contentgrid-assistant-api 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- _version.txt +1 -0
- contentgrid_assistant_api/app.py +131 -0
- contentgrid_assistant_api/config.py +55 -0
- contentgrid_assistant_api/db/repositories/thread_repository.py +78 -0
- contentgrid_assistant_api/db/types/message.py +25 -0
- contentgrid_assistant_api/db/types/thread.py +46 -0
- contentgrid_assistant_api/dependencies.py +79 -0
- contentgrid_assistant_api/routers/agent_home.py +41 -0
- contentgrid_assistant_api/routers/message_router.py +256 -0
- contentgrid_assistant_api/routers/thread_router.py +129 -0
- contentgrid_assistant_api/types/agents.py +50 -0
- contentgrid_assistant_api/types/context.py +9 -0
- contentgrid_assistant_api-0.0.2.dist-info/METADATA +401 -0
- contentgrid_assistant_api-0.0.2.dist-info/RECORD +17 -0
- contentgrid_assistant_api-0.0.2.dist-info/WHEEL +5 -0
- contentgrid_assistant_api-0.0.2.dist-info/licenses/LICENSE +13 -0
- contentgrid_assistant_api-0.0.2.dist-info/top_level.txt +1 -0
_version.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
0.0.2
|
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import List
|
|
3
|
+
from fastapi import Depends, FastAPI, status, Request
|
|
4
|
+
from fastapi.middleware.cors import CORSMiddleware
|
|
5
|
+
from openai import APIError
|
|
6
|
+
from contentgrid_extension_helpers.middleware.exception_middleware import catch_exceptions_middleware
|
|
7
|
+
from contentgrid_extension_helpers.responses.hal import FastAPIHALResponse, FastAPIHALCollection, HALLink, HALLinkFor
|
|
8
|
+
from contentgrid_extension_helpers.logging import setup_json_logging
|
|
9
|
+
from contentgrid_extension_helpers.problem_response import ProblemResponse
|
|
10
|
+
|
|
11
|
+
from contentgrid_assistant_api.routers.agent_home import generate_agent_home_router
|
|
12
|
+
from contentgrid_assistant_api.config import AssistantExtensionConfig, DatabaseConfig
|
|
13
|
+
from contentgrid_assistant_api.types.agents import Agent, AgentHomeResponse
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ContentGridAssistantAPI(FastAPI):
    """Specialized FastAPI application for a ContentGrid Assistant.

    On construction the app configures logging, permissive CORS (non-production
    only), HAL link generation, one router per agent mounted under
    ``{prefix}/{agent.name}``, the root and health endpoints, and the
    exception-handling middleware.
    """

    def __init__(self,
                 extension_config: AssistantExtensionConfig | None = None,
                 database_config: DatabaseConfig | None = None,
                 agents: List[Agent] | None = None,
                 *args, **kwargs):
        """Create and fully wire the application.

        Args:
            extension_config: Extension settings; defaults to ``AssistantExtensionConfig()``.
            database_config: Database settings; defaults to ``DatabaseConfig()``.
            agents: Agents to expose; ``None`` means no agents.
            *args, **kwargs: Forwarded unchanged to ``FastAPI.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.extension_config = extension_config or AssistantExtensionConfig()
        self.database_config = database_config or DatabaseConfig()
        # ``None`` default instead of ``[]`` so no list object is shared between calls.
        agents = agents if agents is not None else []

        self._setup_logging()
        if not self.extension_config.production:
            self._setup_cors()
        self._setup_hal_response()

        self._register_agent_routers(agents)
        self._register_endpoints(agents)
        self._register_middleware()

    @property
    def _server_prefix(self) -> str:
        """Normalized route prefix: ``""`` or ``"/segment[/segment...]"``.

        A leading slash is added when missing and trailing slashes are removed, so the
        result can be concatenated directly with ``"/health"`` or ``"/{agent.name}"``.
        """
        path_prefix = (self.extension_config.extension_path_prefix or "").strip("/")
        return f"/{path_prefix}" if path_prefix else ""

    def _format_problem_type(self, type: str) -> str:
        """Format a problem ``type`` URL by appending *type* to the configured base URL."""
        base_url = self.extension_config.problem_type_base_url.rstrip("/")
        return f"{base_url}/{type}"

    def _setup_logging(self):
        """Configure JSON logging for production environments."""
        if self.extension_config.production:
            setup_json_logging()

    def _setup_cors(self):
        """Configure a fully permissive CORS middleware (only called outside production)."""
        self.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    def _setup_hal_response(self):
        """Initialize HAL response configuration and the server URL used in links."""
        FastAPIHALResponse.init_app(self)
        FastAPIHALResponse.add_server_url(
            self.extension_config.server_url or f"http://localhost:{self.extension_config.server_port or 8000}"
        )

    def _register_agent_routers(self, agents : List[Agent]):
        """Mount one router per agent at ``{prefix}/{agent.name}``, guarded by the agent's user dependency."""
        for agent in agents:
            self.include_router(
                generate_agent_home_router(agent, self.extension_config, self.database_config),
                # Same normalized prefix as the health/root endpoints, so a prefix
                # configured without a leading "/" still yields a valid route prefix.
                prefix=f"{self._server_prefix}/{agent.name}",
                tags=[agent.name],
                dependencies=[Depends(agent.get_current_user_override)]
            )

    def _register_endpoints(self, agents : List[Agent]):
        """Register the root (agent listing) and health check endpoints."""

        @self.get(f"{self._server_prefix}/health")
        def health_check():
            return "ok"

        @self.get(f"{self._server_prefix}/", response_model=FastAPIHALCollection, response_model_exclude_unset=True)
        def get_server_resources():
            return FastAPIHALCollection(
                _embedded={
                    "agents" : [
                        AgentHomeResponse(**agent.model_dump(), tags=[agent.name]) for agent in agents
                    ]
                },
                _links={
                    "self": HALLinkFor(endpoint_function_name="get_server_resources"),
                }
            )

    def _register_middleware(self):
        """Register custom middleware for exception handling.

        In Starlette the middleware registered last is the outermost, so OpenAI errors
        are translated before the generic catch-all middleware sees them.
        """
        self.middleware("http")(self._create_openai_exception_middleware())
        self.middleware("http")(catch_exceptions_middleware)

    def _create_openai_exception_middleware(self):
        """Create an HTTP middleware that converts ``openai.APIError`` into problem responses.

        An error with code ``"unsupported_file"`` becomes a 400 "File not supported"
        problem. Any other OpenAI API error becomes a 500 "OpenAI error" problem.
        """
        async def catch_openai_exceptions_middleware(request: Request, call_next):
            """Handle OpenAI API exceptions and convert them to problem responses."""
            try:
                return await call_next(request)
            except APIError as e:
                # logging.exception already attaches the traceback; stack_info adds the call stack.
                logging.exception("OpenAI error: %s", e, stack_info=True)
                if e.code == "unsupported_file":
                    return ProblemResponse(
                        title="File not supported",
                        problem_type=self._format_problem_type("unsupported-file"),
                        detail="Provided file extension is not supported.",
                        status=status.HTTP_400_BAD_REQUEST
                    )
                return ProblemResponse(
                    title="OpenAI error",
                    problem_type=self._format_problem_type("openai-error"),
                    detail="An unexpected OpenAI error occured.",
                    status=status.HTTP_500_INTERNAL_SERVER_ERROR
                )
        return catch_openai_exceptions_middleware
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
from pydantic_settings import BaseSettings
|
|
2
|
+
from pydantic import Field, ConfigDict
|
|
3
|
+
from pydantic import computed_field
|
|
4
|
+
import logging
|
|
5
|
+
import urllib
|
|
6
|
+
|
|
7
|
+
class AssistantExtensionConfig(BaseSettings):
    """Settings for the assistant extension, read from environment variables and ``.env`` / ``.env.secret``."""
    # extra="allow" so undeclared keys in the env files are accepted rather than rejected.
    # NOTE(review): pydantic-settings documents SettingsConfigDict for settings classes; ConfigDict
    # works at runtime here but gives weaker type checking.
    model_config = ConfigDict(extra="allow", env_file=[".env", ".env.secret"], env_file_encoding="utf-8")

    # When true, ContentGridAssistantAPI enables JSON logging and skips the permissive CORS setup.
    production : bool = False


    server_port: int | None = 8000
    # Public base URL of the server, used for HAL links.
    # NOTE(review): serialization_alias only renames the key in serialized output; it does not make a
    # BACKEND_URL environment variable populate this field — confirm that is intended.
    # Because the default is not None, the app's "http://localhost:{server_port}" fallback only applies
    # when this is explicitly set to an empty value.
    server_url: str | None = Field("http://localhost:8000", serialization_alias="BACKEND_URL")
    # NOTE(review): not used in this module; presumably the number of server worker processes.
    web_concurrency: int | None = 1
    # NOTE(review): not used in this module; presumably the first prompt sent in a new thread.
    opening_message : str = "Hello, please introduce yourself and list your available tools and functionalities."
    # NOTE(review): not used in this module; presumably passed to LangGraph as the recursion limit.
    graph_recursion_limit: int = 100

    # Optional path prefix placed in front of all routes; None means no prefix.
    extension_path_prefix : str | None = None

    # Route path segments; presumably consumed by the routers (not used in this module).
    routes_assistant_prefix : str = "/assistant"
    routes_thread_prefix : str = "/threads"
    routes_message_prefix : str = "/messages"

    # Base URL for the `type` of problem responses; the app appends a slug such as "unsupported-file".
    problem_type_base_url : str = "https://api.contentgrid.com/problems/ml"
