edgeone 1.5.8 → 1.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +26 -26
- package/edgeone-bin/edgeone.js +3 -3
- package/edgeone-dist/cli.js +86879 -2294
- package/edgeone-dist/libs-pages-agent-toolkit/README.md +8 -0
- package/edgeone-dist/libs-pages-agent-toolkit/pages_agent_toolkit-0.1.40-py3-none-any.whl +0 -0
- package/edgeone-dist/libs-pages-blob-python/README.md +38 -0
- package/edgeone-dist/libs-pages-blob-python/pages_blob_python-0.11.0-py3-none-any.whl +0 -0
- package/edgeone-dist/pages/dev/runner-worker.js +86519 -2075
- package/edgeone-dist/pages/observability-python/__init__.py +32 -0
- package/edgeone-dist/pages/observability-python/_compat.py +69 -0
- package/edgeone-dist/pages/observability-python/apm/__init__.py +13 -0
- package/edgeone-dist/pages/observability-python/apm/config.py +85 -0
- package/edgeone-dist/pages/observability-python/apm/llm_semconv.py +53 -0
- package/edgeone-dist/pages/observability-python/apm/metrics_bridge.py +226 -0
- package/edgeone-dist/pages/observability-python/apm/span_exporter.py +384 -0
- package/edgeone-dist/pages/observability-python/bootstrap.py +158 -0
- package/edgeone-dist/pages/observability-python/build.py +119 -0
- package/edgeone-dist/pages/observability-python/context_patches.py +167 -0
- package/edgeone-dist/pages/observability-python/context_propagator.py +78 -0
- package/edgeone-dist/pages/observability-python/registry.json +95 -0
- package/edgeone-dist/pages/observability-python/registry.py +141 -0
- package/edgeone-dist/pages/observability-python/telemetry.py +214 -0
- package/edgeone-dist/pages/observability-python/tracer.py +165 -0
- package/edgeone-dist/pages/templates/agent-python/__init__.py +11 -0
- package/edgeone-dist/pages/templates/agent-python/adapter.py +908 -0
- package/edgeone-dist/pages/templates/agent-python/context.py +689 -0
- package/edgeone-dist/pages/templates/agent-python/local_blob_store.py +172 -0
- package/edgeone-dist/pages/templates/agent-python/memory.py +2301 -0
- package/edgeone-dist/pages/templates/agent-python/runtime.py +839 -0
- package/edgeone-dist/pages/templates/agent-python/store.py +204 -0
- package/edgeone-dist/studio/ui/assets/agent-obs-Dvi4IpEy.js +4 -0
- package/edgeone-dist/studio/ui/assets/agent-obs-qDJCE0TQ.css +1 -0
- package/edgeone-dist/studio/ui/assets/highlight-ClXAL37H.js +3 -0
- package/edgeone-dist/studio/ui/assets/index-Cz5oQnXW.css +1 -0
- package/edgeone-dist/studio/ui/assets/index-DD3d108t.js +1 -0
- package/edgeone-dist/studio/ui/assets/moment-BYRO94Ou.js +10 -0
- package/edgeone-dist/studio/ui/assets/react-dom-ZzBHVjtL.js +24 -0
- package/edgeone-dist/studio/ui/assets/react-hnpCyKql.js +17 -0
- package/edgeone-dist/studio/ui/assets/tea-CADagUwM.css +1 -0
- package/edgeone-dist/studio/ui/assets/tea-Slf_ajmf.js +334 -0
- package/edgeone-dist/studio/ui/favicon.ico +0 -0
- package/edgeone-dist/studio/ui/index.html +31 -0
- package/libs-pages-agent-toolkit/README.md +8 -0
- package/libs-pages-agent-toolkit/pages_agent_toolkit-0.1.40-py3-none-any.whl +0 -0
- package/libs-pages-blob-python/README.md +38 -0
- package/libs-pages-blob-python/pages_blob_python-0.11.0-py3-none-any.whl +0 -0
- package/package.json +33 -7
|
@@ -0,0 +1,2301 @@
|
|
|
1
|
+
# src/pages/builder/templates/agent-python/memory.py
|
|
2
|
+
"""Conversation memory — ctx.store for message history CRUD.
|
|
3
|
+
|
|
4
|
+
提供基于 conversation_id 的消息追加、读取、清空、列举、删除、更新等操作。
|
|
5
|
+
底层直接使用 raw blob store(pages_blob.Store / LocalFileBlobStore),
|
|
6
|
+
自行管理 JSON 序列化,不复用 BlobBackedStore 的 envelope 格式。
|
|
7
|
+
|
|
8
|
+
典型用法::
|
|
9
|
+
|
|
10
|
+
async def handler(ctx):
|
|
11
|
+
# 追加 user 消息
|
|
12
|
+
msg_id = await ctx.store.append_message(
|
|
13
|
+
ctx.conversation_id, "user", "Hello!"
|
|
14
|
+
)
|
|
15
|
+
# 获取对话历史(默认按时间升序,方便拼 prompt)
|
|
16
|
+
messages = await ctx.store.get_messages(ctx.conversation_id)
|
|
17
|
+
# 转换为 OpenAI 格式
|
|
18
|
+
openai_msgs = ctx.store.to_openai_input(messages)
|
|
19
|
+
|
|
20
|
+
数据布局(与 Node `agent/memory.ts` 完全一致):
|
|
21
|
+
|
|
22
|
+
conversations/{encoded_cid}/meta 会话元数据(保留 messageCount=0 后仍存在)
|
|
23
|
+
conversations/{encoded_cid}/messages/{ts}_{msg_id} 消息正文(按时间戳排序)
|
|
24
|
+
message_index/{msg_id} 消息→key 的二级索引(O(1) 定位)
|
|
25
|
+
conversation_index/{rev_ts}_{encoded_cid} 会话最近活跃时间倒排索引
|
|
26
|
+
user_conversation_index/{encoded_uid}/{rev_ts}_{cid} 用户维度会话倒排索引
|
|
27
|
+
langgraph_checkpoints/{thread_id}/checkpoints/{cid} LangGraph checkpoint
|
|
28
|
+
langgraph_checkpoints/{thread_id}/latest LangGraph latest checkpoint id
|
|
29
|
+
langgraph_checkpoints/{thread_id}/writes/{cid}/{tid} LangGraph pending writes
|
|
30
|
+
"""
|
|
31
|
+
from __future__ import annotations
|
|
32
|
+
|
|
33
|
+
import base64
|
|
34
|
+
import json
|
|
35
|
+
import time
|
|
36
|
+
import uuid
|
|
37
|
+
import urllib.parse
|
|
38
|
+
from dataclasses import dataclass, field
|
|
39
|
+
from typing import Any, List, Optional
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# ─── Error Classes ───
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class MemoryError(Exception):
    """Root of the conversation-memory exception hierarchy.

    NOTE(review): inside this module the name shadows the builtin
    ``MemoryError``.
    """
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class MemoryValidationError(MemoryError):
    """Rejected input.

    Examples: limit > 100, conversation_id longer than 256 characters,
    content larger than 50MB, metadata that is not an object.
    """
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class MemoryNotFoundError(MemoryError):
    """The requested conversation or message does not exist."""
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class MemoryQuotaExceededError(MemoryError):
    """A conversation would exceed the cap of 10000 messages."""
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class MemoryStorageError(MemoryError):
    """The underlying blob store failed."""
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
# ─── Data Model ───
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@dataclass
class Message:
    """One stored conversation message.

    ``to_dict`` emits camelCase keys; ``from_dict`` accepts both the
    snake_case and the camelCase spelling of each field.
    """

    message_id: str
    role: str  # 'user' | 'assistant' | 'system' | 'tool'
    content: Any  # str, or a list for multimodal content
    created_at: int  # millisecond timestamp
    metadata: Optional[dict] = None
    updated_at: Optional[int] = None

    def to_dict(self) -> dict:
        """Return the camelCase dict form, leaving out optional fields that are None."""
        payload: dict[str, Any] = dict(
            messageId=self.message_id,
            role=self.role,
            content=self.content,
            createdAt=self.created_at,
        )
        optional_fields = (("metadata", self.metadata), ("updatedAt", self.updated_at))
        for name, value in optional_fields:
            if value is not None:
                payload[name] = value
        return payload

    @classmethod
    def from_dict(cls, d: dict) -> "Message":
        """Build a Message from a dict.

        ``role`` and ``content`` are required (``KeyError`` when absent).
        A missing id becomes ``""`` and a missing creation time becomes ``0``.
        """
        message_id = d.get("message_id") or d.get("messageId", "")
        role = d["role"]
        content = d["content"]
        created_at = d.get("created_at") or d.get("createdAt", 0)
        metadata = d.get("metadata")
        updated_at = d.get("updated_at") or d.get("updatedAt")
        return cls(
            message_id=message_id,
            role=role,
            content=content,
            created_at=created_at,
            metadata=metadata,
            updated_at=updated_at,
        )
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
@dataclass
class ConversationMeta:
    """Per-conversation bookkeeping record.

    ``to_dict`` emits camelCase keys; ``from_dict`` accepts both the
    camelCase and the snake_case spelling of each field.
    """

    conversation_id: str
    created_at: int  # millisecond timestamp
    last_message_at: int  # millisecond timestamp
    message_count: int
    metadata: Optional[dict] = None

    def to_dict(self) -> dict:
        """Return the camelCase dict form; ``metadata`` is included only when set."""
        payload: dict[str, Any] = dict(
            conversationId=self.conversation_id,
            createdAt=self.created_at,
            lastMessageAt=self.last_message_at,
            messageCount=self.message_count,
        )
        if self.metadata is not None:
            payload["metadata"] = self.metadata
        return payload

    @classmethod
    def from_dict(cls, d: dict) -> "ConversationMeta":
        """Build a ConversationMeta from a dict.

        Raises:
            KeyError: when no non-empty conversation id is present.
        """
        cid = d.get("conversationId") or d.get("conversation_id") or ""
        if not cid:
            raise KeyError(f"conversationId missing in meta dict, keys={list(d.keys())}")

        created_at = d.get("createdAt") or d.get("created_at") or 0
        last_message_at = d.get("lastMessageAt") or d.get("last_message_at") or 0
        # The camelCase key wins whenever it is present, even if its value is falsy.
        if "messageCount" in d:
            message_count = d.get("messageCount")
        else:
            message_count = d.get("message_count", 0)
        return cls(cid, created_at, last_message_at, message_count, d.get("metadata"))
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
@dataclass
class ListConversationsResult:
    """Paginated result for list_conversations.

    NOTE(review): the cursors are presumably the opaque tokens produced by
    ``_encode_cursor`` -- confirm against ``list_conversations``.
    """
    # Conversations on this page; each instance gets its own fresh list.
    items: List[ConversationMeta] = field(default_factory=list)
    # Cursor fields; both default to None.
    next_cursor: Optional[str] = None
    previous_cursor: Optional[str] = None
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
# ─── Constants ───
|
|
151
|
+
|
|
152
|
+
# Bounds for ``limit`` arguments: ``_validate_limit`` rejects values above
# ``_MAX_LIMIT``; ``get_messages`` defaults to ``_DEFAULT_LIMIT``.
_MAX_LIMIT = 100
_DEFAULT_LIMIT = 20
# Maximum id lengths in characters. ``_validate_user_id`` reuses the
# conversation-id bound.
_MAX_CONVERSATION_ID_LEN = 256
_MAX_MESSAGE_ID_LEN = 256
# Compared against the UTF-8 byte length of the content (or of its JSON form).
_MAX_CONTENT_SIZE = 50 * 1024 * 1024  # 50MB
_MAX_MESSAGES_PER_CONVERSATION = 10000
_VALID_ROLES = ("user", "assistant", "system", "tool")

# Top-level blob key prefixes used by the key helpers.
_MESSAGE_PREFIX = "conversations"
_MESSAGE_INDEX_PREFIX = "message_index"
_CONVERSATION_INDEX_PREFIX = "conversation_index"
_USER_CONVERSATION_INDEX_PREFIX = "user_conversation_index"
_LANGGRAPH_CHECKPOINT_PREFIX = "langgraph_checkpoints"
# NOTE(review): the two LangGraph store constants are not used in the visible
# part of the file -- confirm their usage further down.
_LANGGRAPH_STORE_PREFIX = "langgraph_store"
_LANGGRAPH_STORE_KEY_SEPARATOR = "__key__"
_MAX_SAFE_INTEGER = 9007199254740991  # JS Number.MAX_SAFE_INTEGER

