veri-agents-api 0.1.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- veri_agents_api-0.1.1/.gitignore +186 -0
- veri_agents_api-0.1.1/PKG-INFO +17 -0
- veri_agents_api-0.1.1/package.json +16 -0
- veri_agents_api-0.1.1/pyproject.toml +30 -0
- veri_agents_api-0.1.1/src/veri_agents_api/__init__.py +0 -0
- veri_agents_api-0.1.1/src/veri_agents_api/fastapi/__init__.py +0 -0
- veri_agents_api-0.1.1/src/veri_agents_api/fastapi/thread/__init__.py +2 -0
- veri_agents_api-0.1.1/src/veri_agents_api/fastapi/thread/router.py +334 -0
- veri_agents_api-0.1.1/src/veri_agents_api/fastapi/thread/schema.py +169 -0
- veri_agents_api-0.1.1/src/veri_agents_api/fastapi/thread/test/multi_thread.py +59 -0
- veri_agents_api-0.1.1/src/veri_agents_api/fastapi/thread/test/single_thread.py +58 -0
- veri_agents_api-0.1.1/src/veri_agents_api/fastapi/threads/__init__.py +1 -0
- veri_agents_api-0.1.1/src/veri_agents_api/fastapi/threads/router.py +75 -0
- veri_agents_api-0.1.1/src/veri_agents_api/fastapi/threads/test/__main__.py +58 -0
- veri_agents_api-0.1.1/src/veri_agents_api/threads_util/__init__.py +2 -0
- veri_agents_api-0.1.1/src/veri_agents_api/threads_util/checkpointer.py +43 -0
- veri_agents_api-0.1.1/src/veri_agents_api/threads_util/schema.py +13 -0
- veri_agents_api-0.1.1/src/veri_agents_api/util/__init__.py +0 -0
- veri_agents_api-0.1.1/src/veri_agents_api/util/awaitable.py +11 -0
|
@@ -0,0 +1,186 @@
|
|
|
1
|
+
.turbo
|
|
2
|
+
|
|
3
|
+
# go
|
|
4
|
+
vendor
|
|
5
|
+
|
|
6
|
+
# js
|
|
7
|
+
dist
|
|
8
|
+
out-tsc
|
|
9
|
+
node_modules
|
|
10
|
+
|
|
11
|
+
.idea
|
|
12
|
+
*.iml
|
|
13
|
+
.DS_Store
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# Byte-compiled / optimized / DLL files
|
|
17
|
+
__pycache__/
|
|
18
|
+
*.py[cod]
|
|
19
|
+
*$py.class
|
|
20
|
+
|
|
21
|
+
# C extensions
|
|
22
|
+
*.so
|
|
23
|
+
|
|
24
|
+
# Distribution / packaging
|
|
25
|
+
.Python
|
|
26
|
+
build/
|
|
27
|
+
develop-eggs/
|
|
28
|
+
dist/
|
|
29
|
+
downloads/
|
|
30
|
+
eggs/
|
|
31
|
+
.eggs/
|
|
32
|
+
lib/
|
|
33
|
+
lib64/
|
|
34
|
+
parts/
|
|
35
|
+
sdist/
|
|
36
|
+
var/
|
|
37
|
+
wheels/
|
|
38
|
+
share/python-wheels/
|
|
39
|
+
*.egg-info/
|
|
40
|
+
.installed.cfg
|
|
41
|
+
*.egg
|
|
42
|
+
MANIFEST
|
|
43
|
+
|
|
44
|
+
# PyInstaller
|
|
45
|
+
# Usually these files are written by a python script from a template
|
|
46
|
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
|
47
|
+
*.manifest
|
|
48
|
+
*.spec
|
|
49
|
+
|
|
50
|
+
# Installer logs
|
|
51
|
+
pip-log.txt
|
|
52
|
+
pip-delete-this-directory.txt
|
|
53
|
+
|
|
54
|
+
# Unit test / coverage reports
|
|
55
|
+
htmlcov/
|
|
56
|
+
.tox/
|
|
57
|
+
.nox/
|
|
58
|
+
.coverage
|
|
59
|
+
.coverage.*
|
|
60
|
+
.cache
|
|
61
|
+
nosetests.xml
|
|
62
|
+
coverage.xml
|
|
63
|
+
*.cover
|
|
64
|
+
*.py,cover
|
|
65
|
+
.hypothesis/
|
|
66
|
+
.pytest_cache/
|
|
67
|
+
cover/
|
|
68
|
+
|
|
69
|
+
# Translations
|
|
70
|
+
*.mo
|
|
71
|
+
*.pot
|
|
72
|
+
|
|
73
|
+
# Django stuff:
|
|
74
|
+
*.log
|
|
75
|
+
local_settings.py
|
|
76
|
+
db.sqlite3
|
|
77
|
+
db.sqlite3-journal
|
|
78
|
+
|
|
79
|
+
# Flask stuff:
|
|
80
|
+
instance/
|
|
81
|
+
.webassets-cache
|
|
82
|
+
|
|
83
|
+
# Scrapy stuff:
|
|
84
|
+
.scrapy
|
|
85
|
+
|
|
86
|
+
# Sphinx documentation
|
|
87
|
+
docs/_build/
|
|
88
|
+
|
|
89
|
+
# PyBuilder
|
|
90
|
+
.pybuilder/
|
|
91
|
+
target/
|
|
92
|
+
|
|
93
|
+
# Jupyter Notebook
|
|
94
|
+
.ipynb_checkpoints
|
|
95
|
+
|
|
96
|
+
# IPython
|
|
97
|
+
profile_default/
|
|
98
|
+
ipython_config.py
|
|
99
|
+
|
|
100
|
+
# pyenv
|
|
101
|
+
# For a library or package, you might want to ignore these files since the code is
|
|
102
|
+
# intended to run in multiple environments; otherwise, check them in:
|
|
103
|
+
# .python-version
|
|
104
|
+
|
|
105
|
+
# pipenv
|
|
106
|
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
|
107
|
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
|
108
|
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
|
109
|
+
# install all needed dependencies.
|
|
110
|
+
#Pipfile.lock
|
|
111
|
+
|
|
112
|
+
# UV
|
|
113
|
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
|
114
|
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
|
115
|
+
# commonly ignored for libraries.
|
|
116
|
+
#uv.lock
|
|
117
|
+
|
|
118
|
+
# poetry
|
|
119
|
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
|
120
|
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
|
121
|
+
# commonly ignored for libraries.
|
|
122
|
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
|
123
|
+
#poetry.lock
|
|
124
|
+
|
|
125
|
+
# pdm
|
|
126
|
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
|
127
|
+
#pdm.lock
|
|
128
|
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
|
129
|
+
# in version control.
|
|
130
|
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
|
131
|
+
.pdm.toml
|
|
132
|
+
.pdm-python
|
|
133
|
+
.pdm-build/
|
|
134
|
+
|
|
135
|
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
|
136
|
+
__pypackages__/
|
|
137
|
+
|
|
138
|
+
# Celery stuff
|
|
139
|
+
celerybeat-schedule
|
|
140
|
+
celerybeat.pid
|
|
141
|
+
|
|
142
|
+
# SageMath parsed files
|
|
143
|
+
*.sage.py
|
|
144
|
+
|
|
145
|
+
# Environments
|
|
146
|
+
.env
|
|
147
|
+
.venv
|
|
148
|
+
env/
|
|
149
|
+
venv/
|
|
150
|
+
ENV/
|
|
151
|
+
env.bak/
|
|
152
|
+
venv.bak/
|
|
153
|
+
|
|
154
|
+
# Spyder project settings
|
|
155
|
+
.spyderproject
|
|
156
|
+
.spyproject
|
|
157
|
+
|
|
158
|
+
# Rope project settings
|
|
159
|
+
.ropeproject
|
|
160
|
+
|
|
161
|
+
# mkdocs documentation
|
|
162
|
+
/site
|
|
163
|
+
|
|
164
|
+
# mypy
|
|
165
|
+
.mypy_cache/
|
|
166
|
+
.dmypy.json
|
|
167
|
+
dmypy.json
|
|
168
|
+
|
|
169
|
+
# Pyre type checker
|
|
170
|
+
.pyre/
|
|
171
|
+
|
|
172
|
+
# pytype static type analyzer
|
|
173
|
+
.pytype/
|
|
174
|
+
|
|
175
|
+
# Cython debug symbols
|
|
176
|
+
cython_debug/
|
|
177
|
+
|
|
178
|
+
# PyCharm
|
|
179
|
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
|
180
|
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
|
181
|
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
|
182
|
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
|
183
|
+
#.idea/
|
|
184
|
+
|
|
185
|
+
# PyPI configuration file
|
|
186
|
+
.pypirc
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: veri-agents-api
|
|
3
|
+
Version: 0.1.1
|
|
4
|
+
Summary: Add your description here
|
|
5
|
+
Author-email: Markus Toman <mtoman@veritone.com>, Teo Boley <tboley@veritone.com>
|
|
6
|
+
Requires-Python: >=3.12
|
|
7
|
+
Requires-Dist: veri-agents-common[langgraph]==0.1.1
|
|
8
|
+
Provides-Extra: dev
|
|
9
|
+
Requires-Dist: langchain-aws>=0.2.21; extra == 'dev'
|
|
10
|
+
Requires-Dist: uvicorn>=0.34.2; extra == 'dev'
|
|
11
|
+
Requires-Dist: veri-agents-common[fastapi,langfuse]==0.1.1; extra == 'dev'
|
|
12
|
+
Provides-Extra: fastapi
|
|
13
|
+
Requires-Dist: veri-agents-common[fastapi,langfuse]==0.1.1; extra == 'fastapi'
|
|
14
|
+
Provides-Extra: fastapi-dev
|
|
15
|
+
Requires-Dist: langchain-aws>=0.2.21; extra == 'fastapi-dev'
|
|
16
|
+
Requires-Dist: uvicorn>=0.34.2; extra == 'fastapi-dev'
|
|
17
|
+
Requires-Dist: veri-agents-common[fastapi,langfuse]==0.1.1; extra == 'fastapi-dev'
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
{
|
|
2
|
+
"name": "@veritone/agents-chat-api",
|
|
3
|
+
"version": "0.0.0",
|
|
4
|
+
"description": "",
|
|
5
|
+
"type": "module",
|
|
6
|
+
"scripts": {
|
|
7
|
+
"typecheck": "uv run pyright",
|
|
8
|
+
"format": "uv run ruff format",
|
|
9
|
+
"lint": "echo \"Warn: no lint specified\" && exit 0",
|
|
10
|
+
"test": "echo \"Warn: no test specified\" && exit 0"
|
|
11
|
+
},
|
|
12
|
+
"author": "",
|
|
13
|
+
"dependencies": {
|
|
14
|
+
"@veritone/agents-common": "workspace:"
|
|
15
|
+
}
|
|
16
|
+
}
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "veri-agents-api"
|
|
3
|
+
version = "0.1.1"
|
|
4
|
+
description = "Add your description here"
|
|
5
|
+
authors = [
|
|
6
|
+
{name = "Markus Toman", email = "mtoman@veritone.com"},
|
|
7
|
+
{name = "Teo Boley", email = "tboley@veritone.com"},
|
|
8
|
+
]
|
|
9
|
+
requires-python = ">=3.12"
|
|
10
|
+
dependencies = [
|
|
11
|
+
"veri-agents-common[langgraph]==0.1.1"
|
|
12
|
+
]
|
|
13
|
+
|
|
14
|
+
[project.optional-dependencies]
|
|
15
|
+
fastapi = [
|
|
16
|
+
"veri-agents-common[fastapi,langfuse]==0.1.1",
|
|
17
|
+
]
|
|
18
|
+
fastapi-dev = [
|
|
19
|
+
"langchain-aws>=0.2.21",
|
|
20
|
+
"uvicorn>=0.34.2",
|
|
21
|
+
"veri-agents-api[fastapi]==0.1.1",
|
|
22
|
+
]
|
|
23
|
+
# as optional dep so it can be referenced in workspace pyproject.toml
|
|
24
|
+
dev = [
|
|
25
|
+
"veri-agents-api[fastapi-dev]==0.1.1"
|
|
26
|
+
]
|
|
27
|
+
|
|
28
|
+
[build-system]
|
|
29
|
+
requires = ["hatchling"]
|
|
30
|
+
build-backend = "hatchling.build"
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,334 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
from typing import Any, AsyncGenerator, Dict, List, Tuple, Callable
|
|
5
|
+
from uuid import uuid4
|
|
6
|
+
|
|
7
|
+
from fastapi import HTTPException, Request, APIRouter
|
|
8
|
+
from fastapi.responses import StreamingResponse
|
|
9
|
+
from langchain_core.callbacks import AsyncCallbackHandler
|
|
10
|
+
from langchain_core.runnables import RunnableConfig
|
|
11
|
+
from langgraph.graph.graph import CompiledGraph
|
|
12
|
+
|
|
13
|
+
from .schema import (
|
|
14
|
+
ChatMessage,
|
|
15
|
+
StreamInput,
|
|
16
|
+
InvokeInput,
|
|
17
|
+
)
|
|
18
|
+
from veri_agents_api.threads_util import ThreadInfo, ThreadsCheckpointerUtil
|
|
19
|
+
from veri_agents_api.util.awaitable import as_awaitable, MaybeAwaitable
|
|
20
|
+
|
|
21
|
+
log = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
class TokenQueueStreamingHandler(AsyncCallbackHandler):
    """Async LangChain callback that pushes each new LLM token onto an asyncio queue."""

    def __init__(self, queue: asyncio.Queue):
        # Queue that receives the tokens; it is supplied (and drained) by the caller.
        self.queue = queue

    async def on_llm_new_token(self, token: str, **kwargs) -> None:
        """Enqueue ``token``; empty tokens are ignored."""
        if not token:
            return
        await self.queue.put(token)