|
|
26
|
+
|
|
27
|
+
class DatabaseConfig(BaseSettings):
    """Database connection settings, read from environment variables and ``.env`` / ``.env.secret``."""
    model_config = ConfigDict(extra="allow", env_file=[".env", ".env.secret"], env_file_encoding="utf-8")

    pg_dbname: str = "assistant"
    pg_user: str = "assistant"
    pg_passwd: str = "assistant"
    pg_host: str = "postgres"
    pg_port: str = "5432"
    # When true, the agent router lifespan wipes the database before creating the tables.
    pg_reinitialize: bool = False
    use_sqlite_db: bool = False

    @computed_field # type: ignore
    @property
    def database_url(self) -> str:
        """SQLAlchemy-style connection URL for the configured database.

        Returns ``"sqlite:///sqlite.db"`` when ``use_sqlite_db`` is set. Otherwise it
        returns a ``postgresql://`` URL whose user and password are percent-encoded.

        NOTE: as a computed field, this URL (including the password) is part of
        ``model_dump()`` output; do not log dumps of this model.
        """
        if self.use_sqlite_db:
            logging.info("Using SQLite database")
            return "sqlite:///sqlite.db"
        # Import the submodule explicitly: a bare ``import urllib`` does not load ``urllib.parse``.
        from urllib.parse import quote
        # safe="" so '/', ':' and '@' in credentials are percent-encoded and cannot be
        # mistaken for URL delimiters (quote's default keeps '/' unescaped).
        user = quote(self.pg_user, safe="")
        password = quote(self.pg_passwd, safe="")
        return f"postgresql://{user}:{password}@{self.pg_host}:{self.pg_port}/{self.pg_dbname}"

    def log_config(self) -> None:
        """Log database configuration (excluding password)."""
        logging.info("===Database Config===")
        logging.info(f"user : {self.pg_user}")
        logging.info("Passwd : ********")
        logging.info(f"dbname : {self.pg_dbname}")
        logging.info(f"host : {self.pg_host}")
        logging.info(f"port : {self.pg_port}")
        logging.info(f"use_sqlite_db : {self.use_sqlite_db}")
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
|
|
2
|
+
from typing import List, Optional
|
|
3
|
+
import uuid
|
|
4
|
+
from pydantic import HttpUrl
|
|
5
|
+
from sqlmodel import Session, select
|
|
6
|
+
from fastapi import HTTPException
|
|
7
|
+
from contentgrid_extension_helpers.dependencies.sqlalch.repositories import BaseRepository
|
|
8
|
+
from contentgrid_assistant_api.db.types.thread import Thread, ThreadCreate, ThreadUpdate
|
|
9
|
+
from contentgrid_extension_helpers.dependencies.authentication.user import ContentGridUser
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ThreadRepository(BaseRepository[Thread, ThreadCreate, ThreadUpdate]):
    """Repository for Thread operations.

    Every ``*_for_user`` method filters on ``Thread.user_sub == user.sub``, so a user
    only ever reads or changes their own threads.
    """

    def __init__(self, session: Session):
        super().__init__(session, Thread)

    def _paginate_threads(self, statement, offset: int, limit: int):
        """Apply a stable ordering, then *offset* / *limit*, to a ``select(Thread)`` statement."""
        # Without ORDER BY the database may return rows in any order, so consecutive
        # offset/limit pages could overlap or skip threads. Oldest first; id breaks ties.
        return statement.order_by(Thread.created_at, Thread.id).offset(offset).limit(limit)  # type: ignore

    def create(self, user: ContentGridUser, origin : Optional[HttpUrl], component : str, thread_id : uuid.UUID | None = None) -> Thread:
        """Create, persist and return a new thread owned by *user*.

        Args:
            user: Owner; its ``sub`` is stored as ``user_sub``.
            origin: Optional origin URL, stored as its encoded string.
            component: Value stored in the thread's ``component`` column.
            thread_id: Explicit id to use; when omitted the model's default id factory applies.
        """
        thread_params = {
            "name": "New Thread",
            "origin": origin.encoded_string() if origin else None,
            "component": component,
            "user_sub": user.sub,
        }
        if thread_id is not None:
            thread_params["id"] = thread_id

        new_thread = Thread(**thread_params)
        self.session.add(new_thread)
        self.session.commit()
        # Reload the committed row so the returned instance reflects the stored state.
        self.session.refresh(new_thread)
        return new_thread

    def get_all_for_user(self, user: ContentGridUser, offset: int = 0, limit: int = 100) -> List[Thread]:
        """Get one page of the user's threads, in creation order."""
        statement = select(Thread).where(Thread.user_sub == user.sub)
        return self.session.exec(self._paginate_threads(statement, offset, limit)).all()

    def get_all_for_user_and_origin(self, user: ContentGridUser, origin: HttpUrl, offset: int = 0, limit: int = 100) -> List[Thread]:
        """Get one page of the user's threads for *origin*, in creation order."""
        statement = select(Thread).where(Thread.user_sub == user.sub, Thread.origin == origin.encoded_string())
        return self.session.exec(self._paginate_threads(statement, offset, limit)).all()

    def get_by_id_for_user(self, thread_id: uuid.UUID, user: ContentGridUser) -> Thread:
        """Get a thread by ID, ensuring it belongs to the user.

        Raises:
            HTTPException: 404 when the thread does not exist or belongs to another user
                (the two cases are deliberately indistinguishable).
        """
        thread = self.session.exec(
            select(Thread).where(Thread.id == thread_id, Thread.user_sub == user.sub)
        ).first()
        if not thread:
            raise HTTPException(status_code=404, detail="Thread not found or access denied")
        return thread

    def update_for_user(self, thread_id: uuid.UUID, update_model: ThreadUpdate, user: ContentGridUser) -> Thread:
        """Apply the explicitly-set fields of *update_model* to the user's thread and return it."""
        thread = self.get_by_id_for_user(thread_id, user)

        # Only fields the client actually sent are applied (PATCH semantics).
        for field, value in update_model.model_dump(exclude_unset=True).items():
            setattr(thread, field, value)

        self.session.add(thread)
        self.session.commit()
        self.session.refresh(thread)
        return thread

    def delete_for_user(self, thread_id: uuid.UUID, user: ContentGridUser) -> Thread:
        """Delete the user's thread and return the (now detached) instance."""
        thread = self.get_by_id_for_user(thread_id, user)
        self.session.delete(thread)
        self.session.commit()
        return thread
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage, UsageMetadata
|
|
2
|
+
from contentgrid_extension_helpers.responses.hal import FastAPIHALResponse
|
|
3
|
+
from pydantic import Field
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class BaseHALMessage(FastAPIHALResponse):
    """HAL response mixin for LangChain message classes.

    Combined with a LangChain message type (see the ``HAL*Message`` subclasses) so
    messages can be returned as HAL resources. LangChain bookkeeping fields stay on
    the model but are excluded from serialized output.
    """
    # Flag clients can use to hide a message; subclasses may change the default.
    hidden : bool = False
    # exclude=True removes the field from serialized output. default_factory keeps the fields
    # optional, matching LangChain's BaseMessage, where both default to an empty dict.
    additional_kwargs: dict = Field(default_factory=dict, exclude=True)
    response_metadata: dict = Field(default_factory=dict, exclude=True)

    def get(self, key, default=None):
        """Dict-style access for compatibility with LangChain validation.

        Returns ``getattr(self, key, default)``.
        """
        return getattr(self, key, default)
