veri-agents-api 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
File without changes
File without changes
@@ -0,0 +1,2 @@
1
+ from .router import *
2
+ from .schema import *
@@ -0,0 +1,334 @@
1
+ import asyncio
2
+ import json
3
+ import logging
4
+ from typing import Any, AsyncGenerator, Dict, List, Tuple, Callable
5
+ from uuid import uuid4
6
+
7
+ from fastapi import HTTPException, Request, APIRouter
8
+ from fastapi.responses import StreamingResponse
9
+ from langchain_core.callbacks import AsyncCallbackHandler
10
+ from langchain_core.runnables import RunnableConfig
11
+ from langgraph.graph.graph import CompiledGraph
12
+
13
+ from .schema import (
14
+ ChatMessage,
15
+ StreamInput,
16
+ InvokeInput,
17
+ )
18
+ from veri_agents_api.threads_util import ThreadInfo, ThreadsCheckpointerUtil
19
+ from veri_agents_api.util.awaitable import as_awaitable, MaybeAwaitable
20
+
21
+ log = logging.getLogger(__name__)
22
+
23
class TokenQueueStreamingHandler(AsyncCallbackHandler):
    """Async LangChain callback that forwards each new LLM token to an asyncio queue."""

    def __init__(self, queue: asyncio.Queue):
        # Queue shared with whoever consumes the tokens.
        self.queue = queue

    async def on_llm_new_token(self, token: str, **kwargs) -> None:
        """Enqueue ``token``; empty tokens are dropped."""
        if not token:
            return
        await self.queue.put(token)
