contentgrid-assistant-api 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- _version.txt +1 -0
- contentgrid_assistant_api/app.py +131 -0
- contentgrid_assistant_api/config.py +55 -0
- contentgrid_assistant_api/db/repositories/thread_repository.py +78 -0
- contentgrid_assistant_api/db/types/message.py +25 -0
- contentgrid_assistant_api/db/types/thread.py +46 -0
- contentgrid_assistant_api/dependencies.py +79 -0
- contentgrid_assistant_api/routers/agent_home.py +41 -0
- contentgrid_assistant_api/routers/message_router.py +256 -0
- contentgrid_assistant_api/routers/thread_router.py +129 -0
- contentgrid_assistant_api/types/agents.py +50 -0
- contentgrid_assistant_api/types/context.py +9 -0
- contentgrid_assistant_api-0.0.2.dist-info/METADATA +401 -0
- contentgrid_assistant_api-0.0.2.dist-info/RECORD +17 -0
- contentgrid_assistant_api-0.0.2.dist-info/WHEEL +5 -0
- contentgrid_assistant_api-0.0.2.dist-info/licenses/LICENSE +13 -0
- contentgrid_assistant_api-0.0.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,256 @@
|
|
|
1
|
+
# type : ignore
|
|
2
|
+
|
|
3
|
+
import base64
|
|
4
|
+
import json
|
|
5
|
+
from typing import Annotated, AsyncGenerator, List, Union, Optional
|
|
6
|
+
import typing
|
|
7
|
+
import uuid
|
|
8
|
+
from enum import Enum
|
|
9
|
+
from langgraph.graph.state import CompiledStateGraph
|
|
10
|
+
from fastapi import APIRouter, BackgroundTasks, File, Form, Request, UploadFile, status
|
|
11
|
+
from fastapi.params import Depends
|
|
12
|
+
from contentgrid_assistant_api.config import AssistantExtensionConfig
|
|
13
|
+
from fastapi.responses import StreamingResponse
|
|
14
|
+
from contentgrid_assistant_api.db.repositories.thread_repository import ThreadRepository
|
|
15
|
+
from contentgrid_assistant_api.db.types.message import HALHumanMessage, HALAIMessage, HALSystemMessage, HALToolMessage
|
|
16
|
+
from contentgrid_assistant_api.dependencies import DependencyResolver
|
|
17
|
+
from contentgrid_extension_helpers.authentication import ContentGridUser
|
|
18
|
+
from langgraph.checkpoint.base import BaseCheckpointSaver
|
|
19
|
+
from contentgrid_extension_helpers.responses.hal import FastAPIHALCollection, HALTemplateFor, HALLinkFor
|
|
20
|
+
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, BaseMessage, ToolMessage
|
|
21
|
+
from fastapi import HTTPException
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# ContentBlocks
|
|
25
|
+
from langchain_core.messages.content import create_image_block, create_file_block, create_audio_block, create_text_block, ImageContentBlock, AudioContentBlock, FileContentBlock, TextContentBlock
|
|
26
|
+
import logging
|
|
27
|
+
|
|
28
|
+
from contentgrid_assistant_api.types.context import DefaultThreadContext
|
|
29
|
+
|
|
30
|
+
def generate_agent_message_router(dep_resolver: DependencyResolver, extension_config: AssistantExtensionConfig, tags: Optional[List[str | Enum]]=None) -> APIRouter:
    """Build the router that serves the messages of one thread.

    The router is mounted at ``"/{thread_id}" + extension_config.routes_message_prefix`` and exposes:

    * ``GET /`` (``read_messages``): every message in the agent state of the thread.
    * ``POST /`` (``add_message``): append a human message (optionally with one uploaded
      file) and run the agent. The run is streamed as server-sent events when the client
      accepts ``text/event-stream``; otherwise it is scheduled as a background task.
    * ``GET /{message_id}`` (``read_message``): a single message looked up by id.

    Args:
        dep_resolver: Source of the FastAPI dependencies (current user, thread repository,
            compiled agent graph, checkpointer, thread context).
        extension_config: Provides the route prefix and ``graph_recursion_limit``.
        tags: OpenAPI tags for the routes (``["messages"]`` when omitted). Also passed
            through to the ``HALLinkFor`` / ``HALTemplateFor`` helpers.

    Returns:
        The configured ``APIRouter``.
    """
    logger = logging.getLogger(__name__)
    messagesrouter = APIRouter(prefix="/{thread_id}" + extension_config.routes_message_prefix, tags=tags or ["messages"])

    def convert_to_hal_message(message: BaseMessage, thread_id: uuid.UUID) -> Union[HALHumanMessage, HALAIMessage, HALSystemMessage, HALToolMessage]:
        """Convert a LangChain message into the matching HAL message model and attach its HAL links.

        Messages of an unrecognised type are logged as a warning and rendered as ``HALSystemMessage``.
        """
        if isinstance(message, ToolMessage):
            # Checked first in case a ToolMessage also inherits from another message type
            hal_message = HALToolMessage(**message.model_dump())
        elif isinstance(message, HumanMessage):
            hal_message = HALHumanMessage(**message.model_dump())
        elif isinstance(message, AIMessage):
            hal_message = HALAIMessage(**message.model_dump())
        elif isinstance(message, SystemMessage):
            hal_message = HALSystemMessage(**message.model_dump())
        else:
            logger.warning("Unknown type : %s", type(message))
            hal_message = HALSystemMessage(**message.model_dump())

        # The "self" link is conditioned on the message having an id; presumably HALLinkFor
        # omits the link when `condition` is False - confirm in contentgrid_extension_helpers.
        hal_message.links = {
            "self": HALLinkFor(endpoint_function_name="read_message", tags=tags, path_params={"thread_id": str(thread_id), "message_id": str(message.id)}, condition=message.id is not None),
            "thread": HALLinkFor(endpoint_function_name="read_thread", tags=tags, path_params={"thread_id": str(thread_id)}),
            "messages": HALLinkFor(endpoint_function_name="read_messages", tags=tags, path_params={"thread_id": str(thread_id)}),
        }
        return hal_message

    async def convert_upload_file_to_contentblock(file: UploadFile) -> ImageContentBlock | FileContentBlock | AudioContentBlock:
        """Read an uploaded file fully and wrap it, base64-encoded, in a LangChain content block.

        ``image/*`` uploads become image blocks and ``audio/*`` uploads become audio blocks.
        Everything else becomes a file block carrying the original filename. A file block
        uses ``application/octet-stream`` when the upload declares no content type.
        Reference: https://docs.langchain.com/oss/python/langchain/messages#content-block-reference
        """
        file_base64 = base64.b64encode(await file.read()).decode("utf-8")
        content_type = file.content_type

        if content_type and content_type.startswith("image/"):
            return create_image_block(base64=file_base64, mime_type=content_type)
        if content_type and content_type.startswith("audio/"):
            return create_audio_block(base64=file_base64, mime_type=content_type)
        # PDFs, other documents, and uploads without a declared content type
        return create_file_block(
            base64=file_base64,
            mime_type=content_type or "application/octet-stream",
            filename=file.filename or None,
        )

    @messagesrouter.get("/", response_model=FastAPIHALCollection, response_model_exclude_none=True)
    def read_messages(
        thread_id: uuid.UUID,
        user: ContentGridUser = Depends(dep_resolver.get_current_user_dependency()),
        thread_repo: ThreadRepository = Depends(dep_resolver.get_thread_repository_dependency()),
        agent: CompiledStateGraph = Depends(dep_resolver.get_agent_dependency()),
        thread_context: DefaultThreadContext = Depends(dep_resolver.get_thread_context_dependency())
    ):
        """Get all messages for a thread.

        The ``addMessage`` template is conditioned on the thread being ready for input. That
        holds when there are no messages yet, or when the last message is an AI message
        without pending tool calls. A state with no ``messages`` entry, or an empty one,
        yields an empty collection instead of an error.
        """
        # `user` and `thread_repo` are not read in the body. NOTE(review): presumably they are
        # declared so their dependencies run (authentication / ownership checks) - confirm.
        state = agent.get_state(config={"configurable": thread_context})  # type: ignore
        messages = []
        if state and state.values and "messages" in state.values:
            messages = state.values["messages"] or []

        hal_messages = [convert_to_hal_message(msg, thread_id) for msg in messages]

        new_message_can_be_added = True
        if messages:
            last_message = messages[-1]
            # Only accept new input after a final AI answer. An AI message that still carries
            # tool calls is waiting for those tools to execute.
            new_message_can_be_added = isinstance(last_message, AIMessage) and not last_message.tool_calls

        return FastAPIHALCollection[Union[HALHumanMessage, HALAIMessage, HALSystemMessage, HALToolMessage]](
            _embedded={"messages": hal_messages},
            _links={
                "self": HALLinkFor(endpoint_function_name="read_messages", tags=tags, path_params={"thread_id": str(thread_id)}),
                "thread": HALLinkFor(endpoint_function_name="read_thread", tags=tags, path_params={"thread_id": str(thread_id)}),
            },
            _templates={"addMessage": HALTemplateFor(endpoint_function_name="add_message", tags=tags, path_params={"thread_id": str(thread_id)}, condition=new_message_can_be_added)},
        )

    @messagesrouter.post("/", response_model=HALHumanMessage, status_code=status.HTTP_202_ACCEPTED, response_model_exclude_none=True)
    async def add_message(
        request: Request,
        thread_id: uuid.UUID,
        question: Annotated[str, Form()],
        background_tasks : BackgroundTasks,
        user: ContentGridUser = Depends(dep_resolver.get_current_user_dependency()),
        thread_repo: ThreadRepository = Depends(dep_resolver.get_thread_repository_dependency()),
        agent: CompiledStateGraph = Depends(dep_resolver.get_agent_dependency()),
        checkpointer: BaseCheckpointSaver = Depends(dep_resolver.get_langgraph_checkpointer_dependency()),
        thread_context: DefaultThreadContext = Depends(dep_resolver.get_thread_context_dependency()),
        file: UploadFile = File(None),
    ):
        """Add a new human message to a thread and run the agent on it.

        With ``Accept: text/event-stream`` the agent run is streamed as server-sent events:

        * ``message`` events carry a JSON list of the HAL messages newly added to the state.
        * ``token`` events carry the content of LLM message chunks, with the chunk id as the SSE id.
        * A final unnamed ``data`` event carries ``{"complete": true}``.

        Otherwise the run is scheduled as a background task, and the HAL form of the human
        message is returned immediately (HTTP 202).
        """
        streaming = "text/event-stream" in request.headers.get("accept", "").lower()

        content_blocks : List[ImageContentBlock | FileContentBlock | AudioContentBlock | TextContentBlock] = [create_text_block(text=question)]
        if file:
            content_blocks.append(await convert_upload_file_to_contentblock(file))

        new_message = HumanMessage(content=content_blocks)  # type: ignore
        messages = [new_message]

        if streaming:
            # Count the messages already in the thread so only messages produced by this run are streamed
            state = agent.get_state(config={"configurable": thread_context})  # type: ignore
            nb_current_messages = 0
            if state and state.values and "messages" in state.values:
                nb_current_messages = len(state.values["messages"] or [])

            @typing.no_type_check
            async def generate_stream(current_message_index) -> AsyncGenerator[str, None]:
                # NOTE(review): agent.stream is a synchronous iterator consumed inside an async
                # generator, so each graph step runs on the event-loop thread and blocks other
                # requests while it executes. Consider agent.astream if the configured
                # checkpointer supports async operation.
                for mode, chunk in agent.stream(
                    {"messages": messages},
                    {"configurable": thread_context, "recursion_limit": extension_config.graph_recursion_limit},
                    context=thread_context,
                    stream_mode=["values", "messages"]
                ):
                    if mode == "values":
                        # Full state snapshot: emit only the messages not sent yet
                        new_state_messages = chunk["messages"][current_message_index:]
                        hal_messages = [convert_to_hal_message(msg, thread_id) for msg in new_state_messages]
                        messages_data = [
                            msg.model_dump(exclude_unset=True) if hasattr(msg, "model_dump") else msg
                            for msg in hal_messages
                        ]
                        yield "event: message\n"
                        yield f"data: {json.dumps(messages_data)}\n\n"
                        current_message_index += len(new_state_messages)
                    elif mode == "messages":
                        message_chunk, _metadata = chunk
                        if message_chunk.content:
                            try:
                                yield "event: token\n"
                                yield f"data: {json.dumps(message_chunk.content)}\n"
                                yield f"id: {message_chunk.id}\n\n"
                            except Exception as e:
                                # Best effort: a token that cannot be serialized is logged and skipped
                                logger.exception(e)

                # Send completion signal
                yield f"data: {json.dumps({'complete': True})}\n\n"

            return StreamingResponse(
                generate_stream(current_message_index=nb_current_messages),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                }
            )

        def process_agent_response():
            """Run the agent to completion; errors are logged because nobody awaits this task."""
            try:
                agent.invoke(
                    {"messages": messages},
                    {"configurable": thread_context, "recursion_limit": extension_config.graph_recursion_limit},
                    context=thread_context,
                )
            except Exception as e:
                logger.exception("Error processing agent response: %s", e)

        background_tasks.add_task(process_agent_response)

        # Return immediately with the human message
        return convert_to_hal_message(new_message, thread_id=thread_id)

    @messagesrouter.get("/{message_id}", response_model=Union[HALHumanMessage, HALAIMessage, HALSystemMessage, HALToolMessage], response_model_exclude_none=True)
    def read_message(
        thread_id: uuid.UUID,
        message_id: str,
        user: ContentGridUser = Depends(dep_resolver.get_current_user_dependency()),
        thread_repo: ThreadRepository = Depends(dep_resolver.get_thread_repository_dependency()),
        checkpointer: BaseCheckpointSaver = Depends(dep_resolver.get_langgraph_checkpointer_dependency()),
        thread_context: DefaultThreadContext = Depends(dep_resolver.get_thread_context_dependency())
    ):
        """Get a specific message by ID from a thread.

        Raises:
            HTTPException: 404 when the thread has no checkpoint, no ``messages`` channel,
                or no message with ``message_id``.
        """
        state = checkpointer.get({"configurable": thread_context})  # type: ignore
        if state and "channel_values" in state:
            for msg in state["channel_values"].get("messages") or []:
                if msg.id == message_id:
                    return convert_to_hal_message(msg, thread_id)

        raise HTTPException(status_code=404, detail="Message not found")

    return messagesrouter
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
|
|
2
|
+
from enum import Enum
|
|
3
|
+
from typing import Annotated, List, Optional
|
|
4
|
+
import uuid
|
|
5
|
+
from fastapi import APIRouter, Depends, Query, status
|
|
6
|
+
from pydantic import HttpUrl
|
|
7
|
+
|
|
8
|
+
from contentgrid_assistant_api.db.repositories.thread_repository import ThreadRepository
|
|
9
|
+
from contentgrid_assistant_api.db.types.thread import ThreadRead, ThreadUpdate
|
|
10
|
+
from contentgrid_assistant_api.dependencies import DependencyResolver
|
|
11
|
+
from contentgrid_extension_helpers.responses.hal import FastAPIHALCollection, HALLinkFor, HALTemplateFor
|
|
12
|
+
from contentgrid_extension_helpers.authentication import ContentGridUser
|
|
13
|
+
from contentgrid_assistant_api.routers.message_router import generate_agent_message_router
|
|
14
|
+
from contentgrid_assistant_api.types.agents import AgentToolCollectionResponse, AgentToolResponse
|
|
15
|
+
from contentgrid_assistant_api.types.context import DefaultThreadContext
|
|
16
|
+
from langchain.messages import HumanMessage
|
|
17
|
+
from langgraph.graph.state import CompiledStateGraph
|
|
18
|
+
from contentgrid_assistant_api.config import AssistantExtensionConfig
|
|
19
|
+
|
|
20
|
+
def _tool_args_schema(tool) -> dict:
    """Return the JSON schema describing a tool's arguments, or ``{}`` when none can be derived.

    LangChain allows ``args_schema`` to be a Pydantic v2 model class (``model_json_schema``),
    a Pydantic v1 model class (``schema``), or an already-built JSON-schema ``dict``.
    """
    args_schema = getattr(tool, "args_schema", None)
    if not args_schema:
        return {}
    if isinstance(args_schema, dict):
        return args_schema
    try:
        if hasattr(args_schema, "model_json_schema"):
            return args_schema.model_json_schema()
        if hasattr(args_schema, "schema"):
            return args_schema.schema()
    except Exception:
        # Best effort: one tool whose schema cannot be generated must not break the listing
        pass
    return {}


