langgraph-api 0.0.8__py3-none-any.whl → 0.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langgraph-api might be problematic. (The original page linked to further details about this release.)

langgraph_storage/ops.py CHANGED
@@ -11,15 +11,17 @@ from collections.abc import AsyncIterator, Sequence
11
11
  from contextlib import asynccontextmanager
12
12
  from copy import deepcopy
13
13
  from datetime import UTC, datetime, timedelta
14
- from typing import Any, Literal
14
+ from typing import Any, Literal, cast
15
15
  from uuid import UUID, uuid4
16
16
 
17
17
  import structlog
18
18
  from langgraph.pregel.debug import CheckpointPayload
19
19
  from langgraph.pregel.types import StateSnapshot
20
+ from langgraph_sdk import Auth
20
21
  from starlette.exceptions import HTTPException
21
22
 
22
23
  from langgraph_api.asyncio import SimpleTaskGroup, ValueEvent, create_task
24
+ from langgraph_api.auth.custom import handle_event
23
25
  from langgraph_api.errors import UserInterrupt, UserRollback
24
26
  from langgraph_api.graph import get_graph
25
27
  from langgraph_api.schema import (
@@ -41,7 +43,7 @@ from langgraph_api.schema import (
41
43
  ThreadUpdateResponse,
42
44
  )
43
45
  from langgraph_api.serde import Fragment
44
- from langgraph_api.utils import fetchone
46
+ from langgraph_api.utils import fetchone, get_auth_ctx
45
47
  from langgraph_storage.checkpoint import Checkpointer
46
48
  from langgraph_storage.database import InMemConnectionProto, connect
47
49
  from langgraph_storage.queue import Message, get_stream_manager
@@ -57,13 +59,51 @@ def _ensure_uuid(id_: str | uuid.UUID | None) -> uuid.UUID:
57
59
  return id_
58
60
 
59
61
 
62
+ class WrappedHTTPException(Exception):
63
+ def __init__(self, http_exception: HTTPException):
64
+ self.http_exception = http_exception
65
+
66
+
60
67
  # Right now the whole API types as UUID but frequently passes a str
61
68
  # We ensure UUIDs for eveerything EXCEPT the checkpoint storage/writes,
62
69
  # which we leave as strings. This is because I'm too lazy to subclass fully
63
70
  # and we use non-UUID examples in the OSS version
64
71
 
65
72
 
66
- class Assistants:
73
+ class Authenticated:
74
+ resource: Literal["threads", "crons", "assistants"]
75
+
76
+ @classmethod
77
+ def _context(
78
+ cls,
79
+ ctx: Auth.types.BaseAuthContext | None,
80
+ action: Literal["create", "read", "update", "delete", "create_run"],
81
+ ) -> Auth.types.AuthContext | None:
82
+ if not ctx:
83
+ return
84
+ return Auth.types.AuthContext(
85
+ user=ctx.user,
86
+ scopes=ctx.scopes,
87
+ resource=cls.resource,
88
+ action=action,
89
+ )
90
+
91
+ @classmethod
92
+ async def handle_event(
93
+ cls,
94
+ ctx: Auth.types.BaseAuthContext | None,
95
+ action: Literal["create", "read", "update", "delete", "search"],
96
+ value: Any,
97
+ ) -> Auth.types.FilterType | None:
98
+ ctx = ctx or get_auth_ctx()
99
+ if not ctx:
100
+ return
101
+ return await handle_event(cls._context(ctx, action), value)
102
+
103
+
104
+ class Assistants(Authenticated):
105
+ resource = "assistants"
106
+
67
107
  @staticmethod
68
108
  async def search(
69
109
  conn: InMemConnectionProto,
@@ -72,7 +112,17 @@ class Assistants:
72
112
  metadata: MetadataInput,
73
113
  limit: int,
74
114
  offset: int,
115
+ ctx: Auth.types.BaseAuthContext | None = None,
75
116
  ) -> AsyncIterator[Assistant]:
117
+ metadata = metadata if metadata is not None else {}
118
+ filters = await Assistants.handle_event(
119
+ ctx,
120
+ "search",
121
+ Auth.types.AssistantsSearch(
122
+ graph_id=graph_id, metadata=metadata, limit=limit, offset=offset
123
+ ),
124
+ )
125
+
76
126
  async def filter_and_yield() -> AsyncIterator[Assistant]:
77
127
  assistants = conn.store["assistants"]
78
128
  filtered_assistants = [
@@ -82,6 +132,7 @@ class Assistants:
82
132
  and (
83
133
  not metadata or is_jsonb_contained(assistant["metadata"], metadata)
84
134
  )
135
+ and (not filters or _check_filter_match(assistant["metadata"], filters))
85
136
  ]
86
137
  filtered_assistants.sort(key=lambda x: x["created_at"], reverse=True)
87
138
  for assistant in filtered_assistants[offset : offset + limit]:
@@ -91,14 +142,23 @@ class Assistants:
91
142
 
92
143
  @staticmethod
93
144
  async def get(
94
- conn: InMemConnectionProto, assistant_id: UUID
145
+ conn: InMemConnectionProto,
146
+ assistant_id: UUID,
147
+ ctx: Auth.types.BaseAuthContext | None = None,
95
148
  ) -> AsyncIterator[Assistant]:
96
149
  """Get an assistant by ID."""
97
150
  assistant_id = _ensure_uuid(assistant_id)
151
+ filters = await Assistants.handle_event(
152
+ ctx,
153
+ "read",
154
+ Auth.types.AssistantsRead(assistant_id=assistant_id),
155
+ )
98
156
 
99
157
  async def _yield_result():
100
158
  for assistant in conn.store["assistants"]:
101
- if assistant["assistant_id"] == assistant_id:
159
+ if assistant["assistant_id"] == assistant_id and (
160
+ not filters or _check_filter_match(assistant["metadata"], filters)
161
+ ):
102
162
  yield assistant
103
163
 
104
164
  return _yield_result()
@@ -113,14 +173,33 @@ class Assistants:
113
173
  metadata: MetadataInput,
114
174
  if_exists: OnConflictBehavior,
115
175
  name: str,
176
+ ctx: Auth.types.BaseAuthContext | None = None,
116
177
  ) -> AsyncIterator[Assistant]:
