veri-agents-api 0.1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,186 @@
1
+ .turbo
2
+
3
+ # go
4
+ vendor
5
+
6
+ # js
7
+ dist
8
+ out-tsc
9
+ node_modules
10
+
11
+ .idea
12
+ *.iml
13
+ .DS_Store
14
+
15
+
16
+ # Byte-compiled / optimized / DLL files
17
+ __pycache__/
18
+ *.py[cod]
19
+ *$py.class
20
+
21
+ # C extensions
22
+ *.so
23
+
24
+ # Distribution / packaging
25
+ .Python
26
+ build/
27
+ develop-eggs/
28
+ dist/
29
+ downloads/
30
+ eggs/
31
+ .eggs/
32
+ lib/
33
+ lib64/
34
+ parts/
35
+ sdist/
36
+ var/
37
+ wheels/
38
+ share/python-wheels/
39
+ *.egg-info/
40
+ .installed.cfg
41
+ *.egg
42
+ MANIFEST
43
+
44
+ # PyInstaller
45
+ # Usually these files are written by a python script from a template
46
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
47
+ *.manifest
48
+ *.spec
49
+
50
+ # Installer logs
51
+ pip-log.txt
52
+ pip-delete-this-directory.txt
53
+
54
+ # Unit test / coverage reports
55
+ htmlcov/
56
+ .tox/
57
+ .nox/
58
+ .coverage
59
+ .coverage.*
60
+ .cache
61
+ nosetests.xml
62
+ coverage.xml
63
+ *.cover
64
+ *.py,cover
65
+ .hypothesis/
66
+ .pytest_cache/
67
+ cover/
68
+
69
+ # Translations
70
+ *.mo
71
+ *.pot
72
+
73
+ # Django stuff:
74
+ *.log
75
+ local_settings.py
76
+ db.sqlite3
77
+ db.sqlite3-journal
78
+
79
+ # Flask stuff:
80
+ instance/
81
+ .webassets-cache
82
+
83
+ # Scrapy stuff:
84
+ .scrapy
85
+
86
+ # Sphinx documentation
87
+ docs/_build/
88
+
89
+ # PyBuilder
90
+ .pybuilder/
91
+ target/
92
+
93
+ # Jupyter Notebook
94
+ .ipynb_checkpoints
95
+
96
+ # IPython
97
+ profile_default/
98
+ ipython_config.py
99
+
100
+ # pyenv
101
+ # For a library or package, you might want to ignore these files since the code is
102
+ # intended to run in multiple environments; otherwise, check them in:
103
+ # .python-version
104
+
105
+ # pipenv
106
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
107
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
108
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
109
+ # install all needed dependencies.
110
+ #Pipfile.lock
111
+
112
+ # UV
113
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
114
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
115
+ # commonly ignored for libraries.
116
+ #uv.lock
117
+
118
+ # poetry
119
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
120
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
121
+ # commonly ignored for libraries.
122
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
123
+ #poetry.lock
124
+
125
+ # pdm
126
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
127
+ #pdm.lock
128
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
129
+ # in version control.
130
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
131
+ .pdm.toml
132
+ .pdm-python
133
+ .pdm-build/
134
+
135
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
136
+ __pypackages__/
137
+
138
+ # Celery stuff
139
+ celerybeat-schedule
140
+ celerybeat.pid
141
+
142
+ # SageMath parsed files
143
+ *.sage.py
144
+
145
+ # Environments
146
+ .env
147
+ .venv
148
+ env/
149
+ venv/
150
+ ENV/
151
+ env.bak/
152
+ venv.bak/
153
+
154
+ # Spyder project settings
155
+ .spyderproject
156
+ .spyproject
157
+
158
+ # Rope project settings
159
+ .ropeproject
160
+
161
+ # mkdocs documentation
162
+ /site
163
+
164
+ # mypy
165
+ .mypy_cache/
166
+ .dmypy.json
167
+ dmypy.json
168
+
169
+ # Pyre type checker
170
+ .pyre/
171
+
172
+ # pytype static type analyzer
173
+ .pytype/
174
+
175
+ # Cython debug symbols
176
+ cython_debug/
177
+
178
+ # PyCharm
179
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
180
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
181
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
182
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
183
+ #.idea/
184
+
185
+ # PyPI configuration file
186
+ .pypirc
@@ -0,0 +1,17 @@
1
+ Metadata-Version: 2.4
2
+ Name: veri-agents-api
3
+ Version: 0.1.1
4
+ Summary: Add your description here
5
+ Author-email: Markus Toman <mtoman@veritone.com>, Teo Boley <tboley@veritone.com>
6
+ Requires-Python: >=3.12
7
+ Requires-Dist: veri-agents-common[langgraph]==0.1.1
8
+ Provides-Extra: dev
9
+ Requires-Dist: langchain-aws>=0.2.21; extra == 'dev'
10
+ Requires-Dist: uvicorn>=0.34.2; extra == 'dev'
11
+ Requires-Dist: veri-agents-common[fastapi,langfuse]==0.1.1; extra == 'dev'
12
+ Provides-Extra: fastapi
13
+ Requires-Dist: veri-agents-common[fastapi,langfuse]==0.1.1; extra == 'fastapi'
14
+ Provides-Extra: fastapi-dev
15
+ Requires-Dist: langchain-aws>=0.2.21; extra == 'fastapi-dev'
16
+ Requires-Dist: uvicorn>=0.34.2; extra == 'fastapi-dev'
17
+ Requires-Dist: veri-agents-common[fastapi,langfuse]==0.1.1; extra == 'fastapi-dev'
@@ -0,0 +1,16 @@
1
+ {
2
+ "name": "@veritone/agents-chat-api",
3
+ "version": "0.0.0",
4
+ "description": "",
5
+ "type": "module",
6
+ "scripts": {
7
+ "typecheck": "uv run pyright",
8
+ "format": "uv run ruff format",
9
+ "lint": "echo \"Warn: no lint specified\" && exit 0",
10
+ "test": "echo \"Warn: no test specified\" && exit 0"
11
+ },
12
+ "author": "",
13
+ "dependencies": {
14
+ "@veritone/agents-common": "workspace:"
15
+ }
16
+ }
@@ -0,0 +1,30 @@
1
+ [project]
2
+ name = "veri-agents-api"
3
+ version = "0.1.1"
4
+ description = "Add your description here"
5
+ authors = [
6
+ {name = "Markus Toman", email = "mtoman@veritone.com"},
7
+ {name = "Teo Boley", email = "tboley@veritone.com"},
8
+ ]
9
+ requires-python = ">=3.12"
10
+ dependencies = [
11
+ "veri-agents-common[langgraph]==0.1.1"
12
+ ]
13
+
14
+ [project.optional-dependencies]
15
+ fastapi = [
16
+ "veri-agents-common[fastapi,langfuse]==0.1.1",
17
+ ]
18
+ fastapi-dev = [
19
+ "langchain-aws>=0.2.21",
20
+ "uvicorn>=0.34.2",
21
+ "veri-agents-api[fastapi]==0.1.1",
22
+ ]
23
+ # as optional dep so it can be referenced in workspace pyproject.toml
24
+ dev = [
25
+ "veri-agents-api[fastapi-dev]==0.1.1"
26
+ ]
27
+
28
+ [build-system]
29
+ requires = ["hatchling"]
30
+ build-backend = "hatchling.build"
File without changes
@@ -0,0 +1,2 @@
1
+ from .router import *
2
+ from .schema import *
@@ -0,0 +1,334 @@
1
+ import asyncio
2
+ import json
3
+ import logging
4
+ from typing import Any, AsyncGenerator, Dict, List, Tuple, Callable
5
+ from uuid import uuid4
6
+
7
+ from fastapi import HTTPException, Request, APIRouter
8
+ from fastapi.responses import StreamingResponse
9
+ from langchain_core.callbacks import AsyncCallbackHandler
10
+ from langchain_core.runnables import RunnableConfig
11
+ from langgraph.graph.graph import CompiledGraph
12
+
13
+ from .schema import (
14
+ ChatMessage,
15
+ StreamInput,
16
+ InvokeInput,
17
+ )
18
+ from veri_agents_api.threads_util import ThreadInfo, ThreadsCheckpointerUtil
19
+ from veri_agents_api.util.awaitable import as_awaitable, MaybeAwaitable
20
+
21
+ log = logging.getLogger(__name__)
22
+
23
class TokenQueueStreamingHandler(AsyncCallbackHandler):
    """Async LangChain callback that pushes every streamed LLM token onto a queue."""

    def __init__(self, queue: asyncio.Queue):
        # Queue shared with whoever consumes the tokens.
        self.queue = queue

    async def on_llm_new_token(self, token: str, **kwargs) -> None:
        """Enqueue ``token``; empty tokens are skipped."""
        if not token:
            return
        await self.queue.put(token)
