khoj 1.27.2.dev15__py3-none-any.whl → 1.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- khoj/configure.py +1 -1
- khoj/database/adapters/__init__.py +50 -12
- khoj/interface/compiled/404/index.html +1 -1
- khoj/interface/compiled/_next/static/chunks/1034-da58b679fcbb79c1.js +1 -0
- khoj/interface/compiled/_next/static/chunks/1467-b331e469fe411347.js +1 -0
- khoj/interface/compiled/_next/static/chunks/1603-c1568f45947e9f2c.js +1 -0
- khoj/interface/compiled/_next/static/chunks/3423-ff7402ae1dd66592.js +1 -0
- khoj/interface/compiled/_next/static/chunks/8423-e80647edf6c92c27.js +1 -0
- khoj/interface/compiled/_next/static/chunks/app/agents/{page-2beaba7c9bb750bd.js → page-fc492762298e975e.js} +1 -1
- khoj/interface/compiled/_next/static/chunks/app/automations/{page-9b5c77e0b0dd772c.js → page-416ee13a00575c39.js} +1 -1
- khoj/interface/compiled/_next/static/chunks/app/chat/page-c70f5b0c722d7627.js +1 -0
- khoj/interface/compiled/_next/static/chunks/app/factchecker/page-1541d90140794f63.js +1 -0
- khoj/interface/compiled/_next/static/chunks/app/{page-4b6008223ea79955.js → page-b269e444fc067759.js} +1 -1
- khoj/interface/compiled/_next/static/chunks/app/search/{page-ab2995529ece3140.js → page-7d431ce8e565c7c3.js} +1 -1
- khoj/interface/compiled/_next/static/chunks/app/settings/{page-7946cabb9c54e22d.js → page-95f56e53f48f0289.js} +1 -1
- khoj/interface/compiled/_next/static/chunks/app/share/chat/{page-6a01e07fb244c10c.js → page-4eba6154f7bb9771.js} +1 -1
- khoj/interface/compiled/_next/static/chunks/{webpack-878569182b3af4c6.js → webpack-33a82ccca02cd2b8.js} +1 -1
- khoj/interface/compiled/_next/static/css/2196fae09c2f906e.css +1 -0
- khoj/interface/compiled/_next/static/css/6bde1f2045622ef7.css +1 -0
- khoj/interface/compiled/_next/static/css/a795ee88875f4853.css +25 -0
- khoj/interface/compiled/_next/static/css/ebef43da1c0651d5.css +1 -0
- khoj/interface/compiled/agents/index.html +1 -1
- khoj/interface/compiled/agents/index.txt +2 -2
- khoj/interface/compiled/automations/index.html +1 -1
- khoj/interface/compiled/automations/index.txt +2 -2
- khoj/interface/compiled/chat/index.html +1 -1
- khoj/interface/compiled/chat/index.txt +2 -2
- khoj/interface/compiled/factchecker/index.html +1 -1
- khoj/interface/compiled/factchecker/index.txt +2 -2
- khoj/interface/compiled/index.html +1 -1
- khoj/interface/compiled/index.txt +2 -2
- khoj/interface/compiled/search/index.html +1 -1
- khoj/interface/compiled/search/index.txt +2 -2
- khoj/interface/compiled/settings/index.html +1 -1
- khoj/interface/compiled/settings/index.txt +2 -2
- khoj/interface/compiled/share/chat/index.html +1 -1
- khoj/interface/compiled/share/chat/index.txt +2 -2
- khoj/processor/conversation/anthropic/anthropic_chat.py +19 -10
- khoj/processor/conversation/anthropic/utils.py +37 -6
- khoj/processor/conversation/google/gemini_chat.py +23 -13
- khoj/processor/conversation/google/utils.py +34 -10
- khoj/processor/conversation/offline/chat_model.py +48 -16
- khoj/processor/conversation/openai/gpt.py +25 -10
- khoj/processor/conversation/openai/utils.py +50 -9
- khoj/processor/conversation/prompts.py +156 -65
- khoj/processor/conversation/utils.py +306 -6
- khoj/processor/embeddings.py +4 -4
- khoj/processor/image/generate.py +2 -0
- khoj/processor/tools/online_search.py +27 -12
- khoj/processor/tools/run_code.py +144 -0
- khoj/routers/api.py +11 -6
- khoj/routers/api_chat.py +213 -111
- khoj/routers/helpers.py +171 -60
- khoj/routers/research.py +320 -0
- khoj/search_filter/date_filter.py +1 -3
- khoj/search_filter/file_filter.py +1 -2
- khoj/search_type/text_search.py +3 -3
- khoj/utils/helpers.py +24 -2
- khoj/utils/yaml.py +4 -0
- {khoj-1.27.2.dev15.dist-info → khoj-1.28.1.dist-info}/METADATA +3 -2
- {khoj-1.27.2.dev15.dist-info → khoj-1.28.1.dist-info}/RECORD +68 -65
- khoj/interface/compiled/_next/static/chunks/1603-b9d95833e0e025e8.js +0 -1
- khoj/interface/compiled/_next/static/chunks/2697-61fcba89fd87eab4.js +0 -1
- khoj/interface/compiled/_next/static/chunks/3423-0b533af8bf6ac218.js +0 -1
- khoj/interface/compiled/_next/static/chunks/9479-ff7d8c4dae2014d1.js +0 -1
- khoj/interface/compiled/_next/static/chunks/app/chat/page-151232d8417a1ea1.js +0 -1
- khoj/interface/compiled/_next/static/chunks/app/factchecker/page-798904432c2417c4.js +0 -1
- khoj/interface/compiled/_next/static/css/2272c73fc7a3b571.css +0 -1
- khoj/interface/compiled/_next/static/css/553f9cdcc7a2bcd6.css +0 -1
- khoj/interface/compiled/_next/static/css/76d55eb435962b19.css +0 -25
- khoj/interface/compiled/_next/static/css/d738728883c68af8.css +0 -1
- /khoj/interface/compiled/_next/static/{vcyFRDGArOFXwUVotHIuv → JcTomiF3o0dIo4RxHR9Vu}/_buildManifest.js +0 -0
- /khoj/interface/compiled/_next/static/{vcyFRDGArOFXwUVotHIuv → JcTomiF3o0dIo4RxHR9Vu}/_ssgManifest.js +0 -0
- /khoj/interface/compiled/_next/static/chunks/{1970-60c96aed937a4928.js → 1970-90dd510762d820ba.js} +0 -0
- /khoj/interface/compiled/_next/static/chunks/{9417-2ca87207387fc790.js → 9417-951f46451a8dd6d7.js} +0 -0
- {khoj-1.27.2.dev15.dist-info → khoj-1.28.1.dist-info}/WHEEL +0 -0
- {khoj-1.27.2.dev15.dist-info → khoj-1.28.1.dist-info}/entry_points.txt +0 -0
- {khoj-1.27.2.dev15.dist-info → khoj-1.28.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,17 +1,23 @@
|
|
1
1
|
import base64
|
2
|
+
import json
|
2
3
|
import logging
|
3
4
|
import math
|
4
5
|
import mimetypes
|
6
|
+
import os
|
5
7
|
import queue
|
8
|
+
import uuid
|
6
9
|
from dataclasses import dataclass
|
7
10
|
from datetime import datetime
|
11
|
+
from enum import Enum
|
8
12
|
from io import BytesIO
|
9
13
|
from time import perf_counter
|
10
|
-
from typing import Any, Dict, List, Optional
|
14
|
+
from typing import Any, Callable, Dict, List, Optional
|
11
15
|
|
12
16
|
import PIL.Image
|
13
17
|
import requests
|
14
18
|
import tiktoken
|
19
|
+
import yaml
|
20
|
+
from git import Repo
|
15
21
|
from langchain.schema import ChatMessage
|
16
22
|
from llama_cpp.llama import Llama
|
17
23
|
from transformers import AutoTokenizer
|
@@ -20,8 +26,17 @@ from khoj.database.adapters import ConversationAdapters
|
|
20
26
|
from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser
|
21
27
|
from khoj.processor.conversation import prompts
|
22
28
|
from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
|
29
|
+
from khoj.search_filter.base_filter import BaseFilter
|
30
|
+
from khoj.search_filter.date_filter import DateFilter
|
31
|
+
from khoj.search_filter.file_filter import FileFilter
|
32
|
+
from khoj.search_filter.word_filter import WordFilter
|
23
33
|
from khoj.utils import state
|
24
|
-
from khoj.utils.helpers import
|
34
|
+
from khoj.utils.helpers import (
|
35
|
+
ConversationCommand,
|
36
|
+
in_debug_mode,
|
37
|
+
is_none_or_empty,
|
38
|
+
merge_dicts,
|
39
|
+
)
|
25
40
|
|
26
41
|
logger = logging.getLogger(__name__)
|
27
42
|
model_to_prompt_size = {
|
@@ -82,8 +97,110 @@ class ThreadedGenerator:
|
|
82
97
|
self.queue.put(StopIteration)
|
83
98
|
|
84
99
|
|
100
|
+
class InformationCollectionIteration:
    """Record of one information-collection step.

    Stores the tool and query used for the step, the context gathered by
    that step (notes, online or code context) and the summarized result.
    Every context field is optional and defaults to ``None``.
    """

    def __init__(
        self,
        tool: str,
        query: str,
        context: Optional[list] = None,
        onlineContext: Optional[dict] = None,
        codeContext: Optional[dict] = None,
        summarizedResult: Optional[str] = None,
    ):
        # Tool used for this step and the query it was given
        self.tool = tool
        self.query = query
        # Context collected by the step; left as None when nothing was collected
        self.context = context
        self.onlineContext = onlineContext
        self.codeContext = codeContext
        # Summary of the step's outcome
        self.summarizedResult = summarizedResult
|
116
|
+
|
117
|
+
|
118
|
+
def construct_iteration_history(
    previous_iterations: List["InformationCollectionIteration"], previous_iteration_prompt: str
) -> str:
    """Render previous iterations into a single history string.

    Each iteration is formatted with ``previous_iteration_prompt``, which
    receives the keyword fields ``tool``, ``query``, ``result`` (the
    iteration's summarized result) and ``index`` (1-based position).

    Args:
        previous_iterations: Iterations to render, in order.
        previous_iteration_prompt: ``str.format`` template applied per iteration.

    Returns:
        The concatenation of all rendered iterations; empty string when
        there are no iterations.
    """
    # Join once instead of growing a string with += inside the loop
    return "".join(
        previous_iteration_prompt.format(
            tool=iteration.tool,
            query=iteration.query,
            result=iteration.summarizedResult,
            index=index,
        )
        for index, iteration in enumerate(previous_iterations, start=1)
    )
|
132
|
+
|
133
|
+
|
134
|
+
def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
    """Render the most recent chat messages as a plain-text transcript.

    Only messages authored by ``"khoj"`` produce output; each one emits the
    user query stored in its intent followed by the agent's reply. Replies
    to image generation requests are replaced by a placeholder, and
    excalidraw replies are represented by their first inferred query.

    Args:
        conversation_history: Dict with the messages under the ``"chat"`` key.
        n: Number of trailing messages to consider.
        agent_name: Label used for the agent's lines in the transcript.

    Returns:
        The transcript; empty string when no message qualifies.
    """
    parts = []
    for chat in conversation_history.get("chat", [])[-n:]:
        if chat["by"] != "khoj":
            continue

        intent = chat.get("intent") or {}
        # A missing intent type used to raise TypeError on the `in` checks below;
        # treat it as an unknown type and skip the message instead.
        intent_type = intent.get("type") or ""

        if intent_type in ["remember", "reminder", "summarize"]:
            parts.append(f"User: {intent['query']}\n")

            inferred_queries = intent.get("inferred-queries")
            if inferred_queries:
                parts.append(f'Khoj: {{"queries": {inferred_queries}}}\n')

            parts.append(f"{agent_name}: {chat['message']}\n\n")
        elif "text-to-image" in intent_type:
            parts.append(f"User: {intent['query']}\n")
            parts.append(f"{agent_name}: [generated image redacted for space]\n")
        elif "excalidraw" in intent_type:
            parts.append(f"User: {intent['query']}\n")
            parts.append(f"{agent_name}: {intent['inferred-queries'][0]}\n")
    return "".join(parts)
|
151
|
+
|
152
|
+
|
153
|
+
def construct_tool_chat_history(
    previous_iterations: List[InformationCollectionIteration], tool: ConversationCommand = None
) -> Dict[str, list]:
    """Convert previous iterations into a chat-log shaped dictionary.

    Every iteration becomes a pair of messages: one by "you" carrying the
    iteration query, and one by "khoj" carrying the summarized result with a
    "remember" intent. The inferred queries of that intent depend on ``tool``:
    the queries of the notes context, or the keys of the online / code
    context. Any other tool yields no inferred queries.

    Returns:
        ``{"chat": [...]}`` with two messages per iteration, in order.
    """

    def queries_used(step: InformationCollectionIteration) -> List[str]:
        """Pull the inferred queries for the selected tool out of one iteration."""
        if tool == ConversationCommand.Notes:
            return [entry["query"] for entry in step.context] if step.context else []
        if tool == ConversationCommand.Online:
            return list(step.onlineContext.keys()) if step.onlineContext else []
        if tool == ConversationCommand.Code:
            return list(step.codeContext.keys()) if step.codeContext else []
        return []

    messages: list = []
    for step in previous_iterations:
        user_turn = {"by": "you", "message": step.query}
        khoj_turn = {
            "by": "khoj",
            "intent": {
                "type": "remember",
                "inferred-queries": queries_used(step),
                "query": step.query,
            },
            "message": step.summarizedResult,
        }
        messages.extend([user_turn, khoj_turn])

    return {"chat": messages}
|
186
|
+
|
187
|
+
|
188
|
+
class ChatEvent(Enum):
    """Event types used to tag items of chat output.

    Members are used as dictionary keys by the generators in this package,
    e.g. ``yield {ChatEvent.STATUS: event}`` for status updates.
    """

    START_LLM_RESPONSE = "start_llm_response"
    END_LLM_RESPONSE = "end_llm_response"
    MESSAGE = "message"
    REFERENCES = "references"
    STATUS = "status"
    METADATA = "metadata"
|
195
|
+
|
196
|
+
|
85
197
|
def message_to_log(
|
86
|
-
user_message,
|
198
|
+
user_message,
|
199
|
+
chat_response,
|
200
|
+
user_message_metadata={},
|
201
|
+
khoj_message_metadata={},
|
202
|
+
conversation_log=[],
|
203
|
+
train_of_thought=[],
|
87
204
|
):
|
88
205
|
"""Create json logs from messages, metadata for conversation log"""
|
89
206
|
default_khoj_message_metadata = {
|
@@ -111,28 +228,37 @@ def save_to_conversation_log(
|
|
111
228
|
user_message_time: str = None,
|
112
229
|
compiled_references: List[Dict[str, Any]] = [],
|
113
230
|
online_results: Dict[str, Any] = {},
|
231
|
+
code_results: Dict[str, Any] = {},
|
114
232
|
inferred_queries: List[str] = [],
|
115
233
|
intent_type: str = "remember",
|
116
234
|
client_application: ClientApplication = None,
|
117
235
|
conversation_id: str = None,
|
118
236
|
automation_id: str = None,
|
119
237
|
query_images: List[str] = None,
|
238
|
+
tracer: Dict[str, Any] = {},
|
239
|
+
train_of_thought: List[Any] = [],
|
120
240
|
):
|
121
241
|
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
242
|
+
turn_id = tracer.get("mid") or str(uuid.uuid4())
|
122
243
|
updated_conversation = message_to_log(
|
123
244
|
user_message=q,
|
124
245
|
chat_response=chat_response,
|
125
246
|
user_message_metadata={
|
126
247
|
"created": user_message_time,
|
127
248
|
"images": query_images,
|
249
|
+
"turnId": turn_id,
|
128
250
|
},
|
129
251
|
khoj_message_metadata={
|
130
252
|
"context": compiled_references,
|
131
253
|
"intent": {"inferred-queries": inferred_queries, "type": intent_type},
|
132
254
|
"onlineContext": online_results,
|
255
|
+
"codeContext": code_results,
|
133
256
|
"automationId": automation_id,
|
257
|
+
"trainOfThought": train_of_thought,
|
258
|
+
"turnId": turn_id,
|
134
259
|
},
|
135
260
|
conversation_log=meta_log.get("chat", []),
|
261
|
+
train_of_thought=train_of_thought,
|
136
262
|
)
|
137
263
|
ConversationAdapters.save_conversation(
|
138
264
|
user,
|
@@ -142,6 +268,9 @@ def save_to_conversation_log(
|
|
142
268
|
user_message=q,
|
143
269
|
)
|
144
270
|
|
271
|
+
if in_debug_mode() or state.verbose > 1:
|
272
|
+
merge_message_into_conversation_trace(q, chat_response, tracer)
|
273
|
+
|
145
274
|
logger.info(
|
146
275
|
f"""
|
147
276
|
Saved Conversation Turn
|
@@ -323,9 +452,23 @@ def reciprocal_conversation_to_chatml(message_pair):
|
|
323
452
|
return [ChatMessage(content=message, role=role) for message, role in zip(message_pair, ["user", "assistant"])]
|
324
453
|
|
325
454
|
|
326
|
-
def
|
327
|
-
"""Remove any markdown json codeblock formatting if present. Useful for non schema enforceable models"""
|
328
|
-
return response.removeprefix("```json").removesuffix("```")
|
455
|
+
def clean_json(response: str):
    """Remove markdown codeblock fences and newlines wrapping a json response.

    Useful for models that cannot be forced to emit schema-valid output.
    Surrounding whitespace and all newlines are dropped first, then an
    opening fence (tagged "json" or untagged) and a closing fence are removed.
    """
    cleaned = response.strip().replace("\n", "")
    # Prefer the tagged fence so "json" is not left behind, then fall back to a bare fence
    if cleaned.startswith("```json"):
        cleaned = cleaned.removeprefix("```json")
    else:
        cleaned = cleaned.removeprefix("```")
    return cleaned.removesuffix("```")
|
458
|
+
|
459
|
+
|
460
|
+
def clean_code_python(code: str):
    """Remove markdown codeblock fences wrapping python code.

    Useful for models that cannot be forced to emit schema-valid output.
    Surrounding whitespace is stripped, then an opening fence (tagged
    "python" or untagged) and a closing fence are removed. Newlines inside
    the code are preserved.
    """
    cleaned = code.strip()
    # Prefer the tagged fence so "python" is not left behind, then fall back to a bare fence
    if cleaned.startswith("```python"):
        cleaned = cleaned.removeprefix("```python")
    else:
        cleaned = cleaned.removeprefix("```")
    return cleaned.removesuffix("```")
|
463
|
+
|
464
|
+
|
465
|
+
def defilter_query(query: str):
    """Strip word, file and date filter terms from the query, in that order."""
    query_filters: List[BaseFilter] = [WordFilter(), FileFilter(), DateFilter()]
    stripped_query = query
    for query_filter in query_filters:
        # Each filter removes its own terms from the output of the previous one
        stripped_query = query_filter.defilter(stripped_query)
    return stripped_query
|
329
472
|
|
330
473
|
|
331
474
|
@dataclass
|
@@ -354,3 +497,160 @@ def get_image_from_url(image_url: str, type="pil"):
|
|
354
497
|
except requests.exceptions.RequestException as e:
|
355
498
|
logger.error(f"Failed to get image from URL {image_url}: {e}")
|
356
499
|
return ImageWithType(content=None, type=None)
|
500
|
+
|
501
|
+
|
502
|
+
def commit_conversation_trace(
    session: list[ChatMessage],
    response: str | list[dict],
    tracer: dict,
    system_message: str | list[dict] = "",
    repo_path: str = "/tmp/promptrace",
) -> Optional[str]:
    """
    Save trace of conversation step using git. Useful to visualize, compare and debug traces.

    The session, response and system message are written as files in a git
    repository and committed on a branch derived from the tracer's
    ``uid``, ``cid`` and (when present) ``mid``.

    Args:
        session: Chat messages sent to the model. The last message is used as the query.
        response: Model response to record.
        tracer: Trace metadata. ``uid`` and ``cid`` default to "main"; ``mid`` is optional.
        system_message: System prompt to record alongside the session.
        repo_path: Repository location. Overridden by the PROMPTRACE_DIR environment variable.

    Returns:
        The path to the repository, or None if the trace could not be saved.
    """
    # Serialize session, system message and response to JSON strings
    system_message_json = json.dumps(system_message, ensure_ascii=False, sort_keys=False)
    response_json = json.dumps(response, ensure_ascii=False, sort_keys=False)
    formatted_session = [{"role": message.role, "content": message.content} for message in session]
    session_json = json.dumps(formatted_session, ensure_ascii=False, sort_keys=False)

    # Extract serialized query from chat session. An empty session has no query to extract.
    last_message_content = session[-1].content if session else ""
    # NOTE(review): json.dumps quotes strings with double quotes, so stripping single
    # quotes looks like a no-op here - confirm whether double quotes were intended.
    query = (
        json.dumps(last_message_content, ensure_ascii=False, sort_keys=False)
        .strip()
        .removeprefix("'")
        .removesuffix("'")
    )

    # Extract chat metadata for session
    uid, cid, mid = tracer.get("uid", "main"), tracer.get("cid", "main"), tracer.get("mid")

    # Infer repository path from environment variable or provided path
    repo_path = os.getenv("PROMPTRACE_DIR", repo_path)

    try:
        # Prepare git repository
        os.makedirs(repo_path, exist_ok=True)
        repo = Repo.init(repo_path)

        # Remove post-commit hook if it exists
        hooks_dir = os.path.join(repo_path, ".git", "hooks")
        post_commit_hook = os.path.join(hooks_dir, "post-commit")
        if os.path.exists(post_commit_hook):
            os.remove(post_commit_hook)

        # Configure git user if not set
        if not repo.config_reader().has_option("user", "email"):
            repo.config_writer().set_value("user", "name", "Prompt Tracer").release()
            repo.config_writer().set_value("user", "email", "promptracer@khoj.dev").release()

        # Create an initial commit if the repository is newly created
        if not repo.head.is_valid():
            repo.index.commit("And then there was a trace")

        # Detach HEAD at the commit it currently points to (HEAD~0) and reset
        # the index and working tree to it before branching
        current_commit = repo.commit("HEAD~0")
        repo.head.reference = current_commit
        repo.head.reset(index=True, working_tree=True)

        # Create or switch to user branch
        user_branch = f"u_{uid}"
        if user_branch not in repo.branches:
            repo.create_head(user_branch)
        repo.heads[user_branch].checkout()

        # Create or switch to conversation branch from user branch
        conv_branch = f"c_{cid}"
        if conv_branch not in repo.branches:
            repo.create_head(conv_branch)
        repo.heads[conv_branch].checkout()

        # Create or switch to message branch from conversation branch.
        # Without a message id the trace is committed on the conversation branch.
        msg_branch = f"m_{mid}" if mid else None
        if msg_branch and msg_branch not in repo.branches:
            repo.create_head(msg_branch)
        if msg_branch:
            repo.heads[msg_branch].checkout()

        # Include file with content to commit
        files_to_commit = {"query": session_json, "response": response_json, "system_prompt": system_message_json}

        # Write files and stage them
        for filename, content in files_to_commit.items():
            file_path = os.path.join(repo_path, filename)
            # Unescape special characters in content for better readability
            content = content.strip().replace("\\n", "\n").replace("\\t", "\t")
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(content)
            repo.index.add([filename])

        # Create commit
        metadata_yaml = yaml.dump(tracer, allow_unicode=True, sort_keys=False, default_flow_style=False)
        commit_message = f"""
{query[:250]}

Response:
---
{response[:500]}...

Metadata
---
{metadata_yaml}
""".strip()

        repo.index.commit(commit_message)

        logger.debug(f"Saved conversation trace to repo at {repo_path}")
        return repo_path
    except Exception as e:
        # Tracing is best effort: log the failure and signal it with None instead of raising
        logger.error(f"Failed to add conversation trace to repo: {str(e)}", exc_info=True)
        return None
|
605
|
+
|
606
|
+
|
607
|
+
def merge_message_into_conversation_trace(query: str, response: str, tracer: dict, repo_path="/tmp/promptrace") -> bool:
    """
    Merge the message branch into its parent conversation branch.

    Args:
        query: User query
        response: Assistant response
        tracer: Dictionary containing uid, cid and mid
        repo_path: Path to the git repository. Overridden by the PROMPTRACE_DIR environment variable.

    Returns:
        bool: True if merge was successful, False otherwise
    """
    # Resolve branch names before the try block so the error handler can always
    # reference them. Previously a tracer without "mid" or "cid" raised inside the
    # try and the handler then failed with an unbound variable error.
    mid = tracer.get("mid")
    conv_branch = f"c_{tracer.get('cid', 'main')}"
    if not mid:
        # Without a message id there is no message branch to merge
        logger.warning(f"No message id in tracer. Skip merging message into conversation {conv_branch}")
        return False
    msg_branch = f"m_{mid}"

    try:
        # Infer repository path from environment variable or provided path
        repo_path = os.getenv("PROMPTRACE_DIR", repo_path)
        repo = Repo(repo_path)

        # Checkout conversation branch
        repo.heads[conv_branch].checkout()

        # Create commit message
        metadata_yaml = yaml.dump(tracer, allow_unicode=True, sort_keys=False, default_flow_style=False)
        commit_message = f"""
{query[:250]}

Response:
---
{response[:500]}...

Metadata
---
{metadata_yaml}
""".strip()

        # Merge message branch into conversation branch
        repo.git.merge(msg_branch, no_ff=True, m=commit_message)

        # Delete message branch after merge
        repo.delete_head(msg_branch, force=True)

        logger.debug(f"Successfully merged {msg_branch} into {conv_branch}")
        return True
    except Exception as e:
        # Tracing is best effort: log the failure and report it through the return value
        logger.error(f"Failed to merge message {msg_branch} into conversation {conv_branch}: {str(e)}", exc_info=True)
        return False
|
khoj/processor/embeddings.py
CHANGED
@@ -13,7 +13,7 @@ from tenacity import (
|
|
13
13
|
)
|
14
14
|
from torch import nn
|
15
15
|
|
16
|
-
from khoj.utils.helpers import get_device, merge_dicts, timer
|
16
|
+
from khoj.utils.helpers import fix_json_dict, get_device, merge_dicts, timer
|
17
17
|
from khoj.utils.rawconfig import SearchResponse
|
18
18
|
|
19
19
|
logger = logging.getLogger(__name__)
|
@@ -31,9 +31,9 @@ class EmbeddingsModel:
|
|
31
31
|
):
|
32
32
|
default_query_encode_kwargs = {"show_progress_bar": False, "normalize_embeddings": True}
|
33
33
|
default_docs_encode_kwargs = {"show_progress_bar": True, "normalize_embeddings": True}
|
34
|
-
self.query_encode_kwargs = merge_dicts(query_encode_kwargs, default_query_encode_kwargs)
|
35
|
-
self.docs_encode_kwargs = merge_dicts(docs_encode_kwargs, default_docs_encode_kwargs)
|
36
|
-
self.model_kwargs = merge_dicts(model_kwargs, {"device": get_device()})
|
34
|
+
self.query_encode_kwargs = merge_dicts(fix_json_dict(query_encode_kwargs), default_query_encode_kwargs)
|
35
|
+
self.docs_encode_kwargs = merge_dicts(fix_json_dict(docs_encode_kwargs), default_docs_encode_kwargs)
|
36
|
+
self.model_kwargs = merge_dicts(fix_json_dict(model_kwargs), {"device": get_device()})
|
37
37
|
self.model_name = model_name
|
38
38
|
self.inference_endpoint = embeddings_inference_endpoint
|
39
39
|
self.api_key = embeddings_inference_endpoint_api_key
|
khoj/processor/image/generate.py
CHANGED
@@ -28,6 +28,7 @@ async def text_to_image(
|
|
28
28
|
send_status_func: Optional[Callable] = None,
|
29
29
|
query_images: Optional[List[str]] = None,
|
30
30
|
agent: Agent = None,
|
31
|
+
tracer: dict = {},
|
31
32
|
):
|
32
33
|
status_code = 200
|
33
34
|
image = None
|
@@ -68,6 +69,7 @@ async def text_to_image(
|
|
68
69
|
query_images=query_images,
|
69
70
|
user=user,
|
70
71
|
agent=agent,
|
72
|
+
tracer=tracer,
|
71
73
|
)
|
72
74
|
|
73
75
|
if send_status_func:
|
@@ -4,7 +4,7 @@ import logging
|
|
4
4
|
import os
|
5
5
|
import urllib.parse
|
6
6
|
from collections import defaultdict
|
7
|
-
from typing import Callable, Dict, List, Optional, Tuple, Union
|
7
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
8
8
|
|
9
9
|
import aiohttp
|
10
10
|
from bs4 import BeautifulSoup
|
@@ -52,7 +52,9 @@ OLOSTEP_QUERY_PARAMS = {
|
|
52
52
|
"expandMarkdown": "True",
|
53
53
|
"expandHtml": "False",
|
54
54
|
}
|
55
|
-
|
55
|
+
|
56
|
+
DEFAULT_MAX_WEBPAGES_TO_READ = 1
|
57
|
+
MAX_WEBPAGES_TO_INFER = 10
|
56
58
|
|
57
59
|
|
58
60
|
async def search_online(
|
@@ -62,8 +64,10 @@ async def search_online(
|
|
62
64
|
user: KhojUser,
|
63
65
|
send_status_func: Optional[Callable] = None,
|
64
66
|
custom_filters: List[str] = [],
|
67
|
+
max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ,
|
65
68
|
query_images: List[str] = None,
|
66
69
|
agent: Agent = None,
|
70
|
+
tracer: dict = {},
|
67
71
|
):
|
68
72
|
query += " ".join(custom_filters)
|
69
73
|
if not is_internet_connected():
|
@@ -73,7 +77,7 @@ async def search_online(
|
|
73
77
|
|
74
78
|
# Breakdown the query into subqueries to get the correct answer
|
75
79
|
subqueries = await generate_online_subqueries(
|
76
|
-
query, conversation_history, location, user, query_images=query_images, agent=agent
|
80
|
+
query, conversation_history, location, user, query_images=query_images, agent=agent, tracer=tracer
|
77
81
|
)
|
78
82
|
response_dict = {}
|
79
83
|
|
@@ -96,7 +100,7 @@ async def search_online(
|
|
96
100
|
for subquery in response_dict:
|
97
101
|
if "answerBox" in response_dict[subquery]:
|
98
102
|
continue
|
99
|
-
for organic in response_dict[subquery].get("organic", [])[:
|
103
|
+
for organic in response_dict[subquery].get("organic", [])[:max_webpages_to_read]:
|
100
104
|
link = organic.get("link")
|
101
105
|
if link in webpages:
|
102
106
|
webpages[link]["queries"].add(subquery)
|
@@ -111,7 +115,7 @@ async def search_online(
|
|
111
115
|
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
|
112
116
|
yield {ChatEvent.STATUS: event}
|
113
117
|
tasks = [
|
114
|
-
read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent)
|
118
|
+
read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent, tracer=tracer)
|
115
119
|
for link, data in webpages.items()
|
116
120
|
]
|
117
121
|
results = await asyncio.gather(*tasks)
|
@@ -153,20 +157,24 @@ async def read_webpages(
|
|
153
157
|
send_status_func: Optional[Callable] = None,
|
154
158
|
query_images: List[str] = None,
|
155
159
|
agent: Agent = None,
|
160
|
+
tracer: dict = {},
|
161
|
+
max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ,
|
156
162
|
):
|
157
163
|
"Infer web pages to read from the query and extract relevant information from them"
|
158
164
|
logger.info(f"Inferring web pages to read")
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
165
|
+
urls = await infer_webpage_urls(
|
166
|
+
query, conversation_history, location, user, query_images, agent=agent, tracer=tracer
|
167
|
+
)
|
168
|
+
|
169
|
+
# Limit the web pages to read to the first max_webpages_to_read inferred urls
|
170
|
+
urls = urls[:max_webpages_to_read]
|
163
171
|
|
164
172
|
logger.info(f"Reading web pages at: {urls}")
|
165
173
|
if send_status_func:
|
166
174
|
webpage_links_str = "\n- " + "\n- ".join(list(urls))
|
167
175
|
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
|
168
176
|
yield {ChatEvent.STATUS: event}
|
169
|
-
tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent) for url in urls]
|
177
|
+
tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent, tracer=tracer) for url in urls]
|
170
178
|
results = await asyncio.gather(*tasks)
|
171
179
|
|
172
180
|
response: Dict[str, Dict] = defaultdict(dict)
|
@@ -192,7 +200,12 @@ async def read_webpage(
|
|
192
200
|
|
193
201
|
|
194
202
|
async def read_webpage_and_extract_content(
|
195
|
-
subqueries: set[str],
|
203
|
+
subqueries: set[str],
|
204
|
+
url: str,
|
205
|
+
content: str = None,
|
206
|
+
user: KhojUser = None,
|
207
|
+
agent: Agent = None,
|
208
|
+
tracer: dict = {},
|
196
209
|
) -> Tuple[set[str], str, Union[None, str]]:
|
197
210
|
# Select the web scrapers to use for reading the web page
|
198
211
|
web_scrapers = await ConversationAdapters.aget_enabled_webscrapers()
|
@@ -214,7 +227,9 @@ async def read_webpage_and_extract_content(
|
|
214
227
|
# Extract relevant information from the web page
|
215
228
|
if is_none_or_empty(extracted_info):
|
216
229
|
with timer(f"Extracting relevant information from web page at '{url}' took", logger):
|
217
|
-
extracted_info = await extract_relevant_info(
|
230
|
+
extracted_info = await extract_relevant_info(
|
231
|
+
subqueries, content, user=user, agent=agent, tracer=tracer
|
232
|
+
)
|
218
233
|
|
219
234
|
# If we successfully extracted information, break the loop
|
220
235
|
if not is_none_or_empty(extracted_info):
|
@@ -0,0 +1,144 @@
|
|
1
|
+
import asyncio
|
2
|
+
import datetime
|
3
|
+
import json
|
4
|
+
import logging
|
5
|
+
import os
|
6
|
+
from typing import Any, Callable, List, Optional
|
7
|
+
|
8
|
+
import aiohttp
|
9
|
+
|
10
|
+
from khoj.database.adapters import ais_user_subscribed
|
11
|
+
from khoj.database.models import Agent, KhojUser
|
12
|
+
from khoj.processor.conversation import prompts
|
13
|
+
from khoj.processor.conversation.utils import (
|
14
|
+
ChatEvent,
|
15
|
+
clean_code_python,
|
16
|
+
clean_json,
|
17
|
+
construct_chat_history,
|
18
|
+
)
|
19
|
+
from khoj.routers.helpers import send_message_to_model_wrapper
|
20
|
+
from khoj.utils.helpers import timer
|
21
|
+
from khoj.utils.rawconfig import LocationData
|
22
|
+
|
23
|
+
logger = logging.getLogger(__name__)
|
24
|
+
|
25
|
+
|
26
|
+
SANDBOX_URL = os.getenv("KHOJ_TERRARIUM_URL", "http://localhost:8080")
|
27
|
+
|
28
|
+
|
29
|
+
async def run_code(
    query: str,
    conversation_history: dict,
    context: str,
    location_data: LocationData,
    user: KhojUser,
    send_status_func: Optional[Callable] = None,
    query_images: List[str] = None,
    agent: Agent = None,
    sandbox_url: str = SANDBOX_URL,
    tracer: dict = {},
):
    """Generate python programs for the query and execute them in the sandbox.

    This is an async generator. When ``send_status_func`` is provided, it
    yields ``{ChatEvent.STATUS: event}`` items before code generation and
    before execution. It then yields one ``{query: {"code": ..., "results": ...}}``
    item per executed program.

    Raises:
        ValueError: If code generation or code execution fails. The original
            error is attached as the cause.
    """
    # Generate Code
    if send_status_func:
        async for event in send_status_func(f"**Generate code snippets** for {query}"):
            yield {ChatEvent.STATUS: event}
    try:
        with timer("Chat actor: Generate programs to execute", logger):
            codes = await generate_python_code(
                query,
                conversation_history,
                context,
                location_data,
                user,
                query_images,
                agent,
                tracer,
            )
    except Exception as e:
        # Chain the cause explicitly so the original failure stays visible in tracebacks
        raise ValueError(f"Failed to generate code for {query} with error: {e}") from e

    # Run Code
    if send_status_func:
        async for event in send_status_func(f"**Running {len(codes)} code snippets**"):
            yield {ChatEvent.STATUS: event}
    try:
        # Execute all generated programs concurrently
        tasks = [execute_sandboxed_python(code, sandbox_url) for code in codes]
        with timer("Chat actor: Execute generated programs", logger):
            results = await asyncio.gather(*tasks)
        for result in results:
            # Separate the executed code from the rest of the execution result
            code = result.pop("code")
            logger.info(f"Executed Code:\n--@@--\n{code}\n--@@--Result:\n--@@--\n{result}\n--@@--")
            yield {query: {"code": code, "results": result}}
    except Exception as e:
        raise ValueError(f"Failed to run code for {query} with error: {e}") from e
|
74
|
+
|
75
|
+
|
76
|
+
async def generate_python_code(
    q: str,
    conversation_history: dict,
    context: str,
    location_data: LocationData,
    user: KhojUser,
    query_images: List[str] = None,
    agent: Agent = None,
    tracer: dict = {},
) -> List[str]:
    """Ask the chat model for python programs that help answer the query.

    Builds the code generation prompt from the query, chat history, context,
    location, user name and agent personality, requests a json object from
    the model and reads the programs from its ``"codes"`` field.

    Returns:
        Non-empty list of non-blank code strings, each stripped of surrounding whitespace.

    Raises:
        ValueError: If the response is not a json object with a ``"codes"``
            list, or the list contains no usable program.
    """
    location = f"{location_data}" if location_data else "Unknown"
    username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
    # NOTE(review): the subscription state is fetched but not used below - confirm if the call is still needed
    subscribed = await ais_user_subscribed(user)
    chat_history = construct_chat_history(conversation_history)

    utc_date = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d")
    personality_context = (
        prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
    )

    code_generation_prompt = prompts.python_code_generation_prompt.format(
        current_date=utc_date,
        query=q,
        chat_history=chat_history,
        context=context,
        location=location,
        username=username,
        personality_context=personality_context,
    )

    response = await send_message_to_model_wrapper(
        code_generation_prompt,
        query_images=query_images,
        response_type="json_object",
        user=user,
        tracer=tracer,
    )

    # Validate that the response is a json object holding a non-empty list of programs
    response = clean_json(response)
    response = json.loads(response)
    raw_codes = response.get("codes") if isinstance(response, dict) else None
    if not isinstance(raw_codes, list):
        raise ValueError(f"Invalid response for generating python code. Expected a list under 'codes': {response}")

    # Keep only non-blank string entries
    codes = [code.strip() for code in raw_codes if isinstance(code, str) and code.strip()]
    if not codes:
        raise ValueError(f"Invalid response for generating python code. No code found: {response}")
    return codes
|
122
|
+
|
123
|
+
|
124
|
+
async def execute_sandboxed_python(code: str, sandbox_url: str = SANDBOX_URL) -> dict[str, Any]:
    """
    Run the given code string by posting it to the terrarium API at sandbox_url.
    Returns the execution result as a dictionary that always includes the executed code.
    A non-200 response is reported as an unsuccessful result rather than raised.
    """
    program = clean_code_python(code)
    payload = {"code": program}
    request_headers = {"Content-Type": "application/json"}

    async with aiohttp.ClientSession() as http:
        async with http.post(sandbox_url, json=payload, headers=request_headers) as reply:
            # Report sandbox failures in the same shape as a regular result
            if reply.status != 200:
                return {
                    "code": program,
                    "success": False,
                    "std_err": f"Failed to execute code with {reply.status}",
                }

            outcome: dict[str, Any] = await reply.json()
            outcome["code"] = program
            return outcome
|