agent-memory-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent_memory/__init__.py +46 -0
- agent_memory/_version.py +24 -0
- agent_memory/benchmark.py +140 -0
- agent_memory/cli.py +168 -0
- agent_memory/decision.py +161 -0
- agent_memory/eval.py +132 -0
- agent_memory/explain.py +116 -0
- agent_memory/manager.py +348 -0
- agent_memory/models.py +172 -0
- agent_memory/policy.py +131 -0
- agent_memory/retriever.py +91 -0
- agent_memory/sqlite_store.py +307 -0
- agent_memory/store.py +366 -0
- agent_memory/ttl.py +48 -0
- agent_memory_sdk-0.1.0.dist-info/METADATA +551 -0
- agent_memory_sdk-0.1.0.dist-info/RECORD +21 -0
- agent_memory_sdk-0.1.0.dist-info/WHEEL +4 -0
- agent_memory_sdk-0.1.0.dist-info/entry_points.txt +3 -0
- agent_memory_sdk-0.1.0.dist-info/licenses/LICENSE +21 -0
- mcp_server/__init__.py +0 -0
- mcp_server/server.py +184 -0
agent_memory/__init__.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
"""Persistent agentic memory with semantic retrieval and restore/replay/verify decisions."""
|
|
2
|
+
|
|
3
|
+
from agent_memory.benchmark import BenchmarkResult, format_benchmark_report, run_benchmark
|
|
4
|
+
from agent_memory.decision import DecisionEngine
|
|
5
|
+
from agent_memory.eval import EvalDataset, EvalResult, format_eval_report, run_eval, run_eval_suite
|
|
6
|
+
from agent_memory.manager import Memory, MemoryManager
|
|
7
|
+
from agent_memory.models import (
|
|
8
|
+
MemoryAction,
|
|
9
|
+
MemoryDecision,
|
|
10
|
+
MemoryEntry,
|
|
11
|
+
MemoryScope,
|
|
12
|
+
MemoryState,
|
|
13
|
+
MemoryType,
|
|
14
|
+
RetrievalResult,
|
|
15
|
+
)
|
|
16
|
+
from agent_memory.policy import DecisionPolicy, DefaultPolicy
|
|
17
|
+
from agent_memory.sqlite_store import SqliteMemoryStore
|
|
18
|
+
|
|
19
|
+
# Prefer the version module generated at build time; fall back to a placeholder
# when it cannot be imported (presumably an un-built source checkout — confirm).
try:
    from agent_memory._version import __version__
except ImportError:
    __version__ = "0.0.0"

# Public API of the package: the names re-exported above from the submodules.
# This list also determines what ``from agent_memory import *`` binds.
__all__ = [
    "BenchmarkResult",
    "DecisionEngine",
    "DecisionPolicy",
    "DefaultPolicy",
    "EvalDataset",
    "EvalResult",
    "Memory",
    "MemoryAction",
    "MemoryDecision",
    "MemoryEntry",
    "MemoryManager",
    "MemoryScope",
    "MemoryState",
    "MemoryType",
    "RetrievalResult",
    "SqliteMemoryStore",
    "format_benchmark_report",
    "format_eval_report",
    "run_benchmark",
    "run_eval",
    "run_eval_suite",
]
|
agent_memory/_version.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
# file generated by vcs-versioning
|
|
2
|
+
# don't change, don't track in version control
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
# Names this generated module exports (dunder and plain aliases of the same values).
__all__ = [
    "__version__",
    "__version_tuple__",
    "version",
    "version_tuple",
    "__commit_id__",
    "commit_id",
]

# Type declarations for the exported names; values are assigned below.
version: str
__version__: str
__version_tuple__: tuple[int | str, ...]
version_tuple: tuple[int | str, ...]
commit_id: str | None
__commit_id__: str | None

# Values written by the generator named in the file header.
__version__ = version = '0.1.0'
__version_tuple__ = version_tuple = (0, 1, 0)

