warm-memory 0.2.1 (warm_memory-0.2.1-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
warm_memory/__init__.py ADDED
@@ -0,0 +1,14 @@
+ from .buffer import WarmMemoryBuffer
+ from .benchmark import BenchmarkConfig, BenchmarkResult, run_benchmark
+ from .decorators import remember_interaction
+ from .scoring import ImportanceScorer, KeywordImportanceScorer
+
+ __all__ = [
+     "BenchmarkConfig",
+     "BenchmarkResult",
+     "ImportanceScorer",
+     "KeywordImportanceScorer",
+     "WarmMemoryBuffer",
+     "remember_interaction",
+     "run_benchmark",
+ ]
warm_memory/benchmark.py ADDED
@@ -0,0 +1,219 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from pathlib import Path
+ from time import perf_counter
+ from typing import Any
+
+ import pandas as pd
+
+ from .buffer import WarmMemoryBuffer
+ from .workload import ScenarioTurn, default_workload
+
+
+ @dataclass(slots=True)
+ class BenchmarkConfig:
+     capacity: int = 8
+     top_k: int = 5
+     long_term_limit: int = 8
+     warm_hit_threshold: float = 0.34
+     long_term_latency_ms: float = 8.0
+     llm_base_latency_ms: float = 35.0
+     llm_latency_per_token_ms: float = 0.12
+     prompt_build_per_item_ms: float = 0.08
+
+
+ @dataclass(slots=True)
+ class BenchmarkResult:
+     name: str
+     turn_log: pd.DataFrame
+     summary: dict[str, float]
+
+
+ def _estimate_tokens(items: list[str]) -> int:
+     return sum(max(1, len(item.split())) for item in items)
+
+
+ def _summarize_turn_log(name: str, turn_log: pd.DataFrame) -> BenchmarkResult:
+     summary = {
+         "turns": float(len(turn_log)),
+         "warm_hit_rate": float(turn_log["warm_hit"].mean()),
+         "fallback_rate": float(turn_log["fallback_used"].mean()),
+         "avg_prompt_tokens": float(turn_log["prompt_tokens"].mean()),
+         "avg_end_to_end_ms": float(turn_log["end_to_end_ms"].mean()),
+         "p95_end_to_end_ms": float(turn_log["end_to_end_ms"].quantile(0.95)),
+         "answer_accuracy": float(turn_log["answer_correct"].mean()),
+         "memory_precision_at_k": float(turn_log["memory_precision_at_k"].mean()),
+         "repeated_tool_calls": float(turn_log["repeated_tool_call"].sum()),
+     }
+     return BenchmarkResult(name=name, turn_log=turn_log, summary=summary)
+
+
+ def _build_prompt_metrics(config: BenchmarkConfig, retrieved_contents: list[str]) -> tuple[int, float]:
+     tokens = _estimate_tokens(retrieved_contents)
+     prompt_build_ms = len(retrieved_contents) * config.prompt_build_per_item_ms
+     llm_ms = config.llm_base_latency_ms + (tokens * config.llm_latency_per_token_ms)
+     return tokens, prompt_build_ms + llm_ms
+
+
+ def _markdown_table(frame: pd.DataFrame) -> str:
+     columns = list(frame.columns)
+     header = "| " + " | ".join(columns) + " |"
+     divider = "| " + " | ".join(["---"] * len(columns)) + " |"
+     rows = [
+         "| " + " | ".join(str(row[column]) for column in columns) + " |"
+         for _, row in frame.iterrows()
+     ]
+     return "\n".join([header, divider, *rows])
+
+
+ def _write_memory(memory: WarmMemoryBuffer, turn: ScenarioTurn, response: str) -> None:
+     memory.add("user", turn.query, metadata={"scenario_id": turn.turn_id, "topic": turn.topic})
+     memory.add("assistant", response, metadata={"scenario_id": turn.turn_id, "topic": turn.topic})
+
+
+ def _retrieve_recent(memory: WarmMemoryBuffer, limit: int) -> pd.DataFrame:
+     recent = memory.recent(limit=limit).copy(deep=True)
+     if not recent.empty:
+         recent["score"] = 1.0
+     return recent
+
+
+ def _retrieve_relevant(memory: WarmMemoryBuffer, query: str, limit: int) -> pd.DataFrame:
+     return memory.relevant(query=query, limit=limit)
+
+
+ def _run_strategy(
+     name: str,
+     config: BenchmarkConfig,
+     turns: list[ScenarioTurn],
+ ) -> BenchmarkResult:
+     warm = WarmMemoryBuffer(capacity=config.capacity)
+     long_term = WarmMemoryBuffer(capacity=max(len(turns) * 2, config.capacity * 4))
+     rows: list[dict[str, Any]] = []
+     resolved_topics: set[str] = set()
+
+     for turn in turns:
+         lookup_start = perf_counter()
+         if name == "recency":
+             warm_candidates = _retrieve_recent(warm, config.top_k)
+         else:
+             warm_candidates = _retrieve_relevant(warm, turn.query, config.top_k)
+         warm_lookup_ms = (perf_counter() - lookup_start) * 1000
+
+         best_score = float(warm_candidates["score"].max()) if "score" in warm_candidates.columns and not warm_candidates.empty else 0.0
+         warm_hit = bool(not warm_candidates.empty and (name == "recency" or best_score >= config.warm_hit_threshold))
+
+         fallback_used = False
+         retrieval_ms = warm_lookup_ms
+         retrieved = warm_candidates
+
+         if name == "fallback" and not warm_hit:
+             fallback_used = True
+             retrieval_ms += config.long_term_latency_ms
+             retrieved = long_term.relevant(turn.query, limit=config.long_term_limit)
+
+         retrieved_contents = retrieved["content"].tolist() if not retrieved.empty else []
+         prompt_tokens, generation_ms = _build_prompt_metrics(config, retrieved_contents)
+         end_to_end_ms = retrieval_ms + generation_ms
+
+         relevant_retrieved = {
+             str(meta.get("topic"))
+             for meta in retrieved["metadata"].tolist()
+             if isinstance(meta, dict) and meta.get("topic")
+         }
+         expected = set(turn.required_topics)
+         memory_precision = len(expected & relevant_retrieved) / max(len(relevant_retrieved), 1) if relevant_retrieved else 0.0
+         answer_correct = expected.issubset(relevant_retrieved)
+
+         repeated_tool_call = int(turn.topic in resolved_topics and not answer_correct)
+         if answer_correct:
+             resolved_topics.add(turn.topic)
+
+         response = f"Response for {turn.topic}: {'correct' if answer_correct else 'incomplete'}"
+         _write_memory(warm, turn, response)
+         _write_memory(long_term, turn, response)
+
+         if name == "relevance":
+             warm.retain_relevant(turn.query, limit=config.capacity)
+
+         rows.append(
+             {
+                 "turn_id": turn.turn_id,
+                 "topic": turn.topic,
+                 "query": turn.query,
+                 "warm_lookup_ms": warm_lookup_ms,
+                 "warm_hit": warm_hit,
+                 "fallback_used": fallback_used,
+                 "retrieval_ms": retrieval_ms,
+                 "prompt_tokens": prompt_tokens,
+                 "end_to_end_ms": end_to_end_ms,
+                 "answer_correct": answer_correct,
+                 "memory_precision_at_k": memory_precision,
+                 "repeated_tool_call": repeated_tool_call,
+             }
+         )
+
+     turn_log = pd.DataFrame(rows)
+     return _summarize_turn_log(name, turn_log)
+
+
+ def _render_report(config: BenchmarkConfig, results: list[BenchmarkResult]) -> str:
+     summary_frame = pd.DataFrame(
+         [
+             {"strategy": result.name, **result.summary}
+             for result in results
+         ]
+     )
+     best_latency = summary_frame.sort_values("avg_end_to_end_ms").iloc[0]["strategy"]
+     best_accuracy = summary_frame.sort_values("answer_accuracy", ascending=False).iloc[0]["strategy"]
+     best_tokens = summary_frame.sort_values("avg_prompt_tokens").iloc[0]["strategy"]
+
+     lines = [
+         "# WarmMemory Benchmark Report",
+         "",
+         "## Configuration",
+         "",
+         f"- capacity: {config.capacity}",
+         f"- top_k: {config.top_k}",
+         f"- long_term_limit: {config.long_term_limit}",
+         f"- warm_hit_threshold: {config.warm_hit_threshold}",
+         "",
+         "## Summary",
+         "",
+         _markdown_table(summary_frame),
+         "",
+         "## Readout",
+         "",
+         f"- Lowest average latency: `{best_latency}`",
+         f"- Highest answer accuracy: `{best_accuracy}`",
+         f"- Smallest prompt footprint: `{best_tokens}`",
+         "",
+         "## Interpretation",
+         "",
+         "- `recency` shows the baseline cost of always trusting the latest interactions.",
+         "- `relevance` shows the effect of ranking and retaining the hottest working set.",
+         "- `fallback` shows a two-tier memory architecture where long-term retrieval is only used on warm misses.",
+     ]
+     return "\n".join(lines) + "\n"
+
+
+ def run_benchmark(
+     *,
+     config: BenchmarkConfig | None = None,
+     report_path: str | Path | None = None,
+ ) -> dict[str, BenchmarkResult]:
+     active_config = config or BenchmarkConfig()
+     turns = default_workload()
+     results = {
+         name: _run_strategy(name, active_config, turns)
+         for name in ("recency", "relevance", "fallback")
+     }
+
+     if report_path is not None:
+         report_text = _render_report(active_config, list(results.values()))
+         target = Path(report_path)
+         target.parent.mkdir(parents=True, exist_ok=True)
+         target.write_text(report_text, encoding="utf-8")
+
+     return results
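
For orientation, here is a minimal driving sketch assembled from the exports above; the report path is illustrative, and the `summary` keys come from `_summarize_turn_log`.

    from warm_memory import BenchmarkConfig, run_benchmark

    # Run the recency, relevance, and fallback strategies over the bundled
    # default workload and write the markdown report.
    results = run_benchmark(
        config=BenchmarkConfig(capacity=8, top_k=5),
        report_path="reports/warm_memory_benchmark.md",  # illustrative path
    )
    for name, result in results.items():
        print(name, result.summary["warm_hit_rate"], result.summary["p95_end_to_end_ms"])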
warm_memory/buffer.py ADDED
@@ -0,0 +1,171 @@
+ from __future__ import annotations
+
+ from dataclasses import asdict, dataclass, field
+ from datetime import datetime, timezone
+ from itertools import count
+ from typing import Any
+
+ import pandas as pd
+
+ from .scoring import ImportanceScorer, KeywordImportanceScorer
+
+
+ @dataclass(slots=True)
+ class InteractionRecord:
+     interaction_id: int
+     timestamp: datetime
+     role: str
+     content: str
+     summary: str = ""
+     tags: tuple[str, ...] = field(default_factory=tuple)
+     metadata: dict[str, Any] = field(default_factory=dict)
+
+
+ class WarmMemoryBuffer:
+     """
+     Pandas-backed in-memory interaction buffer for agent experiments.
+
+     Two usage modes are supported:
+     - `recent(limit)`: classic sliding-window retrieval.
+     - `relevant(query, limit)`: query-aware retrieval using a pluggable scorer.
+     """
+
+     COLUMNS = ["interaction_id", "timestamp", "role", "content", "summary", "tags", "metadata"]
+
+     def __init__(
+         self,
+         capacity: int = 32,
+         scorer: ImportanceScorer | None = None,
+     ) -> None:
+         if capacity <= 0:
+             raise ValueError("capacity must be positive")
+
+         self.capacity = capacity
+         self.scorer = scorer or KeywordImportanceScorer()
+         self._id_source = count(1)
+         self._frame = pd.DataFrame(columns=self.COLUMNS)
+
+     @property
+     def frame(self) -> pd.DataFrame:
+         """Return a copy so callers cannot mutate the live buffer implicitly."""
+         return self._frame.copy(deep=True)
+
+     def __len__(self) -> int:
+         return len(self._frame.index)
+
+     def add(
+         self,
+         role: str,
+         content: str,
+         *,
+         summary: str = "",
+         tags: list[str] | tuple[str, ...] | None = None,
+         metadata: dict[str, Any] | None = None,
+     ) -> int:
+         record = InteractionRecord(
+             interaction_id=next(self._id_source),
+             timestamp=datetime.now(timezone.utc),
+             role=role,
+             content=content,
+             summary=summary,
+             tags=tuple(tags or ()),
+             metadata=dict(metadata or {}),
+         )
+
+         row = pd.DataFrame([asdict(record)], columns=self.COLUMNS)
+         self._frame = pd.concat([self._frame, row], ignore_index=True)
+         self._evict_over_capacity()
+         return record.interaction_id
+
+     def recent(self, limit: int = 5) -> pd.DataFrame:
+         if limit <= 0:
+             return self._frame.head(0).copy(deep=True)
+
+         ordered = self._frame.sort_values("interaction_id", ascending=False).head(limit)
+         return ordered.sort_values("interaction_id", ascending=True).reset_index(drop=True)
+
+     def relevant(self, query: str, limit: int = 5) -> pd.DataFrame:
+         if limit <= 0 or self._frame.empty:
+             return self._frame.head(0).copy(deep=True)
+
+         scored = self._frame.copy(deep=True)
+         scored["score"] = scored.apply(lambda row: self.scorer.score(query, row), axis=1)
+         scored = scored.sort_values(["score", "interaction_id"], ascending=[False, False]).head(limit)
+         return scored.reset_index(drop=True)
+
+     def context_window(self, query: str | None = None, limit: int = 5) -> pd.DataFrame:
+         """
+         Return either a recent window or a query-aware relevant window.
+
+         This gives callers a single method for switching between fixed-window and
+         importance-aware retention policies.
+         """
+
+         if query is None or not query.strip():
+             return self.recent(limit=limit)
+         return self.relevant(query=query, limit=limit)
+
+     def retain_relevant(self, query: str, limit: int | None = None) -> pd.DataFrame:
+         """
+         Compact the live buffer down to the top relevant rows for a query.
+
+         This is useful when you want a strict "working set" rather than keeping the
+         most recent interactions by default.
+         """
+
+         target_size = self.capacity if limit is None else limit
+         if target_size <= 0:
+             self.clear()
+             return self._frame.copy(deep=True)
+
+         retained = self.relevant(query=query, limit=target_size).sort_values("interaction_id")
+         self._frame = retained[self.COLUMNS].reset_index(drop=True)
+         return self.frame
+
+     def clear(self) -> None:
+         self._frame = pd.DataFrame(columns=self.COLUMNS)
+
+     def find_index_by_metadata(self, field: str, value: Any) -> int | None:
+         """
+         Return the integer index of the first row whose `metadata[field]` equals `value`.
+
+         Used by external mutation paths (e.g., key-based stores layered on top of
+         the buffer) that need to update or delete a specific logical entry without
+         scanning the dataframe themselves.
+         """
+         if self._frame.empty:
+             return None
+         for idx, metadata in enumerate(self._frame["metadata"].tolist()):
+             if isinstance(metadata, dict) and metadata.get(field) == value:
+                 return idx
+         return None
+
+     def drop_at(self, index: int) -> None:
+         """
+         Remove the row at the given positional index.
+
+         Pairs with `find_index_by_metadata` to give external code a clean mutation
+         path without reaching into private dataframe internals.
+         """
+         if index < 0 or index >= len(self._frame.index):
+             return
+         self._frame = self._frame.drop(self._frame.index[index]).reset_index(drop=True)
+
+     def iter_rows(self) -> list[dict[str, Any]]:
+         """
+         Return a list of row dicts in insertion order. Use this for read-only
+         scans (e.g., search) instead of poking at the live dataframe.
+         """
+         return self._frame.to_dict("records")
+
+     def _evict_over_capacity(self) -> None:
+         overflow = len(self._frame.index) - self.capacity
+         if overflow <= 0:
+             return
+
+         self._frame = (
+             self._frame.sort_values("interaction_id", ascending=False)
+             .head(self.capacity)
+             .sort_values("interaction_id", ascending=True)
+             .reset_index(drop=True)
+         )
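
A short usage sketch of the two retrieval modes named in the class docstring; the interaction contents are invented for illustration.

    from warm_memory import WarmMemoryBuffer

    buffer = WarmMemoryBuffer(capacity=4)
    buffer.add("user", "How do I rotate the API key?", metadata={"topic": "auth"})
    buffer.add("assistant", "Call the key-rotation endpoint.", metadata={"topic": "auth"})
    buffer.add("user", "When does billing run?", metadata={"topic": "billing"})

    # Sliding-window mode: the last two rows, returned in insertion order.
    print(buffer.recent(limit=2)[["role", "content"]])

    # Query-aware mode: rows ranked by the configured scorer, with a "score" column.
    print(buffer.relevant("api key rotation", limit=2)[["role", "content", "score"]])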
warm_memory/decorators.py ADDED
@@ -0,0 +1,70 @@
+ from __future__ import annotations
+
+ from functools import wraps
+ import inspect
+ from typing import Any, Callable
+
+ from .buffer import WarmMemoryBuffer
+
+
+ def _stringify_payload(payload: Any) -> str:
+     if payload is None:
+         return ""
+     if isinstance(payload, str):
+         return payload
+     return repr(payload)
+
+
+ def remember_interaction(
+     memory: WarmMemoryBuffer,
+     *,
+     input_role: str = "user",
+     output_role: str = "assistant",
+     input_extractor: Callable[[tuple[Any, ...], dict[str, Any]], Any] | None = None,
+     output_extractor: Callable[[Any], Any] | None = None,
+     metadata_factory: Callable[[tuple[Any, ...], dict[str, Any], Any], dict[str, Any]] | None = None,
+ ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
+     """
+     Decorate an agent-like function and persist input/output rows in warm memory.
+
+     Defaults:
+     - the input is the `prompt` argument when bound, otherwise the first bound argument
+     - the output is the function's return value
+     """
+
+     def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
+         signature = inspect.signature(func)
+
+         def default_input_extractor(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
+             bound = signature.bind_partial(*args, **kwargs)
+             if "prompt" in bound.arguments:
+                 return bound.arguments["prompt"]
+             if bound.arguments:
+                 first_name = next(iter(bound.arguments))
+                 return bound.arguments[first_name]
+             return None
+
+         @wraps(func)
+         def wrapper(*args: Any, **kwargs: Any) -> Any:
+             raw_input = (
+                 input_extractor(args, kwargs)
+                 if input_extractor is not None
+                 else default_input_extractor(args, kwargs)
+             )
+
+             result = func(*args, **kwargs)
+
+             raw_output = output_extractor(result) if output_extractor is not None else result
+             metadata = (
+                 metadata_factory(args, kwargs, result)
+                 if metadata_factory is not None
+                 else {"function": func.__name__}
+             )
+
+             memory.add(role=input_role, content=_stringify_payload(raw_input), metadata=metadata)
+             memory.add(role=output_role, content=_stringify_payload(raw_output), metadata=metadata)
+             return result
+
+         return wrapper
+
+     return decorator
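
A minimal sketch of the decorator's default extraction behavior; `answer` is a hypothetical function used only for illustration.

    from warm_memory import WarmMemoryBuffer, remember_interaction

    memory = WarmMemoryBuffer(capacity=16)

    @remember_interaction(memory)
    def answer(prompt: str) -> str:  # hypothetical agent-like function
        return f"echo: {prompt}"

    answer(prompt="hello")
    print(len(memory))  # 2: a "user" row for the prompt, an "assistant" row for the return value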
warm_memory/langgraph/__init__.py ADDED
@@ -0,0 +1,17 @@
+ """
+ LangGraph integration for WarmMemory.
+
+ Install with the optional extra:
+
+     pip install "warm-memory[langgraph]"
+ """
+
+ from .agent import build_warm_memory_agent
+ from .embeddings import EmbeddingsImportanceScorer
+ from .store import WarmStore
+
+ __all__ = [
+     "WarmStore",
+     "EmbeddingsImportanceScorer",
+     "build_warm_memory_agent",
+ ]
warm_memory/langgraph/agent.py ADDED
@@ -0,0 +1,137 @@
+ """
+ Reusable LangGraph agent builder that wires WarmStore into the request loop.
+
+ Pre-call: search the user's namespace for relevant memories and inject them
+ into the system message.
+ Post-call: write the new (user, assistant) exchange back into the warm store.
+
+ Works fully synthetic out of the box (FakeListChatModel + KeywordImportanceScorer).
+ Drop in a real chat model and/or EmbeddingsImportanceScorer to turn it into a
+ production agent.
+ """
+
+ from __future__ import annotations
+
+ from typing import Any, Callable, TypedDict
+
+ from langchain_core.language_models import BaseChatModel
+ from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
+ from langgraph.graph import END, START, StateGraph
+
+ from .store import WarmStore
+
+
+ def _flatten_message_content(content: Any) -> str:
+     """
+     Normalize an AIMessage.content value to a plain string.
+
+     Newer LangChain chat models can return `content` as a list of content
+     blocks (e.g., `[{"type": "text", "text": "..."}, ...]`) instead of a
+     plain string. Storing the list repr in warm memory is useless for
+     keyword/embedding scoring, so flatten to text here.
+     """
+     if isinstance(content, str):
+         return content
+     if isinstance(content, list):
+         parts: list[str] = []
+         for block in content:
+             if isinstance(block, str):
+                 parts.append(block)
+             elif isinstance(block, dict):
+                 if block.get("type") == "text" and isinstance(block.get("text"), str):
+                     parts.append(block["text"])
+                 elif isinstance(block.get("text"), str):
+                     parts.append(block["text"])
+         return "\n".join(parts)
+     return str(content)
+
+
+ class WarmAgentState(TypedDict, total=False):
+     query: str
+     namespace: tuple[str, ...]
+     recalled: list[dict[str, Any]]
+     response: str
+
+
+ _DEFAULT_SYSTEM = (
+     "You are a helpful assistant with access to warm memory of prior exchanges "
+     "with this user. Use the recalled context if relevant; ignore it if not."
+ )
+
+
+ def _format_recalled(recalled: list[dict[str, Any]]) -> str:
+     if not recalled:
+         return "(no prior context)"
+     lines = []
+     for entry in recalled:
+         key = entry.get("key", "?")
+         value = entry.get("value", {})
+         lines.append(f"- [{key}] {value}")
+     return "\n".join(lines)
+
+
+ def build_warm_memory_agent(
+     *,
+     model: BaseChatModel,
+     store: WarmStore,
+     recall_limit: int = 5,
+     system_prompt: str = _DEFAULT_SYSTEM,
+     namespace_default: tuple[str, ...] = ("default",),
+ ) -> Callable[[dict[str, Any]], dict[str, Any]]:
+     """
+     Build a compiled LangGraph agent that uses `store` as warm memory.
+
+     Returns a compiled graph. Invoke it with:
+         agent.invoke({"query": "...", "namespace": ("alice",)})
+     """
+
+     def memory_lookup(state: WarmAgentState) -> dict[str, Any]:
+         namespace = state.get("namespace") or namespace_default
+         query = state.get("query", "")
+         if not query:
+             return {"recalled": []}
+         hits = store.search(namespace, query=query, limit=recall_limit)
+         return {
+             "recalled": [
+                 {"key": h.key, "value": h.value, "score": h.score} for h in hits
+             ],
+             "namespace": namespace,
+         }
+
+     def respond(state: WarmAgentState) -> dict[str, Any]:
+         recalled = state.get("recalled", []) or []
+         memory_block = _format_recalled(recalled)
+         messages = [
+             SystemMessage(content=f"{system_prompt}\n\nRecalled memory:\n{memory_block}"),
+             HumanMessage(content=state.get("query", "")),
+         ]
+         ai_message = model.invoke(messages)
+         raw_content = ai_message.content if isinstance(ai_message, AIMessage) else ai_message
+         text = _flatten_message_content(raw_content)
+         return {"response": text}
+
+     def memory_write(state: WarmAgentState) -> dict[str, Any]:
+         namespace = state.get("namespace") or namespace_default
+         query = state.get("query", "")
+         response = state.get("response", "")
+         if not query and not response:
+             return {}
+         next_key = store.next_key(namespace, prefix="exchange-")
+         store.put(
+             namespace,
+             next_key,
+             {"user": query, "assistant": response},
+         )
+         return {}
+
+     graph = StateGraph(WarmAgentState)
+     graph.add_node("memory_lookup", memory_lookup)
+     graph.add_node("respond", respond)
+     graph.add_node("memory_write", memory_write)
+
+     graph.add_edge(START, "memory_lookup")
+     graph.add_edge("memory_lookup", "respond")
+     graph.add_edge("respond", "memory_write")
+     graph.add_edge("memory_write", END)
+
+     return graph.compile()
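
A sketch of the synthetic wiring the module docstring describes. `FakeListChatModel` ships with langchain_core; constructing `WarmStore()` with no arguments is an assumption here, since store.py is not part of this diff.

    from langchain_core.language_models import FakeListChatModel

    from warm_memory.langgraph import WarmStore, build_warm_memory_agent

    agent = build_warm_memory_agent(
        model=FakeListChatModel(responses=["Noted: rotation happens monthly."]),
        store=WarmStore(),  # assumed default constructor; store.py is not shown above
    )
    state = agent.invoke({"query": "How often does the API key rotate?", "namespace": ("alice",)})
    print(state["response"])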