botrun-flow-lang 5.12.263__py3-none-any.whl → 5.12.264__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. botrun_flow_lang/api/auth_api.py +39 -39
  2. botrun_flow_lang/api/auth_utils.py +183 -183
  3. botrun_flow_lang/api/botrun_back_api.py +65 -65
  4. botrun_flow_lang/api/flow_api.py +3 -3
  5. botrun_flow_lang/api/hatch_api.py +508 -508
  6. botrun_flow_lang/api/langgraph_api.py +811 -811
  7. botrun_flow_lang/api/line_bot_api.py +1484 -1484
  8. botrun_flow_lang/api/model_api.py +300 -300
  9. botrun_flow_lang/api/rate_limit_api.py +32 -32
  10. botrun_flow_lang/api/routes.py +79 -79
  11. botrun_flow_lang/api/search_api.py +53 -53
  12. botrun_flow_lang/api/storage_api.py +395 -395
  13. botrun_flow_lang/api/subsidy_api.py +290 -290
  14. botrun_flow_lang/api/subsidy_api_system_prompt.txt +109 -109
  15. botrun_flow_lang/api/user_setting_api.py +70 -70
  16. botrun_flow_lang/api/version_api.py +31 -31
  17. botrun_flow_lang/api/youtube_api.py +26 -26
  18. botrun_flow_lang/constants.py +13 -13
  19. botrun_flow_lang/langgraph_agents/agents/agent_runner.py +178 -178
  20. botrun_flow_lang/langgraph_agents/agents/agent_tools/step_planner.py +77 -77
  21. botrun_flow_lang/langgraph_agents/agents/checkpointer/firestore_checkpointer.py +666 -666
  22. botrun_flow_lang/langgraph_agents/agents/gov_researcher/GOV_RESEARCHER_PRD.md +192 -192
  23. botrun_flow_lang/langgraph_agents/agents/gov_researcher/gemini_subsidy_graph.py +460 -460
  24. botrun_flow_lang/langgraph_agents/agents/gov_researcher/gov_researcher_2_graph.py +1002 -1002
  25. botrun_flow_lang/langgraph_agents/agents/gov_researcher/gov_researcher_graph.py +822 -822
  26. botrun_flow_lang/langgraph_agents/agents/langgraph_react_agent.py +723 -723
  27. botrun_flow_lang/langgraph_agents/agents/search_agent_graph.py +864 -864
  28. botrun_flow_lang/langgraph_agents/agents/tools/__init__.py +4 -4
  29. botrun_flow_lang/langgraph_agents/agents/tools/gemini_code_execution.py +376 -376
  30. botrun_flow_lang/langgraph_agents/agents/util/gemini_grounding.py +66 -66
  31. botrun_flow_lang/langgraph_agents/agents/util/html_util.py +316 -316
  32. botrun_flow_lang/langgraph_agents/agents/util/img_util.py +294 -294
  33. botrun_flow_lang/langgraph_agents/agents/util/local_files.py +419 -419
  34. botrun_flow_lang/langgraph_agents/agents/util/mermaid_util.py +86 -86
  35. botrun_flow_lang/langgraph_agents/agents/util/model_utils.py +143 -143
  36. botrun_flow_lang/langgraph_agents/agents/util/pdf_analyzer.py +486 -486
  37. botrun_flow_lang/langgraph_agents/agents/util/pdf_cache.py +250 -250
  38. botrun_flow_lang/langgraph_agents/agents/util/pdf_processor.py +204 -204
  39. botrun_flow_lang/langgraph_agents/agents/util/perplexity_search.py +464 -464
  40. botrun_flow_lang/langgraph_agents/agents/util/plotly_util.py +59 -59
  41. botrun_flow_lang/langgraph_agents/agents/util/tavily_search.py +199 -199
  42. botrun_flow_lang/langgraph_agents/agents/util/youtube_util.py +90 -90
  43. botrun_flow_lang/langgraph_agents/cache/langgraph_botrun_cache.py +197 -197
  44. botrun_flow_lang/llm_agent/llm_agent.py +19 -19
  45. botrun_flow_lang/llm_agent/llm_agent_util.py +83 -83
  46. botrun_flow_lang/log/.gitignore +2 -2
  47. botrun_flow_lang/main.py +61 -61
  48. botrun_flow_lang/main_fast.py +51 -51
  49. botrun_flow_lang/mcp_server/__init__.py +10 -10
  50. botrun_flow_lang/mcp_server/default_mcp.py +744 -744
  51. botrun_flow_lang/models/nodes/utils.py +205 -205
  52. botrun_flow_lang/models/token_usage.py +34 -34
  53. botrun_flow_lang/requirements.txt +21 -21
  54. botrun_flow_lang/services/base/firestore_base.py +30 -30
  55. botrun_flow_lang/services/hatch/hatch_factory.py +11 -11
  56. botrun_flow_lang/services/hatch/hatch_fs_store.py +419 -419
  57. botrun_flow_lang/services/storage/storage_cs_store.py +206 -206
  58. botrun_flow_lang/services/storage/storage_factory.py +12 -12
  59. botrun_flow_lang/services/storage/storage_store.py +65 -65
  60. botrun_flow_lang/services/user_setting/user_setting_factory.py +9 -9
  61. botrun_flow_lang/services/user_setting/user_setting_fs_store.py +66 -66
  62. botrun_flow_lang/static/docs/tools/index.html +926 -926
  63. botrun_flow_lang/tests/api_functional_tests.py +1525 -1525
  64. botrun_flow_lang/tests/api_stress_test.py +357 -357
  65. botrun_flow_lang/tests/shared_hatch_tests.py +333 -333
  66. botrun_flow_lang/tests/test_botrun_app.py +46 -46
  67. botrun_flow_lang/tests/test_html_util.py +31 -31
  68. botrun_flow_lang/tests/test_img_analyzer.py +190 -190
  69. botrun_flow_lang/tests/test_img_util.py +39 -39
  70. botrun_flow_lang/tests/test_local_files.py +114 -114
  71. botrun_flow_lang/tests/test_mermaid_util.py +103 -103
  72. botrun_flow_lang/tests/test_pdf_analyzer.py +104 -104
  73. botrun_flow_lang/tests/test_plotly_util.py +151 -151
  74. botrun_flow_lang/tests/test_run_workflow_engine.py +65 -65
  75. botrun_flow_lang/tools/generate_docs.py +133 -133
  76. botrun_flow_lang/tools/templates/tools.html +153 -153
  77. botrun_flow_lang/utils/__init__.py +7 -7
  78. botrun_flow_lang/utils/botrun_logger.py +344 -344
  79. botrun_flow_lang/utils/clients/rate_limit_client.py +209 -209
  80. botrun_flow_lang/utils/clients/token_verify_client.py +153 -153
  81. botrun_flow_lang/utils/google_drive_utils.py +654 -654
  82. botrun_flow_lang/utils/langchain_utils.py +324 -324
  83. botrun_flow_lang/utils/yaml_utils.py +9 -9
  84. {botrun_flow_lang-5.12.263.dist-info → botrun_flow_lang-5.12.264.dist-info}/METADATA +1 -1
  85. botrun_flow_lang-5.12.264.dist-info/RECORD +102 -0
  86. botrun_flow_lang-5.12.263.dist-info/RECORD +0 -102
  87. {botrun_flow_lang-5.12.263.dist-info → botrun_flow_lang-5.12.264.dist-info}/WHEEL +0 -0
@@ -1,666 +1,666 @@
from typing import (
    Any,
    Dict,
    List,
    Optional,
    Tuple,
    AsyncIterator,
    Iterator,
    cast,
    AsyncGenerator,
)
import logging
from datetime import datetime
import os
import asyncio
from dotenv import load_dotenv

from google.cloud import firestore
from google.cloud.firestore_v1.base_query import FieldFilter

from google.cloud.exceptions import GoogleCloudError
from google.oauth2 import service_account

from langgraph.checkpoint.base import (
    BaseCheckpointSaver,
    Checkpoint,
    CheckpointMetadata,
    CheckpointTuple,
    PendingWrite,  # Note: PendingWrite is actually Tuple[str, Any, Any]
    get_checkpoint_id,
    WRITES_IDX_MAP,
    ChannelVersions,
)
from langgraph.checkpoint.serde.base import SerializerProtocol
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
from langgraph.pregel.types import StateSnapshot
from langchain_core.runnables import RunnableConfig

from botrun_flow_lang.constants import CHECKPOINTER_STORE_NAME
from botrun_flow_lang.services.base.firestore_base import FirestoreBase
import time

load_dotenv()

# Set up logger
logger = logging.getLogger("AsyncFirestoreCheckpointer")
# Read the log level from the environment variable; defaults to WARNING (INFO-level logs are not shown)
log_level = os.getenv("FIRESTORE_CHECKPOINTER_LOG_LEVEL", "WARNING").upper()
log_level_map = {
    "DEBUG": logging.DEBUG,
    "INFO": logging.INFO,
    "WARNING": logging.WARNING,
    "ERROR": logging.ERROR,
    "CRITICAL": logging.CRITICAL,
}
logger.setLevel(log_level_map.get(log_level, logging.WARNING))
# Create console handler if it doesn't exist
if not logger.handlers:
    ch = logging.StreamHandler()
    ch.setLevel(log_level_map.get(log_level, logging.WARNING))
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    ch.setFormatter(formatter)
    logger.addHandler(ch)

# Constants for field names
FIELD_THREAD_ID = "thread_id"
FIELD_CHECKPOINT_NS = "checkpoint_ns"
FIELD_CHECKPOINT_ID = "checkpoint_id"
FIELD_PARENT_CHECKPOINT_ID = "parent_checkpoint_id"
FIELD_TASK_ID = "task_id"
FIELD_IDX = "idx"
FIELD_TIMESTAMP = "timestamp"
FIELD_TYPE = "type"
FIELD_DATA = "data"
FIELD_METADATA = "metadata"
FIELD_NEW_VERSIONS = "new_versions"
FIELD_CHANNEL = "channel"
FIELD_VALUE = "value"
FIELD_CREATED_AT = "created_at"
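
# Illustrative note (not part of the packaged module): the logger level is driven entirely by
# the FIRESTORE_CHECKPOINTER_LOG_LEVEL variable read above, so a deployment that wants to see
# the checkpointer's INFO logs can simply export, e.g.:
#
#     export FIRESTORE_CHECKPOINTER_LOG_LEVEL=INFO
#
# before starting the service (the exact launch command is deployment-specific).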


class AsyncFirestoreCheckpointer(BaseCheckpointSaver):
    """Async Firestore-based checkpoint saver implementation.

    This implementation uses Firestore's collections and sub-collections to efficiently
    store and retrieve checkpoints and their associated writes.

    For each environment, it creates:
    - A root collection for all checkpoints
    - A sub-collection for each checkpoint's writes

    This design provides:
    - Efficient querying by thread_id, namespace, and checkpoint_id
    - Hierarchical structure that matches the data relationships
    - Improved query performance with proper indexing
    """

    db: firestore.AsyncClient
    checkpoints_collection: firestore.AsyncCollectionReference

    def __init__(
        self,
        env_name: str,
        serializer: Optional[SerializerProtocol] = None,
        collection_name: Optional[str] = None,
    ):
        """Initialize the AsyncFirestoreCheckpointer.

        Args:
            env_name: Environment name to be used as prefix for collection.
            serializer: Optional serializer to use for converting values to storable format.
            collection_name: Optional custom collection name. If not provided,
                it will use {env_name}-{CHECKPOINTER_STORE_NAME}.
        """
        super().__init__()
        logger.info(f"Initializing AsyncFirestoreCheckpointer with env_name={env_name}")
        self.serde = serializer or JsonPlusSerializer()
        self._collection_name = (
            collection_name or f"{env_name}-{CHECKPOINTER_STORE_NAME}"
        )
        logger.info(f"Using collection: {self._collection_name}")

        try:
            # Initialize async Firestore client
            google_service_account_key_path = os.getenv(
                "GOOGLE_APPLICATION_CREDENTIALS_FOR_FASTAPI",
                "/app/keys/scoop-386004-d22d99a7afd9.json",
            )
            credentials = service_account.Credentials.from_service_account_file(
                google_service_account_key_path,
                scopes=["https://www.googleapis.com/auth/datastore"],
            )

            project_id = os.getenv("GOOGLE_CLOUD_PROJECT")
            if project_id:
                self.db = firestore.AsyncClient(
                    project=project_id, credentials=credentials
                )
            else:
                self.db = firestore.AsyncClient(credentials=credentials)

            self.checkpoints_collection = self.db.collection(self._collection_name)
            logger.info("Async Firestore client initialized successfully")
        except Exception as e:
            logger.error(f"Error initializing Firestore client: {e}", exc_info=True)
            raise

    async def close(self):
        """Close the Firestore client connection."""
        if hasattr(self, "db") and self.db:
            await self.db.close()
            logger.info("Firestore client connection closed")

    async def __aenter__(self):
        """Context manager entry."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit with cleanup."""
        await self.close()
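
    # Illustrative usage sketch (not part of the packaged module). A minimal example of
    # constructing the checkpointer and attaching it to a LangGraph graph; the env name
    # "dev", the graph builder, and the thread id are assumptions for illustration only.
    #
    #     async def run_once(graph_builder, user_input):
    #         async with AsyncFirestoreCheckpointer(env_name="dev") as checkpointer:
    #             graph = graph_builder.compile(checkpointer=checkpointer)
    #             config = {"configurable": {"thread_id": "user-123"}}
    #             return await graph.ainvoke(user_input, config)
    #
    # The `async with` form ensures close() is awaited, so the underlying Firestore
    # AsyncClient connection is released when the work is done.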

    def _get_checkpoint_doc_id(
        self, thread_id: str, checkpoint_ns: str, checkpoint_id: str
    ) -> str:
        """Generate a document ID for a checkpoint.

        For maximum Firestore efficiency, we use a compound ID that naturally clusters
        related data together for efficient retrieval.
        """
        return f"{thread_id}:{checkpoint_ns}:{checkpoint_id}"

    def _get_writes_subcollection(
        self, checkpoint_doc_ref: firestore.AsyncDocumentReference
    ) -> firestore.AsyncCollectionReference:
        """Get the subcollection reference for checkpoint writes."""
        return checkpoint_doc_ref.collection("writes")

    def _parse_checkpoint_doc_id(self, doc_id: str) -> Dict[str, str]:
        """Parse a checkpoint document ID into its components."""
        parts = doc_id.split(":")
        if len(parts) != 3:
            raise ValueError(f"Invalid checkpoint document ID format: {doc_id}")

        return {
            FIELD_THREAD_ID: parts[0],
            FIELD_CHECKPOINT_NS: parts[1],
            FIELD_CHECKPOINT_ID: parts[2],
        }
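
    # Illustrative sketch (not part of the packaged module): the compound document ID is a
    # plain ":"-joined string, so it round-trips through the two helpers above. The values
    # below are hypothetical (`cp` is some AsyncFirestoreCheckpointer instance).
    #
    #     doc_id = cp._get_checkpoint_doc_id("thread-1", "", "1ef4f797-8335-6428-8001-8a1503f9b875")
    #     # -> "thread-1::1ef4f797-8335-6428-8001-8a1503f9b875"
    #     cp._parse_checkpoint_doc_id(doc_id)
    #     # -> {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "1ef4f797-8335-6428-8001-8a1503f9b875"}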

    async def aput(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> RunnableConfig:
        """Save a checkpoint to Firestore asynchronously.

        This method saves a checkpoint to Firestore as a document with fields for
        efficient querying.

        Args:
            config: The config to associate with the checkpoint.
            checkpoint: The checkpoint to save.
            metadata: Additional metadata to save with the checkpoint.
            new_versions: New channel versions as of this write.

        Returns:
            RunnableConfig: Updated configuration after storing the checkpoint.
        """
        thread_id = config["configurable"]["thread_id"]
        checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
        checkpoint_id = checkpoint["id"]
        parent_checkpoint_id = config["configurable"].get("checkpoint_id", "")

        # Generate document ID for efficient querying
        doc_id = self._get_checkpoint_doc_id(thread_id, checkpoint_ns, checkpoint_id)

        # Serialize the data
        type_, serialized_checkpoint = self.serde.dumps_typed(checkpoint)
        serialized_metadata = self.serde.dumps(metadata)

        # Prepare the document data
        data = {
            FIELD_THREAD_ID: thread_id,
            FIELD_CHECKPOINT_NS: checkpoint_ns,
            FIELD_CHECKPOINT_ID: checkpoint_id,
            FIELD_PARENT_CHECKPOINT_ID: parent_checkpoint_id,
            FIELD_TYPE: type_,
            FIELD_DATA: serialized_checkpoint,
            FIELD_METADATA: serialized_metadata,
            FIELD_TIMESTAMP: firestore.SERVER_TIMESTAMP,  # Use server timestamp for consistency
            FIELD_CREATED_AT: datetime.utcnow().isoformat(),  # Backup client-side timestamp
        }

        if new_versions:
            data[FIELD_NEW_VERSIONS] = self.serde.dumps(new_versions)

        try:
            await self.checkpoints_collection.document(doc_id).set(data)
            logger.info(f"Successfully stored checkpoint with ID: {doc_id}")
        except Exception as e:
            logger.error(f"Error storing checkpoint: {e}", exc_info=True)
            raise

        return {
            "configurable": {
                "thread_id": thread_id,
                "checkpoint_ns": checkpoint_ns,
                "checkpoint_id": checkpoint_id,
            }
        }
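
    # Illustrative sketch (not part of the packaged module): given the field constants above,
    # a checkpoint document written by aput() is shaped roughly as follows. All values are
    # hypothetical; "data" and "metadata" hold whatever the configured serde produced.
    #
    #     {
    #         "thread_id": "user-123",
    #         "checkpoint_ns": "",
    #         "checkpoint_id": "1ef4f797-8335-6428-8001-8a1503f9b875",
    #         "parent_checkpoint_id": "",
    #         "type": "...",                     # serde type tag from dumps_typed()
    #         "data": b"...",                    # serialized checkpoint
    #         "metadata": b"...",                # serialized metadata
    #         "timestamp": <server timestamp>,
    #         "created_at": "2025-01-01T00:00:00",
    #         "new_versions": b"...",            # only present when new_versions is truthy
    #     }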

    async def aput_writes(
        self,
        config: RunnableConfig,
        writes: List[Tuple[str, Any]],
        task_id: str,
    ) -> None:
        """Store intermediate writes linked to a checkpoint asynchronously.

        This method saves intermediate writes associated with a checkpoint in a subcollection.

        Args:
            config: Configuration of the related checkpoint.
            writes: List of writes to store, each as (channel, value) pair.
            task_id: Identifier for the task creating the writes.
        """
        thread_id = config["configurable"]["thread_id"]
        checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
        checkpoint_id = config["configurable"]["checkpoint_id"]

        # Get the checkpoint document reference
        checkpoint_doc_id = self._get_checkpoint_doc_id(
            thread_id, checkpoint_ns, checkpoint_id
        )
        checkpoint_doc_ref = self.checkpoints_collection.document(checkpoint_doc_id)

        # Get the writes subcollection
        writes_collection = self._get_writes_subcollection(checkpoint_doc_ref)

        try:
            # Optimize write operations with batching
            batch = self.db.batch()
            batch_size = 0
            max_batch_size = 450  # Slightly below Firestore limit for safety
            batch_futures = []  # For tracking concurrent batch commits

            for idx, (channel, value) in enumerate(writes):
                # Determine the write ID
                write_idx = WRITES_IDX_MAP.get(channel, idx)
                write_id = f"{task_id}:{write_idx}"

                # Serialize the value
                type_, serialized_value = self.serde.dumps_typed(value)

                # Prepare the write data
                data = {
                    FIELD_TASK_ID: task_id,
                    FIELD_IDX: write_idx,
                    FIELD_CHANNEL: channel,
                    FIELD_TYPE: type_,
                    FIELD_VALUE: serialized_value,
                    FIELD_TIMESTAMP: firestore.SERVER_TIMESTAMP,
                    FIELD_CREATED_AT: datetime.utcnow().isoformat(),
                }

                write_doc_ref = writes_collection.document(write_id)

                # Determine if we should set or create-if-not-exists
                if channel in WRITES_IDX_MAP:
                    # For indexed channels, always set (similar to HSET behavior)
                    batch.set(write_doc_ref, data)
                else:
                    # For non-indexed channels, we need a transaction to check existence
                    # We'll check existence manually for now
                    doc = await write_doc_ref.get()
                    if not doc.exists:
                        batch.set(write_doc_ref, data)

                batch_size += 1

                # If batch is getting full, submit it and start a new one
                if batch_size >= max_batch_size:
                    batch_futures.append(batch.commit())
                    batch = self.db.batch()
                    batch_size = 0

            # Commit any remaining writes in the batch
            if batch_size > 0:
                batch_futures.append(batch.commit())

            # Wait for all batch operations to complete
            if batch_futures:
                await asyncio.gather(*batch_futures)

            logger.info(
                f"Successfully stored {len(writes)} writes for checkpoint: {checkpoint_id}"
            )
        except Exception as e:
            logger.error(f"Error storing writes: {e}", exc_info=True)
            raise
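
    # Illustrative sketch (not part of the packaged module) of the batching pattern used in
    # aput_writes(): Firestore batched writes have a per-batch operation limit, so the method
    # chunks at 450, collects each full chunk's commit() coroutine, and awaits all commits
    # together. A generic version of the same idea, with hypothetical names:
    #
    #     async def commit_in_chunks(db, doc_payloads, chunk_size=450):
    #         pending, batch, count = [], db.batch(), 0
    #         for ref, payload in doc_payloads:
    #             batch.set(ref, payload)
    #             count += 1
    #             if count >= chunk_size:
    #                 pending.append(batch.commit())
    #                 batch, count = db.batch(), 0
    #         if count:
    #             pending.append(batch.commit())
    #         await asyncio.gather(*pending)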

    async def aget_tuple(
        self,
        config: RunnableConfig,
    ) -> Optional[CheckpointTuple]:
        """Get a checkpoint tuple from Firestore asynchronously.

        This method retrieves a checkpoint and its associated writes from Firestore.

        Args:
            config: The config to use for retrieving the checkpoint.

        Returns:
            Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if not found.
        """
        thread_id = config["configurable"]["thread_id"]
        checkpoint_id = get_checkpoint_id(config)
        checkpoint_ns = config["configurable"].get("checkpoint_ns", "")

        try:
            # If checkpoint_id is provided, get that specific checkpoint
            if checkpoint_id:
                doc_id = self._get_checkpoint_doc_id(
                    thread_id, checkpoint_ns, checkpoint_id
                )
                doc = await self.checkpoints_collection.document(doc_id).get()

                if not doc.exists:
                    return None
            else:
                # Otherwise, find the latest checkpoint
                query = (
                    self.checkpoints_collection.where(
                        filter=FieldFilter(FIELD_THREAD_ID, "==", thread_id)
                    )
                    .where(filter=FieldFilter(FIELD_CHECKPOINT_NS, "==", checkpoint_ns))
                    .order_by(FIELD_TIMESTAMP, direction=firestore.Query.DESCENDING)
                    .limit(1)
                )

                docs = await query.get()
                if not docs:
                    return None

                doc = docs[0]
                # Extract the checkpoint_id for loading writes
                checkpoint_id = doc.get(FIELD_CHECKPOINT_ID)

            data = doc.to_dict()

            # Parse the document data
            type_ = data.get(FIELD_TYPE)
            serialized_checkpoint = data.get(FIELD_DATA)
            serialized_metadata = data.get(FIELD_METADATA)

            if not type_ or not serialized_checkpoint or not serialized_metadata:
                logger.error(f"Invalid checkpoint data for ID: {doc.id}")
                return None

            # Recombine the type tag and serialized payload into the form loads_typed expects
            checkpoint = self.serde.loads_typed((type_, serialized_checkpoint))
            metadata = self.serde.loads(serialized_metadata)

            # Load pending writes from the subcollection
            pending_writes = await self._aload_pending_writes(doc.reference)

            return CheckpointTuple(
                config=config,
                checkpoint=checkpoint,
                metadata=metadata,
                pending_writes=pending_writes if pending_writes else None,
            )
        except Exception as e:
            logger.error(f"Error retrieving checkpoint tuple: {e}", exc_info=True)
            raise
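
    # Illustrative sketch (not part of the packaged module): fetching checkpoints with
    # aget_tuple(). The thread and checkpoint ids are hypothetical; `cp` is some
    # AsyncFirestoreCheckpointer instance.
    #
    #     # Latest checkpoint for a thread (no checkpoint_id in the config):
    #     latest = await cp.aget_tuple({"configurable": {"thread_id": "user-123"}})
    #
    #     # A specific checkpoint:
    #     specific = await cp.aget_tuple({
    #         "configurable": {
    #             "thread_id": "user-123",
    #             "checkpoint_ns": "",
    #             "checkpoint_id": "1ef4f797-8335-6428-8001-8a1503f9b875",
    #         }
    #     })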

    async def alist(
        self,
        config: Optional[RunnableConfig],
        *,
        filter: Optional[dict[str, Any]] = None,
        before: Optional[RunnableConfig] = None,
        limit: Optional[int] = None,
    ) -> AsyncGenerator[CheckpointTuple, None]:
        """List checkpoints from Firestore asynchronously.

        This method retrieves a list of checkpoint tuples from Firestore based
        on the provided config.

        Args:
            config: Base configuration for filtering checkpoints.
            filter: Additional filtering criteria for metadata.
            before: If provided, only checkpoints before the specified checkpoint ID are returned.
            limit: Maximum number of checkpoints to return.

        Yields:
            AsyncGenerator[CheckpointTuple, None]: An async generator of matching checkpoint tuples.
        """
        if not config:
            logger.error("Config is required for listing checkpoints")
            return

        thread_id = config["configurable"]["thread_id"]
        checkpoint_ns = config["configurable"].get("checkpoint_ns", "")

        try:
            t1 = time.time()
            # Build the query
            query = (
                self.checkpoints_collection.where(
                    filter=FieldFilter(FIELD_THREAD_ID, "==", thread_id)
                )
                .where(filter=FieldFilter(FIELD_CHECKPOINT_NS, "==", checkpoint_ns))
                .order_by(FIELD_TIMESTAMP, direction=firestore.Query.DESCENDING)
            )

            # Apply additional filters
            if before is not None:
                before_id = get_checkpoint_id(before)
                # We need to find the timestamp of the 'before' checkpoint to filter correctly
                before_doc_id = self._get_checkpoint_doc_id(
                    thread_id, checkpoint_ns, before_id
                )
                before_doc = await self.checkpoints_collection.document(
                    before_doc_id
                ).get()

                if before_doc.exists:
                    before_timestamp = before_doc.get(FIELD_TIMESTAMP)
                    if before_timestamp:
                        query = query.where(FIELD_TIMESTAMP, "<", before_timestamp)

            # Apply limit if provided
            if limit is not None:
                query = query.limit(limit)

            # Execute the query
            docs = await query.get()

            # Process each document
            for doc in docs:
                data = doc.to_dict()

                if not data or FIELD_DATA not in data or FIELD_METADATA not in data:
                    continue

                # Extract basic information
                thread_id = data.get(FIELD_THREAD_ID)
                checkpoint_ns = data.get(FIELD_CHECKPOINT_NS)
                checkpoint_id = data.get(FIELD_CHECKPOINT_ID)

                # Build config for this checkpoint
                checkpoint_config = {
                    "configurable": {
                        "thread_id": thread_id,
                        "checkpoint_ns": checkpoint_ns,
                        "checkpoint_id": checkpoint_id,
                    }
                }

                # Parse checkpoint data
                type_ = data.get(FIELD_TYPE)
                serialized_checkpoint = data.get(FIELD_DATA)
                serialized_metadata = data.get(FIELD_METADATA)

                if not type_ or not serialized_checkpoint:
                    continue

                # Recombine the type tag and serialized payload into the form loads_typed expects
                checkpoint = self.serde.loads_typed((type_, serialized_checkpoint))
                metadata = (
                    self.serde.loads(serialized_metadata)
                    if serialized_metadata
                    else None
                )

                # Load pending writes
                pending_writes = await self._aload_pending_writes(doc.reference)

                yield CheckpointTuple(
                    config=checkpoint_config,
                    checkpoint=checkpoint,
                    metadata=metadata,
                    pending_writes=pending_writes if pending_writes else None,
                )
        except Exception as e:
            logger.error(f"Error listing checkpoints: {e}", exc_info=True)
            raise
        t2 = time.time()
        print(f"[AsyncFirestoreCheckpointer:alist] Elapsed {t2 - t1:.3f}s")
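
    # Illustrative sketch (not part of the packaged module): consuming the async generator
    # returned by alist(). The thread id and limit are hypothetical; `cp` is some
    # AsyncFirestoreCheckpointer instance.
    #
    #     async for ckpt_tuple in cp.alist(
    #         {"configurable": {"thread_id": "user-123"}}, limit=10
    #     ):
    #         print(ckpt_tuple.config["configurable"]["checkpoint_id"])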

    async def _aload_pending_writes(
        self, checkpoint_doc_ref: firestore.AsyncDocumentReference
    ) -> List[Tuple[str, Any, None]]:
        """Load pending writes for a checkpoint from its subcollection.

        Returns a flat list of PendingWrite tuples (channel, value, None) similar to Redis implementation.
        """
        try:
            # Get the writes subcollection
            writes_collection = self._get_writes_subcollection(checkpoint_doc_ref)

            # Query all writes documents in the subcollection
            docs = await writes_collection.get()

            # Process the documents to extract writes
            result = []

            for doc in docs:
                data = doc.to_dict()

                if not data:
                    continue

                task_id = data.get(FIELD_TASK_ID)
                channel = data.get(FIELD_CHANNEL)
                type_ = data.get(FIELD_TYPE)
                serialized_value = data.get(FIELD_VALUE)

                if not task_id or not channel or not type_ or not serialized_value:
                    continue

                # Recombine the type tag and serialized payload into the form loads_typed expects
                value = self.serde.loads_typed((type_, serialized_value))

                # Create a proper tuple according to PendingWrite definition (channel, value, None)
                # Following the Redis implementation pattern
                result.append((channel, value, None))

            return result
        except Exception as e:
            logger.error(f"Error loading pending writes: {e}", exc_info=True)
            return []

    async def adelete_thread(self, thread_id: str) -> None:
        """Delete all checkpoints and writes for a specific thread asynchronously.

        This method removes all data associated with a thread, including:
        - All checkpoint documents that match the thread_id
        - All writes subcollections under those checkpoints

        Args:
            thread_id: The thread ID for which to delete all checkpoints and writes.
        """
        try:
            logger.info(f"Starting deletion of all data for thread: {thread_id}")

            # Query all checkpoint documents for this thread_id
            # We need to delete across all checkpoint namespaces
            query = self.checkpoints_collection.where(
                filter=FieldFilter(FIELD_THREAD_ID, "==", thread_id)
            )

            # Get all matching checkpoint documents
            docs = await query.get()

            if not docs:
                logger.info(f"No checkpoints found for thread: {thread_id}")
                return

            deleted_checkpoints = 0
            deleted_writes = 0
            total_operations = 0
            batch_count = 0

            # Use smaller batches to avoid "Transaction too big" error
            batch = self.db.batch()
            batch_size = 0
            max_batch_size = 200  # Conservative batch size

            async def commit_current_batch():
                """Commit the current batch if it has operations"""
                nonlocal batch, batch_size, total_operations, batch_count
                if batch_size > 0:
                    await batch.commit()
                    total_operations += batch_size
                    batch_count += 1
                    logger.info(
                        f"Thread {thread_id}: Committed batch {batch_count} "
                        f"({batch_size} operations, total: {total_operations})"
                    )
                    batch = self.db.batch()
                    batch_size = 0

            for doc in docs:
                # Delete writes subcollection first
                writes_collection = self._get_writes_subcollection(doc.reference)

                # Get all writes documents in the subcollection
                writes_docs = await writes_collection.get()

                # Add writes deletion to batch
                for write_doc in writes_docs:
                    batch.delete(write_doc.reference)
                    batch_size += 1
                    deleted_writes += 1

                    # Commit batch when it reaches max size
                    if batch_size >= max_batch_size:
                        await commit_current_batch()

                # Add checkpoint document deletion to batch
                batch.delete(doc.reference)
                batch_size += 1
                deleted_checkpoints += 1

                # Commit batch when it reaches max size
                if batch_size >= max_batch_size:
                    await commit_current_batch()

            # Commit any remaining operations in the final batch
            await commit_current_batch()

            logger.info(
                f"Successfully deleted thread {thread_id}: "
                f"{deleted_checkpoints} checkpoints, {deleted_writes} writes "
                f"(total: {total_operations} operations in {batch_count} batches)"
            )

        except Exception as e:
            logger.error(f"Error deleting thread {thread_id}: {e}", exc_info=True)
            raise
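

# --- Illustrative sketch, not part of the packaged module: wiping a conversation thread ---
# Assumes the same environment variables the class reads above (the service-account key path
# and, optionally, GOOGLE_CLOUD_PROJECT); the env name and thread id below are hypothetical.
async def _example_delete_thread(env_name: str = "dev", thread_id: str = "user-123") -> None:
    async with AsyncFirestoreCheckpointer(env_name=env_name) as checkpointer:
        # Removes every checkpoint document for the thread and the writes under each of them,
        # committing the deletes in small batches as implemented in adelete_thread().
        await checkpointer.adelete_thread(thread_id)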