langgraph-api 0.4.1__py3-none-any.whl → 0.7.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (135):
  1. langgraph_api/__init__.py +1 -1
  2. langgraph_api/api/__init__.py +111 -51
  3. langgraph_api/api/a2a.py +1610 -0
  4. langgraph_api/api/assistants.py +212 -89
  5. langgraph_api/api/mcp.py +3 -3
  6. langgraph_api/api/meta.py +52 -28
  7. langgraph_api/api/openapi.py +27 -17
  8. langgraph_api/api/profile.py +108 -0
  9. langgraph_api/api/runs.py +342 -195
  10. langgraph_api/api/store.py +19 -2
  11. langgraph_api/api/threads.py +209 -27
  12. langgraph_api/asgi_transport.py +14 -9
  13. langgraph_api/asyncio.py +14 -4
  14. langgraph_api/auth/custom.py +52 -37
  15. langgraph_api/auth/langsmith/backend.py +4 -3
  16. langgraph_api/auth/langsmith/client.py +13 -8
  17. langgraph_api/cli.py +230 -133
  18. langgraph_api/command.py +5 -3
  19. langgraph_api/config/__init__.py +532 -0
  20. langgraph_api/config/_parse.py +58 -0
  21. langgraph_api/config/schemas.py +431 -0
  22. langgraph_api/cron_scheduler.py +17 -1
  23. langgraph_api/encryption/__init__.py +15 -0
  24. langgraph_api/encryption/aes_json.py +158 -0
  25. langgraph_api/encryption/context.py +35 -0
  26. langgraph_api/encryption/custom.py +280 -0
  27. langgraph_api/encryption/middleware.py +632 -0
  28. langgraph_api/encryption/shared.py +63 -0
  29. langgraph_api/errors.py +12 -1
  30. langgraph_api/executor_entrypoint.py +11 -6
  31. langgraph_api/feature_flags.py +29 -0
  32. langgraph_api/graph.py +176 -76
  33. langgraph_api/grpc/client.py +313 -0
  34. langgraph_api/grpc/config_conversion.py +231 -0
  35. langgraph_api/grpc/generated/__init__.py +29 -0
  36. langgraph_api/grpc/generated/checkpointer_pb2.py +63 -0
  37. langgraph_api/grpc/generated/checkpointer_pb2.pyi +99 -0
  38. langgraph_api/grpc/generated/checkpointer_pb2_grpc.py +329 -0
  39. langgraph_api/grpc/generated/core_api_pb2.py +216 -0
  40. langgraph_api/grpc/generated/core_api_pb2.pyi +905 -0
  41. langgraph_api/grpc/generated/core_api_pb2_grpc.py +1621 -0
  42. langgraph_api/grpc/generated/engine_common_pb2.py +219 -0
  43. langgraph_api/grpc/generated/engine_common_pb2.pyi +722 -0
  44. langgraph_api/grpc/generated/engine_common_pb2_grpc.py +24 -0
  45. langgraph_api/grpc/generated/enum_cancel_run_action_pb2.py +37 -0
  46. langgraph_api/grpc/generated/enum_cancel_run_action_pb2.pyi +12 -0
  47. langgraph_api/grpc/generated/enum_cancel_run_action_pb2_grpc.py +24 -0
  48. langgraph_api/grpc/generated/enum_control_signal_pb2.py +37 -0
  49. langgraph_api/grpc/generated/enum_control_signal_pb2.pyi +16 -0
  50. langgraph_api/grpc/generated/enum_control_signal_pb2_grpc.py +24 -0
  51. langgraph_api/grpc/generated/enum_durability_pb2.py +37 -0
  52. langgraph_api/grpc/generated/enum_durability_pb2.pyi +16 -0
  53. langgraph_api/grpc/generated/enum_durability_pb2_grpc.py +24 -0
  54. langgraph_api/grpc/generated/enum_multitask_strategy_pb2.py +37 -0
  55. langgraph_api/grpc/generated/enum_multitask_strategy_pb2.pyi +16 -0
  56. langgraph_api/grpc/generated/enum_multitask_strategy_pb2_grpc.py +24 -0
  57. langgraph_api/grpc/generated/enum_run_status_pb2.py +37 -0
  58. langgraph_api/grpc/generated/enum_run_status_pb2.pyi +22 -0
  59. langgraph_api/grpc/generated/enum_run_status_pb2_grpc.py +24 -0
  60. langgraph_api/grpc/generated/enum_stream_mode_pb2.py +37 -0
  61. langgraph_api/grpc/generated/enum_stream_mode_pb2.pyi +28 -0
  62. langgraph_api/grpc/generated/enum_stream_mode_pb2_grpc.py +24 -0
  63. langgraph_api/grpc/generated/enum_thread_status_pb2.py +37 -0
  64. langgraph_api/grpc/generated/enum_thread_status_pb2.pyi +16 -0
  65. langgraph_api/grpc/generated/enum_thread_status_pb2_grpc.py +24 -0
  66. langgraph_api/grpc/generated/enum_thread_stream_mode_pb2.py +37 -0
  67. langgraph_api/grpc/generated/enum_thread_stream_mode_pb2.pyi +16 -0
  68. langgraph_api/grpc/generated/enum_thread_stream_mode_pb2_grpc.py +24 -0
  69. langgraph_api/grpc/generated/errors_pb2.py +39 -0
  70. langgraph_api/grpc/generated/errors_pb2.pyi +21 -0
  71. langgraph_api/grpc/generated/errors_pb2_grpc.py +24 -0
  72. langgraph_api/grpc/ops/__init__.py +370 -0
  73. langgraph_api/grpc/ops/assistants.py +424 -0
  74. langgraph_api/grpc/ops/runs.py +792 -0
  75. langgraph_api/grpc/ops/threads.py +1013 -0
  76. langgraph_api/http.py +16 -5
  77. langgraph_api/http_metrics.py +15 -35
  78. langgraph_api/http_metrics_utils.py +38 -0
  79. langgraph_api/js/build.mts +1 -1
  80. langgraph_api/js/client.http.mts +13 -7
  81. langgraph_api/js/client.mts +2 -5
  82. langgraph_api/js/package.json +29 -28
  83. langgraph_api/js/remote.py +56 -30
  84. langgraph_api/js/src/graph.mts +20 -0
  85. langgraph_api/js/sse.py +2 -2
  86. langgraph_api/js/ui.py +1 -1
  87. langgraph_api/js/yarn.lock +1204 -1006
  88. langgraph_api/logging.py +29 -2
  89. langgraph_api/metadata.py +99 -28
  90. langgraph_api/middleware/http_logger.py +7 -2
  91. langgraph_api/middleware/private_network.py +7 -7
  92. langgraph_api/models/run.py +54 -93
  93. langgraph_api/otel_context.py +205 -0
  94. langgraph_api/patch.py +5 -3
  95. langgraph_api/queue_entrypoint.py +154 -65
  96. langgraph_api/route.py +47 -5
  97. langgraph_api/schema.py +88 -10
  98. langgraph_api/self_hosted_logs.py +124 -0
  99. langgraph_api/self_hosted_metrics.py +450 -0
  100. langgraph_api/serde.py +79 -37
  101. langgraph_api/server.py +138 -60
  102. langgraph_api/state.py +4 -3
  103. langgraph_api/store.py +25 -16
  104. langgraph_api/stream.py +80 -29
  105. langgraph_api/thread_ttl.py +31 -13
  106. langgraph_api/timing/__init__.py +25 -0
  107. langgraph_api/timing/profiler.py +200 -0
  108. langgraph_api/timing/timer.py +318 -0
  109. langgraph_api/utils/__init__.py +53 -8
  110. langgraph_api/utils/cache.py +47 -10
  111. langgraph_api/utils/config.py +2 -1
  112. langgraph_api/utils/errors.py +77 -0
  113. langgraph_api/utils/future.py +10 -6
  114. langgraph_api/utils/headers.py +76 -2
  115. langgraph_api/utils/retriable_client.py +74 -0
  116. langgraph_api/utils/stream_codec.py +315 -0
  117. langgraph_api/utils/uuids.py +29 -62
  118. langgraph_api/validation.py +9 -0
  119. langgraph_api/webhook.py +120 -6
  120. langgraph_api/worker.py +55 -24
  121. {langgraph_api-0.4.1.dist-info → langgraph_api-0.7.3.dist-info}/METADATA +16 -8
  122. langgraph_api-0.7.3.dist-info/RECORD +168 -0
  123. {langgraph_api-0.4.1.dist-info → langgraph_api-0.7.3.dist-info}/WHEEL +1 -1
  124. langgraph_runtime/__init__.py +1 -0
  125. langgraph_runtime/routes.py +11 -0
  126. logging.json +1 -3
  127. openapi.json +839 -478
  128. langgraph_api/config.py +0 -387
  129. langgraph_api/js/isolate-0x130008000-46649-46649-v8.log +0 -4430
  130. langgraph_api/js/isolate-0x138008000-44681-44681-v8.log +0 -4430
  131. langgraph_api/js/package-lock.json +0 -3308
  132. langgraph_api-0.4.1.dist-info/RECORD +0 -107
  133. /langgraph_api/{utils.py → grpc/__init__.py} +0 -0
  134. {langgraph_api-0.4.1.dist-info → langgraph_api-0.7.3.dist-info}/entry_points.txt +0 -0
  135. {langgraph_api-0.4.1.dist-info → langgraph_api-0.7.3.dist-info}/licenses/LICENSE +0 -0
