khoj 2.0.0b8.dev1__py3-none-any.whl → 2.0.0b9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- khoj/database/adapters/__init__.py +3 -2
- khoj/database/models/__init__.py +28 -0
- khoj/interface/compiled/404/index.html +2 -2
- khoj/interface/compiled/_next/static/chunks/5477-c4209b72942d3038.js +1 -0
- khoj/interface/compiled/_next/static/chunks/9139-8ac4d9feb10f8869.js +1 -0
- khoj/interface/compiled/_next/static/chunks/app/agents/layout-e3d72f0edda6aa0c.js +1 -0
- khoj/interface/compiled/_next/static/chunks/app/agents/{page-5db6ad18da10d353.js → page-9a4610474cd59a71.js} +1 -1
- khoj/interface/compiled/_next/static/chunks/app/automations/{page-6271e2e31c7571d1.js → page-f7bb9d777b7745d4.js} +1 -1
- khoj/interface/compiled/_next/static/chunks/app/chat/page-a8455b8f9d36a2b0.js +1 -0
- khoj/interface/compiled/_next/static/chunks/app/{page-a19a597629e87fb8.js → page-2025944ec1f80144.js} +1 -1
- khoj/interface/compiled/_next/static/chunks/app/search/layout-4505b79deb734a30.js +1 -0
- khoj/interface/compiled/_next/static/chunks/app/search/{page-fa366ac14b228688.js → page-4885df3cd175c957.js} +1 -1
- khoj/interface/compiled/_next/static/chunks/app/settings/{page-8f9a85f96088c18b.js → page-8be3b35178abf2ec.js} +1 -1
- khoj/interface/compiled/_next/static/chunks/app/share/chat/layout-6fb51c5c80f8ec67.js +1 -0
- khoj/interface/compiled/_next/static/chunks/app/share/chat/{page-ed7787cf4938b8e3.js → page-ee8ef5270163e7f2.js} +1 -1
- khoj/interface/compiled/_next/static/chunks/{webpack-c6e14fd89812b96f.js → webpack-6355be48bba04af8.js} +1 -1
- khoj/interface/compiled/_next/static/css/102b97d6472fdd3a.css +1 -0
- khoj/interface/compiled/_next/static/css/{93eeacc43e261162.css → c34713c98384ee87.css} +1 -1
- khoj/interface/compiled/_next/static/css/fc82e43baa9ae218.css +1 -0
- khoj/interface/compiled/agents/index.html +2 -2
- khoj/interface/compiled/agents/index.txt +2 -2
- khoj/interface/compiled/automations/index.html +2 -2
- khoj/interface/compiled/automations/index.txt +2 -2
- khoj/interface/compiled/chat/index.html +2 -2
- khoj/interface/compiled/chat/index.txt +2 -2
- khoj/interface/compiled/index.html +2 -2
- khoj/interface/compiled/index.txt +2 -2
- khoj/interface/compiled/search/index.html +2 -2
- khoj/interface/compiled/search/index.txt +2 -2
- khoj/interface/compiled/settings/index.html +2 -2
- khoj/interface/compiled/settings/index.txt +2 -2
- khoj/interface/compiled/share/chat/index.html +2 -2
- khoj/interface/compiled/share/chat/index.txt +2 -2
- khoj/main.py +11 -1
- khoj/processor/conversation/utils.py +7 -7
- khoj/processor/operator/__init__.py +16 -1
- khoj/processor/tools/run_code.py +5 -2
- khoj/routers/api_chat.py +846 -682
- khoj/routers/helpers.py +149 -5
- khoj/routers/research.py +56 -14
- khoj/utils/rawconfig.py +0 -1
- {khoj-2.0.0b8.dev1.dist-info → khoj-2.0.0b9.dist-info}/METADATA +1 -1
- {khoj-2.0.0b8.dev1.dist-info → khoj-2.0.0b9.dist-info}/RECORD +48 -48
- khoj/interface/compiled/_next/static/chunks/5477-18323501c445315e.js +0 -1
- khoj/interface/compiled/_next/static/chunks/9568-0d60ac475f4cc538.js +0 -1
- khoj/interface/compiled/_next/static/chunks/app/agents/layout-e00fb81dca656a10.js +0 -1
- khoj/interface/compiled/_next/static/chunks/app/chat/page-b186e95387e23ed5.js +0 -1
- khoj/interface/compiled/_next/static/chunks/app/search/layout-f5881c7ae3ba0795.js +0 -1
- khoj/interface/compiled/_next/static/chunks/app/share/chat/layout-abb6c5f4239ad7be.js +0 -1
- khoj/interface/compiled/_next/static/css/76c658ee459140a9.css +0 -1
- khoj/interface/compiled/_next/static/css/a0c2fd63bb396f04.css +0 -1
- /khoj/interface/compiled/_next/static/{ZHB1va0KhWIu8Zs8-kbgt → 6kC_Tt4g0U2gXGUbnSB1O}/_buildManifest.js +0 -0
- /khoj/interface/compiled/_next/static/{ZHB1va0KhWIu8Zs8-kbgt → 6kC_Tt4g0U2gXGUbnSB1O}/_ssgManifest.js +0 -0
- {khoj-2.0.0b8.dev1.dist-info → khoj-2.0.0b9.dist-info}/WHEEL +0 -0
- {khoj-2.0.0b8.dev1.dist-info → khoj-2.0.0b9.dist-info}/entry_points.txt +0 -0
- {khoj-2.0.0b8.dev1.dist-info → khoj-2.0.0b9.dist-info}/licenses/LICENSE +0 -0
khoj/routers/api_chat.py
CHANGED
@@ -10,9 +10,18 @@ from typing import Any, Dict, List, Optional
 from urllib.parse import unquote

 from asgiref.sync import sync_to_async
-from fastapi import
+from fastapi import (
+    APIRouter,
+    Depends,
+    HTTPException,
+    Request,
+    WebSocket,
+    WebSocketDisconnect,
+)
 from fastapi.responses import RedirectResponse, Response, StreamingResponse
+from fastapi.websockets import WebSocketState
 from starlette.authentication import has_required_scope, requires
+from starlette.requests import URL, Headers

 from khoj.app.settings import ALLOWED_HOSTS
 from khoj.database.adapters import (
@@ -51,6 +60,7 @@ from khoj.routers.helpers import (
     ConversationCommandRateLimiter,
     DeleteMessageRequestBody,
     FeedbackData,
+    WebSocketConnectionManager,
     acreate_title_from_history,
     agenerate_chat_response,
     aget_data_sources_and_output_format,
@@ -60,6 +70,7 @@ from khoj.routers.helpers import (
     generate_mermaidjs_diagram,
     generate_summary_from_files,
     get_conversation_command,
+    get_message_from_queue,
     is_query_empty,
     is_ready_to_chat,
     read_chat_stream,
@@ -657,19 +668,13 @@ def delete_message(request: Request, delete_request: DeleteMessageRequestBody) -
     return Response(content=json.dumps({"status": "error", "message": "Message not found"}), status_code=404)


-
-@requires(["authenticated"])
-async def chat(
-    request: Request,
-    common: CommonQueryParams,
+async def event_generator(
     body: ChatRequestBody,
-
-
-
-
-
-    ),
-    image_rate_limiter=Depends(ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)),
+    user_scope: Any,
+    common: CommonQueryParams,
+    headers: Headers,
+    request_obj: Request | WebSocket,
+    parent_interrupt_queue: asyncio.Queue = None,
 ):
     # Access the parameters from the body
     q = body.q
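Note: this hunk turns the old `chat` route into a transport-agnostic `event_generator` shared by HTTP and WebSocket callers. As a hedged illustration only, an HTTP endpoint could delegate to it roughly like this; the decorator, path, and media type below are assumptions, not taken from this diff:

```python
# Hypothetical sketch: the real endpoint also wires up rate limiters and
# other dependencies removed from this hunk.
@api_chat.post("")
@requires(["authenticated"])
async def chat(request: Request, common: CommonQueryParams, body: ChatRequestBody):
    # request.user carries the authenticated scope consumed as user_scope.
    events = event_generator(
        body,
        user_scope=request.user,
        common=common,
        headers=request.headers,
        request_obj=request,
    )
    return StreamingResponse(events, media_type="text/event-stream")
```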
@@ -686,67 +691,68 @@ async def chat(
     timezone = body.timezone
     raw_images = body.images
     raw_query_files = body.files
-    interrupt_flag = body.interrupt
-
-    async def event_generator(q: str, images: list[str]):
-        start_time = time.perf_counter()
-        ttft = None
-        chat_metadata: dict = {}
-        conversation = None
-        user: KhojUser = request.user.object
-        is_subscribed = has_required_scope(request, ["premium"])
-        q = unquote(q)
-        train_of_thought = []
-        nonlocal conversation_id
-        nonlocal raw_query_files
-        cancellation_event = asyncio.Event()
-
-        tracer: dict = {
-            "mid": turn_id,
-            "cid": conversation_id,
-            "uid": user.id,
-            "khoj_version": state.khoj_version,
-        }

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    start_time = time.perf_counter()
+    ttft = None
+    chat_metadata: dict = {}
+    conversation = None
+    user: KhojUser = user_scope.object
+    is_subscribed = has_required_scope(request_obj, ["premium"])
+    q = unquote(q)
+    defiltered_query = defilter_query(q)
+    train_of_thought = []
+    cancellation_event = asyncio.Event()
+    child_interrupt_queue: asyncio.Queue = asyncio.Queue(maxsize=10)
+    event_delimiter = "␃🔚␗"
+
+    tracer: dict = {
+        "mid": turn_id,
+        "cid": conversation_id,
+        "uid": user.id,
+        "khoj_version": state.khoj_version,
+    }
+
+    uploaded_images: list[str] = []
+    if raw_images:
+        for image in raw_images:
+            decoded_string = unquote(image)
+            base64_data = decoded_string.split(",", 1)[1]
+            image_bytes = base64.b64decode(base64_data)
+            webp_image_bytes = convert_image_to_webp(image_bytes)
+            uploaded_image = upload_user_image_to_bucket(webp_image_bytes, user.id)
+            if not uploaded_image:
+                base64_webp_image = base64.b64encode(webp_image_bytes).decode("utf-8")
+                uploaded_image = f"data:image/webp;base64,{base64_webp_image}"
+            uploaded_images.append(uploaded_image)
+
+    query_files: Dict[str, str] = {}
+    if raw_query_files:
+        for file in raw_query_files:
+            query_files[file.name] = file.content
+
+    research_results: List[ResearchIteration] = []
+    online_results: Dict = dict()
+    code_results: Dict = dict()
+    operator_results: List[OperatorRun] = []
+    compiled_references: List[Any] = []
+    inferred_queries: List[Any] = []
+    attached_file_context = gather_raw_query_files(query_files)
+
+    generated_images: List[str] = []
+    generated_files: List[FileAttachment] = []
+    generated_mermaidjs_diagram: str = None
+    generated_asset_results: Dict = dict()
+    program_execution_context: List[str] = []
+    user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+
+    # Create a task to monitor for disconnections
+    disconnect_monitor_task = None
+
+    async def monitor_disconnection():
+        nonlocal q, defiltered_query
+        if isinstance(request_obj, Request):
             try:
-                msg = await
+                msg = await request_obj.receive()
                 if msg["type"] == "http.disconnect":
                     logger.debug(f"Request cancelled. User {user} disconnected from {common.client} client.")
                     cancellation_event.set()
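The `monitor_disconnection` coroutine above detects dropped HTTP clients by awaiting the ASGI receive channel, which Starlette resolves with a `{"type": "http.disconnect"}` message once the client goes away. A minimal standalone sketch of that pattern (names here are illustrative, not from the diff):

```python
import asyncio
from starlette.requests import Request

async def watch_for_disconnect(request: Request, cancelled: asyncio.Event) -> None:
    # While a streaming response is being written, the only message left on
    # the request's ASGI receive channel is the eventual disconnect event.
    msg = await request.receive()
    if msg["type"] == "http.disconnect":
        cancelled.set()
```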
@@ -758,14 +764,13 @@ async def chat(
                                 q,
                                 chat_response="",
                                 user=user,
-                                chat_history=chat_history,
                                 compiled_references=compiled_references,
                                 online_results=online_results,
                                 code_results=code_results,
                                 operator_results=operator_results,
                                 research_results=research_results,
                                 inferred_queries=inferred_queries,
-                                client_application=
+                                client_application=user_scope.client_app,
                                 conversation_id=conversation_id,
                                 query_images=uploaded_images,
                                 train_of_thought=train_of_thought,
@@ -773,502 +778,319 @@ async def chat(
                                 generated_images=generated_images,
                                 raw_generated_files=generated_asset_results,
                                 generated_mermaidjs_diagram=generated_mermaidjs_diagram,
+                                user_message_time=user_message_time,
                                 tracer=tracer,
                             )
                         )
         except Exception as e:
             logger.error(f"Error in disconnect monitor: {e}")
-
-
-
-
-
-
-
-
-
-            pass
-
-        async def send_event(event_type: ChatEvent, data: str | dict):
-            nonlocal ttft, train_of_thought
-            event_delimiter = "␃🔚␗"
-            if cancellation_event.is_set():
-                return
-            try:
-                if event_type == ChatEvent.END_LLM_RESPONSE:
-                    collect_telemetry()
-                elif event_type == ChatEvent.START_LLM_RESPONSE:
-                    ttft = time.perf_counter() - start_time
-                elif event_type == ChatEvent.STATUS:
-                    train_of_thought.append({"type": event_type.value, "data": data})
-                elif event_type == ChatEvent.THOUGHT:
-                    # Append the data to the last thought as thoughts are streamed
-                    if (
-                        len(train_of_thought) > 0
-                        and train_of_thought[-1]["type"] == ChatEvent.THOUGHT.value
-                        and type(train_of_thought[-1]["data"]) == type(data) == str
-                    ):
-                        train_of_thought[-1]["data"] += data
+        elif isinstance(request_obj, WebSocket):
+            while request_obj.client_state == WebSocketState.CONNECTED and not cancellation_event.is_set():
+                await asyncio.sleep(1)
+
+                # Check if any interrupt query is received
+                if interrupt_query := get_message_from_queue(parent_interrupt_queue):
+                    if interrupt_query == event_delimiter:
+                        cancellation_event.set()
+                        logger.debug(f"Chat cancelled by user {user} via interrupt queue.")
                     else:
-
-
-
-
-
-
-
-
-
-
-
+                        # Pass the interrupt query to child tasks
+                        logger.info(f"Continuing chat with the new instruction: {interrupt_query}")
+                        await child_interrupt_queue.put(interrupt_query)
+                        # Append the interrupt query to the main query
+                        q += f"\n\n{interrupt_query}"
+                        defiltered_query += f"\n\n{defilter_query(interrupt_query)}"
+
+            logger.debug(f"WebSocket disconnected or chat cancelled by user {user} from {common.client} client.")
+            if conversation and cancellation_event.is_set():
+                await asyncio.shield(
+                    save_to_conversation_log(
+                        q,
+                        chat_response="",
+                        user=user,
+                        compiled_references=compiled_references,
+                        online_results=online_results,
+                        code_results=code_results,
+                        operator_results=operator_results,
+                        research_results=research_results,
+                        inferred_queries=inferred_queries,
+                        client_application=user_scope.client_app,
+                        conversation_id=conversation_id,
+                        query_images=uploaded_images,
+                        train_of_thought=train_of_thought,
+                        raw_query_files=raw_query_files,
+                        generated_images=generated_images,
+                        raw_generated_files=generated_asset_results,
+                        generated_mermaidjs_diagram=generated_mermaidjs_diagram,
+                        user_message_time=user_message_time,
+                        tracer=tracer,
                     )
-
-            if not cancellation_event.is_set():
-                yield event_delimiter
-            # Cancel the disconnect monitor task if it is still running
-            if cancellation_event.is_set() or event_type == ChatEvent.END_RESPONSE:
-                await cancel_disconnect_monitor()
-
-        async def send_llm_response(response: str, usage: dict = None):
-            # Check if the client is still connected
-            if cancellation_event.is_set():
-                return
-            # Send Chat Response
-            async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
-                yield result
-            async for result in send_event(ChatEvent.MESSAGE, response):
-                yield result
-            async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
-                yield result
-            # Send Usage Metadata once llm interactions are complete
-            if usage:
-                async for event in send_event(ChatEvent.USAGE, usage):
-                    yield event
-            async for result in send_event(ChatEvent.END_RESPONSE, ""):
-                yield result
-
-        def collect_telemetry():
-            # Gather chat response telemetry
-            nonlocal chat_metadata
-            latency = time.perf_counter() - start_time
-            cmd_set = set([cmd.value for cmd in conversation_commands])
-            cost = (tracer.get("usage", {}) or {}).get("cost", 0)
-            chat_metadata = chat_metadata or {}
-            chat_metadata["conversation_command"] = cmd_set
-            chat_metadata["agent"] = conversation.agent.slug if conversation and conversation.agent else None
-            chat_metadata["cost"] = f"{cost:.5f}"
-            chat_metadata["latency"] = f"{latency:.3f}"
-            if ttft:
-                chat_metadata["ttft_latency"] = f"{ttft:.3f}"
-                logger.info(f"Chat response time to first token: {ttft:.3f} seconds")
-            logger.info(f"Chat response total time: {latency:.3f} seconds")
-            logger.info(f"Chat response cost: ${cost:.5f}")
-            update_telemetry_state(
-                request=request,
-                telemetry_type="api",
-                api="chat",
-                client=common.client,
-                user_agent=request.headers.get("user-agent"),
-                host=request.headers.get("host"),
-                metadata=chat_metadata,
-            )
-
-        # Start the disconnect monitor in the background
-        disconnect_monitor_task = asyncio.create_task(monitor_disconnection())
-
-        if is_query_empty(q):
-            async for result in send_llm_response("Please ask your query to get started.", tracer.get("usage")):
-                yield result
-            return
-
-        # Automated tasks are handled before to allow mixing them with other conversation commands
-        cmds_to_rate_limit = []
-        is_automated_task = False
-        if q.startswith("/automated_task"):
-            is_automated_task = True
-            q = q.replace("/automated_task", "").lstrip()
-            cmds_to_rate_limit += [ConversationCommand.AutomatedTask]
+                )

-
-
+    # Cancel the disconnect monitor task if it is still running
+    async def cancel_disconnect_monitor():
+        if disconnect_monitor_task and not disconnect_monitor_task.done():
+            logger.debug(f"Cancelling disconnect monitor task for user {user}")
+            disconnect_monitor_task.cancel()
+            try:
+                await disconnect_monitor_task
+            except asyncio.CancelledError:
+                pass

-
-
-
-            conversation_id=conversation_id,
-            title=title,
-            create_new=body.create_new,
-        )
-        if not conversation:
-            async for result in send_llm_response(f"Conversation {conversation_id} not found", tracer.get("usage")):
-                yield result
+    async def send_event(event_type: ChatEvent, data: str | dict):
+        nonlocal ttft, train_of_thought
+        if cancellation_event.is_set():
             return
-
-
-
-
-
-
-
-
-
-
-        if not conversation.agent:
-            conversation.agent = default_agent
-            await conversation.asave()
-            agent = default_agent
-
-        await is_ready_to_chat(user)
-        user_name = await aget_user_name(user)
-        location = None
-        if city or region or country or country_code:
-            location = LocationData(city=city, region=region, country=country, country_code=country_code)
-        user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-        chat_history = conversation.messages
-
-        # If interrupt flag is set, wait for the previous turn to be saved before proceeding
-        if interrupt_flag:
-            max_wait_time = 20.0  # seconds
-            wait_interval = 0.3  # seconds
-            wait_start = wait_current = time.time()
-            while wait_current - wait_start < max_wait_time:
-                # Refresh conversation to check if interrupted message saved to DB
-                conversation = await ConversationAdapters.aget_conversation_by_user(
-                    user,
-                    client_application=request.user.client_app,
-                    conversation_id=conversation_id,
-                )
+        try:
+            if event_type == ChatEvent.END_LLM_RESPONSE:
+                collect_telemetry()
+            elif event_type == ChatEvent.START_LLM_RESPONSE:
+                ttft = time.perf_counter() - start_time
+            elif event_type == ChatEvent.STATUS:
+                train_of_thought.append({"type": event_type.value, "data": data})
+            elif event_type == ChatEvent.THOUGHT:
+                # Append the data to the last thought as thoughts are streamed
                 if (
-
-                    and
-                    and
-                    and not conversation.messages[-1].message
+                    len(train_of_thought) > 0
+                    and train_of_thought[-1]["type"] == ChatEvent.THOUGHT.value
+                    and type(train_of_thought[-1]["data"]) == type(data) == str
                 ):
-
-
-
-                wait_current = time.time()
-
-            if wait_current - wait_start >= max_wait_time:
-                logger.warning(
-                    f"Timeout waiting to load interrupted context from conversation {conversation_id}. Proceed without previous context."
-                )
+                    train_of_thought[-1]["data"] += data
+                else:
+                    train_of_thought.append({"type": event_type.value, "data": data})

-
-
-
-
-
-
-
-
-
-            online_results = {key: val.model_dump() for key, val in last_message.onlineContext.items() or []}
-            code_results = {key: val.model_dump() for key, val in last_message.codeContext.items() or []}
-            compiled_references = [ref.model_dump() for ref in last_message.context or []]
-            research_results = [
-                ResearchIteration(**iter_dict)
-                for iter_dict in last_message.researchContext or []
-                if iter_dict.get("summarizedResult")
-            ]
-            operator_results = [OperatorRun(**iter_dict) for iter_dict in last_message.operatorContext or []]
-            train_of_thought = [thought.model_dump() for thought in last_message.trainOfThought or []]
-            # Drop the interrupted message from conversation history
-            chat_history.pop()
-            logger.info(f"Loaded interrupted partial context from conversation {conversation_id}.")
-
-        if conversation_commands == [ConversationCommand.Default]:
-            try:
-                chosen_io = await aget_data_sources_and_output_format(
-                    q,
-                    chat_history,
-                    is_automated_task,
-                    user=user,
-                    query_images=uploaded_images,
-                    agent=agent,
-                    query_files=attached_file_context,
-                    tracer=tracer,
+            if event_type == ChatEvent.MESSAGE:
+                yield data
+            elif event_type == ChatEvent.REFERENCES or ChatEvent.METADATA or stream:
+                yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False)
+        except Exception as e:
+            if not cancellation_event.is_set():
+                logger.error(
+                    f"Failed to stream chat API response to {user} on {common.client}: {e}.",
+                    exc_info=True,
                 )
-
-
-
-
-
-
-        # If we're doing research, we don't want to do anything else
-        if ConversationCommand.Research in conversation_commands:
-            conversation_commands = [ConversationCommand.Research]
+        finally:
+            if not cancellation_event.is_set():
+                yield event_delimiter
+            # Cancel the disconnect monitor task if it is still running
+            if cancellation_event.is_set() or event_type == ChatEvent.END_RESPONSE:
+                await cancel_disconnect_monitor()

-
-
-
+    async def send_llm_response(response: str, usage: dict = None):
+        # Check if the client is still connected
+        if cancellation_event.is_set():
+            return
+        # Send Chat Response
+        async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
+            yield result
+        async for result in send_event(ChatEvent.MESSAGE, response):
+            yield result
+        async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
+            yield result
+        # Send Usage Metadata once llm interactions are complete
+        if usage:
+            async for event in send_event(ChatEvent.USAGE, usage):
+                yield event
+        async for result in send_event(ChatEvent.END_RESPONSE, ""):
+            yield result

-
-
-
-
-
-
-
-
-
+    def collect_telemetry():
+        # Gather chat response telemetry
+        nonlocal chat_metadata
+        latency = time.perf_counter() - start_time
+        cmd_set = set([cmd.value for cmd in conversation_commands])
+        cost = (tracer.get("usage", {}) or {}).get("cost", 0)
+        chat_metadata = chat_metadata or {}
+        chat_metadata["conversation_command"] = cmd_set
+        chat_metadata["agent"] = conversation.agent.slug if conversation and conversation.agent else None
+        chat_metadata["cost"] = f"{cost:.5f}"
+        chat_metadata["latency"] = f"{latency:.3f}"
+        if ttft:
+            chat_metadata["ttft_latency"] = f"{ttft:.3f}"
+            logger.info(f"Chat response time to first token: {ttft:.3f} seconds")
+        logger.info(f"Chat response total time: {latency:.3f} seconds")
+        logger.info(f"Chat response cost: ${cost:.5f}")
+        update_telemetry_state(
+            request=request_obj,
+            telemetry_type="api",
+            api="chat",
+            client=common.client,
+            user_agent=headers.get("user-agent"),
+            host=headers.get("host"),
+            metadata=chat_metadata,
+        )

-
-
+    # Start the disconnect monitor in the background
+    disconnect_monitor_task = asyncio.create_task(monitor_disconnection())

-
-
+    if is_query_empty(q):
+        async for result in send_llm_response("Please ask your query to get started.", tracer.get("usage")):
+            yield result
+        return
+
+    # Automated tasks are handled before to allow mixing them with other conversation commands
+    cmds_to_rate_limit = []
+    is_automated_task = False
+    if q.startswith("/automated_task"):
+        is_automated_task = True
+        q = q.replace("/automated_task", "").lstrip()
+        cmds_to_rate_limit += [ConversationCommand.AutomatedTask]
+
+    # Extract conversation command from query
+    conversation_commands = [get_conversation_command(query=q)]
+
+    conversation = await ConversationAdapters.aget_conversation_by_user(
+        user,
+        client_application=user_scope.client_app,
+        conversation_id=conversation_id,
+        title=title,
+        create_new=body.create_new,
+    )
+    if not conversation:
+        async for result in send_llm_response(f"Conversation {conversation_id} not found", tracer.get("usage")):
+            yield result
+        return
+    conversation_id = str(conversation.id)
+
+    async for event in send_event(ChatEvent.METADATA, {"conversationId": conversation_id, "turnId": turn_id}):
+        yield event
+
+    agent: Agent | None = None
+    default_agent = await AgentAdapters.aget_default_agent()
+    if conversation.agent and conversation.agent != default_agent:
+        agent = conversation.agent
+
+    if not conversation.agent:
+        conversation.agent = default_agent
+        await conversation.asave()
+        agent = default_agent
+
+    await is_ready_to_chat(user)
+    user_name = await aget_user_name(user)
+    location = None
+    if city or region or country or country_code:
+        location = LocationData(city=city, region=region, country=country, country_code=country_code)
+    chat_history = conversation.messages
+
+    # If interrupted message in DB
+    if last_message := await conversation.pop_message(interrupted=True):
+        # Populate context from interrupted message
+        online_results = {key: val.model_dump() for key, val in last_message.onlineContext.items() or []}
+        code_results = {key: val.model_dump() for key, val in last_message.codeContext.items() or []}
+        compiled_references = [ref.model_dump() for ref in last_message.context or []]
+        research_results = [
+            ResearchIteration(**iter_dict)
+            for iter_dict in last_message.researchContext or []
+            if iter_dict.get("summarizedResult")
+        ]
+        operator_results = [OperatorRun(**iter_dict) for iter_dict in last_message.operatorContext or []]
+        train_of_thought = [thought.model_dump() for thought in last_message.trainOfThought or []]
+        logger.info(f"Loaded interrupted partial context from conversation {conversation_id}.")
+
+    if conversation_commands == [ConversationCommand.Default]:
+        try:
+            chosen_io = await aget_data_sources_and_output_format(
+                q,
+                chat_history,
+                is_automated_task,
                 user=user,
-                query=defiltered_query,
-                conversation_id=conversation_id,
-                conversation_history=chat_history,
-                previous_iterations=list(research_results),
                 query_images=uploaded_images,
                 agent=agent,
-                send_status_func=partial(send_event, ChatEvent.STATUS),
-                user_name=user_name,
-                location=location,
-                file_filters=file_filters,
                 query_files=attached_file_context,
                 tracer=tracer,
-
-
-
-
-                if research_result.onlineContext:
-                    online_results.update(research_result.onlineContext)
-                if research_result.codeContext:
-                    code_results.update(research_result.codeContext)
-                if research_result.context:
-                    compiled_references.extend(research_result.context)
-                if not research_results or research_results[-1] is not research_result:
-                    research_results.append(research_result)
-            else:
-                yield research_result
-
-            # Track operator results across research and operator iterations
-            # This relies on two conditions:
-            # 1. Check to append new (partial) operator results
-            #    Relies on triggering this check on every status updates.
-            #    Status updates cascade up from operator to research to chat api on every step.
-            # 2. Keep operator results in sync with each research operator step
-            #    Relies on python object references to ensure operator results
-            #    are implicitly kept in sync after the initial append
-            if (
-                research_results
-                and research_results[-1].operatorContext
-                and (not operator_results or operator_results[-1] is not research_results[-1].operatorContext)
-            ):
-                operator_results.append(research_results[-1].operatorContext)
-
-        # researched_results = await extract_relevant_info(q, researched_results, agent)
-        if state.verbose > 1:
-            logger.debug(f'Researched Results: {"".join(r.summarizedResult or "" for r in research_results)}')
-
-        # Gather Context
-        ## Extract Document References
-        if not ConversationCommand.Research in conversation_commands:
-            try:
-                async for result in search_documents(
-                    q,
-                    (n or 7),
-                    d,
-                    user,
-                    chat_history,
-                    conversation_id,
-                    conversation_commands,
-                    location,
-                    partial(send_event, ChatEvent.STATUS),
-                    query_images=uploaded_images,
-                    agent=agent,
-                    query_files=attached_file_context,
-                    tracer=tracer,
-                ):
-                    if isinstance(result, dict) and ChatEvent.STATUS in result:
-                        yield result[ChatEvent.STATUS]
-                    else:
-                        compiled_references.extend(result[0])
-                        inferred_queries.extend(result[1])
-                        defiltered_query = result[2]
-            except Exception as e:
-                error_message = (
-                    f"Error searching knowledge base: {e}. Attempting to respond without document references."
-                )
-                logger.error(error_message, exc_info=True)
-                async for result in send_event(
-                    ChatEvent.STATUS, "Document search failed. I'll try respond without document references"
-                ):
-                    yield result
-
-        if not is_none_or_empty(compiled_references):
-            headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
-            # Strip only leading # from headings
-            headings = headings.replace("#", "")
-            async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"):
-                yield result
+            )
+        except ValueError as e:
+            logger.error(f"Error getting data sources and output format: {e}. Falling back to default.")
+            conversation_commands = [ConversationCommand.General]

-
-            async for result in send_llm_response(f"{no_entries_found.format()}", tracer.get("usage")):
-                yield result
-            return
+        conversation_commands = chosen_io.get("sources") + [chosen_io.get("output")]

-
-
+    # If we're doing research, we don't want to do anything else
+    if ConversationCommand.Research in conversation_commands:
+        conversation_commands = [ConversationCommand.Research]

-
-
-
-                async for result in search_online(
-                    defiltered_query,
-                    chat_history,
-                    location,
-                    user,
-                    partial(send_event, ChatEvent.STATUS),
-                    custom_filters=[],
-                    max_online_searches=3,
-                    query_images=uploaded_images,
-                    query_files=attached_file_context,
-                    agent=agent,
-                    tracer=tracer,
-                ):
-                    if isinstance(result, dict) and ChatEvent.STATUS in result:
-                        yield result[ChatEvent.STATUS]
-                    else:
-                        online_results = result
-            except Exception as e:
-                error_message = f"Error searching online: {e}. Attempting to respond without online results"
-                logger.warning(error_message)
-                async for result in send_event(
-                    ChatEvent.STATUS, "Online search failed. I'll try respond without online references"
-                ):
-                    yield result
+    conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
+    async for result in send_event(ChatEvent.STATUS, f"**Selected Tools:** {conversation_commands_str}"):
+        yield result

-
-
-
-
-
-
-
-
-
-                    max_webpages_to_read=1,
-                    query_images=uploaded_images,
-                    agent=agent,
-                    query_files=attached_file_context,
-                    tracer=tracer,
-                ):
-                    if isinstance(result, dict) and ChatEvent.STATUS in result:
-                        yield result[ChatEvent.STATUS]
-                    else:
-                        direct_web_pages = result
-                        webpages = []
-                        for query in direct_web_pages:
-                            if online_results.get(query):
-                                online_results[query]["webpages"] = direct_web_pages[query]["webpages"]
-                            else:
-                                online_results[query] = {"webpages": direct_web_pages[query]["webpages"]}
+    cmds_to_rate_limit += conversation_commands
+    for cmd in cmds_to_rate_limit:
+        try:
+            await conversation_command_rate_limiter.update_and_check_if_valid(request_obj, cmd)
+            q = q.replace(f"/{cmd.value}", "").strip()
+        except HTTPException as e:
+            async for result in send_llm_response(str(e.detail), tracer.get("usage")):
+                yield result
+            return

-
-
-                        async for result in send_event(ChatEvent.STATUS, f"**Read web pages**: {webpages}"):
-                            yield result
-            except Exception as e:
-                logger.warning(
-                    f"Error reading webpages: {e}. Attempting to respond without webpage results",
-                    exc_info=True,
-                )
-                async for result in send_event(
-                    ChatEvent.STATUS, "Webpage read failed. I'll try respond without webpage references"
-                ):
-                    yield result
+    defiltered_query = defilter_query(q)
+    file_filters = conversation.file_filters if conversation and conversation.file_filters else []

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                    else:
-                        code_results = result
-            except ValueError as e:
-                program_execution_context.append(f"Failed to run code")
-                logger.warning(
-                    f"Failed to use code tool: {e}. Attempting to respond without code results",
-                    exc_info=True,
-                )
-        if ConversationCommand.Operator in conversation_commands:
-            try:
-                async for result in operate_environment(
-                    defiltered_query,
-                    user,
-                    chat_history,
-                    location,
-                    list(operator_results)[-1] if operator_results else None,
-                    query_images=uploaded_images,
-                    query_files=attached_file_context,
-                    send_status_func=partial(send_event, ChatEvent.STATUS),
-                    agent=agent,
-                    cancellation_event=cancellation_event,
-                    tracer=tracer,
-                ):
-                    if isinstance(result, dict) and ChatEvent.STATUS in result:
-                        yield result[ChatEvent.STATUS]
-                    elif isinstance(result, OperatorRun):
-                        if not operator_results or operator_results[-1] is not result:
-                            operator_results.append(result)
-                        # Add webpages visited while operating browser to references
-                        if result.webpages:
-                            if not online_results.get(defiltered_query):
-                                online_results[defiltered_query] = {"webpages": result.webpages}
-                            elif not online_results[defiltered_query].get("webpages"):
-                                online_results[defiltered_query]["webpages"] = result.webpages
-                            else:
-                                online_results[defiltered_query]["webpages"] += result.webpages
-            except ValueError as e:
-                program_execution_context.append(f"Browser operation error: {e}")
-                logger.warning(f"Failed to operate browser with {e}", exc_info=True)
-                async for result in send_event(
-                    ChatEvent.STATUS, "Operating browser failed. I'll try respond appropriately"
-                ):
-                    yield result
-
-        ## Send Gathered References
-        unique_online_results = deduplicate_organic_results(online_results)
-        async for result in send_event(
-            ChatEvent.REFERENCES,
-            {
-                "inferredQueries": inferred_queries,
-                "context": compiled_references,
-                "onlineContext": unique_online_results,
-                "codeContext": code_results,
-            },
+    if conversation_commands == [ConversationCommand.Research]:
+        async for research_result in research(
+            user=user,
+            query=defiltered_query,
+            conversation_id=conversation_id,
+            conversation_history=chat_history,
+            previous_iterations=list(research_results),
+            query_images=uploaded_images,
+            agent=agent,
+            send_status_func=partial(send_event, ChatEvent.STATUS),
+            user_name=user_name,
+            location=location,
+            file_filters=file_filters,
+            query_files=attached_file_context,
+            tracer=tracer,
+            cancellation_event=cancellation_event,
+            interrupt_queue=child_interrupt_queue,
+            abort_message=event_delimiter,
         ):
-
+            if isinstance(research_result, ResearchIteration):
+                if research_result.summarizedResult:
+                    if research_result.onlineContext:
+                        online_results.update(research_result.onlineContext)
+                    if research_result.codeContext:
+                        code_results.update(research_result.codeContext)
+                    if research_result.context:
+                        compiled_references.extend(research_result.context)
+                if not research_results or research_results[-1] is not research_result:
+                    research_results.append(research_result)
+            else:
+                yield research_result
+
+            # Track operator results across research and operator iterations
+            # This relies on two conditions:
+            # 1. Check to append new (partial) operator results
+            #    Relies on triggering this check on every status updates.
+            #    Status updates cascade up from operator to research to chat api on every step.
+            # 2. Keep operator results in sync with each research operator step
+            #    Relies on python object references to ensure operator results
+            #    are implicitly kept in sync after the initial append
+            if (
+                research_results
+                and research_results[-1].operatorContext
+                and (not operator_results or operator_results[-1] is not research_results[-1].operatorContext)
+            ):
+                operator_results.append(research_results[-1].operatorContext)

-            #
-
-
-
-
+    # researched_results = await extract_relevant_info(q, researched_results, agent)
+    if state.verbose > 1:
+        logger.debug(f'Researched Results: {"".join(r.summarizedResult or "" for r in research_results)}')
+
+    # Gather Context
+    ## Extract Document References
+    if not ConversationCommand.Research in conversation_commands:
+        try:
+            async for result in search_documents(
+                q,
+                (n or 7),
+                d,
                 user,
                 chat_history,
-
-
-
-
+                conversation_id,
+                conversation_commands,
+                location,
+                partial(send_event, ChatEvent.STATUS),
                 query_images=uploaded_images,
                 agent=agent,
                 query_files=attached_file_context,
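This hunk polls `parent_interrupt_queue` through the newly imported `get_message_from_queue` helper. Its implementation is not part of this diff; a plausible non-blocking read, inferred only from how it is called above, might look like:

```python
import asyncio
from typing import Optional

def get_message_from_queue(queue: Optional[asyncio.Queue]) -> Optional[str]:
    """Pop a waiting message without ever blocking the poll loop (assumed behavior)."""
    if queue is None:
        return None
    try:
        return queue.get_nowait()
    except asyncio.QueueEmpty:
        return None
```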
@@ -1277,184 +1099,526 @@ async def chat(
|
|
1277
1099
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
1278
1100
|
yield result[ChatEvent.STATUS]
|
1279
1101
|
else:
|
1280
|
-
|
1102
|
+
compiled_references.extend(result[0])
|
1103
|
+
inferred_queries.extend(result[1])
|
1104
|
+
defiltered_query = result[2]
|
1105
|
+
except Exception as e:
|
1106
|
+
error_message = f"Error searching knowledge base: {e}. Attempting to respond without document references."
|
1107
|
+
logger.error(error_message, exc_info=True)
|
1108
|
+
async for result in send_event(
|
1109
|
+
ChatEvent.STATUS, "Document search failed. I'll try respond without document references"
|
1110
|
+
):
|
1111
|
+
yield result
|
1281
1112
|
|
1282
|
-
|
1283
|
-
|
1284
|
-
|
1285
|
-
|
1286
|
-
|
1287
|
-
|
1288
|
-
|
1113
|
+
if not is_none_or_empty(compiled_references):
|
1114
|
+
distinct_headings = set([d.get("compiled").split("\n")[0] for d in compiled_references if "compiled" in d])
|
1115
|
+
distinct_files = set([d["file"] for d in compiled_references])
|
1116
|
+
# Strip only leading # from headings
|
1117
|
+
headings_str = "\n- " + "\n- ".join(distinct_headings).replace("#", "")
|
1118
|
+
async for result in send_event(
|
1119
|
+
ChatEvent.STATUS,
|
1120
|
+
f"**Found {len(distinct_headings)} Notes Across {len(distinct_files)} Files**: {headings_str}",
|
1121
|
+
):
|
1122
|
+
yield result
|
1289
1123
|
|
1290
|
-
|
1291
|
-
|
1292
|
-
|
1124
|
+
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
|
1125
|
+
async for result in send_llm_response(f"{no_entries_found.format()}", tracer.get("usage")):
|
1126
|
+
yield result
|
1127
|
+
return
|
1293
1128
|
|
1294
|
-
|
1295
|
-
|
1296
|
-
{
|
1297
|
-
"images": [generated_image],
|
1298
|
-
},
|
1299
|
-
):
|
1300
|
-
yield result
|
1129
|
+
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
|
1130
|
+
conversation_commands.remove(ConversationCommand.Notes)
|
1301
1131
|
|
1302
|
-
|
1303
|
-
|
1132
|
+
## Gather Online References
|
1133
|
+
if ConversationCommand.Online in conversation_commands:
|
1134
|
+
try:
|
1135
|
+
async for result in search_online(
|
1136
|
+
defiltered_query,
|
1137
|
+
chat_history,
|
1138
|
+
location,
|
1139
|
+
user,
|
1140
|
+
partial(send_event, ChatEvent.STATUS),
|
1141
|
+
custom_filters=[],
|
1142
|
+
max_online_searches=3,
|
1143
|
+
query_images=uploaded_images,
|
1144
|
+
query_files=attached_file_context,
|
1145
|
+
agent=agent,
|
1146
|
+
tracer=tracer,
|
1147
|
+
):
|
1148
|
+
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
1149
|
+
yield result[ChatEvent.STATUS]
|
1150
|
+
else:
|
1151
|
+
online_results = result
|
1152
|
+
except Exception as e:
|
1153
|
+
error_message = f"Error searching online: {e}. Attempting to respond without online results"
|
1154
|
+
logger.warning(error_message)
|
1155
|
+
async for result in send_event(
|
1156
|
+
ChatEvent.STATUS, "Online search failed. I'll try respond without online references"
|
1157
|
+
):
|
1304
1158
|
yield result
|
1305
1159
|
|
1306
|
-
|
1307
|
-
|
1160
|
+
## Gather Webpage References
|
1161
|
+
if ConversationCommand.Webpage in conversation_commands:
|
1162
|
+
try:
|
1163
|
+
async for result in read_webpages(
|
1164
|
+
defiltered_query,
|
1165
|
+
chat_history,
|
1166
|
+
location,
|
1167
|
+
user,
|
1168
|
+
partial(send_event, ChatEvent.STATUS),
|
1169
|
+
max_webpages_to_read=1,
|
1170
|
+
query_images=uploaded_images,
|
1171
|
+
agent=agent,
|
1172
|
+
query_files=attached_file_context,
|
1173
|
+
tracer=tracer,
|
1174
|
+
):
|
1175
|
+
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
1176
|
+
yield result[ChatEvent.STATUS]
|
1177
|
+
else:
|
1178
|
+
direct_web_pages = result
|
1179
|
+
webpages = []
|
1180
|
+
for query in direct_web_pages:
|
1181
|
+
if online_results.get(query):
|
1182
|
+
online_results[query]["webpages"] = direct_web_pages[query]["webpages"]
|
1183
|
+
else:
|
1184
|
+
online_results[query] = {"webpages": direct_web_pages[query]["webpages"]}
|
1308
1185
|
|
1309
|
-
|
1310
|
-
|
1311
|
-
|
1312
|
-
|
1313
|
-
|
1314
|
-
|
1186
|
+
for webpage in direct_web_pages[query]["webpages"]:
|
1187
|
+
webpages.append(webpage["link"])
|
1188
|
+
async for result in send_event(ChatEvent.STATUS, f"**Read web pages**: {webpages}"):
|
1189
|
+
yield result
|
1190
|
+
except Exception as e:
|
1191
|
+
logger.warning(
|
1192
|
+
f"Error reading webpages: {e}. Attempting to respond without webpage results",
|
1193
|
+
exc_info=True,
|
1194
|
+
)
|
1195
|
+
async for result in send_event(
|
1196
|
+
ChatEvent.STATUS, "Webpage read failed. I'll try respond without webpage references"
|
1197
|
+
):
|
1198
|
+
yield result
|
1199
|
+
|
1200
|
+
## Gather Code Results
|
1201
|
+
if ConversationCommand.Code in conversation_commands:
|
1202
|
+
try:
|
1203
|
+
context = f"# Iteration 1:\n#---\nNotes:\n{compiled_references}\n\nOnline Results:{online_results}"
|
1204
|
+
async for result in run_code(
|
1205
|
+
defiltered_query,
|
1206
|
+
chat_history,
|
1207
|
+
context,
|
1208
|
+
location,
|
1209
|
+
user,
|
1210
|
+
partial(send_event, ChatEvent.STATUS),
|
1315
1211
|
query_images=uploaded_images,
|
1316
|
-
user=user,
|
1317
1212
|
agent=agent,
|
1318
|
-
send_status_func=partial(send_event, ChatEvent.STATUS),
|
1319
1213
|
query_files=attached_file_context,
|
1320
1214
|
tracer=tracer,
|
1321
1215
|
):
|
1322
1216
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
1323
1217
|
yield result[ChatEvent.STATUS]
|
1324
1218
|
else:
|
1325
|
-
|
1326
|
-
|
1327
|
-
|
1328
|
-
|
1329
|
-
|
1330
|
-
|
1331
|
-
|
1332
|
-
|
1333
|
-
|
1334
|
-
|
1335
|
-
|
1336
|
-
|
1337
|
-
|
1338
|
-
|
1339
|
-
|
1340
|
-
|
1341
|
-
|
1342
|
-
|
1343
|
-
|
1344
|
-
|
1345
|
-
|
1346
|
-
|
1347
|
-
|
1348
|
-
|
1349
|
-
|
1219
|
+
code_results = result
|
1220
|
+
except ValueError as e:
|
1221
|
+
program_execution_context.append(f"Failed to run code")
|
1222
|
+
logger.warning(
|
1223
|
+
f"Failed to use code tool: {e}. Attempting to respond without code results",
|
1224
|
+
exc_info=True,
|
1225
|
+
)
|
1226
|
+
if ConversationCommand.Operator in conversation_commands:
|
1227
|
+
try:
|
1228
|
+
async for result in operate_environment(
|
1229
|
+
defiltered_query,
|
1230
|
+
user,
|
1231
|
+
chat_history,
|
1232
|
+
location,
|
1233
|
+
list(operator_results)[-1] if operator_results else None,
|
1234
|
+
query_images=uploaded_images,
|
1235
|
+
query_files=attached_file_context,
|
1236
|
+
send_status_func=partial(send_event, ChatEvent.STATUS),
|
1237
|
+
agent=agent,
|
1238
|
+
cancellation_event=cancellation_event,
|
1239
|
+
interrupt_queue=child_interrupt_queue,
|
1240
|
+
tracer=tracer,
|
1241
|
+
):
|
1242
|
+
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
1243
|
+
yield result[ChatEvent.STATUS]
|
1244
|
+
elif isinstance(result, OperatorRun):
|
1245
|
+
if not operator_results or operator_results[-1] is not result:
|
1246
|
+
operator_results.append(result)
|
1247
|
+
# Add webpages visited while operating browser to references
|
1248
|
+
if result.webpages:
|
1249
|
+
if not online_results.get(defiltered_query):
|
1250
|
+
online_results[defiltered_query] = {"webpages": result.webpages}
|
1251
|
+
elif not online_results[defiltered_query].get("webpages"):
|
1252
|
+
online_results[defiltered_query]["webpages"] = result.webpages
|
1253
|
+
else:
|
1254
|
+
online_results[defiltered_query]["webpages"] += result.webpages
|
1255
|
+
except ValueError as e:
|
1256
|
+
program_execution_context.append(f"Browser operation error: {e}")
|
1257
|
+
logger.warning(f"Failed to operate browser with {e}", exc_info=True)
|
1258
|
+
async for result in send_event(
|
1259
|
+
ChatEvent.STATUS, "Operating browser failed. I'll try respond appropriately"
|
1260
|
+
):
|
1261
|
+
yield result
|
1262
|
+
|
1263
|
+
## Send Gathered References
|
1264
|
+
unique_online_results = deduplicate_organic_results(online_results)
|
1265
|
+
async for result in send_event(
|
1266
|
+
ChatEvent.REFERENCES,
|
1267
|
+
{
|
1268
|
+
"inferredQueries": inferred_queries,
|
1269
|
+
"context": compiled_references,
|
1270
|
+
"onlineContext": unique_online_results,
|
1271
|
+
"codeContext": code_results,
|
1272
|
+
},
|
1273
|
+
):
|
1274
|
+
yield result
|
1275
|
+
|
1276
|
+
# Generate Output
|
1277
|
+
## Generate Image Output
|
1278
|
+
if ConversationCommand.Image in conversation_commands:
|
1279
|
+
async for result in text_to_image(
|
1280
|
+
defiltered_query,
|
1281
|
+
user,
|
1282
|
+
chat_history,
|
1283
|
+
location_data=location,
|
1284
|
+
references=compiled_references,
|
1285
|
+
online_results=online_results,
|
1286
|
+
send_status_func=partial(send_event, ChatEvent.STATUS),
|
1287
|
+
query_images=uploaded_images,
|
1288
|
+
agent=agent,
|
1289
|
+
query_files=attached_file_context,
|
1290
|
+
tracer=tracer,
|
1291
|
+
):
|
1292
|
+
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
1293
|
+
yield result[ChatEvent.STATUS]
|
1294
|
+
else:
|
1295
|
+
generated_image, status_code, improved_image_prompt = result
|
1296
|
+
|
1297
|
+
inferred_queries.append(improved_image_prompt)
|
1298
|
+
if generated_image is None or status_code != 200:
|
1299
|
+
program_execution_context.append(f"Failed to generate image with {improved_image_prompt}")
|
1300
|
+
async for result in send_event(ChatEvent.STATUS, f"Failed to generate image"):
|
1301
|
+
yield result
|
1302
|
+
else:
|
1303
|
+
generated_images.append(generated_image)
|
1304
|
+
|
1305
|
+
generated_asset_results["images"] = {
|
1306
|
+
"query": improved_image_prompt,
|
1307
|
+
}
|
1308
|
+
|
1309
|
+
async for result in send_event(
|
1310
|
+
ChatEvent.GENERATED_ASSETS,
|
1311
|
+
{
|
1312
|
+
"images": [generated_image],
|
1313
|
+
},
|
1314
|
+
):
|
1315
|
+
yield result
|
1316
|
+
|
1317
|
+
if ConversationCommand.Diagram in conversation_commands:
|
1318
|
+
async for result in send_event(ChatEvent.STATUS, f"Creating diagram"):
|
1319
|
+
yield result
|
1350
1320
|
|
1351
|
-
|
1352
|
-
|
1321
|
+
inferred_queries = []
|
1322
|
+
diagram_description = ""
|
1323
|
+
|
1324
|
+
async for result in generate_mermaidjs_diagram(
|
1325
|
+
q=defiltered_query,
|
1326
|
+
chat_history=chat_history,
|
1327
|
+
location_data=location,
|
1328
|
+
note_references=compiled_references,
|
1329
|
+
online_results=online_results,
|
1330
|
+
query_images=uploaded_images,
|
1331
|
+
user=user,
|
1332
|
+
agent=agent,
|
1333
|
+
send_status_func=partial(send_event, ChatEvent.STATUS),
|
1334
|
+
query_files=attached_file_context,
|
1335
|
+
tracer=tracer,
|
1336
|
+
):
|
1337
|
+
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
1338
|
+
yield result[ChatEvent.STATUS]
|
1339
|
+
else:
|
1340
|
+
+        better_diagram_description_prompt, mermaidjs_diagram_description = result
+        if better_diagram_description_prompt and mermaidjs_diagram_description:
+            inferred_queries.append(better_diagram_description_prompt)
+            diagram_description = mermaidjs_diagram_description
+
+            generated_mermaidjs_diagram = diagram_description
+
+            generated_asset_results["diagrams"] = {
+                "query": better_diagram_description_prompt,
+            }
+
+            async for result in send_event(
+                ChatEvent.GENERATED_ASSETS,
+                {
+                    "mermaidjsDiagram": mermaidjs_diagram_description,
+                },
+            ):
+                yield result
+        else:
+            error_message = "Failed to generate diagram. Please try again later."
+            program_execution_context.append(
+                prompts.failed_diagram_generation.format(attempted_diagram=better_diagram_description_prompt)
+            )
+
+            async for result in send_event(ChatEvent.STATUS, error_message):
+                yield result
+
+    # Check if the user has disconnected
+    if cancellation_event.is_set():
+        logger.debug(f"Stopping LLM response to user {user} on {common.client} client.")
+        # Cancel the disconnect monitor task if it is still running
+        await cancel_disconnect_monitor()
+        return
+
+    ## Generate Text Output
+    async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"):
+        yield result
+
+    llm_response, chat_metadata = await agenerate_chat_response(
+        defiltered_query,
+        chat_history,
+        conversation,
+        compiled_references,
+        online_results,
+        code_results,
+        operator_results,
+        research_results,
+        user,
+        location,
+        user_name,
+        uploaded_images,
+        attached_file_context,
+        generated_files,
+        program_execution_context,
+        generated_asset_results,
+        is_subscribed,
+        tracer,
+    )

- [1 removed line (old 1354) not preserved in this extract]
+    full_response = ""
+    async for item in llm_response:
+        # Should not happen with async generator. Skip.
+        if item is None or not isinstance(item, ResponseWithThought):
+            logger.warning(f"Unexpected item type in LLM response: {type(item)}. Skipping.")
+            continue
         if cancellation_event.is_set():
- [4 removed lines (old 1356-1359) not preserved in this extract]
+            break
+        message = item.text
+        full_response += message if message else ""
+        if item.thought:
+            async for result in send_event(ChatEvent.THOUGHT, item.thought):
+                yield result
+            continue

- [1 removed line (old 1361) not preserved in this extract]
-        async for result in send_event(ChatEvent. [rest of line not preserved in this extract]
+        # Start sending response
+        async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
             yield result

- [19 removed lines (old 1365-1383) not preserved in this extract]
+        try:
+            async for result in send_event(ChatEvent.MESSAGE, message):
+                yield result
+        except Exception as e:
+            if not cancellation_event.is_set():
+                logger.warning(f"Error during streaming. Stopping send: {e}")
+            break
+
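The loop above routes `item.thought` chunks into THOUGHT events and accumulates only `item.text` into `full_response`, so model reasoning is streamed to the client but never saved as answer text. A minimal sketch of that split, using a stand-in for the `ResponseWithThought` class (the real one appears to live in `khoj/processor/conversation/utils.py`, which this release also touches):

    from dataclasses import dataclass
    from typing import AsyncIterator, Optional

    @dataclass
    class ResponseWithThought:
        # Stand-in mirroring the two fields the loop above reads.
        text: Optional[str] = None      # user-visible response chunk
        thought: Optional[str] = None   # model reasoning, streamed separately

    async def collect_response(chunks: AsyncIterator[ResponseWithThought]) -> str:
        """Accumulate only visible text, mirroring how full_response is built."""
        full_response = ""
        async for item in chunks:
            if item.thought:
                continue  # thoughts become THOUGHT events, not saved text
            full_response += item.text or ""
        return full_response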
+    # Save conversation once finish streaming
+    asyncio.create_task(
+        save_to_conversation_log(
+            q,
+            chat_response=full_response,
+            user=user,
+            compiled_references=compiled_references,
+            online_results=online_results,
+            code_results=code_results,
+            operator_results=operator_results,
+            research_results=research_results,
+            inferred_queries=inferred_queries,
+            client_application=user_scope.client_app,
+            conversation_id=str(conversation.id),
+            query_images=uploaded_images,
+            train_of_thought=train_of_thought,
+            raw_query_files=raw_query_files,
+            generated_images=generated_images,
+            raw_generated_files=generated_files,
+            generated_mermaidjs_diagram=generated_mermaidjs_diagram,
+            tracer=tracer,
         )
+    )

- [5 removed lines (old 1386-1390) not preserved in this extract]
+    # Signal end of LLM response after the loop finishes
+    if not cancellation_event.is_set():
+        async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
+            yield result
+        # Send Usage Metadata once llm interactions are complete
+        if tracer.get("usage"):
+            async for event in send_event(ChatEvent.USAGE, tracer.get("usage")):
+                yield event
+        async for result in send_event(ChatEvent.END_RESPONSE, ""):
+            yield result
+        logger.debug("Finished streaming response")
+
+    # Cancel the disconnect monitor task if it is still running
+    await cancel_disconnect_monitor()
+
+
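Note that `save_to_conversation_log` is scheduled with a bare `asyncio.create_task`, i.e. fire-and-forget. As a general asyncio caveat (not something this diff addresses), a task nobody holds a reference to can be garbage-collected before it finishes; the conventional guard, sketched here for illustration, keeps the task alive in a set until it completes:

    import asyncio

    # Illustrative helper, not part of khoj: retain a reference to each
    # background task so the event loop cannot drop it mid-flight.
    _background_tasks: set[asyncio.Task] = set()

    def fire_and_forget(coro) -> asyncio.Task:
        task = asyncio.create_task(coro)
        _background_tasks.add(task)
        task.add_done_callback(_background_tasks.discard)
        return task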
+@api_chat.websocket("/ws")
+@requires(["authenticated"])
+async def chat_ws(
+    websocket: WebSocket,
+    common: CommonQueryParams,
+):
+    # Validate WebSocket Origin
+    origin = websocket.headers.get("origin")
+    if not origin or URL(origin).hostname not in ALLOWED_HOSTS:
+        await websocket.close(code=1008, reason="Origin not allowed")
+        return
+
+    # Limit open websocket connections per user
+    user = websocket.scope["user"].object
+    connection_manager = WebSocketConnectionManager(trial_user_max_connections=5, subscribed_user_max_connections=10)
+    connection_id = str(uuid.uuid4())
+
+    if not await connection_manager.can_connect(websocket):
+        await websocket.close(code=1008, reason="Connection limit exceeded")
+        logger.info(f"WebSocket connection rejected for user {user.id}: connection limit exceeded")
+        return
+
+    await websocket.accept()
+
+    # Note new websocket connection for the user
+    await connection_manager.register_connection(user, connection_id)
+
+    # Initialize rate limiters
+    rate_limiter_per_minute = ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute")
+    rate_limiter_per_day = ApiUserRateLimiter(
+        requests=100, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day"
+    )
+    image_rate_limiter = ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)
+
+    # Shared interrupt queue for communicating interrupts to ongoing research
+    interrupt_queue: asyncio.Queue = asyncio.Queue(maxsize=10)
+    current_task = None
+
+    try:
+        while True:
+            data = await websocket.receive_json()
+
+            # Check if this is an interrupt message
+            if data.get("type") == "interrupt":
+                if current_task and not current_task.done():
+                    # Send interrupt signal to the ongoing task
+                    abort_message = "␃🔚␗"
+                    await interrupt_queue.put(data.get("query") or abort_message)
+                    logger.info(
+                        f"Interrupt signal sent to ongoing task for user {websocket.scope['user'].object.id} with query: {data.get('query')}"
+                    )
+                    if data.get("query"):
+                        ack_type = "interrupt_message_acknowledged"
+                        await websocket.send_text(json.dumps({"type": ack_type}))
+                    else:
+                        ack_type = "interrupt_acknowledged"
+                        await websocket.send_text(json.dumps({"type": ack_type}))
+                else:
+                    logger.info(f"No ongoing task to interrupt for user {websocket.scope['user'].object.id}")
                 continue
- [4 removed lines (old 1392-1395) not preserved in this extract]
-        if item.thought:
-            async for result in send_event(ChatEvent.THOUGHT, item.thought):
-                yield result
+
+            # Handle regular chat messages - ensure data has required fields
+            if "q" not in data:
+                await websocket.send_text(json.dumps({"error": "Missing required field 'q' in chat message"}))
                 continue

- [1 removed line (old 1401) not preserved in this extract]
-        async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
-            yield result
+            body = ChatRequestBody(**data)

+            # Apply rate limiting manually
             try:
- [6 removed lines (old 1406-1411) not preserved in this extract]
+                rate_limiter_per_minute.check_websocket(websocket)
+                rate_limiter_per_day.check_websocket(websocket)
+                image_rate_limiter.check_websocket(websocket, body)
+            except HTTPException as e:
+                await websocket.send_text(json.dumps({"error": e.detail}))
+                continue

- [23 removed lines (old 1413-1435) not preserved in this extract]
+            # Cancel any ongoing task before starting a new one
+            if current_task and not current_task.done():
+                current_task.cancel()
+                try:
+                    await current_task
+                except asyncio.CancelledError:
+                    pass
+
+            # Create a new task for processing the chat request
+            current_task = asyncio.create_task(process_chat_request(websocket, body, common, interrupt_queue))
+
+    except WebSocketDisconnect:
+        logger.info(f"WebSocket disconnected for user {websocket.scope['user'].object.id}")
+        if current_task and not current_task.done():
+            current_task.cancel()
+    except Exception as e:
+        logger.error(f"Error in websocket chat: {e}", exc_info=True)
+        if current_task and not current_task.done():
+            current_task.cancel()
+        await websocket.close(code=1011, reason="Internal Server Error")
+    finally:
+        # Always unregister the connection on disconnect
+        await connection_manager.unregister_connection(user, connection_id)
+
+
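`chat_ws` defines a small wire protocol: a JSON object carrying `q` starts a chat, while `{"type": "interrupt"}` cancels the in-flight one, optionally carrying a `query` to steer the ongoing task; the server acknowledges with `interrupt_acknowledged` or `interrupt_message_acknowledged`. A hypothetical client sketch using the third-party `websockets` package (the host, `/api/chat/ws` mount point, and auth are assumptions; the message shapes come from the handler above):

    import asyncio
    import json

    import websockets

    async def main():
        uri = "wss://app.khoj.dev/api/chat/ws"  # assumed mount point
        async with websockets.connect(uri) as ws:
            # Start a chat turn.
            await ws.send(json.dumps({"q": "Plan a weekend hiking trip"}))
            await asyncio.sleep(2)
            # Interrupt the in-flight response and steer it with a new query.
            await ws.send(json.dumps({"type": "interrupt", "query": "Keep it under 10km"}))
            async for raw in ws:
                # Server events; acks arrive as {"type": "interrupt_message_acknowledged"}.
                print(raw)

    asyncio.run(main())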
+async def process_chat_request(
+    websocket: WebSocket,
+    body: ChatRequestBody,
+    common: CommonQueryParams,
+    interrupt_queue: asyncio.Queue,
+):
+    """Process a single chat request with interrupt support"""
+    try:
+        # Since we are using websockets, we can ignore the stream parameter and always stream
+        response_iterator = event_generator(
+            body,
+            websocket.scope["user"],
+            common,
+            websocket.headers,
+            websocket,
+            interrupt_queue,
         )
+        async for event in response_iterator:
+            await websocket.send_text(event)
+    except asyncio.CancelledError:
+        logger.debug(f"Chat request cancelled for user {websocket.scope['user'].object.id}")
+        raise
+    except Exception as e:
+        logger.error(f"Error processing chat request: {e}", exc_info=True)
+        await websocket.send_text(json.dumps({"error": "Internal server error"}))
+        raise

-    # Signal end of LLM response after the loop finishes
-    if not cancellation_event.is_set():
-        async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
-            yield result
-        # Send Usage Metadata once llm interactions are complete
-        if tracer.get("usage"):
-            async for event in send_event(ChatEvent.USAGE, tracer.get("usage")):
-                yield event
-        async for result in send_event(ChatEvent.END_RESPONSE, ""):
-            yield result
-        logger.debug("Finished streaming response")

- [2 removed lines (old 1450-1451) not preserved in this extract]
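`process_chat_request` re-raises `asyncio.CancelledError` so the `await current_task` in `chat_ws` observes cancellation cleanly. On the producer side, the `interrupt_queue` handed to `event_generator` carries either the `"␃🔚␗"` abort sentinel or a steering query; one plausible way for a long-running step to drain it without blocking (illustrative only, not khoj's actual implementation):

    import asyncio

    ABORT_MESSAGE = "␃🔚␗"  # sentinel matching chat_ws above

    def drain_interrupts(interrupt_queue: asyncio.Queue) -> tuple[bool, list[str]]:
        """Return (should_abort, steering_messages) from queued interrupts."""
        aborted, steering = False, []
        while True:
            try:
                signal = interrupt_queue.get_nowait()
            except asyncio.QueueEmpty:
                return aborted, steering
            if signal == ABORT_MESSAGE:
                aborted = True
            else:
                steering.append(signal)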
+@api_chat.post("")
+@requires(["authenticated"])
+async def chat(
+    request: Request,
+    common: CommonQueryParams,
+    body: ChatRequestBody,
+    rate_limiter_per_minute=Depends(
+        ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute")
+    ),
+    rate_limiter_per_day=Depends(
+        ApiUserRateLimiter(requests=100, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
+    ),
+    image_rate_limiter=Depends(ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)),
+):
+    response_iterator = event_generator(
+        body,
+        request.user,
+        common,
+        request.headers,
+        request,
+    )

- [1 removed line (old 1453) not preserved in this extract]
-    if stream:
-        return StreamingResponse(
- [1 removed line (old 1456) not preserved in this extract]
+    # Stream Text Response
+    if body.stream:
+        return StreamingResponse(response_iterator, media_type="text/plain")
+    # Non-Streaming Text Response
     else:
-        response_iterator = event_generator(q, images=raw_images)
         response_data = await read_chat_stream(response_iterator)
         return Response(content=json.dumps(response_data), media_type="application/json", status_code=200)