fastapi-fullstack 0.1.7__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {fastapi_fullstack-0.1.7.dist-info → fastapi_fullstack-0.1.15.dist-info}/METADATA +9 -2
- {fastapi_fullstack-0.1.7.dist-info → fastapi_fullstack-0.1.15.dist-info}/RECORD +71 -55
- fastapi_gen/__init__.py +6 -1
- fastapi_gen/cli.py +9 -0
- fastapi_gen/config.py +154 -2
- fastapi_gen/generator.py +34 -14
- fastapi_gen/prompts.py +172 -31
- fastapi_gen/template/VARIABLES.md +33 -4
- fastapi_gen/template/cookiecutter.json +10 -0
- fastapi_gen/template/hooks/post_gen_project.py +87 -2
- fastapi_gen/template/{{cookiecutter.project_slug}}/.env.prod.example +9 -0
- fastapi_gen/template/{{cookiecutter.project_slug}}/.gitlab-ci.yml +178 -0
- fastapi_gen/template/{{cookiecutter.project_slug}}/CLAUDE.md +3 -0
- fastapi_gen/template/{{cookiecutter.project_slug}}/README.md +334 -0
- fastapi_gen/template/{{cookiecutter.project_slug}}/backend/.env.example +32 -0
- fastapi_gen/template/{{cookiecutter.project_slug}}/backend/alembic/env.py +10 -1
- fastapi_gen/template/{{cookiecutter.project_slug}}/backend/app/admin.py +1 -1
- fastapi_gen/template/{{cookiecutter.project_slug}}/backend/app/agents/__init__.py +31 -0
- fastapi_gen/template/{{cookiecutter.project_slug}}/backend/app/agents/crewai_assistant.py +563 -0
- fastapi_gen/template/{{cookiecutter.project_slug}}/backend/app/agents/deepagents_assistant.py +526 -0
- fastapi_gen/template/{{cookiecutter.project_slug}}/backend/app/agents/langchain_assistant.py +4 -3
- fastapi_gen/template/{{cookiecutter.project_slug}}/backend/app/agents/langgraph_assistant.py +371 -0
- fastapi_gen/template/{{cookiecutter.project_slug}}/backend/app/api/routes/v1/agent.py +1472 -0
- fastapi_gen/template/{{cookiecutter.project_slug}}/backend/app/api/routes/v1/oauth.py +3 -7
- fastapi_gen/template/{{cookiecutter.project_slug}}/backend/app/commands/cleanup.py +2 -2
- fastapi_gen/template/{{cookiecutter.project_slug}}/backend/app/commands/seed.py +7 -2
- fastapi_gen/template/{{cookiecutter.project_slug}}/backend/app/core/config.py +44 -7
- fastapi_gen/template/{{cookiecutter.project_slug}}/backend/app/db/__init__.py +7 -0
- fastapi_gen/template/{{cookiecutter.project_slug}}/backend/app/db/base.py +42 -0
- fastapi_gen/template/{{cookiecutter.project_slug}}/backend/app/db/models/conversation.py +262 -1
- fastapi_gen/template/{{cookiecutter.project_slug}}/backend/app/db/models/item.py +76 -1
- fastapi_gen/template/{{cookiecutter.project_slug}}/backend/app/db/models/session.py +118 -1
- fastapi_gen/template/{{cookiecutter.project_slug}}/backend/app/db/models/user.py +158 -1
- fastapi_gen/template/{{cookiecutter.project_slug}}/backend/app/db/models/webhook.py +185 -3
- fastapi_gen/template/{{cookiecutter.project_slug}}/backend/app/main.py +29 -2
- fastapi_gen/template/{{cookiecutter.project_slug}}/backend/app/repositories/base.py +6 -0
- fastapi_gen/template/{{cookiecutter.project_slug}}/backend/app/repositories/session.py +4 -4
- fastapi_gen/template/{{cookiecutter.project_slug}}/backend/app/services/conversation.py +9 -9
- fastapi_gen/template/{{cookiecutter.project_slug}}/backend/app/services/session.py +6 -6
- fastapi_gen/template/{{cookiecutter.project_slug}}/backend/app/services/webhook.py +7 -7
- fastapi_gen/template/{{cookiecutter.project_slug}}/backend/app/worker/__init__.py +1 -1
- fastapi_gen/template/{{cookiecutter.project_slug}}/backend/app/worker/arq_app.py +165 -0
- fastapi_gen/template/{{cookiecutter.project_slug}}/backend/app/worker/tasks/__init__.py +10 -1
- fastapi_gen/template/{{cookiecutter.project_slug}}/backend/pyproject.toml +40 -0
- fastapi_gen/template/{{cookiecutter.project_slug}}/backend/tests/api/test_metrics.py +53 -0
- fastapi_gen/template/{{cookiecutter.project_slug}}/backend/tests/test_agents.py +2 -0
- fastapi_gen/template/{{cookiecutter.project_slug}}/docker-compose.dev.yml +6 -0
- fastapi_gen/template/{{cookiecutter.project_slug}}/docker-compose.prod.yml +100 -0
- fastapi_gen/template/{{cookiecutter.project_slug}}/docker-compose.yml +39 -0
- fastapi_gen/template/{{cookiecutter.project_slug}}/frontend/.env.example +5 -0
- fastapi_gen/template/{{cookiecutter.project_slug}}/frontend/src/components/chat/chat-container.tsx +28 -1
- fastapi_gen/template/{{cookiecutter.project_slug}}/frontend/src/components/chat/index.ts +1 -0
- fastapi_gen/template/{{cookiecutter.project_slug}}/frontend/src/components/chat/message-item.tsx +22 -4
- fastapi_gen/template/{{cookiecutter.project_slug}}/frontend/src/components/chat/message-list.tsx +23 -3
- fastapi_gen/template/{{cookiecutter.project_slug}}/frontend/src/components/chat/tool-approval-dialog.tsx +138 -0
- fastapi_gen/template/{{cookiecutter.project_slug}}/frontend/src/hooks/use-chat.ts +242 -18
- fastapi_gen/template/{{cookiecutter.project_slug}}/frontend/src/hooks/use-local-chat.ts +242 -17
- fastapi_gen/template/{{cookiecutter.project_slug}}/frontend/src/lib/constants.ts +1 -1
- fastapi_gen/template/{{cookiecutter.project_slug}}/frontend/src/types/chat.ts +57 -1
- fastapi_gen/template/{{cookiecutter.project_slug}}/kubernetes/configmap.yaml +63 -0
- fastapi_gen/template/{{cookiecutter.project_slug}}/kubernetes/deployment.yaml +242 -0
- fastapi_gen/template/{{cookiecutter.project_slug}}/kubernetes/ingress.yaml +44 -0
- fastapi_gen/template/{{cookiecutter.project_slug}}/kubernetes/kustomization.yaml +28 -0
- fastapi_gen/template/{{cookiecutter.project_slug}}/kubernetes/namespace.yaml +12 -0
- fastapi_gen/template/{{cookiecutter.project_slug}}/kubernetes/secret.yaml +59 -0
- fastapi_gen/template/{{cookiecutter.project_slug}}/kubernetes/service.yaml +23 -0
- fastapi_gen/template/{{cookiecutter.project_slug}}/nginx/nginx.conf +225 -0
- fastapi_gen/template/{{cookiecutter.project_slug}}/nginx/ssl/.gitkeep +18 -0
- {fastapi_fullstack-0.1.7.dist-info → fastapi_fullstack-0.1.15.dist-info}/WHEEL +0 -0
- {fastapi_fullstack-0.1.7.dist-info → fastapi_fullstack-0.1.15.dist-info}/entry_points.txt +0 -0
- {fastapi_fullstack-0.1.7.dist-info → fastapi_fullstack-0.1.15.dist-info}/licenses/LICENSE +0 -0
|
@@ -893,6 +893,1478 @@ async def agent_websocket(
|
|
|
893
893
|
# Try to send error, but don't fail if connection is closed
|
|
894
894
|
await manager.send_event(websocket, "error", {"message": str(e)})
|
|
895
895
|
|
|
896
|
+
except WebSocketDisconnect:
|
|
897
|
+
pass # Normal disconnect
|
|
898
|
+
finally:
|
|
899
|
+
manager.disconnect(websocket)
|
|
900
|
+
{%- elif cookiecutter.enable_ai_agent and cookiecutter.use_langgraph %}
|
|
901
|
+
"""AI Agent WebSocket routes with streaming support (LangGraph ReAct Agent)."""
|
|
902
|
+
|
|
903
|
+
import logging
|
|
904
|
+
from typing import Any
|
|
905
|
+
{%- if cookiecutter.enable_conversation_persistence and cookiecutter.use_database %}
|
|
906
|
+
from datetime import datetime, UTC
|
|
907
|
+
{%- if cookiecutter.use_postgresql %}
|
|
908
|
+
from uuid import UUID
|
|
909
|
+
{%- endif %}
|
|
910
|
+
{%- endif %}
|
|
911
|
+
|
|
912
|
+
from fastapi import APIRouter, WebSocket, WebSocketDisconnect{%- if cookiecutter.websocket_auth_jwt %}, Depends{%- endif %}{%- if cookiecutter.websocket_auth_api_key %}, Query{%- endif %}
|
|
913
|
+
|
|
914
|
+
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage, SystemMessage, ToolMessage
|
|
915
|
+
|
|
916
|
+
from app.agents.langgraph_assistant import AgentContext, get_agent
|
|
917
|
+
{%- if cookiecutter.websocket_auth_jwt %}
|
|
918
|
+
from app.api.deps import get_current_user_ws
|
|
919
|
+
from app.db.models.user import User
|
|
920
|
+
{%- endif %}
|
|
921
|
+
{%- if cookiecutter.websocket_auth_api_key %}
|
|
922
|
+
from app.core.config import settings
|
|
923
|
+
{%- endif %}
|
|
924
|
+
{%- if cookiecutter.enable_conversation_persistence and (cookiecutter.use_postgresql or cookiecutter.use_sqlite) %}
|
|
925
|
+
from app.db.session import get_db_context
|
|
926
|
+
from app.api.deps import ConversationSvc, get_conversation_service
|
|
927
|
+
from app.schemas.conversation import ConversationCreate, MessageCreate, ToolCallCreate, ToolCallComplete
|
|
928
|
+
{%- elif cookiecutter.enable_conversation_persistence and cookiecutter.use_mongodb %}
|
|
929
|
+
from app.api.deps import ConversationSvc, get_conversation_service
|
|
930
|
+
from app.schemas.conversation import ConversationCreate, MessageCreate, ToolCallCreate, ToolCallComplete
|
|
931
|
+
{%- endif %}
|
|
932
|
+
|
|
933
|
+
logger = logging.getLogger(__name__)
|
|
934
|
+
|
|
935
|
+
router = APIRouter()
|
|
936
|
+
|
|
937
|
+
|
|
938
|
+
class AgentConnectionManager:
    """Keep track of open agent WebSockets and push JSON events to them."""

    def __init__(self) -> None:
        # Sockets accepted via ``connect`` that have not been removed yet.
        self.active_connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket) -> None:
        """Complete the WebSocket handshake and start tracking the socket."""
        await websocket.accept()
        self.active_connections.append(websocket)
        logger.info(f"Agent WebSocket connected. Total connections: {len(self.active_connections)}")

    def disconnect(self, websocket: WebSocket) -> None:
        """Stop tracking *websocket*; an untracked socket is silently ignored."""
        try:
            self.active_connections.remove(websocket)
        except ValueError:
            # Not tracked (never connected, or already removed) - nothing to do.
            pass
        logger.info(f"Agent WebSocket disconnected. Total connections: {len(self.active_connections)}")

    async def send_event(self, websocket: WebSocket, event_type: str, data: Any) -> bool:
        """Send ``{"type": event_type, "data": data}`` as JSON to one client.

        Returns:
            True when the frame was sent; False when the socket was already
            closed (``WebSocketDisconnect`` or ``RuntimeError`` from the send).
        """
        envelope = {"type": event_type, "data": data}
        try:
            await websocket.send_json(envelope)
        except (WebSocketDisconnect, RuntimeError):
            # The connection has already gone away.
            return False
        return True