32
+
33
def create_thread_router(
    get_graph: Callable[[Request], MaybeAwaitable[CompiledGraph]],
    get_thread_id: Callable[[Request], MaybeAwaitable[str]],
    allow_access_thread: Callable[[str, ThreadInfo | None, Request], MaybeAwaitable[bool]] = lambda thread_id, thread_info, request: True,
    allow_invoke_thread: Callable[[str, ThreadInfo | None, InvokeInput, Request], MaybeAwaitable[bool]] = lambda thread_id, thread_info, invoke_input, request: True,
    invoke_runnable_config: Callable[[str, ThreadInfo | None, InvokeInput, Request], MaybeAwaitable[RunnableConfig | None]] = lambda thread_id, thread_info, invoke_input, request: None,
    on_invoke_thread: Callable[[str, ThreadInfo | None, InvokeInput, Request], MaybeAwaitable[None]] = lambda thread_id, thread_info, invoke_input, request: None,
    **router_kwargs
):
    """
    Build an APIRouter exposing chat endpoints for one LangGraph thread.

    Routes registered on the returned router:

        POST /invoke
        POST /stream
        GET /history

    Every callback may be sync or async; results are resolved via ``as_awaitable``.

    Args:
        get_graph: Returns the compiled graph to use for a request.
        get_thread_id: Returns the thread id the request addresses.
        allow_access_thread: Gate for ``GET /history``; a falsy result yields 403.
        allow_invoke_thread: Gate for ``POST /invoke`` and ``POST /stream``;
            a falsy result yields 403.
        invoke_runnable_config: Optional base ``RunnableConfig`` for a run.
            It is copied, never mutated.
        on_invoke_thread: Hook called right before the graph is run.
        **router_kwargs: Forwarded to ``APIRouter``.

    Returns:
        The configured ``APIRouter``.
    """
    # TODO: feedback endpoints (GET/POST /feedback), thread persistence and
    # Langfuse tracing are not implemented in this router yet.

    router = APIRouter(**router_kwargs)

    # Unique end-of-stream marker for the /stream queue. A dedicated object is
    # used (instead of relying on truthiness) so that a falsy item such as an
    # empty update dict cannot end the stream early.
    stream_end = object()

    def _sse(payload: Dict[str, Any]) -> str:
        """Format ``payload`` as a single server-sent-event data frame."""
        return f"data: {json.dumps(payload)}\n\n"

    def _parse_input(user_input: InvokeInput, thread_id: str, invoke_recvd_runnable_config: RunnableConfig | None) -> Tuple[Dict[str, Any], str]:
        """Build the kwargs for ``graph.ainvoke``/``graph.astream`` plus a fresh run id."""
        run_id = uuid4()
        input_message = ChatMessage(type="human", content=user_input.message)

        # Work on a copy: the config returned by `invoke_runnable_config` may be
        # a shared object, and mutating it would leak one request's thread id
        # (and callbacks) into later requests.
        runnable_config = RunnableConfig(**(invoke_recvd_runnable_config or {}))

        # Router-managed keys come last so the run always targets the thread
        # that the access check was performed for.
        runnable_config["configurable"] = {
            **(runnable_config.get("configurable") or {}),
            # used by checkpointer
            "thread_id": thread_id,
            "_has_threadinfo": True,
        }

        kwargs = dict(
            input={"messages": [input_message.to_langchain()]},
            config=runnable_config,
        )
        return kwargs, str(run_id)

    @router.post("/invoke")
    async def invoke(invoke_input: InvokeInput, request: Request) -> ChatMessage:
        """
        Invoke the agent with user input to retrieve a final response.

        The thread id comes from ``get_thread_id``. The generated run_id is
        attached to the returned message.
        """
        graph = await as_awaitable(get_graph(request))
        thread_id = await as_awaitable(get_thread_id(request))

        thread_info = await ThreadsCheckpointerUtil.get_thread_info(thread_id, graph.checkpointer)

        if not await as_awaitable(allow_invoke_thread(thread_id, thread_info, invoke_input, request)):
            raise HTTPException(status_code=403, detail="Forbidden")

        invoke_recvd_runnable_config = await as_awaitable(invoke_runnable_config(thread_id, thread_info, invoke_input, request))

        kwargs, run_id = _parse_input(invoke_input, thread_id, invoke_recvd_runnable_config)

        await as_awaitable(on_invoke_thread(thread_id, thread_info, invoke_input, request))

        # NOTE(review): this replaces any callbacks supplied through
        # `invoke_runnable_config` (kept from the original behaviour) - confirm
        # whether caller callbacks should be preserved instead.
        kwargs["config"]["callbacks"] = []
        try:
            response = await graph.ainvoke(**kwargs)
            output = ChatMessage.from_langchain(response["messages"][-1])
        except Exception as e:
            log.exception("Error invoking graph for thread %s", thread_id)
            raise HTTPException(status_code=500, detail=str(e)) from e
        output.run_id = run_id
        return output

    @router.post("/stream")
    async def stream_agent(stream_input: StreamInput, request: Request):
        """
        Stream the agent's response to a user input, including intermediate messages and tokens.

        The response is a ``text/event-stream`` whose frames carry JSON objects
        of type ``token``, ``message`` or ``error``, followed by ``[DONE]``.
        The generated run_id is attached to every streamed message.
        """
        graph = await as_awaitable(get_graph(request))
        thread_id = await as_awaitable(get_thread_id(request))

        thread_info = await ThreadsCheckpointerUtil.get_thread_info(thread_id, graph.checkpointer)

        if not await as_awaitable(allow_invoke_thread(thread_id, thread_info, stream_input, request)):
            raise HTTPException(status_code=403, detail="Forbidden")

        invoke_recvd_runnable_config = await as_awaitable(invoke_runnable_config(thread_id, thread_info, stream_input, request))

        async def message_generator() -> AsyncGenerator[str, None]:
            """Generate the event stream; this is the workhorse of /stream."""
            kwargs, run_id = _parse_input(stream_input, thread_id, invoke_recvd_runnable_config)

            await as_awaitable(on_invoke_thread(thread_id, thread_info, stream_input, request))

            # Use an asyncio queue to process both messages and tokens in
            # chronological order, so we can easily yield them to the client.
            output_queue: asyncio.Queue = asyncio.Queue(maxsize=10)

            if stream_input.stream_tokens:
                # NOTE(review): replaces caller-supplied callbacks, as before.
                kwargs["config"]["callbacks"] = [
                    TokenQueueStreamingHandler(queue=output_queue),
                ]

            # Pass the agent's stream of updates to the queue in a separate
            # task, so we can yield them to the client from this generator.
            async def run_agent_stream() -> None:
                try:
                    async for update in graph.astream(**kwargs, stream_mode="updates"):
                        await output_queue.put(update)
                except Exception as exc:
                    # Hand the failure to the consumer; without a terminal item
                    # it would wait on the queue forever.
                    log.exception("Error streaming graph for thread %s", thread_id)
                    await output_queue.put(exc)
                else:
                    await output_queue.put(stream_end)
                # On cancellation nothing is queued: the consumer is gone and a
                # full queue would block this task indefinitely.

            stream_task = asyncio.create_task(run_agent_stream())

            try:
                # Process the queue and yield messages over the SSE stream.
                while True:
                    item = await output_queue.get()
                    if item is stream_end:
                        break
                    log.debug("Got from queue: %s: %s", type(item), item)

                    if isinstance(item, Exception):
                        yield _sse({"type": "error", "content": f"Error running agent: {item}"})
                        break

                    if isinstance(item, str):
                        # str is an LLM token
                        yield _sse({"type": "token", "content": item})
                        continue

                    # Otherwise, item should be a dict of state updates keyed
                    # by node; several nodes may report messages at once.
                    new_messages = []
                    for node_state in item.values():
                        if not isinstance(node_state, dict):
                            # Nothing to extract when the update is not a dict.
                            continue
                        node_messages = node_state.get("messages") or []
                        if not isinstance(node_messages, (list, tuple)):
                            node_messages = [node_messages]
                        new_messages.extend(node_messages)

                    for message in new_messages:
                        try:
                            chat_message = ChatMessage.from_langchain(message)
                            chat_message.run_id = str(run_id)
                        except Exception as e:
                            yield _sse({"type": "error", "content": f"Error parsing message: {e}"})
                            continue
                        # LangGraph re-sends the input message, which feels weird, so drop it
                        if (
                            chat_message.type == "human"
                            and chat_message.content == stream_input.message
                        ):
                            continue
                        yield _sse({"type": "message", "content": chat_message.dict()})

                yield "data: [DONE]\n\n"
            finally:
                # Client disconnected or the stream ended: never leave the
                # producer task running or blocked on a full queue.
                if not stream_task.done():
                    stream_task.cancel()

        return StreamingResponse(
            message_generator(),
            media_type="text/event-stream",
        )

    @router.get("/history")
    async def get_history(request: Request) -> List[ChatMessage]:
        """
        Get the history of a thread.

        Messages that cannot be converted to ``ChatMessage`` are logged and skipped.
        """
        graph = await as_awaitable(get_graph(request))
        thread_id = await as_awaitable(get_thread_id(request))

        thread_info = await ThreadsCheckpointerUtil.get_thread_info(thread_id, graph.checkpointer)

        if not await as_awaitable(allow_access_thread(thread_id, thread_info, request)):
            raise HTTPException(status_code=403, detail="Forbidden")

        config = RunnableConfig(configurable={
            # used by checkpointer
            "thread_id": thread_id,
        })
        state = await graph.aget_state(config)
        messages = state.values.get("messages", [])

        converted_messages: List[ChatMessage] = []
        for message in messages:
            try:
                converted_messages.append(ChatMessage.from_langchain(message))
            except Exception as e:
                log.error("Error parsing message: %s", e)
        return converted_messages

    return router