|
|
14
|
+
|
|
15
|
+
class HALHumanMessage(BaseHALMessage, HumanMessage):
    # Human (user) message that can be returned as a HAL resource; all behaviour
    # comes from BaseHALMessage and LangChain's HumanMessage.
    ...
|
|
17
|
+
|
|
18
|
+
class HALAIMessage(BaseHALMessage, AIMessage):
    # AI (assistant) message as a HAL resource. Token usage stays on the model but is excluded
    # from serialized output; it defaults to None, as AIMessage.usage_metadata does, so it is
    # not a required field.
    usage_metadata: UsageMetadata | None = Field(default=None, exclude=True)
|
|
20
|
+
|
|
21
|
+
class HALSystemMessage(BaseHALMessage, SystemMessage):
    # System prompts are flagged hidden by default (presumably so client UIs do not display them).
    hidden : bool = True
|
|
23
|
+
|
|
24
|
+
class HALToolMessage(BaseHALMessage, ToolMessage):
    # Tool-call result message that can be returned as a HAL resource; all behaviour
    # comes from BaseHALMessage and LangChain's ToolMessage.
    ...
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
from typing import Optional, List
|
|
3
|
+
import uuid
|
|
4
|
+
from pydantic import BaseModel
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from sqlmodel import Field, SQLModel
|
|
7
|
+
from contentgrid_extension_helpers.responses.hal import FastAPIHALResponse, HALLinkFor, HALTemplateFor, HALLink
|
|
8
|
+
|
|
9
|
+
class ThreadBase(SQLModel):
    # Shared base for the Thread models; not used on its own.
    # Thread, ThreadCreate and ThreadRead inherit these fields, so every field declared here is
    # publicly visible. ThreadUpdate derives from pydantic's BaseModel instead, so only its own
    # fields can be patched. Private columns belong on the Thread table model.
    id : uuid.UUID = Field(primary_key=True, default_factory=uuid.uuid4)
    name: str = Field(description="Name of the Thread")
    # Encoded origin URL string (ThreadRepository.create stores origin.encoded_string()); indexed.
    origin : Optional[str] = Field(index=True)
    component: str = Field(index=True, default="Undefined")
    # NOTE(review): datetime.now yields a naive local-time timestamp; consider timezone-aware values.
    created_at : datetime = Field(default_factory=datetime.now, description="Creation timestamp")