langgraph_api/api/store.py CHANGED
@@ -5,7 +5,13 @@ from starlette.responses import Response
5
5
  from starlette.routing import BaseRoute
6
6
 
7
7
  from langgraph_api.auth.custom import handle_event as _handle_event
8
+ from langgraph_api.encryption.middleware import (
9
+ decrypt_response,
10
+ decrypt_responses,
11
+ encrypt_request,
12
+ )
8
13
  from langgraph_api.route import ApiRequest, ApiResponse, ApiRoute
14
+ from langgraph_api.schema import STORE_ENCRYPTION_FIELDS
9
15
  from langgraph_api.store import get_store
10
16
  from langgraph_api.utils import get_auth_ctx
11
17
  from langgraph_api.validation import (
@@ -48,6 +54,7 @@ async def handle_event(
48
54
  async def put_item(request: ApiRequest):
49
55
  """Store or update an item."""
50
56
  payload = await request.json(StorePutRequest)
57
+ payload = await encrypt_request(payload, "store", STORE_ENCRYPTION_FIELDS)
51
58
  namespace = tuple(payload["namespace"]) if payload.get("namespace") else ()
52
59
  if err := _validate_namespace(namespace):
53
60
  return err
@@ -78,7 +85,11 @@ async def get_item(request: ApiRequest):
78
85
  }
79
86
  await handle_event("get", handler_payload)
80
87
  result = await (await get_store()).aget(namespace, key)
81
- return ApiResponse(result.dict() if result is not None else None)
88
+ if result is None:
89
+ return ApiResponse(None)
90
+ return ApiResponse(
91
+ await decrypt_response(result.dict(), "store", STORE_ENCRYPTION_FIELDS)
92
+ )
82
93
 
83
94
 
84
95
  @retry_db
@@ -125,7 +136,13 @@ async def search_items(request: ApiRequest):
125
136
  offset=handler_payload["offset"],
126
137
  query=handler_payload["query"],
127
138
  )
