contentgrid-assistant-api 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,256 @@
1
+ # type : ignore
2
+
3
+ import base64
4
+ import json
5
+ from typing import Annotated, AsyncGenerator, List, Union, Optional
6
+ import typing
7
+ import uuid
8
+ from enum import Enum
9
+ from langgraph.graph.state import CompiledStateGraph
10
+ from fastapi import APIRouter, BackgroundTasks, File, Form, Request, UploadFile, status
11
+ from fastapi.params import Depends
12
+ from contentgrid_assistant_api.config import AssistantExtensionConfig
13
+ from fastapi.responses import StreamingResponse
14
+ from contentgrid_assistant_api.db.repositories.thread_repository import ThreadRepository
15
+ from contentgrid_assistant_api.db.types.message import HALHumanMessage, HALAIMessage, HALSystemMessage, HALToolMessage
16
+ from contentgrid_assistant_api.dependencies import DependencyResolver
17
+ from contentgrid_extension_helpers.authentication import ContentGridUser
18
+ from langgraph.checkpoint.base import BaseCheckpointSaver
19
+ from contentgrid_extension_helpers.responses.hal import FastAPIHALCollection, HALTemplateFor, HALLinkFor
20
+ from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, BaseMessage, ToolMessage
21
+ from fastapi import HTTPException
22
+
23
+
24
+ # ContentBlocks
25
+ from langchain_core.messages.content import create_image_block, create_file_block, create_audio_block, create_text_block, ImageContentBlock, AudioContentBlock, FileContentBlock, TextContentBlock
26
+ import logging
27
+
28
+ from contentgrid_assistant_api.types.context import DefaultThreadContext
29
+
30
def generate_agent_message_router(dep_resolver: DependencyResolver, extension_config: AssistantExtensionConfig, tags: Optional[List[str | Enum]]=None) -> APIRouter:
    """Build the router exposing the messages of one assistant thread.

    The router is mounted under ``/{thread_id}`` + ``extension_config.routes_message_prefix``
    and offers three endpoints: list the messages of a thread, post a new human
    message (optionally with one uploaded file) and read a single message by id.

    Args:
        dep_resolver: Supplies the FastAPI dependencies used by the endpoints
            (current user, thread repository, compiled agent graph, LangGraph
            checkpointer, thread context).
        extension_config: Provides the route prefix and the graph recursion limit.
        tags: OpenAPI tags for the routes (defaults to ``["messages"]``); also
            forwarded to the HAL link/template builders.

    Returns:
        The configured ``APIRouter``.
    """
    # Local imports: only needed to move blocking agent calls off the event loop.
    import asyncio
    import threading

    messagesrouter = APIRouter(prefix="/{thread_id}" + extension_config.routes_message_prefix, tags=tags or ["messages"])

    def convert_to_hal_message(message: BaseMessage, thread_id: uuid.UUID) -> Union[HALHumanMessage, HALAIMessage, HALSystemMessage, HALToolMessage]:
        """Convert a LangChain message into the matching HAL message type and attach its links.

        ToolMessage is checked before the other message types. A message of any
        other type is logged and rendered as a HALSystemMessage.
        """
        if isinstance(message, ToolMessage):
            hal_message = HALToolMessage(**message.model_dump())
        elif isinstance(message, HumanMessage):
            hal_message = HALHumanMessage(**message.model_dump())
        elif isinstance(message, AIMessage):
            hal_message = HALAIMessage(**message.model_dump())
        elif isinstance(message, SystemMessage):
            hal_message = HALSystemMessage(**message.model_dump())
        else:
            logging.warning("Unknown type : %s", type(message))
            hal_message = HALSystemMessage(**message.model_dump())

        # The "self" link is conditioned on the message having an id.
        hal_message.links = {
            "self": HALLinkFor(endpoint_function_name="read_message", tags=tags, path_params={"thread_id": str(thread_id), "message_id": str(message.id)}, condition=message.id is not None),
            "thread": HALLinkFor(endpoint_function_name="read_thread", tags=tags, path_params={"thread_id": str(thread_id)}),
            "messages": HALLinkFor(endpoint_function_name="read_messages", tags=tags, path_params={"thread_id": str(thread_id)}),
        }
        return hal_message

    async def convert_upload_file_to_contentblock(file: UploadFile) -> ImageContentBlock | FileContentBlock | AudioContentBlock:
        """Read an uploaded file and wrap its base64-encoded bytes in a LangChain content block.

        ``image/*`` becomes an image block and ``audio/*`` an audio block. Any other
        content type becomes a file block that keeps the original filename. A
        missing content type is reported as ``application/octet-stream``.
        """
        # Reference: https://docs.langchain.com/oss/python/langchain/messages#content-block-reference
        file_base64 = base64.b64encode(await file.read()).decode("utf-8")
        content_type = file.content_type
        if content_type and content_type.startswith("image/"):
            return create_image_block(base64=file_base64, mime_type=content_type)
        if content_type and content_type.startswith("audio/"):
            return create_audio_block(base64=file_base64, mime_type=content_type)
        # PDFs and every other (or unknown) type.
        return create_file_block(
            base64=file_base64,
            mime_type=content_type or "application/octet-stream",
            filename=file.filename or None,
        )

    def _can_accept_new_message(messages: List[BaseMessage]) -> bool:
        """Return True when the conversation is idle and a new human message may be added.

        An empty conversation is idle. Otherwise the last message must be an AI
        message that is not waiting for tool calls to execute.
        """
        if not messages:
            return True
        last_message = messages[-1]
        return isinstance(last_message, AIMessage) and not last_message.tool_calls

    async def _iterate_in_worker_thread(make_iterator: typing.Callable[[], typing.Iterator[typing.Any]]) -> AsyncGenerator[typing.Any, None]:
        """Run a blocking iterator in one worker thread and yield its items asynchronously.

        The whole iteration runs inside a single ``asyncio.to_thread`` call. The
        iterator therefore always runs in the same thread and the same copied
        context, and the event loop stays free while it blocks. Items are handed
        over through an ``asyncio.Queue``. An exception raised by the iterator is
        re-raised here. If the consumer stops early (e.g. the client disconnects),
        the worker stops after the item it is currently producing and closes the
        iterator.
        """
        loop = asyncio.get_running_loop()
        queue: asyncio.Queue = asyncio.Queue()
        stop_requested = threading.Event()
        done = object()

        def publish(item: typing.Any, error: Optional[BaseException] = None) -> None:
            # asyncio.Queue is not thread-safe: hand the item over to the loop thread.
            loop.call_soon_threadsafe(queue.put_nowait, (item, error))

        def produce() -> None:
            try:
                iterator = make_iterator()
                try:
                    for item in iterator:
                        if stop_requested.is_set():
                            break
                        publish(item)
                finally:
                    close = getattr(iterator, "close", None)
                    if close is not None:
                        close()
            except BaseException as error:  # forwarded to the consumer, which re-raises it
                publish(done, error)
            else:
                publish(done)

        # Keep a strong reference so the task is not garbage collected mid-run.
        worker = asyncio.ensure_future(asyncio.to_thread(produce))
        try:
            while True:
                item, error = await queue.get()
                if item is done:
                    if error is not None:
                        raise error
                    return
                yield item
        finally:
            stop_requested.set()

    @messagesrouter.get("/", response_model=FastAPIHALCollection, response_model_exclude_none=True)
    def read_messages(
        thread_id: uuid.UUID,
        user: ContentGridUser = Depends(dep_resolver.get_current_user_dependency()),
        thread_repo: ThreadRepository = Depends(dep_resolver.get_thread_repository_dependency()),
        agent: CompiledStateGraph = Depends(dep_resolver.get_agent_dependency()),
        thread_context: DefaultThreadContext = Depends(dep_resolver.get_thread_context_dependency())
    ):
        """Get all messages for a thread.

        The ``addMessage`` template is conditioned on the conversation being idle:
        it is empty, or its last message is an AI message without pending tool calls.
        """
        # NOTE(review): `user` and `thread_repo` are not used here, so thread
        # ownership is not checked in this function; presumably the thread-context
        # dependency enforces it — confirm.
        state = agent.get_state(config={"configurable": thread_context})  # type: ignore
        messages = (state.values.get("messages") or []) if state and state.values else []
        hal_messages = [convert_to_hal_message(msg, thread_id) for msg in messages]
        new_message_can_be_added = _can_accept_new_message(messages)

        return FastAPIHALCollection[Union[HALHumanMessage, HALAIMessage, HALSystemMessage, HALToolMessage]](
            _embedded={"messages": hal_messages},
            _links={
                "self": HALLinkFor(endpoint_function_name="read_messages", tags=tags, path_params={"thread_id": str(thread_id)}),
                "thread": HALLinkFor(endpoint_function_name="read_thread", tags=tags, path_params={"thread_id": str(thread_id)}),
            },
            _templates={"addMessage": HALTemplateFor(endpoint_function_name="add_message", tags=tags, path_params={"thread_id": str(thread_id)}, condition=new_message_can_be_added)},
        )

    @messagesrouter.post("/", response_model=HALHumanMessage, status_code=status.HTTP_202_ACCEPTED, response_model_exclude_none=True)
    async def add_message(
        request: Request,
        thread_id: uuid.UUID,
        question: Annotated[str, Form()],
        background_tasks: BackgroundTasks,
        user: ContentGridUser = Depends(dep_resolver.get_current_user_dependency()),
        thread_repo: ThreadRepository = Depends(dep_resolver.get_thread_repository_dependency()),
        agent: CompiledStateGraph = Depends(dep_resolver.get_agent_dependency()),
        checkpointer: BaseCheckpointSaver = Depends(dep_resolver.get_langgraph_checkpointer_dependency()),
        thread_context: DefaultThreadContext = Depends(dep_resolver.get_thread_context_dependency()),
        file: UploadFile = File(None),
    ):
        """Add a human message (text plus an optional uploaded file) to a thread and run the agent.

        If the Accept header contains ``text/event-stream``, the run is streamed as
        server-sent events:

        * ``event: message``: a JSON list of the HAL messages added to the graph
          state since the previous such event;
        * ``event: token``: JSON-encoded content of a streamed message chunk, with
          the chunk id as the SSE ``id``;
        * a final unnamed event whose data is ``{"complete": true}``.

        Otherwise the agent runs as a background task and the HAL form of the
        human message is returned immediately (202 Accepted).
        """
        streaming = "text/event-stream" in request.headers.get("accept", "").lower()

        content_blocks: List[ImageContentBlock | FileContentBlock | AudioContentBlock | TextContentBlock] = [create_text_block(text=question)]
        if file:
            content_blocks.append(await convert_upload_file_to_contentblock(file))

        new_message = HumanMessage(content=content_blocks)  # type: ignore
        messages = [new_message]
        run_config = {"configurable": thread_context, "recursion_limit": extension_config.graph_recursion_limit}

        if streaming:
            # get_state is a blocking checkpointer read: keep it off the event loop.
            state = await asyncio.to_thread(agent.get_state, config={"configurable": thread_context})  # type: ignore
            nb_current_messages = 0
            if state and state.values and state.values.get("messages"):
                nb_current_messages = len(state.values["messages"])

            @typing.no_type_check
            async def generate_stream(current_message_index) -> AsyncGenerator[str, None]:
                # agent.stream is synchronous; iterating it directly here would block
                # the event loop for the whole graph run.
                chunks = _iterate_in_worker_thread(
                    lambda: agent.stream(
                        {"messages": messages},
                        run_config,
                        context=thread_context,
                        stream_mode=["values", "messages"],
                    )
                )
                try:
                    async for mode, chunk in chunks:
                        if mode == "values":
                            new_state_messages = chunk["messages"][current_message_index:]
                            hal_messages = [convert_to_hal_message(msg, thread_id) for msg in new_state_messages]
                            messages_data = [
                                msg.model_dump(exclude_unset=True) if hasattr(msg, "model_dump") else msg
                                for msg in hal_messages
                            ]
                            # Emit each SSE frame as one string so it is never left half-written.
                            yield f"event: message\ndata: {json.dumps(messages_data)}\n\n"
                            current_message_index += len(new_state_messages)
                        elif mode == "messages":
                            message_chunk, _metadata = chunk
                            if not message_chunk.content:
                                continue
                            try:
                                data = json.dumps(message_chunk.content)
                            except Exception as e:
                                # Best effort: skip token chunks that cannot be JSON-encoded.
                                logging.exception(e)
                                continue
                            yield f"event: token\ndata: {data}\nid: {message_chunk.id}\n\n"
                finally:
                    await chunks.aclose()

                # Send completion signal
                yield f"data: {json.dumps({'complete': True})}\n\n"

            return StreamingResponse(
                generate_stream(current_message_index=nb_current_messages),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                },
            )

        def process_agent_response():
            try:
                agent.invoke({"messages": messages}, run_config, context=thread_context)
            except Exception as e:
                # Log the error since this runs in background
                logging.exception(f"Error processing agent response: {e}")

        background_tasks.add_task(process_agent_response)

        # Return immediately with the human message
        return convert_to_hal_message(new_message, thread_id=thread_id)

    @messagesrouter.get("/{message_id}", response_model=Union[HALHumanMessage, HALAIMessage, HALSystemMessage, HALToolMessage], response_model_exclude_none=True)
    def read_message(
        thread_id: uuid.UUID,
        message_id: str,
        user: ContentGridUser = Depends(dep_resolver.get_current_user_dependency()),
        thread_repo: ThreadRepository = Depends(dep_resolver.get_thread_repository_dependency()),
        checkpointer: BaseCheckpointSaver = Depends(dep_resolver.get_langgraph_checkpointer_dependency()),
        thread_context: DefaultThreadContext = Depends(dep_resolver.get_thread_context_dependency())
    ):
        """Get a specific message by id from the thread's latest checkpoint.

        Raises:
            HTTPException: 404 when there is no checkpoint, no ``messages`` channel,
                or no message with ``message_id``.
        """
        checkpoint = checkpointer.get({"configurable": thread_context})  # type: ignore
        channel_values = (checkpoint or {}).get("channel_values") or {}
        for msg in channel_values.get("messages") or []:
            if msg.id == message_id:
                return convert_to_hal_message(msg, thread_id)

        # Message not found
        raise HTTPException(status_code=404, detail="Message not found")

    return messagesrouter