|
|
32
|
+
|
|
33
|
+
def create_thread_router(
    get_graph: Callable[[Request], MaybeAwaitable[CompiledGraph]],
    get_thread_id: Callable[[Request], MaybeAwaitable[str]],
    allow_access_thread: Callable[[str, ThreadInfo | None, Request], MaybeAwaitable[bool]] = lambda thread_id, thread_info, request: True,
    allow_invoke_thread: Callable[[str, ThreadInfo | None, InvokeInput, Request], MaybeAwaitable[bool]] = lambda thread_id, thread_info, invoke_input, request: True,
    invoke_runnable_config: Callable[[str, ThreadInfo | None, InvokeInput, Request], MaybeAwaitable[RunnableConfig | None]] = lambda thread_id, thread_info, invoke_input, request: None,
    on_invoke_thread: Callable[[str, ThreadInfo | None, InvokeInput, Request], MaybeAwaitable[None]] = lambda thread_id, thread_info, invoke_input, request: None,
    # InvokeInputCls: Type[InvokeInput] = InvokeInput,
    **router_kwargs
):
    """
    Build an ``APIRouter`` that exposes a single LangGraph thread over HTTP.

    Routes registered on the returned router:

        POST /invoke   - run the graph and return its final message
        POST /stream   - run the graph and stream tokens/messages as server-sent events
        GET  /history  - return the messages stored in the thread's state

    GET/POST /feedback are not implemented yet.

    The result of every hook is passed through ``as_awaitable`` and awaited, so
    hooks are typed to return either a plain value or an awaitable.

    Args:
        get_graph: Returns the compiled graph used to serve the request.
        get_thread_id: Returns the thread id the request refers to.
        allow_access_thread: Gate for GET /history; a falsy result produces a 403.
        allow_invoke_thread: Gate for POST /invoke and POST /stream; a falsy
            result produces a 403.
        invoke_runnable_config: Optional base ``RunnableConfig`` for a run. Its
            ``configurable`` entries take precedence over the defaults set here
            (``thread_id``, ``_has_threadinfo``) and its callbacks are kept.
        on_invoke_thread: Called after the permission check, before the graph runs.
        **router_kwargs: Forwarded unchanged to ``APIRouter``.

    Returns:
        The configured ``APIRouter``.
    """

    router = APIRouter(**router_kwargs)

    def _parse_input(user_input: InvokeInput, thread_id: str, invoke_recvd_runnable_config: RunnableConfig | None) -> Tuple[Dict[str, Any], str]:
        """Build the keyword arguments for the graph call plus a fresh run id."""
        run_id = uuid4()
        input_message = ChatMessage(type="human", content=user_input.message)

        # Work on a shallow copy: the hook may hand back the same config object
        # for every request, and writing into it would leak one request's
        # thread_id/callbacks into the next one.
        runnable_config = RunnableConfig(**(invoke_recvd_runnable_config or {}))

        runnable_config["configurable"] = {
            # used by checkpointer
            "thread_id": thread_id,
            "_has_threadinfo": True,
            # "args": user_input.args,
            # entries supplied by the hook win over the defaults above
            **(runnable_config.get("configurable") or {}),
        }

        kwargs = dict(
            input={"messages": [input_message.to_langchain()]},
            config=runnable_config,
        )
        return kwargs, str(run_id)

    def _add_callback(config: RunnableConfig, handler: AsyncCallbackHandler) -> None:
        """Add ``handler`` to ``config['callbacks']`` without dropping existing callbacks."""
        existing = config.get("callbacks")
        if existing is None:
            config["callbacks"] = [handler]
        elif isinstance(existing, list):
            # new list so the hook's own list is never mutated
            config["callbacks"] = [*existing, handler]
        else:
            # otherwise it is a callback manager: extend a copy of it
            manager = existing.copy()
            manager.add_handler(handler, inherit=True)
            config["callbacks"] = manager

    def _sse(payload: Dict[str, Any]) -> str:
        """Format ``payload`` as one server-sent-event ``data:`` frame."""
        return f"data: {json.dumps(payload)}\n\n"

    @router.post("/invoke")
    async def invoke(invoke_input: InvokeInput, request: Request) -> ChatMessage:
        """
        Invoke the agent with user input to retrieve a final response.

        Use thread_id to persist and continue a multi-turn conversation. run_id kwarg
        is also attached to messages for recording feedback.
        """

        graph = await as_awaitable(get_graph(request))
        thread_id = await as_awaitable(get_thread_id(request))

        thread_info = await ThreadsCheckpointerUtil.get_thread_info(thread_id, graph.checkpointer)

        if not await as_awaitable(allow_invoke_thread(thread_id, thread_info, invoke_input, request)):
            raise HTTPException(status_code=403, detail="Forbidden")

        invoke_recvd_runnable_config = await as_awaitable(invoke_runnable_config(thread_id, thread_info, invoke_input, request))

        # Callbacks returned by invoke_runnable_config are left in place
        # (they used to be overwritten with an empty list here).
        kwargs, run_id = _parse_input(invoke_input, thread_id, invoke_recvd_runnable_config)

        # TODO: persist ThreadInfo for new threads and attach a Langfuse callback handler.
        await as_awaitable(on_invoke_thread(thread_id, thread_info, invoke_input, request))

        try:
            response = await graph.ainvoke(**kwargs)
            output = ChatMessage.from_langchain(response["messages"][-1])
            output.run_id = run_id
            return output
        except Exception as e:
            log.exception("Error invoking graph for thread %s", thread_id)
            raise HTTPException(status_code=500, detail=str(e)) from e

    @router.post("/stream")
    async def stream_agent(stream_input: StreamInput, request: Request):
        """
        Stream the agent's response to a user input, including intermediate messages and tokens.

        Use thread_id to persist and continue a multi-turn conversation. run_id kwarg
        is also attached to all messages for recording feedback.
        """

        graph = await as_awaitable(get_graph(request))
        thread_id = await as_awaitable(get_thread_id(request))

        thread_info = await ThreadsCheckpointerUtil.get_thread_info(thread_id, graph.checkpointer)

        if not await as_awaitable(allow_invoke_thread(thread_id, thread_info, stream_input, request)):
            raise HTTPException(status_code=403, detail="Forbidden")

        invoke_recvd_runnable_config = await as_awaitable(invoke_runnable_config(thread_id, thread_info, stream_input, request))

        async def message_generator() -> AsyncGenerator[str, None]:
            """
            Generate a stream of messages from the agent.

            This is the workhorse method for the /stream endpoint. Frames are
            ``token``, ``message`` and ``error`` events followed by ``[DONE]``.
            """
            kwargs, run_id = _parse_input(stream_input, thread_id, invoke_recvd_runnable_config)

            await as_awaitable(on_invoke_thread(thread_id, thread_info, stream_input, request))

            # TODO: persist ThreadInfo for new threads and attach a Langfuse callback handler.

            # Use an asyncio queue to process both messages and tokens in
            # chronological order, so we can easily yield them to the client.
            # Items are: str (token), dict (state update), Exception (agent
            # failure) or None (end of stream).
            output_queue: asyncio.Queue = asyncio.Queue(maxsize=10)

            if stream_input.stream_tokens:
                _add_callback(kwargs["config"], TokenQueueStreamingHandler(queue=output_queue))

            # Pass the agent's stream of messages to the queue in a separate task, so
            # we can yield the messages to the client in the main task.
            async def run_agent_stream() -> None:
                try:
                    async for update in graph.astream(**kwargs, stream_mode="updates"):
                        await output_queue.put(update)
                except Exception as e:
                    # Hand the failure to the consumer instead of dying silently;
                    # otherwise the consumer would wait on the queue forever.
                    log.exception("Error streaming graph for thread %s", thread_id)
                    await output_queue.put(e)
                # Deliberately not in a ``finally``: on cancellation nobody reads
                # the queue any more and a blocking put would never return.
                await output_queue.put(None)

            stream_task = asyncio.create_task(run_agent_stream())

            try:
                # Process the queue and yield messages over the SSE stream.
                # Only the None sentinel ends the loop (an empty update must not).
                while (item := await output_queue.get()) is not None:
                    log.info("Got from queue: %s: %s", type(item), item)

                    if isinstance(item, Exception):
                        yield _sse({"type": "error", "content": f"Error running agent: {item}"})
                        continue

                    if isinstance(item, str):
                        # str is an LLM token
                        yield _sse({"type": "token", "content": item})
                        continue

                    if not isinstance(item, dict):
                        continue

                    # Otherwise, item is a dict of state updates for each node in the graph.
                    # It could have updates for multiple nodes, so check each for messages.
                    new_messages = []
                    for state in item.values():
                        if not isinstance(state, dict):
                            # e.g. a node that returned nothing, or an interrupt payload
                            continue
                        node_messages = state.get("messages") or []
                        if not isinstance(node_messages, (list, tuple)):
                            node_messages = [node_messages]
                        new_messages.extend(node_messages)

                    for message in new_messages:
                        try:
                            chat_message = ChatMessage.from_langchain(message)
                            chat_message.run_id = run_id
                        except Exception as e:
                            yield _sse({"type": "error", "content": f"Error parsing message: {e}"})
                            continue
                        # LangGraph re-sends the input message, which feels weird, so drop it
                        if (
                            chat_message.type == "human"
                            and chat_message.content == stream_input.message
                        ):
                            continue
                        yield _sse({"type": "message", "content": chat_message.dict()})

                await stream_task
            finally:
                # No-op when the task already finished; stops the producer when
                # the client disconnected or this generator failed midway.
                stream_task.cancel()

            yield "data: [DONE]\n\n"

        return StreamingResponse(
            message_generator(),
            media_type="text/event-stream",
        )

    @router.get("/history")
    async def get_history(request: Request) -> List[ChatMessage]:
        """
        Get the history of a thread.

        Messages that cannot be converted to ``ChatMessage`` are logged and skipped.
        """

        graph = await as_awaitable(get_graph(request))
        thread_id = await as_awaitable(get_thread_id(request))

        thread_info = await ThreadsCheckpointerUtil.get_thread_info(thread_id, graph.checkpointer)

        if not await as_awaitable(allow_access_thread(thread_id, thread_info, request)):
            raise HTTPException(status_code=403, detail="Forbidden")

        config = RunnableConfig(configurable={
            # used by checkpointer
            "thread_id": thread_id,
        })
        state = await graph.aget_state(config)
        messages = state.values.get("messages", [])

        converted_messages: List[ChatMessage] = []
        for message in messages:
            try:
                converted_messages.append(ChatMessage.from_langchain(message))
            except Exception as e:
                log.error("Error parsing message: %s", e)
        return converted_messages

    # TODO: GET /feedback and POST /feedback (store feedback via the checkpointer
    # and, once run ids are persisted, score the run in Langfuse).

    return router