manager = AgentConnectionManager()
|
|
970
|
+
|
|
971
|
+
|
|
972
|
+
def build_message_history(
|
|
973
|
+
history: list[dict[str, str]]
|
|
974
|
+
) -> list[HumanMessage | AIMessage | SystemMessage]:
|
|
975
|
+
"""Convert conversation history to LangChain message format."""
|
|
976
|
+
messages: list[HumanMessage | AIMessage | SystemMessage] = []
|
|
977
|
+
|
|
978
|
+
for msg in history:
|
|
979
|
+
if msg["role"] == "user":
|
|
980
|
+
messages.append(HumanMessage(content=msg["content"]))
|
|
981
|
+
elif msg["role"] == "assistant":
|
|
982
|
+
messages.append(AIMessage(content=msg["content"]))
|
|
983
|
+
elif msg["role"] == "system":
|
|
984
|
+
messages.append(SystemMessage(content=msg["content"]))
|
|
985
|
+
|
|
986
|
+
return messages
|
|
987
|
+
|
|
988
|
+
{%- if cookiecutter.websocket_auth_api_key %}
|
|
989
|
+
|
|
990
|
+
|
|
991
|
+
async def verify_api_key(api_key: str) -> bool:
    """Verify the API key for WebSocket authentication.

    The client key is compared with ``settings.API_KEY`` using
    ``secrets.compare_digest``. Its running time does not depend on how many
    leading characters match, whereas a plain ``==`` can leak that through
    timing. An empty client key, or a configured key that is empty or not a
    string, is always rejected. This stops an unconfigured server from being
    opened with an empty ``api_key`` query parameter.

    Args:
        api_key: Key supplied by the client in the ``api_key`` query parameter.

    Returns:
        True if the key matches the configured key, False otherwise.
    """
    import secrets

    expected = settings.API_KEY
    # A non-str setting could never equal a str key under the old ``==`` check;
    # keep rejecting it instead of letting compare_digest raise TypeError.
    if not isinstance(expected, str) or not expected or not api_key:
        return False
    # Compare UTF-8 bytes: compare_digest rejects str arguments with non-ASCII characters.
    return secrets.compare_digest(api_key.encode("utf-8"), expected.encode("utf-8"))
|
|
994
|
+
{%- endif %}
|
|
995
|
+
|
|
996
|
+
|
|
997
|
+
@router.websocket("/ws/agent")
async def agent_websocket(
    websocket: WebSocket,
{%- if cookiecutter.websocket_auth_jwt %}
    user: User = Depends(get_current_user_ws),
{%- elif cookiecutter.websocket_auth_api_key %}
    api_key: str = Query(..., alias="api_key"),
{%- endif %}
) -> None:
    """WebSocket endpoint for LangGraph ReAct agent with streaming support.

    Each client message is run through ``get_agent().stream()``. That call
    yields ``(stream_mode, data)`` pairs, and the "messages" and "updates"
    modes are translated into these events:
    - user_prompt: When user input is received
    - model_request_start: When model request begins
    - text_delta: Streaming text from the model
    - tool_call: When a tool is called
    - tool_result: When a tool returns a result
    - final_result: When the final result is ready
    - complete: When processing is complete
    - error: When an error occurs (also sent for an empty "message")

    Every event is sent as ``{"type": <event name>, "data": <payload>}``.

    Expected input message format:
    {
        "message": "user message here",
        "history": [{"role": "user|assistant|system", "content": "..."}]{% if cookiecutter.enable_conversation_persistence and cookiecutter.use_database %},
        "conversation_id": "optional-uuid-to-continue-existing-conversation"{% endif %}
    }

    If "history" is present, it replaces this connection's server-side history.
    Otherwise the history accumulated on this connection is used.
{%- if cookiecutter.websocket_auth_jwt %}

    Authentication: Requires a valid JWT token passed as a query parameter or header.
{%- elif cookiecutter.websocket_auth_api_key %}

    Authentication: Requires a valid API key passed as 'api_key' query parameter.
    Example: ws://localhost:{{ cookiecutter.backend_port }}/api/v1/ws/agent?api_key=your-api-key
{%- endif %}
{%- if cookiecutter.enable_conversation_persistence and cookiecutter.use_database %}

    Persistence: Set 'conversation_id' to continue an existing conversation.
    If not provided, a new conversation is created. The conversation_id is
    returned in the 'conversation_created' event.
{%- endif %}
    """
{%- if cookiecutter.websocket_auth_api_key %}
    # Verify API key before accepting connection.
    # NOTE(review): this closes before accept(). Depending on the ASGI server,
    # the client may see a handshake rejection instead of close code 4001 - confirm.
    if not await verify_api_key(api_key):
        await websocket.close(code=4001, reason="Invalid API key")
        return
{%- endif %}

    await manager.connect(websocket)

    # Conversation state per connection
    conversation_history: list[dict[str, str]] = []
    context: AgentContext = {}
{%- if cookiecutter.websocket_auth_jwt %}
    context["user_id"] = str(user.id) if user else None
    context["user_name"] = user.email if user else None
{%- endif %}
{%- if cookiecutter.enable_conversation_persistence and cookiecutter.use_database %}
    # Persisted conversation this connection writes to. It is set from the
    # client's "conversation_id" or when a new conversation is created.
    current_conversation_id: str | None = None
{%- endif %}

    try:
        while True:
            # Receive user message.
            # NOTE(review): assumes the client sends a JSON object. Any other JSON
            # value makes data.get() raise AttributeError. That is not caught inside
            # this loop, so it ends the connection.
            data = await websocket.receive_json()
            user_message = data.get("message", "")
            # Optionally accept history from client (or use server-side tracking)
            if "history" in data:
                conversation_history = data["history"]

            if not user_message:
                await manager.send_event(websocket, "error", {"message": "Empty message"})
                continue
{%- if cookiecutter.enable_conversation_persistence and (cookiecutter.use_postgresql or cookiecutter.use_sqlite) %}

            # Handle conversation persistence. Failures are logged and the
            # request continues without persistence.
            try:
{%- if cookiecutter.use_postgresql %}
                async with get_db_context() as db:
                    conv_service = get_conversation_service(db)

                    # Get or create conversation
                    requested_conv_id = data.get("conversation_id")
                    if requested_conv_id:
                        # NOTE(review): the id is adopted before it is verified. A malformed
                        # id (UUID() raises ValueError) stays current after the failure is
                        # logged below. So may an unknown id, if get_conversation raises
                        # for one. Confirm this is intended.
                        current_conversation_id = requested_conv_id
                        # Verify conversation exists
                        await conv_service.get_conversation(UUID(requested_conv_id))
                    elif not current_conversation_id:
                        # Create new conversation
                        conv_data = ConversationCreate(
{%- if cookiecutter.websocket_auth_jwt %}
                            user_id=user.id,
{%- endif %}
                            title=user_message[:50] if len(user_message) > 50 else user_message,
                        )
                        conversation = await conv_service.create_conversation(conv_data)
                        current_conversation_id = str(conversation.id)
                        await manager.send_event(
                            websocket,
                            "conversation_created",
                            {"conversation_id": current_conversation_id},
                        )

                    # Save user message
                    await conv_service.add_message(
                        UUID(current_conversation_id),
                        MessageCreate(role="user", content=user_message),
                    )
{%- else %}
                # NOTE(review): get_db_session is not imported by this template branch.
                # The import block only brings in get_db_context. As written, this
                # raises NameError, which the ``except Exception`` below logs, so
                # SQLite persistence would silently never happen. Verify the import.
                with get_db_session() as db:
                    conv_service = get_conversation_service(db)

                    # Get or create conversation
                    requested_conv_id = data.get("conversation_id")
                    if requested_conv_id:
                        current_conversation_id = requested_conv_id
                        conv_service.get_conversation(requested_conv_id)
                    elif not current_conversation_id:
                        # Create new conversation
                        conv_data = ConversationCreate(
{%- if cookiecutter.websocket_auth_jwt %}
                            user_id=str(user.id),
{%- endif %}
                            title=user_message[:50] if len(user_message) > 50 else user_message,
                        )
                        conversation = conv_service.create_conversation(conv_data)
                        current_conversation_id = str(conversation.id)
                        await manager.send_event(
                            websocket,
                            "conversation_created",
                            {"conversation_id": current_conversation_id},
                        )

                    # Save user message
                    conv_service.add_message(
                        current_conversation_id,
                        MessageCreate(role="user", content=user_message),
                    )
{%- endif %}
            except Exception as e:
                logger.warning(f"Failed to persist conversation: {e}")
                # Continue without persistence
{%- elif cookiecutter.enable_conversation_persistence and cookiecutter.use_mongodb %}

            # Handle conversation persistence (MongoDB)
            # NOTE(review): unlike the SQL branch, this block has no try/except.
            # A persistence error here propagates out of the loop and ends the
            # connection.
            conv_service = get_conversation_service()

            requested_conv_id = data.get("conversation_id")
            if requested_conv_id:
                current_conversation_id = requested_conv_id
                await conv_service.get_conversation(requested_conv_id)
            elif not current_conversation_id:
                conv_data = ConversationCreate(
{%- if cookiecutter.websocket_auth_jwt %}
                    user_id=str(user.id),
{%- endif %}
                    title=user_message[:50] if len(user_message) > 50 else user_message,
                )
                conversation = await conv_service.create_conversation(conv_data)
                current_conversation_id = str(conversation.id)
                await manager.send_event(
                    websocket,
                    "conversation_created",
                    {"conversation_id": current_conversation_id},
                )

            # Save user message
            await conv_service.add_message(
                current_conversation_id,
                MessageCreate(role="user", content=user_message),
            )
{%- endif %}

            await manager.send_event(websocket, "user_prompt", {"content": user_message})

            try:
                assistant = get_agent()

                final_output = ""
                # NOTE(review): tool_events is appended to but never read in this function.
                tool_events: list[Any] = []
                # Tool-call ids already announced to the client (shared by both stream modes).
                seen_tool_call_ids: set[str] = set()

                await manager.send_event(websocket, "model_request_start", {})

                # Use LangGraph's astream with messages and updates modes.
                # The loop variable rebinds ``data`` (the client payload). The
                # payload is not read again later in this iteration.
                async for stream_mode, data in assistant.stream(
                    user_message,
                    history=conversation_history,
                    context=context,
                ):
                    if stream_mode == "messages":
                        chunk, _metadata = data

                        if isinstance(chunk, AIMessageChunk):
                            if chunk.content:
                                # Content is either a plain string or a list of
                                # blocks. Only "text" dict blocks and bare strings
                                # are forwarded.
                                text_content = ""
                                if isinstance(chunk.content, str):
                                    text_content = chunk.content
                                elif isinstance(chunk.content, list):
                                    for block in chunk.content:
                                        if isinstance(block, dict) and block.get("type") == "text":
                                            text_content += block.get("text", "")
                                        elif isinstance(block, str):
                                            text_content += block

                                if text_content:
                                    await manager.send_event(
                                        websocket,
                                        "text_delta",
                                        {"content": text_content},
                                    )
                                    final_output += text_content

                            # Handle tool call chunks: a tool call is announced as
                            # soon as a streamed chunk carries both its id and name.
                            # NOTE(review): args are always sent as {} here, and the id is
                            # added to seen_tool_call_ids, so the "updates" branch below
                            # skips the same call. The client may never receive the real
                            # args, and tool_events misses the call. Confirm intended.
                            if chunk.tool_call_chunks:
                                for tc_chunk in chunk.tool_call_chunks:
                                    tc_id = tc_chunk.get("id")
                                    tc_name = tc_chunk.get("name")
                                    if tc_id and tc_name and tc_id not in seen_tool_call_ids:
                                        seen_tool_call_ids.add(tc_id)
                                        await manager.send_event(
                                            websocket,
                                            "tool_call",
                                            {
                                                "tool_name": tc_name,
                                                "args": {},
                                                "tool_call_id": tc_id,
                                            },
                                        )

                    elif stream_mode == "updates":
                        # Handle state updates from nodes: "tools" yields tool
                        # results, "agent" yields tool calls not yet announced.
                        for node_name, update in data.items():
                            if node_name == "tools":
                                # Tool node completed - extract tool results
                                for msg in update.get("messages", []):
                                    if isinstance(msg, ToolMessage):
                                        await manager.send_event(
                                            websocket,
                                            "tool_result",
                                            {
                                                "tool_call_id": msg.tool_call_id,
                                                "content": msg.content,
                                            },
                                        )
                            elif node_name == "agent":
                                # Agent node completed - check for tool calls
                                for msg in update.get("messages", []):
                                    if isinstance(msg, AIMessage) and msg.tool_calls:
                                        for tc in msg.tool_calls:
                                            tc_id = tc.get("id", "")
                                            if tc_id not in seen_tool_call_ids:
                                                seen_tool_call_ids.add(tc_id)
                                                tool_events.append(tc)
                                                await manager.send_event(
                                                    websocket,
                                                    "tool_call",
                                                    {
                                                        "tool_name": tc.get("name", ""),
                                                        "args": tc.get("args", {}),
                                                        "tool_call_id": tc_id,
                                                    },
                                                )

                await manager.send_event(
                    websocket,
                    "final_result",
                    {"output": final_output},
                )

                # Update conversation history. The user turn is always recorded;
                # the assistant turn is recorded only when text was produced.
                conversation_history.append({"role": "user", "content": user_message})
                if final_output:
                    conversation_history.append(
                        {"role": "assistant", "content": final_output}
                    )
{%- if cookiecutter.enable_conversation_persistence and (cookiecutter.use_postgresql or cookiecutter.use_sqlite) %}

                # Save assistant response to database (best effort: failures are logged)
                if current_conversation_id and final_output:
                    try:
{%- if cookiecutter.use_postgresql %}
                        async with get_db_context() as db:
                            conv_service = get_conversation_service(db)
                            await conv_service.add_message(
                                UUID(current_conversation_id),
                                MessageCreate(
                                    role="assistant",
                                    content=final_output,
                                    model_name=assistant.model_name if hasattr(assistant, "model_name") else None,
                                ),
                            )
{%- else %}
                        # NOTE(review): get_db_session is not imported by this template
                        # branch. See the matching note on the user-message save above.
                        with get_db_session() as db:
                            conv_service = get_conversation_service(db)
                            conv_service.add_message(
                                current_conversation_id,
                                MessageCreate(
                                    role="assistant",
                                    content=final_output,
                                    model_name=assistant.model_name if hasattr(assistant, "model_name") else None,
                                ),
                            )
{%- endif %}
                    except Exception as e:
                        logger.warning(f"Failed to persist assistant response: {e}")
{%- elif cookiecutter.enable_conversation_persistence and cookiecutter.use_mongodb %}

                # Save assistant response to database (best effort: failures are logged)
                if current_conversation_id and final_output:
                    try:
                        await conv_service.add_message(
                            current_conversation_id,
                            MessageCreate(
                                role="assistant",
                                content=final_output,
                                model_name=assistant.model_name if hasattr(assistant, "model_name") else None,
                            ),
                        )
                    except Exception as e:
                        logger.warning(f"Failed to persist assistant response: {e}")
{%- endif %}

                await manager.send_event(websocket, "complete", {
{%- if cookiecutter.enable_conversation_persistence and cookiecutter.use_database %}
                    "conversation_id": current_conversation_id,
{%- endif %}
                })

            except WebSocketDisconnect:
                # Client disconnected during processing - this is normal
                logger.info("Client disconnected during agent processing")
                break
            except Exception as e:
                logger.exception(f"Error processing agent request: {e}")
                # Try to send error, but don't fail if connection is closed.
                # NOTE(review): str(e) is sent to the client verbatim and may expose
                # internal details.
                await manager.send_event(websocket, "error", {"message": str(e)})

    except WebSocketDisconnect:
        pass  # Normal disconnect
    finally:
        # Always untrack the socket, whatever ended the session.
        manager.disconnect(websocket)