# Sentinel for "argument not provided", as distinct from an explicitly passed
# None. update_message uses it to tell whether content / metadata was supplied,
# mirroring Node's `input.content === undefined` check.
_UNSET: Any = object()
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def _json_default(obj: Any) -> Any:
|
|
176
|
+
"""Fallback for json.dumps — handles LangGraph checkpoint objects
|
|
177
|
+
containing non-serializable types (Runtime, ChannelProtocol, etc.)."""
|
|
178
|
+
if hasattr(obj, "model_dump"):
|
|
179
|
+
return obj.model_dump()
|
|
180
|
+
# For non-serializable objects, return None to strip them rather than
|
|
181
|
+
# corrupting checkpoint data with repr() strings.
|
|
182
|
+
return None
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def _encode_path_segment(value: str) -> str:
|
|
186
|
+
"""Base64url-encode a path segment (no padding). Matches Node encodePathSegment."""
|
|
187
|
+
return base64.urlsafe_b64encode(value.encode("utf-8")).rstrip(b"=").decode("ascii")
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
# ─── Key Schema Helpers ───
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def _encode_cid(conversation_id: str) -> str:
|
|
194
|
+
"""URL-encode conversation_id for use in blob keys."""
|
|
195
|
+
return urllib.parse.quote(conversation_id, safe="")
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def _encode_segment(value: str) -> str:
|
|
199
|
+
"""URL-encode an arbitrary path segment (thread_id / checkpoint_id / task_id)."""
|
|
200
|
+
return urllib.parse.quote(value, safe="")
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def _message_key(encoded_cid: str, created_at: int, message_id: str) -> str:
    """Build ``conversations/{encoded_cid}/messages/{created_at}_{message_id}``."""
    leaf = f"{created_at}_{message_id}"
    return f"{_MESSAGE_PREFIX}/{encoded_cid}/messages/{leaf}"
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def _messages_prefix(encoded_cid: str) -> str:
    """Return the key prefix under which every message of a conversation is stored."""
    conversation_root = f"{_MESSAGE_PREFIX}/{encoded_cid}"
    return f"{conversation_root}/messages/"
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def _meta_key(encoded_cid: str) -> str:
    """Build ``conversations/{encoded_cid}/meta``."""
    conversation_root = f"{_MESSAGE_PREFIX}/{encoded_cid}"
    return f"{conversation_root}/meta"
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def _message_index_key(message_id: str) -> str:
    """Build ``message_index/{message_id}`` (the id is percent-encoded).

    Secondary index: maps a message_id straight to the key of the message
    body. Same behavior as Node; it gives update_message / delete_message
    an O(1) lookup.
    """
    encoded_id = _encode_segment(message_id)
    return f"{_MESSAGE_INDEX_PREFIX}/{encoded_id}"
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def _index_key(last_message_at: int, encoded_cid: str) -> str:
    """Build ``conversation_index/{reverseTimestamp}_{encoded_cid}``.

    The timestamp is stored reversed (MAX_SAFE_INTEGER - ts, padded to 16
    digits), aligned with Node memory.ts, so that ascending lexicographic
    order lists the most recently active conversation first.
    """
    rev_ts = _reverse_timestamp(last_message_at)
    return f"{_CONVERSATION_INDEX_PREFIX}/{rev_ts}_{encoded_cid}"
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def _reverse_timestamp(ts: int) -> str:
    """Return ``_MAX_SAFE_INTEGER - ts`` as text, zero-filled to 16 characters.

    Matches Node reverseTimestamp().
    """
    inverted = _MAX_SAFE_INTEGER - ts
    return str(inverted).zfill(16)
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
def _user_index_key(encoded_user_id: str, last_message_at: int, encoded_cid: str) -> str:
    """Build ``user_conversation_index/{encoded_user_id}/{reverseTimestamp}_{encoded_cid}``."""
    rev_ts = _reverse_timestamp(last_message_at)
    leaf = f"{rev_ts}_{encoded_cid}"
    return f"{_USER_CONVERSATION_INDEX_PREFIX}/{encoded_user_id}/{leaf}"
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def _user_index_prefix(encoded_user_id: str) -> str:
    """Return the key prefix that holds every conversation index entry of one user."""
    user_root = f"{_USER_CONVERSATION_INDEX_PREFIX}/{encoded_user_id}"
    return f"{user_root}/"
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
def _index_prefix() -> str:
    """Return the key prefix of the global conversation index."""
    return _CONVERSATION_INDEX_PREFIX + "/"
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def _generate_message_id() -> str:
|
|
256
|
+
"""Generate a unique message ID: msg_{random_hex}"""
|
|
257
|
+
return f"msg_{uuid.uuid4().hex[:16]}"
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
# ─── Cursor Helpers ───
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
def _encode_cursor(last_message_at: int, conversation_id: str) -> str:
|
|
264
|
+
"""Encode pagination cursor as base64url JSON. Opaque to callers."""
|
|
265
|
+
payload = json.dumps(
|
|
266
|
+
{"v": 1, "lastMessageAt": last_message_at, "conversationId": conversation_id},
|
|
267
|
+
separators=(",", ":"),
|
|
268
|
+
)
|
|
269
|
+
return base64.urlsafe_b64encode(payload.encode()).decode().rstrip("=")
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
def _decode_cursor(cursor: str) -> dict:
|
|
273
|
+
"""Decode and validate an opaque pagination cursor.
|
|
274
|
+
|
|
275
|
+
Raises MemoryValidationError on invalid/unsupported cursor.
|
|
276
|
+
"""
|
|
277
|
+
padded = cursor + "=" * (-len(cursor) % 4)
|
|
278
|
+
try:
|
|
279
|
+
raw = base64.urlsafe_b64decode(padded)
|
|
280
|
+
data = json.loads(raw)
|
|
281
|
+
except Exception:
|
|
282
|
+
raise MemoryValidationError("Invalid cursor.")
|
|
283
|
+
if not isinstance(data, dict) or data.get("v") != 1:
|
|
284
|
+
raise MemoryValidationError("Unsupported cursor version.")
|
|
285
|
+
if not isinstance(data.get("lastMessageAt"), int) or not isinstance(data.get("conversationId"), str):
|
|
286
|
+
raise MemoryValidationError("Invalid cursor fields.")
|
|
287
|
+
return data
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def _cursor_sort_key(last_message_at: int, conversation_id: str) -> str:
    """Return the index-key basename that a cursor position corresponds to.

    Format: ``{reverse timestamp, 16 digits}_{encoded conversation id}``,
    the same layout as the basename of an index key.
    """
    rev_ts = _reverse_timestamp(last_message_at)
    encoded_cid = _encode_cid(conversation_id)
    return f"{rev_ts}_{encoded_cid}"
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
# ─── ConversationMemory ───
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
class ConversationMemory:
|
|
302
|
+
"""Conversation message history CRUD, accessed as ctx.store.
|
|
303
|
+
|
|
304
|
+
Operates on a dedicated blob store (separate from per-route ctx.kv).
|
|
305
|
+
All core methods are async.
|
|
306
|
+
"""
|
|
307
|
+
|
|
308
|
+
def __init__(self, blob_store: Any, run_id: str) -> None:
|
|
309
|
+
"""
|
|
310
|
+
Args:
|
|
311
|
+
blob_store: Raw blob store instance (pages_blob.Store or LocalFileBlobStore).
|
|
312
|
+
Must implement: get(key, type=), set(key, value), delete(key),
|
|
313
|
+
list(prefix=) returning object with .blobs list of objects with .key.
|
|
314
|
+
run_id: Current run ID, auto-injected into message metadata.
|
|
315
|
+
"""
|
|
316
|
+
self._blob = blob_store
|
|
317
|
+
self._run_id = run_id
|
|
318
|
+
|
|
319
|
+
# ─── Validation Helpers ───
|
|
320
|
+
|
|
321
|
+
def _validate_conversation_id(self, conversation_id: str) -> None:
    """Require a non-empty string no longer than ``_MAX_CONVERSATION_ID_LEN`` characters.

    Raises:
        MemoryValidationError: when the id is empty, not a string, or too long.
    """
    if not (conversation_id and isinstance(conversation_id, str)):
        raise MemoryValidationError("conversation_id must be a non-empty string")
    if len(conversation_id) <= _MAX_CONVERSATION_ID_LEN:
        return
    raise MemoryValidationError(
        f"conversation_id exceeds max length of {_MAX_CONVERSATION_ID_LEN} characters"
    )
|
|
328
|
+
|
|
329
|
+
def _validate_message_id(self, message_id: str) -> None:
    """Require a non-empty string no longer than ``_MAX_MESSAGE_ID_LEN`` characters.

    Raises:
        MemoryValidationError: when the id is empty, not a string, or too long.
    """
    if not (message_id and isinstance(message_id, str)):
        raise MemoryValidationError("message_id must be a non-empty string")
    if len(message_id) <= _MAX_MESSAGE_ID_LEN:
        return
    raise MemoryValidationError(
        f"message_id exceeds max length of {_MAX_MESSAGE_ID_LEN} characters"
    )
|
|
336
|
+
|
|
337
|
+
def _validate_user_id(self, user_id: str) -> None:
    """Require a non-empty string no longer than ``_MAX_CONVERSATION_ID_LEN`` characters.

    The user id shares the conversation-id length bound.

    Raises:
        MemoryValidationError: when the id is empty, not a string, or too long.
    """
    if not (user_id and isinstance(user_id, str)):
        raise MemoryValidationError("user_id must be a non-empty string")
    if len(user_id) <= _MAX_CONVERSATION_ID_LEN:
        return
    raise MemoryValidationError(
        f"user_id exceeds max length of {_MAX_CONVERSATION_ID_LEN} characters"
    )
|
|
344
|
+
|
|
345
|
+
def _validate_limit(self, limit: int) -> None:
    """Require ``limit`` to be an integer in ``1.._MAX_LIMIT``.

    Raises:
        MemoryValidationError: when ``limit`` is not an integer, is a bool,
            is below 1, or exceeds ``_MAX_LIMIT``.
    """
    # bool is a subclass of int; without this check ``True`` would be
    # accepted and silently treated as a limit of 1.
    if isinstance(limit, bool) or not isinstance(limit, int) or limit < 1:
        raise MemoryValidationError("limit must be a positive integer")
    if limit > _MAX_LIMIT:
        raise MemoryValidationError(f"limit exceeds maximum of {_MAX_LIMIT}")
|
|
350
|
+
|
|
351
|
+
def _validate_role(self, role: str) -> None:
    """Require ``role`` to be one of ``_VALID_ROLES``.

    Raises:
        MemoryValidationError: for any other value.
    """
    if role in _VALID_ROLES:
        return
    raise MemoryValidationError(
        f"role must be one of {_VALID_ROLES}, got '{role}'"
    )
|
|
356
|
+
|
|
357
|
+
def _validate_content(self, content: Any) -> None:
    """Require content that is present, serializable and within the size limit.

    A string is measured as-is; any other value is measured through its JSON
    form. The size is the UTF-8 byte length, compared with
    ``_MAX_CONTENT_SIZE``.

    Raises:
        MemoryValidationError: when content is None, cannot be serialized
            (non-JSON types, circular references, lone surrogates), or is
            larger than the limit.
    """
    if content is None:
        raise MemoryValidationError("content must not be None")
    try:
        serialized = content if isinstance(content, str) else json.dumps(content, ensure_ascii=False)
        size = len(serialized.encode("utf-8"))
    except (TypeError, ValueError, RecursionError) as exc:
        # These used to escape as raw TypeError / ValueError, outside the
        # module's own error hierarchy.
        raise MemoryValidationError(f"content must be JSON-serializable: {exc}") from exc
    if size > _MAX_CONTENT_SIZE:
        raise MemoryValidationError(
            f"content exceeds maximum size of {_MAX_CONTENT_SIZE // (1024 * 1024)}MB"
        )
|
|
365
|
+
|
|
366
|
+
def _validate_metadata(self, metadata: Any) -> None:
|
|
367
|
+
if metadata is not None and not isinstance(metadata, dict):
|
|
368
|
+
raise MemoryValidationError("metadata must be a dict or None")
|
|
369
|
+
|
|
370
|
+
@staticmethod
|
|
371
|
+
def _assert_single_cursor(after: Optional[str], before: Optional[str]) -> None:
|
|
372
|
+
if after is not None and before is not None:
|
|
373
|
+
raise MemoryValidationError("after and before are mutually exclusive")
|
|
374
|
+
|
|
375
|
+
# ─── Blob I/O Helpers ───
|
|
376
|
+
|
|
377
|
+
async def _blob_get_json(self, key: str, **kwargs) -> Optional[dict]:
|
|
378
|
+
"""Get and parse JSON from blob store. Returns None if not found."""
|
|
379
|
+
try:
|
|
380
|
+
raw = await self._blob.get(key, type="text", **kwargs)
|
|
381
|
+
if raw is None:
|
|
382
|
+
return None
|
|
383
|
+
return json.loads(raw)
|
|
384
|
+
except (json.JSONDecodeError, TypeError, ValueError):
|
|
385
|
+
return None
|
|
386
|
+
except MemoryStorageError:
|
|
387
|
+
raise
|
|
388
|
+
except Exception as e:
|
|
389
|
+
raise MemoryStorageError(f"Failed to read key '{key}': {e}") from e
|
|
390
|
+
|
|
391
|
+
async def _blob_set_json(self, key: str, data: Any) -> None:
    """Serialize ``data`` to compact JSON and write it to the blob store.

    ``data`` may be any JSON-serializable value (dict, list, string,
    number, ...). That covers simple structures such as ``message_index``
    entries as well as plain-string payloads such as the LangGraph latest
    checkpoint id. Values the encoder cannot handle go through
    ``_json_default``.

    Raises:
        MemoryStorageError: when serialization or the write fails.
    """
    try:
        encoded = json.dumps(
            data,
            default=_json_default,
            separators=(",", ":"),
            ensure_ascii=False,
        )
        await self._blob.set(key, encoded)
    except Exception as e:
        if isinstance(e, MemoryStorageError):
            raise
        raise MemoryStorageError(f"Failed to write key '{key}': {e}") from e
