khoj 1.41.1.dev90__py3-none-any.whl → 1.41.1.dev107__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66) hide show
  1. khoj/database/adapters/__init__.py +16 -5
  2. khoj/database/models/__init__.py +3 -0
  3. khoj/interface/compiled/404/index.html +2 -2
  4. khoj/interface/compiled/_next/static/chunks/5477-b91e9926cfc3095c.js +1 -0
  5. khoj/interface/compiled/_next/static/chunks/app/agents/layout-e49165209d2e406c.js +1 -0
  6. khoj/interface/compiled/_next/static/chunks/app/agents/{page-996513ae80f8720c.js → page-c9ceb9b94e24b94a.js} +1 -1
  7. khoj/interface/compiled/_next/static/chunks/app/automations/{page-2320231573aa9a49.js → page-3dc59a0df3827dc7.js} +1 -1
  8. khoj/interface/compiled/_next/static/chunks/app/chat/layout-d5ae861e1ade9d08.js +1 -0
  9. khoj/interface/compiled/_next/static/chunks/app/chat/{page-6257055246cdebd5.js → page-2b27c7118d8d5a16.js} +1 -1
  10. khoj/interface/compiled/_next/static/chunks/app/{page-d9a2e44bbcf49f82.js → page-38f1f125d7aeb4c7.js} +1 -1
  11. khoj/interface/compiled/_next/static/chunks/app/search/{page-31452bbda0e0a56f.js → page-26d4492fb1200e0e.js} +1 -1
  12. khoj/interface/compiled/_next/static/chunks/app/settings/{page-fdb72b15ca908b43.js → page-bf1a4e488b29fceb.js} +1 -1
  13. khoj/interface/compiled/_next/static/chunks/app/share/chat/layout-64a53f8ec4afa6b3.js +1 -0
  14. khoj/interface/compiled/_next/static/chunks/app/share/chat/{page-5b7cb35d835af900.js → page-a1f10c96366c3a4f.js} +1 -1
  15. khoj/interface/compiled/_next/static/chunks/{webpack-e091508620cb8aef.js → webpack-c6bde5961098facd.js} +1 -1
  16. khoj/interface/compiled/_next/static/css/bb7ea98028b368f3.css +1 -0
  17. khoj/interface/compiled/_next/static/css/ee66643a6a5bf71c.css +1 -0
  18. khoj/interface/compiled/agents/index.html +2 -2
  19. khoj/interface/compiled/agents/index.txt +2 -2
  20. khoj/interface/compiled/automations/index.html +2 -2
  21. khoj/interface/compiled/automations/index.txt +3 -3
  22. khoj/interface/compiled/chat/index.html +2 -2
  23. khoj/interface/compiled/chat/index.txt +2 -2
  24. khoj/interface/compiled/index.html +2 -2
  25. khoj/interface/compiled/index.txt +2 -2
  26. khoj/interface/compiled/search/index.html +2 -2
  27. khoj/interface/compiled/search/index.txt +2 -2
  28. khoj/interface/compiled/settings/index.html +2 -2
  29. khoj/interface/compiled/settings/index.txt +4 -4
  30. khoj/interface/compiled/share/chat/index.html +2 -2
  31. khoj/interface/compiled/share/chat/index.txt +2 -2
  32. khoj/processor/conversation/anthropic/anthropic_chat.py +3 -3
  33. khoj/processor/conversation/anthropic/utils.py +37 -19
  34. khoj/processor/conversation/google/gemini_chat.py +2 -2
  35. khoj/processor/conversation/offline/chat_model.py +2 -2
  36. khoj/processor/conversation/openai/gpt.py +3 -3
  37. khoj/processor/conversation/prompts.py +1 -1
  38. khoj/processor/conversation/utils.py +71 -42
  39. khoj/processor/operator/grounding_agent_uitars.py +2 -2
  40. khoj/processor/operator/operate_browser.py +17 -4
  41. khoj/processor/operator/operator_agent_anthropic.py +24 -5
  42. khoj/routers/api_chat.py +98 -28
  43. khoj/routers/api_model.py +3 -3
  44. khoj/routers/helpers.py +11 -8
  45. khoj/routers/research.py +15 -4
  46. khoj/utils/constants.py +6 -0
  47. khoj/utils/rawconfig.py +1 -0
  48. {khoj-1.41.1.dev90.dist-info → khoj-1.41.1.dev107.dist-info}/METADATA +2 -2
  49. {khoj-1.41.1.dev90.dist-info → khoj-1.41.1.dev107.dist-info}/RECORD +60 -60
  50. khoj/interface/compiled/_next/static/chunks/5477-77ce5c6f468d6c25.js +0 -1
  51. khoj/interface/compiled/_next/static/chunks/app/agents/layout-4e2a134ec26aa606.js +0 -1
  52. khoj/interface/compiled/_next/static/chunks/app/chat/layout-ad4d1792ab1a4108.js +0 -1
  53. khoj/interface/compiled/_next/static/chunks/app/share/chat/layout-abb6c5f4239ad7be.js +0 -1
  54. khoj/interface/compiled/_next/static/css/37a73b87f02df402.css +0 -1
  55. khoj/interface/compiled/_next/static/css/55d4a822f8d94b67.css +0 -1
  56. /khoj/interface/compiled/_next/static/chunks/{1915-1943ee8a628b893c.js → 1915-ab4353eaca76f690.js} +0 -0
  57. /khoj/interface/compiled/_next/static/chunks/{2117-5a41630a2bd2eae8.js → 2117-1c18aa2098982bf9.js} +0 -0
  58. /khoj/interface/compiled/_next/static/chunks/{4363-e6ac2203564d1a3b.js → 4363-4efaf12abe696251.js} +0 -0
  59. /khoj/interface/compiled/_next/static/chunks/{4447-e038b251d626c340.js → 4447-5d44807c40355b1a.js} +0 -0
  60. /khoj/interface/compiled/_next/static/chunks/{8667-8136f74e9a086fca.js → 8667-adbe6017a66cef10.js} +0 -0
  61. /khoj/interface/compiled/_next/static/chunks/{9259-640fdd77408475df.js → 9259-d8bcd9da9e80c81e.js} +0 -0
  62. /khoj/interface/compiled/_next/static/{WLmcH2J-wz36GlS6O8HSL → y_k1yn7bI1CgM5ZfW7jUq}/_buildManifest.js +0 -0
  63. /khoj/interface/compiled/_next/static/{WLmcH2J-wz36GlS6O8HSL → y_k1yn7bI1CgM5ZfW7jUq}/_ssgManifest.js +0 -0
  64. {khoj-1.41.1.dev90.dist-info → khoj-1.41.1.dev107.dist-info}/WHEEL +0 -0
  65. {khoj-1.41.1.dev90.dist-info → khoj-1.41.1.dev107.dist-info}/entry_points.txt +0 -0
  66. {khoj-1.41.1.dev90.dist-info → khoj-1.41.1.dev107.dist-info}/licenses/LICENSE +0 -0
