betterdb-agent-memory 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,88 @@
1
+ """betterdb-agent-memory: long-term vector memory tier for AI agents on Valkey.
2
+
3
+ Re-exports everything from ``betterdb-agent-cache`` (the short-term cache tiers)
4
+ alongside the memory tier so the facade is a single import.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import betterdb_agent_cache as _agent_cache
10
+ from betterdb_agent_cache import * # noqa: F401,F403
11
+
12
+ from .agent_memory import AgentMemory, AgentMemoryOptions
13
+ from .build_recall_query import MATCH_ALL_MEMORY_QUERY
14
+ from .composite_score import (
15
+ composite_score,
16
+ recency_decay,
17
+ similarity_from_distance,
18
+ )
19
+ from .discovery import MEMORY_CACHE_TYPE, MEMORY_CAPABILITIES, MemoryDiscovery
20
+ from .memory_store import MemoryStore
21
+ from .telemetry import (
22
+ DEFAULT_METRICS_PREFIX,
23
+ DEFAULT_TRACER_NAME,
24
+ MemoryMetrics,
25
+ MemoryTelemetry,
26
+ MemoryTelemetryOptions,
27
+ create_memory_telemetry,
28
+ )
29
+ from .types import (
30
+ AgentMemoryConfig,
31
+ AgentMemoryRecallConfig,
32
+ ConsolidateResult,
33
+ EmbedFn,
34
+ MemoryConfigRefreshConfig,
35
+ MemoryConfigSnapshot,
36
+ MemoryDiscoveryConfig,
37
+ MemoryHit,
38
+ MemoryItem,
39
+ MemoryListOptions,
40
+ MemoryListResult,
41
+ MemoryScope,
42
+ MemoryStats,
43
+ MemoryStoreClient,
44
+ RecallWeights,
45
+ SummarizeFn,
46
+ )
47
+
48
+ __all__ = [
49
+ # Memory tier
50
+ "AgentMemory",
51
+ "AgentMemoryOptions",
52
+ "AgentMemoryConfig",
53
+ "AgentMemoryRecallConfig",
54
+ "MemoryStore",
55
+ "MemoryDiscovery",
56
+ "MEMORY_CACHE_TYPE",
57
+ "MEMORY_CAPABILITIES",
58
+ # Telemetry
59
+ "create_memory_telemetry",
60
+ "DEFAULT_METRICS_PREFIX",
61
+ "DEFAULT_TRACER_NAME",
62
+ "MemoryTelemetry",
63
+ "MemoryTelemetryOptions",
64
+ "MemoryMetrics",
65
+ # Scoring
66
+ "composite_score",
67
+ "similarity_from_distance",
68
+ "recency_decay",
69
+ # Types
70
+ "EmbedFn",
71
+ "MemoryStoreClient",
72
+ "MemoryScope",
73
+ "MemoryItem",
74
+ "MemoryHit",
75
+ "MemoryListOptions",
76
+ "MemoryListResult",
77
+ "MemoryStats",
78
+ "ConsolidateResult",
79
+ "SummarizeFn",
80
+ "RecallWeights",
81
+ "MemoryConfigSnapshot",
82
+ "MemoryDiscoveryConfig",
83
+ "MemoryConfigRefreshConfig",
84
+ "MATCH_ALL_MEMORY_QUERY",
85
+ ]
86
+
87
+ # Surface everything agent-cache exports so consumers need only one import.
88
+ __all__ += list(getattr(_agent_cache, "__all__", []))
@@ -0,0 +1,51 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ import re
5
+
6
+ _FLOAT_RE = re.compile(r"^[+-]?(\d+\.?\d*|\.\d+)([eE][+-]?\d+)?")
7
+ _INT_RE = re.compile(r"^[+-]?\d+")
8
+
9
+
10
+ def _coerce_str(value: object) -> str:
11
+ """Coerce a valkey reply value to text the way JS ``String()`` would.
12
+
13
+ valkey-py can hand back ``bytes`` for raw replies, so decode those instead of
14
+ stringifying them (``str(b'1')`` is ``"b'1'"``, which would parse as NaN).
15
+ """
16
+ if isinstance(value, bytes):
17
+ return value.decode()
18
+ return str(value)
19
+
20
+
21
def js_number(value: object) -> float:
    """Mimic JavaScript ``Number(value)`` for the string inputs we see.

    Empty/whitespace-only strings become ``0`` (as in JS); unparseable strings
    become ``NaN``. ``None`` becomes ``NaN``.

    Python's ``float()`` accepts several spellings that JavaScript maps to
    ``NaN``: digit-group underscores (``"1_000"``), case-insensitive
    ``inf``/``infinity``/``nan``, and non-ASCII Unicode digits. Those are
    rejected here so results match JS; only JS's exact (optionally signed)
    ``Infinity`` spelling yields an infinity. Python ``float`` inputs are
    returned as floats directly, because their ``str()`` form can be
    ``"inf"``/``"nan"``. Hex/octal/binary literals (``"0x1f"``), which JS
    accepts, are not supported and yield ``NaN``.
    """
    if value is None:
        return math.nan
    if isinstance(value, float):
        # str(float) may be "inf"/"nan"; skip the text path so those survive.
        return float(value)
    text = _coerce_str(value).strip()
    if text == "":
        return 0.0
    # JS numeric strings are pure ASCII and never contain separators.
    if not text.isascii() or "_" in text:
        return math.nan
    unsigned = text[1:] if text[0] in "+-" else text
    if unsigned.lower().startswith(("inf", "nan")) and unsigned != "Infinity":
        return math.nan
    try:
        return float(text)
    except ValueError:
        return math.nan
36
+
37
+
38
def parse_float(value: object) -> float:
    """Mimic JavaScript ``parseFloat``: parse the leading numeric portion, else NaN.

    Surrounding whitespace is stripped first; ``None`` or text with no
    numeric prefix yields ``math.nan``.
    """
    if value is not None:
        found = _FLOAT_RE.match(_coerce_str(value).strip())
        if found is not None:
            return float(found.group())
    return math.nan
44
+
45
+
46
def parse_int(value: object) -> float:
    """Mimic JavaScript ``parseInt(value, 10)``: parse the leading integer, else NaN.

    On success a Python ``int`` is returned (usable wherever a float is
    expected); ``None`` or text without a leading integer yields ``math.nan``.
    """
    if value is None:
        return math.nan
    text = _coerce_str(value).strip()
    found = _INT_RE.match(text)
    return math.nan if found is None else int(found.group())
@@ -0,0 +1,79 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ from dataclasses import dataclass, field
5
+
6
+ from betterdb_agent_cache import AgentCache, AgentCacheOptions
7
+
8
+ from .memory_store import MemoryStore
9
+ from .telemetry import MemoryTelemetryOptions
10
+ from .types import AgentMemoryConfig, EmbedFn, _empty_memory_config
11
+
12
# Default store name / key prefix. NOTE(review): not referenced elsewhere in
# this module (AgentMemory takes the name from ``options.name``); presumably
# kept to mirror agent-cache's default — confirm before removing.
DEFAULT_NAME = "betterdb_ac"
13
+
14
+
15
@dataclass(kw_only=True)
class AgentMemoryOptions(AgentCacheOptions):
    """Options for the batteries-included :class:`AgentMemory` facade.

    Extends :class:`AgentCacheOptions` (the three short-term cache tiers) with
    the long-term memory tier: an ``embed_fn`` to vectorize content plus an
    optional ``memory`` sub-config.

    All fields are keyword-only (``kw_only=True``); presumably this is what
    lets the required ``embed_fn`` follow inherited fields that have
    defaults — confirm against ``AgentCacheOptions``.
    """

    # Required embedding function backing the memory tier; AgentMemory raises
    # ValueError if it is not callable.
    embed_fn: EmbedFn
    # Memory-tier settings; each instance gets a fresh _empty_memory_config().
    memory: AgentMemoryConfig = field(default_factory=_empty_memory_config)
26
+
27
+
28
class AgentMemory:
    """Batteries-included facade: agent cache tiers plus a long-term memory tier.

    ``llm``, ``tool`` and ``session`` are the tiers of an internal
    :class:`AgentCache` built from the same options object; ``memory`` is a
    :class:`MemoryStore` that shares the client and ``name`` with it.
    """

    def __init__(self, options: AgentMemoryOptions) -> None:
        """Build both tiers from one options object.

        Raises:
            ValueError: If ``options.embed_fn`` is missing or not callable.
        """
        if not callable(getattr(options, "embed_fn", None)):
            raise ValueError("AgentMemory requires an embed_fn to back the memory tier")

        # A single ``options.name`` feeds both tiers, so the cache and memory
        # key prefixes cannot drift apart.
        shared_name = options.name
        cache = AgentCache(options)
        self._cache = cache
        self.llm, self.tool, self.session = cache.llm, cache.tool, cache.session

        mem_cfg = options.memory
        registry = options.telemetry.registry
        recall_cfg = mem_cfg.recall
        if recall_cfg is None:
            weights = None
            half_life = None
        else:
            weights = recall_cfg.weights
            half_life = recall_cfg.half_life_seconds
        # Memory telemetry is wired only when a truthy registry is configured.
        telemetry = MemoryTelemetryOptions(registry=registry) if registry else None

        # Discovery and config-refresh settings are forwarded unchanged. The
        # intent is for the facade to discover the memory tier alongside the
        # cache tiers unless explicitly disabled (presumably via the
        # AgentMemoryConfig defaults — confirm in types).
        self.memory = MemoryStore(
            client=options.client,
            name=shared_name,
            embed_fn=options.embed_fn,
            default_threshold=mem_cfg.default_threshold,
            weights=weights,
            half_life_seconds=half_life,
            max_items_per_scope=mem_cfg.max_items_per_scope,
            discovery=mem_cfg.discovery,
            config_refresh=mem_cfg.config_refresh,
            telemetry=telemetry,
        )

    async def initialize(self) -> None:
        """Create the memory index, then start discovery and config refresh.

        The FT index is created first so remember/recall work immediately
        without callers building it by hand; a creation failure propagates.
        The remaining startup steps run concurrently via ``asyncio.gather``,
        and the first failure among them (e.g. a discovery name collision in
        either tier) propagates rather than being swallowed.
        """
        await self.memory.ensure_index()
        startup = (
            self._cache.ensure_discovery_ready(),
            self.memory.ensure_discovery_ready(),
            self.memory.ensure_config_refresh_started(),
        )
        await asyncio.gather(*startup)

    async def close(self) -> None:
        """Close the memory tier, then shut down the cache even if that failed.

        The ``finally`` ensures neither tier's timers or heartbeats leak.
        """
        try:
            await self.memory.close()
        finally:
            await self._cache.shutdown()
@@ -0,0 +1,60 @@
1
+ from __future__ import annotations
2
+
3
+ from .build_recall_query import VECTOR_FIELD
4
+
5
# Vector index algorithm declared for the memory index's VECTOR attribute in
# the FT.CREATE arguments built by build_memory_index_args().
MEMORY_INDEX_ALGORITHM = "HNSW"
6
+
7
+
8
def memory_index_name(name: str) -> str:
    """Return the FT index name for memory store ``name`` (``{name}:mem:idx``)."""
    return "{}:mem:idx".format(name)
10
+
11
+
12
def memory_key_prefix(name: str) -> str:
    """Return the key prefix shared by every memory record of store ``name``."""
    return "{}:mem:".format(name)
14
+
15
+
16
def build_memory_index_args(name: str, dims: int) -> list[str]:
    """Return the FT.CREATE arguments (starting at the index name) for the memory index.

    The index covers HASH keys under ``memory_key_prefix(name)`` and declares
    a FLOAT32 / COSINE vector attribute of ``dims`` dimensions, followed by
    the TAG, NUMERIC and TEXT attributes memory records carry.

    Raises:
        ValueError: If ``dims`` is not a positive ``int`` (``bool`` is rejected).
    """
    dims_ok = isinstance(dims, int) and not isinstance(dims, bool) and dims > 0
    if not dims_ok:
        raise ValueError(f"memory index dimension must be a positive integer, got: {dims}")

    header = [
        memory_index_name(name),
        "ON", "HASH",
        "PREFIX", "1", memory_key_prefix(name),
        "SCHEMA",
    ]
    # "6" counts the vector attribute tokens that follow (three name/value pairs).
    vector_attr = [
        VECTOR_FIELD, "VECTOR", MEMORY_INDEX_ALGORITHM, "6",
        "TYPE", "FLOAT32",
        "DIM", str(dims),
        "DISTANCE_METRIC", "COSINE",
    ]
    tag_attrs = [
        "threadId", "TAG",
        "agentId", "TAG",
        "namespace", "TAG",
        # Tags are written comma-joined, so the TAG separator must be ",".
        "tags", "TAG", "SEPARATOR", ",",
        "source", "TAG",
    ]
    numeric_attrs = [
        "importance", "NUMERIC",
        "created_at", "NUMERIC", "SORTABLE",
        "last_accessed_at", "NUMERIC",
        "access_count", "NUMERIC",
    ]
    text_attrs = ["content", "TEXT"]
    return header + vector_attr + tag_attrs + numeric_attrs + text_attrs
@@ -0,0 +1,68 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from dataclasses import dataclass
5
+
6
+ from betterdb_valkey_search_kit import encode_float32
7
+
8
# Importance stored when build_memory_record() is given none; valid
# importances are finite numbers in [0, 1].
DEFAULT_IMPORTANCE = 0.5
9
+
10
+
11
@dataclass
class MemoryWrite:
    """A prepared memory-record write: the hash key and its flattened fields."""

    # Full record key, ``{name}:mem:{id}``.
    key: str
    # Alternating field names and values (presumably passed to HSET); values
    # are str except, presumably, the encoded vector (bytes).
    fields: list[str | bytes]
15
+
16
+
17
def build_memory_record(
    name: str,
    id: str,
    content: str,
    vector: list[float],
    *,
    importance: float | None = None,
    tags: list[str] | None = None,
    source: str | None = None,
    thread_id: str | None = None,
    agent_id: str | None = None,
    namespace: str | None = None,
    now: int,
) -> MemoryWrite:
    """Build the hash key and flattened field list for one memory record.

    Args:
        name: Store name; the key is ``{name}:mem:{id}``.
        id: Record id (shadows the builtin, but is part of the public signature).
        content: Text stored in the ``content`` field.
        vector: Embedding, encoded with ``encode_float32`` into ``vector``.
        importance: Finite number in ``[0, 1]``; defaults to
            ``DEFAULT_IMPORTANCE``. ``bool`` is rejected even though it is an
            ``int`` subclass, since ``str(True)`` would store ``"True"`` in a
            NUMERIC field.
        tags: Stored comma-joined in ``tags`` (omitted when empty); no tag
            may contain a comma.
        source, thread_id, agent_id, namespace: Written only when not ``None``.
        now: Timestamp written to both ``created_at`` and ``last_accessed_at``.

    Returns:
        A :class:`MemoryWrite`; ``access_count`` starts at ``"0"``.

    Raises:
        ValueError: If ``importance`` is not a finite non-bool number in
            ``[0, 1]``, or any tag contains a comma.
    """
    imp = importance if importance is not None else DEFAULT_IMPORTANCE
    if (
        isinstance(imp, bool)
        or not isinstance(imp, (int, float))
        or not math.isfinite(imp)
        or imp < 0
        or imp > 1
    ):
        raise ValueError(f"importance must be a finite number in [0, 1], got: {importance}")

    # Validate all tags up front so bad input fails before any encoding work.
    tag_list = tags if tags is not None else []
    for tag in tag_list:
        if "," in tag:
            raise ValueError(
                f"Tag '{tag}' must not contain a comma; tags are stored comma-separated"
            )

    fields: list[str | bytes] = [
        "content",
        content,
        "vector",
        encode_float32(vector),
        "importance",
        str(imp),
        "created_at",
        str(now),
        "last_accessed_at",
        str(now),
        "access_count",
        "0",
    ]

    if len(tag_list) > 0:
        fields.extend(["tags", ",".join(tag_list)])

    # Optional scope/provenance fields are omitted entirely when unset.
    if thread_id is not None:
        fields.extend(["threadId", thread_id])
    if agent_id is not None:
        fields.extend(["agentId", agent_id])
    if namespace is not None:
        fields.extend(["namespace", namespace])
    if source is not None:
        fields.extend(["source", source])

    return MemoryWrite(key=f"{name}:mem:{id}", fields=fields)
@@ -0,0 +1,65 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ from betterdb_valkey_search_kit import escape_tag
6
+
7
+ from .types import MemoryScope
8
+
9
# Alias given to the KNN distance in recall queries (``... AS __score``).
SCORE_FIELD = "__score"
# Hash field / index attribute holding the embedding (the VECTOR attribute).
VECTOR_FIELD = "vector"

# valkey-search rejects a bare '*' on VECTOR-schema indexes with
# "Invalid query string syntax". Every memory record has created_at
# (set at write time), so this numeric range matches all documents
# and is accepted by the vector index schema.
MATCH_ALL_MEMORY_QUERY = "@created_at:[-inf +inf]"
17
+
18
+
19
+ def _scope_clauses(scope: MemoryScope, tags: list[str]) -> list[str]:
20
+ clauses: list[str] = []
21
+ if scope.thread_id is not None:
22
+ clauses.append(f"@threadId:{{{escape_tag(scope.thread_id)}}}")
23
+ if scope.agent_id is not None:
24
+ clauses.append(f"@agentId:{{{escape_tag(scope.agent_id)}}}")
25
+ if scope.namespace is not None:
26
+ clauses.append(f"@namespace:{{{escape_tag(scope.namespace)}}}")
27
+ for tag in tags:
28
+ clauses.append(f"@tags:{{{escape_tag(tag)}}}")
29
+ return clauses
30
+
31
+
32
+ def _join_clauses(clauses: list[str]) -> str:
33
+ if len(clauses) == 0:
34
+ return MATCH_ALL_MEMORY_QUERY
35
+ return f"({' '.join(clauses)})"
36
+
37
+
38
def build_scope_filter(scope: MemoryScope, tags: list[str]) -> str:
    """Return the search filter for ``scope`` and ``tags`` (match-all when both are empty)."""
    clauses = _scope_clauses(scope, tags)
    return _join_clauses(clauses)
40
+
41
+
42
@dataclass
class ConsolidateFilterOptions:
    """Extra constraints for :func:`build_consolidate_filter`; ``None`` disables each."""

    # Upper bound on created_at, emitted as the range ``[-inf max_created_at]``.
    max_created_at: int | None = None
    # Upper bound on importance, emitted as ``[-inf max_importance]``.
    max_importance: float | None = None
    # Exclude records whose ``source`` tag matches this value.
    exclude_source: str | None = None
+
48
+
49
def build_consolidate_filter(
    scope: MemoryScope,
    tags: list[str],
    options: ConsolidateFilterOptions,
) -> str:
    """Return the filter selecting consolidation candidates.

    Starts from the scope/tag clauses, then optionally adds upper bounds on
    ``created_at`` and ``importance`` and a negated ``source`` tag match.
    With no clauses at all the match-all query is returned.
    """
    base = _scope_clauses(scope, tags)
    extra: list[str] = []
    if options.max_created_at is not None:
        extra.append(f"@created_at:[-inf {options.max_created_at}]")
    if options.max_importance is not None:
        extra.append(f"@importance:[-inf {options.max_importance}]")
    if options.exclude_source is not None:
        extra.append(f"-@source:{{{escape_tag(options.exclude_source)}}}")
    return _join_clauses(base + extra)
62
+
63
+
64
def build_recall_query(k: int, scope: MemoryScope, tags: list[str]) -> str:
    """Return a KNN recall query: the scope/tag pre-filter, then the ``k`` nearest neighbours.

    The query vector is bound as ``$vec``; the distance is returned under
    ``SCORE_FIELD``.
    """
    prefilter = build_scope_filter(scope, tags)
    knn = f"[KNN {k} @{VECTOR_FIELD} $vec AS {SCORE_FIELD}]"
    return prefilter + "=>" + knn
@@ -0,0 +1,35 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+
5
+ from .types import RecallWeights
6
+
7
+
8
def recency_decay(age_seconds: float, half_life_seconds: float) -> float:
    """True half-life decay: 1 at age 0, 0.5 at one half_life_seconds, approaching 0 beyond.

    Negative ages (e.g. clock skew between the writer that stamped a record
    and the reader computing its age) are treated as age 0. Without that, the
    result would exceed 1 and outweigh the other score components, and very
    negative ages would overflow ``math.exp``. ``half_life_seconds`` is
    expected to be positive; zero raises ``ZeroDivisionError``.
    """
    age = max(age_seconds, 0.0)
    return math.exp((-math.log(2) * age) / half_life_seconds)
11
+
12
+
13
def composite_score(
    *,
    similarity: float,
    age_seconds: float,
    importance: float,
    weights: RecallWeights,
    half_life_seconds: float,
) -> float:
    """Blend semantic similarity, recency and importance into one ranking score.

    Computes ``weights.similarity * similarity + weights.recency * recency
    + weights.importance * importance``, where ``recency`` is the half-life
    decay of ``age_seconds`` (0.5 after one ``half_life_seconds``).
    """
    decayed = recency_decay(age_seconds, half_life_seconds)
    similarity_part = weights.similarity * similarity
    recency_part = weights.recency * decayed
    importance_part = weights.importance * importance
    return similarity_part + recency_part + importance_part
31
+
32
+
33
def similarity_from_distance(distance: float) -> float:
    """Convert a cosine distance (0..2, lower = closer) into a 0..1 similarity.

    The mapping is linear: 0 maps to 1 and 2 maps to 0. Out-of-range
    distances are not clamped.
    """
    halved = distance / 2
    return 1 - halved
@@ -0,0 +1,164 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import json
5
+ import os
6
+ import socket
7
+ import warnings
8
+ from datetime import datetime, timezone
9
+ from typing import Any, Callable
10
+
11
+ from betterdb_agent_cache.discovery import (
12
+ DEFAULT_HEARTBEAT_INTERVAL_S,
13
+ HEARTBEAT_KEY_PREFIX,
14
+ HEARTBEAT_TTL_SECONDS,
15
+ PROTOCOL_KEY,
16
+ PROTOCOL_VERSION,
17
+ REGISTRY_KEY,
18
+ )
19
+
20
+ from .types import MemoryStoreClient
21
+
22
# Marker ``type`` identifying an agent-memory store in the discovery registry.
MEMORY_CACHE_TYPE = "agent_memory"
# Capabilities advertised in every marker (MemoryDiscovery.build_marker sends a copy).
MEMORY_CAPABILITIES = ["recall", "consolidate", "reinforce"]
24
+
25
+
26
class MemoryDiscovery:
    """Best-effort discovery registration for an agent-memory store.

    :meth:`register` writes a JSON marker into the ``REGISTRY_KEY`` hash under
    the field ``{name}:mem``, sets ``PROTOCOL_KEY`` only if absent (``NX``),
    writes a heartbeat key that expires after ``HEARTBEAT_TTL_SECONDS``, and
    starts a background task that refreshes the heartbeat and marker every
    ``heartbeat_interval_s`` seconds. Client errors never propagate; each
    failed command calls ``on_write_failed`` instead.
    """

    def __init__(
        self,
        *,
        client: MemoryStoreClient,
        name: str,
        version: str,
        stats_key: str,
        heartbeat_interval_s: float | None = None,
        on_write_failed: Callable[[], None] | None = None,
    ) -> None:
        """Store configuration; no client commands are issued until :meth:`register`.

        Args:
            client: Client whose awaitable ``execute_command`` is used for all writes.
            name: Store name, reported as the marker's ``prefix``.
            version: Version string reported in the marker.
            stats_key: Key advertised in the marker as ``stats_key``.
            heartbeat_interval_s: Seconds between heartbeat ticks; defaults to
                agent-cache's ``DEFAULT_HEARTBEAT_INTERVAL_S``.
            on_write_failed: Zero-argument callback invoked whenever a
                discovery command fails; defaults to a no-op.
        """
        self._client = client
        self._name = name
        self._version = version
        self._stats_key = stats_key
        self._heartbeat_interval_s = (
            heartbeat_interval_s
            if heartbeat_interval_s is not None
            else DEFAULT_HEARTBEAT_INTERVAL_S
        )
        # Namespace the marker under `{name}:mem` so a memory store and an
        # agent-cache sharing the same name register distinct registry fields
        # and heartbeat keys instead of clobbering each other.
        self._marker_field = f"{name}:mem"
        self._heartbeat_key = f"{HEARTBEAT_KEY_PREFIX}{self._marker_field}"
        self._started_at = datetime.now(timezone.utc).isoformat()
        self._on_write_failed: Callable[[], None] = on_write_failed or (lambda: None)
        self._heartbeat_task: asyncio.Task[None] | None = None

    def build_marker(self) -> dict[str, Any]:
        """Return the JSON-serialisable registry marker describing this store."""
        return {
            "type": MEMORY_CACHE_TYPE,
            "prefix": self._name,
            "version": self._version,
            "protocol_version": PROTOCOL_VERSION,
            "capabilities": list(MEMORY_CAPABILITIES),
            "stats_key": self._stats_key,
            "started_at": self._started_at,
            "pid": os.getpid(),
            "hostname": socket.gethostname(),
        }

    async def register(self) -> None:
        """Write marker, protocol key and first heartbeat, then start heartbeating.

        A differently-typed marker already stored in this field triggers a
        warning and is then overwritten.
        """
        # HGET-then-HSET is not atomic (TOCTOU); acceptable for best-effort
        # discovery — a racing writer just means last-writer-wins on the marker.
        existing = await self._safe_hget()
        if existing is not None:
            self._check_collision(existing)
        await self._write_marker()
        await self._safe_call(
            lambda: self._client.execute_command("SET", PROTOCOL_KEY, str(PROTOCOL_VERSION), "NX")
        )
        await self._write_heartbeat()
        self._start_heartbeat()

    async def stop(self, *, delete_heartbeat: bool) -> None:
        """Cancel the heartbeat task and, if requested, delete the heartbeat key."""
        if self._heartbeat_task is not None:
            self._heartbeat_task.cancel()
            try:
                await self._heartbeat_task
            except (asyncio.CancelledError, Exception):
                pass
            self._heartbeat_task = None
        if not delete_heartbeat:
            return
        try:
            await self._client.execute_command("DEL", self._heartbeat_key)
        except Exception:
            self._on_write_failed()

    async def tick_heartbeat(self) -> None:
        """Refresh the heartbeat key and re-write the registry marker."""
        await self._write_heartbeat()
        await self._write_marker()
        # PROTOCOL_KEY is set once in register(); the NX SET is a guaranteed
        # no-op on every subsequent tick, so it's not re-issued here.

    def _start_heartbeat(self) -> None:
        """Start the periodic heartbeat task, replacing any earlier one."""
        # A repeated register() would otherwise overwrite the handle of a
        # still-running loop, leaking a task that heartbeats forever and that
        # stop() can no longer cancel.
        previous = self._heartbeat_task
        if previous is not None and not previous.done():
            previous.cancel()

        async def _loop() -> None:
            try:
                while True:
                    await asyncio.sleep(self._heartbeat_interval_s)
                    await self.tick_heartbeat()
            except asyncio.CancelledError:
                pass

        self._heartbeat_task = asyncio.create_task(_loop())

    async def _write_heartbeat(self) -> None:
        """SET the heartbeat key to the current UTC time with a TTL (best-effort)."""
        now = datetime.now(timezone.utc).isoformat()
        try:
            await self._client.execute_command(
                "SET", self._heartbeat_key, now, "EX", str(HEARTBEAT_TTL_SECONDS)
            )
        except Exception:
            self._on_write_failed()

    async def _write_marker(self) -> None:
        """HSET the JSON marker into the registry hash (best-effort)."""
        try:
            payload = json.dumps(self.build_marker())
        except Exception:
            self._on_write_failed()
            return
        await self._safe_call(
            lambda: self._client.execute_command("HSET", REGISTRY_KEY, self._marker_field, payload)
        )

    async def _safe_hget(self) -> str | None:
        """Return this store's current registry marker as text, or None if absent or on error."""
        try:
            value = await self._client.execute_command("HGET", REGISTRY_KEY, self._marker_field)
            if value is None:
                return None
            return value.decode() if isinstance(value, bytes) else str(value)
        except Exception:
            self._on_write_failed()
            return None

    async def _safe_call(self, fn: Callable[[], Any]) -> None:
        """Await ``fn()``, reporting any exception through ``on_write_failed``."""
        try:
            await fn()
        except Exception:
            self._on_write_failed()

    def _check_collision(self, existing_json: str) -> None:
        """Warn if the stored marker belongs to a different store type; ignore unparsable JSON."""
        try:
            parsed = json.loads(existing_json)
        except Exception:
            return
        existing_type = parsed.get("type") if isinstance(parsed, dict) else None
        if existing_type and existing_type != MEMORY_CACHE_TYPE:
            # The memory marker lives under `{name}:mem`, distinct from
            # agent-cache's `{name}`, so the two tiers never collide here.
            # Surface it with a visible warning rather than raising into a
            # swallowed registration; registration then proceeds
            # last-writer-wins, matching agent-cache.
            warnings.warn(
                f"agent-memory discovery: field '{self._marker_field}' already holds a "
                f"'{existing_type}' marker; overwriting",
                stacklevel=2,
            )
+ )