khoj 1.27.2.dev12__py3-none-any.whl → 1.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. khoj/configure.py +1 -1
  2. khoj/database/adapters/__init__.py +55 -12
  3. khoj/interface/compiled/404/index.html +1 -1
  4. khoj/interface/compiled/_next/static/chunks/1034-da58b679fcbb79c1.js +1 -0
  5. khoj/interface/compiled/_next/static/chunks/1467-b331e469fe411347.js +1 -0
  6. khoj/interface/compiled/_next/static/chunks/1603-c1568f45947e9f2c.js +1 -0
  7. khoj/interface/compiled/_next/static/chunks/3423-ff7402ae1dd66592.js +1 -0
  8. khoj/interface/compiled/_next/static/chunks/8423-e80647edf6c92c27.js +1 -0
  9. khoj/interface/compiled/_next/static/chunks/app/agents/{page-2beaba7c9bb750bd.js → page-fc492762298e975e.js} +1 -1
  10. khoj/interface/compiled/_next/static/chunks/app/automations/{page-9b5c77e0b0dd772c.js → page-416ee13a00575c39.js} +1 -1
  11. khoj/interface/compiled/_next/static/chunks/app/chat/page-c70f5b0c722d7627.js +1 -0
  12. khoj/interface/compiled/_next/static/chunks/app/factchecker/page-1541d90140794f63.js +1 -0
  13. khoj/interface/compiled/_next/static/chunks/app/{page-8f22b790e50dd722.js → page-b269e444fc067759.js} +1 -1
  14. khoj/interface/compiled/_next/static/chunks/app/search/{page-ab2995529ece3140.js → page-7d431ce8e565c7c3.js} +1 -1
  15. khoj/interface/compiled/_next/static/chunks/app/settings/{page-7946cabb9c54e22d.js → page-95f56e53f48f0289.js} +1 -1
  16. khoj/interface/compiled/_next/static/chunks/app/share/chat/{page-6a01e07fb244c10c.js → page-4eba6154f7bb9771.js} +1 -1
  17. khoj/interface/compiled/_next/static/chunks/{webpack-17202cfae517c5de.js → webpack-33a82ccca02cd2b8.js} +1 -1
  18. khoj/interface/compiled/_next/static/css/2196fae09c2f906e.css +1 -0
  19. khoj/interface/compiled/_next/static/css/6bde1f2045622ef7.css +1 -0
  20. khoj/interface/compiled/_next/static/css/a795ee88875f4853.css +25 -0
  21. khoj/interface/compiled/_next/static/css/ebef43da1c0651d5.css +1 -0
  22. khoj/interface/compiled/agents/index.html +1 -1
  23. khoj/interface/compiled/agents/index.txt +2 -2
  24. khoj/interface/compiled/automations/index.html +1 -1
  25. khoj/interface/compiled/automations/index.txt +2 -2
  26. khoj/interface/compiled/chat/index.html +1 -1
  27. khoj/interface/compiled/chat/index.txt +2 -2
  28. khoj/interface/compiled/factchecker/index.html +1 -1
  29. khoj/interface/compiled/factchecker/index.txt +2 -2
  30. khoj/interface/compiled/index.html +1 -1
  31. khoj/interface/compiled/index.txt +2 -2
  32. khoj/interface/compiled/search/index.html +1 -1
  33. khoj/interface/compiled/search/index.txt +2 -2
  34. khoj/interface/compiled/settings/index.html +1 -1
  35. khoj/interface/compiled/settings/index.txt +2 -2
  36. khoj/interface/compiled/share/chat/index.html +1 -1
  37. khoj/interface/compiled/share/chat/index.txt +2 -2
  38. khoj/processor/conversation/anthropic/anthropic_chat.py +19 -10
  39. khoj/processor/conversation/anthropic/utils.py +37 -6
  40. khoj/processor/conversation/google/gemini_chat.py +23 -13
  41. khoj/processor/conversation/google/utils.py +34 -10
  42. khoj/processor/conversation/offline/chat_model.py +48 -16
  43. khoj/processor/conversation/openai/gpt.py +25 -10
  44. khoj/processor/conversation/openai/utils.py +50 -9
  45. khoj/processor/conversation/prompts.py +156 -65
  46. khoj/processor/conversation/utils.py +306 -6
  47. khoj/processor/embeddings.py +4 -4
  48. khoj/processor/image/generate.py +2 -0
  49. khoj/processor/tools/online_search.py +27 -12
  50. khoj/processor/tools/run_code.py +144 -0
  51. khoj/routers/api.py +11 -6
  52. khoj/routers/api_chat.py +213 -111
  53. khoj/routers/helpers.py +171 -60
  54. khoj/routers/research.py +320 -0
  55. khoj/search_filter/date_filter.py +1 -3
  56. khoj/search_filter/file_filter.py +1 -2
  57. khoj/search_type/text_search.py +3 -3
  58. khoj/utils/helpers.py +25 -3
  59. khoj/utils/yaml.py +4 -0
  60. {khoj-1.27.2.dev12.dist-info → khoj-1.28.1.dist-info}/METADATA +3 -2
  61. {khoj-1.27.2.dev12.dist-info → khoj-1.28.1.dist-info}/RECORD +68 -65
  62. khoj/interface/compiled/_next/static/chunks/1603-b9d95833e0e025e8.js +0 -1
  63. khoj/interface/compiled/_next/static/chunks/2697-61fcba89fd87eab4.js +0 -1
  64. khoj/interface/compiled/_next/static/chunks/3423-8e9c420574a9fbe3.js +0 -1
  65. khoj/interface/compiled/_next/static/chunks/9479-4b443fdcc99141c9.js +0 -1
  66. khoj/interface/compiled/_next/static/chunks/app/chat/page-151232d8417a1ea1.js +0 -1
  67. khoj/interface/compiled/_next/static/chunks/app/factchecker/page-798904432c2417c4.js +0 -1
  68. khoj/interface/compiled/_next/static/css/2272c73fc7a3b571.css +0 -1
  69. khoj/interface/compiled/_next/static/css/553f9cdcc7a2bcd6.css +0 -1
  70. khoj/interface/compiled/_next/static/css/76d55eb435962b19.css +0 -25
  71. khoj/interface/compiled/_next/static/css/b70402177a7c3207.css +0 -1
  72. /khoj/interface/compiled/_next/static/{kul3DNllWR6eaUDc4X0eU → JcTomiF3o0dIo4RxHR9Vu}/_buildManifest.js +0 -0
  73. /khoj/interface/compiled/_next/static/{kul3DNllWR6eaUDc4X0eU → JcTomiF3o0dIo4RxHR9Vu}/_ssgManifest.js +0 -0
  74. /khoj/interface/compiled/_next/static/chunks/{1970-1d6d0c1b00b4f343.js → 1970-90dd510762d820ba.js} +0 -0
  75. /khoj/interface/compiled/_next/static/chunks/{9417-759984ad62caa3dc.js → 9417-951f46451a8dd6d7.js} +0 -0
  76. {khoj-1.27.2.dev12.dist-info → khoj-1.28.1.dist-info}/WHEEL +0 -0
  77. {khoj-1.27.2.dev12.dist-info → khoj-1.28.1.dist-info}/entry_points.txt +0 -0
  78. {khoj-1.27.2.dev12.dist-info → khoj-1.28.1.dist-info}/licenses/LICENSE +0 -0

khoj/processor/conversation/utils.py
@@ -1,17 +1,23 @@
  import base64
+ import json
  import logging
  import math
  import mimetypes
+ import os
  import queue
+ import uuid
  from dataclasses import dataclass
  from datetime import datetime
+ from enum import Enum
  from io import BytesIO
  from time import perf_counter
- from typing import Any, Dict, List, Optional
+ from typing import Any, Callable, Dict, List, Optional

  import PIL.Image
  import requests
  import tiktoken
+ import yaml
+ from git import Repo
  from langchain.schema import ChatMessage
  from llama_cpp.llama import Llama
  from transformers import AutoTokenizer
@@ -20,8 +26,17 @@ from khoj.database.adapters import ConversationAdapters
  from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser
  from khoj.processor.conversation import prompts
  from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
+ from khoj.search_filter.base_filter import BaseFilter
+ from khoj.search_filter.date_filter import DateFilter
+ from khoj.search_filter.file_filter import FileFilter
+ from khoj.search_filter.word_filter import WordFilter
  from khoj.utils import state
- from khoj.utils.helpers import is_none_or_empty, merge_dicts
+ from khoj.utils.helpers import (
+     ConversationCommand,
+     in_debug_mode,
+     is_none_or_empty,
+     merge_dicts,
+ )

  logger = logging.getLogger(__name__)
  model_to_prompt_size = {
@@ -82,8 +97,110 @@ class ThreadedGenerator:
          self.queue.put(StopIteration)


+ class InformationCollectionIteration:
+     def __init__(
+         self,
+         tool: str,
+         query: str,
+         context: list = None,
+         onlineContext: dict = None,
+         codeContext: dict = None,
+         summarizedResult: str = None,
+     ):
+         self.tool = tool
+         self.query = query
+         self.context = context
+         self.onlineContext = onlineContext
+         self.codeContext = codeContext
+         self.summarizedResult = summarizedResult
+
+
+ def construct_iteration_history(
+     previous_iterations: List[InformationCollectionIteration], previous_iteration_prompt: str
+ ) -> str:
+     previous_iterations_history = ""
+     for idx, iteration in enumerate(previous_iterations):
+         iteration_data = previous_iteration_prompt.format(
+             tool=iteration.tool,
+             query=iteration.query,
+             result=iteration.summarizedResult,
+             index=idx + 1,
+         )
+
+         previous_iterations_history += iteration_data
+     return previous_iterations_history
+
+
+ def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
+     chat_history = ""
+     for chat in conversation_history.get("chat", [])[-n:]:
+         if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder", "summarize"]:
+             chat_history += f"User: {chat['intent']['query']}\n"
+
+             if chat["intent"].get("inferred-queries"):
+                 chat_history += f'Khoj: {{"queries": {chat["intent"].get("inferred-queries")}}}\n'
+
+             chat_history += f"{agent_name}: {chat['message']}\n\n"
+         elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")):
+             chat_history += f"User: {chat['intent']['query']}\n"
+             chat_history += f"{agent_name}: [generated image redacted for space]\n"
+         elif chat["by"] == "khoj" and ("excalidraw" in chat["intent"].get("type")):
+             chat_history += f"User: {chat['intent']['query']}\n"
+             chat_history += f"{agent_name}: {chat['intent']['inferred-queries'][0]}\n"
+     return chat_history
+
+
+ def construct_tool_chat_history(
+     previous_iterations: List[InformationCollectionIteration], tool: ConversationCommand = None
+ ) -> Dict[str, list]:
+     chat_history: list = []
+     inferred_query_extractor: Callable[[InformationCollectionIteration], List[str]] = lambda x: []
+     if tool == ConversationCommand.Notes:
+         inferred_query_extractor = (
+             lambda iteration: [c["query"] for c in iteration.context] if iteration.context else []
+         )
+     elif tool == ConversationCommand.Online:
+         inferred_query_extractor = (
+             lambda iteration: list(iteration.onlineContext.keys()) if iteration.onlineContext else []
+         )
+     elif tool == ConversationCommand.Code:
+         inferred_query_extractor = lambda iteration: list(iteration.codeContext.keys()) if iteration.codeContext else []
+     for iteration in previous_iterations:
+         chat_history += [
+             {
+                 "by": "you",
+                 "message": iteration.query,
+             },
+             {
+                 "by": "khoj",
+                 "intent": {
+                     "type": "remember",
+                     "inferred-queries": inferred_query_extractor(iteration),
+                     "query": iteration.query,
+                 },
+                 "message": iteration.summarizedResult,
+             },
+         ]
+
+     return {"chat": chat_history}
+
+
+ class ChatEvent(Enum):
+     START_LLM_RESPONSE = "start_llm_response"
+     END_LLM_RESPONSE = "end_llm_response"
+     MESSAGE = "message"
+     REFERENCES = "references"
+     STATUS = "status"
+     METADATA = "metadata"
+
+
  def message_to_log(
-     user_message, chat_response, user_message_metadata={}, khoj_message_metadata={}, conversation_log=[]
+     user_message,
+     chat_response,
+     user_message_metadata={},
+     khoj_message_metadata={},
+     conversation_log=[],
+     train_of_thought=[],
  ):
      """Create json logs from messages, metadata for conversation log"""
      default_khoj_message_metadata = {
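The block above is the scaffolding for Khoj's new research mode: InformationCollectionIteration records one tool invocation, construct_iteration_history and construct_tool_chat_history re-package prior iterations as prompt context, and ChatEvent enumerates the streaming event types. A minimal sketch of how these pieces fit together, assuming khoj 1.28.1 is importable; the query and result values are made up for illustration:

```python
# Sketch only: exercises the helpers defined in the hunk above.
from khoj.processor.conversation.utils import (
    InformationCollectionIteration,
    construct_tool_chat_history,
)
from khoj.utils.helpers import ConversationCommand

# One completed "online search" step of the research loop (placeholder values)
iteration = InformationCollectionIteration(
    tool="online",
    query="latest khoj release notes",
    onlineContext={"https://docs.khoj.dev": {"queries": ["latest khoj release notes"]}},
    summarizedResult="Khoj 1.28 adds a code tool and prompt tracing.",
)

# Re-package prior iterations as a synthetic chat history for the online tool
history = construct_tool_chat_history([iteration], tool=ConversationCommand.Online)
print(history["chat"][1]["intent"]["inferred-queries"])  # ['https://docs.khoj.dev']
```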
@@ -111,28 +228,37 @@ def save_to_conversation_log(
      user_message_time: str = None,
      compiled_references: List[Dict[str, Any]] = [],
      online_results: Dict[str, Any] = {},
+     code_results: Dict[str, Any] = {},
      inferred_queries: List[str] = [],
      intent_type: str = "remember",
      client_application: ClientApplication = None,
      conversation_id: str = None,
      automation_id: str = None,
      query_images: List[str] = None,
+     tracer: Dict[str, Any] = {},
+     train_of_thought: List[Any] = [],
  ):
      user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+     turn_id = tracer.get("mid") or str(uuid.uuid4())
      updated_conversation = message_to_log(
          user_message=q,
          chat_response=chat_response,
          user_message_metadata={
              "created": user_message_time,
              "images": query_images,
+             "turnId": turn_id,
          },
          khoj_message_metadata={
              "context": compiled_references,
              "intent": {"inferred-queries": inferred_queries, "type": intent_type},
              "onlineContext": online_results,
+             "codeContext": code_results,
              "automationId": automation_id,
+             "trainOfThought": train_of_thought,
+             "turnId": turn_id,
          },
          conversation_log=meta_log.get("chat", []),
+         train_of_thought=train_of_thought,
      )
      ConversationAdapters.save_conversation(
          user,
@@ -142,6 +268,9 @@
          user_message=q,
      )

+     if in_debug_mode() or state.verbose > 1:
+         merge_message_into_conversation_trace(q, chat_response, tracer)
+
      logger.info(
          f"""
      Saved Conversation Turn
@@ -323,9 +452,23 @@ def reciprocal_conversation_to_chatml(message_pair):
      return [ChatMessage(content=message, role=role) for message, role in zip(message_pair, ["user", "assistant"])]


- def remove_json_codeblock(response: str):
-     """Remove any markdown json codeblock formatting if present. Useful for non schema enforceable models"""
-     return response.removeprefix("```json").removesuffix("```")
+ def clean_json(response: str):
+     """Remove any markdown json codeblock and newline formatting if present. Useful for non schema enforceable models"""
+     return response.strip().replace("\n", "").removeprefix("```json").removesuffix("```")
+
+
+ def clean_code_python(code: str):
+     """Remove any markdown codeblock and newline formatting if present. Useful for non schema enforceable models"""
+     return code.strip().removeprefix("```python").removesuffix("```")
+
+
+ def defilter_query(query: str):
+     """Remove any query filters in query"""
+     defiltered_query = query
+     filters: List[BaseFilter] = [WordFilter(), FileFilter(), DateFilter()]
+     for filter in filters:
+         defiltered_query = filter.defilter(defiltered_query)
+     return defiltered_query


  @dataclass
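remove_json_codeblock is renamed to clean_json and now also strips whitespace and embedded newlines before removing the fence, while the new clean_code_python does the same for python blocks; both sanitize model output before parsing or execution. A small illustrative sketch, assuming khoj 1.28.1 is importable:

```python
from khoj.processor.conversation.utils import clean_code_python, clean_json

# Model output wrapped in a markdown fence, as non schema-enforcing models often return
raw = '```json\n{"queries": ["what changed in khoj 1.28?"]}\n```'
print(clean_json(raw))  # {"queries": ["what changed in khoj 1.28?"]}

fenced_code = '```python\nprint(21 * 2)\n```'
print(clean_code_python(fenced_code))  # print(21 * 2), with the surrounding newlines kept
```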
@@ -354,3 +497,160 @@ def get_image_from_url(image_url: str, type="pil"):
      except requests.exceptions.RequestException as e:
          logger.error(f"Failed to get image from URL {image_url}: {e}")
          return ImageWithType(content=None, type=None)
+
+
+ def commit_conversation_trace(
+     session: list[ChatMessage],
+     response: str | list[dict],
+     tracer: dict,
+     system_message: str | list[dict] = "",
+     repo_path: str = "/tmp/promptrace",
+ ) -> str:
+     """
+     Save trace of conversation step using git. Useful to visualize, compare and debug traces.
+     Returns the path to the repository.
+     """
+     # Serialize session, system message and response to yaml
+     system_message_yaml = json.dumps(system_message, ensure_ascii=False, sort_keys=False)
+     response_yaml = json.dumps(response, ensure_ascii=False, sort_keys=False)
+     formatted_session = [{"role": message.role, "content": message.content} for message in session]
+     session_yaml = json.dumps(formatted_session, ensure_ascii=False, sort_keys=False)
+     query = (
+         json.dumps(session[-1].content, ensure_ascii=False, sort_keys=False).strip().removeprefix("'").removesuffix("'")
+     )  # Extract serialized query from chat session
+
+     # Extract chat metadata for session
+     uid, cid, mid = tracer.get("uid", "main"), tracer.get("cid", "main"), tracer.get("mid")
+
+     # Infer repository path from environment variable or provided path
+     repo_path = os.getenv("PROMPTRACE_DIR", repo_path)
+
+     try:
+         # Prepare git repository
+         os.makedirs(repo_path, exist_ok=True)
+         repo = Repo.init(repo_path)
+
+         # Remove post-commit hook if it exists
+         hooks_dir = os.path.join(repo_path, ".git", "hooks")
+         post_commit_hook = os.path.join(hooks_dir, "post-commit")
+         if os.path.exists(post_commit_hook):
+             os.remove(post_commit_hook)
+
+         # Configure git user if not set
+         if not repo.config_reader().has_option("user", "email"):
+             repo.config_writer().set_value("user", "name", "Prompt Tracer").release()
+             repo.config_writer().set_value("user", "email", "promptracer@khoj.dev").release()
+
+         # Create an initial commit if the repository is newly created
+         if not repo.head.is_valid():
+             repo.index.commit("And then there was a trace")
+
+         # Check out the initial commit
+         initial_commit = repo.commit("HEAD~0")
+         repo.head.reference = initial_commit
+         repo.head.reset(index=True, working_tree=True)
+
+         # Create or switch to user branch from initial commit
+         user_branch = f"u_{uid}"
+         if user_branch not in repo.branches:
+             repo.create_head(user_branch)
+         repo.heads[user_branch].checkout()
+
+         # Create or switch to conversation branch from user branch
+         conv_branch = f"c_{cid}"
+         if conv_branch not in repo.branches:
+             repo.create_head(conv_branch)
+         repo.heads[conv_branch].checkout()
+
+         # Create or switch to message branch from conversation branch
+         msg_branch = f"m_{mid}" if mid else None
+         if msg_branch and msg_branch not in repo.branches:
+             repo.create_head(msg_branch)
+         if msg_branch:
+             repo.heads[msg_branch].checkout()
+
+         # Include file with content to commit
+         files_to_commit = {"query": session_yaml, "response": response_yaml, "system_prompt": system_message_yaml}
+
+         # Write files and stage them
+         for filename, content in files_to_commit.items():
+             file_path = os.path.join(repo_path, filename)
+             # Unescape special characters in content for better readability
+             content = content.strip().replace("\\n", "\n").replace("\\t", "\t")
+             with open(file_path, "w", encoding="utf-8") as f:
+                 f.write(content)
+             repo.index.add([filename])
+
+         # Create commit
+         metadata_yaml = yaml.dump(tracer, allow_unicode=True, sort_keys=False, default_flow_style=False)
+         commit_message = f"""
+ {query[:250]}
+
+ Response:
+ ---
+ {response[:500]}...
+
+ Metadata
+ ---
+ {metadata_yaml}
+ """.strip()
+
+         repo.index.commit(commit_message)
+
+         logger.debug(f"Saved conversation trace to repo at {repo_path}")
+         return repo_path
+     except Exception as e:
+         logger.error(f"Failed to add conversation trace to repo: {str(e)}", exc_info=True)
+         return None
+
+
+ def merge_message_into_conversation_trace(query: str, response: str, tracer: dict, repo_path="/tmp/promptrace") -> bool:
+     """
+     Merge the message branch into its parent conversation branch.
+
+     Args:
+         query: User query
+         response: Assistant response
+         tracer: Dictionary containing uid, cid and mid
+         repo_path: Path to the git repository
+
+     Returns:
+         bool: True if merge was successful, False otherwise
+     """
+     try:
+         # Extract branch names
+         msg_branch = f"m_{tracer['mid']}"
+         conv_branch = f"c_{tracer['cid']}"
+
+         # Infer repository path from environment variable or provided path
+         repo_path = os.getenv("PROMPTRACE_DIR", repo_path)
+         repo = Repo(repo_path)
+
+         # Checkout conversation branch
+         repo.heads[conv_branch].checkout()
+
+         # Create commit message
+         metadata_yaml = yaml.dump(tracer, allow_unicode=True, sort_keys=False, default_flow_style=False)
+         commit_message = f"""
+ {query[:250]}
+
+ Response:
+ ---
+ {response[:500]}...
+
+ Metadata
+ ---
+ {metadata_yaml}
+ """.strip()
+
+         # Merge message branch into conversation branch
+         repo.git.merge(msg_branch, no_ff=True, m=commit_message)
+
+         # Delete message branch after merge
+         repo.delete_head(msg_branch, force=True)
+
+         logger.debug(f"Successfully merged {msg_branch} into {conv_branch}")
+         return True
+     except Exception as e:
+         logger.error(f"Failed to merge message {msg_branch} into conversation {conv_branch}: {str(e)}", exc_info=True)
+         return False
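The prompt tracing added here lays traces out as a git tree: a shared initial commit, one u_&lt;uid&gt; branch per user, one c_&lt;cid&gt; branch per conversation, and a temporary m_&lt;mid&gt; branch per message whose query, response and system_prompt files are committed and then merged back into the conversation branch. A hedged sketch of browsing such a repo with GitPython, assuming tracing has already written to PROMPTRACE_DIR (default /tmp/promptrace):

```python
# Sketch only: inspect a prompt-trace repo produced by commit_conversation_trace.
# Branch naming (u_<uid>, c_<cid>, m_<mid>) is taken from the functions above.
import os

from git import Repo

repo_path = os.getenv("PROMPTRACE_DIR", "/tmp/promptrace")
repo = Repo(repo_path)

# Conversation branches accumulate one merge commit per chat turn
conversation_branches = [b.name for b in repo.branches if b.name.startswith("c_")]
for branch in conversation_branches:
    print(f"Conversation {branch}:")
    for commit in repo.iter_commits(branch, max_count=5):
        # Each commit message starts with the (truncated) user query
        print("  ", commit.hexsha[:8], commit.message.splitlines()[0][:80])
```

Tracing only runs when in_debug_mode() or a verbose server is detected, per the save_to_conversation_log change above.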

khoj/processor/embeddings.py
@@ -13,7 +13,7 @@ from tenacity import (
  )
  from torch import nn

- from khoj.utils.helpers import get_device, merge_dicts, timer
+ from khoj.utils.helpers import fix_json_dict, get_device, merge_dicts, timer
  from khoj.utils.rawconfig import SearchResponse

  logger = logging.getLogger(__name__)
@@ -31,9 +31,9 @@ class EmbeddingsModel:
      ):
          default_query_encode_kwargs = {"show_progress_bar": False, "normalize_embeddings": True}
          default_docs_encode_kwargs = {"show_progress_bar": True, "normalize_embeddings": True}
-         self.query_encode_kwargs = merge_dicts(query_encode_kwargs, default_query_encode_kwargs)
-         self.docs_encode_kwargs = merge_dicts(docs_encode_kwargs, default_docs_encode_kwargs)
-         self.model_kwargs = merge_dicts(model_kwargs, {"device": get_device()})
+         self.query_encode_kwargs = merge_dicts(fix_json_dict(query_encode_kwargs), default_query_encode_kwargs)
+         self.docs_encode_kwargs = merge_dicts(fix_json_dict(docs_encode_kwargs), default_docs_encode_kwargs)
+         self.model_kwargs = merge_dicts(fix_json_dict(model_kwargs), {"device": get_device()})
          self.model_name = model_name
          self.inference_endpoint = embeddings_inference_endpoint
          self.api_key = embeddings_inference_endpoint_api_key
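EmbeddingsModel now routes the admin-configured encode and model kwargs through fix_json_dict before merging them with its defaults. fix_json_dict itself lives in khoj/utils/helpers.py and is not shown in this diff; the stand-in below is only an assumption about the kind of normalization it performs (for example, boolean-like strings arriving from JSON or database-stored config):

```python
# Hypothetical stand-in; the real fix_json_dict is in khoj/utils/helpers.py and
# may differ. Shown only to illustrate why raw JSON-sourced kwargs need fixing
# before being handed to the embedding model.
from typing import Any, Dict


def fix_json_dict_sketch(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    fixed: Dict[str, Any] = {}
    for key, value in (kwargs or {}).items():
        if isinstance(value, str) and value.lower() in ("true", "false"):
            fixed[key] = value.lower() == "true"  # "True" -> True, "false" -> False
        elif isinstance(value, dict):
            fixed[key] = fix_json_dict_sketch(value)  # recurse into nested config
        else:
            fixed[key] = value
    return fixed


print(fix_json_dict_sketch({"normalize_embeddings": "True", "batch_size": 32}))
# {'normalize_embeddings': True, 'batch_size': 32}
```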

khoj/processor/image/generate.py
@@ -28,6 +28,7 @@ async def text_to_image(
      send_status_func: Optional[Callable] = None,
      query_images: Optional[List[str]] = None,
      agent: Agent = None,
+     tracer: dict = {},
  ):
      status_code = 200
      image = None
@@ -68,6 +69,7 @@ async def text_to_image(
          query_images=query_images,
          user=user,
          agent=agent,
+         tracer=tracer,
      )

      if send_status_func:

khoj/processor/tools/online_search.py
@@ -4,7 +4,7 @@ import logging
  import os
  import urllib.parse
  from collections import defaultdict
- from typing import Callable, Dict, List, Optional, Tuple, Union
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union

  import aiohttp
  from bs4 import BeautifulSoup
@@ -52,7 +52,9 @@ OLOSTEP_QUERY_PARAMS = {
      "expandMarkdown": "True",
      "expandHtml": "False",
  }
- MAX_WEBPAGES_TO_READ = 1
+
+ DEFAULT_MAX_WEBPAGES_TO_READ = 1
+ MAX_WEBPAGES_TO_INFER = 10


  async def search_online(
@@ -62,8 +64,10 @@ async def search_online(
      user: KhojUser,
      send_status_func: Optional[Callable] = None,
      custom_filters: List[str] = [],
+     max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ,
      query_images: List[str] = None,
      agent: Agent = None,
+     tracer: dict = {},
  ):
      query += " ".join(custom_filters)
      if not is_internet_connected():
@@ -73,7 +77,7 @@

      # Breakdown the query into subqueries to get the correct answer
      subqueries = await generate_online_subqueries(
-         query, conversation_history, location, user, query_images=query_images, agent=agent
+         query, conversation_history, location, user, query_images=query_images, agent=agent, tracer=tracer
      )
      response_dict = {}

@@ -96,7 +100,7 @@
      for subquery in response_dict:
          if "answerBox" in response_dict[subquery]:
              continue
-         for organic in response_dict[subquery].get("organic", [])[:MAX_WEBPAGES_TO_READ]:
+         for organic in response_dict[subquery].get("organic", [])[:max_webpages_to_read]:
              link = organic.get("link")
              if link in webpages:
                  webpages[link]["queries"].add(subquery)
@@ -111,7 +115,7 @@
          async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
              yield {ChatEvent.STATUS: event}
      tasks = [
-         read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent)
+         read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent, tracer=tracer)
          for link, data in webpages.items()
      ]
      results = await asyncio.gather(*tasks)
@@ -153,20 +157,24 @@ async def read_webpages(
      send_status_func: Optional[Callable] = None,
      query_images: List[str] = None,
      agent: Agent = None,
+     tracer: dict = {},
+     max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ,
  ):
      "Infer web pages to read from the query and extract relevant information from them"
      logger.info(f"Inferring web pages to read")
-     if send_status_func:
-         async for event in send_status_func(f"**Inferring web pages to read**"):
-             yield {ChatEvent.STATUS: event}
-     urls = await infer_webpage_urls(query, conversation_history, location, user, query_images)
+     urls = await infer_webpage_urls(
+         query, conversation_history, location, user, query_images, agent=agent, tracer=tracer
+     )
+
+     # Get the top 10 web pages to read
+     urls = urls[:max_webpages_to_read]

      logger.info(f"Reading web pages at: {urls}")
      if send_status_func:
          webpage_links_str = "\n- " + "\n- ".join(list(urls))
          async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
              yield {ChatEvent.STATUS: event}
-     tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent) for url in urls]
+     tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent, tracer=tracer) for url in urls]
      results = await asyncio.gather(*tasks)

      response: Dict[str, Dict] = defaultdict(dict)
@@ -192,7 +200,12 @@ async def read_webpage(


  async def read_webpage_and_extract_content(
-     subqueries: set[str], url: str, content: str = None, user: KhojUser = None, agent: Agent = None
+     subqueries: set[str],
+     url: str,
+     content: str = None,
+     user: KhojUser = None,
+     agent: Agent = None,
+     tracer: dict = {},
  ) -> Tuple[set[str], str, Union[None, str]]:
      # Select the web scrapers to use for reading the web page
      web_scrapers = await ConversationAdapters.aget_enabled_webscrapers()
@@ -214,7 +227,9 @@ async def read_webpage_and_extract_content(
              # Extract relevant information from the web page
              if is_none_or_empty(extracted_info):
                  with timer(f"Extracting relevant information from web page at '{url}' took", logger):
-                     extracted_info = await extract_relevant_info(subqueries, content, user=user, agent=agent)
+                     extracted_info = await extract_relevant_info(
+                         subqueries, content, user=user, agent=agent, tracer=tracer
+                     )

              # If we successfully extracted information, break the loop
              if not is_none_or_empty(extracted_info):
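The tool coroutines in this file (search_online, read_webpages) and the new run_code tool that follows share a streaming contract: each is an async generator that yields {ChatEvent.STATUS: ...} progress events obtained from send_status_func in addition to its actual results, and the chat router separates the two by checking for the ChatEvent.STATUS key. A self-contained sketch of that contract using a placeholder tool generator; only ChatEvent is imported from khoj (assumed installed), everything else is illustrative:

```python
# Sketch of the status/result streaming pattern visible in the hunks above.
import asyncio

from khoj.processor.conversation.utils import ChatEvent


async def send_status(message: str):
    # Matches the send_status_func contract: an async generator over status messages
    yield message


async def fake_tool(query: str, send_status_func=None):
    # Placeholder standing in for search_online / read_webpages / run_code
    if send_status_func:
        async for event in send_status_func(f"**Searching**: {query}"):
            yield {ChatEvent.STATUS: event}
    yield {query: {"organic": []}}  # final result payload


async def main():
    results = {}
    async for item in fake_tool("recent khoj releases", send_status_func=send_status):
        if ChatEvent.STATUS in item:
            print("status:", item[ChatEvent.STATUS])
        else:
            results.update(item)
    print(results)


asyncio.run(main())
```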

khoj/processor/tools/run_code.py (new file)
@@ -0,0 +1,144 @@
+ import asyncio
+ import datetime
+ import json
+ import logging
+ import os
+ from typing import Any, Callable, List, Optional
+
+ import aiohttp
+
+ from khoj.database.adapters import ais_user_subscribed
+ from khoj.database.models import Agent, KhojUser
+ from khoj.processor.conversation import prompts
+ from khoj.processor.conversation.utils import (
+     ChatEvent,
+     clean_code_python,
+     clean_json,
+     construct_chat_history,
+ )
+ from khoj.routers.helpers import send_message_to_model_wrapper
+ from khoj.utils.helpers import timer
+ from khoj.utils.rawconfig import LocationData
+
+ logger = logging.getLogger(__name__)
+
+
+ SANDBOX_URL = os.getenv("KHOJ_TERRARIUM_URL", "http://localhost:8080")
+
+
+ async def run_code(
+     query: str,
+     conversation_history: dict,
+     context: str,
+     location_data: LocationData,
+     user: KhojUser,
+     send_status_func: Optional[Callable] = None,
+     query_images: List[str] = None,
+     agent: Agent = None,
+     sandbox_url: str = SANDBOX_URL,
+     tracer: dict = {},
+ ):
+     # Generate Code
+     if send_status_func:
+         async for event in send_status_func(f"**Generate code snippets** for {query}"):
+             yield {ChatEvent.STATUS: event}
+     try:
+         with timer("Chat actor: Generate programs to execute", logger):
+             codes = await generate_python_code(
+                 query,
+                 conversation_history,
+                 context,
+                 location_data,
+                 user,
+                 query_images,
+                 agent,
+                 tracer,
+             )
+     except Exception as e:
+         raise ValueError(f"Failed to generate code for {query} with error: {e}")
+
+     # Run Code
+     if send_status_func:
+         async for event in send_status_func(f"**Running {len(codes)} code snippets**"):
+             yield {ChatEvent.STATUS: event}
+     try:
+         tasks = [execute_sandboxed_python(code, sandbox_url) for code in codes]
+         with timer("Chat actor: Execute generated programs", logger):
+             results = await asyncio.gather(*tasks)
+         for result in results:
+             code = result.pop("code")
+             logger.info(f"Executed Code:\n--@@--\n{code}\n--@@--Result:\n--@@--\n{result}\n--@@--")
+             yield {query: {"code": code, "results": result}}
+     except Exception as e:
+         raise ValueError(f"Failed to run code for {query} with error: {e}")
+
+
+ async def generate_python_code(
+     q: str,
+     conversation_history: dict,
+     context: str,
+     location_data: LocationData,
+     user: KhojUser,
+     query_images: List[str] = None,
+     agent: Agent = None,
+     tracer: dict = {},
+ ) -> List[str]:
+     location = f"{location_data}" if location_data else "Unknown"
+     username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
+     subscribed = await ais_user_subscribed(user)
+     chat_history = construct_chat_history(conversation_history)
+
+     utc_date = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d")
+     personality_context = (
+         prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
+     )
+
+     code_generation_prompt = prompts.python_code_generation_prompt.format(
+         current_date=utc_date,
+         query=q,
+         chat_history=chat_history,
+         context=context,
+         location=location,
+         username=username,
+         personality_context=personality_context,
+     )
+
+     response = await send_message_to_model_wrapper(
+         code_generation_prompt,
+         query_images=query_images,
+         response_type="json_object",
+         user=user,
+         tracer=tracer,
+     )
+
+     # Validate that the response is a non-empty, JSON-serializable list
+     response = clean_json(response)
+     response = json.loads(response)
+     codes = [code.strip() for code in response["codes"] if code.strip()]
+
+     if not isinstance(codes, list) or not codes or len(codes) == 0:
+         raise ValueError
+     return codes
+
+
+ async def execute_sandboxed_python(code: str, sandbox_url: str = SANDBOX_URL) -> dict[str, Any]:
+     """
+     Takes code to run as a string and calls the terrarium API to execute it.
+     Returns the result of the code execution as a dictionary.
+     """
+     headers = {"Content-Type": "application/json"}
+     cleaned_code = clean_code_python(code)
+     data = {"code": cleaned_code}
+
+     async with aiohttp.ClientSession() as session:
+         async with session.post(sandbox_url, json=data, headers=headers) as response:
+             if response.status == 200:
+                 result: dict[str, Any] = await response.json()
+                 result["code"] = cleaned_code
+                 return result
+             else:
+                 return {
+                     "code": cleaned_code,
+                     "success": False,
+                     "std_err": f"Failed to execute code with {response.status}",
+                 }
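run_code.py wires the new code tool end to end: generate_python_code asks the chat model for a JSON object with a "codes" list, cleans it with clean_json, and execute_sandboxed_python posts each snippet to a Terrarium-style sandbox at KHOJ_TERRARIUM_URL (default http://localhost:8080). A usage sketch for the executor, assuming such a sandbox is actually running; the shape of the success payload comes from the sandbox service and is not defined in this diff:

```python
import asyncio

from khoj.processor.tools.run_code import execute_sandboxed_python

# Fenced snippets are accepted: clean_code_python strips the ```python fence first
code = '```python\nprint(21 * 2)\n```'


async def main():
    result = await execute_sandboxed_python(code)
    # On HTTP 200 the sandbox's JSON response is returned with the cleaned "code" added;
    # otherwise a {"code": ..., "success": False, "std_err": ...} dict comes back.
    print(result)


asyncio.run(main())
```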