khoj 1.27.2.dev29__py3-none-any.whl → 1.28.1.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- khoj/configure.py +1 -1
- khoj/database/adapters/__init__.py +50 -12
- khoj/interface/compiled/404/index.html +1 -1
- khoj/interface/compiled/_next/static/chunks/1034-da58b679fcbb79c1.js +1 -0
- khoj/interface/compiled/_next/static/chunks/1467-b331e469fe411347.js +1 -0
- khoj/interface/compiled/_next/static/chunks/1603-c1568f45947e9f2c.js +1 -0
- khoj/interface/compiled/_next/static/chunks/3423-ff7402ae1dd66592.js +1 -0
- khoj/interface/compiled/_next/static/chunks/8423-e80647edf6c92c27.js +1 -0
- khoj/interface/compiled/_next/static/chunks/app/agents/{page-5ae1e540bb5be8a9.js → page-2beaba7c9bb750bd.js} +1 -1
- khoj/interface/compiled/_next/static/chunks/app/automations/{page-774ae3e033f938cd.js → page-9b5c77e0b0dd772c.js} +1 -1
- khoj/interface/compiled/_next/static/chunks/app/chat/page-bfc70b16ba5e51b4.js +1 -0
- khoj/interface/compiled/_next/static/chunks/app/factchecker/page-340bcf53abf6a2cc.js +1 -0
- khoj/interface/compiled/_next/static/chunks/app/{page-4dc472cf6d674004.js → page-f249666a0cbdaa0d.js} +1 -1
- khoj/interface/compiled/_next/static/chunks/app/search/{page-9b64f61caa5bd7f9.js → page-ab2995529ece3140.js} +1 -1
- khoj/interface/compiled/_next/static/chunks/app/settings/{page-7a8c382af2a7e870.js → page-89e6737b2cc9fb3a.js} +1 -1
- khoj/interface/compiled/_next/static/chunks/app/share/chat/{page-eb9e282691858f2e.js → page-505b07bce608b34e.js} +1 -1
- khoj/interface/compiled/_next/static/chunks/{webpack-2b720658ccc746f2.js → webpack-878569182b3af4c6.js} +1 -1
- khoj/interface/compiled/_next/static/css/{2272c73fc7a3b571.css → 26c1c33d0423a7d8.css} +1 -1
- khoj/interface/compiled/_next/static/css/592ca99f5122e75a.css +1 -0
- khoj/interface/compiled/_next/static/css/a795ee88875f4853.css +25 -0
- khoj/interface/compiled/_next/static/css/d738728883c68af8.css +1 -0
- khoj/interface/compiled/agents/index.html +1 -1
- khoj/interface/compiled/agents/index.txt +2 -2
- khoj/interface/compiled/automations/index.html +1 -1
- khoj/interface/compiled/automations/index.txt +2 -2
- khoj/interface/compiled/chat/index.html +1 -1
- khoj/interface/compiled/chat/index.txt +2 -2
- khoj/interface/compiled/factchecker/index.html +1 -1
- khoj/interface/compiled/factchecker/index.txt +2 -2
- khoj/interface/compiled/index.html +1 -1
- khoj/interface/compiled/index.txt +2 -2
- khoj/interface/compiled/search/index.html +1 -1
- khoj/interface/compiled/search/index.txt +2 -2
- khoj/interface/compiled/settings/index.html +1 -1
- khoj/interface/compiled/settings/index.txt +2 -2
- khoj/interface/compiled/share/chat/index.html +1 -1
- khoj/interface/compiled/share/chat/index.txt +2 -2
- khoj/processor/conversation/anthropic/anthropic_chat.py +14 -10
- khoj/processor/conversation/anthropic/utils.py +13 -2
- khoj/processor/conversation/google/gemini_chat.py +15 -11
- khoj/processor/conversation/offline/chat_model.py +18 -10
- khoj/processor/conversation/openai/gpt.py +11 -8
- khoj/processor/conversation/openai/utils.py +7 -0
- khoj/processor/conversation/prompts.py +156 -49
- khoj/processor/conversation/utils.py +146 -13
- khoj/processor/embeddings.py +4 -4
- khoj/processor/tools/online_search.py +13 -7
- khoj/processor/tools/run_code.py +144 -0
- khoj/routers/api.py +6 -6
- khoj/routers/api_chat.py +193 -112
- khoj/routers/helpers.py +107 -48
- khoj/routers/research.py +320 -0
- khoj/search_filter/date_filter.py +1 -3
- khoj/search_filter/file_filter.py +1 -2
- khoj/search_type/text_search.py +3 -3
- khoj/utils/helpers.py +24 -2
- khoj/utils/yaml.py +4 -0
- {khoj-1.27.2.dev29.dist-info → khoj-1.28.1.dev1.dist-info}/METADATA +3 -3
- {khoj-1.27.2.dev29.dist-info → khoj-1.28.1.dev1.dist-info}/RECORD +66 -63
- khoj/interface/compiled/_next/static/chunks/1603-5138bb7c8035d9a6.js +0 -1
- khoj/interface/compiled/_next/static/chunks/2697-61fcba89fd87eab4.js +0 -1
- khoj/interface/compiled/_next/static/chunks/3423-0b533af8bf6ac218.js +0 -1
- khoj/interface/compiled/_next/static/chunks/9479-ff7d8c4dae2014d1.js +0 -1
- khoj/interface/compiled/_next/static/chunks/app/chat/page-97f5b61aaf46d364.js +0 -1
- khoj/interface/compiled/_next/static/chunks/app/factchecker/page-d82403db2866bad8.js +0 -1
- khoj/interface/compiled/_next/static/css/4cae6c0e5c72fb2d.css +0 -1
- khoj/interface/compiled/_next/static/css/76d55eb435962b19.css +0 -25
- khoj/interface/compiled/_next/static/css/ddcc0cf73e062476.css +0 -1
- /khoj/interface/compiled/_next/static/{atzIseFarmC7TIwq2BgHC → K7ZigmRDrBfpIN7jxKQsA}/_buildManifest.js +0 -0
- /khoj/interface/compiled/_next/static/{atzIseFarmC7TIwq2BgHC → K7ZigmRDrBfpIN7jxKQsA}/_ssgManifest.js +0 -0
- /khoj/interface/compiled/_next/static/chunks/{1970-60c96aed937a4928.js → 1970-90dd510762d820ba.js} +0 -0
- /khoj/interface/compiled/_next/static/chunks/{9417-2ca87207387fc790.js → 9417-951f46451a8dd6d7.js} +0 -0
- {khoj-1.27.2.dev29.dist-info → khoj-1.28.1.dev1.dist-info}/WHEEL +0 -0
- {khoj-1.27.2.dev29.dist-info → khoj-1.28.1.dev1.dist-info}/entry_points.txt +0 -0
- {khoj-1.27.2.dev29.dist-info → khoj-1.28.1.dev1.dist-info}/licenses/LICENSE +0 -0
khoj/processor/conversation/utils.py CHANGED
@@ -1,14 +1,17 @@
 import base64
+import json
 import logging
 import math
 import mimetypes
 import os
 import queue
+import uuid
 from dataclasses import dataclass
 from datetime import datetime
+from enum import Enum
 from io import BytesIO
 from time import perf_counter
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional
 
 import PIL.Image
 import requests
@@ -23,8 +26,17 @@ from khoj.database.adapters import ConversationAdapters
 from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser
 from khoj.processor.conversation import prompts
 from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
+from khoj.search_filter.base_filter import BaseFilter
+from khoj.search_filter.date_filter import DateFilter
+from khoj.search_filter.file_filter import FileFilter
+from khoj.search_filter.word_filter import WordFilter
 from khoj.utils import state
-from khoj.utils.helpers import
+from khoj.utils.helpers import (
+    ConversationCommand,
+    in_debug_mode,
+    is_none_or_empty,
+    merge_dicts,
+)
 
 logger = logging.getLogger(__name__)
 model_to_prompt_size = {
@@ -85,8 +97,110 @@ class ThreadedGenerator:
         self.queue.put(StopIteration)
 
 
+class InformationCollectionIteration:
+    def __init__(
+        self,
+        tool: str,
+        query: str,
+        context: list = None,
+        onlineContext: dict = None,
+        codeContext: dict = None,
+        summarizedResult: str = None,
+    ):
+        self.tool = tool
+        self.query = query
+        self.context = context
+        self.onlineContext = onlineContext
+        self.codeContext = codeContext
+        self.summarizedResult = summarizedResult
+
+
+def construct_iteration_history(
+    previous_iterations: List[InformationCollectionIteration], previous_iteration_prompt: str
+) -> str:
+    previous_iterations_history = ""
+    for idx, iteration in enumerate(previous_iterations):
+        iteration_data = previous_iteration_prompt.format(
+            tool=iteration.tool,
+            query=iteration.query,
+            result=iteration.summarizedResult,
+            index=idx + 1,
+        )
+
+        previous_iterations_history += iteration_data
+    return previous_iterations_history
+
+
+def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
+    chat_history = ""
+    for chat in conversation_history.get("chat", [])[-n:]:
+        if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder", "summarize"]:
+            chat_history += f"User: {chat['intent']['query']}\n"
+
+            if chat["intent"].get("inferred-queries"):
+                chat_history += f'Khoj: {{"queries": {chat["intent"].get("inferred-queries")}}}\n'
+
+            chat_history += f"{agent_name}: {chat['message']}\n\n"
+        elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")):
+            chat_history += f"User: {chat['intent']['query']}\n"
+            chat_history += f"{agent_name}: [generated image redacted for space]\n"
+        elif chat["by"] == "khoj" and ("excalidraw" in chat["intent"].get("type")):
+            chat_history += f"User: {chat['intent']['query']}\n"
+            chat_history += f"{agent_name}: {chat['intent']['inferred-queries'][0]}\n"
+    return chat_history
+
+
+def construct_tool_chat_history(
+    previous_iterations: List[InformationCollectionIteration], tool: ConversationCommand = None
+) -> Dict[str, list]:
+    chat_history: list = []
+    inferred_query_extractor: Callable[[InformationCollectionIteration], List[str]] = lambda x: []
+    if tool == ConversationCommand.Notes:
+        inferred_query_extractor = (
+            lambda iteration: [c["query"] for c in iteration.context] if iteration.context else []
+        )
+    elif tool == ConversationCommand.Online:
+        inferred_query_extractor = (
+            lambda iteration: list(iteration.onlineContext.keys()) if iteration.onlineContext else []
+        )
+    elif tool == ConversationCommand.Code:
+        inferred_query_extractor = lambda iteration: list(iteration.codeContext.keys()) if iteration.codeContext else []
+    for iteration in previous_iterations:
+        chat_history += [
+            {
+                "by": "you",
+                "message": iteration.query,
+            },
+            {
+                "by": "khoj",
+                "intent": {
+                    "type": "remember",
+                    "inferred-queries": inferred_query_extractor(iteration),
+                    "query": iteration.query,
+                },
+                "message": iteration.summarizedResult,
+            },
+        ]
+
+    return {"chat": chat_history}
+
+
+class ChatEvent(Enum):
+    START_LLM_RESPONSE = "start_llm_response"
+    END_LLM_RESPONSE = "end_llm_response"
+    MESSAGE = "message"
+    REFERENCES = "references"
+    STATUS = "status"
+    METADATA = "metadata"
+
+
 def message_to_log(
-    user_message,
+    user_message,
+    chat_response,
+    user_message_metadata={},
+    khoj_message_metadata={},
+    conversation_log=[],
+    train_of_thought=[],
 ):
     """Create json logs from messages, metadata for conversation log"""
     default_khoj_message_metadata = {
@@ -114,6 +228,7 @@ def save_to_conversation_log(
     user_message_time: str = None,
     compiled_references: List[Dict[str, Any]] = [],
     online_results: Dict[str, Any] = {},
+    code_results: Dict[str, Any] = {},
     inferred_queries: List[str] = [],
     intent_type: str = "remember",
     client_application: ClientApplication = None,
@@ -121,22 +236,29 @@ def save_to_conversation_log(
     automation_id: str = None,
     query_images: List[str] = None,
     tracer: Dict[str, Any] = {},
+    train_of_thought: List[Any] = [],
 ):
     user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+    turn_id = tracer.get("mid") or str(uuid.uuid4())
     updated_conversation = message_to_log(
         user_message=q,
         chat_response=chat_response,
         user_message_metadata={
            "created": user_message_time,
            "images": query_images,
+           "turnId": turn_id,
        },
        khoj_message_metadata={
            "context": compiled_references,
            "intent": {"inferred-queries": inferred_queries, "type": intent_type},
            "onlineContext": online_results,
+           "codeContext": code_results,
            "automationId": automation_id,
+           "trainOfThought": train_of_thought,
+           "turnId": turn_id,
        },
        conversation_log=meta_log.get("chat", []),
+       train_of_thought=train_of_thought,
    )
    ConversationAdapters.save_conversation(
        user,
@@ -330,9 +452,23 @@ def reciprocal_conversation_to_chatml(message_pair):
     return [ChatMessage(content=message, role=role) for message, role in zip(message_pair, ["user", "assistant"])]
 
 
-def
-    """Remove any markdown json codeblock formatting if present. Useful for non schema enforceable models"""
-    return response.removeprefix("```json").removesuffix("```")
+def clean_json(response: str):
+    """Remove any markdown json codeblock and newline formatting if present. Useful for non schema enforceable models"""
+    return response.strip().replace("\n", "").removeprefix("```json").removesuffix("```")
+
+
+def clean_code_python(code: str):
+    """Remove any markdown codeblock and newline formatting if present. Useful for non schema enforceable models"""
+    return code.strip().removeprefix("```python").removesuffix("```")
+
+
+def defilter_query(query: str):
+    """Remove any query filters in query"""
+    defiltered_query = query
+    filters: List[BaseFilter] = [WordFilter(), FileFilter(), DateFilter()]
+    for filter in filters:
+        defiltered_query = filter.defilter(defiltered_query)
+    return defiltered_query
 
 
 @dataclass
@@ -375,15 +511,12 @@ def commit_conversation_trace(
     Returns the path to the repository.
     """
     # Serialize session, system message and response to yaml
-    system_message_yaml =
-    response_yaml =
+    system_message_yaml = json.dumps(system_message, ensure_ascii=False, sort_keys=False)
+    response_yaml = json.dumps(response, ensure_ascii=False, sort_keys=False)
     formatted_session = [{"role": message.role, "content": message.content} for message in session]
-    session_yaml =
+    session_yaml = json.dumps(formatted_session, ensure_ascii=False, sort_keys=False)
     query = (
-
-        .strip()
-        .removeprefix("'")
-        .removesuffix("'")
+        json.dumps(session[-1].content, ensure_ascii=False, sort_keys=False).strip().removeprefix("'").removesuffix("'")
     ) # Extract serialized query from chat session
 
     # Extract chat metadata for session
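For context on how the reworked helpers in khoj/processor/conversation/utils.py are meant to be used, here is a minimal sketch (assuming the khoj package from this wheel is installed; the sample model outputs are made up for illustration):

```python
import json

# Helpers added in this release, per the diff above.
from khoj.processor.conversation.utils import clean_code_python, clean_json

# A typical fenced JSON reply from a model that cannot enforce a response schema.
raw_json = '```json\n{"codes": ["print(1 + 1)"]}\n```'
# clean_json strips the fence and newlines so the payload parses directly.
print(json.loads(clean_json(raw_json)))  # {'codes': ['print(1 + 1)']}

# clean_code_python only strips the ```python fence markers; newlines are kept.
raw_code = "```python\nprint('hello')\n```"
print(clean_code_python(raw_code))
```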
khoj/processor/embeddings.py CHANGED
@@ -13,7 +13,7 @@ from tenacity import (
 )
 from torch import nn
 
-from khoj.utils.helpers import get_device, merge_dicts, timer
+from khoj.utils.helpers import fix_json_dict, get_device, merge_dicts, timer
 from khoj.utils.rawconfig import SearchResponse
 
 logger = logging.getLogger(__name__)
@@ -31,9 +31,9 @@ class EmbeddingsModel:
     ):
         default_query_encode_kwargs = {"show_progress_bar": False, "normalize_embeddings": True}
         default_docs_encode_kwargs = {"show_progress_bar": True, "normalize_embeddings": True}
-        self.query_encode_kwargs = merge_dicts(query_encode_kwargs, default_query_encode_kwargs)
-        self.docs_encode_kwargs = merge_dicts(docs_encode_kwargs, default_docs_encode_kwargs)
-        self.model_kwargs = merge_dicts(model_kwargs, {"device": get_device()})
+        self.query_encode_kwargs = merge_dicts(fix_json_dict(query_encode_kwargs), default_query_encode_kwargs)
+        self.docs_encode_kwargs = merge_dicts(fix_json_dict(docs_encode_kwargs), default_docs_encode_kwargs)
+        self.model_kwargs = merge_dicts(fix_json_dict(model_kwargs), {"device": get_device()})
         self.model_name = model_name
         self.inference_endpoint = embeddings_inference_endpoint
         self.api_key = embeddings_inference_endpoint_api_key
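The constructor change above routes the query/docs/model kwargs through fix_json_dict before defaults are merged in. The diff does not include fix_json_dict itself (it lands in khoj/utils/helpers.py), so the following is only a hypothetical sketch of the kind of normalization such a helper might perform on kwargs loaded from JSON, not the actual implementation:

```python
def fix_json_dict_sketch(kwargs: dict) -> dict:
    """Hypothetical stand-in: coerce JSON-style string booleans back to Python bools."""
    fixed = {}
    for key, value in (kwargs or {}).items():
        if value in ("True", "False"):
            value = value == "True"
        fixed[key] = value
    return fixed

# e.g. kwargs persisted as JSON may carry "True"/"False" strings instead of booleans.
print(fix_json_dict_sketch({"show_progress_bar": "False", "batch_size": 32}))
# {'show_progress_bar': False, 'batch_size': 32}
```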
khoj/processor/tools/online_search.py CHANGED
@@ -4,7 +4,7 @@ import logging
 import os
 import urllib.parse
 from collections import defaultdict
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import aiohttp
 from bs4 import BeautifulSoup
@@ -52,7 +52,9 @@ OLOSTEP_QUERY_PARAMS = {
     "expandMarkdown": "True",
     "expandHtml": "False",
 }
-
+
+DEFAULT_MAX_WEBPAGES_TO_READ = 1
+MAX_WEBPAGES_TO_INFER = 10
 
 
 async def search_online(
@@ -62,6 +64,7 @@ async def search_online(
     user: KhojUser,
     send_status_func: Optional[Callable] = None,
     custom_filters: List[str] = [],
+    max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ,
     query_images: List[str] = None,
     agent: Agent = None,
     tracer: dict = {},
@@ -97,7 +100,7 @@ async def search_online(
     for subquery in response_dict:
         if "answerBox" in response_dict[subquery]:
             continue
-        for organic in response_dict[subquery].get("organic", [])[:
+        for organic in response_dict[subquery].get("organic", [])[:max_webpages_to_read]:
             link = organic.get("link")
             if link in webpages:
                 webpages[link]["queries"].add(subquery)
@@ -155,13 +158,16 @@ async def read_webpages(
     query_images: List[str] = None,
     agent: Agent = None,
     tracer: dict = {},
+    max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ,
 ):
     "Infer web pages to read from the query and extract relevant information from them"
     logger.info(f"Inferring web pages to read")
-
-
-
-
+    urls = await infer_webpage_urls(
+        query, conversation_history, location, user, query_images, agent=agent, tracer=tracer
+    )
+
+    # Get the top 10 web pages to read
+    urls = urls[:max_webpages_to_read]
 
     logger.info(f"Reading web pages at: {urls}")
     if send_status_func:
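To make the effect of the new constants concrete, a small illustration of the capping behaviour introduced above (the organic-result dicts are made up; only the slicing mirrors the diff):

```python
DEFAULT_MAX_WEBPAGES_TO_READ = 1  # same default as in the diff above

# Hypothetical Serper-style organic results for one subquery.
organic = [{"link": f"https://example.com/page-{i}"} for i in range(5)]

# Only the first max_webpages_to_read results are queued for reading.
to_read = organic[:DEFAULT_MAX_WEBPAGES_TO_READ]
print([result["link"] for result in to_read])  # ['https://example.com/page-0']
```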
khoj/processor/tools/run_code.py ADDED
@@ -0,0 +1,144 @@
+import asyncio
+import datetime
+import json
+import logging
+import os
+from typing import Any, Callable, List, Optional
+
+import aiohttp
+
+from khoj.database.adapters import ais_user_subscribed
+from khoj.database.models import Agent, KhojUser
+from khoj.processor.conversation import prompts
+from khoj.processor.conversation.utils import (
+    ChatEvent,
+    clean_code_python,
+    clean_json,
+    construct_chat_history,
+)
+from khoj.routers.helpers import send_message_to_model_wrapper
+from khoj.utils.helpers import timer
+from khoj.utils.rawconfig import LocationData
+
+logger = logging.getLogger(__name__)
+
+
+SANDBOX_URL = os.getenv("KHOJ_TERRARIUM_URL", "http://localhost:8080")
+
+
+async def run_code(
+    query: str,
+    conversation_history: dict,
+    context: str,
+    location_data: LocationData,
+    user: KhojUser,
+    send_status_func: Optional[Callable] = None,
+    query_images: List[str] = None,
+    agent: Agent = None,
+    sandbox_url: str = SANDBOX_URL,
+    tracer: dict = {},
+):
+    # Generate Code
+    if send_status_func:
+        async for event in send_status_func(f"**Generate code snippets** for {query}"):
+            yield {ChatEvent.STATUS: event}
+    try:
+        with timer("Chat actor: Generate programs to execute", logger):
+            codes = await generate_python_code(
+                query,
+                conversation_history,
+                context,
+                location_data,
+                user,
+                query_images,
+                agent,
+                tracer,
+            )
+    except Exception as e:
+        raise ValueError(f"Failed to generate code for {query} with error: {e}")
+
+    # Run Code
+    if send_status_func:
+        async for event in send_status_func(f"**Running {len(codes)} code snippets**"):
+            yield {ChatEvent.STATUS: event}
+    try:
+        tasks = [execute_sandboxed_python(code, sandbox_url) for code in codes]
+        with timer("Chat actor: Execute generated programs", logger):
+            results = await asyncio.gather(*tasks)
+            for result in results:
+                code = result.pop("code")
+                logger.info(f"Executed Code:\n--@@--\n{code}\n--@@--Result:\n--@@--\n{result}\n--@@--")
+                yield {query: {"code": code, "results": result}}
+    except Exception as e:
+        raise ValueError(f"Failed to run code for {query} with error: {e}")
+
+
+async def generate_python_code(
+    q: str,
+    conversation_history: dict,
+    context: str,
+    location_data: LocationData,
+    user: KhojUser,
+    query_images: List[str] = None,
+    agent: Agent = None,
+    tracer: dict = {},
+) -> List[str]:
+    location = f"{location_data}" if location_data else "Unknown"
+    username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
+    subscribed = await ais_user_subscribed(user)
+    chat_history = construct_chat_history(conversation_history)
+
+    utc_date = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d")
+    personality_context = (
+        prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
+    )
+
+    code_generation_prompt = prompts.python_code_generation_prompt.format(
+        current_date=utc_date,
+        query=q,
+        chat_history=chat_history,
+        context=context,
+        location=location,
+        username=username,
+        personality_context=personality_context,
+    )
+
+    response = await send_message_to_model_wrapper(
+        code_generation_prompt,
+        query_images=query_images,
+        response_type="json_object",
+        user=user,
+        tracer=tracer,
+    )
+
+    # Validate that the response is a non-empty, JSON-serializable list
+    response = clean_json(response)
+    response = json.loads(response)
+    codes = [code.strip() for code in response["codes"] if code.strip()]
+
+    if not isinstance(codes, list) or not codes or len(codes) == 0:
+        raise ValueError
+    return codes
+
+
+async def execute_sandboxed_python(code: str, sandbox_url: str = SANDBOX_URL) -> dict[str, Any]:
+    """
+    Takes code to run as a string and calls the terrarium API to execute it.
+    Returns the result of the code execution as a dictionary.
+    """
+    headers = {"Content-Type": "application/json"}
+    cleaned_code = clean_code_python(code)
+    data = {"code": cleaned_code}
+
+    async with aiohttp.ClientSession() as session:
+        async with session.post(sandbox_url, json=data, headers=headers) as response:
+            if response.status == 200:
+                result: dict[str, Any] = await response.json()
+                result["code"] = cleaned_code
+                return result
+            else:
+                return {
+                    "code": cleaned_code,
+                    "success": False,
+                    "std_err": f"Failed to execute code with {response.status}",
+                }
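A minimal usage sketch for the new sandbox executor, assuming khoj is installed and a terrarium sandbox is reachable at the default SANDBOX_URL (http://localhost:8080, overridable via the KHOJ_TERRARIUM_URL environment variable):

```python
import asyncio

from khoj.processor.tools.run_code import execute_sandboxed_python


async def main() -> None:
    # POSTs {"code": ...} to the sandbox and returns its JSON response,
    # with the cleaned code echoed back under the "code" key.
    result = await execute_sandboxed_python("print(21 * 2)")
    # On a non-200 response the helper instead returns
    # {"code": ..., "success": False, "std_err": "Failed to execute code with <status>"}.
    print(result)


asyncio.run(main())
```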
khoj/routers/api.py CHANGED
@@ -44,6 +44,7 @@ from khoj.processor.conversation.offline.chat_model import extract_questions_off
 from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
 from khoj.processor.conversation.openai.gpt import extract_questions
 from khoj.processor.conversation.openai.whisper import transcribe_audio
+from khoj.processor.conversation.utils import defilter_query
 from khoj.routers.helpers import (
     ApiUserRateLimiter,
     ChatEvent,
@@ -167,8 +168,8 @@ async def execute_search(
             search_futures += [
                 executor.submit(
                     text_search.query,
-                    user,
                     user_query,
+                    user,
                     t,
                     question_embedding=encoded_asymmetric_query,
                     max_distance=max_distance,
@@ -355,7 +356,7 @@ async def extract_references_and_questions(
     user = request.user.object if request.user.is_authenticated else None
 
     # Initialize Variables
-    compiled_references: List[
+    compiled_references: List[dict[str, str]] = []
     inferred_queries: List[str] = []
 
     agent_has_entries = False
@@ -384,9 +385,7 @@ async def extract_references_and_questions(
         return
 
     # Extract filter terms from user message
-    defiltered_query = q
-    for filter in [DateFilter(), WordFilter(), FileFilter()]:
-        defiltered_query = filter.defilter(defiltered_query)
+    defiltered_query = defilter_query(q)
     filters_in_query = q.replace(defiltered_query, "").strip()
     conversation = await sync_to_async(ConversationAdapters.get_conversation_by_id)(conversation_id)
 
@@ -502,7 +501,8 @@ async def extract_references_and_questions(
     )
     search_results = text_search.deduplicated_search_responses(search_results)
     compiled_references = [
-        {"compiled": item.additional["compiled"], "file": item.additional["file"]}
+        {"query": q, "compiled": item.additional["compiled"], "file": item.additional["file"]}
+        for q, item in zip(inferred_queries, search_results)
     ]
 
     yield compiled_references, inferred_queries, defiltered_query
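For reference, the defilter_query helper that api.py now delegates to can be exercised directly; the filter syntax in the sample query (the file: and dt filters) is an assumption about khoj's search-filter grammar rather than something shown in this diff:

```python
from khoj.processor.conversation.utils import defilter_query

# Hypothetical query mixing free text with khoj-style search filters.
query = 'summarize my project notes file:"projects.org" dt>="2024-01-01"'

# Word, file and date filters are stripped, leaving only the free-text part.
print(defilter_query(query))
```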