betterdb-agent-memory 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- betterdb_agent_memory/__init__.py +88 -0
- betterdb_agent_memory/_num.py +51 -0
- betterdb_agent_memory/agent_memory.py +79 -0
- betterdb_agent_memory/build_memory_index.py +60 -0
- betterdb_agent_memory/build_memory_record.py +68 -0
- betterdb_agent_memory/build_recall_query.py +65 -0
- betterdb_agent_memory/composite_score.py +35 -0
- betterdb_agent_memory/discovery.py +164 -0
- betterdb_agent_memory/memory_store.py +964 -0
- betterdb_agent_memory/parse_memory_item.py +34 -0
- betterdb_agent_memory/select_evictions.py +54 -0
- betterdb_agent_memory/telemetry.py +164 -0
- betterdb_agent_memory/types.py +132 -0
- betterdb_agent_memory-0.1.0.dist-info/METADATA +135 -0
- betterdb_agent_memory-0.1.0.dist-info/RECORD +16 -0
- betterdb_agent_memory-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
"""betterdb-agent-memory: long-term vector memory tier for AI agents on Valkey.
|
|
2
|
+
|
|
3
|
+
Re-exports everything from ``betterdb-agent-cache`` (the short-term cache tiers)
|
|
4
|
+
alongside the memory tier so the facade is a single import.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import betterdb_agent_cache as _agent_cache
|
|
10
|
+
from betterdb_agent_cache import * # noqa: F401,F403
|
|
11
|
+
|
|
12
|
+
from .agent_memory import AgentMemory, AgentMemoryOptions
|
|
13
|
+
from .build_recall_query import MATCH_ALL_MEMORY_QUERY
|
|
14
|
+
from .composite_score import (
|
|
15
|
+
composite_score,
|
|
16
|
+
recency_decay,
|
|
17
|
+
similarity_from_distance,
|
|
18
|
+
)
|
|
19
|
+
from .discovery import MEMORY_CACHE_TYPE, MEMORY_CAPABILITIES, MemoryDiscovery
|
|
20
|
+
from .memory_store import MemoryStore
|
|
21
|
+
from .telemetry import (
|
|
22
|
+
DEFAULT_METRICS_PREFIX,
|
|
23
|
+
DEFAULT_TRACER_NAME,
|
|
24
|
+
MemoryMetrics,
|
|
25
|
+
MemoryTelemetry,
|
|
26
|
+
MemoryTelemetryOptions,
|
|
27
|
+
create_memory_telemetry,
|
|
28
|
+
)
|
|
29
|
+
from .types import (
|
|
30
|
+
AgentMemoryConfig,
|
|
31
|
+
AgentMemoryRecallConfig,
|
|
32
|
+
ConsolidateResult,
|
|
33
|
+
EmbedFn,
|
|
34
|
+
MemoryConfigRefreshConfig,
|
|
35
|
+
MemoryConfigSnapshot,
|
|
36
|
+
MemoryDiscoveryConfig,
|
|
37
|
+
MemoryHit,
|
|
38
|
+
MemoryItem,
|
|
39
|
+
MemoryListOptions,
|
|
40
|
+
MemoryListResult,
|
|
41
|
+
MemoryScope,
|
|
42
|
+
MemoryStats,
|
|
43
|
+
MemoryStoreClient,
|
|
44
|
+
RecallWeights,
|
|
45
|
+
SummarizeFn,
|
|
46
|
+
)
|
|
47
|
+
|
|
48
|
+
__all__ = [
    # Memory tier
    "AgentMemory",
    "AgentMemoryOptions",
    "AgentMemoryConfig",
    "AgentMemoryRecallConfig",
    "MemoryStore",
    "MemoryDiscovery",
    "MEMORY_CACHE_TYPE",
    "MEMORY_CAPABILITIES",
    # Telemetry
    "create_memory_telemetry",
    "DEFAULT_METRICS_PREFIX",
    "DEFAULT_TRACER_NAME",
    "MemoryTelemetry",
    "MemoryTelemetryOptions",
    "MemoryMetrics",
    # Scoring
    "composite_score",
    "similarity_from_distance",
    "recency_decay",
    # Types
    "EmbedFn",
    "MemoryStoreClient",
    "MemoryScope",
    "MemoryItem",
    "MemoryHit",
    "MemoryListOptions",
    "MemoryListResult",
    "MemoryStats",
    "ConsolidateResult",
    "SummarizeFn",
    "RecallWeights",
    "MemoryConfigSnapshot",
    "MemoryDiscoveryConfig",
    "MemoryConfigRefreshConfig",
    "MATCH_ALL_MEMORY_QUERY",
]

# Surface everything agent-cache exports so consumers need only one import.
# Mirror what ``from betterdb_agent_cache import *`` actually binds: the
# module's ``__all__`` when it declares one, otherwise every public
# (non-underscore) name. Names already listed above (for example, if
# agent-cache exports a name the memory tier also exports) are skipped so
# ``__all__`` stays free of duplicates.
_cache_exports = getattr(_agent_cache, "__all__", None)
if _cache_exports is None:
    _cache_exports = [n for n in dir(_agent_cache) if not n.startswith("_")]
__all__ += [n for n in dict.fromkeys(_cache_exports) if n not in __all__]
del _cache_exports
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
import re
|
|
5
|
+
|
|
6
|
+
# Leading-number grammars for the parseFloat/parseInt shims. ``re.ASCII`` keeps
# ``\d`` to [0-9]: JavaScript only accepts ASCII digits, whereas Python's
# Unicode ``\d`` would also match e.g. Arabic-Indic digits.
_FLOAT_RE = re.compile(r"^[+-]?(Infinity|(\d+\.?\d*|\.\d+)([eE][+-]?\d+)?)", re.ASCII)
_INT_RE = re.compile(r"^[+-]?\d+", re.ASCII)

# Whole-string grammar for JS ``Number()`` decimal literals. Python's float() is
# looser (it accepts "1_000", and "inf"/"nan"/"infinity" in any letter case),
# so text is checked against this before being handed to float().
_JS_DECIMAL_RE = re.compile(
    r"[+-]?(?:Infinity|(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?)", re.ASCII
)


def _coerce_str(value: object) -> str:
    """Coerce a valkey reply value to text the way JS ``String()`` would.

    valkey-py can hand back ``bytes`` for raw replies, so decode those instead of
    stringifying them (``str(b'1')`` is ``"b'1'"``, which would parse as NaN).
    """
    if isinstance(value, bytes):
        return value.decode()
    return str(value)


def js_number(value: object) -> float:
    """Mimic JavaScript ``Number(value)`` for the reply values we see.

    - ``None`` becomes ``NaN``.
    - Python floats are returned as-is (as a plain ``float``), so ``inf`` and
      ``nan`` survive.
    - Anything else is coerced to text (bytes are decoded) and stripped:
      empty/whitespace-only text becomes ``0`` (as in JS); text matching the JS
      decimal-literal grammar (optional sign, digits with optional fraction and
      exponent, or ``Infinity``) is parsed; everything else becomes ``NaN``.

    Unlike Python's ``float()``, underscores and ``inf``/``nan`` spellings are
    rejected. Hex/octal/binary literals (accepted by JS) are not supported and
    also yield ``NaN``.
    """
    if value is None:
        return math.nan
    if isinstance(value, float):
        # Already numeric: a str()/float() round-trip would produce "inf"/"nan",
        # which the JS grammar below (correctly) rejects for text input.
        return float(value)
    text = _coerce_str(value).strip()
    if text == "":
        return 0.0
    if _JS_DECIMAL_RE.fullmatch(text) is None:
        return math.nan
    return float(text)


def parse_float(value: object) -> float:
    """Mimic JavaScript ``parseFloat``: parse the leading numeric portion, else NaN.

    A leading ``Infinity`` (optionally signed) parses to an infinity, as in JS.
    """
    if value is None:
        return math.nan
    match = _FLOAT_RE.match(_coerce_str(value).strip())
    return float(match.group()) if match else math.nan


def parse_int(value: object) -> int | float:
    """Mimic JavaScript ``parseInt(value, 10)``: parse the leading integer, else NaN.

    Returns an ``int`` on success and the float ``NaN`` otherwise.
    """
    if value is None:
        return math.nan
    match = _INT_RE.match(_coerce_str(value).strip())
    return int(match.group()) if match else math.nan
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
|
|
6
|
+
from betterdb_agent_cache import AgentCache, AgentCacheOptions
|
|
7
|
+
|
|
8
|
+
from .memory_store import MemoryStore
|
|
9
|
+
from .telemetry import MemoryTelemetryOptions
|
|
10
|
+
from .types import AgentMemoryConfig, EmbedFn, _empty_memory_config
|
|
11
|
+
|
|
12
|
+
DEFAULT_NAME = "betterdb_ac"


@dataclass(kw_only=True)
class AgentMemoryOptions(AgentCacheOptions):
    """Configuration for the :class:`AgentMemory` facade.

    Inherits every :class:`AgentCacheOptions` setting used by the short-term
    llm/tool/session cache tiers (``AgentMemory`` reads ``client``, ``name``
    and ``telemetry`` from it) and adds what the long-term memory tier needs.
    """

    # Callable used to vectorize memory content; required, no default.
    embed_fn: EmbedFn
    # Memory-tier sub-config; built by ``_empty_memory_config`` for each instance.
    memory: AgentMemoryConfig = field(default_factory=_empty_memory_config)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class AgentMemory:
    """Batteries-included facade: agent cache tiers plus a long-term memory tier.

    Exposes the ``llm``, ``tool`` and ``session`` tiers of an underlying
    :class:`AgentCache` and a :class:`MemoryStore` as ``memory``. Call
    :meth:`initialize` to create the memory index and start discovery;
    :meth:`close` tears both tiers down.
    """

    def __init__(self, options: AgentMemoryOptions) -> None:
        """Build both tiers from one shared options object.

        Raises:
            ValueError: if ``options.embed_fn`` is missing or not callable.
        """
        if not callable(getattr(options, "embed_fn", None)):
            raise ValueError("AgentMemory requires an embed_fn to back the memory tier")

        # Both tiers take their key prefix from the same ``options.name``, so
        # the cache and memory prefixes cannot drift apart.
        prefix = options.name
        cache = AgentCache(options)
        self._cache = cache
        self.llm, self.tool, self.session = cache.llm, cache.tool, cache.session

        mem_cfg = options.memory
        recall_cfg = mem_cfg.recall
        registry = options.telemetry.registry
        telemetry = MemoryTelemetryOptions(registry=registry) if registry else None

        self.memory = MemoryStore(
            client=options.client,
            name=prefix,
            embed_fn=options.embed_fn,
            default_threshold=mem_cfg.default_threshold,
            weights=None if recall_cfg is None else recall_cfg.weights,
            half_life_seconds=None if recall_cfg is None else recall_cfg.half_life_seconds,
            max_items_per_scope=mem_cfg.max_items_per_scope,
            # As the batteries-included product, the facade discovers the
            # memory tier alongside the cache tiers unless ``memory.discovery``
            # disables it; the setting is forwarded unchanged.
            discovery=mem_cfg.discovery,
            config_refresh=mem_cfg.config_refresh,
            telemetry=telemetry,
        )

    async def initialize(self) -> None:
        """Create the memory index, then start discovery and config refresh.

        The FT index is created first so the facade can remember/recall as soon
        as this returns; a creation failure propagates because the memory tier
        is unusable without the index. Discovery for both tiers and the memory
        config-refresh startup then run concurrently, and the first error (such
        as a discovery name collision) is raised rather than swallowed.
        """
        await self.memory.ensure_index()
        startup = (
            self._cache.ensure_discovery_ready(),
            self.memory.ensure_discovery_ready(),
            self.memory.ensure_config_refresh_started(),
        )
        await asyncio.gather(*startup)

    async def close(self) -> None:
        """Close the memory tier, then shut down the cache tier.

        The cache shutdown runs even if closing the memory store raises, so
        timers and heartbeats from neither tier are leaked.
        """
        try:
            await self.memory.close()
        finally:
            await self._cache.shutdown()
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from .build_recall_query import VECTOR_FIELD
|
|
4
|
+
|
|
5
|
+
MEMORY_INDEX_ALGORITHM = "HNSW"


def memory_index_name(name: str) -> str:
    """Return the FT index name of the memory store *name*: ``{name}:mem:idx``."""
    return memory_key_prefix(name) + "idx"


def memory_key_prefix(name: str) -> str:
    """Return the key prefix indexed for the memory store *name*: ``{name}:mem:``."""
    return f"{name}:mem:"
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def build_memory_index_args(name: str, dims: int) -> list[str]:
    """Build the argument list for creating the memory FT index.

    The index covers HASH keys under ``memory_key_prefix(name)``. Its schema is
    an HNSW FLOAT32 cosine vector field of *dims* dimensions, TAG fields for
    scope/tags/source, NUMERIC fields for importance and timestamps/counters
    (``created_at`` sortable), and ``content`` as TEXT.

    Raises:
        ValueError: if *dims* is not a positive ``int`` (``bool`` is rejected).
    """
    is_positive_int = isinstance(dims, int) and not isinstance(dims, bool) and dims > 0
    if not is_positive_int:
        raise ValueError(f"memory index dimension must be a positive integer, got: {dims}")

    vector_params = ["TYPE", "FLOAT32", "DIM", str(dims), "DISTANCE_METRIC", "COSINE"]

    args = [memory_index_name(name), "ON", "HASH", "PREFIX", "1", memory_key_prefix(name)]
    # The vector field declares how many attribute tokens follow it.
    args += ["SCHEMA", VECTOR_FIELD, "VECTOR", MEMORY_INDEX_ALGORITHM, str(len(vector_params))]
    args += vector_params
    # Scope fields.
    args += ["threadId", "TAG", "agentId", "TAG", "namespace", "TAG"]
    # Tags are written comma-joined, hence the explicit separator.
    args += ["tags", "TAG", "SEPARATOR", ","]
    args += ["source", "TAG"]
    args += ["importance", "NUMERIC", "created_at", "NUMERIC", "SORTABLE"]
    args += ["last_accessed_at", "NUMERIC", "access_count", "NUMERIC"]
    args += ["content", "TEXT"]
    return args
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
|
|
6
|
+
from betterdb_valkey_search_kit import encode_float32
|
|
7
|
+
|
|
8
|
+
# Importance given to a memory record when the caller supplies none.
DEFAULT_IMPORTANCE = 0.5


@dataclass
class MemoryWrite:
    """A memory record ready to be written: its hash ``key`` plus field data.

    ``fields`` is a flat list alternating field names and values; the ``bytes``
    alternative presumably carries the encoded vector.
    """

    key: str
    fields: list[str | bytes]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def build_memory_record(
    name: str,
    id: str,
    content: str,
    vector: list[float],
    *,
    importance: float | None = None,
    tags: list[str] | None = None,
    source: str | None = None,
    thread_id: str | None = None,
    agent_id: str | None = None,
    namespace: str | None = None,
    now: int,
) -> MemoryWrite:
    """Build the key and field list for writing a new memory record.

    Args:
        name: Store name; the record key is ``{name}:mem:{id}``.
        id: Record identifier appended to the key.
        content: Text of the memory.
        vector: Embedding of *content*, stored via ``encode_float32``.
        importance: Number in ``[0, 1]``; ``DEFAULT_IMPORTANCE`` when ``None``.
        tags: Optional tags, stored comma-joined (so no tag may contain a comma).
            The field is omitted when there are no tags.
        source, thread_id, agent_id, namespace: Optional fields, each written
            only when not ``None``.
        now: Timestamp written to both ``created_at`` and ``last_accessed_at``.

    Returns:
        A :class:`MemoryWrite`; ``access_count`` starts at ``"0"``.

    Raises:
        ValueError: if *importance* is not a finite real number in ``[0, 1]``
            (``bool`` is rejected), or if any tag contains a comma.
    """
    imp = importance if importance is not None else DEFAULT_IMPORTANCE
    # bool is an int subclass: without the explicit check True/False would pass
    # the range test and be written as the strings "True"/"False" into the
    # NUMERIC importance field.
    if (
        isinstance(imp, bool)
        or not isinstance(imp, (int, float))
        or not math.isfinite(imp)
        or not 0 <= imp <= 1
    ):
        raise ValueError(f"importance must be a finite number in [0, 1], got: {importance}")

    fields: list[str | bytes] = [
        "content",
        content,
        "vector",
        encode_float32(vector),
        "importance",
        str(imp),
        "created_at",
        str(now),
        "last_accessed_at",
        str(now),
        "access_count",
        "0",
    ]

    tag_list = tags if tags is not None else []
    for tag in tag_list:
        # Tags are stored as one comma-separated field, so a comma inside a tag
        # would split it into two tags.
        if "," in tag:
            raise ValueError(
                f"Tag '{tag}' must not contain a comma; tags are stored comma-separated"
            )
    if tag_list:
        fields.extend(["tags", ",".join(tag_list)])

    if thread_id is not None:
        fields.extend(["threadId", thread_id])
    if agent_id is not None:
        fields.extend(["agentId", agent_id])
    if namespace is not None:
        fields.extend(["namespace", namespace])
    if source is not None:
        fields.extend(["source", source])

    return MemoryWrite(key=f"{name}:mem:{id}", fields=fields)
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
|
|
5
|
+
from betterdb_valkey_search_kit import escape_tag
|
|
6
|
+
|
|
7
|
+
from .types import MemoryScope
|
|
8
|
+
|
|
9
|
+
# Alias under which KNN queries return each hit's vector distance.
SCORE_FIELD = "__score"
# Name of the indexed vector field.
VECTOR_FIELD = "vector"

# valkey-search answers a bare '*' on an index with a VECTOR schema with
# "Invalid query string syntax". Every memory record gets created_at at write
# time, so an unbounded numeric range on it selects all documents and is
# accepted by the vector index.
MATCH_ALL_MEMORY_QUERY = "@created_at:[-inf +inf]"


def _scope_clauses(scope: MemoryScope, tags: list[str]) -> list[str]:
    """Return one exact-match TAG clause per scope field that is set, then one per tag."""
    scoped_fields = (
        ("threadId", scope.thread_id),
        ("agentId", scope.agent_id),
        ("namespace", scope.namespace),
    )
    clauses = [
        f"@{field_name}:{{{escape_tag(value)}}}"
        for field_name, value in scoped_fields
        if value is not None
    ]
    clauses.extend(f"@tags:{{{escape_tag(tag)}}}" for tag in tags)
    return clauses


def _join_clauses(clauses: list[str]) -> str:
    """Combine clauses space-separated inside parentheses, or match everything if none."""
    if not clauses:
        return MATCH_ALL_MEMORY_QUERY
    return "(" + " ".join(clauses) + ")"


def build_scope_filter(scope: MemoryScope, tags: list[str]) -> str:
    """Build the filter query restricting results to *scope* and requiring every tag in *tags*."""
    clauses = _scope_clauses(scope, tags)
    return _join_clauses(clauses)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass
class ConsolidateFilterOptions:
    """Optional extra bounds for consolidation queries; ``None`` disables a bound."""

    # Keep only records whose created_at is at most this value (inclusive).
    max_created_at: int | None = None
    # Keep only records whose importance is at most this value (inclusive).
    max_importance: float | None = None
    # Drop records whose source tag matches this value.
    exclude_source: str | None = None


def build_consolidate_filter(
    scope: MemoryScope,
    tags: list[str],
    options: ConsolidateFilterOptions,
) -> str:
    """Build the consolidation filter: scope/tag clauses plus the bounds in *options*."""
    clauses = _scope_clauses(scope, tags)
    upper_bounds = (
        ("created_at", options.max_created_at),
        ("importance", options.max_importance),
    )
    clauses.extend(
        f"@{field_name}:[-inf {limit}]"
        for field_name, limit in upper_bounds
        if limit is not None
    )
    excluded = options.exclude_source
    if excluded is not None:
        clauses.append(f"-@source:{{{escape_tag(excluded)}}}")
    return _join_clauses(clauses)


def build_recall_query(k: int, scope: MemoryScope, tags: list[str]) -> str:
    """Build a KNN query for the *k* nearest records within the scope/tag filter.

    The query vector is bound as the ``$vec`` parameter; each hit's distance is
    returned under ``SCORE_FIELD``.
    """
    knn_clause = f"[KNN {k} @{VECTOR_FIELD} $vec AS {SCORE_FIELD}]"
    return f"{build_scope_filter(scope, tags)}=>{knn_clause}"
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
|
|
5
|
+
from .types import RecallWeights
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def recency_decay(age_seconds: float, half_life_seconds: float) -> float:
    """Return a true half-life decay factor for a record of the given age.

    1 at age 0, 0.5 at one ``half_life_seconds``, approaching 0 beyond. Negative
    ages (e.g. a created_at slightly ahead of the reader's clock) are treated as
    age 0, so the factor never exceeds 1; a NaN age still yields NaN.

    Raises:
        ValueError: if ``half_life_seconds`` is not positive (zero would divide
            by zero; a negative value would turn decay into growth).
    """
    if half_life_seconds <= 0:
        raise ValueError(f"half_life_seconds must be positive, got: {half_life_seconds}")
    # Written as a comparison rather than max() so a NaN age is not turned into 0.
    if age_seconds < 0:
        age_seconds = 0.0
    return math.exp((-math.log(2) * age_seconds) / half_life_seconds)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def composite_score(
    *,
    similarity: float,
    age_seconds: float,
    importance: float,
    weights: RecallWeights,
    half_life_seconds: float,
) -> float:
    """Blend semantic similarity, recency and importance into one ranking score.

    ``weights.similarity * similarity + weights.recency * recency
    + weights.importance * importance``, where recency is
    ``recency_decay(age_seconds, half_life_seconds)`` (1 at age 0, 0.5 after
    one half-life).
    """
    recency = recency_decay(age_seconds, half_life_seconds)
    weighted_similarity = weights.similarity * similarity
    weighted_recency = weights.recency * recency
    weighted_importance = weights.importance * importance
    return weighted_similarity + weighted_recency + weighted_importance
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def similarity_from_distance(distance: float) -> float:
    """Convert a cosine distance into a similarity score.

    Cosine distance runs from 0 (closest) to 2; the mapping is linear:
    1 at distance 0, 0.5 at 1, 0 at 2. Out-of-range inputs are not clamped.
    """
    half_distance = distance / 2
    return 1 - half_distance
|
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import json
|
|
5
|
+
import os
|
|
6
|
+
import socket
|
|
7
|
+
import warnings
|
|
8
|
+
from datetime import datetime, timezone
|
|
9
|
+
from typing import Any, Callable
|
|
10
|
+
|
|
11
|
+
from betterdb_agent_cache.discovery import (
|
|
12
|
+
DEFAULT_HEARTBEAT_INTERVAL_S,
|
|
13
|
+
HEARTBEAT_KEY_PREFIX,
|
|
14
|
+
HEARTBEAT_TTL_SECONDS,
|
|
15
|
+
PROTOCOL_KEY,
|
|
16
|
+
PROTOCOL_VERSION,
|
|
17
|
+
REGISTRY_KEY,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
from .types import MemoryStoreClient
|
|
21
|
+
|
|
22
|
+
MEMORY_CACHE_TYPE = "agent_memory"
MEMORY_CAPABILITIES = ["recall", "consolidate", "reinforce"]


class MemoryDiscovery:
    """Best-effort discovery registration for one memory store.

    Writes a JSON marker describing the store into the registry hash
    (``REGISTRY_KEY``), sets ``PROTOCOL_KEY`` if absent, and keeps a heartbeat
    key (set with an ``HEARTBEAT_TTL_SECONDS`` expiry) refreshed from a
    background task. Failed Valkey commands are not raised; each one is
    reported through ``on_write_failed`` instead.
    """

    def __init__(
        self,
        *,
        client: MemoryStoreClient,
        name: str,
        version: str,
        stats_key: str,
        heartbeat_interval_s: float | None = None,
        on_write_failed: Callable[[], None] | None = None,
    ) -> None:
        """Prepare (but do not start) discovery for the store *name*.

        Args:
            client: Client used for every registry/heartbeat command.
            name: Store name; published as the marker's ``prefix``.
            version: Version string published in the marker.
            stats_key: Key published in the marker as ``stats_key``.
            heartbeat_interval_s: Seconds between heartbeat refreshes;
                ``DEFAULT_HEARTBEAT_INTERVAL_S`` when ``None``.
            on_write_failed: Called with no arguments whenever a discovery
                command fails; defaults to a no-op.
        """
        self._client = client
        self._name = name
        self._version = version
        self._stats_key = stats_key
        self._heartbeat_interval_s = (
            heartbeat_interval_s
            if heartbeat_interval_s is not None
            else DEFAULT_HEARTBEAT_INTERVAL_S
        )
        # Namespace the marker under `{name}:mem` so a memory store and an
        # agent-cache sharing the same name register distinct registry fields
        # and heartbeat keys instead of clobbering each other.
        self._marker_field = f"{name}:mem"
        self._heartbeat_key = f"{HEARTBEAT_KEY_PREFIX}{self._marker_field}"
        self._started_at = datetime.now(timezone.utc).isoformat()
        self._on_write_failed: Callable[[], None] = on_write_failed or (lambda: None)
        self._heartbeat_task: asyncio.Task[None] | None = None

    def build_marker(self) -> dict[str, Any]:
        """Return the JSON-serializable marker published in the registry hash."""
        return {
            "type": MEMORY_CACHE_TYPE,
            "prefix": self._name,
            "version": self._version,
            "protocol_version": PROTOCOL_VERSION,
            "capabilities": list(MEMORY_CAPABILITIES),
            "stats_key": self._stats_key,
            "started_at": self._started_at,
            "pid": os.getpid(),
            "hostname": socket.gethostname(),
        }

    async def register(self) -> None:
        """Publish the marker, protocol key and heartbeat, then start the heartbeat loop.

        Safe to call more than once: only one heartbeat task is kept running.
        """
        # HGET-then-HSET is not atomic (TOCTOU); acceptable for best-effort
        # discovery — a racing writer just means last-writer-wins on the marker.
        existing = await self._safe_hget()
        if existing is not None:
            self._check_collision(existing)
        await self._write_marker()
        await self._safe_call(
            lambda: self._client.execute_command("SET", PROTOCOL_KEY, str(PROTOCOL_VERSION), "NX")
        )
        await self._write_heartbeat()
        self._start_heartbeat()

    async def stop(self, *, delete_heartbeat: bool) -> None:
        """Cancel the heartbeat task and optionally delete the heartbeat key."""
        if self._heartbeat_task is not None:
            self._heartbeat_task.cancel()
            try:
                await self._heartbeat_task
            except (asyncio.CancelledError, Exception):
                pass
            self._heartbeat_task = None
        if not delete_heartbeat:
            return
        try:
            await self._client.execute_command("DEL", self._heartbeat_key)
        except Exception:
            self._on_write_failed()

    async def tick_heartbeat(self) -> None:
        """Refresh the heartbeat key and re-publish the marker."""
        await self._write_heartbeat()
        await self._write_marker()
        # PROTOCOL_KEY is set once in register(); the NX SET is a guaranteed
        # no-op on every subsequent tick, so it's not re-issued here.

    def _start_heartbeat(self) -> None:
        """Start the background heartbeat loop unless one is already running."""
        # register() may run more than once. Creating a second task would
        # overwrite the only reference stop() cancels, leaving the first loop
        # beating forever after shutdown.
        if self._heartbeat_task is not None and not self._heartbeat_task.done():
            return

        async def _loop() -> None:
            try:
                while True:
                    await asyncio.sleep(self._heartbeat_interval_s)
                    try:
                        await self.tick_heartbeat()
                    except Exception:
                        # The write helpers already swallow command errors; this
                        # only triggers if something else raises (e.g. the
                        # on_write_failed callback). Keep beating instead of
                        # letting the loop die silently.
                        pass
            except asyncio.CancelledError:
                pass

        self._heartbeat_task = asyncio.create_task(_loop())

    async def _write_heartbeat(self) -> None:
        """SET the heartbeat key to the current UTC time with a TTL."""
        now = datetime.now(timezone.utc).isoformat()
        try:
            await self._client.execute_command(
                "SET", self._heartbeat_key, now, "EX", str(HEARTBEAT_TTL_SECONDS)
            )
        except Exception:
            self._on_write_failed()

    async def _write_marker(self) -> None:
        """HSET the serialized marker into the registry hash."""
        try:
            payload = json.dumps(self.build_marker())
        except Exception:
            self._on_write_failed()
            return
        await self._safe_call(
            lambda: self._client.execute_command("HSET", REGISTRY_KEY, self._marker_field, payload)
        )

    async def _safe_hget(self) -> str | None:
        """Return the currently registered marker JSON, or ``None`` if absent or on error."""
        try:
            value = await self._client.execute_command("HGET", REGISTRY_KEY, self._marker_field)
            if value is None:
                return None
            return value.decode() if isinstance(value, bytes) else str(value)
        except Exception:
            self._on_write_failed()
            return None

    async def _safe_call(self, fn: Callable[[], Any]) -> None:
        """Await ``fn()``, reporting (not raising) any exception."""
        try:
            await fn()
        except Exception:
            self._on_write_failed()

    def _check_collision(self, existing_json: str) -> None:
        """Warn if the registry field already holds a marker of a different type."""
        try:
            parsed = json.loads(existing_json)
        except Exception:
            return
        existing_type = parsed.get("type") if isinstance(parsed, dict) else None
        if existing_type and existing_type != MEMORY_CACHE_TYPE:
            # The memory marker lives under `{name}:mem`, distinct from
            # agent-cache's `{name}`, so the two tiers never collide here.
            # Surface it with a visible warning rather than raising into a
            # swallowed registration; registration then proceeds
            # last-writer-wins, matching agent-cache.
            warnings.warn(
                f"agent-memory discovery: field '{self._marker_field}' already holds a "
                f"'{existing_type}' marker; overwriting",
                stacklevel=2,
            )
|