khoj 1.16.1.dev25__py3-none-any.whl → 1.17.1.dev216__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. khoj/configure.py +6 -6
  2. khoj/database/adapters/__init__.py +55 -26
  3. khoj/database/migrations/0053_agent_style_color_agent_style_icon.py +61 -0
  4. khoj/database/models/__init__.py +35 -0
  5. khoj/interface/web/assets/icons/favicon-128x128.png +0 -0
  6. khoj/interface/web/assets/icons/favicon-256x256.png +0 -0
  7. khoj/interface/web/assets/icons/khoj-logo-sideways-200.png +0 -0
  8. khoj/interface/web/assets/icons/khoj-logo-sideways-500.png +0 -0
  9. khoj/interface/web/assets/icons/khoj-logo-sideways.svg +31 -5384
  10. khoj/interface/web/assets/icons/khoj.svg +26 -0
  11. khoj/interface/web/chat.html +191 -301
  12. khoj/interface/web/content_source_computer_input.html +3 -3
  13. khoj/interface/web/content_source_github_input.html +1 -1
  14. khoj/interface/web/content_source_notion_input.html +1 -1
  15. khoj/interface/web/public_conversation.html +1 -1
  16. khoj/interface/web/search.html +2 -2
  17. khoj/interface/web/{config.html → settings.html} +30 -30
  18. khoj/interface/web/utils.html +1 -1
  19. khoj/processor/content/docx/docx_to_entries.py +4 -9
  20. khoj/processor/content/github/github_to_entries.py +1 -3
  21. khoj/processor/content/images/image_to_entries.py +4 -9
  22. khoj/processor/content/markdown/markdown_to_entries.py +4 -9
  23. khoj/processor/content/notion/notion_to_entries.py +1 -3
  24. khoj/processor/content/org_mode/org_to_entries.py +4 -9
  25. khoj/processor/content/pdf/pdf_to_entries.py +4 -9
  26. khoj/processor/content/plaintext/plaintext_to_entries.py +4 -9
  27. khoj/processor/content/text_to_entries.py +1 -3
  28. khoj/processor/conversation/utils.py +0 -4
  29. khoj/processor/tools/online_search.py +13 -7
  30. khoj/routers/api.py +58 -9
  31. khoj/routers/api_agents.py +3 -1
  32. khoj/routers/api_chat.py +335 -562
  33. khoj/routers/api_content.py +538 -0
  34. khoj/routers/api_model.py +156 -0
  35. khoj/routers/helpers.py +338 -23
  36. khoj/routers/notion.py +2 -8
  37. khoj/routers/web_client.py +43 -256
  38. khoj/search_type/text_search.py +5 -4
  39. khoj/utils/fs_syncer.py +4 -2
  40. khoj/utils/rawconfig.py +6 -1
  41. {khoj-1.16.1.dev25.dist-info → khoj-1.17.1.dev216.dist-info}/METADATA +2 -2
  42. {khoj-1.16.1.dev25.dist-info → khoj-1.17.1.dev216.dist-info}/RECORD +45 -43
  43. khoj/routers/api_config.py +0 -434
  44. khoj/routers/indexer.py +0 -349
  45. {khoj-1.16.1.dev25.dist-info → khoj-1.17.1.dev216.dist-info}/WHEEL +0 -0
  46. {khoj-1.16.1.dev25.dist-info → khoj-1.17.1.dev216.dist-info}/entry_points.txt +0 -0
  47. {khoj-1.16.1.dev25.dist-info → khoj-1.17.1.dev216.dist-info}/licenses/LICENSE +0 -0
khoj/routers/api_chat.py CHANGED
@@ -1,41 +1,36 @@
1
+ import asyncio
1
2
  import json
2
3
  import logging
3
- import math
4
+ import time
4
5
  from datetime import datetime
6
+ from functools import partial
5
7
  from typing import Any, Dict, List, Optional
6
8
  from urllib.parse import unquote
7
9
 
8
10
  from asgiref.sync import sync_to_async
9
- from fastapi import APIRouter, Depends, HTTPException, Request, WebSocket
11
+ from fastapi import APIRouter, Depends, HTTPException, Request
10
12
  from fastapi.requests import Request
11
13
  from fastapi.responses import Response, StreamingResponse
12
14
  from starlette.authentication import requires
13
- from starlette.websockets import WebSocketDisconnect
14
- from websockets import ConnectionClosedOK
15
15
 
16
16
  from khoj.app.settings import ALLOWED_HOSTS
17
17
  from khoj.database.adapters import (
18
18
  ConversationAdapters,
19
- DataStoreAdapters,
20
19
  EntryAdapters,
21
20
  FileObjectAdapters,
22
21
  PublicConversationAdapters,
23
22
  aget_user_name,
24
23
  )
25
24
  from khoj.database.models import KhojUser
26
- from khoj.processor.conversation.prompts import (
27
- help_message,
28
- no_entries_found,
29
- no_notes_found,
30
- )
25
+ from khoj.processor.conversation.prompts import help_message, no_entries_found
31
26
  from khoj.processor.conversation.utils import save_to_conversation_log
32
27
  from khoj.processor.speech.text_to_speech import generate_text_to_speech
33
28
  from khoj.processor.tools.online_search import read_webpages, search_online
34
29
  from khoj.routers.api import extract_references_and_questions