@@ -0,0 +1,169 @@
1
+ from datetime import datetime
2
+ from typing import Dict, Any, List, Literal, Optional, Union
3
+ from langchain_core.messages import (
4
+ BaseMessage,
5
+ HumanMessage,
6
+ AIMessage,
7
+ ToolMessage,
8
+ ToolCall,
9
+ message_to_dict,
10
+ messages_from_dict,
11
+ )
12
+ from pydantic import BaseModel, Field
13
+
14
class InvokeInput(BaseModel):
    """Request body carrying the user's message for the agent."""

    # NOTE: `args`, `user` and `thread_id` fields are currently disabled and
    # are therefore not part of this schema.
    message: str = Field(
        examples=["What is the weather in Tokyo?"],
        description="User input to the agent.",
    )
35
+
36
+
37
class StreamInput(InvokeInput):
    """Request body for the streaming endpoint; extends InvokeInput."""

    # When true, individual LLM tokens are streamed in addition to messages.
    stream_tokens: bool = Field(
        default=True,
        description="Whether to stream LLM tokens to the client.",
    )
44
+
45
+
46
class AgentResponse(BaseModel):
    """Agent reply model wrapping a serialized LangChain message."""

    message: Dict[str, Any] = Field(
        examples=[
            {
                "message": {
                    "type": "ai",
                    "data": {
                        "content": "The weather in Tokyo is 70 degrees.",
                        "type": "ai",
                    },
                }
            }
        ],
        description="Final response from the agent, as a serialized LangChain message.",
    )