|
|
20
|
+
|
|
21
|
+
class Thread(ThreadBase, table=True):
    # Table model that business logic works with. Fields declared only here (not on ThreadBase)
    # are kept out of ThreadCreate and ThreadRead.
    # Subject ("sub") of the owning user; ThreadRepository's per-user queries filter on it.
    user_sub : str = Field(index=True, description="User subject (sub) who owns the Thread")
|
|
24
|
+
|
|
25
|
+
class ThreadRead(ThreadBase, FastAPIHALResponse):
    # Public, read-only view of a Thread for API responses. It inherits only the ThreadBase
    # fields, so columns declared solely on Thread (such as user_sub) are not exposed.

    def __init__(self, tags: Optional[List[str | Enum]]=None, **kwargs):
        super().__init__(**kwargs)

        def thread_path_params(instance):
            # Every link and template targets a route keyed by this thread's id.
            return {"thread_id": instance.id}

        # (relation name, endpoint function name) pairs; the dict order below is the output order.
        link_targets = (
            ("self", "read_thread"),
            ("messages", "read_messages"),
            ("tools", "get_thread_tools"),
        )
        template_targets = (
            ("update", "update_thread"),
            ("delete", "delete_thread"),
        )

        self.links = {
            rel: HALLinkFor(endpoint_function_name=endpoint, tags=tags, templated=False, path_params=thread_path_params)
            for rel, endpoint in link_targets
        }
        self.templates = {
            rel: HALTemplateFor(endpoint_function_name=endpoint, tags=tags, templated=False, path_params=thread_path_params)
            for rel, endpoint in template_targets
        }