|
|
405
|
+
|
|
406
|
+
async def _blob_get_text(self, key: str, **kwargs) -> Optional[str]:
|
|
407
|
+
"""Get raw text from blob store. Returns None if not found."""
|
|
408
|
+
try:
|
|
409
|
+
raw = await self._blob.get(key, type="text", **kwargs)
|
|
410
|
+
if raw is None:
|
|
411
|
+
return None
|
|
412
|
+
return raw if isinstance(raw, str) else str(raw)
|
|
413
|
+
except MemoryStorageError:
|
|
414
|
+
raise
|
|
415
|
+
except Exception as e:
|
|
416
|
+
raise MemoryStorageError(f"Failed to read key '{key}': {e}") from e
|
|
417
|
+
|
|
418
|
+
async def _blob_delete(self, key: str) -> None:
|
|
419
|
+
"""Delete a key from blob store (idempotent)."""
|
|
420
|
+
try:
|
|
421
|
+
await self._blob.delete(key)
|
|
422
|
+
except MemoryStorageError:
|
|
423
|
+
raise
|
|
424
|
+
except Exception as e:
|
|
425
|
+
raise MemoryStorageError(f"Failed to delete key '{key}': {e}") from e
|
|
426
|
+
|
|
427
|
+
async def _blob_list_keys(self, prefix: str, **kwargs) -> List[str]:
|
|
428
|
+
"""List all keys with given prefix, sorted lexicographically."""
|
|
429
|
+
try:
|
|
430
|
+
result = await self._blob.list(prefix=prefix, **kwargs)
|
|
431
|
+
return [blob.key for blob in result.blobs]
|
|
432
|
+
except MemoryStorageError:
|
|
433
|
+
raise
|
|
434
|
+
except Exception as e:
|
|
435
|
+
raise MemoryStorageError(f"Failed to list keys with prefix '{prefix}': {e}") from e
|
|
436
|
+
|
|
437
|
+
# ─── Core API Methods ───
|
|
438
|
+
|
|
439
|
+
async def append_message(
    self,
    conversation_id: str,
    role: str,
    content: Any,
    metadata: Optional[dict] = None,
    user_id: Optional[str] = None,
) -> str:
    """Append a message to a conversation.

    Writes, in order: the message body, its ``message_index`` entry, the
    conversation meta, and the conversation index entry (plus the per-user
    index entry when ``user_id`` is given). New index entries are written
    before the superseded ones are deleted.

    Args:
        conversation_id: Conversation identifier (max 256 chars).
        role: One of 'user', 'assistant', 'system', 'tool'.
        content: Message content (str or list for multimodal).
        metadata: Optional metadata dict. It is copied, never mutated.
            ``run_id`` is auto-injected unless already present. Pass
            ``idempotency_key`` in metadata to enable idempotent writes.
        user_id: Optional user identifier. When provided, the conversation
            is indexed under this user for later retrieval via
            ``list_conversations(user_id=...)``.

    Returns:
        The generated message_id, or the id of an existing message that
        already carries the same ``idempotency_key``.

    Raises:
        MemoryValidationError: Invalid input.
        MemoryQuotaExceededError: Conversation has >= 10000 messages.
        MemoryStorageError: Blob store failure.
    """
    # Validate
    self._validate_conversation_id(conversation_id)
    self._validate_role(role)
    self._validate_content(content)
    self._validate_metadata(metadata)
    if user_id is not None:
        self._validate_user_id(user_id)

    encoded_cid = _encode_cid(conversation_id)
    now_ms = int(time.time() * 1000)

    # Prepare metadata with run_id injection
    if metadata is None:
        metadata = {}
    else:
        metadata = dict(metadata)  # shallow copy to avoid mutating caller's dict
    if "run_id" not in metadata:
        metadata["run_id"] = self._run_id

    # Idempotency check
    idempotency_key = metadata.get("idempotency_key")
    if idempotency_key is not None:
        existing_id = await self._check_idempotency(encoded_cid, idempotency_key)
        if existing_id is not None:
            return existing_id

    # Load conversation meta (or prepare to create new)
    meta_k = _meta_key(encoded_cid)
    meta_data = await self._blob_get_json(meta_k)

    if meta_data is not None:
        meta = ConversationMeta.from_dict(meta_data)
        if meta.message_count >= _MAX_MESSAGES_PER_CONVERSATION:
            raise MemoryQuotaExceededError(
                f"Conversation '{conversation_id}' has reached the maximum of "
                f"{_MAX_MESSAGES_PER_CONVERSATION} messages"
            )
        old_last_message_at = meta.last_message_at
    else:
        meta = ConversationMeta(
            conversation_id=conversation_id,
            created_at=now_ms,
            last_message_at=0,
            message_count=0,
            metadata=None,
        )
        old_last_message_at = None  # no old index to delete

    # Generate message
    message_id = _generate_message_id()
    message = Message(
        message_id=message_id,
        role=role,
        content=content,
        created_at=now_ms,
        metadata=metadata if metadata else None,
    )

    # Write message blob
    msg_key = _message_key(encoded_cid, now_ms, message_id)
    await self._blob_set_json(msg_key, message.to_dict())

    # Write message_index for O(1) lookup by message_id (mirrors Node memory.ts).
    await self._blob_set_json(
        _message_index_key(message_id),
        {
            "conversationId": conversation_id,
            "key": msg_key,
            "messageId": message_id,
            "createdAt": now_ms,
        },
    )

    # Update conversation meta
    meta.last_message_at = now_ms
    meta.message_count += 1
    await self._blob_set_json(meta_k, meta.to_dict())

    # Update conversation index: write new first, then delete old (safety)
    new_index_key = _index_key(now_ms, encoded_cid)
    index_entry = {"conversationId": conversation_id, "lastMessageAt": now_ms}
    await self._blob_set_json(new_index_key, index_entry)

    if old_last_message_at is not None and old_last_message_at != now_ms:
        old_index_key = _index_key(old_last_message_at, encoded_cid)
        await self._blob_delete(old_index_key)

    # Update user conversation index (if user_id provided)
    if user_id is not None:
        encoded_user_id = _encode_segment(user_id)
        new_user_key = _user_index_key(encoded_user_id, now_ms, encoded_cid)
        index_entry_user = {"conversationId": conversation_id, "lastMessageAt": now_ms}
        await self._blob_set_json(new_user_key, index_entry_user)

        # Delete old user index key for this user+conversation (if exists and different)
        if old_last_message_at is not None and old_last_message_at != now_ms:
            old_user_key = _user_index_key(encoded_user_id, old_last_message_at, encoded_cid)
            await self._blob_delete(old_user_key)
        elif old_last_message_at is None:
            # First message in this conversation — scan for stale user index
            # (edge case: conversation was deleted + recreated under same id)
            prefix = _user_index_prefix(encoded_user_id)
            existing_keys = await self._blob_list_keys(prefix)
            for k in existing_keys:
                if k == new_user_key:
                    continue
                # The basename is "{reverse_ts}_{encoded_cid}" and the
                # timestamp has no "_", so everything after the first "_" is
                # the conversation id. It must match exactly: a suffix test
                # would also delete the entry of another conversation whose
                # id merely ends with "_<this id>" (creating "1" would wipe
                # "chat_1").
                basename = k.rpartition("/")[2]
                if basename.partition("_")[2] == encoded_cid:
                    await self._blob_delete(k)

    return message_id
|
|
577
|
+
|
|
578
|
+
async def _check_idempotency(self, encoded_cid: str, idempotency_key: str) -> Optional[str]:
    """Look for an existing message that carries ``idempotency_key``.

    Only the last 50 listed message keys are inspected, to bound the cost.

    Returns:
        The id of the first matching message (``""`` if that record has no
        id field), or ``None`` when nothing matches.
    """
    listed = await self._blob_list_keys(_messages_prefix(encoded_cid))
    for message_key in listed[-50:]:
        record = await self._blob_get_json(message_key)
        if not record:
            continue
        stored_meta = record.get("metadata")
        if not isinstance(stored_meta, dict):
            continue
        if stored_meta.get("idempotency_key") == idempotency_key:
            return record.get("messageId") or record.get("message_id", "")
    return None
|
|
590
|
+
|
|
591
|
+
async def get_messages(
    self,
    conversation_id: str,
    limit: int = _DEFAULT_LIMIT,
    order: str = "asc",
    after: Optional[str] = None,
    before: Optional[str] = None,
) -> List[Message]:
    """Read a page of messages from a conversation.

    Args:
        conversation_id: Conversation identifier.
        limit: Max messages to return (default 20, max 100).
        order: 'asc' (oldest first, default) or 'desc' (newest first).
        after: Cursor, a message_id; only messages listed after it are
            returned (exclusive).
        before: Cursor, a message_id; only messages listed before it are
            returned (exclusive).

    A cursor id that matches no key leaves the listing unfiltered.

    Returns:
        List of Message objects in the requested order. Keys whose body
        cannot be read as JSON are left out.

    Raises:
        MemoryValidationError: Invalid id, limit or order, or both cursors given.
        MemoryStorageError: Blob store failure.
    """
    import asyncio

    self._validate_conversation_id(conversation_id)
    self._validate_limit(limit)
    if order != "asc" and order != "desc":
        raise MemoryValidationError("order must be 'asc' or 'desc'")
    self._assert_single_cursor(after, before)

    # Keys are expected in lexicographic order of {created_at}_{message_id}.
    listing_prefix = _messages_prefix(_encode_cid(conversation_id))
    window = await self._blob_list_keys(listing_prefix, consistency="strong")

    # Narrow the window with whichever cursor was supplied.
    if after is not None:
        position = self._find_message_key_index(window, after)
        if position is not None:
            window = window[position + 1:]

    if before is not None:
        position = self._find_message_key_index(window, before)
        if position is not None:
            window = window[:position]

    # Order first, then truncate, so 'desc' yields the newest ``limit`` messages.
    if order == "desc":
        window = window[::-1]
    window = window[:limit]

    # Read all bodies concurrently: about one round-trip of wall-clock time
    # instead of one per message.
    async def _load(message_key: str) -> Optional[Message]:
        record = await self._blob_get_json(message_key, consistency="strong")
        if not record:
            return None
        return Message.from_dict(record)

    loaded = await asyncio.gather(*(_load(message_key) for message_key in window))
    return [message for message in loaded if message is not None]
|
|
650
|
+
|
|
651
|
+
@staticmethod
|
|
652
|
+
def _find_message_key_index(keys: List[str], message_id: str) -> Optional[int]:
|
|
653
|
+
"""Find index of key containing given message_id."""
|
|
654
|
+
suffix = f"_{message_id}"
|
|
655
|
+
for i, key in enumerate(keys):
|
|
656
|
+
if key.endswith(suffix):
|
|
657
|
+
return i
|
|
658
|
+
return None
|
|
659
|
+
|
|
660
|
+
async def update_message(
    self,
    conversation_id: str,
    message_id: str,
    content: Any = _UNSET,
    metadata: Any = _UNSET,
) -> Message:
    """Update the content and/or metadata of an existing message.

    Semantics follow Node ``memory.updateMessage``:
    - ``content``: replaces the stored content as a whole when supplied;
      left unchanged when omitted.
    - ``metadata``: shallow-merged into the stored metadata (same key
      overwrites, new key is added); ``None`` leaves metadata unchanged.
    - ``updated_at`` is refreshed on every successful call.

    Args:
        conversation_id: Must match the conversation the message belongs to;
            otherwise ``MemoryNotFoundError`` is raised.
        message_id: Id of the message to update.
        content: New content (str / list / dict ...); omit to keep the old value.
        metadata: Metadata to shallow-merge; omit or pass ``None`` to keep the old value.

    Returns:
        The updated ``Message``.

    Raises:
        MemoryValidationError: Invalid content / metadata / id, or neither field supplied.
        MemoryNotFoundError: Message missing or conversation mismatch.
        MemoryStorageError: Blob read/write failure.
    """
    self._validate_conversation_id(conversation_id)
    self._validate_message_id(message_id)

    has_content = content is not _UNSET
    has_metadata = metadata is not _UNSET
    # NOTE(review): an explicit ``metadata=None`` with no content passes this
    # check and only refreshes ``updated_at`` - confirm that is intended.
    if not (has_content or has_metadata):
        raise MemoryValidationError("content or metadata is required for update_message")

    if has_content:
        self._validate_content(content)
    if has_metadata and not (metadata is None or isinstance(metadata, dict)):
        raise MemoryValidationError("metadata must be a dict or None")

    # Resolve the message's blob key through the message_index secondary index.
    index_entry = await self._blob_get_json(_message_index_key(message_id))
    owner_cid = (
        (index_entry.get("conversationId") or index_entry.get("conversation_id"))
        if index_entry
        else None
    )
    if not index_entry or owner_cid != conversation_id:
        raise MemoryNotFoundError(
            f"Message '{message_id}' not found in conversation '{conversation_id}'"
        )

    blob_key = index_entry.get("key")
    if not blob_key:
        raise MemoryNotFoundError(f"Message '{message_id}' has no stored key")

    stored = await self._blob_get_json(blob_key)
    if not stored:
        raise MemoryNotFoundError(f"Message '{message_id}' not found")

    message = Message.from_dict(stored)
    if has_content:
        message.content = content
    if has_metadata and metadata is not None:
        merged = {**dict(message.metadata or {}), **metadata}
        message.metadata = merged or None
    message.updated_at = int(time.time() * 1000)

    await self._blob_set_json(blob_key, message.to_dict())
    return message
|
|
730
|
+
|
|
731
|
+
async def delete_message(self, conversation_id: str, message_id: str) -> None:
    """Delete one message identified by ``message_id``.

    Semantics follow Node ``memory.deleteMessage``:
    - Locate via message_index, delete the message blob, delete the
      message_index entry, then recompute the conversation meta.
    - Raises ``MemoryNotFoundError`` when the message cannot be found.
    - After the last message is removed the meta is **kept**
      (``message_count=0``, ``last_message_at`` falls back to ``created_at``),
      matching ``clear_messages``.

    Args:
        conversation_id: Conversation the message belongs to.
        message_id: Id of the message to delete.

    Raises:
        MemoryValidationError: Invalid id.
        MemoryNotFoundError: Message or conversation does not exist.
        MemoryStorageError: Blob read/write failure.
    """
    self._validate_conversation_id(conversation_id)
    self._validate_message_id(message_id)

    index_entry = await self._blob_get_json(_message_index_key(message_id))
    owner_cid = (
        (index_entry.get("conversationId") or index_entry.get("conversation_id"))
        if index_entry
        else None
    )
    if not index_entry or owner_cid != conversation_id:
        raise MemoryNotFoundError(
            f"Message '{message_id}' not found in conversation '{conversation_id}'"
        )

    blob_key = index_entry.get("key")
    if not blob_key:
        raise MemoryNotFoundError(f"Message '{message_id}' has no stored key")

    encoded_cid = _encode_cid(conversation_id)
    meta_k = _meta_key(encoded_cid)
    stored_meta = await self._blob_get_json(meta_k)
    if stored_meta is None:
        raise MemoryNotFoundError(f"Conversation '{conversation_id}' not found")
    previous_meta = ConversationMeta.from_dict(stored_meta)

    # Remove the message blob first, then its secondary-index entry.
    await self._blob_delete(blob_key)
    await self._blob_delete(_message_index_key(message_id))

    # Recompute meta from the remaining messages and persist it.
    next_meta = await self._recalculate_conversation_meta(encoded_cid, previous_meta)
    await self._blob_set_json(meta_k, next_meta.to_dict())

    # Move the conversation_index entry only if last_message_at changed:
    # write the new entry before dropping the old one.
    new_index_key = _index_key(next_meta.last_message_at, encoded_cid)
    old_index_key = _index_key(previous_meta.last_message_at, encoded_cid)
    if new_index_key == old_index_key:
        return
    index_payload = {
        "conversationId": conversation_id,
        "lastMessageAt": next_meta.last_message_at,
    }
    await self._blob_set_json(new_index_key, index_payload)
    await self._blob_delete(old_index_key)
|
|
789
|
+
|
|
790
|
+
async def _recalculate_conversation_meta(
    self,
    encoded_cid: str,
    previous_meta: ConversationMeta,
) -> ConversationMeta:
    """Recompute ``message_count`` / ``last_message_at`` for a conversation.

    Called after a single message has been deleted. The remaining message
    keys are listed and ``created_at`` is parsed from each key's last segment
    (key format: ``conversations/{cid}/messages/{created_at}_{msg_id}``).

    Args:
        encoded_cid: Encoded conversation id used to build the message prefix.
        previous_meta: Meta as stored before the deletion; its
            ``conversation_id``, ``created_at`` and ``metadata`` are carried over.

    Returns:
        A new ``ConversationMeta``. ``last_message_at`` falls back to
        ``created_at`` when no message is left or no key carries a parseable
        timestamp.
    """
    prefix = _messages_prefix(encoded_cid)
    # NOTE(review): assumes the default listing mode may return stale data, as
    # the strong-consistency listings elsewhere in this class suggest. This
    # runs right after a delete, so a stale listing would persist a wrong count.
    keys = await self._blob_list_keys(prefix, consistency="strong")

    latest_created_at = 0
    for key in keys:
        # The last path segment is "{created_at}_{message_id}".
        segment = key.rsplit("/", 1)[-1]
        try:
            created_at = int(segment.split("_", 1)[0])
        except ValueError:
            created_at = 0
        latest_created_at = max(latest_created_at, created_at)

    return ConversationMeta(
        conversation_id=previous_meta.conversation_id,
        created_at=previous_meta.created_at,
        # 0 means "nothing parseable"; never persist that as a timestamp.
        last_message_at=latest_created_at if latest_created_at > 0 else previous_meta.created_at,
        message_count=len(keys),
        metadata=previous_meta.metadata,
    )
