khoj 1.16.1.dev25__py3-none-any.whl → 1.16.1.dev47__py3-none-any.whl

This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
khoj/routers/api_chat.py CHANGED
@@ -1,17 +1,17 @@
1
+ import asyncio
1
2
  import json
2
3
  import logging
3
- import math
4
+ import time
4
5
  from datetime import datetime
6
+ from functools import partial
5
7
  from typing import Any, Dict, List, Optional
6
8
  from urllib.parse import unquote
7
9
 
8
10
  from asgiref.sync import sync_to_async
9
- from fastapi import APIRouter, Depends, HTTPException, Request, WebSocket
11
+ from fastapi import APIRouter, Depends, HTTPException, Request
10
12
  from fastapi.requests import Request
11
13
  from fastapi.responses import Response, StreamingResponse
12
14
  from starlette.authentication import requires
13
- from starlette.websockets import WebSocketDisconnect
14
- from websockets import ConnectionClosedOK
15
15
 
16
16
  from khoj.app.settings import ALLOWED_HOSTS
17
17
  from khoj.database.adapters import (
@@ -23,19 +23,15 @@ from khoj.database.adapters import (
23
23
  aget_user_name,
24
24
  )
25
25
  from khoj.database.models import KhojUser
26
- from khoj.processor.conversation.prompts import (
27
- help_message,
28
- no_entries_found,
29
- no_notes_found,
30
- )
26
+ from khoj.processor.conversation.prompts import help_message, no_entries_found
31
27
  from khoj.processor.conversation.utils import save_to_conversation_log
32
28
  from khoj.processor.speech.text_to_speech import generate_text_to_speech
33
29
  from khoj.processor.tools.online_search import read_webpages, search_online
34
30
  from khoj.routers.api import extract_references_and_questions
35
31
  from khoj.routers.helpers import (
36
32
  ApiUserRateLimiter,
33
+ ChatEvent,
37
34
  CommonQueryParams,
38
- CommonQueryParamsClass,
39
35
  ConversationCommandRateLimiter,
40
36
  agenerate_chat_response,
41
37
  aget_relevant_information_sources,
@@ -526,141 +522,142 @@ async def set_conversation_title(
526
522
  )
527
523
 
528
524
 
529
- @api_chat.websocket("/ws")
530
- async def websocket_endpoint(
531
- websocket: WebSocket,
532
- conversation_id: int,
525
+ @api_chat.get("")
526
+ async def chat(
527
+ request: Request,
528
+ common: CommonQueryParams,
529
+ q: str,
530
+ n: int = 7,
531
+ d: float = 0.18,
532
+ stream: Optional[bool] = False,
533
+ title: Optional[str] = None,
534
+ conversation_id: Optional[int] = None,
533
535
  city: Optional[str] = None,
534
536
  region: Optional[str] = None,
535
537
  country: Optional[str] = None,
536
538
  timezone: Optional[str] = None,
539
+ rate_limiter_per_minute=Depends(
540
+ ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
541
+ ),
542
+ rate_limiter_per_day=Depends(
543
+ ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
544
+ ),
537
545
  ):
538
- connection_alive = True
539
-
540
- async def send_status_update(message: str):
541
- nonlocal connection_alive
542
- if not connection_alive:
543
- return
546
+ async def event_generator(q: str):
547
+ start_time = time.perf_counter()
548
+ ttft = None
549
+ chat_metadata: dict = {}
550
+ connection_alive = True
551
+ user: KhojUser = request.user.object
552
+ event_delimiter = "␃🔚␗"
553
+ q = unquote(q)
554
+
555
+ async def send_event(event_type: ChatEvent, data: str | dict):
556
+ nonlocal connection_alive, ttft
557
+ if not connection_alive or await request.is_disconnected():
558
+ connection_alive = False
559
+ logger.warn(f"User {user} disconnected from {common.client} client")
560
+ return
561
+ try:
562
+ if event_type == ChatEvent.END_LLM_RESPONSE:
563
+ collect_telemetry()
564
+ if event_type == ChatEvent.START_LLM_RESPONSE:
565
+ ttft = time.perf_counter() - start_time
566
+ if event_type == ChatEvent.MESSAGE:
567
+ yield data
568
+ elif event_type == ChatEvent.REFERENCES or stream:
569
+ yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False)
570
+ except asyncio.CancelledError as e:
571
+ connection_alive = False
572
+ logger.warn(f"User {user} disconnected from {common.client} client: {e}")
573
+ return
574
+ except Exception as e:
575
+ connection_alive = False
576
+ logger.error(f"Failed to stream chat API response to {user} on {common.client}: {e}", exc_info=True)
577
+ return
578
+ finally:
579
+ if stream:
580
+ yield event_delimiter
581
+
582
+ async def send_llm_response(response: str):
583
+ async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
584
+ yield result
585
+ async for result in send_event(ChatEvent.MESSAGE, response):
586
+ yield result
587
+ async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
588
+ yield result
589
+
590
+ def collect_telemetry():
591
+ # Gather chat response telemetry
592
+ nonlocal chat_metadata
593
+ latency = time.perf_counter() - start_time
594
+ cmd_set = set([cmd.value for cmd in conversation_commands])
595
+ chat_metadata = chat_metadata or {}
596
+ chat_metadata["conversation_command"] = cmd_set
597
+ chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
598
+ chat_metadata["latency"] = f"{latency:.3f}"
599
+ chat_metadata["ttft_latency"] = f"{ttft:.3f}"
600
+
601
+ logger.info(f"Chat response time to first token: {ttft:.3f} seconds")
602
+ logger.info(f"Chat response total time: {latency:.3f} seconds")
603
+ update_telemetry_state(
604
+ request=request,
605
+ telemetry_type="api",
606
+ api="chat",
607
+ client=request.user.client_app,
608
+ user_agent=request.headers.get("user-agent"),
609
+ host=request.headers.get("host"),
610
+ metadata=chat_metadata,
611
+ )
544
612
 
545
- status_packet = {
546
- "type": "status",
547
- "message": message,
548
- "content-type": "application/json",
549
- }
550
- try:
551
- await websocket.send_text(json.dumps(status_packet))
552
- except ConnectionClosedOK:
553
- connection_alive = False
554
- logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
555
-
556
- async def send_complete_llm_response(llm_response: str):
557
- nonlocal connection_alive
558
- if not connection_alive:
559
- return
560
- try:
561
- await websocket.send_text("start_llm_response")
562
- await websocket.send_text(llm_response)
563
- await websocket.send_text("end_llm_response")
564
- except ConnectionClosedOK:
565
- connection_alive = False
566
- logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
567
-
568
- async def send_message(message: str):
569
- nonlocal connection_alive
570
- if not connection_alive:
571
- return
572
- try:
573
- await websocket.send_text(message)
574
- except ConnectionClosedOK:
575
- connection_alive = False
576
- logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
577
-
578
- async def send_rate_limit_message(message: str):
579
- nonlocal connection_alive
580
- if not connection_alive:
613
+ conversation = await ConversationAdapters.aget_conversation_by_user(
614
+ user, client_application=request.user.client_app, conversation_id=conversation_id, title=title
615
+ )
616
+ if not conversation:
617
+ async for result in send_llm_response(f"Conversation {conversation_id} not found"):
618
+ yield result
581
619
  return
582
620
 
583
- status_packet = {
584
- "type": "rate_limit",
585
- "message": message,
586
- "content-type": "application/json",
587
- }
588
- try:
589
- await websocket.send_text(json.dumps(status_packet))
590
- except ConnectionClosedOK:
591
- connection_alive = False
592
- logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
593
-
594
- user: KhojUser = websocket.user.object
595
- conversation = await ConversationAdapters.aget_conversation_by_user(
596
- user, client_application=websocket.user.client_app, conversation_id=conversation_id
597
- )
598
-
599
- hourly_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
600
-
601
- daily_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
602
-
603
- await is_ready_to_chat(user)
604
-
605
- user_name = await aget_user_name(user)
621
+ await is_ready_to_chat(user)
606
622
 
607
- location = None
608
-
609
- if city or region or country:
610
- location = LocationData(city=city, region=region, country=country)
611
-
612
- await websocket.accept()
613
- while connection_alive:
614
- try:
615
- if conversation:
616
- await sync_to_async(conversation.refresh_from_db)(fields=["conversation_log"])
617
- q = await websocket.receive_text()
618
-
619
- # Refresh these because the connection to the database might have been closed
620
- await conversation.arefresh_from_db()
621
-
622
- except WebSocketDisconnect:
623
- logger.debug(f"User {user} disconnected web socket")
624
- break
625
-
626
- try:
627
- await sync_to_async(hourly_limiter)(websocket)
628
- await sync_to_async(daily_limiter)(websocket)
629
- except HTTPException as e:
630
- await send_rate_limit_message(e.detail)
631
- break
623
+ user_name = await aget_user_name(user)
624
+ location = None
625
+ if city or region or country:
626
+ location = LocationData(city=city, region=region, country=country)
632
627
 
633
628
  if is_query_empty(q):
634
- await send_message("start_llm_response")
635
- await send_message(
636
- "It seems like your query is incomplete. Could you please provide more details or specify what you need help with?"
637
- )
638
- await send_message("end_llm_response")
639
- continue
629
+ async for result in send_llm_response("Please ask your query to get started."):
630
+ yield result
631
+ return
640
632
 
641
633
  user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
642
634
  conversation_commands = [get_conversation_command(query=q, any_references=True)]
643
635
 
644
- await send_status_update(f"**👀 Understanding Query**: {q}")
636
+ async for result in send_event(ChatEvent.STATUS, f"**👀 Understanding Query**: {q}"):
637
+ yield result
645
638
 
646
639
  meta_log = conversation.conversation_log
647
640
  is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
648
- used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
649
641
 
650
642
  if conversation_commands == [ConversationCommand.Default] or is_automated_task:
651
643
  conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
652
644
  conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
653
- await send_status_update(f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}")
645
+ async for result in send_event(
646
+ ChatEvent.STATUS, f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}"
647
+ ):
648
+ yield result
654
649
 
655
650
  mode = await aget_relevant_output_modes(q, meta_log, is_automated_task)
656
- await send_status_update(f"**🧑🏾‍💻 Decided Response Mode:** {mode.value}")
651
+ async for result in send_event(ChatEvent.STATUS, f"**🧑🏾‍💻 Decided Response Mode:** {mode.value}"):
652
+ yield result
657
653
  if mode not in conversation_commands:
658
654
  conversation_commands.append(mode)
659
655
 
660
656
  for cmd in conversation_commands:
661
- await conversation_command_rate_limiter.update_and_check_if_valid(websocket, cmd)
657
+ await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
662
658
  q = q.replace(f"/{cmd.value}", "").strip()
663
659
 
660
+ used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
664
661
  file_filters = conversation.file_filters if conversation else []
665
662
  # Skip trying to summarize if
666
663
  if (
@@ -676,28 +673,37 @@ async def websocket_endpoint(
676
673
  response_log = ""
677
674
  if len(file_filters) == 0:
678
675
  response_log = "No files selected for summarization. Please add files using the section on the left."
679
- await send_complete_llm_response(response_log)
676
+ async for result in send_llm_response(response_log):
677
+ yield result
680
678
  elif len(file_filters) > 1:
681
679
  response_log = "Only one file can be selected for summarization."
682
- await send_complete_llm_response(response_log)
680
+ async for result in send_llm_response(response_log):
681
+ yield result
683
682
  else:
684
683
  try:
685
684
  file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
686
685
  if len(file_object) == 0:
687
686
  response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
688
- await send_complete_llm_response(response_log)
689
- continue
687
+ async for result in send_llm_response(response_log):
688
+ yield result
689
+ return
690
690
  contextual_data = " ".join([file.raw_text for file in file_object])
691
691
  if not q:
692
692
  q = "Create a general summary of the file"
693
- await send_status_update(f"**🧑🏾‍💻 Constructing Summary Using:** {file_object[0].file_name}")
693
+ async for result in send_event(
694
+ ChatEvent.STATUS, f"**🧑🏾‍💻 Constructing Summary Using:** {file_object[0].file_name}"
695
+ ):
696
+ yield result
697
+
694
698
  response = await extract_relevant_summary(q, contextual_data)
695
699
  response_log = str(response)
696
- await send_complete_llm_response(response_log)
700
+ async for result in send_llm_response(response_log):
701
+ yield result
697
702
  except Exception as e:
698
703
  response_log = "Error summarizing file."
699
704
  logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
700
- await send_complete_llm_response(response_log)
705
+ async for result in send_llm_response(response_log):
706
+ yield result
701
707
  await sync_to_async(save_to_conversation_log)(
702
708
  q,
703
709
  response_log,
@@ -705,16 +711,10 @@ async def websocket_endpoint(
705
711
  meta_log,
706
712
  user_message_time,
707
713
  intent_type="summarize",
708
- client_application=websocket.user.client_app,
714
+ client_application=request.user.client_app,
709
715
  conversation_id=conversation_id,
710
716
  )
711
- update_telemetry_state(
712
- request=websocket,
713
- telemetry_type="api",
714
- api="chat",
715
- metadata={"conversation_command": conversation_commands[0].value},
716
- )
717
- continue
717
+ return
718
718
 
719
719
  custom_filters = []
720
720
  if conversation_commands == [ConversationCommand.Help]:
@@ -724,8 +724,9 @@ async def websocket_endpoint(
724
724
  conversation_config = await ConversationAdapters.aget_default_conversation_config()
725
725
  model_type = conversation_config.model_type
726
726
  formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
727
- await send_complete_llm_response(formatted_help)
728
- continue
727
+ async for result in send_llm_response(formatted_help):
728
+ yield result
729
+ return
729
730
  # Adding specification to search online specifically on khoj.dev pages.
730
731
  custom_filters.append("site:khoj.dev")
731
732
  conversation_commands.append(ConversationCommand.Online)
@@ -733,14 +734,14 @@ async def websocket_endpoint(
733
734
  if ConversationCommand.Automation in conversation_commands:
734
735
  try:
735
736
  automation, crontime, query_to_run, subject = await create_automation(
736
- q, timezone, user, websocket.url, meta_log
737
+ q, timezone, user, request.url, meta_log
737
738
  )
738
739
  except Exception as e:
739
740
  logger.error(f"Error scheduling task {q} for {user.email}: {e}")
740
- await send_complete_llm_response(
741
- f"Unable to create automation. Ensure the automation doesn't already exist."
742
- )
743
- continue
741
+ error_message = f"Unable to create automation. Ensure the automation doesn't already exist."
742
+ async for result in send_llm_response(error_message):
743
+ yield result
744
+ return
744
745
 
745
746
  llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
746
747
  await sync_to_async(save_to_conversation_log)(
@@ -750,57 +751,78 @@ async def websocket_endpoint(
750
751
  meta_log,
751
752
  user_message_time,
752
753
  intent_type="automation",
753
- client_application=websocket.user.client_app,
754
+ client_application=request.user.client_app,
754
755
  conversation_id=conversation_id,
755
756
  inferred_queries=[query_to_run],
756
757
  automation_id=automation.id,
757
758
  )
758
- common = CommonQueryParamsClass(
759
- client=websocket.user.client_app,
760
- user_agent=websocket.headers.get("user-agent"),
761
- host=websocket.headers.get("host"),
762
- )
763
- update_telemetry_state(
764
- request=websocket,
765
- telemetry_type="api",
766
- api="chat",
767
- **common.__dict__,
768
- )
769
- await send_complete_llm_response(llm_response)
770
- continue
759
+ async for result in send_llm_response(llm_response):
760
+ yield result
761
+ return
771
762
 
772
- compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
773
- websocket, meta_log, q, 7, 0.18, conversation_id, conversation_commands, location, send_status_update
774
- )
763
+ # Gather Context
764
+ ## Extract Document References
765
+ compiled_references, inferred_queries, defiltered_query = [], [], None
766
+ async for result in extract_references_and_questions(
767
+ request,
768
+ meta_log,
769
+ q,
770
+ (n or 7),
771
+ (d or 0.18),
772
+ conversation_id,
773
+ conversation_commands,
774
+ location,
775
+ partial(send_event, ChatEvent.STATUS),
776
+ ):
777
+ if isinstance(result, dict) and ChatEvent.STATUS in result:
778
+ yield result[ChatEvent.STATUS]
779
+ else:
780
+ compiled_references.extend(result[0])
781
+ inferred_queries.extend(result[1])
782
+ defiltered_query = result[2]
775
783
 
776
- if compiled_references:
784
+ if not is_none_or_empty(compiled_references):
777
785
  headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
778
- await send_status_update(f"**📜 Found Relevant Notes**: {headings}")
786
+ async for result in send_event(ChatEvent.STATUS, f"**📜 Found Relevant Notes**: {headings}"):
787
+ yield result
779
788
 
780
789
  online_results: Dict = dict()
781
790
 
782
791
  if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
783
- await send_complete_llm_response(f"{no_entries_found.format()}")
784
- continue
792
+ async for result in send_llm_response(f"{no_entries_found.format()}"):
793
+ yield result
794
+ return
785
795
 
786
796
  if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
787
797
  conversation_commands.remove(ConversationCommand.Notes)
788
798
 
799
+ ## Gather Online References
789
800
  if ConversationCommand.Online in conversation_commands:
790
801
  try:
791
- online_results = await search_online(
792
- defiltered_query, meta_log, location, send_status_update, custom_filters
793
- )
802
+ async for result in search_online(
803
+ defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS), custom_filters
804
+ ):
805
+ if isinstance(result, dict) and ChatEvent.STATUS in result:
806
+ yield result[ChatEvent.STATUS]
807
+ else:
808
+ online_results = result
794
809
  except ValueError as e:
795
- logger.warning(f"Error searching online: {e}. Attempting to respond without online results")
796
- await send_complete_llm_response(
797
- f"Error searching online: {e}. Attempting to respond without online results"
798
- )
799
- continue
810
+ error_message = f"Error searching online: {e}. Attempting to respond without online results"
811
+ logger.warning(error_message)
812
+ async for result in send_llm_response(error_message):
813
+ yield result
814
+ return
800
815
 
816
+ ## Gather Webpage References
801
817
  if ConversationCommand.Webpage in conversation_commands:
802
818
  try:
803
- direct_web_pages = await read_webpages(defiltered_query, meta_log, location, send_status_update)
819
+ async for result in read_webpages(
820
+ defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS)
821
+ ):
822
+ if isinstance(result, dict) and ChatEvent.STATUS in result:
823
+ yield result[ChatEvent.STATUS]
824
+ else:
825
+ direct_web_pages = result
804
826
  webpages = []
805
827
  for query in direct_web_pages:
806
828
  if online_results.get(query):
@@ -810,38 +832,52 @@ async def websocket_endpoint(
810
832
 
811
833
  for webpage in direct_web_pages[query]["webpages"]:
812
834
  webpages.append(webpage["link"])
813
-
814
- await send_status_update(f"**📚 Read web pages**: {webpages}")
835
+ async for result in send_event(ChatEvent.STATUS, f"**📚 Read web pages**: {webpages}"):
836
+ yield result
815
837
  except ValueError as e:
816
838
  logger.warning(
817
- f"Error directly reading webpages: {e}. Attempting to respond without online results", exc_info=True
839
+ f"Error directly reading webpages: {e}. Attempting to respond without online results",
840
+ exc_info=True,
818
841
  )
819
842
 
843
+ ## Send Gathered References
844
+ async for result in send_event(
845
+ ChatEvent.REFERENCES,
846
+ {
847
+ "inferredQueries": inferred_queries,
848
+ "context": compiled_references,
849
+ "onlineContext": online_results,
850
+ },
851
+ ):
852
+ yield result
853
+
854
+ # Generate Output
855
+ ## Generate Image Output
820
856
  if ConversationCommand.Image in conversation_commands:
821
- update_telemetry_state(
822
- request=websocket,
823
- telemetry_type="api",
824
- api="chat",
825
- metadata={"conversation_command": conversation_commands[0].value},
826
- )
827
- image, status_code, improved_image_prompt, intent_type = await text_to_image(
857
+ async for result in text_to_image(
828
858
  q,
829
859
  user,
830
860
  meta_log,
831
861
  location_data=location,
832
862
  references=compiled_references,
833
863
  online_results=online_results,
834
- send_status_func=send_status_update,
835
- )
864
+ send_status_func=partial(send_event, ChatEvent.STATUS),
865
+ ):
866
+ if isinstance(result, dict) and ChatEvent.STATUS in result:
867
+ yield result[ChatEvent.STATUS]
868
+ else:
869
+ image, status_code, improved_image_prompt, intent_type = result
870
+
836
871
  if image is None or status_code != 200:
837
872
  content_obj = {
838
- "image": image,
873
+ "content-type": "application/json",
839
874
  "intentType": intent_type,
840
875
  "detail": improved_image_prompt,
841
- "content-type": "application/json",
876
+ "image": image,
842
877
  }
843
- await send_complete_llm_response(json.dumps(content_obj))
844
- continue
878
+ async for result in send_llm_response(json.dumps(content_obj)):
879
+ yield result
880
+ return
845
881
 
846
882
  await sync_to_async(save_to_conversation_log)(
847
883
  q,
@@ -851,17 +887,23 @@ async def websocket_endpoint(
851
887
  user_message_time,
852
888
  intent_type=intent_type,
853
889
  inferred_queries=[improved_image_prompt],
854
- client_application=websocket.user.client_app,
890
+ client_application=request.user.client_app,
855
891
  conversation_id=conversation_id,
856
892
  compiled_references=compiled_references,
857
893
  online_results=online_results,
858
894
  )
859
- content_obj = {"image": image, "intentType": intent_type, "inferredQueries": [improved_image_prompt], "context": compiled_references, "content-type": "application/json", "online_results": online_results} # type: ignore
860
-
861
- await send_complete_llm_response(json.dumps(content_obj))
862
- continue
895
+ content_obj = {
896
+ "intentType": intent_type,
897
+ "inferredQueries": [improved_image_prompt],
898
+ "image": image,
899
+ }
900
+ async for result in send_llm_response(json.dumps(content_obj)):
901
+ yield result
902
+ return
863
903
 
864
- await send_status_update(f"**💭 Generating a well-informed response**")
904
+ ## Generate Text Output
905
+ async for result in send_event(ChatEvent.STATUS, f"**💭 Generating a well-informed response**"):
906
+ yield result
865
907
  llm_response, chat_metadata = await agenerate_chat_response(
866
908
  defiltered_query,
867
909
  meta_log,
@@ -871,310 +913,49 @@ async def websocket_endpoint(
871
913
  inferred_queries,
872
914
  conversation_commands,
873
915
  user,
874
- websocket.user.client_app,
916
+ request.user.client_app,
875
917
  conversation_id,
876
918
  location,
877
919
  user_name,
878
920
  )
879
921
 
880
- chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
922
+ # Send Response
923
+ async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
924
+ yield result
881
925
 
882
- update_telemetry_state(
883
- request=websocket,
884
- telemetry_type="api",
885
- api="chat",
886
- metadata=chat_metadata,
887
- )
926
+ continue_stream = True
888
927
  iterator = AsyncIteratorWrapper(llm_response)
889
-
890
- await send_message("start_llm_response")
891
-
892
928
  async for item in iterator:
893
929
  if item is None:
894
- break
895
- if connection_alive:
896
- try:
897
- await send_message(f"{item}")
898
- except ConnectionClosedOK:
899
- connection_alive = False
900
- logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
901
-
902
- await send_message("end_llm_response")
903
-
904
-
905
- @api_chat.get("", response_class=Response)
906
- @requires(["authenticated"])
907
- async def chat(
908
- request: Request,
909
- common: CommonQueryParams,
910
- q: str,
911
- n: Optional[int] = 5,
912
- d: Optional[float] = 0.22,
913
- stream: Optional[bool] = False,
914
- title: Optional[str] = None,
915
- conversation_id: Optional[int] = None,
916
- city: Optional[str] = None,
917
- region: Optional[str] = None,
918
- country: Optional[str] = None,
919
- timezone: Optional[str] = None,
920
- rate_limiter_per_minute=Depends(
921
- ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
922
- ),
923
- rate_limiter_per_day=Depends(
924
- ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
925
- ),
926
- ) -> Response:
927
- user: KhojUser = request.user.object
928
- q = unquote(q)
929
- if is_query_empty(q):
930
- return Response(
931
- content="It seems like your query is incomplete. Could you please provide more details or specify what you need help with?",
932
- media_type="text/plain",
933
- status_code=400,
934
- )
935
- user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
936
- logger.info(f"Chat request by {user.username}: {q}")
937
-
938
- await is_ready_to_chat(user)
939
- conversation_commands = [get_conversation_command(query=q, any_references=True)]
940
-
941
- _custom_filters = []
942
- if conversation_commands == [ConversationCommand.Help]:
943
- help_str = "/" + ConversationCommand.Help
944
- if q.strip() == help_str:
945
- conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
946
- if conversation_config == None:
947
- conversation_config = await ConversationAdapters.aget_default_conversation_config()
948
- model_type = conversation_config.model_type
949
- formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
950
- return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200)
951
- # Adding specification to search online specifically on khoj.dev pages.
952
- _custom_filters.append("site:khoj.dev")
953
- conversation_commands.append(ConversationCommand.Online)
954
-
955
- conversation = await ConversationAdapters.aget_conversation_by_user(
956
- user, request.user.client_app, conversation_id, title
957
- )
958
- conversation_id = conversation.id if conversation else None
959
-
960
- if not conversation:
961
- return Response(
962
- content=f"No conversation found with requested id, title", media_type="text/plain", status_code=400
963
- )
964
- else:
965
- meta_log = conversation.conversation_log
966
-
967
- if ConversationCommand.Summarize in conversation_commands:
968
- file_filters = conversation.file_filters
969
- llm_response = ""
970
- if len(file_filters) == 0:
971
- llm_response = "No files selected for summarization. Please add files using the section on the left."
972
- elif len(file_filters) > 1:
973
- llm_response = "Only one file can be selected for summarization."
974
- else:
930
+ async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
931
+ yield result
932
+ logger.debug("Finished streaming response")
933
+ return
934
+ if not connection_alive or not continue_stream:
935
+ continue
975
936
  try:
976
- file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
977
- if len(file_object) == 0:
978
- llm_response = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
979
- return StreamingResponse(content=llm_response, media_type="text/event-stream", status_code=200)
980
- contextual_data = " ".join([file.raw_text for file in file_object])
981
- summarizeStr = "/" + ConversationCommand.Summarize
982
- if q.strip() == summarizeStr:
983
- q = "Create a general summary of the file"
984
- response = await extract_relevant_summary(q, contextual_data)
985
- llm_response = str(response)
937
+ async for result in send_event(ChatEvent.MESSAGE, f"{item}"):
938
+ yield result
986
939
  except Exception as e:
987
- logger.error(f"Error summarizing file for {user.email}: {e}")
988
- llm_response = "Error summarizing file."
989
- await sync_to_async(save_to_conversation_log)(
990
- q,
991
- llm_response,
992
- user,
993
- conversation.conversation_log,
994
- user_message_time,
995
- intent_type="summarize",
996
- client_application=request.user.client_app,
997
- conversation_id=conversation_id,
998
- )
999
- update_telemetry_state(
1000
- request=request,
1001
- telemetry_type="api",
1002
- api="chat",
1003
- metadata={"conversation_command": conversation_commands[0].value},
1004
- **common.__dict__,
1005
- )
1006
- return StreamingResponse(content=llm_response, media_type="text/event-stream", status_code=200)
1007
-
1008
- is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
1009
-
1010
- if conversation_commands == [ConversationCommand.Default] or is_automated_task:
1011
- conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
1012
- mode = await aget_relevant_output_modes(q, meta_log, is_automated_task)
1013
- if mode not in conversation_commands:
1014
- conversation_commands.append(mode)
1015
-
1016
- for cmd in conversation_commands:
1017
- await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
1018
- q = q.replace(f"/{cmd.value}", "").strip()
1019
-
1020
- location = None
1021
-
1022
- if city or region or country:
1023
- location = LocationData(city=city, region=region, country=country)
1024
-
1025
- user_name = await aget_user_name(user)
1026
-
1027
- if ConversationCommand.Automation in conversation_commands:
1028
- try:
1029
- automation, crontime, query_to_run, subject = await create_automation(
1030
- q, timezone, user, request.url, meta_log
1031
- )
1032
- except Exception as e:
1033
- logger.error(f"Error creating automation {q} for {user.email}: {e}", exc_info=True)
1034
- return Response(
1035
- content=f"Unable to create automation. Ensure the automation doesn't already exist.",
1036
- media_type="text/plain",
1037
- status_code=500,
1038
- )
1039
-
1040
- llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
1041
- await sync_to_async(save_to_conversation_log)(
1042
- q,
1043
- llm_response,
1044
- user,
1045
- meta_log,
1046
- user_message_time,
1047
- intent_type="automation",
1048
- client_application=request.user.client_app,
1049
- conversation_id=conversation_id,
1050
- inferred_queries=[query_to_run],
1051
- automation_id=automation.id,
1052
- )
1053
-
1054
- if stream:
1055
- return StreamingResponse(llm_response, media_type="text/event-stream", status_code=200)
1056
- else:
1057
- return Response(content=llm_response, media_type="text/plain", status_code=200)
1058
-
1059
- compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
1060
- request, meta_log, q, (n or 5), (d or math.inf), conversation_id, conversation_commands, location
1061
- )
1062
- online_results: Dict[str, Dict] = {}
1063
-
1064
- if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
1065
- no_entries_found_format = no_entries_found.format()
1066
- if stream:
1067
- return StreamingResponse(iter([no_entries_found_format]), media_type="text/event-stream", status_code=200)
1068
- else:
1069
- response_obj = {"response": no_entries_found_format}
1070
- return Response(content=json.dumps(response_obj), media_type="text/plain", status_code=200)
1071
-
1072
- if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
1073
- no_notes_found_format = no_notes_found.format()
1074
- if stream:
1075
- return StreamingResponse(iter([no_notes_found_format]), media_type="text/event-stream", status_code=200)
1076
- else:
1077
- response_obj = {"response": no_notes_found_format}
1078
- return Response(content=json.dumps(response_obj), media_type="text/plain", status_code=200)
1079
-
1080
- if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
1081
- conversation_commands.remove(ConversationCommand.Notes)
1082
-
1083
- if ConversationCommand.Online in conversation_commands:
1084
- try:
1085
- online_results = await search_online(defiltered_query, meta_log, location, custom_filters=_custom_filters)
1086
- except ValueError as e:
1087
- logger.warning(f"Error searching online: {e}. Attempting to respond without online results")
1088
-
1089
- if ConversationCommand.Webpage in conversation_commands:
1090
- try:
1091
- online_results = await read_webpages(defiltered_query, meta_log, location)
1092
- except ValueError as e:
1093
- logger.warning(
1094
- f"Error directly reading webpages: {e}. Attempting to respond without online results", exc_info=True
1095
- )
1096
-
1097
- if ConversationCommand.Image in conversation_commands:
1098
- update_telemetry_state(
1099
- request=request,
1100
- telemetry_type="api",
1101
- api="chat",
1102
- metadata={"conversation_command": conversation_commands[0].value},
1103
- **common.__dict__,
1104
- )
1105
- image, status_code, improved_image_prompt, intent_type = await text_to_image(
1106
- q, user, meta_log, location_data=location, references=compiled_references, online_results=online_results
1107
- )
1108
- if image is None:
1109
- content_obj = {"image": image, "intentType": intent_type, "detail": improved_image_prompt}
1110
- return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
1111
-
1112
- await sync_to_async(save_to_conversation_log)(
1113
- q,
1114
- image,
1115
- user,
1116
- meta_log,
1117
- user_message_time,
1118
- intent_type=intent_type,
1119
- inferred_queries=[improved_image_prompt],
1120
- client_application=request.user.client_app,
1121
- conversation_id=conversation.id,
1122
- compiled_references=compiled_references,
1123
- online_results=online_results,
1124
- )
1125
- content_obj = {"image": image, "intentType": intent_type, "inferredQueries": [improved_image_prompt], "context": compiled_references, "online_results": online_results} # type: ignore
1126
- return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
1127
-
1128
- # Get the (streamed) chat response from the LLM of choice.
1129
- llm_response, chat_metadata = await agenerate_chat_response(
1130
- defiltered_query,
1131
- meta_log,
1132
- conversation,
1133
- compiled_references,
1134
- online_results,
1135
- inferred_queries,
1136
- conversation_commands,
1137
- user,
1138
- request.user.client_app,
1139
- conversation.id,
1140
- location,
1141
- user_name,
1142
- )
1143
-
1144
- cmd_set = set([cmd.value for cmd in conversation_commands])
1145
- chat_metadata["conversation_command"] = cmd_set
1146
- chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
1147
-
1148
- update_telemetry_state(
1149
- request=request,
1150
- telemetry_type="api",
1151
- api="chat",
1152
- metadata=chat_metadata,
1153
- **common.__dict__,
1154
- )
1155
-
1156
- if llm_response is None:
1157
- return Response(content=llm_response, media_type="text/plain", status_code=500)
940
+ continue_stream = False
941
+ logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}")
1158
942
 
943
+ ## Stream Text Response
1159
944
  if stream:
1160
- return StreamingResponse(llm_response, media_type="text/event-stream", status_code=200)
1161
-
1162
- iterator = AsyncIteratorWrapper(llm_response)
1163
-
1164
- # Get the full response from the generator if the stream is not requested.
1165
- aggregated_gpt_response = ""
1166
- async for item in iterator:
1167
- if item is None:
1168
- break
1169
- aggregated_gpt_response += item
1170
-
1171
- actual_response = aggregated_gpt_response.split("### compiled references:")[0]
1172
-
1173
- response_obj = {
1174
- "response": actual_response,
1175
- "inferredQueries": inferred_queries,
1176
- "context": compiled_references,
1177
- "online_results": online_results,
1178
- }
1179
-
1180
- return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200)
945
+ return StreamingResponse(event_generator(q), media_type="text/plain")
946
+ ## Non-Streaming Text Response
947
+ else:
948
+ # Get the full response from the generator if the stream is not requested.
949
+ response_obj = {}
950
+ actual_response = ""
951
+ iterator = event_generator(q)
952
+ async for item in iterator:
953
+ try:
954
+ item_json = json.loads(item)
955
+ if "type" in item_json and item_json["type"] == ChatEvent.REFERENCES.value:
956
+ response_obj = item_json["data"]
957
+ except:
958
+ actual_response += item
959
+ response_obj["response"] = actual_response
960
+
961
+ return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200)