def generate_agent_thread_router(dep_resolver: DependencyResolver, extension_config: AssistantExtensionConfig, tags: Optional[List[str | Enum]]=None):
    """Build the router for thread CRUD, the per-thread tool listing, and the nested message routes.

    Args:
        dep_resolver: Source of the FastAPI dependencies and of ``dep_resolver.agent``
            (thread-context class and tool factory).
        extension_config: Provides the thread route prefix, the opening message and
            ``graph_recursion_limit``.
        tags: OpenAPI tags for the routes (``["threads"]`` when omitted). Also passed to
            the message router and to the HAL response helpers.

    Returns:
        The configured ``APIRouter``.
    """
    threadrouter = APIRouter(prefix=extension_config.routes_thread_prefix, tags=tags or ["threads"])

    threadrouter.include_router(
        generate_agent_message_router(dep_resolver, extension_config, tags=tags)
    )

    @threadrouter.post("/", response_model=ThreadRead, status_code=status.HTTP_201_CREATED, response_model_exclude_none=True)
    def create_thread(
        origin: Optional[HttpUrl] = None,
        user: ContentGridUser = Depends(dep_resolver.get_current_user_dependency()),
        thread_repo: ThreadRepository = Depends(dep_resolver.get_thread_repository_dependency()),
        agent : CompiledStateGraph = Depends(dep_resolver.get_agent_dependency())
    ):
        """Create a new thread.

        The agent is first run synchronously on ``extension_config.opening_message`` under
        the new thread's context. Only after that is the thread record persisted.
        """
        thread_id = uuid.uuid4()
        messages = [HumanMessage(content=extension_config.opening_message)]
        context = dep_resolver.agent.thread_context(thread_id=str(thread_id), user=user, origin=origin)
        # Same recursion limit as the message endpoints, so every agent run is bounded alike
        agent.invoke(
            {"messages": messages},
            {"configurable": context, "recursion_limit": extension_config.graph_recursion_limit},  # type: ignore
            context=context,  # type: ignore
        )

        created_thread = thread_repo.create(user=user, origin=origin, component="datamodel", thread_id=thread_id)
        return ThreadRead(**created_thread.model_dump(), tags=tags)

    @threadrouter.get("/", response_model=FastAPIHALCollection, response_model_exclude_none=True)
    def read_threads(
        origin: Optional[HttpUrl] = None,
        user: ContentGridUser = Depends(dep_resolver.get_current_user_dependency()),
        thread_repo: ThreadRepository = Depends(dep_resolver.get_thread_repository_dependency()),
        offset: int = 0,
        limit: Annotated[int, Query(le=100)] = 100,
    ):
        """Get the current user's threads with offset/limit pagination, optionally filtered by origin."""
        threads = (
            thread_repo.get_all_for_user_and_origin(user, origin, offset=offset, limit=limit)
            if origin
            else thread_repo.get_all_for_user(user, offset=offset, limit=limit)
        )
        thread_reads = [ThreadRead(**thread.model_dump(), tags=tags) for thread in threads]
        return FastAPIHALCollection[ThreadRead](
            _embedded={"threads" : thread_reads},
            _links={"self": HALLinkFor(endpoint_function_name="read_threads", tags=tags, templated=True, params={"offset": offset, "limit": limit})},
            _templates={"startThread": HALTemplateFor(endpoint_function_name="create_thread", tags=tags, params={"origin": origin.encoded_string()} if origin else {})},
        )

    @threadrouter.get("/{thread_id}", response_model=ThreadRead, response_model_exclude_none=True)
    def read_thread(
        thread_id: uuid.UUID,
        user: ContentGridUser = Depends(dep_resolver.get_current_user_dependency()),
        thread_repo: ThreadRepository = Depends(dep_resolver.get_thread_repository_dependency())
    ):
        """Get thread by ID (looked up for the current user)."""
        thread = thread_repo.get_by_id_for_user(thread_id, user)
        return ThreadRead(**thread.model_dump(), tags=tags)

    @threadrouter.get("/{thread_id}/tools", response_model=AgentToolCollectionResponse, response_model_exclude_none=True)
    def get_thread_tools(
        thread_id: uuid.UUID,
        user: ContentGridUser = Depends(dep_resolver.get_current_user_dependency()),
        thread_repo: ThreadRepository = Depends(dep_resolver.get_thread_repository_dependency()),
        thread_context: DefaultThreadContext = Depends(dep_resolver.get_thread_context_dependency())
    ):
        """List the tools available to the agent in this thread, with each tool's argument JSON schema in ``args``."""
        tool_responses = []
        for tool in dep_resolver.agent.get_tools(thread_context):
            tool_dict = tool.model_dump()
            tool_dict["args"] = _tool_args_schema(tool)
            tool_responses.append(AgentToolResponse(**tool_dict))

        return AgentToolCollectionResponse(thread_id=thread_id, _embedded={"tools": tool_responses}, tags=tags)

    @threadrouter.patch("/{thread_id}", response_model=ThreadRead, response_model_exclude_none=True)
    def update_thread(
        thread_id: uuid.UUID,
        thread_update: ThreadUpdate,
        user: ContentGridUser = Depends(dep_resolver.get_current_user_dependency()),
        thread_repo: ThreadRepository = Depends(dep_resolver.get_thread_repository_dependency())
    ):
        """Update thread (for the current user) and return its new representation."""
        updated_thread = thread_repo.update_for_user(thread_id, thread_update, user)
        return ThreadRead(**updated_thread.model_dump(), tags=tags)

    @threadrouter.delete("/{thread_id}", status_code=status.HTTP_204_NO_CONTENT)
    def delete_thread(
        thread_id: uuid.UUID,
        user: ContentGridUser = Depends(dep_resolver.get_current_user_dependency()),
        thread_repo: ThreadRepository = Depends(dep_resolver.get_thread_repository_dependency())
    ):
        """Delete thread (for the current user)."""
        thread_repo.delete_for_user(thread_id, user)

    return threadrouter
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
|
|
2
|
+
from typing import Any, Callable, List, Optional
|
|
3
|
+
from enum import Enum
|
|
4
|
+
import uuid
|
|
5
|
+
from pydantic import BaseModel, Field, ConfigDict
|
|
6
|
+
from langchain.tools import BaseTool
|
|
7
|
+
from contentgrid_extension_helpers.responses.hal import FastAPIHALResponse, FastAPIHALCollection, HALLinkFor
|
|
8
|
+
from contentgrid_assistant_api.types.context import DefaultThreadContext
|
|
9
|
+
|
|
10
|
+
class Agent(BaseModel):
    """Description of the agent served by the API, plus the hooks used to build it.

    Only ``name`` and ``version`` are serialized. The remaining fields are declared with
    ``exclude=True`` and never appear in ``model_dump`` output.
    """
    model_config = ConfigDict(arbitrary_types_allowed=True)

    name : str = "Default agent name"
    version : str = "v0"
    # Used to override the get-agent dependency. None = no override (presumably the default
    # dependency is used then - confirm in dependencies.py).
    # A plain `default=None` is used because a `default_factory` taking one argument
    # breaks on Pydantic < 2.10, which calls factories with no arguments.
    get_agent_override : Optional[Callable[..., Any]] = Field(exclude=True, default=None)
    # Used to override authentication (the current-user dependency). None = no override.
    get_current_user_override : Optional[Callable[..., Any]] = Field(exclude=True, default=None)
    # Class instantiated to build the per-thread context (called with thread_id/user/origin by the thread router)
    thread_context : type = Field(exclude=True, default=DefaultThreadContext)
    # Function that accepts the thread_context and returns the tools.
    # We do this dynamically because tools could be generated based on the context (like based on the origin profile)
    get_tools : Callable[..., List[BaseTool]] = Field(exclude=True, default=lambda _ : [])
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class AgentHomeResponse(FastAPIHALResponse):
    """HAL home document of the agent: its name and version, linked to itself and to the threads collection."""

    model_config = ConfigDict(arbitrary_types_allowed=True)
    name : str
    version : str

    def __init__(self, tags: Optional[List[str | Enum]]=None, **kwargs):
        super().__init__(**kwargs)
        # Link relation -> name of the endpoint function it resolves to
        rel_to_endpoint = {
            "self": "get_agent_home",
            "threads": "read_threads",
        }
        self.links = {
            rel: HALLinkFor(endpoint_function_name=endpoint, tags=tags)
            for rel, endpoint in rel_to_endpoint.items()
        }
|
|
34
|
+
|
|
35
|
+
class AgentToolResponse(FastAPIHALResponse):
    """HAL representation of one tool available to the agent (built from the tool's model_dump in the thread router)."""
    # Tool name
    name : str
    # Tool description
    description : str
    # JSON schema of the tool's arguments ({} when no schema could be derived)
    args : dict
|
|
39
|
+
|
|
40
|
+
class AgentToolCollectionResponse(FastAPIHALCollection[AgentToolResponse]):
    """HAL collection of the tools available in a thread, linked to itself and to the owning thread."""

    thread_id : uuid.UUID

    def __init__(self, tags: Optional[List[str | Enum]]=None, **kwargs):
        super().__init__(**kwargs)

        def thread_path_params(instance):
            # Path params derived from the response instance; presumably invoked by HALLinkFor when rendering
            return {"thread_id": instance.thread_id}

        link_specs = (
            ("self", "get_thread_tools"),
            ("thread", "read_thread"),
        )
        self.links = {
            rel: HALLinkFor(endpoint_function_name=endpoint, tags=tags, templated=False, path_params=thread_path_params)
            for rel, endpoint in link_specs
        }
|
|
49
|
+
|
|
50
|
+
|