MemoryOS 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of MemoryOS might be problematic. Click here for more details.
- {memoryos-0.2.0.dist-info → memoryos-0.2.1.dist-info}/METADATA +66 -26
- {memoryos-0.2.0.dist-info → memoryos-0.2.1.dist-info}/RECORD +80 -56
- memoryos-0.2.1.dist-info/entry_points.txt +3 -0
- memos/__init__.py +1 -1
- memos/api/config.py +471 -0
- memos/api/exceptions.py +28 -0
- memos/api/mcp_serve.py +502 -0
- memos/api/product_api.py +35 -0
- memos/api/product_models.py +159 -0
- memos/api/routers/__init__.py +1 -0
- memos/api/routers/product_router.py +358 -0
- memos/chunkers/sentence_chunker.py +8 -2
- memos/cli.py +113 -0
- memos/configs/embedder.py +27 -0
- memos/configs/graph_db.py +83 -2
- memos/configs/llm.py +47 -0
- memos/configs/mem_cube.py +1 -1
- memos/configs/mem_scheduler.py +91 -5
- memos/configs/memory.py +5 -4
- memos/dependency.py +52 -0
- memos/embedders/ark.py +92 -0
- memos/embedders/factory.py +4 -0
- memos/embedders/sentence_transformer.py +8 -2
- memos/embedders/universal_api.py +32 -0
- memos/graph_dbs/base.py +2 -2
- memos/graph_dbs/factory.py +2 -0
- memos/graph_dbs/neo4j.py +331 -122
- memos/graph_dbs/neo4j_community.py +300 -0
- memos/llms/base.py +9 -0
- memos/llms/deepseek.py +54 -0
- memos/llms/factory.py +10 -1
- memos/llms/hf.py +170 -13
- memos/llms/hf_singleton.py +114 -0
- memos/llms/ollama.py +4 -0
- memos/llms/openai.py +67 -1
- memos/llms/qwen.py +63 -0
- memos/llms/vllm.py +153 -0
- memos/mem_cube/general.py +77 -16
- memos/mem_cube/utils.py +102 -0
- memos/mem_os/core.py +131 -41
- memos/mem_os/main.py +93 -11
- memos/mem_os/product.py +1098 -35
- memos/mem_os/utils/default_config.py +352 -0
- memos/mem_os/utils/format_utils.py +1154 -0
- memos/mem_reader/simple_struct.py +5 -5
- memos/mem_scheduler/base_scheduler.py +467 -36
- memos/mem_scheduler/general_scheduler.py +125 -244
- memos/mem_scheduler/modules/base.py +9 -0
- memos/mem_scheduler/modules/dispatcher.py +68 -2
- memos/mem_scheduler/modules/misc.py +39 -0
- memos/mem_scheduler/modules/monitor.py +228 -49
- memos/mem_scheduler/modules/rabbitmq_service.py +317 -0
- memos/mem_scheduler/modules/redis_service.py +32 -22
- memos/mem_scheduler/modules/retriever.py +250 -23
- memos/mem_scheduler/modules/schemas.py +189 -7
- memos/mem_scheduler/mos_for_test_scheduler.py +143 -0
- memos/mem_scheduler/utils.py +51 -2
- memos/mem_user/persistent_user_manager.py +260 -0
- memos/memories/activation/item.py +25 -0
- memos/memories/activation/kv.py +10 -3
- memos/memories/activation/vllmkv.py +219 -0
- memos/memories/factory.py +2 -0
- memos/memories/textual/general.py +7 -5
- memos/memories/textual/tree.py +9 -5
- memos/memories/textual/tree_text_memory/organize/conflict.py +5 -3
- memos/memories/textual/tree_text_memory/organize/manager.py +26 -18
- memos/memories/textual/tree_text_memory/organize/redundancy.py +25 -44
- memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +11 -13
- memos/memories/textual/tree_text_memory/organize/reorganizer.py +73 -51
- memos/memories/textual/tree_text_memory/retrieve/recall.py +0 -1
- memos/memories/textual/tree_text_memory/retrieve/reranker.py +2 -2
- memos/memories/textual/tree_text_memory/retrieve/searcher.py +6 -5
- memos/parsers/markitdown.py +8 -2
- memos/templates/mem_reader_prompts.py +65 -23
- memos/templates/mem_scheduler_prompts.py +96 -47
- memos/templates/tree_reorganize_prompts.py +85 -30
- memos/vec_dbs/base.py +12 -0
- memos/vec_dbs/qdrant.py +46 -20
- {memoryos-0.2.0.dist-info → memoryos-0.2.1.dist-info}/LICENSE +0 -0
- {memoryos-0.2.0.dist-info → memoryos-0.2.1.dist-info}/WHEEL +0 -0
|
@@ -1,45 +1,241 @@
|
|
|
1
|
-
import
|
|
2
|
-
|
|
1
|
+
from datetime import datetime
|
|
3
2
|
from typing import Any
|
|
4
3
|
|
|
4
|
+
from memos.configs.mem_scheduler import BaseSchedulerConfig
|
|
5
|
+
from memos.llms.base import BaseLLM
|
|
5
6
|
from memos.log import get_logger
|
|
6
7
|
from memos.mem_cube.general import GeneralMemCube
|
|
7
8
|
from memos.mem_scheduler.modules.base import BaseSchedulerModule
|
|
9
|
+
from memos.mem_scheduler.modules.misc import AutoDroppingQueue as Queue
|
|
10
|
+
from memos.mem_scheduler.modules.schemas import (
|
|
11
|
+
DEFAULT_ACTIVATION_MEM_MONITOR_SIZE_LIMIT,
|
|
12
|
+
DEFAULT_WORKING_MEM_MONITOR_SIZE_LIMIT,
|
|
13
|
+
MONITOR_ACTIVATION_MEMORY_TYPE,
|
|
14
|
+
MONITOR_WORKING_MEMORY_TYPE,
|
|
15
|
+
MemCubeID,
|
|
16
|
+
MemoryMonitorManager,
|
|
17
|
+
UserID,
|
|
18
|
+
)
|
|
8
19
|
from memos.mem_scheduler.utils import extract_json_dict
|
|
9
|
-
from memos.memories.textual.tree import TreeTextMemory
|
|
20
|
+
from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory
|
|
10
21
|
|
|
11
22
|
|
|
12
23
|
logger = get_logger(__name__)
|
|
13
24
|
|
|
14
25
|
|
|
15
26
|
class SchedulerMonitor(BaseSchedulerModule):
|
|
16
|
-
|
|
27
|
+
"""Monitors and manages scheduling operations with LLM integration."""
|
|
28
|
+
|
|
29
|
+
def __init__(self, process_llm: BaseLLM, config: BaseSchedulerConfig):
    """Store the LLM and configuration and create empty monitoring state.

    Args:
        process_llm: LLM kept on the instance as ``_process_llm``.
        config: Scheduler configuration; ``act_mem_update_interval`` and
            ``context_window_size`` are read from it here.
    """
    super().__init__()

    # --- hyper-parameters -------------------------------------------------
    self.config: BaseSchedulerConfig = config
    # Falls back to 300 when the config carries no explicit interval.
    self.act_mem_update_interval = self.config.get("act_mem_update_interval", 300)

    # --- partial retention strategy ---------------------------------------
    self.partial_retention_number = 2
    self.working_mem_monitor_capacity = DEFAULT_WORKING_MEM_MONITOR_SIZE_LIMIT
    self.activation_mem_monitor_capacity = DEFAULT_ACTIVATION_MEM_MONITOR_SIZE_LIMIT

    # --- runtime attributes -----------------------------------------------
    # Both history queues are bounded by the configured context window size.
    self.query_history = Queue(maxsize=self.config.context_window_size)
    self.intent_history = Queue(maxsize=self.config.context_window_size)
    # Registries are nested as user id -> mem cube id -> manager.
    self.working_memory_monitors: dict[UserID, dict[MemCubeID, MemoryMonitorManager]] = {}
    self.activation_memory_monitors: dict[UserID, dict[MemCubeID, MemoryMonitorManager]] = {}

    # --- lifecycle --------------------------------------------------------
    # datetime.min marks "never updated".
    self._last_activation_mem_update_time = datetime.min

    self._process_llm = process_llm
