botrun-flow-lang 5.12.263__py3-none-any.whl → 5.12.264__py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- botrun_flow_lang/api/auth_api.py +39 -39
- botrun_flow_lang/api/auth_utils.py +183 -183
- botrun_flow_lang/api/botrun_back_api.py +65 -65
- botrun_flow_lang/api/flow_api.py +3 -3
- botrun_flow_lang/api/hatch_api.py +508 -508
- botrun_flow_lang/api/langgraph_api.py +811 -811
- botrun_flow_lang/api/line_bot_api.py +1484 -1484
- botrun_flow_lang/api/model_api.py +300 -300
- botrun_flow_lang/api/rate_limit_api.py +32 -32
- botrun_flow_lang/api/routes.py +79 -79
- botrun_flow_lang/api/search_api.py +53 -53
- botrun_flow_lang/api/storage_api.py +395 -395
- botrun_flow_lang/api/subsidy_api.py +290 -290
- botrun_flow_lang/api/subsidy_api_system_prompt.txt +109 -109
- botrun_flow_lang/api/user_setting_api.py +70 -70
- botrun_flow_lang/api/version_api.py +31 -31
- botrun_flow_lang/api/youtube_api.py +26 -26
- botrun_flow_lang/constants.py +13 -13
- botrun_flow_lang/langgraph_agents/agents/agent_runner.py +178 -178
- botrun_flow_lang/langgraph_agents/agents/agent_tools/step_planner.py +77 -77
- botrun_flow_lang/langgraph_agents/agents/checkpointer/firestore_checkpointer.py +666 -666
- botrun_flow_lang/langgraph_agents/agents/gov_researcher/GOV_RESEARCHER_PRD.md +192 -192
- botrun_flow_lang/langgraph_agents/agents/gov_researcher/gemini_subsidy_graph.py +460 -460
- botrun_flow_lang/langgraph_agents/agents/gov_researcher/gov_researcher_2_graph.py +1002 -1002
- botrun_flow_lang/langgraph_agents/agents/gov_researcher/gov_researcher_graph.py +822 -822
- botrun_flow_lang/langgraph_agents/agents/langgraph_react_agent.py +723 -723
- botrun_flow_lang/langgraph_agents/agents/search_agent_graph.py +864 -864
- botrun_flow_lang/langgraph_agents/agents/tools/__init__.py +4 -4
- botrun_flow_lang/langgraph_agents/agents/tools/gemini_code_execution.py +376 -376
- botrun_flow_lang/langgraph_agents/agents/util/gemini_grounding.py +66 -66
- botrun_flow_lang/langgraph_agents/agents/util/html_util.py +316 -316
- botrun_flow_lang/langgraph_agents/agents/util/img_util.py +294 -294
- botrun_flow_lang/langgraph_agents/agents/util/local_files.py +419 -419
- botrun_flow_lang/langgraph_agents/agents/util/mermaid_util.py +86 -86
- botrun_flow_lang/langgraph_agents/agents/util/model_utils.py +143 -143
- botrun_flow_lang/langgraph_agents/agents/util/pdf_analyzer.py +486 -486
- botrun_flow_lang/langgraph_agents/agents/util/pdf_cache.py +250 -250
- botrun_flow_lang/langgraph_agents/agents/util/pdf_processor.py +204 -204
- botrun_flow_lang/langgraph_agents/agents/util/perplexity_search.py +464 -464
- botrun_flow_lang/langgraph_agents/agents/util/plotly_util.py +59 -59
- botrun_flow_lang/langgraph_agents/agents/util/tavily_search.py +199 -199
- botrun_flow_lang/langgraph_agents/agents/util/youtube_util.py +90 -90
- botrun_flow_lang/langgraph_agents/cache/langgraph_botrun_cache.py +197 -197
- botrun_flow_lang/llm_agent/llm_agent.py +19 -19
- botrun_flow_lang/llm_agent/llm_agent_util.py +83 -83
- botrun_flow_lang/log/.gitignore +2 -2
- botrun_flow_lang/main.py +61 -61
- botrun_flow_lang/main_fast.py +51 -51
- botrun_flow_lang/mcp_server/__init__.py +10 -10
- botrun_flow_lang/mcp_server/default_mcp.py +744 -744
- botrun_flow_lang/models/nodes/utils.py +205 -205
- botrun_flow_lang/models/token_usage.py +34 -34
- botrun_flow_lang/requirements.txt +21 -21
- botrun_flow_lang/services/base/firestore_base.py +30 -30
- botrun_flow_lang/services/hatch/hatch_factory.py +11 -11
- botrun_flow_lang/services/hatch/hatch_fs_store.py +419 -419
- botrun_flow_lang/services/storage/storage_cs_store.py +206 -206
- botrun_flow_lang/services/storage/storage_factory.py +12 -12
- botrun_flow_lang/services/storage/storage_store.py +65 -65
- botrun_flow_lang/services/user_setting/user_setting_factory.py +9 -9
- botrun_flow_lang/services/user_setting/user_setting_fs_store.py +66 -66
- botrun_flow_lang/static/docs/tools/index.html +926 -926
- botrun_flow_lang/tests/api_functional_tests.py +1525 -1525
- botrun_flow_lang/tests/api_stress_test.py +357 -357
- botrun_flow_lang/tests/shared_hatch_tests.py +333 -333
- botrun_flow_lang/tests/test_botrun_app.py +46 -46
- botrun_flow_lang/tests/test_html_util.py +31 -31
- botrun_flow_lang/tests/test_img_analyzer.py +190 -190
- botrun_flow_lang/tests/test_img_util.py +39 -39
- botrun_flow_lang/tests/test_local_files.py +114 -114
- botrun_flow_lang/tests/test_mermaid_util.py +103 -103
- botrun_flow_lang/tests/test_pdf_analyzer.py +104 -104
- botrun_flow_lang/tests/test_plotly_util.py +151 -151
- botrun_flow_lang/tests/test_run_workflow_engine.py +65 -65
- botrun_flow_lang/tools/generate_docs.py +133 -133
- botrun_flow_lang/tools/templates/tools.html +153 -153
- botrun_flow_lang/utils/__init__.py +7 -7
- botrun_flow_lang/utils/botrun_logger.py +344 -344
- botrun_flow_lang/utils/clients/rate_limit_client.py +209 -209
- botrun_flow_lang/utils/clients/token_verify_client.py +153 -153
- botrun_flow_lang/utils/google_drive_utils.py +654 -654
- botrun_flow_lang/utils/langchain_utils.py +324 -324
- botrun_flow_lang/utils/yaml_utils.py +9 -9
- {botrun_flow_lang-5.12.263.dist-info → botrun_flow_lang-5.12.264.dist-info}/METADATA +1 -1
- botrun_flow_lang-5.12.264.dist-info/RECORD +102 -0
- botrun_flow_lang-5.12.263.dist-info/RECORD +0 -102
- {botrun_flow_lang-5.12.263.dist-info → botrun_flow_lang-5.12.264.dist-info}/WHEEL +0 -0
botrun_flow_lang/langgraph_agents/agents/checkpointer/firestore_checkpointer.py
@@ -1,666 +1,666 @@
(all 666 lines of the file are removed and re-added with identical content; the content is shown once below)