|
|
822
|
+
|
|
823
|
+
async def clear_messages(self, conversation_id: str) -> None:
    """Delete every message of a conversation while **keeping its meta**.

    Semantics follow Node ``memory.clearMessages``: after clearing,
    ``get_conversation`` still returns the meta (``message_count=0``,
    ``last_message_at=created_at``). Use ``delete_conversation`` to remove
    the conversation entirely.

    Raises:
        MemoryNotFoundError: The conversation does not exist.
    """
    self._validate_conversation_id(conversation_id)

    encoded_cid = _encode_cid(conversation_id)
    meta_k = _meta_key(encoded_cid)
    stored_meta = await self._blob_get_json(meta_k)
    if stored_meta is None:
        raise MemoryNotFoundError(f"Conversation '{conversation_id}' not found")
    previous_meta = ConversationMeta.from_dict(stored_meta)

    # Delete each message blob together with its secondary-index entry.
    for key in await self._blob_list_keys(_messages_prefix(encoded_cid)):
        payload = await self._blob_get_json(key)
        await self._blob_delete(key)
        if not payload:
            continue
        stored_id = payload.get("messageId") or payload.get("message_id")
        if isinstance(stored_id, str):
            await self._blob_delete(_message_index_key(stored_id))

    # Reset meta: keep conversation_id / created_at / custom metadata, set
    # message_count to 0 and roll last_message_at back to created_at.
    next_meta = ConversationMeta(
        conversation_id=previous_meta.conversation_id,
        created_at=previous_meta.created_at,
        last_message_at=previous_meta.created_at,
        message_count=0,
        metadata=previous_meta.metadata,
    )
    await self._blob_set_json(meta_k, next_meta.to_dict())

    # Move the conversation_index entry if last_message_at changed.
    new_index_key = _index_key(next_meta.last_message_at, encoded_cid)
    old_index_key = _index_key(previous_meta.last_message_at, encoded_cid)
    if new_index_key == old_index_key:
        return
    index_payload = {
        "conversationId": conversation_id,
        "lastMessageAt": next_meta.last_message_at,
    }
    await self._blob_set_json(new_index_key, index_payload)
    await self._blob_delete(old_index_key)
|
|
876
|
+
|
|
877
|
+
async def list_conversations(
    self,
    limit: int = _DEFAULT_LIMIT,
    order: str = "desc",
    after: Optional[str] = None,
    before: Optional[str] = None,
    user_id: Optional[str] = None,
) -> ListConversationsResult:
    """List conversations ordered by ``last_message_at``.

    Args:
        limit: Max conversations to return (default 20, max 100).
        order: ``'desc'`` (most recent first, default) or ``'asc'``.
        after: Opaque cursor taken from a previous ``next_cursor``.
        before: Opaque cursor taken from a previous ``previous_cursor``.
        user_id: When given, only conversations indexed under this user are listed.

    Returns:
        ``ListConversationsResult`` carrying ``items``, ``next_cursor`` and
        ``previous_cursor``.

    Raises:
        MemoryValidationError: ``order`` is neither ``'asc'`` nor ``'desc'``.
    """
    self._validate_limit(limit)
    if order not in ("asc", "desc"):
        raise MemoryValidationError("order must be 'asc' or 'desc'")
    self._assert_single_cursor(after, before)

    # Pick the index prefix: global, or scoped to one user.
    if user_id is None:
        prefix = _index_prefix()
    else:
        self._validate_user_id(user_id)
        prefix = _user_index_prefix(_encode_segment(user_id))

    # Index keys embed a reversed timestamp, so ascending lexicographic order
    # is already most-recent-first. Strong consistency aligns with the Node
    # behavior (avoids stale CDN data).
    keys = await self._blob_list_keys(prefix, consistency="strong")
    if order == "asc":
        keys = list(reversed(keys))

    # Trim the key list at the cursor boundary, if any.
    prefix_len = len(prefix)
    if after is not None:
        decoded = _decode_cursor(after)
        boundary = _cursor_sort_key(decoded["lastMessageAt"], decoded["conversationId"])
        keys = self._apply_after_cursor(keys, prefix_len, boundary, order)

    if before is not None:
        decoded = _decode_cursor(before)
        boundary = _cursor_sort_key(decoded["lastMessageAt"], decoded["conversationId"])
        keys = self._apply_before_cursor(keys, prefix_len, boundary, order)

    # Collect up to ``limit`` metas; seeing one more key afterwards signals a
    # further page. Entries without a meta are skipped.
    items: List[ConversationMeta] = []
    has_more = False
    for key in keys:
        if len(items) >= limit:
            has_more = True
            break
        # Key basename: {revTs:016d}_{encoded_cid}
        _, separator, encoded_cid = key[prefix_len:].partition("_")
        if not separator:
            continue
        meta_data = await self._blob_get_json(_meta_key(encoded_cid), consistency="strong")
        if not meta_data:
            # Residual index entry (conversation deleted): skip silently.
            continue
        try:
            items.append(ConversationMeta.from_dict(meta_data))
        except (KeyError, TypeError):
            # Corrupted / incompatible meta entry: skip it.
            continue

    next_cursor: Optional[str] = None
    previous_cursor: Optional[str] = None
    if items and has_more:
        tail = items[-1]
        next_cursor = _encode_cursor(tail.last_message_at, tail.conversation_id)
    if items and after is not None:
        head = items[0]
        previous_cursor = _encode_cursor(head.last_message_at, head.conversation_id)

    return ListConversationsResult(
        items=items,
        next_cursor=next_cursor,
        previous_cursor=previous_cursor,
    )
|
|
969
|
+
|
|
970
|
+
@staticmethod
|
|
971
|
+
def _apply_after_cursor(
|
|
972
|
+
keys: List[str], prefix_len: int, cursor_sk: str, order: str
|
|
973
|
+
) -> List[str]:
|
|
974
|
+
"""Skip keys up to and including the cursor position."""
|
|
975
|
+
# In desc order (natural revTs ascending): cursor_sk should be compared
|
|
976
|
+
# with the basename of each key. We skip all items <= cursor for desc,
|
|
977
|
+
# and all items >= cursor for asc.
|
|
978
|
+
result = []
|
|
979
|
+
found = False
|
|
980
|
+
for key in keys:
|
|
981
|
+
basename = key[prefix_len:]
|
|
982
|
+
if not found:
|
|
983
|
+
if order == "desc":
|
|
984
|
+
# desc: natural lex asc. "after" means skip items with
|
|
985
|
+
# sort_key <= cursor_sk (i.e. more recent or equal)
|
|
986
|
+
if basename > cursor_sk:
|
|
987
|
+
found = True
|
|
988
|
+
result.append(key)
|
|
989
|
+
else:
|
|
990
|
+
# asc: reversed lex. "after" means skip items with
|
|
991
|
+
# sort_key >= cursor_sk (i.e. older or equal, since list is reversed)
|
|
992
|
+
if basename < cursor_sk:
|
|
993
|
+
found = True
|
|
994
|
+
result.append(key)
|
|
995
|
+
else:
|
|
996
|
+
result.append(key)
|
|
997
|
+
return result
|
|
998
|
+
|
|
999
|
+
@staticmethod
|
|
1000
|
+
def _apply_before_cursor(
|
|
1001
|
+
keys: List[str], prefix_len: int, cursor_sk: str, order: str
|
|
1002
|
+
) -> List[str]:
|
|
1003
|
+
"""Take keys only before the cursor position."""
|
|
1004
|
+
result = []
|
|
1005
|
+
for key in keys:
|
|
1006
|
+
basename = key[prefix_len:]
|
|
1007
|
+
if order == "desc":
|
|
1008
|
+
# desc: natural lex asc. "before" means stop at sort_key >= cursor_sk
|
|
1009
|
+
if basename >= cursor_sk:
|
|
1010
|
+
break
|
|
1011
|
+
else:
|
|
1012
|
+
# asc: reversed lex. "before" means stop at sort_key <= cursor_sk
|
|
1013
|
+
if basename <= cursor_sk:
|
|
1014
|
+
break
|
|
1015
|
+
result.append(key)
|
|
1016
|
+
return result
|
|
1017
|
+
|
|
1018
|
+
async def delete_conversation(self, conversation_id: str) -> None:
    """Remove a conversation together with all of its messages.

    Deletes, in order: every message blob and its secondary-index entry, the
    conversation index entry derived from the stored ``last_message_at``, and
    finally the meta blob.

    Raises:
        MemoryNotFoundError: Conversation doesn't exist.
    """
    self._validate_conversation_id(conversation_id)

    encoded_cid = _encode_cid(conversation_id)
    meta_k = _meta_key(encoded_cid)
    stored_meta = await self._blob_get_json(meta_k)
    if stored_meta is None:
        raise MemoryNotFoundError(f"Conversation '{conversation_id}' not found")
    meta = ConversationMeta.from_dict(stored_meta)

    # Message blobs and their secondary-index entries.
    for key in await self._blob_list_keys(_messages_prefix(encoded_cid)):
        payload = await self._blob_get_json(key)
        await self._blob_delete(key)
        if not payload:
            continue
        stored_id = payload.get("messageId") or payload.get("message_id")
        if isinstance(stored_id, str):
            await self._blob_delete(_message_index_key(stored_id))

    # Conversation index entry, then the meta itself.
    await self._blob_delete(_index_key(meta.last_message_at, encoded_cid))
    await self._blob_delete(meta_k)
|
|
1050
|
+
|
|
1051
|
+
async def update_conversation(
    self, conversation_id: str, metadata: dict
) -> ConversationMeta:
    """Update a conversation's metadata (the only writable field).

    Semantics follow Node ``memory.updateConversation``: the supplied
    ``metadata`` is **shallow-merged** into the stored metadata (same key
    overwrites, new key is added) rather than replacing it.

    Args:
        conversation_id: Conversation identifier.
        metadata: Metadata dict to shallow-merge.

    Returns:
        Updated ConversationMeta.

    Raises:
        MemoryNotFoundError: Conversation doesn't exist.
        MemoryValidationError: metadata is not a dict.
    """
    self._validate_conversation_id(conversation_id)
    if not isinstance(metadata, dict):
        raise MemoryValidationError("metadata must be a dict for update_conversation")

    meta_k = _meta_key(_encode_cid(conversation_id))
    stored_meta = await self._blob_get_json(meta_k)
    if stored_meta is None:
        raise MemoryNotFoundError(f"Conversation '{conversation_id}' not found")

    meta = ConversationMeta.from_dict(stored_meta)
    combined = {**dict(meta.metadata or {}), **metadata}
    # An empty merge result is stored as None rather than {}.
    meta.metadata = combined or None
    await self._blob_set_json(meta_k, meta.to_dict())
    return meta
|
|
1087
|
+
|
|
1088
|
+
async def get_conversation(self, conversation_id: str) -> ConversationMeta:
    """Fetch the stored metadata of a conversation.

    Args:
        conversation_id: Conversation identifier.

    Returns:
        The ``ConversationMeta`` built from the stored meta blob.

    Raises:
        MemoryNotFoundError: Conversation doesn't exist.
    """
    self._validate_conversation_id(conversation_id)

    stored_meta = await self._blob_get_json(_meta_key(_encode_cid(conversation_id)))
    if stored_meta is not None:
        return ConversationMeta.from_dict(stored_meta)
    raise MemoryNotFoundError(f"Conversation '{conversation_id}' not found")
|
|
1104
|
+
|
|
1105
|
+
# ─── Framework Helpers ───
|
|
1106
|
+
|
|
1107
|
+
@staticmethod
|
|
1108
|
+
def to_openai_input(messages: List[Message]) -> List[dict]:
|
|
1109
|
+
"""Convert Messages to OpenAI-compatible format.
|
|
1110
|
+
|
|
1111
|
+
Strips platform fields (message_id, created_at, metadata).
|
|
1112
|
+
Returns list of ``{"role": ..., "content": ...}`` dicts.
|
|
1113
|
+
"""
|
|
1114
|
+
return [{"role": m.role, "content": m.content} for m in messages]
|
|
1115
|
+
|
|
1116
|
+
@staticmethod
|
|
1117
|
+
def to_anthropic_messages(messages: List[Message]) -> List[dict]:
|
|
1118
|
+
"""Convert Messages to Anthropic-compatible format.
|
|
1119
|
+
|
|
1120
|
+
Strips platform fields (message_id, created_at, metadata).
|
|
1121
|
+
Returns list of ``{"role": ..., "content": ...}`` dicts.
|
|
1122
|
+
"""
|
|
1123
|
+
return [{"role": m.role, "content": m.content} for m in messages]
|
|
1124
|
+
|
|
1125
|
+
@property
def langgraph_checkpointer(self) -> Any:
    """Lazily built adapter implementing LangGraph ``BaseCheckpointSaver``.

    Maps ``conversation_id`` to ``thread_id``.

    Checkpoints live under the separate ``langgraph_checkpoints/`` prefix,
    fully isolated from message history, so they do not affect
    ``message_count`` / quota.
    """
    adapter = getattr(self, "_langgraph_checkpointer_cached", None)
    if adapter is None:
        # First access: build the adapter and remember it on the instance.
        adapter = _LangGraphCheckpointerAdapter(self)
        self._langgraph_checkpointer_cached = adapter
    return adapter
|
|
1140
|
+
|
|
1141
|
+
@property
def langgraph_store(self) -> Any:
    """LangGraph ``BaseStore`` adapter offering get/put/search/batch/listNamespaces.

    Items live under the separate ``langgraph_store/items/`` prefix, fully
    isolated from message history and checkpoints. The storage layout is the
    same as Node ``context.memory.langgraphStore``, which allows
    cross-language interoperability.

    Usage::

        store = ctx.store.langgraph_store
        await store.put(["agent", "memories"], "facts", {"content": "..."})
        item = await store.get(["agent", "memories"], "facts")
        items = await store.search(["agent"])
    """
    adapter = getattr(self, "_langgraph_store_cached", None)
    if adapter is None:
        # First access: build the adapter and remember it on the instance.
        adapter = _LangGraphStoreAdapter(self)
        self._langgraph_store_cached = adapter
    return adapter
|
|
1161
|
+
|
|
1162
|
+
def openai_session(self, session_id: str, *, max_items: int = 100) -> "_EdgeOneMemorySession":
    """Build an OpenAI Agents SDK Session backed by ctx.store.

    Usage with OpenAI Agents SDK::

        from agents import Agent, Runner

        async def handler(ctx):
            agent = Agent(name="Assistant", instructions="Reply concisely.")
            session = ctx.store.openai_session(ctx.conversation_id)

            result = await Runner.run(agent, user_input, session=session)
            return {"reply": result.final_output}

    The session implements the Agents SDK Session protocol:
    - get_items(): reads history from memory
    - add_items(): persists new items to memory
    - pop_item(): removes and returns the last item
    - clear_session(): clears all items

    Args:
        session_id: Conversation/session identifier (typically ctx.conversation_id).
        max_items: Maximum items to retrieve per get_items() call (default 100).

    Returns:
        A new ``_EdgeOneMemorySession`` bound to this memory instance.
    """
    memory_session = _EdgeOneMemorySession(self, session_id, max_items=max_items)
    return memory_session
|
|
1187
|
+
|
|
1188
|
+
def session(self, session_id: str, *, max_items: int = 100) -> "_EdgeOneMemorySession":
    """Backward-compatible alias that forwards to :meth:`openai_session`."""
    forwarded = self.openai_session(session_id, max_items=max_items)
    return forwarded