|
|
51
|
+
|
|
52
|
+
def register_memory_manager_if_not_exists(
    self,
    user_id: str,
    mem_cube_id: str,
    memory_monitors: dict[UserID, dict[MemCubeID, MemoryMonitorManager]],
    max_capacity: int,
) -> None:
    """Ensure ``memory_monitors`` holds a manager for this user / mem cube pair.

    When no entry exists for ``(user_id, mem_cube_id)`` a new
    MemoryMonitorManager is created and stored in the nested dictionary;
    otherwise the dictionary is left untouched. Either outcome is logged.

    Args:
        user_id: The ID of the user to associate with the memory manager.
        mem_cube_id: The ID of the memory cube to monitor.
        memory_monitors: Registry (user id -> mem cube id -> manager) that is
            inspected and, if needed, mutated in place.
        max_capacity: Capacity handed to a newly created manager.
    """
    already_registered = (user_id in memory_monitors) and (
        mem_cube_id in memory_monitors[user_id]
    )
    if already_registered:
        logger.info(
            f"MemoryMonitorManager already exists for user_id={user_id}, "
            f"mem_cube_id={mem_cube_id} in the provided memory_monitors dictionary"
        )
        return

    # Build the manager first so a failing constructor leaves the registry untouched.
    new_manager = MemoryMonitorManager(
        user_id=user_id, mem_cube_id=mem_cube_id, max_capacity=max_capacity
    )

    # setdefault creates the per-user level on demand.
    cubes_of_user = memory_monitors.setdefault(user_id, {})
    cubes_of_user[mem_cube_id] = new_manager
    logger.info(
        f"Registered new MemoryMonitorManager for user_id={user_id},"
        f" mem_cube_id={mem_cube_id} with max_capacity={max_capacity}"
    )
