langgraph-api 0.0.8__py3-none-any.whl → 0.0.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langgraph-api might be problematic. Click here for more details.

langgraph_storage/ops.py CHANGED
@@ -11,15 +11,17 @@ from collections.abc import AsyncIterator, Sequence
11
11
  from contextlib import asynccontextmanager
12
12
  from copy import deepcopy
13
13
  from datetime import UTC, datetime, timedelta
14
- from typing import Any, Literal
14
+ from typing import Any, Literal, cast
15
15
  from uuid import UUID, uuid4
16
16
 
17
17
  import structlog
18
18
  from langgraph.pregel.debug import CheckpointPayload
19
19
  from langgraph.pregel.types import StateSnapshot
20
+ from langgraph_sdk import Auth
20
21
  from starlette.exceptions import HTTPException
21
22
 
22
23
  from langgraph_api.asyncio import SimpleTaskGroup, ValueEvent, create_task
24
+ from langgraph_api.auth.custom import handle_event
23
25
  from langgraph_api.errors import UserInterrupt, UserRollback
24
26
  from langgraph_api.graph import get_graph
25
27
  from langgraph_api.schema import (
@@ -41,7 +43,7 @@ from langgraph_api.schema import (
41
43
  ThreadUpdateResponse,
42
44
  )
43
45
  from langgraph_api.serde import Fragment
44
- from langgraph_api.utils import fetchone
46
+ from langgraph_api.utils import fetchone, get_auth_ctx
45
47
  from langgraph_storage.checkpoint import Checkpointer
46
48
  from langgraph_storage.database import InMemConnectionProto, connect
47
49
  from langgraph_storage.queue import Message, get_stream_manager
@@ -57,13 +59,51 @@ def _ensure_uuid(id_: str | uuid.UUID | None) -> uuid.UUID:
57
59
  return id_
58
60
 
59
61
 
62
+ class WrappedHTTPException(Exception):
63
+ def __init__(self, http_exception: HTTPException):
64
+ self.http_exception = http_exception
65
+
66
+
60
67
  # Right now the whole API types as UUID but frequently passes a str
61
68
  # We ensure UUIDs for eveerything EXCEPT the checkpoint storage/writes,
62
69
  # which we leave as strings. This is because I'm too lazy to subclass fully
63
70
  # and we use non-UUID examples in the OSS version
64
71
 
65
72
 
66
- class Assistants:
73
+ class Authenticated:
74
+ resource: Literal["threads", "crons", "assistants"]
75
+
76
+ @classmethod
77
+ def _context(
78
+ cls,
79
+ ctx: Auth.types.BaseAuthContext | None,
80
+ action: Literal["create", "read", "update", "delete", "create_run"],
81
+ ) -> Auth.types.AuthContext | None:
82
+ if not ctx:
83
+ return
84
+ return Auth.types.AuthContext(
85
+ user=ctx.user,
86
+ scopes=ctx.scopes,
87
+ resource=cls.resource,
88
+ action=action,
89
+ )
90
+
91
+ @classmethod
92
+ async def handle_event(
93
+ cls,
94
+ ctx: Auth.types.BaseAuthContext | None,
95
+ action: Literal["create", "read", "update", "delete", "search"],
96
+ value: Any,
97
+ ) -> Auth.types.FilterType | None:
98
+ ctx = ctx or get_auth_ctx()
99
+ if not ctx:
100
+ return
101
+ return await handle_event(cls._context(ctx, action), value)
102
+
103
+
104
+ class Assistants(Authenticated):
105
+ resource = "assistants"
106
+
67
107
  @staticmethod
68
108
  async def search(
69
109
  conn: InMemConnectionProto,
@@ -72,7 +112,17 @@ class Assistants:
72
112
  metadata: MetadataInput,
73
113
  limit: int,
74
114
  offset: int,
115
+ ctx: Auth.types.BaseAuthContext | None = None,
75
116
  ) -> AsyncIterator[Assistant]:
117
+ metadata = metadata if metadata is not None else {}
118
+ filters = await Assistants.handle_event(
119
+ ctx,
120
+ "search",
121
+ Auth.types.AssistantsSearch(
122
+ graph_id=graph_id, metadata=metadata, limit=limit, offset=offset
123
+ ),
124
+ )
125
+
76
126
  async def filter_and_yield() -> AsyncIterator[Assistant]:
77
127
  assistants = conn.store["assistants"]
78
128
  filtered_assistants = [
@@ -82,6 +132,7 @@ class Assistants:
82
132
  and (
83
133
  not metadata or is_jsonb_contained(assistant["metadata"], metadata)
84
134
  )
135
+ and (not filters or _check_filter_match(assistant["metadata"], filters))
85
136
  ]
86
137
  filtered_assistants.sort(key=lambda x: x["created_at"], reverse=True)
87
138
  for assistant in filtered_assistants[offset : offset + limit]:
@@ -91,14 +142,23 @@ class Assistants:
91
142
 
92
143
  @staticmethod
93
144
  async def get(
94
- conn: InMemConnectionProto, assistant_id: UUID
145
+ conn: InMemConnectionProto,
146
+ assistant_id: UUID,
147
+ ctx: Auth.types.BaseAuthContext | None = None,
95
148
  ) -> AsyncIterator[Assistant]:
96
149
  """Get an assistant by ID."""
97
150
  assistant_id = _ensure_uuid(assistant_id)
151
+ filters = await Assistants.handle_event(
152
+ ctx,
153
+ "read",
154
+ Auth.types.AssistantsRead(assistant_id=assistant_id),
155
+ )
98
156
 
99
157
  async def _yield_result():
100
158
  for assistant in conn.store["assistants"]:
101
- if assistant["assistant_id"] == assistant_id:
159
+ if assistant["assistant_id"] == assistant_id and (
160
+ not filters or _check_filter_match(assistant["metadata"], filters)
161
+ ):
102
162
  yield assistant
103
163
 
104
164
  return _yield_result()
@@ -113,14 +173,33 @@ class Assistants:
113
173
  metadata: MetadataInput,
114
174
  if_exists: OnConflictBehavior,
115
175
  name: str,
176
+ ctx: Auth.types.BaseAuthContext | None = None,
116
177
  ) -> AsyncIterator[Assistant]:
