langgraph-api 0.4.1__py3-none-any.whl → 0.7.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (135)
  1. langgraph_api/__init__.py +1 -1
  2. langgraph_api/api/__init__.py +111 -51
  3. langgraph_api/api/a2a.py +1610 -0
  4. langgraph_api/api/assistants.py +212 -89
  5. langgraph_api/api/mcp.py +3 -3
  6. langgraph_api/api/meta.py +52 -28
  7. langgraph_api/api/openapi.py +27 -17
  8. langgraph_api/api/profile.py +108 -0
  9. langgraph_api/api/runs.py +342 -195
  10. langgraph_api/api/store.py +19 -2
  11. langgraph_api/api/threads.py +209 -27
  12. langgraph_api/asgi_transport.py +14 -9
  13. langgraph_api/asyncio.py +14 -4
  14. langgraph_api/auth/custom.py +52 -37
  15. langgraph_api/auth/langsmith/backend.py +4 -3
  16. langgraph_api/auth/langsmith/client.py +13 -8
  17. langgraph_api/cli.py +230 -133
  18. langgraph_api/command.py +5 -3
  19. langgraph_api/config/__init__.py +532 -0
  20. langgraph_api/config/_parse.py +58 -0
  21. langgraph_api/config/schemas.py +431 -0
  22. langgraph_api/cron_scheduler.py +17 -1
  23. langgraph_api/encryption/__init__.py +15 -0
  24. langgraph_api/encryption/aes_json.py +158 -0
  25. langgraph_api/encryption/context.py +35 -0
  26. langgraph_api/encryption/custom.py +280 -0
  27. langgraph_api/encryption/middleware.py +632 -0
  28. langgraph_api/encryption/shared.py +63 -0
  29. langgraph_api/errors.py +12 -1
  30. langgraph_api/executor_entrypoint.py +11 -6
  31. langgraph_api/feature_flags.py +29 -0
  32. langgraph_api/graph.py +176 -76
  33. langgraph_api/grpc/client.py +313 -0
  34. langgraph_api/grpc/config_conversion.py +231 -0
  35. langgraph_api/grpc/generated/__init__.py +29 -0
  36. langgraph_api/grpc/generated/checkpointer_pb2.py +63 -0
  37. langgraph_api/grpc/generated/checkpointer_pb2.pyi +99 -0
  38. langgraph_api/grpc/generated/checkpointer_pb2_grpc.py +329 -0
  39. langgraph_api/grpc/generated/core_api_pb2.py +216 -0
  40. langgraph_api/grpc/generated/core_api_pb2.pyi +905 -0
  41. langgraph_api/grpc/generated/core_api_pb2_grpc.py +1621 -0
  42. langgraph_api/grpc/generated/engine_common_pb2.py +219 -0
  43. langgraph_api/grpc/generated/engine_common_pb2.pyi +722 -0
  44. langgraph_api/grpc/generated/engine_common_pb2_grpc.py +24 -0
  45. langgraph_api/grpc/generated/enum_cancel_run_action_pb2.py +37 -0
  46. langgraph_api/grpc/generated/enum_cancel_run_action_pb2.pyi +12 -0
  47. langgraph_api/grpc/generated/enum_cancel_run_action_pb2_grpc.py +24 -0
  48. langgraph_api/grpc/generated/enum_control_signal_pb2.py +37 -0
  49. langgraph_api/grpc/generated/enum_control_signal_pb2.pyi +16 -0
  50. langgraph_api/grpc/generated/enum_control_signal_pb2_grpc.py +24 -0
  51. langgraph_api/grpc/generated/enum_durability_pb2.py +37 -0
  52. langgraph_api/grpc/generated/enum_durability_pb2.pyi +16 -0
  53. langgraph_api/grpc/generated/enum_durability_pb2_grpc.py +24 -0
  54. langgraph_api/grpc/generated/enum_multitask_strategy_pb2.py +37 -0
  55. langgraph_api/grpc/generated/enum_multitask_strategy_pb2.pyi +16 -0
  56. langgraph_api/grpc/generated/enum_multitask_strategy_pb2_grpc.py +24 -0
  57. langgraph_api/grpc/generated/enum_run_status_pb2.py +37 -0
  58. langgraph_api/grpc/generated/enum_run_status_pb2.pyi +22 -0
  59. langgraph_api/grpc/generated/enum_run_status_pb2_grpc.py +24 -0
  60. langgraph_api/grpc/generated/enum_stream_mode_pb2.py +37 -0
  61. langgraph_api/grpc/generated/enum_stream_mode_pb2.pyi +28 -0
  62. langgraph_api/grpc/generated/enum_stream_mode_pb2_grpc.py +24 -0
  63. langgraph_api/grpc/generated/enum_thread_status_pb2.py +37 -0
  64. langgraph_api/grpc/generated/enum_thread_status_pb2.pyi +16 -0
  65. langgraph_api/grpc/generated/enum_thread_status_pb2_grpc.py +24 -0
  66. langgraph_api/grpc/generated/enum_thread_stream_mode_pb2.py +37 -0
  67. langgraph_api/grpc/generated/enum_thread_stream_mode_pb2.pyi +16 -0
  68. langgraph_api/grpc/generated/enum_thread_stream_mode_pb2_grpc.py +24 -0
  69. langgraph_api/grpc/generated/errors_pb2.py +39 -0
  70. langgraph_api/grpc/generated/errors_pb2.pyi +21 -0
  71. langgraph_api/grpc/generated/errors_pb2_grpc.py +24 -0
  72. langgraph_api/grpc/ops/__init__.py +370 -0
  73. langgraph_api/grpc/ops/assistants.py +424 -0
  74. langgraph_api/grpc/ops/runs.py +792 -0
  75. langgraph_api/grpc/ops/threads.py +1013 -0
  76. langgraph_api/http.py +16 -5
  77. langgraph_api/http_metrics.py +15 -35
  78. langgraph_api/http_metrics_utils.py +38 -0
  79. langgraph_api/js/build.mts +1 -1
  80. langgraph_api/js/client.http.mts +13 -7
  81. langgraph_api/js/client.mts +2 -5
  82. langgraph_api/js/package.json +29 -28
  83. langgraph_api/js/remote.py +56 -30
  84. langgraph_api/js/src/graph.mts +20 -0
  85. langgraph_api/js/sse.py +2 -2
  86. langgraph_api/js/ui.py +1 -1
  87. langgraph_api/js/yarn.lock +1204 -1006
  88. langgraph_api/logging.py +29 -2
  89. langgraph_api/metadata.py +99 -28
  90. langgraph_api/middleware/http_logger.py +7 -2
  91. langgraph_api/middleware/private_network.py +7 -7
  92. langgraph_api/models/run.py +54 -93
  93. langgraph_api/otel_context.py +205 -0
  94. langgraph_api/patch.py +5 -3
  95. langgraph_api/queue_entrypoint.py +154 -65
  96. langgraph_api/route.py +47 -5
  97. langgraph_api/schema.py +88 -10
  98. langgraph_api/self_hosted_logs.py +124 -0
  99. langgraph_api/self_hosted_metrics.py +450 -0
  100. langgraph_api/serde.py +79 -37
  101. langgraph_api/server.py +138 -60
  102. langgraph_api/state.py +4 -3
  103. langgraph_api/store.py +25 -16
  104. langgraph_api/stream.py +80 -29
  105. langgraph_api/thread_ttl.py +31 -13
  106. langgraph_api/timing/__init__.py +25 -0
  107. langgraph_api/timing/profiler.py +200 -0
  108. langgraph_api/timing/timer.py +318 -0
  109. langgraph_api/utils/__init__.py +53 -8
  110. langgraph_api/utils/cache.py +47 -10
  111. langgraph_api/utils/config.py +2 -1
  112. langgraph_api/utils/errors.py +77 -0
  113. langgraph_api/utils/future.py +10 -6
  114. langgraph_api/utils/headers.py +76 -2
  115. langgraph_api/utils/retriable_client.py +74 -0
  116. langgraph_api/utils/stream_codec.py +315 -0
  117. langgraph_api/utils/uuids.py +29 -62
  118. langgraph_api/validation.py +9 -0
  119. langgraph_api/webhook.py +120 -6
  120. langgraph_api/worker.py +55 -24
  121. {langgraph_api-0.4.1.dist-info → langgraph_api-0.7.3.dist-info}/METADATA +16 -8
  122. langgraph_api-0.7.3.dist-info/RECORD +168 -0
  123. {langgraph_api-0.4.1.dist-info → langgraph_api-0.7.3.dist-info}/WHEEL +1 -1
  124. langgraph_runtime/__init__.py +1 -0
  125. langgraph_runtime/routes.py +11 -0
  126. logging.json +1 -3
  127. openapi.json +839 -478
  128. langgraph_api/config.py +0 -387
  129. langgraph_api/js/isolate-0x130008000-46649-46649-v8.log +0 -4430
  130. langgraph_api/js/isolate-0x138008000-44681-44681-v8.log +0 -4430
  131. langgraph_api/js/package-lock.json +0 -3308
  132. langgraph_api-0.4.1.dist-info/RECORD +0 -107
  133. /langgraph_api/{utils.py → grpc/__init__.py} +0 -0
  134. {langgraph_api-0.4.1.dist-info → langgraph_api-0.7.3.dist-info}/entry_points.txt +0 -0
  135. {langgraph_api-0.4.1.dist-info → langgraph_api-0.7.3.dist-info}/licenses/LICENSE +0 -0