|
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
from typing import Dict, Any, List, Literal, Optional, Union
|
|
3
|
+
from langchain_core.messages import (
|
|
4
|
+
BaseMessage,
|
|
5
|
+
HumanMessage,
|
|
6
|
+
AIMessage,
|
|
7
|
+
ToolMessage,
|
|
8
|
+
ToolCall,
|
|
9
|
+
message_to_dict,
|
|
10
|
+
messages_from_dict,
|
|
11
|
+
)
|
|
12
|
+
from pydantic import BaseModel, Field
|
|
13
|
+
|
|
14
|
+
class InvokeInput(BaseModel):
    """Request body carrying the user's message for one agent run."""

    # The only field sent by the client; the thread id is not part of the body.
    message: str = Field(
        examples=["What is the weather in Tokyo?"],
        description="User input to the agent.",
    )
|
|
21
|
+
# args: Dict[str, Any] = Field(
|
|
22
|
+
# description="Arguments to pass to the workflow.",
|
|
23
|
+
# default={},
|
|
24
|
+
# examples=[{"kb": "veritone_support"}],
|
|
25
|
+
# )
|
|
26
|
+
# user: Optional[str] = Field(
|
|
27
|
+
# description="A user identifier to validate the user in knowledge bases and other tools.",
|
|
28
|
+
# default=None,
|
|
29
|
+
# examples=["jjohnson", "ccarlson"],
|
|
30
|
+
# )
|
|
31
|
+
# thread_id: str = Field(
|
|
32
|
+
# description="Thread ID to persist and continue a multi-turn conversation.",
|
|
33
|
+
# examples=["847c6285-8fc9-4560-a83f-4e6285809254"],
|
|
34
|
+
# )
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class StreamInput(InvokeInput):
    """Request body for a streamed run: the user's message plus a token-streaming flag."""

    # Defaults to True, i.e. tokens are streamed unless the client opts out.
    stream_tokens: bool = Field(
        default=True,
        description="Whether to stream LLM tokens to the client.",
    )
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class AgentResponse(BaseModel):
    """Response from the agent when called via /invoke."""

    # Serialized LangChain message; the example mirrors the {"type", "data"} layout.
    message: Dict[str, Any] = Field(
        examples=[
            {
                "message": {
                    "type": "ai",
                    "data": {
                        "content": "The weather in Tokyo is 70 degrees.",
                        "type": "ai",
                    },
                }
            }
        ],
        description="Final response from the agent, as a serialized LangChain message.",
    )
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class ChatMessage(BaseModel):
|
|
66
|
+
"""Message in a chat."""
|
|
67
|
+
|
|
68
|
+
type: Literal["human", "ai", "tool"] = Field(
|
|
69
|
+
description="Role of the message.",
|
|
70
|
+
examples=["human", "ai", "tool"],
|
|
71
|
+
)
|
|
72
|
+
content: Union[str, list[Union[str, dict]]] = Field(
|
|
73
|
+
description="Content of the message.",
|
|
74
|
+
examples=["Hello, world!"],
|
|
75
|
+
)
|
|
76
|
+
tool_calls: List[ToolCall] = Field(
|
|
77
|
+
description="Tool calls in the message.",
|
|
78
|
+
default=[],
|
|
79
|
+
)
|
|
80
|
+
tool_call_id: str | None = Field(
|
|
81
|
+
description="Tool call that this message is responding to.",
|
|
82
|
+
default=None,
|
|
83
|
+
examples=["call_Jja7J89XsjrOLA5r!MEOW!SL"],
|
|
84
|
+
)
|
|
85
|
+
run_id: str | None = Field(
|
|
86
|
+
description="Run ID of the message.",
|
|
87
|
+
default=None,
|
|
88
|
+
examples=["847c6285-8fc9-4560-a83f-4e6285809254"],
|
|
89
|
+
)
|
|
90
|
+
original: Dict[str, Any] = Field(
|
|
91
|
+
description="Original LangChain message in serialized form.",
|
|
92
|
+
default={},
|
|
93
|
+
)
|
|
94
|
+
|
|
95
|
+
@classmethod
|
|
96
|
+
def from_langchain(cls, message: BaseMessage) -> "ChatMessage":
|
|
97
|
+
"""Create a ChatMessage from a LangChain message."""
|
|
98
|
+
original = message_to_dict(message)
|
|
99
|
+
match message:
|
|
100
|
+
case HumanMessage():
|
|
101
|
+
human_message = cls(
|
|
102
|
+
type="human", content=message.content, original=original
|
|
103
|
+
)
|
|
104
|
+
return human_message
|
|
105
|
+
case AIMessage():
|
|
106
|
+
ai_message = cls(type="ai", content=message.content, original=original)
|
|
107
|
+
if message.tool_calls:
|
|
108
|
+
ai_message.tool_calls = message.tool_calls
|
|
109
|
+
return ai_message
|
|
110
|
+
case ToolMessage():
|
|
111
|
+
tool_message = cls(
|
|
112
|
+
type="tool",
|
|
113
|
+
content=message.content,
|
|
114
|
+
tool_call_id=message.tool_call_id,
|
|
115
|
+
original=original,
|
|
116
|
+
)
|
|
117
|
+
return tool_message
|
|
118
|
+
case _:
|
|
119
|
+
raise ValueError(
|
|
120
|
+
f"Unsupported message type: {message.__class__.__name__}"
|
|
121
|
+
)
|
|
122
|
+
|
|
123
|
+
def to_langchain(self) -> BaseMessage:
|
|
124
|
+
"""Convert the ChatMessage to a LangChain message."""
|
|
125
|
+
if self.original:
|
|
126
|
+
return messages_from_dict([self.original])[0]
|
|
127
|
+
match self.type:
|
|
128
|
+
case "human":
|
|
129
|
+
return HumanMessage(content=self.content)
|
|
130
|
+
case _:
|
|
131
|
+
raise NotImplementedError(f"Unsupported message type: {self.type}")
|
|
132
|
+
|
|
133
|
+
def pretty_print(self) -> None:
|
|
134
|
+
"""Pretty print the ChatMessage."""
|
|
135
|
+
lc_msg = self.to_langchain()
|
|
136
|
+
lc_msg.pretty_print()
|
|
137
|
+
|
|
138
|
+
def get_artifact(self) -> Optional[Dict[str, Any]]:
|
|
139
|
+
"""Get the artifact from the message if there is one."""
|
|
140
|
+
if (
|
|
141
|
+
self.original.get("data")
|
|
142
|
+
and self.original["data"].get("artifact")
|
|
143
|
+
and self.original["data"]["artifact"]
|
|
144
|
+
):
|
|
145
|
+
return self.original["data"]["artifact"]
|
|
146
|
+
return None
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
class Feedback(BaseModel):
    """A score, with optional extra data, recorded against one message."""

    message_id: str = Field(
        examples=["847c6285-8fc9-4560-a83f-4e6285809254"],
        description="Message ID to record feedback for.",
    )
    # thread_id: str = Field(
    #     description="Thread ID to record feedback for.",
    #     examples=["847c6285-8fc9-4560-a83f-4e6285809254"],
    # )
    score: float = Field(
        examples=[0.8],
        description="Feedback score.",
    )
    kwargs: Dict[str, Any] = Field(
        default={},
        examples=[{"comment": "In-line human feedback"}],
        description="Additional feedback kwargs, passed to LangSmith.",
    )
    # Filled with datetime.now() when the model is created and no value is given.
    creation: datetime = Field(default_factory=datetime.now)
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
from typing import Annotated
|
|
2
|
+
|
|
3
|
+
from langchain_aws import ChatBedrock
|
|
4
|
+
from typing_extensions import TypedDict
|
|
5
|
+
|
|
6
|
+
from fastapi import FastAPI
|
|
7
|
+
from langgraph.checkpoint.memory import MemorySaver
|
|
8
|
+
from langgraph.graph import StateGraph
|
|
9
|
+
from langgraph.graph.message import add_messages
|
|
10
|
+
import uvicorn
|
|
11
|
+
|
|
12
|
+
from veri_agents_api.fastapi.thread import create_thread_router
|
|
13
|
+
|
|
14
|
+
if __name__ == "__main__":
    # Manual demo: a one-node chatbot graph with in-memory checkpoints, served
    # under /threads/{thread_id} so each path parameter value is its own thread.
    class State(TypedDict):
        messages: Annotated[list, add_messages]

    graph_builder = StateGraph(State)

    llm = ChatBedrock(model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0")  # pyright: ignore[reportCallIssue]

    def chatbot(state: State):
        return {"messages": [llm.invoke(state["messages"])]}

    # The first argument is the unique node name
    # The second argument is the function or object that will be called whenever
    # the node is used.
    graph_builder.add_node("chatbot", chatbot)
    graph_builder.set_entry_point("chatbot")
    graph_builder.set_finish_point("chatbot")

    # in-memory persistence
    memory = MemorySaver()
    graph = graph_builder.compile(checkpointer=memory)

    # veri-agents convenience router
    thread_router = create_thread_router(
        # same graph for every request
        get_graph=lambda req: graph,
        # derive thread id from the /threads/{thread_id} path param
        get_thread_id=lambda req: req.path_params["thread_id"]
    )

    # root fastapi app
    app = FastAPI()
    app.include_router(thread_router, prefix="/threads/{thread_id}")

    # Once the server is running you can access:
    #   GET  /openapi.json
    #   POST /threads/{thread_id}/invoke
    #   POST /threads/{thread_id}/stream
    #   GET  /threads/{thread_id}/history
    # (the /feedback routes are not implemented by create_thread_router yet)
    uvicorn.run(app, port=5000, log_level="info")
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
from typing import Annotated
|
|
2
|
+
|
|
3
|
+
from langchain_aws import ChatBedrock
|
|
4
|
+
from typing_extensions import TypedDict
|
|
5
|
+
|
|
6
|
+
from langgraph.checkpoint.memory import MemorySaver
|
|
7
|
+
from langgraph.graph import StateGraph
|
|
8
|
+
from langgraph.graph.message import add_messages
|
|
9
|
+
from fastapi import FastAPI
|
|
10
|
+
import uvicorn
|
|
11
|
+
|
|
12
|
+
from veri_agents_api.fastapi.thread import create_thread_router
|
|
13
|
+
|
|
14
|
+
if __name__ == "__main__":
    # Example: one graph and one fixed thread id ("inmem") for every request.

    class State(TypedDict):
        # Chat history; new messages are merged in through `add_messages`.
        messages: Annotated[list, add_messages]

    graph_builder = StateGraph(State)

    llm = ChatBedrock(model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0")  # pyright: ignore[reportCallIssue]

    def chatbot(state: State):
        """Single graph node: ask the model for the next message."""
        reply = llm.invoke(state["messages"])
        return {"messages": [reply]}

    # Register the node under its unique name and make it both the entry
    # point and the finish point of the graph.
    graph_builder.add_node("chatbot", chatbot)
    graph_builder.set_entry_point("chatbot")
    graph_builder.set_finish_point("chatbot")

    # Checkpoints are kept in memory only.
    memory = MemorySaver()
    graph = graph_builder.compile(checkpointer=memory)

    def _same_graph(req):
        # Every request is served by the same compiled graph.
        return graph

    def _same_thread(req):
        # Every request shares the same thread.
        return "inmem"

    # veri-agents convenience router, mounted at the application root.
    thread_router = create_thread_router(
        get_graph=_same_graph,
        get_thread_id=_same_thread
    )

    app = FastAPI()
    app.include_router(thread_router)

    uvicorn.run(app, port=5000, log_level="info")
    # Endpoints exposed while the server runs:
    # GET /openapi.json
    # POST /invoke
    # POST /stream
    # GET /history
    # GET /feedback
    # POST /feedback
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .router import *
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Callable, cast, Type, Awaitable
|
|
3
|
+
|
|
4
|
+
from fastapi import HTTPException, Request, APIRouter
|
|
5
|
+
from langgraph.graph.graph import CompiledGraph
|
|
6
|
+
|
|
7
|
+
from veri_agents_api.fastapi.thread import (
|
|
8
|
+
create_thread_router as create_thread_router, InvokeInput
|
|
9
|
+
)
|
|
10
|
+
from veri_agents_api.threads_util import ThreadsCheckpointerUtil, ThreadInfo
|
|
11
|
+
from veri_agents_api.util.awaitable import MaybeAwaitable, as_awaitable
|
|
12
|
+
|
|
13
|
+
log = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def create_threads_router(
    get_graph: Callable[[Request], MaybeAwaitable[CompiledGraph]],
    allow_access_thread: Callable[[str, ThreadInfo | None, Request], MaybeAwaitable[bool]] = lambda thread_id, thread_info, request: True,
    allow_invoke_thread: Callable[[str, ThreadInfo | None, InvokeInput, Request], MaybeAwaitable[bool]] = lambda thread_id, thread_info, invoke_input, request: True,
    on_invoke_thread: Callable[[str, ThreadInfo | None, InvokeInput, Request], MaybeAwaitable[None]] = lambda thread_id, thread_info, invoke_input, request: None,
    # InvokeInputCls: Type[InvokeInput] = InvokeInput,
    **router_kwargs
):
    """Build an APIRouter exposing a thread listing plus per-thread routes.

    The returned router is prefixed with ``/threads`` and provides:

    * ``GET /threads/`` - all threads for which ``allow_access_thread``
      returns a truthy value.
    * ``GET /threads/{thread_id}/info`` - info for a single thread.
    * everything from ``create_thread_router``, mounted under
      ``/threads/{thread_id}``.

    Arguments:
        get_graph: Returns (or resolves to) the compiled graph for a request.
        allow_access_thread: Sync or async predicate deciding whether the
            request may see a thread. ``thread_info`` is ``None`` when the
            thread does not exist.
        allow_invoke_thread: Forwarded unchanged to ``create_thread_router``.
        on_invoke_thread: Forwarded unchanged to ``create_thread_router``.
        **router_kwargs: Extra keyword arguments for ``APIRouter``. Must not
            contain ``prefix``, which is already set here.

    Returns:
        The configured ``APIRouter``.
    """
    router = APIRouter(prefix="/threads", **router_kwargs)

    thread_router = create_thread_router(
        # derive thread id from the /threads/{thread_id} path param
        get_thread_id=lambda req: req.path_params["thread_id"],

        # arg passthrough - TODO: make more elegant
        get_graph=get_graph,
        allow_access_thread=allow_access_thread,
        allow_invoke_thread=allow_invoke_thread,
        on_invoke_thread=on_invoke_thread,

        # InvokeInputCls=InvokeInputCls
    )

    @router.get("/")
    async def get_threads(request: Request):
        """Get all threads the user has access to."""

        graph = await as_awaitable(get_graph(request))

        all_thread_info = await ThreadsCheckpointerUtil.list_threads(graph.checkpointer)

        accessible_thread_info: list[ThreadInfo] = []
        for thread_info in all_thread_info:
            # The callback may be async; an un-awaited coroutine is always
            # truthy and would grant access unconditionally.
            allowed = await as_awaitable(
                allow_access_thread(thread_info.thread_id, thread_info, request)
            )
            if allowed:
                accessible_thread_info.append(thread_info)

        return accessible_thread_info

    @router.get("/{thread_id}/info")
    async def get_thread_by_id(thread_id: str, request: Request):
        """Get a thread by its ID.

        Arguments:
            thread_id: The ID of the thread to get.

        Raises:
            HTTPException: 403 when access is denied, 404 when the thread
                does not exist.
        """
        graph = await as_awaitable(get_graph(request))

        thread_info = await ThreadsCheckpointerUtil.get_thread_info(thread_id, graph.checkpointer)

        # Access is checked before existence, as before, so the callback
        # still sees thread_info=None for unknown threads.
        if not await as_awaitable(allow_access_thread(thread_id, thread_info, request)):
            raise HTTPException(status_code=403, detail="Forbidden")

        # get_thread_info signals a missing thread by returning None rather
        # than raising, so the 404 has to be an explicit check.
        if thread_info is None:
            raise HTTPException(status_code=404, detail="Thread not found")

        return thread_info

    router.include_router(thread_router, prefix="/{thread_id}")

    return router
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
from typing import Annotated
|
|
2
|
+
|
|
3
|
+
from langchain_aws import ChatBedrock
|
|
4
|
+
from typing_extensions import TypedDict
|
|
5
|
+
|
|
6
|
+
from fastapi import FastAPI
|
|
7
|
+
from langgraph.checkpoint.memory import MemorySaver
|
|
8
|
+
from langgraph.graph import StateGraph
|
|
9
|
+
from langgraph.graph.message import add_messages
|
|
10
|
+
import uvicorn
|
|
11
|
+
|
|
12
|
+
from veri_agents_api.fastapi.threads import create_threads_router
|
|
13
|
+
|
|
14
|
+
if __name__ == "__main__":
    # Example: expose the full /threads API on top of a single in-memory graph.

    class State(TypedDict):
        # Message list of the conversation, combined via `add_messages`.
        messages: Annotated[list, add_messages]

    graph_builder = StateGraph(State)

    # noinspection PyArgumentList
    llm = ChatBedrock(model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0")  # pyright: ignore[reportCallIssue]

    def chatbot(state: State):
        """Produce the model's answer to the messages collected so far."""
        answer = llm.invoke(state["messages"])
        return {"messages": [answer]}

    # add_node takes the unique node name first, then the callable that is
    # run whenever the node is used.
    graph_builder.add_node("chatbot", chatbot)
    graph_builder.set_entry_point("chatbot")
    graph_builder.set_finish_point("chatbot")

    # in-memory persistence
    memory = MemorySaver()
    graph = graph_builder.compile(checkpointer=memory)

    def _shared_graph(req):
        # same graph for every request
        return graph

    # veri-agents convenience router
    threads_router = create_threads_router(get_graph=_shared_graph)

    # root fastapi app
    app = FastAPI()
    app.include_router(threads_router)

    uvicorn.run(app, port=5000, log_level="info")
    # Routes served while the server is up:
    # GET /openapi.json
    # GET /threads
    # GET /threads/{thread_id}/info
    # POST /threads/{thread_id}/invoke
    # POST /threads/{thread_id}/stream
    # GET /threads/{thread_id}/history
    # GET /threads/{thread_id}/feedback
    # POST /threads/{thread_id}/feedback
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
from typing import cast
|
|
2
|
+
|
|
3
|
+
from .schema import ThreadInfo
|
|
4
|
+
from langgraph.types import Checkpointer
|
|
5
|
+
from langgraph.checkpoint.base import BaseCheckpointSaver
|
|
6
|
+
|
|
7
|
+
class ThreadsCheckpointerUtil:
    """Static helpers that derive thread information from a graph checkpointer."""

    @staticmethod
    async def get_thread_info(thread_id: str, checkpointer: Checkpointer) -> ThreadInfo | None:
        """Look up a single thread by ID.

        Arguments:
            thread_id: ID of the thread to look up.
            checkpointer: Checkpointer of the compiled graph.

        Returns:
            A ``ThreadInfo`` if the checkpointer has a checkpoint tuple for
            the thread, otherwise ``None``.

        Raises:
            TypeError: If ``checkpointer`` is not a ``BaseCheckpointSaver``.
        """
        if not isinstance(checkpointer, BaseCheckpointSaver):
            raise TypeError("checkpointer must be instance of BaseCheckpointSaver")

        chk_tuple = await checkpointer.aget_tuple(
            config={"configurable": {"_has_threadinfo": True, "thread_id": thread_id}}
        )

        if chk_tuple is None:
            return None

        # Only the thread id is exposed for now; checkpoint metadata is not
        # part of ThreadInfo.
        return ThreadInfo(
            thread_id=thread_id,
        )

    @staticmethod
    async def list_threads(checkpointer: Checkpointer) -> list[ThreadInfo]:
        """List all threads known to the checkpointer.

        Arguments:
            checkpointer: Checkpointer of the compiled graph.

        Returns:
            One ``ThreadInfo`` per distinct thread ID, in the order the IDs
            were first yielded by the checkpointer.

        Raises:
            TypeError: If ``checkpointer`` is not a ``BaseCheckpointSaver``.
        """
        if not isinstance(checkpointer, BaseCheckpointSaver):
            raise TypeError("checkpointer must be instance of BaseCheckpointSaver")

        # Get initial steps only (step == -1) so that, normally, each thread
        # contributes a single checkpoint.
        init_step_checkpoints = checkpointer.alist(config=None, filter={'step': -1})

        # A dict is used as an ordered set: should a thread ever yield more
        # than one matching checkpoint, it is still listed (and looked up)
        # only once.
        thread_ids: dict[str, None] = {}
        async for checkpoint in init_step_checkpoints:
            thread_id = cast(
                str | None,
                checkpoint.config.get("configurable", {}).get("thread_id"),
            )
            if thread_id is not None:  # and allow_access_thread(thread_id, request) ?
                thread_ids[thread_id] = None

        all_thread_info: list[ThreadInfo] = []
        for thread_id in thread_ids:
            thread_info = await ThreadsCheckpointerUtil.get_thread_info(thread_id, checkpointer)
            if thread_info is not None:
                all_thread_info.append(thread_info)

        return all_thread_info
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, Field
|
|
4
|
+
|
|
5
|
+
class ThreadInfo(BaseModel):
    """Information about a single thread.

    Currently this carries only the thread's identifier.
    """

    # Identifier of the thread.
    thread_id: str

    # Candidate fields that are not enabled yet:
    # workflow_id: str
    # name: str
    # user: str
    # metadata: dict = Field(default={})
    # creation: datetime = Field(default_factory=datetime.now)
|
|
File without changes
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from typing import TypeVar, Awaitable, cast
|
|
3
|
+
|
|
4
|
+
T = TypeVar('T')
|
|
5
|
+
|
|
6
|
+
async def as_awaitable(maybe_coroutine: T | Awaitable[T]) -> T:
|
|
7
|
+
if asyncio.iscoroutine(maybe_coroutine):
|
|
8
|
+
return await maybe_coroutine
|
|
9
|
+
return cast(T, maybe_coroutine)
|
|
10
|
+
|
|
11
|
+
# Generic alias for "either a T or an awaitable that resolves to a T"; used to
# annotate callbacks that may be sync or async.
# Uses the `type` statement (PEP 695), which requires Python 3.12+.
type MaybeAwaitable[T] = T | Awaitable[T]
|