@@ -0,0 +1,129 @@
1
+
2
+ from enum import Enum
3
+ from typing import Annotated, List, Optional
4
+ import uuid
5
+ from fastapi import APIRouter, Depends, Query, status
6
+ from pydantic import HttpUrl
7
+
8
+ from contentgrid_assistant_api.db.repositories.thread_repository import ThreadRepository
9
+ from contentgrid_assistant_api.db.types.thread import ThreadRead, ThreadUpdate
10
+ from contentgrid_assistant_api.dependencies import DependencyResolver
11
+ from contentgrid_extension_helpers.responses.hal import FastAPIHALCollection, HALLinkFor, HALTemplateFor
12
+ from contentgrid_extension_helpers.authentication import ContentGridUser
13
+ from contentgrid_assistant_api.routers.message_router import generate_agent_message_router
14
+ from contentgrid_assistant_api.types.agents import AgentToolCollectionResponse, AgentToolResponse
15
+ from contentgrid_assistant_api.types.context import DefaultThreadContext
16
+ from langchain.messages import HumanMessage
17
+ from langgraph.graph.state import CompiledStateGraph
18
+ from contentgrid_assistant_api.config import AssistantExtensionConfig
19
+
20
def generate_agent_thread_router(dep_resolver: DependencyResolver, extension_config: AssistantExtensionConfig, tags: Optional[List[str | Enum]]=None):
    """Build the router exposing assistant threads (CRUD plus tool listing).

    The per-thread message router is included as a sub-router, so thread and
    message endpoints share ``extension_config.routes_thread_prefix``.

    Args:
        dep_resolver: Supplies the FastAPI dependencies (current user, thread
            repository, compiled agent graph, thread context) and the configured
            agent (``dep_resolver.agent``).
        extension_config: Provides the route prefix and the opening message sent
            when a thread is created.
        tags: OpenAPI tags for the routes (defaults to ``["threads"]``); also
            forwarded to the HAL link/template builders and the response models.

    Returns:
        The configured ``APIRouter``.
    """
    import logging

    logger = logging.getLogger(__name__)

    threadrouter = APIRouter(prefix=extension_config.routes_thread_prefix, tags=tags or ["threads"])

    threadrouter.include_router(
        generate_agent_message_router(dep_resolver, extension_config, tags=tags)
    )

    def _tool_args_schema(tool) -> dict:
        """Return the JSON schema of a tool's arguments, or {} when none is available."""
        args_schema = getattr(tool, "args_schema", None)
        if not args_schema:
            return {}
        if isinstance(args_schema, dict):
            # Tools may declare their arguments directly as a JSON-schema dict.
            return args_schema
        try:
            # Pydantic model class.
            return args_schema.model_json_schema()
        except Exception:
            logger.warning("Could not extract the argument schema of tool %r", getattr(tool, "name", tool), exc_info=True)
            return {}

    @threadrouter.post("/", response_model=ThreadRead, status_code=status.HTTP_201_CREATED, response_model_exclude_none=True)
    def create_thread(
        origin: Optional[HttpUrl] = None,
        user: ContentGridUser = Depends(dep_resolver.get_current_user_dependency()),
        thread_repo: ThreadRepository = Depends(dep_resolver.get_thread_repository_dependency()),
        agent : CompiledStateGraph = Depends(dep_resolver.get_agent_dependency())
    ):
        """Create a new thread.

        The agent is first invoked with ``extension_config.opening_message`` as a
        human message. The thread record is persisted afterwards.
        """
        # NOTE(review): if agent.invoke raises, no thread record is created, but a
        # checkpoint may already exist for thread_id — confirm this is acceptable.
        thread_id = uuid.uuid4()
        messages = [HumanMessage(content=extension_config.opening_message)]
        context = dep_resolver.agent.thread_context(thread_id=str(thread_id), user=user, origin=origin)
        agent.invoke(
            {"messages": messages},
            {"configurable": context},  # type: ignore
            context=context,  # type: ignore
        )

        created_thread = thread_repo.create(user=user, origin=origin, component="datamodel", thread_id=thread_id)
        return ThreadRead(**created_thread.model_dump(), tags=tags)

    @threadrouter.get("/", response_model=FastAPIHALCollection, response_model_exclude_none=True)
    def read_threads(
        origin: Optional[HttpUrl] = None,
        user: ContentGridUser = Depends(dep_resolver.get_current_user_dependency()),
        thread_repo: ThreadRepository = Depends(dep_resolver.get_thread_repository_dependency()),
        offset: Annotated[int, Query(ge=0)] = 0,
        limit: Annotated[int, Query(ge=0, le=100)] = 100,
    ):
        """List the current user's threads, optionally filtered by origin, with offset/limit pagination."""
        if origin:
            threads = thread_repo.get_all_for_user_and_origin(user, origin, offset=offset, limit=limit)
        else:
            threads = thread_repo.get_all_for_user(user, offset=offset, limit=limit)
        thread_reads = [ThreadRead(**thread.model_dump(), tags=tags) for thread in threads]
        return FastAPIHALCollection[ThreadRead](
            _embedded={"threads": thread_reads},
            _links={"self": HALLinkFor(endpoint_function_name="read_threads", tags=tags, templated=True, params={"offset": offset, "limit": limit})},
            _templates={"startThread": HALTemplateFor(endpoint_function_name="create_thread", tags=tags, params={"origin": origin.encoded_string()} if origin else {})},
        )

    @threadrouter.get("/{thread_id}", response_model=ThreadRead, response_model_exclude_none=True)
    def read_thread(
        thread_id: uuid.UUID,
        user: ContentGridUser = Depends(dep_resolver.get_current_user_dependency()),
        thread_repo: ThreadRepository = Depends(dep_resolver.get_thread_repository_dependency())
    ):
        """Get one of the current user's threads by id."""
        thread = thread_repo.get_by_id_for_user(thread_id, user)
        return ThreadRead(**thread.model_dump(), tags=tags)

    @threadrouter.get("/{thread_id}/tools", response_model=AgentToolCollectionResponse, response_model_exclude_none=True)
    def get_thread_tools(
        thread_id: uuid.UUID,
        user: ContentGridUser = Depends(dep_resolver.get_current_user_dependency()),
        thread_repo: ThreadRepository = Depends(dep_resolver.get_thread_repository_dependency()),
        thread_context: DefaultThreadContext = Depends(dep_resolver.get_thread_context_dependency())
    ):
        """List the tools available in a thread, each with the JSON schema of its arguments."""
        # Tools are resolved per thread context (they may depend on e.g. the origin).
        tool_responses = []
        for tool in dep_resolver.agent.get_tools(thread_context):
            tool_dict = tool.model_dump()
            tool_dict["args"] = _tool_args_schema(tool)
            tool_responses.append(AgentToolResponse(**tool_dict))

        return AgentToolCollectionResponse(thread_id=thread_id, _embedded={"tools": tool_responses}, tags=tags)

    @threadrouter.patch("/{thread_id}", response_model=ThreadRead, response_model_exclude_none=True)
    def update_thread(
        thread_id: uuid.UUID,
        thread_update: ThreadUpdate,
        user: ContentGridUser = Depends(dep_resolver.get_current_user_dependency()),
        thread_repo: ThreadRepository = Depends(dep_resolver.get_thread_repository_dependency())
    ):
        """Apply a partial update to one of the current user's threads."""
        updated_thread = thread_repo.update_for_user(thread_id, thread_update, user)
        return ThreadRead(**updated_thread.model_dump(), tags=tags)

    @threadrouter.delete("/{thread_id}", status_code=status.HTTP_204_NO_CONTENT)
    def delete_thread(
        thread_id: uuid.UUID,
        user: ContentGridUser = Depends(dep_resolver.get_current_user_dependency()),
        thread_repo: ThreadRepository = Depends(dep_resolver.get_thread_repository_dependency())
    ):
        """Delete one of the current user's threads."""
        # NOTE(review): only the thread record is deleted here; agent state stored by
        # the LangGraph checkpointer is not touched — confirm whether it is cleaned up elsewhere.
        thread_repo.delete_for_user(thread_id, user)

    return threadrouter