|
|
1342
|
+
{%- elif cookiecutter.enable_ai_agent and cookiecutter.use_crewai %}
|
|
1343
|
+
"""AI Agent WebSocket routes with streaming support (CrewAI Multi-Agent)."""
|
|
1344
|
+
|
|
1345
|
+
import logging
|
|
1346
|
+
from typing import Any
|
|
1347
|
+
{%- if cookiecutter.enable_conversation_persistence and cookiecutter.use_database %}
|
|
1348
|
+
from datetime import datetime, UTC
|
|
1349
|
+
{%- if cookiecutter.use_postgresql %}
|
|
1350
|
+
from uuid import UUID
|
|
1351
|
+
{%- endif %}
|
|
1352
|
+
{%- endif %}
|
|
1353
|
+
|
|
1354
|
+
from fastapi import APIRouter, WebSocket, WebSocketDisconnect{%- if cookiecutter.websocket_auth_jwt %}, Depends{%- endif %}{%- if cookiecutter.websocket_auth_api_key %}, Query{%- endif %}
|
|
1355
|
+
|
|
1356
|
+
from app.agents.crewai_assistant import CrewContext, get_crew
|
|
1357
|
+
{%- if cookiecutter.websocket_auth_jwt %}
|
|
1358
|
+
from app.api.deps import get_current_user_ws
|
|
1359
|
+
from app.db.models.user import User
|
|
1360
|
+
{%- endif %}
|
|
1361
|
+
{%- if cookiecutter.websocket_auth_api_key %}
|
|
1362
|
+
from app.core.config import settings
|
|
1363
|
+
{%- endif %}
|
|
1364
|
+
{%- if cookiecutter.enable_conversation_persistence and (cookiecutter.use_postgresql or cookiecutter.use_sqlite) %}
|
|
1365
|
+
from app.db.session import get_db_context
|
|
1366
|
+
from app.api.deps import ConversationSvc, get_conversation_service
|
|
1367
|
+
from app.schemas.conversation import ConversationCreate, MessageCreate
|
|
1368
|
+
{%- elif cookiecutter.enable_conversation_persistence and cookiecutter.use_mongodb %}
|
|
1369
|
+
from app.api.deps import ConversationSvc, get_conversation_service
|
|
1370
|
+
from app.schemas.conversation import ConversationCreate, MessageCreate
|
|
1371
|
+
{%- endif %}
|
|
1372
|
+
|
|
1373
|
+
logger = logging.getLogger(__name__)
|
|
1374
|
+
|
|
1375
|
+
router = APIRouter()
|
|
1376
|
+
|
|
1377
|
+
|
|
1378
|
+
class AgentConnectionManager:
    """Keep track of open agent WebSockets and push JSON events to them."""

    def __init__(self) -> None:
        # Sockets accepted via ``connect`` that have not been removed yet.
        self.active_connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket) -> None:
        """Complete the WebSocket handshake and start tracking the socket."""
        await websocket.accept()
        self.active_connections.append(websocket)
        logger.info(f"Agent WebSocket connected. Total connections: {len(self.active_connections)}")

    def disconnect(self, websocket: WebSocket) -> None:
        """Stop tracking *websocket*; an untracked socket is silently ignored."""
        try:
            self.active_connections.remove(websocket)
        except ValueError:
            # Not tracked (never connected, or already removed) - nothing to do.
            pass
        logger.info(f"Agent WebSocket disconnected. Total connections: {len(self.active_connections)}")

    async def send_event(self, websocket: WebSocket, event_type: str, data: Any) -> bool:
        """Send ``{"type": event_type, "data": data}`` as JSON to one client.

        Returns:
            True when the frame was sent; False when the socket was already
            closed (``WebSocketDisconnect`` or ``RuntimeError`` from the send).
        """
        envelope = {"type": event_type, "data": data}
        try:
            await websocket.send_json(envelope)
        except (WebSocketDisconnect, RuntimeError):
            # The connection has already gone away.
            return False
        return True


manager = AgentConnectionManager()
|
|
1410
|
+
|
|
1411
|
+
{%- if cookiecutter.websocket_auth_api_key %}
|
|
1412
|
+
|
|
1413
|
+
|
|
1414
|
+
async def verify_api_key(api_key: str) -> bool:
    """Verify the API key for WebSocket authentication.

    The client key is compared with ``settings.API_KEY`` using
    ``secrets.compare_digest``. Its running time does not depend on how many
    leading characters match, whereas a plain ``==`` can leak that through
    timing. An empty client key, or a configured key that is empty or not a
    string, is always rejected. This stops an unconfigured server from being
    opened with an empty ``api_key`` query parameter.

    Args:
        api_key: Key supplied by the client in the ``api_key`` query parameter.

    Returns:
        True if the key matches the configured key, False otherwise.
    """
    import secrets

    expected = settings.API_KEY
    # A non-str setting could never equal a str key under the old ``==`` check;
    # keep rejecting it instead of letting compare_digest raise TypeError.
    if not isinstance(expected, str) or not expected or not api_key:
        return False
    # Compare UTF-8 bytes: compare_digest rejects str arguments with non-ASCII characters.
    return secrets.compare_digest(api_key.encode("utf-8"), expected.encode("utf-8"))