|
|
40
|
+
|
|
41
|
+
class ThreadCreate(ThreadBase):
    # Input model for new threads. The owner (user_sub) and the origin are supplied by the
    # request dependencies rather than by the client.
    ...
|
|
44
|
+
|
|
45
|
+
class ThreadUpdate(BaseModel):
    # PATCH payload for a thread; only fields declared here can be changed.
    # The None default (hence the type: ignore) lets clients omit the field; ThreadRepository
    # applies only explicitly-set fields (model_dump(exclude_unset=True)).
    name : str = Field(None, description="New name of the Thread") # type: ignore
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
from typing import Generator
|
|
2
|
+
import uuid
|
|
3
|
+
from fastapi import Depends
|
|
4
|
+
from fastapi.security import OAuth2PasswordBearer
|
|
5
|
+
from sqlmodel import Session
|
|
6
|
+
from contentgrid_assistant_api.db.repositories.thread_repository import ThreadRepository
|
|
7
|
+
from contentgrid_extension_helpers.dependencies.sqlalch.db import SQLiteSessionFactory, PostgresSessionFactory
|
|
8
|
+
from contentgrid_extension_helpers.dependencies.authentication.user import ContentGridUser
|
|
9
|
+
from langgraph.checkpoint.postgres import PostgresSaver
|
|
10
|
+
from langgraph.checkpoint.memory import MemorySaver
|
|
11
|
+
from langgraph.checkpoint.base import BaseCheckpointSaver
|
|
12
|
+
from langgraph.graph.state import CompiledStateGraph
|
|
13
|
+
from contentgrid_assistant_api.types.context import DefaultThreadContext
|
|
14
|
+
from contentgrid_assistant_api.config import DatabaseConfig
|
|
15
|
+
from contentgrid_assistant_api.types.agents import Agent
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# OAuth2 bearer-token extractor; tokenUrl is the token endpoint advertised in the OpenAPI schema.
# NOTE(review): not referenced in this module — presumably imported elsewhere; confirm before removing.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
|
19
|
+
|
|
20
|
+
class DependencyResolver():
    """Builds the FastAPI dependency callables for a single agent.

    Holds the agent, its database configuration and a DB session factory. Each
    ``get_*_dependency`` method returns a closure suitable for ``Depends(...)``.
    """

    def __init__(self, agent : Agent, db_config : DatabaseConfig) -> None:
        """Set up storage for *agent*.

        The database name is always set to ``agent.name``, but on a private copy of
        *db_config*. The caller's instance may be shared between agents
        (ContentGridAssistantAPI passes one instance to every agent). Because the
        checkpointer reads ``self.db_config.database_url`` at request time, mutating
        the shared object would point every agent at the last agent's database.

        In SQLite mode a per-agent ``{agent.name}.db`` file backs the threads and
        LangGraph checkpoints use an in-memory ``MemorySaver``. Otherwise a PostgreSQL
        session factory is created and ``PostgresSaver.setup()`` runs immediately,
        which connects to the database during construction.
        """
        db_config = db_config.model_copy(update={"pg_dbname": agent.name})
        self.agent = agent
        self.db_config = db_config
        if db_config.use_sqlite_db:
            self.db_conn_factory = SQLiteSessionFactory(sqlite_file_name=f"{agent.name}.db")
            self.in_mem_store = MemorySaver()
        else:
            self.db_conn_factory = PostgresSessionFactory(
                pg_host=db_config.pg_host,
                pg_dbname=db_config.pg_dbname,
                pg_user=db_config.pg_user,
                pg_passwd=db_config.pg_passwd,
                pg_port=db_config.pg_port
            )
            # Only set up the PostgreSQL checkpointer when not using SQLite.
            with PostgresSaver.from_conn_string(db_config.database_url) as store:
                store.setup()

    def get_langgraph_checkpointer_dependency(self):
        """Return a generator dependency that yields the LangGraph checkpointer.

        In SQLite mode it yields the shared in-memory saver. In PostgreSQL mode it opens
        a ``PostgresSaver`` for the duration of the request.
        """
        def get_langgraph_checkpointer() -> Generator[BaseCheckpointSaver, None, None]:
            if self.db_config.use_sqlite_db:
                yield self.in_mem_store
                return
            database_url = self.db_config.database_url
            if "postgresql" not in database_url:
                # Fail explicitly rather than ending the generator without yielding, which
                # surfaces as an opaque "generator didn't yield" error. The URL is not
                # included because it contains the password.
                raise RuntimeError("No LangGraph checkpointer is available for the configured database")
            with PostgresSaver.from_conn_string(database_url) as store:
                yield store
        return get_langgraph_checkpointer


    def get_thread_repository_dependency(self):
        """Return a dependency that builds a ``ThreadRepository`` on a request-scoped DB session."""
        def get_thread_repository(session: Session = Depends(self.db_conn_factory)) -> ThreadRepository:
            """Get a thread repository instance."""
            return ThreadRepository(session)
        return get_thread_repository

    def get_thread_context_dependency(self):
        """Return a dependency that builds the agent's thread context for ``/{thread_id}/...`` routes."""
        def get_thread_context(thread_id: uuid.UUID, thread_repository: ThreadRepository = Depends(self.get_thread_repository_dependency()), user: ContentGridUser = Depends(self.get_current_user_dependency())) -> DefaultThreadContext:
            # Requires a ``thread_id`` path parameter. The thread is loaded for the current user
            # only, so a thread the user may not read ends the request here with a 404.
            # Its stored origin is injected into the context, where tools can access it.
            thread = thread_repository.get_by_id_for_user(thread_id=thread_id, user=user)
            # Whether the user may still reach the origin is not re-checked here. That is left to
            # the agent graph (e.g. a node that fetches the origin) or to the tools themselves.
            # ``messages`` is deliberately not passed (mypy flags this): the conversation must come
            # from the persisted checkpoint instead of starting out empty.
            return self.agent.thread_context(user=user, origin=thread.origin, thread_id=str(thread_id)) # type: ignore
        return get_thread_context

    def get_agent_dependency(self):
        """Return a dependency that compiles the agent graph with the request's checkpointer."""
        def get_agent(checkpointer: BaseCheckpointSaver = Depends(self.get_langgraph_checkpointer_dependency())) -> CompiledStateGraph:
            return self.agent.get_agent_override(checkpointer=checkpointer)
        return get_agent

    def get_current_user_dependency(self):
        """Return the agent's current-user dependency.

        Raises:
            Exception: if the agent has no ``get_current_user_override``.
        """
        if self.agent.get_current_user_override:
            return self.agent.get_current_user_override
        raise Exception("Get current user dependency not set...")
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
from fastapi import APIRouter, FastAPI
|
|
2
|
+
import os
|
|
3
|
+
from fastapi.concurrency import asynccontextmanager
|
|
4
|
+
from contentgrid_assistant_api.config import DatabaseConfig, AssistantExtensionConfig
|
|
5
|
+
from contentgrid_assistant_api.dependencies import DependencyResolver
|
|
6
|
+
from contentgrid_assistant_api.routers.thread_router import generate_agent_thread_router
|
|
7
|
+
from contentgrid_assistant_api.types.agents import Agent, AgentHomeResponse
|
|
8
|
+
|
|
9
|
+
def exit_uvicorn():
    """Ask the parent process (expected to be uvicorn) to shut down.

    Sends SIGINT, the same signal Ctrl+C produces, to ``os.getppid()``. This is a
    pragmatic way to "gracefully crash" the server.

    NOTE(review): the signal goes to whatever process is the parent. If this process
    was not spawned by a uvicorn supervisor, that is whichever process launched it.
    Confirm the deployment layout before relying on this.
    """
    from signal import SIGINT

    parent_pid = os.getppid()
    os.kill(parent_pid, SIGINT)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def generate_agent_home_router(agent : Agent, extension_config: AssistantExtensionConfig, database_config: DatabaseConfig) -> APIRouter:
    """Build the router for one agent: its home resource plus the thread routes.

    Args:
        agent: Agent to expose. Its name is the router tag and, when ``pg_dbname``
            is still at its default, the database name.
        extension_config: Passed through to the thread router.
        database_config: Database settings. Copied before any per-agent change, so
            the caller's (possibly shared) instance is never mutated.

    Returns:
        An ``APIRouter`` whose lifespan optionally wipes the database
        (``pg_reinitialize``) and then creates the tables on startup.
    """
    # The app passes the same DatabaseConfig for every agent; per-agent changes must not leak.
    database_config = database_config.model_copy()
    if database_config.pg_dbname == database_config.__class__.model_fields["pg_dbname"].default:
        # pg_dbname is still the default: use the agent name instead.
        database_config.pg_dbname = agent.name

    dep_resolver = DependencyResolver(agent=agent, db_config=database_config)

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        # Code before the yield runs at application startup, code after it at shutdown.
        if database_config.pg_reinitialize:
            # Clean up the threads and wipe the database.
            dep_resolver.db_conn_factory.wipe_database()
        dep_resolver.db_conn_factory.create_db_and_tables()
        yield

    router = APIRouter(lifespan=lifespan, tags=[agent.name])
    router.include_router(generate_agent_thread_router(dep_resolver, extension_config, tags=[agent.name]))

    @router.get("/", response_model=AgentHomeResponse, response_model_exclude_unset=True)
    def get_agent_home():
        return AgentHomeResponse(**agent.model_dump(), tags=[agent.name])

    return router
|