warm-memory 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- warm_memory/__init__.py +14 -0
- warm_memory/benchmark.py +219 -0
- warm_memory/buffer.py +171 -0
- warm_memory/decorators.py +70 -0
- warm_memory/langgraph/__init__.py +17 -0
- warm_memory/langgraph/agent.py +137 -0
- warm_memory/langgraph/benchmark.py +325 -0
- warm_memory/langgraph/embeddings.py +94 -0
- warm_memory/langgraph/store.py +335 -0
- warm_memory/scoring.py +61 -0
- warm_memory/workload.py +35 -0
- warm_memory-0.2.1.dist-info/METADATA +306 -0
- warm_memory-0.2.1.dist-info/RECORD +16 -0
- warm_memory-0.2.1.dist-info/WHEEL +5 -0
- warm_memory-0.2.1.dist-info/licenses/LICENSE +21 -0
- warm_memory-0.2.1.dist-info/top_level.txt +1 -0
warm_memory/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from .buffer import WarmMemoryBuffer
|
|
2
|
+
from .benchmark import BenchmarkConfig, BenchmarkResult, run_benchmark
|
|
3
|
+
from .decorators import remember_interaction
|
|
4
|
+
from .scoring import ImportanceScorer, KeywordImportanceScorer
|
|
5
|
+
|
|
6
|
+
# Public API re-exported at the package root; `from warm_memory import *`
# exposes exactly these names (all imported above).
__all__ = [
    "BenchmarkConfig",
    "BenchmarkResult",
    "ImportanceScorer",
    "KeywordImportanceScorer",
    "WarmMemoryBuffer",
    "remember_interaction",
    "run_benchmark",
]
|
warm_memory/benchmark.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from time import perf_counter
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import pandas as pd
|
|
9
|
+
|
|
10
|
+
from .buffer import WarmMemoryBuffer
|
|
11
|
+
from .workload import ScenarioTurn, default_workload
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass(slots=True)
class BenchmarkConfig:
    """Tunable parameters for the synthetic warm-memory benchmark.

    The ``*_ms`` fields are simulated latency costs that are added
    arithmetically to the measured warm-lookup time; they are not measured.
    """

    # Maximum number of rows kept in the warm buffer.
    capacity: int = 8
    # Number of warm-memory candidates retrieved per turn.
    top_k: int = 5
    # Number of rows fetched from the long-term buffer on a fallback.
    long_term_limit: int = 8
    # Minimum best relevance score for a warm lookup to count as a hit
    # (the "recency" strategy treats any non-empty lookup as a hit).
    warm_hit_threshold: float = 0.34
    # Simulated cost added to retrieval time when long-term memory is used.
    long_term_latency_ms: float = 8.0
    # Simulated fixed LLM latency per turn.
    llm_base_latency_ms: float = 35.0
    # Simulated LLM latency per estimated prompt token.
    llm_latency_per_token_ms: float = 0.12
    # Simulated prompt-assembly cost per retrieved item.
    prompt_build_per_item_ms: float = 0.08
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass(slots=True)
class BenchmarkResult:
    """Outcome of replaying the workload with one retrieval strategy."""

    # Strategy name; run_benchmark uses "recency", "relevance" and "fallback".
    name: str
    # One row per workload turn with the per-turn metrics.
    turn_log: pd.DataFrame
    # Aggregate metrics keyed by metric name (built by _summarize_turn_log).
    summary: dict[str, float]
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _estimate_tokens(items: list[str]) -> int:
|
|
34
|
+
return sum(max(1, len(item.split())) for item in items)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _summarize_turn_log(name: str, turn_log: pd.DataFrame) -> BenchmarkResult:
    """Aggregate a per-turn log into the headline metrics for one strategy.

    Rates and averages are column means; the p95 latency is the 0.95
    quantile of ``end_to_end_ms``; repeated tool calls are summed.
    """

    def _mean(column: str) -> float:
        return float(turn_log[column].mean())

    # Insertion order matters: it becomes the column order of the report table.
    summary: dict[str, float] = {"turns": float(len(turn_log))}
    summary["warm_hit_rate"] = _mean("warm_hit")
    summary["fallback_rate"] = _mean("fallback_used")
    summary["avg_prompt_tokens"] = _mean("prompt_tokens")
    summary["avg_end_to_end_ms"] = _mean("end_to_end_ms")
    summary["p95_end_to_end_ms"] = float(turn_log["end_to_end_ms"].quantile(0.95))
    summary["answer_accuracy"] = _mean("answer_correct")
    summary["memory_precision_at_k"] = _mean("memory_precision_at_k")
    summary["repeated_tool_calls"] = float(turn_log["repeated_tool_call"].sum())
    return BenchmarkResult(name=name, turn_log=turn_log, summary=summary)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _build_prompt_metrics(config: BenchmarkConfig, retrieved_contents: list[str]) -> tuple[int, float]:
    """Return ``(estimated prompt tokens, simulated generation latency in ms)``.

    The latency combines a per-item prompt-assembly cost, a fixed LLM cost and
    a per-token LLM cost, all read from ``config``.
    """
    token_count = _estimate_tokens(retrieved_contents)
    assembly_ms = len(retrieved_contents) * config.prompt_build_per_item_ms
    model_ms = config.llm_base_latency_ms + (token_count * config.llm_latency_per_token_ms)
    total_ms = assembly_ms + model_ms
    return token_count, total_ms
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _markdown_table(frame: pd.DataFrame) -> str:
|
|
60
|
+
columns = list(frame.columns)
|
|
61
|
+
header = "| " + " | ".join(columns) + " |"
|
|
62
|
+
divider = "| " + " | ".join(["---"] * len(columns)) + " |"
|
|
63
|
+
rows = [
|
|
64
|
+
"| " + " | ".join(str(row[column]) for column in columns) + " |"
|
|
65
|
+
for _, row in frame.iterrows()
|
|
66
|
+
]
|
|
67
|
+
return "\n".join([header, divider, *rows])
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _write_memory(memory: WarmMemoryBuffer, turn: ScenarioTurn, response: str) -> None:
|
|
71
|
+
memory.add("user", turn.query, metadata={"scenario_id": turn.turn_id, "topic": turn.topic})
|
|
72
|
+
memory.add("assistant", response, metadata={"scenario_id": turn.turn_id, "topic": turn.topic})
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _retrieve_recent(memory: WarmMemoryBuffer, limit: int) -> pd.DataFrame:
|
|
76
|
+
recent = memory.recent(limit=limit).copy(deep=True)
|
|
77
|
+
if not recent.empty:
|
|
78
|
+
recent["score"] = 1.0
|
|
79
|
+
return recent
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def _retrieve_relevant(memory: WarmMemoryBuffer, query: str, limit: int) -> pd.DataFrame:
|
|
83
|
+
return memory.relevant(query=query, limit=limit)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def _run_strategy(
|
|
87
|
+
name: str,
|
|
88
|
+
config: BenchmarkConfig,
|
|
89
|
+
turns: list[ScenarioTurn],
|
|
90
|
+
) -> BenchmarkResult:
|
|
91
|
+
warm = WarmMemoryBuffer(capacity=config.capacity)
|
|
92
|
+
long_term = WarmMemoryBuffer(capacity=max(len(turns) * 2, config.capacity * 4))
|
|
93
|
+
rows: list[dict[str, Any]] = []
|
|
94
|
+
resolved_topics: set[str] = set()
|
|
95
|
+
|
|
96
|
+
for turn in turns:
|
|
97
|
+
lookup_start = perf_counter()
|
|
98
|
+
if name == "recency":
|
|
99
|
+
warm_candidates = _retrieve_recent(warm, config.top_k)
|
|
100
|
+
else:
|
|
101
|
+
warm_candidates = _retrieve_relevant(warm, turn.query, config.top_k)
|
|
102
|
+
warm_lookup_ms = (perf_counter() - lookup_start) * 1000
|
|
103
|
+
|
|
104
|
+
best_score = float(warm_candidates["score"].max()) if "score" in warm_candidates.columns and not warm_candidates.empty else 0.0
|
|
105
|
+
warm_hit = bool(not warm_candidates.empty and (name == "recency" or best_score >= config.warm_hit_threshold))
|
|
106
|
+
|
|
107
|
+
fallback_used = False
|
|
108
|
+
retrieval_ms = warm_lookup_ms
|
|
109
|
+
retrieved = warm_candidates
|
|
110
|
+
|
|
111
|
+
if name == "fallback" and not warm_hit:
|
|
112
|
+
fallback_used = True
|
|
113
|
+
retrieval_ms += config.long_term_latency_ms
|
|
114
|
+
retrieved = long_term.relevant(turn.query, limit=config.long_term_limit)
|
|
115
|
+
|
|
116
|
+
retrieved_contents = retrieved["content"].tolist() if not retrieved.empty else []
|
|
117
|
+
prompt_tokens, generation_ms = _build_prompt_metrics(config, retrieved_contents)
|
|
118
|
+
end_to_end_ms = retrieval_ms + generation_ms
|
|
119
|
+
|
|
120
|
+
relevant_retrieved = {
|
|
121
|
+
str(meta.get("topic"))
|
|
122
|
+
for meta in retrieved["metadata"].tolist()
|
|
123
|
+
if isinstance(meta, dict) and meta.get("topic")
|
|
124
|
+
}
|
|
125
|
+
expected = set(turn.required_topics)
|
|
126
|
+
memory_precision = len(expected & relevant_retrieved) / max(len(relevant_retrieved), 1) if relevant_retrieved else 0.0
|
|
127
|
+
answer_correct = expected.issubset(relevant_retrieved)
|
|
128
|
+
|
|
129
|
+
repeated_tool_call = int(turn.topic in resolved_topics and not answer_correct)
|
|
130
|
+
if answer_correct:
|
|
131
|
+
resolved_topics.add(turn.topic)
|
|
132
|
+
|
|
133
|
+
response = f"Response for {turn.topic}: {'correct' if answer_correct else 'incomplete'}"
|
|
134
|
+
_write_memory(warm, turn, response)
|
|
135
|
+
_write_memory(long_term, turn, response)
|
|
136
|
+
|
|
137
|
+
if name == "relevance":
|
|
138
|
+
warm.retain_relevant(turn.query, limit=config.capacity)
|
|
139
|
+
|
|
140
|
+
rows.append(
|
|
141
|
+
{
|
|
142
|
+
"turn_id": turn.turn_id,
|
|
143
|
+
"topic": turn.topic,
|
|
144
|
+
"query": turn.query,
|
|
145
|
+
"warm_lookup_ms": warm_lookup_ms,
|
|
146
|
+
"warm_hit": warm_hit,
|
|
147
|
+
"fallback_used": fallback_used,
|
|
148
|
+
"retrieval_ms": retrieval_ms,
|
|
149
|
+
"prompt_tokens": prompt_tokens,
|
|
150
|
+
"end_to_end_ms": end_to_end_ms,
|
|
151
|
+
"answer_correct": answer_correct,
|
|
152
|
+
"memory_precision_at_k": memory_precision,
|
|
153
|
+
"repeated_tool_call": repeated_tool_call,
|
|
154
|
+
}
|
|
155
|
+
)
|
|
156
|
+
|
|
157
|
+
turn_log = pd.DataFrame(rows)
|
|
158
|
+
return _summarize_turn_log(name, turn_log)
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def _render_report(config: BenchmarkConfig, results: list[BenchmarkResult]) -> str:
    """Compose the Markdown benchmark report.

    The report has these sections, in order:
      * the main configuration values;
      * a summary table with one row per strategy;
      * the leading strategy for latency, accuracy and prompt size;
      * a fixed interpretation of each strategy.

    The returned text ends with a newline.
    """
    summary_frame = pd.DataFrame([{"strategy": result.name, **result.summary} for result in results])

    def _leader(metric: str, *, ascending: bool = True) -> Any:
        # Strategy in the first row after sorting by `metric`.
        ranked = summary_frame.sort_values(metric, ascending=ascending)
        return ranked.iloc[0]["strategy"]

    best_latency = _leader("avg_end_to_end_ms")
    best_accuracy = _leader("answer_accuracy", ascending=False)
    best_tokens = _leader("avg_prompt_tokens")

    configuration = [
        f"- capacity: {config.capacity}",
        f"- top_k: {config.top_k}",
        f"- long_term_limit: {config.long_term_limit}",
        f"- warm_hit_threshold: {config.warm_hit_threshold}",
    ]
    readout = [
        f"- Lowest average latency: `{best_latency}`",
        f"- Highest answer accuracy: `{best_accuracy}`",
        f"- Smallest prompt footprint: `{best_tokens}`",
    ]
    interpretation = [
        "- `recency` shows the baseline cost of always trusting the latest interactions.",
        "- `relevance` shows the effect of ranking and retaining the hottest working set.",
        "- `fallback` shows a two-tier memory architecture where long-term retrieval is only used on warm misses.",
    ]

    sections = [
        "# WarmMemory Benchmark Report",
        "",
        "## Configuration",
        "",
        *configuration,
        "",
        "## Summary",
        "",
        _markdown_table(summary_frame),
        "",
        "## Readout",
        "",
        *readout,
        "",
        "## Interpretation",
        "",
        *interpretation,
    ]
    return "\n".join(sections) + "\n"
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def run_benchmark(
    *,
    config: BenchmarkConfig | None = None,
    report_path: str | Path | None = None,
) -> dict[str, BenchmarkResult]:
    """Replay the default workload with every retrieval strategy.

    Args:
        config: Benchmark parameters. A default ``BenchmarkConfig()`` is used
            when omitted.
        report_path: When given, a Markdown report is written there as UTF-8.
            Missing parent directories are created.

    Returns:
        Results keyed by strategy name, in the order recency, relevance,
        fallback.
    """
    settings = config if config else BenchmarkConfig()
    workload = default_workload()

    results: dict[str, BenchmarkResult] = {}
    for strategy in ("recency", "relevance", "fallback"):
        results[strategy] = _run_strategy(strategy, settings, workload)

    if report_path is None:
        return results

    report = _render_report(settings, list(results.values()))
    destination = Path(report_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_text(report, encoding="utf-8")
    return results
|
warm_memory/buffer.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import asdict, dataclass, field
|
|
4
|
+
from datetime import datetime, timezone
|
|
5
|
+
from itertools import count
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import pandas as pd
|
|
9
|
+
|
|
10
|
+
from .scoring import ImportanceScorer, KeywordImportanceScorer
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass(slots=True)
class InteractionRecord:
    """One stored interaction; field order matches WarmMemoryBuffer.COLUMNS."""

    # Id assigned by WarmMemoryBuffer.add from its increasing counter.
    interaction_id: int
    # Creation time; WarmMemoryBuffer.add uses datetime.now(timezone.utc).
    timestamp: datetime
    # Speaker role, e.g. "user" or "assistant".
    role: str
    # Interaction text; this is what scorers and retrievers read.
    content: str
    # Optional short summary.
    summary: str = ""
    # Optional free-form tags.
    tags: tuple[str, ...] = field(default_factory=tuple)
    # Arbitrary caller-supplied metadata.
    metadata: dict[str, Any] = field(default_factory=dict)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class WarmMemoryBuffer:
    """
    Pandas-backed in-memory interaction buffer for agent experiments.

    Two usage modes are supported:
    - `recent(limit)`: classic sliding-window retrieval.
    - `relevant(query, limit)`: query-aware retrieval using a pluggable scorer.

    Every stored row gets a monotonically increasing `interaction_id`. When the
    buffer holds more than `capacity` rows, the rows with the lowest ids (the
    oldest) are evicted.
    """

    # Column order of the live frame; mirrors the fields of InteractionRecord.
    COLUMNS = ["interaction_id", "timestamp", "role", "content", "summary", "tags", "metadata"]

    def __init__(
        self,
        capacity: int = 32,
        scorer: ImportanceScorer | None = None,
    ) -> None:
        """
        Create an empty buffer.

        Args:
            capacity: Maximum number of rows retained; must be positive.
            scorer: Scorer used by `relevant`. A `KeywordImportanceScorer` is
                created when this is omitted (or falsy).

        Raises:
            ValueError: If `capacity` is not positive.
        """
        if capacity <= 0:
            raise ValueError("capacity must be positive")

        self.capacity = capacity
        self.scorer = scorer or KeywordImportanceScorer()
        # Never reset (not even by clear()), so ids stay unique for the
        # lifetime of the buffer.
        self._id_source = count(1)
        self._frame = pd.DataFrame(columns=self.COLUMNS)

    @property
    def frame(self) -> pd.DataFrame:
        """Return a copy so callers cannot mutate the live buffer implicitly."""
        return self._frame.copy(deep=True)

    def __len__(self) -> int:
        """Return the number of rows currently stored."""
        return len(self._frame.index)

    def add(
        self,
        role: str,
        content: str,
        *,
        summary: str = "",
        tags: list[str] | tuple[str, ...] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> int:
        """
        Store one interaction and return its `interaction_id`.

        `tags` and `metadata` are copied, so later changes the caller makes to
        the passed objects do not affect the stored row. Adding may evict the
        oldest rows to stay within `capacity`.
        """
        record = InteractionRecord(
            interaction_id=next(self._id_source),
            timestamp=datetime.now(timezone.utc),
            role=role,
            content=content,
            summary=summary,
            tags=tuple(tags or ()),
            metadata=dict(metadata or {}),
        )

        row = pd.DataFrame([asdict(record)], columns=self.COLUMNS)
        self._append(row)
        return record.interaction_id

    def _append(self, row: pd.DataFrame) -> None:
        """
        Append `row` (already in `COLUMNS` layout) and enforce `capacity`.

        Concatenating onto an empty frame is skipped on purpose:
        - pandas >= 2.1 emits a FutureWarning for `concat` with empty entries.
        - The same warning announces that a future version will let the empty
          frame's object-dtype columns take part in dtype resolution.

        Using the row directly keeps the current dtypes without the warning.
        """
        if self._frame.empty:
            self._frame = row.reset_index(drop=True)
        else:
            self._frame = pd.concat([self._frame, row], ignore_index=True)
        self._evict_over_capacity()

    def recent(self, limit: int = 5) -> pd.DataFrame:
        """
        Return the `limit` newest rows, ordered oldest first.

        A non-positive `limit` yields an empty frame with the buffer's columns.
        """
        if limit <= 0:
            return self._frame.head(0).copy(deep=True)

        ordered = self._frame.sort_values("interaction_id", ascending=False).head(limit)
        return ordered.sort_values("interaction_id", ascending=True).reset_index(drop=True)

    def relevant(self, query: str, limit: int = 5) -> pd.DataFrame:
        """
        Return up to `limit` rows ranked by `self.scorer.score(query, row)`.

        The result gains a `score` column and is ordered by score (highest
        first), with newer rows first on ties. When `limit <= 0` or the buffer
        is empty, an empty frame *without* a `score` column is returned.
        """
        if limit <= 0 or self._frame.empty:
            return self._frame.head(0).copy(deep=True)

        scored = self._frame.copy(deep=True)
        scored["score"] = scored.apply(lambda row: self.scorer.score(query, row), axis=1)
        scored = scored.sort_values(["score", "interaction_id"], ascending=[False, False]).head(limit)
        return scored.reset_index(drop=True)

    def context_window(self, query: str | None = None, limit: int = 5) -> pd.DataFrame:
        """
        Return either a recent window or a query-aware relevant window.

        This gives callers a single method for switching between fixed-window and
        importance-aware retention policies. A missing or blank `query` selects
        `recent`; otherwise `relevant` is used.
        """

        if query is None or not query.strip():
            return self.recent(limit=limit)
        return self.relevant(query=query, limit=limit)

    def retain_relevant(self, query: str, limit: int | None = None) -> pd.DataFrame:
        """
        Compact the live buffer down to the top relevant rows for a query.

        This is useful when you want a strict "working set" rather than keeping the
        most recent interactions by default. `limit` defaults to `capacity`, and
        a non-positive limit clears the buffer.

        The kept rows are stored in ascending `interaction_id` order without the
        temporary `score` column. A copy of the new live frame is returned.
        """

        target_size = self.capacity if limit is None else limit
        if target_size <= 0:
            self.clear()
            return self._frame.copy(deep=True)

        retained = self.relevant(query=query, limit=target_size).sort_values("interaction_id")
        self._frame = retained[self.COLUMNS].reset_index(drop=True)
        return self.frame

    def clear(self) -> None:
        """Drop every row. The id counter is not reset."""
        self._frame = pd.DataFrame(columns=self.COLUMNS)

    def find_index_by_metadata(self, field: str, value: Any) -> int | None:
        """
        Return the integer index of the first row whose `metadata[field]` equals `value`.

        Used by external mutation paths (e.g., key-based stores layered on top of
        the buffer) that need to update or delete a specific logical entry without
        scanning the dataframe themselves. Rows whose metadata is not a dict are
        skipped. Returns None when nothing matches.
        """
        if self._frame.empty:
            return None
        for idx, metadata in enumerate(self._frame["metadata"].tolist()):
            if isinstance(metadata, dict) and metadata.get(field) == value:
                return idx
        return None

    def drop_at(self, index: int) -> None:
        """
        Remove the row at the given positional index.

        Pairs with `find_index_by_metadata` to give external code a clean mutation
        path without reaching into private dataframe internals. Out-of-range
        indices (including negative ones) are ignored.
        """
        if index < 0 or index >= len(self._frame.index):
            return
        self._frame = self._frame.drop(self._frame.index[index]).reset_index(drop=True)

    def iter_rows(self) -> list[dict[str, Any]]:
        """
        Return a list of row dicts in insertion order. Use this for read-only
        scans (e.g., search) instead of poking at the live dataframe.
        """
        return self._frame.to_dict("records")

    def _evict_over_capacity(self) -> None:
        """Keep only the `capacity` newest rows (highest ids), oldest first."""
        overflow = len(self._frame.index) - self.capacity
        if overflow <= 0:
            return

        self._frame = (
            self._frame.sort_values("interaction_id", ascending=False)
            .head(self.capacity)
            .sort_values("interaction_id", ascending=True)
            .reset_index(drop=True)
        )
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from functools import wraps
|
|
4
|
+
import inspect
|
|
5
|
+
from typing import Any, Callable
|
|
6
|
+
|
|
7
|
+
from .buffer import WarmMemoryBuffer
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _stringify_payload(payload: Any) -> str:
    """Convert an arbitrary payload to the text stored in warm memory.

    ``None`` becomes ``""``, strings pass through unchanged, and anything else
    is stored as its ``repr``.
    """
    if payload is None:
        return ""
    if isinstance(payload, str):
        return payload
    return repr(payload)


def remember_interaction(
    memory: WarmMemoryBuffer,
    *,
    input_role: str = "user",
    output_role: str = "assistant",
    input_extractor: Callable[[tuple[Any, ...], dict[str, Any]], Any] | None = None,
    output_extractor: Callable[[Any], Any] | None = None,
    metadata_factory: Callable[[tuple[Any, ...], dict[str, Any], Any], dict[str, Any]] | None = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """
    Decorate an agent-like function and persist input/output rows in warm memory.

    After each successful call, two rows are added to ``memory``:
    - the extracted input, with role ``input_role``;
    - the extracted output, with role ``output_role``.

    Payloads are stored as text: ``None`` becomes ``""`` and non-strings are
    stored as their ``repr``. Nothing is recorded if the wrapped function
    raises.

    Defaults:
    - input is the ``prompt`` argument when present, otherwise the first bound
      argument (in signature order) that is not ``memory`` itself
    - output is stored as the function return value
    - metadata is ``{"function": <wrapped function's __name__>}``

    Args:
        memory: Buffer that receives the rows.
        input_role: Role recorded for the input row.
        output_role: Role recorded for the output row.
        input_extractor: ``(args, kwargs) -> input`` override.
        output_extractor: ``result -> output`` override.
        metadata_factory: ``(args, kwargs, result) -> metadata`` override.
    """

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        signature = inspect.signature(func)
        # Some callables (e.g. functools.partial) have no __name__.
        function_name = getattr(func, "__name__", type(func).__name__)

        def default_input_extractor(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
            bound = signature.bind_partial(*args, **kwargs)
            if "prompt" in bound.arguments:
                return bound.arguments["prompt"]
            for value in bound.arguments.values():
                # Agents often receive the buffer itself as a parameter; its
                # repr is not a meaningful "input", so skip it.
                if value is not memory:
                    return value
            return None

        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            raw_input = (
                input_extractor(args, kwargs)
                if input_extractor is not None
                else default_input_extractor(args, kwargs)
            )

            result = func(*args, **kwargs)

            raw_output = output_extractor(result) if output_extractor is not None else result
            metadata = (
                metadata_factory(args, kwargs, result)
                if metadata_factory is not None
                else {"function": function_name}
            )

            memory.add(role=input_role, content=_stringify_payload(raw_input), metadata=metadata)
            memory.add(role=output_role, content=_stringify_payload(raw_output), metadata=metadata)
            return result

        return wrapper

    return decorator
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
"""
|
|
2
|
+
LangGraph integration for WarmMemory.
|
|
3
|
+
|
|
4
|
+
Install with the optional extra:
|
|
5
|
+
|
|
6
|
+
pip install "warm-memory[langgraph]"
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from .agent import build_warm_memory_agent
|
|
10
|
+
from .embeddings import EmbeddingsImportanceScorer
|
|
11
|
+
from .store import WarmStore
|
|
12
|
+
|
|
13
|
+
# Public API of the LangGraph integration subpackage (all imported above).
__all__ = [
    "WarmStore",
    "EmbeddingsImportanceScorer",
    "build_warm_memory_agent",
]
|
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Reusable LangGraph agent builder that wires WarmStore into the request loop.
|
|
3
|
+
|
|
4
|
+
Pre-call: search the user's namespace for relevant memories and inject them
|
|
5
|
+
into the system message.
|
|
6
|
+
Post-call: write the new (user, assistant) exchange back into the warm store.
|
|
7
|
+
|
|
8
|
+
Works fully synthetic out of the box (FakeListChatModel + KeywordImportanceScorer).
|
|
9
|
+
Drop in a real chat model and/or EmbeddingsImportanceScorer to turn it into a
|
|
10
|
+
production agent.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
from typing import Any, Callable, TypedDict
|
|
16
|
+
|
|
17
|
+
from langchain_core.language_models import BaseChatModel
|
|
18
|
+
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
|
19
|
+
from langgraph.graph import END, START, StateGraph
|
|
20
|
+
|
|
21
|
+
from .store import WarmStore
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _flatten_message_content(content: Any) -> str:
|
|
25
|
+
"""
|
|
26
|
+
Normalize an AIMessage.content value to a plain string.
|
|
27
|
+
|
|
28
|
+
Newer LangChain chat models can return `content` as a list of content
|
|
29
|
+
blocks (e.g., `[{"type": "text", "text": "..."}, ...]`) instead of a
|
|
30
|
+
plain string. Storing the list repr in warm memory is useless for
|
|
31
|
+
keyword/embedding scoring, so flatten to text here.
|
|
32
|
+
"""
|
|
33
|
+
if isinstance(content, str):
|
|
34
|
+
return content
|
|
35
|
+
if isinstance(content, list):
|
|
36
|
+
parts: list[str] = []
|
|
37
|
+
for block in content:
|
|
38
|
+
if isinstance(block, str):
|
|
39
|
+
parts.append(block)
|
|
40
|
+
elif isinstance(block, dict):
|
|
41
|
+
if block.get("type") == "text" and isinstance(block.get("text"), str):
|
|
42
|
+
parts.append(block["text"])
|
|
43
|
+
elif isinstance(block.get("text"), str):
|
|
44
|
+
parts.append(block["text"])
|
|
45
|
+
return "\n".join(parts)
|
|
46
|
+
return str(content)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class WarmAgentState(TypedDict, total=False):
    """Graph state passed through memory_lookup -> respond -> memory_write.

    Every key is optional (``total=False``); the nodes read them with ``.get``.
    """

    # User input for this turn.
    query: str
    # Store namespace to read/write; nodes fall back to `namespace_default`.
    namespace: tuple[str, ...]
    # Search hits as {"key", "value", "score"} dicts, set by memory_lookup.
    recalled: list[dict[str, Any]]
    # Flattened model reply text, set by respond.
    response: str
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
# Default `system_prompt` for build_warm_memory_agent; the respond node appends
# the "Recalled memory:" block after it on every turn.
_DEFAULT_SYSTEM = (
    "You are a helpful assistant with access to warm memory of prior exchanges "
    "with this user. Use the recalled context if relevant; ignore it if not."
)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _format_recalled(recalled: list[dict[str, Any]]) -> str:
|
|
63
|
+
if not recalled:
|
|
64
|
+
return "(no prior context)"
|
|
65
|
+
lines = []
|
|
66
|
+
for entry in recalled:
|
|
67
|
+
key = entry.get("key", "?")
|
|
68
|
+
value = entry.get("value", {})
|
|
69
|
+
lines.append(f"- [{key}] {value}")
|
|
70
|
+
return "\n".join(lines)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def build_warm_memory_agent(
    *,
    model: BaseChatModel,
    store: WarmStore,
    recall_limit: int = 5,
    system_prompt: str = _DEFAULT_SYSTEM,
    namespace_default: tuple[str, ...] = ("default",),
) -> Callable[[dict[str, Any]], dict[str, Any]]:
    """
    Build a compiled LangGraph agent that uses `store` as warm memory.

    The graph runs three nodes in sequence:
    1. ``memory_lookup``: calls ``store.search`` for up to ``recall_limit``
       hits for the query. It is skipped when the query is empty.
    2. ``respond``: calls ``model`` with ``system_prompt`` plus the recalled
       memories and the user query, and stores the flattened reply text.
    3. ``memory_write``: saves ``{"user": query, "assistant": response}``
       under the key from ``store.next_key(namespace, prefix="exchange-")``.
       It is skipped when both query and response are empty.

    A missing or empty ``namespace`` in the input state falls back to
    ``namespace_default``.

    Returns a compiled graph. Invoke it with:
        agent.invoke({"query": "...", "namespace": ("alice",)})
    """

    def memory_lookup(state: WarmAgentState) -> dict[str, Any]:
        namespace = state.get("namespace") or namespace_default
        query = state.get("query", "")
        if not query:
            return {"recalled": []}
        hits = store.search(namespace, query=query, limit=recall_limit)
        return {
            "recalled": [
                {"key": h.key, "value": h.value, "score": h.score} for h in hits
            ],
            "namespace": namespace,
        }

    def respond(state: WarmAgentState) -> dict[str, Any]:
        recalled = state.get("recalled", []) or []
        memory_block = _format_recalled(recalled)
        messages = [
            SystemMessage(content=f"{system_prompt}\n\nRecalled memory:\n{memory_block}"),
            HumanMessage(content=state.get("query", "")),
        ]
        ai_message = model.invoke(messages)
        # Chat models may return any message type, not only AIMessage, so read
        # `.content` from whatever message object comes back. Fall back to the
        # raw value for stubs that return plain strings.
        raw_content = getattr(ai_message, "content", ai_message)
        text = _flatten_message_content(raw_content)
        return {"response": text}

    def memory_write(state: WarmAgentState) -> dict[str, Any]:
        namespace = state.get("namespace") or namespace_default
        query = state.get("query", "")
        response = state.get("response", "")
        if not query and not response:
            return {}
        next_key = store.next_key(namespace, prefix="exchange-")
        store.put(
            namespace,
            next_key,
            {"user": query, "assistant": response},
        )
        return {}

    graph = StateGraph(WarmAgentState)
    graph.add_node("memory_lookup", memory_lookup)
    graph.add_node("respond", respond)
    graph.add_node("memory_write", memory_write)

    graph.add_edge(START, "memory_lookup")
    graph.add_edge("memory_lookup", "respond")
    graph.add_edge("respond", "memory_write")
    graph.add_edge("memory_write", END)

    return graph.compile()
|