|
|
92
|
+
|
|
93
|
+
def update_memory_monitors(self, user_id: str, mem_cube_id: str, mem_cube: GeneralMemCube):
    """Refresh the working and activation memory monitors of one mem cube.

    Only cubes whose ``text_mem`` is a TreeTextMemory are handled; for any
    other type an error is logged and nothing is updated.

    Args:
        user_id: Owner of the monitors.
        mem_cube_id: Identifier of the memory cube.
        mem_cube: Cube whose textual memory is inspected.
    """
    text_mem_base = mem_cube.text_mem
    if not isinstance(text_mem_base, TreeTextMemory):
        logger.error("Not Implemented")
        return

    # Capacity tracks the current WorkingMemory size plus the retention
    # slack, but never exceeds the default limit.
    retained_size = (
        text_mem_base.memory_manager.memory_size["WorkingMemory"]
        + self.partial_retention_number
    )
    self.working_mem_monitor_capacity = min(
        DEFAULT_WORKING_MEM_MONITOR_SIZE_LIMIT, retained_size
    )

    # Working monitors are refreshed first; the activation refresh reads them.
    target = {"user_id": user_id, "mem_cube_id": mem_cube_id, "mem_cube": mem_cube}
    self.update_working_memory_monitors(**target)
    self.update_activation_memory_monitors(**target)
|
|
115
|
+
|
|
116
|
+
def update_working_memory_monitors(
    self, user_id: str, mem_cube_id: str, mem_cube: GeneralMemCube
):
    """Feed the cube's current working memory texts into its working monitor.

    Args:
        user_id: Owner of the monitor.
        mem_cube_id: Identifier of the memory cube.
        mem_cube: Cube whose ``text_mem.get_working_memory()`` is read.
    """
    # Make sure a manager exists before it is indexed below.
    self.register_memory_manager_if_not_exists(
        user_id=user_id,
        mem_cube_id=mem_cube_id,
        memory_monitors=self.working_memory_monitors,
        max_capacity=self.working_mem_monitor_capacity,
    )

    # Snapshot the working memory and keep only the text of each item.
    current_items: list[TextualMemoryItem] = mem_cube.text_mem.get_working_memory()
    current_texts: list[str] = [item.memory for item in current_items]

    working_monitor = self.working_memory_monitors[user_id][mem_cube_id]
    working_monitor.update_memories(
        text_working_memories=current_texts,
        partial_retention_number=self.partial_retention_number,
    )
|
|
137
|
+
|
|
138
|
+
def update_activation_memory_monitors(
    self, user_id: str, mem_cube_id: str, mem_cube: GeneralMemCube
):
    """Promote the highest-scoring working monitor entries to the activation monitor.

    Args:
        user_id: Owner of the monitors.
        mem_cube_id: Identifier of the memory cube.
        mem_cube: Accepted for signature parity with the working-memory
            update; it is not read in this method.
    """
    self.register_memory_manager_if_not_exists(
        user_id=user_id,
        mem_cube_id=mem_cube_id,
        memory_monitors=self.activation_memory_monitors,
        max_capacity=self.activation_mem_monitor_capacity,
    )

    # Rank the working monitor's entries by score, best first, and keep
    # at most as many as the activation monitor can hold.
    ranked = list(self.working_memory_monitors[user_id][mem_cube_id].memories)
    ranked.sort(key=lambda entry: entry.get_score(), reverse=True)
    selected = ranked[: self.activation_mem_monitor_capacity]

    # Only the text of each selected entry is forwarded.
    selected_texts = [entry.memory_text for entry in selected]

    activation_monitor = self.activation_memory_monitors[user_id][mem_cube_id]
    activation_monitor.update_memories(
        text_working_memories=selected_texts,
        partial_retention_number=self.partial_retention_number,
    )