|
|
1191
|
+
|
|
1192
|
+
def claude_session_store(self) -> "EdgeOneSessionStore":
    """Return a Claude Agent SDK SessionStore backed by EdgeOne blob storage.

    The store is built on first use and cached on this instance. It
    implements the ``SessionStore`` protocol from ``claude_agent_sdk``
    (append/load/list_sessions/delete/list_subkeys).

    Usage with Claude Agent SDK::

        from claude_agent_sdk import query, ClaudeAgentOptions

        async def handler(ctx):
            store = ctx.store.claude_session_store()
            async for msg in query(
                prompt="Fix the bug in auth.py",
                options=ClaudeAgentOptions(session_store=store),
            ):
                ...

    Returns:
        EdgeOneSessionStore instance (singleton per ConversationMemory).
    """
    store = getattr(self, "_claude_session_store_cached", None)
    if store is None:
        # First call: build the store over this memory's blob client and cache it.
        store = EdgeOneSessionStore(self._blob)
        self._claude_session_store_cached = store
    return store
|
|
1220
|
+
|
|
1221
|
+
|
|
1222
|
+
# ─── LangGraph Store Adapter ───
|
|
1223
|
+
|
|
1224
|
+
try:
|
|
1225
|
+
from langgraph.store.base import BaseStore as _LangGraphBaseStore
|
|
1226
|
+
except ImportError:
|
|
1227
|
+
_LangGraphBaseStore = object # type: ignore[misc,assignment]
|
|
1228
|
+
|
|
1229
|
+
|
|
1230
|
+
class _LangGraphStoreAdapter(_LangGraphBaseStore):
|
|
1231
|
+
"""LangGraph BaseStore 适配器 — 对齐 Node ``createLangGraphStore``。
|
|
1232
|
+
|
|
1233
|
+
存储布局:
|
|
1234
|
+
langgraph_store/items/{base64url(ns1)}/{base64url(ns2)}/{__key__}/{base64url(key)}
|
|
1235
|
+
|
|
1236
|
+
只需实现 batch / abatch,BaseStore 的 get/put/search/delete/list_namespaces
|
|
1237
|
+
都有默认实现(内部调 self.batch)。
|
|
1238
|
+
"""
|
|
1239
|
+
|
|
1240
|
+
def __init__(self, memory: "ConversationMemory") -> None:
    """Bind the adapter to the memory object whose blob helpers it uses.

    Args:
        memory: Owning memory instance, kept as ``self._memory``.
    """
    # The base initializer is only called when the base class is the real
    # LangGraph one; the ``object`` fallback means the import failed.
    langgraph_available = _LangGraphBaseStore is not object
    if langgraph_available:
        super().__init__()
    self._memory = memory
|
|
1244
|
+
|
|
1245
|
+
# ─── Key helpers ───
|
|
1246
|
+
|
|
1247
|
+
@staticmethod
def _namespace_to_path(namespace) -> str:
    """Encode each namespace segment and join the results with ``/``.

    Every segment is stringified before being passed to
    ``_encode_path_segment``; an empty namespace yields an empty string.
    """
    encoded_segments = [_encode_path_segment(str(part)) for part in namespace]
    return "/".join(encoded_segments)
|
|
1250
|
+
|
|
1251
|
+
@classmethod
def _item_key(cls, namespace, key: str) -> str:
    """Build the blob key under which a store item is saved.

    The result is ``{store prefix}/items/[{namespace path}/]{separator}/{encoded key}``;
    the namespace part is left out when the namespace is empty.
    """
    items_root = f"{_LANGGRAPH_STORE_PREFIX}/items/"
    ns_path = cls._namespace_to_path(namespace)
    encoded_key = _encode_path_segment(str(key))
    key_tail = f"{_LANGGRAPH_STORE_KEY_SEPARATOR}/{encoded_key}"
    if not ns_path:
        return f"{items_root}{key_tail}"
    return f"{items_root}{ns_path}/{key_tail}"
|
|
1259
|
+
|
|
1260
|
+
@classmethod
def _search_prefix(cls, namespace_prefix) -> str:
    """Build the blob-key prefix used to list items under a namespace prefix.

    Returns the bare items root for an empty namespace prefix, otherwise the
    root followed by the encoded namespace path and a trailing ``/``.
    """
    items_root = f"{_LANGGRAPH_STORE_PREFIX}/items/"
    ns_path = cls._namespace_to_path(namespace_prefix)
    if not ns_path:
        return items_root
    return f"{items_root}{ns_path}/"
|
|
1265
|
+
|
|
1266
|
+
# ─── Internal helpers ───
|
|
1267
|
+
|
|
1268
|
+
async def _parse_stored_item(self, blob_key: str) -> Optional[dict]:
|
|
1269
|
+
try:
|
|
1270
|
+
raw = await self._memory._blob.get(blob_key, type="text", consistency="strong")
|
|
1271
|
+
if raw is None:
|
|
1272
|
+
return None
|
|
1273
|
+
item = json.loads(raw)
|
|
1274
|
+
except Exception:
|
|
1275
|
+
return None
|
|
1276
|
+
if (
|
|
1277
|
+
not item
|
|
1278
|
+
or not isinstance(item, dict)
|
|
1279
|
+
or not isinstance(item.get("value"), dict)
|
|
1280
|
+
or not isinstance(item.get("key"), str)
|
|
1281
|
+
or not isinstance(item.get("namespace"), list)
|
|
1282
|
+
or not isinstance(item.get("createdAt"), (int, float))
|
|
1283
|
+
or not isinstance(item.get("updatedAt"), (int, float))
|
|
1284
|
+
):
|
|
1285
|
+
return None
|
|
1286
|
+
return item
|
|
1287
|
+
|
|
1288
|
+
@staticmethod
|
|
1289
|
+
def _to_public_item(item: dict) -> Any:
|
|
1290
|
+
"""Convert internal item to LangGraph Item (if available) or plain dict."""
|
|
1291
|
+
from datetime import datetime, timezone
|
|
1292
|
+
try:
|
|
1293
|
+
from langgraph.store.base import Item
|
|
1294
|
+
return Item(
|
|
1295
|
+
value=item["value"],
|
|
1296
|
+
key=item["key"],
|
|
1297
|
+
namespace=tuple(str(s) for s in item["namespace"]),
|
|
1298
|
+
created_at=datetime.fromtimestamp(item["createdAt"] / 1000, tz=timezone.utc),
|
|
1299
|
+
updated_at=datetime.fromtimestamp(item["updatedAt"] / 1000, tz=timezone.utc),
|
|
1300
|
+
)
|
|
1301
|
+
except ImportError:
|
|
1302
|
+
return {
|
|
1303
|
+
"value": item["value"],
|
|
1304
|
+
"key": item["key"],
|
|
1305
|
+
"namespace": tuple(str(s) for s in item["namespace"]),
|
|
1306
|
+
"created_at": datetime.fromtimestamp(item["createdAt"] / 1000, tz=timezone.utc),
|
|
1307
|
+
"updated_at": datetime.fromtimestamp(item["updatedAt"] / 1000, tz=timezone.utc),
|
|
1308
|
+
}
|
|
1309
|
+
|
|
1310
|
+
@staticmethod
|
|
1311
|
+
def _namespace_starts_with(namespace, prefix) -> bool:
|
|
1312
|
+
ns = list(namespace)
|
|
1313
|
+
pf = list(prefix)
|
|
1314
|
+
if len(pf) > len(ns):
|
|
1315
|
+
return False
|
|
1316
|
+
return all(ns[i] == pf[i] for i in range(len(pf)))
|
|
1317
|
+
|
|
1318
|
+
@staticmethod
|
|
1319
|
+
def _namespace_matches_path(namespace, path, direction: str) -> bool:
|
|
1320
|
+
ns = list(namespace)
|
|
1321
|
+
p = list(path)
|
|
1322
|
+
if len(p) > len(ns):
|
|
1323
|
+
return False
|
|
1324
|
+
offset = len(ns) - len(p) if direction == "suffix" else 0
|
|
1325
|
+
return all(
|
|
1326
|
+
seg == "*" or ns[offset + i] == seg
|
|
1327
|
+
for i, seg in enumerate(p)
|
|
1328
|
+
)
|
|
1329
|
+
|
|
1330
|
+
@staticmethod
|
|
1331
|
+
def _compare_filter_value(item_value: Any, filter_value: Any) -> bool:
|
|
1332
|
+
if isinstance(filter_value, dict):
|
|
1333
|
+
operators = list(filter_value.keys())
|
|
1334
|
+
valid_ops = {"$eq", "$ne", "$gt", "$gte", "$lt", "$lte", "$in", "$nin"}
|
|
1335
|
+
if operators and all(op in valid_ops for op in operators):
|
|
1336
|
+
for op in operators:
|
|
1337
|
+
v = filter_value[op]
|
|
1338
|
+
if op == "$eq" and item_value != v:
|
|
1339
|
+
return False
|
|
1340
|
+
elif op == "$ne" and item_value == v:
|
|
1341
|
+
return False
|
|
1342
|
+
elif op == "$gt" and not (float(item_value or 0) > float(v or 0)):
|
|
1343
|
+
return False
|
|
1344
|
+
elif op == "$gte" and not (float(item_value or 0) >= float(v or 0)):
|
|
1345
|
+
return False
|
|
1346
|
+
elif op == "$lt" and not (float(item_value or 0) < float(v or 0)):
|
|
1347
|
+
return False
|
|
1348
|
+
elif op == "$lte" and not (float(item_value or 0) <= float(v or 0)):
|
|
1349
|
+
return False
|
|
1350
|
+
elif op == "$in" and (not isinstance(v, list) or item_value not in v):
|
|
1351
|
+
return False
|
|
1352
|
+
elif op == "$nin" and isinstance(v, list) and item_value in v:
|
|
1353
|
+
return False
|
|
1354
|
+
return True
|
|
1355
|
+
return item_value == filter_value
|
|
1356
|
+
|
|
1357
|
+
@classmethod
|
|
1358
|
+
def _matches_filter(cls, item: dict, filter_: Optional[dict]) -> bool:
|
|
1359
|
+
if not filter_:
|
|
1360
|
+
return True
|
|
1361
|
+
value = item.get("value", {})
|
|
1362
|
+
return all(
|
|
1363
|
+
cls._compare_filter_value(value.get(k), v)
|
|
1364
|
+
for k, v in filter_.items()
|
|
1365
|
+
)
|
|
1366
|
+
|
|
1367
|
+
# ─── Core: abatch (async) ───
|
|
1368
|
+
|
|
1369
|
+
async def abatch(self, ops) -> list:
|
|
1370
|
+
"""Execute operations asynchronously. This is the core implementation."""
|
|
1371
|
+
results: list = []
|
|
1372
|
+
for op in ops:
|
|
1373
|
+
op_type = type(op).__name__
|
|
1374
|
+
if op_type == "GetOp":
|
|
1375
|
+
ns = list(op.namespace)
|
|
1376
|
+
blob_key = self._item_key(ns, str(op.key))
|
|
1377
|
+
item = await self._parse_stored_item(blob_key)
|
|
1378
|
+
result = self._to_public_item(item) if item else None
|
|
1379
|
+
results.append(result)
|
|
1380
|
+
elif op_type == "PutOp":
|
|
1381
|
+
ns = list(op.namespace)
|
|
1382
|
+
k = str(op.key)
|
|
1383
|
+
blob_key = self._item_key(ns, k)
|
|
1384
|
+
if op.value is None:
|
|
1385
|
+
await self._memory._blob_delete(blob_key)
|
|
1386
|
+
else:
|
|
1387
|
+
previous = await self._parse_stored_item(blob_key)
|
|
1388
|
+
now = int(time.time() * 1000)
|
|
1389
|
+
stored = {
|
|
1390
|
+
"value": op.value,
|
|
1391
|
+
"key": k,
|
|
1392
|
+
"namespace": ns,
|
|
1393
|
+
"createdAt": previous["createdAt"] if previous else now,
|
|
1394
|
+
"updatedAt": now,
|
|
1395
|
+
}
|
|
1396
|
+
await self._memory._blob_set_json(blob_key, stored)
|
|
1397
|
+
results.append(None)
|
|
1398
|
+
elif op_type == "SearchOp":
|
|
1399
|
+
ns_prefix = list(op.namespace_prefix)
|
|
1400
|
+
prefix = self._search_prefix(ns_prefix)
|
|
1401
|
+
listed = await self._memory._blob.list(prefix=prefix, consistency="strong")
|
|
1402
|
+
keys = [blob.key for blob in listed.blobs]
|
|
1403
|
+
items: list[dict] = []
|
|
1404
|
+
for bk in keys:
|
|
1405
|
+
if f"/{_LANGGRAPH_STORE_KEY_SEPARATOR}/" not in bk:
|
|
1406
|
+
continue
|
|
1407
|
+
it = await self._parse_stored_item(bk)
|
|
1408
|
+
if it and self._namespace_starts_with(it["namespace"], ns_prefix):
|
|
1409
|
+
items.append(it)
|
|
1410
|
+
filtered = [it for it in items if self._matches_filter(it, getattr(op, "filter", None))]
|
|
1411
|
+
filtered.sort(key=lambda it: ("/".join(it["namespace"]), it["key"]))
|
|
1412
|
+
offset = getattr(op, "offset", 0) or 0
|
|
1413
|
+
limit = getattr(op, "limit", 10) or 10
|
|
1414
|
+
paged = filtered[offset: offset + limit]
|
|
1415
|
+
try:
|
|
1416
|
+
from langgraph.store.base import SearchItem
|
|
1417
|
+
from datetime import datetime, timezone
|
|
1418
|
+
results.append([
|
|
1419
|
+
SearchItem(
|
|
1420
|
+
value=it["value"],
|
|
1421
|
+
key=it["key"],
|
|
1422
|
+
namespace=tuple(str(s) for s in it["namespace"]),
|
|
1423
|
+
created_at=datetime.fromtimestamp(it["createdAt"] / 1000, tz=timezone.utc),
|
|
1424
|
+
updated_at=datetime.fromtimestamp(it["updatedAt"] / 1000, tz=timezone.utc),
|
|
1425
|
+
)
|
|
1426
|
+
for it in paged
|
|
1427
|
+
])
|
|
1428
|
+
except ImportError:
|
|
1429
|
+
results.append([self._to_public_item(it) for it in paged])
|
|
1430
|
+
elif op_type == "ListNamespacesOp":
|
|
1431
|
+
all_prefix = f"{_LANGGRAPH_STORE_PREFIX}/items/"
|
|
1432
|
+
listed = await self._memory._blob.list(prefix=all_prefix, consistency="strong")
|
|
1433
|
+
keys = [blob.key for blob in listed.blobs]
|
|
1434
|
+
namespaces: set[str] = set()
|
|
1435
|
+
for bk in keys:
|
|
1436
|
+
if f"/{_LANGGRAPH_STORE_KEY_SEPARATOR}/" not in bk:
|
|
1437
|
+
continue
|
|
1438
|
+
it = await self._parse_stored_item(bk)
|
|
1439
|
+
if not it:
|
|
1440
|
+
continue
|
|
1441
|
+
ns = [str(s) for s in it["namespace"]]
|
|
1442
|
+
conditions = getattr(op, "match_conditions", None) or []
|
|
1443
|
+
matches = all(
|
|
1444
|
+
self._namespace_matches_path(
|
|
1445
|
+
ns,
|
|
1446
|
+
[str(s) for s in (cond.path if hasattr(cond, 'path') else [])],
|
|
1447
|
+
getattr(cond, "match_type", "prefix"),
|
|
1448
|
+
)
|
|
1449
|
+
for cond in conditions
|
|
1450
|
+
)
|
|
1451
|
+
if not matches:
|
|
1452
|
+
continue
|
|
1453
|
+
max_depth = getattr(op, "max_depth", None)
|
|
1454
|
+
if max_depth is not None and max_depth > 0:
|
|
1455
|
+
ns = ns[:max_depth]
|
|
1456
|
+
namespaces.add(json.dumps(ns))
|
|
1457
|
+
result_ns = [tuple(json.loads(s)) for s in namespaces]
|
|
1458
|
+
result_ns.sort(key=lambda ns: "/".join(ns))
|
|
1459
|
+
offset = getattr(op, "offset", 0) or 0
|
|
1460
|
+
limit = getattr(op, "limit", 100) or 100
|
|
1461
|
+
results.append(result_ns[offset: offset + limit])
|
|
1462
|
+
else:
|
|
1463
|
+
results.append(None)
|
|
1464
|
+
return results
|
|
1465
|
+
|
|
1466
|
+
# ─── Core: batch (sync) — bridges to abatch ───
|
|
1467
|
+
|
|
1468
|
+
# Shared worker thread + event loop for sync→async bridge.
|
|
1469
|
+
# Avoids creating/destroying ThreadPoolExecutor on every batch call.
|
|
1470
|
+
_worker_loop: Any = None
|
|
1471
|
+
_worker_thread: Any = None
|
|
1472
|
+
|
|
1473
|
+
@classmethod
|
|
1474
|
+
def _get_worker_loop(cls):
|
|
1475
|
+
"""Get or create a shared background event loop for sync batch calls."""
|
|
1476
|
+
import asyncio
|
|
1477
|
+
import threading
|
|
1478
|
+
if cls._worker_loop is None or cls._worker_loop.is_closed():
|
|
1479
|
+
cls._worker_loop = asyncio.new_event_loop()
|
|
1480
|
+
cls._worker_thread = threading.Thread(
|
|
1481
|
+
target=cls._worker_loop.run_forever, daemon=True, name="langgraph-store-worker"
|
|
1482
|
+
)
|
|
1483
|
+
cls._worker_thread.start()
|
|
1484
|
+
return cls._worker_loop
|
|
1485
|
+
|
|
1486
|
+
def batch(self, ops) -> list:
    """Execute store operations synchronously by bridging to ``abatch``.

    Args:
        ops: Iterable of store operations; it is materialized into a list
            before being handed to ``abatch``.

    Returns:
        The list of results produced by ``abatch``, one entry per operation.

    Raises:
        concurrent.futures.TimeoutError: If the bridged call does not finish
            within 60 seconds when dispatched to the background loop.
    """
    import asyncio
    import concurrent.futures

    ops_list = list(ops)

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No event loop is running in this thread, so we can own one for the
        # duration of the call.
        return asyncio.run(self.abatch(ops_list))

    # get_running_loop() only succeeds on the thread that is running the loop,
    # so we are inside that loop right now. Submitting the coroutine back to it
    # and blocking on the result would deadlock; run it on the shared
    # background loop instead.
    worker_loop = self._get_worker_loop()
    future = asyncio.run_coroutine_threadsafe(self.abatch(ops_list), worker_loop)
    try:
        return future.result(timeout=60)
    except concurrent.futures.TimeoutError:
        # Do not leave the abandoned coroutine running on the worker loop.
        future.cancel()
        raise