langgraph_api/api/runs.py CHANGED
@@ -1,17 +1,33 @@
1
1
  import asyncio
2
- from collections.abc import AsyncIterator
2
+ from collections.abc import AsyncIterator, Awaitable, Callable
3
3
  from typing import Literal, cast
4
- from uuid import uuid4
4
+ from uuid import UUID, uuid4
5
5
 
6
6
  import orjson
7
+ import structlog
7
8
  from starlette.exceptions import HTTPException
8
9
  from starlette.responses import Response, StreamingResponse
9
10
 
10
11
  from langgraph_api import config
11
- from langgraph_api.asyncio import ValueEvent, aclosing
12
+ from langgraph_api.asyncio import ValueEvent
13
+ from langgraph_api.encryption.middleware import (
14
+ decrypt_response,
15
+ decrypt_responses,
16
+ encrypt_request,
17
+ )
18
+ from langgraph_api.feature_flags import FF_USE_CORE_API
19
+ from langgraph_api.graph import _validate_assistant_id
20
+ from langgraph_api.grpc.ops import Runs as GrpcRuns
12
21
  from langgraph_api.models.run import create_valid_run
13
22
  from langgraph_api.route import ApiRequest, ApiResponse, ApiRoute
14
- from langgraph_api.schema import CRON_FIELDS, RUN_FIELDS
23
+ from langgraph_api.schema import (
24
+ CRON_ENCRYPTION_FIELDS,
25
+ CRON_FIELDS,
26
+ CRON_PAYLOAD_ENCRYPTION_SUBFIELDS,
27
+ RUN_ENCRYPTION_FIELDS,
28
+ RUN_FIELDS,
29
+ )
30
+ from langgraph_api.serde import json_dumpb, json_loads
15
31
  from langgraph_api.sse import EventSourceResponse