from typing import (
    Any,
    Dict,
    List,
    Optional,
    Tuple,
    AsyncIterator,
    Iterator,
    cast,
    AsyncGenerator,
)
import logging
from datetime import datetime
import os
import asyncio
from dotenv import load_dotenv

from google.cloud import firestore
from google.cloud.firestore_v1.base_query import FieldFilter

from google.cloud.exceptions import GoogleCloudError
from google.oauth2 import service_account

from langgraph.checkpoint.base import (
    BaseCheckpointSaver,
    Checkpoint,
    CheckpointMetadata,
    CheckpointTuple,
    PendingWrite,  # Note: PendingWrite is actually Tuple[str, Any, Any]
    get_checkpoint_id,
    WRITES_IDX_MAP,
    ChannelVersions,
)
from langgraph.checkpoint.serde.base import SerializerProtocol
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
from langgraph.pregel.types import StateSnapshot
from langchain_core.runnables import RunnableConfig

from botrun_flow_lang.constants import CHECKPOINTER_STORE_NAME
from botrun_flow_lang.services.base.firestore_base import FirestoreBase
import time

load_dotenv()

# Set up logger
logger = logging.getLogger("AsyncFirestoreCheckpointer")
# 從環境變數取得日誌級別,默認為 WARNING(不顯示 INFO 級別日誌)
log_level = os.getenv("FIRESTORE_CHECKPOINTER_LOG_LEVEL", "WARNING").upper()
log_level_map = {
    "DEBUG": logging.DEBUG,
    "INFO": logging.INFO,
    "WARNING": logging.WARNING,
    "ERROR": logging.ERROR,
    "CRITICAL": logging.CRITICAL,
}
logger.setLevel(log_level_map.get(log_level, logging.WARNING))
# Create console handler if it doesn't exist
if not logger.handlers:
    ch = logging.StreamHandler()
    ch.setLevel(log_level_map.get(log_level, logging.WARNING))
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    ch.setFormatter(formatter)
    logger.addHandler(ch)

