langgraph-api 0.4.40__py3-none-any.whl → 0.5.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langgraph-api might be problematic. Click here for more details.
- langgraph_api/__init__.py +1 -1
- langgraph_api/api/assistants.py +65 -61
- langgraph_api/api/meta.py +6 -0
- langgraph_api/api/threads.py +11 -7
- langgraph_api/auth/custom.py +29 -24
- langgraph_api/cli.py +2 -49
- langgraph_api/config.py +131 -16
- langgraph_api/graph.py +1 -1
- langgraph_api/grpc/client.py +183 -0
- langgraph_api/grpc/config_conversion.py +225 -0
- langgraph_api/grpc/generated/core_api_pb2.py +275 -0
- langgraph_api/{grpc_ops → grpc}/generated/core_api_pb2.pyi +35 -40
- langgraph_api/grpc/generated/engine_common_pb2.py +190 -0
- langgraph_api/grpc/generated/engine_common_pb2.pyi +634 -0
- langgraph_api/grpc/generated/engine_common_pb2_grpc.py +24 -0
- langgraph_api/grpc/ops.py +1045 -0
- langgraph_api/js/build.mts +1 -1
- langgraph_api/js/client.http.mts +1 -1
- langgraph_api/js/client.mts +1 -1
- langgraph_api/js/package.json +12 -12
- langgraph_api/js/src/graph.mts +20 -0
- langgraph_api/js/yarn.lock +176 -234
- langgraph_api/metadata.py +29 -21
- langgraph_api/queue_entrypoint.py +2 -2
- langgraph_api/route.py +14 -4
- langgraph_api/schema.py +2 -2
- langgraph_api/self_hosted_metrics.py +48 -2
- langgraph_api/serde.py +58 -14
- langgraph_api/server.py +16 -2
- langgraph_api/worker.py +1 -1
- {langgraph_api-0.4.40.dist-info → langgraph_api-0.5.6.dist-info}/METADATA +6 -6
- {langgraph_api-0.4.40.dist-info → langgraph_api-0.5.6.dist-info}/RECORD +38 -34
- langgraph_api/grpc_ops/client.py +0 -80
- langgraph_api/grpc_ops/generated/core_api_pb2.py +0 -274
- langgraph_api/grpc_ops/ops.py +0 -610
- /langgraph_api/{grpc_ops → grpc}/__init__.py +0 -0
- /langgraph_api/{grpc_ops → grpc}/generated/__init__.py +0 -0
- /langgraph_api/{grpc_ops → grpc}/generated/core_api_pb2_grpc.py +0 -0
- {langgraph_api-0.4.40.dist-info → langgraph_api-0.5.6.dist-info}/WHEEL +0 -0
- {langgraph_api-0.4.40.dist-info → langgraph_api-0.5.6.dist-info}/entry_points.txt +0 -0
- {langgraph_api-0.4.40.dist-info → langgraph_api-0.5.6.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,1045 @@
|
|
|
1
|
+
"""gRPC-based operations for LangGraph API."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import functools
|
|
7
|
+
from collections.abc import AsyncIterator
|
|
8
|
+
from datetime import UTC
|
|
9
|
+
from http import HTTPStatus
|
|
10
|
+
from typing import Any, overload
|
|
11
|
+
from uuid import UUID
|
|
12
|
+
|
|
13
|
+
import orjson
|
|
14
|
+
import structlog
|
|
15
|
+
from google.protobuf.json_format import MessageToDict
|
|
16
|
+
from google.protobuf.struct_pb2 import Struct # type: ignore[import]
|
|
17
|
+
from grpc import StatusCode
|
|
18
|
+
from grpc.aio import AioRpcError
|
|
19
|
+
from langgraph_sdk.schema import Config
|
|
20
|
+
from starlette.exceptions import HTTPException
|
|
21
|
+
|
|
22
|
+
from langgraph_api.grpc import config_conversion
|
|
23
|
+
from langgraph_api.schema import (
|
|
24
|
+
Assistant,
|
|
25
|
+
AssistantSelectField,
|
|
26
|
+
Context,
|
|
27
|
+
MetadataInput,
|
|
28
|
+
OnConflictBehavior,
|
|
29
|
+
Thread,
|
|
30
|
+
ThreadSelectField,
|
|
31
|
+
ThreadStatus,
|
|
32
|
+
)
|
|
33
|
+
from langgraph_api.serde import json_dumpb, json_loads
|
|
34
|
+
|
|
35
|
+
from .client import get_shared_client
|
|
36
|
+
from .generated import core_api_pb2 as pb
|
|
37
|
+
|
|
38
|
+
# Translation table from gRPC status codes to HTTP status codes.
# Any gRPC code not listed here is surfaced as 500 INTERNAL_SERVER_ERROR
# (see _handle_grpc_error below).
GRPC_STATUS_TO_HTTP_STATUS = {
    StatusCode.NOT_FOUND: HTTPStatus.NOT_FOUND,
    StatusCode.ALREADY_EXISTS: HTTPStatus.CONFLICT,
    StatusCode.INVALID_ARGUMENT: HTTPStatus.UNPROCESSABLE_ENTITY,
}

# Module-level structured logger.
logger = structlog.stdlib.get_logger(__name__)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def map_if_exists(if_exists: str) -> pb.OnConflictBehavior:
    """Translate an ``if_exists`` string into the protobuf conflict behavior.

    Only ``"do_nothing"`` maps to DO_NOTHING; every other value means RAISE.
    """
    return (
        pb.OnConflictBehavior.DO_NOTHING
        if if_exists == "do_nothing"
        else pb.OnConflictBehavior.RAISE
    )
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@overload
def consolidate_config_and_context(
    config: Config | None, context: None
) -> tuple[Config, None]: ...


@overload
def consolidate_config_and_context(
    config: Config | None, context: Context
) -> tuple[Config, Context]: ...


def consolidate_config_and_context(
    config: Config | None, context: Context | None
) -> tuple[Config, Context | None]:
    """Return a new (config, context) with consistent configurable/context.

    Does not mutate the passed-in objects. If both configurable and context
    are provided, raises 400. If only one is provided, mirrors it to the other.
    """
    # Work on fresh copies so callers' objects are never mutated.
    new_config: Config = Config(config or {})
    new_context: Context | None = None if context is None else dict(context)
    configurable = new_config.get("configurable")

    if configurable and new_context:
        raise HTTPException(
            status_code=400,
            detail=(
                "Cannot specify both configurable and context. Prefer setting context alone."
                " Context was introduced in LangGraph 0.6.0 and "
                "is the long term planned replacement for configurable."
            ),
        )

    if configurable:
        # Mirror configurable into context.
        new_context = configurable
    elif new_context is not None:
        # Mirror context into config.configurable.
        new_config["configurable"] = new_context

    return new_config, new_context
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def dict_to_struct(data: dict[str, Any]) -> Struct:
    """Build a protobuf ``Struct`` from a plain dictionary (empty dict -> empty Struct)."""
    result = Struct()
    if data:
        result.update(data)
    return result
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def struct_to_dict(struct: Struct) -> dict[str, Any]:
    """Convert a protobuf ``Struct`` to a plain dictionary; falsy input yields {}."""
    if not struct:
        return {}
    return MessageToDict(struct)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def proto_to_assistant(proto_assistant: pb.Assistant) -> Assistant:
    """Convert a protobuf Assistant message to the API dictionary format."""
    # HasField distinguishes "unset" from an empty string, so an absent
    # description stays None instead of becoming "".
    if proto_assistant.HasField("description"):
        description = proto_assistant.description
    else:
        description = None

    return {
        "assistant_id": proto_assistant.assistant_id,
        "graph_id": proto_assistant.graph_id,
        "version": proto_assistant.version,
        "created_at": proto_assistant.created_at.ToDatetime(tzinfo=UTC),
        "updated_at": proto_assistant.updated_at.ToDatetime(tzinfo=UTC),
        "config": config_conversion.config_from_proto(proto_assistant.config),
        "context": struct_to_dict(proto_assistant.context),
        "metadata": struct_to_dict(proto_assistant.metadata),
        "name": proto_assistant.name,
        "description": description,
    }
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
# API thread-status string -> protobuf ThreadStatus enum.
THREAD_STATUS_TO_PB = {
    "idle": pb.ThreadStatus.THREAD_STATUS_IDLE,
    "busy": pb.ThreadStatus.THREAD_STATUS_BUSY,
    "interrupted": pb.ThreadStatus.THREAD_STATUS_INTERRUPTED,
    "error": pb.ThreadStatus.THREAD_STATUS_ERROR,
}

# Inverse mapping: protobuf ThreadStatus enum -> API status string.
THREAD_STATUS_FROM_PB = {
    pb.ThreadStatus.THREAD_STATUS_IDLE: "idle",
    pb.ThreadStatus.THREAD_STATUS_BUSY: "busy",
    pb.ThreadStatus.THREAD_STATUS_INTERRUPTED: "interrupted",
    pb.ThreadStatus.THREAD_STATUS_ERROR: "error",
}

# Thread-search sort-key name -> protobuf ThreadsSortBy enum.
THREAD_SORT_BY_MAP = {
    "thread_id": pb.ThreadsSortBy.THREADS_SORT_BY_THREAD_ID,
    "created_at": pb.ThreadsSortBy.THREADS_SORT_BY_CREATED_AT,
    "updated_at": pb.ThreadsSortBy.THREADS_SORT_BY_UPDATED_AT,
    "status": pb.ThreadsSortBy.THREADS_SORT_BY_STATUS,
}

# TTL strategy name -> protobuf enum; "delete" is the only supported strategy.
THREAD_TTL_STRATEGY_MAP = {"delete": pb.ThreadTTLStrategy.THREAD_TTL_STRATEGY_DELETE}
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def _map_thread_status(status: ThreadStatus | None) -> pb.ThreadStatus | None:
    """Map an API thread status to its protobuf enum; None for absent/unknown."""
    return None if status is None else THREAD_STATUS_TO_PB.get(status)
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def _map_threads_sort_by(sort_by: str | None) -> pb.ThreadsSortBy:
    """Map a sort-key name to the protobuf enum, defaulting to created_at."""
    default = pb.ThreadsSortBy.THREADS_SORT_BY_CREATED_AT
    if not sort_by:
        return default
    # Lookup is case-insensitive; unrecognized keys fall back to the default.
    return THREAD_SORT_BY_MAP.get(sort_by.lower(), default)
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def _map_thread_ttl(ttl: dict[str, Any] | None) -> pb.ThreadTTLConfig | None:
    """Build a protobuf ThreadTTLConfig from a TTL dict, or None when absent.

    Raises:
        HTTPException: 422 when the strategy is not a supported name.
    """
    if not ttl:
        return None

    result = pb.ThreadTTLConfig()

    raw_strategy = ttl.get("strategy")
    if raw_strategy:
        resolved = THREAD_TTL_STRATEGY_MAP.get(str(raw_strategy).lower())
        if resolved is None:
            raise HTTPException(
                status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
                detail=f"Invalid thread TTL strategy: {raw_strategy}. Expected one of ['delete']",
            )
        result.strategy = resolved

    # Prefer the "ttl" key when present; fall back to the "default_ttl" key.
    ttl_value = ttl.get("ttl", ttl.get("default_ttl"))
    if ttl_value is not None:
        result.default_ttl = float(ttl_value)

    interval = ttl.get("sweep_interval_minutes")
    if interval is not None:
        result.sweep_interval_minutes = int(interval)

    return result
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def fragment_to_value(fragment: pb.Fragment | None) -> Any:
    """Decode a Fragment's JSON payload.

    Returns {} when the fragment is absent or has an empty value, and also when
    the payload is not valid JSON (the failure is logged, not raised).
    """
    if fragment is None or not fragment.value:
        return {}
    try:
        return json_loads(fragment.value)
    except orjson.JSONDecodeError:
        # Best-effort: a corrupt payload degrades to an empty dict.
        logger.warning("Failed to decode fragment", fragment=fragment.value)
        return {}
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def _proto_interrupts_to_dict(
    interrupts_map: dict[str, pb.Interrupts],
) -> dict[str, list[dict[str, Any]]]:
    """Flatten a protobuf interrupts map into plain JSON-compatible dicts."""
    result: dict[str, list[dict[str, Any]]] = {}
    for task_key, wrapper in interrupts_map.items():
        converted: list[dict[str, Any]] = []
        for item in wrapper.interrupts:
            record: dict[str, Any] = {
                "id": item.id or None,  # empty string becomes None
                "value": json_loads(item.value),
            }
            # Optional attributes are emitted only when truthy.
            if item.when:
                record["when"] = item.when
            if item.resumable:
                record["resumable"] = item.resumable
            if item.ns:
                record["ns"] = list(item.ns)
            converted.append(record)
        result[task_key] = converted
    return result
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def proto_to_thread(proto_thread: pb.Thread) -> Thread:
    """Convert a protobuf Thread message to the API dictionary format.

    Raises:
        HTTPException: 500 when the response lacks a thread_id.
    """
    # A thread without an id is a malformed server response, not a user error.
    if not proto_thread.HasField("thread_id"):
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            detail="Thread response missing thread_id",
        )
    thread_id = UUID(proto_thread.thread_id.value)

    # Timestamps are optional; absent fields stay None.
    created_at = None
    if proto_thread.HasField("created_at"):
        created_at = proto_thread.created_at.ToDatetime(tzinfo=UTC)
    updated_at = None
    if proto_thread.HasField("updated_at"):
        updated_at = proto_thread.updated_at.ToDatetime(tzinfo=UTC)

    return {
        "thread_id": thread_id,
        "created_at": created_at,
        "updated_at": updated_at,
        "metadata": fragment_to_value(proto_thread.metadata),
        "config": fragment_to_value(proto_thread.config),
        "error": fragment_to_value(proto_thread.error),
        # Unknown enum values degrade to "idle".
        "status": THREAD_STATUS_FROM_PB.get(proto_thread.status, "idle"),  # type: ignore[typeddict-item]
        "values": fragment_to_value(proto_thread.values),
        "interrupts": _proto_interrupts_to_dict(dict(proto_thread.interrupts)),
    }
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def exception_to_struct(exception: BaseException | None) -> Struct | None:
    """Serialize an exception into a protobuf Struct; None maps to None."""
    if exception is None:
        return None
    try:
        data = orjson.loads(json_dumpb(exception))
    except orjson.JSONDecodeError:
        # Fall back to a minimal shape when serde output is not valid JSON.
        data = {"error": type(exception).__name__, "message": str(exception)}
    return dict_to_struct(data)
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
def _filter_thread_fields(
    thread: Thread, select: list[ThreadSelectField] | None
) -> dict[str, Any]:
    """Project *thread* onto the *select* fields; a falsy select copies everything."""
    if not select:
        return dict(thread)
    # Silently skip selected fields that the thread does not carry.
    return {name: thread[name] for name in select if name in thread}
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def _normalize_uuid(value: UUID | str) -> str:
    """Return the canonical hyphenated string form of a UUID or UUID-like string."""
    if isinstance(value, UUID):
        return str(value)
    # Round-trip through UUID to validate and normalize the textual form.
    return str(UUID(str(value)))
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
def _map_sort_by(sort_by: str | None) -> pb.AssistantsSortBy:
    """Map an assistants sort-key name to the protobuf enum (default CREATED_AT)."""
    default = pb.AssistantsSortBy.CREATED_AT
    if not sort_by:
        return default
    lookup = {
        "assistant_id": pb.AssistantsSortBy.ASSISTANT_ID,
        "graph_id": pb.AssistantsSortBy.GRAPH_ID,
        "name": pb.AssistantsSortBy.NAME,
        "created_at": pb.AssistantsSortBy.CREATED_AT,
        "updated_at": pb.AssistantsSortBy.UPDATED_AT,
    }
    return lookup.get(sort_by.lower(), default)
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
def _map_sort_order(sort_order: str | None) -> pb.SortOrder:
    """Map a sort-order string to the protobuf enum; only "ASC" yields ASC."""
    if sort_order and sort_order.upper() == "ASC":
        return pb.SortOrder.ASC
    # None, empty, "desc", or anything unrecognized all mean descending.
    return pb.SortOrder.DESC
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
def _handle_grpc_error(error: AioRpcError) -> None:
    """Translate a gRPC error into an HTTPException and raise it.

    Codes not present in GRPC_STATUS_TO_HTTP_STATUS become 500s.
    """
    http_status = GRPC_STATUS_TO_HTTP_STATUS.get(
        error.code(), HTTPStatus.INTERNAL_SERVER_ERROR
    )
    raise HTTPException(status_code=http_status, detail=str(error.details()))
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
class Authenticated:
    """Base class for authenticated operations (matches storage_postgres interface)."""

    # Resource name used by subclasses when reporting auth events.
    resource: str = "assistants"

    @classmethod
    async def handle_event(
        cls,
        ctx: Any,  # Auth context
        action: str,
        value: Any,
    ) -> dict[str, Any] | None:
        """Handle authentication event - stub implementation for now.

        Always returns None (no auth filters) until the TODO below is done.
        """
        # TODO: Implement proper auth handling that converts auth context
        # to gRPC AuthFilter format when needed
        return None
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
def grpc_error_guard(cls):
    """Class decorator that wraps every async method to translate gRPC errors.

    Each coroutine method (plain, static, or class) is replaced by a wrapper
    that catches AioRpcError and re-raises it as an HTTPException via
    _handle_grpc_error. Non-async attributes are left untouched.
    """

    def _guard(fn):
        @functools.wraps(fn)
        async def guarded(*args, **kwargs):
            try:
                return await fn(*args, **kwargs)
            except AioRpcError as exc:
                _handle_grpc_error(exc)

        return guarded

    for attr_name, attr_value in list(cls.__dict__.items()):
        # Unwrap static/class methods so we can inspect the raw function.
        if isinstance(attr_value, staticmethod):
            inner, rewrap = attr_value.__func__, staticmethod
        elif isinstance(attr_value, classmethod):
            inner, rewrap = attr_value.__func__, classmethod
        elif callable(attr_value):
            inner, rewrap = attr_value, None
        else:
            continue

        if not asyncio.iscoroutinefunction(inner):
            continue

        guarded = _guard(inner)
        # Re-apply the original descriptor type where there was one.
        setattr(cls, attr_name, rewrap(guarded) if rewrap else guarded)

    return cls
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
@grpc_error_guard
class Assistants(Authenticated):
    """gRPC-based assistants operations.

    Matches the storage-backed interface: every method takes a ``conn``
    argument for signature compatibility, but the gRPC backend does not use
    it. Read/write methods return async iterators so callers can consume
    results uniformly with the database-backed implementation.
    """

    resource = "assistants"

    @staticmethod
    async def search(
        conn,  # Not used in gRPC implementation
        *,
        graph_id: str | None,
        metadata: MetadataInput,
        limit: int,
        offset: int,
        sort_by: str | None = None,
        sort_order: str | None = None,
        select: list[AssistantSelectField] | None = None,
        ctx: Any = None,
    ) -> tuple[AsyncIterator[Assistant], int | None]:  # type: ignore[return-value]
        """Search assistants via gRPC.

        Returns:
            A tuple of (async iterator of assistants, next-page offset).
            The cursor is None when the current page was not full.
        """
        # Handle auth filters
        auth_filters = await Assistants.handle_event(
            ctx,
            "search",
            {
                "graph_id": graph_id,
                "metadata": metadata,
                "limit": limit,
                "offset": offset,
            },
        )

        # Build the gRPC request.
        # NOTE: `auth_filters or {}` keeps this consistent with every other
        # method in this class when handle_event returns None.
        request = pb.SearchAssistantsRequest(
            filters=auth_filters or {},
            graph_id=graph_id,
            metadata=dict_to_struct(metadata or {}),
            limit=limit,
            offset=offset,
            sort_by=_map_sort_by(sort_by),
            sort_order=_map_sort_order(sort_order),
            select=select,
        )

        client = await get_shared_client()
        response = await client.assistants.Search(request)

        # Convert response to expected format
        assistants = [
            proto_to_assistant(assistant) for assistant in response.assistants
        ]

        # Determine if there are more results.
        # gRPC doesn't return cursor info, so estimate: a full page implies
        # there may be another one.
        cursor = offset + limit if len(assistants) == limit else None

        async def generate_results():
            for assistant in assistants:
                # Project onto the selected fields when a selection was given.
                yield {
                    k: v for k, v in assistant.items() if select is None or k in select
                }

        return generate_results(), cursor

    @staticmethod
    async def get(
        conn,  # Not used in gRPC implementation
        assistant_id: UUID | str,
        ctx: Any = None,
    ) -> AsyncIterator[Assistant]:  # type: ignore[return-value]
        """Get assistant by ID via gRPC (yields exactly one assistant)."""
        # Handle auth filters
        auth_filters = await Assistants.handle_event(
            ctx, "read", {"assistant_id": str(assistant_id)}
        )

        # Build the gRPC request
        request = pb.GetAssistantRequest(
            assistant_id=str(assistant_id),
            filters=auth_filters or {},
        )

        client = await get_shared_client()
        response = await client.assistants.Get(request)

        # Convert and yield the result
        assistant = proto_to_assistant(response)

        async def generate_result():
            yield assistant

        return generate_result()

    @staticmethod
    async def put(
        conn,  # Not used in gRPC implementation
        assistant_id: UUID | str,
        *,
        graph_id: str,
        config: Config,
        context: Context,
        metadata: MetadataInput,
        if_exists: OnConflictBehavior,
        name: str,
        description: str | None = None,
        ctx: Any = None,
    ) -> AsyncIterator[Assistant]:  # type: ignore[return-value]
        """Create an assistant via gRPC; conflict handling follows ``if_exists``."""
        # Handle auth filters
        auth_filters = await Assistants.handle_event(
            ctx,
            "create",
            {
                "assistant_id": str(assistant_id),
                "graph_id": graph_id,
                "config": config,
                "context": context,
                "metadata": metadata,
                "name": name,
                "description": description,
            },
        )

        # Mirror configurable <-> context before sending (raises 400 if both set).
        config, context = consolidate_config_and_context(config, context)

        on_conflict = map_if_exists(if_exists)

        # Build the gRPC request
        request = pb.CreateAssistantRequest(
            assistant_id=str(assistant_id),
            graph_id=graph_id,
            filters=auth_filters or {},
            if_exists=on_conflict,
            config=config_conversion.config_to_proto(config),
            context=dict_to_struct(context or {}),
            name=name,
            description=description,
            metadata=dict_to_struct(metadata or {}),
        )

        client = await get_shared_client()
        response = await client.assistants.Create(request)

        # Convert and yield the result
        assistant = proto_to_assistant(response)

        async def generate_result():
            yield assistant

        return generate_result()

    @staticmethod
    async def patch(
        conn,  # Not used in gRPC implementation
        assistant_id: UUID | str,
        *,
        config: Config | None = None,
        context: Context | None = None,
        graph_id: str | None = None,
        metadata: MetadataInput | None = None,
        name: str | None = None,
        description: str | None = None,
        ctx: Any = None,
    ) -> AsyncIterator[Assistant]:  # type: ignore[return-value]
        """Update an assistant via gRPC; only provided fields are sent."""
        metadata = metadata if metadata is not None else {}
        config = config if config is not None else Config()
        # Handle auth filters
        auth_filters = await Assistants.handle_event(
            ctx,
            "update",
            {
                "assistant_id": str(assistant_id),
                "graph_id": graph_id,
                "config": config,
                "context": context,
                "metadata": metadata,
                "name": name,
                "description": description,
            },
        )

        config, context = consolidate_config_and_context(config, context)

        # Build the gRPC request
        request = pb.PatchAssistantRequest(
            assistant_id=str(assistant_id),
            filters=auth_filters or {},
            graph_id=graph_id,
            name=name,
            description=description,
            metadata=dict_to_struct(metadata or {}),
        )

        # Add optional config if provided
        if config:
            request.config.CopyFrom(config_conversion.config_to_proto(config))

        # Add optional context if provided
        if context:
            request.context.CopyFrom(dict_to_struct(context))

        client = await get_shared_client()
        response = await client.assistants.Patch(request)

        # Convert and yield the result
        assistant = proto_to_assistant(response)

        async def generate_result():
            yield assistant

        return generate_result()

    @staticmethod
    async def delete(
        conn,  # Not used in gRPC implementation
        assistant_id: UUID | str,
        ctx: Any = None,
    ) -> AsyncIterator[UUID]:  # type: ignore[return-value]
        """Delete an assistant via gRPC; yields the deleted assistant's UUID."""
        # Handle auth filters
        auth_filters = await Assistants.handle_event(
            ctx, "delete", {"assistant_id": str(assistant_id)}
        )

        # Build the gRPC request
        request = pb.DeleteAssistantRequest(
            assistant_id=str(assistant_id),
            filters=auth_filters or {},
        )

        client = await get_shared_client()
        await client.assistants.Delete(request)

        # Return the deleted ID
        async def generate_result():
            yield UUID(str(assistant_id))

        return generate_result()

    @staticmethod
    async def set_latest(
        conn,  # Not used in gRPC implementation
        assistant_id: UUID | str,
        version: int,
        ctx: Any = None,
    ) -> AsyncIterator[Assistant]:  # type: ignore[return-value]
        """Set the latest version of an assistant via gRPC."""
        # Handle auth filters
        auth_filters = await Assistants.handle_event(
            ctx,
            "update",
            {
                "assistant_id": str(assistant_id),
                "version": version,
            },
        )

        # Build the gRPC request
        request = pb.SetLatestAssistantRequest(
            assistant_id=str(assistant_id),
            version=version,
            filters=auth_filters or {},
        )

        client = await get_shared_client()
        response = await client.assistants.SetLatest(request)

        # Convert and yield the result
        assistant = proto_to_assistant(response)

        async def generate_result():
            yield assistant

        return generate_result()

    @staticmethod
    async def get_versions(
        conn,  # Not used in gRPC implementation
        assistant_id: UUID | str,
        metadata: MetadataInput,
        limit: int,
        offset: int,
        ctx: Any = None,
    ) -> AsyncIterator[Assistant]:  # type: ignore[return-value]
        """Get all versions of an assistant via gRPC."""
        # Handle auth filters
        auth_filters = await Assistants.handle_event(
            ctx,
            "search",
            {"assistant_id": str(assistant_id), "metadata": metadata},
        )

        # Build the gRPC request
        request = pb.GetAssistantVersionsRequest(
            assistant_id=str(assistant_id),
            filters=auth_filters or {},
            metadata=dict_to_struct(metadata or {}),
            limit=limit,
            offset=offset,
        )

        client = await get_shared_client()
        response = await client.assistants.GetVersions(request)

        # Convert and yield the results
        async def generate_results():
            for version in response.versions:
                # Preserve None for optional scalar fields by checking presence
                version_description = (
                    version.description if version.HasField("description") else None
                )
                yield {
                    "assistant_id": version.assistant_id,
                    "graph_id": version.graph_id,
                    "version": version.version,
                    "created_at": version.created_at.ToDatetime(tzinfo=UTC),
                    "config": config_conversion.config_from_proto(version.config),
                    "context": struct_to_dict(version.context),
                    "metadata": struct_to_dict(version.metadata),
                    "name": version.name,
                    "description": version_description,
                }

        return generate_results()

    @staticmethod
    async def count(
        conn,  # Not used in gRPC implementation
        *,
        graph_id: str | None = None,
        metadata: MetadataInput = None,
        ctx: Any = None,
    ) -> int:  # type: ignore[return-value]
        """Count assistants matching the filters via gRPC."""
        # Handle auth filters
        auth_filters = await Assistants.handle_event(
            ctx, "search", {"graph_id": graph_id, "metadata": metadata}
        )

        # Build the gRPC request
        request = pb.CountAssistantsRequest(
            filters=auth_filters or {},
            graph_id=graph_id,
            metadata=dict_to_struct(metadata or {}),
        )

        client = await get_shared_client()
        response = await client.assistants.Count(request)

        return int(response.count)
|
|
719
|
+
|
|
720
|
+
|
|
721
|
+
def _json_contains(container: Any, subset: dict[str, Any]) -> bool:
    """Recursively check that *container* includes every key/value of *subset*.

    Nested dicts are matched structurally; all other values use equality.
    An empty subset matches anything; a non-dict container matches nothing.
    """
    if not subset:
        return True
    if not isinstance(container, dict):
        return False
    for key, expected in subset.items():
        if key not in container:
            return False
        actual = container[key]
        if isinstance(expected, dict):
            if not _json_contains(actual, expected):
                return False
        elif actual != expected:
            return False
    return True
|
|
737
|
+
|
|
738
|
+
|
|
739
|
+
@grpc_error_guard
|
|
740
|
+
class Threads(Authenticated):
|
|
741
|
+
"""gRPC-based threads operations."""
|
|
742
|
+
|
|
743
|
+
resource = "threads"
|
|
744
|
+
|
|
745
|
+
@staticmethod
async def search(
    conn,  # Not used in gRPC implementation
    *,
    ids: list[str] | list[UUID] | None = None,
    metadata: MetadataInput,
    values: MetadataInput,
    status: ThreadStatus | None,
    limit: int,
    offset: int,
    sort_by: str | None = None,
    sort_order: str | None = None,
    select: list[ThreadSelectField] | None = None,
    ctx: Any = None,
) -> tuple[AsyncIterator[Thread], int | None]:  # type: ignore[return-value]
    """Search threads through the gRPC backend.

    Two paths:
    - When ``ids`` is given, each thread is fetched individually with
      ``Get`` and the metadata/values/status filters plus offset/limit
      pagination are applied client-side.
    - Otherwise a single server-side ``Search`` RPC is issued with all
      filters, sorting, and pagination.

    Returns a ``(async-iterator-of-threads, next-offset-or-None)`` pair;
    the cursor is ``None`` when no further page is expected.

    Raises:
        HTTPException: 422 when ``status`` does not map to a proto status.
    """
    metadata = metadata or {}
    values = values or {}

    # Custom-auth hook; returns extra filters to forward to the backend.
    auth_filters = await Threads.handle_event(
        ctx,
        "search",
        {
            "metadata": metadata,
            "values": values,
            "status": status,
            "limit": limit,
            "offset": offset,
        },
    )

    if ids:
        # Per-id lookup path: fetch each thread, filter locally.
        normalized_ids = [_normalize_uuid(thread_id) for thread_id in ids]
        threads: list[Thread] = []
        client = await get_shared_client()
        for thread_id in normalized_ids:
            # ids are already normalized above — do not normalize twice.
            request = pb.GetThreadRequest(
                thread_id=pb.UUID(value=thread_id),
                filters=auth_filters or {},
            )
            response = await client.threads.Get(request)
            thread = proto_to_thread(response)

            # Apply the same filters Search would apply server-side.
            if status and thread["status"] != status:
                continue
            if metadata and not _json_contains(thread["metadata"], metadata):
                continue
            if values and not _json_contains(thread.get("values") or {}, values):
                continue
            threads.append(thread)

        total = len(threads)
        paginated = threads[offset : offset + limit]
        # Only advance the cursor if there are threads beyond this page.
        cursor = offset + limit if total > offset + limit else None

        async def generate_results():
            for thread in paginated:
                yield _filter_thread_fields(thread, select)

        return generate_results(), cursor

    # Server-side search path.
    request_kwargs: dict[str, Any] = {
        "filters": auth_filters or {},
        "metadata": dict_to_struct(metadata),
        "values": dict_to_struct(values),
        "limit": limit,
        "offset": offset,
        "sort_by": _map_threads_sort_by(sort_by),
        "sort_order": _map_sort_order(sort_order),
    }

    if status:
        mapped_status = _map_thread_status(status)
        if mapped_status is None:
            raise HTTPException(
                status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
                detail=f"Invalid thread status: {status}",
            )
        request_kwargs["status"] = mapped_status

    if select:
        request_kwargs["select"] = select

    client = await get_shared_client()
    response = await client.threads.Search(
        pb.SearchThreadsRequest(**request_kwargs)
    )

    threads = [proto_to_thread(thread) for thread in response.threads]
    # A full page suggests more results may exist; otherwise stop paging.
    cursor = offset + limit if len(threads) == limit else None

    async def generate_results():
        for thread in threads:
            yield _filter_thread_fields(thread, select)

    return generate_results(), cursor
|
|
840
|
+
|
|
841
|
+
@staticmethod
async def count(
    conn,  # Not used
    *,
    metadata: MetadataInput,
    values: MetadataInput,
    status: ThreadStatus | None,
    ctx: Any = None,
) -> int:  # type: ignore[override]
    """Count threads matching the metadata/values/status filters via gRPC.

    Raises:
        HTTPException: 422 when ``status`` does not map to a proto status.
    """
    metadata = metadata or {}
    values = values or {}

    # Custom-auth hook; produces backend filters for the request.
    auth_filters = await Threads.handle_event(
        ctx,
        "search",
        {"metadata": metadata, "values": values, "status": status},
    )

    kwargs: dict[str, Any] = {
        "filters": auth_filters or {},
        "metadata": dict_to_struct(metadata),
        "values": dict_to_struct(values),
    }
    if status:
        proto_status = _map_thread_status(status)
        if proto_status is None:
            raise HTTPException(
                status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
                detail=f"Invalid thread status: {status}",
            )
        kwargs["status"] = proto_status

    grpc_client = await get_shared_client()
    reply = await grpc_client.threads.Count(pb.CountThreadsRequest(**kwargs))
    return int(reply.count)
|
|
881
|
+
|
|
882
|
+
@staticmethod
async def get(
    conn,  # Not used
    thread_id: UUID | str,
    ctx: Any = None,
) -> AsyncIterator[Thread]:  # type: ignore[return-value]
    """Fetch a single thread by id; yields exactly one Thread."""
    # Custom-auth hook for the read action.
    auth_filters = await Threads.handle_event(
        ctx, "read", {"thread_id": str(thread_id)}
    )

    grpc_client = await get_shared_client()
    reply = await grpc_client.threads.Get(
        pb.GetThreadRequest(
            thread_id=pb.UUID(value=_normalize_uuid(thread_id)),
            filters=auth_filters or {},
        )
    )
    found = proto_to_thread(reply)

    async def one():
        yield found

    return one()
|
|
905
|
+
|
|
906
|
+
@staticmethod
async def put(
    conn,  # Not used
    thread_id: UUID | str,
    *,
    metadata: MetadataInput,
    if_exists: OnConflictBehavior,
    ttl: dict[str, Any] | None = None,
    ctx: Any = None,
) -> AsyncIterator[Thread]:  # type: ignore[return-value]
    """Create a thread (honoring ``if_exists``); yields the created Thread."""
    metadata = metadata or {}

    # Custom-auth hook for the create action.
    auth_filters = await Threads.handle_event(
        ctx,
        "create",
        {
            "thread_id": str(thread_id),
            "metadata": metadata,
            "if_exists": if_exists,
        },
    )

    create_req = pb.CreateThreadRequest(
        thread_id=pb.UUID(value=_normalize_uuid(thread_id)),
        filters=auth_filters or {},
        if_exists=map_if_exists(if_exists),
        metadata=dict_to_struct(metadata),
    )
    # TTL is an optional proto message; only set it when mapped.
    ttl_proto = _map_thread_ttl(ttl)
    if ttl_proto is not None:
        create_req.ttl.CopyFrom(ttl_proto)

    grpc_client = await get_shared_client()
    created = proto_to_thread(await grpc_client.threads.Create(create_req))

    async def one():
        yield created

    return one()
|
|
946
|
+
|
|
947
|
+
@staticmethod
async def patch(
    conn,  # Not used
    thread_id: UUID | str,
    *,
    metadata: MetadataInput,
    ttl: dict[str, Any] | None = None,
    ctx: Any = None,
) -> AsyncIterator[Thread]:  # type: ignore[return-value]
    """Update a thread's metadata (and optionally TTL); yields the updated Thread."""
    metadata = metadata or {}

    # Custom-auth hook for the update action.
    auth_filters = await Threads.handle_event(
        ctx,
        "update",
        {"thread_id": str(thread_id), "metadata": metadata},
    )

    patch_req = pb.PatchThreadRequest(
        thread_id=pb.UUID(value=_normalize_uuid(thread_id)),
        filters=auth_filters or {},
    )
    # Optional proto fields are only set when a value is present.
    if metadata:
        patch_req.metadata.CopyFrom(dict_to_struct(metadata))
    ttl_proto = _map_thread_ttl(ttl)
    if ttl_proto is not None:
        patch_req.ttl.CopyFrom(ttl_proto)

    grpc_client = await get_shared_client()
    updated = proto_to_thread(await grpc_client.threads.Patch(patch_req))

    async def one():
        yield updated

    return one()
|
|
988
|
+
|
|
989
|
+
@staticmethod
async def delete(
    conn,  # Not used
    thread_id: UUID | str,
    ctx: Any = None,
) -> AsyncIterator[UUID]:  # type: ignore[return-value]
    """Delete a thread by id; yields the deleted thread's UUID."""
    # Custom-auth hook for the delete action.
    auth_filters = await Threads.handle_event(
        ctx, "delete", {"thread_id": str(thread_id)}
    )

    grpc_client = await get_shared_client()
    reply = await grpc_client.threads.Delete(
        pb.DeleteThreadRequest(
            thread_id=pb.UUID(value=_normalize_uuid(thread_id)),
            filters=auth_filters or {},
        )
    )
    removed = UUID(reply.value)

    async def one():
        yield removed

    return one()
|
|
1017
|
+
|
|
1018
|
+
@staticmethod
async def copy(
    conn,  # Not used
    thread_id: UUID | str,
    ctx: Any = None,
) -> AsyncIterator[Thread]:  # type: ignore[return-value]
    """Copy an existing thread; yields the newly created copy."""
    # Copy is gated on read access to the source thread.
    auth_filters = await Threads.handle_event(
        ctx,
        "read",
        {"thread_id": str(thread_id)},
    )

    grpc_client = await get_shared_client()
    reply = await grpc_client.threads.Copy(
        pb.CopyThreadRequest(
            thread_id=pb.UUID(value=_normalize_uuid(thread_id)),
            filters=auth_filters or {},
        )
    )
    duplicated = proto_to_thread(reply)

    async def one():
        yield duplicated

    return one()
|