16
32
  from langgraph_api.utils import (
17
33
  fetchone,
@@ -28,18 +44,138 @@ from langgraph_api.validation import (
28
44
  RunCreateStateful,
29
45
  RunCreateStateless,
30
46
  RunsCancel,
47
+ ThreadCronCreate,
31
48
  )
49
+ from langgraph_api.webhook import validate_webhook_url_or_raise
32
50
  from langgraph_license.validation import plus_features_enabled
33
51
  from langgraph_runtime.database import connect
34
- from langgraph_runtime.ops import Crons, Runs, Threads
52
+ from langgraph_runtime.ops import Crons, Runs, StreamHandler, Threads
35
53
  from langgraph_runtime.retry import retry_db
36
54
 
55
+ CrudRuns = GrpcRuns if FF_USE_CORE_API else Runs
56
+
57
+ logger = structlog.stdlib.get_logger(__name__)
58
+
59
+
60
+ _RunResultFallback = Callable[[], Awaitable[bytes]]
61
+
62
+
63
+ def _ensure_crons_enabled() -> None:
64
+ if not (config.FF_CRONS_ENABLED and plus_features_enabled()):
65
+ raise HTTPException(
66
+ status_code=403,
67
+ detail="Crons are currently only available in the cloud version of LangSmith Deployment or with a self-hosting enterprise license. Please visit https://docs.langchain.com/langsmith/deployments to learn more about deployment options, or contact sales@langchain.com for more information",
68
+ )
69
+
70
+
71
+ def _thread_values_fallback(thread_id: UUID) -> _RunResultFallback:
72
+ async def fetch_thread_values() -> bytes:
73
+ async with connect() as conn:
74
+ thread_iter = await Threads.get(conn, thread_id)
75
+ try:
76
+ row = await anext(thread_iter)
77
+ # Decrypt thread fields (values, interrupts, error) if encryption is enabled
78
+ thread = await decrypt_response(
79
+ dict(row),
80
+ "thread",
81
+ ["values", "interrupts", "error"],
82
+ )
83
+ if row["status"] == "error":
84
+ return json_dumpb({"__error__": json_loads(thread["error"])})
85
+ if row["status"] == "interrupted":
86
+ # Get an interrupt for the thread. There is the case where there are multiple interrupts for the same run and we may not show the same
87
+ # interrupt, but we'll always show one. Long term we should show all of them.
88
+ try:
89
+ interrupt_map = json_loads(thread["interrupts"])
90
+ interrupt = [next(iter(interrupt_map.values()))[0]]
91
+ return json_dumpb({"__interrupt__": interrupt})
92
+ except Exception:
93
+ # No interrupt, but status is interrupted from a before/after block. Default back to values.
94
+ pass
95
+ values = json_loads(thread["values"]) if thread["values"] else None
96
+ return json_dumpb(values) if values else b"{}"
97
+ except StopAsyncIteration:
98
+ await logger.awarning(
99
+ f"No checkpoint found for thread {thread_id}",
100
+ thread_id=thread_id,
101
+ )
102
+ return b"{}"
103
+
104
+ return fetch_thread_values
105
+
106
+
107
+ def _run_result_body(
108
+ *,
109
+ run_id: UUID,
110
+ thread_id: UUID,
111
+ sub: StreamHandler,
112
+ cancel_on_disconnect: bool = False,
113
+ ignore_404: bool = False,
114
+ fallback: _RunResultFallback | None = None,
115
+ cancel_message: str | None = None,
116
+ ) -> Callable[[], AsyncIterator[bytes]]:
117
+ last_chunk = ValueEvent()
118
+
119
+ async def consume() -> None:
120
+ vchunk: bytes | None = None
121
+ try:
122
+ async for mode, chunk, _ in Runs.Stream.join(
123
+ run_id,
124
+ stream_channel=sub,
125
+ cancel_on_disconnect=cancel_on_disconnect,
126
+ thread_id=thread_id,
127
+ ignore_404=ignore_404,
128
+ ):
129
+ if mode == b"values" or (
130
+ mode == b"updates" and b"__interrupt__" in chunk
131
+ ):
132
+ vchunk = chunk
133
+ elif mode == b"error":
134
+ vchunk = orjson.dumps({"__error__": orjson.Fragment(chunk)})
135
+ if vchunk is not None:
136
+ last_chunk.set(vchunk)
137
+ elif fallback is not None:
138
+ last_chunk.set(await fallback())
139
+ else:
140
+ last_chunk.set(b"{}")
141
+ finally:
142
+ # Make sure to always clean up the pubsub
143
+ await sub.__aexit__(None, None, None)
144
+
145
+ # keep the connection open by sending whitespace every 5 seconds
146
+ # leading whitespace will be ignored by json parsers
147
+ async def body() -> AsyncIterator[bytes]:
148
+ try:
149
+ stream = asyncio.create_task(consume())
150
+ while True:
151
+ try:
152
+ if stream.done():
153
+ # raise stream exception if any
154
+ stream.result()
155
+ yield await asyncio.wait_for(last_chunk.wait(), timeout=5)
156
+ break
157
+ except TimeoutError:
158
+ yield b"\n"
159
+ except asyncio.CancelledError:
160
+ if cancel_message is not None:
161
+ stream.cancel(cancel_message)
162
+ else:
163
+ stream.cancel()
164
+ await stream
165
+ raise
166
+ finally:
167
+ # Make sure to always clean up the pubsub
168
+ await sub.__aexit__(None, None, None)
169
+
170
+ return body
171
+
37
172
 
38
173
  @retry_db
39
174
  async def create_run(request: ApiRequest):
40
175
  """Create a run."""
41
176
  thread_id = request.path_params["thread_id"]
42
177
  payload = await request.json(RunCreateStateful)
178
+
43
179
  async with connect() as conn:
44
180
  run = await create_valid_run(
45
181
  conn,
@@ -48,6 +184,7 @@ async def create_run(request: ApiRequest):
48
184
  request.headers,
49
185
  request_start_time=request.scope.get("request_start_time_ms"),
50
186
  )
187
+ run = await decrypt_response(run, "run", RUN_ENCRYPTION_FIELDS)
51
188
  return ApiResponse(
52
189
  run,
53
190
  headers={"Content-Location": f"/threads/{thread_id}/runs/{run['run_id']}"},
@@ -58,6 +195,7 @@ async def create_run(request: ApiRequest):
58
195
  async def create_stateless_run(request: ApiRequest):
59
196
  """Create a run."""
60
197
  payload = await request.json(RunCreateStateless)
198
+
61
199
  async with connect() as conn:
62
200
  run = await create_valid_run(
63
201
  conn,
@@ -66,6 +204,7 @@ async def create_stateless_run(request: ApiRequest):
66
204
  request.headers,
67
205
  request_start_time=request.scope.get("request_start_time_ms"),
68
206
  )
207
+ run = await decrypt_response(run, "run", RUN_ENCRYPTION_FIELDS)
69
208
  return ApiResponse(
70
209
  run,
71
210
  headers={"Content-Location": f"/runs/{run['run_id']}"},
@@ -90,6 +229,7 @@ async def create_stateless_run_batch(request: ApiRequest):
90
229
  for payload in batch_payload
91
230
  ]
92
231
  runs = await asyncio.gather(*coros)
232
+ runs = await decrypt_responses(list(runs), "run", RUN_ENCRYPTION_FIELDS)
93
233
  return ApiResponse(runs)
94
234
 
95
235
 
@@ -101,8 +241,8 @@ async def stream_run(
101
241
  payload = await request.json(RunCreateStateful)
102
242
  on_disconnect = payload.get("on_disconnect", "continue")
103
243
  run_id = uuid7()
104
- sub = asyncio.create_task(Runs.Stream.subscribe(run_id, thread_id))
105
244
 
245
+ sub = await Runs.Stream.subscribe(run_id, thread_id)
106
246
  try:
107
247
  async with connect() as conn:
108
248
  run = await create_valid_run(
@@ -114,19 +254,26 @@ async def stream_run(
114
254
  request_start_time=request.scope.get("request_start_time_ms"),
115
255
  )
116
256
  except Exception:
117
- if not sub.cancelled():
118
- handle = await sub
119
- await handle.__aexit__(None, None, None)
257
+ # Clean up the pubsub on errors
258
+ await sub.__aexit__(None, None, None)
120
259
  raise
121
260
 
261
+ async def body():
262
+ try:
263
+ async for event, message, stream_id in Runs.Stream.join(
264
+ run["run_id"],
265
+ thread_id=thread_id,
266
+ cancel_on_disconnect=on_disconnect == "cancel",
267
+ stream_channel=sub,
268
+ last_event_id=None,
269
+ ):
270
+ yield event, message, stream_id
271
+ finally:
272
+ # Make sure to always clean up the pubsub
273
+ await sub.__aexit__(None, None, None)
274
+
122
275
  return EventSourceResponse(
123
- Runs.Stream.join(
124
- run["run_id"],
125
- thread_id=thread_id,
126
- cancel_on_disconnect=on_disconnect == "cancel",
127
- stream_channel=await sub,
128
- last_event_id=None,
129
- ),
276
+ body(),
130
277
  headers={
131
278
  "Location": f"/threads/{thread_id}/runs/{run['run_id']}/stream",
132
279
  "Content-Location": f"/threads/{thread_id}/runs/{run['run_id']}",
@@ -143,7 +290,8 @@ async def stream_run_stateless(
143
290
  on_disconnect = payload.get("on_disconnect", "continue")
144
291
  run_id = uuid7()
145
292
  thread_id = uuid4()
146
- sub = asyncio.create_task(Runs.Stream.subscribe(run_id, thread_id))
293
+
294
+ sub = await Runs.Stream.subscribe(run_id, thread_id)
147
295
  try:
148
296
  async with connect() as conn:
149
297
  run = await create_valid_run(
@@ -156,20 +304,27 @@ async def stream_run_stateless(
156
304
  temporary=True,
157
305
  )
158
306
  except Exception:
159
- if not sub.cancelled():
160
- handle = await sub
161
- await handle.__aexit__(None, None, None)
307
+ # Clean up the pubsub on errors
308
+ await sub.__aexit__(None, None, None)
162
309
  raise
163
310
 
311
+ async def body():
312
+ try:
313
+ async for event, message, stream_id in Runs.Stream.join(
314
+ run["run_id"],
315
+ thread_id=run["thread_id"],
316
+ ignore_404=True,
317
+ cancel_on_disconnect=on_disconnect == "cancel",
318
+ stream_channel=sub,
319
+ last_event_id=None,
320
+ ):
321
+ yield event, message, stream_id
322
+ finally:
323
+ # Make sure to always clean up the pubsub
324
+ await sub.__aexit__(None, None, None)
325
+
164
326
  return EventSourceResponse(
165
- Runs.Stream.join(
166
- run["run_id"],
167
- thread_id=run["thread_id"],
168
- ignore_404=True,
169
- cancel_on_disconnect=on_disconnect == "cancel",
170
- stream_channel=await sub,
171
- last_event_id=None,
172
- ),
327
+ body(),
173
328
  headers={
174
329
  "Location": f"/runs/{run['run_id']}/stream",
175
330
  "Content-Location": f"/runs/{run['run_id']}",
@@ -184,8 +339,7 @@ async def wait_run(request: ApiRequest):
184
339
  payload = await request.json(RunCreateStateful)
185
340
  on_disconnect = payload.get("on_disconnect", "continue")
186
341
  run_id = uuid7()
187
- sub = asyncio.create_task(Runs.Stream.subscribe(run_id, thread_id))
188
-
342
+ sub = await Runs.Stream.subscribe(run_id, thread_id)
189
343
  try:
190
344
  async with connect() as conn:
191
345
  run = await create_valid_run(
@@ -197,60 +351,17 @@ async def wait_run(request: ApiRequest):
197
351
  request_start_time=request.scope.get("request_start_time_ms"),
198
352
  )
199
353
  except Exception:
200
- if not sub.cancelled():
201
- handle = await sub
202
- await handle.__aexit__(None, None, None)
354
+ # Clean up the pubsub on errors
355
+ await sub.__aexit__(None, None, None)
203
356
  raise
204
357
 
205
- last_chunk = ValueEvent()
206
-
207
- async def consume():
208
- vchunk: bytes | None = None
209
- async with aclosing(
210
- Runs.Stream.join(
211
- run["run_id"],
212
- thread_id=run["thread_id"],
213
- stream_channel=await sub,
214
- cancel_on_disconnect=on_disconnect == "cancel",
215
- )
216
- ) as stream:
217
- async for mode, chunk, _ in stream:
218
- if (
219
- mode == b"values"
220
- or mode == b"updates"
221
- and b"__interrupt__" in chunk
222
- ):
223
- vchunk = chunk
224
- elif mode == b"error":
225
- vchunk = orjson.dumps({"__error__": orjson.Fragment(chunk)})
226
- if vchunk is not None:
227
- last_chunk.set(vchunk)
228
- else:
229
- async with connect() as conn:
230
- thread_iter = await Threads.get(conn, thread_id)
231
- try:
232
- thread = await anext(thread_iter)
233
- last_chunk.set(thread["values"])
234
- except StopAsyncIteration:
235
- last_chunk.set(b"{}")
236
-
237
- # keep the connection open by sending whitespace every 5 seconds
238
- # leading whitespace will be ignored by json parsers
239
- async def body() -> AsyncIterator[bytes]:
240
- stream = asyncio.create_task(consume())
241
- while True:
242
- try:
243
- if stream.done():
244
- # raise stream exception if any
245
- stream.result()
246
- yield await asyncio.wait_for(last_chunk.wait(), timeout=5)
247
- break
248
- except TimeoutError:
249
- yield b"\n"
250
- except asyncio.CancelledError:
251
- stream.cancel()
252
- await stream
253
- raise
358
+ body = _run_result_body(
359
+ run_id=run["run_id"],
360
+ thread_id=run["thread_id"],
361
+ sub=sub,
362
+ cancel_on_disconnect=on_disconnect == "cancel",
363
+ fallback=_thread_values_fallback(thread_id),
364
+ )
254
365
 
255
366
  return StreamingResponse(
256
367
  body(),
@@ -270,8 +381,8 @@ async def wait_run_stateless(request: ApiRequest):
270
381
  on_disconnect = payload.get("on_disconnect", "continue")
271
382
  run_id = uuid7()
272
383
  thread_id = uuid4()
273
- sub = asyncio.create_task(Runs.Stream.subscribe(run_id, thread_id))
274
384
 
385
+ sub = await Runs.Stream.subscribe(run_id, thread_id)
275
386
  try:
276
387
  async with connect() as conn:
277
388
  run = await create_valid_run(
@@ -284,55 +395,27 @@ async def wait_run_stateless(request: ApiRequest):
284
395
  temporary=True,
285
396
  )
286
397
  except Exception:
287
- if not sub.cancelled():
288
- handle = await sub
289
- await handle.__aexit__(None, None, None)
398
+ # Clean up the pubsub on errors
399
+ await sub.__aexit__(None, None, None)
290
400
  raise
291
- last_chunk = ValueEvent()
292
-
293
- async def consume():
294
- vchunk: bytes | None = None
295
- async with aclosing(
296
- Runs.Stream.join(
297
- run["run_id"],
298
- thread_id=run["thread_id"],
299
- stream_channel=await sub,
300
- ignore_404=True,
301
- cancel_on_disconnect=on_disconnect == "cancel",
302
- )
303
- ) as stream:
304
- async for mode, chunk, _ in stream:
305
- if (
306
- mode == b"values"
307
- or mode == b"updates"
308
- and b"__interrupt__" in chunk
309
- ):
310
- vchunk = chunk
311
- elif mode == b"error":
312
- vchunk = orjson.dumps({"__error__": orjson.Fragment(chunk)})
313
- if vchunk is not None:
314
- last_chunk.set(vchunk)
315
- else:
316
- # we can't fetch the thread (it was deleted), so just return empty values
317
- last_chunk.set(b"{}")
318
401
 
319
- # keep the connection open by sending whitespace every 5 seconds
320
- # leading whitespace will be ignored by json parsers
321
- async def body() -> AsyncIterator[bytes]:
322
- stream = asyncio.create_task(consume())
323
- while True:
324
- try:
325
- if stream.done():
326
- # raise stream exception if any
327
- stream.result()
328
- yield await asyncio.wait_for(last_chunk.wait(), timeout=5)
329
- break
330
- except TimeoutError:
331
- yield b"\n"
332
- except asyncio.CancelledError:
333
- stream.cancel("Run stream cancelled")
334
- await stream
335
- raise
402
+ async def stateless_fallback() -> bytes:
403
+ await logger.awarning(
404
+ "No checkpoint emitted for stateless run",
405
+ run_id=run["run_id"],
406
+ thread_id=run["thread_id"],
407
+ )
408
+ return b"{}"
409
+
410
+ body = _run_result_body(
411
+ run_id=run["run_id"],
412
+ thread_id=run["thread_id"],
413
+ sub=sub,
414
+ cancel_on_disconnect=on_disconnect == "cancel",
415
+ ignore_404=True,
416
+ fallback=stateless_fallback,
417
+ cancel_message="Run stream cancelled",
418
+ )
336
419
 
337
420
  return StreamingResponse(
338
421
  body(),
@@ -361,7 +444,7 @@ async def list_runs(
361
444
  async with connect() as conn, conn.pipeline():
362
445
  thread, runs = await asyncio.gather(
363
446
  Threads.get(conn, thread_id),
364
- Runs.search(
447
+ CrudRuns.search(
365
448
  conn,
366
449
  thread_id,
367
450
  limit=limit,
@@ -371,7 +454,12 @@ async def list_runs(
371
454
  ),
372
455
  )
373
456
  await fetchone(thread)
374
- return ApiResponse([run async for run in runs])
457
+
458
+ # Collect and decrypt runs
459
+ runs_list = [run async for run in runs]
460
+ runs_list = await decrypt_responses(runs_list, "run", RUN_ENCRYPTION_FIELDS)
461
+
462
+ return ApiResponse(runs_list)
375
463
 
376
464
 
377
465
  @retry_db
@@ -385,14 +473,19 @@ async def get_run(request: ApiRequest):
385
473
  async with connect() as conn, conn.pipeline():
386
474
  thread, run = await asyncio.gather(
387
475
  Threads.get(conn, thread_id),
388
- Runs.get(
476
+ CrudRuns.get(
389
477
  conn,
390
478
  run_id,
391
479
  thread_id=thread_id,
392
480
  ),
393
481
  )
394
482
  await fetchone(thread)
395
- return ApiResponse(await fetchone(run))
483
+ run_dict = await fetchone(run)
484
+
485
+ # Decrypt run metadata and kwargs
486
+ run_dict = await decrypt_response(run_dict, "run", RUN_ENCRYPTION_FIELDS)
487
+
488
+ return ApiResponse(run_dict)
396
489
 
397
490
 
398
491
  @retry_db
@@ -403,11 +496,23 @@ async def join_run(request: ApiRequest):
403
496
  validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
404
497
  validate_uuid(run_id, "Invalid run ID: must be a UUID")
405
498
 
406
- return ApiResponse(
407
- await Runs.join(
408
- run_id,
409
- thread_id=thread_id,
410
- )
499
+ # A touch redundant, but to meet the existing signature of join, we need to throw any 404s before we enter the streaming body
500
+ await Runs.Stream.check_run_stream_auth(run_id, thread_id)
501
+ sub = await Runs.Stream.subscribe(run_id, thread_id)
502
+ body = _run_result_body(
503
+ run_id=run_id,
504
+ thread_id=thread_id,
505
+ sub=sub,
506
+ fallback=_thread_values_fallback(thread_id),
507
+ )
508
+
509
+ return StreamingResponse(
510
+ body(),
511
+ media_type="application/json",
512
+ headers={
513
+ "Location": f"/threads/{thread_id}/runs/{run_id}/join",
514
+ "Content-Location": f"/threads/{thread_id}/runs/{run_id}",
515
+ },
411
516
  )
412
517
 
413
518
 
@@ -422,14 +527,25 @@ async def join_run_stream(request: ApiRequest):
422
527
  validate_uuid(run_id, "Invalid run ID: must be a UUID")
423
528
  stream_mode = request.query_params.get("stream_mode") or []
424
529
  last_event_id = request.headers.get("last-event-id") or None
530
+
531
+ async def body():
532
+ async with await Runs.Stream.subscribe(run_id, thread_id) as sub:
533
+ async for event, message, stream_id in Runs.Stream.join(
534
+ run_id,
535
+ thread_id=thread_id,
536
+ cancel_on_disconnect=cancel_on_disconnect,
537
+ stream_channel=sub,
538
+ stream_mode=stream_mode,
539
+ last_event_id=last_event_id,
540
+ ):
541
+ yield event, message, stream_id
542
+
425
543
  return EventSourceResponse(
426
- Runs.Stream.join(
427
- run_id,
428
- thread_id=thread_id,
429
- cancel_on_disconnect=cancel_on_disconnect,
430
- stream_mode=stream_mode,
431
- last_event_id=last_event_id,
432
- ),
544
+ body(),
545
+ headers={
546
+ "Location": f"/threads/{thread_id}/runs/{run_id}/stream",
547
+ "Content-Location": f"/threads/{thread_id}/runs/{run_id}",
548
+ },
433
549
  )
434
550
 
435
551
 
@@ -446,23 +562,40 @@ async def cancel_run(
446
562
  wait = wait_str.lower() in {"true", "yes", "1"}
447
563
  action_str = request.query_params.get("action", "interrupt")
448
564
  action = cast(
449
- Literal["interrupt", "rollback"],
565
+ "Literal['interrupt', 'rollback']",
450
566
  action_str if action_str in {"interrupt", "rollback"} else "interrupt",
451
567
  )
452
568
 
453
- async with connect() as conn:
454
- await Runs.cancel(
455
- conn,
456
- [run_id],
457
- action=action,
458
- thread_id=thread_id,
459
- )
460
- if wait:
461
- await Runs.join(
462
- run_id,
463
- thread_id=thread_id,
464
- )
465
- return Response(status_code=204 if wait else 202)
569
+ sub = await Runs.Stream.subscribe(run_id, thread_id) if wait else None
570
+ try:
571
+ async with connect() as conn:
572
+ await CrudRuns.cancel(
573
+ conn,
574
+ [run_id],
575
+ action=action,
576
+ thread_id=thread_id,
577
+ )
578
+ except Exception:
579
+ if sub is not None:
580
+ await sub.__aexit__(None, None, None)
581
+ raise
582
+ if not wait:
583
+ return Response(status_code=202)
584
+
585
+ body = _run_result_body(
586
+ run_id=run_id,
587
+ thread_id=thread_id,
588
+ sub=sub,
589
+ )
590
+
591
+ return StreamingResponse(
592
+ body(),
593
+ media_type="application/json",
594
+ headers={
595
+ "Location": f"/threads/{thread_id}/runs/{run_id}/join",
596
+ "Content-Location": f"/threads/{thread_id}/runs/{run_id}",
597
+ },
598
+ )
466
599
 
467
600
 
468
601
  @retry_db
@@ -495,12 +628,12 @@ async def cancel_runs(
495
628
  validate_uuid(rid, "Invalid run ID: must be a UUID")
496
629
  action_str = request.query_params.get("action", "interrupt")
497
630
  action = cast(
498
- Literal["interrupt", "rollback"],
631
+ "Literal['interrupt', 'rollback']",
499
632
  action_str if action_str in ("interrupt", "rollback") else "interrupt",
500
633
  )
501
634
 
502
635
  async with connect() as conn:
503
- await Runs.cancel(
636
+ await CrudRuns.cancel(
504
637
  conn,
505
638
  run_ids,
506
639
  action=action,
@@ -519,7 +652,7 @@ async def delete_run(request: ApiRequest):
519
652
  validate_uuid(run_id, "Invalid run ID: must be a UUID")
520
653
 
521
654
  async with connect() as conn:
522
- rid = await Runs.delete(
655
+ rid = await CrudRuns.delete(
523
656
  conn,
524
657
  run_id,
525
658
  thread_id=thread_id,
@@ -531,40 +664,71 @@ async def delete_run(request: ApiRequest):
531
664
  @retry_db
532
665
  async def create_cron(request: ApiRequest):
533
666
  """Create a cron with new thread."""
667
+ _ensure_crons_enabled()
534
668
  payload = await request.json(CronCreate)
669
+ if webhook := payload.get("webhook"):
670
+ await validate_webhook_url_or_raise(str(webhook))
671
+ _validate_assistant_id(payload.get("assistant_id"))
672
+
673
+ encrypted_payload = await encrypt_request(
674
+ payload,
675
+ "cron",
676
+ CRON_PAYLOAD_ENCRYPTION_SUBFIELDS,
677
+ )
535
678
 
536
679
  async with connect() as conn:
537
680
  cron = await Crons.put(
538
681
  conn,
539
682
  thread_id=None,
683
+ on_run_completed=payload.get("on_run_completed", "delete"),
540
684
  end_time=payload.get("end_time"),
541
685
  schedule=payload.get("schedule"),
542
- payload=payload,
686
+ payload=encrypted_payload,
687
+ metadata=encrypted_payload.get("metadata"),
543
688
  )
544
- return ApiResponse(await fetchone(cron))
689
+ cron_dict = await fetchone(cron)
690
+ cron_dict = await decrypt_response(cron_dict, "cron", CRON_ENCRYPTION_FIELDS)
691
+
692
+ return ApiResponse(cron_dict)
545
693
 
546
694
 
547
695
  @retry_db
548
696
  async def create_thread_cron(request: ApiRequest):
549
697
  """Create a thread specific cron."""
698
+ _ensure_crons_enabled()
550
699
  thread_id = request.path_params["thread_id"]
551
700
  validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
552
- payload = await request.json(CronCreate)
701
+ payload = await request.json(ThreadCronCreate)
702
+ if webhook := payload.get("webhook"):
703
+ await validate_webhook_url_or_raise(str(webhook))
704
+ _validate_assistant_id(payload.get("assistant_id"))
705
+
706
+ encrypted_payload = await encrypt_request(
707
+ payload,
708
+ "cron",
709
+ CRON_PAYLOAD_ENCRYPTION_SUBFIELDS,
710
+ )
553
711
 
554
712
  async with connect() as conn:
555
713
  cron = await Crons.put(
556
714
  conn,
557
715
  thread_id=thread_id,
716
+ on_run_completed=None,
558
717
  end_time=payload.get("end_time"),
559
718
  schedule=payload.get("schedule"),
560
- payload=payload,
719
+ payload=encrypted_payload,
720
+ metadata=encrypted_payload.get("metadata"),
561
721
  )
562
- return ApiResponse(await fetchone(cron))
722
+ cron_dict = await fetchone(cron)
723
+ cron_dict = await decrypt_response(cron_dict, "cron", CRON_ENCRYPTION_FIELDS)
724
+
725
+ return ApiResponse(cron_dict)
563
726
 
564
727
 
565
728
  @retry_db
566
729
  async def delete_cron(request: ApiRequest):
567
730
  """Delete a cron by ID."""
731
+ _ensure_crons_enabled()
568
732
  cron_id = request.path_params["cron_id"]
569
733
  validate_uuid(cron_id, "Invalid cron ID: must be a UUID")
570
734
 
@@ -580,6 +744,7 @@ async def delete_cron(request: ApiRequest):
580
744
  @retry_db
581
745
  async def search_crons(request: ApiRequest):
582
746
  """List all cron jobs for an assistant"""
747
+ _ensure_crons_enabled()
583
748
  payload = await request.json(CronSearch)
584
749
  select = validate_select_columns(payload.get("select") or None, CRON_FIELDS)
585
750
  if assistant_id := payload.get("assistant_id"):
@@ -602,12 +767,16 @@ async def search_crons(request: ApiRequest):
602
767
  crons, response_headers = await get_pagination_headers(
603
768
  crons_iter, next_offset, offset
604
769
  )
770
+
771
+ crons = await decrypt_responses(crons, "cron", CRON_ENCRYPTION_FIELDS)
772
+
605
773
  return ApiResponse(crons, headers=response_headers)
606
774
 
607
775
 
608
776
  @retry_db
609
777
  async def count_crons(request: ApiRequest):
610
778
  """Count cron jobs."""
779
+ _ensure_crons_enabled()
611
780
  payload = await request.json(CronCountRequest)
612
781
  if assistant_id := payload.get("assistant_id"):
613
782
  validate_uuid(assistant_id, "Invalid assistant ID: must be a UUID")
@@ -629,21 +798,9 @@ runs_routes = [
629
798
  ApiRoute("/runs", create_stateless_run, methods=["POST"]),
630
799
  ApiRoute("/runs/batch", create_stateless_run_batch, methods=["POST"]),
631
800
  ApiRoute("/runs/cancel", cancel_runs, methods=["POST"]),
632
- (
633
- ApiRoute("/runs/crons", create_cron, methods=["POST"])
634
- if config.FF_CRONS_ENABLED and plus_features_enabled()
635
- else None
636
- ),
637
- (
638
- ApiRoute("/runs/crons/search", search_crons, methods=["POST"])
639
- if config.FF_CRONS_ENABLED and plus_features_enabled()
640
- else None
641
- ),
642
- (
643
- ApiRoute("/runs/crons/count", count_crons, methods=["POST"])
644
- if config.FF_CRONS_ENABLED and plus_features_enabled()
645
- else None
646
- ),
801
+ ApiRoute("/runs/crons", create_cron, methods=["POST"]),
802
+ ApiRoute("/runs/crons/search", search_crons, methods=["POST"]),
803
+ ApiRoute("/runs/crons/count", count_crons, methods=["POST"]),
647
804
  ApiRoute("/threads/{thread_id}/runs/{run_id}/join", join_run, methods=["GET"]),
648
805
  ApiRoute(
649
806
  "/threads/{thread_id}/runs/{run_id}/stream",
@@ -656,19 +813,9 @@ runs_routes = [
656
813
  ApiRoute("/threads/{thread_id}/runs/stream", stream_run, methods=["POST"]),
657
814
  ApiRoute("/threads/{thread_id}/runs/wait", wait_run, methods=["POST"]),
658
815
  ApiRoute("/threads/{thread_id}/runs", create_run, methods=["POST"]),
659
- (
660
- ApiRoute(
661
- "/threads/{thread_id}/runs/crons", create_thread_cron, methods=["POST"]
662
- )
663
- if config.FF_CRONS_ENABLED and plus_features_enabled()
664
- else None
665
- ),
816
+ ApiRoute("/threads/{thread_id}/runs/crons", create_thread_cron, methods=["POST"]),
666
817
  ApiRoute("/threads/{thread_id}/runs", list_runs, methods=["GET"]),
667
- (
668
- ApiRoute("/runs/crons/{cron_id}", delete_cron, methods=["DELETE"])
669
- if config.FF_CRONS_ENABLED and plus_features_enabled()
670
- else None
671
- ),
818
+ ApiRoute("/runs/crons/{cron_id}", delete_cron, methods=["DELETE"]),
672
819
  ]
673
820
 
674
821
  runs_routes = [route for route in runs_routes if route is not None]