32
+
33
def create_thread_router(
    get_graph: Callable[[Request], MaybeAwaitable[CompiledGraph]],
    get_thread_id: Callable[[Request], MaybeAwaitable[str]],
    allow_access_thread: Callable[[str, ThreadInfo | None, Request], MaybeAwaitable[bool]] = lambda thread_id, thread_info, request: True,
    allow_invoke_thread: Callable[[str, ThreadInfo | None, InvokeInput, Request], MaybeAwaitable[bool]] = lambda thread_id, thread_info, invoke_input, request: True,
    invoke_runnable_config: Callable[[str, ThreadInfo | None, InvokeInput, Request], MaybeAwaitable[RunnableConfig | None]] = lambda thread_id, thread_info, invoke_input, request: None,
    on_invoke_thread: Callable[[str, ThreadInfo | None, InvokeInput, Request], MaybeAwaitable[None]] = lambda thread_id, thread_info, invoke_input, request: None,
    # InvokeInputCls: Type[InvokeInput] = InvokeInput,
    **router_kwargs
) -> APIRouter:
    """
    Build an APIRouter exposing a single conversation thread of a graph.

    Registered endpoints:
        POST /invoke
        POST /stream
        GET /history

    Feedback endpoints (GET/POST /feedback) are not implemented yet.

    Arguments:
        get_graph: Returns the compiled graph to use for a request.
        get_thread_id: Returns the thread id a request refers to.
        allow_access_thread: Access check for GET /history; a falsy result
            yields HTTP 403.
        allow_invoke_thread: Access check for /invoke and /stream; a falsy
            result yields HTTP 403.
        invoke_runnable_config: Optionally supplies a RunnableConfig that is
            used as the base config of the graph call.
        on_invoke_thread: Hook called right before the graph is run.
        **router_kwargs: Forwarded to the ``APIRouter`` constructor.

    All callables may be sync or async; their results go through
    ``as_awaitable``.
    """

    router = APIRouter(**router_kwargs)

    def _sse(payload: Dict[str, Any]) -> str:
        """Serialize ``payload`` as one server-sent-event data frame."""
        return f"data: {json.dumps(payload)}\n\n"

    def _parse_input(user_input: InvokeInput, thread_id: str, invoke_recvd_runnable_config: RunnableConfig | None) -> Tuple[Dict[str, Any], str]:
        """Build the graph call kwargs and a fresh run id for ``user_input``."""
        run_id = uuid4()
        input_message = ChatMessage(type="human", content=user_input.message)

        # Work on a shallow copy: the caller may hand out the same config
        # object for every request, and writing "configurable"/"callbacks"
        # into it would leak one request's thread id and handlers into the next.
        runnable_config = RunnableConfig(**(invoke_recvd_runnable_config or {}))

        runnable_config["configurable"] = {
            # used by checkpointer
            "thread_id": thread_id,
            "_has_threadinfo": True,
            # values supplied by the caller take precedence over the defaults
            **(runnable_config.get("configurable") or {}),
        }

        kwargs = dict(
            input={"messages": [input_message.to_langchain()]},
            config=runnable_config,
        )
        return kwargs, str(run_id)

    async def _prepare_invocation(invoke_input: InvokeInput, request: Request):
        """Resolve graph and thread, enforce ``allow_invoke_thread``, fetch the base config."""
        graph = await as_awaitable(get_graph(request))
        thread_id = await as_awaitable(get_thread_id(request))

        thread_info = await ThreadsCheckpointerUtil.get_thread_info(thread_id, graph.checkpointer)

        if not await as_awaitable(allow_invoke_thread(thread_id, thread_info, invoke_input, request)):
            raise HTTPException(status_code=403, detail="Forbidden")

        recvd_config = await as_awaitable(invoke_runnable_config(thread_id, thread_info, invoke_input, request))
        return graph, thread_id, thread_info, recvd_config

    @router.post("/invoke")
    async def invoke(invoke_input: InvokeInput, request: Request) -> ChatMessage:
        """
        Invoke the agent with user input to retrieve a final response.

        Use thread_id to persist and continue a multi-turn conversation. run_id kwarg
        is also attached to messages for recording feedback.
        """

        graph, thread_id, thread_info, recvd_config = await _prepare_invocation(invoke_input, request)

        kwargs, run_id = _parse_input(invoke_input, thread_id, recvd_config)

        # TODO: persist ThreadInfo for new threads here.
        await as_awaitable(on_invoke_thread(thread_id, thread_info, invoke_input, request))

        # NOTE(review): this discards any callbacks supplied through
        # invoke_runnable_config; kept as-is. TODO: attach a Langfuse handler.
        kwargs["config"]["callbacks"] = []
        try:
            response = await graph.ainvoke(**kwargs)
            output = ChatMessage.from_langchain(response["messages"][-1])
            output.run_id = str(run_id)
            return output
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e

    @router.post("/stream")
    async def stream_agent(stream_input: StreamInput, request: Request):
        """
        Stream the agent's response to a user input, including intermediate messages and tokens.

        Use thread_id to persist and continue a multi-turn conversation. run_id kwarg
        is also attached to all messages for recording feedback.
        """

        graph, thread_id, thread_info, recvd_config = await _prepare_invocation(stream_input, request)

        async def message_generator() -> AsyncGenerator[str, None]:
            """
            Generate a stream of messages from the agent.

            This is the workhorse method for the /stream endpoint.
            """
            kwargs, run_id = _parse_input(stream_input, thread_id, recvd_config)

            # TODO: persist ThreadInfo for new threads here.
            await as_awaitable(on_invoke_thread(thread_id, thread_info, stream_input, request))

            # Use an asyncio queue to process both messages and tokens in
            # chronological order, so we can easily yield them to the client.
            output_queue: asyncio.Queue = asyncio.Queue(maxsize=10)

            if stream_input.stream_tokens:
                # NOTE(review): replaces any callbacks supplied through
                # invoke_runnable_config; kept as-is.
                kwargs["config"]["callbacks"] = [
                    TokenQueueStreamingHandler(queue=output_queue),
                ]

            # Pass the agent's stream of messages to the queue in a separate task, so
            # we can yield the messages to the client from this generator.
            async def run_agent_stream() -> None:
                try:
                    async for update in graph.astream(**kwargs, stream_mode="updates"):
                        await output_queue.put(update)
                except Exception as exc:
                    # Hand the failure to the consumer instead of dying silently.
                    # (Cancellation is a BaseException and propagates untouched.)
                    log.exception("Agent stream failed")
                    await output_queue.put(exc)
                # Always signal end-of-stream, otherwise the consumer would
                # wait on the queue forever after a failure.
                await output_queue.put(None)

            stream_task = asyncio.create_task(run_agent_stream())

            try:
                # Process the queue and yield messages over the SSE stream.
                # Only the None sentinel ends the loop; an empty update must not.
                while (item := await output_queue.get()) is not None:
                    log.debug("Got from queue: %s: %s", type(item), item)

                    if isinstance(item, Exception):
                        yield _sse({"type": "error", "content": f"Error running agent: {item}"})
                        continue

                    if isinstance(item, str):
                        # str is an LLM token
                        yield _sse({"type": "token", "content": item})
                        continue

                    # Otherwise, item should be a dict of state updates for each node in the graph.
                    # It could have updates for multiple nodes, so check each for messages.
                    new_messages = []
                    for state in item.values():
                        # A node update may carry no messages at all.
                        if isinstance(state, dict):
                            new_messages.extend(state.get("messages", []))

                    for message in new_messages:
                        try:
                            chat_message = ChatMessage.from_langchain(message)
                            chat_message.run_id = str(run_id)
                        except Exception as e:
                            yield _sse({"type": "error", "content": f"Error parsing message: {e}"})
                            continue
                        # LangGraph re-sends the input message, which feels weird, so drop it
                        if (
                            chat_message.type == "human"
                            and chat_message.content == stream_input.message
                        ):
                            continue
                        yield _sse({"type": "message", "content": chat_message.dict()})

                await stream_task
                yield "data: [DONE]\n\n"
            finally:
                # The client may disconnect mid-stream; do not leave the
                # producer blocked on a queue nobody reads anymore.
                if not stream_task.done():
                    stream_task.cancel()

        return StreamingResponse(
            message_generator(),
            media_type="text/event-stream",
        )

    @router.get("/history")
    async def get_history(request: Request) -> List[ChatMessage]:
        """
        Get the history of a thread.
        """

        graph = await as_awaitable(get_graph(request))
        thread_id = await as_awaitable(get_thread_id(request))

        thread_info = await ThreadsCheckpointerUtil.get_thread_info(thread_id, graph.checkpointer)

        if not await as_awaitable(allow_access_thread(thread_id, thread_info, request)):
            raise HTTPException(status_code=403, detail="Forbidden")

        config = RunnableConfig(configurable={
            # used by checkpointer
            "thread_id": thread_id,
        })
        state = await graph.aget_state(config)
        messages = state.values.get("messages", [])

        converted_messages: List[ChatMessage] = []
        for message in messages:
            try:
                converted_messages.append(ChatMessage.from_langchain(message))
            except Exception as e:
                # Messages that cannot be converted are logged and left out.
                log.error("Error parsing message: %s", e)
        return converted_messages

    # TODO: GET /feedback and POST /feedback (store feedback via the
    # checkpointer and forward the score to Langfuse) are not implemented yet.

    return router