117
178
  """Insert an assistant."""
118
179
  assistant_id = _ensure_uuid(assistant_id)
180
+ metadata = metadata if metadata is not None else {}
181
+ filters = await Assistants.handle_event(
182
+ ctx,
183
+ "create",
184
+ Auth.types.AssistantsCreate(
185
+ assistant_id=assistant_id,
186
+ graph_id=graph_id,
187
+ config=config,
188
+ metadata=metadata,
189
+ name=name,
190
+ ),
191
+ )
119
192
  existing_assistant = next(
120
193
  (a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
121
194
  None,
122
195
  )
123
196
  if existing_assistant:
197
+ if filters and not _check_filter_match(
198
+ existing_assistant["metadata"], filters
199
+ ):
200
+ raise HTTPException(
201
+ status_code=409, detail=f"Assistant {assistant_id} already exists"
202
+ )
124
203
  if if_exists == "raise":
125
204
  raise HTTPException(
126
205
  status_code=409, detail=f"Assistant {assistant_id} already exists"
@@ -168,6 +247,7 @@ class Assistants:
168
247
  graph_id: str | None = None,
169
248
  metadata: MetadataInput | None = None,
170
249
  name: str | None = None,
250
+ ctx: Auth.types.BaseAuthContext | None = None,
171
251
  ) -> AsyncIterator[Assistant]:
172
252
  """Update an assistant.
173
253
 
@@ -182,6 +262,18 @@ class Assistants:
182
262
  return the updated assistant model.
183
263
  """
184
264
  assistant_id = _ensure_uuid(assistant_id)
265
+ metadata = metadata if metadata is not None else {}
266
+ filters = await Assistants.handle_event(
267
+ ctx,
268
+ "update",
269
+ Auth.types.AssistantsUpdate(
270
+ assistant_id=assistant_id,
271
+ graph_id=graph_id,
272
+ config=config,
273
+ metadata=metadata,
274
+ name=name,
275
+ ),
276
+ )
185
277
  assistant = next(
186
278
  (a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
187
279
  None,
@@ -190,6 +282,10 @@ class Assistants:
190
282
  raise HTTPException(
191
283
  status_code=404, detail=f"Assistant {assistant_id} not found"
192
284
  )
285
+ elif filters and not _check_filter_match(assistant["metadata"], filters):
286
+ raise HTTPException(
287
+ status_code=404, detail=f"Assistant {assistant_id} not found"
288
+ )
193
289
 
194
290
  now = datetime.now(UTC)
195
291
  new_version = (
@@ -204,6 +300,11 @@ class Assistants:
204
300
  )
205
301
 
206
302
  # Update assistant_versions table
303
+ if metadata:
304
+ metadata = {
305
+ **assistant["metadata"],
306
+ **metadata,
307
+ }
207
308
  new_version_entry = {
208
309
  "assistant_id": assistant_id,
209
310
  "version": new_version,
@@ -233,10 +334,33 @@ class Assistants:
233
334
 
234
335
  @staticmethod
235
336
  async def delete(
236
- conn: InMemConnectionProto, assistant_id: UUID
337
+ conn: InMemConnectionProto,
338
+ assistant_id: UUID,
339
+ ctx: Auth.types.BaseAuthContext | None = None,
237
340
  ) -> AsyncIterator[UUID]:
238
341
  """Delete an assistant by ID."""
239
342
  assistant_id = _ensure_uuid(assistant_id)
343
+ filters = await Assistants.handle_event(
344
+ ctx,
345
+ "delete",
346
+ Auth.types.AssistantsDelete(
347
+ assistant_id=assistant_id,
348
+ ),
349
+ )
350
+ assistant = next(
351
+ (a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
352
+ None,
353
+ )
354
+
355
+ if not assistant:
356
+ raise HTTPException(
357
+ status_code=404, detail=f"Assistant with ID {assistant_id} not found"
358
+ )
359
+ elif filters and not _check_filter_match(assistant["metadata"], filters):
360
+ raise HTTPException(
361
+ status_code=404, detail=f"Assistant with ID {assistant_id} not found"
362
+ )
363
+
240
364
  conn.store["assistants"] = [
241
365
  a for a in conn.store["assistants"] if a["assistant_id"] != assistant_id
242
366
  ]
@@ -249,9 +373,10 @@ class Assistants:
249
373
  retained = []
250
374
  for run in conn.store["runs"]:
251
375
  if run["assistant_id"] == assistant_id:
252
- res = await Runs.delete(conn, run["run_id"], thread_id=run["thread_id"])
376
+ res = await Runs.delete(
377
+ conn, run["run_id"], thread_id=run["thread_id"], ctx=ctx
378
+ )
253
379
  await anext(res)
254
-
255
380
  else:
256
381
  retained.append(run)
257
382
 
@@ -262,10 +387,21 @@ class Assistants:
262
387
 
263
388
  @staticmethod
264
389
  async def set_latest(
265
- conn: InMemConnectionProto, assistant_id: UUID, version: int
390
+ conn: InMemConnectionProto,
391
+ assistant_id: UUID,
392
+ version: int,
393
+ ctx: Auth.types.BaseAuthContext | None = None,
266
394
  ) -> AsyncIterator[Assistant]:
267
395
  """Change the version of an assistant."""
268
396
  assistant_id = _ensure_uuid(assistant_id)
397
+ filters = await Assistants.handle_event(
398
+ ctx,
399
+ "update",
400
+ Auth.types.AssistantsUpdate(
401
+ assistant_id=assistant_id,
402
+ version=version,
403
+ ),
404
+ )
269
405
  assistant = next(
270
406
  (a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
271
407
  None,
@@ -274,6 +410,10 @@ class Assistants:
274
410
  raise HTTPException(
275
411
  status_code=404, detail=f"Assistant {assistant_id} not found"
276
412
  )
413
+ elif filters and not _check_filter_match(assistant["metadata"], filters):
414
+ raise HTTPException(
415
+ status_code=404, detail=f"Assistant {assistant_id} not found"
416
+ )
277
417
 
278
418
  version_data = next(
279
419
  (
@@ -310,14 +450,21 @@ class Assistants:
310
450
  metadata: MetadataInput,
311
451
  limit: int,
312
452
  offset: int,
453
+ ctx: Auth.types.BaseAuthContext | None = None,
313
454
  ) -> AsyncIterator[Assistant]:
314
455
  """Get all versions of an assistant."""
315
456
  assistant_id = _ensure_uuid(assistant_id)
457
+ filters = await Assistants.handle_event(
458
+ ctx,
459
+ "read",
460
+ Auth.types.AssistantsRead(assistant_id=assistant_id),
461
+ )
316
462
  versions = [
317
463
  v
318
464
  for v in conn.store["assistant_versions"]
319
465
  if v["assistant_id"] == assistant_id
320
466
  and (not metadata or is_jsonb_contained(v["metadata"], metadata))
467
+ and (not filters or _check_filter_match(v["metadata"], filters))
321
468
  ]
322
469
  versions.sort(key=lambda x: x["version"], reverse=True)
323
470
 
@@ -379,7 +526,9 @@ def _replace_thread_id(data, new_thread_id, thread_id):
379
526
  return d
380
527
 
381
528
 
382
- class Threads:
529
+ class Threads(Authenticated):
530
+ resource = "threads"
531
+
383
532
  @staticmethod
384
533
  async def search(
385
534
  conn: InMemConnectionProto,
@@ -389,29 +538,43 @@ class Threads:
389
538
  status: ThreadStatus | None,
390
539
  limit: int,
391
540
  offset: int,
541
+ ctx: Auth.types.BaseAuthContext | None = None,
392
542
  ) -> AsyncIterator[Thread]:
393
543
  threads = conn.store["threads"]
394
544
  filtered_threads: list[Thread] = []
545
+ metadata = metadata if metadata is not None else {}
546
+ values = values if values is not None else {}
547
+ filters = await Threads.handle_event(
548
+ ctx,
549
+ "search",
550
+ Auth.types.ThreadsSearch(
551
+ metadata=metadata,
552
+ values=values,
553
+ status=status,
554
+ limit=limit,
555
+ offset=offset,
556
+ ),
557
+ )
395
558
 
396
559
  # Apply filters
397
560
  for thread in threads:
398
- matches = True
561
+ if filters and not _check_filter_match(thread["metadata"], filters):
562
+ continue
399
563
 
400
564
  if metadata and not is_jsonb_contained(thread["metadata"], metadata):
401
- matches = False
565
+ continue
402
566
 
403
567
  if (
404
568
  values
405
569
  and "values" in thread
406
570
  and not is_jsonb_contained(thread["values"], values)
407
571
  ):
408
- matches = False
572
+ continue
409
573
 
410
574
  if status and thread.get("status") != status:
411
- matches = False
575
+ continue
412
576
 
413
- if matches:
414
- filtered_threads.append(thread)
577
+ filtered_threads.append(thread)
415
578
 
416
579
  # Sort by created_at in descending order
417
580
  sorted_threads = sorted(
@@ -428,8 +591,11 @@ class Threads:
428
591
  return thread_iterator()
429
592
 
430
593
  @staticmethod
431
- async def get(conn: InMemConnectionProto, thread_id: UUID) -> AsyncIterator[Thread]:
432
- """Get a thread by ID."""
594
+ async def _get_with_filters(
595
+ conn: InMemConnectionProto,
596
+ thread_id: UUID,
597
+ filters: Auth.types.FilterType | None,
598
+ ) -> Thread | None:
433
599
  thread_id = _ensure_uuid(thread_id)
434
600
  matching_thread = next(
435
601
  (
@@ -439,6 +605,37 @@ class Threads:
439
605
  ),
440
606
  None,
441
607
  )
608
+ if not matching_thread or (
609
+ filters and not _check_filter_match(matching_thread["metadata"], filters)
610
+ ):
611
+ return
612
+
613
+ return matching_thread
614
+
615
+ @staticmethod
616
+ async def _get(
617
+ conn: InMemConnectionProto,
618
+ thread_id: UUID,
619
+ ctx: Auth.types.BaseAuthContext | None = None,
620
+ ) -> Thread | None:
621
+ """Get a thread by ID."""
622
+ thread_id = _ensure_uuid(thread_id)
623
+ filters = await Threads.handle_event(
624
+ ctx,
625
+ "read",
626
+ Auth.types.ThreadsRead(thread_id=thread_id),
627
+ )
628
+ return await Threads._get_with_filters(conn, thread_id, filters)
629
+
630
+ @staticmethod
631
+ async def get(
632
+ conn: InMemConnectionProto,
633
+ thread_id: UUID,
634
+ ctx: Auth.types.BaseAuthContext | None = None,
635
+ ) -> AsyncIterator[Thread]:
636
+ """Get a thread by ID."""
637
+ matching_thread = await Threads._get(conn, thread_id, ctx)
638
+
442
639
  if not matching_thread:
443
640
  raise HTTPException(
444
641
  status_code=404, detail=f"Thread with ID {thread_id} not found"
@@ -457,6 +654,7 @@ class Threads:
457
654
  *,
458
655
  metadata: MetadataInput,
459
656
  if_exists: OnConflictBehavior,
657
+ ctx: Auth.types.BaseAuthContext | None = None,
460
658
  ) -> AsyncIterator[Thread]:
461
659
  """Insert or update a thread."""
462
660
  thread_id = _ensure_uuid(thread_id)
@@ -467,8 +665,22 @@ class Threads:
467
665
  existing_thread = next(
468
666
  (t for t in conn.store["threads"] if t["thread_id"] == thread_id), None
469
667
  )
668
+ filters = await Threads.handle_event(
669
+ ctx,
670
+ "create",
671
+ Auth.types.ThreadsCreate(
672
+ thread_id=thread_id, metadata=metadata, if_exists=if_exists
673
+ ),
674
+ )
470
675
 
471
676
  if existing_thread:
677
+ if filters and not _check_filter_match(
678
+ existing_thread["metadata"], filters
679
+ ):
680
+ # Should we use a different status code here?
681
+ raise HTTPException(
682
+ status_code=409, detail=f"Thread with ID {thread_id} already exists"
683
+ )
472
684
  if if_exists == "raise":
473
685
  raise HTTPException(
474
686
  status_code=409, detail=f"Thread with ID {thread_id} already exists"
@@ -479,7 +691,6 @@ class Threads:
479
691
  yield existing_thread
480
692
 
481
693
  return _yield_existing()
482
-
483
694
  # Create new thread
484
695
  new_thread: Thread = {
485
696
  "thread_id": thread_id,
@@ -501,7 +712,11 @@ class Threads:
501
712
 
502
713
  @staticmethod
503
714
  async def patch(
504
- conn: InMemConnectionProto, thread_id: UUID, *, metadata: MetadataValue
715
+ conn: InMemConnectionProto,
716
+ thread_id: UUID,
717
+ *,
718
+ metadata: MetadataValue,
719
+ ctx: Auth.types.BaseAuthContext | None = None,
505
720
  ) -> AsyncIterator[Thread]:
506
721
  """Update a thread."""
507
722
  thread_list = conn.store["threads"]
@@ -514,15 +729,23 @@ class Threads:
514
729
  break
515
730
 
516
731
  if thread_idx is not None:
517
- thread = copy.deepcopy(thread_list[thread_idx])
518
- thread["metadata"] = {**thread["metadata"], **metadata}
519
- thread["updated_at"] = datetime.now(UTC)
520
- thread_list[thread_idx] = thread
732
+ filters = await Threads.handle_event(
733
+ ctx,
734
+ "update",
735
+ Auth.types.ThreadsUpdate(thread_id=thread_id, metadata=metadata),
736
+ )
737
+ if not filters or _check_filter_match(
738
+ thread_list[thread_idx]["metadata"], filters
739
+ ):
740
+ thread = copy.deepcopy(thread_list[thread_idx])
741
+ thread["metadata"] = {**thread["metadata"], **metadata}
742
+ thread["updated_at"] = datetime.now(UTC)
743
+ thread_list[thread_idx] = thread
521
744
 
522
- async def thread_iterator() -> AsyncIterator[Thread]:
523
- yield thread
745
+ async def thread_iterator() -> AsyncIterator[Thread]:
746
+ yield thread
524
747
 
525
- return thread_iterator()
748
+ return thread_iterator()
526
749
 
527
750
  async def empty_iterator() -> AsyncIterator[Thread]:
528
751
  if False: # This ensures the iterator is empty
@@ -536,6 +759,7 @@ class Threads:
536
759
  thread_id: UUID,
537
760
  checkpoint: CheckpointPayload | None,
538
761
  exception: BaseException | None,
762
+ # This does not accept the auth context since it's only used internally
539
763
  ) -> None:
540
764
  """Set the status of a thread."""
541
765
  thread_id = _ensure_uuid(thread_id)
@@ -597,19 +821,33 @@ class Threads:
597
821
 
598
822
  @staticmethod
599
823
  async def delete(
600
- conn: InMemConnectionProto, thread_id: UUID
824
+ conn: InMemConnectionProto,
825
+ thread_id: UUID,
826
+ ctx: Auth.types.BaseAuthContext | None = None,
601
827
  ) -> AsyncIterator[UUID]:
602
828
  """Delete a thread by ID and cascade delete all associated runs."""
603
829
  thread_list = conn.store["threads"]
604
830
  thread_idx = None
605
831
  thread_id = _ensure_uuid(thread_id)
606
- conn.locks.pop(thread_id, None)
607
832
 
608
833
  # Find the thread to delete
609
834
  for idx, thread in enumerate(thread_list):
610
835
  if thread["thread_id"] == thread_id:
611
836
  thread_idx = idx
612
837
  break
838
+ filters = await Threads.handle_event(
839
+ ctx,
840
+ "delete",
841
+ Auth.types.ThreadsDelete(thread_id=thread_id),
842
+ )
843
+ if (filters and not _check_filter_match(thread["metadata"], filters)) or (
844
+ thread_idx is None
845
+ ):
846
+ raise HTTPException(
847
+ status_code=404, detail=f"Thread with ID {thread_id} not found"
848
+ )
849
+ # Delete the thread
850
+ conn.locks.pop(thread_id, None)
613
851
  # Cascade delete all runs associated with this thread
614
852
  conn.store["runs"] = [
615
853
  run for run in conn.store["runs"] if run["thread_id"] != thread_id
@@ -635,12 +873,20 @@ class Threads:
635
873
 
636
874
  @staticmethod
637
875
  async def copy(
638
- conn: InMemConnectionProto, thread_id: UUID
876
+ conn: InMemConnectionProto,
877
+ thread_id: UUID,
878
+ ctx: Auth.types.BaseAuthContext | None = None,
639
879
  ) -> AsyncIterator[Thread]:
640
880
  """Create a copy of an existing thread."""
641
881
  thread_id = _ensure_uuid(thread_id)
642
882
  new_thread_id = uuid4()
643
-
883
+ filters = await Threads.handle_event(
884
+ ctx,
885
+ "read",
886
+ Auth.types.ThreadsRead(
887
+ thread_id=new_thread_id,
888
+ ),
889
+ )
644
890
  async with conn.pipeline():
645
891
  # Find the original thread in our store
646
892
  original_thread = next(
@@ -648,7 +894,11 @@ class Threads:
648
894
  )
649
895
 
650
896
  if not original_thread:
651
- return
897
+ return _empty_generator()
898
+ if filters and not _check_filter_match(
899
+ original_thread["metadata"], filters
900
+ ):
901
+ return _empty_generator()
652
902
 
653
903
  # Create new thread with copied metadata
654
904
  new_thread: Thread = {
@@ -690,15 +940,22 @@ class Threads:
690
940
 
691
941
  return row_generator()
692
942
 
693
- class State:
943
+ class State(Authenticated):
944
+ # We will treat this like a runs resource for now.
945
+ resource = "threads"
946
+
694
947
  @staticmethod
695
948
  async def get(
696
- conn: InMemConnectionProto, config: Config, subgraphs: bool = False
949
+ conn: InMemConnectionProto,
950
+ config: Config,
951
+ subgraphs: bool = False,
952
+ ctx: Auth.types.BaseAuthContext | None = None,
697
953
  ) -> StateSnapshot:
698
954
  """Get state for a thread."""
699
955
  checkpointer = Checkpointer(conn)
700
956
  thread_id = _ensure_uuid(config["configurable"]["thread_id"])
701
- thread_iter = await Threads.get(conn, thread_id)
957
+ # Auth will be applied here so no need to use filters downstream
958
+ thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
702
959
  thread = await anext(thread_iter)
703
960
  checkpoint = await checkpointer.aget(config)
704
961
 
@@ -747,12 +1004,19 @@ class Threads:
747
1004
  config: Config,
748
1005
  values: Sequence[dict] | dict[str, Any] | None,
749
1006
  as_node: str | None = None,
1007
+ ctx: Auth.types.BaseAuthContext | None = None,
750
1008
  ) -> ThreadUpdateResponse:
751
1009
  """Add state to a thread."""
1010
+ thread_id = _ensure_uuid(config["configurable"]["thread_id"])
1011
+ filters = await Threads.handle_event(
1012
+ ctx,
1013
+ "update",
1014
+ Auth.types.ThreadsUpdate(thread_id=thread_id),
1015
+ )
752
1016
 
753
1017
  checkpointer = Checkpointer(conn)
754
- thread_id = _ensure_uuid(config["configurable"]["thread_id"])
755
- thread_iter = await Threads.get(conn, thread_id)
1018
+
1019
+ thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
756
1020
  thread = await fetchone(
757
1021
  thread_iter, not_found_detail=f"Thread {thread_id} not found."
758
1022
  )
@@ -760,6 +1024,8 @@ class Threads:
760
1024
 
761
1025
  if not thread:
762
1026
  raise HTTPException(status_code=404, detail="Thread not found")
1027
+ if not _check_filter_match(thread["metadata"], filters):
1028
+ raise HTTPException(status_code=403, detail="Forbidden")
763
1029
 
764
1030
  metadata = thread["metadata"]
765
1031
  thread_config = thread["config"]
@@ -781,7 +1047,7 @@ class Threads:
781
1047
  )
782
1048
 
783
1049
  # Get current state
784
- state = await Threads.State.get(conn, config, subgraphs=False)
1050
+ state = await Threads.State.get(conn, config, subgraphs=False, ctx=ctx)
785
1051
  # Update thread values
786
1052
  for thread in conn.store["threads"]:
787
1053
  if thread["thread_id"] == thread_id:
@@ -805,22 +1071,26 @@ class Threads:
805
1071
  limit: int = 10,
806
1072
  before: str | Checkpoint | None = None,
807
1073
  metadata: MetadataInput = None,
1074
+ ctx: Auth.types.BaseAuthContext | None = None,
808
1075
  ) -> list[StateSnapshot]:
809
1076
  """Get the history of a thread."""
810
1077
 
811
1078
  thread_id = _ensure_uuid(config["configurable"]["thread_id"])
812
1079
  thread = None
813
-
814
- for t in conn.store["threads"]:
815
- if t["thread_id"] == thread_id:
816
- thread = t
817
- break
818
-
819
- if not thread:
820
- return []
1080
+ filters = await Threads.handle_event(
1081
+ ctx,
1082
+ "read",
1083
+ Auth.types.ThreadsRead(thread_id=thread_id),
1084
+ )
1085
+ thread = await fetchone(
1086
+ await Threads.get(conn, config["configurable"]["thread_id"], ctx=ctx)
1087
+ )
821
1088
 
822
1089
  # Parse thread metadata and config
823
1090
  thread_metadata = thread["metadata"]
1091
+ if not _check_filter_match(thread_metadata, filters):
1092
+ return []
1093
+
824
1094
  thread_config = thread["config"]
825
1095
  # If graph_id exists, get state history
826
1096
  if graph_id := thread_metadata.get("graph_id"):
@@ -847,7 +1117,9 @@ class Threads:
847
1117
  return []
848
1118
 
849
1119
 
850
- class Runs:
1120
+ class Runs(Authenticated):
1121
+ resource = "threads"
1122
+
851
1123
  @staticmethod
852
1124
  async def stats(conn: InMemConnectionProto) -> QueueStats:
853
1125
  """Get stats about the queue."""
@@ -1001,6 +1273,7 @@ class Runs:
1001
1273
  multitask_strategy: MultitaskStrategy = "reject",
1002
1274
  if_not_exists: IfNotExists = "reject",
1003
1275
  after_seconds: int = 0,
1276
+ ctx: Auth.types.BaseAuthContext | None = None,
1004
1277
  ) -> AsyncIterator[Run]:
1005
1278
  """Create a run."""
1006
1279
  assistant_id = _ensure_uuid(assistant_id)
@@ -1009,22 +1282,38 @@ class Runs:
1009
1282
  None,
1010
1283
  )
1011
1284
 
1012
- async def empty_generator():
1013
- if False:
1014
- yield
1015
-
1016
1285
  if not assistant:
1017
- return empty_generator()
1286
+ return _empty_generator()
1018
1287
 
1019
1288
  thread_id = _ensure_uuid(thread_id) if thread_id else None
1020
1289
  run_id = _ensure_uuid(run_id) if run_id else None
1021
- metadata = metadata or {}
1290
+ metadata = metadata if metadata is not None else {}
1022
1291
  config = kwargs.get("config", {})
1023
1292
 
1024
1293
  # Handle thread creation/update
1025
1294
  existing_thread = next(
1026
1295
  (t for t in conn.store["threads"] if t["thread_id"] == thread_id), None
1027
1296
  )
1297
+ filters = await Runs.handle_event(
1298
+ ctx,
1299
+ "create",
1300
+ Auth.types.RunsCreate(
1301
+ thread_id=thread_id,
1302
+ assistant_id=assistant_id,
1303
+ run_id=run_id,
1304
+ status=status,
1305
+ metadata=metadata,
1306
+ prevent_insert_if_inflight=prevent_insert_if_inflight,
1307
+ multitask_strategy=multitask_strategy,
1308
+ if_not_exists=if_not_exists,
1309
+ after_seconds=after_seconds,
1310
+ kwargs=kwargs,
1311
+ ),
1312
+ )
1313
+ if existing_thread and filters:
1314
+ # Reject if the user doesn't own the thread
1315
+ if not _check_filter_match(existing_thread["metadata"], filters):
1316
+ return _empty_generator()
1028
1317
 
1029
1318
  if not existing_thread and (thread_id is None or if_not_exists == "create"):
1030
1319
  # Create new thread
@@ -1069,7 +1358,7 @@ class Runs:
1069
1358
  )
1070
1359
  existing_thread["updated_at"] = datetime.now(UTC)
1071
1360
  else:
1072
- return empty_generator()
1361
+ return _empty_generator()
1073
1362
 
1074
1363
  # Check for inflight runs if needed
1075
1364
  inflight_runs = [
@@ -1089,9 +1378,11 @@ class Runs:
1089
1378
  # Create new run
1090
1379
  configurable = Runs._merge_jsonb(
1091
1380
  Runs._get_configurable(assistant["config"]),
1092
- Runs._get_configurable(existing_thread["config"])
1093
- if existing_thread
1094
- else {},
1381
+ (
1382
+ Runs._get_configurable(existing_thread["config"])
1383
+ if existing_thread
1384
+ else {}
1385
+ ),
1095
1386
  Runs._get_configurable(config),
1096
1387
  {
1097
1388
  "run_id": str(run_id),
@@ -1149,11 +1440,20 @@ class Runs:
1149
1440
 
1150
1441
  @staticmethod
1151
1442
  async def get(
1152
- conn: InMemConnectionProto, run_id: UUID, *, thread_id: UUID
1443
+ conn: InMemConnectionProto,
1444
+ run_id: UUID,
1445
+ *,
1446
+ thread_id: UUID,
1447
+ ctx: Auth.types.BaseAuthContext | None = None,
1153
1448
  ) -> AsyncIterator[Run]:
1154
1449
  """Get a run by ID."""
1155
1450
 
1156
1451
  run_id, thread_id = _ensure_uuid(run_id), _ensure_uuid(thread_id)
1452
+ filters = await Runs.handle_event(
1453
+ ctx,
1454
+ "read",
1455
+ Auth.types.ThreadsRead(thread_id=thread_id),
1456
+ )
1157
1457
 
1158
1458
  async def _yield_result():
1159
1459
  matching_run = None
@@ -1162,16 +1462,36 @@ class Runs:
1162
1462
  matching_run = run
1163
1463
  break
1164
1464
  if matching_run:
1465
+ if filters:
1466
+ thread = await Threads._get_with_filters(
1467
+ conn, matching_run["thread_id"], filters
1468
+ )
1469
+ if not thread:
1470
+ return
1165
1471
  yield matching_run
1166
1472
 
1167
1473
  return _yield_result()
1168
1474
 
1169
1475
  @staticmethod
1170
1476
  async def delete(
1171
- conn: InMemConnectionProto, run_id: UUID, *, thread_id: UUID
1477
+ conn: InMemConnectionProto,
1478
+ run_id: UUID,
1479
+ *,
1480
+ thread_id: UUID,
1481
+ ctx: Auth.types.BaseAuthContext | None = None,
1172
1482
  ) -> AsyncIterator[UUID]:
1173
1483
  """Delete a run by ID."""
1174
1484
  run_id, thread_id = _ensure_uuid(run_id), _ensure_uuid(thread_id)
1485
+ filters = await Runs.handle_event(
1486
+ ctx,
1487
+ "delete",
1488
+ Auth.types.ThreadsDelete(run_id=run_id, thread_id=thread_id),
1489
+ )
1490
+
1491
+ if filters:
1492
+ thread = await Threads._get_with_filters(conn, thread_id, filters)
1493
+ if not thread:
1494
+ return _empty_generator()
1175
1495
  _delete_checkpoints_for_thread(thread_id, conn, run_id=run_id)
1176
1496
  found = False
1177
1497
  for i, run in enumerate(conn.store["runs"]):
@@ -1192,16 +1512,22 @@ class Runs:
1192
1512
  run_id: UUID,
1193
1513
  *,
1194
1514
  thread_id: UUID,
1515
+ ctx: Auth.types.BaseAuthContext | None = None,
1195
1516
  ) -> Fragment:
1196
1517
  """Wait for a run to complete. If already done, return immediately.
1197
1518
 
1198
1519
  Returns:
1199
1520
  the final state of the run.
1200
1521
  """
1522
+ async with connect() as conn:
1523
+ # Validate ownership
1524
+ thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
1525
+ await fetchone(thread_iter)
1201
1526
  last_chunk: bytes | None = None
1202
1527
  # wait for the run to complete
1528
+ # Rely on this join's auth
1203
1529
  async for mode, chunk in Runs.Stream.join(
1204
- run_id, thread_id=thread_id, stream_mode="values"
1530
+ run_id, thread_id=thread_id, stream_mode="values", ctx=ctx
1205
1531
  ):
1206
1532
  if mode == b"values":
1207
1533
  last_chunk = chunk
@@ -1212,7 +1538,7 @@ class Runs:
1212
1538
  else:
1213
1539
  # otherwise, the run had already finished, so fetch the state from thread
1214
1540
  async with connect() as conn:
1215
- thread_iter = await Threads.get(conn, thread_id)
1541
+ thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
1216
1542
  thread = await fetchone(thread_iter)
1217
1543
  return thread["values"]
1218
1544
 
@@ -1223,8 +1549,10 @@ class Runs:
1223
1549
  *,
1224
1550
  action: Literal["interrupt", "rollback"] = "interrupt",
1225
1551
  thread_id: UUID,
1552
+ ctx: Auth.types.BaseAuthContext | None = None,
1226
1553
  ) -> None:
1227
1554
  """Cancel a run."""
1555
+ # Authwise, this invokes the runs.update handler
1228
1556
  # Cancellation tries to take two actions, to cover runs in different states:
1229
1557
  # - For any run, send a cancellation message through the stream manager
1230
1558
  # - For queued runs not yet picked up by a worker, update their status if interrupt,
@@ -1233,6 +1561,15 @@ class Runs:
1233
1561
  # - For runs in any other state, we raise a 404
1234
1562
  run_ids = [_ensure_uuid(run_id) for run_id in run_ids]
1235
1563
  thread_id = _ensure_uuid(thread_id)
1564
+ filters = await Runs.handle_event(
1565
+ ctx,
1566
+ "update",
1567
+ Auth.types.ThreadsUpdate(
1568
+ thread_id=thread_id,
1569
+ action=action,
1570
+ metadata={"run_ids": run_ids},
1571
+ ),
1572
+ )
1236
1573
 
1237
1574
  stream_manager = get_stream_manager()
1238
1575
  found_runs = []
@@ -1247,6 +1584,10 @@ class Runs:
1247
1584
  None,
1248
1585
  )
1249
1586
  if run:
1587
+ if filters:
1588
+ thread = await Threads._get_with_filters(conn, thread_id, filters)
1589
+ if not thread:
1590
+ continue
1250
1591
  found_runs.append(run)
1251
1592
  # Send cancellation message through stream manager
1252
1593
  control_message = Message(
@@ -1296,16 +1637,26 @@ class Runs:
1296
1637
  offset: int = 0,
1297
1638
  metadata: MetadataInput,
1298
1639
  status: RunStatus | None = None,
1640
+ ctx: Auth.types.BaseAuthContext | None = None,
1299
1641
  ) -> AsyncIterator[Run]:
1300
1642
  """List all runs by thread."""
1301
1643
  runs = conn.store["runs"]
1302
- metadata = metadata or {}
1644
+ metadata = metadata if metadata is not None else {}
1303
1645
  thread_id = _ensure_uuid(thread_id)
1646
+ filters = await Runs.handle_event(
1647
+ ctx,
1648
+ "search",
1649
+ Auth.types.ThreadsSearch(thread_id=thread_id, metadata=metadata),
1650
+ )
1304
1651
  filtered_runs = [
1305
1652
  run
1306
1653
  for run in runs
1307
1654
  if run["thread_id"] == thread_id
1308
1655
  and is_jsonb_contained(run["metadata"], metadata)
1656
+ and (
1657
+ not filters
1658
+ or (await Threads._get_with_filters(conn, thread_id, filters))
1659
+ )
1309
1660
  and (status is None or run["status"] == status)
1310
1661
  ]
1311
1662
  sorted_runs = sorted(filtered_runs, key=lambda x: x["created_at"], reverse=True)
@@ -1358,6 +1709,7 @@ class Runs:
1358
1709
  ignore_404: bool = False,
1359
1710
  cancel_on_disconnect: bool = False,
1360
1711
  stream_mode: "StreamMode | asyncio.Queue | None" = None,
1712
+ ctx: Auth.types.BaseAuthContext | None = None,
1361
1713
  ) -> AsyncIterator[tuple[bytes, bytes]]:
1362
1714
  """Stream the run output."""
1363
1715
  log = logger.isEnabledFor(logging.DEBUG)
@@ -1369,6 +1721,21 @@ class Runs:
1369
1721
 
1370
1722
  try:
1371
1723
  async with connect() as conn:
1724
+ filters = await Runs.handle_event(
1725
+ ctx,
1726
+ "read",
1727
+ Auth.types.ThreadsRead(thread_id=thread_id),
1728
+ )
1729
+ if filters:
1730
+ thread = await Threads._get_with_filters(
1731
+ cast(InMemConnectionProto, conn), thread_id, filters
1732
+ )
1733
+ if not thread:
1734
+ raise WrappedHTTPException(
1735
+ HTTPException(
1736
+ status_code=404, detail="Thread not found"
1737
+ )
1738
+ )
1372
1739
  channel_prefix = f"run:{run_id}:stream:"
1373
1740
  len_prefix = len(channel_prefix.encode())
1374
1741
 
@@ -1393,7 +1760,9 @@ class Runs:
1393
1760
  )
1394
1761
  except TimeoutError:
1395
1762
  # Check if the run is still pending
1396
- run_iter = await Runs.get(conn, run_id, thread_id=thread_id)
1763
+ run_iter = await Runs.get(
1764
+ conn, run_id, thread_id=thread_id, ctx=ctx
1765
+ )
1397
1766
  run = await anext(run_iter, None)
1398
1767
 
1399
1768
  if ignore_404 and run is None:
@@ -1408,6 +1777,8 @@ class Runs:
1408
1777
  break
1409
1778
  elif run["status"] != "pending":
1410
1779
  break
1780
+ except WrappedHTTPException as e:
1781
+ raise e.http_exception from None
1411
1782
  except:
1412
1783
  if cancel_on_disconnect:
1413
1784
  create_task(cancel_run(thread_id, run_id))
@@ -1475,22 +1846,32 @@ class Crons:
1475
1846
  schedule: str,
1476
1847
  cron_id: UUID | None = None,
1477
1848
  thread_id: UUID | None = None,
1478
- user_id: str | None = None,
1479
1849
  end_time: datetime | None = None,
1850
+ ctx: Auth.types.BaseAuthContext | None = None,
1480
1851
  ) -> AsyncIterator[Cron]:
1481
1852
  raise NotImplementedError
1482
1853
 
1483
1854
  @staticmethod
1484
- async def delete(conn: InMemConnectionProto, cron_id: UUID) -> AsyncIterator[UUID]:
1855
+ async def delete(
1856
+ conn: InMemConnectionProto,
1857
+ cron_id: UUID,
1858
+ ctx: Auth.types.BaseAuthContext | None = None,
1859
+ ) -> AsyncIterator[UUID]:
1485
1860
  raise NotImplementedError
1486
1861
 
1487
1862
  @staticmethod
1488
- async def next(conn: InMemConnectionProto) -> AsyncIterator[Cron]:
1863
+ async def next(
1864
+ conn: InMemConnectionProto,
1865
+ ctx: Auth.types.BaseAuthContext | None = None,
1866
+ ) -> AsyncIterator[Cron]:
1489
1867
  raise NotImplementedError
1490
1868
 
1491
1869
  @staticmethod
1492
1870
  async def set_next_run_date(
1493
- conn: InMemConnectionProto, cron_id: UUID, next_run_date: datetime
1871
+ conn: InMemConnectionProto,
1872
+ cron_id: UUID,
1873
+ next_run_date: datetime,
1874
+ ctx: Auth.types.BaseAuthContext | None = None,
1494
1875
  ) -> None:
1495
1876
  raise NotImplementedError
1496
1877
 
@@ -1502,13 +1883,16 @@ class Crons:
1502
1883
  thread_id: UUID | None,
1503
1884
  limit: int,
1504
1885
  offset: int,
1886
+ ctx: Auth.types.BaseAuthContext | None = None,
1505
1887
  ) -> AsyncIterator[Cron]:
1506
1888
  raise NotImplementedError
1507
1889
 
1508
1890
 
1509
- async def cancel_run(thread_id: UUID, run_id: UUID) -> None:
1891
+ async def cancel_run(
1892
+ thread_id: UUID, run_id: UUID, ctx: Auth.types.BaseAuthContext | None = None
1893
+ ) -> None:
1510
1894
  async with connect() as conn:
1511
- await Runs.cancel(conn, [run_id], thread_id=thread_id)
1895
+ await Runs.cancel(conn, [run_id], thread_id=thread_id, ctx=ctx)
1512
1896
 
1513
1897
 
1514
1898
  def _delete_checkpoints_for_thread(
@@ -1538,6 +1922,47 @@ def _delete_checkpoints_for_thread(
1538
1922
  )
1539
1923
 
1540
1924
 
1925
+ def _check_filter_match(metadata: dict, filters: Auth.types.FilterType | None) -> bool:
1926
+ """Check if metadata matches the filter conditions.
1927
+
1928
+ Args:
1929
+ metadata: The metadata to check
1930
+ filters: The filter conditions to apply
1931
+
1932
+ Returns:
1933
+ True if the metadata matches all filter conditions, False otherwise
1934
+ """
1935
+ if not filters:
1936
+ return True
1937
+
1938
+ for key, value in filters.items():
1939
+ if isinstance(value, dict):
1940
+ op = next(iter(value))
1941
+ filter_value = value[op]
1942
+
1943
+ if op == "$eq":
1944
+ if key not in metadata or metadata[key] != filter_value:
1945
+ return False
1946
+ elif op == "$contains":
1947
+ if (
1948
+ key not in metadata
1949
+ or not isinstance(metadata[key], list)
1950
+ or filter_value not in metadata[key]
1951
+ ):
1952
+ return False
1953
+ else:
1954
+ # Direct equality
1955
+ if key not in metadata or metadata[key] != value:
1956
+ return False
1957
+
1958
+ return True
1959
+
1960
+
1961
+ async def _empty_generator():
1962
+ if False:
1963
+ yield
1964
+
1965
+
1541
1966
  __all__ = [
1542
1967
  "Assistants",
1543
1968
  "Crons",