langgraph-api 0.4.1__py3-none-any.whl → 0.4.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langgraph-api might be problematic. Click here for more details.
- langgraph_api/__init__.py +1 -1
- langgraph_api/api/__init__.py +4 -0
- langgraph_api/api/a2a.py +1128 -0
- langgraph_api/api/assistants.py +8 -0
- langgraph_api/api/runs.py +126 -146
- langgraph_api/api/threads.py +64 -14
- langgraph_api/asyncio.py +2 -1
- langgraph_api/feature_flags.py +1 -0
- langgraph_api/logging.py +5 -2
- langgraph_api/models/run.py +10 -67
- langgraph_api/schema.py +2 -0
- langgraph_api/stream.py +9 -1
- langgraph_api/utils/headers.py +76 -2
- {langgraph_api-0.4.1.dist-info → langgraph_api-0.4.9.dist-info}/METADATA +2 -2
- {langgraph_api-0.4.1.dist-info → langgraph_api-0.4.9.dist-info}/RECORD +19 -22
- openapi.json +244 -0
- langgraph_api/js/isolate-0x130008000-46649-46649-v8.log +0 -4430
- langgraph_api/js/isolate-0x138008000-44681-44681-v8.log +0 -4430
- langgraph_api/js/package-lock.json +0 -3308
- langgraph_api/utils.py +0 -0
- {langgraph_api-0.4.1.dist-info → langgraph_api-0.4.9.dist-info}/WHEEL +0 -0
- {langgraph_api-0.4.1.dist-info → langgraph_api-0.4.9.dist-info}/entry_points.txt +0 -0
- {langgraph_api-0.4.1.dist-info → langgraph_api-0.4.9.dist-info}/licenses/LICENSE +0 -0
langgraph_api/api/assistants.py
CHANGED
|
@@ -24,6 +24,7 @@ from langgraph_api.utils import (
|
|
|
24
24
|
validate_select_columns,
|
|
25
25
|
validate_uuid,
|
|
26
26
|
)
|
|
27
|
+
from langgraph_api.utils.headers import get_configurable_headers
|
|
27
28
|
from langgraph_api.validation import (
|
|
28
29
|
AssistantCountRequest,
|
|
29
30
|
AssistantCreate,
|
|
@@ -240,6 +241,9 @@ async def get_assistant_graph(
|
|
|
240
241
|
assistant_ = await Assistants.get(conn, assistant_id)
|
|
241
242
|
assistant = await fetchone(assistant_)
|
|
242
243
|
config = await ajson_loads(assistant["config"])
|
|
244
|
+
configurable = config.setdefault("configurable", {})
|
|
245
|
+
configurable.update(get_configurable_headers(request.headers))
|
|
246
|
+
|
|
243
247
|
async with get_graph(
|
|
244
248
|
assistant["graph_id"],
|
|
245
249
|
config,
|
|
@@ -294,6 +298,8 @@ async def get_assistant_subgraphs(
|
|
|
294
298
|
assistant_ = await Assistants.get(conn, assistant_id)
|
|
295
299
|
assistant = await fetchone(assistant_)
|
|
296
300
|
config = await ajson_loads(assistant["config"])
|
|
301
|
+
configurable = config.setdefault("configurable", {})
|
|
302
|
+
configurable.update(get_configurable_headers(request.headers))
|
|
297
303
|
async with get_graph(
|
|
298
304
|
assistant["graph_id"],
|
|
299
305
|
config,
|
|
@@ -340,6 +346,8 @@ async def get_assistant_schemas(
|
|
|
340
346
|
# TODO Implement a cache so we can de-dent and release this connection.
|
|
341
347
|
assistant = await fetchone(assistant_)
|
|
342
348
|
config = await ajson_loads(assistant["config"])
|
|
349
|
+
configurable = config.setdefault("configurable", {})
|
|
350
|
+
configurable.update(get_configurable_headers(request.headers))
|
|
343
351
|
async with get_graph(
|
|
344
352
|
assistant["graph_id"],
|
|
345
353
|
config,
|
langgraph_api/api/runs.py
CHANGED
|
@@ -4,11 +4,12 @@ from typing import Literal, cast
|
|
|
4
4
|
from uuid import uuid4
|
|
5
5
|
|
|
6
6
|
import orjson
|
|
7
|
+
import structlog
|
|
7
8
|
from starlette.exceptions import HTTPException
|
|
8
9
|
from starlette.responses import Response, StreamingResponse
|
|
9
10
|
|
|
10
11
|
from langgraph_api import config
|
|
11
|
-
from langgraph_api.asyncio import ValueEvent
|
|
12
|
+
from langgraph_api.asyncio import ValueEvent
|
|
12
13
|
from langgraph_api.models.run import create_valid_run
|
|
13
14
|
from langgraph_api.route import ApiRequest, ApiResponse, ApiRoute
|
|
14
15
|
from langgraph_api.schema import CRON_FIELDS, RUN_FIELDS
|
|
@@ -34,6 +35,8 @@ from langgraph_runtime.database import connect
|
|
|
34
35
|
from langgraph_runtime.ops import Crons, Runs, Threads
|
|
35
36
|
from langgraph_runtime.retry import retry_db
|
|
36
37
|
|
|
38
|
+
logger = structlog.stdlib.get_logger(__name__)
|
|
39
|
+
|
|
37
40
|
|
|
38
41
|
@retry_db
|
|
39
42
|
async def create_run(request: ApiRequest):
|
|
@@ -101,9 +104,7 @@ async def stream_run(
|
|
|
101
104
|
payload = await request.json(RunCreateStateful)
|
|
102
105
|
on_disconnect = payload.get("on_disconnect", "continue")
|
|
103
106
|
run_id = uuid7()
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
try:
|
|
107
|
+
async with await Runs.Stream.subscribe(run_id, thread_id) as sub:
|
|
107
108
|
async with connect() as conn:
|
|
108
109
|
run = await create_valid_run(
|
|
109
110
|
conn,
|
|
@@ -113,25 +114,20 @@ async def stream_run(
|
|
|
113
114
|
run_id=run_id,
|
|
114
115
|
request_start_time=request.scope.get("request_start_time_ms"),
|
|
115
116
|
)
|
|
116
|
-
except Exception:
|
|
117
|
-
if not sub.cancelled():
|
|
118
|
-
handle = await sub
|
|
119
|
-
await handle.__aexit__(None, None, None)
|
|
120
|
-
raise
|
|
121
117
|
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
118
|
+
return EventSourceResponse(
|
|
119
|
+
Runs.Stream.join(
|
|
120
|
+
run["run_id"],
|
|
121
|
+
thread_id=thread_id,
|
|
122
|
+
cancel_on_disconnect=on_disconnect == "cancel",
|
|
123
|
+
stream_channel=sub,
|
|
124
|
+
last_event_id=None,
|
|
125
|
+
),
|
|
126
|
+
headers={
|
|
127
|
+
"Location": f"/threads/{thread_id}/runs/{run['run_id']}/stream",
|
|
128
|
+
"Content-Location": f"/threads/{thread_id}/runs/{run['run_id']}",
|
|
129
|
+
},
|
|
130
|
+
)
|
|
135
131
|
|
|
136
132
|
|
|
137
133
|
async def stream_run_stateless(
|
|
@@ -143,8 +139,7 @@ async def stream_run_stateless(
|
|
|
143
139
|
on_disconnect = payload.get("on_disconnect", "continue")
|
|
144
140
|
run_id = uuid7()
|
|
145
141
|
thread_id = uuid4()
|
|
146
|
-
|
|
147
|
-
try:
|
|
142
|
+
async with await Runs.Stream.subscribe(run_id, thread_id) as sub:
|
|
148
143
|
async with connect() as conn:
|
|
149
144
|
run = await create_valid_run(
|
|
150
145
|
conn,
|
|
@@ -155,26 +150,21 @@ async def stream_run_stateless(
|
|
|
155
150
|
request_start_time=request.scope.get("request_start_time_ms"),
|
|
156
151
|
temporary=True,
|
|
157
152
|
)
|
|
158
|
-
except Exception:
|
|
159
|
-
if not sub.cancelled():
|
|
160
|
-
handle = await sub
|
|
161
|
-
await handle.__aexit__(None, None, None)
|
|
162
|
-
raise
|
|
163
153
|
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
154
|
+
return EventSourceResponse(
|
|
155
|
+
Runs.Stream.join(
|
|
156
|
+
run["run_id"],
|
|
157
|
+
thread_id=run["thread_id"],
|
|
158
|
+
ignore_404=True,
|
|
159
|
+
cancel_on_disconnect=on_disconnect == "cancel",
|
|
160
|
+
stream_channel=sub,
|
|
161
|
+
last_event_id=None,
|
|
162
|
+
),
|
|
163
|
+
headers={
|
|
164
|
+
"Location": f"/runs/{run['run_id']}/stream",
|
|
165
|
+
"Content-Location": f"/runs/{run['run_id']}",
|
|
166
|
+
},
|
|
167
|
+
)
|
|
178
168
|
|
|
179
169
|
|
|
180
170
|
@retry_db
|
|
@@ -184,9 +174,7 @@ async def wait_run(request: ApiRequest):
|
|
|
184
174
|
payload = await request.json(RunCreateStateful)
|
|
185
175
|
on_disconnect = payload.get("on_disconnect", "continue")
|
|
186
176
|
run_id = uuid7()
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
try:
|
|
177
|
+
async with await Runs.Stream.subscribe(run_id, thread_id) as sub:
|
|
190
178
|
async with connect() as conn:
|
|
191
179
|
run = await create_valid_run(
|
|
192
180
|
conn,
|
|
@@ -196,25 +184,17 @@ async def wait_run(request: ApiRequest):
|
|
|
196
184
|
run_id=run_id,
|
|
197
185
|
request_start_time=request.scope.get("request_start_time_ms"),
|
|
198
186
|
)
|
|
199
|
-
except Exception:
|
|
200
|
-
if not sub.cancelled():
|
|
201
|
-
handle = await sub
|
|
202
|
-
await handle.__aexit__(None, None, None)
|
|
203
|
-
raise
|
|
204
187
|
|
|
205
|
-
|
|
188
|
+
last_chunk = ValueEvent()
|
|
206
189
|
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
Runs.Stream.join(
|
|
190
|
+
async def consume():
|
|
191
|
+
vchunk: bytes | None = None
|
|
192
|
+
async for mode, chunk, _ in Runs.Stream.join(
|
|
211
193
|
run["run_id"],
|
|
212
194
|
thread_id=run["thread_id"],
|
|
213
|
-
stream_channel=
|
|
195
|
+
stream_channel=sub,
|
|
214
196
|
cancel_on_disconnect=on_disconnect == "cancel",
|
|
215
|
-
)
|
|
216
|
-
) as stream:
|
|
217
|
-
async for mode, chunk, _ in stream:
|
|
197
|
+
):
|
|
218
198
|
if (
|
|
219
199
|
mode == b"values"
|
|
220
200
|
or mode == b"updates"
|
|
@@ -223,43 +203,47 @@ async def wait_run(request: ApiRequest):
|
|
|
223
203
|
vchunk = chunk
|
|
224
204
|
elif mode == b"error":
|
|
225
205
|
vchunk = orjson.dumps({"__error__": orjson.Fragment(chunk)})
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
206
|
+
if vchunk is not None:
|
|
207
|
+
last_chunk.set(vchunk)
|
|
208
|
+
else:
|
|
209
|
+
async with connect() as conn:
|
|
210
|
+
thread_iter = await Threads.get(conn, thread_id)
|
|
211
|
+
try:
|
|
212
|
+
thread = await anext(thread_iter)
|
|
213
|
+
last_chunk.set(thread["values"])
|
|
214
|
+
except StopAsyncIteration:
|
|
215
|
+
await logger.awarning(
|
|
216
|
+
f"No checkpoint found for thread {thread_id}",
|
|
217
|
+
thread_id=thread_id,
|
|
218
|
+
)
|
|
219
|
+
last_chunk.set(b"{}")
|
|
220
|
+
|
|
221
|
+
# keep the connection open by sending whitespace every 5 seconds
|
|
222
|
+
# leading whitespace will be ignored by json parsers
|
|
223
|
+
async def body() -> AsyncIterator[bytes]:
|
|
224
|
+
stream = asyncio.create_task(consume())
|
|
225
|
+
while True:
|
|
231
226
|
try:
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
last_chunk.
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
await stream
|
|
253
|
-
raise
|
|
254
|
-
|
|
255
|
-
return StreamingResponse(
|
|
256
|
-
body(),
|
|
257
|
-
media_type="application/json",
|
|
258
|
-
headers={
|
|
259
|
-
"Location": f"/threads/{thread_id}/runs/{run['run_id']}/join",
|
|
260
|
-
"Content-Location": f"/threads/{thread_id}/runs/{run['run_id']}",
|
|
261
|
-
},
|
|
262
|
-
)
|
|
227
|
+
if stream.done():
|
|
228
|
+
# raise stream exception if any
|
|
229
|
+
stream.result()
|
|
230
|
+
yield await asyncio.wait_for(last_chunk.wait(), timeout=5)
|
|
231
|
+
break
|
|
232
|
+
except TimeoutError:
|
|
233
|
+
yield b"\n"
|
|
234
|
+
except asyncio.CancelledError:
|
|
235
|
+
stream.cancel()
|
|
236
|
+
await stream
|
|
237
|
+
raise
|
|
238
|
+
|
|
239
|
+
return StreamingResponse(
|
|
240
|
+
body(),
|
|
241
|
+
media_type="application/json",
|
|
242
|
+
headers={
|
|
243
|
+
"Location": f"/threads/{thread_id}/runs/{run['run_id']}/join",
|
|
244
|
+
"Content-Location": f"/threads/{thread_id}/runs/{run['run_id']}",
|
|
245
|
+
},
|
|
246
|
+
)
|
|
263
247
|
|
|
264
248
|
|
|
265
249
|
@retry_db
|
|
@@ -270,9 +254,7 @@ async def wait_run_stateless(request: ApiRequest):
|
|
|
270
254
|
on_disconnect = payload.get("on_disconnect", "continue")
|
|
271
255
|
run_id = uuid7()
|
|
272
256
|
thread_id = uuid4()
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
try:
|
|
257
|
+
async with await Runs.Stream.subscribe(run_id, thread_id) as sub:
|
|
276
258
|
async with connect() as conn:
|
|
277
259
|
run = await create_valid_run(
|
|
278
260
|
conn,
|
|
@@ -283,25 +265,18 @@ async def wait_run_stateless(request: ApiRequest):
|
|
|
283
265
|
request_start_time=request.scope.get("request_start_time_ms"),
|
|
284
266
|
temporary=True,
|
|
285
267
|
)
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
async def consume():
|
|
294
|
-
vchunk: bytes | None = None
|
|
295
|
-
async with aclosing(
|
|
296
|
-
Runs.Stream.join(
|
|
268
|
+
|
|
269
|
+
last_chunk = ValueEvent()
|
|
270
|
+
|
|
271
|
+
async def consume():
|
|
272
|
+
vchunk: bytes | None = None
|
|
273
|
+
async for mode, chunk, _ in Runs.Stream.join(
|
|
297
274
|
run["run_id"],
|
|
298
275
|
thread_id=run["thread_id"],
|
|
299
|
-
stream_channel=
|
|
276
|
+
stream_channel=sub,
|
|
300
277
|
ignore_404=True,
|
|
301
278
|
cancel_on_disconnect=on_disconnect == "cancel",
|
|
302
|
-
)
|
|
303
|
-
) as stream:
|
|
304
|
-
async for mode, chunk, _ in stream:
|
|
279
|
+
):
|
|
305
280
|
if (
|
|
306
281
|
mode == b"values"
|
|
307
282
|
or mode == b"updates"
|
|
@@ -310,38 +285,43 @@ async def wait_run_stateless(request: ApiRequest):
|
|
|
310
285
|
vchunk = chunk
|
|
311
286
|
elif mode == b"error":
|
|
312
287
|
vchunk = orjson.dumps({"__error__": orjson.Fragment(chunk)})
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
288
|
+
if vchunk is not None:
|
|
289
|
+
last_chunk.set(vchunk)
|
|
290
|
+
else:
|
|
291
|
+
# we can't fetch the thread (it was deleted), so just return empty values
|
|
292
|
+
await logger.awarning(
|
|
293
|
+
"No checkpoint emitted for stateless run",
|
|
294
|
+
run_id=run["run_id"],
|
|
295
|
+
thread_id=run["thread_id"],
|
|
296
|
+
)
|
|
297
|
+
last_chunk.set(b"{}")
|
|
298
|
+
|
|
299
|
+
# keep the connection open by sending whitespace every 5 seconds
|
|
300
|
+
# leading whitespace will be ignored by json parsers
|
|
301
|
+
async def body() -> AsyncIterator[bytes]:
|
|
302
|
+
stream = asyncio.create_task(consume())
|
|
303
|
+
while True:
|
|
304
|
+
try:
|
|
305
|
+
if stream.done():
|
|
306
|
+
# raise stream exception if any
|
|
307
|
+
stream.result()
|
|
308
|
+
yield await asyncio.wait_for(last_chunk.wait(), timeout=5)
|
|
309
|
+
break
|
|
310
|
+
except TimeoutError:
|
|
311
|
+
yield b"\n"
|
|
312
|
+
except asyncio.CancelledError:
|
|
313
|
+
stream.cancel("Run stream cancelled")
|
|
314
|
+
await stream
|
|
315
|
+
raise
|
|
316
|
+
|
|
317
|
+
return StreamingResponse(
|
|
318
|
+
body(),
|
|
319
|
+
media_type="application/json",
|
|
320
|
+
headers={
|
|
321
|
+
"Location": f"/threads/{run['thread_id']}/runs/{run['run_id']}/join",
|
|
322
|
+
"Content-Location": f"/threads/{run['thread_id']}/runs/{run['run_id']}",
|
|
323
|
+
},
|
|
324
|
+
)
|
|
345
325
|
|
|
346
326
|
|
|
347
327
|
@retry_db
|
langgraph_api/api/threads.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
from typing import get_args
|
|
1
2
|
from uuid import uuid4
|
|
2
3
|
|
|
3
4
|
from starlette.exceptions import HTTPException
|
|
@@ -5,7 +6,7 @@ from starlette.responses import Response
|
|
|
5
6
|
from starlette.routing import BaseRoute
|
|
6
7
|
|
|
7
8
|
from langgraph_api.route import ApiRequest, ApiResponse, ApiRoute
|
|
8
|
-
from langgraph_api.schema import THREAD_FIELDS
|
|
9
|
+
from langgraph_api.schema import THREAD_FIELDS, ThreadStreamMode
|
|
9
10
|
from langgraph_api.sse import EventSourceResponse
|
|
10
11
|
from langgraph_api.state import state_snapshot_to_thread_state
|
|
11
12
|
from langgraph_api.utils import (
|
|
@@ -15,6 +16,7 @@ from langgraph_api.utils import (
|
|
|
15
16
|
validate_stream_id,
|
|
16
17
|
validate_uuid,
|
|
17
18
|
)
|
|
19
|
+
from langgraph_api.utils.headers import get_configurable_headers
|
|
18
20
|
from langgraph_api.validation import (
|
|
19
21
|
ThreadCountRequest,
|
|
20
22
|
ThreadCreate,
|
|
@@ -46,12 +48,17 @@ async def create_thread(
|
|
|
46
48
|
if_exists=payload.get("if_exists") or "raise",
|
|
47
49
|
ttl=payload.get("ttl"),
|
|
48
50
|
)
|
|
49
|
-
|
|
51
|
+
config = {
|
|
52
|
+
"configurable": {
|
|
53
|
+
**get_configurable_headers(request.headers),
|
|
54
|
+
"thread_id": thread_id,
|
|
55
|
+
}
|
|
56
|
+
}
|
|
50
57
|
if supersteps := payload.get("supersteps"):
|
|
51
58
|
try:
|
|
52
59
|
await Threads.State.bulk(
|
|
53
60
|
conn,
|
|
54
|
-
config=
|
|
61
|
+
config=config,
|
|
55
62
|
supersteps=supersteps,
|
|
56
63
|
)
|
|
57
64
|
except HTTPException as e:
|
|
@@ -76,6 +83,7 @@ async def search_threads(
|
|
|
76
83
|
status=payload.get("status"),
|
|
77
84
|
values=payload.get("values"),
|
|
78
85
|
metadata=payload.get("metadata"),
|
|
86
|
+
ids=payload.get("ids"),
|
|
79
87
|
limit=limit,
|
|
80
88
|
offset=offset,
|
|
81
89
|
sort_by=payload.get("sort_by"),
|
|
@@ -113,10 +121,14 @@ async def get_thread_state(
|
|
|
113
121
|
validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
|
|
114
122
|
subgraphs = request.query_params.get("subgraphs") in ("true", "True")
|
|
115
123
|
async with connect() as conn:
|
|
124
|
+
config = {
|
|
125
|
+
"configurable": {
|
|
126
|
+
**get_configurable_headers(request.headers),
|
|
127
|
+
"thread_id": thread_id,
|
|
128
|
+
}
|
|
129
|
+
}
|
|
116
130
|
state = state_snapshot_to_thread_state(
|
|
117
|
-
await Threads.State.get(
|
|
118
|
-
conn, {"configurable": {"thread_id": thread_id}}, subgraphs=subgraphs
|
|
119
|
-
)
|
|
131
|
+
await Threads.State.get(conn, config=config, subgraphs=subgraphs)
|
|
120
132
|
)
|
|
121
133
|
return ApiResponse(state)
|
|
122
134
|
|
|
@@ -130,15 +142,17 @@ async def get_thread_state_at_checkpoint(
|
|
|
130
142
|
validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
|
|
131
143
|
checkpoint_id = request.path_params["checkpoint_id"]
|
|
132
144
|
async with connect() as conn:
|
|
145
|
+
config = {
|
|
146
|
+
"configurable": {
|
|
147
|
+
**get_configurable_headers(request.headers),
|
|
148
|
+
"thread_id": thread_id,
|
|
149
|
+
"checkpoint_id": checkpoint_id,
|
|
150
|
+
}
|
|
151
|
+
}
|
|
133
152
|
state = state_snapshot_to_thread_state(
|
|
134
153
|
await Threads.State.get(
|
|
135
154
|
conn,
|
|
136
|
-
|
|
137
|
-
"configurable": {
|
|
138
|
-
"thread_id": thread_id,
|
|
139
|
-
"checkpoint_id": checkpoint_id,
|
|
140
|
-
}
|
|
141
|
-
},
|
|
155
|
+
config=config,
|
|
142
156
|
subgraphs=request.query_params.get("subgraphs") in ("true", "True"),
|
|
143
157
|
)
|
|
144
158
|
)
|
|
@@ -154,10 +168,17 @@ async def get_thread_state_at_checkpoint_post(
|
|
|
154
168
|
validate_uuid(thread_id, "Invalid thread ID: must be a UUID")
|
|
155
169
|
payload = await request.json(ThreadStateCheckpointRequest)
|
|
156
170
|
async with connect() as conn:
|
|
171
|
+
config = {
|
|
172
|
+
"configurable": {
|
|
173
|
+
**payload["checkpoint"],
|
|
174
|
+
**get_configurable_headers(request.headers),
|
|
175
|
+
"thread_id": thread_id,
|
|
176
|
+
}
|
|
177
|
+
}
|
|
157
178
|
state = state_snapshot_to_thread_state(
|
|
158
179
|
await Threads.State.get(
|
|
159
180
|
conn,
|
|
160
|
-
|
|
181
|
+
config=config,
|
|
161
182
|
subgraphs=payload.get("subgraphs", False),
|
|
162
183
|
)
|
|
163
184
|
)
|
|
@@ -182,6 +203,7 @@ async def update_thread_state(
|
|
|
182
203
|
config["configurable"]["user_id"] = user_id
|
|
183
204
|
except AssertionError:
|
|
184
205
|
pass
|
|
206
|
+
config["configurable"].update(get_configurable_headers(request.headers))
|
|
185
207
|
async with connect() as conn:
|
|
186
208
|
inserted = await Threads.State.post(
|
|
187
209
|
conn,
|
|
@@ -205,7 +227,13 @@ async def get_thread_history(
|
|
|
205
227
|
except ValueError:
|
|
206
228
|
raise HTTPException(status_code=422, detail=f"Invalid limit {limit_}") from None
|
|
207
229
|
before = request.query_params.get("before")
|
|
208
|
-
config = {
|
|
230
|
+
config = {
|
|
231
|
+
"configurable": {
|
|
232
|
+
"thread_id": thread_id,
|
|
233
|
+
"checkpoint_ns": "",
|
|
234
|
+
**get_configurable_headers(request.headers),
|
|
235
|
+
}
|
|
236
|
+
}
|
|
209
237
|
async with connect() as conn:
|
|
210
238
|
states = [
|
|
211
239
|
state_snapshot_to_thread_state(c)
|
|
@@ -226,6 +254,7 @@ async def get_thread_history_post(
|
|
|
226
254
|
payload = await request.json(ThreadStateSearch)
|
|
227
255
|
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
|
228
256
|
config["configurable"].update(payload.get("checkpoint", {}))
|
|
257
|
+
config["configurable"].update(get_configurable_headers(request.headers))
|
|
229
258
|
async with connect() as conn:
|
|
230
259
|
states = [
|
|
231
260
|
state_snapshot_to_thread_state(c)
|
|
@@ -293,10 +322,31 @@ async def join_thread_stream(request: ApiRequest):
|
|
|
293
322
|
validate_stream_id(
|
|
294
323
|
last_event_id, "Invalid last-event-id: must be a valid Redis stream ID"
|
|
295
324
|
)
|
|
325
|
+
|
|
326
|
+
# Parse stream_modes parameter - can be single string or comma-separated list
|
|
327
|
+
stream_modes_param = request.query_params.get("stream_modes")
|
|
328
|
+
if stream_modes_param:
|
|
329
|
+
if "," in stream_modes_param:
|
|
330
|
+
# Handle comma-separated list
|
|
331
|
+
stream_modes = [mode.strip() for mode in stream_modes_param.split(",")]
|
|
332
|
+
else:
|
|
333
|
+
# Handle single value
|
|
334
|
+
stream_modes = [stream_modes_param]
|
|
335
|
+
# Validate each mode
|
|
336
|
+
for mode in stream_modes:
|
|
337
|
+
if mode not in get_args(ThreadStreamMode):
|
|
338
|
+
raise HTTPException(
|
|
339
|
+
status_code=422, detail=f"Invalid stream mode: {mode}"
|
|
340
|
+
)
|
|
341
|
+
else:
|
|
342
|
+
# Default to run_modes
|
|
343
|
+
stream_modes = ["run_modes"]
|
|
344
|
+
|
|
296
345
|
return EventSourceResponse(
|
|
297
346
|
Threads.Stream.join(
|
|
298
347
|
thread_id,
|
|
299
348
|
last_event_id=last_event_id,
|
|
349
|
+
stream_modes=stream_modes,
|
|
300
350
|
),
|
|
301
351
|
)
|
|
302
352
|
|
langgraph_api/asyncio.py
CHANGED
|
@@ -162,7 +162,8 @@ class SimpleTaskGroup(AbstractAsyncContextManager["SimpleTaskGroup"]):
|
|
|
162
162
|
taskset: set[asyncio.Task] | None = None,
|
|
163
163
|
taskgroup_name: str | None = None,
|
|
164
164
|
) -> None:
|
|
165
|
-
|
|
165
|
+
# Copy the taskset to avoid modifying the original set unintentionally (like in lifespan)
|
|
166
|
+
self.tasks = taskset.copy() if taskset is not None else set()
|
|
166
167
|
self.cancel = cancel
|
|
167
168
|
self.wait = wait
|
|
168
169
|
if taskset:
|
langgraph_api/feature_flags.py
CHANGED
langgraph_api/logging.py
CHANGED
|
@@ -69,9 +69,12 @@ class AddApiVersion:
|
|
|
69
69
|
def __call__(
|
|
70
70
|
self, logger: logging.Logger, method_name: str, event_dict: EventDict
|
|
71
71
|
) -> EventDict:
|
|
72
|
-
|
|
72
|
+
try:
|
|
73
|
+
from langgraph_api import __version__
|
|
73
74
|
|
|
74
|
-
|
|
75
|
+
event_dict["langgraph_api_version"] = __version__
|
|
76
|
+
except ImportError:
|
|
77
|
+
pass
|
|
75
78
|
return event_dict
|
|
76
79
|
|
|
77
80
|
|