@@ -0,0 +1,169 @@
1
+ from datetime import datetime
2
+ from typing import Dict, Any, List, Literal, Optional, Union
3
+ from langchain_core.messages import (
4
+ BaseMessage,
5
+ HumanMessage,
6
+ AIMessage,
7
+ ToolMessage,
8
+ ToolCall,
9
+ message_to_dict,
10
+ messages_from_dict,
11
+ )
12
+ from pydantic import BaseModel, Field
13
+
14
class InvokeInput(BaseModel):
    """Basic user input for the agent."""

    # The only active field; the ones below are currently disabled.
    message: str = Field(
        examples=["What is the weather in Tokyo?"],
        description="User input to the agent.",
    )
    # args: Dict[str, Any] = Field(
    #     description="Arguments to pass to the workflow.",
    #     default={},
    #     examples=[{"kb": "veritone_support"}],
    # )
    # user: Optional[str] = Field(
    #     description="A user identifier to validate the user in knowledge bases and other tools.",
    #     default=None,
    #     examples=["jjohnson", "ccarlson"],
    # )
    # thread_id: str = Field(
    #     description="Thread ID to persist and continue a multi-turn conversation.",
    #     examples=["847c6285-8fc9-4560-a83f-4e6285809254"],
    # )
35
+
36
+
37
class StreamInput(InvokeInput):
    """User input for streaming the agent's response."""

    # Defaults to token streaming being enabled.
    stream_tokens: bool = Field(
        default=True,
        description="Whether to stream LLM tokens to the client.",
    )
