khoj 2.0.0b8.dev1__py3-none-any.whl → 2.0.0b9.dev7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. khoj/database/adapters/__init__.py +3 -2
  2. khoj/database/models/__init__.py +28 -0
  3. khoj/interface/compiled/404/index.html +2 -2
  4. khoj/interface/compiled/_next/static/chunks/{2327-d863b0b21ecb23de.js → 2327-ea623ca2d22f78e9.js} +1 -1
  5. khoj/interface/compiled/_next/static/chunks/5477-c4209b72942d3038.js +1 -0
  6. khoj/interface/compiled/_next/static/chunks/{5639-0c4604668cb2d4c0.js → 5639-09e2009a2adedf8b.js} +1 -1
  7. khoj/interface/compiled/_next/static/chunks/9139-ce1ae935dac9c871.js +1 -0
  8. khoj/interface/compiled/_next/static/chunks/app/agents/layout-4e2a134ec26aa606.js +1 -0
  9. khoj/interface/compiled/_next/static/chunks/app/agents/{page-5db6ad18da10d353.js → page-0006674668eb5a4d.js} +1 -1
  10. khoj/interface/compiled/_next/static/chunks/app/automations/{page-6271e2e31c7571d1.js → page-4c465cde2d14cb52.js} +1 -1
  11. khoj/interface/compiled/_next/static/chunks/app/chat/layout-d6acbba22ccac0ff.js +1 -0
  12. khoj/interface/compiled/_next/static/chunks/app/chat/page-4408125f66c165cf.js +1 -0
  13. khoj/interface/compiled/_next/static/chunks/app/{page-a19a597629e87fb8.js → page-85b9b416898738f7.js} +1 -1
  14. khoj/interface/compiled/_next/static/chunks/app/search/layout-484d34239ed0f2b1.js +1 -0
  15. khoj/interface/compiled/_next/static/chunks/app/search/{page-fa366ac14b228688.js → page-883b7d8d2e3abe3e.js} +1 -1
  16. khoj/interface/compiled/_next/static/chunks/app/settings/{page-8f9a85f96088c18b.js → page-95e994ddac31473f.js} +1 -1
  17. khoj/interface/compiled/_next/static/chunks/app/share/chat/layout-b3f7ae1ef8871d30.js +1 -0
  18. khoj/interface/compiled/_next/static/chunks/app/share/chat/{page-ed7787cf4938b8e3.js → page-c062269e6906ef22.js} +1 -1
  19. khoj/interface/compiled/_next/static/chunks/{webpack-c6e14fd89812b96f.js → webpack-c375c47fee5a4dda.js} +1 -1
  20. khoj/interface/compiled/_next/static/css/1e9b757ee2a2b34b.css +1 -0
  21. khoj/interface/compiled/_next/static/css/ee66643a6a5bf71c.css +1 -0
  22. khoj/interface/compiled/agents/index.html +2 -2
  23. khoj/interface/compiled/agents/index.txt +2 -2
  24. khoj/interface/compiled/automations/index.html +2 -2
  25. khoj/interface/compiled/automations/index.txt +3 -3
  26. khoj/interface/compiled/chat/index.html +2 -2
  27. khoj/interface/compiled/chat/index.txt +2 -2
  28. khoj/interface/compiled/index.html +2 -2
  29. khoj/interface/compiled/index.txt +2 -2
  30. khoj/interface/compiled/search/index.html +2 -2
  31. khoj/interface/compiled/search/index.txt +2 -2
  32. khoj/interface/compiled/settings/index.html +2 -2
  33. khoj/interface/compiled/settings/index.txt +4 -4
  34. khoj/interface/compiled/share/chat/index.html +2 -2
  35. khoj/interface/compiled/share/chat/index.txt +2 -2
  36. khoj/main.py +11 -1
  37. khoj/processor/conversation/utils.py +7 -7
  38. khoj/processor/operator/__init__.py +16 -1
  39. khoj/processor/tools/run_code.py +5 -2
  40. khoj/routers/api_chat.py +822 -682
  41. khoj/routers/helpers.py +101 -2
  42. khoj/routers/research.py +56 -14
  43. khoj/utils/rawconfig.py +0 -1
  44. {khoj-2.0.0b8.dev1.dist-info → khoj-2.0.0b9.dev7.dist-info}/METADATA +1 -1
  45. {khoj-2.0.0b8.dev1.dist-info → khoj-2.0.0b9.dev7.dist-info}/RECORD +56 -56
  46. khoj/interface/compiled/_next/static/chunks/5477-18323501c445315e.js +0 -1
  47. khoj/interface/compiled/_next/static/chunks/9568-0d60ac475f4cc538.js +0 -1
  48. khoj/interface/compiled/_next/static/chunks/app/agents/layout-e00fb81dca656a10.js +0 -1
  49. khoj/interface/compiled/_next/static/chunks/app/chat/layout-33934fc2d6ae6838.js +0 -1
  50. khoj/interface/compiled/_next/static/chunks/app/chat/page-b186e95387e23ed5.js +0 -1
  51. khoj/interface/compiled/_next/static/chunks/app/search/layout-f5881c7ae3ba0795.js +0 -1
  52. khoj/interface/compiled/_next/static/chunks/app/share/chat/layout-abb6c5f4239ad7be.js +0 -1
  53. khoj/interface/compiled/_next/static/css/37a73b87f02df402.css +0 -1
  54. khoj/interface/compiled/_next/static/css/93eeacc43e261162.css +0 -1
  55. /khoj/interface/compiled/_next/static/chunks/{1327-1a9107b9a2a04a98.js → 1327-3b1a41af530fa8ee.js} +0 -0
  56. /khoj/interface/compiled/_next/static/chunks/{1915-5c6508f6ebb62a30.js → 1915-fbfe167c84ad60c5.js} +0 -0
  57. /khoj/interface/compiled/_next/static/chunks/{2117-080746c8e170c81a.js → 2117-e78b6902ad6f75ec.js} +0 -0
  58. /khoj/interface/compiled/_next/static/chunks/{2939-4af3fd24b8ffc9ad.js → 2939-4d4084c5b888b960.js} +0 -0
  59. /khoj/interface/compiled/_next/static/chunks/{4447-cd95608f8e93e711.js → 4447-d6cf93724d57e34b.js} +0 -0
  60. /khoj/interface/compiled/_next/static/chunks/{8667-50b03a89e82e0ba7.js → 8667-4b7790573b08c50d.js} +0 -0
  61. /khoj/interface/compiled/_next/static/{ZHB1va0KhWIu8Zs8-kbgt → j6hqMrHT_uO81BNrhFQkJ}/_buildManifest.js +0 -0
  62. /khoj/interface/compiled/_next/static/{ZHB1va0KhWIu8Zs8-kbgt → j6hqMrHT_uO81BNrhFQkJ}/_ssgManifest.js +0 -0
  63. {khoj-2.0.0b8.dev1.dist-info → khoj-2.0.0b9.dev7.dist-info}/WHEEL +0 -0
  64. {khoj-2.0.0b8.dev1.dist-info → khoj-2.0.0b9.dev7.dist-info}/entry_points.txt +0 -0
  65. {khoj-2.0.0b8.dev1.dist-info → khoj-2.0.0b9.dev7.dist-info}/licenses/LICENSE +0 -0
khoj/routers/api_chat.py CHANGED
@@ -10,9 +10,18 @@ from typing import Any, Dict, List, Optional
 from urllib.parse import unquote
 
 from asgiref.sync import sync_to_async
-from fastapi import APIRouter, Depends, HTTPException, Request
+from fastapi import (
+    APIRouter,
+    Depends,
+    HTTPException,
+    Request,
+    WebSocket,
+    WebSocketDisconnect,
+)
 from fastapi.responses import RedirectResponse, Response, StreamingResponse
+from fastapi.websockets import WebSocketState
 from starlette.authentication import has_required_scope, requires
+from starlette.requests import Headers
 
 from khoj.app.settings import ALLOWED_HOSTS
 from khoj.database.adapters import (
@@ -60,6 +69,7 @@ from khoj.routers.helpers import (
     generate_mermaidjs_diagram,
     generate_summary_from_files,
     get_conversation_command,
+    get_message_from_queue,
     is_query_empty,
     is_ready_to_chat,
     read_chat_stream,
@@ -657,19 +667,13 @@ def delete_message(request: Request, delete_request: DeleteMessageRequestBody) -
     return Response(content=json.dumps({"status": "error", "message": "Message not found"}), status_code=404)
 
 
-@api_chat.post("")
-@requires(["authenticated"])
-async def chat(
-    request: Request,
-    common: CommonQueryParams,
+async def event_generator(
     body: ChatRequestBody,
-    rate_limiter_per_minute=Depends(
-        ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute")
-    ),
-    rate_limiter_per_day=Depends(
-        ApiUserRateLimiter(requests=100, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
-    ),
-    image_rate_limiter=Depends(ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)),
+    user_scope: Any,
+    common: CommonQueryParams,
+    headers: Headers,
+    request_obj: Request | WebSocket,
+    parent_interrupt_queue: asyncio.Queue = None,
 ):
     # Access the parameters from the body
     q = body.q
@@ -686,67 +690,68 @@ async def chat(
     timezone = body.timezone
     raw_images = body.images
     raw_query_files = body.files
-    interrupt_flag = body.interrupt
-
-    async def event_generator(q: str, images: list[str]):
-        start_time = time.perf_counter()
-        ttft = None
-        chat_metadata: dict = {}
-        conversation = None
-        user: KhojUser = request.user.object
-        is_subscribed = has_required_scope(request, ["premium"])
-        q = unquote(q)
-        train_of_thought = []
-        nonlocal conversation_id
-        nonlocal raw_query_files
-        cancellation_event = asyncio.Event()
-
-        tracer: dict = {
-            "mid": turn_id,
-            "cid": conversation_id,
-            "uid": user.id,
-            "khoj_version": state.khoj_version,
-        }
 
-        uploaded_images: list[str] = []
-        if images:
-            for image in images:
-                decoded_string = unquote(image)
-                base64_data = decoded_string.split(",", 1)[1]
-                image_bytes = base64.b64decode(base64_data)
-                webp_image_bytes = convert_image_to_webp(image_bytes)
-                uploaded_image = upload_user_image_to_bucket(webp_image_bytes, request.user.object.id)
-                if not uploaded_image:
-                    base64_webp_image = base64.b64encode(webp_image_bytes).decode("utf-8")
-                    uploaded_image = f"data:image/webp;base64,{base64_webp_image}"
-                uploaded_images.append(uploaded_image)
-
-        query_files: Dict[str, str] = {}
-        if raw_query_files:
-            for file in raw_query_files:
-                query_files[file.name] = file.content
-
-        research_results: List[ResearchIteration] = []
-        online_results: Dict = dict()
-        code_results: Dict = dict()
-        operator_results: List[OperatorRun] = []
-        compiled_references: List[Any] = []
-        inferred_queries: List[Any] = []
-        attached_file_context = gather_raw_query_files(query_files)
-
-        generated_images: List[str] = []
-        generated_files: List[FileAttachment] = []
-        generated_mermaidjs_diagram: str = None
-        generated_asset_results: Dict = dict()
-        program_execution_context: List[str] = []
-        chat_history: List[ChatMessageModel] = []
-
-        # Create a task to monitor for disconnections
-        disconnect_monitor_task = None
-
-        async def monitor_disconnection():
+    start_time = time.perf_counter()
+    ttft = None
+    chat_metadata: dict = {}
+    conversation = None
+    user: KhojUser = user_scope.object
+    is_subscribed = has_required_scope(request_obj, ["premium"])
+    q = unquote(q)
+    defiltered_query = defilter_query(q)
+    train_of_thought = []
+    cancellation_event = asyncio.Event()
+    child_interrupt_queue: asyncio.Queue = asyncio.Queue()
+    event_delimiter = "␃🔚␗"
+
+    tracer: dict = {
+        "mid": turn_id,
+        "cid": conversation_id,
+        "uid": user.id,
+        "khoj_version": state.khoj_version,
+    }
+
+    uploaded_images: list[str] = []
+    if raw_images:
+        for image in raw_images:
+            decoded_string = unquote(image)
+            base64_data = decoded_string.split(",", 1)[1]
+            image_bytes = base64.b64decode(base64_data)
+            webp_image_bytes = convert_image_to_webp(image_bytes)
+            uploaded_image = upload_user_image_to_bucket(webp_image_bytes, user.id)
+            if not uploaded_image:
+                base64_webp_image = base64.b64encode(webp_image_bytes).decode("utf-8")
+                uploaded_image = f"data:image/webp;base64,{base64_webp_image}"
+            uploaded_images.append(uploaded_image)
+
+    query_files: Dict[str, str] = {}
+    if raw_query_files:
+        for file in raw_query_files:
+            query_files[file.name] = file.content
+
+    research_results: List[ResearchIteration] = []
+    online_results: Dict = dict()
+    code_results: Dict = dict()
+    operator_results: List[OperatorRun] = []
+    compiled_references: List[Any] = []
+    inferred_queries: List[Any] = []
+    attached_file_context = gather_raw_query_files(query_files)
+
+    generated_images: List[str] = []
+    generated_files: List[FileAttachment] = []
+    generated_mermaidjs_diagram: str = None
+    generated_asset_results: Dict = dict()
+    program_execution_context: List[str] = []
+    user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+
+    # Create a task to monitor for disconnections
+    disconnect_monitor_task = None
+
+    async def monitor_disconnection():
+        nonlocal q, defiltered_query
+        if isinstance(request_obj, Request):
             try:
-                msg = await request.receive()
+                msg = await request_obj.receive()
                 if msg["type"] == "http.disconnect":
                     logger.debug(f"Request cancelled. User {user} disconnected from {common.client} client.")
                     cancellation_event.set()
@@ -758,14 +763,13 @@
                                 q,
                                 chat_response="",
                                 user=user,
-                                chat_history=chat_history,
                                 compiled_references=compiled_references,
                                 online_results=online_results,
                                 code_results=code_results,
                                 operator_results=operator_results,
                                 research_results=research_results,
                                 inferred_queries=inferred_queries,
-                                client_application=request.user.client_app,
+                                client_application=user_scope.client_app,
                                 conversation_id=conversation_id,
                                 query_images=uploaded_images,
                                 train_of_thought=train_of_thought,
@@ -773,502 +777,318 @@ async def chat(
                                 generated_images=generated_images,
                                 raw_generated_files=generated_asset_results,
                                 generated_mermaidjs_diagram=generated_mermaidjs_diagram,
+                                user_message_time=user_message_time,
                                 tracer=tracer,
                             )
                         )
             except Exception as e:
                 logger.error(f"Error in disconnect monitor: {e}")
-
-        # Cancel the disconnect monitor task if it is still running
-        async def cancel_disconnect_monitor():
-            if disconnect_monitor_task and not disconnect_monitor_task.done():
-                logger.debug(f"Cancelling disconnect monitor task for user {user}")
-                disconnect_monitor_task.cancel()
-                try:
-                    await disconnect_monitor_task
-                except asyncio.CancelledError:
-                    pass
-
-        async def send_event(event_type: ChatEvent, data: str | dict):
-            nonlocal ttft, train_of_thought
-            event_delimiter = "␃🔚␗"
-            if cancellation_event.is_set():
-                return
-            try:
-                if event_type == ChatEvent.END_LLM_RESPONSE:
-                    collect_telemetry()
-                elif event_type == ChatEvent.START_LLM_RESPONSE:
-                    ttft = time.perf_counter() - start_time
-                elif event_type == ChatEvent.STATUS:
-                    train_of_thought.append({"type": event_type.value, "data": data})
-                elif event_type == ChatEvent.THOUGHT:
-                    # Append the data to the last thought as thoughts are streamed
-                    if (
-                        len(train_of_thought) > 0
-                        and train_of_thought[-1]["type"] == ChatEvent.THOUGHT.value
-                        and type(train_of_thought[-1]["data"]) == type(data) == str
-                    ):
-                        train_of_thought[-1]["data"] += data
+        elif isinstance(request_obj, WebSocket):
+            while request_obj.client_state == WebSocketState.CONNECTED and not cancellation_event.is_set():
+                await asyncio.sleep(1)
+
+                # Check if any interrupt query is received
+                if interrupt_query := get_message_from_queue(parent_interrupt_queue):
+                    if interrupt_query == event_delimiter:
+                        cancellation_event.set()
+                        logger.debug(f"Chat cancelled by user {user} via interrupt queue.")
                     else:
-                        train_of_thought.append({"type": event_type.value, "data": data})
-
-                if event_type == ChatEvent.MESSAGE:
-                    yield data
-                elif event_type == ChatEvent.REFERENCES or ChatEvent.METADATA or stream:
-                    yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False)
-            except Exception as e:
-                if not cancellation_event.is_set():
-                    logger.error(
-                        f"Failed to stream chat API response to {user} on {common.client}: {e}.",
-                        exc_info=True,
+                        # Pass the interrupt query to child tasks
+                        logger.info(f"Continuing chat with the new instruction: {interrupt_query}")
+                        await child_interrupt_queue.put(interrupt_query)
+                        q += f"\n\n{interrupt_query}"
+                        defiltered_query += f"\n\n{defilter_query(interrupt_query)}"
+
+            logger.debug(f"WebSocket disconnected or chat cancelled by user {user} from {common.client} client.")
+            if conversation and cancellation_event.is_set():
+                await asyncio.shield(
+                    save_to_conversation_log(
+                        q,
+                        chat_response="",
+                        user=user,
+                        compiled_references=compiled_references,
+                        online_results=online_results,
+                        code_results=code_results,
+                        operator_results=operator_results,
+                        research_results=research_results,
+                        inferred_queries=inferred_queries,
+                        client_application=user_scope.client_app,
+                        conversation_id=conversation_id,
+                        query_images=uploaded_images,
+                        train_of_thought=train_of_thought,
+                        raw_query_files=raw_query_files,
+                        generated_images=generated_images,
+                        raw_generated_files=generated_asset_results,
+                        generated_mermaidjs_diagram=generated_mermaidjs_diagram,
+                        user_message_time=user_message_time,
+                        tracer=tracer,
                     )
-            finally:
-                if not cancellation_event.is_set():
-                    yield event_delimiter
-                # Cancel the disconnect monitor task if it is still running
-                if cancellation_event.is_set() or event_type == ChatEvent.END_RESPONSE:
-                    await cancel_disconnect_monitor()
-
-        async def send_llm_response(response: str, usage: dict = None):
-            # Check if the client is still connected
-            if cancellation_event.is_set():
-                return
-            # Send Chat Response
-            async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
-                yield result
-            async for result in send_event(ChatEvent.MESSAGE, response):
-                yield result
-            async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
-                yield result
-            # Send Usage Metadata once llm interactions are complete
-            if usage:
-                async for event in send_event(ChatEvent.USAGE, usage):
-                    yield event
-            async for result in send_event(ChatEvent.END_RESPONSE, ""):
-                yield result
-
-        def collect_telemetry():
-            # Gather chat response telemetry
-            nonlocal chat_metadata
-            latency = time.perf_counter() - start_time
-            cmd_set = set([cmd.value for cmd in conversation_commands])
-            cost = (tracer.get("usage", {}) or {}).get("cost", 0)
-            chat_metadata = chat_metadata or {}
-            chat_metadata["conversation_command"] = cmd_set
-            chat_metadata["agent"] = conversation.agent.slug if conversation and conversation.agent else None
-            chat_metadata["cost"] = f"{cost:.5f}"
-            chat_metadata["latency"] = f"{latency:.3f}"
-            if ttft:
-                chat_metadata["ttft_latency"] = f"{ttft:.3f}"
-                logger.info(f"Chat response time to first token: {ttft:.3f} seconds")
-            logger.info(f"Chat response total time: {latency:.3f} seconds")
-            logger.info(f"Chat response cost: ${cost:.5f}")
-            update_telemetry_state(
-                request=request,
-                telemetry_type="api",
-                api="chat",
-                client=common.client,
-                user_agent=request.headers.get("user-agent"),
-                host=request.headers.get("host"),
-                metadata=chat_metadata,
-            )
-
-        # Start the disconnect monitor in the background
-        disconnect_monitor_task = asyncio.create_task(monitor_disconnection())
-
-        if is_query_empty(q):
-            async for result in send_llm_response("Please ask your query to get started.", tracer.get("usage")):
-                yield result
-            return
-
-        # Automated tasks are handled before to allow mixing them with other conversation commands
-        cmds_to_rate_limit = []
-        is_automated_task = False
-        if q.startswith("/automated_task"):
-            is_automated_task = True
-            q = q.replace("/automated_task", "").lstrip()
-            cmds_to_rate_limit += [ConversationCommand.AutomatedTask]
+                )
 
-        # Extract conversation command from query
-        conversation_commands = [get_conversation_command(query=q)]
+    # Cancel the disconnect monitor task if it is still running
+    async def cancel_disconnect_monitor():
+        if disconnect_monitor_task and not disconnect_monitor_task.done():
+            logger.debug(f"Cancelling disconnect monitor task for user {user}")
+            disconnect_monitor_task.cancel()
+            try:
+                await disconnect_monitor_task
+            except asyncio.CancelledError:
+                pass
 
-        conversation = await ConversationAdapters.aget_conversation_by_user(
-            user,
-            client_application=request.user.client_app,
-            conversation_id=conversation_id,
-            title=title,
-            create_new=body.create_new,
-        )
-        if not conversation:
-            async for result in send_llm_response(f"Conversation {conversation_id} not found", tracer.get("usage")):
-                yield result
+    async def send_event(event_type: ChatEvent, data: str | dict):
+        nonlocal ttft, train_of_thought
+        if cancellation_event.is_set():
             return
-        conversation_id = str(conversation.id)
-
-        async for event in send_event(ChatEvent.METADATA, {"conversationId": conversation_id, "turnId": turn_id}):
-            yield event
-
-        agent: Agent | None = None
-        default_agent = await AgentAdapters.aget_default_agent()
-        if conversation.agent and conversation.agent != default_agent:
-            agent = conversation.agent
-
-        if not conversation.agent:
-            conversation.agent = default_agent
-            await conversation.asave()
-            agent = default_agent
-
-        await is_ready_to_chat(user)
-        user_name = await aget_user_name(user)
-        location = None
-        if city or region or country or country_code:
-            location = LocationData(city=city, region=region, country=country, country_code=country_code)
-        user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-        chat_history = conversation.messages
-
-        # If interrupt flag is set, wait for the previous turn to be saved before proceeding
-        if interrupt_flag:
-            max_wait_time = 20.0  # seconds
-            wait_interval = 0.3  # seconds
-            wait_start = wait_current = time.time()
-            while wait_current - wait_start < max_wait_time:
-                # Refresh conversation to check if interrupted message saved to DB
-                conversation = await ConversationAdapters.aget_conversation_by_user(
-                    user,
-                    client_application=request.user.client_app,
-                    conversation_id=conversation_id,
-                )
+        try:
+            if event_type == ChatEvent.END_LLM_RESPONSE:
+                collect_telemetry()
+            elif event_type == ChatEvent.START_LLM_RESPONSE:
+                ttft = time.perf_counter() - start_time
+            elif event_type == ChatEvent.STATUS:
+                train_of_thought.append({"type": event_type.value, "data": data})
+            elif event_type == ChatEvent.THOUGHT:
+                # Append the data to the last thought as thoughts are streamed
                 if (
-                    conversation
-                    and conversation.messages
-                    and conversation.messages[-1].by == "khoj"
-                    and not conversation.messages[-1].message
+                    len(train_of_thought) > 0
+                    and train_of_thought[-1]["type"] == ChatEvent.THOUGHT.value
+                    and type(train_of_thought[-1]["data"]) == type(data) == str
                 ):
-                    logger.info(f"Detected interrupted message save to conversation {conversation_id}.")
-                    break
-                await asyncio.sleep(wait_interval)
-                wait_current = time.time()
-
-            if wait_current - wait_start >= max_wait_time:
-                logger.warning(
-                    f"Timeout waiting to load interrupted context from conversation {conversation_id}. Proceed without previous context."
-                )
+                    train_of_thought[-1]["data"] += data
+                else:
+                    train_of_thought.append({"type": event_type.value, "data": data})
 
-        # If interrupted message in DB
-        if (
-            conversation
-            and conversation.messages
-            and conversation.messages[-1].by == "khoj"
-            and not conversation.messages[-1].message
-        ):
-            # Populate context from interrupted message
-            last_message = conversation.messages[-1]
-            online_results = {key: val.model_dump() for key, val in last_message.onlineContext.items() or []}
-            code_results = {key: val.model_dump() for key, val in last_message.codeContext.items() or []}
-            compiled_references = [ref.model_dump() for ref in last_message.context or []]
-            research_results = [
-                ResearchIteration(**iter_dict)
-                for iter_dict in last_message.researchContext or []
-                if iter_dict.get("summarizedResult")
-            ]
-            operator_results = [OperatorRun(**iter_dict) for iter_dict in last_message.operatorContext or []]
-            train_of_thought = [thought.model_dump() for thought in last_message.trainOfThought or []]
-            # Drop the interrupted message from conversation history
-            chat_history.pop()
-            logger.info(f"Loaded interrupted partial context from conversation {conversation_id}.")
-
-        if conversation_commands == [ConversationCommand.Default]:
-            try:
-                chosen_io = await aget_data_sources_and_output_format(
-                    q,
-                    chat_history,
-                    is_automated_task,
-                    user=user,
-                    query_images=uploaded_images,
-                    agent=agent,
-                    query_files=attached_file_context,
-                    tracer=tracer,
+            if event_type == ChatEvent.MESSAGE:
+                yield data
+            elif event_type == ChatEvent.REFERENCES or ChatEvent.METADATA or stream:
+                yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False)
+        except Exception as e:
+            if not cancellation_event.is_set():
+                logger.error(
+                    f"Failed to stream chat API response to {user} on {common.client}: {e}.",
+                    exc_info=True,
                 )
-            except ValueError as e:
-                logger.error(f"Error getting data sources and output format: {e}. Falling back to default.")
-                conversation_commands = [ConversationCommand.General]
-
-            conversation_commands = chosen_io.get("sources") + [chosen_io.get("output")]
-
-            # If we're doing research, we don't want to do anything else
-            if ConversationCommand.Research in conversation_commands:
-                conversation_commands = [ConversationCommand.Research]
+        finally:
+            if not cancellation_event.is_set():
+                yield event_delimiter
+            # Cancel the disconnect monitor task if it is still running
+            if cancellation_event.is_set() or event_type == ChatEvent.END_RESPONSE:
+                await cancel_disconnect_monitor()
 
-        conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
-        async for result in send_event(ChatEvent.STATUS, f"**Selected Tools:** {conversation_commands_str}"):
-            yield result
+    async def send_llm_response(response: str, usage: dict = None):
+        # Check if the client is still connected
+        if cancellation_event.is_set():
+            return
+        # Send Chat Response
+        async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
+            yield result
+        async for result in send_event(ChatEvent.MESSAGE, response):
+            yield result
+        async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
+            yield result
+        # Send Usage Metadata once llm interactions are complete
+        if usage:
+            async for event in send_event(ChatEvent.USAGE, usage):
+                yield event
+        async for result in send_event(ChatEvent.END_RESPONSE, ""):
+            yield result
 
-        cmds_to_rate_limit += conversation_commands
-        for cmd in cmds_to_rate_limit:
-            try:
-                await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
-                q = q.replace(f"/{cmd.value}", "").strip()
-            except HTTPException as e:
-                async for result in send_llm_response(str(e.detail), tracer.get("usage")):
-                    yield result
-                return
+    def collect_telemetry():
+        # Gather chat response telemetry
+        nonlocal chat_metadata
+        latency = time.perf_counter() - start_time
+        cmd_set = set([cmd.value for cmd in conversation_commands])
+        cost = (tracer.get("usage", {}) or {}).get("cost", 0)
+        chat_metadata = chat_metadata or {}
+        chat_metadata["conversation_command"] = cmd_set
+        chat_metadata["agent"] = conversation.agent.slug if conversation and conversation.agent else None
+        chat_metadata["cost"] = f"{cost:.5f}"
+        chat_metadata["latency"] = f"{latency:.3f}"
+        if ttft:
+            chat_metadata["ttft_latency"] = f"{ttft:.3f}"
+            logger.info(f"Chat response time to first token: {ttft:.3f} seconds")
+        logger.info(f"Chat response total time: {latency:.3f} seconds")
+        logger.info(f"Chat response cost: ${cost:.5f}")
+        update_telemetry_state(
+            request=request_obj,
+            telemetry_type="api",
+            api="chat",
+            client=common.client,
+            user_agent=headers.get("user-agent"),
+            host=headers.get("host"),
+            metadata=chat_metadata,
+        )
 
-        defiltered_query = defilter_query(q)
-        file_filters = conversation.file_filters if conversation and conversation.file_filters else []
+    # Start the disconnect monitor in the background
+    disconnect_monitor_task = asyncio.create_task(monitor_disconnection())
 
-        if conversation_commands == [ConversationCommand.Research]:
-            async for research_result in research(
+    if is_query_empty(q):
+        async for result in send_llm_response("Please ask your query to get started.", tracer.get("usage")):
+            yield result
+        return
+
+    # Automated tasks are handled before to allow mixing them with other conversation commands
+    cmds_to_rate_limit = []
+    is_automated_task = False
+    if q.startswith("/automated_task"):
+        is_automated_task = True
+        q = q.replace("/automated_task", "").lstrip()
+        cmds_to_rate_limit += [ConversationCommand.AutomatedTask]
+
+    # Extract conversation command from query
+    conversation_commands = [get_conversation_command(query=q)]
+
+    conversation = await ConversationAdapters.aget_conversation_by_user(
+        user,
+        client_application=user_scope.client_app,
+        conversation_id=conversation_id,
+        title=title,
+        create_new=body.create_new,
+    )
+    if not conversation:
+        async for result in send_llm_response(f"Conversation {conversation_id} not found", tracer.get("usage")):
+            yield result
+        return
+    conversation_id = str(conversation.id)
+
+    async for event in send_event(ChatEvent.METADATA, {"conversationId": conversation_id, "turnId": turn_id}):
+        yield event
+
+    agent: Agent | None = None
+    default_agent = await AgentAdapters.aget_default_agent()
+    if conversation.agent and conversation.agent != default_agent:
+        agent = conversation.agent
+
+    if not conversation.agent:
+        conversation.agent = default_agent
+        await conversation.asave()
+        agent = default_agent
+
+    await is_ready_to_chat(user)
+    user_name = await aget_user_name(user)
+    location = None
+    if city or region or country or country_code:
+        location = LocationData(city=city, region=region, country=country, country_code=country_code)
+    chat_history = conversation.messages
+
+    # If interrupted message in DB
+    if last_message := await conversation.pop_message(interrupted=True):
+        # Populate context from interrupted message
+        online_results = {key: val.model_dump() for key, val in last_message.onlineContext.items() or []}
+        code_results = {key: val.model_dump() for key, val in last_message.codeContext.items() or []}
+        compiled_references = [ref.model_dump() for ref in last_message.context or []]
+        research_results = [
+            ResearchIteration(**iter_dict)
+            for iter_dict in last_message.researchContext or []
+            if iter_dict.get("summarizedResult")
+        ]
+        operator_results = [OperatorRun(**iter_dict) for iter_dict in last_message.operatorContext or []]
+        train_of_thought = [thought.model_dump() for thought in last_message.trainOfThought or []]
+        logger.info(f"Loaded interrupted partial context from conversation {conversation_id}.")
+
+    if conversation_commands == [ConversationCommand.Default]:
+        try:
+            chosen_io = await aget_data_sources_and_output_format(
+                q,
+                chat_history,
+                is_automated_task,
                 user=user,
-                query=defiltered_query,
-                conversation_id=conversation_id,
-                conversation_history=chat_history,
-                previous_iterations=list(research_results),
                 query_images=uploaded_images,
                 agent=agent,
-                send_status_func=partial(send_event, ChatEvent.STATUS),
-                user_name=user_name,
-                location=location,
-                file_filters=file_filters,
                 query_files=attached_file_context,
                 tracer=tracer,
-                cancellation_event=cancellation_event,
-            ):
-                if isinstance(research_result, ResearchIteration):
-                    if research_result.summarizedResult:
-                        if research_result.onlineContext:
-                            online_results.update(research_result.onlineContext)
-                        if research_result.codeContext:
-                            code_results.update(research_result.codeContext)
-                        if research_result.context:
-                            compiled_references.extend(research_result.context)
-                        if not research_results or research_results[-1] is not research_result:
-                            research_results.append(research_result)
-                    else:
-                        yield research_result
-
-                # Track operator results across research and operator iterations
-                # This relies on two conditions:
-                # 1. Check to append new (partial) operator results
-                #    Relies on triggering this check on every status updates.
-                #    Status updates cascade up from operator to research to chat api on every step.
-                # 2. Keep operator results in sync with each research operator step
-                #    Relies on python object references to ensure operator results
-                #    are implicitly kept in sync after the initial append
-                if (
-                    research_results
-                    and research_results[-1].operatorContext
-                    and (not operator_results or operator_results[-1] is not research_results[-1].operatorContext)
-                ):
-                    operator_results.append(research_results[-1].operatorContext)
-
-            # researched_results = await extract_relevant_info(q, researched_results, agent)
-            if state.verbose > 1:
-                logger.debug(f'Researched Results: {"".join(r.summarizedResult or "" for r in research_results)}')
-
-        # Gather Context
-        ## Extract Document References
-        if not ConversationCommand.Research in conversation_commands:
-            try:
-                async for result in search_documents(
-                    q,
-                    (n or 7),
-                    d,
-                    user,
-                    chat_history,
-                    conversation_id,
-                    conversation_commands,
-                    location,
-                    partial(send_event, ChatEvent.STATUS),
-                    query_images=uploaded_images,
-                    agent=agent,
-                    query_files=attached_file_context,
-                    tracer=tracer,
-                ):
-                    if isinstance(result, dict) and ChatEvent.STATUS in result:
-                        yield result[ChatEvent.STATUS]
-                    else:
-                        compiled_references.extend(result[0])
-                        inferred_queries.extend(result[1])
-                        defiltered_query = result[2]
-            except Exception as e:
-                error_message = (
-                    f"Error searching knowledge base: {e}. Attempting to respond without document references."
-                )
-                logger.error(error_message, exc_info=True)
-                async for result in send_event(
-                    ChatEvent.STATUS, "Document search failed. I'll try respond without document references"
-                ):
-                    yield result
-
-        if not is_none_or_empty(compiled_references):
-            headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
-            # Strip only leading # from headings
-            headings = headings.replace("#", "")
-            async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"):
-                yield result
+            )
+        except ValueError as e:
+            logger.error(f"Error getting data sources and output format: {e}. Falling back to default.")
+            conversation_commands = [ConversationCommand.General]
 
-        if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
-            async for result in send_llm_response(f"{no_entries_found.format()}", tracer.get("usage")):
-                yield result
-            return
+        conversation_commands = chosen_io.get("sources") + [chosen_io.get("output")]
 
-        if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
-            conversation_commands.remove(ConversationCommand.Notes)
+        # If we're doing research, we don't want to do anything else
+        if ConversationCommand.Research in conversation_commands:
+            conversation_commands = [ConversationCommand.Research]
 
-        ## Gather Online References
-        if ConversationCommand.Online in conversation_commands:
-            try:
-                async for result in search_online(
-                    defiltered_query,
-                    chat_history,
-                    location,
-                    user,
-                    partial(send_event, ChatEvent.STATUS),
-                    custom_filters=[],
-                    max_online_searches=3,
-                    query_images=uploaded_images,
-                    query_files=attached_file_context,
-                    agent=agent,
-                    tracer=tracer,
-                ):
-                    if isinstance(result, dict) and ChatEvent.STATUS in result:
-                        yield result[ChatEvent.STATUS]
-                    else:
-                        online_results = result
-            except Exception as e:
-                error_message = f"Error searching online: {e}. Attempting to respond without online results"
-                logger.warning(error_message)
-                async for result in send_event(
-                    ChatEvent.STATUS, "Online search failed. I'll try respond without online references"
-                ):
-                    yield result
+    conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
+    async for result in send_event(ChatEvent.STATUS, f"**Selected Tools:** {conversation_commands_str}"):
+        yield result
 
-        ## Gather Webpage References
-        if ConversationCommand.Webpage in conversation_commands:
-            try:
-                async for result in read_webpages(
-                    defiltered_query,
-                    chat_history,
-                    location,
-                    user,
-                    partial(send_event, ChatEvent.STATUS),
-                    max_webpages_to_read=1,
-                    query_images=uploaded_images,
-                    agent=agent,
-                    query_files=attached_file_context,
-                    tracer=tracer,
-                ):
-                    if isinstance(result, dict) and ChatEvent.STATUS in result:
-                        yield result[ChatEvent.STATUS]
-                    else:
-                        direct_web_pages = result
-                        webpages = []
-                        for query in direct_web_pages:
-                            if online_results.get(query):
-                                online_results[query]["webpages"] = direct_web_pages[query]["webpages"]
-                            else:
-                                online_results[query] = {"webpages": direct_web_pages[query]["webpages"]}
+    cmds_to_rate_limit += conversation_commands
+    for cmd in cmds_to_rate_limit:
+        try:
+            await conversation_command_rate_limiter.update_and_check_if_valid(request_obj, cmd)
+            q = q.replace(f"/{cmd.value}", "").strip()
+        except HTTPException as e:
+            async for result in send_llm_response(str(e.detail), tracer.get("usage")):
+                yield result
+            return
 
-                            for webpage in direct_web_pages[query]["webpages"]:
-                                webpages.append(webpage["link"])
-                        async for result in send_event(ChatEvent.STATUS, f"**Read web pages**: {webpages}"):
-                            yield result
-            except Exception as e:
-                logger.warning(
-                    f"Error reading webpages: {e}. Attempting to respond without webpage results",
-                    exc_info=True,
-                )
-                async for result in send_event(
-                    ChatEvent.STATUS, "Webpage read failed. I'll try respond without webpage references"
-                ):
-                    yield result
+    defiltered_query = defilter_query(q)
+    file_filters = conversation.file_filters if conversation and conversation.file_filters else []
 
-        ## Gather Code Results
-        if ConversationCommand.Code in conversation_commands:
-            try:
-                context = f"# Iteration 1:\n#---\nNotes:\n{compiled_references}\n\nOnline Results:{online_results}"
-                async for result in run_code(
-                    defiltered_query,
-                    chat_history,
-                    context,
-                    location,
-                    user,
-                    partial(send_event, ChatEvent.STATUS),
-                    query_images=uploaded_images,
-                    agent=agent,
-                    query_files=attached_file_context,
-                    tracer=tracer,
-                ):
-                    if isinstance(result, dict) and ChatEvent.STATUS in result:
-                        yield result[ChatEvent.STATUS]
-                    else:
-                        code_results = result
-            except ValueError as e:
-                program_execution_context.append(f"Failed to run code")
-                logger.warning(
-                    f"Failed to use code tool: {e}. Attempting to respond without code results",
-                    exc_info=True,
-                )
-        if ConversationCommand.Operator in conversation_commands:
-            try:
-                async for result in operate_environment(
-                    defiltered_query,
-                    user,
-                    chat_history,
-                    location,
-                    list(operator_results)[-1] if operator_results else None,
-                    query_images=uploaded_images,
-                    query_files=attached_file_context,
-                    send_status_func=partial(send_event, ChatEvent.STATUS),
-                    agent=agent,
-                    cancellation_event=cancellation_event,
-                    tracer=tracer,
-                ):
-                    if isinstance(result, dict) and ChatEvent.STATUS in result:
-                        yield result[ChatEvent.STATUS]
-                    elif isinstance(result, OperatorRun):
-                        if not operator_results or operator_results[-1] is not result:
-                            operator_results.append(result)
-                        # Add webpages visited while operating browser to references
-                        if result.webpages:
-                            if not online_results.get(defiltered_query):
-                                online_results[defiltered_query] = {"webpages": result.webpages}
-                            elif not online_results[defiltered_query].get("webpages"):
-                                online_results[defiltered_query]["webpages"] = result.webpages
-                            else:
-                                online_results[defiltered_query]["webpages"] += result.webpages
-            except ValueError as e:
-                program_execution_context.append(f"Browser operation error: {e}")
-                logger.warning(f"Failed to operate browser with {e}", exc_info=True)
-                async for result in send_event(
-                    ChatEvent.STATUS, "Operating browser failed. I'll try respond appropriately"
-                ):
-                    yield result
-
-        ## Send Gathered References
-        unique_online_results = deduplicate_organic_results(online_results)
-        async for result in send_event(
-            ChatEvent.REFERENCES,
-            {
-                "inferredQueries": inferred_queries,
-                "context": compiled_references,
-                "onlineContext": unique_online_results,
-                "codeContext": code_results,
-            },
+    if conversation_commands == [ConversationCommand.Research]:
+        async for research_result in research(
+            user=user,
+            query=defiltered_query,
+            conversation_id=conversation_id,
+            conversation_history=chat_history,
+            previous_iterations=list(research_results),
+            query_images=uploaded_images,
+            agent=agent,
+            send_status_func=partial(send_event, ChatEvent.STATUS),
+            user_name=user_name,
+            location=location,
+            file_filters=file_filters,
+            query_files=attached_file_context,
+            tracer=tracer,
+            cancellation_event=cancellation_event,
+            interrupt_queue=child_interrupt_queue,
+            abort_message=event_delimiter,
         ):
-            yield result
+            if isinstance(research_result, ResearchIteration):
+                if research_result.summarizedResult:
+                    if research_result.onlineContext:
+                        online_results.update(research_result.onlineContext)
+                    if research_result.codeContext:
+                        code_results.update(research_result.codeContext)
+                    if research_result.context:
+                        compiled_references.extend(research_result.context)
+                    if not research_results or research_results[-1] is not research_result:
+                        research_results.append(research_result)
+                else:
+                    yield research_result
+
+            # Track operator results across research and operator iterations
+            # This relies on two conditions:
+            # 1. Check to append new (partial) operator results
+            #    Relies on triggering this check on every status updates.
+            #    Status updates cascade up from operator to research to chat api on every step.
+            # 2. Keep operator results in sync with each research operator step
+            #    Relies on python object references to ensure operator results
+            #    are implicitly kept in sync after the initial append
+            if (
+                research_results
+                and research_results[-1].operatorContext
+                and (not operator_results or operator_results[-1] is not research_results[-1].operatorContext)
+            ):
+                operator_results.append(research_results[-1].operatorContext)
 
-        # Generate Output
-        ## Generate Image Output
-        if ConversationCommand.Image in conversation_commands:
-            async for result in text_to_image(
-                defiltered_query,
+    # researched_results = await extract_relevant_info(q, researched_results, agent)
+    if state.verbose > 1:
+        logger.debug(f'Researched Results: {"".join(r.summarizedResult or "" for r in research_results)}')
+
+    # Gather Context
+    ## Extract Document References
+    if not ConversationCommand.Research in conversation_commands:
+        try:
+            async for result in search_documents(
+                q,
+                (n or 7),
+                d,
                 user,
                 chat_history,
-                location_data=location,
-                references=compiled_references,
-                online_results=online_results,
-                send_status_func=partial(send_event, ChatEvent.STATUS),
+                conversation_id,
+                conversation_commands,
+                location,
+                partial(send_event, ChatEvent.STATUS),
                 query_images=uploaded_images,
                 agent=agent,
                 query_files=attached_file_context,
@@ -1277,184 +1097,504 @@ async def chat(
1277
1097
  if isinstance(result, dict) and ChatEvent.STATUS in result:
1278
1098
  yield result[ChatEvent.STATUS]
1279
1099
  else:
1280
- generated_image, status_code, improved_image_prompt = result
1100
+ compiled_references.extend(result[0])
1101
+ inferred_queries.extend(result[1])
1102
+ defiltered_query = result[2]
1103
+ except Exception as e:
1104
+ error_message = f"Error searching knowledge base: {e}. Attempting to respond without document references."
1105
+ logger.error(error_message, exc_info=True)
1106
+ async for result in send_event(
1107
+ ChatEvent.STATUS, "Document search failed. I'll try respond without document references"
1108
+ ):
1109
+ yield result
1281
1110
 
1282
- inferred_queries.append(improved_image_prompt)
1283
- if generated_image is None or status_code != 200:
1284
- program_execution_context.append(f"Failed to generate image with {improved_image_prompt}")
1285
- async for result in send_event(ChatEvent.STATUS, f"Failed to generate image"):
1286
- yield result
1287
- else:
1288
- generated_images.append(generated_image)
1111
+ if not is_none_or_empty(compiled_references):
1112
+ distinct_headings = set([d.get("compiled").split("\n")[0] for d in compiled_references if "compiled" in d])
1113
+ distinct_files = set([d["file"] for d in compiled_references])
1114
+ # Strip only leading # from headings
1115
+ headings_str = "\n- " + "\n- ".join(distinct_headings).replace("#", "")
1116
+ async for result in send_event(
1117
+ ChatEvent.STATUS,
1118
+ f"**Found {len(distinct_headings)} Notes Across {len(distinct_files)} Files**: {headings_str}",
1119
+ ):
1120
+ yield result
1289
1121
 
1290
- generated_asset_results["images"] = {
1291
- "query": improved_image_prompt,
1292
- }
1122
+ if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
1123
+ async for result in send_llm_response(f"{no_entries_found.format()}", tracer.get("usage")):
1124
+ yield result
1125
+ return
1293
1126
 
1294
- async for result in send_event(
1295
- ChatEvent.GENERATED_ASSETS,
1296
- {
1297
- "images": [generated_image],
1298
- },
1299
- ):
1300
- yield result
1127
+ if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
1128
+ conversation_commands.remove(ConversationCommand.Notes)
1301
1129
 
1302
- if ConversationCommand.Diagram in conversation_commands:
1303
- async for result in send_event(ChatEvent.STATUS, f"Creating diagram"):
1130
+ ## Gather Online References
1131
+ if ConversationCommand.Online in conversation_commands:
1132
+ try:
1133
+ async for result in search_online(
1134
+ defiltered_query,
1135
+ chat_history,
1136
+ location,
1137
+ user,
1138
+ partial(send_event, ChatEvent.STATUS),
1139
+ custom_filters=[],
1140
+ max_online_searches=3,
1141
+ query_images=uploaded_images,
1142
+ query_files=attached_file_context,
1143
+ agent=agent,
1144
+ tracer=tracer,
1145
+ ):
1146
+ if isinstance(result, dict) and ChatEvent.STATUS in result:
1147
+ yield result[ChatEvent.STATUS]
1148
+ else:
1149
+ online_results = result
1150
+ except Exception as e:
1151
+ error_message = f"Error searching online: {e}. Attempting to respond without online results"
1152
+ logger.warning(error_message)
1153
+ async for result in send_event(
1154
+ ChatEvent.STATUS, "Online search failed. I'll try respond without online references"
1155
+ ):
1304
1156
  yield result
1305
1157
 
1306
- inferred_queries = []
1307
- diagram_description = ""
1158
+ ## Gather Webpage References
1159
+ if ConversationCommand.Webpage in conversation_commands:
1160
+ try:
1161
+ async for result in read_webpages(
1162
+ defiltered_query,
1163
+ chat_history,
1164
+ location,
1165
+ user,
1166
+ partial(send_event, ChatEvent.STATUS),
1167
+ max_webpages_to_read=1,
1168
+ query_images=uploaded_images,
1169
+ agent=agent,
1170
+ query_files=attached_file_context,
1171
+ tracer=tracer,
1172
+ ):
1173
+ if isinstance(result, dict) and ChatEvent.STATUS in result:
1174
+ yield result[ChatEvent.STATUS]
1175
+ else:
1176
+ direct_web_pages = result
1177
+ webpages = []
1178
+ for query in direct_web_pages:
1179
+ if online_results.get(query):
1180
+ online_results[query]["webpages"] = direct_web_pages[query]["webpages"]
1181
+ else:
1182
+ online_results[query] = {"webpages": direct_web_pages[query]["webpages"]}
1308
1183
 
1309
- async for result in generate_mermaidjs_diagram(
1310
- q=defiltered_query,
1311
- chat_history=chat_history,
1312
- location_data=location,
1313
- note_references=compiled_references,
1314
- online_results=online_results,
1184
+ for webpage in direct_web_pages[query]["webpages"]:
1185
+ webpages.append(webpage["link"])
1186
+ async for result in send_event(ChatEvent.STATUS, f"**Read web pages**: {webpages}"):
1187
+ yield result
1188
+ except Exception as e:
1189
+ logger.warning(
1190
+ f"Error reading webpages: {e}. Attempting to respond without webpage results",
1191
+ exc_info=True,
1192
+ )
1193
+ async for result in send_event(
1194
+ ChatEvent.STATUS, "Webpage read failed. I'll try respond without webpage references"
1195
+ ):
1196
+ yield result
1197
+
1198
+ ## Gather Code Results
1199
+ if ConversationCommand.Code in conversation_commands:
1200
+ try:
1201
+ context = f"# Iteration 1:\n#---\nNotes:\n{compiled_references}\n\nOnline Results:{online_results}"
1202
+ async for result in run_code(
1203
+ defiltered_query,
1204
+ chat_history,
1205
+ context,
1206
+ location,
1207
+ user,
1208
+ partial(send_event, ChatEvent.STATUS),
1315
1209
  query_images=uploaded_images,
1316
- user=user,
1317
1210
  agent=agent,
1318
- send_status_func=partial(send_event, ChatEvent.STATUS),
1319
1211
  query_files=attached_file_context,
1320
1212
  tracer=tracer,
1321
1213
  ):
1322
1214
  if isinstance(result, dict) and ChatEvent.STATUS in result:
1323
1215
  yield result[ChatEvent.STATUS]
1324
1216
  else:
1325
- better_diagram_description_prompt, mermaidjs_diagram_description = result
1326
- if better_diagram_description_prompt and mermaidjs_diagram_description:
1327
- inferred_queries.append(better_diagram_description_prompt)
1328
- diagram_description = mermaidjs_diagram_description
1329
-
1330
- generated_mermaidjs_diagram = diagram_description
1331
-
1332
- generated_asset_results["diagrams"] = {
1333
- "query": better_diagram_description_prompt,
1334
- }
1335
-
1336
- async for result in send_event(
1337
- ChatEvent.GENERATED_ASSETS,
1338
- {
1339
- "mermaidjsDiagram": mermaidjs_diagram_description,
1340
- },
1341
- ):
1342
- yield result
1343
- else:
1344
- error_message = "Failed to generate diagram. Please try again later."
1345
- program_execution_context.append(
1346
- prompts.failed_diagram_generation.format(
1347
- attempted_diagram=better_diagram_description_prompt
1348
- )
1349
- )
1217
+ code_results = result
1218
+ except ValueError as e:
1219
+ program_execution_context.append(f"Failed to run code")
1220
+ logger.warning(
1221
+ f"Failed to use code tool: {e}. Attempting to respond without code results",
1222
+ exc_info=True,
1223
+ )
1224
+ if ConversationCommand.Operator in conversation_commands:
1225
+ try:
1226
+ async for result in operate_environment(
1227
+ defiltered_query,
1228
+ user,
1229
+ chat_history,
1230
+ location,
1231
+ list(operator_results)[-1] if operator_results else None,
1232
+ query_images=uploaded_images,
1233
+ query_files=attached_file_context,
1234
+ send_status_func=partial(send_event, ChatEvent.STATUS),
1235
+ agent=agent,
1236
+ cancellation_event=cancellation_event,
1237
+ interrupt_queue=child_interrupt_queue,
1238
+ tracer=tracer,
1239
+ ):
1240
+ if isinstance(result, dict) and ChatEvent.STATUS in result:
1241
+ yield result[ChatEvent.STATUS]
1242
+ elif isinstance(result, OperatorRun):
1243
+ if not operator_results or operator_results[-1] is not result:
1244
+ operator_results.append(result)
1245
+ # Add webpages visited while operating browser to references
1246
+ if result.webpages:
1247
+ if not online_results.get(defiltered_query):
1248
+ online_results[defiltered_query] = {"webpages": result.webpages}
1249
+ elif not online_results[defiltered_query].get("webpages"):
1250
+ online_results[defiltered_query]["webpages"] = result.webpages
1251
+ else:
1252
+ online_results[defiltered_query]["webpages"] += result.webpages
1253
+ except ValueError as e:
1254
+ program_execution_context.append(f"Browser operation error: {e}")
1255
+ logger.warning(f"Failed to operate browser with {e}", exc_info=True)
1256
+ async for result in send_event(
1257
+ ChatEvent.STATUS, "Operating browser failed. I'll try respond appropriately"
1258
+ ):
1259
+ yield result
1350
1260
 
1351
- async for result in send_event(ChatEvent.STATUS, error_message):
1352
- yield result
1261
+ ## Send Gathered References
1262
+ unique_online_results = deduplicate_organic_results(online_results)
1263
+ async for result in send_event(
1264
+ ChatEvent.REFERENCES,
1265
+ {
1266
+ "inferredQueries": inferred_queries,
1267
+ "context": compiled_references,
1268
+ "onlineContext": unique_online_results,
1269
+ "codeContext": code_results,
1270
+ },
1271
+ ):
1272
+ yield result
1273
+
1274
+ # Generate Output
1275
+ ## Generate Image Output
1276
+ if ConversationCommand.Image in conversation_commands:
1277
+ async for result in text_to_image(
1278
+ defiltered_query,
1279
+ user,
1280
+ chat_history,
1281
+ location_data=location,
1282
+ references=compiled_references,
1283
+ online_results=online_results,
1284
+ send_status_func=partial(send_event, ChatEvent.STATUS),
1285
+ query_images=uploaded_images,
1286
+ agent=agent,
1287
+ query_files=attached_file_context,
1288
+ tracer=tracer,
1289
+ ):
1290
+ if isinstance(result, dict) and ChatEvent.STATUS in result:
1291
+ yield result[ChatEvent.STATUS]
1292
+ else:
1293
+ generated_image, status_code, improved_image_prompt = result
1294
+
1295
+ inferred_queries.append(improved_image_prompt)
1296
+ if generated_image is None or status_code != 200:
1297
+ program_execution_context.append(f"Failed to generate image with {improved_image_prompt}")
1298
+ async for result in send_event(ChatEvent.STATUS, f"Failed to generate image"):
1299
+ yield result
1300
+ else:
1301
+ generated_images.append(generated_image)
1353
1302
 
1354
- # Check if the user has disconnected
1303
+ generated_asset_results["images"] = {
1304
+ "query": improved_image_prompt,
1305
+ }
1306
+
1307
+ async for result in send_event(
1308
+ ChatEvent.GENERATED_ASSETS,
1309
+ {
1310
+ "images": [generated_image],
1311
+ },
1312
+ ):
1313
+ yield result
1314
+
1315
+    if ConversationCommand.Diagram in conversation_commands:
+        async for result in send_event(ChatEvent.STATUS, f"Creating diagram"):
+            yield result
+
+        inferred_queries = []
+        diagram_description = ""
+
+        async for result in generate_mermaidjs_diagram(
+            q=defiltered_query,
+            chat_history=chat_history,
+            location_data=location,
+            note_references=compiled_references,
+            online_results=online_results,
+            query_images=uploaded_images,
+            user=user,
+            agent=agent,
+            send_status_func=partial(send_event, ChatEvent.STATUS),
+            query_files=attached_file_context,
+            tracer=tracer,
+        ):
+            if isinstance(result, dict) and ChatEvent.STATUS in result:
+                yield result[ChatEvent.STATUS]
+            else:
+                better_diagram_description_prompt, mermaidjs_diagram_description = result
+                if better_diagram_description_prompt and mermaidjs_diagram_description:
+                    inferred_queries.append(better_diagram_description_prompt)
+                    diagram_description = mermaidjs_diagram_description
+
+                    generated_mermaidjs_diagram = diagram_description
+
+                    generated_asset_results["diagrams"] = {
+                        "query": better_diagram_description_prompt,
+                    }
+
+                    async for result in send_event(
+                        ChatEvent.GENERATED_ASSETS,
+                        {
+                            "mermaidjsDiagram": mermaidjs_diagram_description,
+                        },
+                    ):
+                        yield result
+                else:
+                    error_message = "Failed to generate diagram. Please try again later."
+                    program_execution_context.append(
+                        prompts.failed_diagram_generation.format(attempted_diagram=better_diagram_description_prompt)
+                    )
+
+                    async for result in send_event(ChatEvent.STATUS, error_message):
+                        yield result
+
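Similarly, the two GENERATED_ASSETS events above carry small per-asset payloads. Their keys appear in this hunk; the values below are illustrative only:

    # Sketches of the ChatEvent.GENERATED_ASSETS payloads emitted above.
    image_assets_event = {"images": ["<generated image, e.g. a base64 string or URL>"]}
    diagram_assets_event = {"mermaidjsDiagram": "flowchart TD; Query --> Diagram;"}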
+    # Check if the user has disconnected
+    if cancellation_event.is_set():
+        logger.debug(f"Stopping LLM response to user {user} on {common.client} client.")
+        # Cancel the disconnect monitor task if it is still running
+        await cancel_disconnect_monitor()
+        return
+
+    ## Generate Text Output
+    async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"):
+        yield result
+
+    llm_response, chat_metadata = await agenerate_chat_response(
+        defiltered_query,
+        chat_history,
+        conversation,
+        compiled_references,
+        online_results,
+        code_results,
+        operator_results,
+        research_results,
+        user,
+        location,
+        user_name,
+        uploaded_images,
+        attached_file_context,
+        generated_files,
+        program_execution_context,
+        generated_asset_results,
+        is_subscribed,
+        tracer,
+    )
+
+    full_response = ""
+    async for item in llm_response:
+        # Should not happen with async generator. Skip.
+        if item is None or not isinstance(item, ResponseWithThought):
+            logger.warning(f"Unexpected item type in LLM response: {type(item)}. Skipping.")
+            continue
         if cancellation_event.is_set():
-            logger.debug(f"Stopping LLM response to user {user} on {common.client} client.")
-            # Cancel the disconnect monitor task if it is still running
-            await cancel_disconnect_monitor()
-            return
+            break
+        message = item.text
+        full_response += message if message else ""
+        if item.thought:
+            async for result in send_event(ChatEvent.THOUGHT, item.thought):
+                yield result
+            continue
 
-        ## Generate Text Output
-        async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"):
+        # Start sending response
+        async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
             yield result
 
-        llm_response, chat_metadata = await agenerate_chat_response(
-            defiltered_query,
-            chat_history,
-            conversation,
-            compiled_references,
-            online_results,
-            code_results,
-            operator_results,
-            research_results,
-            user,
-            location,
-            user_name,
-            uploaded_images,
-            attached_file_context,
-            generated_files,
-            program_execution_context,
-            generated_asset_results,
-            is_subscribed,
-            tracer,
+        try:
+            async for result in send_event(ChatEvent.MESSAGE, message):
+                yield result
+        except Exception as e:
+            if not cancellation_event.is_set():
+                logger.warning(f"Error during streaming. Stopping send: {e}")
+            break
+
+    # Save conversation once finished streaming
+    asyncio.create_task(
+        save_to_conversation_log(
+            q,
+            chat_response=full_response,
+            user=user,
+            compiled_references=compiled_references,
+            online_results=online_results,
+            code_results=code_results,
+            operator_results=operator_results,
+            research_results=research_results,
+            inferred_queries=inferred_queries,
+            client_application=user_scope.client_app,
+            conversation_id=str(conversation.id),
+            query_images=uploaded_images,
+            train_of_thought=train_of_thought,
+            raw_query_files=raw_query_files,
+            generated_images=generated_images,
+            raw_generated_files=generated_files,
+            generated_mermaidjs_diagram=generated_mermaidjs_diagram,
+            tracer=tracer,
         )
+    )
+
+    # Signal end of LLM response after the loop finishes
+    if not cancellation_event.is_set():
+        async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
+            yield result
+        # Send Usage Metadata once LLM interactions are complete
+        if tracer.get("usage"):
+            async for event in send_event(ChatEvent.USAGE, tracer.get("usage")):
+                yield event
+        async for result in send_event(ChatEvent.END_RESPONSE, ""):
+            yield result
+    logger.debug("Finished streaming response")
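Taken together, the generator above settles into a predictable per-turn event order. A compact sketch, using the ChatEvent names from this hunk; framing and serialization of these events happen elsewhere in api_chat.py:

    # STATUS*                 progress updates, may repeat throughout the turn
    # REFERENCES              gathered document/online/code context
    # GENERATED_ASSETS?       images and/or mermaidjs diagrams, if requested
    # (THOUGHT | MESSAGE)*    START_LLM_RESPONSE precedes the first MESSAGE
    # END_LLM_RESPONSE        then USAGE, if token usage was tracked
    # END_RESPONSE            closes the turn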
 
-        full_response = ""
-        async for item in llm_response:
-            # Should not happen with async generator. Skip.
-            if item is None or not isinstance(item, ResponseWithThought):
-                logger.warning(f"Unexpected item type in LLM response: {type(item)}. Skipping.")
+    # Cancel the disconnect monitor task if it is still running
+    await cancel_disconnect_monitor()
+
+
+@api_chat.websocket("/ws")
+@requires(["authenticated"])
+async def chat_ws(
+    websocket: WebSocket,
+    common: CommonQueryParams,
+):
+    await websocket.accept()
+
+    # Initialize rate limiters
+    rate_limiter_per_minute = ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute")
+    rate_limiter_per_day = ApiUserRateLimiter(
+        requests=100, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day"
+    )
+    image_rate_limiter = ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)
+
+    # Shared interrupt queue for communicating interrupts to ongoing research
+    interrupt_queue: asyncio.Queue = asyncio.Queue()
+    current_task = None
+
+    try:
+        while True:
+            data = await websocket.receive_json()
+
+            # Check if this is an interrupt message
+            if data.get("type") == "interrupt":
+                if current_task and not current_task.done():
+                    # Send interrupt signal to the ongoing task
+                    abort_message = "␃🔚␗"
+                    await interrupt_queue.put(data.get("query") or abort_message)
+                    logger.info(
+                        f"Interrupt signal sent to ongoing task for user {websocket.scope['user'].object.id} with query: {data.get('query')}"
+                    )
+                    if data.get("query"):
+                        ack_type = "interrupt_message_acknowledged"
+                        await websocket.send_text(json.dumps({"type": ack_type}))
+                    else:
+                        ack_type = "interrupt_acknowledged"
+                        await websocket.send_text(json.dumps({"type": ack_type}))
+                else:
+                    logger.info(f"No ongoing task to interrupt for user {websocket.scope['user'].object.id}")
                 continue
-        if cancellation_event.is_set():
-            break
-        message = item.text
-        full_response += message if message else ""
-        if item.thought:
-            async for result in send_event(ChatEvent.THOUGHT, item.thought):
-                yield result
+
+            # Handle regular chat messages - ensure data has required fields
+            if "q" not in data:
+                await websocket.send_text(json.dumps({"error": "Missing required field 'q' in chat message"}))
                 continue
 
-        # Start sending response
-        async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
-            yield result
+            body = ChatRequestBody(**data)
 
+            # Apply rate limiting manually
             try:
-                async for result in send_event(ChatEvent.MESSAGE, message):
-                    yield result
-            except Exception as e:
-                if not cancellation_event.is_set():
-                    logger.warning(f"Error during streaming. Stopping send: {e}")
-                break
+                rate_limiter_per_minute.check_websocket(websocket)
+                rate_limiter_per_day.check_websocket(websocket)
+                image_rate_limiter.check_websocket(websocket, body)
+            except HTTPException as e:
+                await websocket.send_text(json.dumps({"error": e.detail}))
+                continue
 
-        # Save conversation once finish streaming
-        asyncio.create_task(
-            save_to_conversation_log(
-                q,
-                chat_response=full_response,
-                user=user,
-                chat_history=chat_history,
-                compiled_references=compiled_references,
-                online_results=online_results,
-                code_results=code_results,
-                operator_results=operator_results,
-                research_results=research_results,
-                inferred_queries=inferred_queries,
-                client_application=request.user.client_app,
-                conversation_id=str(conversation.id),
-                query_images=uploaded_images,
-                train_of_thought=train_of_thought,
-                raw_query_files=raw_query_files,
-                generated_images=generated_images,
-                raw_generated_files=generated_files,
-                generated_mermaidjs_diagram=generated_mermaidjs_diagram,
-                tracer=tracer,
-            )
+            # Cancel any ongoing task before starting a new one
+            if current_task and not current_task.done():
+                current_task.cancel()
+                try:
+                    await current_task
+                except asyncio.CancelledError:
+                    pass
+
+            # Create a new task for processing the chat request
+            current_task = asyncio.create_task(process_chat_request(websocket, body, common, interrupt_queue))
+
+    except WebSocketDisconnect:
+        logger.info(f"WebSocket disconnected for user {websocket.scope['user'].object.id}")
+        if current_task and not current_task.done():
+            current_task.cancel()
+    except Exception as e:
+        logger.error(f"Error in websocket chat: {e}", exc_info=True)
+        if current_task and not current_task.done():
+            current_task.cancel()
+        await websocket.close(code=1011, reason="Internal Server Error")
+
+
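A minimal client sketch for the websocket endpoint above, using the third-party websockets package. The URL, port, and authentication are assumptions; only the message shapes ({"q": ...}, {"type": "interrupt", ...}) and the interrupt_* ack types come from this diff:

    import asyncio
    import json

    import websockets  # third-party: pip install websockets


    async def demo():
        # Assumed URL: api_chat router mounted at /api/chat on khoj's default port.
        # Cookie/header auth is elided; the endpoint requires an authenticated user.
        async with websockets.connect("ws://localhost:42110/api/chat/ws") as ws:
            # Regular chat message; "q" is the one required field
            await ws.send(json.dumps({"q": "What did I write about Bern?"}))
            # Interrupt the in-flight turn; include "query" to replace it mid-stream
            await ws.send(json.dumps({"type": "interrupt", "query": "Keep it short"}))
            async for raw in ws:
                print(raw)  # streamed chat events, plus interrupt_* acknowledgements


    asyncio.run(demo())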
+async def process_chat_request(
+    websocket: WebSocket,
+    body: ChatRequestBody,
+    common: CommonQueryParams,
+    interrupt_queue: asyncio.Queue,
+):
+    """Process a single chat request with interrupt support"""
+    try:
+        # Since we are using websockets, we can ignore the stream parameter and always stream
+        response_iterator = event_generator(
+            body,
+            websocket.scope["user"],
+            common,
+            websocket.headers,
+            websocket,
+            interrupt_queue,
         )
+        async for event in response_iterator:
+            await websocket.send_text(event)
+    except asyncio.CancelledError:
+        logger.debug(f"Chat request cancelled for user {websocket.scope['user'].object.id}")
+        raise
+    except Exception as e:
+        logger.error(f"Error processing chat request: {e}", exc_info=True)
+        await websocket.send_text(json.dumps({"error": "Internal server error"}))
+        raise
 
-        # Signal end of LLM response after the loop finishes
-        if not cancellation_event.is_set():
-            async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
-                yield result
-            # Send Usage Metadata once llm interactions are complete
-            if tracer.get("usage"):
-                async for event in send_event(ChatEvent.USAGE, tracer.get("usage")):
-                    yield event
-            async for result in send_event(ChatEvent.END_RESPONSE, ""):
-                yield result
-        logger.debug("Finished streaming response")
 
-        # Cancel the disconnect monitor task if it is still running
-        await cancel_disconnect_monitor()
+@api_chat.post("")
+@requires(["authenticated"])
+async def chat(
+    request: Request,
+    common: CommonQueryParams,
+    body: ChatRequestBody,
+    rate_limiter_per_minute=Depends(
+        ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute")
+    ),
+    rate_limiter_per_day=Depends(
+        ApiUserRateLimiter(requests=100, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
+    ),
+    image_rate_limiter=Depends(ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)),
+):
+    response_iterator = event_generator(
+        body,
+        request.user,
+        common,
+        request.headers,
+        request,
+    )
 
-    ## Stream Text Response
-    if stream:
-        return StreamingResponse(event_generator(q, images=raw_images), media_type="text/plain")
-    ## Non-Streaming Text Response
+    # Stream Text Response
+    if body.stream:
+        return StreamingResponse(response_iterator, media_type="text/plain")
+    # Non-Streaming Text Response
     else:
-        response_iterator = event_generator(q, images=raw_images)
         response_data = await read_chat_stream(response_iterator)
         return Response(content=json.dumps(response_data), media_type="application/json", status_code=200)
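And a sketch of calling the reworked POST endpoint above in both modes. The base URL and bearer-token auth are assumptions; the body fields mirror how ChatRequestBody is used in this diff:

    import requests

    url = "http://localhost:42110/api/chat"  # assumed mount point and default port
    headers = {"Authorization": "Bearer <khoj-api-key>"}  # assumed auth scheme

    # Non-streaming: the server drains the event stream, then returns one JSON object
    resp = requests.post(url, json={"q": "Summarize my notes on Bern", "stream": False}, headers=headers)
    print(resp.json())

    # Streaming: plain-text chunks arrive as they are generated
    with requests.post(url, json={"q": "Summarize my notes on Bern", "stream": True}, headers=headers, stream=True) as resp:
        for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
            print(chunk, end="")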