langgraph-api 0.4.40__py3-none-any.whl → 0.5.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langgraph-api might be problematic. Click here for more details.

Files changed (41)
  1. langgraph_api/__init__.py +1 -1
  2. langgraph_api/api/assistants.py +65 -61
  3. langgraph_api/api/meta.py +6 -0
  4. langgraph_api/api/threads.py +11 -7
  5. langgraph_api/auth/custom.py +29 -24
  6. langgraph_api/cli.py +2 -49
  7. langgraph_api/config.py +131 -16
  8. langgraph_api/graph.py +1 -1
  9. langgraph_api/grpc/client.py +183 -0
  10. langgraph_api/grpc/config_conversion.py +225 -0
  11. langgraph_api/grpc/generated/core_api_pb2.py +275 -0
  12. langgraph_api/{grpc_ops → grpc}/generated/core_api_pb2.pyi +35 -40
  13. langgraph_api/grpc/generated/engine_common_pb2.py +190 -0
  14. langgraph_api/grpc/generated/engine_common_pb2.pyi +634 -0
  15. langgraph_api/grpc/generated/engine_common_pb2_grpc.py +24 -0
  16. langgraph_api/grpc/ops.py +1045 -0
  17. langgraph_api/js/build.mts +1 -1
  18. langgraph_api/js/client.http.mts +1 -1
  19. langgraph_api/js/client.mts +1 -1
  20. langgraph_api/js/package.json +12 -12
  21. langgraph_api/js/src/graph.mts +20 -0
  22. langgraph_api/js/yarn.lock +176 -234
  23. langgraph_api/metadata.py +29 -21
  24. langgraph_api/queue_entrypoint.py +2 -2
  25. langgraph_api/route.py +14 -4
  26. langgraph_api/schema.py +2 -2
  27. langgraph_api/self_hosted_metrics.py +48 -2
  28. langgraph_api/serde.py +58 -14
  29. langgraph_api/server.py +16 -2
  30. langgraph_api/worker.py +1 -1
  31. {langgraph_api-0.4.40.dist-info → langgraph_api-0.5.6.dist-info}/METADATA +6 -6
  32. {langgraph_api-0.4.40.dist-info → langgraph_api-0.5.6.dist-info}/RECORD +38 -34
  33. langgraph_api/grpc_ops/client.py +0 -80
  34. langgraph_api/grpc_ops/generated/core_api_pb2.py +0 -274
  35. langgraph_api/grpc_ops/ops.py +0 -610
  36. langgraph_api/{grpc_ops → grpc}/__init__.py +0 -0
  37. langgraph_api/{grpc_ops → grpc}/generated/__init__.py +0 -0
  38. langgraph_api/{grpc_ops → grpc}/generated/core_api_pb2_grpc.py +0 -0
  39. {langgraph_api-0.4.40.dist-info → langgraph_api-0.5.6.dist-info}/WHEEL +0 -0
  40. {langgraph_api-0.4.40.dist-info → langgraph_api-0.5.6.dist-info}/entry_points.txt +0 -0
  41. {langgraph_api-0.4.40.dist-info → langgraph_api-0.5.6.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,1045 @@
1
+ """gRPC-based operations for LangGraph API."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import functools
7
+ from collections.abc import AsyncIterator
8
+ from datetime import UTC
9
+ from http import HTTPStatus
10
+ from typing import Any, overload
11
+ from uuid import UUID
12
+
13
+ import orjson
14
+ import structlog
15
+ from google.protobuf.json_format import MessageToDict
16
+ from google.protobuf.struct_pb2 import Struct # type: ignore[import]
17
+ from grpc import StatusCode
18
+ from grpc.aio import AioRpcError
19
+ from langgraph_sdk.schema import Config
20
+ from starlette.exceptions import HTTPException
21
+
22
+ from langgraph_api.grpc import config_conversion
23
+ from langgraph_api.schema import (
24
+ Assistant,
25
+ AssistantSelectField,
26
+ Context,
27
+ MetadataInput,
28
+ OnConflictBehavior,
29
+ Thread,
30
+ ThreadSelectField,
31
+ ThreadStatus,
32
+ )
33
+ from langgraph_api.serde import json_dumpb, json_loads
34
+
35
+ from .client import get_shared_client
36
+ from .generated import core_api_pb2 as pb
37
+
38
# gRPC status codes that have a specific HTTP counterpart. Codes missing from
# this table are reported as 500 by the error handler below.
GRPC_STATUS_TO_HTTP_STATUS: dict[StatusCode, HTTPStatus] = {
    StatusCode.NOT_FOUND: HTTPStatus.NOT_FOUND,
    StatusCode.ALREADY_EXISTS: HTTPStatus.CONFLICT,
    StatusCode.INVALID_ARGUMENT: HTTPStatus.UNPROCESSABLE_ENTITY,
}

# Module-level structured logger.
logger = structlog.stdlib.get_logger(__name__)
45
+
46
+
47
def map_if_exists(if_exists: str) -> pb.OnConflictBehavior:
    """Translate the API ``if_exists`` option into the protobuf conflict enum.

    Only ``"do_nothing"`` selects ``DO_NOTHING``; any other value maps to
    ``RAISE``.
    """
    behavior = pb.OnConflictBehavior
    return behavior.DO_NOTHING if if_exists == "do_nothing" else behavior.RAISE
51
+
52
+
53
@overload
def consolidate_config_and_context(
    config: Config | None, context: None
) -> tuple[Config, None]: ...


@overload
def consolidate_config_and_context(
    config: Config | None, context: Context
) -> tuple[Config, Context]: ...


def consolidate_config_and_context(
    config: Config | None, context: Context | None
) -> tuple[Config, Context | None]:
    """Reconcile ``config["configurable"]`` with ``context``.

    Both inputs are shallow-copied first, so neither argument is modified in
    place.

    Args:
        config: Config mapping; ``None`` is treated as ``{}``.
        context: Context mapping, or ``None`` when not supplied.

    Returns:
        ``(config, context)``:
        - When ``configurable`` is non-empty, it is returned as the context.
          That returned context is the same object stored under
          ``config["configurable"]``.
        - Otherwise, a non-``None`` context is copied into ``configurable``.

    Raises:
        HTTPException: 400 when both a non-empty ``configurable`` and a
            non-empty ``context`` are given.
    """
    merged_config: Config = Config(config or {})
    merged_context: Context | None = None if context is None else dict(context)
    configurable = merged_config.get("configurable")

    if not configurable:
        # Empty or missing "configurable": mirror the context into it, but
        # only when a context was actually supplied.
        if merged_context is not None:
            merged_config["configurable"] = merged_context
        return merged_config, merged_context

    if merged_context:
        raise HTTPException(
            status_code=400,
            detail="Cannot specify both configurable and context. Prefer setting context alone."
            " Context was introduced in LangGraph 0.6.0 and "
            "is the long term planned replacement for configurable.",
        )

    # A populated "configurable" doubles as the context.
    return merged_config, configurable
91
+
92
+
93
def dict_to_struct(data: dict[str, Any]) -> Struct:
    """Build a protobuf ``Struct`` holding the items of *data*.

    A falsy *data* (``None`` or ``{}``) yields an empty ``Struct``.
    """
    result = Struct()
    if not data:
        return result
    result.update(data)
    return result
99
+
100
+
101
def struct_to_dict(struct: Struct) -> dict[str, Any]:
    """Convert a protobuf ``Struct`` into a plain dict (``{}`` when falsy)."""
    if not struct:
        return {}
    return MessageToDict(struct)
104
+
105
+
106
def proto_to_assistant(proto_assistant: pb.Assistant) -> Assistant:
    """Build the API ``Assistant`` dict from its protobuf message.

    Timestamps are converted to UTC-aware datetimes. ``description`` is
    ``None`` whenever the optional field is not set (checked via ``HasField``).
    """
    msg = proto_assistant
    has_description = msg.HasField("description")
    return {
        "assistant_id": msg.assistant_id,
        "graph_id": msg.graph_id,
        "version": msg.version,
        "created_at": msg.created_at.ToDatetime(tzinfo=UTC),
        "updated_at": msg.updated_at.ToDatetime(tzinfo=UTC),
        "config": config_conversion.config_from_proto(msg.config),
        "context": struct_to_dict(msg.context),
        "metadata": struct_to_dict(msg.metadata),
        "name": msg.name,
        "description": msg.description if has_description else None,
    }
124
+
125
+
126
# API thread-status strings -> protobuf ThreadStatus values.
THREAD_STATUS_TO_PB = {
    "idle": pb.ThreadStatus.THREAD_STATUS_IDLE,
    "busy": pb.ThreadStatus.THREAD_STATUS_BUSY,
    "interrupted": pb.ThreadStatus.THREAD_STATUS_INTERRUPTED,
    "error": pb.ThreadStatus.THREAD_STATUS_ERROR,
}

# Reverse lookup (protobuf ThreadStatus -> API string), derived from the
# forward table so the two can never drift apart.
THREAD_STATUS_FROM_PB = {
    pb_status: api_status for api_status, pb_status in THREAD_STATUS_TO_PB.items()
}

# Lower-case thread sort column names -> protobuf ThreadsSortBy values.
THREAD_SORT_BY_MAP = {
    "thread_id": pb.ThreadsSortBy.THREADS_SORT_BY_THREAD_ID,
    "created_at": pb.ThreadsSortBy.THREADS_SORT_BY_CREATED_AT,
    "updated_at": pb.ThreadsSortBy.THREADS_SORT_BY_UPDATED_AT,
    "status": pb.ThreadsSortBy.THREADS_SORT_BY_STATUS,
}

# Lower-case TTL strategy names -> protobuf ThreadTTLStrategy values.
THREAD_TTL_STRATEGY_MAP = {"delete": pb.ThreadTTLStrategy.THREAD_TTL_STRATEGY_DELETE}
148
+
149
+
150
def _map_thread_status(status: ThreadStatus | None) -> pb.ThreadStatus | None:
    """Return the protobuf enum for *status*, or ``None`` if absent/unknown."""
    return None if status is None else THREAD_STATUS_TO_PB.get(status)
154
+
155
+
156
def _map_threads_sort_by(sort_by: str | None) -> pb.ThreadsSortBy:
    """Resolve a thread sort column name (case-insensitive) to its enum.

    Missing or unrecognised names fall back to ``created_at``.
    """
    fallback = pb.ThreadsSortBy.THREADS_SORT_BY_CREATED_AT
    if sort_by:
        return THREAD_SORT_BY_MAP.get(sort_by.lower(), fallback)
    return fallback
162
+
163
+
164
def _map_thread_ttl(ttl: dict[str, Any] | None) -> pb.ThreadTTLConfig | None:
    """Translate an API TTL dict into a ``ThreadTTLConfig`` message.

    Recognised keys:
    - ``strategy``: case-insensitive; only ``"delete"`` is accepted.
    - ``ttl``: falls back to ``default_ttl`` when the ``ttl`` key is absent.
    - ``sweep_interval_minutes``.

    Returns ``None`` for a falsy *ttl*.

    Raises:
        HTTPException: 422 for an unrecognised strategy.
    """
    if not ttl:
        return None

    ttl_config = pb.ThreadTTLConfig()

    raw_strategy = ttl.get("strategy")
    if raw_strategy:
        resolved = THREAD_TTL_STRATEGY_MAP.get(str(raw_strategy).lower())
        if resolved is None:
            raise HTTPException(
                status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
                detail=f"Invalid thread TTL strategy: {raw_strategy}. Expected one of ['delete']",
            )
        ttl_config.strategy = resolved

    # An explicit "ttl" key (even when None) takes precedence over "default_ttl".
    default_ttl = ttl["ttl"] if "ttl" in ttl else ttl.get("default_ttl")
    if default_ttl is not None:
        ttl_config.default_ttl = float(default_ttl)

    sweep_minutes = ttl.get("sweep_interval_minutes")
    if sweep_minutes is not None:
        ttl_config.sweep_interval_minutes = int(sweep_minutes)

    return ttl_config
188
+
189
+
190
def fragment_to_value(fragment: pb.Fragment | None) -> Any:
    """Decode the JSON payload carried by a ``Fragment``.

    Returns ``{}`` when the fragment is missing, has an empty payload, or
    holds invalid JSON. Invalid JSON is also logged as a warning.
    """
    payload = None if fragment is None else fragment.value
    if not payload:
        return {}
    try:
        decoded = json_loads(payload)
    except orjson.JSONDecodeError:
        logger.warning("Failed to decode fragment", fragment=payload)
        return {}
    return decoded
198
+
199
+
200
def _proto_interrupts_to_dict(
    interrupts_map: dict[str, pb.Interrupts],
) -> dict[str, list[dict[str, Any]]]:
    """Convert a map of protobuf ``Interrupts`` into plain lists of dicts.

    Each entry always has ``id`` (``None`` when empty) and a JSON-decoded
    ``value``. ``when``, ``resumable`` and ``ns`` are added only when truthy.
    """

    def _convert(interrupt: Any) -> dict[str, Any]:
        entry: dict[str, Any] = {
            "id": interrupt.id or None,
            "value": json_loads(interrupt.value),
        }
        if interrupt.when:
            entry["when"] = interrupt.when
        if interrupt.resumable:
            entry["resumable"] = interrupt.resumable
        if interrupt.ns:
            entry["ns"] = list(interrupt.ns)
        return entry

    return {
        key: [_convert(item) for item in group.interrupts]
        for key, group in interrupts_map.items()
    }
220
+
221
+
222
def proto_to_thread(proto_thread: pb.Thread) -> Thread:
    """Build the API ``Thread`` dict from its protobuf message.

    Unset timestamps become ``None``. Unknown status values default to
    ``"idle"``. JSON fragments are decoded with ``fragment_to_value``.

    Raises:
        HTTPException: 500 when the message carries no ``thread_id``.
    """
    if not proto_thread.HasField("thread_id"):
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            detail="Thread response missing thread_id",
        )
    thread_id = UUID(proto_thread.thread_id.value)

    def _timestamp(field_name: str):
        # Optional timestamps: only convert when the field is actually set.
        if not proto_thread.HasField(field_name):
            return None
        return getattr(proto_thread, field_name).ToDatetime(tzinfo=UTC)

    return {
        "thread_id": thread_id,
        "created_at": _timestamp("created_at"),
        "updated_at": _timestamp("updated_at"),
        "metadata": fragment_to_value(proto_thread.metadata),
        "config": fragment_to_value(proto_thread.config),
        "error": fragment_to_value(proto_thread.error),
        "status": THREAD_STATUS_FROM_PB.get(proto_thread.status, "idle"),  # type: ignore[typeddict-item]
        "values": fragment_to_value(proto_thread.values),
        "interrupts": _proto_interrupts_to_dict(dict(proto_thread.interrupts)),
    }
257
+
258
+
259
def exception_to_struct(exception: BaseException | None) -> Struct | None:
    """Serialize *exception* into a protobuf ``Struct``.

    The exception is first round-tripped through ``json_dumpb``. If that step
    fails, or does not produce a JSON object, a minimal
    ``{"error": <type name>, "message": str(exception)}`` payload is used
    instead. This keeps the conversion best-effort and lets it never raise
    from serialization.

    Returns:
        ``None`` when *exception* is ``None``, otherwise the ``Struct``.
    """
    if exception is None:
        return None

    fallback = {"error": type(exception).__name__, "message": str(exception)}
    try:
        payload = orjson.loads(json_dumpb(exception))
    except (orjson.JSONEncodeError, orjson.JSONDecodeError):
        # Encoding (not only decoding) can fail for unserializable payloads.
        payload = fallback
    if not isinstance(payload, dict):
        # Struct.update() needs a mapping; a bare JSON scalar or list would crash.
        payload = fallback
    return dict_to_struct(payload)
267
+
268
+
269
+ def _filter_thread_fields(
270
+ thread: Thread, select: list[ThreadSelectField] | None
271
+ ) -> dict[str, Any]:
272
+ if not select:
273
+ return dict(thread)
274
+ return {field: thread[field] for field in select if field in thread}
275
+
276
+
277
+ def _normalize_uuid(value: UUID | str) -> str:
278
+ return str(value) if isinstance(value, UUID) else str(UUID(str(value)))
279
+
280
+
281
def _map_sort_by(sort_by: str | None) -> pb.AssistantsSortBy:
    """Resolve an assistant sort column name (case-insensitive) to its enum.

    Missing or unrecognised names fall back to ``CREATED_AT``.
    """
    sort_enum = pb.AssistantsSortBy
    fallback = sort_enum.CREATED_AT
    if not sort_by:
        return fallback
    by_column = {
        "assistant_id": sort_enum.ASSISTANT_ID,
        "graph_id": sort_enum.GRAPH_ID,
        "name": sort_enum.NAME,
        "created_at": sort_enum.CREATED_AT,
        "updated_at": sort_enum.UPDATED_AT,
    }
    return by_column.get(sort_by.lower(), fallback)
295
+
296
+
297
def _map_sort_order(sort_order: str | None) -> pb.SortOrder:
    """Return ``ASC`` only for a case-insensitive ``"asc"``; otherwise ``DESC``."""
    ascending = bool(sort_order) and sort_order.upper() == "ASC"
    return pb.SortOrder.ASC if ascending else pb.SortOrder.DESC
302
+
303
+
304
def _handle_grpc_error(error: AioRpcError) -> None:
    """Re-raise a gRPC failure as an ``HTTPException``; never returns normally.

    Status codes listed in ``GRPC_STATUS_TO_HTTP_STATUS`` keep their mapped
    HTTP status; every other code becomes 500. The original ``AioRpcError`` is
    chained as ``__cause__`` for logging and debugging.
    """
    code = error.code()
    # details() is Optional; without this fallback an absent message would
    # surface to clients as the literal string "None".
    details = error.details() or f"gRPC request failed with status {code}"
    raise HTTPException(
        status_code=GRPC_STATUS_TO_HTTP_STATUS.get(
            code, HTTPStatus.INTERNAL_SERVER_ERROR
        ),
        detail=str(details),
    ) from error
312
+
313
+
314
class Authenticated:
    """Base for gRPC-backed resources (mirrors the storage_postgres interface).

    Subclasses override ``resource`` with the name of the resource they serve.
    """

    resource: str = "assistants"

    @classmethod
    async def handle_event(
        cls,
        ctx: Any,  # Auth context
        action: str,
        value: Any,
    ) -> dict[str, Any] | None:
        """Resolve the auth filters for *action* applied to *value*.

        Currently a stub: it ignores its arguments and always returns ``None``,
        i.e. no filters are applied.
        """
        # TODO: convert the auth context into the gRPC AuthFilter format once
        # auth handling is wired through.
        return None
330
+
331
+
332
def grpc_error_guard(cls):
    """Class decorator that converts ``AioRpcError`` into ``HTTPException``.

    Every coroutine function defined directly on *cls* is replaced by a
    wrapper. This covers plain, ``staticmethod`` and ``classmethod`` members.
    The wrapper awaits the original and routes any ``AioRpcError`` through
    ``_handle_grpc_error``.

    Inherited members, synchronous callables and async generators are left
    untouched. Errors raised while a returned async iterator is being consumed
    (after the wrapped call has returned) are not intercepted.

    Returns:
        *cls*, modified in place.
    """
    # inspect.iscoroutinefunction replaces asyncio.iscoroutinefunction, which
    # is deprecated as of Python 3.14.
    import inspect

    def _guard(func):
        # Separate factory so each wrapper closes over its own `func`.
        @functools.wraps(func)
        async def guarded(*args, **kwargs):
            try:
                return await func(*args, **kwargs)
            except AioRpcError as exc:
                _handle_grpc_error(exc)

        return guarded

    for attr_name, attr in list(vars(cls).items()):
        if isinstance(attr, staticmethod):
            func, rewrap = attr.__func__, staticmethod
        elif isinstance(attr, classmethod):
            func, rewrap = attr.__func__, classmethod
        elif callable(attr):
            func, rewrap = attr, None
        else:
            continue

        if not inspect.iscoroutinefunction(func):
            continue

        guarded = _guard(func)
        setattr(cls, attr_name, rewrap(guarded) if rewrap else guarded)
    return cls
366
+
367
+
368
@grpc_error_guard
class Assistants(Authenticated):
    """gRPC-backed assistant operations.

    Method signatures mirror the storage_postgres implementation. The ``conn``
    argument is accepted for interface compatibility but is not used. Auth
    filters obtained from ``handle_event`` are forwarded with every request.
    """

    resource = "assistants"

    @staticmethod
    async def search(
        conn,  # Not used in gRPC implementation
        *,
        graph_id: str | None,
        metadata: MetadataInput,
        limit: int,
        offset: int,
        sort_by: str | None = None,
        sort_order: str | None = None,
        select: list[AssistantSelectField] | None = None,
        ctx: Any = None,
    ) -> tuple[AsyncIterator[Assistant], int | None]:  # type: ignore[return-value]
        """Search assistants via gRPC.

        Returns:
            A pair of:
            - an async iterator over the matching assistants, restricted to
              the *select* fields when *select* is non-empty;
            - the next-page offset, or ``None`` when the page came back short.
        """
        auth_filters = await Assistants.handle_event(
            ctx,
            "search",
            {
                "graph_id": graph_id,
                "metadata": metadata,
                "limit": limit,
                "offset": offset,
            },
        )

        request = pb.SearchAssistantsRequest(
            # Same `or {}` convention as every other request in this class.
            filters=auth_filters or {},
            graph_id=graph_id,
            metadata=dict_to_struct(metadata or {}),
            limit=limit,
            offset=offset,
            sort_by=_map_sort_by(sort_by),
            sort_order=_map_sort_order(sort_order),
            select=select,
        )

        client = await get_shared_client()
        response = await client.assistants.Search(request)

        assistants = [
            proto_to_assistant(assistant) for assistant in response.assistants
        ]

        # The gRPC response carries no cursor: a full page is taken as a hint
        # that more results may exist.
        cursor = offset + limit if len(assistants) == limit else None

        async def generate_results():
            for assistant in assistants:
                if not select:
                    # None or [] means "all fields", consistent with
                    # _filter_thread_fields (previously [] yielded empty dicts).
                    yield assistant
                else:
                    yield {k: v for k, v in assistant.items() if k in select}

        return generate_results(), cursor

    @staticmethod
    async def get(
        conn,  # Not used in gRPC implementation
        assistant_id: UUID | str,
        ctx: Any = None,
    ) -> AsyncIterator[Assistant]:  # type: ignore[return-value]
        """Fetch an assistant by ID; the iterator yields exactly one item."""
        auth_filters = await Assistants.handle_event(
            ctx, "read", {"assistant_id": str(assistant_id)}
        )

        request = pb.GetAssistantRequest(
            assistant_id=str(assistant_id),
            filters=auth_filters or {},
        )

        client = await get_shared_client()
        response = await client.assistants.Get(request)

        assistant = proto_to_assistant(response)

        async def generate_result():
            yield assistant

        return generate_result()

    @staticmethod
    async def put(
        conn,  # Not used in gRPC implementation
        assistant_id: UUID | str,
        *,
        graph_id: str,
        config: Config,
        context: Context,
        metadata: MetadataInput,
        if_exists: OnConflictBehavior,
        name: str,
        description: str | None = None,
        ctx: Any = None,
    ) -> AsyncIterator[Assistant]:  # type: ignore[return-value]
        """Create an assistant via gRPC.

        *if_exists* is mapped with ``map_if_exists``. *config* and *context*
        are reconciled by ``consolidate_config_and_context``, which raises 400
        when both ``configurable`` and ``context`` are set. The iterator yields
        the stored assistant once.
        """
        auth_filters = await Assistants.handle_event(
            ctx,
            "create",
            {
                "assistant_id": str(assistant_id),
                "graph_id": graph_id,
                "config": config,
                "context": context,
                "metadata": metadata,
                "name": name,
                "description": description,
            },
        )

        config, context = consolidate_config_and_context(config, context)

        on_conflict = map_if_exists(if_exists)

        request = pb.CreateAssistantRequest(
            assistant_id=str(assistant_id),
            graph_id=graph_id,
            filters=auth_filters or {},
            if_exists=on_conflict,
            config=config_conversion.config_to_proto(config),
            context=dict_to_struct(context or {}),
            name=name,
            description=description,
            metadata=dict_to_struct(metadata or {}),
        )

        client = await get_shared_client()
        response = await client.assistants.Create(request)

        assistant = proto_to_assistant(response)

        async def generate_result():
            yield assistant

        return generate_result()

    @staticmethod
    async def patch(
        conn,  # Not used in gRPC implementation
        assistant_id: UUID | str,
        *,
        config: Config | None = None,
        context: Context | None = None,
        graph_id: str | None = None,
        metadata: MetadataInput | None = None,
        name: str | None = None,
        description: str | None = None,
        ctx: Any = None,
    ) -> AsyncIterator[Assistant]:  # type: ignore[return-value]
        """Update an assistant via gRPC.

        ``config`` and ``context`` are attached only when non-empty after
        consolidation. ``metadata`` is always sent, as an empty ``Struct`` when
        omitted. The iterator yields the updated assistant once.
        """
        metadata = metadata if metadata is not None else {}
        config = config if config is not None else Config()
        auth_filters = await Assistants.handle_event(
            ctx,
            "update",
            {
                "assistant_id": str(assistant_id),
                "graph_id": graph_id,
                "config": config,
                "context": context,
                "metadata": metadata,
                "name": name,
                "description": description,
            },
        )

        config, context = consolidate_config_and_context(config, context)

        request = pb.PatchAssistantRequest(
            assistant_id=str(assistant_id),
            filters=auth_filters or {},
            graph_id=graph_id,
            name=name,
            description=description,
            metadata=dict_to_struct(metadata or {}),
        )

        if config:
            request.config.CopyFrom(config_conversion.config_to_proto(config))

        if context:
            request.context.CopyFrom(dict_to_struct(context))

        client = await get_shared_client()
        response = await client.assistants.Patch(request)

        assistant = proto_to_assistant(response)

        async def generate_result():
            yield assistant

        return generate_result()

    @staticmethod
    async def delete(
        conn,  # Not used in gRPC implementation
        assistant_id: UUID | str,
        ctx: Any = None,
    ) -> AsyncIterator[UUID]:  # type: ignore[return-value]
        """Delete an assistant; the iterator yields the requested ID as a UUID."""
        auth_filters = await Assistants.handle_event(
            ctx, "delete", {"assistant_id": str(assistant_id)}
        )

        request = pb.DeleteAssistantRequest(
            assistant_id=str(assistant_id),
            filters=auth_filters or {},
        )

        client = await get_shared_client()
        await client.assistants.Delete(request)

        async def generate_result():
            yield UUID(str(assistant_id))

        return generate_result()

    @staticmethod
    async def set_latest(
        conn,  # Not used in gRPC implementation
        assistant_id: UUID | str,
        version: int,
        ctx: Any = None,
    ) -> AsyncIterator[Assistant]:  # type: ignore[return-value]
        """Mark *version* as the assistant's latest; yields the result once."""
        auth_filters = await Assistants.handle_event(
            ctx,
            "update",
            {
                "assistant_id": str(assistant_id),
                "version": version,
            },
        )

        request = pb.SetLatestAssistantRequest(
            assistant_id=str(assistant_id),
            version=version,
            filters=auth_filters or {},
        )

        client = await get_shared_client()
        response = await client.assistants.SetLatest(request)

        assistant = proto_to_assistant(response)

        async def generate_result():
            yield assistant

        return generate_result()

    @staticmethod
    async def get_versions(
        conn,  # Not used in gRPC implementation
        assistant_id: UUID | str,
        metadata: MetadataInput,
        limit: int,
        offset: int,
        ctx: Any = None,
    ) -> AsyncIterator[Assistant]:  # type: ignore[return-value]
        """List the stored versions of an assistant via gRPC.

        Version dicts carry ``created_at`` but no ``updated_at``.
        ``description`` is ``None`` when the optional field is unset.
        """
        auth_filters = await Assistants.handle_event(
            ctx,
            "search",
            {"assistant_id": str(assistant_id), "metadata": metadata},
        )

        request = pb.GetAssistantVersionsRequest(
            assistant_id=str(assistant_id),
            filters=auth_filters or {},
            metadata=dict_to_struct(metadata or {}),
            limit=limit,
            offset=offset,
        )

        client = await get_shared_client()
        response = await client.assistants.GetVersions(request)

        async def generate_results():
            for version in response.versions:
                version_description = (
                    version.description if version.HasField("description") else None
                )
                yield {
                    "assistant_id": version.assistant_id,
                    "graph_id": version.graph_id,
                    "version": version.version,
                    "created_at": version.created_at.ToDatetime(tzinfo=UTC),
                    "config": config_conversion.config_from_proto(version.config),
                    "context": struct_to_dict(version.context),
                    "metadata": struct_to_dict(version.metadata),
                    "name": version.name,
                    "description": version_description,
                }

        return generate_results()

    @staticmethod
    async def count(
        conn,  # Not used in gRPC implementation
        *,
        graph_id: str | None = None,
        metadata: MetadataInput = None,
        ctx: Any = None,
    ) -> int:  # type: ignore[return-value]
        """Return the number of assistants matching *graph_id*/*metadata*."""
        auth_filters = await Assistants.handle_event(
            ctx, "search", {"graph_id": graph_id, "metadata": metadata}
        )

        request = pb.CountAssistantsRequest(
            filters=auth_filters or {},
            graph_id=graph_id,
            metadata=dict_to_struct(metadata or {}),
        )

        client = await get_shared_client()
        response = await client.assistants.Count(request)

        return int(response.count)
719
+
720
+
721
+ def _json_contains(container: Any, subset: dict[str, Any]) -> bool:
722
+ if not subset:
723
+ return True
724
+ if not isinstance(container, dict):
725
+ return False
726
+ for key, value in subset.items():
727
+ if key not in container:
728
+ return False
729
+ candidate = container[key]
730
+ if isinstance(value, dict):
731
+ if not _json_contains(candidate, value):
732
+ return False
733
+ else:
734
+ if candidate != value:
735
+ return False
736
+ return True
737
+
738
+
739
@grpc_error_guard
class Threads(Authenticated):
    """gRPC-backed thread operations.

    Method signatures mirror the storage_postgres implementation. The ``conn``
    argument is accepted for interface compatibility but is not used. Auth
    filters obtained from ``handle_event`` are forwarded with every request.
    """

    resource = "threads"

    @staticmethod
    async def search(
        conn,  # Not used in gRPC implementation
        *,
        ids: list[str] | list[UUID] | None = None,
        metadata: MetadataInput,
        values: MetadataInput,
        status: ThreadStatus | None,
        limit: int,
        offset: int,
        sort_by: str | None = None,
        sort_order: str | None = None,
        select: list[ThreadSelectField] | None = None,
        ctx: Any = None,
    ) -> tuple[AsyncIterator[Thread], int | None]:  # type: ignore[return-value]
        """Search threads via gRPC.

        With *ids*:
        - each distinct thread is fetched individually and filtered locally
          by *status*, *metadata* and *values*;
        - ids that are not found are skipped instead of failing the search;
        - results keep the order of *ids*, and *sort_by*/*sort_order* are not
          applied.

        Without *ids*, a single ``Search`` RPC is issued.

        Returns:
            A pair of:
            - an async iterator over the threads, projected to *select* when
              *select* is non-empty;
            - the next-page offset, or ``None`` when there is no further page.

        Raises:
            HTTPException: 422 for an unrecognised *status* (``Search`` RPC
                path only).
        """
        metadata = metadata or {}
        values = values or {}

        auth_filters = await Threads.handle_event(
            ctx,
            "search",
            {
                "metadata": metadata,
                "values": values,
                "status": status,
                "limit": limit,
                "offset": offset,
            },
        )

        if ids:
            # Normalize once and drop duplicates, preserving first-seen order.
            unique_ids = list(dict.fromkeys(_normalize_uuid(tid) for tid in ids))
            threads: list[Thread] = []
            client = await get_shared_client()
            for thread_id in unique_ids:
                request = pb.GetThreadRequest(
                    thread_id=pb.UUID(value=thread_id),
                    filters=auth_filters or {},
                )
                try:
                    response = await client.threads.Get(request)
                except AioRpcError as exc:
                    # A search returns the matches it finds; one unknown (or
                    # filtered-out) id must not turn the whole call into a 404.
                    if exc.code() == StatusCode.NOT_FOUND:
                        continue
                    raise
                thread = proto_to_thread(response)

                if status and thread["status"] != status:
                    continue
                if metadata and not _json_contains(thread["metadata"], metadata):
                    continue
                if values and not _json_contains(thread.get("values") or {}, values):
                    continue
                threads.append(thread)

            total = len(threads)
            paginated = threads[offset : offset + limit]
            cursor = offset + limit if total > offset + limit else None

            async def generate_results():
                for thread in paginated:
                    yield _filter_thread_fields(thread, select)

            return generate_results(), cursor

        request_kwargs: dict[str, Any] = {
            "filters": auth_filters or {},
            "metadata": dict_to_struct(metadata),
            "values": dict_to_struct(values),
            "limit": limit,
            "offset": offset,
            "sort_by": _map_threads_sort_by(sort_by),
            "sort_order": _map_sort_order(sort_order),
        }

        if status:
            mapped_status = _map_thread_status(status)
            if mapped_status is None:
                raise HTTPException(
                    status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
                    detail=f"Invalid thread status: {status}",
                )
            request_kwargs["status"] = mapped_status

        if select:
            request_kwargs["select"] = select

        client = await get_shared_client()
        response = await client.threads.Search(
            pb.SearchThreadsRequest(**request_kwargs)
        )

        threads = [proto_to_thread(thread) for thread in response.threads]
        # No cursor in the response: a full page hints that more may exist.
        cursor = offset + limit if len(threads) == limit else None

        async def generate_results():
            for thread in threads:
                yield _filter_thread_fields(thread, select)

        return generate_results(), cursor

    @staticmethod
    async def count(
        conn,  # Not used
        *,
        metadata: MetadataInput,
        values: MetadataInput,
        status: ThreadStatus | None,
        ctx: Any = None,
    ) -> int:  # type: ignore[override]
        """Return the number of threads matching *metadata*, *values*, *status*.

        Raises:
            HTTPException: 422 for an unrecognised *status*.
        """
        metadata = metadata or {}
        values = values or {}

        auth_filters = await Threads.handle_event(
            ctx,
            "search",
            {
                "metadata": metadata,
                "values": values,
                "status": status,
            },
        )

        request_kwargs: dict[str, Any] = {
            "filters": auth_filters or {},
            "metadata": dict_to_struct(metadata),
            "values": dict_to_struct(values),
        }
        if status:
            mapped_status = _map_thread_status(status)
            if mapped_status is None:
                raise HTTPException(
                    status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
                    detail=f"Invalid thread status: {status}",
                )
            request_kwargs["status"] = mapped_status

        client = await get_shared_client()
        response = await client.threads.Count(pb.CountThreadsRequest(**request_kwargs))

        return int(response.count)

    @staticmethod
    async def get(
        conn,  # Not used
        thread_id: UUID | str,
        ctx: Any = None,
    ) -> AsyncIterator[Thread]:  # type: ignore[return-value]
        """Fetch a thread by ID; the iterator yields exactly one ``Thread``."""
        auth_filters = await Threads.handle_event(
            ctx, "read", {"thread_id": str(thread_id)}
        )

        request = pb.GetThreadRequest(
            thread_id=pb.UUID(value=_normalize_uuid(thread_id)),
            filters=auth_filters or {},
        )
        client = await get_shared_client()
        response = await client.threads.Get(request)

        thread = proto_to_thread(response)

        async def generate_result():
            yield thread

        return generate_result()

    @staticmethod
    async def put(
        conn,  # Not used
        thread_id: UUID | str,
        *,
        metadata: MetadataInput,
        if_exists: OnConflictBehavior,
        ttl: dict[str, Any] | None = None,
        ctx: Any = None,
    ) -> AsyncIterator[Thread]:  # type: ignore[return-value]
        """Create a thread via gRPC, optionally with a TTL configuration.

        Raises:
            HTTPException: 422 for an invalid TTL strategy.
        """
        metadata = metadata or {}

        auth_filters = await Threads.handle_event(
            ctx,
            "create",
            {
                "thread_id": str(thread_id),
                "metadata": metadata,
                "if_exists": if_exists,
            },
        )

        request = pb.CreateThreadRequest(
            thread_id=pb.UUID(value=_normalize_uuid(thread_id)),
            filters=auth_filters or {},
            if_exists=map_if_exists(if_exists),
            metadata=dict_to_struct(metadata),
        )
        ttl_config = _map_thread_ttl(ttl)
        if ttl_config is not None:
            request.ttl.CopyFrom(ttl_config)

        client = await get_shared_client()
        response = await client.threads.Create(request)
        thread = proto_to_thread(response)

        async def generate_result():
            yield thread

        return generate_result()

    @staticmethod
    async def patch(
        conn,  # Not used
        thread_id: UUID | str,
        *,
        metadata: MetadataInput,
        ttl: dict[str, Any] | None = None,
        ctx: Any = None,
    ) -> AsyncIterator[Thread]:  # type: ignore[return-value]
        """Update a thread's metadata and/or TTL via gRPC.

        Empty metadata and a missing TTL are left out of the request.
        """
        metadata = metadata or {}

        auth_filters = await Threads.handle_event(
            ctx,
            "update",
            {
                "thread_id": str(thread_id),
                "metadata": metadata,
            },
        )

        request = pb.PatchThreadRequest(
            thread_id=pb.UUID(value=_normalize_uuid(thread_id)),
            filters=auth_filters or {},
        )

        if metadata:
            request.metadata.CopyFrom(dict_to_struct(metadata))

        ttl_config = _map_thread_ttl(ttl)
        if ttl_config is not None:
            request.ttl.CopyFrom(ttl_config)

        client = await get_shared_client()
        response = await client.threads.Patch(request)

        thread = proto_to_thread(response)

        async def generate_result():
            yield thread

        return generate_result()

    @staticmethod
    async def delete(
        conn,  # Not used
        thread_id: UUID | str,
        ctx: Any = None,
    ) -> AsyncIterator[UUID]:  # type: ignore[return-value]
        """Delete a thread; yields the ID reported back by the server."""
        auth_filters = await Threads.handle_event(
            ctx,
            "delete",
            {
                "thread_id": str(thread_id),
            },
        )

        request = pb.DeleteThreadRequest(
            thread_id=pb.UUID(value=_normalize_uuid(thread_id)),
            filters=auth_filters or {},
        )

        client = await get_shared_client()
        response = await client.threads.Delete(request)

        deleted_id = UUID(response.value)

        async def generate_result():
            yield deleted_id

        return generate_result()

    @staticmethod
    async def copy(
        conn,  # Not used
        thread_id: UUID | str,
        ctx: Any = None,
    ) -> AsyncIterator[Thread]:  # type: ignore[return-value]
        """Copy a thread (authorized as a ``read``); yields the new thread."""
        auth_filters = await Threads.handle_event(
            ctx,
            "read",
            {
                "thread_id": str(thread_id),
            },
        )

        request = pb.CopyThreadRequest(
            thread_id=pb.UUID(value=_normalize_uuid(thread_id)),
            filters=auth_filters or {},
        )

        client = await get_shared_client()
        response = await client.threads.Copy(request)

        thread = proto_to_thread(response)

        async def generate_result():
            yield thread

        return generate_result()