63
+
64
+
65
+ class ChatMessage(BaseModel):
66
+ """Message in a chat."""
67
+
68
+ type: Literal["human", "ai", "tool"] = Field(
69
+ description="Role of the message.",
70
+ examples=["human", "ai", "tool"],
71
+ )
72
+ content: Union[str, list[Union[str, dict]]] = Field(
73
+ description="Content of the message.",
74
+ examples=["Hello, world!"],
75
+ )
76
+ tool_calls: List[ToolCall] = Field(
77
+ description="Tool calls in the message.",
78
+ default=[],
79
+ )
80
+ tool_call_id: str | None = Field(
81
+ description="Tool call that this message is responding to.",
82
+ default=None,
83
+ examples=["call_Jja7J89XsjrOLA5r!MEOW!SL"],
84
+ )
85
+ run_id: str | None = Field(
86
+ description="Run ID of the message.",
87
+ default=None,
88
+ examples=["847c6285-8fc9-4560-a83f-4e6285809254"],
89
+ )
90
+ original: Dict[str, Any] = Field(
91
+ description="Original LangChain message in serialized form.",
92
+ default={},
93
+ )
94
+
95
+ @classmethod
96
+ def from_langchain(cls, message: BaseMessage) -> "ChatMessage":
97
+ """Create a ChatMessage from a LangChain message."""
98
+ original = message_to_dict(message)
99
+ match message:
100
+ case HumanMessage():
101
+ human_message = cls(
102
+ type="human", content=message.content, original=original
103
+ )
104
+ return human_message
105
+ case AIMessage():
106
+ ai_message = cls(type="ai", content=message.content, original=original)
107
+ if message.tool_calls:
108
+ ai_message.tool_calls = message.tool_calls
109
+ return ai_message
110
+ case ToolMessage():
111
+ tool_message = cls(
112
+ type="tool",
113
+ content=message.content,
114
+ tool_call_id=message.tool_call_id,
115
+ original=original,
116
+ )
117
+ return tool_message
118
+ case _:
119
+ raise ValueError(
120
+ f"Unsupported message type: {message.__class__.__name__}"
121
+ )
122
+
123
+ def to_langchain(self) -> BaseMessage:
124
+ """Convert the ChatMessage to a LangChain message."""
125
+ if self.original:
126
+ return messages_from_dict([self.original])[0]
127
+ match self.type:
128
+ case "human":
129
+ return HumanMessage(content=self.content)
130
+ case _:
131
+ raise NotImplementedError(f"Unsupported message type: {self.type}")
132
+
133
+ def pretty_print(self) -> None:
134
+ """Pretty print the ChatMessage."""
135
+ lc_msg = self.to_langchain()
136
+ lc_msg.pretty_print()
137
+
138
+ def get_artifact(self) -> Optional[Dict[str, Any]]:
139
+ """Get the artifact from the message if there is one."""
140
+ if (
141
+ self.original.get("data")
142
+ and self.original["data"].get("artifact")
143
+ and self.original["data"]["artifact"]
144
+ ):
145
+ return self.original["data"]["artifact"]
146
+ return None
147
+
148
+
149
class Feedback(BaseModel):
    """Feedback record attached to a message of a run."""

    # NOTE: a `thread_id` field is currently disabled and not part of the schema.
    message_id: str = Field(
        examples=["847c6285-8fc9-4560-a83f-4e6285809254"],
        description="Message ID to record feedback for.",
    )
    score: float = Field(
        examples=[0.8],
        description="Feedback score.",
    )
    kwargs: Dict[str, Any] = Field(
        default={},
        examples=[{"comment": "In-line human feedback"}],
        description="Additional feedback kwargs, passed to LangSmith.",
    )
    # Set from datetime.now() when the model is created without an explicit value.
    creation: datetime = Field(default_factory=datetime.now)