|
|
1510
|
+
|
|
1511
|
+
def start(self) -> None:
    """Lifecycle hook that intentionally does nothing."""
    return None
|
|
1513
|
+
|
|
1514
|
+
def stop(self) -> None:
    """Lifecycle hook that intentionally does nothing."""
    return None
|
|
1516
|
+
|
|
1517
|
+
|
|
1518
|
+
# ─── LangGraph Checkpointer Adapter ───
|
|
1519
|
+
|
|
1520
|
+
try:
|
|
1521
|
+
from langgraph.checkpoint.base import BaseCheckpointSaver as _BaseCheckpointSaver
|
|
1522
|
+
except ImportError:
|
|
1523
|
+
_BaseCheckpointSaver = object # type: ignore[misc,assignment]
|
|
1524
|
+
|
|
1525
|
+
try:
|
|
1526
|
+
from langgraph.checkpoint.base import CheckpointTuple as _CheckpointTuple
|
|
1527
|
+
except ImportError:
|
|
1528
|
+
_CheckpointTuple = None # type: ignore[misc,assignment]
|
|
1529
|
+
|
|
1530
|
+
|
|
1531
|
+
def _to_checkpoint_tuple(data: dict) -> Any:
|
|
1532
|
+
"""Convert a stored dict to a CheckpointTuple (if available) or a namespace object."""
|
|
1533
|
+
import pickle
|
|
1534
|
+
|
|
1535
|
+
# Restore channel_values from pickle if stored separately
|
|
1536
|
+
checkpoint = dict(data.get("checkpoint") or {})
|
|
1537
|
+
channel_values_b64 = data.get("channel_values_pickle")
|
|
1538
|
+
if channel_values_b64 and "channel_values" not in checkpoint:
|
|
1539
|
+
try:
|
|
1540
|
+
cv = pickle.loads(base64.b64decode(channel_values_b64))
|
|
1541
|
+
checkpoint["channel_values"] = cv
|
|
1542
|
+
except Exception:
|
|
1543
|
+
checkpoint["channel_values"] = {}
|
|
1544
|
+
|
|
1545
|
+
# Clean config: remove LangGraph internal keys that were serialized as None.
|
|
1546
|
+
# When these keys exist with None value, LangGraph uses None instead of
|
|
1547
|
+
# DEFAULT_RUNTIME, causing 'NoneType' has no attribute 'override' errors.
|
|
1548
|
+
config = data.get("config") or {}
|
|
1549
|
+
configurable = config.get("configurable") or {}
|
|
1550
|
+
_INTERNAL_KEYS_TO_STRIP = (
|
|
1551
|
+
"__pregel_runtime",
|
|
1552
|
+
"__pregel_store",
|
|
1553
|
+
"__pregel_checkpointer",
|
|
1554
|
+
"__pregel_cache",
|
|
1555
|
+
)
|
|
1556
|
+
for key in _INTERNAL_KEYS_TO_STRIP:
|
|
1557
|
+
configurable.pop(key, None)
|
|
1558
|
+
if configurable:
|
|
1559
|
+
config["configurable"] = configurable
|
|
1560
|
+
|
|
1561
|
+
parent_config = data.get("parent_config")
|
|
1562
|
+
if isinstance(parent_config, dict):
|
|
1563
|
+
parent_configurable = (parent_config.get("configurable") or {})
|
|
1564
|
+
for key in _INTERNAL_KEYS_TO_STRIP:
|
|
1565
|
+
parent_configurable.pop(key, None)
|
|
1566
|
+
|
|
1567
|
+
# Clean config: remove LangGraph internal keys that were serialized as None.
|
|
1568
|
+
# When these keys exist with None value, LangGraph uses None instead of
|
|
1569
|
+
# DEFAULT_RUNTIME, causing 'NoneType' has no attribute 'override' errors.
|
|
1570
|
+
config = data.get("config") or {}
|
|
1571
|
+
configurable = config.get("configurable") or {}
|
|
1572
|
+
_INTERNAL_KEYS_TO_STRIP = (
|
|
1573
|
+
"__pregel_runtime",
|
|
1574
|
+
"__pregel_store",
|
|
1575
|
+
"__pregel_checkpointer",
|
|
1576
|
+
"__pregel_cache",
|
|
1577
|
+
)
|
|
1578
|
+
for key in _INTERNAL_KEYS_TO_STRIP:
|
|
1579
|
+
configurable.pop(key, None)
|
|
1580
|
+
if configurable:
|
|
1581
|
+
config["configurable"] = configurable
|
|
1582
|
+
|
|
1583
|
+
parent_config = data.get("parent_config")
|
|
1584
|
+
if isinstance(parent_config, dict):
|
|
1585
|
+
parent_configurable = (parent_config.get("configurable") or {})
|
|
1586
|
+
for key in _INTERNAL_KEYS_TO_STRIP:
|
|
1587
|
+
parent_configurable.pop(key, None)
|
|
1588
|
+
|
|
1589
|
+
if _CheckpointTuple is not None:
|
|
1590
|
+
return _CheckpointTuple(
|
|
1591
|
+
config=config,
|
|
1592
|
+
checkpoint=checkpoint,
|
|
1593
|
+
metadata=data.get("metadata"),
|
|
1594
|
+
parent_config=parent_config,
|
|
1595
|
+
pending_writes=data.get("pending_writes"),
|
|
1596
|
+
)
|
|
1597
|
+
# Fallback: return a simple namespace so .checkpoint / .config attribute access works
|
|
1598
|
+
class _Ns:
|
|
1599
|
+
pass
|
|
1600
|
+
ns = _Ns()
|
|
1601
|
+
ns.config = config
|
|
1602
|
+
ns.checkpoint = checkpoint
|
|
1603
|
+
ns.metadata = data.get("metadata")
|
|
1604
|
+
ns.parent_config = parent_config
|
|
1605
|
+
ns.pending_writes = data.get("pending_writes")
|
|
1606
|
+
return ns
|
|
1607
|
+
|
|
1608
|
+
|
|
1609
|
+
class _LangGraphCheckpointerAdapter(_BaseCheckpointSaver):
    """LangGraph checkpoint saver stored under its own key prefix.

    Mirrors the Node ``createLangGraphCheckpointer`` implementation.

    Storage layout (``{base}`` is ``<checkpoint prefix>/{thread_id}``, with
    ``/namespaces/{checkpoint_ns}`` appended when a checkpoint namespace is set)::

        {base}/checkpoints/{checkpoint_id}         a single checkpoint
        {base}/latest                              id of the latest checkpoint
        {base}/writes/{checkpoint_id}/{task_id}    pending writes

    Every thread_id / checkpoint_ns / checkpoint_id / task_id is passed through
    ``_encode_segment`` before being placed in a key, so characters such as
    ``/`` cannot break the key structure.

    Both LangGraph API flavours are provided:

    - ``aget_tuple`` / ``aput`` / ``aput_writes`` / ``alist`` /
      ``adelete_thread``: the async implementations.
    - ``get_tuple`` / ``put`` / ``put_writes`` / ``list`` / ``delete_thread``:
      blocking wrappers that run the async implementation to completion
      (see ``_run_sync``).

    Difference from Node: the Node version exposes camelCase and snake_case
    aliases from one implementation; here everything is snake_case and async
    methods keep the ``a`` prefix, per Python convention.
    """

    def __init__(self, memory: ConversationMemory) -> None:
        # Only chain to a real saver base; the ``object`` fallback needs no init.
        if _BaseCheckpointSaver is not object:
            super().__init__()
        self._memory = memory

    # ─── Key helpers ───

    @staticmethod
    def _thread_base(thread_id: str, checkpoint_ns: str = "") -> str:
        """Return the key root for a thread (and optional checkpoint namespace)."""
        base = f"{_LANGGRAPH_CHECKPOINT_PREFIX}/{_encode_segment(thread_id)}"
        if checkpoint_ns:
            return f"{base}/namespaces/{_encode_segment(checkpoint_ns)}"
        return base

    @classmethod
    def _checkpoint_key(cls, thread_id: str, checkpoint_id: str, checkpoint_ns: str = "") -> str:
        """Return the key of one stored checkpoint."""
        return f"{cls._thread_base(thread_id, checkpoint_ns)}/checkpoints/{_encode_segment(checkpoint_id)}"

    @classmethod
    def _latest_key(cls, thread_id: str, checkpoint_ns: str = "") -> str:
        """Return the key holding the id of the latest checkpoint."""
        return f"{cls._thread_base(thread_id, checkpoint_ns)}/latest"

    @classmethod
    def _writes_key(cls, thread_id: str, checkpoint_id: str, task_id: str, checkpoint_ns: str = "") -> str:
        """Return the key of the pending writes of one task for one checkpoint."""
        return f"{cls._thread_base(thread_id, checkpoint_ns)}/writes/{_encode_segment(checkpoint_id)}/{_encode_segment(task_id)}"

    @classmethod
    def _thread_prefix(cls, thread_id: str, checkpoint_ns: str = "") -> str:
        """Return the prefix covering every key of a thread."""
        return f"{cls._thread_base(thread_id, checkpoint_ns)}/"

    @classmethod
    def _checkpoints_prefix(cls, thread_id: str, checkpoint_ns: str = "") -> str:
        """Return the prefix covering every checkpoint of a thread."""
        return f"{cls._thread_base(thread_id, checkpoint_ns)}/checkpoints/"

    @staticmethod
    def _configurable_of(config: Any) -> dict:
        """Return ``config["configurable"]``, or an empty dict when absent."""
        return (config or {}).get("configurable") or {}

    @staticmethod
    def _resolve_thread_id(config: Any) -> str:
        """Return the thread id from ``config``.

        Accepts ``thread_id`` or ``threadId``.

        Raises:
            MemoryValidationError: If neither key holds a truthy value.
        """
        configurable = (config or {}).get("configurable") or {}
        thread_id = configurable.get("thread_id") or configurable.get("threadId")
        if not thread_id:
            raise MemoryValidationError(
                "LangGraph checkpoint config requires configurable.thread_id"
            )
        return str(thread_id)

    @staticmethod
    def _resolve_checkpoint_ns(config: Any) -> str:
        """Return the checkpoint namespace from ``config`` (``""`` when unset)."""
        configurable = (config or {}).get("configurable") or {}
        return str(configurable.get("checkpoint_ns") or configurable.get("checkpointNs") or "")

    @staticmethod
    def _resolve_checkpoint_id(config: Any, checkpoint: Any = None) -> str:
        """Pick the checkpoint id to store under.

        Preference order: ``checkpoint["id"]``, then the config's
        ``checkpoint_id`` / ``checkpointId``, then a freshly generated uuid4 hex.
        """
        if isinstance(checkpoint, dict) and checkpoint.get("id"):
            return str(checkpoint["id"])
        configurable = (config or {}).get("configurable") or {}
        cid = configurable.get("checkpoint_id") or configurable.get("checkpointId")
        if cid:
            return str(cid)
        return uuid.uuid4().hex

    # ─── Async API (LangGraph BaseCheckpointSaver async interface) ───

    async def aget_tuple(self, config: Any) -> Optional[Any]:
        """Get a checkpoint tuple by thread_id (+ optional checkpoint_id).

        - Without ``checkpoint_id``: returns the thread's latest checkpoint.
        - With ``checkpoint_id``: returns that specific checkpoint.
        - Not found: returns ``None``.

        ``pending_writes`` is rebuilt from the separately stored writes of the
        resolved checkpoint so that LangGraph can walk the parent chain and
        collect DeltaChannel data.
        """
        thread_id = self._resolve_thread_id(config)
        checkpoint_ns = self._resolve_checkpoint_ns(config)
        configurable = self._configurable_of(config)
        checkpoint_id = configurable.get("checkpoint_id") or configurable.get("checkpointId")
        if not checkpoint_id:
            checkpoint_id = await self._read_latest_id(thread_id, checkpoint_ns)
        if not checkpoint_id:
            return None
        checkpoint_id = str(checkpoint_id)

        data = await self._memory._blob_get_json(
            self._checkpoint_key(thread_id, checkpoint_id, checkpoint_ns),
            consistency="strong",
        )
        if data is None:
            return None

        # Terminate the parent chain when the stored parent carries no
        # checkpoint id or points back at this very checkpoint; otherwise a
        # parent-chain walk would loop forever.
        parent_config = data.get("parent_config")
        if isinstance(parent_config, dict):
            parent_configurable = parent_config.get("configurable") or {}
            parent_ckpt_id = parent_configurable.get("checkpoint_id") or parent_configurable.get("checkpointId")
            if not parent_ckpt_id or str(parent_ckpt_id) == checkpoint_id:
                data["parent_config"] = None

        # pending_writes is expected as [(task_id, channel_name, value), ...].
        data["pending_writes"] = await self._load_pending_writes(thread_id, checkpoint_id, checkpoint_ns)

        return _to_checkpoint_tuple(data)

    async def _load_pending_writes(self, thread_id: str, checkpoint_id: str, checkpoint_ns: str = "") -> list:
        """Load the writes stored for one checkpoint as pending_writes tuples.

        Only writes stored under ``checkpoint_id`` are loaded (matching the
        Node behaviour where pending writes are scoped to a checkpoint). The
        write blobs are read concurrently.

        Returns:
            A list of ``(task_id, channel_name, value)`` tuples. Empty when the
            listing fails or nothing is stored; unreadable blobs are skipped.
        """
        import asyncio
        import pickle

        prefix = f"{self._thread_base(thread_id, checkpoint_ns)}/writes/{_encode_segment(checkpoint_id)}/"
        try:
            listed = await self._memory._blob.list(prefix=prefix, consistency="strong")
            keys = [blob.key for blob in listed.blobs]
        except Exception:
            # Best effort: a failed listing is treated as "no pending writes".
            return []

        if not keys:
            return []

        async def _read_write(key: str):
            """Read and JSON-decode one write blob; ``None`` on any failure."""
            try:
                raw = await self._memory._blob.get(key, type="text", consistency="strong")
                if raw is None:
                    return None
                return json.loads(raw)
            except Exception:
                return None

        records = await asyncio.gather(*[_read_write(k) for k in keys])

        pending_writes = []
        for record in records:
            if not isinstance(record, dict):
                continue
            task_id = record.get("task_id", "")

            # Current format: writes serialized as base64(pickle).
            writes_pickle = record.get("writes_pickle")
            if writes_pickle:
                try:
                    # NOTE(security): pickle.loads can execute arbitrary code;
                    # only safe while the backing blob store is trusted.
                    decoded = [
                        (task_id, channel_name, channel_value)
                        for channel_name, channel_value in pickle.loads(base64.b64decode(writes_pickle))
                    ]
                except Exception:
                    # Decode fully before committing, so a mid-way failure
                    # cannot leave partial entries that the legacy path below
                    # would then duplicate.
                    decoded = None
                if decoded is not None:
                    pending_writes.extend(decoded)
                    continue

            # Legacy format: writes stored as plain JSON pairs.
            for channel_name, channel_value in record.get("writes") or ():
                pending_writes.append((task_id, channel_name, channel_value))

        return pending_writes

    async def aput(
        self,
        config: Any,
        checkpoint: Any,
        metadata: Any = None,
        new_versions: Any = None,
    ) -> dict:
        """Persist a checkpoint and update the ``latest`` pointer.

        Returns:
            A copy of ``config`` whose ``configurable`` carries the resolved
            ``thread_id``, ``checkpoint_ns`` and the actual ``checkpoint_id``,
            following the LangGraph convention.
        """
        import logging
        import pickle

        thread_id = self._resolve_thread_id(config)
        checkpoint_ns = self._resolve_checkpoint_ns(config)
        checkpoint_id = self._resolve_checkpoint_id(config, checkpoint)

        next_configurable = dict(self._configurable_of(config))
        next_configurable["thread_id"] = thread_id
        next_configurable["checkpoint_ns"] = checkpoint_ns
        next_configurable["checkpoint_id"] = checkpoint_id
        next_config = dict(config or {})
        next_config["configurable"] = next_configurable

        ckpt_payload = dict(checkpoint or {})
        ckpt_payload["id"] = checkpoint_id

        # channel_values may hold objects that are not JSON-serializable
        # (Runtime, etc.), so it is stored separately as base64(pickle).
        channel_values = ckpt_payload.pop("channel_values", None)
        channel_values_b64 = None
        if channel_values is not None:
            try:
                channel_values_b64 = base64.b64encode(pickle.dumps(channel_values)).decode("ascii")
            except Exception:
                # Best effort: the checkpoint is still written, but make the
                # loss of channel_values visible instead of silent.
                logging.getLogger(__name__).warning(
                    "channel_values of checkpoint %s could not be pickled and were dropped",
                    checkpoint_id,
                    exc_info=True,
                )

        await self._memory._blob_set_json(
            self._checkpoint_key(thread_id, checkpoint_id, checkpoint_ns),
            {
                "config": next_config,
                "checkpoint": ckpt_payload,
                "channel_values_pickle": channel_values_b64,
                "metadata": metadata,
                # The incoming config is recorded as the parent; aget_tuple
                # drops it when it has no checkpoint id or is self-referential.
                "parent_config": config,
                "pending_writes": [],
                "new_versions": new_versions,
            },
        )
        await self._memory._blob_set_json(self._latest_key(thread_id, checkpoint_ns), checkpoint_id)
        return next_config

    async def aput_writes(self, config: Any, writes: Any, task_id: Any) -> None:
        """Persist the pending writes of one task for a checkpoint.

        Counterpart of LangGraph ``BaseCheckpointSaver.aput_writes``. The target
        checkpoint id is taken from ``config``; when absent the thread's latest
        id is used, and ``"pending"`` when there is none either. A falsy
        ``task_id`` is replaced by a generated uuid4 hex.

        ``writes`` may contain objects (e.g. LangChain messages) that do not
        round-trip through JSON, so they are stored as base64(pickle).
        """
        import logging
        import pickle

        thread_id = self._resolve_thread_id(config)
        checkpoint_ns = self._resolve_checkpoint_ns(config)
        configurable = self._configurable_of(config)
        checkpoint_id = (
            configurable.get("checkpoint_id")
            or configurable.get("checkpointId")
            or await self._read_latest_id(thread_id, checkpoint_ns)
            or "pending"
        )
        task_id_str = str(task_id) if task_id else uuid.uuid4().hex

        writes_b64 = None
        try:
            writes_b64 = base64.b64encode(pickle.dumps(writes)).decode("ascii")
        except Exception:
            # Best effort: the record is still written, but report that its
            # writes payload is missing.
            logging.getLogger(__name__).warning(
                "pending writes of task %s could not be pickled and were dropped",
                task_id_str,
                exc_info=True,
            )

        await self._memory._blob_set_json(
            self._writes_key(thread_id, str(checkpoint_id), task_id_str, checkpoint_ns),
            {
                "config": config,
                "writes_pickle": writes_b64,
                "task_id": task_id_str,
            },
        )

    async def alist(self, config: Any, *, limit: Optional[int] = None) -> list:
        """List the checkpoints of a thread, newest first.

        Args:
            config: Config carrying the thread id (and optional namespace).
            limit: When a positive number, return at most that many entries.

        Returns:
            A list of checkpoint tuples; keys whose blob is missing are skipped.
        """
        thread_id = self._resolve_thread_id(config)
        checkpoint_ns = self._resolve_checkpoint_ns(config)
        keys = await self._memory._blob_list_keys(self._checkpoints_prefix(thread_id, checkpoint_ns))
        # Keys are assumed to come back in ascending lexicographic order, so
        # reversing yields newest-first by checkpoint id. This approximates
        # recency as long as checkpoint ids sort monotonically (ordered
        # UUIDs / timestamps).
        keys = list(reversed(keys))
        if limit is not None and limit > 0:
            keys = keys[:limit]
        tuples = []
        for key in keys:
            data = await self._memory._blob_get_json(key)
            if data is not None:
                tuples.append(_to_checkpoint_tuple(data))
        return tuples

    async def adelete_thread(self, thread_id: str) -> None:
        """Delete every key under a thread (checkpoints + writes + latest).

        Raises:
            MemoryValidationError: If ``thread_id`` is falsy.
        """
        if not thread_id:
            raise MemoryValidationError("thread_id is required for delete_thread")
        keys = await self._memory._blob_list_keys(self._thread_prefix(thread_id))
        for key in keys:
            await self._memory._blob_delete(key)

    @staticmethod
    def get_next_version(current: Any, _channel: Any = None) -> int:
        """Compute the next channel version, matching Node ``getNextVersion``.

        LangGraph calls this as ``get_next_version(current, channel)``; the
        channel is ignored and the result is simply ``current + 1``.

        Returns:
            ``current + 1`` truncated to an int; ``1`` when ``current`` is
            ``None``, not numeric, NaN or infinite.
        """
        if isinstance(current, int):
            # Exact for arbitrarily large ints; avoids float rounding and the
            # OverflowError float() raises for huge values.
            return int(current) + 1
        try:
            value = float(current) if current is not None else 0.0
        except (TypeError, ValueError):
            return 1
        if value != value or value in (float("inf"), float("-inf")):  # NaN/inf
            return 1
        return int(value) + 1

    # ─── Sync API ───
    #
    # The blob store is async-only, so each sync method runs its async
    # counterpart to completion through ``_run_sync`` and returns the plain
    # result (not an awaitable). Calling one from inside a running event loop
    # blocks that loop's thread until the operation finishes or times out.

    def get_tuple(self, config: Any) -> Any:
        """Blocking version of ``aget_tuple``."""
        return self._run_sync(self.aget_tuple(config))

    def put(self, config: Any, checkpoint: Any, metadata: Any = None, new_versions: Any = None) -> Any:
        """Blocking version of ``aput``."""
        return self._run_sync(self.aput(config, checkpoint, metadata, new_versions))

    def put_writes(self, config: Any, writes: Any, task_id: Any) -> Any:
        """Blocking version of ``aput_writes``."""
        return self._run_sync(self.aput_writes(config, writes, task_id))

    def list(self, config: Any, *, limit: Optional[int] = None) -> Any:
        """Blocking version of ``alist``."""
        return self._run_sync(self.alist(config, limit=limit))

    def delete_thread(self, thread_id: str) -> Any:
        """Blocking version of ``adelete_thread``."""
        return self._run_sync(self.adelete_thread(thread_id))

    @staticmethod
    def _run_sync(coro) -> Any:
        """Run ``coro`` to completion from synchronous code and return its result.

        Without a running loop in the current thread the coroutine is run with
        ``asyncio.run``. With one, it is run with ``asyncio.run`` on a helper
        thread, because blocking on the current loop from its own thread would
        deadlock.

        Raises:
            concurrent.futures.TimeoutError: If the helper thread does not
                finish within 60 seconds.
        """
        import asyncio
        import concurrent.futures

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(coro)

        # get_running_loop() only succeeds on the loop's own thread, so no
        # thread-id comparison (and no private loop attribute) is needed.
        pool = concurrent.futures.ThreadPoolExecutor(1)
        try:
            return pool.submit(asyncio.run, coro).result(timeout=60)
        finally:
            # wait=False so a timeout is not followed by blocking on the worker.
            pool.shutdown(wait=False)

    # ─── Internal helpers ───

    async def _read_latest_id(self, thread_id: str, checkpoint_ns: str = "") -> Optional[str]:
        """Return the latest checkpoint id of a thread, or ``None`` when unset."""
        raw = await self._memory._blob_get_text(
            self._latest_key(thread_id, checkpoint_ns), consistency="strong"
        )
        if raw is None:
            return None
        # ``latest`` is written as a JSON string, so decode it once; fall back
        # to the raw text when it is not valid JSON.
        try:
            value = json.loads(raw)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            value = raw
        return str(value) if value else None