@@ -0,0 +1,50 @@
1
+
2
+ from typing import Any, Callable, List, Optional
3
+ from enum import Enum
4
+ import uuid
5
+ from pydantic import BaseModel, Field, ConfigDict
6
+ from langchain.tools import BaseTool
7
+ from contentgrid_extension_helpers.responses.hal import FastAPIHALResponse, FastAPIHALCollection, HALLinkFor
8
+ from contentgrid_assistant_api.types.context import DefaultThreadContext
9
+
10
class Agent(BaseModel):
    """Configuration of the assistant agent served by the API.

    Only ``name`` and ``version`` are serialized. The override hooks, the
    context type and the tool factory are marked ``exclude=True``.
    """
    model_config = ConfigDict(arbitrary_types_allowed=True)

    name : str = "Default agent name"
    version : str = "v0"
    # Used to override the get-agent dependency; None when not overridden.
    # (A plain default is used instead of a one-argument default_factory, which
    # pydantic < 2.10 would call with no arguments and fail.)
    get_agent_override : Optional[Callable[..., Any]] = Field(exclude=True, default=None)
    # Used to override authentication (the current-user dependency); None when not overridden.
    get_current_user_override : Optional[Callable[..., Any]] = Field(exclude=True, default=None)
    # Type used to build the per-thread context.
    thread_context : type = Field(exclude=True, default=DefaultThreadContext)
    # Function that accepts the thread_context and returns the tools.
    # We do this dynamically because tools could be generated based on the context (like based on the origin profile).
    get_tools : Callable[..., List[BaseTool]] = Field(exclude=True, default=lambda _ : [])