|
|
1417
|
+
{%- endif %}
|
|
1418
|
+
|
|
1419
|
+
|
|
1420
|
+
@router.websocket("/ws/agent")
|
|
1421
|
+
async def agent_websocket(
|
|
1422
|
+
websocket: WebSocket,
|
|
1423
|
+
{%- if cookiecutter.websocket_auth_jwt %}
|
|
1424
|
+
user: User = Depends(get_current_user_ws),
|
|
1425
|
+
{%- elif cookiecutter.websocket_auth_api_key %}
|
|
1426
|
+
api_key: str = Query(..., alias="api_key"),
|
|
1427
|
+
{%- endif %}
|
|
1428
|
+
) -> None:
|
|
1429
|
+
"""WebSocket endpoint for CrewAI multi-agent with streaming support.
|
|
1430
|
+
|
|
1431
|
+
Uses CrewAI to stream crew execution events including:
|
|
1432
|
+
- user_prompt: When user input is received
|
|
1433
|
+
- task_start: When a task begins execution
|
|
1434
|
+
- agent_action: When an agent takes an action
|
|
1435
|
+
- task_complete: When a task finishes
|
|
1436
|
+
- crew_complete: When all tasks are done
|
|
1437
|
+
- final_result: When the final result is ready
|
|
1438
|
+
- complete: When processing is complete
|
|
1439
|
+
- error: When an error occurs
|
|
1440
|
+
|
|
1441
|
+
Expected input message format:
|
|
1442
|
+
{
|
|
1443
|
+
"message": "user message here",
|
|
1444
|
+
"history": [{"role": "user|assistant|system", "content": "..."}]{% if cookiecutter.enable_conversation_persistence and cookiecutter.use_database %},
|
|
1445
|
+
"conversation_id": "optional-uuid-to-continue-existing-conversation"{% endif %}
|
|
1446
|
+
}
|
|
1447
|
+
{%- if cookiecutter.websocket_auth_jwt %}
|
|
1448
|
+
|
|
1449
|
+
Authentication: Requires a valid JWT token passed as a query parameter or header.
|
|
1450
|
+
{%- elif cookiecutter.websocket_auth_api_key %}
|
|
1451
|
+
|
|
1452
|
+
Authentication: Requires a valid API key passed as 'api_key' query parameter.
|
|
1453
|
+
Example: ws://localhost:{{ cookiecutter.backend_port }}/api/v1/ws/agent?api_key=your-api-key
|
|
1454
|
+
{%- endif %}
|
|
1455
|
+
{%- if cookiecutter.enable_conversation_persistence and cookiecutter.use_database %}
|
|
1456
|
+
|
|
1457
|
+
Persistence: Set 'conversation_id' to continue an existing conversation.
|
|
1458
|
+
If not provided, a new conversation is created. The conversation_id is
|
|
1459
|
+
returned in the 'conversation_created' event.
|
|
1460
|
+
{%- endif %}
|
|
1461
|
+
"""
|
|
1462
|
+
{%- if cookiecutter.websocket_auth_api_key %}
|
|
1463
|
+
# Verify API key before accepting connection
|
|
1464
|
+
if not await verify_api_key(api_key):
|
|
1465
|
+
await websocket.close(code=4001, reason="Invalid API key")
|
|
1466
|
+
return
|
|
1467
|
+
{%- endif %}
|
|
1468
|
+
|
|
1469
|
+
await manager.connect(websocket)
|
|
1470
|
+
|
|
1471
|
+
# Conversation state per connection
|
|
1472
|
+
conversation_history: list[dict[str, str]] = []
|
|
1473
|
+
context: CrewContext = {}
|
|
1474
|
+
{%- if cookiecutter.websocket_auth_jwt %}
|
|
1475
|
+
context["user_id"] = str(user.id) if user else None
|
|
1476
|
+
context["user_name"] = user.email if user else None
|
|
1477
|
+
{%- endif %}
|
|
1478
|
+
{%- if cookiecutter.enable_conversation_persistence and cookiecutter.use_database %}
|
|
1479
|
+
current_conversation_id: str | None = None
|
|
1480
|
+
{%- endif %}
|
|
1481
|
+
|
|
1482
|
+
try:
|
|
1483
|
+
while True:
|
|
1484
|
+
# Receive user message
|
|
1485
|
+
data = await websocket.receive_json()
|
|
1486
|
+
user_message = data.get("message", "")
|
|
1487
|
+
# Optionally accept history from client (or use server-side tracking)
|
|
1488
|
+
if "history" in data:
|
|
1489
|
+
conversation_history = data["history"]
|
|
1490
|
+
|
|
1491
|
+
if not user_message:
|
|
1492
|
+
await manager.send_event(websocket, "error", {"message": "Empty message"})
|
|
1493
|
+
continue
|
|
1494
|
+
|
|
1495
|
+
{%- if cookiecutter.enable_conversation_persistence and (cookiecutter.use_postgresql or cookiecutter.use_sqlite) %}
|
|
1496
|
+
|
|
1497
|
+
# Handle conversation persistence
|
|
1498
|
+
try:
|
|
1499
|
+
{%- if cookiecutter.use_postgresql %}
|
|
1500
|
+
async with get_db_context() as db:
|
|
1501
|
+
conv_service = get_conversation_service(db)
|
|
1502
|
+
|
|
1503
|
+
# Get or create conversation
|
|
1504
|
+
requested_conv_id = data.get("conversation_id")
|
|
1505
|
+
if requested_conv_id:
|
|
1506
|
+
current_conversation_id = requested_conv_id
|
|
1507
|
+
# Verify conversation exists
|
|
1508
|
+
await conv_service.get_conversation(UUID(requested_conv_id))
|
|
1509
|
+
elif not current_conversation_id:
|
|
1510
|
+
# Create new conversation
|
|
1511
|
+
conv_data = ConversationCreate(
|
|
1512
|
+
{%- if cookiecutter.websocket_auth_jwt %}
|
|
1513
|
+
user_id=user.id,
|
|
1514
|
+
{%- endif %}
|
|
1515
|
+
title=user_message[:50] if len(user_message) > 50 else user_message,
|
|
1516
|
+
)
|
|
1517
|
+
conversation = await conv_service.create_conversation(conv_data)
|
|
1518
|
+
current_conversation_id = str(conversation.id)
|
|
1519
|
+
await manager.send_event(
|
|
1520
|
+
websocket,
|
|
1521
|
+
"conversation_created",
|
|
1522
|
+
{"conversation_id": current_conversation_id},
|
|
1523
|
+
)
|
|
1524
|
+
|
|
1525
|
+
# Save user message
|
|
1526
|
+
await conv_service.add_message(
|
|
1527
|
+
UUID(current_conversation_id),
|
|
1528
|
+
MessageCreate(role="user", content=user_message),
|
|
1529
|
+
)
|
|
1530
|
+
{%- else %}
|
|
1531
|
+
with get_db_session() as db:
|
|
1532
|
+
conv_service = get_conversation_service(db)
|
|
1533
|
+
|
|
1534
|
+
# Get or create conversation
|
|
1535
|
+
requested_conv_id = data.get("conversation_id")
|
|
1536
|
+
if requested_conv_id:
|
|
1537
|
+
current_conversation_id = requested_conv_id
|
|
1538
|
+
conv_service.get_conversation(requested_conv_id)
|
|
1539
|
+
elif not current_conversation_id:
|
|
1540
|
+
# Create new conversation
|
|
1541
|
+
conv_data = ConversationCreate(
|
|
1542
|
+
{%- if cookiecutter.websocket_auth_jwt %}
|
|
1543
|
+
user_id=str(user.id),
|
|
1544
|
+
{%- endif %}
|
|
1545
|
+
title=user_message[:50] if len(user_message) > 50 else user_message,
|
|
1546
|
+
)
|
|
1547
|
+
conversation = conv_service.create_conversation(conv_data)
|
|
1548
|
+
current_conversation_id = str(conversation.id)
|
|
1549
|
+
await manager.send_event(
|
|
1550
|
+
websocket,
|
|
1551
|
+
"conversation_created",
|
|
1552
|
+
{"conversation_id": current_conversation_id},
|
|
1553
|
+
)
|
|
1554
|
+
|
|
1555
|
+
# Save user message
|
|
1556
|
+
conv_service.add_message(
|
|
1557
|
+
current_conversation_id,
|
|
1558
|
+
MessageCreate(role="user", content=user_message),
|
|
1559
|
+
)
|
|
1560
|
+
{%- endif %}
|
|
1561
|
+
except Exception as e:
|
|
1562
|
+
logger.warning(f"Failed to persist conversation: {e}")
|
|
1563
|
+
# Continue without persistence
|
|
1564
|
+
{%- elif cookiecutter.enable_conversation_persistence and cookiecutter.use_mongodb %}
|
|
1565
|
+
|
|
1566
|
+
# Handle conversation persistence (MongoDB)
|
|
1567
|
+
conv_service = get_conversation_service()
|
|
1568
|
+
|
|
1569
|
+
requested_conv_id = data.get("conversation_id")
|
|
1570
|
+
if requested_conv_id:
|
|
1571
|
+
current_conversation_id = requested_conv_id
|
|
1572
|
+
await conv_service.get_conversation(requested_conv_id)
|
|
1573
|
+
elif not current_conversation_id:
|
|
1574
|
+
conv_data = ConversationCreate(
|
|
1575
|
+
{%- if cookiecutter.websocket_auth_jwt %}
|
|
1576
|
+
user_id=str(user.id),
|
|
1577
|
+
{%- endif %}
|
|
1578
|
+
title=user_message[:50] if len(user_message) > 50 else user_message,
|
|
1579
|
+
)
|
|
1580
|
+
conversation = await conv_service.create_conversation(conv_data)
|
|
1581
|
+
current_conversation_id = str(conversation.id)
|
|
1582
|
+
await manager.send_event(
|
|
1583
|
+
websocket,
|
|
1584
|
+
"conversation_created",
|
|
1585
|
+
{"conversation_id": current_conversation_id},
|
|
1586
|
+
)
|
|
1587
|
+
|
|
1588
|
+
# Save user message
|
|
1589
|
+
await conv_service.add_message(
|
|
1590
|
+
current_conversation_id,
|
|
1591
|
+
MessageCreate(role="user", content=user_message),
|
|
1592
|
+
)
|
|
1593
|
+
{%- endif %}
|
|
1594
|
+
|
|
1595
|
+
await manager.send_event(websocket, "user_prompt", {"content": user_message})
|
|
1596
|
+
|
|
1597
|
+
try:
|
|
1598
|
+
crew_assistant = get_crew()
|
|
1599
|
+
|
|
1600
|
+
final_output = ""
|
|
1601
|
+
|
|
1602
|
+
await manager.send_event(websocket, "crew_start", {
|
|
1603
|
+
"crew_name": crew_assistant.config.name,
|
|
1604
|
+
"process": crew_assistant.config.process,
|
|
1605
|
+
})
|
|
1606
|
+
|
|
1607
|
+
# Stream crew execution events
|
|
1608
|
+
async for event in crew_assistant.stream(
|
|
1609
|
+
user_message,
|
|
1610
|
+
history=conversation_history,
|
|
1611
|
+
context=context,
|
|
1612
|
+
):
|
|
1613
|
+
event_type = event.get("type", "unknown")
|
|
1614
|
+
|
|
1615
|
+
# Crew lifecycle events
|
|
1616
|
+
if event_type == "crew_started":
|
|
1617
|
+
await manager.send_event(
|
|
1618
|
+
websocket,
|
|
1619
|
+
"crew_started",
|
|
1620
|
+
{
|
|
1621
|
+
"crew_name": event.get("crew_name", ""),
|
|
1622
|
+
"crew_id": event.get("crew_id", ""),
|
|
1623
|
+
},
|
|
1624
|
+
)
|
|
1625
|
+
|
|
1626
|
+
# Agent events
|
|
1627
|
+
elif event_type == "agent_started":
|
|
1628
|
+
await manager.send_event(
|
|
1629
|
+
websocket,
|
|
1630
|
+
"agent_started",
|
|
1631
|
+
{
|
|
1632
|
+
"agent": event.get("agent", ""),
|
|
1633
|
+
"task": event.get("task", ""),
|
|
1634
|
+
},
|
|
1635
|
+
)
|
|
1636
|
+
|
|
1637
|
+
elif event_type == "agent_completed":
|
|
1638
|
+
agent_name = event.get("agent", "")
|
|
1639
|
+
agent_output = event.get("output", "")
|
|
1640
|
+
await manager.send_event(
|
|
1641
|
+
websocket,
|
|
1642
|
+
"agent_completed",
|
|
1643
|
+
{
|
|
1644
|
+
"agent": agent_name,
|
|
1645
|
+
"output": agent_output,
|
|
1646
|
+
},
|
|
1647
|
+
)
|
|
1648
|
+
{%- if cookiecutter.enable_conversation_persistence and (cookiecutter.use_postgresql or cookiecutter.use_sqlite) %}
|
|
1649
|
+
# Save agent's output as a separate message
|
|
1650
|
+
if current_conversation_id and agent_output:
|
|
1651
|
+
try:
|
|
1652
|
+
{%- if cookiecutter.use_postgresql %}
|
|
1653
|
+
async with get_db_context() as db:
|
|
1654
|
+
conv_service = get_conversation_service(db)
|
|
1655
|
+
await conv_service.add_message(
|
|
1656
|
+
UUID(current_conversation_id),
|
|
1657
|
+
MessageCreate(
|
|
1658
|
+
role="assistant",
|
|
1659
|
+
content=f"✅ **{agent_name}**\n\n{agent_output}",
|
|
1660
|
+
),
|
|
1661
|
+
)
|
|
1662
|
+
{%- else %}
|
|
1663
|
+
with get_db_session() as db:
|
|
1664
|
+
conv_service = get_conversation_service(db)
|
|
1665
|
+
conv_service.add_message(
|
|
1666
|
+
current_conversation_id,
|
|
1667
|
+
MessageCreate(
|
|
1668
|
+
role="assistant",
|
|
1669
|
+
content=f"✅ **{agent_name}**\n\n{agent_output}",
|
|
1670
|
+
),
|
|
1671
|
+
)
|
|
1672
|
+
{%- endif %}
|
|
1673
|
+
except Exception as e:
|
|
1674
|
+
logger.warning(f"Failed to persist agent response: {e}")
|
|
1675
|
+
{%- elif cookiecutter.enable_conversation_persistence and cookiecutter.use_mongodb %}
|
|
1676
|
+
# Save agent's output as a separate message
|
|
1677
|
+
if current_conversation_id and agent_output:
|
|
1678
|
+
try:
|
|
1679
|
+
await conv_service.add_message(
|
|
1680
|
+
current_conversation_id,
|
|
1681
|
+
MessageCreate(
|
|
1682
|
+
role="assistant",
|
|
1683
|
+
content=f"✅ **{agent_name}**\n\n{agent_output}",
|
|
1684
|
+
),
|
|
1685
|
+
)
|
|
1686
|
+
except Exception as e:
|
|
1687
|
+
logger.warning(f"Failed to persist agent response: {e}")
|
|
1688
|
+
{%- endif %}
|
|
1689
|
+
|
|
1690
|
+
# Task events
|
|
1691
|
+
elif event_type == "task_started":
|
|
1692
|
+
await manager.send_event(
|
|
1693
|
+
websocket,
|
|
1694
|
+
"task_started",
|
|
1695
|
+
{
|
|
1696
|
+
"task_id": event.get("task_id", ""),
|
|
1697
|
+
"description": event.get("description", ""),
|
|
1698
|
+
"agent": event.get("agent", ""),
|
|
1699
|
+
},
|
|
1700
|
+
)
|
|
1701
|
+
|
|
1702
|
+
elif event_type == "task_completed":
|
|
1703
|
+
await manager.send_event(
|
|
1704
|
+
websocket,
|
|
1705
|
+
"task_completed",
|
|
1706
|
+
{
|
|
1707
|
+
"task_id": event.get("task_id", ""),
|
|
1708
|
+
"output": event.get("output", ""),
|
|
1709
|
+
"agent": event.get("agent", ""),
|
|
1710
|
+
},
|
|
1711
|
+
)
|
|
1712
|
+
|
|
1713
|
+
# Tool events
|
|
1714
|
+
elif event_type == "tool_started":
|
|
1715
|
+
await manager.send_event(
|
|
1716
|
+
websocket,
|
|
1717
|
+
"tool_started",
|
|
1718
|
+
{
|
|
1719
|
+
"tool_name": event.get("tool_name", ""),
|
|
1720
|
+
"tool_args": event.get("tool_args", ""),
|
|
1721
|
+
"agent": event.get("agent", ""),
|
|
1722
|
+
},
|
|
1723
|
+
)
|
|
1724
|
+
|
|
1725
|
+
elif event_type == "tool_finished":
|
|
1726
|
+
await manager.send_event(
|
|
1727
|
+
websocket,
|
|
1728
|
+
"tool_finished",
|
|
1729
|
+
{
|
|
1730
|
+
"tool_name": event.get("tool_name", ""),
|
|
1731
|
+
"tool_result": event.get("tool_result", ""),
|
|
1732
|
+
"agent": event.get("agent", ""),
|
|
1733
|
+
},
|
|
1734
|
+
)
|
|
1735
|
+
|
|
1736
|
+
# LLM events
|
|
1737
|
+
elif event_type == "llm_started":
|
|
1738
|
+
await manager.send_event(
|
|
1739
|
+
websocket,
|
|
1740
|
+
"llm_started",
|
|
1741
|
+
{
|
|
1742
|
+
"agent": event.get("agent", ""),
|
|
1743
|
+
},
|
|
1744
|
+
)
|
|
1745
|
+
|
|
1746
|
+
elif event_type == "llm_completed":
|
|
1747
|
+
await manager.send_event(
|
|
1748
|
+
websocket,
|
|
1749
|
+
"llm_completed",
|
|
1750
|
+
{
|
|
1751
|
+
"agent": event.get("agent", ""),
|
|
1752
|
+
"response": event.get("response", ""),
|
|
1753
|
+
},
|
|
1754
|
+
)
|
|
1755
|
+
|
|
1756
|
+
# Final result
|
|
1757
|
+
elif event_type == "crew_complete":
|
|
1758
|
+
final_output = event.get("result", "")
|
|
1759
|
+
await manager.send_event(
|
|
1760
|
+
websocket,
|
|
1761
|
+
"final_result",
|
|
1762
|
+
{"output": final_output},
|
|
1763
|
+
)
|
|
1764
|
+
|
|
1765
|
+
# Error
|
|
1766
|
+
elif event_type == "error":
|
|
1767
|
+
await manager.send_event(
|
|
1768
|
+
websocket,
|
|
1769
|
+
"error",
|
|
1770
|
+
{"message": event.get("error", "Unknown error")},
|
|
1771
|
+
)
|
|
1772
|
+
|
|
1773
|
+
# Update conversation history
|
|
1774
|
+
conversation_history.append({"role": "user", "content": user_message})
|
|
1775
|
+
if final_output:
|
|
1776
|
+
conversation_history.append(
|
|
1777
|
+
{"role": "assistant", "content": final_output}
|
|
1778
|
+
)
|
|
1779
|
+
|
|
1780
|
+
{%- if cookiecutter.enable_conversation_persistence and cookiecutter.use_database %}
|
|
1781
|
+
# Note: Agent outputs are saved individually in agent_completed events above
|
|
1782
|
+
{%- endif %}
|
|
1783
|
+
|
|
1784
|
+
await manager.send_event(websocket, "complete", {
|
|
1785
|
+
{%- if cookiecutter.enable_conversation_persistence and cookiecutter.use_database %}
|
|
1786
|
+
"conversation_id": current_conversation_id,
|
|
1787
|
+
{%- endif %}
|
|
1788
|
+
})
|
|
1789
|
+
|
|
1790
|
+
except WebSocketDisconnect:
|
|
1791
|
+
# Client disconnected during processing - this is normal
|
|
1792
|
+
logger.info("Client disconnected during agent processing")
|
|
1793
|
+
break
|
|
1794
|
+
except Exception as e:
|
|
1795
|
+
logger.exception(f"Error processing agent request: {e}")
|
|
1796
|
+
# Try to send error, but don't fail if connection is closed
|
|
1797
|
+
await manager.send_event(websocket, "error", {"message": str(e)})
|
|
1798
|
+
|
|
1799
|
+
except WebSocketDisconnect:
|
|
1800
|
+
pass # Normal disconnect
|
|
1801
|
+
finally:
|
|
1802
|
+
manager.disconnect(websocket)
|
|
1803
|
+
{%- elif cookiecutter.enable_ai_agent and cookiecutter.use_deepagents %}
|
|
1804
|
+
"""AI Agent WebSocket routes with streaming and human-in-the-loop support (DeepAgents)."""
|
|
1805
|
+
|
|
1806
|
+
import logging
|
|
1807
|
+
import uuid
|
|
1808
|
+
from typing import Any
|
|
1809
|
+
{%- if cookiecutter.enable_conversation_persistence and cookiecutter.use_database %}
|
|
1810
|
+
from datetime import datetime, UTC
|
|
1811
|
+
{%- if cookiecutter.use_postgresql %}
|
|
1812
|
+
from uuid import UUID
|
|
1813
|
+
{%- endif %}
|
|
1814
|
+
{%- endif %}
|
|
1815
|
+
|
|
1816
|
+
from fastapi import APIRouter, WebSocket, WebSocketDisconnect{%- if cookiecutter.websocket_auth_jwt %}, Depends{%- endif %}{%- if cookiecutter.websocket_auth_api_key %}, Query{%- endif %}
|
|
1817
|
+
|
|
1818
|
+
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage, SystemMessage, ToolMessage
|
|
1819
|
+
|
|
1820
|
+
from app.agents.deepagents_assistant import AgentContext, Decision, InterruptData, get_agent
|
|
1821
|
+
{%- if cookiecutter.websocket_auth_jwt %}
|
|
1822
|
+
from app.api.deps import get_current_user_ws
|
|
1823
|
+
from app.db.models.user import User
|
|
1824
|
+
{%- endif %}
|
|
1825
|
+
{%- if cookiecutter.websocket_auth_api_key %}
|
|
1826
|
+
from app.core.config import settings
|
|
1827
|
+
{%- endif %}
|
|
1828
|
+
{%- if cookiecutter.enable_conversation_persistence and (cookiecutter.use_postgresql or cookiecutter.use_sqlite) %}
|
|
1829
|
+
from app.db.session import get_db_context
|
|
1830
|
+
from app.api.deps import ConversationSvc, get_conversation_service
|
|
1831
|
+
from app.schemas.conversation import ConversationCreate, MessageCreate, ToolCallCreate, ToolCallComplete
|
|
1832
|
+
{%- elif cookiecutter.enable_conversation_persistence and cookiecutter.use_mongodb %}
|
|
1833
|
+
from app.api.deps import ConversationSvc, get_conversation_service
|
|
1834
|
+
from app.schemas.conversation import ConversationCreate, MessageCreate, ToolCallCreate, ToolCallComplete
|
|
1835
|
+
{%- endif %}
|
|
1836
|
+
|
|
1837
|
+
logger = logging.getLogger(__name__)
|
|
1838
|
+
|
|
1839
|
+
router = APIRouter()
|
|
1840
|
+
|
|
1841
|
+
|
|
1842
|
+
class AgentConnectionManager:
    """Keep track of the WebSocket clients attached to the agent endpoint.

    Accepted sockets live in ``active_connections``. ``send_event`` wraps every
    outgoing message in a ``{"type": ..., "data": ...}`` JSON envelope.
    """

    def __init__(self) -> None:
        # Sockets that completed the handshake, in connection order.
        self.active_connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket) -> None:
        """Finish the WebSocket handshake and begin tracking the socket."""
        await websocket.accept()
        self.active_connections.append(websocket)
        logger.info(
            "Agent WebSocket connected. Total connections: %s",
            len(self.active_connections),
        )

    def disconnect(self, websocket: WebSocket) -> None:
        """Stop tracking ``websocket``.

        A socket that is not currently tracked is silently ignored. The
        remaining connection count is logged in both cases.
        """
        try:
            self.active_connections.remove(websocket)
        except ValueError:
            # Not in the list (e.g. already removed) - nothing to drop.
            pass
        logger.info(
            "Agent WebSocket disconnected. Total connections: %s",
            len(self.active_connections),
        )

    async def send_event(self, websocket: WebSocket, event_type: str, data: Any) -> bool:
        """Send one ``{"type": event_type, "data": data}`` JSON message to a client.

        Returns:
            ``True`` if the message was handed to the socket. ``False`` if
            sending raised ``WebSocketDisconnect`` or ``RuntimeError``, which
            this code treats as "the connection is already closed". Any other
            exception propagates to the caller.
        """
        envelope = {"type": event_type, "data": data}
        try:
            await websocket.send_json(envelope)
        except (WebSocketDisconnect, RuntimeError):
            return False
        return True
