langgraph-api 0.4.20__py3-none-any.whl → 0.4.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

langgraph_api/__init__.py CHANGED
@@ -1 +1 @@
- __version__ = "0.4.20"
+ __version__ = "0.4.22"
langgraph_api/api/a2a.py CHANGED
@@ -7,7 +7,7 @@ A2A Protocol specification:
  https://a2a-protocol.org/dev/specification/

  The implementation currently supports JSON-RPC 2.0 transport only.
- Streaming (SSE) and push notifications are not implemented.
+ Push notifications are not implemented.
  """

  import functools
@@ -16,18 +16,19 @@ from datetime import UTC, datetime
  from typing import Any, Literal, NotRequired, cast

  import orjson
+ import structlog
  from langgraph_sdk.client import LangGraphClient, get_client
  from starlette.datastructures import Headers
  from starlette.responses import JSONResponse, Response
- from structlog import getLogger
  from typing_extensions import TypedDict

  from langgraph_api import __version__
  from langgraph_api.metadata import USER_API_URL
  from langgraph_api.route import ApiRequest, ApiRoute
+ from langgraph_api.sse import EventSourceResponse
  from langgraph_api.utils.cache import LRUCache

- logger = getLogger(__name__)
+ logger = structlog.stdlib.get_logger(__name__)

  # Cache for assistant schemas (assistant_id -> schemas dict)
  _assistant_schemas_cache = LRUCache[dict[str, Any]](max_size=1000, ttl=60)
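
The logger change above is what enables the async logging calls used later in this diff (`await logger.aexception(...)`, `logger.awarning`, `logger.adebug`): `structlog.stdlib.get_logger` returns a stdlib-flavored bound logger whose methods also come in async variants. A minimal sketch, assuming a recent structlog (≥ 23.1, where the `a*` methods exist):

    import structlog

    logger = structlog.stdlib.get_logger(__name__)

    async def handler() -> None:
        # adebug/ainfo/awarning/aerror/aexception run the processor chain in a
        # worker thread so the event loop is not blocked by logging I/O
        await logger.ainfo("stream_started", assistant_id="example")
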
@@ -286,6 +287,101 @@ def _extract_a2a_response(result: dict[str, Any]) -> str:
      return str(last_message)


+ def _lc_stream_items_to_a2a_message(
+     items: list[dict[str, Any]],
+     *,
+     task_id: str,
+     context_id: str,
+     role: Literal["agent", "user"] = "agent",
+ ) -> dict[str, Any]:
+     """Convert LangChain stream "messages/*" items into a valid A2A Message.
+
+     This takes the list found in a messages/* StreamPart's data field and
+     constructs a single A2A Message object, concatenating textual content and
+     preserving select structured metadata into a DataPart.
+
+     Args:
+         items: List of LangChain message dicts from stream (e.g., with keys like
+             "content", "type", "response_metadata", "tool_calls", etc.)
+         task_id: The A2A task ID this message belongs to
+         context_id: The A2A context ID (thread) for grouping
+         role: A2A role; defaults to "agent" for streamed assistant output
+
+     Returns:
+         A2A Message dict with required fields and minimally valid parts.
+     """
+     # Aggregate any text content across items
+     text_parts: list[str] = []
+     # Collect a small amount of structured data for debugging/traceability
+     extra_data: dict[str, Any] = {}
+
+     def _sse_safe_text(s: str) -> str:
+         return s.replace("\u2028", "\\u2028").replace("\u2029", "\\u2029")
+
+     for it in items:
+         if not isinstance(it, dict):
+             continue
+         content = it.get("content")
+         if isinstance(content, str) and content:
+             text_parts.append(_sse_safe_text(content))
+
+         # Preserve a couple of useful fields if present
+         # Keep this small to avoid bloating the message payload
+         rm = it.get("response_metadata")
+         if isinstance(rm, dict) and rm:
+             extra_data.setdefault("response_metadata", rm)
+         tc = it.get("tool_calls")
+         if isinstance(tc, list) and tc:
+             extra_data.setdefault("tool_calls", tc)
+
+     parts: list[dict[str, Any]] = []
+     if text_parts:
+         parts.append({"kind": "text", "text": "".join(text_parts)})
+     if extra_data:
+         parts.append({"kind": "data", "data": extra_data})
+
+     # Ensure we always produce a minimally valid A2A Message
+     if not parts:
+         parts = [{"kind": "text", "text": ""}]
+
+     return {
+         "role": role,
+         "parts": parts,
+         "messageId": str(uuid.uuid4()),
+         "taskId": task_id,
+         "contextId": context_id,
+         "kind": "message",
+     }
+
+
+ def _lc_items_to_status_update_event(
+     items: list[dict[str, Any]],
+     *,
+     task_id: str,
+     context_id: str,
+     state: str = "working",
+ ) -> dict[str, Any]:
+     """Build a TaskStatusUpdateEvent embedding a converted A2A Message.
+
+     This avoids emitting standalone Message results (which some clients reject)
+     and keeps message content within the status update per spec.
+     """
+     message = _lc_stream_items_to_a2a_message(
+         items, task_id=task_id, context_id=context_id, role="agent"
+     )
+     return {
+         "taskId": task_id,
+         "contextId": context_id,
+         "kind": "status-update",
+         "status": {
+             "state": state,
+             "message": message,
+             "timestamp": datetime.now(UTC).isoformat(),
+         },
+         "final": False,
+     }
+
+
  def _map_runs_create_error_to_rpc(
      exception: Exception, assistant_id: str, thread_id: str | None = None
  ) -> dict[str, Any]:
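
To make the conversion concrete, here is a minimal sketch of what the new helpers produce for a typical `messages/*` chunk. The item shapes below are illustrative stand-ins for LangChain message dicts, not output captured from a real run:

    items = [
        {"type": "AIMessageChunk", "content": "Hello, "},
        {
            "type": "AIMessageChunk",
            "content": "world!",
            "response_metadata": {"finish_reason": "stop"},
        },
    ]

    event = _lc_items_to_status_update_event(
        items, task_id="run-123", context_id="thread-456"
    )
    # Text content is concatenated into a single TextPart and select metadata
    # is preserved as a DataPart:
    # event["status"]["message"]["parts"] == [
    #     {"kind": "text", "text": "Hello, world!"},
    #     {"kind": "data", "data": {"response_metadata": {"finish_reason": "stop"}}},
    # ]
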
@@ -519,9 +615,6 @@ async def handle_post_request(request: ApiRequest, assistant_id: str) -> Respons
      except orjson.JSONDecodeError:
          return create_error_response("Invalid JSON payload", 400)

-     if not is_valid_accept_header(request):
-         return create_error_response("Accept header must include application/json", 400)
-
      if not isinstance(message, dict):
          return create_error_response("Invalid message format", 400)

@@ -534,6 +627,18 @@ async def handle_post_request(request: ApiRequest, assistant_id: str) -> Respons
      id_ = message.get("id")
      method = message.get("method")

+     accept_header = request.headers.get("Accept") or ""
+     if method == "message/stream":
+         if "text/event-stream" not in accept_header:
+             return create_error_response(
+                 "Accept header must include text/event-stream for streaming", 400
+             )
+     else:
+         if "application/json" not in accept_header:
+             return create_error_response(
+                 "Accept header must include application/json", 400
+             )
+
      if id_ is not None and method:
          # JSON-RPC request
          return await handle_jsonrpc_request(
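
The net effect of this change: non-streaming JSON-RPC methods still require `application/json` in the Accept header, while `message/stream` now requires `text/event-stream`. A hypothetical client showing both (the base URL and assistant id are placeholders):

    import httpx

    URL = "http://localhost:2024/a2a/my-assistant"  # hypothetical deployment
    msg = {
        "role": "user",
        "parts": [{"kind": "text", "text": "Hi"}],
        "messageId": "m-1",
    }

    # Non-streaming methods must accept JSON...
    resp = httpx.post(
        URL,
        headers={"Accept": "application/json"},
        json={"jsonrpc": "2.0", "id": 1, "method": "message/send",
              "params": {"message": msg}},
    )

    # ...while message/stream must accept SSE, or the server responds with 400.
    with httpx.stream(
        "POST",
        URL,
        headers={"Accept": "text/event-stream"},
        json={"jsonrpc": "2.0", "id": 2, "method": "message/stream",
              "params": {"message": msg}},
    ) as stream:
        for line in stream.iter_lines():
            print(line)
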
@@ -553,19 +658,6 @@ async def handle_post_request(request: ApiRequest, assistant_id: str) -> Respons
          )


- def is_valid_accept_header(request: ApiRequest) -> bool:
-     """Check if Accept header contains supported content types.
-
-     Args:
-         request: The incoming request
-
-     Returns:
-         True if header contains application/json
-     """
-     accept_header = request.headers.get("Accept", "")
-     return "application/json" in accept_header
-
-
  def create_error_response(message: str, status_code: int) -> Response:
      """Create a JSON error response.

@@ -603,9 +695,10 @@ async def handle_jsonrpc_request(
      """
      method = message["method"]
      params = message.get("params", {})
-
      # Route to appropriate A2A method handler
-     if method == "message/send":
+     if method == "message/stream":
+         return await handle_message_stream(request, params, assistant_id, message["id"])
+     elif method == "message/send":
          result_or_error = await handle_message_send(request, params, assistant_id)
      elif method == "tasks/get":
          result_or_error = await handle_tasks_get(request, params)
@@ -949,7 +1042,9 @@ async def generate_agent_card(request: ApiRequest, assistant_id: str) -> dict[st
      required = input_schema.get("required", [])

      assistant_name = assistant["name"]
-     assistant_description = assistant.get("description", f"{assistant_name} assistant")
+     assistant_description = (
+         assistant.get("description") or f"{assistant_name} assistant"
+     )

      # For now, each assistant has one main skill - itself
      skills = [
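
The switch from a `.get()` default to `or` matters when an assistant row stores an explicit `None` or empty-string description; only the new form falls back in those cases. A quick illustration:

    assistant = {"name": "doc-bot", "description": None}

    # dict.get(key, default) only falls back when the key is absent:
    assistant.get("description", "doc-bot assistant")    # -> None

    # `or` also falls back when the stored value is falsy (None or ""):
    assistant.get("description") or "doc-bot assistant"  # -> "doc-bot assistant"
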
@@ -978,10 +1073,11 @@ async def generate_agent_card(request: ApiRequest, assistant_id: str) -> dict[st
      scheme = request.url.scheme
      host = request.url.hostname or "localhost"
      port = request.url.port
+     path = request.url.path.removesuffix("/.well-known/agent-card.json")
      if port and (
          (scheme == "http" and port != 80) or (scheme == "https" and port != 443)
      ):
-         base_url = f"{scheme}://{host}:{port}"
+         base_url = f"{scheme}://{host}:{port}{path}"
      else:
          base_url = f"{scheme}://{host}"

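
With the new `path` component, an agent card served behind a path prefix now advertises URLs under that prefix instead of the bare host. A small illustration with a hypothetical prefix:

    request_path = "/proxy/langgraph/.well-known/agent-card.json"
    path = request_path.removesuffix("/.well-known/agent-card.json")
    # path == "/proxy/langgraph", so a card served on https://example.com:8443
    # now reports "url": "https://example.com:8443/proxy/langgraph/a2a/<assistant_id>"
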
@@ -992,7 +1088,7 @@ async def generate_agent_card(request: ApiRequest, assistant_id: str) -> dict[st
          "url": f"{base_url}/a2a/{assistant_id}",
          "preferredTransport": "JSONRPC",
          "capabilities": {
-             "streaming": False,  # Not implemented yet
+             "streaming": True,
              "pushNotifications": False,  # Not implemented yet
              "stateTransitionHistory": False,
          },
@@ -1062,6 +1158,281 @@ async def handle_agent_card_endpoint(request: ApiRequest) -> Response:
      )


+ # ============================================================================
+ # Message Streaming
+ # ============================================================================
+
+
+ async def handle_message_stream(
+     request: ApiRequest,
+     params: dict[str, Any],
+     assistant_id: str,
+     rpc_id: str | int,
+ ) -> Response:
+     """Handle message/stream requests and stream JSON-RPC responses via SSE.
+
+     Each SSE "data" is a JSON-RPC 2.0 response object. We emit:
+     - An initial TaskStatusUpdateEvent with state "submitted".
+     - Optionally a TaskStatusUpdateEvent with state "working" on first update.
+     - A final Task result when the run completes.
+     - A JSON-RPC error if anything fails.
+     """
+     client = _client()
+
+     async def stream_body():
+         try:
+             message = params.get("message")
+             if not message:
+                 yield (
+                     b"message",
+                     {
+                         "jsonrpc": "2.0",
+                         "id": rpc_id,
+                         "error": {
+                             "code": ERROR_CODE_INVALID_PARAMS,
+                             "message": "Missing 'message' in params",
+                         },
+                     },
+                 )
+                 return
+
+             parts = message.get("parts", [])
+             if not parts:
+                 yield (
+                     b"message",
+                     {
+                         "jsonrpc": "2.0",
+                         "id": rpc_id,
+                         "error": {
+                             "code": ERROR_CODE_INVALID_PARAMS,
+                             "message": "Message must contain at least one part",
+                         },
+                     },
+                 )
+                 return
+
+             try:
+                 assistant = await _get_assistant(client, assistant_id, request.headers)
+                 await _validate_supports_messages(
+                     client, assistant, request.headers, parts
+                 )
+             except ValueError as e:
+                 yield (
+                     b"message",
+                     {
+                         "jsonrpc": "2.0",
+                         "id": rpc_id,
+                         "error": {
+                             "code": ERROR_CODE_INVALID_PARAMS,
+                             "message": str(e),
+                         },
+                     },
+                 )
+                 return
+
+             # Process A2A message parts into LangChain messages format
+             try:
+                 message_role = message.get("role", "user")
+                 input_content = _process_a2a_message_parts(parts, message_role)
+             except ValueError as e:
+                 yield (
+                     b"message",
+                     {
+                         "jsonrpc": "2.0",
+                         "id": rpc_id,
+                         "error": {
+                             "code": ERROR_CODE_CONTENT_TYPE_NOT_SUPPORTED,
+                             "message": str(e),
+                         },
+                     },
+                 )
+                 return
+
+             run = await client.runs.create(
+                 thread_id=message.get("contextId"),
+                 assistant_id=assistant_id,
+                 stream_mode=["messages", "values"],
+                 if_not_exists="create",
+                 input=input_content,
+                 headers=request.headers,
+             )
+             context_id = run["thread_id"]
+             # Emit initial Task object to establish task context
+             initial_task = {
+                 "id": run["run_id"],
+                 "contextId": context_id,
+                 "history": [
+                     {
+                         **message,
+                         "taskId": run["run_id"],
+                         "contextId": context_id,
+                         "kind": "message",
+                     }
+                 ],
+                 "kind": "task",
+                 "status": {
+                     "state": "submitted",
+                     "timestamp": datetime.now(UTC).isoformat(),
+                 },
+             }
+             yield (b"message", {"jsonrpc": "2.0", "id": rpc_id, "result": initial_task})
+             task_id = run["run_id"]
+             stream = client.runs.join_stream(
+                 run_id=task_id,
+                 thread_id=context_id,
+                 headers=request.headers,
+             )
+             result = None
+             err = None
+             notified_is_working = False
+             async for chunk in stream:
+                 try:
+                     if chunk.event == "metadata":
+                         data = chunk.data or {}
+                         if data.get("status") == "run_done":
+                             final_message = None
+                             if isinstance(result, dict):
+                                 try:
+                                     final_text = _extract_a2a_response(result)
+                                     final_message = {
+                                         "role": "agent",
+                                         "parts": [{"kind": "text", "text": final_text}],
+                                         "messageId": str(uuid.uuid4()),
+                                         "taskId": task_id,
+                                         "contextId": context_id,
+                                         "kind": "message",
+                                     }
+                                 except Exception:
+                                     await logger.aexception(
+                                         "Failed to extract final message from result",
+                                         result=result,
+                                     )
+                             if final_message is None:
+                                 final_message = {
+                                     "role": "agent",
+                                     "parts": [{"kind": "text", "text": str(result)}],
+                                     "messageId": str(uuid.uuid4()),
+                                     "taskId": task_id,
+                                     "contextId": context_id,
+                                     "kind": "message",
+                                 }
+                             completed = {
+                                 "taskId": task_id,
+                                 "contextId": context_id,
+                                 "kind": "status-update",
+                                 "status": {
+                                     "state": "completed",
+                                     "message": final_message,
+                                     "timestamp": datetime.now(UTC).isoformat(),
+                                 },
+                                 "final": True,
+                             }
+                             yield (
+                                 b"message",
+                                 {"jsonrpc": "2.0", "id": rpc_id, "result": completed},
+                             )
+                             return
+                         if data.get("run_id") and not notified_is_working:
+                             notified_is_working = True
+                             yield (
+                                 b"message",
+                                 {
+                                     "jsonrpc": "2.0",
+                                     "id": rpc_id,
+                                     "result": {
+                                         "taskId": task_id,
+                                         "contextId": context_id,
+                                         "kind": "status-update",
+                                         "status": {"state": "working"},
+                                         "final": False,
+                                     },
+                                 },
+                             )
+                     elif chunk.event == "error":
+                         err = chunk.data
+                     elif chunk.event == "values":
+                         err = None  # Error was retriable
+                         result = chunk.data
+                     elif chunk.event.startswith("messages"):
+                         err = None  # Error was retriable
+                         items = chunk.data or []
+                         if isinstance(items, list) and items:
+                             update = _lc_items_to_status_update_event(
+                                 items,
+                                 task_id=task_id,
+                                 context_id=context_id,
+                                 state="working",
+                             )
+                             yield (
+                                 b"message",
+                                 {"jsonrpc": "2.0", "id": rpc_id, "result": update},
+                             )
+                     else:
+                         await logger.awarning(
+                             "Ignoring unknown event type: " + chunk.event
+                         )
+
+                 except Exception as e:
+                     await logger.aexception("Failed to process message stream")
+                     err = {"error": type(e).__name__, "message": str(e)}
+                     continue
+
+             # If we exit unexpectedly, send a final status based on error presence
+             final_message = None
+             if isinstance(err, dict) and ("__error__" in err or "error" in err):
+                 msg = (
+                     err.get("__error__", {}).get("error")
+                     if isinstance(err.get("__error__"), dict)
+                     else err.get("message")
+                 )
+                 await logger.aerror("Failed to process message stream", err=err)
+                 final_message = {
+                     "role": "agent",
+                     "parts": [{"kind": "text", "text": str(msg or "")}],
+                     "messageId": str(uuid.uuid4()),
+                     "taskId": task_id,
+                     "contextId": context_id,
+                     "kind": "message",
+                 }
+             fallback = {
+                 "taskId": task_id,
+                 "contextId": context_id,
+                 "kind": "status-update",
+                 "status": {
+                     "state": "failed" if err else "completed",
+                     **({"message": final_message} if final_message else {}),
+                     "timestamp": datetime.now(UTC).isoformat(),
+                 },
+                 "final": True,
+             }
+             yield (b"message", {"jsonrpc": "2.0", "id": rpc_id, "result": fallback})
+         except Exception as e:
+             await logger.aerror(
+                 f"Error in message/stream for assistant {assistant_id}: {str(e)}",
+                 exc_info=True,
+             )
+             yield (
+                 b"message",
+                 {
+                     "jsonrpc": "2.0",
+                     "id": rpc_id,
+                     "error": {
+                         "code": ERROR_CODE_INTERNAL_ERROR,
+                         "message": f"Internal server error: {str(e)}",
+                     },
+                 },
+             )
+
+     async def consume_():
+         async for chunk in stream_body():
+             await logger.adebug("A2A.stream_body: Yielding chunk", chunk=chunk)
+             yield chunk
+
+     return EventSourceResponse(
+         consume_(), headers={"Content-Type": "text/event-stream"}
+     )
+
+
  # ============================================================================
  # Route Definitions
  # ============================================================================
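
Taken together, a client consuming `message/stream` sees an SSE body whose `data:` lines are JSON-RPC responses, roughly in this order (payloads abbreviated, field values illustrative):

    event: message
    data: {"jsonrpc": "2.0", "id": 1, "result": {"kind": "task", "status": {"state": "submitted", ...}, ...}}

    event: message
    data: {"jsonrpc": "2.0", "id": 1, "result": {"kind": "status-update", "status": {"state": "working"}, "final": false, ...}}

    event: message
    data: {"jsonrpc": "2.0", "id": 1, "result": {"kind": "status-update", "status": {"state": "working", "message": {...}}, "final": false, ...}}

    event: message
    data: {"jsonrpc": "2.0", "id": 1, "result": {"kind": "status-update", "status": {"state": "completed", "message": {...}}, "final": true, ...}}
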
langgraph_api/api/assistants.py CHANGED
@@ -17,7 +17,7 @@ from langgraph_api.graph import get_assistant_id, get_graph
  from langgraph_api.js.base import BaseRemotePregel
  from langgraph_api.route import ApiRequest, ApiResponse, ApiRoute
  from langgraph_api.schema import ASSISTANT_FIELDS
- from langgraph_api.serde import ajson_loads
+ from langgraph_api.serde import json_loads
  from langgraph_api.utils import (
      fetchone,
      get_pagination_headers,
@@ -240,7 +240,7 @@ async def get_assistant_graph(
      async with connect() as conn:
          assistant_ = await Assistants.get(conn, assistant_id)
          assistant = await fetchone(assistant_)
-         config = await ajson_loads(assistant["config"])
+         config = json_loads(assistant["config"])
          configurable = config.setdefault("configurable", {})
          configurable.update(get_configurable_headers(request.headers))

@@ -297,7 +297,7 @@ async def get_assistant_subgraphs(
      async with connect() as conn:
          assistant_ = await Assistants.get(conn, assistant_id)
          assistant = await fetchone(assistant_)
-         config = await ajson_loads(assistant["config"])
+         config = json_loads(assistant["config"])
          configurable = config.setdefault("configurable", {})
          configurable.update(get_configurable_headers(request.headers))
          async with get_graph(
@@ -345,7 +345,7 @@ async def get_assistant_schemas(
          assistant_ = await Assistants.get(conn, assistant_id)
          # TODO Implementa cache so we can de-dent and release this connection.
          assistant = await fetchone(assistant_)
-         config = await ajson_loads(assistant["config"])
+         config = json_loads(assistant["config"])
          configurable = config.setdefault("configurable", {})
          configurable.update(get_configurable_headers(request.headers))
          async with get_graph(
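
The `ajson_loads` → `json_loads` swap makes assistant-config parsing synchronous, dropping three `await`s. A minimal sketch of the presumed distinction between the two helpers; the real implementations live in `langgraph_api/serde.py`, so treat this as an assumption, not the actual code:

    import asyncio
    import orjson

    def json_loads(content: bytes | str) -> object:
        # plain synchronous parse; cheap for small payloads such as assistant configs
        return orjson.loads(content)

    async def ajson_loads(content: bytes | str) -> object:
        # offloads parsing to a worker thread; only worthwhile for large payloads
        return await asyncio.to_thread(json_loads, content)
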