35
30
  from khoj.routers.helpers import (
36
31
  ApiUserRateLimiter,
32
+ ChatEvent,
37
33
  CommonQueryParams,
38
- CommonQueryParamsClass,
39
34
  ConversationCommandRateLimiter,
40
35
  agenerate_chat_response,
41
36
  aget_relevant_information_sources,
@@ -58,7 +53,7 @@ from khoj.utils.helpers import (
58
53
  get_device,
59
54
  is_none_or_empty,
60
55
  )
61
- from khoj.utils.rawconfig import FilterRequest, LocationData
56
+ from khoj.utils.rawconfig import FileFilterRequest, FilesFilterRequest, LocationData
62
57
 
63
58
  # Initialize Router
64
59
  logger = logging.getLogger(__name__)
@@ -92,68 +87,36 @@ def get_file_filter(request: Request, conversation_id: str) -> Response:
92
87
  return Response(content=json.dumps(file_filters), media_type="application/json", status_code=200)
93
88
 
94
89
 
95
- class FactCheckerStoreDataFormat(BaseModel):
96
- factToVerify: str
97
- response: str
98
- references: Any
99
- childReferences: List[Any]
100
- runId: str
101
- modelUsed: Dict[str, Any]
102
-
103
-
104
- class FactCheckerStoreData(BaseModel):
105
- runId: str
106
- storeData: FactCheckerStoreDataFormat
107
-
108
-
109
- @api_chat.post("/store/factchecker", response_class=Response)
90
+ @api_chat.delete("/conversation/file-filters/bulk", response_class=Response)
110
91
  @requires(["authenticated"])
111
- async def store_factchecker(request: Request, common: CommonQueryParams, data: FactCheckerStoreData):
112
- user = request.user.object
113
-
114
- update_telemetry_state(
115
- request=request,
116
- telemetry_type="api",
117
- api="store_factchecker",
118
- **common.__dict__,
119
- )
120
- fact_checker_key = f"factchecker_{data.runId}"
121
- await DataStoreAdapters.astore_data(data.storeData.model_dump_json(), fact_checker_key, user, private=False)
122
- return Response(content=json.dumps({"status": "ok"}), media_type="application/json", status_code=200)
123
-
124
-
125
- @api_chat.get("/store/factchecker", response_class=Response)
126
- async def get_factchecker(request: Request, common: CommonQueryParams, runId: str):
127
- update_telemetry_state(
128
- request=request,
129
- telemetry_type="api",
130
- api="read_factchecker",
131
- **common.__dict__,
132
- )
92
+ def remove_files_filter(request: Request, filter: FilesFilterRequest) -> Response:
93
+ conversation_id = int(filter.conversation_id)
94
+ files_filter = filter.filenames
95
+ file_filters = ConversationAdapters.remove_files_from_filter(request.user.object, conversation_id, files_filter)
96
+ return Response(content=json.dumps(file_filters), media_type="application/json", status_code=200)
133
97
 
134
- fact_checker_key = f"factchecker_{runId}"
135
98
 
136
- data = await DataStoreAdapters.aretrieve_public_data(fact_checker_key)
137
- if data is None:
138
- return Response(status_code=404)
139
- return Response(content=json.dumps(data.value), media_type="application/json", status_code=200)
99
+ @api_chat.post("/conversation/file-filters/bulk", response_class=Response)
100
+ @requires(["authenticated"])
101
+ def add_files_filter(request: Request, filter: FilesFilterRequest):
102
+ try:
103
+ conversation_id = int(filter.conversation_id)
104
+ files_filter = filter.filenames
105
+ file_filters = ConversationAdapters.add_files_to_filter(request.user.object, conversation_id, files_filter)
106
+ return Response(content=json.dumps(file_filters), media_type="application/json", status_code=200)
107
+ except Exception as e:
108
+ logger.error(f"Error adding file filter {filter.filename}: {e}", exc_info=True)
109
+ raise HTTPException(status_code=422, detail=str(e))
140
110
 
141
111
 
142
112
  @api_chat.post("/conversation/file-filters", response_class=Response)
143
113
  @requires(["authenticated"])
144
- def add_file_filter(request: Request, filter: FilterRequest):
114
+ def add_file_filter(request: Request, filter: FileFilterRequest):
145
115
  try:
146
- conversation = ConversationAdapters.get_conversation_by_user(
147
- request.user.object, conversation_id=int(filter.conversation_id)
148
- )
149
- file_list = EntryAdapters.get_all_filenames_by_source(request.user.object, "computer")
150
- if filter.filename in file_list and filter.filename not in conversation.file_filters:
151
- conversation.file_filters.append(filter.filename)
152
- conversation.save()
153
- # remove files from conversation.file_filters that are not in file_list
154
- conversation.file_filters = [file for file in conversation.file_filters if file in file_list]
155
- conversation.save()
156
- return Response(content=json.dumps(conversation.file_filters), media_type="application/json", status_code=200)
116
+ conversation_id = int(filter.conversation_id)
117
+ files_filter = [filter.filename]
118
+ file_filters = ConversationAdapters.add_files_to_filter(request.user.object, conversation_id, files_filter)
119
+ return Response(content=json.dumps(file_filters), media_type="application/json", status_code=200)
157
120
  except Exception as e:
158
121
  logger.error(f"Error adding file filter {filter.filename}: {e}", exc_info=True)
159
122
  raise HTTPException(status_code=422, detail=str(e))
@@ -161,18 +124,11 @@ def add_file_filter(request: Request, filter: FilterRequest):
161
124
 
162
125
  @api_chat.delete("/conversation/file-filters", response_class=Response)
163
126
  @requires(["authenticated"])
164
- def remove_file_filter(request: Request, filter: FilterRequest) -> Response:
165
- conversation = ConversationAdapters.get_conversation_by_user(
166
- request.user.object, conversation_id=int(filter.conversation_id)
167
- )
168
- if filter.filename in conversation.file_filters:
169
- conversation.file_filters.remove(filter.filename)
170
- conversation.save()
171
- # remove files from conversation.file_filters that are not in file_list
172
- file_list = EntryAdapters.get_all_filenames_by_source(request.user.object, "computer")
173
- conversation.file_filters = [file for file in conversation.file_filters if file in file_list]
174
- conversation.save()
175
- return Response(content=json.dumps(conversation.file_filters), media_type="application/json", status_code=200)
127
+ def remove_file_filter(request: Request, filter: FileFilterRequest) -> Response:
128
+ conversation_id = int(filter.conversation_id)
129
+ files_filter = [filter.filename]
130
+ file_filters = ConversationAdapters.remove_files_from_filter(request.user.object, conversation_id, files_filter)
131
+ return Response(content=json.dumps(file_filters), media_type="application/json", status_code=200)
176
132
 
177
133
 
178
134
  class FeedbackData(BaseModel):
@@ -195,10 +151,10 @@ async def text_to_speech(
195
151
  common: CommonQueryParams,
196
152
  text: str,
197
153
  rate_limiter_per_minute=Depends(
198
- ApiUserRateLimiter(requests=5, subscribed_requests=20, window=60, slug="chat_minute")
154
+ ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute")
199
155
  ),
200
156
  rate_limiter_per_day=Depends(
201
- ApiUserRateLimiter(requests=5, subscribed_requests=300, window=60 * 60 * 24, slug="chat_day")
157
+ ApiUserRateLimiter(requests=50, subscribed_requests=300, window=60 * 60 * 24, slug="chat_day")
202
158
  ),
203
159
  ) -> Response:
204
160
  voice_model = await ConversationAdapters.aget_voice_model_config(request.user.object)
@@ -252,6 +208,9 @@ def chat_history(
252
208
  "name": conversation.agent.name,
253
209
  "avatar": conversation.agent.avatar,
254
210
  "isCreator": conversation.agent.creator == user,
211
+ "color": conversation.agent.style_color,
212
+ "icon": conversation.agent.style_icon,
213
+ "persona": conversation.agent.personality,
255
214
  }
256
215
 
257
216
  meta_log = conversation.conversation_log
@@ -306,13 +265,21 @@ def get_shared_chat(
306
265
  "name": conversation.agent.name,
307
266
  "avatar": conversation.agent.avatar,
308
267
  "isCreator": conversation.agent.creator == user,
268
+ "color": conversation.agent.style_color,
269
+ "icon": conversation.agent.style_icon,
270
+ "persona": conversation.agent.personality,
309
271
  }
310
272
 
311
273
  meta_log = conversation.conversation_log
274
+ scrubbed_title = conversation.title if conversation.title else conversation.slug
275
+
276
+ if scrubbed_title:
277
+ scrubbed_title = scrubbed_title.replace("-", " ")
278
+
312
279
  meta_log.update(
313
280
  {
314
281
  "conversation_id": conversation.id,
315
- "slug": conversation.title if conversation.title else conversation.slug,
282
+ "slug": scrubbed_title,
316
283
  "agent": agent_metadata,
317
284
  }
318
285
  )
@@ -328,7 +295,7 @@ def get_shared_chat(
328
295
  update_telemetry_state(
329
296
  request=request,
330
297
  telemetry_type="api",
331
- api="public_conversation_history",
298
+ api="chat_history",
332
299
  **common.__dict__,
333
300
  )
334
301
 
@@ -370,7 +337,7 @@ def fork_public_conversation(
370
337
  public_conversation = PublicConversationAdapters.get_public_conversation_by_slug(public_conversation_slug)
371
338
 
372
339
  # Duplicate Public Conversation to User's Private Conversation
373
- ConversationAdapters.create_conversation_from_public_conversation(
340
+ new_conversation = ConversationAdapters.create_conversation_from_public_conversation(
374
341
  user, public_conversation, request.user.client_app
375
342
  )
376
343
 
@@ -386,7 +353,16 @@ def fork_public_conversation(
386
353
 
387
354
  redirect_uri = str(request.app.url_path_for("chat_page"))
388
355
 
389
- return Response(status_code=200, content=json.dumps({"status": "ok", "next_url": redirect_uri}))
356
+ return Response(
357
+ status_code=200,
358
+ content=json.dumps(
359
+ {
360
+ "status": "ok",
361
+ "next_url": redirect_uri,
362
+ "conversation_id": new_conversation.id,
363
+ }
364
+ ),
365
+ )
390
366
 
391
367
 
392
368
  @api_chat.post("/share")
@@ -427,15 +403,30 @@ def duplicate_chat_history_public_conversation(
427
403
  def chat_sessions(
428
404
  request: Request,
429
405
  common: CommonQueryParams,
406
+ recent: Optional[bool] = False,
430
407
  ):
431
408
  user = request.user.object
432
409
 
433
410
  # Load Conversation Sessions
434
- sessions = ConversationAdapters.get_conversation_sessions(user, request.user.client_app).values_list(
435
- "id", "slug", "title"
411
+ conversations = ConversationAdapters.get_conversation_sessions(user, request.user.client_app)
412
+ if recent:
413
+ conversations = conversations[:8]
414
+
415
+ sessions = conversations.values_list(
416
+ "id", "slug", "title", "agent__slug", "agent__name", "agent__avatar", "created_at", "updated_at"
436
417
  )
437
418
 
438
- session_values = [{"conversation_id": session[0], "slug": session[2] or session[1]} for session in sessions]
419
+ session_values = [
420
+ {
421
+ "conversation_id": session[0],
422
+ "slug": session[2] or session[1],
423
+ "agent_name": session[4],
424
+ "agent_avatar": session[5],
425
+ "created": session[6].strftime("%Y-%m-%d %H:%M:%S"),
426
+ "updated": session[7].strftime("%Y-%m-%d %H:%M:%S"),
427
+ }
428
+ for session in sessions
429
+ ]
439
430
 
440
431
  update_telemetry_state(
441
432
  request=request,
@@ -477,7 +468,6 @@ async def create_chat_session(
477
468
 
478
469
 
479
470
  @api_chat.get("/options", response_class=Response)
480
- @requires(["authenticated"])
481
471
  async def chat_options(
482
472
  request: Request,
483
473
  common: CommonQueryParams,
@@ -526,141 +516,140 @@ async def set_conversation_title(
526
516
  )
527
517
 
528
518
 
529
- @api_chat.websocket("/ws")
530
- async def websocket_endpoint(
531
- websocket: WebSocket,
532
- conversation_id: int,
519
+ @api_chat.get("")
520
+ async def chat(
521
+ request: Request,
522
+ common: CommonQueryParams,
523
+ q: str,
524
+ n: int = 7,
525
+ d: float = 0.18,
526
+ stream: Optional[bool] = False,
527
+ title: Optional[str] = None,
528
+ conversation_id: Optional[int] = None,
533
529
  city: Optional[str] = None,
534
530
  region: Optional[str] = None,
535
531
  country: Optional[str] = None,
536
532
  timezone: Optional[str] = None,
533
+ rate_limiter_per_minute=Depends(
534
+ ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
535
+ ),
536
+ rate_limiter_per_day=Depends(
537
+ ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
538
+ ),
537
539
  ):
538
- connection_alive = True
540
+ async def event_generator(q: str):
541
+ start_time = time.perf_counter()
542
+ ttft = None
543
+ chat_metadata: dict = {}
544
+ connection_alive = True
545
+ user: KhojUser = request.user.object
546
+ event_delimiter = "␃🔚␗"
547
+ q = unquote(q)
548
+
549
+ async def send_event(event_type: ChatEvent, data: str | dict):
550
+ nonlocal connection_alive, ttft
551
+ if not connection_alive or await request.is_disconnected():
552
+ connection_alive = False
553
+ logger.warn(f"User {user} disconnected from {common.client} client")
554
+ return
555
+ try:
556
+ if event_type == ChatEvent.END_LLM_RESPONSE:
557
+ collect_telemetry()
558
+ if event_type == ChatEvent.START_LLM_RESPONSE:
559
+ ttft = time.perf_counter() - start_time
560
+ if event_type == ChatEvent.MESSAGE:
561
+ yield data
562
+ elif event_type == ChatEvent.REFERENCES or stream:
563
+ yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False)
564
+ except asyncio.CancelledError as e:
565
+ connection_alive = False
566
+ logger.warn(f"User {user} disconnected from {common.client} client: {e}")
567
+ return
568
+ except Exception as e:
569
+ connection_alive = False
570
+ logger.error(f"Failed to stream chat API response to {user} on {common.client}: {e}", exc_info=True)
571
+ return
572
+ finally:
573
+ if stream:
574
+ yield event_delimiter
575
+
576
+ async def send_llm_response(response: str):
577
+ async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
578
+ yield result
579
+ async for result in send_event(ChatEvent.MESSAGE, response):
580
+ yield result
581
+ async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
582
+ yield result
583
+
584
+ def collect_telemetry():
585
+ # Gather chat response telemetry
586
+ nonlocal chat_metadata
587
+ latency = time.perf_counter() - start_time
588
+ cmd_set = set([cmd.value for cmd in conversation_commands])
589
+ chat_metadata = chat_metadata or {}
590
+ chat_metadata["conversation_command"] = cmd_set
591
+ chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
592
+ chat_metadata["latency"] = f"{latency:.3f}"
593
+ chat_metadata["ttft_latency"] = f"{ttft:.3f}"
594
+
595
+ logger.info(f"Chat response time to first token: {ttft:.3f} seconds")
596
+ logger.info(f"Chat response total time: {latency:.3f} seconds")
597
+ update_telemetry_state(
598
+ request=request,
599
+ telemetry_type="api",
600
+ api="chat",
601
+ client=request.user.client_app,
602
+ user_agent=request.headers.get("user-agent"),
603
+ host=request.headers.get("host"),
604
+ metadata=chat_metadata,
605
+ )
539
606
 
540
- async def send_status_update(message: str):
541
- nonlocal connection_alive
542
- if not connection_alive:
543
- return
607
+ conversation_commands = [get_conversation_command(query=q, any_references=True)]
544
608
 
545
- status_packet = {
546
- "type": "status",
547
- "message": message,
548
- "content-type": "application/json",
549
- }
550
- try:
551
- await websocket.send_text(json.dumps(status_packet))
552
- except ConnectionClosedOK:
553
- connection_alive = False
554
- logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
555
-
556
- async def send_complete_llm_response(llm_response: str):
557
- nonlocal connection_alive
558
- if not connection_alive:
559
- return
560
- try:
561
- await websocket.send_text("start_llm_response")
562
- await websocket.send_text(llm_response)
563
- await websocket.send_text("end_llm_response")
564
- except ConnectionClosedOK:
565
- connection_alive = False
566
- logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
567
-
568
- async def send_message(message: str):
569
- nonlocal connection_alive
570
- if not connection_alive:
571
- return
572
- try:
573
- await websocket.send_text(message)
574
- except ConnectionClosedOK:
575
- connection_alive = False
576
- logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
577
-
578
- async def send_rate_limit_message(message: str):
579
- nonlocal connection_alive
580
- if not connection_alive:
609
+ conversation = await ConversationAdapters.aget_conversation_by_user(
610
+ user, client_application=request.user.client_app, conversation_id=conversation_id, title=title
611
+ )
612
+ if not conversation:
613
+ async for result in send_llm_response(f"Conversation {conversation_id} not found"):
614
+ yield result
581
615
  return
582
616
 
583
- status_packet = {
584
- "type": "rate_limit",
585
- "message": message,
586
- "content-type": "application/json",
587
- }
588
- try:
589
- await websocket.send_text(json.dumps(status_packet))
590
- except ConnectionClosedOK:
591
- connection_alive = False
592
- logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
593
-
594
- user: KhojUser = websocket.user.object
595
- conversation = await ConversationAdapters.aget_conversation_by_user(
596
- user, client_application=websocket.user.client_app, conversation_id=conversation_id
597
- )
617
+ await is_ready_to_chat(user)
598
618
 
599
- hourly_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
600
-
601
- daily_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
602
-
603
- await is_ready_to_chat(user)
604
-
605
- user_name = await aget_user_name(user)
606
-
607
- location = None
608
-
609
- if city or region or country:
610
- location = LocationData(city=city, region=region, country=country)
611
-
612
- await websocket.accept()
613
- while connection_alive:
614
- try:
615
- if conversation:
616
- await sync_to_async(conversation.refresh_from_db)(fields=["conversation_log"])
617
- q = await websocket.receive_text()
618
-
619
- # Refresh these because the connection to the database might have been closed
620
- await conversation.arefresh_from_db()
621
-
622
- except WebSocketDisconnect:
623
- logger.debug(f"User {user} disconnected web socket")
624
- break
625
-
626
- try:
627
- await sync_to_async(hourly_limiter)(websocket)
628
- await sync_to_async(daily_limiter)(websocket)
629
- except HTTPException as e:
630
- await send_rate_limit_message(e.detail)
631
- break
619
+ user_name = await aget_user_name(user)
620
+ location = None
621
+ if city or region or country:
622
+ location = LocationData(city=city, region=region, country=country)
632
623
 
633
624
  if is_query_empty(q):
634
- await send_message("start_llm_response")
635
- await send_message(
636
- "It seems like your query is incomplete. Could you please provide more details or specify what you need help with?"
637
- )
638
- await send_message("end_llm_response")
639
- continue
625
+ async for result in send_llm_response("Please ask your query to get started."):
626
+ yield result
627
+ return
640
628
 
641
629
  user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
642
- conversation_commands = [get_conversation_command(query=q, any_references=True)]
643
-
644
- await send_status_update(f"**👀 Understanding Query**: {q}")
645
630
 
646
631
  meta_log = conversation.conversation_log
647
632
  is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
648
- used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
649
633
 
650
634
  if conversation_commands == [ConversationCommand.Default] or is_automated_task:
651
635
  conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
652
636
  conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
653
- await send_status_update(f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}")
637
+ async for result in send_event(
638
+ ChatEvent.STATUS, f"**Chose Data Sources to Search:** {conversation_commands_str}"
639
+ ):
640
+ yield result
654
641
 
655
642
  mode = await aget_relevant_output_modes(q, meta_log, is_automated_task)
656
- await send_status_update(f"**🧑🏾‍💻 Decided Response Mode:** {mode.value}")
643
+ async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
644
+ yield result
657
645
  if mode not in conversation_commands:
658
646
  conversation_commands.append(mode)
659
647
 
660
648
  for cmd in conversation_commands:
661
- await conversation_command_rate_limiter.update_and_check_if_valid(websocket, cmd)
649
+ await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
662
650
  q = q.replace(f"/{cmd.value}", "").strip()
663
651
 
652
+ used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
664
653
  file_filters = conversation.file_filters if conversation else []
665
654
  # Skip trying to summarize if
666
655
  if (
@@ -676,28 +665,37 @@ async def websocket_endpoint(
676
665
  response_log = ""
677
666
  if len(file_filters) == 0:
678
667
  response_log = "No files selected for summarization. Please add files using the section on the left."
679
- await send_complete_llm_response(response_log)
668
+ async for result in send_llm_response(response_log):
669
+ yield result
680
670
  elif len(file_filters) > 1:
681
671
  response_log = "Only one file can be selected for summarization."
682
- await send_complete_llm_response(response_log)
672
+ async for result in send_llm_response(response_log):
673
+ yield result
683
674
  else:
684
675
  try:
685
676
  file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
686
677
  if len(file_object) == 0:
687
678
  response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
688
- await send_complete_llm_response(response_log)
689
- continue
679
+ async for result in send_llm_response(response_log):
680
+ yield result
681
+ return
690
682
  contextual_data = " ".join([file.raw_text for file in file_object])
691
683
  if not q:
692
684
  q = "Create a general summary of the file"
693
- await send_status_update(f"**🧑🏾‍💻 Constructing Summary Using:** {file_object[0].file_name}")
685
+ async for result in send_event(
686
+ ChatEvent.STATUS, f"**Constructing Summary Using:** {file_object[0].file_name}"
687
+ ):
688
+ yield result
689
+
694
690
  response = await extract_relevant_summary(q, contextual_data)
695
691
  response_log = str(response)
696
- await send_complete_llm_response(response_log)
692
+ async for result in send_llm_response(response_log):
693
+ yield result
697
694
  except Exception as e:
698
695
  response_log = "Error summarizing file."
699
696
  logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
700
- await send_complete_llm_response(response_log)
697
+ async for result in send_llm_response(response_log):
698
+ yield result
701
699
  await sync_to_async(save_to_conversation_log)(
702
700
  q,
703
701
  response_log,
@@ -705,16 +703,10 @@ async def websocket_endpoint(
705
703
  meta_log,
706
704
  user_message_time,
707
705
  intent_type="summarize",
708
- client_application=websocket.user.client_app,
706
+ client_application=request.user.client_app,
709
707
  conversation_id=conversation_id,
710
708
  )
711
- update_telemetry_state(
712
- request=websocket,
713
- telemetry_type="api",
714
- api="chat",
715
- metadata={"conversation_command": conversation_commands[0].value},
716
- )
717
- continue
709
+ return
718
710
 
719
711
  custom_filters = []
720
712
  if conversation_commands == [ConversationCommand.Help]:
@@ -724,8 +716,9 @@ async def websocket_endpoint(
724
716
  conversation_config = await ConversationAdapters.aget_default_conversation_config()
725
717
  model_type = conversation_config.model_type
726
718
  formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
727
- await send_complete_llm_response(formatted_help)
728
- continue
719
+ async for result in send_llm_response(formatted_help):
720
+ yield result
721
+ return
729
722
  # Adding specification to search online specifically on khoj.dev pages.
730
723
  custom_filters.append("site:khoj.dev")
731
724
  conversation_commands.append(ConversationCommand.Online)
@@ -733,14 +726,14 @@ async def websocket_endpoint(
733
726
  if ConversationCommand.Automation in conversation_commands:
734
727
  try:
735
728
  automation, crontime, query_to_run, subject = await create_automation(
736
- q, timezone, user, websocket.url, meta_log
729
+ q, timezone, user, request.url, meta_log
737
730
  )
738
731
  except Exception as e:
739
732
  logger.error(f"Error scheduling task {q} for {user.email}: {e}")
740
- await send_complete_llm_response(
741
- f"Unable to create automation. Ensure the automation doesn't already exist."
742
- )
743
- continue
733
+ error_message = f"Unable to create automation. Ensure the automation doesn't already exist."
734
+ async for result in send_llm_response(error_message):
735
+ yield result
736
+ return
744
737
 
745
738
  llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
746
739
  await sync_to_async(save_to_conversation_log)(
@@ -750,57 +743,78 @@ async def websocket_endpoint(
750
743
  meta_log,
751
744
  user_message_time,
752
745
  intent_type="automation",
753
- client_application=websocket.user.client_app,
746
+ client_application=request.user.client_app,
754
747
  conversation_id=conversation_id,
755
748
  inferred_queries=[query_to_run],
756
749
  automation_id=automation.id,
757
750
  )
758
- common = CommonQueryParamsClass(
759
- client=websocket.user.client_app,
760
- user_agent=websocket.headers.get("user-agent"),
761
- host=websocket.headers.get("host"),
762
- )
763
- update_telemetry_state(
764
- request=websocket,
765
- telemetry_type="api",
766
- api="chat",
767
- **common.__dict__,
768
- )
769
- await send_complete_llm_response(llm_response)
770
- continue
751
+ async for result in send_llm_response(llm_response):
752
+ yield result
753
+ return
771
754
 
772
- compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
773
- websocket, meta_log, q, 7, 0.18, conversation_id, conversation_commands, location, send_status_update
774
- )
755
+ # Gather Context
756
+ ## Extract Document References
757
+ compiled_references, inferred_queries, defiltered_query = [], [], None
758
+ async for result in extract_references_and_questions(
759
+ request,
760
+ meta_log,
761
+ q,
762
+ (n or 7),
763
+ (d or 0.18),
764
+ conversation_id,
765
+ conversation_commands,
766
+ location,
767
+ partial(send_event, ChatEvent.STATUS),
768
+ ):
769
+ if isinstance(result, dict) and ChatEvent.STATUS in result:
770
+ yield result[ChatEvent.STATUS]
771
+ else:
772
+ compiled_references.extend(result[0])
773
+ inferred_queries.extend(result[1])
774
+ defiltered_query = result[2]
775
775
 
776
- if compiled_references:
776
+ if not is_none_or_empty(compiled_references):
777
777
  headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
778
- await send_status_update(f"**📜 Found Relevant Notes**: {headings}")
778
+ async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"):
779
+ yield result
779
780
 
780
781
  online_results: Dict = dict()
781
782
 
782
783
  if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
783
- await send_complete_llm_response(f"{no_entries_found.format()}")
784
- continue
784
+ async for result in send_llm_response(f"{no_entries_found.format()}"):
785
+ yield result
786
+ return
785
787
 
786
788
  if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
787
789
  conversation_commands.remove(ConversationCommand.Notes)
788
790
 
791
+ ## Gather Online References
789
792
  if ConversationCommand.Online in conversation_commands:
790
793
  try:
791
- online_results = await search_online(
792
- defiltered_query, meta_log, location, send_status_update, custom_filters
793
- )
794
+ async for result in search_online(
795
+ defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS), custom_filters
796
+ ):
797
+ if isinstance(result, dict) and ChatEvent.STATUS in result:
798
+ yield result[ChatEvent.STATUS]
799
+ else:
800
+ online_results = result
794
801
  except ValueError as e:
795
- logger.warning(f"Error searching online: {e}. Attempting to respond without online results")
796
- await send_complete_llm_response(
797
- f"Error searching online: {e}. Attempting to respond without online results"
798
- )
799
- continue
802
+ error_message = f"Error searching online: {e}. Attempting to respond without online results"
803
+ logger.warning(error_message)
804
+ async for result in send_llm_response(error_message):
805
+ yield result
806
+ return
800
807
 
808
+ ## Gather Webpage References
801
809
  if ConversationCommand.Webpage in conversation_commands:
802
810
  try:
803
- direct_web_pages = await read_webpages(defiltered_query, meta_log, location, send_status_update)
811
+ async for result in read_webpages(
812
+ defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS)
813
+ ):
814
+ if isinstance(result, dict) and ChatEvent.STATUS in result:
815
+ yield result[ChatEvent.STATUS]
816
+ else:
817
+ direct_web_pages = result
804
818
  webpages = []
805
819
  for query in direct_web_pages:
806
820
  if online_results.get(query):
@@ -810,38 +824,52 @@ async def websocket_endpoint(
810
824
 
811
825
  for webpage in direct_web_pages[query]["webpages"]:
812
826
  webpages.append(webpage["link"])
813
-
814
- await send_status_update(f"**📚 Read web pages**: {webpages}")
827
+ async for result in send_event(ChatEvent.STATUS, f"**Read web pages**: {webpages}"):
828
+ yield result
815
829
  except ValueError as e:
816
830
  logger.warning(
817
- f"Error directly reading webpages: {e}. Attempting to respond without online results", exc_info=True
831
+ f"Error directly reading webpages: {e}. Attempting to respond without online results",
832
+ exc_info=True,
818
833
  )
819
834
 
835
+ ## Send Gathered References
836
+ async for result in send_event(
837
+ ChatEvent.REFERENCES,
838
+ {
839
+ "inferredQueries": inferred_queries,
840
+ "context": compiled_references,
841
+ "onlineContext": online_results,
842
+ },
843
+ ):
844
+ yield result
845
+
846
+ # Generate Output
847
+ ## Generate Image Output
820
848
  if ConversationCommand.Image in conversation_commands:
821
- update_telemetry_state(
822
- request=websocket,
823
- telemetry_type="api",
824
- api="chat",
825
- metadata={"conversation_command": conversation_commands[0].value},
826
- )
827
- image, status_code, improved_image_prompt, intent_type = await text_to_image(
849
+ async for result in text_to_image(
828
850
  q,
829
851
  user,
830
852
  meta_log,
831
853
  location_data=location,
832
854
  references=compiled_references,
833
855
  online_results=online_results,
834
- send_status_func=send_status_update,
835
- )
856
+ send_status_func=partial(send_event, ChatEvent.STATUS),
857
+ ):
858
+ if isinstance(result, dict) and ChatEvent.STATUS in result:
859
+ yield result[ChatEvent.STATUS]
860
+ else:
861
+ image, status_code, improved_image_prompt, intent_type = result
862
+
836
863
  if image is None or status_code != 200:
837
864
  content_obj = {
838
- "image": image,
865
+ "content-type": "application/json",
839
866
  "intentType": intent_type,
840
867
  "detail": improved_image_prompt,
841
- "content-type": "application/json",
868
+ "image": image,
842
869
  }
843
- await send_complete_llm_response(json.dumps(content_obj))
844
- continue
870
+ async for result in send_llm_response(json.dumps(content_obj)):
871
+ yield result
872
+ return
845
873
 
846
874
  await sync_to_async(save_to_conversation_log)(
847
875
  q,
@@ -851,17 +879,23 @@ async def websocket_endpoint(
851
879
  user_message_time,
852
880
  intent_type=intent_type,
853
881
  inferred_queries=[improved_image_prompt],
854
- client_application=websocket.user.client_app,
882
+ client_application=request.user.client_app,
855
883
  conversation_id=conversation_id,
856
884
  compiled_references=compiled_references,
857
885
  online_results=online_results,
858
886
  )
859
- content_obj = {"image": image, "intentType": intent_type, "inferredQueries": [improved_image_prompt], "context": compiled_references, "content-type": "application/json", "online_results": online_results} # type: ignore
860
-
861
- await send_complete_llm_response(json.dumps(content_obj))
862
- continue
887
+ content_obj = {
888
+ "intentType": intent_type,
889
+ "inferredQueries": [improved_image_prompt],
890
+ "image": image,
891
+ }
892
+ async for result in send_llm_response(json.dumps(content_obj)):
893
+ yield result
894
+ return
863
895
 
864
- await send_status_update(f"**💭 Generating a well-informed response**")
896
+ ## Generate Text Output
897
+ async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"):
898
+ yield result
865
899
  llm_response, chat_metadata = await agenerate_chat_response(
866
900
  defiltered_query,
867
901
  meta_log,
@@ -871,310 +905,49 @@ async def websocket_endpoint(
871
905
  inferred_queries,
872
906
  conversation_commands,
873
907
  user,
874
- websocket.user.client_app,
908
+ request.user.client_app,
875
909
  conversation_id,
876
910
  location,
877
911
  user_name,
878
912
  )
879
913
 
880
- chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
914
+ # Send Response
915
+ async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
916
+ yield result
881
917
 
882
- update_telemetry_state(
883
- request=websocket,
884
- telemetry_type="api",
885
- api="chat",
886
- metadata=chat_metadata,
887
- )
918
+ continue_stream = True
888
919
  iterator = AsyncIteratorWrapper(llm_response)
889
-
890
- await send_message("start_llm_response")
891
-
892
920
  async for item in iterator:
893
921
  if item is None:
894
- break
895
- if connection_alive:
896
- try:
897
- await send_message(f"{item}")
898
- except ConnectionClosedOK:
899
- connection_alive = False
900
- logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
901
-
902
- await send_message("end_llm_response")
903
-
904
-
905
- @api_chat.get("", response_class=Response)
906
- @requires(["authenticated"])
907
- async def chat(
908
- request: Request,
909
- common: CommonQueryParams,
910
- q: str,
911
- n: Optional[int] = 5,
912
- d: Optional[float] = 0.22,
913
- stream: Optional[bool] = False,
914
- title: Optional[str] = None,
915
- conversation_id: Optional[int] = None,
916
- city: Optional[str] = None,
917
- region: Optional[str] = None,
918
- country: Optional[str] = None,
919
- timezone: Optional[str] = None,
920
- rate_limiter_per_minute=Depends(
921
- ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
922
- ),
923
- rate_limiter_per_day=Depends(
924
- ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
925
- ),
926
- ) -> Response:
927
- user: KhojUser = request.user.object
928
- q = unquote(q)
929
- if is_query_empty(q):
930
- return Response(
931
- content="It seems like your query is incomplete. Could you please provide more details or specify what you need help with?",
932
- media_type="text/plain",
933
- status_code=400,
934
- )
935
- user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
936
- logger.info(f"Chat request by {user.username}: {q}")
937
-
938
- await is_ready_to_chat(user)
939
- conversation_commands = [get_conversation_command(query=q, any_references=True)]
940
-
941
- _custom_filters = []
942
- if conversation_commands == [ConversationCommand.Help]:
943
- help_str = "/" + ConversationCommand.Help
944
- if q.strip() == help_str:
945
- conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
946
- if conversation_config == None:
947
- conversation_config = await ConversationAdapters.aget_default_conversation_config()
948
- model_type = conversation_config.model_type
949
- formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
950
- return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200)
951
- # Adding specification to search online specifically on khoj.dev pages.
952
- _custom_filters.append("site:khoj.dev")
953
- conversation_commands.append(ConversationCommand.Online)
954
-
955
- conversation = await ConversationAdapters.aget_conversation_by_user(
956
- user, request.user.client_app, conversation_id, title
957
- )
958
- conversation_id = conversation.id if conversation else None
959
-
960
- if not conversation:
961
- return Response(
962
- content=f"No conversation found with requested id, title", media_type="text/plain", status_code=400
963
- )
964
- else:
965
- meta_log = conversation.conversation_log
966
-
967
- if ConversationCommand.Summarize in conversation_commands:
968
- file_filters = conversation.file_filters
969
- llm_response = ""
970
- if len(file_filters) == 0:
971
- llm_response = "No files selected for summarization. Please add files using the section on the left."
972
- elif len(file_filters) > 1:
973
- llm_response = "Only one file can be selected for summarization."
974
- else:
922
+ async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
923
+ yield result
924
+ logger.debug("Finished streaming response")
925
+ return
926
+ if not connection_alive or not continue_stream:
927
+ continue
975
928
  try:
976
- file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
977
- if len(file_object) == 0:
978
- llm_response = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
979
- return StreamingResponse(content=llm_response, media_type="text/event-stream", status_code=200)
980
- contextual_data = " ".join([file.raw_text for file in file_object])
981
- summarizeStr = "/" + ConversationCommand.Summarize
982
- if q.strip() == summarizeStr:
983
- q = "Create a general summary of the file"
984
- response = await extract_relevant_summary(q, contextual_data)
985
- llm_response = str(response)
929
+ async for result in send_event(ChatEvent.MESSAGE, f"{item}"):
930
+ yield result
986
931
  except Exception as e:
987
- logger.error(f"Error summarizing file for {user.email}: {e}")
988
- llm_response = "Error summarizing file."
989
- await sync_to_async(save_to_conversation_log)(
990
- q,
991
- llm_response,
992
- user,
993
- conversation.conversation_log,
994
- user_message_time,
995
- intent_type="summarize",
996
- client_application=request.user.client_app,
997
- conversation_id=conversation_id,
998
- )
999
- update_telemetry_state(
1000
- request=request,
1001
- telemetry_type="api",
1002
- api="chat",
1003
- metadata={"conversation_command": conversation_commands[0].value},
1004
- **common.__dict__,
1005
- )
1006
- return StreamingResponse(content=llm_response, media_type="text/event-stream", status_code=200)
1007
-
1008
- is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
1009
-
1010
- if conversation_commands == [ConversationCommand.Default] or is_automated_task:
1011
- conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
1012
- mode = await aget_relevant_output_modes(q, meta_log, is_automated_task)
1013
- if mode not in conversation_commands:
1014
- conversation_commands.append(mode)
1015
-
1016
- for cmd in conversation_commands:
1017
- await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
1018
- q = q.replace(f"/{cmd.value}", "").strip()
1019
-
1020
- location = None
1021
-
1022
- if city or region or country:
1023
- location = LocationData(city=city, region=region, country=country)
1024
-
1025
- user_name = await aget_user_name(user)
1026
-
1027
- if ConversationCommand.Automation in conversation_commands:
1028
- try:
1029
- automation, crontime, query_to_run, subject = await create_automation(
1030
- q, timezone, user, request.url, meta_log
1031
- )
1032
- except Exception as e:
1033
- logger.error(f"Error creating automation {q} for {user.email}: {e}", exc_info=True)
1034
- return Response(
1035
- content=f"Unable to create automation. Ensure the automation doesn't already exist.",
1036
- media_type="text/plain",
1037
- status_code=500,
1038
- )
1039
-
1040
- llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
1041
- await sync_to_async(save_to_conversation_log)(
1042
- q,
1043
- llm_response,
1044
- user,
1045
- meta_log,
1046
- user_message_time,
1047
- intent_type="automation",
1048
- client_application=request.user.client_app,
1049
- conversation_id=conversation_id,
1050
- inferred_queries=[query_to_run],
1051
- automation_id=automation.id,
1052
- )
1053
-
1054
- if stream:
1055
- return StreamingResponse(llm_response, media_type="text/event-stream", status_code=200)
1056
- else:
1057
- return Response(content=llm_response, media_type="text/plain", status_code=200)
1058
-
1059
- compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
1060
- request, meta_log, q, (n or 5), (d or math.inf), conversation_id, conversation_commands, location
1061
- )
1062
- online_results: Dict[str, Dict] = {}
1063
-
1064
- if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
1065
- no_entries_found_format = no_entries_found.format()
1066
- if stream:
1067
- return StreamingResponse(iter([no_entries_found_format]), media_type="text/event-stream", status_code=200)
1068
- else:
1069
- response_obj = {"response": no_entries_found_format}
1070
- return Response(content=json.dumps(response_obj), media_type="text/plain", status_code=200)
1071
-
1072
- if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
1073
- no_notes_found_format = no_notes_found.format()
1074
- if stream:
1075
- return StreamingResponse(iter([no_notes_found_format]), media_type="text/event-stream", status_code=200)
1076
- else:
1077
- response_obj = {"response": no_notes_found_format}
1078
- return Response(content=json.dumps(response_obj), media_type="text/plain", status_code=200)
1079
-
1080
- if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
1081
- conversation_commands.remove(ConversationCommand.Notes)
1082
-
1083
- if ConversationCommand.Online in conversation_commands:
1084
- try:
1085
- online_results = await search_online(defiltered_query, meta_log, location, custom_filters=_custom_filters)
1086
- except ValueError as e:
1087
- logger.warning(f"Error searching online: {e}. Attempting to respond without online results")
1088
-
1089
- if ConversationCommand.Webpage in conversation_commands:
1090
- try:
1091
- online_results = await read_webpages(defiltered_query, meta_log, location)
1092
- except ValueError as e:
1093
- logger.warning(
1094
- f"Error directly reading webpages: {e}. Attempting to respond without online results", exc_info=True
1095
- )
1096
-
1097
- if ConversationCommand.Image in conversation_commands:
1098
- update_telemetry_state(
1099
- request=request,
1100
- telemetry_type="api",
1101
- api="chat",
1102
- metadata={"conversation_command": conversation_commands[0].value},
1103
- **common.__dict__,
1104
- )
1105
- image, status_code, improved_image_prompt, intent_type = await text_to_image(
1106
- q, user, meta_log, location_data=location, references=compiled_references, online_results=online_results
1107
- )
1108
- if image is None:
1109
- content_obj = {"image": image, "intentType": intent_type, "detail": improved_image_prompt}
1110
- return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
1111
-
1112
- await sync_to_async(save_to_conversation_log)(
1113
- q,
1114
- image,
1115
- user,
1116
- meta_log,
1117
- user_message_time,
1118
- intent_type=intent_type,
1119
- inferred_queries=[improved_image_prompt],
1120
- client_application=request.user.client_app,
1121
- conversation_id=conversation.id,
1122
- compiled_references=compiled_references,
1123
- online_results=online_results,
1124
- )
1125
- content_obj = {"image": image, "intentType": intent_type, "inferredQueries": [improved_image_prompt], "context": compiled_references, "online_results": online_results} # type: ignore
1126
- return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
1127
-
1128
- # Get the (streamed) chat response from the LLM of choice.
1129
- llm_response, chat_metadata = await agenerate_chat_response(
1130
- defiltered_query,
1131
- meta_log,
1132
- conversation,
1133
- compiled_references,
1134
- online_results,
1135
- inferred_queries,
1136
- conversation_commands,
1137
- user,
1138
- request.user.client_app,
1139
- conversation.id,
1140
- location,
1141
- user_name,
1142
- )
1143
-
1144
- cmd_set = set([cmd.value for cmd in conversation_commands])
1145
- chat_metadata["conversation_command"] = cmd_set
1146
- chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
1147
-
1148
- update_telemetry_state(
1149
- request=request,
1150
- telemetry_type="api",
1151
- api="chat",
1152
- metadata=chat_metadata,
1153
- **common.__dict__,
1154
- )
1155
-
1156
- if llm_response is None:
1157
- return Response(content=llm_response, media_type="text/plain", status_code=500)
932
+ continue_stream = False
933
+ logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}")
1158
934
 
935
+ ## Stream Text Response
1159
936
  if stream:
1160
- return StreamingResponse(llm_response, media_type="text/event-stream", status_code=200)
1161
-
1162
- iterator = AsyncIteratorWrapper(llm_response)
1163
-
1164
- # Get the full response from the generator if the stream is not requested.
1165
- aggregated_gpt_response = ""
1166
- async for item in iterator:
1167
- if item is None:
1168
- break
1169
- aggregated_gpt_response += item
1170
-
1171
- actual_response = aggregated_gpt_response.split("### compiled references:")[0]
1172
-
1173
- response_obj = {
1174
- "response": actual_response,
1175
- "inferredQueries": inferred_queries,
1176
- "context": compiled_references,
1177
- "online_results": online_results,
1178
- }
1179
-
1180
- return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200)
937
+ return StreamingResponse(event_generator(q), media_type="text/plain")
938
+ ## Non-Streaming Text Response
939
+ else:
940
+ # Get the full response from the generator if the stream is not requested.
941
+ response_obj = {}
942
+ actual_response = ""
943
+ iterator = event_generator(q)
944
+ async for item in iterator:
945
+ try:
946
+ item_json = json.loads(item)
947
+ if "type" in item_json and item_json["type"] == ChatEvent.REFERENCES.value:
948
+ response_obj = item_json["data"]
949
+ except:
950
+ actual_response += item
951
+ response_obj["response"] = actual_response
952
+
953
+ return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200)