MemoryOS 1.0.1__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of MemoryOS might be problematic. Click here for more details.
- {memoryos-1.0.1.dist-info → memoryos-1.1.1.dist-info}/METADATA +7 -2
- {memoryos-1.0.1.dist-info → memoryos-1.1.1.dist-info}/RECORD +79 -65
- {memoryos-1.0.1.dist-info → memoryos-1.1.1.dist-info}/WHEEL +1 -1
- memos/__init__.py +1 -1
- memos/api/client.py +109 -0
- memos/api/config.py +11 -9
- memos/api/context/dependencies.py +15 -55
- memos/api/middleware/request_context.py +9 -40
- memos/api/product_api.py +2 -3
- memos/api/product_models.py +91 -16
- memos/api/routers/product_router.py +23 -16
- memos/api/start_api.py +10 -0
- memos/configs/graph_db.py +4 -0
- memos/configs/mem_scheduler.py +38 -3
- memos/context/context.py +255 -0
- memos/embedders/factory.py +2 -0
- memos/graph_dbs/nebular.py +230 -232
- memos/graph_dbs/neo4j.py +35 -1
- memos/graph_dbs/neo4j_community.py +7 -0
- memos/llms/factory.py +2 -0
- memos/llms/openai.py +74 -2
- memos/log.py +27 -15
- memos/mem_cube/general.py +3 -1
- memos/mem_os/core.py +60 -22
- memos/mem_os/main.py +3 -6
- memos/mem_os/product.py +35 -11
- memos/mem_reader/factory.py +2 -0
- memos/mem_reader/simple_struct.py +127 -74
- memos/mem_scheduler/analyzer/__init__.py +0 -0
- memos/mem_scheduler/analyzer/mos_for_test_scheduler.py +569 -0
- memos/mem_scheduler/analyzer/scheduler_for_eval.py +280 -0
- memos/mem_scheduler/base_scheduler.py +126 -56
- memos/mem_scheduler/general_modules/dispatcher.py +2 -2
- memos/mem_scheduler/general_modules/misc.py +99 -1
- memos/mem_scheduler/general_modules/scheduler_logger.py +17 -11
- memos/mem_scheduler/general_scheduler.py +40 -88
- memos/mem_scheduler/memory_manage_modules/__init__.py +5 -0
- memos/mem_scheduler/memory_manage_modules/memory_filter.py +308 -0
- memos/mem_scheduler/{general_modules → memory_manage_modules}/retriever.py +34 -7
- memos/mem_scheduler/monitors/dispatcher_monitor.py +9 -8
- memos/mem_scheduler/monitors/general_monitor.py +119 -39
- memos/mem_scheduler/optimized_scheduler.py +124 -0
- memos/mem_scheduler/orm_modules/__init__.py +0 -0
- memos/mem_scheduler/orm_modules/base_model.py +635 -0
- memos/mem_scheduler/orm_modules/monitor_models.py +261 -0
- memos/mem_scheduler/scheduler_factory.py +2 -0
- memos/mem_scheduler/schemas/monitor_schemas.py +96 -29
- memos/mem_scheduler/utils/config_utils.py +100 -0
- memos/mem_scheduler/utils/db_utils.py +33 -0
- memos/mem_scheduler/utils/filter_utils.py +1 -1
- memos/mem_scheduler/webservice_modules/__init__.py +0 -0
- memos/memories/activation/kv.py +2 -1
- memos/memories/textual/item.py +95 -16
- memos/memories/textual/naive.py +1 -1
- memos/memories/textual/tree.py +27 -3
- memos/memories/textual/tree_text_memory/organize/handler.py +4 -2
- memos/memories/textual/tree_text_memory/organize/manager.py +28 -14
- memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +1 -2
- memos/memories/textual/tree_text_memory/organize/reorganizer.py +75 -23
- memos/memories/textual/tree_text_memory/retrieve/bochasearch.py +7 -5
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever.py +6 -2
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py +2 -0
- memos/memories/textual/tree_text_memory/retrieve/recall.py +70 -22
- memos/memories/textual/tree_text_memory/retrieve/searcher.py +101 -33
- memos/memories/textual/tree_text_memory/retrieve/xinyusearch.py +5 -4
- memos/memos_tools/singleton.py +174 -0
- memos/memos_tools/thread_safe_dict.py +22 -0
- memos/memos_tools/thread_safe_dict_segment.py +382 -0
- memos/parsers/factory.py +2 -0
- memos/reranker/concat.py +59 -0
- memos/reranker/cosine_local.py +1 -0
- memos/reranker/factory.py +5 -0
- memos/reranker/http_bge.py +225 -12
- memos/templates/mem_scheduler_prompts.py +242 -0
- memos/types.py +4 -1
- memos/api/context/context.py +0 -147
- memos/api/context/context_thread.py +0 -96
- memos/mem_scheduler/mos_for_test_scheduler.py +0 -146
- {memoryos-1.0.1.dist-info → memoryos-1.1.1.dist-info}/entry_points.txt +0 -0
- {memoryos-1.0.1.dist-info → memoryos-1.1.1.dist-info/licenses}/LICENSE +0 -0
- /memos/mem_scheduler/{general_modules → webservice_modules}/rabbitmq_service.py +0 -0
- /memos/mem_scheduler/{general_modules → webservice_modules}/redis_service.py +0 -0
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Singleton decorator module for caching factory instances to avoid excessive memory usage
|
|
3
|
+
from repeated initialization.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import hashlib
|
|
7
|
+
import json
|
|
8
|
+
|
|
9
|
+
from collections.abc import Callable
|
|
10
|
+
from functools import wraps
|
|
11
|
+
from typing import Any, TypeVar
|
|
12
|
+
from weakref import WeakValueDictionary
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
T = TypeVar("T")
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class FactorySingleton:
|
|
19
|
+
"""Factory singleton manager that caches instances based on configuration parameters"""
|
|
20
|
+
|
|
21
|
+
def __init__(self):
|
|
22
|
+
# Use weak reference dictionary for automatic cleanup when instances are no longer referenced
|
|
23
|
+
self._instances: dict[str, WeakValueDictionary] = {}
|
|
24
|
+
|
|
25
|
+
def _generate_cache_key(self, config: Any, *args, **kwargs) -> str:
|
|
26
|
+
"""Generate cache key based on configuration only (ignoring other parameters)"""
|
|
27
|
+
|
|
28
|
+
# Handle configuration objects - only use the config parameter
|
|
29
|
+
if hasattr(config, "model_dump"): # Pydantic model
|
|
30
|
+
config_data = config.model_dump()
|
|
31
|
+
elif hasattr(config, "dict"): # Legacy Pydantic model
|
|
32
|
+
config_data = config.dict()
|
|
33
|
+
elif isinstance(config, dict):
|
|
34
|
+
config_data = config
|
|
35
|
+
else:
|
|
36
|
+
# For other types, try to convert to string
|
|
37
|
+
config_data = str(config)
|
|
38
|
+
|
|
39
|
+
# Filter out time-related fields that shouldn't affect caching
|
|
40
|
+
filtered_config = self._filter_temporal_fields(config_data)
|
|
41
|
+
|
|
42
|
+
# Generate hash key based only on config
|
|
43
|
+
try:
|
|
44
|
+
cache_str = json.dumps(filtered_config, sort_keys=True, ensure_ascii=False, default=str)
|
|
45
|
+
except (TypeError, ValueError):
|
|
46
|
+
# If JSON serialization fails, convert the entire config to string
|
|
47
|
+
cache_str = str(filtered_config)
|
|
48
|
+
|
|
49
|
+
return hashlib.md5(cache_str.encode("utf-8")).hexdigest()
|
|
50
|
+
|
|
51
|
+
def _filter_temporal_fields(self, config_data: Any) -> Any:
|
|
52
|
+
"""Filter out temporal fields that shouldn't affect instance caching"""
|
|
53
|
+
if isinstance(config_data, dict):
|
|
54
|
+
filtered = {}
|
|
55
|
+
for key, value in config_data.items():
|
|
56
|
+
# Skip common temporal field names
|
|
57
|
+
if key.lower() in {
|
|
58
|
+
"created_at",
|
|
59
|
+
"updated_at",
|
|
60
|
+
"timestamp",
|
|
61
|
+
"time",
|
|
62
|
+
"date",
|
|
63
|
+
"created_time",
|
|
64
|
+
"updated_time",
|
|
65
|
+
"last_modified",
|
|
66
|
+
"modified_at",
|
|
67
|
+
"start_time",
|
|
68
|
+
"end_time",
|
|
69
|
+
"execution_time",
|
|
70
|
+
"run_time",
|
|
71
|
+
}:
|
|
72
|
+
continue
|
|
73
|
+
# Recursively filter nested dictionaries
|
|
74
|
+
filtered[key] = self._filter_temporal_fields(value)
|
|
75
|
+
return filtered
|
|
76
|
+
elif isinstance(config_data, list):
|
|
77
|
+
# Recursively filter lists
|
|
78
|
+
return [self._filter_temporal_fields(item) for item in config_data]
|
|
79
|
+
else:
|
|
80
|
+
# For primitive types, return as-is
|
|
81
|
+
return config_data
|
|
82
|
+
|
|
83
|
+
def get_or_create(self, factory_class: type, cache_key: str, creator_func: Callable) -> Any:
|
|
84
|
+
"""Get or create instance"""
|
|
85
|
+
class_name = factory_class.__name__
|
|
86
|
+
|
|
87
|
+
if class_name not in self._instances:
|
|
88
|
+
self._instances[class_name] = WeakValueDictionary()
|
|
89
|
+
|
|
90
|
+
class_cache = self._instances[class_name]
|
|
91
|
+
|
|
92
|
+
if cache_key in class_cache:
|
|
93
|
+
return class_cache[cache_key]
|
|
94
|
+
|
|
95
|
+
# Create new instance
|
|
96
|
+
instance = creator_func()
|
|
97
|
+
class_cache[cache_key] = instance
|
|
98
|
+
return instance
|
|
99
|
+
|
|
100
|
+
def clear_cache(self, factory_class: type | None = None):
|
|
101
|
+
"""Clear cache"""
|
|
102
|
+
if factory_class:
|
|
103
|
+
class_name = factory_class.__name__
|
|
104
|
+
if class_name in self._instances:
|
|
105
|
+
self._instances[class_name].clear()
|
|
106
|
+
else:
|
|
107
|
+
for cache in self._instances.values():
|
|
108
|
+
cache.clear()
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
# Global singleton manager
|
|
112
|
+
_factory_singleton = FactorySingleton()
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def singleton_factory(factory_class: type | str | None = None):
    """
    Factory singleton decorator

    Usage:
        @singleton_factory()
        def from_config(cls, config):
            return SomeClass(config)

    Or specify factory class:
        @singleton_factory(EmbedderFactory)
        def from_config(cls, config):
            return SomeClass(config)
    """

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        @wraps(func)
        def wrapper(*args, **kwargs) -> T:
            resolved_factory = factory_class
            config = None

            if args:
                first = args[0]
                # Heuristic: a class (the `cls` of a @classmethod) exposes both
                # __name__ and __module__, while a config object normally doesn't.
                if hasattr(first, "__name__") and hasattr(first, "__module__"):
                    # @classmethod call: cls is first, config is second (if any).
                    if resolved_factory is None:
                        resolved_factory = first
                    config = args[1] if len(args) > 1 else None
                else:
                    # @staticmethod call: the first argument is the config itself.
                    if resolved_factory is None:
                        raise ValueError(
                            "Factory class must be explicitly specified for static methods"
                        )
                    if isinstance(resolved_factory, str):
                        # Wrap the string in a stub class so the cache can key
                        # on a __name__ attribute like it does for real classes.
                        stub_name = resolved_factory

                        class _NamedStub:
                            __name__ = stub_name

                        resolved_factory = _NamedStub
                    config = first

            # Without a configuration there is nothing to key on: bypass caching.
            if config is None:
                return func(*args, **kwargs)

            # Cache key depends only on the configuration contents.
            cache_key = _factory_singleton._generate_cache_key(config)

            # Defer actual construction until the cache confirms a miss.
            return _factory_singleton.get_or_create(
                resolved_factory, cache_key, lambda: func(*args, **kwargs)
            )

        return wrapper

    return decorator