# No commit id was recorded for this build.
__commit_id__ = commit_id = None
|
|
agent_memory/benchmark.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from statistics import mean
|
|
6
|
+
|
|
7
|
+
from agent_memory.manager import Memory
|
|
8
|
+
from agent_memory.models import MemoryAction
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class BenchmarkResult:
    """Aggregated counters and latencies collected by :func:`run_benchmark`.

    All ``*_rate`` properties divide by ``queries`` and return 0.0 when no
    queries were run; latency properties return 0.0 when no samples exist.
    """

    queries: int = 0
    replay_count: int = 0
    restore_count: int = 0
    verify_count: int = 0
    none_count: int = 0
    # Per-query wall-clock latency of ``resolve()``, in milliseconds.
    latencies_ms: list[float] = field(default_factory=list)
    # Queries whose decision was REPLAY, RESTORE or VERIFY (see run_benchmark).
    memory_hits: int = 0

    def _rate(self, count: int) -> float:
        """Return *count* as a fraction of ``queries`` (0.0 when no queries ran)."""
        return count / self.queries if self.queries else 0.0

    def _percentile(self, pct: float) -> float:
        """Return the *pct*-th percentile of ``latencies_ms`` using linear interpolation.

        Interpolates between the two closest order statistics (the same rule as
        ``numpy.percentile``'s default ``linear`` method). Picking the floor index
        instead biases the result low on small samples; e.g. with two samples it
        would report the minimum as the 95th percentile.
        """
        if not self.latencies_ms:
            return 0.0
        ordered = sorted(self.latencies_ms)
        last = len(ordered) - 1
        position = (pct / 100.0) * last
        lower = int(position)  # floor, since position >= 0
        upper = min(lower + 1, last)
        fraction = position - lower
        return ordered[lower] + (ordered[upper] - ordered[lower]) * fraction

    @property
    def replay_rate(self) -> float:
        return self._rate(self.replay_count)

    @property
    def restore_rate(self) -> float:
        return self._rate(self.restore_count)

    @property
    def verify_rate(self) -> float:
        return self._rate(self.verify_count)

    @property
    def hit_rate(self) -> float:
        return self._rate(self.memory_hits)

    @property
    def avg_latency_ms(self) -> float:
        """Arithmetic mean latency in ms (0.0 when there are no samples)."""
        return mean(self.latencies_ms) if self.latencies_ms else 0.0

    @property
    def p95_latency_ms(self) -> float:
        """95th-percentile latency in ms, linearly interpolated (0.0 when empty)."""
        return self._percentile(95.0)

    def to_dict(self) -> dict:
        """Serialize rates (4 d.p.), latencies (2 d.p.) and raw action counts."""
        return {
            "queries": self.queries,
            "replay_rate": round(self.replay_rate, 4),
            "restore_rate": round(self.restore_rate, 4),
            "verify_rate": round(self.verify_rate, 4),
            "hit_rate": round(self.hit_rate, 4),
            "avg_latency_ms": round(self.avg_latency_ms, 2),
            "p95_latency_ms": round(self.p95_latency_ms, 2),
            "action_counts": {
                "replay": self.replay_count,
                "restore": self.restore_count,
                "verify": self.verify_count,
                "none": self.none_count,
            },
        }
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def run_benchmark(
    memory: Memory,
    queries: list[str],
    *,
    baseline_no_memory_ms: float | None = None,
    token_savings_per_hit: float = 0.65,
) -> tuple[BenchmarkResult, dict]:
    """
    Run ``memory.resolve()`` over *queries* and collect action distribution + latency.

    Each call is timed with ``time.perf_counter``. REPLAY, RESTORE and VERIFY
    decisions count as memory hits; any other action counts toward ``none_count``.

    Args:
        memory: Memory instance to query.
        queries: Queries to resolve, in order.
        baseline_no_memory_ms: Assumed average latency of the no-memory path in ms.
            ``None`` uses a fixed 800 ms estimate; an explicit value (including 0.0)
            is used as given.
        token_savings_per_hit: Assumed fraction of generation tokens saved per
            memory hit (default 0.65).

    Returns:
        ``(result, comparison)`` where *comparison* contrasts the measured average
        latency with the no-memory baseline and reports estimated token savings and
        the memory hit percentage. ``latency_reduction_pct`` is clamped at 0 and is
        0.0 when the baseline is not positive.
    """
    result = BenchmarkResult()

    for query in queries:
        start = time.perf_counter()
        decision = memory.resolve(query)
        elapsed_ms = (time.perf_counter() - start) * 1000

        result.queries += 1
        result.latencies_ms.append(elapsed_ms)

        action = decision.action
        if action == MemoryAction.REPLAY:
            result.replay_count += 1
        elif action == MemoryAction.RESTORE:
            result.restore_count += 1
        elif action == MemoryAction.VERIFY:
            result.verify_count += 1
        else:
            result.none_count += 1
            continue
        result.memory_hits += 1

    # The no-memory path (just the LLM) is not measured here: use the caller's figure
    # or a fixed 800 ms stub. Compare against None so an explicit 0.0 is respected.
    estimated_baseline = 800.0 if baseline_no_memory_ms is None else baseline_no_memory_ms
    with_memory_avg = result.avg_latency_ms
    # Rough model: a memory hit skips part of full generation.
    token_savings_estimate = result.hit_rate * token_savings_per_hit

    if estimated_baseline > 0:
        latency_reduction_pct = max(0.0, (1 - with_memory_avg / estimated_baseline) * 100)
    else:
        latency_reduction_pct = 0.0

    comparison = {
        "without_memory_avg_latency_ms": estimated_baseline,
        "with_memory_avg_latency_ms": round(with_memory_avg, 2),
        "latency_reduction_pct": round(latency_reduction_pct, 1),
        "estimated_token_savings_pct": round(token_savings_estimate * 100, 1),
        "memory_hit_pct": round(result.hit_rate * 100, 1),
    }

    return result, comparison
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def format_benchmark_report(result: BenchmarkResult, comparison: dict) -> str:
    """Render a benchmark run as a plain-text report.

    The first section summarises *result* (rates and latencies). After a blank
    line, the second section prints the with/without-memory figures from
    *comparison*, the dict returned alongside the result by ``run_benchmark``.
    """
    summary_section = [
        "Agent Memory Benchmark",
        "========================",
        f"Queries: {result.queries}",
        f"Replay rate: {result.replay_rate:.1%}",
        f"Restore rate: {result.restore_rate:.1%}",
        f"Verify rate: {result.verify_rate:.1%}",
        f"Memory hit rate: {result.hit_rate:.1%}",
        f"Avg latency: {result.avg_latency_ms:.2f} ms",
        f"P95 latency: {result.p95_latency_ms:.2f} ms",
    ]
    comparison_section = [
        "Without memory vs Agent Memory",
        "------------------------------",
        f"Without memory: {comparison['without_memory_avg_latency_ms']:.0f} ms (estimated)",
        f"With agent-memory: {comparison['with_memory_avg_latency_ms']:.2f} ms",
        f"Latency reduction: {comparison['latency_reduction_pct']:.1f}%",
        f"Est. token savings:{comparison['estimated_token_savings_pct']:.1f}%",
        f"Memory hit rate: {comparison['memory_hit_pct']:.1f}%",
    ]
    return "\n".join(summary_section + [""] + comparison_section)