@@ -0,0 +1,59 @@
1
+ from typing import Annotated
2
+
3
+ from langchain_aws import ChatBedrock
4
+ from typing_extensions import TypedDict
5
+
6
+ from fastapi import FastAPI
7
+ from langgraph.checkpoint.memory import MemorySaver
8
+ from langgraph.graph import StateGraph
9
+ from langgraph.graph.message import add_messages
10
+ import uvicorn
11
+
12
+ from veri_agents_api.fastapi.thread import create_thread_router
13
+
14
if __name__ == "__main__":
    # Demo: serve one graph under /threads/{thread_id}, one thread per path id.

    class State(TypedDict):
        messages: Annotated[list, add_messages]

    builder = StateGraph(State)

    model = ChatBedrock(model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0")  # pyright: ignore[reportCallIssue]

    def chatbot(state: State):
        # Single node: answer with the model's reply to the messages so far.
        reply = model.invoke(state["messages"])
        return {"messages": [reply]}

    # "chatbot" is both the entry point and the finish point of the graph.
    builder.add_node("chatbot", chatbot)
    builder.set_entry_point("chatbot")
    builder.set_finish_point("chatbot")

    # Compile with in-memory persistence.
    checkpointer = MemorySaver()
    graph = builder.compile(checkpointer=checkpointer)

    # veri-agents convenience router
    thread_router = create_thread_router(
        # every request is served by the same graph
        get_graph=lambda req: graph,
        # thread id is taken from the /threads/{thread_id} path parameter
        get_thread_id=lambda req: req.path_params["thread_id"],
    )

    # Root FastAPI app
    app = FastAPI()
    app.include_router(thread_router, prefix="/threads/{thread_id}")

    uvicorn.run(app, port=5000, log_level="info")
    # Endpoints served while running:
    # GET /openapi.json
    # POST /threads/{thread_id}/invoke
    # POST /threads/{thread_id}/stream
    # GET /threads/{thread_id}/history
@@ -0,0 +1,58 @@
1
+ from typing import Annotated
2
+
3
+ from langchain_aws import ChatBedrock
4
+ from typing_extensions import TypedDict
5
+
6
+ from langgraph.checkpoint.memory import MemorySaver
7
+ from langgraph.graph import StateGraph
8
+ from langgraph.graph.message import add_messages
9
+ from fastapi import FastAPI
10
+ import uvicorn
11
+
12
+ from veri_agents_api.fastapi.thread import create_thread_router
13
+
14
if __name__ == "__main__":
    # Demo: serve one graph at the application root, using one fixed thread.

    class State(TypedDict):
        messages: Annotated[list, add_messages]

    builder = StateGraph(State)

    model = ChatBedrock(model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0")  # pyright: ignore[reportCallIssue]

    def chatbot(state: State):
        # Single node: answer with the model's reply to the messages so far.
        reply = model.invoke(state["messages"])
        return {"messages": [reply]}

    # "chatbot" is both the entry point and the finish point of the graph.
    builder.add_node("chatbot", chatbot)
    builder.set_entry_point("chatbot")
    builder.set_finish_point("chatbot")

    # Compile with in-memory persistence.
    checkpointer = MemorySaver()
    graph = builder.compile(checkpointer=checkpointer)

    # veri-agents convenience router
    thread_router = create_thread_router(
        # every request is served by the same graph
        get_graph=lambda req: graph,
        # every request addresses the same thread
        get_thread_id=lambda req: "inmem",
    )

    # Root FastAPI app
    app = FastAPI()
    app.include_router(thread_router)

    uvicorn.run(app, port=5000, log_level="info")
    # Endpoints served while running:
    # GET /openapi.json
    # POST /invoke
    # POST /stream
    # GET /history
@@ -0,0 +1 @@
1
+ from .router import *
@@ -0,0 +1,75 @@
1
+ import logging
2
+ from typing import Callable, cast, Type, Awaitable
3
+
4
+ from fastapi import HTTPException, Request, APIRouter
5
+ from langgraph.graph.graph import CompiledGraph
6
+
7
+ from veri_agents_api.fastapi.thread import (
8
+ create_thread_router as create_thread_router, InvokeInput
9
+ )
10
+ from veri_agents_api.threads_util import ThreadsCheckpointerUtil, ThreadInfo
11
+ from veri_agents_api.util.awaitable import MaybeAwaitable, as_awaitable
12
+
13
+ log = logging.getLogger(__name__)
14
+
15
+
16
def create_threads_router(
    get_graph: Callable[[Request], MaybeAwaitable[CompiledGraph]],
    allow_access_thread: Callable[[str, ThreadInfo | None, Request], MaybeAwaitable[bool]] = lambda thread_id, thread_info, request: True,
    allow_invoke_thread: Callable[[str, ThreadInfo | None, InvokeInput, Request], MaybeAwaitable[bool]] = lambda thread_id, thread_info, invoke_input, request: True,
    on_invoke_thread: Callable[[str, ThreadInfo | None, InvokeInput, Request], MaybeAwaitable[None]] = lambda thread_id, thread_info, invoke_input, request: None,
    invoke_runnable_config: Callable[[str, ThreadInfo | None, InvokeInput, Request], MaybeAwaitable[dict | None]] | None = None,
    **router_kwargs
):
    """
    Build an APIRouter under ``/threads`` for listing and using threads.

    Routes registered on the returned router:

        GET /threads/
        GET /threads/{thread_id}/info
        plus the per-thread routes of ``create_thread_router`` under
        ``/threads/{thread_id}``

    Every callback may be sync or async; results are resolved via ``as_awaitable``.

    Args:
        get_graph: Returns the compiled graph to use for a request.
        allow_access_thread: Decides whether a thread may be listed or read.
        allow_invoke_thread: Forwarded to ``create_thread_router``.
        on_invoke_thread: Forwarded to ``create_thread_router``.
        invoke_runnable_config: Optional; forwarded to ``create_thread_router``
            when given, otherwise that router's default is used.
        **router_kwargs: Forwarded to ``APIRouter``.

    Returns:
        The configured ``APIRouter``.
    """
    router = APIRouter(prefix="/threads", **router_kwargs)

    # Only pass `invoke_runnable_config` along when the caller supplied one, so
    # the nested router keeps its own default otherwise.
    optional_passthrough = {}
    if invoke_runnable_config is not None:
        optional_passthrough["invoke_runnable_config"] = invoke_runnable_config

    thread_router = create_thread_router(
        # derive thread id from /threads/{thread_id} path param
        get_thread_id=lambda req: req.path_params["thread_id"],

        # arg passthrough - TODO: make more elegant
        get_graph=get_graph,
        allow_access_thread=allow_access_thread,
        allow_invoke_thread=allow_invoke_thread,
        on_invoke_thread=on_invoke_thread,
        **optional_passthrough,
    )

    @router.get("/")
    async def get_threads(request: Request):
        """Get all threads the user has access to."""
        graph = await as_awaitable(get_graph(request))

        all_thread_info = await ThreadsCheckpointerUtil.list_threads(graph.checkpointer)

        accessible_thread_info: list[ThreadInfo] = []
        for thread_info in all_thread_info:
            # The callback may be async: an un-awaited coroutine is always
            # truthy and would grant access to every thread.
            if await as_awaitable(allow_access_thread(thread_info.thread_id, thread_info, request)):
                accessible_thread_info.append(thread_info)

        return accessible_thread_info

    @router.get("/{thread_id}/info")
    async def get_thread_by_id(thread_id: str, request: Request):
        """Get a thread by its ID.

        Arguments:
            thread_id: The ID of the thread to get.

        Raises:
            HTTPException: 403 when access is denied, 404 when no thread
                information exists for ``thread_id``.
        """
        graph = await as_awaitable(get_graph(request))

        thread_info = await ThreadsCheckpointerUtil.get_thread_info(thread_id, graph.checkpointer)

        # Check access before reporting existence, so a denied caller cannot
        # probe which thread ids exist.
        if not await as_awaitable(allow_access_thread(thread_id, thread_info, request)):
            raise HTTPException(status_code=403, detail="Forbidden")

        if thread_info is None:
            raise HTTPException(status_code=404, detail="Thread not found")

        return thread_info

    router.include_router(thread_router, prefix="/{thread_id}")

    return router
@@ -0,0 +1,58 @@
1
+ from typing import Annotated
2
+
3
+ from langchain_aws import ChatBedrock
4
+ from typing_extensions import TypedDict
5
+
6
+ from fastapi import FastAPI
7
+ from langgraph.checkpoint.memory import MemorySaver
8
+ from langgraph.graph import StateGraph
9
+ from langgraph.graph.message import add_messages
10
+ import uvicorn
11
+
12
+ from veri_agents_api.fastapi.threads import create_threads_router
13
+
14
if __name__ == "__main__":
    # Demo: serve the multi-thread router (listing plus per-thread endpoints).

    class State(TypedDict):
        messages: Annotated[list, add_messages]

    builder = StateGraph(State)

    # noinspection PyArgumentList
    model = ChatBedrock(model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0")  # pyright: ignore[reportCallIssue]

    def chatbot(state: State):
        # Single node: answer with the model's reply to the messages so far.
        reply = model.invoke(state["messages"])
        return {"messages": [reply]}

    # "chatbot" is both the entry point and the finish point of the graph.
    builder.add_node("chatbot", chatbot)
    builder.set_entry_point("chatbot")
    builder.set_finish_point("chatbot")

    # Compile with in-memory persistence.
    checkpointer = MemorySaver()
    graph = builder.compile(checkpointer=checkpointer)

    # veri-agents convenience router
    threads_router = create_threads_router(
        # every request is served by the same graph
        get_graph=lambda req: graph,
    )

    # Root FastAPI app
    app = FastAPI()
    app.include_router(threads_router)

    uvicorn.run(app, port=5000, log_level="info")
    # Endpoints served while running:
    # GET /openapi.json
    # GET /threads/
    # GET /threads/{thread_id}/info
    # POST /threads/{thread_id}/invoke
    # POST /threads/{thread_id}/stream
    # GET /threads/{thread_id}/history
@@ -0,0 +1,2 @@
1
+ from .checkpointer import *
2
+ from .schema import *
@@ -0,0 +1,43 @@
1
+ from typing import cast
2
+
3
+ from .schema import ThreadInfo
4
+ from langgraph.types import Checkpointer
5
+ from langgraph.checkpoint.base import BaseCheckpointSaver
6
+
7
class ThreadsCheckpointerUtil:
    """Helpers that derive thread information from a LangGraph checkpointer."""

    @staticmethod
    async def get_thread_info(thread_id: str, checkpointer: Checkpointer) -> ThreadInfo | None:
        """Return info for ``thread_id``, or None if no checkpoint exists for it.

        Raises:
            TypeError: If ``checkpointer`` is not a ``BaseCheckpointSaver``.
        """
        if not isinstance(checkpointer, BaseCheckpointSaver):
            raise TypeError("checkpointer must be instance of BaseCheckpointSaver")

        chk_tuple = await checkpointer.aget_tuple(config={
            "configurable": {"_has_threadinfo": True, "thread_id": thread_id}})

        if chk_tuple is None:
            return None

        # Only the id is exposed for now; checkpoint metadata is not used.
        return ThreadInfo(
            thread_id=thread_id,
        )

    @staticmethod
    async def list_threads(checkpointer: Checkpointer) -> list[ThreadInfo]:
        """Return info for every thread found in ``checkpointer``.

        Each thread id is reported once, in order of first appearance.

        Raises:
            TypeError: If ``checkpointer`` is not a ``BaseCheckpointSaver``.
        """
        if not isinstance(checkpointer, BaseCheckpointSaver):
            raise TypeError("checkpointer must be instance of BaseCheckpointSaver")

        # Look at initial steps only (step == -1) so a thread is not reported
        # once per checkpoint it owns.
        init_step_checkpoints = checkpointer.alist(config=None, filter={'step': -1})

        # A dict is used as an ordered set: should several initial checkpoints
        # share a thread id, the thread is still listed only once.
        thread_ids: dict[str, None] = {}
        async for checkpoint in init_step_checkpoints:
            thread_id = checkpoint.config.get("configurable", {}).get("thread_id")
            if thread_id is not None:
                thread_ids.setdefault(cast(str, thread_id), None)

        all_thread_info: list[ThreadInfo] = []
        for thread_id in thread_ids:
            thread_info = await ThreadsCheckpointerUtil.get_thread_info(thread_id, checkpointer)
            if thread_info is not None:
                all_thread_info.append(thread_info)

        return all_thread_info
@@ -0,0 +1,13 @@
1
+ from datetime import datetime
2
+
3
+ from pydantic import BaseModel, Field
4
+
5
class ThreadInfo(BaseModel):
    """Describes one conversation thread."""

    # NOTE: `workflow_id`, `name`, `user`, `metadata` and `creation` fields are
    # currently disabled; the thread id is the only field of this model.
    thread_id: str
File without changes
@@ -0,0 +1,11 @@
1
+ import asyncio
2
+ from typing import TypeVar, Awaitable, cast
3
+
4
+ T = TypeVar('T')
5
+
6
+ async def as_awaitable(maybe_coroutine: T | Awaitable[T]) -> T:
7
+ if asyncio.iscoroutine(maybe_coroutine):
8
+ return await maybe_coroutine
9
+ return cast(T, maybe_coroutine)
10
+
11
+ type MaybeAwaitable[T] = T | Awaitable[T]
@@ -0,0 +1,17 @@
1
+ Metadata-Version: 2.4
2
+ Name: veri-agents-api
3
+ Version: 0.1.1
4
+ Summary: Add your description here
5
+ Author-email: Markus Toman <mtoman@veritone.com>, Teo Boley <tboley@veritone.com>
6
+ Requires-Python: >=3.12
7
+ Requires-Dist: veri-agents-common[langgraph]==0.1.1
8
+ Provides-Extra: dev
9
+ Requires-Dist: langchain-aws>=0.2.21; extra == 'dev'
10
+ Requires-Dist: uvicorn>=0.34.2; extra == 'dev'
11
+ Requires-Dist: veri-agents-common[fastapi,langfuse]==0.1.1; extra == 'dev'
12
+ Provides-Extra: fastapi
13
+ Requires-Dist: veri-agents-common[fastapi,langfuse]==0.1.1; extra == 'fastapi'
14
+ Provides-Extra: fastapi-dev
15
+ Requires-Dist: langchain-aws>=0.2.21; extra == 'fastapi-dev'
16
+ Requires-Dist: uvicorn>=0.34.2; extra == 'fastapi-dev'
17
+ Requires-Dist: veri-agents-common[fastapi,langfuse]==0.1.1; extra == 'fastapi-dev'
@@ -0,0 +1,18 @@
1
+ veri_agents_api/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ veri_agents_api/fastapi/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
+ veri_agents_api/fastapi/thread/__init__.py,sha256=WFpfiIPAMaAbD2OFDUlIJ-tt1-bXOz2fOAGRYvyCFts,44
4
+ veri_agents_api/fastapi/thread/router.py,sha256=DDEjb2-xGX3Osg9BBP2AwHOkt7JzO-tL58rWKnmsjME,14135
5
+ veri_agents_api/fastapi/thread/schema.py,sha256=2IQlKuzWzLPeynIHVKKKXhCw3pVZe0KjJkQo1CM8B3Y,5519
6
+ veri_agents_api/fastapi/thread/test/multi_thread.py,sha256=xxe6SpfOLbWeTFItNAX6ipHz2zkOdgtE4P_ADCgWNF4,1888
7
+ veri_agents_api/fastapi/thread/test/single_thread.py,sha256=lbSehG3oWo-bC5iJCVWylkV_TtEdemo87PFappb0Et4,1655
8
+ veri_agents_api/fastapi/threads/__init__.py,sha256=GAq8qINP3yNKm8tnrRGiWuH5IuEsPI4Rzl5T-5Ql1bY,22
9
+ veri_agents_api/fastapi/threads/router.py,sha256=wlYYFHQqJvXB2MZUWF_7xvZE8oTJnj6C9GIi-BWPmhI,2887
10
+ veri_agents_api/fastapi/threads/test/__main__.py,sha256=2Cp7dtrKRBsC_DQM7iim8lF7lOgQk6T0s7BNH56b40A,1770
11
+ veri_agents_api/threads_util/__init__.py,sha256=P06UGWyMeigeiE6dDmsS40k28JZs0NmIOSnldsB7bQ4,50
12
+ veri_agents_api/threads_util/checkpointer.py,sha256=4eiR5zfbjpPqPNzQ7_K8tiFWyWXYxwfxWw6k-fwI774,1904
13
+ veri_agents_api/threads_util/schema.py,sha256=f3l4l5MozS4QvxOewwVSRCKScQmw3Xi2IIpCfbqtsAM,323
14
+ veri_agents_api/util/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
+ veri_agents_api/util/awaitable.py,sha256=FerfQ0QKK-nUo-jQ9MfxM8g5pbo0P10Tt1mYZZa1zTs,303
16
+ veri_agents_api-0.1.1.dist-info/METADATA,sha256=McgIAMbZ9e_DtLQUbQYo8M-jkKDntJrIT15c_qFDNWU,777
17
+ veri_agents_api-0.1.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
18
+ veri_agents_api-0.1.1.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.27.0
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any