|
|
164
|
+
|
|
165
|
+
def timed_trigger(self, last_time: datetime, interval_seconds: float) -> bool:
    """Tell whether at least ``interval_seconds`` have passed since ``last_time``.

    Args:
        last_time: Reference timestamp, compared against ``datetime.now()``.
        interval_seconds: Minimum number of seconds that must have elapsed.

    Returns:
        True when the interval has elapsed; otherwise False, after logging
        the elapsed time at debug level.
    """
    elapsed = (datetime.now() - last_time).total_seconds()
    is_due = elapsed >= interval_seconds
    if not is_due:
        logger.debug(f"Time trigger not ready, {elapsed:.1f}s elapsed (needs {interval_seconds}s)")
    return is_due
|
|
172
|
+
|
|
173
|
+
def get_monitor_memories(
    self,
    user_id: str,
    mem_cube_id: str,
    memory_type: str = MONITOR_WORKING_MEMORY_TYPE,
    top_k: int = 10,
) -> list[str]:
    """Return the texts of the most frequently recorded monitored memories.

    Args:
        user_id: Unique identifier of the user.
        mem_cube_id: Unique identifier of the memory cube.
        memory_type: Which registry to read: MONITOR_WORKING_MEMORY_TYPE or
            MONITOR_ACTIVATION_MEMORY_TYPE.
        top_k: Maximum number of memory texts to return (default: 10).

    Returns:
        Memory texts ordered by ``recording_count``, highest first. An empty
        list is returned (and a warning logged) when ``memory_type`` is not
        recognised or no manager is registered for the given ids.
    """
    # Candidates are checked in order, so the working registry wins ties.
    candidates = (
        (MONITOR_WORKING_MEMORY_TYPE, self.working_memory_monitors),
        (MONITOR_ACTIVATION_MEMORY_TYPE, self.activation_memory_monitors),
    )
    monitor_dict = next(
        (registry for type_name, registry in candidates if memory_type == type_name),
        None,
    )
    if monitor_dict is None:
        logger.warning(f"Invalid memory type: {memory_type}")
        return []

    has_manager = user_id in monitor_dict and mem_cube_id in monitor_dict[user_id]
    if not has_manager:
        logger.warning(
            f"MemoryMonitorManager not found for user {user_id}, "
            f"mem_cube {mem_cube_id}, type {memory_type}"
        )
        return []

    # Order by recording_count, most recorded first, then keep the first top_k texts.
    by_count = sorted(
        monitor_dict[user_id][mem_cube_id].memories,
        key=lambda entry: entry.recording_count,
        reverse=True,
    )
    return [entry.memory_text for entry in by_count[:top_k]]
|
|
214
|
+
|
|
215
|
+
def get_monitors_info(self, user_id: str, mem_cube_id: str) -> dict[str, Any]:
    """Retrieve monitoring information for a specific memory cube.

    Args:
        user_id: Unique identifier of the user.
        mem_cube_id: Unique identifier of the memory cube.

    Returns:
        A dict with one entry per registered monitor, keyed by
        MONITOR_WORKING_MEMORY_TYPE / MONITOR_ACTIVATION_MEMORY_TYPE. Each
        entry holds ``user_id``, ``mem_cube_id``, ``memory_count``,
        ``max_capacity`` and ``top_memories`` (the single most recorded
        memory text of that monitor). An empty dict is returned, with a
        warning, when no working memory monitor is registered for the ids.
    """
    if (
        user_id not in self.working_memory_monitors
        or mem_cube_id not in self.working_memory_monitors[user_id]
    ):
        logger.warning(
            f"MemoryMonitorManager not found for user {user_id}, mem_cube {mem_cube_id}"
        )
        return {}

    # Keying by memory type keeps both entries. Both managers are
    # MemoryMonitorManager instances, so keying by str(type(manager)) made
    # the second entry silently overwrite the first.
    registries = (
        (MONITOR_WORKING_MEMORY_TYPE, self.working_memory_monitors),
        (MONITOR_ACTIVATION_MEMORY_TYPE, self.activation_memory_monitors),
    )

    info_dict: dict[str, Any] = {}
    for memory_type, registry in registries:
        manager = registry.get(user_id, {}).get(mem_cube_id)
        if manager is None:
            # The activation monitor may not be registered yet; skip it
            # instead of raising KeyError.
            continue
        info_dict[memory_type] = {
            "user_id": user_id,
            "mem_cube_id": mem_cube_id,
            "memory_count": manager.memory_size,
            "max_capacity": manager.max_capacity,
            # Read the top memory from the monitor this entry describes.
            "top_memories": self.get_monitor_memories(
                user_id, mem_cube_id, memory_type=memory_type, top_k=1
            ),
        }
    return info_dict