|
agent_memory/cli.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import argparse
|
|
4
|
+
import sys
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
from agent_memory import Memory
|
|
8
|
+
from agent_memory.benchmark import format_benchmark_report, run_benchmark
|
|
9
|
+
from agent_memory.eval import EvalDataset, format_eval_report, run_eval_suite, seed_dataset
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def _default_datasets_dir() -> Path:
|
|
13
|
+
return Path(__file__).resolve().parent.parent / "benchmarks" / "datasets"
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _create_memory(args: argparse.Namespace) -> Memory:
    """Construct a ``Memory`` from the parsed ``--data-dir`` and ``--backend`` options."""
    memory_options = {"persist_dir": args.data_dir, "backend": args.backend}
    return Memory(**memory_options)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def cmd_remember(args: argparse.Namespace) -> int:
    """Handle ``agent-memory remember``: store one query/response pair.

    Passes ``--type``, ``--scope`` and ``--ttl`` through to ``Memory.remember``
    and prints the id of the stored entry. Returns exit status 0.
    """
    stored_entry = _create_memory(args).remember(
        args.query,
        args.response,
        type=args.type,
        scope=args.scope,
        ttl=args.ttl,
    )
    print(f"Stored memory {stored_entry.id}")
    return 0
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def cmd_resolve(args: argparse.Namespace) -> int:
    """Handle ``agent-memory resolve``: print the decision for a query.

    With ``--explain``, a blank line and ``decision.explain()`` follow.
    Returns exit status 0.
    """
    outcome = _create_memory(args).resolve(args.query)
    print(outcome)
    if not args.explain:
        return 0
    print()
    print(outcome.explain())
    return 0
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def cmd_stats(args: argparse.Namespace) -> int:
    """Handle ``agent-memory stats``: print each statistic as ``key: value``.

    Returns exit status 0.
    """
    statistics = _create_memory(args).stats()
    for heading in ("Agent Memory Stats", "=================="):
        print(heading)
    for stat_name, stat_value in statistics.items():
        print(f"{stat_name}: {stat_value}")
    return 0
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def cmd_cleanup(args: argparse.Namespace) -> int:
    """Handle ``agent-memory cleanup``: run ``Memory.cleanup`` and print its result.

    ``--delete`` is forwarded as ``delete=``. Returns exit status 0.
    """
    cleanup_report = _create_memory(args).cleanup(delete=args.delete)
    print(cleanup_report)
    return 0
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def cmd_benchmark(args: argparse.Namespace) -> int:
    """Handle ``agent-memory benchmark``: resolve a query set and print the report.

    The query set is built as follows:

    * the positional queries;
    * with ``--seed``, the memories of every ``*.json`` dataset in
      ``_default_datasets_dir()`` are seeded first, and each dataset's case
      queries are appended;
    * a small built-in sample when no query results from the above.

    Every query is then repeated ``max(1, --repeat)`` times consecutively.
    Returns exit status 0.
    """
    memory = _create_memory(args)
    repeat = max(1, args.repeat)
    # Work on a copy so the parsed ``args.queries`` list is never mutated.
    queries = list(args.queries or [])

    if args.seed:
        # Sorted so the query order is deterministic (as in cmd_eval).
        for dataset_path in sorted(_default_datasets_dir().glob("*.json")):
            dataset = EvalDataset.load(dataset_path)
            seed_dataset(memory, dataset)
            queries.extend(case.query for case in dataset.cases)

    if not queries:
        queries = [
            "How do I reset my password?",
            "I forgot my password",
            "What is the API rate limit?",
            "What's the weather today?",
        ]

    # Apply --repeat uniformly, whether a query was seeded or user-supplied.
    queries = [query for query in queries for _ in range(repeat)]

    result, comparison = run_benchmark(memory, queries, baseline_no_memory_ms=args.baseline_ms)
    print(format_benchmark_report(result, comparison))
    return 0
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def cmd_eval(args: argparse.Namespace) -> int:
    """Handle ``agent-memory eval``: seed and score every ``*.json`` dataset.

    Datasets are read from ``--datasets`` or, by default, ``_default_datasets_dir()``.
    Prints an error to stderr and returns 1 when that directory is missing or
    contains no ``*.json`` files. Otherwise it prints the report and returns 0.
    """
    datasets_dir = Path(args.datasets) if args.datasets else _default_datasets_dir()

    if not datasets_dir.exists():
        print(f"Datasets directory not found: {datasets_dir}", file=sys.stderr)
        return 1

    paths = sorted(datasets_dir.glob("*.json"))
    if not paths:
        print(f"No datasets found in {datasets_dir}", file=sys.stderr)
        return 1

    # Built through _create_memory so the --backend option registered for this
    # subcommand is honoured like in every other command. This happens only after
    # validation, so a failed run never constructs (and possibly initializes) a store.
    memory = _create_memory(args)
    results = run_eval_suite(memory, paths)
    print(format_eval_report(results))
    return 0
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def build_parser() -> argparse.ArgumentParser:
    """Build the ``agent-memory`` command-line parser.

    ``--data-dir`` is a global option and must precede the subcommand. Every
    subcommand additionally accepts ``--backend``, registered after its own
    arguments. Each subparser stores its handler in ``func`` for ``main``.
    """
    parser = argparse.ArgumentParser(prog="agent-memory", description="Agent Memory CLI")
    parser.add_argument(
        "--data-dir",
        default=".agent_memory",
        help="Persistence directory (default: .agent_memory)",
    )

    commands = parser.add_subparsers(dest="command", required=True)

    remember_cmd = commands.add_parser("remember", help="Store a memory")
    remember_cmd.add_argument("query")
    remember_cmd.add_argument("response")
    remember_cmd.add_argument("--type", default="conversation")
    remember_cmd.add_argument("--scope", default="user")
    remember_cmd.add_argument("--ttl", default=None)

    resolve_cmd = commands.add_parser("resolve", help="Resolve a query against memory")
    resolve_cmd.add_argument("query")
    resolve_cmd.add_argument("--explain", action="store_true", help="Print score breakdown")

    stats_cmd = commands.add_parser("stats", help="Show memory statistics")

    cleanup_cmd = commands.add_parser("cleanup", help="Expire or delete stale memories")
    cleanup_cmd.add_argument("--delete", action="store_true", help="Delete expired memories")

    benchmark_cmd = commands.add_parser("benchmark", help="Run latency and hit-rate benchmark")
    benchmark_cmd.add_argument("queries", nargs="*", help="Queries to benchmark")
    benchmark_cmd.add_argument("--seed", action="store_true", help="Seed from eval datasets first")
    benchmark_cmd.add_argument("--repeat", type=int, default=1, help="Repeat each query N times")
    benchmark_cmd.add_argument("--baseline-ms", type=float, default=800.0, help="No-memory baseline ms")

    eval_cmd = commands.add_parser("eval", help="Run evaluation datasets")
    eval_cmd.add_argument("--datasets", default=None, help="Path to datasets directory")

    handlers = (
        (remember_cmd, cmd_remember),
        (resolve_cmd, cmd_resolve),
        (stats_cmd, cmd_stats),
        (cleanup_cmd, cmd_cleanup),
        (benchmark_cmd, cmd_benchmark),
        (eval_cmd, cmd_eval),
    )
    for subcommand, handler in handlers:
        # Shared by every subcommand; added last so it follows the specific arguments.
        subcommand.add_argument(
            "--backend",
            choices=["sqlite", "chromadb"],
            default="sqlite",
            help="Storage backend (default: sqlite)",
        )
        subcommand.set_defaults(func=handler)

    return parser
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def main() -> None:
    """Command-line entry point: parse ``sys.argv``, run the chosen subcommand, exit.

    The handler's integer return value becomes the process exit status via ``SystemExit``.
    """
    namespace = build_parser().parse_args()
    exit_status = namespace.func(namespace)
    raise SystemExit(exit_status)