128
- return ApiResponse({"items": [item.dict() for item in items]})
139
+ return ApiResponse(
140
+ {
141
+ "items": await decrypt_responses(
142
+ [item.dict() for item in items], "store", STORE_ENCRYPTION_FIELDS
143
+ )
144
+ }
145
+ )
129
146
 
130
147
 
131
148
  @retry_db
langgraph_api/api/threads.py CHANGED
@@ -1,11 +1,23 @@
1
+ from typing import get_args
1
2
  from uuid import uuid4
2
3
 
3
4
  from starlette.exceptions import HTTPException
4
5
  from starlette.responses import Response
5
6
  from starlette.routing import BaseRoute
6
7
 
8
+ from langgraph_api.encryption.middleware import (
9
+ decrypt_response,
10
+ decrypt_responses,
11
+ encrypt_request,
12
+ )
13
+ from langgraph_api.feature_flags import FF_USE_CORE_API
14
+ from langgraph_api.grpc.ops import Threads as GrpcThreads
7
15
  from langgraph_api.route import ApiRequest, ApiResponse, ApiRoute
8
- from langgraph_api.schema import THREAD_FIELDS
16
+ from langgraph_api.schema import (
17
+ THREAD_ENCRYPTION_FIELDS,
18
+ THREAD_FIELDS,
19
+ ThreadStreamMode,
20
+ )
9
21
  from langgraph_api.sse import EventSourceResponse
