langgraph-api 0.5.4__py3-none-any.whl → 0.7.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122):
  1. langgraph_api/__init__.py +1 -1
  2. langgraph_api/api/__init__.py +93 -27
  3. langgraph_api/api/a2a.py +36 -32
  4. langgraph_api/api/assistants.py +114 -26
  5. langgraph_api/api/mcp.py +3 -3
  6. langgraph_api/api/meta.py +15 -2
  7. langgraph_api/api/openapi.py +27 -17
  8. langgraph_api/api/profile.py +108 -0
  9. langgraph_api/api/runs.py +114 -57
  10. langgraph_api/api/store.py +19 -2
  11. langgraph_api/api/threads.py +133 -10
  12. langgraph_api/asgi_transport.py +14 -9
  13. langgraph_api/auth/custom.py +23 -13
  14. langgraph_api/cli.py +86 -41
  15. langgraph_api/command.py +2 -2
  16. langgraph_api/config/__init__.py +532 -0
  17. langgraph_api/config/_parse.py +58 -0
  18. langgraph_api/config/schemas.py +431 -0
  19. langgraph_api/cron_scheduler.py +17 -1
  20. langgraph_api/encryption/__init__.py +15 -0
  21. langgraph_api/encryption/aes_json.py +158 -0
  22. langgraph_api/encryption/context.py +35 -0
  23. langgraph_api/encryption/custom.py +280 -0
  24. langgraph_api/encryption/middleware.py +632 -0
  25. langgraph_api/encryption/shared.py +63 -0
  26. langgraph_api/errors.py +12 -1
  27. langgraph_api/executor_entrypoint.py +11 -6
  28. langgraph_api/feature_flags.py +19 -0
  29. langgraph_api/graph.py +163 -64
  30. langgraph_api/{grpc_ops → grpc}/client.py +142 -12
  31. langgraph_api/{grpc_ops → grpc}/config_conversion.py +16 -10
  32. langgraph_api/grpc/generated/__init__.py +29 -0
  33. langgraph_api/grpc/generated/checkpointer_pb2.py +63 -0
  34. langgraph_api/grpc/generated/checkpointer_pb2.pyi +99 -0
  35. langgraph_api/grpc/generated/checkpointer_pb2_grpc.py +329 -0
  36. langgraph_api/grpc/generated/core_api_pb2.py +216 -0
  37. langgraph_api/{grpc_ops → grpc}/generated/core_api_pb2.pyi +292 -372
  38. langgraph_api/{grpc_ops → grpc}/generated/core_api_pb2_grpc.py +252 -31
  39. langgraph_api/grpc/generated/engine_common_pb2.py +219 -0
  40. langgraph_api/{grpc_ops → grpc}/generated/engine_common_pb2.pyi +178 -104
  41. langgraph_api/grpc/generated/enum_cancel_run_action_pb2.py +37 -0
  42. langgraph_api/grpc/generated/enum_cancel_run_action_pb2.pyi +12 -0
  43. langgraph_api/grpc/generated/enum_cancel_run_action_pb2_grpc.py +24 -0
  44. langgraph_api/grpc/generated/enum_control_signal_pb2.py +37 -0
  45. langgraph_api/grpc/generated/enum_control_signal_pb2.pyi +16 -0
  46. langgraph_api/grpc/generated/enum_control_signal_pb2_grpc.py +24 -0
  47. langgraph_api/grpc/generated/enum_durability_pb2.py +37 -0
  48. langgraph_api/grpc/generated/enum_durability_pb2.pyi +16 -0
  49. langgraph_api/grpc/generated/enum_durability_pb2_grpc.py +24 -0
  50. langgraph_api/grpc/generated/enum_multitask_strategy_pb2.py +37 -0
  51. langgraph_api/grpc/generated/enum_multitask_strategy_pb2.pyi +16 -0
  52. langgraph_api/grpc/generated/enum_multitask_strategy_pb2_grpc.py +24 -0
  53. langgraph_api/grpc/generated/enum_run_status_pb2.py +37 -0
  54. langgraph_api/grpc/generated/enum_run_status_pb2.pyi +22 -0
  55. langgraph_api/grpc/generated/enum_run_status_pb2_grpc.py +24 -0
  56. langgraph_api/grpc/generated/enum_stream_mode_pb2.py +37 -0
  57. langgraph_api/grpc/generated/enum_stream_mode_pb2.pyi +28 -0
  58. langgraph_api/grpc/generated/enum_stream_mode_pb2_grpc.py +24 -0
  59. langgraph_api/grpc/generated/enum_thread_status_pb2.py +37 -0
  60. langgraph_api/grpc/generated/enum_thread_status_pb2.pyi +16 -0
  61. langgraph_api/grpc/generated/enum_thread_status_pb2_grpc.py +24 -0
  62. langgraph_api/grpc/generated/enum_thread_stream_mode_pb2.py +37 -0
  63. langgraph_api/grpc/generated/enum_thread_stream_mode_pb2.pyi +16 -0
  64. langgraph_api/grpc/generated/enum_thread_stream_mode_pb2_grpc.py +24 -0
  65. langgraph_api/grpc/generated/errors_pb2.py +39 -0
  66. langgraph_api/grpc/generated/errors_pb2.pyi +21 -0
  67. langgraph_api/grpc/generated/errors_pb2_grpc.py +24 -0
  68. langgraph_api/grpc/ops/__init__.py +370 -0
  69. langgraph_api/grpc/ops/assistants.py +424 -0
  70. langgraph_api/grpc/ops/runs.py +792 -0
  71. langgraph_api/grpc/ops/threads.py +1013 -0
  72. langgraph_api/http.py +16 -5
  73. langgraph_api/js/client.mts +1 -4
  74. langgraph_api/js/package.json +28 -27
  75. langgraph_api/js/remote.py +39 -17
  76. langgraph_api/js/sse.py +2 -2
  77. langgraph_api/js/ui.py +1 -1
  78. langgraph_api/js/yarn.lock +1139 -869
  79. langgraph_api/metadata.py +29 -3
  80. langgraph_api/middleware/http_logger.py +1 -1
  81. langgraph_api/middleware/private_network.py +7 -7
  82. langgraph_api/models/run.py +44 -26
  83. langgraph_api/otel_context.py +205 -0
  84. langgraph_api/patch.py +2 -2
  85. langgraph_api/queue_entrypoint.py +34 -35
  86. langgraph_api/route.py +33 -1
  87. langgraph_api/schema.py +84 -9
  88. langgraph_api/self_hosted_logs.py +2 -2
  89. langgraph_api/self_hosted_metrics.py +73 -3
  90. langgraph_api/serde.py +16 -4
  91. langgraph_api/server.py +33 -31
  92. langgraph_api/state.py +3 -2
  93. langgraph_api/store.py +25 -16
  94. langgraph_api/stream.py +20 -16
  95. langgraph_api/thread_ttl.py +28 -13
  96. langgraph_api/timing/__init__.py +25 -0
  97. langgraph_api/timing/profiler.py +200 -0
  98. langgraph_api/timing/timer.py +318 -0
  99. langgraph_api/utils/__init__.py +53 -8
  100. langgraph_api/utils/config.py +2 -1
  101. langgraph_api/utils/future.py +10 -6
  102. langgraph_api/utils/uuids.py +29 -62
  103. langgraph_api/validation.py +6 -0
  104. langgraph_api/webhook.py +120 -6
  105. langgraph_api/worker.py +54 -24
  106. {langgraph_api-0.5.4.dist-info → langgraph_api-0.7.3.dist-info}/METADATA +8 -6
  107. langgraph_api-0.7.3.dist-info/RECORD +168 -0
  108. {langgraph_api-0.5.4.dist-info → langgraph_api-0.7.3.dist-info}/WHEEL +1 -1
  109. langgraph_runtime/__init__.py +1 -0
  110. langgraph_runtime/routes.py +11 -0
  111. logging.json +1 -3
  112. openapi.json +635 -537
  113. langgraph_api/config.py +0 -523
  114. langgraph_api/grpc_ops/generated/__init__.py +0 -5
  115. langgraph_api/grpc_ops/generated/core_api_pb2.py +0 -275
  116. langgraph_api/grpc_ops/generated/engine_common_pb2.py +0 -194
  117. langgraph_api/grpc_ops/ops.py +0 -1045
  118. langgraph_api-0.5.4.dist-info/RECORD +0 -121
  119. /langgraph_api/{grpc_ops → grpc}/__init__.py +0 -0
  120. /langgraph_api/{grpc_ops → grpc}/generated/engine_common_pb2_grpc.py +0 -0
  121. {langgraph_api-0.5.4.dist-info → langgraph_api-0.7.3.dist-info}/entry_points.txt +0 -0
  122. {langgraph_api-0.5.4.dist-info → langgraph_api-0.7.3.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,792 @@
1
+ """gRPC-based runs operations."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from contextlib import asynccontextmanager
6
+ from datetime import UTC
7
+ from http import HTTPStatus
8
+ from typing import TYPE_CHECKING, Any, Literal
9
+ from uuid import UUID
10
+
11
+ import structlog
12
+ from google.protobuf.empty_pb2 import Empty # type: ignore[import]
13
+ from grpc import StatusCode
14
+ from grpc.aio import AioRpcError
15
+ from langgraph_sdk import Auth
16
+ from starlette.exceptions import HTTPException
17
+
18
+ from langgraph_api.asyncio import SimpleTaskGroup, ValueEvent
19
+ from langgraph_api.errors import UserInterrupt, UserRollback
20
+ from langgraph_api.grpc.client import get_shared_client
21
+ from langgraph_api.grpc.config_conversion import config_from_proto
22
+ from langgraph_api.grpc.generated import (
23
+ core_api_pb2 as pb,
24
+ )
25
+ from langgraph_api.grpc.generated import (
26
+ enum_cancel_run_action_pb2 as enum_cancel_run_action,
27
+ )
28
+ from langgraph_api.grpc.generated import (
29
+ enum_control_signal_pb2 as enum_control_signal,
30
+ )
31
+ from langgraph_api.grpc.generated import (
32
+ enum_multitask_strategy_pb2 as enum_multitask_strategy,
33
+ )
34
+ from langgraph_api.grpc.generated import (
35
+ enum_run_status_pb2 as enum_run_status,
36
+ )
37
+ from langgraph_api.grpc.generated import (
38
+ enum_stream_mode_pb2 as enum_stream_mode,
39
+ )
40
+ from langgraph_api.grpc.ops import (
41
+ Authenticated,
42
+ grpc_error_guard,
43
+ )
44
+ from langgraph_api.serde import json_dumpb, json_dumpb_optional, json_loads_optional
45
+
46
+ if TYPE_CHECKING:
47
+ from collections.abc import AsyncIterator, Sequence
48
+
49
+ from langgraph_api.schema import (
50
+ IfNotExists,
51
+ MetadataInput,
52
+ MultitaskStrategy,
53
+ QueueStats,
54
+ Run,
55
+ RunSelectField,
56
+ RunStatus,
57
+ )
58
+
59
+
60
# Run status names (as used by the API schema) mapped to protobuf enum values.
RUN_STATUS_TO_PB = dict(
    pending=enum_run_status.pending,
    running=enum_run_status.running,
    error=enum_run_status.error,
    success=enum_run_status.success,
    timeout=enum_run_status.timeout,
    interrupted=enum_run_status.interrupted,
    # "rollback" is a pseudo-status that is not exposed to the user; it is
    # used internally to indicate a rollback. It should never be persisted,
    # so it should never be returned from the API.
    rollback=enum_run_status.rollback,
)

# Reverse lookup: protobuf enum value -> status name.
RUN_STATUS_FROM_PB = {pb_value: name for name, pb_value in RUN_STATUS_TO_PB.items()}

# Target statuses accepted by ``Runs.cancel(status=...)``, mapped to protobuf values.
CANCEL_STATUS_TO_PB = dict(
    pending=pb.CancelRunStatus.CANCEL_RUN_STATUS_PENDING,
    running=pb.CancelRunStatus.CANCEL_RUN_STATUS_RUNNING,
    all=pb.CancelRunStatus.CANCEL_RUN_STATUS_ALL,
)
80
+
81
+
82
+ def _map_run_status(status: RunStatus | None) -> enum_run_status.RunStatus | None:
83
+ """Map string status to protobuf enum."""
84
+ return None if status is None else RUN_STATUS_TO_PB.get(status)
85
+
86
+
87
# Multitask strategy names mapped to protobuf enum values.
MULTITASK_STRATEGY_TO_PB = dict(
    reject=enum_multitask_strategy.reject,
    interrupt=enum_multitask_strategy.interrupt,
    rollback=enum_multitask_strategy.rollback,
    enqueue=enum_multitask_strategy.enqueue,
)

# Reverse lookup: protobuf enum value -> strategy name.
MULTITASK_STRATEGY_FROM_PB = {
    pb_value: name for name, pb_value in MULTITASK_STRATEGY_TO_PB.items()
}

# Stream mode names mapped to protobuf enum values. "messages-tuple" is not a
# valid Python identifier, so this table is written as a dict literal.
STREAM_MODE_TO_PB = {
    "unknown": enum_stream_mode.unknown,
    "values": enum_stream_mode.values,
    "updates": enum_stream_mode.updates,
    "checkpoints": enum_stream_mode.checkpoints,
    "tasks": enum_stream_mode.tasks,
    "debug": enum_stream_mode.debug,
    "messages": enum_stream_mode.messages,
    "custom": enum_stream_mode.custom,
    "events": enum_stream_mode.events,
    "messages-tuple": enum_stream_mode.messages_tuple,
}

# Reverse lookup: protobuf enum value -> stream mode name.
STREAM_MODE_FROM_PB: dict[Any, str | None] = {
    pb_value: name for name, pb_value in STREAM_MODE_TO_PB.items()
}
# ``unknown`` is only a placeholder in the protobuf definition, not a real
# stream mode. If the gRPC server ever sends it, map it to None so it is
# suppressed instead of being exposed to the user.
STREAM_MODE_FROM_PB[enum_stream_mode.unknown] = None


logger = structlog.stdlib.get_logger(__name__)
120
+
121
+
122
class GrpcRetryableException(Exception):
    """Signals a gRPC failure after which the run should be retried.

    ``Runs.enter`` stores an instance of this on the run's ``done`` event when
    reading the control-signal stream fails.
    """
126
+
127
+
128
# gRPC status codes treated as transient. ``Runs.enter`` does not mark a run
# done when it sees an ``AioRpcError`` carrying one of these codes.
GRPC_RETRIABLE_STATUS_CODES = (StatusCode.UNAVAILABLE, StatusCode.DEADLINE_EXCEEDED)
132
+
133
+
134
+ def _map_multitask_strategy(
135
+ strategy: MultitaskStrategy | None,
136
+ ) -> enum_multitask_strategy.MultitaskStrategy | None:
137
+ """Map string multitask strategy to protobuf enum."""
138
+ return None if strategy is None else MULTITASK_STRATEGY_TO_PB.get(strategy)
139
+
140
+
141
+ def _map_if_not_exists(
142
+ if_not_exists: IfNotExists | None,
143
+ ) -> pb.CreateRunBehavior | None:
144
+ """Map if_not_exists string to protobuf enum."""
145
+ if if_not_exists is None:
146
+ return None
147
+ return (
148
+ pb.CreateRunBehavior.CREATE_THREAD_IF_THREAD_NOT_EXISTS
149
+ if if_not_exists == "create"
150
+ else pb.CreateRunBehavior.REJECT_RUN_IF_THREAD_NOT_EXISTS
151
+ )
152
+
153
+
154
def proto_to_run(proto_run: pb.Run) -> Run:
    """Convert a protobuf ``Run`` message into the API's run dictionary.

    Unset message fields become None (ids, timestamps) or an empty dict
    (metadata, kwargs). An unrecognized status falls back to ``"pending"``,
    and an unrecognized multitask strategy becomes None. Timestamps are
    returned as timezone-aware UTC datetimes.
    """

    def uuid_or_none(field: str) -> UUID | None:
        # UUID-valued fields are wrapper messages carrying a string ``value``.
        if not proto_run.HasField(field):
            return None
        return UUID(getattr(proto_run, field).value)

    def utc_datetime_or_none(field: str):
        if not proto_run.HasField(field):
            return None
        return getattr(proto_run, field).ToDatetime(tzinfo=UTC)

    has_metadata = proto_run.HasField("metadata")
    has_kwargs = proto_run.HasField("kwargs")

    return {
        "run_id": uuid_or_none("run_id"),
        "thread_id": uuid_or_none("thread_id"),
        "assistant_id": uuid_or_none("assistant_id"),
        "created_at": utc_datetime_or_none("created_at"),
        "updated_at": utc_datetime_or_none("updated_at"),
        "status": RUN_STATUS_FROM_PB.get(proto_run.status, "pending"),
        "metadata": (
            json_loads_optional(proto_run.metadata.value) if has_metadata else {}
        ),
        "kwargs": _proto_kwargs_to_dict(proto_run.kwargs) if has_kwargs else {},
        "multitask_strategy": MULTITASK_STRATEGY_FROM_PB.get(
            proto_run.multitask_strategy
        ),
    }
183
+
184
+
185
def _proto_kwargs_to_dict(kwargs: pb.RunKwargs) -> dict:
    """Convert a protobuf ``RunKwargs`` message into a plain kwargs dict.

    Unset optional fields fall back to defaults:
    - None for most values;
    - False for ``temporary``, ``subgraphs`` and ``resumable``;
    - True for ``checkpoint_during``.

    JSON-encoded fields (input, context, command) are decoded.
    """
    has = kwargs.HasField

    def json_field(name: str):
        return json_loads_optional(getattr(kwargs, name)) if has(name) else None

    def field_or(name: str, default):
        return getattr(kwargs, name) if has(name) else default

    result: dict = {}
    result["input"] = json_field("input_json")
    result["config"] = (
        dict(config_from_proto(kwargs.config)) if has("config") else None
    )
    result["context"] = json_field("context_json")
    result["command"] = json_field("command_json")
    # ``stream_mode`` is read as a single enum value and wrapped in a list.
    # A zero/unset value yields None.
    result["stream_mode"] = (
        [STREAM_MODE_FROM_PB.get(kwargs.stream_mode)] if kwargs.stream_mode else None
    )
    result["interrupt_before"] = (
        list(kwargs.interrupt_before.node_names.names)
        if has("interrupt_before")
        else None
    )
    result["interrupt_after"] = (
        list(kwargs.interrupt_after.node_names.names)
        if has("interrupt_after")
        else None
    )
    result["webhook"] = field_or("webhook", None)
    result["feedback_keys"] = (
        list(kwargs.feedback_keys) if kwargs.feedback_keys else None
    )
    result["temporary"] = field_or("temporary", False)
    result["subgraphs"] = field_or("subgraphs", False)
    result["resumable"] = field_or("resumable", False)
    result["checkpoint_during"] = field_or("checkpoint_during", True)
    # NOTE(review): durability is passed through as the raw protobuf value.
    # Status, strategy and stream_mode are mapped back to strings; confirm
    # callers expect the raw value here.
    result["durability"] = field_or("durability", None)
    return result
220
+
221
+
222
+ def _filter_run_fields(run: Run, select: list[RunSelectField] | None) -> dict[str, Any]:
223
+ """Filter run fields based on select list.
224
+
225
+ Returns the original run if no fields are provided."""
226
+ if not select:
227
+ return run
228
+ return {field: run[field] for field in select if field in run}
229
+
230
+
231
def _mark_done_request(
    run_id: UUID,
    thread_id: UUID,
    resumable: bool,
) -> pb.MarkRunDoneRequest:
    """Build the MarkRunDone request shared by ``Runs.mark_done`` and ``Runs._mark_done``."""
    return pb.MarkRunDoneRequest(
        run_id=pb.UUID(value=str(run_id)),
        thread_id=pb.UUID(value=str(thread_id)),
        resumable=resumable,
    )


@grpc_error_guard
class Runs(Authenticated):
    """gRPC-based runs operations.

    Every operation is forwarded to the shared gRPC client. The ``conn``
    parameters exist only for signature compatibility and are ignored.
    """

    # Auth for runs is applied at the thread level.
    resource = "threads"

    @staticmethod
    async def search(
        conn,  # Not used in gRPC implementation
        thread_id: UUID,
        *,
        limit: int = 10,
        offset: int = 0,
        status: RunStatus | None = None,
        select: list[RunSelectField] | None = None,
        ctx: Any = None,
    ) -> AsyncIterator[Run]:  # type: ignore[return-value]
        """List the runs of a thread.

        Args:
            conn: Ignored.
            thread_id: Thread whose runs are listed.
            limit: Maximum number of runs to request.
            offset: Number of runs to skip.
            status: Optional status filter. Unrecognized values are not sent.
            select: Optional list of fields to keep in each returned run.
            ctx: Auth context passed to the ``search`` auth handler.

        Returns:
            An async iterator over the matching runs. All runs are fetched in
            a single RPC before this coroutine returns.
        """
        auth_filters = await Runs.handle_event(
            ctx,
            "search",
            Auth.types.ThreadsSearch(thread_id=thread_id, metadata={}),
        )

        request_kwargs: dict[str, Any] = {
            "filters": auth_filters,
            "thread_id": pb.UUID(value=str(thread_id)),
            "limit": limit,
            "offset": offset,
        }

        mapped_status = _map_run_status(status)
        if mapped_status is not None:
            request_kwargs["status"] = mapped_status

        if select:
            request_kwargs["select"] = select

        client = await get_shared_client()
        response = await client.runs.Search(pb.SearchRunsRequest(**request_kwargs))

        runs = [proto_to_run(run) for run in response.runs]

        async def generate_results():
            for run in runs:
                yield _filter_run_fields(run, select)

        return generate_results()

    @staticmethod
    async def count(
        *,
        thread_id: UUID | str,
        statuses: list[str] | None = None,
    ) -> int:
        """Count runs matching criteria.

        This is an internal method with no auth. It is used to check whether
        a thread has pending/running runs.

        Args:
            thread_id: Thread ID to count runs for.
            statuses: Optional list of statuses to filter by
                (e.g., ["pending", "running"]).

        Returns:
            Count of matching runs.
        """
        request = pb.CountRunsRequest(
            thread_id=pb.UUID(value=str(thread_id)),
            statuses=statuses or [],
        )

        client = await get_shared_client()
        response = await client.runs.Count(request)

        return int(response.count)

    @staticmethod
    async def get(
        conn,  # Not used in gRPC implementation
        run_id: UUID,
        *,
        thread_id: UUID,
        ctx: Any = None,
    ) -> AsyncIterator[Run]:  # type: ignore[return-value]
        """Get a run by ID.

        Returns an async iterator yielding the single fetched run. Auth is
        checked through the ``read`` thread event.
        """
        auth_filters = await Runs.handle_event(
            ctx,
            "read",
            Auth.types.ThreadsRead(run_id=run_id, thread_id=thread_id),
        )

        request = pb.GetRunRequest(
            run_id=pb.UUID(value=str(run_id)),
            thread_id=pb.UUID(value=str(thread_id)),
            filters=auth_filters,
        )

        client = await get_shared_client()
        response = await client.runs.Get(request)

        run = proto_to_run(response)

        async def generate_result():
            yield run

        return generate_result()

    @staticmethod
    async def delete(
        conn,  # Not used in gRPC implementation
        run_id: UUID,
        *,
        thread_id: UUID,
        ctx: Any = None,
    ) -> AsyncIterator[UUID]:  # type: ignore[return-value]
        """Delete a run by ID.

        Returns an async iterator yielding the deleted run's ID as reported
        by the server.
        """
        auth_filters = await Runs.handle_event(
            ctx,
            "delete",
            Auth.types.ThreadsDelete(run_id=run_id, thread_id=thread_id),
        )

        request = pb.DeleteRunRequest(
            run_id=pb.UUID(value=str(run_id)),
            thread_id=pb.UUID(value=str(thread_id)),
            filters=auth_filters,
        )

        client = await get_shared_client()
        response = await client.runs.Delete(request)

        deleted_id = UUID(response.value)

        async def generate_result():
            yield deleted_id

        return generate_result()

    @staticmethod
    async def put(
        conn,  # Not used in gRPC implementation
        assistant_id: UUID,
        kwargs: dict,
        *,
        thread_id: UUID | None = None,
        user_id: str | None = None,
        run_id: UUID | None = None,
        status: RunStatus | None = "pending",
        metadata: MetadataInput,
        prevent_insert_if_inflight: bool,
        multitask_strategy: MultitaskStrategy = "reject",
        if_not_exists: IfNotExists = "reject",
        after_seconds: int = 0,
        ctx: Any = None,
    ) -> AsyncIterator[Run]:  # type: ignore[return-value]
        """Create a run.

        ``kwargs`` is sent JSON-encoded. Optional request fields are only set
        when they carry a non-default value.

        Returns an async iterator over the runs included in the server's
        response.
        """
        metadata = metadata or {}
        kwargs = kwargs or {}
        temporary = kwargs.get("temporary", False)

        auth_filters = await Runs.handle_event(
            ctx,
            "create_run",
            Auth.types.RunsCreate(
                # Temporary runs are authorized without a thread_id.
                thread_id=None if temporary else thread_id,
                assistant_id=assistant_id,
                run_id=run_id,
                status=status,
                metadata=metadata,
                prevent_insert_if_inflight=prevent_insert_if_inflight,
                multitask_strategy=multitask_strategy,
                if_not_exists=if_not_exists,
                after_seconds=after_seconds,
                kwargs=kwargs,
            ),
        )

        kwargs_json_bytes = json_dumpb(kwargs)
        request_kwargs: dict[str, Any] = {
            "assistant_id": pb.UUID(value=str(assistant_id)),
            "kwargs_json": kwargs_json_bytes,
            "thread_filters": auth_filters,
        }

        if thread_id is not None:
            request_kwargs["thread_id"] = pb.UUID(value=str(thread_id))
        if user_id is not None:
            request_kwargs["user_id"] = user_id
        if run_id is not None:
            request_kwargs["run_id"] = pb.UUID(value=str(run_id))

        mapped_status = _map_run_status(status)
        if mapped_status is not None:
            request_kwargs["status"] = mapped_status
        if metadata:
            request_kwargs["metadata_json"] = json_dumpb_optional(metadata)
        if prevent_insert_if_inflight:
            request_kwargs["prevent_insert_if_inflight"] = prevent_insert_if_inflight

        mapped_strategy = _map_multitask_strategy(multitask_strategy)
        if mapped_strategy is not None:
            request_kwargs["multitask_strategy"] = mapped_strategy

        mapped_if_not_exists = _map_if_not_exists(if_not_exists)
        if mapped_if_not_exists is not None:
            request_kwargs["if_not_exists"] = mapped_if_not_exists

        if after_seconds > 0:
            request_kwargs["after_seconds"] = int(after_seconds)

        client = await get_shared_client()
        response = await client.runs.Create(pb.CreateRunRequest(**request_kwargs))

        async def generate_result():
            for run in response.runs:
                yield proto_to_run(run)

        return generate_result()

    @staticmethod
    async def cancel(
        conn,  # Not used in gRPC implementation
        run_ids: Sequence[UUID] | None = None,
        *,
        action: Literal["interrupt", "rollback"] = "interrupt",
        thread_id: UUID | None = None,
        status: Literal["pending", "running", "all"] | None = None,
        ctx: Any = None,
    ) -> None:
        """Cancel runs.

        Must provide either:
        1) thread_id + run_ids, or
        2) a status (pending, running, all).

        Raises:
            HTTPException: 422 if the arguments do not match exactly one of
                the two forms above.
        """
        if status is not None:
            if thread_id is not None or run_ids is not None:
                raise HTTPException(
                    status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
                    detail="Cannot specify 'thread_id' or 'run_ids' when using 'status'",
                )
        elif thread_id is None or run_ids is None:
            raise HTTPException(
                status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
                detail="Please provide a thread_id and run_ids, or a status to cancel",
            )

        auth_filters = await Runs.handle_event(
            ctx,
            "update",
            Auth.types.ThreadsUpdate(
                thread_id=thread_id,  # type: ignore
                action=action,
                metadata={"run_ids": run_ids, "status": status},
            ),
        )

        action_enum = (
            enum_cancel_run_action.rollback
            if action == "rollback"
            else enum_cancel_run_action.interrupt
        )

        request_kwargs: dict[str, Any] = {
            "filters": auth_filters,
            "action": action_enum,
        }

        if status is not None:
            request_kwargs["status"] = pb.CancelStatusTarget(
                status=CANCEL_STATUS_TO_PB[status]
            )
        else:
            request_kwargs["run_ids"] = pb.CancelRunIdsTarget(
                thread_id=pb.UUID(value=str(thread_id)),
                run_ids=[pb.UUID(value=str(rid)) for rid in run_ids],  # type: ignore
            )

        client = await get_shared_client()
        await client.runs.Cancel(pb.CancelRunRequest(**request_kwargs))

    @staticmethod
    async def stats(conn) -> QueueStats:  # type: ignore[return-value]
        """Get queue statistics (not exposed via API, no auth).

        Wait-time fields are None when the server leaves them unset.
        """
        client = await get_shared_client()
        response = await client.runs.Stats(Empty())

        return {
            "n_pending": response.n_pending,
            "n_running": response.n_running,
            "pending_runs_wait_time_max_secs": (
                response.pending_runs_wait_time_max_secs
                if response.HasField("pending_runs_wait_time_max_secs")
                else None
            ),
            "pending_runs_wait_time_med_secs": (
                response.pending_runs_wait_time_med_secs
                if response.HasField("pending_runs_wait_time_med_secs")
                else None
            ),
            "pending_unblocked_runs_wait_time_max_secs": (
                response.pending_unblocked_runs_wait_time_max_secs
                if response.HasField("pending_unblocked_runs_wait_time_max_secs")
                else None
            ),
        }

    @staticmethod
    async def set_status(
        conn,  # Not used in gRPC implementation
        run_id: UUID,
        status: RunStatus,
    ) -> None:
        """Set the status of a run (not exposed via API, no auth).

        An unrecognized status is silently ignored: no RPC is made.
        """
        mapped_status = _map_run_status(status)
        if mapped_status is None:
            return

        request = pb.SetRunStatusRequest(
            run_id=pb.UUID(value=str(run_id)),
            status=mapped_status,
        )

        client = await get_shared_client()
        await client.runs.SetStatus(request)

    @staticmethod
    async def sweep() -> list[UUID]:
        """Sweep runs that have been in running state for too long (not exposed via API, no auth).

        Returns the run IDs reported by the server's Sweep RPC.
        """
        client = await get_shared_client()
        response = await client.runs.Sweep(Empty())

        return [UUID(uuid_pb.value) for uuid_pb in response.run_ids]

    @staticmethod
    async def _mark_done(run_id: UUID, thread_id: UUID, resumable: bool) -> None:
        """Send the MarkDone RPC for a run.

        Internal method used by workers. Not exposed via API, no auth.
        """
        client = await get_shared_client()
        await client.runs.MarkDone(_mark_done_request(run_id, thread_id, resumable))

    @staticmethod
    async def next(wait: bool, limit: int = 1) -> AsyncIterator[tuple[Run, int]]:  # type: ignore[return-value]
        """Get the next run from the queue, and the attempt number.

        1 is the first attempt, 2 is the first retry, etc.

        Not exposed via API, no auth.
        """
        request = pb.NextRunRequest(wait=wait, limit=limit)

        client = await get_shared_client()
        response = await client.runs.Next(request)

        async def generate_results():
            for run_with_attempt in response.runs:
                run = proto_to_run(run_with_attempt.run)
                yield run, run_with_attempt.attempt

        return generate_results()

    class Stream(Authenticated):
        """Stream operations for runs.

        Most operations are not yet implemented for gRPC and raise
        ``NotImplementedError``.
        """

        resource = "threads"

        @staticmethod
        async def subscribe(
            run_id: UUID,
            thread_id: UUID | None = None,
        ):
            """Subscribe to the run stream, returning a stream handler.

            The stream handler must be passed to `join` to receive messages.
            """
            # TODO: Implement gRPC streaming subscription
            raise NotImplementedError("Stream.subscribe not yet implemented for gRPC")

        @staticmethod
        async def join(
            run_id: UUID,
            *,
            stream_channel,
            thread_id: UUID,
            ignore_404: bool = False,
            cancel_on_disconnect: bool = False,
            stream_mode=None,
            last_event_id: str | None = None,
            ctx: Any = None,
        ):
            """Stream the run output."""
            # TODO: Implement gRPC streaming join
            raise NotImplementedError("Stream.join not yet implemented for gRPC")

        @staticmethod
        async def check_run_stream_auth(
            run_id: UUID,
            thread_id: UUID,
            ctx: Any = None,
        ) -> None:
            """Check auth for streaming a run."""
            # TODO: Implement auth check for gRPC streaming
            raise NotImplementedError(
                "Stream.check_run_stream_auth not yet implemented for gRPC"
            )

        @staticmethod
        async def publish(
            run_id: UUID | str,
            event: str,
            message: bytes,
            *,
            thread_id: UUID | str | None = None,
            resumable: bool = False,
        ) -> None:
            """Publish a message to the run stream."""
            # TODO: Implement gRPC stream publishing
            raise NotImplementedError("Stream.publish not yet implemented for gRPC")

    @staticmethod
    def enter(
        run_id: UUID,
        thread_id: UUID,
        loop: Any,  # unused, for API compatibility
        resumable: bool,
    ):
        """Return an async context manager wrapping a run's execution.

        On entry it opens the streaming Enter RPC and starts a background task
        that reads control signals from it. The context yields a ``ValueEvent``
        that is set to:
        - ``UserInterrupt()`` when an interrupt signal arrives;
        - ``UserRollback()`` when a rollback signal arrives;
        - a ``GrpcRetryableException`` if reading the stream fails.

        On a clean exit the run is marked done via ``Runs.mark_done``.
        Exceptions from the context body or from ``mark_done`` are handled
        as follows:
        - ``GrpcRetryableException``, or an ``AioRpcError`` with a code in
          ``GRPC_RETRIABLE_STATUS_CODES``: logged and suppressed, and the run
          is not marked done.
        - Anything else: logged and re-raised.

        NOTE(review): because retriable errors are suppressed, the caller's
        ``async with`` completes normally in that case. Confirm this is
        intended.

        The Enter RPC is cancelled when the context exits. That call is a
        no-op if the RPC has already finished.

        Args:
            run_id: The run ID.
            thread_id: The thread ID.
            loop: The event loop (unused in gRPC implementation).
            resumable: Whether the run is resumable.
        """

        @asynccontextmanager
        async def _enter_impl():
            done = ValueEvent()

            # Open streaming Enter RPC
            client = await get_shared_client()
            enter_request = pb.EnterRunRequest(
                run_id=pb.UUID(value=str(run_id)),
                thread_id=pb.UUID(value=str(thread_id)),
                resumable=resumable,
            )
            enter_stream = client.runs.Enter(enter_request)

            try:
                async with SimpleTaskGroup(
                    cancel=True, taskgroup_name="Runs.enter"
                ) as tg:
                    # Background task to listen for control signals from the stream
                    async def listen_for_signals():
                        try:
                            async for event in enter_stream:
                                if event.action == enum_control_signal.interrupt:
                                    logger.info(
                                        "Received interrupt signal from gRPC stream",
                                        run_id=run_id,
                                        thread_id=thread_id,
                                    )
                                    done.set(UserInterrupt())
                                    break
                                elif event.action == enum_control_signal.rollback:
                                    logger.info(
                                        "Received rollback signal from gRPC stream",
                                        run_id=run_id,
                                        thread_id=thread_id,
                                    )
                                    done.set(UserRollback())
                                    break
                        except Exception as exc:
                            logger.exception(
                                "listen_for_signals failed",
                                run_id=run_id,
                                thread_id=thread_id,
                                exc_info=exc,
                            )
                            done.set(
                                GrpcRetryableException(
                                    f"listen_for_signals failed. \nError: {exc!r}"
                                )
                            )
                            raise

                    tg.create_task(listen_for_signals())

                    try:
                        yield done
                        # Signal done via gRPC (stops heartbeat and cleanup on server)
                        await Runs.mark_done(
                            run_id=run_id, thread_id=thread_id, resumable=resumable
                        )
                    except GrpcRetryableException:
                        logger.info(
                            "Retriable exception, will not signal done",
                            run_id=run_id,
                            thread_id=thread_id,
                        )
                    except AioRpcError as e:
                        if e.code() in GRPC_RETRIABLE_STATUS_CODES:
                            logger.info(
                                "Retriable gRPC error, will not signal done",
                                run_id=run_id,
                                thread_id=thread_id,
                                grpc_code=e.code().name,
                            )
                        else:
                            logger.exception(
                                "Non-retriable gRPC error when marking run as done",
                                run_id=run_id,
                                thread_id=thread_id,
                                grpc_code=e.code().name,
                            )
                            raise
                    except BaseException:
                        logger.exception(
                            "Non-retriable exception when marking run as done",
                            run_id=run_id,
                            thread_id=thread_id,
                        )
                        raise
            finally:
                # The listener stops reading after an interrupt/rollback signal
                # (it breaks out of the loop), which left the Enter RPC open.
                # Cancel it explicitly so the stream is always released when
                # the context exits; cancel() is a no-op on a finished call.
                enter_stream.cancel()

        return _enter_impl()

    @staticmethod
    @grpc_error_guard
    async def mark_done(
        run_id: UUID,
        thread_id: UUID,
        resumable: bool,
    ) -> None:
        """Mark a run as done.

        Args:
            run_id: The run ID.
            thread_id: The thread ID.
            resumable: Whether streaming is resumable.
        """
        client = await get_shared_client()
        await client.runs.MarkDone(_mark_done_request(run_id, thread_id, resumable))