# Constants for field names
FIELD_THREAD_ID = "thread_id"
FIELD_CHECKPOINT_NS = "checkpoint_ns"
FIELD_CHECKPOINT_ID = "checkpoint_id"
FIELD_PARENT_CHECKPOINT_ID = "parent_checkpoint_id"
FIELD_TASK_ID = "task_id"
FIELD_IDX = "idx"
FIELD_TIMESTAMP = "timestamp"
FIELD_TYPE = "type"
FIELD_DATA = "data"
FIELD_METADATA = "metadata"
FIELD_NEW_VERSIONS = "new_versions"
FIELD_CHANNEL = "channel"
FIELD_VALUE = "value"
FIELD_CREATED_AT = "created_at"


class AsyncFirestoreCheckpointer(BaseCheckpointSaver):
    """Async Firestore-based checkpoint saver implementation.

    This implementation uses Firestore's collections and sub-collections to efficiently
    store and retrieve checkpoints and their associated writes.

    For each environment, it creates:
    - A root collection for all checkpoints
    - A sub-collection for each checkpoint's writes

    This design provides:
    - Efficient querying by thread_id, namespace, and checkpoint_id
    - Hierarchical structure that matches the data relationships
    - Improved query performance with proper indexing
    """

    db: firestore.AsyncClient
    checkpoints_collection: firestore.AsyncCollectionReference

    def __init__(
        self,
        env_name: str,
        serializer: Optional[SerializerProtocol] = None,
        collection_name: Optional[str] = None,
    ):
        """Initialize the AsyncFirestoreCheckpointer.

        Args:
            env_name: Environment name to be used as prefix for collection.
            serializer: Optional serializer to use for converting values to storable format.
            collection_name: Optional custom collection name. If not provided,
                it will use {env_name}-{CHECKPOINTER_STORE_NAME}.
        """
        super().__init__()
        logger.info(f"Initializing AsyncFirestoreCheckpointer with env_name={env_name}")
        self.serde = serializer or JsonPlusSerializer()
        self._collection_name = (
            collection_name or f"{env_name}-{CHECKPOINTER_STORE_NAME}"
        )
        logger.info(f"Using collection: {self._collection_name}")

        try:
            # Initialize async Firestore client
            google_service_account_key_path = os.getenv(
                "GOOGLE_APPLICATION_CREDENTIALS_FOR_FASTAPI",
                "/app/keys/scoop-386004-d22d99a7afd9.json",
            )
            credentials = service_account.Credentials.from_service_account_file(
                google_service_account_key_path,
                scopes=["https://www.googleapis.com/auth/datastore"],
            )

            project_id = os.getenv("GOOGLE_CLOUD_PROJECT")
            if project_id:
                self.db = firestore.AsyncClient(
                    project=project_id, credentials=credentials
                )
            else:
                self.db = firestore.AsyncClient(credentials=credentials)

            self.checkpoints_collection = self.db.collection(self._collection_name)
            logger.info("Async Firestore client initialized successfully")
        except Exception as e:
            logger.error(f"Error initializing Firestore client: {e}", exc_info=True)
            raise

    async def close(self):
        """Close the Firestore client connection."""
        if hasattr(self, "db") and self.db:
            await self.db.close()
            logger.info("Firestore client connection closed")

    async def __aenter__(self):
        """Context manager entry."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit with cleanup."""
        await self.close()

    def _get_checkpoint_doc_id(
        self, thread_id: str, checkpoint_ns: str, checkpoint_id: str
    ) -> str:
        """Generate a document ID for a checkpoint.

        For maximum Firestore efficiency, we use a compound ID that naturally clusters
        related data together for efficient retrieval.
        """
        return f"{thread_id}:{checkpoint_ns}:{checkpoint_id}"

    def _get_writes_subcollection(
        self, checkpoint_doc_ref: firestore.AsyncDocumentReference
    ) -> firestore.AsyncCollectionReference:
        """Get the subcollection reference for checkpoint writes."""
        return checkpoint_doc_ref.collection("writes")

    def _parse_checkpoint_doc_id(self, doc_id: str) -> Dict[str, str]:
        """Parse a checkpoint document ID into its components."""
        parts = doc_id.split(":")
        if len(parts) != 3:
            raise ValueError(f"Invalid checkpoint document ID format: {doc_id}")

        return {
            FIELD_THREAD_ID: parts[0],
            FIELD_CHECKPOINT_NS: parts[1],
            FIELD_CHECKPOINT_ID: parts[2],
        }

    async def aput(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> RunnableConfig:
        """Save a checkpoint to Firestore asynchronously.

        This method saves a checkpoint to Firestore as a document with fields for
        efficient querying.

        Args:
            config: The config to associate with the checkpoint.
            checkpoint: The checkpoint to save.
            metadata: Additional metadata to save with the checkpoint.
            new_versions: New channel versions as of this write.

        Returns:
            RunnableConfig: Updated configuration after storing the checkpoint.
        """
        thread_id = config["configurable"]["thread_id"]
        checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
        checkpoint_id = checkpoint["id"]
        parent_checkpoint_id = config["configurable"].get("checkpoint_id", "")

        # Generate document ID for efficient querying
        doc_id = self._get_checkpoint_doc_id(thread_id, checkpoint_ns, checkpoint_id)

        # Serialize the data
        type_, serialized_checkpoint = self.serde.dumps_typed(checkpoint)
        serialized_metadata = self.serde.dumps(metadata)

        # Prepare the document data
        data = {
            FIELD_THREAD_ID: thread_id,
            FIELD_CHECKPOINT_NS: checkpoint_ns,
            FIELD_CHECKPOINT_ID: checkpoint_id,
            FIELD_PARENT_CHECKPOINT_ID: parent_checkpoint_id,
            FIELD_TYPE: type_,
            FIELD_DATA: serialized_checkpoint,
            FIELD_METADATA: serialized_metadata,
            FIELD_TIMESTAMP: firestore.SERVER_TIMESTAMP,  # Use server timestamp for consistency
            FIELD_CREATED_AT: datetime.utcnow().isoformat(),  # Backup client-side timestamp
        }

        if new_versions:
            data[FIELD_NEW_VERSIONS] = self.serde.dumps(new_versions)

        try:
            await self.checkpoints_collection.document(doc_id).set(data)
            logger.info(f"Successfully stored checkpoint with ID: {doc_id}")
        except Exception as e:
            logger.error(f"Error storing checkpoint: {e}", exc_info=True)
            raise

        return {
            "configurable": {
                "thread_id": thread_id,
                "checkpoint_ns": checkpoint_ns,
                "checkpoint_id": checkpoint_id,
            }
        }

    async def aput_writes(
        self,
        config: RunnableConfig,
        writes: List[Tuple[str, Any]],
        task_id: str,
    ) -> None:
        """Store intermediate writes linked to a checkpoint asynchronously.

        This method saves intermediate writes associated with a checkpoint in a subcollection.

        Args:
            config: Configuration of the related checkpoint.
            writes: List of writes to store, each as (channel, value) pair.
            task_id: Identifier for the task creating the writes.
        """
        thread_id = config["configurable"]["thread_id"]
        checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
        checkpoint_id = config["configurable"]["checkpoint_id"]

        # Get the checkpoint document reference
        checkpoint_doc_id = self._get_checkpoint_doc_id(
            thread_id, checkpoint_ns, checkpoint_id
        )
        checkpoint_doc_ref = self.checkpoints_collection.document(checkpoint_doc_id)

        # Get the writes subcollection
        writes_collection = self._get_writes_subcollection(checkpoint_doc_ref)

        try:
            # Optimize write operations with batching
            batch = self.db.batch()
            batch_size = 0
            max_batch_size = 450  # Slightly below Firestore limit for safety
            batch_futures = []  # For tracking concurrent batch commits

            for idx, (channel, value) in enumerate(writes):
                # Determine the write ID
                write_idx = WRITES_IDX_MAP.get(channel, idx)
                write_id = f"{task_id}:{write_idx}"

                # Serialize the value
                type_, serialized_value = self.serde.dumps_typed(value)

                # Prepare the write data
                data = {
                    FIELD_TASK_ID: task_id,
                    FIELD_IDX: write_idx,
                    FIELD_CHANNEL: channel,
                    FIELD_TYPE: type_,
                    FIELD_VALUE: serialized_value,
                    FIELD_TIMESTAMP: firestore.SERVER_TIMESTAMP,
                    FIELD_CREATED_AT: datetime.utcnow().isoformat(),
                }

                write_doc_ref = writes_collection.document(write_id)

                # Determine if we should set or create-if-not-exists
                if channel in WRITES_IDX_MAP:
                    # For indexed channels, always set (similar to HSET behavior)
                    batch.set(write_doc_ref, data)
                else:
                    # For non-indexed channels, we need a transaction to check existence
                    # We'll check existence manually for now
                    doc = await write_doc_ref.get()
                    if not doc.exists:
                        batch.set(write_doc_ref, data)

                batch_size += 1

                # If batch is getting full, submit it and start a new one
                if batch_size >= max_batch_size:
                    batch_futures.append(batch.commit())
                    batch = self.db.batch()
                    batch_size = 0

            # Commit any remaining writes in the batch
            if batch_size > 0:
                batch_futures.append(batch.commit())

            # Wait for all batch operations to complete
            if batch_futures:
                await asyncio.gather(*batch_futures)

            logger.info(
                f"Successfully stored {len(writes)} writes for checkpoint: {checkpoint_id}"
            )
        except Exception as e:
            logger.error(f"Error storing writes: {e}", exc_info=True)
            raise

    async def aget_tuple(
        self,
        config: RunnableConfig,
    ) -> Optional[CheckpointTuple]:
        """Get a checkpoint tuple from Firestore asynchronously.

        This method retrieves a checkpoint and its associated writes from Firestore.

        Args:
            config: The config to use for retrieving the checkpoint.

        Returns:
            Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if not found.
        """
        thread_id = config["configurable"]["thread_id"]
        checkpoint_id = get_checkpoint_id(config)
        checkpoint_ns = config["configurable"].get("checkpoint_ns", "")

        try:
            # If checkpoint_id is provided, get that specific checkpoint
            if checkpoint_id:
                doc_id = self._get_checkpoint_doc_id(
                    thread_id, checkpoint_ns, checkpoint_id
                )
                doc = await self.checkpoints_collection.document(doc_id).get()

                if not doc.exists:
                    return None
            else:
                # Otherwise, find the latest checkpoint
                query = (
                    self.checkpoints_collection.where(
                        filter=FieldFilter(FIELD_THREAD_ID, "==", thread_id)
                    )
                    .where(filter=FieldFilter(FIELD_CHECKPOINT_NS, "==", checkpoint_ns))
                    .order_by(FIELD_TIMESTAMP, direction=firestore.Query.DESCENDING)
                    .limit(1)
                )

                docs = await query.get()
                if not docs:
                    return None

                doc = docs[0]
                # Extract the checkpoint_id for loading writes
                checkpoint_id = doc.get(FIELD_CHECKPOINT_ID)

            data = doc.to_dict()

            # Parse the document data
            type_ = data.get(FIELD_TYPE)
            serialized_checkpoint = data.get(FIELD_DATA)
            serialized_metadata = data.get(FIELD_METADATA)

            if not type_ or not serialized_checkpoint or not serialized_metadata:
                logger.error(f"Invalid checkpoint data for ID: {doc.id}")
                return None

            # 重新組合類型和序列化數據,以符合 loads_typed 的期望
            checkpoint = self.serde.loads_typed((type_, serialized_checkpoint))
            metadata = self.serde.loads(serialized_metadata)

            # Load pending writes from the subcollection
            pending_writes = await self._aload_pending_writes(doc.reference)

            return CheckpointTuple(
                config=config,
                checkpoint=checkpoint,
                metadata=metadata,
                pending_writes=pending_writes if pending_writes else None,
            )
        except Exception as e:
            logger.error(f"Error retrieving checkpoint tuple: {e}", exc_info=True)
            raise

    async def alist(
        self,
        config: Optional[RunnableConfig],
        *,
        filter: Optional[dict[str, Any]] = None,
        before: Optional[RunnableConfig] = None,
        limit: Optional[int] = None,
    ) -> AsyncGenerator[CheckpointTuple, None]:
        """List checkpoints from Firestore asynchronously.

        This method retrieves a list of checkpoint tuples from Firestore based
        on the provided config.

        Args:
            config: Base configuration for filtering checkpoints.
            filter: Additional filtering criteria for metadata.
            before: If provided, only checkpoints before the specified checkpoint ID are returned.
            limit: Maximum number of checkpoints to return.

        Yields:
            AsyncGenerator[CheckpointTuple, None]: An async generator of matching checkpoint tuples.
        """
        if not config:
            logger.error("Config is required for listing checkpoints")
            return

        thread_id = config["configurable"]["thread_id"]
        checkpoint_ns = config["configurable"].get("checkpoint_ns", "")

        try:
            t1 = time.time()
            # Build the query
            query = (
                self.checkpoints_collection.where(
                    filter=FieldFilter(FIELD_THREAD_ID, "==", thread_id)
                )
                .where(filter=FieldFilter(FIELD_CHECKPOINT_NS, "==", checkpoint_ns))
                .order_by(FIELD_TIMESTAMP, direction=firestore.Query.DESCENDING)
            )

            # Apply additional filters
            if before is not None:
                before_id = get_checkpoint_id(before)
                # We need to find the timestamp of the 'before' checkpoint to filter correctly
                before_doc_id = self._get_checkpoint_doc_id(
                    thread_id, checkpoint_ns, before_id
                )
                before_doc = await self.checkpoints_collection.document(
                    before_doc_id
                ).get()

                if before_doc.exists:
                    before_timestamp = before_doc.get(FIELD_TIMESTAMP)
                    if before_timestamp:
                        query = query.where(FIELD_TIMESTAMP, "<", before_timestamp)

            # Apply limit if provided
            if limit is not None:
                query = query.limit(limit)

            # Execute the query
            docs = await query.get()

            # Process each document
            for doc in docs:
                data = doc.to_dict()

                if not data or FIELD_DATA not in data or FIELD_METADATA not in data:
                    continue

                # Extract basic information
                thread_id = data.get(FIELD_THREAD_ID)
                checkpoint_ns = data.get(FIELD_CHECKPOINT_NS)
                checkpoint_id = data.get(FIELD_CHECKPOINT_ID)

                # Build config for this checkpoint
                checkpoint_config = {
                    "configurable": {
                        "thread_id": thread_id,
                        "checkpoint_ns": checkpoint_ns,
                        "checkpoint_id": checkpoint_id,
                    }
                }

                # Parse checkpoint data
                type_ = data.get(FIELD_TYPE)
                serialized_checkpoint = data.get(FIELD_DATA)
                serialized_metadata = data.get(FIELD_METADATA)

                if not type_ or not serialized_checkpoint:
                    continue

                # 重新組合類型和序列化數據,以符合 loads_typed 的期望
                checkpoint = self.serde.loads_typed((type_, serialized_checkpoint))
                metadata = (
                    self.serde.loads(serialized_metadata)
                    if serialized_metadata
                    else None
                )

                # Load pending writes
                pending_writes = await self._aload_pending_writes(doc.reference)

                yield CheckpointTuple(
                    config=checkpoint_config,
                    checkpoint=checkpoint,
                    metadata=metadata,
                    pending_writes=pending_writes if pending_writes else None,
                )
        except Exception as e:
            logger.error(f"Error listing checkpoints: {e}", exc_info=True)
            raise
        t2 = time.time()
        print(f"[AsyncFirestoreCheckpointer:alist] Elapsed {t2 - t1:.3f}s")

    async def _aload_pending_writes(
        self, checkpoint_doc_ref: firestore.AsyncDocumentReference
    ) -> List[Tuple[str, Any, None]]:
        """Load pending writes for a checkpoint from its subcollection.

        Returns a flat list of PendingWrite tuples (channel, value, None) similar to Redis implementation.
        """
        try:
            # Get the writes subcollection
            writes_collection = self._get_writes_subcollection(checkpoint_doc_ref)

            # Query all writes documents in the subcollection
            docs = await writes_collection.get()

            # Process the documents to extract writes
            result = []

            for doc in docs:
                data = doc.to_dict()

                if not data:
                    continue

                task_id = data.get(FIELD_TASK_ID)
                channel = data.get(FIELD_CHANNEL)
                type_ = data.get(FIELD_TYPE)
                serialized_value = data.get(FIELD_VALUE)

                if not task_id or not channel or not type_ or not serialized_value:
                    continue

                # 重新組合類型和序列化數據,以符合 loads_typed 的期望
                value = self.serde.loads_typed((type_, serialized_value))

                # Create a proper tuple according to PendingWrite definition (channel, value, None)
                # Following the Redis implementation pattern
                result.append((channel, value, None))

            return result
        except Exception as e:
            logger.error(f"Error loading pending writes: {e}", exc_info=True)
            return []

    async def adelete_thread(self, thread_id: str) -> None:
        """Delete all checkpoints and writes for a specific thread asynchronously.

        This method removes all data associated with a thread, including:
        - All checkpoint documents that match the thread_id
        - All writes subcollections under those checkpoints

        Args:
            thread_id: The thread ID for which to delete all checkpoints and writes.
        """
        try:
            logger.info(f"Starting deletion of all data for thread: {thread_id}")

            # Query all checkpoint documents for this thread_id
            # We need to delete across all checkpoint namespaces
            query = self.checkpoints_collection.where(
                filter=FieldFilter(FIELD_THREAD_ID, "==", thread_id)
            )

            # Get all matching checkpoint documents
            docs = await query.get()

            if not docs:
                logger.info(f"No checkpoints found for thread: {thread_id}")
                return

            deleted_checkpoints = 0
            deleted_writes = 0
            total_operations = 0
            batch_count = 0

            # Use smaller batches to avoid "Transaction too big" error
            batch = self.db.batch()
            batch_size = 0
            max_batch_size = 200  # Conservative batch size

            async def commit_current_batch():
                """Commit the current batch if it has operations"""
                nonlocal batch, batch_size, total_operations, batch_count
                if batch_size > 0:
                    await batch.commit()
                    total_operations += batch_size
                    batch_count += 1
                    logger.info(
                        f"Thread {thread_id}: Committed batch {batch_count} "
                        f"({batch_size} operations, total: {total_operations})"
                    )
                    batch = self.db.batch()
                    batch_size = 0

            for doc in docs:
                # Delete writes subcollection first
                writes_collection = self._get_writes_subcollection(doc.reference)

                # Get all writes documents in the subcollection
                writes_docs = await writes_collection.get()

                # Add writes deletion to batch
                for write_doc in writes_docs:
                    batch.delete(write_doc.reference)
                    batch_size += 1
                    deleted_writes += 1

                    # Commit batch when it reaches max size
                    if batch_size >= max_batch_size:
                        await commit_current_batch()

                # Add checkpoint document deletion to batch
                batch.delete(doc.reference)
                batch_size += 1
                deleted_checkpoints += 1

                # Commit batch when it reaches max size
                if batch_size >= max_batch_size:
                    await commit_current_batch()

            # Commit any remaining operations in the final batch
            await commit_current_batch()

            logger.info(
                f"Successfully deleted thread {thread_id}: "
                f"{deleted_checkpoints} checkpoints, {deleted_writes} writes "
                f"(total: {total_operations} operations in {batch_count} batches)"
            )

        except Exception as e:
            logger.error(f"Error deleting thread {thread_id}: {e}", exc_info=True)
            raise