khoj 2.0.0b8.dev3__py3-none-any.whl → 2.0.0b9__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (68)
  1. khoj/database/adapters/__init__.py +3 -2
  2. khoj/database/models/__init__.py +28 -0
  3. khoj/interface/compiled/404/index.html +2 -2
  4. khoj/interface/compiled/_next/static/chunks/{2327-ea623ca2d22f78e9.js → 2327-d863b0b21ecb23de.js} +1 -1
  5. khoj/interface/compiled/_next/static/chunks/5477-c4209b72942d3038.js +1 -0
  6. khoj/interface/compiled/_next/static/chunks/{5639-09e2009a2adedf8b.js → 5639-0c4604668cb2d4c0.js} +1 -1
  7. khoj/interface/compiled/_next/static/chunks/9139-8ac4d9feb10f8869.js +1 -0
  8. khoj/interface/compiled/_next/static/chunks/app/agents/layout-e3d72f0edda6aa0c.js +1 -0
  9. khoj/interface/compiled/_next/static/chunks/app/agents/{page-0006674668eb5a4d.js → page-9a4610474cd59a71.js} +1 -1
  10. khoj/interface/compiled/_next/static/chunks/app/automations/{page-4c465cde2d14cb52.js → page-f7bb9d777b7745d4.js} +1 -1
  11. khoj/interface/compiled/_next/static/chunks/app/chat/layout-33934fc2d6ae6838.js +1 -0
  12. khoj/interface/compiled/_next/static/chunks/app/chat/page-a8455b8f9d36a2b0.js +1 -0
  13. khoj/interface/compiled/_next/static/chunks/app/{page-85b9b416898738f7.js → page-2025944ec1f80144.js} +1 -1
  14. khoj/interface/compiled/_next/static/chunks/app/search/layout-4505b79deb734a30.js +1 -0
  15. khoj/interface/compiled/_next/static/chunks/app/search/{page-883b7d8d2e3abe3e.js → page-4885df3cd175c957.js} +1 -1
  16. khoj/interface/compiled/_next/static/chunks/app/settings/{page-95e994ddac31473f.js → page-8be3b35178abf2ec.js} +1 -1
  17. khoj/interface/compiled/_next/static/chunks/app/share/chat/layout-6fb51c5c80f8ec67.js +1 -0
  18. khoj/interface/compiled/_next/static/chunks/app/share/chat/{page-c062269e6906ef22.js → page-ee8ef5270163e7f2.js} +1 -1
  19. khoj/interface/compiled/_next/static/chunks/{webpack-c375c47fee5a4dda.js → webpack-6355be48bba04af8.js} +1 -1
  20. khoj/interface/compiled/_next/static/css/102b97d6472fdd3a.css +1 -0
  21. khoj/interface/compiled/_next/static/css/37a73b87f02df402.css +1 -0
  22. khoj/interface/compiled/_next/static/css/c34713c98384ee87.css +1 -0
  23. khoj/interface/compiled/_next/static/css/fc82e43baa9ae218.css +1 -0
  24. khoj/interface/compiled/agents/index.html +2 -2
  25. khoj/interface/compiled/agents/index.txt +2 -2
  26. khoj/interface/compiled/automations/index.html +2 -2
  27. khoj/interface/compiled/automations/index.txt +3 -3
  28. khoj/interface/compiled/chat/index.html +2 -2
  29. khoj/interface/compiled/chat/index.txt +2 -2
  30. khoj/interface/compiled/index.html +2 -2
  31. khoj/interface/compiled/index.txt +2 -2
  32. khoj/interface/compiled/search/index.html +2 -2
  33. khoj/interface/compiled/search/index.txt +2 -2
  34. khoj/interface/compiled/settings/index.html +2 -2
  35. khoj/interface/compiled/settings/index.txt +4 -4
  36. khoj/interface/compiled/share/chat/index.html +2 -2
  37. khoj/interface/compiled/share/chat/index.txt +2 -2
  38. khoj/main.py +11 -1
  39. khoj/processor/conversation/utils.py +6 -6
  40. khoj/processor/operator/__init__.py +16 -1
  41. khoj/routers/api_chat.py +846 -682
  42. khoj/routers/helpers.py +149 -5
  43. khoj/routers/research.py +26 -1
  44. khoj/utils/rawconfig.py +0 -1
  45. {khoj-2.0.0b8.dev3.dist-info → khoj-2.0.0b9.dist-info}/METADATA +1 -1
  46. {khoj-2.0.0b8.dev3.dist-info → khoj-2.0.0b9.dist-info}/RECORD +57 -57
  47. khoj/interface/compiled/_next/static/chunks/5477-18323501c445315e.js +0 -1
  48. khoj/interface/compiled/_next/static/chunks/9568-0d60ac475f4cc538.js +0 -1
  49. khoj/interface/compiled/_next/static/chunks/app/agents/layout-e49165209d2e406c.js +0 -1
  50. khoj/interface/compiled/_next/static/chunks/app/chat/layout-d5ae861e1ade9d08.js +0 -1
  51. khoj/interface/compiled/_next/static/chunks/app/chat/page-b371304895e54627.js +0 -1
  52. khoj/interface/compiled/_next/static/chunks/app/search/layout-f5881c7ae3ba0795.js +0 -1
  53. khoj/interface/compiled/_next/static/chunks/app/share/chat/layout-64a53f8ec4afa6b3.js +0 -1
  54. khoj/interface/compiled/_next/static/css/76c658ee459140a9.css +0 -1
  55. khoj/interface/compiled/_next/static/css/a0c2fd63bb396f04.css +0 -1
  56. khoj/interface/compiled/_next/static/css/ee66643a6a5bf71c.css +0 -1
  57. khoj/interface/compiled/_next/static/css/fbacbdfd5e7f3f0e.css +0 -1
  58. /khoj/interface/compiled/_next/static/{RrD-yExRLXeXG-f3lGWYq → 6kC_Tt4g0U2gXGUbnSB1O}/_buildManifest.js +0 -0
  59. /khoj/interface/compiled/_next/static/{RrD-yExRLXeXG-f3lGWYq → 6kC_Tt4g0U2gXGUbnSB1O}/_ssgManifest.js +0 -0
  60. /khoj/interface/compiled/_next/static/chunks/{1327-3b1a41af530fa8ee.js → 1327-1a9107b9a2a04a98.js} +0 -0
  61. /khoj/interface/compiled/_next/static/chunks/{1915-fbfe167c84ad60c5.js → 1915-5c6508f6ebb62a30.js} +0 -0
  62. /khoj/interface/compiled/_next/static/chunks/{2117-e78b6902ad6f75ec.js → 2117-080746c8e170c81a.js} +0 -0
  63. /khoj/interface/compiled/_next/static/chunks/{2939-4d4084c5b888b960.js → 2939-4af3fd24b8ffc9ad.js} +0 -0
  64. /khoj/interface/compiled/_next/static/chunks/{4447-d6cf93724d57e34b.js → 4447-cd95608f8e93e711.js} +0 -0
  65. /khoj/interface/compiled/_next/static/chunks/{8667-4b7790573b08c50d.js → 8667-50b03a89e82e0ba7.js} +0 -0
  66. {khoj-2.0.0b8.dev3.dist-info → khoj-2.0.0b9.dist-info}/WHEEL +0 -0
  67. {khoj-2.0.0b8.dev3.dist-info → khoj-2.0.0b9.dist-info}/entry_points.txt +0 -0
  68. {khoj-2.0.0b8.dev3.dist-info → khoj-2.0.0b9.dist-info}/licenses/LICENSE +0 -0
khoj/routers/api_chat.py CHANGED
@@ -10,9 +10,18 @@ from typing import Any, Dict, List, Optional
 from urllib.parse import unquote
 
 from asgiref.sync import sync_to_async
-from fastapi import APIRouter, Depends, HTTPException, Request
+from fastapi import (
+    APIRouter,
+    Depends,
+    HTTPException,
+    Request,
+    WebSocket,
+    WebSocketDisconnect,
+)
 from fastapi.responses import RedirectResponse, Response, StreamingResponse
+from fastapi.websockets import WebSocketState
 from starlette.authentication import has_required_scope, requires
+from starlette.requests import URL, Headers
 
 from khoj.app.settings import ALLOWED_HOSTS
 from khoj.database.adapters import (
@@ -51,6 +60,7 @@ from khoj.routers.helpers import (
     ConversationCommandRateLimiter,
     DeleteMessageRequestBody,
     FeedbackData,
+    WebSocketConnectionManager,
     acreate_title_from_history,
     agenerate_chat_response,
     aget_data_sources_and_output_format,
@@ -60,6 +70,7 @@ from khoj.routers.helpers import (
     generate_mermaidjs_diagram,
     generate_summary_from_files,
     get_conversation_command,
+    get_message_from_queue,
     is_query_empty,
     is_ready_to_chat,
     read_chat_stream,
@@ -657,19 +668,13 @@ def delete_message(request: Request, delete_request: DeleteMessageRequestBody) -
     return Response(content=json.dumps({"status": "error", "message": "Message not found"}), status_code=404)
 
 
-@api_chat.post("")
-@requires(["authenticated"])
-async def chat(
-    request: Request,
-    common: CommonQueryParams,
+async def event_generator(
     body: ChatRequestBody,
-    rate_limiter_per_minute=Depends(
-        ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute")
-    ),
-    rate_limiter_per_day=Depends(
-        ApiUserRateLimiter(requests=100, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
-    ),
-    image_rate_limiter=Depends(ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)),
+    user_scope: Any,
+    common: CommonQueryParams,
+    headers: Headers,
+    request_obj: Request | WebSocket,
+    parent_interrupt_queue: asyncio.Queue = None,
 ):
     # Access the parameters from the body
     q = body.q
@@ -686,67 +691,68 @@ async def chat(
     timezone = body.timezone
     raw_images = body.images
     raw_query_files = body.files
-    interrupt_flag = body.interrupt
-
-    async def event_generator(q: str, images: list[str]):
-        start_time = time.perf_counter()
-        ttft = None
-        chat_metadata: dict = {}
-        conversation = None
-        user: KhojUser = request.user.object
-        is_subscribed = has_required_scope(request, ["premium"])
-        q = unquote(q)
-        train_of_thought = []
-        nonlocal conversation_id
-        nonlocal raw_query_files
-        cancellation_event = asyncio.Event()
-
-        tracer: dict = {
-            "mid": turn_id,
-            "cid": conversation_id,
-            "uid": user.id,
-            "khoj_version": state.khoj_version,
-        }
 
-        uploaded_images: list[str] = []
-        if images:
-            for image in images:
-                decoded_string = unquote(image)
-                base64_data = decoded_string.split(",", 1)[1]
-                image_bytes = base64.b64decode(base64_data)
-                webp_image_bytes = convert_image_to_webp(image_bytes)
-                uploaded_image = upload_user_image_to_bucket(webp_image_bytes, request.user.object.id)
-                if not uploaded_image:
-                    base64_webp_image = base64.b64encode(webp_image_bytes).decode("utf-8")
-                    uploaded_image = f"data:image/webp;base64,{base64_webp_image}"
-                uploaded_images.append(uploaded_image)
-
-        query_files: Dict[str, str] = {}
-        if raw_query_files:
-            for file in raw_query_files:
-                query_files[file.name] = file.content
-
-        research_results: List[ResearchIteration] = []
-        online_results: Dict = dict()
-        code_results: Dict = dict()
-        operator_results: List[OperatorRun] = []
-        compiled_references: List[Any] = []
-        inferred_queries: List[Any] = []
-        attached_file_context = gather_raw_query_files(query_files)
-
-        generated_images: List[str] = []
-        generated_files: List[FileAttachment] = []
-        generated_mermaidjs_diagram: str = None
-        generated_asset_results: Dict = dict()
-        program_execution_context: List[str] = []
-        chat_history: List[ChatMessageModel] = []
-
-        # Create a task to monitor for disconnections
-        disconnect_monitor_task = None
-
-        async def monitor_disconnection():
+    start_time = time.perf_counter()
+    ttft = None
+    chat_metadata: dict = {}
+    conversation = None
+    user: KhojUser = user_scope.object
+    is_subscribed = has_required_scope(request_obj, ["premium"])
+    q = unquote(q)
+    defiltered_query = defilter_query(q)
+    train_of_thought = []
+    cancellation_event = asyncio.Event()
+    child_interrupt_queue: asyncio.Queue = asyncio.Queue(maxsize=10)
+    event_delimiter = "␃🔚␗"
+
+    tracer: dict = {
+        "mid": turn_id,
+        "cid": conversation_id,
+        "uid": user.id,
+        "khoj_version": state.khoj_version,
+    }
+
+    uploaded_images: list[str] = []
+    if raw_images:
+        for image in raw_images:
+            decoded_string = unquote(image)
+            base64_data = decoded_string.split(",", 1)[1]
+            image_bytes = base64.b64decode(base64_data)
+            webp_image_bytes = convert_image_to_webp(image_bytes)
+            uploaded_image = upload_user_image_to_bucket(webp_image_bytes, user.id)
+            if not uploaded_image:
+                base64_webp_image = base64.b64encode(webp_image_bytes).decode("utf-8")
+                uploaded_image = f"data:image/webp;base64,{base64_webp_image}"
+            uploaded_images.append(uploaded_image)
+
+    query_files: Dict[str, str] = {}
+    if raw_query_files:
+        for file in raw_query_files:
+            query_files[file.name] = file.content
+
+    research_results: List[ResearchIteration] = []
+    online_results: Dict = dict()
+    code_results: Dict = dict()
+    operator_results: List[OperatorRun] = []
+    compiled_references: List[Any] = []
+    inferred_queries: List[Any] = []
+    attached_file_context = gather_raw_query_files(query_files)
+
+    generated_images: List[str] = []
+    generated_files: List[FileAttachment] = []
+    generated_mermaidjs_diagram: str = None
+    generated_asset_results: Dict = dict()
+    program_execution_context: List[str] = []
+    user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+
+    # Create a task to monitor for disconnections
+    disconnect_monitor_task = None
+
+    async def monitor_disconnection():
+        nonlocal q, defiltered_query
+        if isinstance(request_obj, Request):
             try:
-                msg = await request.receive()
+                msg = await request_obj.receive()
                 if msg["type"] == "http.disconnect":
                     logger.debug(f"Request cancelled. User {user} disconnected from {common.client} client.")
                     cancellation_event.set()
@@ -758,14 +764,13 @@ async def chat(
                             q,
                             chat_response="",
                             user=user,
-                            chat_history=chat_history,
                             compiled_references=compiled_references,
                             online_results=online_results,
                             code_results=code_results,
                             operator_results=operator_results,
                             research_results=research_results,
                             inferred_queries=inferred_queries,
-                            client_application=request.user.client_app,
+                            client_application=user_scope.client_app,
                             conversation_id=conversation_id,
                             query_images=uploaded_images,
                             train_of_thought=train_of_thought,
@@ -773,502 +778,319 @@ async def chat(
                             generated_images=generated_images,
                             raw_generated_files=generated_asset_results,
                             generated_mermaidjs_diagram=generated_mermaidjs_diagram,
+                            user_message_time=user_message_time,
                             tracer=tracer,
                         )
                     )
             except Exception as e:
                 logger.error(f"Error in disconnect monitor: {e}")
-
-        # Cancel the disconnect monitor task if it is still running
-        async def cancel_disconnect_monitor():
-            if disconnect_monitor_task and not disconnect_monitor_task.done():
-                logger.debug(f"Cancelling disconnect monitor task for user {user}")
-                disconnect_monitor_task.cancel()
-                try:
-                    await disconnect_monitor_task
-                except asyncio.CancelledError:
-                    pass
-
-        async def send_event(event_type: ChatEvent, data: str | dict):
-            nonlocal ttft, train_of_thought
-            event_delimiter = "␃🔚␗"
-            if cancellation_event.is_set():
-                return
-            try:
-                if event_type == ChatEvent.END_LLM_RESPONSE:
-                    collect_telemetry()
-                elif event_type == ChatEvent.START_LLM_RESPONSE:
-                    ttft = time.perf_counter() - start_time
-                elif event_type == ChatEvent.STATUS:
-                    train_of_thought.append({"type": event_type.value, "data": data})
-                elif event_type == ChatEvent.THOUGHT:
-                    # Append the data to the last thought as thoughts are streamed
-                    if (
-                        len(train_of_thought) > 0
-                        and train_of_thought[-1]["type"] == ChatEvent.THOUGHT.value
-                        and type(train_of_thought[-1]["data"]) == type(data) == str
-                    ):
-                        train_of_thought[-1]["data"] += data
+        elif isinstance(request_obj, WebSocket):
+            while request_obj.client_state == WebSocketState.CONNECTED and not cancellation_event.is_set():
+                await asyncio.sleep(1)
+
+                # Check if any interrupt query is received
+                if interrupt_query := get_message_from_queue(parent_interrupt_queue):
+                    if interrupt_query == event_delimiter:
+                        cancellation_event.set()
+                        logger.debug(f"Chat cancelled by user {user} via interrupt queue.")
                     else:
-                        train_of_thought.append({"type": event_type.value, "data": data})
-
-                if event_type == ChatEvent.MESSAGE:
-                    yield data
-                elif event_type == ChatEvent.REFERENCES or ChatEvent.METADATA or stream:
-                    yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False)
-            except Exception as e:
-                if not cancellation_event.is_set():
-                    logger.error(
-                        f"Failed to stream chat API response to {user} on {common.client}: {e}.",
-                        exc_info=True,
+                        # Pass the interrupt query to child tasks
+                        logger.info(f"Continuing chat with the new instruction: {interrupt_query}")
+                        await child_interrupt_queue.put(interrupt_query)
+                        # Append the interrupt query to the main query
+                        q += f"\n\n{interrupt_query}"
+                        defiltered_query += f"\n\n{defilter_query(interrupt_query)}"
+
+            logger.debug(f"WebSocket disconnected or chat cancelled by user {user} from {common.client} client.")
+            if conversation and cancellation_event.is_set():
+                await asyncio.shield(
+                    save_to_conversation_log(
+                        q,
+                        chat_response="",
+                        user=user,
+                        compiled_references=compiled_references,
+                        online_results=online_results,
+                        code_results=code_results,
+                        operator_results=operator_results,
+                        research_results=research_results,
+                        inferred_queries=inferred_queries,
+                        client_application=user_scope.client_app,
+                        conversation_id=conversation_id,
+                        query_images=uploaded_images,
+                        train_of_thought=train_of_thought,
+                        raw_query_files=raw_query_files,
+                        generated_images=generated_images,
+                        raw_generated_files=generated_asset_results,
+                        generated_mermaidjs_diagram=generated_mermaidjs_diagram,
+                        user_message_time=user_message_time,
+                        tracer=tracer,
                     )
-            finally:
-                if not cancellation_event.is_set():
-                    yield event_delimiter
-                # Cancel the disconnect monitor task if it is still running
-                if cancellation_event.is_set() or event_type == ChatEvent.END_RESPONSE:
-                    await cancel_disconnect_monitor()
-
-        async def send_llm_response(response: str, usage: dict = None):
-            # Check if the client is still connected
-            if cancellation_event.is_set():
-                return
-            # Send Chat Response
-            async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
-                yield result
-            async for result in send_event(ChatEvent.MESSAGE, response):
-                yield result
-            async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
-                yield result
-            # Send Usage Metadata once llm interactions are complete
-            if usage:
-                async for event in send_event(ChatEvent.USAGE, usage):
-                    yield event
-            async for result in send_event(ChatEvent.END_RESPONSE, ""):
-                yield result
-
-        def collect_telemetry():
-            # Gather chat response telemetry
-            nonlocal chat_metadata
-            latency = time.perf_counter() - start_time
-            cmd_set = set([cmd.value for cmd in conversation_commands])
-            cost = (tracer.get("usage", {}) or {}).get("cost", 0)
-            chat_metadata = chat_metadata or {}
-            chat_metadata["conversation_command"] = cmd_set
-            chat_metadata["agent"] = conversation.agent.slug if conversation and conversation.agent else None
-            chat_metadata["cost"] = f"{cost:.5f}"
-            chat_metadata["latency"] = f"{latency:.3f}"
-            if ttft:
-                chat_metadata["ttft_latency"] = f"{ttft:.3f}"
-                logger.info(f"Chat response time to first token: {ttft:.3f} seconds")
-            logger.info(f"Chat response total time: {latency:.3f} seconds")
-            logger.info(f"Chat response cost: ${cost:.5f}")
-            update_telemetry_state(
-                request=request,
-                telemetry_type="api",
-                api="chat",
-                client=common.client,
-                user_agent=request.headers.get("user-agent"),
-                host=request.headers.get("host"),
-                metadata=chat_metadata,
-            )
-
-        # Start the disconnect monitor in the background
-        disconnect_monitor_task = asyncio.create_task(monitor_disconnection())
-
-        if is_query_empty(q):
-            async for result in send_llm_response("Please ask your query to get started.", tracer.get("usage")):
-                yield result
-            return
-
-        # Automated tasks are handled before to allow mixing them with other conversation commands
-        cmds_to_rate_limit = []
-        is_automated_task = False
-        if q.startswith("/automated_task"):
-            is_automated_task = True
-            q = q.replace("/automated_task", "").lstrip()
-            cmds_to_rate_limit += [ConversationCommand.AutomatedTask]
+                )
 
-        # Extract conversation command from query
-        conversation_commands = [get_conversation_command(query=q)]
+    # Cancel the disconnect monitor task if it is still running
+    async def cancel_disconnect_monitor():
+        if disconnect_monitor_task and not disconnect_monitor_task.done():
+            logger.debug(f"Cancelling disconnect monitor task for user {user}")
+            disconnect_monitor_task.cancel()
+            try:
+                await disconnect_monitor_task
+            except asyncio.CancelledError:
+                pass
 
-        conversation = await ConversationAdapters.aget_conversation_by_user(
-            user,
-            client_application=request.user.client_app,
-            conversation_id=conversation_id,
-            title=title,
-            create_new=body.create_new,
-        )
-        if not conversation:
-            async for result in send_llm_response(f"Conversation {conversation_id} not found", tracer.get("usage")):
-                yield result
+    async def send_event(event_type: ChatEvent, data: str | dict):
+        nonlocal ttft, train_of_thought
+        if cancellation_event.is_set():
             return
-        conversation_id = str(conversation.id)
-
-        async for event in send_event(ChatEvent.METADATA, {"conversationId": conversation_id, "turnId": turn_id}):
-            yield event
-
-        agent: Agent | None = None
-        default_agent = await AgentAdapters.aget_default_agent()
-        if conversation.agent and conversation.agent != default_agent:
-            agent = conversation.agent
-
-        if not conversation.agent:
-            conversation.agent = default_agent
-            await conversation.asave()
-            agent = default_agent
-
-        await is_ready_to_chat(user)
-        user_name = await aget_user_name(user)
-        location = None
-        if city or region or country or country_code:
-            location = LocationData(city=city, region=region, country=country, country_code=country_code)
-        user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-        chat_history = conversation.messages
-
-        # If interrupt flag is set, wait for the previous turn to be saved before proceeding
-        if interrupt_flag:
-            max_wait_time = 20.0  # seconds
-            wait_interval = 0.3  # seconds
-            wait_start = wait_current = time.time()
-            while wait_current - wait_start < max_wait_time:
-                # Refresh conversation to check if interrupted message saved to DB
-                conversation = await ConversationAdapters.aget_conversation_by_user(
-                    user,
-                    client_application=request.user.client_app,
-                    conversation_id=conversation_id,
-                )
+        try:
+            if event_type == ChatEvent.END_LLM_RESPONSE:
+                collect_telemetry()
+            elif event_type == ChatEvent.START_LLM_RESPONSE:
+                ttft = time.perf_counter() - start_time
+            elif event_type == ChatEvent.STATUS:
+                train_of_thought.append({"type": event_type.value, "data": data})
+            elif event_type == ChatEvent.THOUGHT:
+                # Append the data to the last thought as thoughts are streamed
                 if (
-                    conversation
-                    and conversation.messages
-                    and conversation.messages[-1].by == "khoj"
-                    and not conversation.messages[-1].message
+                    len(train_of_thought) > 0
+                    and train_of_thought[-1]["type"] == ChatEvent.THOUGHT.value
+                    and type(train_of_thought[-1]["data"]) == type(data) == str
                 ):
-                    logger.info(f"Detected interrupted message save to conversation {conversation_id}.")
-                    break
-                await asyncio.sleep(wait_interval)
-                wait_current = time.time()
-
-            if wait_current - wait_start >= max_wait_time:
-                logger.warning(
-                    f"Timeout waiting to load interrupted context from conversation {conversation_id}. Proceed without previous context."
-                )
+                    train_of_thought[-1]["data"] += data
+                else:
+                    train_of_thought.append({"type": event_type.value, "data": data})
 
-        # If interrupted message in DB
-        if (
-            conversation
-            and conversation.messages
-            and conversation.messages[-1].by == "khoj"
-            and not conversation.messages[-1].message
-        ):
-            # Populate context from interrupted message
-            last_message = conversation.messages[-1]
-            online_results = {key: val.model_dump() for key, val in last_message.onlineContext.items() or []}
-            code_results = {key: val.model_dump() for key, val in last_message.codeContext.items() or []}
-            compiled_references = [ref.model_dump() for ref in last_message.context or []]
-            research_results = [
-                ResearchIteration(**iter_dict)
-                for iter_dict in last_message.researchContext or []
-                if iter_dict.get("summarizedResult")
-            ]
-            operator_results = [OperatorRun(**iter_dict) for iter_dict in last_message.operatorContext or []]
-            train_of_thought = [thought.model_dump() for thought in last_message.trainOfThought or []]
-            # Drop the interrupted message from conversation history
-            chat_history.pop()
-            logger.info(f"Loaded interrupted partial context from conversation {conversation_id}.")
-
-        if conversation_commands == [ConversationCommand.Default]:
-            try:
-                chosen_io = await aget_data_sources_and_output_format(
-                    q,
-                    chat_history,
-                    is_automated_task,
-                    user=user,
-                    query_images=uploaded_images,
-                    agent=agent,
-                    query_files=attached_file_context,
-                    tracer=tracer,
+            if event_type == ChatEvent.MESSAGE:
+                yield data
+            elif event_type == ChatEvent.REFERENCES or ChatEvent.METADATA or stream:
+                yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False)
+        except Exception as e:
+            if not cancellation_event.is_set():
+                logger.error(
+                    f"Failed to stream chat API response to {user} on {common.client}: {e}.",
+                    exc_info=True,
                 )
-            except ValueError as e:
-                logger.error(f"Error getting data sources and output format: {e}. Falling back to default.")
-                conversation_commands = [ConversationCommand.General]
-
-            conversation_commands = chosen_io.get("sources") + [chosen_io.get("output")]
-
-            # If we're doing research, we don't want to do anything else
-            if ConversationCommand.Research in conversation_commands:
-                conversation_commands = [ConversationCommand.Research]
+        finally:
+            if not cancellation_event.is_set():
+                yield event_delimiter
+            # Cancel the disconnect monitor task if it is still running
+            if cancellation_event.is_set() or event_type == ChatEvent.END_RESPONSE:
+                await cancel_disconnect_monitor()
 
-        conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
-        async for result in send_event(ChatEvent.STATUS, f"**Selected Tools:** {conversation_commands_str}"):
-            yield result
+    async def send_llm_response(response: str, usage: dict = None):
+        # Check if the client is still connected
+        if cancellation_event.is_set():
+            return
+        # Send Chat Response
+        async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
+            yield result
+        async for result in send_event(ChatEvent.MESSAGE, response):
+            yield result
+        async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
+            yield result
+        # Send Usage Metadata once llm interactions are complete
+        if usage:
+            async for event in send_event(ChatEvent.USAGE, usage):
+                yield event
+        async for result in send_event(ChatEvent.END_RESPONSE, ""):
+            yield result
 
-        cmds_to_rate_limit += conversation_commands
-        for cmd in cmds_to_rate_limit:
-            try:
-                await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
-                q = q.replace(f"/{cmd.value}", "").strip()
-            except HTTPException as e:
-                async for result in send_llm_response(str(e.detail), tracer.get("usage")):
-                    yield result
-                return
+    def collect_telemetry():
+        # Gather chat response telemetry
+        nonlocal chat_metadata
+        latency = time.perf_counter() - start_time
+        cmd_set = set([cmd.value for cmd in conversation_commands])
+        cost = (tracer.get("usage", {}) or {}).get("cost", 0)
+        chat_metadata = chat_metadata or {}
+        chat_metadata["conversation_command"] = cmd_set
+        chat_metadata["agent"] = conversation.agent.slug if conversation and conversation.agent else None
+        chat_metadata["cost"] = f"{cost:.5f}"
+        chat_metadata["latency"] = f"{latency:.3f}"
+        if ttft:
+            chat_metadata["ttft_latency"] = f"{ttft:.3f}"
+            logger.info(f"Chat response time to first token: {ttft:.3f} seconds")
+        logger.info(f"Chat response total time: {latency:.3f} seconds")
+        logger.info(f"Chat response cost: ${cost:.5f}")
+        update_telemetry_state(
+            request=request_obj,
+            telemetry_type="api",
+            api="chat",
+            client=common.client,
+            user_agent=headers.get("user-agent"),
+            host=headers.get("host"),
+            metadata=chat_metadata,
+        )
 
-        defiltered_query = defilter_query(q)
-        file_filters = conversation.file_filters if conversation and conversation.file_filters else []
+    # Start the disconnect monitor in the background
+    disconnect_monitor_task = asyncio.create_task(monitor_disconnection())
 
-        if conversation_commands == [ConversationCommand.Research]:
-            async for research_result in research(
+    if is_query_empty(q):
+        async for result in send_llm_response("Please ask your query to get started.", tracer.get("usage")):
+            yield result
+        return
+
+    # Automated tasks are handled before to allow mixing them with other conversation commands
+    cmds_to_rate_limit = []
+    is_automated_task = False
+    if q.startswith("/automated_task"):
+        is_automated_task = True
+        q = q.replace("/automated_task", "").lstrip()
+        cmds_to_rate_limit += [ConversationCommand.AutomatedTask]
+
+    # Extract conversation command from query
+    conversation_commands = [get_conversation_command(query=q)]
+
+    conversation = await ConversationAdapters.aget_conversation_by_user(
+        user,
+        client_application=user_scope.client_app,
+        conversation_id=conversation_id,
+        title=title,
+        create_new=body.create_new,
+    )
+    if not conversation:
+        async for result in send_llm_response(f"Conversation {conversation_id} not found", tracer.get("usage")):
+            yield result
+        return
+    conversation_id = str(conversation.id)
+
+    async for event in send_event(ChatEvent.METADATA, {"conversationId": conversation_id, "turnId": turn_id}):
+        yield event
+
+    agent: Agent | None = None
+    default_agent = await AgentAdapters.aget_default_agent()
+    if conversation.agent and conversation.agent != default_agent:
+        agent = conversation.agent
+
+    if not conversation.agent:
+        conversation.agent = default_agent
+        await conversation.asave()
+        agent = default_agent
+
+    await is_ready_to_chat(user)
+    user_name = await aget_user_name(user)
+    location = None
+    if city or region or country or country_code:
+        location = LocationData(city=city, region=region, country=country, country_code=country_code)
+    chat_history = conversation.messages
+
+    # If interrupted message in DB
+    if last_message := await conversation.pop_message(interrupted=True):
+        # Populate context from interrupted message
+        online_results = {key: val.model_dump() for key, val in last_message.onlineContext.items() or []}
+        code_results = {key: val.model_dump() for key, val in last_message.codeContext.items() or []}
+        compiled_references = [ref.model_dump() for ref in last_message.context or []]
+        research_results = [
+            ResearchIteration(**iter_dict)
+            for iter_dict in last_message.researchContext or []
+            if iter_dict.get("summarizedResult")
+        ]
+        operator_results = [OperatorRun(**iter_dict) for iter_dict in last_message.operatorContext or []]
+        train_of_thought = [thought.model_dump() for thought in last_message.trainOfThought or []]
+        logger.info(f"Loaded interrupted partial context from conversation {conversation_id}.")
+
+    if conversation_commands == [ConversationCommand.Default]:
+        try:
+            chosen_io = await aget_data_sources_and_output_format(
+                q,
+                chat_history,
+                is_automated_task,
                 user=user,
-                query=defiltered_query,
-                conversation_id=conversation_id,
-                conversation_history=chat_history,
-                previous_iterations=list(research_results),
                 query_images=uploaded_images,
                 agent=agent,
-                send_status_func=partial(send_event, ChatEvent.STATUS),
-                user_name=user_name,
-                location=location,
-                file_filters=file_filters,
                 query_files=attached_file_context,
                 tracer=tracer,
-                cancellation_event=cancellation_event,
-            ):
-                if isinstance(research_result, ResearchIteration):
-                    if research_result.summarizedResult:
-                        if research_result.onlineContext:
-                            online_results.update(research_result.onlineContext)
-                        if research_result.codeContext:
-                            code_results.update(research_result.codeContext)
-                        if research_result.context:
-                            compiled_references.extend(research_result.context)
-                    if not research_results or research_results[-1] is not research_result:
-                        research_results.append(research_result)
-                else:
-                    yield research_result
-
-                # Track operator results across research and operator iterations
-                # This relies on two conditions:
-                # 1. Check to append new (partial) operator results
-                #    Relies on triggering this check on every status updates.
-                #    Status updates cascade up from operator to research to chat api on every step.
-                # 2. Keep operator results in sync with each research operator step
-                #    Relies on python object references to ensure operator results
-                #    are implicitly kept in sync after the initial append
-                if (
-                    research_results
-                    and research_results[-1].operatorContext
-                    and (not operator_results or operator_results[-1] is not research_results[-1].operatorContext)
-                ):
-                    operator_results.append(research_results[-1].operatorContext)
-
-            # researched_results = await extract_relevant_info(q, researched_results, agent)
-            if state.verbose > 1:
-                logger.debug(f'Researched Results: {"".join(r.summarizedResult or "" for r in research_results)}')
-
-        # Gather Context
-        ## Extract Document References
-        if not ConversationCommand.Research in conversation_commands:
-            try:
-                async for result in search_documents(
-                    q,
-                    (n or 7),
-                    d,
-                    user,
-                    chat_history,
-                    conversation_id,
-                    conversation_commands,
-                    location,
-                    partial(send_event, ChatEvent.STATUS),
-                    query_images=uploaded_images,
-                    agent=agent,
-                    query_files=attached_file_context,
-                    tracer=tracer,
-                ):
-                    if isinstance(result, dict) and ChatEvent.STATUS in result:
-                        yield result[ChatEvent.STATUS]
-                    else:
-                        compiled_references.extend(result[0])
-                        inferred_queries.extend(result[1])
-                        defiltered_query = result[2]
-            except Exception as e:
-                error_message = (
-                    f"Error searching knowledge base: {e}. Attempting to respond without document references."
-                )
-                logger.error(error_message, exc_info=True)
-                async for result in send_event(
-                    ChatEvent.STATUS, "Document search failed. I'll try respond without document references"
-                ):
-                    yield result
-
-        if not is_none_or_empty(compiled_references):
-            headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
-            # Strip only leading # from headings
-            headings = headings.replace("#", "")
-            async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"):
-                yield result
+            )
+        except ValueError as e:
+            logger.error(f"Error getting data sources and output format: {e}. Falling back to default.")
+            conversation_commands = [ConversationCommand.General]
 
-        if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
-            async for result in send_llm_response(f"{no_entries_found.format()}", tracer.get("usage")):
-                yield result
-            return
+        conversation_commands = chosen_io.get("sources") + [chosen_io.get("output")]
 
-        if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
-            conversation_commands.remove(ConversationCommand.Notes)
+        # If we're doing research, we don't want to do anything else
+        if ConversationCommand.Research in conversation_commands:
+            conversation_commands = [ConversationCommand.Research]
 
-        ## Gather Online References
-        if ConversationCommand.Online in conversation_commands:
-            try:
-                async for result in search_online(
-                    defiltered_query,
-                    chat_history,
-                    location,
-                    user,
-                    partial(send_event, ChatEvent.STATUS),
-                    custom_filters=[],
-                    max_online_searches=3,
-                    query_images=uploaded_images,
-                    query_files=attached_file_context,
-                    agent=agent,
-                    tracer=tracer,
-                ):
-                    if isinstance(result, dict) and ChatEvent.STATUS in result:
-                        yield result[ChatEvent.STATUS]
-                    else:
-                        online_results = result
-            except Exception as e:
-                error_message = f"Error searching online: {e}. Attempting to respond without online results"
-                logger.warning(error_message)
-                async for result in send_event(
-                    ChatEvent.STATUS, "Online search failed. I'll try respond without online references"
-                ):
-                    yield result
+    conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
+    async for result in send_event(ChatEvent.STATUS, f"**Selected Tools:** {conversation_commands_str}"):
+        yield result
 
-        ## Gather Webpage References
-        if ConversationCommand.Webpage in conversation_commands:
-            try:
-                async for result in read_webpages(
-                    defiltered_query,
-                    chat_history,
-                    location,
-                    user,
-                    partial(send_event, ChatEvent.STATUS),
-                    max_webpages_to_read=1,
-                    query_images=uploaded_images,
-                    agent=agent,
-                    query_files=attached_file_context,
-                    tracer=tracer,
-                ):
-                    if isinstance(result, dict) and ChatEvent.STATUS in result:
-                        yield result[ChatEvent.STATUS]
-                    else:
-                        direct_web_pages = result
-                        webpages = []
-                        for query in direct_web_pages:
-                            if online_results.get(query):
-                                online_results[query]["webpages"] = direct_web_pages[query]["webpages"]
-                            else:
-                                online_results[query] = {"webpages": direct_web_pages[query]["webpages"]}
+    cmds_to_rate_limit += conversation_commands
+    for cmd in cmds_to_rate_limit:
+        try:
+            await conversation_command_rate_limiter.update_and_check_if_valid(request_obj, cmd)
+            q = q.replace(f"/{cmd.value}", "").strip()
+        except HTTPException as e:
+            async for result in send_llm_response(str(e.detail), tracer.get("usage")):
+                yield result
+            return
 
-                            for webpage in direct_web_pages[query]["webpages"]:
-                                webpages.append(webpage["link"])
-                        async for result in send_event(ChatEvent.STATUS, f"**Read web pages**: {webpages}"):
-                            yield result
-            except Exception as e:
-                logger.warning(
-                    f"Error reading webpages: {e}. Attempting to respond without webpage results",
-                    exc_info=True,
-                )
-                async for result in send_event(
-                    ChatEvent.STATUS, "Webpage read failed. I'll try respond without webpage references"
-                ):
-                    yield result
+    defiltered_query = defilter_query(q)
+    file_filters = conversation.file_filters if conversation and conversation.file_filters else []
 
-        ## Gather Code Results
-        if ConversationCommand.Code in conversation_commands:
-            try:
-                context = f"# Iteration 1:\n#---\nNotes:\n{compiled_references}\n\nOnline Results:{online_results}"
-                async for result in run_code(
-                    defiltered_query,
-                    chat_history,
-                    context,
-                    location,
-                    user,
-                    partial(send_event, ChatEvent.STATUS),
+    if conversation_commands == [ConversationCommand.Research]:
+        async for research_result in research(
+            user=user,
+            query=defiltered_query,
+            conversation_id=conversation_id,
+            conversation_history=chat_history,
+            previous_iterations=list(research_results),
+            query_images=uploaded_images,
+            agent=agent,
+            send_status_func=partial(send_event, ChatEvent.STATUS),
+            user_name=user_name,
+            location=location,
+            file_filters=file_filters,
+            query_files=attached_file_context,
+            tracer=tracer,
+            cancellation_event=cancellation_event,
+            interrupt_queue=child_interrupt_queue,
+            abort_message=event_delimiter,
         ):
-                    yield result
+            if isinstance(research_result, ResearchIteration):
+                if research_result.summarizedResult:
+                    if research_result.onlineContext:
+                        online_results.update(research_result.onlineContext)
+                    if research_result.codeContext:
+                        code_results.update(research_result.codeContext)
+                    if research_result.context:
+                        compiled_references.extend(research_result.context)
+                if not research_results or research_results[-1] is not research_result:
+                    research_results.append(research_result)
+            else:
+                yield research_result
+
+            # Track operator results across research and operator iterations
+            # This relies on two conditions:
+            # 1. Check to append new (partial) operator results
+            #    Relies on triggering this check on every status updates.
+            #    Status updates cascade up from operator to research to chat api on every step.
+            # 2. Keep operator results in sync with each research operator step
+            #    Relies on python object references to ensure operator results
+            #    are implicitly kept in sync after the initial append
+            if (
+                research_results
+                and research_results[-1].operatorContext
+                and (not operator_results or operator_results[-1] is not research_results[-1].operatorContext)
+            ):
+                operator_results.append(research_results[-1].operatorContext)
 
-        # Generate Output
-        ## Generate Image Output
-        if ConversationCommand.Image in conversation_commands:
-            async for result in text_to_image(
-                defiltered_query,
+    # researched_results = await extract_relevant_info(q, researched_results, agent)
+    if state.verbose > 1:
+        logger.debug(f'Researched Results: {"".join(r.summarizedResult or "" for r in research_results)}')
+
+    # Gather Context
+    ## Extract Document References
+    if not ConversationCommand.Research in conversation_commands:
+        try:
+            async for result in search_documents(
+                q,
+                (n or 7),
+                d,
                 user,
                 chat_history,
-                location_data=location,
-                references=compiled_references,
-                online_results=online_results,
-                send_status_func=partial(send_event, ChatEvent.STATUS),
+                conversation_id,
+                conversation_commands,
+                location,
+                partial(send_event, ChatEvent.STATUS),
                 query_images=uploaded_images,
                 agent=agent,
                 query_files=attached_file_context,
@@ -1277,184 +1099,526 @@ async def chat(
1277
1099
  if isinstance(result, dict) and ChatEvent.STATUS in result:
1278
1100
  yield result[ChatEvent.STATUS]
1279
1101
  else:
1280
- generated_image, status_code, improved_image_prompt = result
1102
+ compiled_references.extend(result[0])
1103
+ inferred_queries.extend(result[1])
1104
+ defiltered_query = result[2]
1105
+ except Exception as e:
1106
+ error_message = f"Error searching knowledge base: {e}. Attempting to respond without document references."
1107
+ logger.error(error_message, exc_info=True)
1108
+ async for result in send_event(
1109
+ ChatEvent.STATUS, "Document search failed. I'll try respond without document references"
1110
+ ):
1111
+ yield result
1281
1112
 
1282
- inferred_queries.append(improved_image_prompt)
1283
- if generated_image is None or status_code != 200:
1284
- program_execution_context.append(f"Failed to generate image with {improved_image_prompt}")
1285
- async for result in send_event(ChatEvent.STATUS, f"Failed to generate image"):
1286
- yield result
1287
- else:
1288
- generated_images.append(generated_image)
1113
+ if not is_none_or_empty(compiled_references):
1114
+ distinct_headings = set([d.get("compiled").split("\n")[0] for d in compiled_references if "compiled" in d])
1115
+ distinct_files = set([d["file"] for d in compiled_references])
1116
+ # Strip only leading # from headings
1117
+ headings_str = "\n- " + "\n- ".join(distinct_headings).replace("#", "")
1118
+ async for result in send_event(
1119
+ ChatEvent.STATUS,
1120
+ f"**Found {len(distinct_headings)} Notes Across {len(distinct_files)} Files**: {headings_str}",
1121
+ ):
1122
+ yield result
1289
1123
 
1290
- generated_asset_results["images"] = {
1291
- "query": improved_image_prompt,
1292
- }
1124
+ if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
1125
+ async for result in send_llm_response(f"{no_entries_found.format()}", tracer.get("usage")):
1126
+ yield result
1127
+ return
1293
1128
 
1294
- async for result in send_event(
1295
- ChatEvent.GENERATED_ASSETS,
1296
- {
1297
- "images": [generated_image],
1298
- },
1299
- ):
1300
- yield result
1129
+ if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
1130
+ conversation_commands.remove(ConversationCommand.Notes)
1301
1131
 
1302
- if ConversationCommand.Diagram in conversation_commands:
1303
- async for result in send_event(ChatEvent.STATUS, f"Creating diagram"):
1132
+ ## Gather Online References
1133
+ if ConversationCommand.Online in conversation_commands:
1134
+ try:
1135
+ async for result in search_online(
1136
+ defiltered_query,
1137
+ chat_history,
1138
+ location,
1139
+ user,
1140
+ partial(send_event, ChatEvent.STATUS),
1141
+ custom_filters=[],
1142
+ max_online_searches=3,
1143
+ query_images=uploaded_images,
1144
+ query_files=attached_file_context,
1145
+ agent=agent,
1146
+ tracer=tracer,
1147
+ ):
1148
+ if isinstance(result, dict) and ChatEvent.STATUS in result:
1149
+ yield result[ChatEvent.STATUS]
1150
+ else:
1151
+ online_results = result
1152
+ except Exception as e:
1153
+ error_message = f"Error searching online: {e}. Attempting to respond without online results"
1154
+ logger.warning(error_message)
1155
+ async for result in send_event(
1156
+ ChatEvent.STATUS, "Online search failed. I'll try respond without online references"
1157
+ ):
1304
1158
  yield result
1305
1159
 
1306
- inferred_queries = []
1307
- diagram_description = ""
1160
+ ## Gather Webpage References
1161
+ if ConversationCommand.Webpage in conversation_commands:
1162
+ try:
1163
+ async for result in read_webpages(
1164
+ defiltered_query,
1165
+ chat_history,
1166
+ location,
1167
+ user,
1168
+ partial(send_event, ChatEvent.STATUS),
1169
+ max_webpages_to_read=1,
1170
+ query_images=uploaded_images,
1171
+ agent=agent,
1172
+ query_files=attached_file_context,
1173
+ tracer=tracer,
1174
+ ):
1175
+ if isinstance(result, dict) and ChatEvent.STATUS in result:
1176
+ yield result[ChatEvent.STATUS]
1177
+ else:
1178
+ direct_web_pages = result
1179
+ webpages = []
1180
+ for query in direct_web_pages:
1181
+ if online_results.get(query):
1182
+ online_results[query]["webpages"] = direct_web_pages[query]["webpages"]
1183
+ else:
1184
+ online_results[query] = {"webpages": direct_web_pages[query]["webpages"]}
1308
1185
 
1309
- async for result in generate_mermaidjs_diagram(
1310
- q=defiltered_query,
1311
- chat_history=chat_history,
1312
- location_data=location,
1313
- note_references=compiled_references,
1314
- online_results=online_results,
1186
+ for webpage in direct_web_pages[query]["webpages"]:
1187
+ webpages.append(webpage["link"])
1188
+ async for result in send_event(ChatEvent.STATUS, f"**Read web pages**: {webpages}"):
1189
+ yield result
1190
+ except Exception as e:
1191
+ logger.warning(
1192
+ f"Error reading webpages: {e}. Attempting to respond without webpage results",
1193
+ exc_info=True,
1194
+ )
1195
+ async for result in send_event(
1196
+ ChatEvent.STATUS, "Webpage read failed. I'll try respond without webpage references"
1197
+ ):
1198
+ yield result
1199
+
1200
+ ## Gather Code Results
1201
+ if ConversationCommand.Code in conversation_commands:
1202
+ try:
1203
+ context = f"# Iteration 1:\n#---\nNotes:\n{compiled_references}\n\nOnline Results:{online_results}"
1204
+ async for result in run_code(
1205
+ defiltered_query,
1206
+ chat_history,
1207
+ context,
1208
+ location,
1209
+ user,
1210
+ partial(send_event, ChatEvent.STATUS),
1315
1211
  query_images=uploaded_images,
1316
- user=user,
1317
1212
  agent=agent,
1318
- send_status_func=partial(send_event, ChatEvent.STATUS),
1319
1213
  query_files=attached_file_context,
1320
1214
  tracer=tracer,
1321
1215
  ):
1322
1216
  if isinstance(result, dict) and ChatEvent.STATUS in result:
1323
1217
  yield result[ChatEvent.STATUS]
1324
1218
  else:
1325
- better_diagram_description_prompt, mermaidjs_diagram_description = result
1326
- if better_diagram_description_prompt and mermaidjs_diagram_description:
1327
- inferred_queries.append(better_diagram_description_prompt)
1328
- diagram_description = mermaidjs_diagram_description
1329
-
1330
- generated_mermaidjs_diagram = diagram_description
1331
-
1332
- generated_asset_results["diagrams"] = {
1333
- "query": better_diagram_description_prompt,
1334
- }
1335
-
1336
- async for result in send_event(
1337
- ChatEvent.GENERATED_ASSETS,
1338
- {
1339
- "mermaidjsDiagram": mermaidjs_diagram_description,
1340
- },
1341
- ):
1342
- yield result
1343
- else:
1344
- error_message = "Failed to generate diagram. Please try again later."
1345
- program_execution_context.append(
1346
- prompts.failed_diagram_generation.format(
1347
- attempted_diagram=better_diagram_description_prompt
1348
- )
1349
- )
1219
+ code_results = result
1220
+ except ValueError as e:
1221
+ program_execution_context.append(f"Failed to run code")
1222
+ logger.warning(
1223
+ f"Failed to use code tool: {e}. Attempting to respond without code results",
1224
+ exc_info=True,
1225
+ )
1226
+ if ConversationCommand.Operator in conversation_commands:
1227
+ try:
1228
+ async for result in operate_environment(
1229
+ defiltered_query,
1230
+ user,
1231
+ chat_history,
1232
+ location,
1233
+ list(operator_results)[-1] if operator_results else None,
1234
+ query_images=uploaded_images,
1235
+ query_files=attached_file_context,
1236
+ send_status_func=partial(send_event, ChatEvent.STATUS),
1237
+ agent=agent,
1238
+ cancellation_event=cancellation_event,
1239
+ interrupt_queue=child_interrupt_queue,
1240
+ tracer=tracer,
1241
+ ):
1242
+ if isinstance(result, dict) and ChatEvent.STATUS in result:
1243
+ yield result[ChatEvent.STATUS]
1244
+ elif isinstance(result, OperatorRun):
1245
+ if not operator_results or operator_results[-1] is not result:
1246
+ operator_results.append(result)
1247
+ # Add webpages visited while operating browser to references
1248
+ if result.webpages:
1249
+ if not online_results.get(defiltered_query):
1250
+ online_results[defiltered_query] = {"webpages": result.webpages}
1251
+ elif not online_results[defiltered_query].get("webpages"):
1252
+ online_results[defiltered_query]["webpages"] = result.webpages
1253
+ else:
1254
+ online_results[defiltered_query]["webpages"] += result.webpages
1255
+ except ValueError as e:
1256
+ program_execution_context.append(f"Browser operation error: {e}")
1257
+ logger.warning(f"Failed to operate browser with {e}", exc_info=True)
1258
+ async for result in send_event(
1259
+ ChatEvent.STATUS, "Operating browser failed. I'll try respond appropriately"
1260
+ ):
1261
+ yield result
1262
+
1263
+ ## Send Gathered References
1264
+ unique_online_results = deduplicate_organic_results(online_results)
1265
+ async for result in send_event(
1266
+ ChatEvent.REFERENCES,
1267
+ {
1268
+ "inferredQueries": inferred_queries,
1269
+ "context": compiled_references,
1270
+ "onlineContext": unique_online_results,
1271
+ "codeContext": code_results,
1272
+ },
1273
+ ):
1274
+ yield result
1275
+
1276
+ # Generate Output
1277
+ ## Generate Image Output
1278
+ if ConversationCommand.Image in conversation_commands:
1279
+ async for result in text_to_image(
1280
+ defiltered_query,
1281
+ user,
1282
+ chat_history,
1283
+ location_data=location,
1284
+ references=compiled_references,
1285
+ online_results=online_results,
1286
+ send_status_func=partial(send_event, ChatEvent.STATUS),
1287
+ query_images=uploaded_images,
1288
+ agent=agent,
1289
+ query_files=attached_file_context,
1290
+ tracer=tracer,
1291
+ ):
1292
+ if isinstance(result, dict) and ChatEvent.STATUS in result:
1293
+ yield result[ChatEvent.STATUS]
1294
+ else:
1295
+ generated_image, status_code, improved_image_prompt = result
1296
+
1297
+ inferred_queries.append(improved_image_prompt)
1298
+ if generated_image is None or status_code != 200:
1299
+ program_execution_context.append(f"Failed to generate image with {improved_image_prompt}")
1300
+ async for result in send_event(ChatEvent.STATUS, f"Failed to generate image"):
1301
+ yield result
1302
+ else:
1303
+ generated_images.append(generated_image)
1304
+
1305
+ generated_asset_results["images"] = {
1306
+ "query": improved_image_prompt,
1307
+ }
1308
+
1309
+ async for result in send_event(
1310
+ ChatEvent.GENERATED_ASSETS,
1311
+ {
1312
+ "images": [generated_image],
1313
+ },
1314
+ ):
1315
+ yield result
1316
+
1317
+    if ConversationCommand.Diagram in conversation_commands:
+        async for result in send_event(ChatEvent.STATUS, "Creating diagram"):
+            yield result
 
-            async for result in send_event(ChatEvent.STATUS, error_message):
-                yield result
+        inferred_queries = []
+        diagram_description = ""
+
+        async for result in generate_mermaidjs_diagram(
+            q=defiltered_query,
+            chat_history=chat_history,
+            location_data=location,
+            note_references=compiled_references,
+            online_results=online_results,
+            query_images=uploaded_images,
+            user=user,
+            agent=agent,
+            send_status_func=partial(send_event, ChatEvent.STATUS),
+            query_files=attached_file_context,
+            tracer=tracer,
+        ):
+            if isinstance(result, dict) and ChatEvent.STATUS in result:
+                yield result[ChatEvent.STATUS]
+            else:
+                better_diagram_description_prompt, mermaidjs_diagram_description = result
+                if better_diagram_description_prompt and mermaidjs_diagram_description:
+                    inferred_queries.append(better_diagram_description_prompt)
+                    diagram_description = mermaidjs_diagram_description
+
+                    generated_mermaidjs_diagram = diagram_description
+
+                    generated_asset_results["diagrams"] = {
+                        "query": better_diagram_description_prompt,
+                    }
+
+                    async for result in send_event(
+                        ChatEvent.GENERATED_ASSETS,
+                        {
+                            "mermaidjsDiagram": mermaidjs_diagram_description,
+                        },
+                    ):
+                        yield result
+                else:
+                    error_message = "Failed to generate diagram. Please try again later."
+                    program_execution_context.append(
+                        prompts.failed_diagram_generation.format(attempted_diagram=better_diagram_description_prompt)
+                    )
+
+                    async for result in send_event(ChatEvent.STATUS, error_message):
+                        yield result
+
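For client authors, the two `GENERATED_ASSETS` payload shapes introduced above, side by side. The key names are verbatim from this diff; the example values are placeholders:

```python
# Image asset event, as emitted by the image hunk above.
image_assets = {"images": ["<generated image URL or base64 payload>"]}

# Diagram asset event, as emitted by the diagram hunk above.
diagram_assets = {"mermaidjsDiagram": "graph TD; Query-->Diagram;"}
```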
+    # Check if the user has disconnected
+    if cancellation_event.is_set():
+        logger.debug(f"Stopping LLM response to user {user} on {common.client} client.")
+        # Cancel the disconnect monitor task if it is still running
+        await cancel_disconnect_monitor()
+        return
+
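`cancellation_event` and `cancel_disconnect_monitor` are set up earlier in `event_generator`, outside this hunk. The general pattern is a background task that flips an `asyncio.Event` when the client goes away; a minimal sketch under that assumption (none of these names are Khoj's):

```python
import asyncio


async def monitor_disconnect(transport_closed: asyncio.Event, cancellation_event: asyncio.Event):
    # Wait for the transport to close, then flag cancellation for the generator.
    await transport_closed.wait()
    cancellation_event.set()


async def main():
    transport_closed, cancellation_event = asyncio.Event(), asyncio.Event()
    monitor = asyncio.create_task(monitor_disconnect(transport_closed, cancellation_event))

    transport_closed.set()  # simulate the client disconnecting
    await asyncio.sleep(0.01)  # let the monitor task run
    assert cancellation_event.is_set()

    # Mirror of cancel_disconnect_monitor(): stop the watcher if still running.
    monitor.cancel()


asyncio.run(main())
```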
+    ## Generate Text Output
+    async for result in send_event(ChatEvent.STATUS, "**Generating a well-informed response**"):
+        yield result
+
+    llm_response, chat_metadata = await agenerate_chat_response(
+        defiltered_query,
+        chat_history,
+        conversation,
+        compiled_references,
+        online_results,
+        code_results,
+        operator_results,
+        research_results,
+        user,
+        location,
+        user_name,
+        uploaded_images,
+        attached_file_context,
+        generated_files,
+        program_execution_context,
+        generated_asset_results,
+        is_subscribed,
+        tracer,
+    )
 
-    # Check if the user has disconnected
+    full_response = ""
+    async for item in llm_response:
+        # Should not happen with async generator. Skip.
+        if item is None or not isinstance(item, ResponseWithThought):
+            logger.warning(f"Unexpected item type in LLM response: {type(item)}. Skipping.")
+            continue
         if cancellation_event.is_set():
-            logger.debug(f"Stopping LLM response to user {user} on {common.client} client.")
-            # Cancel the disconnect monitor task if it is still running
-            await cancel_disconnect_monitor()
-            return
+            break
+        message = item.text
+        full_response += message if message else ""
+        if item.thought:
+            async for result in send_event(ChatEvent.THOUGHT, item.thought):
+                yield result
+            continue
 
-        ## Generate Text Output
-        async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"):
+        # Start sending response
+        async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
             yield result
 
-        llm_response, chat_metadata = await agenerate_chat_response(
-            defiltered_query,
-            chat_history,
-            conversation,
-            compiled_references,
-            online_results,
-            code_results,
-            operator_results,
-            research_results,
-            user,
-            location,
-            user_name,
-            uploaded_images,
-            attached_file_context,
-            generated_files,
-            program_execution_context,
-            generated_asset_results,
-            is_subscribed,
-            tracer,
+        try:
+            async for result in send_event(ChatEvent.MESSAGE, message):
+                yield result
+        except Exception as e:
+            if not cancellation_event.is_set():
+                logger.warning(f"Error during streaming. Stopping send: {e}")
+            break
+
+    # Save conversation once finished streaming
+    asyncio.create_task(
+        save_to_conversation_log(
+            q,
+            chat_response=full_response,
+            user=user,
+            compiled_references=compiled_references,
+            online_results=online_results,
+            code_results=code_results,
+            operator_results=operator_results,
+            research_results=research_results,
+            inferred_queries=inferred_queries,
+            client_application=user_scope.client_app,
+            conversation_id=str(conversation.id),
+            query_images=uploaded_images,
+            train_of_thought=train_of_thought,
+            raw_query_files=raw_query_files,
+            generated_images=generated_images,
+            raw_generated_files=generated_files,
+            generated_mermaidjs_diagram=generated_mermaidjs_diagram,
+            tracer=tracer,
         )
+    )
 
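Design note on the fire-and-forget save above: `asyncio.create_task` only keeps a weak reference from the event loop, so a task with no other live reference can in principle be garbage-collected before it completes (this caveat is in the CPython asyncio docs). If that ever bites, the usual remedy is to hold the task in a collection until it finishes; a sketch, not what this diff does:

```python
import asyncio

background_tasks: set = set()


def spawn(coro) -> asyncio.Task:
    # Keep a strong reference until the task completes, then drop it.
    task = asyncio.create_task(coro)
    background_tasks.add(task)
    task.add_done_callback(background_tasks.discard)
    return task
```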
-    full_response = ""
-    async for item in llm_response:
-        # Should not happen with async generator. Skip.
-        if item is None or not isinstance(item, ResponseWithThought):
-            logger.warning(f"Unexpected item type in LLM response: {type(item)}. Skipping.")
+    # Signal end of LLM response after the loop finishes
+    if not cancellation_event.is_set():
+        async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
+            yield result
+        # Send Usage Metadata once llm interactions are complete
+        if tracer.get("usage"):
+            async for event in send_event(ChatEvent.USAGE, tracer.get("usage")):
+                yield event
+        async for result in send_event(ChatEvent.END_RESPONSE, ""):
+            yield result
+        logger.debug("Finished streaming response")
+
+    # Cancel the disconnect monitor task if it is still running
+    await cancel_disconnect_monitor()
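Taken together with the hunks above, a full stream from `event_generator` runs: REFERENCES, optional GENERATED_ASSETS, START_LLM_RESPONSE, interleaved THOUGHT and MESSAGE chunks, END_LLM_RESPONSE, optional USAGE, then END_RESPONSE. A client-side dispatch sketch for that ordering; the lowercase event-type strings and the wire framing are assumptions, only the event sequence is from this diff:

```python
def handle_event(event_type: str, data, state: dict) -> None:
    # Dispatch in the order event_generator emits events.
    if event_type == "references":
        state["references"] = data
    elif event_type == "generated_assets":
        state.setdefault("assets", {}).update(data)
    elif event_type == "start_llm_response":
        state["message"] = ""
    elif event_type == "thought":
        state.setdefault("thoughts", []).append(data)
    elif event_type == "message":
        state["message"] += data
    elif event_type == "end_llm_response":
        state["streaming_done"] = True
    elif event_type == "usage":
        state["usage"] = data
    elif event_type == "end_response":
        state["done"] = True


state: dict = {}
for etype, payload in [("start_llm_response", ""), ("message", "Hello"), ("end_llm_response", ""), ("end_response", "")]:
    handle_event(etype, payload, state)
print(state)
```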
+
+
+@api_chat.websocket("/ws")
+@requires(["authenticated"])
+async def chat_ws(
+    websocket: WebSocket,
+    common: CommonQueryParams,
+):
+    # Validate WebSocket Origin
+    origin = websocket.headers.get("origin")
+    if not origin or URL(origin).hostname not in ALLOWED_HOSTS:
+        await websocket.close(code=1008, reason="Origin not allowed")
+        return
+
+    # Limit open websocket connections per user
+    user = websocket.scope["user"].object
+    connection_manager = WebSocketConnectionManager(trial_user_max_connections=5, subscribed_user_max_connections=10)
+    connection_id = str(uuid.uuid4())
+
+    if not await connection_manager.can_connect(websocket):
+        await websocket.close(code=1008, reason="Connection limit exceeded")
+        logger.info(f"WebSocket connection rejected for user {user.id}: connection limit exceeded")
+        return
+
+    await websocket.accept()
+
+    # Note new websocket connection for the user
+    await connection_manager.register_connection(user, connection_id)
+
+    # Initialize rate limiters
+    rate_limiter_per_minute = ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute")
+    rate_limiter_per_day = ApiUserRateLimiter(
+        requests=100, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day"
+    )
+    image_rate_limiter = ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)
+
+    # Shared interrupt queue for communicating interrupts to ongoing research
+    interrupt_queue: asyncio.Queue = asyncio.Queue(maxsize=10)
+    current_task = None
+
+    try:
+        while True:
+            data = await websocket.receive_json()
+
+            # Check if this is an interrupt message
+            if data.get("type") == "interrupt":
+                if current_task and not current_task.done():
+                    # Send interrupt signal to the ongoing task
+                    abort_message = "␃🔚␗"
+                    await interrupt_queue.put(data.get("query") or abort_message)
+                    logger.info(
+                        f"Interrupt signal sent to ongoing task for user {websocket.scope['user'].object.id} with query: {data.get('query')}"
+                    )
+                    if data.get("query"):
+                        ack_type = "interrupt_message_acknowledged"
+                        await websocket.send_text(json.dumps({"type": ack_type}))
+                    else:
+                        ack_type = "interrupt_acknowledged"
+                        await websocket.send_text(json.dumps({"type": ack_type}))
+                else:
+                    logger.info(f"No ongoing task to interrupt for user {websocket.scope['user'].object.id}")
                 continue
-            if cancellation_event.is_set():
-                break
-            message = item.text
-            full_response += message if message else ""
-            if item.thought:
-                async for result in send_event(ChatEvent.THOUGHT, item.thought):
-                    yield result
+
+            # Handle regular chat messages - ensure data has required fields
+            if "q" not in data:
+                await websocket.send_text(json.dumps({"error": "Missing required field 'q' in chat message"}))
                 continue
 
-            # Start sending response
-            async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
-                yield result
+            body = ChatRequestBody(**data)
 
+            # Apply rate limiting manually
             try:
-                async for result in send_event(ChatEvent.MESSAGE, message):
-                    yield result
-            except Exception as e:
-                if not cancellation_event.is_set():
-                    logger.warning(f"Error during streaming. Stopping send: {e}")
-                break
+                rate_limiter_per_minute.check_websocket(websocket)
+                rate_limiter_per_day.check_websocket(websocket)
+                image_rate_limiter.check_websocket(websocket, body)
+            except HTTPException as e:
+                await websocket.send_text(json.dumps({"error": e.detail}))
+                continue
 
-            # Save conversation once finish streaming
-            asyncio.create_task(
-                save_to_conversation_log(
-                    q,
-                    chat_response=full_response,
-                    user=user,
-                    chat_history=chat_history,
-                    compiled_references=compiled_references,
-                    online_results=online_results,
-                    code_results=code_results,
-                    operator_results=operator_results,
-                    research_results=research_results,
-                    inferred_queries=inferred_queries,
-                    client_application=request.user.client_app,
-                    conversation_id=str(conversation.id),
-                    query_images=uploaded_images,
-                    train_of_thought=train_of_thought,
-                    raw_query_files=raw_query_files,
-                    generated_images=generated_images,
-                    raw_generated_files=generated_files,
-                    generated_mermaidjs_diagram=generated_mermaidjs_diagram,
-                    tracer=tracer,
-                )
+            # Cancel any ongoing task before starting a new one
+            if current_task and not current_task.done():
+                current_task.cancel()
+                try:
+                    await current_task
+                except asyncio.CancelledError:
+                    pass
+
+            # Create a new task for processing the chat request
+            current_task = asyncio.create_task(process_chat_request(websocket, body, common, interrupt_queue))
+
+    except WebSocketDisconnect:
+        logger.info(f"WebSocket disconnected for user {websocket.scope['user'].object.id}")
+        if current_task and not current_task.done():
+            current_task.cancel()
+    except Exception as e:
+        logger.error(f"Error in websocket chat: {e}", exc_info=True)
+        if current_task and not current_task.done():
+            current_task.cancel()
+        await websocket.close(code=1011, reason="Internal Server Error")
+    finally:
+        # Always unregister the connection on disconnect
+        await connection_manager.unregister_connection(user, connection_id)
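The interrupt path end to end: `chat_ws` puts either the user's steering query or the `"␃🔚␗"` abort sentinel onto `interrupt_queue`, and the research loop downstream of `event_generator` drains it between steps. A self-contained sketch of that handshake; only the sentinel value and the bounded queue come from this diff, the draining logic is an assumption:

```python
import asyncio

ABORT_SENTINEL = "␃🔚␗"


async def research_loop(interrupt_queue: asyncio.Queue):
    for step in range(5):
        # Drain any pending interrupt between research steps.
        try:
            signal = interrupt_queue.get_nowait()
        except asyncio.QueueEmpty:
            signal = None
        if signal == ABORT_SENTINEL:
            print("aborted at step", step)
            return
        elif signal:
            print("steering query received:", signal)
        await asyncio.sleep(0.01)  # stand-in for one research step


async def main():
    queue: asyncio.Queue = asyncio.Queue(maxsize=10)
    task = asyncio.create_task(research_loop(queue))
    await queue.put(ABORT_SENTINEL)  # what chat_ws does for a bare interrupt
    await task


asyncio.run(main())
```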
+
+
+async def process_chat_request(
+    websocket: WebSocket,
+    body: ChatRequestBody,
+    common: CommonQueryParams,
+    interrupt_queue: asyncio.Queue,
+):
+    """Process a single chat request with interrupt support"""
+    try:
+        # Since we are using websockets, we can ignore the stream parameter and always stream
+        response_iterator = event_generator(
+            body,
+            websocket.scope["user"],
+            common,
+            websocket.headers,
+            websocket,
+            interrupt_queue,
         )
+        async for event in response_iterator:
+            await websocket.send_text(event)
+    except asyncio.CancelledError:
+        logger.debug(f"Chat request cancelled for user {websocket.scope['user'].object.id}")
+        raise
+    except Exception as e:
+        logger.error(f"Error processing chat request: {e}", exc_info=True)
+        await websocket.send_text(json.dumps({"error": "Internal server error"}))
+        raise
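For reviewers who want to poke at the new endpoint, a minimal client using the `websockets` package. The `/api/chat/ws` path assumes this router is mounted at `/api/chat` on khoj's default port, and auth setup is elided; the message shapes (`q`, `{"type": "interrupt", ...}`, the `*_acknowledged` acks, and `{"error": ...}` replies) come straight from the handler above:

```python
import asyncio
import json

import websockets  # pip install websockets


async def demo():
    async with websockets.connect("ws://localhost:42110/api/chat/ws") as ws:
        await ws.send(json.dumps({"q": "What did I write about gardening?"}))
        # Steer the in-flight request; a bare {"type": "interrupt"} aborts instead.
        await ws.send(json.dumps({"type": "interrupt", "query": "Focus on tomatoes"}))

        while True:
            frame = await ws.recv()
            try:
                event = json.loads(frame)
            except json.JSONDecodeError:
                print("chunk:", frame)  # not every frame is necessarily JSON
                continue
            if str(event.get("type", "")).startswith("interrupt"):
                print("ack:", event["type"])
            elif "error" in event:
                print("server error:", event["error"])
                break
            else:
                print("event:", event)


asyncio.run(demo())
```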
 
-    # Signal end of LLM response after the loop finishes
-    if not cancellation_event.is_set():
-        async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
-            yield result
-        # Send Usage Metadata once llm interactions are complete
-        if tracer.get("usage"):
-            async for event in send_event(ChatEvent.USAGE, tracer.get("usage")):
-                yield event
-        async for result in send_event(ChatEvent.END_RESPONSE, ""):
-            yield result
-        logger.debug("Finished streaming response")
 
-    # Cancel the disconnect monitor task if it is still running
-    await cancel_disconnect_monitor()
+@api_chat.post("")
+@requires(["authenticated"])
+async def chat(
+    request: Request,
+    common: CommonQueryParams,
+    body: ChatRequestBody,
+    rate_limiter_per_minute=Depends(
+        ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute")
+    ),
+    rate_limiter_per_day=Depends(
+        ApiUserRateLimiter(requests=100, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
+    ),
+    image_rate_limiter=Depends(ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)),
+):
+    response_iterator = event_generator(
+        body,
+        request.user,
+        common,
+        request.headers,
+        request,
+    )
 
-    ## Stream Text Response
-    if stream:
-        return StreamingResponse(event_generator(q, images=raw_images), media_type="text/plain")
-    ## Non-Streaming Text Response
+    # Stream Text Response
+    if body.stream:
+        return StreamingResponse(response_iterator, media_type="text/plain")
+    # Non-Streaming Text Response
     else:
-        response_iterator = event_generator(q, images=raw_images)
         response_data = await read_chat_stream(response_iterator)
         return Response(content=json.dumps(response_data), media_type="application/json", status_code=200)
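And the equivalent over plain HTTP: the POST endpoint streams `text/plain` events when `stream` is true and returns a single JSON document assembled by `read_chat_stream` otherwise. A sketch with `httpx`; the base URL, auth header, and any body fields beyond `q` and `stream` are assumptions:

```python
import httpx

BASE = "http://localhost:42110/api/chat"  # assumes khoj's default port
HEADERS = {"Authorization": "Bearer <api-token>"}

# Non-streaming: one JSON document for the whole response.
resp = httpx.post(BASE, json={"q": "Summarize my notes", "stream": False}, headers=HEADERS, timeout=120)
print(resp.json())

# Streaming: consume the text/plain event stream incrementally.
with httpx.stream("POST", BASE, json={"q": "Summarize my notes", "stream": True}, headers=HEADERS, timeout=None) as r:
    for chunk in r.iter_text():
        print(chunk, end="")
```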