if __name__ == "__main__":
    main()
|
agent_memory/decision.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from agent_memory.explain import enrich_decision
|
|
4
|
+
from agent_memory.models import MemoryAction, MemoryDecision, MemoryScope, RetrievalResult
|
|
5
|
+
from agent_memory.policy import DecisionPolicy, DefaultPolicy
|
|
6
|
+
from agent_memory.retriever import MemoryRetriever
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class DecisionEngine:
    """
    Choose whether the agent should replay, restore, verify, or ignore memory.

    Thresholds are compared against ``RetrievalResult.decision_score``, and
    ``results[0]`` is treated as the best candidate:

    * ``replay_threshold`` — minimum score for replay mode; also given to the policy.
    * ``restore_threshold`` — minimum score for restore mode and for the explicit
      verify mode; also given to the policy.
    * ``verify_threshold`` — given to the policy in auto mode (passed as 0.0 when
      verification is disabled).
    """

    def __init__(
        self,
        retriever: MemoryRetriever,
        policy: DecisionPolicy | None = None,
        replay_threshold: float = 0.85,
        restore_threshold: float = 0.70,
        verify_threshold: float = 0.80,
    ) -> None:
        self._retriever = retriever
        # Any falsy policy argument (normally None) falls back to the default policy.
        self._policy = policy or DefaultPolicy()
        self.replay_threshold = replay_threshold
        self.restore_threshold = restore_threshold
        self.verify_threshold = verify_threshold

    def decide(
        self,
        query: str,
        *,
        mode: str = "auto",
        top_k: int = 3,
        scopes: list[MemoryScope] | None = None,
        enable_verify: bool = True,
    ) -> MemoryDecision:
        """
        Retrieve up to *top_k* memories for *query* and turn them into a decision.

        ``mode`` is ``"replay"``, ``"restore"`` or ``"verify"``; any other value
        selects the policy-driven auto mode. ``enable_verify=False`` prevents auto
        mode from returning VERIFY. When retrieval yields nothing, a NONE decision
        is returned directly. Otherwise the decision is passed through
        ``enrich_decision`` before being returned.
        """
        candidates = self._retriever.retrieve(query, top_k=top_k, scopes=scopes)
        if not candidates:
            return MemoryDecision(
                action=MemoryAction.NONE,
                query=query,
                confidence=0.0,
                reason="No memories stored yet.",
                reasons=["no memories stored"],
            )

        if mode == "replay":
            chosen = self._decide_replay(query, candidates[0])
        elif mode == "restore":
            chosen = self._decide_restore(query, candidates)
        elif mode == "verify":
            chosen = self._decide_verify(query, candidates)
        else:
            chosen = self._decide_auto(query, candidates, enable_verify=enable_verify)

        return enrich_decision(
            chosen,
            self._policy,
            replay_threshold=self.replay_threshold,
            restore_threshold=self.restore_threshold,
        )

    def _decide_auto(
        self,
        query: str,
        results: list[RetrievalResult],
        *,
        enable_verify: bool,
    ) -> MemoryDecision:
        """Let the policy choose an action, then build the matching decision."""
        policy_verify_threshold = self.verify_threshold if enable_verify else 0.0
        action, confidence, reason = self._policy.select_action(
            results,
            replay_threshold=self.replay_threshold,
            restore_threshold=self.restore_threshold,
            verify_threshold=policy_verify_threshold,
        )

        if action == MemoryAction.REPLAY:
            return self._finalize_replay(query, results[0], reason)
        if action == MemoryAction.VERIFY and enable_verify:
            return self._finalize_verify(query, results[0], reason)

        # RESTORE keeps its action; anything else (including a VERIFY suggested while
        # verification is disabled) becomes NONE. Both carry every candidate as context.
        settled_action = MemoryAction.RESTORE if action == MemoryAction.RESTORE else MemoryAction.NONE
        return MemoryDecision(
            action=settled_action,
            query=query,
            confidence=confidence,
            context=results,
            reason=reason,
        )

    def _decide_replay(self, query: str, best: RetrievalResult) -> MemoryDecision:
        """Replay mode: replay *best* when it reaches ``replay_threshold``, else NONE."""
        if best.decision_score >= self.replay_threshold:
            return self._finalize_replay(query, best, "Replay mode — match found.")
        return self._no_match(
            query,
            best.decision_score,
            [best],
            "Replay mode — no match above replay threshold.",
        )

    def _decide_restore(self, query: str, results: list[RetrievalResult]) -> MemoryDecision:
        """Restore mode: return all *results* as context when the best reaches ``restore_threshold``."""
        top = results[0]
        if top.decision_score >= self.restore_threshold:
            return MemoryDecision(
                action=MemoryAction.RESTORE,
                query=query,
                confidence=top.decision_score,
                context=results,
                reason="Restore mode — returning relevant memories as context.",
            )
        return self._no_match(
            query,
            top.decision_score,
            results,
            "Restore mode — no match above restore threshold.",
        )

    def _decide_verify(self, query: str, results: list[RetrievalResult]) -> MemoryDecision:
        """Verify mode: flag the best result for verification when it reaches ``restore_threshold``."""
        top = results[0]
        if top.decision_score >= self.restore_threshold:
            return self._finalize_verify(query, top, "Verify mode — memory requires validation.")
        return self._no_match(
            query,
            top.decision_score,
            results,
            "Verify mode — no match above restore threshold.",
        )

    @staticmethod
    def _no_match(
        query: str,
        confidence: float,
        context: list[RetrievalResult],
        reason: str,
    ) -> MemoryDecision:
        """Build a NONE decision that carries the inspected candidates and a reason."""
        return MemoryDecision(
            action=MemoryAction.NONE,
            query=query,
            confidence=confidence,
            context=context,
            reason=reason,
        )

    def _finalize_replay(self, query: str, best: RetrievalResult, reason: str) -> MemoryDecision:
        """Record an access on *best*'s entry and build a REPLAY decision from the refreshed entry.

        ``best.entry`` is replaced in place by the entry returned from ``record_access``
        before the confidence is read.
        """
        refreshed_entry = self._retriever.record_access(best.entry)
        best.entry = refreshed_entry
        return MemoryDecision(
            action=MemoryAction.REPLAY,
            query=query,
            confidence=best.decision_score,
            response=refreshed_entry.response,
            memory=refreshed_entry,
            context=[best],
            reason=reason,
        )

    def _finalize_verify(self, query: str, best: RetrievalResult, reason: str) -> MemoryDecision:
        """Build a VERIFY decision that exposes *best*'s entry for validation."""
        return MemoryDecision(
            action=MemoryAction.VERIFY,
            query=query,
            confidence=best.decision_score,
            memory=best.entry,
            context=[best],
            reason=reason,
        )