44
+
45
+
46
class AgentResponse(BaseModel):
    """Response from the agent when called via /invoke."""

    # Serialized LangChain message, kept as a plain dictionary.
    message: Dict[str, Any] = Field(
        examples=[
            {
                "message": {
                    "type": "ai",
                    "data": {
                        "content": "The weather in Tokyo is 70 degrees.",
                        "type": "ai",
                    },
                }
            }
        ],
        description="Final response from the agent, as a serialized LangChain message.",
    )
63
+
64
+
65
+ class ChatMessage(BaseModel):
66
+ """Message in a chat."""
67
+
68
+ type: Literal["human", "ai", "tool"] = Field(
69
+ description="Role of the message.",
70
+ examples=["human", "ai", "tool"],
71
+ )
72
+ content: Union[str, list[Union[str, dict]]] = Field(
73
+ description="Content of the message.",
74
+ examples=["Hello, world!"],
75
+ )
76
+ tool_calls: List[ToolCall] = Field(
77
+ description="Tool calls in the message.",
78
+ default=[],
79
+ )
80
+ tool_call_id: str | None = Field(
81
+ description="Tool call that this message is responding to.",
82
+ default=None,
83
+ examples=["call_Jja7J89XsjrOLA5r!MEOW!SL"],
84
+ )
85
+ run_id: str | None = Field(
86
+ description="Run ID of the message.",
87
+ default=None,
88
+ examples=["847c6285-8fc9-4560-a83f-4e6285809254"],
89
+ )
90
+ original: Dict[str, Any] = Field(
91
+ description="Original LangChain message in serialized form.",
92
+ default={},
93
+ )
94
+
95
+ @classmethod
96
+ def from_langchain(cls, message: BaseMessage) -> "ChatMessage":
97
+ """Create a ChatMessage from a LangChain message."""
98
+ original = message_to_dict(message)
99
+ match message:
100
+ case HumanMessage():
101
+ human_message = cls(
102
+ type="human", content=message.content, original=original
103
+ )
104
+ return human_message
105
+ case AIMessage():
106
+ ai_message = cls(type="ai", content=message.content, original=original)
107
+ if message.tool_calls:
108
+ ai_message.tool_calls = message.tool_calls
109
+ return ai_message
110
+ case ToolMessage():
111
+ tool_message = cls(
112
+ type="tool",
113
+ content=message.content,
114
+ tool_call_id=message.tool_call_id,
115
+ original=original,
116
+ )
117
+ return tool_message
118
+ case _:
119
+ raise ValueError(
120
+ f"Unsupported message type: {message.__class__.__name__}"
121
+ )
122
+
123
+ def to_langchain(self) -> BaseMessage:
124
+ """Convert the ChatMessage to a LangChain message."""
125
+ if self.original:
126
+ return messages_from_dict([self.original])[0]
127
+ match self.type:
128
+ case "human":
129
+ return HumanMessage(content=self.content)
130
+ case _:
131
+ raise NotImplementedError(f"Unsupported message type: {self.type}")
132
+
133
+ def pretty_print(self) -> None:
134
+ """Pretty print the ChatMessage."""
135
+ lc_msg = self.to_langchain()
136
+ lc_msg.pretty_print()
137
+
138
+ def get_artifact(self) -> Optional[Dict[str, Any]]:
139
+ """Get the artifact from the message if there is one."""
140
+ if (
141
+ self.original.get("data")
142
+ and self.original["data"].get("artifact")
143
+ and self.original["data"]["artifact"]
144
+ ):
145
+ return self.original["data"]["artifact"]
146
+ return None
147
+
148
+
149
class Feedback(BaseModel):
    """Feedback for a run."""

    message_id: str = Field(
        examples=["847c6285-8fc9-4560-a83f-4e6285809254"],
        description="Message ID to record feedback for.",
    )
    # thread_id: str = Field(
    #     description="Thread ID to record feedback for.",
    #     examples=["847c6285-8fc9-4560-a83f-4e6285809254"],
    # )
    score: float = Field(
        examples=[0.8],
        description="Feedback score.",
    )
    kwargs: Dict[str, Any] = Field(
        default={},
        examples=[{"comment": "In-line human feedback"}],
        description="Additional feedback kwargs, passed to LangSmith.",
    )
    # Filled in with datetime.now() when the model is instantiated.
    creation: datetime = Field(default_factory=datetime.now)