|
|
1871
|
+
|
|
1872
|
+
|
|
1873
|
+
manager = AgentConnectionManager()
|
|
1874
|
+
|
|
1875
|
+
|
|
1876
|
+
def build_message_history(
    history: list[dict[str, str]]
) -> list[HumanMessage | AIMessage | SystemMessage]:
    """Translate ``{"role": ..., "content": ...}`` dicts into LangChain messages.

    Role mapping:
        ``"user"`` -> ``HumanMessage``
        ``"assistant"`` -> ``AIMessage``
        ``"system"`` -> ``SystemMessage``

    Entries with any other role are skipped. Every entry must carry a
    ``"role"`` key, and entries with a recognised role must also carry
    ``"content"``; a missing key raises ``KeyError``. Input order is kept.
    """
    # Checked in this order, mirroring an if/elif chain on the role value.
    role_table = (
        ("user", HumanMessage),
        ("assistant", AIMessage),
        ("system", SystemMessage),
    )

    converted: list[HumanMessage | AIMessage | SystemMessage] = []
    for entry in history:
        role = entry["role"]
        for role_name, message_cls in role_table:
            if role == role_name:
                converted.append(message_cls(content=entry["content"]))
                break
    return converted
|
|
1891
|
+
|
|
1892
|
+
{%- if cookiecutter.websocket_auth_api_key %}
|
|
1893
|
+
|
|
1894
|
+
|
|
1895
|
+
async def verify_api_key(api_key: str) -> bool:
    """Check a client-supplied API key against ``settings.API_KEY``.

    The comparison uses ``secrets.compare_digest`` on UTF-8 bytes, so:

    * response timing does not reveal how much of the key matched (a plain
      ``==`` short-circuits on the first differing character);
    * non-ASCII input cannot raise ``TypeError``, which ``compare_digest``
      does for non-ASCII ``str`` arguments.

    If no usable key is configured (``settings.API_KEY`` is empty, ``None`` or
    not a string), every key is rejected. This prevents an empty
    ``?api_key=`` from authenticating.

    Args:
        api_key: Key taken from the ``api_key`` query parameter.

    Returns:
        ``True`` only when ``api_key`` equals the configured key.
    """
    import secrets

    expected = settings.API_KEY
    if not isinstance(expected, str) or not expected:
        # Misconfiguration: refuse everything rather than accept an empty key.
        return False
    return secrets.compare_digest(api_key.encode("utf-8"), expected.encode("utf-8"))