|
agent_memory/eval.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
from agent_memory.manager import Memory
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass
class EvalCase:
    """One evaluation query together with the memory action it should produce."""

    query: str
    # Compared with ``decision.action.value`` in run_eval; "replay", "restore"
    # and "verify" get lenient matching there.
    expected_action: str
    # Free-form annotation loaded from the dataset; not used during scoring.
    notes: str = ""
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class EvalDataset:
    """A named evaluation dataset: memories to seed plus cases to resolve.

    The JSON form uses these optional top-level keys:

    * ``name`` — defaults to the file stem;
    * ``description``;
    * ``memories`` — a list of dicts handed to seeding as-is;
    * ``cases`` — each case requires ``query`` and ``expected_action`` and may
      include ``notes``.
    """

    name: str
    description: str = ""
    memories: list[dict] = field(default_factory=list)
    cases: list[EvalCase] = field(default_factory=list)

    @classmethod
    def load(cls, path: str | Path) -> EvalDataset:
        """Parse the UTF-8 JSON file at *path* into a dataset."""
        source = Path(path)
        payload = json.loads(source.read_text(encoding="utf-8"))
        parsed_cases = [
            EvalCase(
                query=raw_case["query"],
                expected_action=raw_case["expected_action"],
                notes=raw_case.get("notes", ""),
            )
            for raw_case in payload.get("cases", [])
        ]
        fields_from_file = {
            "name": payload.get("name", source.stem),
            "description": payload.get("description", ""),
            "memories": payload.get("memories", []),
        }
        return cls(cases=parsed_cases, **fields_from_file)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass
class EvalResult:
    """Score for a single dataset: case counts plus details of mismatched cases."""

    dataset: str
    total: int = 0
    correct: int = 0
    failures: list[dict] = field(default_factory=list)

    @property
    def precision(self) -> float:
        """Fraction of cases judged correct; 0.0 when no cases were run."""
        if not self.total:
            return 0.0
        return self.correct / self.total

    def to_dict(self) -> dict:
        """Serialize to a plain dict, rounding ``precision`` to 4 decimal places."""
        summary = dict(dataset=self.dataset, total=self.total, correct=self.correct)
        summary["precision"] = round(self.precision, 4)
        summary["failures"] = self.failures
        return summary
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def seed_dataset(memory: Memory, dataset: EvalDataset) -> None:
    """Store every entry of ``dataset.memories`` in *memory* via ``remember``.

    Each entry must provide ``query`` and ``response``. Missing optional keys
    default as follows: ``type`` to ``"conversation"``, ``scope`` to ``"user"``,
    ``tags`` to an empty list, and ``ttl`` to ``None``.
    """
    for record in dataset.memories:
        query_text = record["query"]
        response_text = record["response"]
        memory.remember(
            query=query_text,
            response=response_text,
            type=record.get("type", "conversation"),
            scope=record.get("scope", "user"),
            tags=record.get("tags", []),
            ttl=record.get("ttl"),
        )
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def run_eval(memory: Memory, dataset: EvalDataset) -> EvalResult:
    """Resolve every case of *dataset* against *memory* and score the chosen actions.

    A case counts as correct when the resolved ``action.value`` is its
    ``expected_action`` or one of the extras accepted for it:

    * expected ``"restore"`` also accepts ``"replay"``;
    * expected ``"replay"`` also accepts ``"restore"``;
    * expected ``"verify"`` also accepts ``"restore"`` and ``"replay"``.

    Each mismatch is recorded in ``failures`` with the query, the expected and
    actual action, the confidence rounded to 4 places, and ``decision.reasons``.
    """
    # (expected action, extra actions accepted in addition to it)
    leniency = (
        ("restore", {"replay"}),
        ("replay", {"restore"}),
        ("verify", {"restore", "replay"}),
    )
    scored = EvalResult(dataset=dataset.name)

    for case in dataset.cases:
        decision = memory.resolve(case.query)
        scored.total += 1
        actual = decision.action.value

        accepted = {case.expected_action}
        for label, extras in leniency:
            if case.expected_action == label:
                accepted |= extras

        if actual not in accepted:
            scored.failures.append(
                {
                    "query": case.query,
                    "expected": case.expected_action,
                    "actual": actual,
                    "confidence": round(decision.confidence, 4),
                    "reasons": decision.reasons,
                }
            )
        else:
            scored.correct += 1

    return scored
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def run_eval_suite(memory: Memory, dataset_paths: list[Path]) -> list[EvalResult]:
    """Load, seed and evaluate each dataset file in order against one shared *memory*.

    Returns one result per path, in the order of *dataset_paths*.

    NOTE(review): every dataset is seeded into the same *memory* and nothing is
    cleared in between. Later datasets are therefore presumably evaluated with the
    earlier datasets' memories still present — confirm this is intended.
    """

    def _evaluate_one(dataset_path: Path) -> EvalResult:
        loaded = EvalDataset.load(dataset_path)
        seed_dataset(memory, loaded)
        return run_eval(memory, loaded)

    return [_evaluate_one(dataset_path) for dataset_path in dataset_paths]
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def format_eval_report(results: list[EvalResult]) -> str:
    """Render evaluation results as a plain-text report.

    Each dataset section is followed by a blank line. A ``Failures`` line appears
    only for datasets with failures. An overall precision line is appended when
    at least one case ran across all datasets.
    """
    report = ["Agent Memory Evaluation", "=======================", ""]
    for entry in results:
        report += [
            f"Dataset: {entry.dataset}",
            f" Cases: {entry.total}",
            f" Correct: {entry.correct}",
            f" Precision: {entry.precision:.1%}",
        ]
        if entry.failures:
            report.append(f" Failures: {len(entry.failures)}")
        report.append("")

    grand_total = sum(entry.total for entry in results)
    grand_correct = sum(entry.correct for entry in results)
    if grand_total:
        report.append(f"Overall precision: {grand_correct / grand_total:.1%}")
    return "\n".join(report)
|