@@ -73,6 +73,10 @@ model_to_prompt_size = {
73
73
  "claude-3-7-sonnet-20250219": 60000,
74
74
  "claude-3-7-sonnet-latest": 60000,
75
75
  "claude-3-5-haiku-20241022": 60000,
76
+ "claude-sonnet-4": 60000,
77
+ "claude-sonnet-4-20250514": 60000,
78
+ "claude-opus-4": 60000,
79
+ "claude-opus-4-20250514": 60000,
76
80
  # Offline Models
77
81
  "bartowski/Qwen2.5-14B-Instruct-GGUF": 20000,
78
82
  "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF": 20000,
@@ -106,9 +110,12 @@ class InformationCollectionIteration:
106
110
 
107
111
 
108
112
  def construct_iteration_history(
109
- query: str, previous_iterations: List[InformationCollectionIteration], previous_iteration_prompt: str
113
+ previous_iterations: List[InformationCollectionIteration],
114
+ previous_iteration_prompt: str,
115
+ query: str = None,
110
116
  ) -> list[dict]:
111
- previous_iterations_history = []
117
+ iteration_history: list[dict] = []
118
+ previous_iteration_messages: list[dict] = []
112
119
  for idx, iteration in enumerate(previous_iterations):
113
120
  iteration_data = previous_iteration_prompt.format(
114
121
  tool=iteration.tool,
@@ -117,23 +124,19 @@ def construct_iteration_history(
117
124
  index=idx + 1,
118
125
  )
119
126
 
120
- previous_iterations_history.append(iteration_data)
127
+ previous_iteration_messages.append({"type": "text", "text": iteration_data})
121
128
 
122
- return (
123
- [
124
- {
125
- "by": "you",
126
- "message": query,
127
- },
129
+ if previous_iteration_messages:
130
+ if query:
131
+ iteration_history.append({"by": "you", "message": query})
132
+ iteration_history.append(
128
133
  {
129
134
  "by": "khoj",
130
135
  "intent": {"type": "remember", "query": query},
131
- "message": previous_iterations_history,
132
- },
133
- ]
134
- if previous_iterations_history
135
- else []
136
- )
136
+ "message": previous_iteration_messages,
137
+ }
138
+ )
139
+ return iteration_history
137
140
 
138
141
 
139
142
  def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
@@ -281,6 +284,7 @@ async def save_to_conversation_log(
281
284
  generated_images: List[str] = [],
282
285
  raw_generated_files: List[FileAttachment] = [],
283
286
  generated_mermaidjs_diagram: str = None,
287
+ research_results: Optional[List[InformationCollectionIteration]] = None,
284
288
  train_of_thought: List[Any] = [],
285
289
  tracer: Dict[str, Any] = {},
286
290
  ):
@@ -298,6 +302,7 @@ async def save_to_conversation_log(
298
302
  "onlineContext": online_results,
299
303
  "codeContext": code_results,
300
304
  "operatorContext": operator_results,
305
+ "researchContext": [vars(r) for r in research_results] if research_results and not chat_response else None,
301
306
  "automationId": automation_id,
302
307
  "trainOfThought": train_of_thought,
303
308
  "turnId": turn_id,
@@ -337,7 +342,7 @@ Khoj: "{chat_response}"
337
342
 
338
343
 
339
344
  def construct_structured_message(
340
- message: list[str] | str,
345
+ message: list[dict] | str,
341
346
  images: list[str],
342
347
  model_type: str,
343
348
  vision_enabled: bool,
@@ -351,11 +356,9 @@ def construct_structured_message(
351
356
  ChatModel.ModelType.GOOGLE,
352
357
  ChatModel.ModelType.ANTHROPIC,
353
358
  ]:
354
- message = [message] if isinstance(message, str) else message
355
-
356
- constructed_messages: List[dict[str, Any]] = [
357
- {"type": "text", "text": message_part} for message_part in message
358
- ]
359
+ constructed_messages: List[dict[str, Any]] = (
360
+ [{"type": "text", "text": message}] if isinstance(message, str) else message
361
+ )
359
362
 
360
363
  if not is_none_or_empty(attached_file_context):
361
364
  constructed_messages.append({"type": "text", "text": attached_file_context})
@@ -364,6 +367,7 @@ def construct_structured_message(
364
367
  constructed_messages.append({"type": "image_url", "image_url": {"url": image}})
365
368
  return constructed_messages
366
369
 
370
+ message = message if isinstance(message, str) else "\n\n".join(m["text"] for m in message)
367
371
  if not is_none_or_empty(attached_file_context):
368
372
  return f"{attached_file_context}\n\n{message}"
369
373
 
@@ -387,7 +391,7 @@ def gather_raw_query_files(
387
391
 
388
392
 
389
393
  def generate_chatml_messages_with_context(
390
- user_message,
394
+ user_message: str,
391
395
  system_message: str = None,
392
396
  conversation_log={},
393
397
  model_name="gpt-4o-mini",
@@ -417,7 +421,7 @@ def generate_chatml_messages_with_context(
417
421
  # Extract Chat History for Context
418
422
  chatml_messages: List[ChatMessage] = []
419
423
  for chat in conversation_log.get("chat", []):
420
- message_context = ""
424
+ message_context = []
421
425
  message_attached_files = ""
422
426
 
423
427
  generated_assets = {}
@@ -429,16 +433,6 @@ def generate_chatml_messages_with_context(
429
433
  if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""):
430
434
  chat_message = chat["intent"].get("inferred-queries")[0]
431
435
 
432
- if not is_none_or_empty(chat.get("context")):
433
- references = "\n\n".join(
434
- {
435
- f"# File: {item['file']}\n## {item['compiled']}\n"
436
- for item in chat.get("context") or []
437
- if isinstance(item, dict)
438
- }
439
- )
440
- message_context += f"{prompts.notes_conversation.format(references=references)}\n\n"
441
-
442
436
  if chat.get("queryFiles"):
443
437
  raw_query_files = chat.get("queryFiles")
444
438
  query_files_dict = dict()
@@ -449,15 +443,38 @@ def generate_chatml_messages_with_context(
449
443
  chatml_messages.append(ChatMessage(content=message_attached_files, role=role))
450
444
 
451
445
  if not is_none_or_empty(chat.get("onlineContext")):
452
- message_context += f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}"
446
+ message_context += [
447
+ {
448
+ "type": "text",
449
+ "text": f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}",
450
+ }
451
+ ]
453
452
 
454
453
  if not is_none_or_empty(chat.get("codeContext")):
455
- message_context += f"{prompts.code_executed_context.format(code_results=chat.get('codeContext'))}"
454
+ message_context += [
455
+ {
456
+ "type": "text",
457
+ "text": f"{prompts.code_executed_context.format(code_results=chat.get('codeContext'))}",
458
+ }
459
+ ]
456
460
 
457
461
  if not is_none_or_empty(chat.get("operatorContext")):
458
- message_context += (
459
- f"{prompts.operator_execution_context.format(operator_results=chat.get('operatorContext'))}"
462
+ message_context += [
463
+ {
464
+ "type": "text",
465
+ "text": f"{prompts.operator_execution_context.format(operator_results=chat.get('operatorContext'))}",
466
+ }
467
+ ]
468
+
469
+ if not is_none_or_empty(chat.get("context")):
470
+ references = "\n\n".join(
471
+ {
472
+ f"# File: {item['file']}\n## {item['compiled']}\n"
473
+ for item in chat.get("context") or []
474
+ if isinstance(item, dict)
475
+ }
460
476
  )
477
+ message_context += [{"type": "text", "text": f"{prompts.notes_conversation.format(references=references)}"}]
461
478
 
462
479
  if not is_none_or_empty(message_context):
463
480
  reconstructed_context_message = ChatMessage(content=message_context, role="user")
@@ -697,8 +714,9 @@ def clean_code_python(code: str):
697
714
 
698
715
  def load_complex_json(json_str):
699
716
  """
700
- Preprocess a raw JSON string to escape unescaped double quotes within value strings,
701
- while preserving the JSON structure and already escaped quotes.
717
+ Preprocess a raw JSON string to
718
+ - escape unescaped double quotes within value strings while preserving the JSON structure and already escaped quotes.
719
+ - remove suffix after the first valid JSON object,
702
720
  """
703
721
 
704
722
  def replace_unescaped_quotes(match):
@@ -726,9 +744,20 @@ def load_complex_json(json_str):
726
744
  for loads in json_loaders_to_try:
727
745
  try:
728
746
  return loads(processed)
729
- except (json.JSONDecodeError, pyjson5.Json5Exception) as e:
730
- errors.append(f"{type(e).__name__}: {str(e)}")
731
-
747
+ except (json.JSONDecodeError, pyjson5.Json5Exception) as e_load:
748
+ loader_name = loads.__name__
749
+ errors.append(f"{loader_name} (initial parse): {type(e_load).__name__}: {str(e_load)}")
750
+
751
+ # Handle plain text suffixes by slicing at error position
752
+ if hasattr(e_load, "pos") and 0 < e_load.pos < len(processed):
753
+ try:
754
+ sliced = processed[: e_load.pos].strip()
755
+ if sliced:
756
+ return loads(sliced)
757
+ except Exception as e_slice:
758
+ errors.append(
759
+ f"{loader_name} after slice at {e_load.pos}: {type(e_slice).__name__}: {str(e_slice)}"
760
+ )
732
761
  # If all loaders fail, raise the aggregated error
733
762
  raise ValueError(
734
763
  f"Failed to load JSON with errors: {'; '.join(errors)}\n\n"
@@ -13,7 +13,7 @@ from io import BytesIO
13
13
  from typing import Any, List
14
14
 
15
15
  import numpy as np
16
- from openai import AzureOpenAI, OpenAI
16
+ from openai import AsyncAzureOpenAI, AsyncOpenAI
17
17
  from openai.types.chat import ChatCompletion
18
18
  from PIL import Image
19
19
 
@@ -72,7 +72,7 @@ class GroundingAgentUitars:
72
72
  def __init__(
73
73
  self,
74
74
  model_name: str,
75
- client: OpenAI | AzureOpenAI,
75
+ client: AsyncOpenAI | AsyncAzureOpenAI,
76
76
  max_iterations=50,
77
77
  environment_type: Literal["computer", "web"] = "computer",
78
78
  runtime_conf: dict = {
@@ -4,8 +4,6 @@ import logging
4
4
  import os
5
5
  from typing import Callable, List, Optional
6
6
 
7
- import requests
8
-
9
7
  from khoj.database.adapters import AgentAdapters, ConversationAdapters
10
8
  from khoj.database.models import Agent, ChatModel, KhojUser
11
9
  from khoj.processor.operator.operator_actions import *
@@ -49,9 +47,9 @@ async def operate_browser(
49
47
  # Initialize Agent
50
48
  max_iterations = int(os.getenv("KHOJ_OPERATOR_ITERATIONS", 40))
51
49
  operator_agent: OperatorAgent
52
- if reasoning_model.name.startswith("gpt-4o"):
50
+ if is_operator_model(reasoning_model.name) == ChatModel.ModelType.OPENAI:
53
51
  operator_agent = OpenAIOperatorAgent(query, reasoning_model, max_iterations, tracer)
54
- elif reasoning_model.name.startswith("claude-3-7-sonnet"):
52
+ elif is_operator_model(reasoning_model.name) == ChatModel.ModelType.ANTHROPIC:
55
53
  operator_agent = AnthropicOperatorAgent(query, reasoning_model, max_iterations, tracer)
56
54
  else:
57
55
  grounding_model_name = "ui-tars-1.5"
@@ -150,3 +148,18 @@ async def operate_browser(
150
148
  "result": user_input_message or response,
151
149
  "webpages": [{"link": url, "snippet": ""} for url in environment.visited_urls],
152
150
  }
151
+
152
+
153
+ def is_operator_model(model: str) -> ChatModel.ModelType | None:
154
+ """Check if the model is an operator model."""
155
+ operator_models = {
156
+ "gpt-4o": ChatModel.ModelType.OPENAI,
157
+ "claude-3-7-sonnet": ChatModel.ModelType.ANTHROPIC,
158
+ "claude-sonnet-4": ChatModel.ModelType.ANTHROPIC,
159
+ "claude-opus-4": ChatModel.ModelType.ANTHROPIC,
160
+ "ui-tars-1.5": ChatModel.ModelType.OFFLINE,
161
+ }
162
+ for operator_model in operator_models:
163
+ if model.startswith(operator_model):
164
+ return operator_models[operator_model] # type: ignore[return-value]
165
+ return None
@@ -3,10 +3,11 @@ import json
3
3
  import logging
4
4
  from copy import deepcopy
5
5
  from datetime import datetime
6
- from typing import Any, List, Optional, cast
6
+ from typing import List, Optional, cast
7
7
 
8
8
  from anthropic.types.beta import BetaContentBlock
9
9
 
10
+ from khoj.processor.conversation.anthropic.utils import is_reasoning_model
10
11
  from khoj.processor.operator.operator_actions import *
11
12
  from khoj.processor.operator.operator_agent_base import (
12
13
  AgentActResult,
@@ -25,8 +26,7 @@ class AnthropicOperatorAgent(OperatorAgent):
25
26
  client = get_anthropic_async_client(
26
27
  self.vision_model.ai_model_api.api_key, self.vision_model.ai_model_api.api_base_url
27
28
  )
28
- tool_version = "2025-01-24"
29
- betas = [f"computer-use-{tool_version}", "token-efficient-tools-2025-02-19"]
29
+ betas = self.model_default_headers()
30
30
  temperature = 1.0
31
31
  actions: List[OperatorAction] = []
32
32
  action_results: List[dict] = []
@@ -56,7 +56,7 @@ class AnthropicOperatorAgent(OperatorAgent):
56
56
 
57
57
  tools = [
58
58
  {
59
- "type": f"computer_20250124",
59
+ "type": self.model_default_tool("computer"),
60
60
  "name": "computer",
61
61
  "display_width_px": 1024,
62
62
  "display_height_px": 768,
@@ -78,7 +78,7 @@ class AnthropicOperatorAgent(OperatorAgent):
78
78
  ]
79
79
 
80
80
  thinking: dict[str, str | int] = {"type": "disabled"}
81
- if self.vision_model.name.startswith("claude-3-7"):
81
+ if is_reasoning_model(self.vision_model.name):
82
82
  thinking = {"type": "enabled", "budget_tokens": 1024}
83
83
 
84
84
  messages_for_api = self._format_message_for_api(self.messages)
@@ -381,3 +381,22 @@ class AnthropicOperatorAgent(OperatorAgent):
381
381
  return None
382
382
 
383
383
  return coord
384
+
385
+ def model_default_tool(self, tool_type: Literal["computer", "editor", "terminal"]) -> str:
386
+ """Get the default tool of specified type for the given model."""
387
+ if self.vision_model.name.startswith("claude-3-7-sonnet"):
388
+ if tool_type == "computer":
389
+ return "computer_20250124"
390
+ elif self.vision_model.name.startswith("claude-sonnet-4") or self.vision_model.name.startswith("claude-opus-4"):
391
+ if tool_type == "computer":
392
+ return "computer_20250124"
393
+ raise ValueError(f"Unsupported tool type for model '{self.vision_model.name}': {tool_type}")
394
+
395
+ def model_default_headers(self) -> list[str]:
396
+ """Get the default computer use headers for the given model."""
397
+ if self.vision_model.name.startswith("claude-3-7-sonnet"):
398
+ return [f"computer-use-2025-01-24", "token-efficient-tools-2025-02-19"]
399
+ elif self.vision_model.name.startswith("claude-sonnet-4") or self.vision_model.name.startswith("claude-opus-4"):
400
+ return ["computer-use-2025-01-24"]
401
+ else:
402
+ return []
khoj/routers/api_chat.py CHANGED
@@ -682,11 +682,13 @@ async def chat(
682
682
  timezone = body.timezone
683
683
  raw_images = body.images
684
684
  raw_query_files = body.files
685
+ interrupt_flag = body.interrupt
685
686
 
686
687
  async def event_generator(q: str, images: list[str]):
687
688
  start_time = time.perf_counter()
688
689
  ttft = None
689
690
  chat_metadata: dict = {}
691
+ conversation = None
690
692
  user: KhojUser = request.user.object
691
693
  is_subscribed = has_required_scope(request, ["premium"])
692
694
  q = unquote(q)
@@ -720,6 +722,20 @@ async def chat(
720
722
  for file in raw_query_files:
721
723
  query_files[file.name] = file.content
722
724
 
725
+ research_results: List[InformationCollectionIteration] = []
726
+ online_results: Dict = dict()
727
+ code_results: Dict = dict()
728
+ operator_results: Dict[str, str] = {}
729
+ compiled_references: List[Any] = []
730
+ inferred_queries: List[Any] = []
731
+ attached_file_context = gather_raw_query_files(query_files)
732
+
733
+ generated_images: List[str] = []
734
+ generated_files: List[FileAttachment] = []
735
+ generated_mermaidjs_diagram: str = None
736
+ generated_asset_results: Dict = dict()
737
+ program_execution_context: List[str] = []
738
+
723
739
  # Create a task to monitor for disconnections
724
740
  disconnect_monitor_task = None
725
741
 
@@ -727,8 +743,34 @@ async def chat(
727
743
  try:
728
744
  msg = await request.receive()
729
745
  if msg["type"] == "http.disconnect":
730
- logger.debug(f"User {user} disconnected from {common.client} client.")
746
+ logger.debug(f"Request cancelled. User {user} disconnected from {common.client} client.")
731
747
  cancellation_event.set()
748
+ # ensure partial chat state saved on interrupt
749
+ # shield the save against task cancellation
750
+ if conversation:
751
+ await asyncio.shield(
752
+ save_to_conversation_log(
753
+ q,
754
+ chat_response="",
755
+ user=user,
756
+ meta_log=meta_log,
757
+ compiled_references=compiled_references,
758
+ online_results=online_results,
759
+ code_results=code_results,
760
+ operator_results=operator_results,
761
+ research_results=research_results,
762
+ inferred_queries=inferred_queries,
763
+ client_application=request.user.client_app,
764
+ conversation_id=conversation_id,
765
+ query_images=uploaded_images,
766
+ train_of_thought=train_of_thought,
767
+ raw_query_files=raw_query_files,
768
+ generated_images=generated_images,
769
+ raw_generated_files=generated_asset_results,
770
+ generated_mermaidjs_diagram=generated_mermaidjs_diagram,
771
+ tracer=tracer,
772
+ )
773
+ )
732
774
  except Exception as e:
733
775
  logger.error(f"Error in disconnect monitor: {e}")
734
776
 
@@ -746,7 +788,6 @@ async def chat(
746
788
  nonlocal ttft, train_of_thought
747
789
  event_delimiter = "␃🔚␗"
748
790
  if cancellation_event.is_set():
749
- logger.debug(f"User {user} disconnected from {common.client} client. Setting cancellation event.")
750
791
  return
751
792
  try:
752
793
  if event_type == ChatEvent.END_LLM_RESPONSE:
@@ -770,9 +811,6 @@ async def chat(
770
811
  yield data
771
812
  elif event_type == ChatEvent.REFERENCES or ChatEvent.METADATA or stream:
772
813
  yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False)
773
- except asyncio.CancelledError as e:
774
- if cancellation_event.is_set():
775
- logger.debug(f"Request cancelled. User {user} disconnected from {common.client} client: {e}.")
776
814
  except Exception as e:
777
815
  if not cancellation_event.is_set():
778
816
  logger.error(
@@ -860,9 +898,9 @@ async def chat(
860
898
  async for result in send_llm_response(f"Conversation {conversation_id} not found", tracer.get("usage")):
861
899
  yield result
862
900
  return
863
- conversation_id = conversation.id
901
+ conversation_id = str(conversation.id)
864
902
 
865
- async for event in send_event(ChatEvent.METADATA, {"conversationId": str(conversation_id), "turnId": turn_id}):
903
+ async for event in send_event(ChatEvent.METADATA, {"conversationId": conversation_id, "turnId": turn_id}):
866
904
  yield event
867
905
 
868
906
  agent: Agent | None = None
@@ -883,21 +921,53 @@ async def chat(
883
921
  user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
884
922
  meta_log = conversation.conversation_log
885
923
 
886
- researched_results = ""
887
- online_results: Dict = dict()
888
- code_results: Dict = dict()
889
- operator_results: Dict[str, str] = {}
890
- generated_asset_results: Dict = dict()
891
- ## Extract Document References
892
- compiled_references: List[Any] = []
893
- inferred_queries: List[Any] = []
894
- file_filters = conversation.file_filters if conversation and conversation.file_filters else []
895
- attached_file_context = gather_raw_query_files(query_files)
924
+ # If interrupt flag is set, wait for the previous turn to be saved before proceeding
925
+ if interrupt_flag:
926
+ max_wait_time = 20.0 # seconds
927
+ wait_interval = 0.3 # seconds
928
+ wait_start = wait_current = time.time()
929
+ while wait_current - wait_start < max_wait_time:
930
+ # Refresh conversation to check if interrupted message saved to DB
931
+ conversation = await ConversationAdapters.aget_conversation_by_user(
932
+ user,
933
+ client_application=request.user.client_app,
934
+ conversation_id=conversation_id,
935
+ )
936
+ if (
937
+ conversation
938
+ and conversation.messages
939
+ and conversation.messages[-1].by == "khoj"
940
+ and not conversation.messages[-1].message
941
+ ):
942
+ logger.info(f"Detected interrupted message save to conversation {conversation_id}.")
943
+ break
944
+ await asyncio.sleep(wait_interval)
945
+ wait_current = time.time()
896
946
 
897
- generated_images: List[str] = []
898
- generated_files: List[FileAttachment] = []
899
- generated_mermaidjs_diagram: str = None
900
- program_execution_context: List[str] = []
947
+ if wait_current - wait_start >= max_wait_time:
948
+ logger.warning(
949
+ f"Timeout waiting to load interrupted context from conversation {conversation_id}. Proceed without previous context."
950
+ )
951
+
952
+ # If interrupted message in DB
953
+ if (
954
+ conversation
955
+ and conversation.messages
956
+ and conversation.messages[-1].by == "khoj"
957
+ and not conversation.messages[-1].message
958
+ ):
959
+ # Populate context from interrupted message
960
+ last_message = conversation.messages[-1]
961
+ online_results = {key: val.model_dump() for key, val in last_message.onlineContext.items() or []}
962
+ code_results = {key: val.model_dump() for key, val in last_message.codeContext.items() or []}
963
+ operator_results = last_message.operatorContext or {}
964
+ compiled_references = [ref.model_dump() for ref in last_message.context or []]
965
+ research_results = [
966
+ InformationCollectionIteration(**iter_dict) for iter_dict in last_message.researchContext or []
967
+ ]
968
+ # Drop the interrupted message from conversation history
969
+ meta_log["chat"].pop()
970
+ logger.info(f"Loaded interrupted partial context from conversation {conversation_id}.")
901
971
 
902
972
  if conversation_commands == [ConversationCommand.Default]:
903
973
  try:
@@ -936,6 +1006,7 @@ async def chat(
936
1006
  return
937
1007
 
938
1008
  defiltered_query = defilter_query(q)
1009
+ file_filters = conversation.file_filters if conversation and conversation.file_filters else []
939
1010
 
940
1011
  if conversation_commands == [ConversationCommand.Research]:
941
1012
  async for research_result in execute_information_collection(
@@ -943,12 +1014,13 @@ async def chat(
943
1014
  query=defiltered_query,
944
1015
  conversation_id=conversation_id,
945
1016
  conversation_history=meta_log,
1017
+ previous_iterations=research_results,
946
1018
  query_images=uploaded_images,
947
1019
  agent=agent,
948
1020
  send_status_func=partial(send_event, ChatEvent.STATUS),
949
1021
  user_name=user_name,
950
1022
  location=location,
951
- file_filters=conversation.file_filters if conversation else [],
1023
+ file_filters=file_filters,
952
1024
  query_files=attached_file_context,
953
1025
  tracer=tracer,
954
1026
  cancellation_event=cancellation_event,
@@ -963,17 +1035,16 @@ async def chat(
963
1035
  compiled_references.extend(research_result.context)
964
1036
  if research_result.operatorContext:
965
1037
  operator_results.update(research_result.operatorContext)
966
- researched_results += research_result.summarizedResult
1038
+ research_results.append(research_result)
967
1039
 
968
1040
  else:
969
1041
  yield research_result
970
1042
 
971
1043
  # researched_results = await extract_relevant_info(q, researched_results, agent)
972
1044
  if state.verbose > 1:
973
- logger.debug(f"Researched Results: {researched_results}")
1045
+ logger.debug(f'Researched Results: {"".join(r.summarizedResult for r in research_results)}')
974
1046
 
975
1047
  used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
976
- file_filters = conversation.file_filters if conversation else []
977
1048
  # Skip trying to summarize if
978
1049
  if (
979
1050
  # summarization intent was inferred
@@ -1362,7 +1433,7 @@ async def chat(
1362
1433
 
1363
1434
  # Check if the user has disconnected
1364
1435
  if cancellation_event.is_set():
1365
- logger.debug(f"User {user} disconnected from {common.client} client. Stopping LLM response.")
1436
+ logger.debug(f"Stopping LLM response to user {user} on {common.client} client.")
1366
1437
  # Cancel the disconnect monitor task if it is still running
1367
1438
  await cancel_disconnect_monitor()
1368
1439
  return
@@ -1379,14 +1450,13 @@ async def chat(
1379
1450
  online_results,
1380
1451
  code_results,
1381
1452
  operator_results,
1453
+ research_results,
1382
1454
  inferred_queries,
1383
1455
  conversation_commands,
1384
1456
  user,
1385
1457
  request.user.client_app,
1386
- conversation_id,
1387
1458
  location,
1388
1459
  user_name,
1389
- researched_results,
1390
1460
  uploaded_images,
1391
1461
  train_of_thought,
1392
1462
  attached_file_context,
khoj/routers/api_model.py CHANGED
@@ -72,7 +72,7 @@ async def update_chat_model(
72
72
  if chat_model is None:
73
73
  return Response(status_code=404, content=json.dumps({"status": "error", "message": "Chat model not found"}))
74
74
  if not subscribed and chat_model.price_tier != PriceTier.FREE:
75
- raise Response(
75
+ return Response(
76
76
  status_code=403,
77
77
  content=json.dumps({"status": "error", "message": "Subscribe to switch to this chat model"}),
78
78
  )
@@ -108,7 +108,7 @@ async def update_voice_model(
108
108
  if voice_model is None:
109
109
  return Response(status_code=404, content=json.dumps({"status": "error", "message": "Voice model not found"}))
110
110
  if not subscribed and voice_model.price_tier != PriceTier.FREE:
111
- raise Response(
111
+ return Response(
112
112
  status_code=403,
113
113
  content=json.dumps({"status": "error", "message": "Subscribe to switch to this voice model"}),
114
114
  )
@@ -143,7 +143,7 @@ async def update_paint_model(
143
143
  if image_model is None:
144
144
  return Response(status_code=404, content=json.dumps({"status": "error", "message": "Image model not found"}))
145
145
  if not subscribed and image_model.price_tier != PriceTier.FREE:
146
- raise Response(
146
+ return Response(
147
147
  status_code=403,
148
148
  content=json.dumps({"status": "error", "message": "Subscribe to switch to this image model"}),
149
149
  )