|
|
1985
|
+
|
|
1986
|
+
|
|
1987
|
+
# ─── OpenAI Agents SDK Session Adapter ───
|
|
1988
|
+
|
|
1989
|
+
|
|
1990
|
+
_SESSION_VALID_ROLES = {"user", "assistant", "system", "tool"}
|
|
1991
|
+
_SESSION_METADATA_MARKER = {"agent_sdk_session": True}
|
|
1992
|
+
|
|
1993
|
+
|
|
1994
|
+
class _EdgeOneMemorySession:
|
|
1995
|
+
"""OpenAI Agents SDK Session protocol adapter backed by ctx.store.
|
|
1996
|
+
|
|
1997
|
+
Runner.run(..., session=session) will automatically:
|
|
1998
|
+
1. Call get_items() to read history and prepend to current input
|
|
1999
|
+
2. Call add_items() to persist this turn's user / assistant / tool items
|
|
2000
|
+
|
|
2001
|
+
This adapter stores each Agents SDK item as a memory message with
|
|
2002
|
+
metadata ``{"agent_sdk_session": True}`` to distinguish from manually
|
|
2003
|
+
written messages. The full item dict is stored as ``content``.
|
|
2004
|
+
"""
|
|
2005
|
+
|
|
2006
|
+
session_settings = None
|
|
2007
|
+
|
|
2008
|
+
def __init__(self, memory: ConversationMemory, session_id: str, *, max_items: int = 100):
|
|
2009
|
+
self.memory = memory
|
|
2010
|
+
self.session_id = session_id
|
|
2011
|
+
self.max_items = max_items
|
|
2012
|
+
|
|
2013
|
+
async def get_items(self, limit: int | None = None) -> list:
|
|
2014
|
+
"""Retrieve conversation history as Agents SDK input items."""
|
|
2015
|
+
effective_limit = min(limit or self.max_items, 100)
|
|
2016
|
+
messages = await self.memory.get_messages(
|
|
2017
|
+
self.session_id,
|
|
2018
|
+
limit=effective_limit,
|
|
2019
|
+
order="asc",
|
|
2020
|
+
)
|
|
2021
|
+
|
|
2022
|
+
items = []
|
|
2023
|
+
for message in messages:
|
|
2024
|
+
item = self._message_to_item(message)
|
|
2025
|
+
if item is not None:
|
|
2026
|
+
items.append(item)
|
|
2027
|
+
return items[-limit:] if limit is not None else items
|
|
2028
|
+
|
|
2029
|
+
async def add_items(self, items: list) -> None:
|
|
2030
|
+
"""Persist new Agents SDK items to memory."""
|
|
2031
|
+
if not items:
|
|
2032
|
+
return
|
|
2033
|
+
|
|
2034
|
+
for item in items:
|
|
2035
|
+
normalized = self._jsonable(item)
|
|
2036
|
+
role = self._role_for_item(normalized)
|
|
2037
|
+
await self.memory.append_message(
|
|
2038
|
+
self.session_id,
|
|
2039
|
+
role,
|
|
2040
|
+
normalized,
|
|
2041
|
+
metadata={
|
|
2042
|
+
**_SESSION_METADATA_MARKER,
|
|
2043
|
+
"item_type": normalized.get("type") if isinstance(normalized, dict) else None,
|
|
2044
|
+
},
|
|
2045
|
+
)
|
|
2046
|
+
|
|
2047
|
+
async def pop_item(self) -> Optional[dict]:
|
|
2048
|
+
"""Remove and return the most recent item from the session."""
|
|
2049
|
+
messages = await self.memory.get_messages(self.session_id, limit=100, order="desc")
|
|
2050
|
+
for message in messages:
|
|
2051
|
+
item = self._message_to_item(message)
|
|
2052
|
+
if item is None:
|
|
2053
|
+
continue
|
|
2054
|
+
await self.memory.delete_message(self.session_id, message.message_id)
|
|
2055
|
+
return item
|
|
2056
|
+
return None
|
|
2057
|
+
|
|
2058
|
+
async def clear_session(self) -> None:
|
|
2059
|
+
"""Clear all items for this session."""
|
|
2060
|
+
await self.memory.clear_messages(self.session_id)
|
|
2061
|
+
|
|
2062
|
+
def _message_to_item(self, message: Any) -> Optional[dict]:
|
|
2063
|
+
content = getattr(message, "content", None)
|
|
2064
|
+
role = getattr(message, "role", None)
|
|
2065
|
+
metadata = getattr(message, "metadata", None) or {}
|
|
2066
|
+
|
|
2067
|
+
# New format: content holds the full Agents SDK input item.
|
|
2068
|
+
if metadata.get("agent_sdk_session") and isinstance(content, dict):
|
|
2069
|
+
return content
|
|
2070
|
+
|
|
2071
|
+
# Legacy compat: plain text messages written by hand.
|
|
2072
|
+
if role in _SESSION_VALID_ROLES and content is not None:
|
|
2073
|
+
return {"role": role, "content": content}
|
|
2074
|
+
|
|
2075
|
+
return None
|
|
2076
|
+
|
|
2077
|
+
@staticmethod
|
|
2078
|
+
def _role_for_item(item: Any) -> str:
|
|
2079
|
+
"""Determine the best role string for storage from an Agents SDK item."""
|
|
2080
|
+
if isinstance(item, dict):
|
|
2081
|
+
role = item.get("role")
|
|
2082
|
+
if role in _SESSION_VALID_ROLES:
|
|
2083
|
+
return role
|
|
2084
|
+
|
|
2085
|
+
item_type = item.get("type")
|
|
2086
|
+
if item_type == "message":
|
|
2087
|
+
msg_role = item.get("role")
|
|
2088
|
+
return msg_role if msg_role in _SESSION_VALID_ROLES else "assistant"
|
|
2089
|
+
if item_type in ("function_call_output", "computer_call_output"):
|
|
2090
|
+
return "tool"
|
|
2091
|
+
if item_type in ("function_call", "computer_call", "reasoning"):
|
|
2092
|
+
return "assistant"
|
|
2093
|
+
|
|
2094
|
+
return "tool"
|
|
2095
|
+
|
|
2096
|
+
@staticmethod
|
|
2097
|
+
def _jsonable(item: Any) -> dict:
|
|
2098
|
+
"""Convert an Agents SDK item to a JSON-serializable dict."""
|
|
2099
|
+
if isinstance(item, dict):
|
|
2100
|
+
return item
|
|
2101
|
+
if hasattr(item, "model_dump"):
|
|
2102
|
+
return item.model_dump(exclude_unset=True)
|
|
2103
|
+
if hasattr(item, "dict"):
|
|
2104
|
+
return item.dict(exclude_unset=True)
|
|
2105
|
+
return {"role": "tool", "content": str(item)}
|
|
2106
|
+
|
|
2107
|
+
|
|
2108
|
+
# ─── Claude Agent SDK SessionStore Adapter ───
|
|
2109
|
+
|
|
2110
|
+
|
|
2111
|
+
_CLAUDE_SESSION_PREFIX = "claude_sessions"
|
|
2112
|
+
|
|
2113
|
+
|
|
2114
|
+
class EdgeOneSessionStore:
|
|
2115
|
+
"""Claude Agent SDK SessionStore backed by EdgeOne blob storage.
|
|
2116
|
+
|
|
2117
|
+
Implements the ``SessionStore`` protocol from ``claude_agent_sdk``:
|
|
2118
|
+
- append(key, entries): persist transcript entries
|
|
2119
|
+
- load(key): retrieve all entries in append order
|
|
2120
|
+
- list_sessions(project_key): enumerate sessions with mtime
|
|
2121
|
+
- delete(key): remove session (cascade) or single subpath
|
|
2122
|
+
- list_subkeys(key): discover subagent subpaths
|
|
2123
|
+
|
|
2124
|
+
Storage layout (multi-part-file, aligned with S3 reference impl)::
|
|
2125
|
+
|
|
2126
|
+
claude_sessions/{project_key}/{session_id}/parts/{ts13}_{rand6}
|
|
2127
|
+
claude_sessions/{project_key}/{session_id}/subpaths/{subpath}/parts/{ts13}_{rand6}
|
|
2128
|
+
|
|
2129
|
+
Each ``append()`` creates one new part file containing JSONL (one entry
|
|
2130
|
+
per line). ``load()`` lists all part files in lexicographic order (which
|
|
2131
|
+
equals chronological order due to the 13-digit timestamp prefix), reads
|
|
2132
|
+
each, and concatenates entries.
|
|
2133
|
+
|
|
2134
|
+
Usage::
|
|
2135
|
+
|
|
2136
|
+
from claude_agent_sdk import query, ClaudeAgentOptions
|
|
2137
|
+
|
|
2138
|
+
async def handler(ctx):
|
|
2139
|
+
store = ctx.store.claude_session_store()
|
|
2140
|
+
async for msg in query(
|
|
2141
|
+
prompt="Fix the bug",
|
|
2142
|
+
options=ClaudeAgentOptions(session_store=store),
|
|
2143
|
+
):
|
|
2144
|
+
...
|
|
2145
|
+
"""
|
|
2146
|
+
|
|
2147
|
+
def __init__(self, blob_store: Any) -> None:
|
|
2148
|
+
"""
|
|
2149
|
+
Args:
|
|
2150
|
+
blob_store: Raw blob store instance. Must implement get(key, type=),
|
|
2151
|
+
set(key, value), delete(key), list(prefix=).
|
|
2152
|
+
"""
|
|
2153
|
+
self._blob = blob_store
|
|
2154
|
+
|
|
2155
|
+
# ─── Key Helpers ───
|
|
2156
|
+
|
|
2157
|
+
@staticmethod
|
|
2158
|
+
def _parts_prefix(key: dict) -> str:
    """Return the ``.../parts/`` key prefix addressed by a SessionKey.

    Args:
        key: Mapping with ``project_key`` and ``session_id``; a truthy
            ``subpath`` entry selects that subpath's parts instead of the
            main transcript's.

    Returns:
        ``claude_sessions/{project}/{session}/parts/`` or, with a subpath,
        ``claude_sessions/{project}/{session}/subpaths/{subpath}/parts/``,
        where every variable segment has been passed through
        ``_encode_segment``.
    """
    project = _encode_segment(key["project_key"])
    session = _encode_segment(key["session_id"])
    root = f"{_CLAUDE_SESSION_PREFIX}/{project}/{session}"
    subpath = key.get("subpath")
    if not subpath:
        return f"{root}/parts/"
    return f"{root}/subpaths/{_encode_segment(subpath)}/parts/"