10
22
  from langgraph_api.state import state_snapshot_to_thread_state
11
23
  from langgraph_api.utils import (
@@ -15,10 +27,12 @@ from langgraph_api.utils import (
15
27
  validate_stream_id,
16
28
  validate_uuid,
17
29
  )
30
+ from langgraph_api.utils.headers import get_configurable_headers
18
31
  from langgraph_api.validation import (
19
32
  ThreadCountRequest,
20
33
  ThreadCreate,
21
34
  ThreadPatch,
35
+ ThreadPruneRequest,
22
36
  ThreadSearchRequest,
23
37
  ThreadStateCheckpointRequest,
24
38
  ThreadStateSearch,
@@ -28,6 +42,8 @@ from langgraph_runtime.database import connect
28
42
  from langgraph_runtime.ops import Threads
29
43
  from langgraph_runtime.retry import retry_db
30
44
 
45
+ CrudThreads = GrpcThreads if FF_USE_CORE_API else Threads
46
+
31
47
 
32
48
  @retry_db
33
49
  async def create_thread(
@@ -37,28 +53,56 @@ async def create_thread(
37
53
  payload = await request.json(ThreadCreate)
38
54
  if thread_id := payload.get("thread_id"):
39
55
  validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
56
+
57
+ # Validate keep_latest TTL requires core API
58
+ ttl = payload.get("ttl")
59
+ if ttl and ttl.get("strategy") == "keep_latest" and not FF_USE_CORE_API:
60
+ raise HTTPException(
61
+ status_code=422,
62
+ detail="keep_latest TTL strategy requires FF_USE_CORE_API=true",
63
+ )
64
+
65
+ # Encrypt metadata before storing
66
+ encrypted_payload = await encrypt_request(
67
+ payload,
68
+ "thread",
69
+ ["metadata"],
70
+ )
71
+
40
72
  async with connect() as conn:
41
73
  thread_id = thread_id or str(uuid4())
42
- iter = await Threads.put(
74
+ iter = await CrudThreads.put(
43
75
  conn,
44
76
  thread_id,
45
- metadata=payload.get("metadata"),
77
+ metadata=encrypted_payload.get("metadata") or {},
46
78
  if_exists=payload.get("if_exists") or "raise",
47
79
  ttl=payload.get("ttl"),
48
80
  )
49
-
81
+ config = {
82
+ "configurable": {
83
+ **get_configurable_headers(request.headers),
84
+ "thread_id": thread_id,
85
+ }
86
+ }
50
87
  if supersteps := payload.get("supersteps"):
51
88
  try:
52
89
  await Threads.State.bulk(
53
90
  conn,
54
- config={"configurable": {"thread_id": thread_id}},
91
+ config=config,
55
92
  supersteps=supersteps,
56
93
  )
57
94
  except HTTPException as e:
58
95
  detail = f"Thread {thread_id} was created, but there were problems updating the state: {e.detail}"
59
96
  raise HTTPException(status_code=201, detail=detail) from e
60
97
 
61
- return ApiResponse(await fetchone(iter, not_found_code=409))
98
+ # Decrypt thread fields in response
99
+ thread = await fetchone(iter, not_found_code=409)
100
+ thread = await decrypt_response(
101
+ thread,
102
+ "thread",
103
+ THREAD_ENCRYPTION_FIELDS,
104
+ )
105
+ return ApiResponse(thread)
62
106
 
63
107
 
64
108
  @retry_db
@@ -70,12 +114,14 @@ async def search_threads(
70
114
  select = validate_select_columns(payload.get("select") or None, THREAD_FIELDS)
71
115
  limit = int(payload.get("limit") or 10)
72
116
  offset = int(payload.get("offset") or 0)
117
+
73
118
  async with connect() as conn:
74
- threads_iter, next_offset = await Threads.search(
119
+ threads_iter, next_offset = await CrudThreads.search(
75
120
  conn,
76
121
  status=payload.get("status"),
77
122
  values=payload.get("values"),
78
123
  metadata=payload.get("metadata"),
124
+ ids=payload.get("ids"),
79
125
  limit=limit,
80
126
  offset=offset,
81
127
  sort_by=payload.get("sort_by"),
@@ -85,7 +131,15 @@ async def search_threads(
85
131
  threads, response_headers = await get_pagination_headers(
86
132
  threads_iter, next_offset, offset
87
133
  )
88
- return ApiResponse(threads, headers=response_headers)
134
+
135
+ # Decrypt metadata, values, interrupts, and error in all returned threads
136
+ decrypted_threads = await decrypt_responses(
137
+ threads,
138
+ "thread",
139
+ THREAD_ENCRYPTION_FIELDS,
140
+ )
141
+
142
+ return ApiResponse(decrypted_threads, headers=response_headers)
89
143
 
90
144
 
91
145
  @retry_db
@@ -95,7 +149,7 @@ async def count_threads(
95
149
  """Count threads."""
96
150
  payload = await request.json(ThreadCountRequest)
97
151
  async with connect() as conn:
98
- count = await Threads.count(
152
+ count = await CrudThreads.count(
99
153
  conn,
100
154
  status=payload.get("status"),
101
155
  values=payload.get("values"),
@@ -113,10 +167,14 @@ async def get_thread_state(
113
167
  validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
114
168
  subgraphs = request.query_params.get("subgraphs") in ("true", "True")
115
169
  async with connect() as conn:
170
+ config = {
171
+ "configurable": {
172
+ **get_configurable_headers(request.headers),
173
+ "thread_id": thread_id,
174
+ }
175
+ }
116
176
  state = state_snapshot_to_thread_state(
117
- await Threads.State.get(
118
- conn, {"configurable": {"thread_id": thread_id}}, subgraphs=subgraphs
119
- )
177
+ await Threads.State.get(conn, config=config, subgraphs=subgraphs)
120
178
  )
121
179
  return ApiResponse(state)
122
180
 
@@ -130,15 +188,17 @@ async def get_thread_state_at_checkpoint(
130
188
  validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
131
189
  checkpoint_id = request.path_params["checkpoint_id"]
132
190
  async with connect() as conn:
191
+ config = {
192
+ "configurable": {
193
+ **get_configurable_headers(request.headers),
194
+ "thread_id": thread_id,
195
+ "checkpoint_id": checkpoint_id,
196
+ }
197
+ }
133
198
  state = state_snapshot_to_thread_state(
134
199
  await Threads.State.get(
135
200
  conn,
136
- {
137
- "configurable": {
138
- "thread_id": thread_id,
139
- "checkpoint_id": checkpoint_id,
140
- }
141
- },
201
+ config=config,
142
202
  subgraphs=request.query_params.get("subgraphs") in ("true", "True"),
143
203
  )
144
204
  )
@@ -154,10 +214,17 @@ async def get_thread_state_at_checkpoint_post(
154
214
  validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
155
215
  payload = await request.json(ThreadStateCheckpointRequest)
156
216
  async with connect() as conn:
217
+ config = {
218
+ "configurable": {
219
+ **payload["checkpoint"],
220
+ **get_configurable_headers(request.headers),
221
+ "thread_id": thread_id,
222
+ }
223
+ }
157
224
  state = state_snapshot_to_thread_state(
158
225
  await Threads.State.get(
159
226
  conn,
160
- {"configurable": {"thread_id": thread_id, **payload["checkpoint"]}},
227
+ config=config,
161
228
  subgraphs=payload.get("subgraphs", False),
162
229
  )
163
230
  )
@@ -182,6 +249,7 @@ async def update_thread_state(
182
249
  config["configurable"]["user_id"] = user_id
183
250
  except AssertionError:
184
251
  pass
252
+ config["configurable"].update(get_configurable_headers(request.headers))
185
253
  async with connect() as conn:
186
254
  inserted = await Threads.State.post(
187
255
  conn,
@@ -205,7 +273,13 @@ async def get_thread_history(
205
273
  except ValueError:
206
274
  raise HTTPException(status_code=422, detail=f"Invalid limit {limit_}") from None
207
275
  before = request.query_params.get("before")
208
- config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
276
+ config = {
277
+ "configurable": {
278
+ "thread_id": thread_id,
279
+ "checkpoint_ns": "",
280
+ **get_configurable_headers(request.headers),
281
+ }
282
+ }
209
283
  async with connect() as conn:
210
284
  states = [
211
285
  state_snapshot_to_thread_state(c)
@@ -226,6 +300,7 @@ async def get_thread_history_post(
226
300
  payload = await request.json(ThreadStateSearch)
227
301
  config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
228
302
  config["configurable"].update(payload.get("checkpoint", {}))
303
+ config["configurable"].update(get_configurable_headers(request.headers))
229
304
  async with connect() as conn:
230
305
  states = [
231
306
  state_snapshot_to_thread_state(c)
@@ -247,9 +322,23 @@ async def get_thread(
247
322
  """Get a thread by ID."""
248
323
  thread_id = request.path_params["thread_id"]
249
324
  validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
325
+
326
+ # Parse include parameter for optional fields (e.g., ttl)
327
+ include_param = request.query_params.get("include", "")
328
+ include_fields = [f.strip() for f in include_param.split(",") if f.strip()]
329
+ include_ttl = "ttl" in include_fields
330
+
250
331
  async with connect() as conn:
251
- thread = await Threads.get(conn, thread_id)
252
- return ApiResponse(await fetchone(thread))
332
+ thread = await CrudThreads.get(conn, thread_id, include_ttl=include_ttl)
333
+
334
+ # Decrypt metadata, values, interrupts, and error in response
335
+ thread_data = await fetchone(thread)
336
+ thread_data = await decrypt_response(
337
+ thread_data,
338
+ "thread",
339
+ THREAD_ENCRYPTION_FIELDS,
340
+ )
341
+ return ApiResponse(thread_data)
253
342
 
254
343
 
255
344
  @retry_db
@@ -260,9 +349,37 @@ async def patch_thread(
260
349
  thread_id = request.path_params["thread_id"]
261
350
  validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
262
351
  payload = await request.json(ThreadPatch)
352
+
353
+ # Validate keep_latest TTL requires core API
354
+ ttl = payload.get("ttl")
355
+ if ttl and ttl.get("strategy") == "keep_latest" and not FF_USE_CORE_API:
356
+ raise HTTPException(
357
+ status_code=422,
358
+ detail="keep_latest TTL strategy requires FF_USE_CORE_API=true",
359
+ )
360
+
361
+ # Encrypt metadata before storing
362
+ encrypted_payload = await encrypt_request(
363
+ payload,
364
+ "thread",
365
+ ["metadata"],
366
+ )
367
+
263
368
  async with connect() as conn:
264
- thread = await Threads.patch(conn, thread_id, metadata=payload["metadata"])
265
- return ApiResponse(await fetchone(thread))
369
+ thread = await CrudThreads.patch(
370
+ conn,
371
+ thread_id,
372
+ metadata=encrypted_payload.get("metadata") or {},
373
+ ttl=payload.get("ttl"),
374
+ )
375
+ thread_data = await fetchone(thread)
376
+ # Decrypt metadata, values, interrupts, and error in response
377
+ thread_data = await decrypt_response(
378
+ thread_data,
379
+ "thread",
380
+ THREAD_ENCRYPTION_FIELDS,
381
+ )
382
+ return ApiResponse(thread_data)
266
383
 
267
384
 
268
385
  @retry_db
@@ -271,17 +388,60 @@ async def delete_thread(request: ApiRequest):
271
388
  thread_id = request.path_params["thread_id"]
272
389
  validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
273
390
  async with connect() as conn:
274
- tid = await Threads.delete(conn, thread_id)
391
+ tid = await CrudThreads.delete(conn, thread_id)
275
392
  await fetchone(tid)
276
393
  return Response(status_code=204)
277
394
 
278
395
 
396
+ @retry_db
397
+ async def prune_threads(request: ApiRequest):
398
+ """Prune threads by ID."""
399
+ payload = await request.json(ThreadPruneRequest)
400
+ thread_ids = payload.get("thread_ids", [])
401
+ strategy = payload.get("strategy", "delete")
402
+
403
+ # Validate each thread_id is a valid UUID
404
+ for tid in thread_ids:
405
+ validate_uuid(tid, "Invalid thread ID: must be a UUID")
406
+
407
+ # Validate strategy
408
+ if strategy not in ("delete", "keep_latest"):
409
+ raise HTTPException(
410
+ status_code=422,
411
+ detail=f"Invalid strategy: {strategy}. Expected 'delete' or 'keep_latest'.",
412
+ )
413
+
414
+ # Empty list is a no-op, return early
415
+ if not thread_ids:
416
+ return ApiResponse({"pruned_count": 0})
417
+
418
+ if not FF_USE_CORE_API:
419
+ raise HTTPException(
420
+ status_code=422,
421
+ detail="Thread prune requires FF_USE_CORE_API=true",
422
+ )
423
+
424
+ pruned_count = await CrudThreads.prune(
425
+ thread_ids=thread_ids,
426
+ strategy=strategy,
427
+ )
428
+
429
+ return ApiResponse({"pruned_count": pruned_count})
430
+
431
+
279
432
  @retry_db
280
433
  async def copy_thread(request: ApiRequest):
281
434
  thread_id = request.path_params["thread_id"]
282
435
  async with connect() as conn:
283
- iter = await Threads.copy(conn, thread_id)
284
- return ApiResponse(await fetchone(iter, not_found_code=409))
436
+ iter = await CrudThreads.copy(conn, thread_id)
437
+ thread_data = await fetchone(iter, not_found_code=409)
438
+ # Decrypt metadata, values, interrupts, and error in response
439
+ thread_data = await decrypt_response(
440
+ thread_data,
441
+ "thread",
442
+ THREAD_ENCRYPTION_FIELDS,
443
+ )
444
+ return ApiResponse(thread_data)
285
445
 
286
446
 
287
447
  @retry_db
@@ -293,10 +453,31 @@ async def join_thread_stream(request: ApiRequest):
293
453
  validate_stream_id(
294
454
  last_event_id, "Invalid last-event-id: must be a valid Redis stream ID"
295
455
  )
456
+
457
+ # Parse stream_modes parameter - can be single string or comma-separated list
458
+ stream_modes_param = request.query_params.get("stream_modes")
459
+ if stream_modes_param:
460
+ if "," in stream_modes_param:
461
+ # Handle comma-separated list
462
+ stream_modes = [mode.strip() for mode in stream_modes_param.split(",")]
463
+ else:
464
+ # Handle single value
465
+ stream_modes = [stream_modes_param]
466
+ # Validate each mode
467
+ for mode in stream_modes:
468
+ if mode not in get_args(ThreadStreamMode):
469
+ raise HTTPException(
470
+ status_code=422, detail=f"Invalid stream mode: {mode}"
471
+ )
472
+ else:
473
+ # Default to run_modes
474
+ stream_modes = ["run_modes"]
475
+
296
476
  return EventSourceResponse(
297
477
  Threads.Stream.join(
298
478
  thread_id,
299
479
  last_event_id=last_event_id,
480
+ stream_modes=stream_modes,
300
481
  ),
301
482
  )
302
483
 
@@ -305,6 +486,7 @@ threads_routes: list[BaseRoute] = [
305
486
  ApiRoute("/threads", endpoint=create_thread, methods=["POST"]),
306
487
  ApiRoute("/threads/search", endpoint=search_threads, methods=["POST"]),
307
488
  ApiRoute("/threads/count", endpoint=count_threads, methods=["POST"]),
489
+ ApiRoute("/threads/prune", endpoint=prune_threads, methods=["POST"]),
308
490
  ApiRoute("/threads/{thread_id}", endpoint=get_thread, methods=["GET"]),
309
491
  ApiRoute("/threads/{thread_id}", endpoint=patch_thread, methods=["PATCH"]),
310
492
  ApiRoute("/threads/{thread_id}", endpoint=delete_thread, methods=["DELETE"]),
langgraph_api/asgi_transport.py CHANGED
@@ -25,7 +25,7 @@ def is_running_trio() -> bool:
25
25
  # sniffio is a dependency of trio.
26
26
 
27
27
  # See https://github.com/python-trio/trio/issues/2802
28
- import sniffio
28
+ import sniffio # type: ignore[unresolved-import]
29
29
 
30
30
  if sniffio.current_async_library() == "trio":
31
31
  return True
@@ -84,7 +84,8 @@ class ASGITransport(ASGITransportBase):
84
84
  ) -> Response:
85
85
  from langgraph_api.asyncio import call_soon_in_main_loop
86
86
 
87
- assert isinstance(request.stream, AsyncByteStream)
87
+ if not isinstance(request.stream, AsyncByteStream):
88
+ raise ValueError("Request stream must be an AsyncByteStream")
88
89
 
89
90
  # ASGI scope.
90
91
  scope = {
@@ -133,14 +134,15 @@ class ASGITransport(ASGITransportBase):
133
134
  nonlocal status_code, response_headers, response_started
134
135
 
135
136
  if message["type"] == "http.response.start":
136
- assert not response_started
137
-
137
+ if response_started:
138
+ raise RuntimeError("Response already started")
138
139
  status_code = message["status"]
139
140
  response_headers = message.get("headers", [])
140
141
  response_started = True
141
142
 
142
143
  elif message["type"] == "http.response.body":
143
- assert not response_complete.is_set()
144
+ if response_complete.is_set():
145
+ raise RuntimeError("Response already complete")
144
146
  body = message.get("body", b"")
145
147
  more_body = message.get("more_body", False)
146
148
 
@@ -152,7 +154,7 @@ class ASGITransport(ASGITransportBase):
152
154
 
153
155
  try:
154
156
  await call_soon_in_main_loop(self.app(scope, receive, send))
155
- except Exception: # noqa: PIE-786
157
+ except Exception:
156
158
  if self.raise_app_exceptions:
157
159
  raise
158
160
 
@@ -162,9 +164,12 @@ class ASGITransport(ASGITransportBase):
162
164
  if response_headers is None:
163
165
  response_headers = {}
164
166
 
165
- assert response_complete.is_set()
166
- assert status_code is not None
167
- assert response_headers is not None
167
+ if not response_complete.is_set():
168
+ raise RuntimeError("Response not complete")
169
+ if status_code is None:
170
+ raise RuntimeError("Status code not set")
171
+ if response_headers is None:
172
+ raise RuntimeError("Response headers not set")
168
173
 
169
174
  stream = ASGIResponseStream(body_parts)
170
175
 
langgraph_api/asyncio.py CHANGED
@@ -115,11 +115,16 @@ def create_task(
115
115
 
116
116
 
117
117
  def run_coroutine_threadsafe(
118
- coro: Coroutine[Any, Any, T], ignore_exceptions: tuple[type[Exception], ...] = ()
118
+ coro: Coroutine[Any, Any, T],
119
+ ignore_exceptions: tuple[type[Exception], ...] = (),
120
+ *,
121
+ loop: asyncio.AbstractEventLoop | None = None,
119
122
  ) -> concurrent.futures.Future[T] | concurrent.futures.Future[None]:
120
- if _MAIN_LOOP is None:
123
+ if loop is None:
124
+ loop = _MAIN_LOOP
125
+ if loop is None:
121
126
  raise RuntimeError("No event loop set")
122
- future = asyncio.run_coroutine_threadsafe(coro, _MAIN_LOOP)
127
+ future = asyncio.run_coroutine_threadsafe(coro, loop)
123
128
  future.add_done_callback(partial(_create_task_done_callback, ignore_exceptions))
124
129
  return future
125
130
 
@@ -158,12 +163,15 @@ class SimpleTaskGroup(AbstractAsyncContextManager["SimpleTaskGroup"]):
158
163
  self,
159
164
  *coros: Coroutine[Any, Any, T],
160
165
  cancel: bool = False,
166
+ cancel_event: asyncio.Event | None = None,
161
167
  wait: bool = True,
162
168
  taskset: set[asyncio.Task] | None = None,
163
169
  taskgroup_name: str | None = None,
164
170
  ) -> None:
165
- self.tasks = taskset if taskset is not None else set()
171
+ # Copy the taskset to avoid modifying the original set unintentionally (like in lifespan)
172
+ self.tasks = taskset.copy() if taskset is not None else set()
166
173
  self.cancel = cancel
174
+ self.cancel_event = cancel_event
167
175
  self.wait = wait
168
176
  if taskset:
169
177
  for task in tuple(taskset):
@@ -180,6 +188,8 @@ class SimpleTaskGroup(AbstractAsyncContextManager["SimpleTaskGroup"]):
180
188
  try:
181
189
  if (exc := task.exception()) and not isinstance(exc, ignore_exceptions):
182
190
  logger.exception("asyncio.task failed in task group", exc_info=exc)
191
+ if self.cancel_event:
192
+ self.cancel_event.set()
183
193
  except asyncio.CancelledError:
184
194
  pass
185
195