@@ -0,0 +1,59 @@
1
+ from typing import Annotated
2
+
3
+ from langchain_aws import ChatBedrock
4
+ from typing_extensions import TypedDict
5
+
6
+ from fastapi import FastAPI
7
+ from langgraph.checkpoint.memory import MemorySaver
8
+ from langgraph.graph import StateGraph
9
+ from langgraph.graph.message import add_messages
10
+ import uvicorn
11
+
12
+ from veri_agents_api.fastapi.thread import create_thread_router
13
+
14
if __name__ == "__main__":
    class State(TypedDict):
        # Graph state: the message list, annotated with add_messages.
        messages: Annotated[list, add_messages]

    builder = StateGraph(State)

    llm = ChatBedrock(model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0")  # pyright: ignore[reportCallIssue]

    def chatbot(state: State):
        reply = llm.invoke(state["messages"])
        return {"messages": [reply]}

    # Single-node graph: "chatbot" is both the entry and the finish point.
    builder.add_node("chatbot", chatbot)
    builder.set_entry_point("chatbot")
    builder.set_finish_point("chatbot")

    # in-memory persistence
    graph = builder.compile(checkpointer=MemorySaver())

    # veri-agents convenience router
    thread_router = create_thread_router(
        # same graph for every request
        get_graph=lambda request: graph,
        # derive thread id from the /threads/{thread_id} path param
        get_thread_id=lambda request: request.path_params["thread_id"],
    )

    # root fastapi app
    app = FastAPI()
    app.include_router(thread_router, prefix="/threads/{thread_id}")

    uvicorn.run(app, port=5000, log_level="info")
    # while the server is running you can access:
    #   GET /openapi.json
    #   POST /threads/{thread_id}/invoke
    #   POST /threads/{thread_id}/stream
    #   GET /threads/{thread_id}/history
@@ -0,0 +1,58 @@
1
+ from typing import Annotated
2
+
3
+ from langchain_aws import ChatBedrock
4
+ from typing_extensions import TypedDict
5
+
6
+ from langgraph.checkpoint.memory import MemorySaver
7
+ from langgraph.graph import StateGraph
8
+ from langgraph.graph.message import add_messages
9
+ from fastapi import FastAPI
10
+ import uvicorn
11
+
12
+ from veri_agents_api.fastapi.thread import create_thread_router
13
+
14
if __name__ == "__main__":
    class State(TypedDict):
        # Graph state: the message list, annotated with add_messages.
        messages: Annotated[list, add_messages]

    builder = StateGraph(State)

    llm = ChatBedrock(model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0")  # pyright: ignore[reportCallIssue]

    def chatbot(state: State):
        reply = llm.invoke(state["messages"])
        return {"messages": [reply]}

    # Single-node graph: "chatbot" is both the entry and the finish point.
    builder.add_node("chatbot", chatbot)
    builder.set_entry_point("chatbot")
    builder.set_finish_point("chatbot")

    # in-memory persistence
    graph = builder.compile(checkpointer=MemorySaver())

    # veri-agents convenience router
    thread_router = create_thread_router(
        # same graph for every request
        get_graph=lambda request: graph,
        # same thread for every request
        get_thread_id=lambda request: "inmem",
    )

    # root fastapi app
    app = FastAPI()
    app.include_router(thread_router)

    uvicorn.run(app, port=5000, log_level="info")
    # while the server is running you can access:
    #   GET /openapi.json
    #   POST /invoke
    #   POST /stream
    #   GET /history
@@ -0,0 +1 @@
1
+ from .router import *
@@ -0,0 +1,75 @@
1
+ import logging
2
+ from typing import Callable, cast, Type, Awaitable
3
+
4
+ from fastapi import HTTPException, Request, APIRouter
5
+ from langgraph.graph.graph import CompiledGraph
6
+
7
+ from veri_agents_api.fastapi.thread import (
8
+ create_thread_router as create_thread_router, InvokeInput
9
+ )
10
+ from veri_agents_api.threads_util import ThreadsCheckpointerUtil, ThreadInfo
11
+ from veri_agents_api.util.awaitable import MaybeAwaitable, as_awaitable
12
+
13
+ log = logging.getLogger(__name__)
14
+
15
+
16
def create_threads_router(
    get_graph: Callable[[Request], MaybeAwaitable[CompiledGraph]],
    allow_access_thread: Callable[[str, ThreadInfo | None, Request], MaybeAwaitable[bool]] = lambda thread_id, thread_info, request: True,
    allow_invoke_thread: Callable[[str, ThreadInfo | None, InvokeInput, Request], MaybeAwaitable[bool]] = lambda thread_id, thread_info, invoke_input, request: True,
    on_invoke_thread: Callable[[str, ThreadInfo | None, InvokeInput, Request], MaybeAwaitable[None]] = lambda thread_id, thread_info, invoke_input, request: None,
    # InvokeInputCls: Type[InvokeInput] = InvokeInput,
    **router_kwargs
) -> APIRouter:
    """
    Build an APIRouter under the ``/threads`` prefix.

    Registers GET /threads/ and GET /threads/{thread_id}/info, and mounts the
    router from ``create_thread_router`` under /threads/{thread_id}.

    Arguments:
        get_graph: Returns the compiled graph to use for a request.
        allow_access_thread: Access check used for listing and reading threads.
        allow_invoke_thread: Forwarded to ``create_thread_router``.
        on_invoke_thread: Forwarded to ``create_thread_router``.
        **router_kwargs: Forwarded to the ``APIRouter`` constructor.

    All callables may be sync or async; their results go through
    ``as_awaitable``.
    """
    router = APIRouter(prefix="/threads", **router_kwargs)

    thread_router = create_thread_router(
        # derive thread id from the /threads/{thread_id} path param
        get_thread_id=lambda req: req.path_params["thread_id"],

        # arg passthrough - TODO: make more elegant
        get_graph=get_graph,
        allow_access_thread=allow_access_thread,
        allow_invoke_thread=allow_invoke_thread,
        on_invoke_thread=on_invoke_thread,

        # InvokeInputCls=InvokeInputCls
    )

    @router.get("/")
    async def get_threads(request: Request):
        """Get all threads the user has access to."""

        graph = await as_awaitable(get_graph(request))

        all_thread_info = await ThreadsCheckpointerUtil.list_threads(graph.checkpointer)

        accessible_thread_info: list[ThreadInfo] = []
        for thread_info in all_thread_info:
            # The check may be async: an un-awaited coroutine is always
            # truthy and would grant access to every thread.
            if await as_awaitable(allow_access_thread(thread_info.thread_id, thread_info, request)):
                accessible_thread_info.append(thread_info)

        return accessible_thread_info

    @router.get("/{thread_id}/info")
    async def get_thread_by_id(thread_id: str, request: Request):
        """Get a thread by its ID.

        Arguments:
            thread_id: The ID of the thread to get.
        """
        graph = await as_awaitable(get_graph(request))

        thread_info = await ThreadsCheckpointerUtil.get_thread_info(thread_id, graph.checkpointer)

        if not await as_awaitable(allow_access_thread(thread_id, thread_info, request)):
            raise HTTPException(status_code=403, detail="Forbidden")

        # No info for this id: answer 404 instead of a null body.
        if thread_info is None:
            raise HTTPException(status_code=404, detail="Thread not found")

        return thread_info

    router.include_router(thread_router, prefix="/{thread_id}")

    return router
@@ -0,0 +1,58 @@
1
+ from typing import Annotated
2
+
3
+ from langchain_aws import ChatBedrock
4
+ from typing_extensions import TypedDict
5
+
6
+ from fastapi import FastAPI
7
+ from langgraph.checkpoint.memory import MemorySaver
8
+ from langgraph.graph import StateGraph
9
+ from langgraph.graph.message import add_messages
10
+ import uvicorn
11
+
12
+ from veri_agents_api.fastapi.threads import create_threads_router
13
+
14
if __name__ == "__main__":
    # Example: serve a single-node chat graph through the veri-agents
    # threads router.

    class State(TypedDict):
        # Message list, annotated with LangGraph's add_messages.
        messages: Annotated[list, add_messages]

    # noinspection PyArgumentList
    llm = ChatBedrock(model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0")  # pyright: ignore[reportCallIssue]

    def chatbot(state: State):
        reply = llm.invoke(state["messages"])
        return {"messages": [reply]}

    graph_builder = StateGraph(State)

    # add_node takes the unique node name first, then the callable that is
    # run whenever the node is used. The one node is both entry and finish.
    graph_builder.add_node("chatbot", chatbot)
    graph_builder.set_entry_point("chatbot")
    graph_builder.set_finish_point("chatbot")

    # Compile with in-memory persistence.
    memory = MemorySaver()
    graph = graph_builder.compile(checkpointer=memory)

    # veri-agents convenience router; every request is served by the same graph.
    threads_router = create_threads_router(
        get_graph=lambda _request: graph
    )

    # Root FastAPI app.
    app = FastAPI()
    app.include_router(threads_router)

    # Once the server is running you can access:
    #   GET  /openapi.json
    #   GET  /threads
    #   GET  /threads/{thread_id}/info
    #   POST /threads/{thread_id}/invoke
    #   POST /threads/{thread_id}/stream
    #   GET  /threads/{thread_id}/history
    #   GET  /threads/{thread_id}/feedback
    #   POST /threads/{thread_id}/feedback
    uvicorn.run(app, port=5000, log_level="info")
@@ -0,0 +1,2 @@
1
+ from .checkpointer import *
2
+ from .schema import *
@@ -0,0 +1,43 @@
1
+ from typing import cast
2
+
3
+ from .schema import ThreadInfo
4
+ from langgraph.types import Checkpointer
5
+ from langgraph.checkpoint.base import BaseCheckpointSaver
6
+
7
class ThreadsCheckpointerUtil:
    """Static helpers that derive ``ThreadInfo`` objects from a checkpointer."""

    @staticmethod
    def _require_saver(checkpointer: Checkpointer) -> BaseCheckpointSaver:
        """Return ``checkpointer`` if it is a ``BaseCheckpointSaver``, else raise.

        Raises:
            TypeError: if ``checkpointer`` is not a ``BaseCheckpointSaver``
                (for example ``None`` or a boolean).
        """
        if not isinstance(checkpointer, BaseCheckpointSaver):
            raise TypeError("checkpointer must be instance of BaseCheckpointSaver")
        return checkpointer

    @staticmethod
    async def get_thread_info(thread_id: str, checkpointer: Checkpointer) -> ThreadInfo | None:
        """Look up a single thread in the checkpointer.

        Arguments:
            thread_id: The ID of the thread to look up.
            checkpointer: The checkpointer to query.

        Returns:
            A ``ThreadInfo`` for the thread, or ``None`` when the checkpointer
            returns no checkpoint tuple for ``thread_id``.

        Raises:
            TypeError: if ``checkpointer`` is not a ``BaseCheckpointSaver``.
        """
        saver = ThreadsCheckpointerUtil._require_saver(checkpointer)

        chk_tuple = await saver.aget_tuple(
            config={"configurable": {"_has_threadinfo": True, "thread_id": thread_id}}
        )
        if chk_tuple is None:
            return None

        return ThreadInfo(
            thread_id=thread_id,
        )

    @staticmethod
    async def list_threads(checkpointer: Checkpointer) -> list[ThreadInfo]:
        """List every thread the checkpointer knows about.

        Arguments:
            checkpointer: The checkpointer to query.

        Returns:
            One ``ThreadInfo`` per distinct thread ID, in the order the IDs
            were first encountered.

        Raises:
            TypeError: if ``checkpointer`` is not a ``BaseCheckpointSaver``.
        """
        saver = ThreadsCheckpointerUtil._require_saver(checkpointer)

        # Get initial steps only - this is meant to yield a single checkpoint
        # (and therefore a single thread_id) per thread.
        init_step_checkpoints = saver.alist(config=None, filter={"step": -1})

        # A dict is used as an ordered set: should the same thread_id still be
        # reported more than once, it is looked up and returned only once.
        thread_ids: dict[str, None] = {}
        async for checkpoint in init_step_checkpoints:
            thread_id = cast(
                str | None,
                checkpoint.config.get("configurable", {}).get("thread_id"),
            )
            if thread_id is not None:
                thread_ids.setdefault(thread_id, None)

        all_thread_info: list[ThreadInfo] = []
        for thread_id in thread_ids:
            thread_info = await ThreadsCheckpointerUtil.get_thread_info(thread_id, saver)
            if thread_info is not None:
                all_thread_info.append(thread_info)

        return all_thread_info
@@ -0,0 +1,13 @@
1
+ from datetime import datetime
2
+
3
+ from pydantic import BaseModel, Field
4
+
5
class ThreadInfo(BaseModel):
    """Describes one thread; currently only its identifier is carried."""

    # Identifier of the thread.
    thread_id: str

    # Further fields are sketched out below but are not enabled:
    # workflow_id: str
    # name: str
    # user: str
    # metadata: dict = Field(default={})
    # creation: datetime = Field(default_factory=datetime.now)
@@ -0,0 +1,11 @@
1
+ import asyncio
2
+ from typing import TypeVar, Awaitable, cast
3
+
4
+ T = TypeVar('T')
5
+
6
+ async def as_awaitable(maybe_coroutine: T | Awaitable[T]) -> T:
7
+ if asyncio.iscoroutine(maybe_coroutine):
8
+ return await maybe_coroutine
9
+ return cast(T, maybe_coroutine)
10
+
11
+ type MaybeAwaitable[T] = T | Awaitable[T]  # either a plain T or an awaitable that yields a T