117
178
  """Insert an assistant."""
118
179
  assistant_id = _ensure_uuid(assistant_id)
180
+ metadata = metadata if metadata is not None else {}
181
+ filters = await Assistants.handle_event(
182
+ ctx,
183
+ "create",
184
+ Auth.types.AssistantsCreate(
185
+ assistant_id=assistant_id,
186
+ graph_id=graph_id,
187
+ config=config,
188
+ metadata=metadata,
189
+ name=name,
190
+ ),
191
+ )
119
192
  existing_assistant = next(
120
193
  (a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
121
194
  None,
122
195
  )
123
196
  if existing_assistant:
197
+ if filters and not _check_filter_match(
198
+ existing_assistant["metadata"], filters
199
+ ):
200
+ raise HTTPException(
201
+ status_code=409, detail=f"Assistant {assistant_id} already exists"
202
+ )
124
203
  if if_exists == "raise":
125
204
  raise HTTPException(
126
205
  status_code=409, detail=f"Assistant {assistant_id} already exists"
@@ -168,6 +247,7 @@ class Assistants:
168
247
  graph_id: str | None = None,
169
248
  metadata: MetadataInput | None = None,
170
249
  name: str | None = None,
250
+ ctx: Auth.types.BaseAuthContext | None = None,
171
251
  ) -> AsyncIterator[Assistant]:
172
252
  """Update an assistant.
173
253
 
@@ -182,6 +262,18 @@ class Assistants:
182
262
  return the updated assistant model.
183
263
  """
184
264
  assistant_id = _ensure_uuid(assistant_id)
265
+ metadata = metadata if metadata is not None else {}
266
+ filters = await Assistants.handle_event(
267
+ ctx,
268
+ "update",
269
+ Auth.types.AssistantsUpdate(
270
+ assistant_id=assistant_id,
271
+ graph_id=graph_id,
272
+ config=config,
273
+ metadata=metadata,
274
+ name=name,
275
+ ),
276
+ )
185
277
  assistant = next(
186
278
  (a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
187
279
  None,
@@ -190,6 +282,10 @@ class Assistants:
190
282
  raise HTTPException(
191
283
  status_code=404, detail=f"Assistant {assistant_id} not found"
192
284
  )
285
+ elif filters and not _check_filter_match(assistant["metadata"], filters):
286
+ raise HTTPException(
287
+ status_code=404, detail=f"Assistant {assistant_id} not found"
288
+ )
193
289
 
194
290
  now = datetime.now(UTC)
195
291
  new_version = (
@@ -204,6 +300,11 @@ class Assistants:
204
300
  )
205
301
 
206
302
  # Update assistant_versions table
303
+ if metadata:
304
+ metadata = {
305
+ **assistant["metadata"],
306
+ **metadata,
307
+ }
207
308
  new_version_entry = {
208
309
  "assistant_id": assistant_id,
209
310
  "version": new_version,
@@ -233,10 +334,33 @@ class Assistants:
233
334
 
234
335
  @staticmethod
235
336
  async def delete(
236
- conn: InMemConnectionProto, assistant_id: UUID
337
+ conn: InMemConnectionProto,
338
+ assistant_id: UUID,
339
+ ctx: Auth.types.BaseAuthContext | None = None,
237
340
  ) -> AsyncIterator[UUID]:
238
341
  """Delete an assistant by ID."""
239
342
  assistant_id = _ensure_uuid(assistant_id)
343
+ filters = await Assistants.handle_event(
344
+ ctx,
345
+ "delete",
346
+ Auth.types.AssistantsDelete(
347
+ assistant_id=assistant_id,
348
+ ),
349
+ )
350
+ assistant = next(
351
+ (a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
352
+ None,
353
+ )
354
+
355
+ if not assistant:
356
+ raise HTTPException(
357
+ status_code=404, detail=f"Assistant with ID {assistant_id} not found"
358
+ )
359
+ elif filters and not _check_filter_match(assistant["metadata"], filters):
360
+ raise HTTPException(
361
+ status_code=404, detail=f"Assistant with ID {assistant_id} not found"
362
+ )
363
+
240
364
  conn.store["assistants"] = [
241
365
  a for a in conn.store["assistants"] if a["assistant_id"] != assistant_id
242
366
  ]
@@ -249,9 +373,10 @@ class Assistants:
249
373
  retained = []
250
374
  for run in conn.store["runs"]:
251
375
  if run["assistant_id"] == assistant_id:
252
- res = await Runs.delete(conn, run["run_id"], thread_id=run["thread_id"])
376
+ res = await Runs.delete(
377
+ conn, run["run_id"], thread_id=run["thread_id"], ctx=ctx
378
+ )
253
379
  await anext(res)
254
-
255
380
  else:
256
381
  retained.append(run)
257
382
 
@@ -262,10 +387,21 @@ class Assistants:
262
387
 
263
388
  @staticmethod
264
389
  async def set_latest(
265
- conn: InMemConnectionProto, assistant_id: UUID, version: int
390
+ conn: InMemConnectionProto,
391
+ assistant_id: UUID,
392
+ version: int,
393
+ ctx: Auth.types.BaseAuthContext | None = None,
266
394
  ) -> AsyncIterator[Assistant]:
267
395
  """Change the version of an assistant."""
268
396
  assistant_id = _ensure_uuid(assistant_id)
397
+ filters = await Assistants.handle_event(
398
+ ctx,
399
+ "update",
400
+ Auth.types.AssistantsUpdate(
401
+ assistant_id=assistant_id,
402
+ version=version,
403
+ ),
404
+ )
269
405
  assistant = next(
270
406
  (a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
271
407
  None,
@@ -274,6 +410,10 @@ class Assistants:
274
410
  raise HTTPException(
275
411
  status_code=404, detail=f"Assistant {assistant_id} not found"
276
412
  )
413
+ elif filters and not _check_filter_match(assistant["metadata"], filters):
414
+ raise HTTPException(
415
+ status_code=404, detail=f"Assistant {assistant_id} not found"
416
+ )
277
417
 
278
418
  version_data = next(
279
419
  (
@@ -310,14 +450,21 @@ class Assistants:
310
450
  metadata: MetadataInput,
311
451
  limit: int,
312
452
  offset: int,
453
+ ctx: Auth.types.BaseAuthContext | None = None,
313
454
  ) -> AsyncIterator[Assistant]:
314
455
  """Get all versions of an assistant."""
315
456
  assistant_id = _ensure_uuid(assistant_id)
457
+ filters = await Assistants.handle_event(
458
+ ctx,
459
+ "read",
460
+ Auth.types.AssistantsRead(assistant_id=assistant_id),
461
+ )
316
462
  versions = [
317
463
  v
318
464
  for v in conn.store["assistant_versions"]
319
465
  if v["assistant_id"] == assistant_id
320
466
  and (not metadata or is_jsonb_contained(v["metadata"], metadata))
467
+ and (not filters or _check_filter_match(v["metadata"], filters))
321
468
  ]
322
469
  versions.sort(key=lambda x: x["version"], reverse=True)
323
470
 
@@ -379,7 +526,9 @@ def _replace_thread_id(data, new_thread_id, thread_id):
379
526
  return d
380
527
 
381
528
 
382
- class Threads:
529
+ class Threads(Authenticated):
530
+ resource = "threads"
531
+
383
532
  @staticmethod
384
533
  async def search(
385
534
  conn: InMemConnectionProto,
@@ -389,29 +538,43 @@ class Threads:
389
538
  status: ThreadStatus | None,
390
539
  limit: int,
391
540
  offset: int,
541
+ ctx: Auth.types.BaseAuthContext | None = None,
392
542
  ) -> AsyncIterator[Thread]:
393
543
  threads = conn.store["threads"]
394
544
  filtered_threads: list[Thread] = []
545
+ metadata = metadata if metadata is not None else {}
546
+ values = values if values is not None else {}
547
+ filters = await Threads.handle_event(
548
+ ctx,
549
+ "search",
550
+ Auth.types.ThreadsSearch(
551
+ metadata=metadata,
552
+ values=values,
553
+ status=status,
554
+ limit=limit,
555
+ offset=offset,
556
+ ),
557
+ )
395
558
 
396
559
  # Apply filters
397
560
  for thread in threads:
398
- matches = True
561
+ if filters and not _check_filter_match(thread["metadata"], filters):
562
+ continue
399
563
 
400
564
  if metadata and not is_jsonb_contained(thread["metadata"], metadata):
401
- matches = False
565
+ continue
402
566
 
403
567
  if (
404
568
  values
405
569
  and "values" in thread
406
570
  and not is_jsonb_contained(thread["values"], values)
407
571
  ):
408
- matches = False
572
+ continue
409
573
 
410
574
  if status and thread.get("status") != status:
411
- matches = False
575
+ continue
412
576
 
413
- if matches:
414
- filtered_threads.append(thread)
577
+ filtered_threads.append(thread)
415
578
 
416
579
  # Sort by created_at in descending order
417
580
  sorted_threads = sorted(
@@ -428,8 +591,11 @@ class Threads:
428
591
  return thread_iterator()
429
592
 
430
593
  @staticmethod
431
- async def get(conn: InMemConnectionProto, thread_id: UUID) -> AsyncIterator[Thread]:
432
- """Get a thread by ID."""
594
+ async def _get_with_filters(
595
+ conn: InMemConnectionProto,
596
+ thread_id: UUID,
597
+ filters: Auth.types.FilterType | None,
598
+ ) -> Thread | None:
433
599
  thread_id = _ensure_uuid(thread_id)
434
600
  matching_thread = next(
435
601
  (
@@ -439,6 +605,37 @@ class Threads:
439
605
  ),
440
606
  None,
441
607
  )
608
+ if not matching_thread or (
609
+ filters and not _check_filter_match(matching_thread["metadata"], filters)
610
+ ):
611
+ return
612
+
613
+ return matching_thread
614
+
615
+ @staticmethod
616
+ async def _get(
617
+ conn: InMemConnectionProto,
618
+ thread_id: UUID,
619
+ ctx: Auth.types.BaseAuthContext | None = None,
620
+ ) -> Thread | None:
621
+ """Get a thread by ID."""
622
+ thread_id = _ensure_uuid(thread_id)
623
+ filters = await Threads.handle_event(
624
+ ctx,
625
+ "read",
626
+ Auth.types.ThreadsRead(thread_id=thread_id),
627
+ )
628
+ return await Threads._get_with_filters(conn, thread_id, filters)
629
+
630
+ @staticmethod
631
+ async def get(
632
+ conn: InMemConnectionProto,
633
+ thread_id: UUID,
634
+ ctx: Auth.types.BaseAuthContext | None = None,
635
+ ) -> AsyncIterator[Thread]:
636
+ """Get a thread by ID."""
637
+ matching_thread = await Threads._get(conn, thread_id, ctx)
638
+
442
639
  if not matching_thread:
443
640
  raise HTTPException(
444
641
  status_code=404, detail=f"Thread with ID {thread_id} not found"
@@ -457,6 +654,7 @@ class Threads:
457
654
  *,
458
655
  metadata: MetadataInput,
459
656
  if_exists: OnConflictBehavior,
657
+ ctx: Auth.types.BaseAuthContext | None = None,
460
658
  ) -> AsyncIterator[Thread]:
461
659
  """Insert or update a thread."""
462
660
  thread_id = _ensure_uuid(thread_id)
@@ -467,8 +665,22 @@ class Threads:
467
665
  existing_thread = next(
468
666
  (t for t in conn.store["threads"] if t["thread_id"] == thread_id), None
469
667
  )
668
+ filters = await Threads.handle_event(
669
+ ctx,
670
+ "create",
671
+ Auth.types.ThreadsCreate(
672
+ thread_id=thread_id, metadata=metadata, if_exists=if_exists
673
+ ),
674
+ )
470
675
 
471
676
  if existing_thread:
677
+ if filters and not _check_filter_match(
678
+ existing_thread["metadata"], filters
679
+ ):
680
+ # Should we use a different status code here?
681
+ raise HTTPException(
682
+ status_code=409, detail=f"Thread with ID {thread_id} already exists"
683
+ )
472
684
  if if_exists == "raise":
473
685
  raise HTTPException(
474
686
  status_code=409, detail=f"Thread with ID {thread_id} already exists"
@@ -479,7 +691,6 @@ class Threads:
479
691
  yield existing_thread
480
692
 
481
693
  return _yield_existing()
482
-
483
694
  # Create new thread
484
695
  new_thread: Thread = {
485
696
  "thread_id": thread_id,
@@ -501,7 +712,11 @@ class Threads:
501
712
 
502
713
  @staticmethod
503
714
  async def patch(
504
- conn: InMemConnectionProto, thread_id: UUID, *, metadata: MetadataValue
715
+ conn: InMemConnectionProto,
716
+ thread_id: UUID,
717
+ *,
718
+ metadata: MetadataValue,
719
+ ctx: Auth.types.BaseAuthContext | None = None,
505
720
  ) -> AsyncIterator[Thread]:
506
721
  """Update a thread."""
507
722
  thread_list = conn.store["threads"]
@@ -514,15 +729,23 @@ class Threads:
514
729
  break
515
730
 
516
731
  if thread_idx is not None:
517
- thread = copy.deepcopy(thread_list[thread_idx])
518
- thread["metadata"] = {**thread["metadata"], **metadata}
519
- thread["updated_at"] = datetime.now(UTC)
520
- thread_list[thread_idx] = thread
732
+ filters = await Threads.handle_event(
733
+ ctx,
734
+ "update",
735
+ Auth.types.ThreadsUpdate(thread_id=thread_id, metadata=metadata),
736
+ )
737
+ if not filters or _check_filter_match(
738
+ thread_list[thread_idx]["metadata"], filters
739
+ ):
740
+ thread = copy.deepcopy(thread_list[thread_idx])
741
+ thread["metadata"] = {**thread["metadata"], **metadata}
742
+ thread["updated_at"] = datetime.now(UTC)
743
+ thread_list[thread_idx] = thread
521
744
 
522
- async def thread_iterator() -> AsyncIterator[Thread]:
523
- yield thread
745
+ async def thread_iterator() -> AsyncIterator[Thread]:
746
+ yield thread
524
747
 
525
- return thread_iterator()
748
+ return thread_iterator()
526
749
 
527
750
  async def empty_iterator() -> AsyncIterator[Thread]:
528
751
  if False: # This ensures the iterator is empty
@@ -536,6 +759,7 @@ class Threads:
536
759
  thread_id: UUID,
537
760
  checkpoint: CheckpointPayload | None,
538
761
  exception: BaseException | None,
762
+ # This does not accept the auth context since it's only used internally
539
763
  ) -> None:
540
764
  """Set the status of a thread."""
541
765
  thread_id = _ensure_uuid(thread_id)
@@ -597,19 +821,34 @@ class Threads:
597
821
 
598
822
  @staticmethod
599
823
  async def delete(
600
- conn: InMemConnectionProto, thread_id: UUID
824
+ conn: InMemConnectionProto,
825
+ thread_id: UUID,
826
+ ctx: Auth.types.BaseAuthContext | None = None,
601
827
  ) -> AsyncIterator[UUID]:
602
828
  """Delete a thread by ID and cascade delete all associated runs."""
603
829
  thread_list = conn.store["threads"]
604
830
  thread_idx = None
605
831
  thread_id = _ensure_uuid(thread_id)
606
- conn.locks.pop(thread_id, None)
607
832
 
608
833
  # Find the thread to delete
609
834
  for idx, thread in enumerate(thread_list):
610
835
  if thread["thread_id"] == thread_id:
611
836
  thread_idx = idx
612
837
  break
838
+ filters = await Threads.handle_event(
839
+ ctx,
840
+ "delete",
841
+ Auth.types.ThreadsDelete(thread_id=thread_id),
842
+ )
843
+ if (filters and not _check_filter_match(thread["metadata"], filters)) or (
844
+ thread_idx is None
845
+ ):
846
+ raise HTTPException(
847
+ status_code=404, detail=f"Thread with ID {thread_id} not found"
848
+ )
849
+
850
+ # Delete the thread
851
+ conn.locks.pop(thread_id, None)
613
852
  # Cascade delete all runs associated with this thread
614
853
  conn.store["runs"] = [
615
854
  run for run in conn.store["runs"] if run["thread_id"] != thread_id
@@ -635,12 +874,20 @@ class Threads:
635
874
 
636
875
  @staticmethod
637
876
  async def copy(
638
- conn: InMemConnectionProto, thread_id: UUID
877
+ conn: InMemConnectionProto,
878
+ thread_id: UUID,
879
+ ctx: Auth.types.BaseAuthContext | None = None,
639
880
  ) -> AsyncIterator[Thread]:
640
881
  """Create a copy of an existing thread."""
641
882
  thread_id = _ensure_uuid(thread_id)
642
883
  new_thread_id = uuid4()
643
-
884
+ filters = await Threads.handle_event(
885
+ ctx,
886
+ "read",
887
+ Auth.types.ThreadsRead(
888
+ thread_id=new_thread_id,
889
+ ),
890
+ )
644
891
  async with conn.pipeline():
645
892
  # Find the original thread in our store
646
893
  original_thread = next(
@@ -648,7 +895,11 @@ class Threads:
648
895
  )
649
896
 
650
897
  if not original_thread:
651
- return
898
+ return _empty_generator()
899
+ if filters and not _check_filter_match(
900
+ original_thread["metadata"], filters
901
+ ):
902
+ return _empty_generator()
652
903
 
653
904
  # Create new thread with copied metadata
654
905
  new_thread: Thread = {
@@ -690,15 +941,22 @@ class Threads:
690
941
 
691
942
  return row_generator()
692
943
 
693
- class State:
944
+ class State(Authenticated):
945
+ # We will treat this like a runs resource for now.
946
+ resource = "threads"
947
+
694
948
  @staticmethod
695
949
  async def get(
696
- conn: InMemConnectionProto, config: Config, subgraphs: bool = False
950
+ conn: InMemConnectionProto,
951
+ config: Config,
952
+ subgraphs: bool = False,
953
+ ctx: Auth.types.BaseAuthContext | None = None,
697
954
  ) -> StateSnapshot:
698
955
  """Get state for a thread."""
699
956
  checkpointer = Checkpointer(conn)
700
957
  thread_id = _ensure_uuid(config["configurable"]["thread_id"])
701
- thread_iter = await Threads.get(conn, thread_id)
958
+ # Auth will be applied here so no need to use filters downstream
959
+ thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
702
960
  thread = await anext(thread_iter)
703
961
  checkpoint = await checkpointer.aget(config)
704
962
 
@@ -747,12 +1005,19 @@ class Threads:
747
1005
  config: Config,
748
1006
  values: Sequence[dict] | dict[str, Any] | None,
749
1007
  as_node: str | None = None,
1008
+ ctx: Auth.types.BaseAuthContext | None = None,
750
1009
  ) -> ThreadUpdateResponse:
751
1010
  """Add state to a thread."""
1011
+ thread_id = _ensure_uuid(config["configurable"]["thread_id"])
1012
+ filters = await Threads.handle_event(
1013
+ ctx,
1014
+ "update",
1015
+ Auth.types.ThreadsUpdate(thread_id=thread_id),
1016
+ )
752
1017
 
753
1018
  checkpointer = Checkpointer(conn)
754
- thread_id = _ensure_uuid(config["configurable"]["thread_id"])
755
- thread_iter = await Threads.get(conn, thread_id)
1019
+
1020
+ thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
756
1021
  thread = await fetchone(
757
1022
  thread_iter, not_found_detail=f"Thread {thread_id} not found."
758
1023
  )
@@ -760,6 +1025,8 @@ class Threads:
760
1025
 
761
1026
  if not thread:
762
1027
  raise HTTPException(status_code=404, detail="Thread not found")
1028
+ if not _check_filter_match(thread["metadata"], filters):
1029
+ raise HTTPException(status_code=403, detail="Forbidden")
763
1030
 
764
1031
  metadata = thread["metadata"]
765
1032
  thread_config = thread["config"]
@@ -781,7 +1048,7 @@ class Threads:
781
1048
  )
782
1049
 
783
1050
  # Get current state
784
- state = await Threads.State.get(conn, config, subgraphs=False)
1051
+ state = await Threads.State.get(conn, config, subgraphs=False, ctx=ctx)
785
1052
  # Update thread values
786
1053
  for thread in conn.store["threads"]:
787
1054
  if thread["thread_id"] == thread_id:
@@ -805,22 +1072,26 @@ class Threads:
805
1072
  limit: int = 10,
806
1073
  before: str | Checkpoint | None = None,
807
1074
  metadata: MetadataInput = None,
1075
+ ctx: Auth.types.BaseAuthContext | None = None,
808
1076
  ) -> list[StateSnapshot]:
809
1077
  """Get the history of a thread."""
810
1078
 
811
1079
  thread_id = _ensure_uuid(config["configurable"]["thread_id"])
812
1080
  thread = None
813
-
814
- for t in conn.store["threads"]:
815
- if t["thread_id"] == thread_id:
816
- thread = t
817
- break
818
-
819
- if not thread:
820
- return []
1081
+ filters = await Threads.handle_event(
1082
+ ctx,
1083
+ "read",
1084
+ Auth.types.ThreadsRead(thread_id=thread_id),
1085
+ )
1086
+ thread = await fetchone(
1087
+ await Threads.get(conn, config["configurable"]["thread_id"], ctx=ctx)
1088
+ )
821
1089
 
822
1090
  # Parse thread metadata and config
823
1091
  thread_metadata = thread["metadata"]
1092
+ if not _check_filter_match(thread_metadata, filters):
1093
+ return []
1094
+
824
1095
  thread_config = thread["config"]
825
1096
  # If graph_id exists, get state history
826
1097
  if graph_id := thread_metadata.get("graph_id"):
@@ -847,7 +1118,9 @@ class Threads:
847
1118
  return []
848
1119
 
849
1120
 
850
- class Runs:
1121
+ class Runs(Authenticated):
1122
+ resource = "threads"
1123
+
851
1124
  @staticmethod
852
1125
  async def stats(conn: InMemConnectionProto) -> QueueStats:
853
1126
  """Get stats about the queue."""
@@ -1001,6 +1274,7 @@ class Runs:
1001
1274
  multitask_strategy: MultitaskStrategy = "reject",
1002
1275
  if_not_exists: IfNotExists = "reject",
1003
1276
  after_seconds: int = 0,
1277
+ ctx: Auth.types.BaseAuthContext | None = None,
1004
1278
  ) -> AsyncIterator[Run]:
1005
1279
  """Create a run."""
1006
1280
  assistant_id = _ensure_uuid(assistant_id)
@@ -1009,22 +1283,38 @@ class Runs:
1009
1283
  None,
1010
1284
  )
1011
1285
 
1012
- async def empty_generator():
1013
- if False:
1014
- yield
1015
-
1016
1286
  if not assistant:
1017
- return empty_generator()
1287
+ return _empty_generator()
1018
1288
 
1019
1289
  thread_id = _ensure_uuid(thread_id) if thread_id else None
1020
1290
  run_id = _ensure_uuid(run_id) if run_id else None
1021
- metadata = metadata or {}
1291
+ metadata = metadata if metadata is not None else {}
1022
1292
  config = kwargs.get("config", {})
1023
1293
 
1024
1294
  # Handle thread creation/update
1025
1295
  existing_thread = next(
1026
1296
  (t for t in conn.store["threads"] if t["thread_id"] == thread_id), None
1027
1297
  )
1298
+ filters = await Runs.handle_event(
1299
+ ctx,
1300
+ "create",
1301
+ Auth.types.RunsCreate(
1302
+ thread_id=thread_id,
1303
+ assistant_id=assistant_id,
1304
+ run_id=run_id,
1305
+ status=status,
1306
+ metadata=metadata,
1307
+ prevent_insert_if_inflight=prevent_insert_if_inflight,
1308
+ multitask_strategy=multitask_strategy,
1309
+ if_not_exists=if_not_exists,
1310
+ after_seconds=after_seconds,
1311
+ kwargs=kwargs,
1312
+ ),
1313
+ )
1314
+ if existing_thread and filters:
1315
+ # Reject if the user doesn't own the thread
1316
+ if not _check_filter_match(existing_thread["metadata"], filters):
1317
+ return _empty_generator()
1028
1318
 
1029
1319
  if not existing_thread and (thread_id is None or if_not_exists == "create"):
1030
1320
  # Create new thread
@@ -1069,7 +1359,7 @@ class Runs:
1069
1359
  )
1070
1360
  existing_thread["updated_at"] = datetime.now(UTC)
1071
1361
  else:
1072
- return empty_generator()
1362
+ return _empty_generator()
1073
1363
 
1074
1364
  # Check for inflight runs if needed
1075
1365
  inflight_runs = [
@@ -1089,9 +1379,11 @@ class Runs:
1089
1379
  # Create new run
1090
1380
  configurable = Runs._merge_jsonb(
1091
1381
  Runs._get_configurable(assistant["config"]),
1092
- Runs._get_configurable(existing_thread["config"])
1093
- if existing_thread
1094
- else {},
1382
+ (
1383
+ Runs._get_configurable(existing_thread["config"])
1384
+ if existing_thread
1385
+ else {}
1386
+ ),
1095
1387
  Runs._get_configurable(config),
1096
1388
  {
1097
1389
  "run_id": str(run_id),
@@ -1149,11 +1441,20 @@ class Runs:
1149
1441
 
1150
1442
  @staticmethod
1151
1443
  async def get(
1152
- conn: InMemConnectionProto, run_id: UUID, *, thread_id: UUID
1444
+ conn: InMemConnectionProto,
1445
+ run_id: UUID,
1446
+ *,
1447
+ thread_id: UUID,
1448
+ ctx: Auth.types.BaseAuthContext | None = None,
1153
1449
  ) -> AsyncIterator[Run]:
1154
1450
  """Get a run by ID."""
1155
1451
 
1156
1452
  run_id, thread_id = _ensure_uuid(run_id), _ensure_uuid(thread_id)
1453
+ filters = await Runs.handle_event(
1454
+ ctx,
1455
+ "read",
1456
+ Auth.types.ThreadsRead(thread_id=thread_id),
1457
+ )
1157
1458
 
1158
1459
  async def _yield_result():
1159
1460
  matching_run = None
@@ -1162,16 +1463,36 @@ class Runs:
1162
1463
  matching_run = run
1163
1464
  break
1164
1465
  if matching_run:
1466
+ if filters:
1467
+ thread = await Threads._get_with_filters(
1468
+ conn, matching_run["thread_id"], filters
1469
+ )
1470
+ if not thread:
1471
+ return
1165
1472
  yield matching_run
1166
1473
 
1167
1474
  return _yield_result()
1168
1475
 
1169
1476
  @staticmethod
1170
1477
  async def delete(
1171
- conn: InMemConnectionProto, run_id: UUID, *, thread_id: UUID
1478
+ conn: InMemConnectionProto,
1479
+ run_id: UUID,
1480
+ *,
1481
+ thread_id: UUID,
1482
+ ctx: Auth.types.BaseAuthContext | None = None,
1172
1483
  ) -> AsyncIterator[UUID]:
1173
1484
  """Delete a run by ID."""
1174
1485
  run_id, thread_id = _ensure_uuid(run_id), _ensure_uuid(thread_id)
1486
+ filters = await Runs.handle_event(
1487
+ ctx,
1488
+ "delete",
1489
+ Auth.types.ThreadsDelete(run_id=run_id, thread_id=thread_id),
1490
+ )
1491
+
1492
+ if filters:
1493
+ thread = await Threads._get_with_filters(conn, thread_id, filters)
1494
+ if not thread:
1495
+ return _empty_generator()
1175
1496
  _delete_checkpoints_for_thread(thread_id, conn, run_id=run_id)
1176
1497
  found = False
1177
1498
  for i, run in enumerate(conn.store["runs"]):
@@ -1192,16 +1513,22 @@ class Runs:
1192
1513
  run_id: UUID,
1193
1514
  *,
1194
1515
  thread_id: UUID,
1516
+ ctx: Auth.types.BaseAuthContext | None = None,
1195
1517
  ) -> Fragment:
1196
1518
  """Wait for a run to complete. If already done, return immediately.
1197
1519
 
1198
1520
  Returns:
1199
1521
  the final state of the run.
1200
1522
  """
1523
+ async with connect() as conn:
1524
+ # Validate ownership
1525
+ thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
1526
+ await fetchone(thread_iter)
1201
1527
  last_chunk: bytes | None = None
1202
1528
  # wait for the run to complete
1529
+ # Rely on this join's auth
1203
1530
  async for mode, chunk in Runs.Stream.join(
1204
- run_id, thread_id=thread_id, stream_mode="values"
1531
+ run_id, thread_id=thread_id, stream_mode="values", ctx=ctx
1205
1532
  ):
1206
1533
  if mode == b"values":
1207
1534
  last_chunk = chunk
@@ -1212,7 +1539,7 @@ class Runs:
1212
1539
  else:
1213
1540
  # otherwise, the run had already finished, so fetch the state from thread
1214
1541
  async with connect() as conn:
1215
- thread_iter = await Threads.get(conn, thread_id)
1542
+ thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
1216
1543
  thread = await fetchone(thread_iter)
1217
1544
  return thread["values"]
1218
1545
 
@@ -1223,8 +1550,10 @@ class Runs:
1223
1550
  *,
1224
1551
  action: Literal["interrupt", "rollback"] = "interrupt",
1225
1552
  thread_id: UUID,
1553
+ ctx: Auth.types.BaseAuthContext | None = None,
1226
1554
  ) -> None:
1227
1555
  """Cancel a run."""
1556
+ # Authwise, this invokes the runs.update handler
1228
1557
  # Cancellation tries to take two actions, to cover runs in different states:
1229
1558
  # - For any run, send a cancellation message through the stream manager
1230
1559
  # - For queued runs not yet picked up by a worker, update their status if interrupt,
@@ -1233,6 +1562,15 @@ class Runs:
1233
1562
  # - For runs in any other state, we raise a 404
1234
1563
  run_ids = [_ensure_uuid(run_id) for run_id in run_ids]
1235
1564
  thread_id = _ensure_uuid(thread_id)
1565
+ filters = await Runs.handle_event(
1566
+ ctx,
1567
+ "update",
1568
+ Auth.types.ThreadsUpdate(
1569
+ thread_id=thread_id,
1570
+ action=action,
1571
+ metadata={"run_ids": run_ids},
1572
+ ),
1573
+ )
1236
1574
 
1237
1575
  stream_manager = get_stream_manager()
1238
1576
  found_runs = []
@@ -1247,6 +1585,10 @@ class Runs:
1247
1585
  None,
1248
1586
  )
1249
1587
  if run:
1588
+ if filters:
1589
+ thread = await Threads._get_with_filters(conn, thread_id, filters)
1590
+ if not thread:
1591
+ continue
1250
1592
  found_runs.append(run)
1251
1593
  # Send cancellation message through stream manager
1252
1594
  control_message = Message(
@@ -1296,16 +1638,26 @@ class Runs:
1296
1638
  offset: int = 0,
1297
1639
  metadata: MetadataInput,
1298
1640
  status: RunStatus | None = None,
1641
+ ctx: Auth.types.BaseAuthContext | None = None,
1299
1642
  ) -> AsyncIterator[Run]:
1300
1643
  """List all runs by thread."""
1301
1644
  runs = conn.store["runs"]
1302
- metadata = metadata or {}
1645
+ metadata = metadata if metadata is not None else {}
1303
1646
  thread_id = _ensure_uuid(thread_id)
1647
+ filters = await Runs.handle_event(
1648
+ ctx,
1649
+ "search",
1650
+ Auth.types.ThreadsSearch(thread_id=thread_id, metadata=metadata),
1651
+ )
1304
1652
  filtered_runs = [
1305
1653
  run
1306
1654
  for run in runs
1307
1655
  if run["thread_id"] == thread_id
1308
1656
  and is_jsonb_contained(run["metadata"], metadata)
1657
+ and (
1658
+ not filters
1659
+ or (await Threads._get_with_filters(conn, thread_id, filters))
1660
+ )
1309
1661
  and (status is None or run["status"] == status)
1310
1662
  ]
1311
1663
  sorted_runs = sorted(filtered_runs, key=lambda x: x["created_at"], reverse=True)
@@ -1358,6 +1710,7 @@ class Runs:
1358
1710
  ignore_404: bool = False,
1359
1711
  cancel_on_disconnect: bool = False,
1360
1712
  stream_mode: "StreamMode | asyncio.Queue | None" = None,
1713
+ ctx: Auth.types.BaseAuthContext | None = None,
1361
1714
  ) -> AsyncIterator[tuple[bytes, bytes]]:
1362
1715
  """Stream the run output."""
1363
1716
  log = logger.isEnabledFor(logging.DEBUG)
@@ -1369,6 +1722,21 @@ class Runs:
1369
1722
 
1370
1723
  try:
1371
1724
  async with connect() as conn:
1725
+ filters = await Runs.handle_event(
1726
+ ctx,
1727
+ "read",
1728
+ Auth.types.ThreadsRead(thread_id=thread_id),
1729
+ )
1730
+ if filters:
1731
+ thread = await Threads._get_with_filters(
1732
+ cast(InMemConnectionProto, conn), thread_id, filters
1733
+ )
1734
+ if not thread:
1735
+ raise WrappedHTTPException(
1736
+ HTTPException(
1737
+ status_code=404, detail="Thread not found"
1738
+ )
1739
+ )
1372
1740
  channel_prefix = f"run:{run_id}:stream:"
1373
1741
  len_prefix = len(channel_prefix.encode())
1374
1742
 
@@ -1393,7 +1761,9 @@ class Runs:
1393
1761
  )
1394
1762
  except TimeoutError:
1395
1763
  # Check if the run is still pending
1396
- run_iter = await Runs.get(conn, run_id, thread_id=thread_id)
1764
+ run_iter = await Runs.get(
1765
+ conn, run_id, thread_id=thread_id, ctx=ctx
1766
+ )
1397
1767
  run = await anext(run_iter, None)
1398
1768
 
1399
1769
  if ignore_404 and run is None:
@@ -1408,6 +1778,8 @@ class Runs:
1408
1778
  break
1409
1779
  elif run["status"] != "pending":
1410
1780
  break
1781
+ except WrappedHTTPException as e:
1782
+ raise e.http_exception from None
1411
1783
  except:
1412
1784
  if cancel_on_disconnect:
1413
1785
  create_task(cancel_run(thread_id, run_id))
@@ -1475,22 +1847,32 @@ class Crons:
1475
1847
  schedule: str,
1476
1848
  cron_id: UUID | None = None,
1477
1849
  thread_id: UUID | None = None,
1478
- user_id: str | None = None,
1479
1850
  end_time: datetime | None = None,
1851
+ ctx: Auth.types.BaseAuthContext | None = None,
1480
1852
  ) -> AsyncIterator[Cron]:
1481
1853
  raise NotImplementedError
1482
1854
 
1483
1855
  @staticmethod
1484
- async def delete(conn: InMemConnectionProto, cron_id: UUID) -> AsyncIterator[UUID]:
1856
+ async def delete(
1857
+ conn: InMemConnectionProto,
1858
+ cron_id: UUID,
1859
+ ctx: Auth.types.BaseAuthContext | None = None,
1860
+ ) -> AsyncIterator[UUID]:
1485
1861
  raise NotImplementedError
1486
1862
 
1487
1863
  @staticmethod
1488
- async def next(conn: InMemConnectionProto) -> AsyncIterator[Cron]:
1864
+ async def next(
1865
+ conn: InMemConnectionProto,
1866
+ ctx: Auth.types.BaseAuthContext | None = None,
1867
+ ) -> AsyncIterator[Cron]:
1489
1868
  raise NotImplementedError
1490
1869
 
1491
1870
  @staticmethod
1492
1871
  async def set_next_run_date(
1493
- conn: InMemConnectionProto, cron_id: UUID, next_run_date: datetime
1872
+ conn: InMemConnectionProto,
1873
+ cron_id: UUID,
1874
+ next_run_date: datetime,
1875
+ ctx: Auth.types.BaseAuthContext | None = None,
1494
1876
  ) -> None:
1495
1877
  raise NotImplementedError
1496
1878
 
@@ -1502,13 +1884,16 @@ class Crons:
1502
1884
  thread_id: UUID | None,
1503
1885
  limit: int,
1504
1886
  offset: int,
1887
+ ctx: Auth.types.BaseAuthContext | None = None,
1505
1888
  ) -> AsyncIterator[Cron]:
1506
1889
  raise NotImplementedError
1507
1890
 
1508
1891
 
1509
- async def cancel_run(thread_id: UUID, run_id: UUID) -> None:
1892
+ async def cancel_run(
1893
+ thread_id: UUID, run_id: UUID, ctx: Auth.types.BaseAuthContext | None = None
1894
+ ) -> None:
1510
1895
  async with connect() as conn:
1511
- await Runs.cancel(conn, [run_id], thread_id=thread_id)
1896
+ await Runs.cancel(conn, [run_id], thread_id=thread_id, ctx=ctx)
1512
1897
 
1513
1898
 
1514
1899
  def _delete_checkpoints_for_thread(
@@ -1538,6 +1923,47 @@ def _delete_checkpoints_for_thread(
1538
1923
  )
1539
1924
 
1540
1925
 
1926
+ def _check_filter_match(metadata: dict, filters: Auth.types.FilterType | None) -> bool:
1927
+ """Check if metadata matches the filter conditions.
1928
+
1929
+ Args:
1930
+ metadata: The metadata to check
1931
+ filters: The filter conditions to apply
1932
+
1933
+ Returns:
1934
+ True if the metadata matches all filter conditions, False otherwise
1935
+ """
1936
+ if not filters:
1937
+ return True
1938
+
1939
+ for key, value in filters.items():
1940
+ if isinstance(value, dict):
1941
+ op = next(iter(value))
1942
+ filter_value = value[op]
1943
+
1944
+ if op == "$eq":
1945
+ if key not in metadata or metadata[key] != filter_value:
1946
+ return False
1947
+ elif op == "$contains":
1948
+ if (
1949
+ key not in metadata
1950
+ or not isinstance(metadata[key], list)
1951
+ or filter_value not in metadata[key]
1952
+ ):
1953
+ return False
1954
+ else:
1955
+ # Direct equality
1956
+ if key not in metadata or metadata[key] != value:
1957
+ return False
1958
+
1959
+ return True
1960
+
1961
+
1962
+ async def _empty_generator():
1963
+ if False:
1964
+ yield
1965
+
1966
+
1541
1967
  __all__ = [
1542
1968
  "Assistants",
1543
1969
  "Crons",