|
|
1898
|
+
{%- endif %}
|
|
1899
|
+
|
|
1900
|
+
|
|
1901
|
+
@router.websocket("/ws/agent")
|
|
1902
|
+
async def agent_websocket(
|
|
1903
|
+
websocket: WebSocket,
|
|
1904
|
+
{%- if cookiecutter.websocket_auth_jwt %}
|
|
1905
|
+
user: User = Depends(get_current_user_ws),
|
|
1906
|
+
{%- elif cookiecutter.websocket_auth_api_key %}
|
|
1907
|
+
api_key: str = Query(..., alias="api_key"),
|
|
1908
|
+
{%- endif %}
|
|
1909
|
+
) -> None:
|
|
1910
|
+
"""WebSocket endpoint for DeepAgents with streaming and human-in-the-loop support.
|
|
1911
|
+
|
|
1912
|
+
Uses DeepAgents (LangGraph-based) to stream agent events including:
|
|
1913
|
+
- user_prompt: When user input is received
|
|
1914
|
+
- model_request_start: When model request begins
|
|
1915
|
+
- text_delta: Streaming text from the model
|
|
1916
|
+
- tool_call: When a tool is called (ls, read_file, write_file, edit_file, etc.)
|
|
1917
|
+
- tool_result: When a tool returns a result
|
|
1918
|
+
- tool_approval_required: When human approval is needed for tool execution
|
|
1919
|
+
- final_result: When the final result is ready
|
|
1920
|
+
- complete: When processing is complete
|
|
1921
|
+
- error: When an error occurs
|
|
1922
|
+
|
|
1923
|
+
Human-in-the-loop:
|
|
1924
|
+
When DEEPAGENTS_INTERRUPT_TOOLS is configured, certain tools will require
|
|
1925
|
+
approval before execution. The frontend receives a 'tool_approval_required'
|
|
1926
|
+
event and should respond with a 'resume' message containing decisions.
|
|
1927
|
+
|
|
1928
|
+
Expected input message formats:
|
|
1929
|
+
|
|
1930
|
+
Regular message:
|
|
1931
|
+
{
|
|
1932
|
+
"type": "message", // optional, default
|
|
1933
|
+
"message": "user message here",
|
|
1934
|
+
"history": [{"role": "user|assistant|system", "content": "..."}]{% if cookiecutter.enable_conversation_persistence and cookiecutter.use_database %},
|
|
1935
|
+
"conversation_id": "optional-uuid"{% endif %}
|
|
1936
|
+
}
|
|
1937
|
+
|
|
1938
|
+
Resume after interrupt:
|
|
1939
|
+
{
|
|
1940
|
+
"type": "resume",
|
|
1941
|
+
"decisions": [
|
|
1942
|
+
{"type": "approve"},
|
|
1943
|
+
{"type": "reject"},
|
|
1944
|
+
{"type": "edit", "edited_action": {"name": "tool_name", "args": {...}}}
|
|
1945
|
+
]
|
|
1946
|
+
}
|
|
1947
|
+
{%- if cookiecutter.websocket_auth_jwt %}
|
|
1948
|
+
|
|
1949
|
+
Authentication: Requires a valid JWT token passed as a query parameter or header.
|
|
1950
|
+
{%- elif cookiecutter.websocket_auth_api_key %}
|
|
1951
|
+
|
|
1952
|
+
Authentication: Requires a valid API key passed as 'api_key' query parameter.
|
|
1953
|
+
Example: ws://localhost:{{ cookiecutter.backend_port }}/api/v1/ws/agent?api_key=your-api-key
|
|
1954
|
+
{%- endif %}
|
|
1955
|
+
{%- if cookiecutter.enable_conversation_persistence and cookiecutter.use_database %}
|
|
1956
|
+
|
|
1957
|
+
Persistence: Set 'conversation_id' to continue an existing conversation.
|
|
1958
|
+
If not provided, a new conversation is created. The conversation_id is
|
|
1959
|
+
returned in the 'conversation_created' event.
|
|
1960
|
+
{%- endif %}
|
|
1961
|
+
"""
|
|
1962
|
+
{%- if cookiecutter.websocket_auth_api_key %}
|
|
1963
|
+
# Verify API key before accepting connection
|
|
1964
|
+
if not await verify_api_key(api_key):
|
|
1965
|
+
await websocket.close(code=4001, reason="Invalid API key")
|
|
1966
|
+
return
|
|
1967
|
+
{%- endif %}
|
|
1968
|
+
|
|
1969
|
+
await manager.connect(websocket)
|
|
1970
|
+
|
|
1971
|
+
# Conversation state per connection
|
|
1972
|
+
conversation_history: list[dict[str, str]] = []
|
|
1973
|
+
context: AgentContext = {}
|
|
1974
|
+
# Thread ID for LangGraph state persistence (required for HITL)
|
|
1975
|
+
thread_id: str = str(uuid.uuid4())
|
|
1976
|
+
# Track pending interrupt for resume
|
|
1977
|
+
pending_interrupt: InterruptData | None = None
|
|
1978
|
+
{%- if cookiecutter.websocket_auth_jwt %}
|
|
1979
|
+
context["user_id"] = str(user.id) if user else None
|
|
1980
|
+
context["user_name"] = user.email if user else None
|
|
1981
|
+
{%- endif %}
|
|
1982
|
+
{%- if cookiecutter.enable_conversation_persistence and cookiecutter.use_database %}
|
|
1983
|
+
current_conversation_id: str | None = None
|
|
1984
|
+
{%- endif %}
|
|
1985
|
+
|
|
1986
|
+
# Create assistant instance (reused for the connection)
|
|
1987
|
+
assistant = get_agent()
|
|
1988
|
+
|
|
1989
|
+
try:
|
|
1990
|
+
while True:
|
|
1991
|
+
# Receive message from client
|
|
1992
|
+
raw_data = await websocket.receive_json()
|
|
1993
|
+
message_type = raw_data.get("type", "message")
|
|
1994
|
+
|
|
1995
|
+
# Handle resume after interrupt
|
|
1996
|
+
if message_type == "resume":
|
|
1997
|
+
if not pending_interrupt:
|
|
1998
|
+
await manager.send_event(websocket, "error", {"message": "No pending interrupt to resume"})
|
|
1999
|
+
continue
|
|
2000
|
+
|
|
2001
|
+
decisions: list[Decision] = raw_data.get("decisions", [])
|
|
2002
|
+
if len(decisions) != len(pending_interrupt["action_requests"]):
|
|
2003
|
+
await manager.send_event(
|
|
2004
|
+
websocket,
|
|
2005
|
+
"error",
|
|
2006
|
+
{"message": f"Expected {len(pending_interrupt['action_requests'])} decisions, got {len(decisions)}"},
|
|
2007
|
+
)
|
|
2008
|
+
continue
|
|
2009
|
+
|
|
2010
|
+
# Clear pending interrupt
|
|
2011
|
+
pending_interrupt = None
|
|
2012
|
+
|
|
2013
|
+
try:
|
|
2014
|
+
await manager.send_event(websocket, "resume_start", {})
|
|
2015
|
+
|
|
2016
|
+
final_output = ""
|
|
2017
|
+
seen_tool_call_ids: set[str] = set()
|
|
2018
|
+
|
|
2019
|
+
# Stream resume
|
|
2020
|
+
async for stream_mode, stream_data in assistant.stream_resume(
|
|
2021
|
+
decisions=decisions,
|
|
2022
|
+
thread_id=thread_id,
|
|
2023
|
+
context=context,
|
|
2024
|
+
):
|
|
2025
|
+
if stream_mode == "interrupt":
|
|
2026
|
+
# Another interrupt occurred
|
|
2027
|
+
pending_interrupt = stream_data
|
|
2028
|
+
await manager.send_event(
|
|
2029
|
+
websocket,
|
|
2030
|
+
"tool_approval_required",
|
|
2031
|
+
{
|
|
2032
|
+
"action_requests": pending_interrupt["action_requests"],
|
|
2033
|
+
"review_configs": pending_interrupt["review_configs"],
|
|
2034
|
+
},
|
|
2035
|
+
)
|
|
2036
|
+
break
|
|
2037
|
+
|
|
2038
|
+
if stream_mode == "messages":
|
|
2039
|
+
chunk, _metadata = stream_data
|
|
2040
|
+
if isinstance(chunk, AIMessageChunk) and chunk.content:
|
|
2041
|
+
text_content = ""
|
|
2042
|
+
if isinstance(chunk.content, str):
|
|
2043
|
+
text_content = chunk.content
|
|
2044
|
+
elif isinstance(chunk.content, list):
|
|
2045
|
+
for block in chunk.content:
|
|
2046
|
+
if isinstance(block, dict) and block.get("type") == "text":
|
|
2047
|
+
text_content += block.get("text", "")
|
|
2048
|
+
elif isinstance(block, str):
|
|
2049
|
+
text_content += block
|
|
2050
|
+
if text_content:
|
|
2051
|
+
await manager.send_event(websocket, "text_delta", {"content": text_content})
|
|
2052
|
+
final_output += text_content
|
|
2053
|
+
|
|
2054
|
+
elif stream_mode == "updates":
|
|
2055
|
+
for node_name, update in stream_data.items():
|
|
2056
|
+
if node_name == "tools":
|
|
2057
|
+
for msg in update.get("messages", []):
|
|
2058
|
+
if isinstance(msg, ToolMessage):
|
|
2059
|
+
await manager.send_event(
|
|
2060
|
+
websocket,
|
|
2061
|
+
"tool_result",
|
|
2062
|
+
{"tool_call_id": msg.tool_call_id, "content": msg.content},
|
|
2063
|
+
)
|
|
2064
|
+
|
|
2065
|
+
if not pending_interrupt:
|
|
2066
|
+
# No interrupt, send final result
|
|
2067
|
+
if final_output:
|
|
2068
|
+
conversation_history.append({"role": "assistant", "content": final_output})
|
|
2069
|
+
await manager.send_event(websocket, "final_result", {"output": final_output})
|
|
2070
|
+
await manager.send_event(websocket, "complete", {})
|
|
2071
|
+
|
|
2072
|
+
except Exception as e:
|
|
2073
|
+
logger.exception(f"Error resuming agent: {e}")
|
|
2074
|
+
await manager.send_event(websocket, "error", {"message": str(e)})
|
|
2075
|
+
|
|
2076
|
+
continue
|
|
2077
|
+
|
|
2078
|
+
# Regular message handling
|
|
2079
|
+
user_message = raw_data.get("message", "")
|
|
2080
|
+
# Optionally accept history from client (or use server-side tracking)
|
|
2081
|
+
if "history" in raw_data:
|
|
2082
|
+
conversation_history = raw_data["history"]
|
|
2083
|
+
|
|
2084
|
+
if not user_message:
|
|
2085
|
+
await manager.send_event(websocket, "error", {"message": "Empty message"})
|
|
2086
|
+
continue
|
|
2087
|
+
|
|
2088
|
+
{%- if cookiecutter.enable_conversation_persistence and (cookiecutter.use_postgresql or cookiecutter.use_sqlite) %}
|
|
2089
|
+
|
|
2090
|
+
# Handle conversation persistence
|
|
2091
|
+
try:
|
|
2092
|
+
{%- if cookiecutter.use_postgresql %}
|
|
2093
|
+
async with get_db_context() as db:
|
|
2094
|
+
conv_service = get_conversation_service(db)
|
|
2095
|
+
|
|
2096
|
+
# Get or create conversation
|
|
2097
|
+
requested_conv_id = raw_data.get("conversation_id")
|
|
2098
|
+
if requested_conv_id:
|
|
2099
|
+
current_conversation_id = requested_conv_id
|
|
2100
|
+
# Verify conversation exists
|
|
2101
|
+
await conv_service.get_conversation(UUID(requested_conv_id))
|
|
2102
|
+
elif not current_conversation_id:
|
|
2103
|
+
# Create new conversation
|
|
2104
|
+
conv_data = ConversationCreate(
|
|
2105
|
+
{%- if cookiecutter.websocket_auth_jwt %}
|
|
2106
|
+
user_id=user.id,
|
|
2107
|
+
{%- endif %}
|
|
2108
|
+
title=user_message[:50] if len(user_message) > 50 else user_message,
|
|
2109
|
+
)
|
|
2110
|
+
conversation = await conv_service.create_conversation(conv_data)
|
|
2111
|
+
current_conversation_id = str(conversation.id)
|
|
2112
|
+
await manager.send_event(
|
|
2113
|
+
websocket,
|
|
2114
|
+
"conversation_created",
|
|
2115
|
+
{"conversation_id": current_conversation_id},
|
|
2116
|
+
)
|
|
2117
|
+
|
|
2118
|
+
# Save user message
|
|
2119
|
+
await conv_service.add_message(
|
|
2120
|
+
UUID(current_conversation_id),
|
|
2121
|
+
MessageCreate(role="user", content=user_message),
|
|
2122
|
+
)
|
|
2123
|
+
{%- else %}
|
|
2124
|
+
with get_db_session() as db:
|
|
2125
|
+
conv_service = get_conversation_service(db)
|
|
2126
|
+
|
|
2127
|
+
# Get or create conversation
|
|
2128
|
+
requested_conv_id = raw_data.get("conversation_id")
|
|
2129
|
+
if requested_conv_id:
|
|
2130
|
+
current_conversation_id = requested_conv_id
|
|
2131
|
+
conv_service.get_conversation(requested_conv_id)
|
|
2132
|
+
elif not current_conversation_id:
|
|
2133
|
+
# Create new conversation
|
|
2134
|
+
conv_data = ConversationCreate(
|
|
2135
|
+
{%- if cookiecutter.websocket_auth_jwt %}
|
|
2136
|
+
user_id=str(user.id),
|
|
2137
|
+
{%- endif %}
|
|
2138
|
+
title=user_message[:50] if len(user_message) > 50 else user_message,
|
|
2139
|
+
)
|
|
2140
|
+
conversation = conv_service.create_conversation(conv_data)
|
|
2141
|
+
current_conversation_id = str(conversation.id)
|
|
2142
|
+
await manager.send_event(
|
|
2143
|
+
websocket,
|
|
2144
|
+
"conversation_created",
|
|
2145
|
+
{"conversation_id": current_conversation_id},
|
|
2146
|
+
)
|
|
2147
|
+
|
|
2148
|
+
# Save user message
|
|
2149
|
+
conv_service.add_message(
|
|
2150
|
+
current_conversation_id,
|
|
2151
|
+
MessageCreate(role="user", content=user_message),
|
|
2152
|
+
)
|
|
2153
|
+
{%- endif %}
|
|
2154
|
+
except Exception as e:
|
|
2155
|
+
logger.warning(f"Failed to persist conversation: {e}")
|
|
2156
|
+
# Continue without persistence
|
|
2157
|
+
{%- elif cookiecutter.enable_conversation_persistence and cookiecutter.use_mongodb %}
|
|
2158
|
+
|
|
2159
|
+
# Handle conversation persistence (MongoDB)
|
|
2160
|
+
conv_service = get_conversation_service()
|
|
2161
|
+
|
|
2162
|
+
requested_conv_id = raw_data.get("conversation_id")
|
|
2163
|
+
if requested_conv_id:
|
|
2164
|
+
current_conversation_id = requested_conv_id
|
|
2165
|
+
await conv_service.get_conversation(requested_conv_id)
|
|
2166
|
+
elif not current_conversation_id:
|
|
2167
|
+
conv_data = ConversationCreate(
|
|
2168
|
+
{%- if cookiecutter.websocket_auth_jwt %}
|
|
2169
|
+
user_id=str(user.id),
|
|
2170
|
+
{%- endif %}
|
|
2171
|
+
title=user_message[:50] if len(user_message) > 50 else user_message,
|
|
2172
|
+
)
|
|
2173
|
+
conversation = await conv_service.create_conversation(conv_data)
|
|
2174
|
+
current_conversation_id = str(conversation.id)
|
|
2175
|
+
await manager.send_event(
|
|
2176
|
+
websocket,
|
|
2177
|
+
"conversation_created",
|
|
2178
|
+
{"conversation_id": current_conversation_id},
|
|
2179
|
+
)
|
|
2180
|
+
|
|
2181
|
+
# Save user message
|
|
2182
|
+
await conv_service.add_message(
|
|
2183
|
+
current_conversation_id,
|
|
2184
|
+
MessageCreate(role="user", content=user_message),
|
|
2185
|
+
)
|
|
2186
|
+
{%- endif %}
|
|
2187
|
+
|
|
2188
|
+
await manager.send_event(websocket, "user_prompt", {"content": user_message})
|
|
2189
|
+
|
|
2190
|
+
try:
|
|
2191
|
+
final_output = ""
|
|
2192
|
+
tool_events: list[Any] = []
|
|
2193
|
+
seen_tool_call_ids: set[str] = set()
|
|
2194
|
+
|
|
2195
|
+
await manager.send_event(websocket, "model_request_start", {})
|
|
2196
|
+
|
|
2197
|
+
# Use DeepAgents' stream() which wraps LangGraph's astream
|
|
2198
|
+
async for stream_mode, stream_data in assistant.stream(
|
|
2199
|
+
user_message,
|
|
2200
|
+
history=conversation_history,
|
|
2201
|
+
context=context,
|
|
2202
|
+
thread_id=thread_id,
|
|
2203
|
+
):
|
|
2204
|
+
# Handle interrupt - human approval required
|
|
2205
|
+
if stream_mode == "interrupt":
|
|
2206
|
+
pending_interrupt = stream_data
|
|
2207
|
+
await manager.send_event(
|
|
2208
|
+
websocket,
|
|
2209
|
+
"tool_approval_required",
|
|
2210
|
+
{
|
|
2211
|
+
"action_requests": pending_interrupt["action_requests"],
|
|
2212
|
+
"review_configs": pending_interrupt["review_configs"],
|
|
2213
|
+
},
|
|
2214
|
+
)
|
|
2215
|
+
break
|
|
2216
|
+
|
|
2217
|
+
if stream_mode == "messages":
|
|
2218
|
+
chunk, _metadata = stream_data
|
|
2219
|
+
|
|
2220
|
+
if isinstance(chunk, AIMessageChunk):
|
|
2221
|
+
if chunk.content:
|
|
2222
|
+
text_content = ""
|
|
2223
|
+
if isinstance(chunk.content, str):
|
|
2224
|
+
text_content = chunk.content
|
|
2225
|
+
elif isinstance(chunk.content, list):
|
|
2226
|
+
for block in chunk.content:
|
|
2227
|
+
if isinstance(block, dict) and block.get("type") == "text":
|
|
2228
|
+
text_content += block.get("text", "")
|
|
2229
|
+
elif isinstance(block, str):
|
|
2230
|
+
text_content += block
|
|
2231
|
+
|
|
2232
|
+
if text_content:
|
|
2233
|
+
await manager.send_event(
|
|
2234
|
+
websocket,
|
|
2235
|
+
"text_delta",
|
|
2236
|
+
{"content": text_content},
|
|
2237
|
+
)
|
|
2238
|
+
final_output += text_content
|
|
2239
|
+
|
|
2240
|
+
# Handle tool call chunks
|
|
2241
|
+
if chunk.tool_call_chunks:
|
|
2242
|
+
for tc_chunk in chunk.tool_call_chunks:
|
|
2243
|
+
tc_id = tc_chunk.get("id")
|
|
2244
|
+
tc_name = tc_chunk.get("name")
|
|
2245
|
+
if tc_id and tc_name and tc_id not in seen_tool_call_ids:
|
|
2246
|
+
seen_tool_call_ids.add(tc_id)
|
|
2247
|
+
await manager.send_event(
|
|
2248
|
+
websocket,
|
|
2249
|
+
"tool_call",
|
|
2250
|
+
{
|
|
2251
|
+
"tool_name": tc_name,
|
|
2252
|
+
"args": {},
|
|
2253
|
+
"tool_call_id": tc_id,
|
|
2254
|
+
},
|
|
2255
|
+
)
|
|
2256
|
+
|
|
2257
|
+
elif stream_mode == "updates":
|
|
2258
|
+
# Handle state updates from nodes
|
|
2259
|
+
for node_name, update in stream_data.items():
|
|
2260
|
+
if node_name == "tools":
|
|
2261
|
+
# Tool node completed - extract tool results
|
|
2262
|
+
for msg in update.get("messages", []):
|
|
2263
|
+
if isinstance(msg, ToolMessage):
|
|
2264
|
+
await manager.send_event(
|
|
2265
|
+
websocket,
|
|
2266
|
+
"tool_result",
|
|
2267
|
+
{
|
|
2268
|
+
"tool_call_id": msg.tool_call_id,
|
|
2269
|
+
"content": msg.content,
|
|
2270
|
+
},
|
|
2271
|
+
)
|
|
2272
|
+
elif node_name == "agent":
|
|
2273
|
+
# Agent node completed - check for tool calls
|
|
2274
|
+
for msg in update.get("messages", []):
|
|
2275
|
+
if isinstance(msg, AIMessage) and msg.tool_calls:
|
|
2276
|
+
for tc in msg.tool_calls:
|
|
2277
|
+
tc_id = tc.get("id", "")
|
|
2278
|
+
if tc_id not in seen_tool_call_ids:
|
|
2279
|
+
seen_tool_call_ids.add(tc_id)
|
|
2280
|
+
tool_events.append(tc)
|
|
2281
|
+
await manager.send_event(
|
|
2282
|
+
websocket,
|
|
2283
|
+
"tool_call",
|
|
2284
|
+
{
|
|
2285
|
+
"tool_name": tc.get("name", ""),
|
|
2286
|
+
"args": tc.get("args", {}),
|
|
2287
|
+
"tool_call_id": tc_id,
|
|
2288
|
+
},
|
|
2289
|
+
)
|
|
2290
|
+
|
|
2291
|
+
# Only send final result if not interrupted
|
|
2292
|
+
if not pending_interrupt:
|
|
2293
|
+
await manager.send_event(
|
|
2294
|
+
websocket,
|
|
2295
|
+
"final_result",
|
|
2296
|
+
{"output": final_output},
|
|
2297
|
+
)
|
|
2298
|
+
|
|
2299
|
+
# Update conversation history
|
|
2300
|
+
conversation_history.append({"role": "user", "content": user_message})
|
|
2301
|
+
if final_output:
|
|
2302
|
+
conversation_history.append(
|
|
2303
|
+
{"role": "assistant", "content": final_output}
|
|
2304
|
+
)
|
|
2305
|
+
|
|
2306
|
+
{%- if cookiecutter.enable_conversation_persistence and (cookiecutter.use_postgresql or cookiecutter.use_sqlite) %}
|
|
2307
|
+
|
|
2308
|
+
# Save assistant response to database
|
|
2309
|
+
if current_conversation_id and final_output:
|
|
2310
|
+
try:
|
|
2311
|
+
{%- if cookiecutter.use_postgresql %}
|
|
2312
|
+
async with get_db_context() as db:
|
|
2313
|
+
conv_service = get_conversation_service(db)
|
|
2314
|
+
await conv_service.add_message(
|
|
2315
|
+
UUID(current_conversation_id),
|
|
2316
|
+
MessageCreate(
|
|
2317
|
+
role="assistant",
|
|
2318
|
+
content=final_output,
|
|
2319
|
+
model_name=assistant.model_name if hasattr(assistant, "model_name") else None,
|
|
2320
|
+
),
|
|
2321
|
+
)
|
|
2322
|
+
{%- else %}
|
|
2323
|
+
with get_db_session() as db:
|
|
2324
|
+
conv_service = get_conversation_service(db)
|
|
2325
|
+
conv_service.add_message(
|
|
2326
|
+
current_conversation_id,
|
|
2327
|
+
MessageCreate(
|
|
2328
|
+
role="assistant",
|
|
2329
|
+
content=final_output,
|
|
2330
|
+
model_name=assistant.model_name if hasattr(assistant, "model_name") else None,
|
|
2331
|
+
),
|
|
2332
|
+
)
|
|
2333
|
+
{%- endif %}
|
|
2334
|
+
except Exception as e:
|
|
2335
|
+
logger.warning(f"Failed to persist assistant response: {e}")
|
|
2336
|
+
{%- elif cookiecutter.enable_conversation_persistence and cookiecutter.use_mongodb %}
|
|
2337
|
+
|
|
2338
|
+
# Save assistant response to database
|
|
2339
|
+
if current_conversation_id and final_output:
|
|
2340
|
+
try:
|
|
2341
|
+
await conv_service.add_message(
|
|
2342
|
+
current_conversation_id,
|
|
2343
|
+
MessageCreate(
|
|
2344
|
+
role="assistant",
|
|
2345
|
+
content=final_output,
|
|
2346
|
+
model_name=assistant.model_name if hasattr(assistant, "model_name") else None,
|
|
2347
|
+
),
|
|
2348
|
+
)
|
|
2349
|
+
except Exception as e:
|
|
2350
|
+
logger.warning(f"Failed to persist assistant response: {e}")
|
|
2351
|
+
{%- endif %}
|
|
2352
|
+
|
|
2353
|
+
await manager.send_event(websocket, "complete", {
|
|
2354
|
+
{%- if cookiecutter.enable_conversation_persistence and cookiecutter.use_database %}
|
|
2355
|
+
"conversation_id": current_conversation_id,
|
|
2356
|
+
{%- endif %}
|
|
2357
|
+
})
|
|
2358
|
+
|
|
2359
|
+
except WebSocketDisconnect:
|
|
2360
|
+
# Client disconnected during processing - this is normal
|
|
2361
|
+
logger.info("Client disconnected during agent processing")
|
|
2362
|
+
break
|
|
2363
|
+
except Exception as e:
|
|
2364
|
+
logger.exception(f"Error processing agent request: {e}")
|
|
2365
|
+
# Try to send error, but don't fail if connection is closed
|
|
2366
|
+
await manager.send_event(websocket, "error", {"message": str(e)})
|
|
2367
|
+
|
|
896
2368
|
except WebSocketDisconnect:
|
|
897
2369
|
pass # Normal disconnect
|
|
898
2370
|
finally:
|