langgraph-api 0.2.44__py3-none-any.whl → 0.2.46__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langgraph-api might be problematic. Click here for more details.

langgraph_api/__init__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.2.44"
1
+ __version__ = "0.2.46"
langgraph_api/api/assistants.py CHANGED
@@ -189,40 +189,40 @@ async def get_assistant_graph(
189
189
  async with connect() as conn:
190
190
  assistant_ = await Assistants.get(conn, assistant_id)
191
191
  assistant = await fetchone(assistant_)
192
- config = await ajson_loads(assistant["config"])
193
- async with get_graph(
194
- assistant["graph_id"],
195
- config,
196
- checkpointer=Checkpointer(conn),
197
- store=(await api_store.get_store()),
198
- ) as graph:
199
- xray: bool | int = False
200
- xray_query = request.query_params.get("xray")
201
- if xray_query:
202
- if xray_query in ("true", "True"):
203
- xray = True
204
- elif xray_query in ("false", "False"):
205
- xray = False
206
- else:
207
- try:
208
- xray = int(xray_query)
209
- except ValueError:
210
- raise HTTPException(422, detail="Invalid xray value") from None
211
-
212
- if xray <= 0:
213
- raise HTTPException(422, detail="Invalid xray value") from None
214
-
215
- if isinstance(graph, BaseRemotePregel):
216
- drawable_graph = await graph.fetch_graph(xray=xray)
217
- return ApiResponse(drawable_graph.to_json())
192
+ config = await ajson_loads(assistant["config"])
193
+ async with get_graph(
194
+ assistant["graph_id"],
195
+ config,
196
+ checkpointer=Checkpointer(),
197
+ store=(await api_store.get_store()),
198
+ ) as graph:
199
+ xray: bool | int = False
200
+ xray_query = request.query_params.get("xray")
201
+ if xray_query:
202
+ if xray_query in ("true", "True"):
203
+ xray = True
204
+ elif xray_query in ("false", "False"):
205
+ xray = False
206
+ else:
207
+ try:
208
+ xray = int(xray_query)
209
+ except ValueError:
210
+ raise HTTPException(422, detail="Invalid xray value") from None
211
+
212
+ if xray <= 0:
213
+ raise HTTPException(422, detail="Invalid xray value") from None
214
+
215
+ if isinstance(graph, BaseRemotePregel):
216
+ drawable_graph = await graph.fetch_graph(xray=xray)
217
+ return ApiResponse(drawable_graph.to_json())
218
218
 
219
- try:
220
- drawable_graph = await graph.aget_graph(xray=xray)
221
- return ApiResponse(drawable_graph.to_json())
222
- except NotImplementedError:
223
- raise HTTPException(
224
- 422, detail="The graph does not support visualization"
225
- ) from None
219
+ try:
220
+ drawable_graph = await graph.aget_graph(xray=xray)
221
+ return ApiResponse(drawable_graph.to_json())
222
+ except NotImplementedError:
223
+ raise HTTPException(
224
+ 422, detail="The graph does not support visualization"
225
+ ) from None
226
226
 
227
227
 
228
228
  @retry_db
@@ -239,7 +239,7 @@ async def get_assistant_subgraphs(
239
239
  async with get_graph(
240
240
  assistant["graph_id"],
241
241
  config,
242
- checkpointer=Checkpointer(conn),
242
+ checkpointer=Checkpointer(),
243
243
  store=(await api_store.get_store()),
244
244
  ) as graph:
245
245
  namespace = request.path_params.get("namespace")
@@ -285,7 +285,7 @@ async def get_assistant_schemas(
285
285
  async with get_graph(
286
286
  assistant["graph_id"],
287
287
  config,
288
- checkpointer=Checkpointer(conn),
288
+ checkpointer=Checkpointer(),
289
289
  store=(await api_store.get_store()),
290
290
  ) as graph:
291
291
  if isinstance(graph, BaseRemotePregel):
langgraph_api/api/runs.py CHANGED
@@ -222,6 +222,9 @@ async def wait_run(request: ApiRequest):
222
222
  stream = asyncio.create_task(consume())
223
223
  while True:
224
224
  try:
225
+ if stream.done():
226
+ # raise stream exception if any
227
+ stream.result()
225
228
  yield await asyncio.wait_for(last_chunk.wait(), timeout=5)
226
229
  break
227
230
  except TimeoutError:
@@ -270,7 +273,10 @@ async def wait_run_stateless(request: ApiRequest):
270
273
  vchunk: bytes | None = None
271
274
  async with aclosing(
272
275
  Runs.Stream.join(
273
- run["run_id"], thread_id=run["thread_id"], stream_mode=await sub
276
+ run["run_id"],
277
+ thread_id=run["thread_id"],
278
+ stream_mode=await sub,
279
+ ignore_404=True,
274
280
  )
275
281
  ) as stream:
276
282
  async for mode, chunk, _ in stream:
@@ -290,6 +296,9 @@ async def wait_run_stateless(request: ApiRequest):
290
296
  stream = asyncio.create_task(consume())
291
297
  while True:
292
298
  try:
299
+ if stream.done():
300
+ # raise stream exception if any
301
+ stream.result()
293
302
  yield await asyncio.wait_for(last_chunk.wait(), timeout=5)
294
303
  break
295
304
  except TimeoutError:
langgraph_api/js/remote.py CHANGED
@@ -48,7 +48,6 @@ from langgraph_api.js.sse import SSEDecoder, aiter_lines_raw
48
48
  from langgraph_api.route import ApiResponse
49
49
  from langgraph_api.schema import Config
50
50
  from langgraph_api.serde import json_dumpb
51
- from langgraph_api.utils import AsyncConnectionProto
52
51
 
53
52
  logger = structlog.stdlib.get_logger(__name__)
54
53
 
@@ -452,27 +451,27 @@ async def run_js_http_process(paths_str: str, http_config: dict, watch: bool = F
452
451
  attempt += 1
453
452
 
454
453
 
455
- def _get_passthrough_checkpointer(conn: AsyncConnectionProto):
456
- from langgraph_runtime.checkpoint import Checkpointer
454
+ class PassthroughSerialiser(SerializerProtocol):
455
+ def dumps(self, obj: Any) -> bytes:
456
+ return json_dumpb(obj)
457
457
 
458
- class PassthroughSerialiser(SerializerProtocol):
459
- def dumps(self, obj: Any) -> bytes:
460
- return json_dumpb(obj)
458
+ def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
459
+ return "json", json_dumpb(obj)
461
460
 
462
- def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
463
- return "json", json_dumpb(obj)
461
+ def loads(self, data: bytes) -> Any:
462
+ return orjson.loads(data)
464
463
 
465
- def loads(self, data: bytes) -> Any:
466
- return orjson.loads(data)
464
+ def loads_typed(self, data: tuple[str, bytes]) -> Any:
465
+ type, payload = data
466
+ if type != "json":
467
+ raise ValueError(f"Unsupported type {type}")
468
+ return orjson.loads(payload)
467
469
 
468
- def loads_typed(self, data: tuple[str, bytes]) -> Any:
469
- type, payload = data
470
- if type != "json":
471
- raise ValueError(f"Unsupported type {type}")
472
- return orjson.loads(payload)
473
470
 
474
- checkpointer = Checkpointer(conn)
471
+ def _get_passthrough_checkpointer():
472
+ from langgraph_runtime.checkpoint import Checkpointer
475
473
 
474
+ checkpointer = Checkpointer()
476
475
  # This checkpointer does not attempt to revive LC-objects.
477
476
  # Instead, it will pass through the JSON values as-is.
478
477
  checkpointer.serde = PassthroughSerialiser()
@@ -487,53 +486,46 @@ async def _get_passthrough_store():
487
486
  # Setup a HTTP server on top of CHECKPOINTER_SOCKET unix socket
488
487
  # used by `client.mts` to communicate with the Python checkpointer
489
488
  async def run_remote_checkpointer():
490
- from langgraph_runtime.database import connect
491
-
492
489
  async def checkpointer_list(payload: dict):
493
490
  """Search checkpoints"""
494
491
 
495
492
  result = []
496
- async with connect() as conn:
497
- checkpointer = _get_passthrough_checkpointer(conn)
498
- async for item in checkpointer.alist(
499
- config=payload.get("config"),
500
- limit=int(payload.get("limit") or 10),
501
- before=payload.get("before"),
502
- filter=payload.get("filter"),
503
- ):
504
- result.append(item)
493
+ checkpointer = _get_passthrough_checkpointer()
494
+ async for item in checkpointer.alist(
495
+ config=payload.get("config"),
496
+ limit=int(payload.get("limit") or 10),
497
+ before=payload.get("before"),
498
+ filter=payload.get("filter"),
499
+ ):
500
+ result.append(item)
505
501
 
506
502
  return result
507
503
 
508
504
  async def checkpointer_put(payload: dict):
509
505
  """Put the new checkpoint metadata"""
510
506
 
511
- async with connect() as conn:
512
- checkpointer = _get_passthrough_checkpointer(conn)
513
- return await checkpointer.aput(
514
- payload["config"],
515
- payload["checkpoint"],
516
- payload["metadata"],
517
- payload.get("new_versions", {}),
518
- )
507
+ checkpointer = _get_passthrough_checkpointer()
508
+ return await checkpointer.aput(
509
+ payload["config"],
510
+ payload["checkpoint"],
511
+ payload["metadata"],
512
+ payload.get("new_versions", {}),
513
+ )
519
514
 
520
515
  async def checkpointer_get_tuple(payload: dict):
521
516
  """Get actual checkpoint values (reads)"""
522
-
523
- async with connect() as conn:
524
- checkpointer = _get_passthrough_checkpointer(conn)
525
- return await checkpointer.aget_tuple(config=payload["config"])
517
+ checkpointer = _get_passthrough_checkpointer()
518
+ return await checkpointer.aget_tuple(config=payload["config"])
526
519
 
527
520
  async def checkpointer_put_writes(payload: dict):
528
521
  """Put actual checkpoint values (writes)"""
529
522
 
530
- async with connect() as conn:
531
- checkpointer = _get_passthrough_checkpointer(conn)
532
- return await checkpointer.aput_writes(
533
- payload["config"],
534
- payload["writes"],
535
- payload["taskId"],
536
- )
523
+ checkpointer = _get_passthrough_checkpointer()
524
+ return await checkpointer.aput_writes(
525
+ payload["config"],
526
+ payload["writes"],
527
+ payload["taskId"],
528
+ )
537
529
 
538
530
  async def store_batch(payload: dict):
539
531
  """Batch operations on the store"""
@@ -687,7 +679,11 @@ async def run_remote_checkpointer():
687
679
  payload = orjson.loads(await request.body())
688
680
  return ApiResponse(await cb(payload))
689
681
  except ValueError as exc:
682
+ await logger.error(exc)
690
683
  return ApiResponse({"error": str(exc)}, status_code=400)
684
+ except Exception as exc:
685
+ await logger.error(exc)
686
+ return ApiResponse({"error": str(exc)}, status_code=500)
691
687
 
692
688
  return wrapped
693
689
 
@@ -800,15 +796,18 @@ async def js_healthcheck():
800
796
  transport=httpx.AsyncHTTPTransport(verify=SSL),
801
797
  ) as checkpointer_client,
802
798
  ):
799
+ graph_passed = False
803
800
  try:
804
801
  res = await graph_client.get("/ok")
805
802
  res.raise_for_status()
803
+ graph_passed = True
806
804
  res = await checkpointer_client.get("/ok")
807
805
  res.raise_for_status()
808
806
  return True
809
807
  except httpx.HTTPError as exc:
810
808
  logger.warning(
811
809
  "JS healthcheck failed. Either the JS server is not running or the event loop is blocked by a CPU-intensive task.",
810
+ graph_passed=graph_passed,
812
811
  error=exc,
813
812
  )
814
813
  raise HTTPException(
langgraph_api/logging.py CHANGED
@@ -57,6 +57,16 @@ class AddPrefixedEnvVars:
57
57
  return event_dict
58
58
 
59
59
 
60
+ class AddApiVersion:
61
+ def __call__(
62
+ self, logger: logging.Logger, method_name: str, event_dict: EventDict
63
+ ) -> EventDict:
64
+ from langgraph_api import __version__
65
+
66
+ event_dict["langgraph_api_version"] = __version__
67
+ return event_dict
68
+
69
+
60
70
  class AddLoggingContext:
61
71
  def __call__(
62
72
  self, logger: logging.Logger, method_name: str, event_dict: EventDict
@@ -90,11 +100,14 @@ shared_processors = [
90
100
  structlog.stdlib.PositionalArgumentsFormatter(),
91
101
  structlog.stdlib.ExtraAdder(),
92
102
  AddPrefixedEnvVars("LANGSMITH_LANGGRAPH_"), # injected by docker build
103
+ AddApiVersion(),
93
104
  structlog.processors.TimeStamper(fmt="iso", utc=True),
94
105
  structlog.processors.StackInfoRenderer(),
95
- structlog.processors.dict_tracebacks
96
- if LOG_JSON
97
- else structlog.processors.format_exc_info,
106
+ (
107
+ structlog.processors.dict_tracebacks
108
+ if LOG_JSON
109
+ else structlog.processors.format_exc_info
110
+ ),
98
111
  structlog.processors.UnicodeDecoder(),
99
112
  AddLoggingContext(),
100
113
  ]
langgraph_api/models/run.py CHANGED
@@ -29,6 +29,13 @@ from langgraph_api.utils import AsyncConnectionProto, get_auth_ctx
29
29
  from langgraph_runtime.ops import Runs, logger
30
30
 
31
31
 
32
+ class LangSmithTracer(TypedDict, total=False):
33
+ """Configuration for LangSmith tracing."""
34
+
35
+ example_id: str | None
36
+ project_name: str | None
37
+
38
+
32
39
  class RunCreateDict(TypedDict):
33
40
  """Payload for creating a run."""
34
41
 
@@ -87,6 +94,8 @@ class RunCreateDict(TypedDict):
87
94
  """Start the run after this many seconds. Defaults to 0."""
88
95
  if_not_exists: IfNotExists
89
96
  """Create the thread if it doesn't exist. If False, reply with 404."""
97
+ langsmith_tracer: LangSmithTracer | None
98
+ """Configuration for additional tracing with LangSmith."""
90
99
 
91
100
 
92
101
  def ensure_ids(
@@ -295,6 +304,9 @@ async def create_valid_run(
295
304
  user_id = None
296
305
  if not configurable.get("langgraph_request_id"):
297
306
  configurable["langgraph_request_id"] = request_id
307
+ if ls_tracing := payload.get("langsmith_tracer"):
308
+ configurable["__langsmith_project__"] = ls_tracing.get("project_name")
309
+ configurable["__langsmith_example_id__"] = ls_tracing.get("example_id")
298
310
  if request_start_time:
299
311
  configurable["__request_start_time_ms__"] = request_start_time
300
312
  after_seconds = payload.get("after_seconds", 0)
langgraph_api/stream.py CHANGED
@@ -1,5 +1,5 @@
1
1
  from collections.abc import AsyncIterator, Callable
2
- from contextlib import AsyncExitStack, aclosing
2
+ from contextlib import AsyncExitStack, aclosing, asynccontextmanager
3
3
  from functools import lru_cache
4
4
  from typing import Any, cast
5
5
 
@@ -20,6 +20,7 @@ from langgraph.errors import (
20
20
  InvalidUpdateError,
21
21
  )
22
22
  from langgraph.pregel.debug import CheckpointPayload, TaskResultPayload
23
+ from langsmith.utils import get_tracer_project
23
24
  from pydantic import ValidationError
24
25
  from pydantic.v1 import ValidationError as ValidationErrorLegacy
25
26
 
@@ -31,7 +32,6 @@ from langgraph_api.js.base import BaseRemotePregel
31
32
  from langgraph_api.metadata import HOST, PLAN, USER_API_URL, incr_nodes
32
33
  from langgraph_api.schema import Run, StreamMode
33
34
  from langgraph_api.serde import json_dumpb
34
- from langgraph_api.utils import AsyncConnectionProto
35
35
  from langgraph_runtime.checkpoint import Checkpointer
36
36
  from langgraph_runtime.ops import Runs
37
37
 
@@ -71,9 +71,13 @@ def _preprocess_debug_checkpoint(payload: CheckpointPayload | None) -> dict[str,
71
71
  return payload
72
72
 
73
73
 
74
+ @asynccontextmanager
75
+ async def async_tracing_context(*args, **kwargs):
76
+ with langsmith.tracing_context(*args, **kwargs):
77
+ yield
78
+
79
+
74
80
  async def astream_state(
75
- stack: AsyncExitStack,
76
- conn: AsyncConnectionProto,
77
81
  run: Run,
78
82
  attempt: int,
79
83
  done: ValueEvent,
@@ -83,7 +87,6 @@ async def astream_state(
83
87
  ) -> AnyStream:
84
88
  """Stream messages from the runnable."""
85
89
  run_id = str(run["run_id"])
86
- await stack.enter_async_context(conn.pipeline())
87
90
  # extract args from run
88
91
  kwargs = run["kwargs"].copy()
89
92
  kwargs.pop("webhook", None)
@@ -91,12 +94,13 @@ async def astream_state(
91
94
  subgraphs = kwargs.get("subgraphs", False)
92
95
  temporary = kwargs.pop("temporary", False)
93
96
  config = kwargs.pop("config")
97
+ stack = AsyncExitStack()
94
98
  graph = await stack.enter_async_context(
95
99
  get_graph(
96
100
  config["configurable"]["graph_id"],
97
101
  config,
98
102
  store=(await api_store.get_store()),
99
- checkpointer=None if temporary else Checkpointer(conn),
103
+ checkpointer=None if temporary else Checkpointer(),
100
104
  )
101
105
  )
102
106
  input = kwargs.pop("input")
@@ -131,6 +135,25 @@ async def astream_state(
131
135
  use_astream_events = "events" in stream_mode or isinstance(graph, BaseRemotePregel)
132
136
  # yield metadata chunk
133
137
  yield "metadata", {"run_id": run_id, "attempt": attempt}
138
+
139
+ # is a langsmith tracing project is specified, additionally pass that in to tracing context
140
+ if ls_project := config["configurable"].get("__langsmith_project__"):
141
+ updates = None
142
+ if example_id := config["configurable"].get("__langsmith_example_id__"):
143
+ updates = {"reference_example_id": example_id}
144
+
145
+ await stack.enter_async_context(
146
+ async_tracing_context(
147
+ replicas=[
148
+ (
149
+ ls_project,
150
+ updates,
151
+ ),
152
+ (get_tracer_project(), None),
153
+ ]
154
+ )
155
+ )
156
+
134
157
  # stream run
135
158
  if use_astream_events:
136
159
  async with (
@@ -269,7 +292,7 @@ async def astream_state(
269
292
  yield mode, chunk
270
293
  # --- end shared logic with astream_events ---
271
294
  if is_remote_pregel:
272
- # increament the remote runs
295
+ # increment the remote runs
273
296
  try:
274
297
  nodes_executed = await graph.fetch_nodes_executed()
275
298
  incr_nodes(None, incr=nodes_executed)