|
|
@@ -7,10 +7,15 @@ import threading
|
|
|
7
7
|
from collections.abc import ItemsView, Iterator, KeysView, ValuesView
|
|
8
8
|
from typing import Generic, TypeVar
|
|
9
9
|
|
|
10
|
+
from memos.log import get_logger
|
|
11
|
+
from memos.utils import timed
|
|
12
|
+
|
|
10
13
|
|
|
11
14
|
K = TypeVar("K")
|
|
12
15
|
V = TypeVar("V")
|
|
13
16
|
|
|
17
|
+
logger = get_logger(__name__)
|
|
18
|
+
|
|
14
19
|
|
|
15
20
|
class ReadWriteLock:
|
|
16
21
|
"""A simple read-write lock implementation. use for product-server scenario"""
|
|
@@ -19,6 +24,7 @@ class ReadWriteLock:
|
|
|
19
24
|
self._read_ready = threading.Condition(threading.RLock())
|
|
20
25
|
self._readers = 0
|
|
21
26
|
|
|
27
|
+
@timed
|
|
22
28
|
def acquire_read(self):
|
|
23
29
|
"""Acquire a read lock. Multiple readers can hold the lock simultaneously."""
|
|
24
30
|
self._read_ready.acquire()
|
|
@@ -37,6 +43,7 @@ class ReadWriteLock:
|
|
|
37
43
|
finally:
|
|
38
44
|
self._read_ready.release()
|
|
39
45
|
|
|
46
|
+
@timed
|
|
40
47
|
def acquire_write(self):
|
|
41
48
|
"""Acquire a write lock. Only one writer can hold the lock."""
|
|
42
49
|
self._read_ready.acquire()
|
|
@@ -67,6 +74,7 @@ class ThreadSafeDict(Generic[K, V]):
|
|
|
67
74
|
self._dict: dict[K, V] = initial_dict.copy() if initial_dict else {}
|
|
68
75
|
self._lock = ReadWriteLock()
|
|
69
76
|
|
|
77
|
+
@timed
|
|
70
78
|
def __getitem__(self, key: K) -> V:
|
|
71
79
|
"""Get item by key."""
|
|
72
80
|
self._lock.acquire_read()
|
|
@@ -75,6 +83,7 @@ class ThreadSafeDict(Generic[K, V]):
|
|
|
75
83
|
finally:
|
|
76
84
|
self._lock.release_read()
|
|
77
85
|
|
|
86
|
+
@timed
|
|
78
87
|
def __setitem__(self, key: K, value: V) -> None:
|
|
79
88
|
"""Set item by key."""
|
|
80
89
|
self._lock.acquire_write()
|
|
@@ -83,6 +92,7 @@ class ThreadSafeDict(Generic[K, V]):
|
|
|
83
92
|
finally:
|
|
84
93
|
self._lock.release_write()
|
|
85
94
|
|
|
95
|
+
@timed
|
|
86
96
|
def __delitem__(self, key: K) -> None:
|
|
87
97
|
"""Delete item by key."""
|
|
88
98
|
self._lock.acquire_write()
|
|
@@ -91,6 +101,7 @@ class ThreadSafeDict(Generic[K, V]):
|
|
|
91
101
|
finally:
|
|
92
102
|
self._lock.release_write()
|
|
93
103
|
|
|
104
|
+
@timed
|
|
94
105
|
def __contains__(self, key: K) -> bool:
|
|
95
106
|
"""Check if key exists in dictionary."""
|
|
96
107
|
self._lock.acquire_read()
|
|
@@ -99,6 +110,7 @@ class ThreadSafeDict(Generic[K, V]):
|
|
|
99
110
|
finally:
|
|
100
111
|
self._lock.release_read()
|
|
101
112
|
|
|
113
|
+
@timed
|
|
102
114
|
def __len__(self) -> int:
|
|
103
115
|
"""Get length of dictionary."""
|
|
104
116
|
self._lock.acquire_read()
|
|
@@ -115,6 +127,7 @@ class ThreadSafeDict(Generic[K, V]):
|
|
|
115
127
|
finally:
|
|
116
128
|
self._lock.release_read()
|
|
117
129
|
|
|
130
|
+
@timed
|
|
118
131
|
def __iter__(self) -> Iterator[K]:
|
|
119
132
|
"""Iterate over keys. Returns a snapshot to avoid iteration issues."""
|
|
120
133
|
self._lock.acquire_read()
|
|
@@ -124,6 +137,7 @@ class ThreadSafeDict(Generic[K, V]):
|
|
|
124
137
|
finally:
|
|
125
138
|
self._lock.release_read()
|
|
126
139
|
|
|
140
|
+
@timed
|
|
127
141
|
def get(self, key: K, default: V | None = None) -> V:
|
|
128
142
|
"""Get item by key with optional default."""
|
|
129
143
|
self._lock.acquire_read()
|
|
@@ -132,6 +146,7 @@ class ThreadSafeDict(Generic[K, V]):
|
|
|
132
146
|
finally:
|
|
133
147
|
self._lock.release_read()
|
|
134
148
|
|
|
149
|
+
@timed
|
|
135
150
|
def pop(self, key: K, *args) -> V:
|
|
136
151
|
"""Pop item by key."""
|
|
137
152
|
self._lock.acquire_write()
|
|
@@ -140,6 +155,7 @@ class ThreadSafeDict(Generic[K, V]):
|
|
|
140
155
|
finally:
|
|
141
156
|
self._lock.release_write()
|
|
142
157
|
|
|
158
|
+
@timed
|
|
143
159
|
def update(self, *args, **kwargs) -> None:
|
|
144
160
|
"""Update dictionary."""
|
|
145
161
|
self._lock.acquire_write()
|
|
@@ -148,6 +164,7 @@ class ThreadSafeDict(Generic[K, V]):
|
|
|
148
164
|
finally:
|
|
149
165
|
self._lock.release_write()
|
|
150
166
|
|
|
167
|
+
@timed
|
|
151
168
|
def clear(self) -> None:
|
|
152
169
|
"""Clear all items."""
|
|
153
170
|
self._lock.acquire_write()
|
|
@@ -156,6 +173,7 @@ class ThreadSafeDict(Generic[K, V]):
|
|
|
156
173
|
finally:
|
|
157
174
|
self._lock.release_write()
|
|
158
175
|
|
|
176
|
+
@timed
|
|
159
177
|
def keys(self) -> KeysView[K]:
|
|
160
178
|
"""Get dictionary keys view (snapshot)."""
|
|
161
179
|
self._lock.acquire_read()
|
|
@@ -164,6 +182,7 @@ class ThreadSafeDict(Generic[K, V]):
|
|
|
164
182
|
finally:
|
|
165
183
|
self._lock.release_read()
|
|
166
184
|
|
|
185
|
+
@timed
|
|
167
186
|
def values(self) -> ValuesView[V]:
|
|
168
187
|
"""Get dictionary values view (snapshot)."""
|
|
169
188
|
self._lock.acquire_read()
|
|
@@ -172,6 +191,7 @@ class ThreadSafeDict(Generic[K, V]):
|
|
|
172
191
|
finally:
|
|
173
192
|
self._lock.release_read()
|
|
174
193
|
|
|
194
|
+
@timed
|
|
175
195
|
def items(self) -> ItemsView[K, V]:
|
|
176
196
|
"""Get dictionary items view (snapshot)."""
|
|
177
197
|
self._lock.acquire_read()
|
|
@@ -180,6 +200,7 @@ class ThreadSafeDict(Generic[K, V]):
|
|
|
180
200
|
finally:
|
|
181
201
|
self._lock.release_read()
|
|
182
202
|
|
|
203
|
+
@timed
|
|
183
204
|
def copy(self) -> dict[K, V]:
|
|
184
205
|
"""Create a copy of the dictionary."""
|
|
185
206
|
self._lock.acquire_read()
|
|
@@ -188,6 +209,7 @@ class ThreadSafeDict(Generic[K, V]):
|
|
|
188
209
|
finally:
|
|
189
210
|
self._lock.release_read()
|
|
190
211
|
|
|
212
|
+
@timed
|
|
191
213
|
def setdefault(self, key: K, default: V | None = None) -> V:
|
|
192
214
|
"""Set default value for key if not exists."""
|
|
193
215
|
self._lock.acquire_write()
|
|
@@ -0,0 +1,382 @@
|
|
|
1
|
+
import threading
|
|
2
|
+
import time
|
|
3
|
+
|
|
4
|
+
from collections.abc import Iterator
|
|
5
|
+
from contextlib import contextmanager
|
|
6
|
+
from typing import Any, Generic, TypeVar
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
K = TypeVar("K")
|
|
10
|
+
V = TypeVar("V")
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class FastReadWriteLock:
    """Read-write lock optimized for FastAPI scenarios:
    reader priority with writer starvation prevention.

    Readers normally proceed without blocking each other. Once a writer has
    been waiting longer than ``_write_starvation_threshold`` seconds, newly
    arriving readers block until that writer gets a turn, preventing a
    constant stream of readers from starving writers forever.
    """

    def __init__(self):
        self._readers = 0  # number of readers currently holding the lock
        self._writers = 0  # 1 while a writer holds the lock, else 0
        self._waiting_writers = 0  # writers currently blocked in acquire_write
        self._lock = threading.RLock()
        self._read_ready = threading.Condition(self._lock)
        self._write_ready = threading.Condition(self._lock)
        # Writer starvation detection
        self._last_write_time = 0
        self._write_starvation_threshold = 0.1  # 100ms

    def _writer_starving(self) -> bool:
        """Return True when some writer has waited past the starvation threshold.

        Must be called with ``self._lock`` held. Hoisted into a helper because
        the original computed this expression in three separate places.
        """
        return (
            self._waiting_writers > 0
            and time.time() - self._last_write_time > self._write_starvation_threshold
        )

    def acquire_read(self) -> bool:
        """Fast read lock acquisition.

        Blocks only while a writer is active or a waiting writer is starving;
        the starvation check is re-evaluated after every wakeup.
        """
        with self._lock:
            while self._writers > 0 or self._writer_starving():
                self._read_ready.wait()
            self._readers += 1
            return True

    def release_read(self):
        """Release read lock; wakes one waiting writer when the last reader leaves."""
        with self._lock:
            self._readers -= 1
            if self._readers == 0:
                self._write_ready.notify()

    def acquire_write(self) -> bool:
        """Write lock acquisition; blocks until no readers or writer is active."""
        with self._lock:
            self._waiting_writers += 1
            try:
                while self._readers > 0 or self._writers > 0:
                    self._write_ready.wait()
                self._writers = 1
                self._last_write_time = time.time()
                return True
            finally:
                # Runs on both success and interruption (e.g. an exception
                # raised during wait). Replaces the original bare ``except:``
                # clause that duplicated this decrement on two paths.
                self._waiting_writers -= 1

    def release_write(self):
        """Release write lock."""
        with self._lock:
            self._writers = 0
            # Prioritize notifying readers (reader priority strategy)
            self._read_ready.notify_all()
            self._write_ready.notify()
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class SegmentedLock:
    """Segmented lock: maps each key, by hash, to one of ``segment_count``
    independent read-write locks so unrelated keys don't contend."""

    def __init__(self, segment_count: int = 64):
        self.segment_count = segment_count
        self.locks = [FastReadWriteLock() for _ in range(segment_count)]

    def get_lock(self, key: K) -> FastReadWriteLock:
        """Return the lock guarding the segment that *key* hashes into."""
        return self.locks[hash(key) % self.segment_count]

    @contextmanager
    def read_lock(self, key: K):
        """Context manager holding the read side of *key*'s segment lock."""
        segment_lock = self.get_lock(key)
        segment_lock.acquire_read()
        try:
            yield
        finally:
            segment_lock.release_read()

    @contextmanager
    def write_lock(self, key: K):
        """Context manager holding the write side of *key*'s segment lock."""
        segment_lock = self.get_lock(key)
        segment_lock.acquire_write()
        try:
            yield
        finally:
            segment_lock.release_write()
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class OptimizedThreadSafeDict(Generic[K, V]):
    """
    Thread-safe dictionary optimized for FastAPI scenarios:
    - Segmented locks to reduce contention
    - Reader priority with writer starvation prevention
    - Support for large object storage
    - Strong consistency guarantee
    """

    def __init__(
        self, initial_dict: dict[K, V] | None = None, segment_count: int = 128
    ):  # More segments for high concurrency
        self._segments: list[dict[K, V]] = [{} for _ in range(segment_count)]
        self._segment_count = segment_count
        self._segmented_lock = SegmentedLock(segment_count)

        # Seed initial data; no locking needed since no other thread can
        # observe the instance before __init__ returns.
        if initial_dict:
            for k, v in initial_dict.items():
                self._segments[self._get_segment(k)][k] = v

    def _get_segment(self, key: K) -> int:
        """Calculate the segment index corresponding to the key."""
        return hash(key) % self._segment_count

    @contextmanager
    def _all_locks(self, *, write: bool = False):
        """Hold every segment lock (read mode by default, write when ``write=True``).

        Whole-dict operations (len/keys/values/items/copy/stats/clear) need a
        consistent global snapshot or exclusive access. Locks are always taken
        in ascending segment order — so concurrent whole-dict calls cannot
        deadlock — and released in reverse order. Extracted as one helper to
        replace six hand-rolled copies of the same acquire/release boilerplate.
        """
        acquired = []
        try:
            for lock in self._segmented_lock.locks:
                if write:
                    lock.acquire_write()
                else:
                    lock.acquire_read()
                acquired.append(lock)
            yield
        finally:
            # Release only what was actually acquired, in reverse order.
            for lock in reversed(acquired):
                if write:
                    lock.release_write()
                else:
                    lock.release_read()

    def __getitem__(self, key: K) -> V:
        """Get element; raises KeyError when the key is absent."""
        segment_idx = self._get_segment(key)
        with self._segmented_lock.read_lock(key):
            return self._segments[segment_idx][key]

    def __setitem__(self, key: K, value: V) -> None:
        """Set element under the key's segment write lock."""
        segment_idx = self._get_segment(key)
        with self._segmented_lock.write_lock(key):
            self._segments[segment_idx][key] = value

    def __delitem__(self, key: K) -> None:
        """Delete element; raises KeyError when the key is absent."""
        segment_idx = self._get_segment(key)
        with self._segmented_lock.write_lock(key):
            del self._segments[segment_idx][key]

    def __contains__(self, key: K) -> bool:
        """Check if key is contained, under the key's segment read lock."""
        segment_idx = self._get_segment(key)
        with self._segmented_lock.read_lock(key):
            return key in self._segments[segment_idx]

    def get(self, key: K, default: V | None = None) -> V | None:
        """Safely get element, returning *default* when the key is absent."""
        segment_idx = self._get_segment(key)
        with self._segmented_lock.read_lock(key):
            return self._segments[segment_idx].get(key, default)

    def pop(self, key: K, *args) -> V:
        """Pop element; an optional second positional argument is the default."""
        segment_idx = self._get_segment(key)
        with self._segmented_lock.write_lock(key):
            return self._segments[segment_idx].pop(key, *args)

    def setdefault(self, key: K, default: V | None = None) -> V:
        """Insert *default* if the key is absent; return the stored value."""
        segment_idx = self._get_segment(key)
        with self._segmented_lock.write_lock(key):
            return self._segments[segment_idx].setdefault(key, default)

    def update(self, other=None, **kwargs) -> None:
        """Batch update - optimized batch operation.

        Items are grouped by segment so each segment lock is acquired exactly
        once, keeping lock holding time short.
        """
        items = (other.items() if hasattr(other, "items") else other) if other is not None else []

        # Group update items by segment
        segment_updates: dict[int, list[tuple[K, V]]] = {}
        for k, v in items:
            segment_updates.setdefault(self._get_segment(k), []).append((k, v))
        for k, v in kwargs.items():
            segment_updates.setdefault(self._get_segment(k), []).append((k, v))

        # Update segment by segment to reduce lock holding time
        for segment_idx, updates in segment_updates.items():
            # All keys in one segment map to the same lock; use the first key.
            first_key = updates[0][0]
            with self._segmented_lock.write_lock(first_key):
                for k, v in updates:
                    self._segments[segment_idx][k] = v

    def clear(self) -> None:
        """Clear all elements; needs exclusive access to every segment."""
        with self._all_locks(write=True):
            for segment in self._segments:
                segment.clear()

    def __len__(self) -> int:
        """Get total length as a consistent snapshot across all segments."""
        with self._all_locks():
            return sum(len(segment) for segment in self._segments)

    def __bool__(self) -> bool:
        """True when the dictionary holds at least one item."""
        return len(self) > 0

    def keys(self) -> list[K]:
        """Get snapshot of all keys."""
        with self._all_locks():
            all_keys: list[K] = []
            for segment in self._segments:
                all_keys.extend(segment.keys())
            return all_keys

    def values(self) -> list[V]:
        """Get snapshot of all values."""
        with self._all_locks():
            all_values: list[V] = []
            for segment in self._segments:
                all_values.extend(segment.values())
            return all_values

    def items(self) -> list[tuple[K, V]]:
        """Get snapshot of all (key, value) pairs."""
        with self._all_locks():
            all_items: list[tuple[K, V]] = []
            for segment in self._segments:
                all_items.extend(segment.items())
            return all_items

    def copy(self) -> dict[K, V]:
        """Create a plain-dict copy taken as one consistent snapshot."""
        with self._all_locks():
            result: dict[K, V] = {}
            for segment in self._segments:
                result.update(segment)
            return result

    def __iter__(self) -> Iterator[K]:
        """Iterator - returns a snapshot of the keys."""
        return iter(self.keys())

    def __repr__(self) -> str:
        """String representation showing the full contents."""
        return f"OptimizedThreadSafeDict({dict(self.items())})"

    def stats(self) -> dict[str, Any]:
        """Get statistics — per-segment sizes, useful for checking hash balance."""
        with self._all_locks():
            segment_sizes = [len(segment) for segment in self._segments]
            total_items = sum(segment_sizes)
            avg_size = total_items / self._segment_count if self._segment_count > 0 else 0
            max_size = max(segment_sizes) if segment_sizes else 0
            min_size = min(segment_sizes) if segment_sizes else 0
            return {
                "total_items": total_items,
                "segment_count": self._segment_count,
                "avg_segment_size": avg_size,
                "max_segment_size": max_size,
                "min_segment_size": min_size,
                "load_balance_ratio": min_size / max_size if max_size > 0 else 1.0,
            }