|
|
43
239
|
|
|
44
240
|
def detect_intent(
|
|
45
241
|
self,
|
|
@@ -55,28 +251,11 @@ class SchedulerMonitor(BaseSchedulerModule):
|
|
|
55
251
|
q_list=q_list,
|
|
56
252
|
working_memory_list=text_working_memory,
|
|
57
253
|
)
|
|
58
|
-
response = self.
|
|
59
|
-
response = extract_json_dict(response)
|
|
60
|
-
return response
|
|
61
|
-
|
|
62
|
-
def update_freq(
|
|
63
|
-
self,
|
|
64
|
-
answer: str,
|
|
65
|
-
activation_memory_freq_list: list[dict],
|
|
66
|
-
prompt_name="freq_detecting",
|
|
67
|
-
) -> list[dict]:
|
|
68
|
-
"""
|
|
69
|
-
Use LLM to detect which memories in activation_memory_freq_list appear in the answer,
|
|
70
|
-
increment their count by 1, and return the updated list.
|
|
71
|
-
"""
|
|
72
|
-
prompt = self.build_prompt(
|
|
73
|
-
template_name=prompt_name,
|
|
74
|
-
answer=answer,
|
|
75
|
-
activation_memory_freq_list=activation_memory_freq_list,
|
|
76
|
-
)
|
|
77
|
-
response = self._chat_llm.generate([{"role": "user", "content": prompt}])
|
|
254
|
+
response = self._process_llm.generate([{"role": "user", "content": prompt}])
|
|
78
255
|
try:
|
|
79
|
-
|
|
256
|
+
response = extract_json_dict(response)
|
|
257
|
+
assert ("trigger_retrieval" in response) and ("missing_evidences" in response)
|
|
80
258
|
except Exception:
|
|
81
|
-
|
|
82
|
-
|
|
259
|
+
logger.error(f"Fail to extract json dict from response: {response}")
|
|
260
|
+
response = {"trigger_retrieval": False, "missing_evidences": q_list}
|
|
261
|
+
return response
|
|
@@ -0,0 +1,317 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import ssl
|
|
3
|
+
import threading
|
|
4
|
+
import time
|
|
5
|
+
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from queue import Queue
|
|
8
|
+
|
|
9
|
+
from memos.configs.mem_scheduler import AuthConfig, RabbitMQConfig
|
|
10
|
+
from memos.dependency import require_python_package
|
|
11
|
+
from memos.log import get_logger
|
|
12
|
+
from memos.mem_scheduler.modules.base import BaseSchedulerModule
|
|
13
|
+
from memos.mem_scheduler.modules.schemas import DIRECT_EXCHANGE_TYPE, FANOUT_EXCHANGE_TYPE
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
logger = get_logger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class RabbitMQSchedulerModule(BaseSchedulerModule):
    """Scheduler module that talks to RabbitMQ through pika.

    The connection is a pika ``SelectConnection`` whose IOLoop is started in
    a daemon thread by ``initialize_rabbitmq``. Exchange, queue and binding
    are declared through the chain of ``on_rabbitmq_*`` callbacks once the
    connection and channel open.
    """

    @require_python_package(
        import_name="pika",
        install_command="pip install pika",
        install_link="https://pika.readthedocs.io/en/stable/index.html",
    )
    def __init__(self):
        """
        Initialize RabbitMQ connection settings.
        """
        super().__init__()

        # RabbitMQ settings
        self.rabbitmq_config: RabbitMQConfig | None = None
        self.rabbit_queue_name = "memos-scheduler"
        self.rabbitmq_exchange_name = "memos-fanout"
        self.rabbitmq_exchange_type = FANOUT_EXCHANGE_TYPE
        self.rabbitmq_connection = None
        self.rabbitmq_channel = None

        # fixed params
        self.rabbitmq_message_cache_max_size = 10  # Max 10 messages
        self.rabbitmq_message_cache = Queue(maxsize=self.rabbitmq_message_cache_max_size)
        self.rabbitmq_connection_attempts = 3  # Max retry attempts on connection failure
        self.rabbitmq_retry_delay = 5  # Delay (seconds) between retries
        self.rabbitmq_heartbeat = 60  # Heartbeat interval (seconds) for connection
        self.rabbitmq_conn_max_waiting_seconds = 30
        self.rabbitmq_conn_sleep_seconds = 1

        # Thread management
        # _io_loop_thread is the attribute the rest of the class reads and
        # writes; it must exist before initialize_rabbitmq() runs so that
        # rabbitmq_close() is safe on a never-connected instance.
        self._io_loop_thread = None  # For IOLoop execution
        self._rabbitmq_io_loop_thread = None  # Kept for backward compatibility
        self._rabbitmq_stop_flag = False  # Graceful shutdown flag
        self._rabbitmq_lock = threading.Lock()  # Ensure thread safety

    def is_rabbitmq_connected(self) -> bool:
        """Check if RabbitMQ connection is alive.

        Returns:
            True only when both the connection and the channel exist and
            report ``is_open``.
        """
        # bool() so callers get a real boolean, not None or a pika object.
        return bool(
            self.rabbitmq_connection
            and self.rabbitmq_connection.is_open
            and self.rabbitmq_channel
            and self.rabbitmq_channel.is_open
        )

    def initialize_rabbitmq(
        self, config: dict | None | RabbitMQConfig = None, config_path: str | Path | None = None
    ):
        """
        Establish connection to RabbitMQ using pika.

        Args:
            config: A RabbitMQConfig, a dict accepted by ``AuthConfig.from_dict``,
                or None to load the configuration from a YAML file.
            config_path: YAML file to load when ``config`` is None. When it is
                also None, the default AuthConfig file is used if it exists.

        On an unusable configuration an error is logged and no connection
        attempt is made.
        """
        from pika.adapters.select_connection import SelectConnection

        if config is None:
            if config_path is None and AuthConfig.default_config_exists():
                auth_config = AuthConfig.from_local_yaml()
            elif config_path is not None and Path(config_path).exists():
                # The None check avoids Path(None), which raises TypeError.
                auth_config = AuthConfig.from_local_yaml(config_path=config_path)
            else:
                logger.error("Fail to initialize auth_config")
                return
            self.rabbitmq_config = auth_config.rabbitmq
        elif isinstance(config, RabbitMQConfig):
            self.rabbitmq_config = config
        elif isinstance(config, dict):
            self.rabbitmq_config = AuthConfig.from_dict(config).rabbitmq
        else:
            logger.error("Not implemented")
            # Do not try to connect with an unsupported configuration type.
            return

        # Start connection process
        parameters = self.get_rabbitmq_connection_param()
        self.rabbitmq_connection = SelectConnection(
            parameters,
            on_open_callback=self.on_rabbitmq_connection_open,
            on_open_error_callback=self.on_rabbitmq_connection_error,
            on_close_callback=self.on_rabbitmq_connection_closed,
        )

        # Start IOLoop in dedicated thread
        self._io_loop_thread = threading.Thread(
            target=self.rabbitmq_connection.ioloop.start, daemon=True
        )
        self._io_loop_thread.start()
        logger.info("RabbitMQ connection process started")

    def get_rabbitmq_queue_size(self) -> int:
        """Get the current number of messages in the queue.

        Returns:
            int: Number of messages in the queue.
                 Returns -1 if there's an error or no active connection,
                 including when the exchange type is not direct.
        """
        if self.rabbitmq_exchange_type != DIRECT_EXCHANGE_TYPE:
            logger.warning("Queue size can only be checked for direct exchanges")
            # -1 matches the documented error value and the int annotation.
            return -1

        with self._rabbitmq_lock:
            if not self.is_rabbitmq_connected():
                logger.warning("No active connection to check queue size")
                return -1

            # Declare queue passively (only checks existence, doesn't create)
            # Using passive=True prevents accidental queue creation
            # NOTE(review): the channel comes from an asynchronous
            # SelectConnection; confirm queue_declare returns a result
            # synchronously here rather than only through a callback.
            result = self.rabbitmq_channel.queue_declare(
                queue=self.rabbit_queue_name,
                durable=True,  # Match the original queue durability setting
                passive=True,  # Only check queue existence, don't create
            )

            if result is None:
                return 0
            # Return the message count from the queue declaration result
            return result.method.message_count

    def get_rabbitmq_connection_param(self):
        """Build the pika ConnectionParameters from ``self.rabbitmq_config``.

        Returns:
            A ``pika.ConnectionParameters`` instance. When the configured port
            is 5671, SSL options are attached.
        """
        import pika

        credentials = pika.PlainCredentials(
            username=self.rabbitmq_config.user_name,
            password=self.rabbitmq_config.password,
            erase_on_connect=self.rabbitmq_config.erase_on_connect,
        )
        # Arguments shared by the plain and the TLS connection.
        params = {
            "host": self.rabbitmq_config.host_name,
            "port": self.rabbitmq_config.port,
            "virtual_host": self.rabbitmq_config.virtual_host,
            "credentials": credentials,
            "connection_attempts": self.rabbitmq_connection_attempts,
            "retry_delay": self.rabbitmq_retry_delay,
            "heartbeat": self.rabbitmq_heartbeat,
        }
        if self.rabbitmq_config.port == 5671:
            context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
            # NOTE(security): hostname checking and certificate verification
            # are both disabled, so the broker's identity is not validated.
            # Kept as-is to preserve behavior; review before production use.
            context.check_hostname = False
            context.verify_mode = ssl.CERT_NONE
            params["ssl_options"] = pika.SSLOptions(context)
        return pika.ConnectionParameters(**params)

    # Connection lifecycle callbacks
    def on_rabbitmq_connection_open(self, connection):
        """Called when connection is established; requests a channel."""
        logger.debug("Connection opened")
        connection.channel(on_open_callback=self.on_rabbitmq_channel_open)

    def on_rabbitmq_connection_error(self, connection, error):
        """Called if connection fails to open; triggers a reconnect."""
        logger.error(f"Connection failed: {error}")
        self.rabbit_reconnect()

    def on_rabbitmq_connection_closed(self, connection, reason):
        """Called when connection closes; reconnects unless shutting down."""
        logger.warning(f"Connection closed: {reason}")
        if not self._rabbitmq_stop_flag:
            self.rabbit_reconnect()

    # Channel lifecycle callbacks
    def on_rabbitmq_channel_open(self, channel):
        """Called when channel is ready; stores it and declares the exchange."""
        self.rabbitmq_channel = channel
        logger.debug("Channel opened")

        # Setup exchange and queue
        channel.exchange_declare(
            exchange=self.rabbitmq_exchange_name,
            exchange_type=self.rabbitmq_exchange_type,
            durable=True,
            callback=self.on_rabbitmq_exchange_declared,
        )

    def on_rabbitmq_exchange_declared(self, frame):
        """Called when exchange is ready; declares the durable queue."""
        self.rabbitmq_channel.queue_declare(
            queue=self.rabbit_queue_name, durable=True, callback=self.on_rabbitmq_queue_declared
        )

    def on_rabbitmq_queue_declared(self, frame):
        """Called when queue is ready; binds it to the exchange."""
        self.rabbitmq_channel.queue_bind(
            exchange=self.rabbitmq_exchange_name,
            queue=self.rabbit_queue_name,
            routing_key=self.rabbit_queue_name,
            callback=self.on_rabbitmq_bind_ok,
        )

    def on_rabbitmq_bind_ok(self, frame):
        """Final setup step when bind is complete."""
        logger.info("RabbitMQ setup completed")

    def on_rabbitmq_message(self, channel, method, properties, body):
        """Handle incoming messages. Only for test.

        The message is cached and then acknowledged. If caching fails (for
        example ``put_nowait`` raising because the bounded cache is full) the
        error is logged and the message is not acknowledged.
        """
        try:
            print(f"Received message: {body.decode()}")
            self.rabbitmq_message_cache.put_nowait({"properties": properties, "body": body})
            print(f"message delivery_tag: {method.delivery_tag}")
            channel.basic_ack(delivery_tag=method.delivery_tag)
        except Exception as e:
            logger.error(f"Message handling failed: {e}")

    def wait_for_connection_ready(self) -> bool:
        """Block until the connection and channel are open, or time out.

        While not connected, a reconnect is requested and the thread sleeps
        ``rabbitmq_conn_sleep_seconds`` between checks.

        Returns:
            True once connected; False if ``rabbitmq_conn_max_waiting_seconds``
            elapsed first.
        """
        start_time = time.time()
        while not self.is_rabbitmq_connected():
            delta_time = time.time() - start_time
            if delta_time > self.rabbitmq_conn_max_waiting_seconds:
                logger.error("Failed to start consuming: Connection timeout")
                return False
            self.rabbit_reconnect()
            time.sleep(self.rabbitmq_conn_sleep_seconds)  # Reduced frequency of checks
        # Explicit success value so callers can tell ready from timed out.
        return True

    # Message handling
    def rabbitmq_start_consuming(self):
        """Start consuming messages asynchronously.

        Does nothing when the connection does not become ready in time (the
        timeout is already logged by ``wait_for_connection_ready``).
        """
        if not self.wait_for_connection_ready():
            # The channel may be None here; calling basic_consume would fail.
            return

        self.rabbitmq_channel.basic_consume(
            queue=self.rabbit_queue_name,
            on_message_callback=self.on_rabbitmq_message,
            auto_ack=False,
        )
        logger.info("Started rabbitmq consuming messages")

    def rabbitmq_publish_message(self, message: dict):
        """
        Publish a message to RabbitMQ.

        Args:
            message: JSON-serializable payload; sent as ``json.dumps(message)``.

        Returns:
            True when the publish call succeeded; False when there is no
            active connection or publishing raised (a reconnect is then
            requested).
        """
        import pika

        with self._rabbitmq_lock:
            if not self.is_rabbitmq_connected():
                logger.error("Cannot publish - no active connection")
                return False

            try:
                self.rabbitmq_channel.basic_publish(
                    exchange=self.rabbitmq_exchange_name,
                    routing_key=self.rabbit_queue_name,
                    body=json.dumps(message),
                    properties=pika.BasicProperties(
                        delivery_mode=2,  # Persistent
                    ),
                    mandatory=True,
                )
                logger.debug(f"Published message: {message}")
                return True
            except Exception as e:
                logger.error(f"Failed to publish message: {e}")
                self.rabbit_reconnect()
                return False

    # Connection management
    def rabbit_reconnect(self):
        """Schedule reconnection attempt.

        Stops the current IOLoop if the connection is not closed, clears the
        channel, and re-initializes with the configuration already in use.
        """
        logger.info("Attempting to reconnect...")
        if self.rabbitmq_connection and not self.rabbitmq_connection.is_closed:
            self.rabbitmq_connection.ioloop.stop()

        # Reset connection state
        self.rabbitmq_channel = None
        # Reuse the active configuration; with no argument an explicitly
        # supplied config was dropped in favour of the default YAML lookup.
        # When rabbitmq_config is None this is the same call as before.
        self.initialize_rabbitmq(config=self.rabbitmq_config)

    def rabbitmq_close(self):
        """Gracefully shutdown connection.

        Sets the stop flag (so the close callback does not reconnect), closes
        channel and connection, stops the IOLoop and waits up to 5 seconds for
        the IOLoop thread. Errors during shutdown are logged as warnings.
        """
        with self._rabbitmq_lock:
            self._rabbitmq_stop_flag = True

            # Close channel if open
            if self.rabbitmq_channel and self.rabbitmq_channel.is_open:
                try:
                    self.rabbitmq_channel.close()
                except Exception as e:
                    logger.warning(f"Error closing channel: {e}")

            # Close connection if open
            if self.rabbitmq_connection:
                if self.rabbitmq_connection.is_open:
                    try:
                        self.rabbitmq_connection.close()
                    except Exception as e:
                        logger.warning(f"Error closing connection: {e}")

                # Stop IOLoop if running
                try:
                    self.rabbitmq_connection.ioloop.stop()
                except Exception as e:
                    logger.warning(f"Error stopping IOLoop: {e}")

            # Wait for IOLoop thread to finish
            if self._io_loop_thread and self._io_loop_thread.is_alive():
                self._io_loop_thread.join(timeout=5)
                if self._io_loop_thread.is_alive():
                    logger.warning("IOLoop thread did not terminate cleanly")

        logger.info("RabbitMQ connection closed")
|