|
|
2169
|
+
|
|
2170
|
+
@staticmethod
|
|
2171
|
+
def _session_base(project_key: str, session_id: str) -> str:
    """Return the key prefix covering everything stored for one session.

    The result ends with ``/``; ``delete()`` lists keys under it to remove a
    whole session, subpaths included.
    """
    project = _encode_segment(project_key)
    session = _encode_segment(session_id)
    return f"{_CLAUDE_SESSION_PREFIX}/{project}/{session}/"
|
|
2178
|
+
|
|
2179
|
+
@staticmethod
|
|
2180
|
+
def _project_prefix(project_key: str) -> str:
    """Return the key prefix (ending in ``/``) shared by a project's sessions."""
    project = _encode_segment(project_key)
    return f"{_CLAUDE_SESSION_PREFIX}/{project}/"
|
|
2183
|
+
|
|
2184
|
+
# ─── Required Methods ───
|
|
2185
|
+
|
|
2186
|
+
async def append(self, key: dict, entries: list) -> None:
    """Persist a batch of transcript entries as one new part file.

    The part is written under ``self._parts_prefix(key)`` with the name
    ``{ts13}_{rand6}``: a zero-padded 13-digit millisecond timestamp followed
    by six hex characters from a fresh UUID. The payload is JSONL, one
    compact JSON document per entry, joined with ``"\\n"``.

    Args:
        key: SessionKey dict with project_key, session_id, optional subpath.
        entries: List of JSON-serializable SessionStoreEntry objects. An
            empty list is a no-op (no part file is written).

    Raises:
        TypeError: Propagated from ``json.dumps`` when an entry is not
            JSON-serializable.
    """
    if not entries:
        return
    # load() replays parts in lexicographic name order. A bare wall-clock
    # millisecond cannot order two appends made within the same millisecond
    # (the random suffix would decide), so never reuse or go below the last
    # stamp this instance handed out. getattr() keeps this safe on instances
    # whose __init__ did not set the attribute.
    now_ms = int(time.time() * 1000)
    stamp_ms = max(now_ms, getattr(self, "_last_part_ms", -1) + 1)
    self._last_part_ms = stamp_ms

    prefix = self._parts_prefix(key)
    ts = str(stamp_ms).zfill(13)
    rand = uuid.uuid4().hex[:6]
    part_key = f"{prefix}{ts}_{rand}"
    payload = "\n".join(
        json.dumps(entry, ensure_ascii=False, separators=(",", ":"))
        for entry in entries
    )
    await self._blob.set(part_key, payload)
|
|
2204
|
+
|
|
2205
|
+
async def load(self, key: dict) -> Optional[list]:
    """Read back every entry stored for a session/subpath, oldest part first.

    Part keys under ``self._parts_prefix(key)`` are sorted ascending, each
    part is fetched as text and split on ``"\\n"``, and every non-blank line
    is parsed as JSON. Lines that fail to parse are skipped.

    Returns:
        The decoded entries in part order, or ``None`` when no part exists
        or no line could be decoded.
    """
    part_keys = await self._list_keys(self._parts_prefix(key))
    if not part_keys:
        return None
    # Ascending name order == chronological order == append order.
    part_keys.sort()

    collected: list = []
    for part_key in part_keys:
        text = await self._blob.get(part_key, type="text")
        if not text:
            continue
        for raw_line in text.split("\n"):
            candidate = raw_line.strip()
            if not candidate:
                continue
            try:
                collected.append(json.loads(candidate))
            except (json.JSONDecodeError, ValueError):
                # Best-effort read: an undecodable line is dropped, not fatal.
                continue
    return collected or None
|
|
2228
|
+
|
|
2229
|
+
# ─── Optional Methods ───
|
|
2230
|
+
|
|
2231
|
+
async def list_sessions(self, project_key: str) -> list:
    """List the sessions stored under a project with their newest part time.

    Every key under the project prefix is inspected, so parts written to a
    session's subpaths count towards that session's ``mtime`` as well. Keys
    whose filename does not start with an integer are ignored.

    Returns:
        List of {"session_id": str, "mtime": int} dicts, where ``mtime`` is
        the largest timestamp prefix found among the session's part names.
    """
    root = self._project_prefix(project_key)
    latest: dict = {}  # session_id -> newest part timestamp seen so far
    for blob_key in await self._list_keys(root):
        # Relative key: {session_id}/parts/{ts}_{rand}
        #           or: {session_id}/subpaths/{sp}/parts/{ts}_{rand}
        relative = blob_key[len(root):]
        session_id = urllib.parse.unquote(relative.partition("/")[0])
        # The timestamp is the part filename up to its first underscore.
        stamp_text = blob_key.rpartition("/")[2].partition("_")[0]
        try:
            stamp = int(stamp_text)
        except (ValueError, IndexError):
            continue
        latest[session_id] = max(stamp, latest.get(session_id, stamp))
    return [{"session_id": sid, "mtime": mtime} for sid, mtime in latest.items()]
|
|
2256
|
+
|
|
2257
|
+
async def delete(self, key: dict) -> None:
    """Remove stored data for a session or for one of its subpaths.

    * ``key`` has a truthy ``subpath``: only the parts under that subpath
      are deleted.
    * Otherwise: every key under the session's base prefix is deleted,
      which includes all of its subpaths.

    Keys are listed first and then deleted one at a time.
    """
    if key.get("subpath"):
        target = self._parts_prefix(key)
    else:
        target = self._session_base(key["project_key"], key["session_id"])
    for doomed in await self._list_keys(target):
        await self._blob.delete(doomed)
|
|
2271
|
+
|
|
2272
|
+
async def list_subkeys(self, key: dict) -> list:
    """Return the distinct subpaths that have stored parts for a session.

    Keys under ``.../{session}/subpaths/`` are listed; the first path
    segment after that prefix is URL-unquoted and collected. Empty names,
    ``"."`` and any name containing ``".."`` are discarded.

    Returns:
        Sorted list of subpath strings (e.g. ["subagents/agent-abc", ...]).
    """
    project = _encode_segment(key["project_key"])
    session = _encode_segment(key["session_id"])
    root = f"{_CLAUDE_SESSION_PREFIX}/{project}/{session}/subpaths/"

    found: set = set()
    for blob_key in await self._list_keys(root):
        # Relative key: {encoded_subpath}/parts/{ts}_{rand}
        encoded = blob_key[len(root):].partition("/")[0]
        name = urllib.parse.unquote(encoded)
        if not name or name == "." or ".." in name:
            continue
        found.add(name)
    return sorted(found)
|
|
2295
|
+
|
|
2296
|
+
# ─── Internal Helpers ───
|
|
2297
|
+
|
|
2298
|
+
async def _list_keys(self, prefix: str) -> List[str]:
    """Return the ``key`` of every blob the store lists under *prefix*.

    Issues a single ``list(prefix=...)`` call and reads ``.blobs`` from its
    result.
    NOTE(review): if the blob API paginates, only the first page is seen
    here -- confirm against the blob store's documentation.
    """
    listing = await self._blob.list(prefix=prefix)
    return [entry.key for entry in listing.blobs]
|