langgraph-api 0.4.1__py3-none-any.whl → 0.7.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (135)
  1. langgraph_api/__init__.py +1 -1
  2. langgraph_api/api/__init__.py +111 -51
  3. langgraph_api/api/a2a.py +1610 -0
  4. langgraph_api/api/assistants.py +212 -89
  5. langgraph_api/api/mcp.py +3 -3
  6. langgraph_api/api/meta.py +52 -28
  7. langgraph_api/api/openapi.py +27 -17
  8. langgraph_api/api/profile.py +108 -0
  9. langgraph_api/api/runs.py +342 -195
  10. langgraph_api/api/store.py +19 -2
  11. langgraph_api/api/threads.py +209 -27
  12. langgraph_api/asgi_transport.py +14 -9
  13. langgraph_api/asyncio.py +14 -4
  14. langgraph_api/auth/custom.py +52 -37
  15. langgraph_api/auth/langsmith/backend.py +4 -3
  16. langgraph_api/auth/langsmith/client.py +13 -8
  17. langgraph_api/cli.py +230 -133
  18. langgraph_api/command.py +5 -3
  19. langgraph_api/config/__init__.py +532 -0
  20. langgraph_api/config/_parse.py +58 -0
  21. langgraph_api/config/schemas.py +431 -0
  22. langgraph_api/cron_scheduler.py +17 -1
  23. langgraph_api/encryption/__init__.py +15 -0
  24. langgraph_api/encryption/aes_json.py +158 -0
  25. langgraph_api/encryption/context.py +35 -0
  26. langgraph_api/encryption/custom.py +280 -0
  27. langgraph_api/encryption/middleware.py +632 -0
  28. langgraph_api/encryption/shared.py +63 -0
  29. langgraph_api/errors.py +12 -1
  30. langgraph_api/executor_entrypoint.py +11 -6
  31. langgraph_api/feature_flags.py +29 -0
  32. langgraph_api/graph.py +176 -76
  33. langgraph_api/grpc/client.py +313 -0
  34. langgraph_api/grpc/config_conversion.py +231 -0
  35. langgraph_api/grpc/generated/__init__.py +29 -0
  36. langgraph_api/grpc/generated/checkpointer_pb2.py +63 -0
  37. langgraph_api/grpc/generated/checkpointer_pb2.pyi +99 -0
  38. langgraph_api/grpc/generated/checkpointer_pb2_grpc.py +329 -0
  39. langgraph_api/grpc/generated/core_api_pb2.py +216 -0
  40. langgraph_api/grpc/generated/core_api_pb2.pyi +905 -0
  41. langgraph_api/grpc/generated/core_api_pb2_grpc.py +1621 -0
  42. langgraph_api/grpc/generated/engine_common_pb2.py +219 -0
  43. langgraph_api/grpc/generated/engine_common_pb2.pyi +722 -0
  44. langgraph_api/grpc/generated/engine_common_pb2_grpc.py +24 -0
  45. langgraph_api/grpc/generated/enum_cancel_run_action_pb2.py +37 -0
  46. langgraph_api/grpc/generated/enum_cancel_run_action_pb2.pyi +12 -0
  47. langgraph_api/grpc/generated/enum_cancel_run_action_pb2_grpc.py +24 -0
  48. langgraph_api/grpc/generated/enum_control_signal_pb2.py +37 -0
  49. langgraph_api/grpc/generated/enum_control_signal_pb2.pyi +16 -0
  50. langgraph_api/grpc/generated/enum_control_signal_pb2_grpc.py +24 -0
  51. langgraph_api/grpc/generated/enum_durability_pb2.py +37 -0
  52. langgraph_api/grpc/generated/enum_durability_pb2.pyi +16 -0
  53. langgraph_api/grpc/generated/enum_durability_pb2_grpc.py +24 -0
  54. langgraph_api/grpc/generated/enum_multitask_strategy_pb2.py +37 -0
  55. langgraph_api/grpc/generated/enum_multitask_strategy_pb2.pyi +16 -0
  56. langgraph_api/grpc/generated/enum_multitask_strategy_pb2_grpc.py +24 -0
  57. langgraph_api/grpc/generated/enum_run_status_pb2.py +37 -0
  58. langgraph_api/grpc/generated/enum_run_status_pb2.pyi +22 -0
  59. langgraph_api/grpc/generated/enum_run_status_pb2_grpc.py +24 -0
  60. langgraph_api/grpc/generated/enum_stream_mode_pb2.py +37 -0
  61. langgraph_api/grpc/generated/enum_stream_mode_pb2.pyi +28 -0
  62. langgraph_api/grpc/generated/enum_stream_mode_pb2_grpc.py +24 -0
  63. langgraph_api/grpc/generated/enum_thread_status_pb2.py +37 -0
  64. langgraph_api/grpc/generated/enum_thread_status_pb2.pyi +16 -0
  65. langgraph_api/grpc/generated/enum_thread_status_pb2_grpc.py +24 -0
  66. langgraph_api/grpc/generated/enum_thread_stream_mode_pb2.py +37 -0
  67. langgraph_api/grpc/generated/enum_thread_stream_mode_pb2.pyi +16 -0
  68. langgraph_api/grpc/generated/enum_thread_stream_mode_pb2_grpc.py +24 -0
  69. langgraph_api/grpc/generated/errors_pb2.py +39 -0
  70. langgraph_api/grpc/generated/errors_pb2.pyi +21 -0
  71. langgraph_api/grpc/generated/errors_pb2_grpc.py +24 -0
  72. langgraph_api/grpc/ops/__init__.py +370 -0
  73. langgraph_api/grpc/ops/assistants.py +424 -0
  74. langgraph_api/grpc/ops/runs.py +792 -0
  75. langgraph_api/grpc/ops/threads.py +1013 -0
  76. langgraph_api/http.py +16 -5
  77. langgraph_api/http_metrics.py +15 -35
  78. langgraph_api/http_metrics_utils.py +38 -0
  79. langgraph_api/js/build.mts +1 -1
  80. langgraph_api/js/client.http.mts +13 -7
  81. langgraph_api/js/client.mts +2 -5
  82. langgraph_api/js/package.json +29 -28
  83. langgraph_api/js/remote.py +56 -30
  84. langgraph_api/js/src/graph.mts +20 -0
  85. langgraph_api/js/sse.py +2 -2
  86. langgraph_api/js/ui.py +1 -1
  87. langgraph_api/js/yarn.lock +1204 -1006
  88. langgraph_api/logging.py +29 -2
  89. langgraph_api/metadata.py +99 -28
  90. langgraph_api/middleware/http_logger.py +7 -2
  91. langgraph_api/middleware/private_network.py +7 -7
  92. langgraph_api/models/run.py +54 -93
  93. langgraph_api/otel_context.py +205 -0
  94. langgraph_api/patch.py +5 -3
  95. langgraph_api/queue_entrypoint.py +154 -65
  96. langgraph_api/route.py +47 -5
  97. langgraph_api/schema.py +88 -10
  98. langgraph_api/self_hosted_logs.py +124 -0
  99. langgraph_api/self_hosted_metrics.py +450 -0
  100. langgraph_api/serde.py +79 -37
  101. langgraph_api/server.py +138 -60
  102. langgraph_api/state.py +4 -3
  103. langgraph_api/store.py +25 -16
  104. langgraph_api/stream.py +80 -29
  105. langgraph_api/thread_ttl.py +31 -13
  106. langgraph_api/timing/__init__.py +25 -0
  107. langgraph_api/timing/profiler.py +200 -0
  108. langgraph_api/timing/timer.py +318 -0
  109. langgraph_api/utils/__init__.py +53 -8
  110. langgraph_api/utils/cache.py +47 -10
  111. langgraph_api/utils/config.py +2 -1
  112. langgraph_api/utils/errors.py +77 -0
  113. langgraph_api/utils/future.py +10 -6
  114. langgraph_api/utils/headers.py +76 -2
  115. langgraph_api/utils/retriable_client.py +74 -0
  116. langgraph_api/utils/stream_codec.py +315 -0
  117. langgraph_api/utils/uuids.py +29 -62
  118. langgraph_api/validation.py +9 -0
  119. langgraph_api/webhook.py +120 -6
  120. langgraph_api/worker.py +55 -24
  121. {langgraph_api-0.4.1.dist-info → langgraph_api-0.7.3.dist-info}/METADATA +16 -8
  122. langgraph_api-0.7.3.dist-info/RECORD +168 -0
  123. {langgraph_api-0.4.1.dist-info → langgraph_api-0.7.3.dist-info}/WHEEL +1 -1
  124. langgraph_runtime/__init__.py +1 -0
  125. langgraph_runtime/routes.py +11 -0
  126. logging.json +1 -3
  127. openapi.json +839 -478
  128. langgraph_api/config.py +0 -387
  129. langgraph_api/js/isolate-0x130008000-46649-46649-v8.log +0 -4430
  130. langgraph_api/js/isolate-0x138008000-44681-44681-v8.log +0 -4430
  131. langgraph_api/js/package-lock.json +0 -3308
  132. langgraph_api-0.4.1.dist-info/RECORD +0 -107
  133. /langgraph_api/{utils.py → grpc/__init__.py} +0 -0
  134. {langgraph_api-0.4.1.dist-info → langgraph_api-0.7.3.dist-info}/entry_points.txt +0 -0
  135. {langgraph_api-0.4.1.dist-info → langgraph_api-0.7.3.dist-info}/licenses/LICENSE +0 -0
langgraph_api/self_hosted_metrics.py ADDED
@@ -0,0 +1,450 @@
+ import os
+
+ import structlog
+ from opentelemetry import metrics
+ from opentelemetry.exporter.otlp.proto.http.metric_exporter import (
+     OTLPMetricExporter,
+ )
+ from opentelemetry.metrics import CallbackOptions, Observation
+ from opentelemetry.sdk.metrics import MeterProvider
+ from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
+ from opentelemetry.sdk.resources import SERVICE_NAME, Resource
+
+ from langgraph_api import asyncio as lg_asyncio
+ from langgraph_api import config, metadata
+ from langgraph_api.feature_flags import FF_USE_CORE_API
+ from langgraph_api.grpc.ops import Runs as GrpcRuns
+ from langgraph_api.http_metrics_utils import HTTP_LATENCY_BUCKETS
+ from langgraph_runtime.database import connect, pool_stats
+ from langgraph_runtime.metrics import get_metrics
+ from langgraph_runtime.ops import Runs
+
+ CrudRuns = GrpcRuns if FF_USE_CORE_API else Runs
+
+ logger = structlog.stdlib.get_logger(__name__)
+
+ _meter_provider = None
+ _customer_attributes = {}
+
+ _http_request_counter = None
+ _http_latency_histogram = None
+
+
+ def initialize_self_hosted_metrics():
+     global \
+         _meter_provider, \
+         _http_request_counter, \
+         _http_latency_histogram, \
+         _customer_attributes
+
+     if not config.LANGGRAPH_METRICS_ENABLED:
+         return
+
+     if not config.LANGGRAPH_METRICS_ENDPOINT:
+         raise RuntimeError(
+             "LANGGRAPH_METRICS_ENABLED is true but no LANGGRAPH_METRICS_ENDPOINT is configured"
+         )
+
+     # for now, this is only enabled for fully self-hosted customers
+     # we will need to update the otel collector auth model to support hybrid customers
+     if not config.LANGGRAPH_CLOUD_LICENSE_KEY:
+         logger.warning(
+             "Self-hosted metrics require a license key, and do not work with hybrid deployments yet."
+         )
+         return
+
+     try:
+         exporter = OTLPMetricExporter(
+             endpoint=config.LANGGRAPH_METRICS_ENDPOINT,
+             headers={"X-Langchain-License-Key": config.LANGGRAPH_CLOUD_LICENSE_KEY},
+         )
+
+         # this will periodically export metrics to our beacon lgp otel collector in a separate thread
+         metric_reader = PeriodicExportingMetricReader(
+             exporter=exporter,
+             export_interval_millis=config.LANGGRAPH_METRICS_EXPORT_INTERVAL_MS,
+         )
+
+         resource_attributes = {
+             SERVICE_NAME: config.SELF_HOSTED_OBSERVABILITY_SERVICE_NAME,
+         }
+
+         resource = Resource.create(resource_attributes)
+
+         if config.LANGGRAPH_CLOUD_LICENSE_KEY:
+             try:
+                 from langgraph_license.validation import (
+                     CUSTOMER_ID,  # type: ignore[unresolved-import]
+                     CUSTOMER_NAME,  # type: ignore[unresolved-import]
+                 )
+
+                 if CUSTOMER_ID:
+                     _customer_attributes["customer_id"] = CUSTOMER_ID
+                 if CUSTOMER_NAME:
+                     _customer_attributes["customer_name"] = CUSTOMER_NAME
+             except ImportError:
+                 pass
+             except Exception as e:
+                 logger.warning("Failed to get customer info from license", exc_info=e)
+
+         # resolves to pod name in k8s, or container id in docker
+         instance_id = os.environ.get("HOSTNAME")
+         if instance_id:
+             _customer_attributes["instance_id"] = instance_id
+
+         _meter_provider = MeterProvider(
+             metric_readers=[metric_reader], resource=resource
+         )
+         metrics.set_meter_provider(_meter_provider)
+
+         meter = metrics.get_meter("langgraph_api.self_hosted")
+
+         meter.create_observable_gauge(
+             name="lg_api_num_pending_runs",
+             description="The number of runs currently pending",
+             unit="1",
+             callbacks=[_get_pending_runs_callback],
+         )
+
+         meter.create_observable_gauge(
+             name="lg_api_num_running_runs",
+             description="The number of runs currently running",
+             unit="1",
+             callbacks=[_get_running_runs_callback],
+         )
+
+         meter.create_observable_gauge(
+             name="lg_api_pending_runs_wait_time_max",
+             description="The maximum time a run has been pending, in seconds",
+             unit="s",
+             callbacks=[_get_pending_runs_wait_time_max_callback],
+         )
+
+         meter.create_observable_gauge(
+             name="lg_api_pending_runs_wait_time_med",
+             description="The median pending wait time across runs, in seconds",
+             unit="s",
+             callbacks=[_get_pending_runs_wait_time_med_callback],
+         )
+
+         meter.create_observable_gauge(
+             name="lg_api_pending_unblocked_runs_wait_time_max",
+             description="The maximum time a run has been pending excluding runs blocked by another run on the same thread, in seconds",
+             unit="s",
+             callbacks=[_get_pending_unblocked_runs_wait_time_max_callback],
+         )
+
+         if config.N_JOBS_PER_WORKER > 0:
+             meter.create_observable_gauge(
+                 name="lg_api_workers_max",
+                 description="The maximum number of workers available",
+                 unit="1",
+                 callbacks=[_get_workers_max_callback],
+             )
+
+             meter.create_observable_gauge(
+                 name="lg_api_workers_active",
+                 description="The number of currently active workers",
+                 unit="1",
+                 callbacks=[_get_workers_active_callback],
+             )
+
+             meter.create_observable_gauge(
+                 name="lg_api_workers_available",
+                 description="The number of available (idle) workers",
+                 unit="1",
+                 callbacks=[_get_workers_available_callback],
+             )
+
+         if not config.IS_QUEUE_ENTRYPOINT and not config.IS_EXECUTOR_ENTRYPOINT:
+             _http_request_counter = meter.create_counter(
+                 name="lg_api_http_requests_total",
+                 description="Total number of HTTP requests",
+                 unit="1",
+             )
+
+             _http_latency_histogram = meter.create_histogram(
+                 name="lg_api_http_requests_latency_seconds",
+                 description="HTTP request latency in seconds",
+                 unit="s",
+                 explicit_bucket_boundaries_advisory=[
+                     b for b in HTTP_LATENCY_BUCKETS if b != float("inf")
+                 ],
+             )
+
+         meter.create_observable_gauge(
+             name="lg_api_pg_pool_max",
+             description="The maximum size of the postgres connection pool",
+             unit="1",
+             callbacks=[_get_pg_pool_max_callback],
+         )
+
+         meter.create_observable_gauge(
+             name="lg_api_pg_pool_size",
+             description="Number of connections currently managed by the postgres connection pool",
+             unit="1",
+             callbacks=[_get_pg_pool_size_callback],
+         )
+
+         meter.create_observable_gauge(
+             name="lg_api_pg_pool_available",
+             description="Number of connections currently idle in the postgres connection pool",
+             unit="1",
+             callbacks=[_get_pg_pool_available_callback],
+         )
+
+         meter.create_observable_gauge(
+             name="lg_api_redis_pool_max",
+             description="The maximum size of the redis connection pool",
+             unit="1",
+             callbacks=[_get_redis_pool_max_callback],
+         )
+
+         meter.create_observable_gauge(
+             name="lg_api_redis_pool_size",
+             description="Number of connections currently in use in the redis connection pool",
+             unit="1",
+             callbacks=[_get_redis_pool_size_callback],
+         )
+
+         meter.create_observable_gauge(
+             name="lg_api_redis_pool_available",
+             description="Number of connections currently idle in the redis connection pool",
+             unit="1",
+             callbacks=[_get_redis_pool_available_callback],
+         )
+
+         logger.info(
+             "Self-hosted metrics initialized successfully",
+             endpoint=config.LANGGRAPH_METRICS_ENDPOINT,
+             export_interval_ms=config.LANGGRAPH_METRICS_EXPORT_INTERVAL_MS,
+         )
+
+     except Exception as e:
+         logger.exception("Failed to initialize self-hosted metrics", exc_info=e)
+
+
+ def shutdown_self_hosted_metrics():
+     global _meter_provider
+
+     if _meter_provider:
+         try:
+             logger.info("Shutting down self-hosted metrics")
+             _meter_provider.shutdown(timeout_millis=5000)
+             _meter_provider = None
+         except Exception as e:
+             logger.exception("Failed to shutdown self-hosted metrics", exc_info=e)
+
+
+ def record_http_request(
+     method: str, route_path: str, status: int, latency_seconds: float
+ ):
+     if not _meter_provider or not _http_request_counter or not _http_latency_histogram:
+         return
+
+     attributes = {"method": method, "path": route_path, "status": str(status)}
+     if _customer_attributes:
+         attributes.update(_customer_attributes)
+
+     _http_request_counter.add(1, attributes)
+     _http_latency_histogram.record(latency_seconds, attributes)
+
+
+ def _get_queue_stats():
+     async def _fetch_queue_stats():
+         try:
+             async with connect() as conn:
+                 return await CrudRuns.stats(conn)
+         except Exception as e:
+             logger.warning("Failed to get queue stats from database", exc_info=e)
+             return {
+                 "n_pending": 0,
+                 "n_running": 0,
+                 "pending_runs_wait_time_max_secs": 0,
+                 "pending_runs_wait_time_med_secs": 0,
+                 "pending_unblocked_runs_wait_time_max_secs": 0,
+             }
+
+     try:
+         future = lg_asyncio.run_coroutine_threadsafe(_fetch_queue_stats())
+         return future.result(timeout=5)
+     except Exception as e:
+         logger.warning("Failed to get queue stats", exc_info=e)
+         return {
+             "n_pending": 0,
+             "n_running": 0,
+             "pending_runs_wait_time_max_secs": 0,
+             "pending_runs_wait_time_med_secs": 0,
+             "pending_unblocked_runs_wait_time_max_secs": 0,
+         }
+
+
+ def _get_pool_stats():
+     # _get_pool() inside the pool_stats fn will not work correctly if called from the daemon thread created by PeriodicExportingMetricReader,
+     # so we submit this as a coro to run in the main event loop
+     async def _fetch_pool_stats():
+         try:
+             return pool_stats(
+                 metadata.PROJECT_ID, metadata.HOST_REVISION_ID, format="json"
+             )
+         except Exception as e:
+             logger.warning("Failed to get pool stats", exc_info=e)
+             return {"postgres": {}, "redis": {}}
+
+     try:
+         future = lg_asyncio.run_coroutine_threadsafe(_fetch_pool_stats())
+         return future.result(timeout=5)
+     except Exception as e:
+         logger.warning("Failed to get pool stats", exc_info=e)
+         return {"postgres": {}, "redis": {}}
+
+
+ def _get_pending_runs_callback(options: CallbackOptions):
+     try:
+         stats = _get_queue_stats()
+         return [Observation(stats.get("n_pending", 0), attributes=_customer_attributes)]
+     except Exception as e:
+         logger.warning("Failed to get pending runs", exc_info=e)
+         return [Observation(0, attributes=_customer_attributes)]
+
+
+ def _get_running_runs_callback(options: CallbackOptions):
+     try:
+         stats = _get_queue_stats()
+         return [Observation(stats.get("n_running", 0), attributes=_customer_attributes)]
+     except Exception as e:
+         logger.warning("Failed to get running runs", exc_info=e)
+         return [Observation(0, attributes=_customer_attributes)]
+
+
+ def _get_pending_runs_wait_time_max_callback(options: CallbackOptions):
+     try:
+         stats = _get_queue_stats()
+         value = stats.get("pending_runs_wait_time_max_secs")
+         value = 0 if value is None else value
+         return [Observation(value, attributes=_customer_attributes)]
+     except Exception as e:
+         logger.warning("Failed to get max pending wait time", exc_info=e)
+         return [Observation(0, attributes=_customer_attributes)]
+
+
+ def _get_pending_runs_wait_time_med_callback(options: CallbackOptions):
+     try:
+         stats = _get_queue_stats()
+         value = stats.get("pending_runs_wait_time_med_secs")
+         value = 0 if value is None else value
+         return [Observation(value, attributes=_customer_attributes)]
+     except Exception as e:
+         logger.warning("Failed to get median pending wait time", exc_info=e)
+         return [Observation(0, attributes=_customer_attributes)]
+
+
+ def _get_pending_unblocked_runs_wait_time_max_callback(options: CallbackOptions):
+     try:
+         stats = _get_queue_stats()
+         value = stats.get("pending_unblocked_runs_wait_time_max_secs")
+         value = 0 if value is None else value
+         return [Observation(value, attributes=_customer_attributes)]
+     except Exception as e:
+         logger.warning("Failed to get max unblocked pending wait time", exc_info=e)
+         return [Observation(0, attributes=_customer_attributes)]
+
+
+ def _get_workers_max_callback(options: CallbackOptions):
+     try:
+         metrics_data = get_metrics()
+         worker_metrics = metrics_data.get("workers", {})
+         return [
+             Observation(worker_metrics.get("max", 0), attributes=_customer_attributes)
+         ]
+     except Exception as e:
+         logger.warning("Failed to get max workers", exc_info=e)
+         return [Observation(0, attributes=_customer_attributes)]
+
+
+ def _get_workers_active_callback(options: CallbackOptions):
+     try:
+         metrics_data = get_metrics()
+         worker_metrics = metrics_data.get("workers", {})
+         return [
+             Observation(
+                 worker_metrics.get("active", 0), attributes=_customer_attributes
+             )
+         ]
+     except Exception as e:
+         logger.warning("Failed to get active workers", exc_info=e)
+         return [Observation(0, attributes=_customer_attributes)]
+
+
+ def _get_workers_available_callback(options: CallbackOptions):
+     try:
+         metrics_data = get_metrics()
+         worker_metrics = metrics_data.get("workers", {})
+         return [
+             Observation(
+                 worker_metrics.get("available", 0), attributes=_customer_attributes
+             )
+         ]
+     except Exception as e:
+         logger.warning("Failed to get available workers", exc_info=e)
+         return [Observation(0, attributes=_customer_attributes)]
+
+
+ def _get_pg_pool_max_callback(options: CallbackOptions):
+     try:
+         stats = _get_pool_stats()
+         pg_max = stats.get("postgres", {}).get("pool_max", 0)
+         return [Observation(pg_max, attributes=_customer_attributes)]
+     except Exception as e:
+         logger.warning("Failed to get PG pool max", exc_info=e)
+         return [Observation(0, attributes=_customer_attributes)]
+
+
+ def _get_pg_pool_size_callback(options: CallbackOptions):
+     try:
+         stats = _get_pool_stats()
+         pg_size = stats.get("postgres", {}).get("pool_size", 0)
+         return [Observation(pg_size, attributes=_customer_attributes)]
+     except Exception as e:
+         logger.warning("Failed to get PG pool size", exc_info=e)
+         return [Observation(0, attributes=_customer_attributes)]
+
+
+ def _get_pg_pool_available_callback(options: CallbackOptions):
+     try:
+         stats = _get_pool_stats()
+         pg_available = stats.get("postgres", {}).get("pool_available", 0)
+         return [Observation(pg_available, attributes=_customer_attributes)]
+     except Exception as e:
+         logger.warning("Failed to get PG pool available", exc_info=e)
+         return [Observation(0, attributes=_customer_attributes)]
+
+
+ def _get_redis_pool_max_callback(options: CallbackOptions):
+     try:
+         stats = _get_pool_stats()
+         redis_max = stats.get("redis", {}).get("max_connections", 0)
+         return [Observation(redis_max, attributes=_customer_attributes)]
+     except Exception as e:
+         logger.warning("Failed to get Redis pool max", exc_info=e)
+         return [Observation(0, attributes=_customer_attributes)]
+
+
+ def _get_redis_pool_size_callback(options: CallbackOptions):
+     try:
+         stats = _get_pool_stats()
+         redis_size = stats.get("redis", {}).get("in_use_connections", 0)
+         return [Observation(redis_size, attributes=_customer_attributes)]
+     except Exception as e:
+         logger.warning("Failed to get Redis pool size", exc_info=e)
+         return [Observation(0, attributes=_customer_attributes)]
+
+
+ def _get_redis_pool_available_callback(options: CallbackOptions):
+     try:
+         stats = _get_pool_stats()
+         redis_available = stats.get("redis", {}).get("idle_connections", 0)
+         return [Observation(redis_available, attributes=_customer_attributes)]
+     except Exception as e:
+         logger.warning("Failed to get Redis pool available", exc_info=e)
+         return [Observation(0, attributes=_customer_attributes)]
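The new module exposes three hooks: `initialize_self_hosted_metrics()` sets up the OTLP exporter, the periodic reader, and the observable gauges once per process; `record_http_request()` feeds the request counter and latency histogram (and is a no-op until initialization succeeds); `shutdown_self_hosted_metrics()` flushes the reader on exit. Below is a minimal sketch of how a deployment might wire these hooks into an ASGI app; the Starlette `lifespan` and middleware shown here are illustrative assumptions, not the package's actual server wiring, and `request.url.path` stands in for whatever route template the real middleware records:

```python
import time
from contextlib import asynccontextmanager

from starlette.applications import Starlette
from starlette.middleware.base import BaseHTTPMiddleware


@asynccontextmanager
async def lifespan(app):
    from langgraph_api.self_hosted_metrics import (
        initialize_self_hosted_metrics,
        shutdown_self_hosted_metrics,
    )

    # Set up the exporter and observable gauges once per process.
    initialize_self_hosted_metrics()
    try:
        yield
    finally:
        # Flush any buffered metrics before the process exits.
        shutdown_self_hosted_metrics()


class MetricsMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request, call_next):
        from langgraph_api.self_hosted_metrics import record_http_request

        start = time.perf_counter()
        response = await call_next(request)
        # No-op unless metrics were initialized (see the guard in record_http_request).
        record_http_request(
            method=request.method,
            route_path=request.url.path,
            status=response.status_code,
            latency_seconds=time.perf_counter() - start,
        )
        return response


app = Starlette(lifespan=lifespan)
app.add_middleware(MetricsMiddleware)
```

Because `record_http_request` checks the module-level globals before recording, middleware like this is safe to install even when metrics are disabled.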
langgraph_api/serde.py CHANGED
@@ -1,10 +1,9 @@
  import asyncio
- import base64
  import re
  import uuid
  from base64 import b64encode
  from collections import deque
- from collections.abc import Mapping
+ from collections.abc import Callable, Mapping
  from datetime import timedelta, timezone
  from decimal import Decimal
  from ipaddress import (
@@ -17,7 +16,7 @@ from ipaddress import (
  )
  from pathlib import Path
  from re import Pattern
- from typing import Any, NamedTuple, cast
+ from typing import Any, Literal, NamedTuple, cast
  from zoneinfo import ZoneInfo

  import cloudpickle
@@ -32,6 +31,10 @@ class Fragment(NamedTuple):
      buf: bytes


+ def fragment_loads(data: bytes) -> Fragment:
+     return Fragment(data)
+
+
  def decimal_encoder(dec_value: Decimal) -> int | float:
      """
      Encodes a Decimal as int if there's no exponent, otherwise float
@@ -51,7 +54,7 @@ def decimal_encoder(dec_value: Decimal) -> int | float:
          # maps to float('nan') / float('inf') / float('-inf')
          not dec_value.is_finite()
          # or regular float
-         or cast(int, dec_value.as_tuple().exponent) < 0
+         or cast("int", dec_value.as_tuple().exponent) < 0
      ):
          return float(dec_value)
      return int(dec_value)
@@ -76,15 +79,15 @@ def default(obj):
          return obj._asdict()
      elif isinstance(obj, BaseException):
          return {"error": type(obj).__name__, "message": str(obj)}
-     elif isinstance(obj, (set, frozenset, deque)):  # noqa: UP038
+     elif isinstance(obj, (set, frozenset, deque)):
          return list(obj)
-     elif isinstance(obj, (timezone, ZoneInfo)):  # noqa: UP038
+     elif isinstance(obj, (timezone, ZoneInfo)):
          return obj.tzname(None)
      elif isinstance(obj, timedelta):
          return obj.total_seconds()
      elif isinstance(obj, Decimal):
          return decimal_encoder(obj)
-     elif isinstance(  # noqa: UP038
+     elif isinstance(
          obj,
          (
              uuid.UUID,
@@ -110,16 +113,24 @@ _option = orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_NON_STR_KEYS
  _SURROGATE_RE = re.compile(r"[\ud800-\udfff]")


- def _strip_surr(s: str) -> str:
-     return s if _SURROGATE_RE.search(s) is None else _SURROGATE_RE.sub("", s)
+ def _replace_surr(s: str) -> str:
+     return s if _SURROGATE_RE.search(s) is None else _SURROGATE_RE.sub("?", s)


  def _sanitise(o: Any) -> Any:
      if isinstance(o, str):
-         return _strip_surr(o)
+         return _replace_surr(o)
      if isinstance(o, Mapping):
          return {_sanitise(k): _sanitise(v) for k, v in o.items()}
      if isinstance(o, list | tuple | set):
+         if (
+             isinstance(o, tuple)
+             and hasattr(o, "_asdict")
+             and callable(o._asdict)
+             and hasattr(o, "_fields")
+             and isinstance(o._fields, tuple)
+         ):  # named tuple
+             return {f: _sanitise(ov) for f, ov in zip(o._fields, o, strict=True)}
          ctor = list if isinstance(o, list) else type(o)
          return ctor(_sanitise(x) for x in o)
      return o
@@ -146,26 +157,67 @@ def json_loads(content: bytes | Fragment | dict) -> Any:
          content = content.buf
      if isinstance(content, dict):
          return content
-     return orjson.loads(cast(bytes, content))
+     return orjson.loads(content)
+
+
+ def json_dumpb_optional(obj: Any | None) -> bytes | None:
+     if obj is None:
+         return
+     return json_dumpb(obj)
+

+ def json_loads_optional(content: bytes | None) -> Any | None:
+     if content is None:
+         return
+     return json_loads(content)

+
+ # Do not use. orjson holds the GIL the entire time it's running anyway.
  async def ajson_loads(content: bytes | Fragment) -> Any:
      return await asyncio.to_thread(json_loads, content)


  class Serializer(JsonPlusSerializer):
+     def __init__(
+         self,
+         __unpack_ext_hook__: Callable[[int, bytes], Any] | None = None,
+         pickle_fallback: bool | None = None,
+     ):
+         from langgraph_api.config import SERDE
+
+         allowed_json_modules: list[tuple[str, ...]] | Literal[True] | None = None
+         if SERDE and "allowed_json_modules" in SERDE:
+             allowed_ = SERDE["allowed_json_modules"]
+             if allowed_ is True:
+                 allowed_json_modules = True
+             elif allowed_ is None:
+                 allowed_json_modules = None
+             else:
+                 allowed_json_modules = [tuple(x) for x in allowed_]
+         if pickle_fallback is None:
+             if SERDE and "pickle_fallback" in SERDE:
+                 pickle_fallback = SERDE["pickle_fallback"]
+             else:
+                 pickle_fallback = True
+
+         super().__init__(
+             allowed_json_modules=allowed_json_modules,
+             __unpack_ext_hook__=__unpack_ext_hook__,
+         )
+         self.pickle_fallback = pickle_fallback
+
      def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
          try:
              return super().dumps_typed(obj)
          except TypeError:
              return "pickle", cloudpickle.dumps(obj)

-     def dumps(self, obj: Any) -> bytes:
-         # See comment above (in json_dumpb)
-         return super().dumps(obj).replace(rb"\\u0000", b"").replace(rb"\u0000", b"")
-
      def loads_typed(self, data: tuple[str, bytes]) -> Any:
          if data[0] == "pickle":
+             if not self.pickle_fallback:
+                 raise ValueError(
+                     "Pickle fallback is disabled. Cannot deserialize pickled object."
+                 )
              try:
                  return cloudpickle.loads(data[1])
              except Exception as e:
@@ -173,26 +225,16 @@ class Serializer(JsonPlusSerializer):
                      "Failed to unpickle object, replacing w None", exc_info=e
                  )
                  return None
-         return super().loads_typed(data)
-
-
- mpack_keys = {"method", "value"}
- SERIALIZER = Serializer()
-
-
- # TODO: Make more performant (by removing)
- async def reserialize_message(message: bytes) -> bytes:
-     # Stream messages from golang runtime are a byte dict of StreamChunks.
-     loaded = await ajson_loads(message)
-     converted = {}
-     for k, v in loaded.items():
-         if isinstance(v, dict) and v.keys() == mpack_keys:
-             if v["method"] == "missing":
-                 converted[k] = v["value"]  # oops
-             else:
-                 converted[k] = SERIALIZER.loads_typed(
-                     (v["method"], base64.b64decode(v["value"]))
+         try:
+             return super().loads_typed(data)
+         except Exception:
+             if data[0] == "json":
+                 logger.exception(
+                     "Heads up! There was a deserialization error of an item stored using 'json'-type serialization."
+                     ' For security reasons, starting in langgraph-api version 0.5.0, we no longer serialize objects using the "json" type.'
+                     " If you would like to retain the ability to deserialize old checkpoints saved in this format, "
+                     'please set the "allowed_json_modules" option in your langgraph.json configuration to add the'
+                     " necessary module and type paths to an allow-list to be deserialized. You can also retain the"
+                     ' ability to insecurely deserialize custom types by setting it to "true".'
                  )
-             else:
-                 converted[k] = v
-     return json_dumpb(converted)
+             raise
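As the new `Serializer.__init__` above shows, serde behavior is now driven by the `SERDE` mapping in `langgraph_api.config`: `allowed_json_modules` may be `True`, `None`, or a list of module/type paths (each coerced with `tuple(x)`), and `pickle_fallback` defaults to `True` when unset. A small sketch of the stricter posture, assuming `langgraph_api` is installed and its config module imports outside a full deployment:

```python
from langgraph_api.serde import Serializer

# Disable the cloudpickle fallback: "pickle"-typed payloads are then rejected
# up front instead of being deserialized from untrusted bytes.
serde = Serializer(pickle_fallback=False)

try:
    serde.loads_typed(("pickle", b"..."))  # b"..." is a placeholder payload
except ValueError as err:
    print(err)  # Pickle fallback is disabled. Cannot deserialize pickled object.

# Per the error message in loads_typed, old "json"-typed checkpoints can be
# re-enabled via an allow-list in langgraph.json; based on the parsing above,
# each entry is a module/type path sequence such as
# ["my_pkg", "models", "MyModel"] (hypothetical names).
```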