21
+
22
+
23
class AgentHomeResponse(FastAPIHALResponse):
    """HAL home document of the agent: its name and version plus entry-point links.

    Links: ``self`` -> ``get_agent_home`` and ``threads`` -> ``read_threads``.
    """
    model_config = ConfigDict(arbitrary_types_allowed=True)

    name: str
    version: str

    def __init__(self, tags: Optional[List[str | Enum]] = None, **kwargs):
        super().__init__(**kwargs)
        # Link relation -> name of the endpoint function it points to.
        endpoints_by_rel = {"self": "get_agent_home", "threads": "read_threads"}
        self.links = {
            rel: HALLinkFor(endpoint_function_name=endpoint, tags=tags)
            for rel, endpoint in endpoints_by_rel.items()
        }
34
+
35
class AgentToolResponse(FastAPIHALResponse):
    """HAL representation of one tool available to the agent."""

    # Tool name.
    name: str
    # Tool description.
    description: str
    # Argument schema of the tool, as a plain dict.
    args: dict
39
+
40
class AgentToolCollectionResponse(FastAPIHALCollection[AgentToolResponse]):
    """HAL collection of the tools available in one thread.

    Links: ``self`` -> ``get_thread_tools`` and ``thread`` -> ``read_thread``,
    both parameterized with this response's ``thread_id``.
    """
    thread_id: uuid.UUID

    def __init__(self, tags: Optional[List[str | Enum]] = None, **kwargs):
        super().__init__(**kwargs)

        def thread_path(instance):
            # HALLinkFor receives a callable; presumably it is called with this
            # response instance when the links are rendered.
            return {"thread_id": instance.thread_id}

        self.links = {
            rel: HALLinkFor(endpoint_function_name=endpoint, tags=tags, templated=False, path_params=thread_path)
            for rel, endpoint in (("self", "get_thread_tools"), ("thread", "read_thread"))
        }
49
+
50
+
@@ -0,0 +1,9 @@
1
+ from typing import Optional
2
+ from langchain.agents import AgentState
3
+ from contentgrid_extension_helpers.authentication import ContentGridUser
4
+
5
class DefaultThreadContext(AgentState):
    """Per-thread context handed to the agent graph.

    Subclasses langchain's ``AgentState`` and adds the requesting user, the
    thread id and the optional origin of the thread.
    """

    # Current user of the request.
    user: ContentGridUser
    # Thread identifier, kept as a string.
    thread_id: str
    # Origin the thread is scoped to, or None.
    # NOTE(review): callers may pass a pydantic HttpUrl rather than a str — confirm.
    origin: Optional[str]
9
+