langgraph-api 0.3.4__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
langgraph_api/__init__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.3.4"
1
+ __version__ = "0.4.1"
langgraph_api/api/runs.py CHANGED
@@ -1,6 +1,7 @@
1
1
  import asyncio
2
2
  from collections.abc import AsyncIterator
3
3
  from typing import Literal, cast
4
+ from uuid import uuid4
4
5
 
5
6
  import orjson
6
7
  from starlette.exceptions import HTTPException
@@ -100,7 +101,7 @@ async def stream_run(
100
101
  payload = await request.json(RunCreateStateful)
101
102
  on_disconnect = payload.get("on_disconnect", "continue")
102
103
  run_id = uuid7()
103
- sub = asyncio.create_task(Runs.Stream.subscribe(run_id))
104
+ sub = asyncio.create_task(Runs.Stream.subscribe(run_id, thread_id))
104
105
 
105
106
  try:
106
107
  async with connect() as conn:
@@ -138,19 +139,21 @@ async def stream_run_stateless(
138
139
  ):
139
140
  """Create a stateless run."""
140
141
  payload = await request.json(RunCreateStateless)
142
+ payload["if_not_exists"] = "create"
141
143
  on_disconnect = payload.get("on_disconnect", "continue")
142
144
  run_id = uuid7()
143
- sub = asyncio.create_task(Runs.Stream.subscribe(run_id))
144
-
145
+ thread_id = uuid4()
146
+ sub = asyncio.create_task(Runs.Stream.subscribe(run_id, thread_id))
145
147
  try:
146
148
  async with connect() as conn:
147
149
  run = await create_valid_run(
148
150
  conn,
149
- None,
151
+ str(thread_id),
150
152
  payload,
151
153
  request.headers,
152
154
  run_id=run_id,
153
155
  request_start_time=request.scope.get("request_start_time_ms"),
156
+ temporary=True,
154
157
  )
155
158
  except Exception:
156
159
  if not sub.cancelled():
@@ -181,7 +184,7 @@ async def wait_run(request: ApiRequest):
181
184
  payload = await request.json(RunCreateStateful)
182
185
  on_disconnect = payload.get("on_disconnect", "continue")
183
186
  run_id = uuid7()
184
- sub = asyncio.create_task(Runs.Stream.subscribe(run_id))
187
+ sub = asyncio.create_task(Runs.Stream.subscribe(run_id, thread_id))
185
188
 
186
189
  try:
187
190
  async with connect() as conn:
@@ -263,26 +266,28 @@ async def wait_run(request: ApiRequest):
263
266
  async def wait_run_stateless(request: ApiRequest):
264
267
  """Create a stateless run, wait for the output."""
265
268
  payload = await request.json(RunCreateStateless)
269
+ payload["if_not_exists"] = "create"
266
270
  on_disconnect = payload.get("on_disconnect", "continue")
267
271
  run_id = uuid7()
268
- sub = asyncio.create_task(Runs.Stream.subscribe(run_id))
272
+ thread_id = uuid4()
273
+ sub = asyncio.create_task(Runs.Stream.subscribe(run_id, thread_id))
269
274
 
270
275
  try:
271
276
  async with connect() as conn:
272
277
  run = await create_valid_run(
273
278
  conn,
274
- None,
279
+ str(thread_id),
275
280
  payload,
276
281
  request.headers,
277
282
  run_id=run_id,
278
283
  request_start_time=request.scope.get("request_start_time_ms"),
284
+ temporary=True,
279
285
  )
280
286
  except Exception:
281
287
  if not sub.cancelled():
282
288
  handle = await sub
283
289
  await handle.__aexit__(None, None, None)
284
290
  raise
285
-
286
291
  last_chunk = ValueEvent()
287
292
 
288
293
  async def consume():
@@ -6,11 +6,13 @@ from starlette.routing import BaseRoute
6
6
 
7
7
  from langgraph_api.route import ApiRequest, ApiResponse, ApiRoute
8
8
  from langgraph_api.schema import THREAD_FIELDS
9
+ from langgraph_api.sse import EventSourceResponse
9
10
  from langgraph_api.state import state_snapshot_to_thread_state
10
11
  from langgraph_api.utils import (
11
12
  fetchone,
12
13
  get_pagination_headers,
13
14
  validate_select_columns,
15
+ validate_stream_id,
14
16
  validate_uuid,
15
17
  )
16
18
  from langgraph_api.validation import (
@@ -282,6 +284,23 @@ async def copy_thread(request: ApiRequest):
282
284
  return ApiResponse(await fetchone(iter, not_found_code=409))
283
285
 
284
286
 
287
+ @retry_db
288
+ async def join_thread_stream(request: ApiRequest):
289
+ """Join a thread stream."""
290
+ thread_id = request.path_params["thread_id"]
291
+ validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
292
+ last_event_id = request.headers.get("last-event-id") or None
293
+ validate_stream_id(
294
+ last_event_id, "Invalid last-event-id: must be a valid Redis stream ID"
295
+ )
296
+ return EventSourceResponse(
297
+ Threads.Stream.join(
298
+ thread_id,
299
+ last_event_id=last_event_id,
300
+ ),
301
+ )
302
+
303
+
285
304
  threads_routes: list[BaseRoute] = [
286
305
  ApiRoute("/threads", endpoint=create_thread, methods=["POST"]),
287
306
  ApiRoute("/threads/search", endpoint=search_threads, methods=["POST"]),
@@ -312,4 +331,9 @@ threads_routes: list[BaseRoute] = [
312
331
  endpoint=get_thread_state_at_checkpoint_post,
313
332
  methods=["POST"],
314
333
  ),
334
+ ApiRoute(
335
+ "/threads/{thread_id}/stream",
336
+ endpoint=join_thread_stream,
337
+ methods=["GET"],
338
+ ),
315
339
  ]