langgraph-api 0.5.4__py3-none-any.whl → 0.7.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. langgraph_api/__init__.py +1 -1
  2. langgraph_api/api/__init__.py +93 -27
  3. langgraph_api/api/a2a.py +36 -32
  4. langgraph_api/api/assistants.py +114 -26
  5. langgraph_api/api/mcp.py +3 -3
  6. langgraph_api/api/meta.py +15 -2
  7. langgraph_api/api/openapi.py +27 -17
  8. langgraph_api/api/profile.py +108 -0
  9. langgraph_api/api/runs.py +114 -57
  10. langgraph_api/api/store.py +19 -2
  11. langgraph_api/api/threads.py +133 -10
  12. langgraph_api/asgi_transport.py +14 -9
  13. langgraph_api/auth/custom.py +23 -13
  14. langgraph_api/cli.py +86 -41
  15. langgraph_api/command.py +2 -2
  16. langgraph_api/config/__init__.py +532 -0
  17. langgraph_api/config/_parse.py +58 -0
  18. langgraph_api/config/schemas.py +431 -0
  19. langgraph_api/cron_scheduler.py +17 -1
  20. langgraph_api/encryption/__init__.py +15 -0
  21. langgraph_api/encryption/aes_json.py +158 -0
  22. langgraph_api/encryption/context.py +35 -0
  23. langgraph_api/encryption/custom.py +280 -0
  24. langgraph_api/encryption/middleware.py +632 -0
  25. langgraph_api/encryption/shared.py +63 -0
  26. langgraph_api/errors.py +12 -1
  27. langgraph_api/executor_entrypoint.py +11 -6
  28. langgraph_api/feature_flags.py +19 -0
  29. langgraph_api/graph.py +163 -64
  30. langgraph_api/{grpc_ops → grpc}/client.py +142 -12
  31. langgraph_api/{grpc_ops → grpc}/config_conversion.py +16 -10
  32. langgraph_api/grpc/generated/__init__.py +29 -0
  33. langgraph_api/grpc/generated/checkpointer_pb2.py +63 -0
  34. langgraph_api/grpc/generated/checkpointer_pb2.pyi +99 -0
  35. langgraph_api/grpc/generated/checkpointer_pb2_grpc.py +329 -0
  36. langgraph_api/grpc/generated/core_api_pb2.py +216 -0
  37. langgraph_api/{grpc_ops → grpc}/generated/core_api_pb2.pyi +292 -372
  38. langgraph_api/{grpc_ops → grpc}/generated/core_api_pb2_grpc.py +252 -31
  39. langgraph_api/grpc/generated/engine_common_pb2.py +219 -0
  40. langgraph_api/{grpc_ops → grpc}/generated/engine_common_pb2.pyi +178 -104
  41. langgraph_api/grpc/generated/enum_cancel_run_action_pb2.py +37 -0
  42. langgraph_api/grpc/generated/enum_cancel_run_action_pb2.pyi +12 -0
  43. langgraph_api/grpc/generated/enum_cancel_run_action_pb2_grpc.py +24 -0
  44. langgraph_api/grpc/generated/enum_control_signal_pb2.py +37 -0
  45. langgraph_api/grpc/generated/enum_control_signal_pb2.pyi +16 -0
  46. langgraph_api/grpc/generated/enum_control_signal_pb2_grpc.py +24 -0
  47. langgraph_api/grpc/generated/enum_durability_pb2.py +37 -0
  48. langgraph_api/grpc/generated/enum_durability_pb2.pyi +16 -0
  49. langgraph_api/grpc/generated/enum_durability_pb2_grpc.py +24 -0
  50. langgraph_api/grpc/generated/enum_multitask_strategy_pb2.py +37 -0
  51. langgraph_api/grpc/generated/enum_multitask_strategy_pb2.pyi +16 -0
  52. langgraph_api/grpc/generated/enum_multitask_strategy_pb2_grpc.py +24 -0
  53. langgraph_api/grpc/generated/enum_run_status_pb2.py +37 -0
  54. langgraph_api/grpc/generated/enum_run_status_pb2.pyi +22 -0
  55. langgraph_api/grpc/generated/enum_run_status_pb2_grpc.py +24 -0
  56. langgraph_api/grpc/generated/enum_stream_mode_pb2.py +37 -0
  57. langgraph_api/grpc/generated/enum_stream_mode_pb2.pyi +28 -0
  58. langgraph_api/grpc/generated/enum_stream_mode_pb2_grpc.py +24 -0
  59. langgraph_api/grpc/generated/enum_thread_status_pb2.py +37 -0
  60. langgraph_api/grpc/generated/enum_thread_status_pb2.pyi +16 -0
  61. langgraph_api/grpc/generated/enum_thread_status_pb2_grpc.py +24 -0
  62. langgraph_api/grpc/generated/enum_thread_stream_mode_pb2.py +37 -0
  63. langgraph_api/grpc/generated/enum_thread_stream_mode_pb2.pyi +16 -0
  64. langgraph_api/grpc/generated/enum_thread_stream_mode_pb2_grpc.py +24 -0
  65. langgraph_api/grpc/generated/errors_pb2.py +39 -0
  66. langgraph_api/grpc/generated/errors_pb2.pyi +21 -0
  67. langgraph_api/grpc/generated/errors_pb2_grpc.py +24 -0
  68. langgraph_api/grpc/ops/__init__.py +370 -0
  69. langgraph_api/grpc/ops/assistants.py +424 -0
  70. langgraph_api/grpc/ops/runs.py +792 -0
  71. langgraph_api/grpc/ops/threads.py +1013 -0
  72. langgraph_api/http.py +16 -5
  73. langgraph_api/js/client.mts +1 -4
  74. langgraph_api/js/package.json +28 -27
  75. langgraph_api/js/remote.py +39 -17
  76. langgraph_api/js/sse.py +2 -2
  77. langgraph_api/js/ui.py +1 -1
  78. langgraph_api/js/yarn.lock +1139 -869
  79. langgraph_api/metadata.py +29 -3
  80. langgraph_api/middleware/http_logger.py +1 -1
  81. langgraph_api/middleware/private_network.py +7 -7
  82. langgraph_api/models/run.py +44 -26
  83. langgraph_api/otel_context.py +205 -0
  84. langgraph_api/patch.py +2 -2
  85. langgraph_api/queue_entrypoint.py +34 -35
  86. langgraph_api/route.py +33 -1
  87. langgraph_api/schema.py +84 -9
  88. langgraph_api/self_hosted_logs.py +2 -2
  89. langgraph_api/self_hosted_metrics.py +73 -3
  90. langgraph_api/serde.py +16 -4
  91. langgraph_api/server.py +33 -31
  92. langgraph_api/state.py +3 -2
  93. langgraph_api/store.py +25 -16
  94. langgraph_api/stream.py +20 -16
  95. langgraph_api/thread_ttl.py +28 -13
  96. langgraph_api/timing/__init__.py +25 -0
  97. langgraph_api/timing/profiler.py +200 -0
  98. langgraph_api/timing/timer.py +318 -0
  99. langgraph_api/utils/__init__.py +53 -8
  100. langgraph_api/utils/config.py +2 -1
  101. langgraph_api/utils/future.py +10 -6
  102. langgraph_api/utils/uuids.py +29 -62
  103. langgraph_api/validation.py +6 -0
  104. langgraph_api/webhook.py +120 -6
  105. langgraph_api/worker.py +54 -24
  106. {langgraph_api-0.5.4.dist-info → langgraph_api-0.7.3.dist-info}/METADATA +8 -6
  107. langgraph_api-0.7.3.dist-info/RECORD +168 -0
  108. {langgraph_api-0.5.4.dist-info → langgraph_api-0.7.3.dist-info}/WHEEL +1 -1
  109. langgraph_runtime/__init__.py +1 -0
  110. langgraph_runtime/routes.py +11 -0
  111. logging.json +1 -3
  112. openapi.json +635 -537
  113. langgraph_api/config.py +0 -523
  114. langgraph_api/grpc_ops/generated/__init__.py +0 -5
  115. langgraph_api/grpc_ops/generated/core_api_pb2.py +0 -275
  116. langgraph_api/grpc_ops/generated/engine_common_pb2.py +0 -194
  117. langgraph_api/grpc_ops/ops.py +0 -1045
  118. langgraph_api-0.5.4.dist-info/RECORD +0 -121
  119. /langgraph_api/{grpc_ops → grpc}/__init__.py +0 -0
  120. /langgraph_api/{grpc_ops → grpc}/generated/engine_common_pb2_grpc.py +0 -0
  121. {langgraph_api-0.5.4.dist-info → langgraph_api-0.7.3.dist-info}/entry_points.txt +0 -0
  122. {langgraph_api-0.5.4.dist-info → langgraph_api-0.7.3.dist-info}/licenses/LICENSE +0 -0
langgraph_api/api/mcp.py CHANGED
@@ -193,13 +193,13 @@ async def handle_post_request(request: ApiRequest) -> Response:
  # Careful ID checks as the integer 0 is a valid ID
  if id_ is not None and method:
  # JSON-RPC request
- return await handle_jsonrpc_request(request, cast(JsonRpcRequest, message))
+ return await handle_jsonrpc_request(request, cast("JsonRpcRequest", message))
  elif id_ is not None:
  # JSON-RPC response
- return handle_jsonrpc_response(cast(JsonRpcResponse, message))
+ return handle_jsonrpc_response(cast("JsonRpcResponse", message))
  elif method:
  # JSON-RPC notification
- return handle_jsonrpc_notification(cast(JsonRpcNotification, message))
+ return handle_jsonrpc_notification(cast("JsonRpcNotification", message))
  else:
  # Invalid message format
  return create_error_response(
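Note on the cast() change above: passing the type name as a string keeps it a forward reference, so the class only needs to exist for the type checker. A minimal sketch, assuming the JSON-RPC types sit behind a TYPE_CHECKING guard (the module name mcp_types below is a placeholder, not the package's actual layout):

    from typing import TYPE_CHECKING, cast

    if TYPE_CHECKING:
        # Imported only during static analysis; never evaluated at runtime.
        from mcp_types import JsonRpcRequest  # placeholder module name

    def as_request(message: dict) -> "JsonRpcRequest":
        # cast() is a runtime no-op, so the string form never touches the real class.
        return cast("JsonRpcRequest", message)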
langgraph_api/api/meta.py CHANGED
@@ -3,6 +3,8 @@ import structlog
  from starlette.responses import JSONResponse, PlainTextResponse

  from langgraph_api import __version__, config, metadata
+ from langgraph_api.feature_flags import FF_USE_CORE_API
+ from langgraph_api.grpc.ops import Runs as GrpcRuns
  from langgraph_api.http_metrics import HTTP_METRICS_COLLECTOR
  from langgraph_api.route import ApiRequest
  from langgraph_license.validation import plus_features_enabled
@@ -10,6 +12,8 @@ from langgraph_runtime.database import connect, pool_stats
  from langgraph_runtime.metrics import get_metrics
  from langgraph_runtime.ops import Runs

+ CrudRuns = GrpcRuns if FF_USE_CORE_API else Runs
+
  METRICS_FORMATS = {"prometheus", "json"}

  logger = structlog.stdlib.get_logger(__name__)
@@ -66,7 +70,7 @@ async def meta_metrics(request: ApiRequest):
  async with connect() as conn:
  resp = {
  **pg_redis_stats,
- "queue": await Runs.stats(conn),
+ "queue": await CrudRuns.stats(conn),
  **http_metrics,
  }
  if config.N_JOBS_PER_WORKER > 0:
@@ -76,7 +80,7 @@ async def meta_metrics(request: ApiRequest):
  metrics = []
  try:
  async with connect() as conn:
- queue_stats = await Runs.stats(conn)
+ queue_stats = await CrudRuns.stats(conn)

  metrics.extend(
  [
@@ -86,6 +90,15 @@ async def meta_metrics(request: ApiRequest):
  "# HELP lg_api_num_running_runs The number of runs currently running.",
  "# TYPE lg_api_num_running_runs gauge",
  f'lg_api_num_running_runs{{project_id="{metadata.PROJECT_ID}", revision_id="{metadata.HOST_REVISION_ID}"}} {queue_stats["n_running"]}',
+ "# HELP lg_api_pending_runs_wait_time_max The maximum time a run has been pending, in seconds.",
+ "# TYPE lg_api_pending_runs_wait_time_max gauge",
+ f'lg_api_pending_runs_wait_time_max{{project_id="{metadata.PROJECT_ID}", revision_id="{metadata.HOST_REVISION_ID}"}} {queue_stats.get("pending_runs_wait_time_max_secs") or 0}',
+ "# HELP lg_api_pending_runs_wait_time_med The median pending wait time across runs, in seconds.",
+ "# TYPE lg_api_pending_runs_wait_time_med gauge",
+ f'lg_api_pending_runs_wait_time_med{{project_id="{metadata.PROJECT_ID}", revision_id="{metadata.HOST_REVISION_ID}"}} {queue_stats.get("pending_runs_wait_time_med_secs") or 0}',
+ "# HELP lg_api_pending_unblocked_runs_wait_time_max The maximum time a run has been pending excluding runs blocked by another run on the same thread, in seconds.",
+ "# TYPE lg_api_pending_unblocked_runs_wait_time_max gauge",
+ f'lg_api_pending_unblocked_runs_wait_time_max{{project_id="{metadata.PROJECT_ID}", revision_id="{metadata.HOST_REVISION_ID}"}} {queue_stats.get("pending_unblocked_runs_wait_time_max_secs") or 0}',
  ]
  )
  except Exception as e:
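Note on the meta.py changes above: queue statistics are now read through CrudRuns, an alias resolved once at import time from the FF_USE_CORE_API feature flag, and three new pending-wait-time gauges are exposed. A standalone sketch of the alias pattern (the stub classes below are illustrative, not the package's real implementations):

    FF_USE_CORE_API = False  # normally imported from langgraph_api.feature_flags

    class Runs:  # Postgres-backed ops (stub)
        @staticmethod
        async def stats(conn) -> dict:
            return {"n_pending": 0, "n_running": 0}

    class GrpcRuns:  # core-API / gRPC-backed ops (stub)
        @staticmethod
        async def stats(conn) -> dict:
            return {"n_pending": 0, "n_running": 0}

    # Every caller uses CrudRuns.stats(...); flipping the flag swaps the backend
    # without touching call sites.
    CrudRuns = GrpcRuns if FF_USE_CORE_API else Runs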
langgraph_api/api/openapi.py CHANGED
@@ -5,12 +5,7 @@ from functools import lru_cache

  import orjson

- from langgraph_api.config import (
- HTTP_CONFIG,
- LANGGRAPH_AUTH,
- LANGGRAPH_AUTH_TYPE,
- MOUNT_PREFIX,
- )
+ from langgraph_api import config
  from langgraph_api.graph import GRAPHS
  from langgraph_api.validation import openapi

@@ -39,17 +34,20 @@ def get_openapi_spec() -> bytes:
  graph_ids
  )
  # patch the auth schemes
- if LANGGRAPH_AUTH_TYPE == "langsmith":
+ if config.LANGGRAPH_AUTH_TYPE == "langsmith":
  openapi["security"] = [
  {"x-api-key": []},
  ]
  openapi["components"]["securitySchemes"] = {
  "x-api-key": {"type": "apiKey", "in": "header", "name": "x-api-key"}
  }
- if LANGGRAPH_AUTH:
+ if config.LANGGRAPH_AUTH:
  # Allow user to specify OpenAPI security configuration
- if isinstance(LANGGRAPH_AUTH, dict) and "openapi" in LANGGRAPH_AUTH:
- openapi_config = LANGGRAPH_AUTH["openapi"]
+ if (
+ isinstance(config.LANGGRAPH_AUTH, dict)
+ and "openapi" in config.LANGGRAPH_AUTH
+ ):
+ openapi_config = config.LANGGRAPH_AUTH["openapi"]
  if isinstance(openapi_config, dict):
  # Add security schemes
  if "securitySchemes" in openapi_config:
@@ -82,8 +80,13 @@ def get_openapi_spec() -> bytes:
  )

  # Remove webhook parameters if webhooks are disabled
- if HTTP_CONFIG and HTTP_CONFIG.get("disable_webhooks"):
- webhook_schemas = ["CronCreate", "RunCreateStateful", "RunCreateStateless"]
+ if config.WEBHOOKS_ENABLED:
+ webhook_schemas = [
+ "CronCreate",
+ "ThreadCronCreate",
+ "RunCreateStateful",
+ "RunCreateStateless",
+ ]
  for schema_name in webhook_schemas:
  if schema_name in openapi["components"]["schemas"]:
  schema = openapi["components"]["schemas"][schema_name]
@@ -96,14 +99,21 @@ def get_openapi_spec() -> bytes:
  final = openapi
  if CUSTOM_OPENAPI_SPEC:
  final = merge_openapi_specs(openapi, CUSTOM_OPENAPI_SPEC)
- if MOUNT_PREFIX:
- final["servers"] = [{"url": MOUNT_PREFIX}]
-
- MCP_ENABLED = HTTP_CONFIG is None or not HTTP_CONFIG.get("disable_mcp")
+ if config.MOUNT_PREFIX:
+ final["servers"] = [{"url": config.MOUNT_PREFIX}]

- if not MCP_ENABLED:
+ if not config.MCP_ENABLED:
  # Remove the MCP paths from the OpenAPI spec
  final["paths"].pop("/mcp/", None)
+ # Remove the MCP tag definition
+ final["tags"] = [t for t in final.get("tags", []) if t.get("name") != "MCP"]
+
+ if not config.A2A_ENABLED:
+ # Remove the A2A paths from the OpenAPI spec
+ final["paths"].pop("/a2a/{assistant_id}", None)
+ final["paths"].pop("/.well-known/agent-card.json", None)
+ # Remove the A2A tag definition
+ final["tags"] = [t for t in final.get("tags", []) if t.get("name") != "A2A"]

  return orjson.dumps(final)

langgraph_api/api/profile.py ADDED
@@ -0,0 +1,108 @@
+ import asyncio
+ import contextlib
+ import os
+ import shutil
+ import subprocess
+ import tempfile
+ import time
+ from typing import Literal
+
+ import structlog
+ from starlette.responses import JSONResponse, Response
+ from starlette.routing import BaseRoute
+
+ from langgraph_api import config
+ from langgraph_api.route import ApiRequest, ApiRoute
+
+ logger = structlog.stdlib.get_logger(__name__)
+
+
+ def _clamp_duration(seconds: int) -> int:
+ seconds = max(1, seconds)
+ return min(seconds, max(1, config.FF_PYSPY_PROFILING_MAX_DURATION_SECS))
+
+
+ async def _profile_with_pyspy(seconds: int, fmt: Literal["svg"]) -> Response:
+ """Run py-spy against the current process for N seconds and return SVG."""
+ pyspy = shutil.which("py-spy")
+ if not pyspy:
+ return JSONResponse({"error": "py-spy not found on PATH"}, status_code=501)
+
+ # py-spy writes to a file; use a temp file then return its contents.
+ fd, path = tempfile.mkstemp(suffix=".svg")
+ os.close(fd)
+ try:
+ pid = os.getpid()
+ # Example:
+ # py-spy record -p <pid> -d <seconds> --format flamegraph -o out.svg
+ cmd = [
+ pyspy,
+ "record",
+ "-p",
+ str(pid),
+ "-d",
+ str(seconds),
+ "--format",
+ "flamegraph",
+ "-o",
+ path,
+ ]
+ proc = await asyncio.create_subprocess_exec(
+ *cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
+ )
+ try:
+ _, stderr = await asyncio.wait_for(proc.communicate(), timeout=seconds + 15)
+ except TimeoutError:
+ with contextlib.suppress(ProcessLookupError):
+ proc.kill()
+ return JSONResponse(
+ {
+ "error": "py-spy timed out",
+ "hint": "Check ptrace permissions or reduce duration",
+ },
+ status_code=504,
+ )
+ if proc.returncode != 0:
+ # Common failures: missing ptrace capability in containers.
+ msg = stderr.decode("utf-8", errors="ignore") if stderr else "py-spy failed"
+ await logger.awarning("py-spy failed", returncode=proc.returncode, msg=msg)
+ return JSONResponse(
+ {
+ "error": "py-spy failed",
+ "detail": msg,
+ "hint": "Ensure the container has CAP_SYS_PTRACE / seccomp=unconfined",
+ },
+ status_code=500,
+ )
+
+ with open(path, "rb") as f:
+ content = f.read()
+ ts = int(time.time())
+ return Response(
+ content,
+ media_type="image/svg+xml",
+ headers={
+ "Content-Disposition": f"inline; filename=pyspy-{ts}.svg",
+ "Cache-Control": "no-store",
+ },
+ )
+ finally:
+ with contextlib.suppress(FileNotFoundError):
+ os.remove(path)
+
+
+ async def profile(request: ApiRequest):
+ if not config.FF_PYSPY_PROFILING_ENABLED:
+ return JSONResponse({"error": "Profiling disabled"}, status_code=403)
+
+ params = request.query_params
+ try:
+ seconds = _clamp_duration(int(params.get("seconds", "15")))
+ except ValueError:
+ return JSONResponse({"error": "Invalid seconds"}, status_code=400)
+ return await _profile_with_pyspy(seconds, "svg")
+
+
+ profile_routes: list[BaseRoute] = [
+ ApiRoute("/profile", profile, methods=["GET"]),
+ ]
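Usage sketch for the new /profile endpoint (the base URL and port are assumptions for a local deployment; the endpoint only responds when FF_PYSPY_PROFILING_ENABLED is set and py-spy is on PATH with ptrace permissions):

    import httpx

    resp = httpx.get(
        "http://localhost:2024/profile",  # assumed local server address
        params={"seconds": 10},  # clamped server-side to FF_PYSPY_PROFILING_MAX_DURATION_SECS
        timeout=60.0,  # must exceed the requested sampling duration
    )
    resp.raise_for_status()
    with open("pyspy-flamegraph.svg", "wb") as f:
        f.write(resp.content)  # image/svg+xml flamegraph produced by py-spy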
langgraph_api/api/runs.py CHANGED
@@ -10,9 +10,24 @@ from starlette.responses import Response, StreamingResponse

  from langgraph_api import config
  from langgraph_api.asyncio import ValueEvent
+ from langgraph_api.encryption.middleware import (
+ decrypt_response,
+ decrypt_responses,
+ encrypt_request,
+ )
+ from langgraph_api.feature_flags import FF_USE_CORE_API
+ from langgraph_api.graph import _validate_assistant_id
+ from langgraph_api.grpc.ops import Runs as GrpcRuns
  from langgraph_api.models.run import create_valid_run
  from langgraph_api.route import ApiRequest, ApiResponse, ApiRoute
- from langgraph_api.schema import CRON_FIELDS, RUN_FIELDS
+ from langgraph_api.schema import (
+ CRON_ENCRYPTION_FIELDS,
+ CRON_FIELDS,
+ CRON_PAYLOAD_ENCRYPTION_SUBFIELDS,
+ RUN_ENCRYPTION_FIELDS,
+ RUN_FIELDS,
+ )
+ from langgraph_api.serde import json_dumpb, json_loads
  from langgraph_api.sse import EventSourceResponse
  from langgraph_api.utils import (
  fetchone,
@@ -29,41 +44,56 @@ from langgraph_api.validation import (
  RunCreateStateful,
  RunCreateStateless,
  RunsCancel,
+ ThreadCronCreate,
  )
+ from langgraph_api.webhook import validate_webhook_url_or_raise
  from langgraph_license.validation import plus_features_enabled
  from langgraph_runtime.database import connect
  from langgraph_runtime.ops import Crons, Runs, StreamHandler, Threads
  from langgraph_runtime.retry import retry_db

+ CrudRuns = GrpcRuns if FF_USE_CORE_API else Runs
+
  logger = structlog.stdlib.get_logger(__name__)


  _RunResultFallback = Callable[[], Awaitable[bytes]]


+ def _ensure_crons_enabled() -> None:
+ if not (config.FF_CRONS_ENABLED and plus_features_enabled()):
+ raise HTTPException(
+ status_code=403,
+ detail="Crons are currently only available in the cloud version of LangSmith Deployment or with a self-hosting enterprise license. Please visit https://docs.langchain.com/langsmith/deployments to learn more about deployment options, or contact sales@langchain.com for more information",
+ )
+
+
  def _thread_values_fallback(thread_id: UUID) -> _RunResultFallback:
  async def fetch_thread_values() -> bytes:
  async with connect() as conn:
  thread_iter = await Threads.get(conn, thread_id)
  try:
- thread = await anext(thread_iter)
- if thread["status"] == "error":
- return orjson.dumps({"__error__": orjson.Fragment(thread["error"])})
- if thread["status"] == "interrupted":
+ row = await anext(thread_iter)
+ # Decrypt thread fields (values, interrupts, error) if encryption is enabled
+ thread = await decrypt_response(
+ dict(row),
+ "thread",
+ ["values", "interrupts", "error"],
+ )
+ if row["status"] == "error":
+ return json_dumpb({"__error__": json_loads(thread["error"])})
+ if row["status"] == "interrupted":
  # Get an interrupt for the thread. There is the case where there are multiple interrupts for the same run and we may not show the same
  # interrupt, but we'll always show one. Long term we should show all of them.
  try:
- if isinstance(thread["interrupts"], dict):
- # Handle in memory format
- interrupt_map = thread["interrupts"]
- else:
- interrupt_map = orjson.loads(thread["interrupts"].buf)
+ interrupt_map = json_loads(thread["interrupts"])
  interrupt = [next(iter(interrupt_map.values()))[0]]
- return orjson.dumps({"__interrupt__": interrupt})
+ return json_dumpb({"__interrupt__": interrupt})
  except Exception:
  # No interrupt, but status is interrupted from a before/after block. Default back to values.
  pass
- return cast(bytes, thread["values"])
+ values = json_loads(thread["values"]) if thread["values"] else None
+ return json_dumpb(values) if values else b"{}"
  except StopAsyncIteration:
  await logger.awarning(
  f"No checkpoint found for thread {thread_id}",
@@ -96,10 +126,8 @@ def _run_result_body(
  thread_id=thread_id,
  ignore_404=ignore_404,
  ):
- if (
- mode == b"values"
- or mode == b"updates"
- and b"__interrupt__" in chunk
+ if mode == b"values" or (
+ mode == b"updates" and b"__interrupt__" in chunk
  ):
  vchunk = chunk
  elif mode == b"error":
@@ -147,6 +175,7 @@ async def create_run(request: ApiRequest):
  """Create a run."""
  thread_id = request.path_params["thread_id"]
  payload = await request.json(RunCreateStateful)
+
  async with connect() as conn:
  run = await create_valid_run(
  conn,
@@ -155,6 +184,7 @@ async def create_run(request: ApiRequest):
  request.headers,
  request_start_time=request.scope.get("request_start_time_ms"),
  )
+ run = await decrypt_response(run, "run", RUN_ENCRYPTION_FIELDS)
  return ApiResponse(
  run,
  headers={"Content-Location": f"/threads/{thread_id}/runs/{run['run_id']}"},
@@ -165,6 +195,7 @@ async def create_run(request: ApiRequest):
  async def create_stateless_run(request: ApiRequest):
  """Create a run."""
  payload = await request.json(RunCreateStateless)
+
  async with connect() as conn:
  run = await create_valid_run(
  conn,
@@ -173,6 +204,7 @@ async def create_stateless_run(request: ApiRequest):
  request.headers,
  request_start_time=request.scope.get("request_start_time_ms"),
  )
+ run = await decrypt_response(run, "run", RUN_ENCRYPTION_FIELDS)
  return ApiResponse(
  run,
  headers={"Content-Location": f"/runs/{run['run_id']}"},
@@ -197,6 +229,7 @@ async def create_stateless_run_batch(request: ApiRequest):
  for payload in batch_payload
  ]
  runs = await asyncio.gather(*coros)
+ runs = await decrypt_responses(list(runs), "run", RUN_ENCRYPTION_FIELDS)
  return ApiResponse(runs)


@@ -411,7 +444,7 @@ async def list_runs(
  async with connect() as conn, conn.pipeline():
  thread, runs = await asyncio.gather(
  Threads.get(conn, thread_id),
- Runs.search(
+ CrudRuns.search(
  conn,
  thread_id,
  limit=limit,
@@ -421,7 +454,12 @@ async def list_runs(
  ),
  )
  await fetchone(thread)
- return ApiResponse([run async for run in runs])
+
+ # Collect and decrypt runs
+ runs_list = [run async for run in runs]
+ runs_list = await decrypt_responses(runs_list, "run", RUN_ENCRYPTION_FIELDS)
+
+ return ApiResponse(runs_list)


  @retry_db
@@ -435,14 +473,19 @@ async def get_run(request: ApiRequest):
  async with connect() as conn, conn.pipeline():
  thread, run = await asyncio.gather(
  Threads.get(conn, thread_id),
- Runs.get(
+ CrudRuns.get(
  conn,
  run_id,
  thread_id=thread_id,
  ),
  )
  await fetchone(thread)
- return ApiResponse(await fetchone(run))
+ run_dict = await fetchone(run)
+
+ # Decrypt run metadata and kwargs
+ run_dict = await decrypt_response(run_dict, "run", RUN_ENCRYPTION_FIELDS)
+
+ return ApiResponse(run_dict)


  @retry_db
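The handlers in this file now pass every row read from storage through decrypt_response / decrypt_responses before returning it, and encrypt payloads with encrypt_request before writing. A conceptual sketch of what such field-level helpers can look like; the Fernet-based implementation below is an assumption for illustration only, not the package's actual encryption middleware:

    import json
    from cryptography.fernet import Fernet

    _fernet = Fernet(Fernet.generate_key())  # real deployments would load a configured key

    def encrypt_fields(doc: dict, fields: list[str]) -> dict:
        out = dict(doc)
        for field in fields:
            if out.get(field) is not None:
                # Serialize the field, encrypt it, and store the token as a string.
                out[field] = _fernet.encrypt(json.dumps(out[field]).encode()).decode()
        return out

    def decrypt_fields(doc: dict, fields: list[str]) -> dict:
        out = dict(doc)
        for field in fields:
            if isinstance(out.get(field), str):
                # Reverse of encrypt_fields: decrypt the token and parse the JSON back.
                out[field] = json.loads(_fernet.decrypt(out[field].encode()))
        return out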
@@ -519,14 +562,14 @@ async def cancel_run(
  wait = wait_str.lower() in {"true", "yes", "1"}
  action_str = request.query_params.get("action", "interrupt")
  action = cast(
- Literal["interrupt", "rollback"],
+ "Literal['interrupt', 'rollback']",
  action_str if action_str in {"interrupt", "rollback"} else "interrupt",
  )

  sub = await Runs.Stream.subscribe(run_id, thread_id) if wait else None
  try:
  async with connect() as conn:
- await Runs.cancel(
+ await CrudRuns.cancel(
  conn,
  [run_id],
  action=action,
@@ -585,12 +628,12 @@ async def cancel_runs(
  validate_uuid(rid, "Invalid run ID: must be a UUID")
  action_str = request.query_params.get("action", "interrupt")
  action = cast(
- Literal["interrupt", "rollback"],
+ "Literal['interrupt', 'rollback']",
  action_str if action_str in ("interrupt", "rollback") else "interrupt",
  )

  async with connect() as conn:
- await Runs.cancel(
+ await CrudRuns.cancel(
  conn,
  run_ids,
  action=action,
@@ -609,7 +652,7 @@ async def delete_run(request: ApiRequest):
  validate_uuid(run_id, "Invalid run ID: must be a UUID")

  async with connect() as conn:
- rid = await Runs.delete(
+ rid = await CrudRuns.delete(
  conn,
  run_id,
  thread_id=thread_id,
@@ -621,40 +664,71 @@
  @retry_db
  async def create_cron(request: ApiRequest):
  """Create a cron with new thread."""
+ _ensure_crons_enabled()
  payload = await request.json(CronCreate)
+ if webhook := payload.get("webhook"):
+ await validate_webhook_url_or_raise(str(webhook))
+ _validate_assistant_id(payload.get("assistant_id"))
+
+ encrypted_payload = await encrypt_request(
+ payload,
+ "cron",
+ CRON_PAYLOAD_ENCRYPTION_SUBFIELDS,
+ )

  async with connect() as conn:
  cron = await Crons.put(
  conn,
  thread_id=None,
+ on_run_completed=payload.get("on_run_completed", "delete"),
  end_time=payload.get("end_time"),
  schedule=payload.get("schedule"),
- payload=payload,
+ payload=encrypted_payload,
+ metadata=encrypted_payload.get("metadata"),
  )
- return ApiResponse(await fetchone(cron))
+ cron_dict = await fetchone(cron)
+ cron_dict = await decrypt_response(cron_dict, "cron", CRON_ENCRYPTION_FIELDS)
+
+ return ApiResponse(cron_dict)


  @retry_db
  async def create_thread_cron(request: ApiRequest):
  """Create a thread specific cron."""
+ _ensure_crons_enabled()
  thread_id = request.path_params["thread_id"]
  validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
- payload = await request.json(CronCreate)
+ payload = await request.json(ThreadCronCreate)
+ if webhook := payload.get("webhook"):
+ await validate_webhook_url_or_raise(str(webhook))
+ _validate_assistant_id(payload.get("assistant_id"))
+
+ encrypted_payload = await encrypt_request(
+ payload,
+ "cron",
+ CRON_PAYLOAD_ENCRYPTION_SUBFIELDS,
+ )

  async with connect() as conn:
  cron = await Crons.put(
  conn,
  thread_id=thread_id,
+ on_run_completed=None,
  end_time=payload.get("end_time"),
  schedule=payload.get("schedule"),
- payload=payload,
+ payload=encrypted_payload,
+ metadata=encrypted_payload.get("metadata"),
  )
- return ApiResponse(await fetchone(cron))
+ cron_dict = await fetchone(cron)
+ cron_dict = await decrypt_response(cron_dict, "cron", CRON_ENCRYPTION_FIELDS)
+
+ return ApiResponse(cron_dict)


  @retry_db
  async def delete_cron(request: ApiRequest):
  """Delete a cron by ID."""
+ _ensure_crons_enabled()
  cron_id = request.path_params["cron_id"]
  validate_uuid(cron_id, "Invalid cron ID: must be a UUID")

@@ -670,6 +744,7 @@ async def delete_cron(request: ApiRequest):
  @retry_db
  async def search_crons(request: ApiRequest):
  """List all cron jobs for an assistant"""
+ _ensure_crons_enabled()
  payload = await request.json(CronSearch)
  select = validate_select_columns(payload.get("select") or None, CRON_FIELDS)
  if assistant_id := payload.get("assistant_id"):
@@ -692,12 +767,16 @@
  crons, response_headers = await get_pagination_headers(
  crons_iter, next_offset, offset
  )
+
+ crons = await decrypt_responses(crons, "cron", CRON_ENCRYPTION_FIELDS)
+
  return ApiResponse(crons, headers=response_headers)


  @retry_db
  async def count_crons(request: ApiRequest):
  """Count cron jobs."""
+ _ensure_crons_enabled()
  payload = await request.json(CronCountRequest)
  if assistant_id := payload.get("assistant_id"):
  validate_uuid(assistant_id, "Invalid assistant ID: must be a UUID")
@@ -719,21 +798,9 @@ runs_routes = [
  ApiRoute("/runs", create_stateless_run, methods=["POST"]),
  ApiRoute("/runs/batch", create_stateless_run_batch, methods=["POST"]),
  ApiRoute("/runs/cancel", cancel_runs, methods=["POST"]),
- (
- ApiRoute("/runs/crons", create_cron, methods=["POST"])
- if config.FF_CRONS_ENABLED and plus_features_enabled()
- else None
- ),
- (
- ApiRoute("/runs/crons/search", search_crons, methods=["POST"])
- if config.FF_CRONS_ENABLED and plus_features_enabled()
- else None
- ),
- (
- ApiRoute("/runs/crons/count", count_crons, methods=["POST"])
- if config.FF_CRONS_ENABLED and plus_features_enabled()
- else None
- ),
+ ApiRoute("/runs/crons", create_cron, methods=["POST"]),
+ ApiRoute("/runs/crons/search", search_crons, methods=["POST"]),
+ ApiRoute("/runs/crons/count", count_crons, methods=["POST"]),
  ApiRoute("/threads/{thread_id}/runs/{run_id}/join", join_run, methods=["GET"]),
  ApiRoute(
  "/threads/{thread_id}/runs/{run_id}/stream",
@@ -746,19 +813,9 @@ runs_routes = [
  ApiRoute("/threads/{thread_id}/runs/stream", stream_run, methods=["POST"]),
  ApiRoute("/threads/{thread_id}/runs/wait", wait_run, methods=["POST"]),
  ApiRoute("/threads/{thread_id}/runs", create_run, methods=["POST"]),
- (
- ApiRoute(
- "/threads/{thread_id}/runs/crons", create_thread_cron, methods=["POST"]
- )
- if config.FF_CRONS_ENABLED and plus_features_enabled()
- else None
- ),
+ ApiRoute("/threads/{thread_id}/runs/crons", create_thread_cron, methods=["POST"]),
  ApiRoute("/threads/{thread_id}/runs", list_runs, methods=["GET"]),
- (
- ApiRoute("/runs/crons/{cron_id}", delete_cron, methods=["DELETE"])
- if config.FF_CRONS_ENABLED and plus_features_enabled()
- else None
- ),
+ ApiRoute("/runs/crons/{cron_id}", delete_cron, methods=["DELETE"]),
  ]

  runs_routes = [route for route in runs_routes if route is not None]
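Note on the route changes above: the cron endpoints are now always registered, and the license/feature check moved into the handlers via _ensure_crons_enabled(), so a disabled deployment answers 403 at request time instead of 404. A minimal standalone sketch of that pattern (the stub flag and license check below are illustrative, not the package's wiring):

    from starlette.exceptions import HTTPException

    FF_CRONS_ENABLED = False  # stand-in for config.FF_CRONS_ENABLED

    def plus_features_enabled() -> bool:  # stand-in for the license check
        return False

    def _ensure_crons_enabled() -> None:
        # Raise at request time rather than omitting the route at startup.
        if not (FF_CRONS_ENABLED and plus_features_enabled()):
            raise HTTPException(status_code=403, detail="Crons are not enabled")

    async def create_cron(request):
        _ensure_crons_enabled()  # callers get 403 instead of an unregistered 404
        ...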