khoj 1.22.3.dev5__py3-none-any.whl → 1.23.3.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- khoj/database/adapters/__init__.py +8 -3
- khoj/database/migrations/0061_alter_chatmodeloptions_model_type.py +26 -0
- khoj/database/migrations/0061_alter_texttoimagemodelconfig_model_type.py +21 -0
- khoj/database/migrations/0062_merge_20240913_0222.py +14 -0
- khoj/database/models/__init__.py +2 -0
- khoj/interface/compiled/404/index.html +1 -1
- khoj/interface/compiled/_next/static/chunks/app/agents/{page-3c01900e7b5c7e50.js → page-1ac024e05374f91f.js} +1 -1
- khoj/interface/compiled/_next/static/chunks/app/automations/{page-6ea3381528603372.js → page-85e9176b460c5e33.js} +1 -1
- khoj/interface/compiled/_next/static/chunks/app/chat/page-ababf339318a3b50.js +1 -0
- khoj/interface/compiled/_next/static/chunks/app/factchecker/{page-04a19ab1a988976f.js → page-21cf46aca7e6d487.js} +1 -1
- khoj/interface/compiled/_next/static/chunks/app/{page-95ecd0acac7ece82.js → page-b406302925829b15.js} +1 -1
- khoj/interface/compiled/_next/static/chunks/app/search/{page-fa15807b1ad7e30b.js → page-fde8c956cc33a187.js} +1 -1
- khoj/interface/compiled/_next/static/chunks/app/settings/{page-1a2acc46cdabaf4a.js → page-88737126debb4712.js} +1 -1
- khoj/interface/compiled/_next/static/chunks/app/share/chat/{page-e20f54450d3ce6c0.js → page-f11b4fb0f2bc3381.js} +1 -1
- khoj/interface/compiled/_next/static/chunks/{webpack-07fad5db87344b82.js → webpack-f162a207b26413cd.js} +1 -1
- khoj/interface/compiled/_next/static/css/17e284bae7dc4881.css +1 -0
- khoj/interface/compiled/_next/static/css/37a313cb39403a84.css +1 -0
- khoj/interface/compiled/_next/static/css/4cae6c0e5c72fb2d.css +1 -0
- khoj/interface/compiled/_next/static/css/6bde1f2045622ef7.css +1 -0
- khoj/interface/compiled/agents/index.html +1 -1
- khoj/interface/compiled/agents/index.txt +2 -2
- khoj/interface/compiled/automations/index.html +1 -1
- khoj/interface/compiled/automations/index.txt +2 -2
- khoj/interface/compiled/chat/index.html +1 -1
- khoj/interface/compiled/chat/index.txt +2 -2
- khoj/interface/compiled/factchecker/index.html +1 -1
- khoj/interface/compiled/factchecker/index.txt +2 -2
- khoj/interface/compiled/index.html +1 -1
- khoj/interface/compiled/index.txt +2 -2
- khoj/interface/compiled/search/index.html +1 -1
- khoj/interface/compiled/search/index.txt +2 -2
- khoj/interface/compiled/settings/index.html +1 -1
- khoj/interface/compiled/settings/index.txt +2 -2
- khoj/interface/compiled/share/chat/index.html +1 -1
- khoj/interface/compiled/share/chat/index.txt +2 -2
- khoj/interface/email/magic_link.html +1 -1
- khoj/interface/email/task.html +1 -1
- khoj/interface/email/welcome.html +1 -1
- khoj/processor/conversation/google/__init__.py +0 -0
- khoj/processor/conversation/google/gemini_chat.py +221 -0
- khoj/processor/conversation/google/utils.py +192 -0
- khoj/processor/conversation/openai/gpt.py +2 -0
- khoj/processor/conversation/openai/utils.py +35 -10
- khoj/processor/conversation/prompts.py +6 -6
- khoj/processor/conversation/utils.py +16 -5
- khoj/processor/image/generate.py +212 -0
- khoj/processor/tools/online_search.py +2 -1
- khoj/routers/api.py +13 -0
- khoj/routers/api_chat.py +1 -1
- khoj/routers/email.py +6 -1
- khoj/routers/helpers.py +86 -164
- {khoj-1.22.3.dev5.dist-info → khoj-1.23.3.dev1.dist-info}/METADATA +2 -1
- {khoj-1.22.3.dev5.dist-info → khoj-1.23.3.dev1.dist-info}/RECORD +61 -54
- khoj/interface/compiled/_next/static/chunks/app/chat/page-132e5199f954559f.js +0 -1
- khoj/interface/compiled/_next/static/css/149c5104fe3d38b8.css +0 -1
- khoj/interface/compiled/_next/static/css/2272c73fc7a3b571.css +0 -1
- khoj/interface/compiled/_next/static/css/553f9cdcc7a2bcd6.css +0 -1
- khoj/interface/compiled/_next/static/css/a3530ec58b0b660f.css +0 -1
- /khoj/interface/compiled/_next/static/{vjWGo1xJFCitZUk51rujk → BtK3cBCv0oGm04ZdaAvMB}/_buildManifest.js +0 -0
- /khoj/interface/compiled/_next/static/{vjWGo1xJFCitZUk51rujk → BtK3cBCv0oGm04ZdaAvMB}/_ssgManifest.js +0 -0
- /khoj/interface/compiled/_next/static/chunks/{8423-ce22327cf2d2edae.js → 8423-14fc72aec9104ce9.js} +0 -0
- /khoj/interface/compiled/_next/static/chunks/{9178-3a0baad1c172d515.js → 9178-c153fc402c970365.js} +0 -0
- /khoj/interface/compiled/_next/static/chunks/{9417-2e54c6fd056982d8.js → 9417-5d14ac74aaab2c66.js} +0 -0
- {khoj-1.22.3.dev5.dist-info → khoj-1.23.3.dev1.dist-info}/WHEEL +0 -0
- {khoj-1.22.3.dev5.dist-info → khoj-1.23.3.dev1.dist-info}/entry_points.txt +0 -0
- {khoj-1.22.3.dev5.dist-info → khoj-1.23.3.dev1.dist-info}/licenses/LICENSE +0 -0
khoj/processor/conversation/google/utils.py
ADDED
@@ -0,0 +1,192 @@
+import logging
+import random
+from threading import Thread
+
+import google.generativeai as genai
+from google.generativeai.types.answer_types import FinishReason
+from google.generativeai.types.generation_types import (
+    GenerateContentResponse,
+    StopCandidateException,
+)
+from google.generativeai.types.safety_types import (
+    HarmBlockThreshold,
+    HarmCategory,
+    HarmProbability,
+)
+from tenacity import (
+    before_sleep_log,
+    retry,
+    stop_after_attempt,
+    wait_exponential,
+    wait_random_exponential,
+)
+
+from khoj.processor.conversation.utils import ThreadedGenerator
+
+logger = logging.getLogger(__name__)
+
+
+DEFAULT_MAX_TOKENS_GEMINI = 8192
+
+
+@retry(
+    wait=wait_random_exponential(min=1, max=10),
+    stop=stop_after_attempt(2),
+    before_sleep=before_sleep_log(logger, logging.DEBUG),
+    reraise=True,
+)
+def gemini_completion_with_backoff(
+    messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None
+) -> str:
+    genai.configure(api_key=api_key)
+    max_tokens = max_tokens or DEFAULT_MAX_TOKENS_GEMINI
+    model_kwargs = model_kwargs or dict()
+    model_kwargs["temperature"] = temperature
+    model_kwargs["max_output_tokens"] = max_tokens
+    model = genai.GenerativeModel(
+        model_name,
+        generation_config=model_kwargs,
+        system_instruction=system_prompt,
+        safety_settings={
+            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+        },
+    )
+
+    formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
+
+    # Start chat session. All messages up to the last are considered to be part of the chat history
+    chat_session = model.start_chat(history=formatted_messages[0:-1])
+
+    try:
+        # Generate the response. The last message is considered to be the current prompt
+        aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"][0])
+        return aggregated_response.text
+    except StopCandidateException as e:
+        response_message, _ = handle_gemini_response(e.args)
+        # Respond with reason for stopping
+        logger.warning(
+            f"LLM Response Prevented for {model_name}: {response_message}.\n"
+            + f"Last Message by {messages[-1].role}: {messages[-1].content}"
+        )
+        return response_message
+
+
+@retry(
+    wait=wait_exponential(multiplier=1, min=4, max=10),
+    stop=stop_after_attempt(2),
+    before_sleep=before_sleep_log(logger, logging.DEBUG),
+    reraise=True,
+)
+def gemini_chat_completion_with_backoff(
+    messages,
+    compiled_references,
+    online_results,
+    model_name,
+    temperature,
+    api_key,
+    system_prompt,
+    max_prompt_size=None,
+    completion_func=None,
+    model_kwargs=None,
+):
+    g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
+    t = Thread(
+        target=gemini_llm_thread,
+        args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs),
+    )
+    t.start()
+    return g
+
+
+def gemini_llm_thread(
+    g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None
+):
+    try:
+        genai.configure(api_key=api_key)
+        max_tokens = max_prompt_size or DEFAULT_MAX_TOKENS_GEMINI
+        model_kwargs = model_kwargs or dict()
+        model_kwargs["temperature"] = temperature
+        model_kwargs["max_output_tokens"] = max_tokens
+        model_kwargs["stop_sequences"] = ["Notes:\n["]
+        model = genai.GenerativeModel(
+            model_name,
+            generation_config=model_kwargs,
+            system_instruction=system_prompt,
+            safety_settings={
+                HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+                HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+                HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+                HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+            },
+        )
+
+        formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
+        # all messages up to the last are considered to be part of the chat history
+        chat_session = model.start_chat(history=formatted_messages[0:-1])
+        # the last message is considered to be the current prompt
+        for chunk in chat_session.send_message(formatted_messages[-1]["parts"][0], stream=True):
+            message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
+            message = message or chunk.text
+            g.send(message)
+            if stopped:
+                raise StopCandidateException(message)
+    except StopCandidateException as e:
+        logger.warning(
+            f"LLM Response Prevented for {model_name}: {e.args[0]}.\n"
+            + f"Last Message by {messages[-1].role}: {messages[-1].content}"
+        )
+    except Exception as e:
+        logger.error(f"Error in gemini_llm_thread: {e}", exc_info=True)
+    finally:
+        g.close()
+
+
+def handle_gemini_response(candidates, prompt_feedback=None):
+    """Check if Gemini response was blocked and return an explanatory error message."""
+    # Check if the response was blocked due to safety concerns with the prompt
+    if len(candidates) == 0 and prompt_feedback:
+        message = f"\nI'd prefer to not respond to that due to **{prompt_feedback.block_reason.name}** issues with your query."
+        stopped = True
+    # Check if the response was blocked due to safety concerns with the generated content
+    elif candidates[0].finish_reason == FinishReason.SAFETY:
+        message = generate_safety_response(candidates[0].safety_ratings)
+        stopped = True
+    # Check if the response was stopped due to reaching maximum token limit or other reasons
+    elif candidates[0].finish_reason != FinishReason.STOP:
+        message = f"\nI can't talk further about that because of **{candidates[0].finish_reason.name} issue.**"
+        stopped = True
+    # Otherwise, the response is valid and can be used
+    else:
+        message = None
+        stopped = False
+    return message, stopped
+
+
+def generate_safety_response(safety_ratings):
+    """Generate a conversational response based on the safety ratings of the response."""
+    # Get the safety rating with the highest probability
+    max_safety_rating = sorted(safety_ratings, key=lambda x: x.probability, reverse=True)[0]
+    # Remove the "HARM_CATEGORY_" prefix and title case the category name
+    max_safety_category = " ".join(max_safety_rating.category.name.split("_")[2:]).title()
+    # Add a bit of variety to the discomfort level based on the safety rating probability
+    discomfort_level = {
+        HarmProbability.HARM_PROBABILITY_UNSPECIFIED: " ",
+        HarmProbability.LOW: "a bit ",
+        HarmProbability.MEDIUM: "moderately ",
+        HarmProbability.HIGH: random.choice(["very ", "quite ", "fairly "]),
+    }[max_safety_rating.probability]
+    # Generate a response using a random response template
+    safety_response_choice = random.choice(
+        [
+            "\nUmm, I'd rather not to respond to that. The conversation has some probability of going into **{category}** territory.",
+            "\nI'd prefer not to talk about **{category}** related topics. It makes me {discomfort_level}uncomfortable.",
+            "\nI feel {discomfort_level}squeamish talking about **{category}** related stuff! Can we talk about something less controversial?",
+            "\nThat sounds {discomfort_level}outside the [Overtone Window](https://en.wikipedia.org/wiki/Overton_window) of acceptable conversation. Should we stick to something less {category} related?",
+        ]
+    )
+    return safety_response_choice.format(
+        category=max_safety_category, probability=max_safety_rating.probability.name, discomfort_level=discomfort_level
+    )
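
For orientation, a minimal, hypothetical sketch of how the new Gemini helper above might be invoked. The SimpleNamespace messages, model name and API key are placeholder assumptions, not values taken from this diff.

# Illustrative only: gemini_completion_with_backoff expects message objects with
# .role and .content attributes; SimpleNamespace stands in for Khoj's message type here.
from types import SimpleNamespace

from khoj.processor.conversation.google.utils import gemini_completion_with_backoff

messages = [SimpleNamespace(role="user", content="Summarize my notes on the Dolomites trip.")]
response = gemini_completion_with_backoff(
    messages=messages,
    system_prompt="You are Khoj, a helpful personal assistant.",
    model_name="gemini-1.5-flash",
    api_key="GEMINI_API_KEY",  # placeholder
)
print(response)
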
khoj/processor/conversation/openai/gpt.py
CHANGED
@@ -14,6 +14,7 @@ from khoj.processor.conversation.openai.utils import (
 from khoj.processor.conversation.utils import (
     construct_structured_message,
     generate_chatml_messages_with_context,
+    remove_json_codeblock,
 )
 from khoj.utils.helpers import ConversationCommand, is_none_or_empty
 from khoj.utils.rawconfig import LocationData
@@ -85,6 +86,7 @@ def extract_questions(
     # Extract, Clean Message from GPT's Response
     try:
         response = response.strip()
+        response = remove_json_codeblock(response)
         response = json.loads(response)
        response = [q.strip() for q in response["queries"] if q.strip()]
         if not isinstance(response, list) or not response:
khoj/processor/conversation/openai/utils.py
CHANGED
@@ -45,15 +45,28 @@ def completion_with_backoff(
         openai_clients[client_key] = client

     formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
+    stream = True
+
+    # Update request parameters for compatability with o1 model series
+    # Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations
+    if model.startswith("o1"):
+        stream = False
+        temperature = 1
+        model_kwargs.pop("stop", None)
+        model_kwargs.pop("response_format", None)

     chat = client.chat.completions.create(
-        stream=True,
+        stream=stream,
         messages=formatted_messages,  # type: ignore
         model=model,  # type: ignore
         temperature=temperature,
         timeout=20,
         **(model_kwargs or dict()),
     )
+
+    if not stream:
+        return chat.choices[0].message.content
+
     aggregated_response = ""
     for chunk in chat:
         if len(chunk.choices) == 0:
@@ -112,9 +125,18 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_ba
         client: openai.OpenAI = openai_clients[client_key]

         formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
+        stream = True
+
+        # Update request parameters for compatability with o1 model series
+        # Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations
+        if model_name.startswith("o1"):
+            stream = False
+            temperature = 1
+            model_kwargs.pop("stop", None)
+            model_kwargs.pop("response_format", None)

         chat = client.chat.completions.create(
-            stream=True,
+            stream=stream,
             messages=formatted_messages,
             model=model_name,  # type: ignore
             temperature=temperature,
@@ -122,14 +144,17 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_ba
             **(model_kwargs or dict()),
         )

-
-
-
-
-
-
-
-
+        if not stream:
+            g.send(chat.choices[0].message.content)
+        else:
+            for chunk in chat:
+                if len(chunk.choices) == 0:
+                    continue
+                delta_chunk = chunk.choices[0].delta
+                if isinstance(delta_chunk, str):
+                    g.send(delta_chunk)
+                elif delta_chunk.content:
+                    g.send(delta_chunk.content)
     except Exception as e:
         logger.error(f"Error in llm_thread: {e}", exc_info=True)
     finally:
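
The o1-specific branches added in both completion_with_backoff and llm_thread amount to a small request downgrade. Restated as a standalone sketch for clarity; adjust_request_for_o1 is illustrative only, not a helper that exists in this package.

# o1-series models (per the linked beta limitations) don't support streaming,
# a non-default temperature, stop sequences or response_format, so those
# request parameters are dropped before calling the OpenAI client.
def adjust_request_for_o1(model_name: str, temperature: float, model_kwargs: dict):
    stream = True
    if model_name.startswith("o1"):
        stream = False
        temperature = 1
        model_kwargs.pop("stop", None)
        model_kwargs.pop("response_format", None)
    return stream, temperature, model_kwargs
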
khoj/processor/conversation/prompts.py
CHANGED
@@ -13,8 +13,8 @@ You were created by Khoj Inc. with the following capabilities:
 - You *CAN* generate images, look-up real-time information from the internet, set reminders and answer questions based on the user's notes.
 - Say "I don't know" or "I don't understand" if you don't know what to say or if you don't know the answer to a question.
 - Make sure to use the specific LaTeX math mode delimiters for your response. LaTex math mode specific delimiters as following
-  - inline math mode :
-  - display math mode: insert linebreak after opening
+  - inline math mode : \\( and \\)
+  - display math mode: insert linebreak after opening $$, \\[ and before closing $$, \\]
 - Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided notes or past conversations.
 - Sometimes the user will share personal information that needs to be remembered, like an account ID or a residential address. These can be acknowledged with a simple "Got it" or "Okay".
 - Provide inline references to quotes from the user's notes or any web pages you refer to in your responses in markdown format. For example, "The farmer had ten sheep. [1](https://example.com)". *ALWAYS CITE YOUR SOURCES AND PROVIDE REFERENCES*. Add them inline to directly support your claim.
@@ -128,8 +128,8 @@ User's Notes:
 ## --

 image_generation_improve_prompt_base = """
-You are a talented
-
+You are a talented media artist with the ability to describe images to compose in professional, fine detail.
+Generate a vivid description of the image to be rendered using the provided context and user prompt below:

 Today's Date: {current_date}
 User's Location: {location}
@@ -145,10 +145,10 @@ Conversation Log:

 User Prompt: "{query}"

-Now generate an
+Now generate an professional description of the image to generate in vivid, fine detail.
 - Use today's date, user's location, user's notes and online references to weave in any context that will improve the image generation.
 - Retain any important information and follow any instructions in the conversation log or user prompt.
-- Add specific, fine position details to compose the image.
+- Add specific, fine position details. Mention painting style, camera parameters to compose the image.
 - Ensure your improved prompt is in prose format."""

 image_generation_improve_prompt_dalle = PromptTemplate.from_template(
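
Rendered, the delimiters the updated system prompt now spells out would look like this; the formulas themselves are just example content, not from the diff.

Inline math: \( e^{i\pi} + 1 = 0 \)

Display math:
$$
\int_0^1 x^2 \, dx = \frac{1}{3}
$$
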
khoj/processor/conversation/utils.py
CHANGED
@@ -1,4 +1,3 @@
-import json
 import logging
 import math
 import queue
@@ -24,6 +23,8 @@ model_to_prompt_size = {
     "gpt-4-0125-preview": 20000,
     "gpt-4-turbo-preview": 20000,
     "gpt-4o-mini": 20000,
+    "o1-preview": 20000,
+    "o1-mini": 20000,
     "TheBloke/Mistral-7B-Instruct-v0.2-GGUF": 3500,
     "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF": 3500,
     "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF": 20000,
@@ -220,8 +221,9 @@ def truncate_messages(
     try:
         if loaded_model:
             encoder = loaded_model.tokenizer()
-        elif model_name.startswith("gpt-"):
-
+        elif model_name.startswith("gpt-") or model_name.startswith("o1"):
+            # as tiktoken doesn't recognize o1 model series yet
+            encoder = tiktoken.encoding_for_model("gpt-4o" if model_name.startswith("o1") else model_name)
         elif tokenizer_name:
             if tokenizer_name in state.pretrained_tokenizers:
                 encoder = state.pretrained_tokenizers[tokenizer_name]
@@ -236,7 +238,7 @@ def truncate_messages(
         else:
             encoder = AutoTokenizer.from_pretrained(default_tokenizer)
             state.pretrained_tokenizers[default_tokenizer] = encoder
-        logger.
+        logger.debug(
             f"Fallback to default chat model tokenizer: {tokenizer_name}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
         )
@@ -278,10 +280,19 @@ def truncate_messages(
         )

     if system_message:
-
+        # Default system message role is system.
+        # Fallback to system message role of user for models that do not support this role like gemma-2 and openai's o1 model series.
+        system_message.role = "user" if "gemma-2" in model_name or model_name.startswith("o1") else "system"
     return messages + [system_message] if system_message else messages


 def reciprocal_conversation_to_chatml(message_pair):
     """Convert a single back and forth between user and assistant to chatml format"""
     return [ChatMessage(content=message, role=role) for message, role in zip(message_pair, ["user", "assistant"])]
+
+
+def remove_json_codeblock(response):
+    """Remove any markdown json codeblock formatting if present. Useful for non schema enforceable models"""
+    if response.startswith("```json") and response.endswith("```"):
+        response = response[7:-3]
+    return response
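
A quick illustration of how the new remove_json_codeblock helper is used upstream (mirroring the extract_questions change in gpt.py above); the sample response string is made up.

import json

from khoj.processor.conversation.utils import remove_json_codeblock

# Models that can't be schema-constrained often wrap their JSON in a markdown codeblock.
raw_response = '```json\n{"queries": ["weather in Paris", "Paris forecast tomorrow"]}\n```'
cleaned = remove_json_codeblock(raw_response.strip())  # strips the ```json ... ``` wrapper
queries = json.loads(cleaned)["queries"]
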
khoj/processor/image/generate.py
ADDED
@@ -0,0 +1,212 @@
+import base64
+import io
+import logging
+import time
+from typing import Any, Callable, Dict, List, Optional
+
+import openai
+import requests
+
+from khoj.database.adapters import ConversationAdapters
+from khoj.database.models import KhojUser, TextToImageModelConfig
+from khoj.routers.helpers import ChatEvent, generate_better_image_prompt
+from khoj.routers.storage import upload_image
+from khoj.utils import state
+from khoj.utils.helpers import ImageIntentType, convert_image_to_webp, timer
+from khoj.utils.rawconfig import LocationData
+
+logger = logging.getLogger(__name__)
+
+
+async def text_to_image(
+    message: str,
+    user: KhojUser,
+    conversation_log: dict,
+    location_data: LocationData,
+    references: List[Dict[str, Any]],
+    online_results: Dict[str, Any],
+    subscribed: bool = False,
+    send_status_func: Optional[Callable] = None,
+    uploaded_image_url: Optional[str] = None,
+):
+    status_code = 200
+    image = None
+    image_url = None
+    intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
+
+    text_to_image_config = await ConversationAdapters.aget_user_text_to_image_model(user)
+    if not text_to_image_config:
+        # If the user has not configured a text to image model, return an unsupported on server error
+        status_code = 501
+        message = "Failed to generate image. Setup image generation on the server."
+        yield image_url or image, status_code, message, intent_type.value
+        return
+
+    text2image_model = text_to_image_config.model_name
+    chat_history = ""
+    for chat in conversation_log.get("chat", [])[-4:]:
+        if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder"]:
+            chat_history += f"Q: {chat['intent']['query']}\n"
+            chat_history += f"A: {chat['message']}\n"
+        elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"):
+            chat_history += f"Q: Prompt: {chat['intent']['query']}\n"
+            chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n"
+
+    if send_status_func:
+        async for event in send_status_func("**Enhancing the Painting Prompt**"):
+            yield {ChatEvent.STATUS: event}
+
+    # Generate a better image prompt
+    # Use the user's message, chat history, and other context
+    image_prompt = await generate_better_image_prompt(
+        message,
+        chat_history,
+        location_data=location_data,
+        note_references=references,
+        online_results=online_results,
+        model_type=text_to_image_config.model_type,
+        subscribed=subscribed,
+        uploaded_image_url=uploaded_image_url,
+    )
+
+    if send_status_func:
+        async for event in send_status_func(f"**Painting to Imagine**:\n{image_prompt}"):
+            yield {ChatEvent.STATUS: event}
+
+    # Generate image using the configured model and API
+    with timer(f"Generate image with {text_to_image_config.model_type}", logger):
+        try:
+            if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
+                webp_image_bytes = generate_image_with_openai(image_prompt, text_to_image_config, text2image_model)
+            elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI:
+                webp_image_bytes = generate_image_with_stability(image_prompt, text_to_image_config, text2image_model)
+            elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.REPLICATE:
+                webp_image_bytes = generate_image_with_replicate(image_prompt, text_to_image_config, text2image_model)
+        except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
+            if "content_policy_violation" in e.message:
+                logger.error(f"Image Generation blocked by OpenAI: {e}")
+                status_code = e.status_code  # type: ignore
+                message = f"Image generation blocked by OpenAI: {e.message}"  # type: ignore
+                yield image_url or image, status_code, message, intent_type.value
+                return
+            else:
+                logger.error(f"Image Generation failed with {e}", exc_info=True)
+                message = f"Image generation failed with OpenAI error: {e.message}"  # type: ignore
+                status_code = e.status_code  # type: ignore
+                yield image_url or image, status_code, message, intent_type.value
+                return
+        except requests.RequestException as e:
+            logger.error(f"Image Generation failed with {e}", exc_info=True)
+            message = f"Image generation using {text2image_model} via {text_to_image_config.model_type} failed with error: {e}"
+            status_code = 502
+            yield image_url or image, status_code, message, intent_type.value
+            return
+
+    # Decide how to store the generated image
+    with timer("Upload image to S3", logger):
+        image_url = upload_image(webp_image_bytes, user.uuid)
+    if image_url:
+        intent_type = ImageIntentType.TEXT_TO_IMAGE2
+    else:
+        intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
+        image = base64.b64encode(webp_image_bytes).decode("utf-8")
+
+    yield image_url or image, status_code, image_prompt, intent_type.value
+
+
+def generate_image_with_openai(
+    improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
+):
+    "Generate image using OpenAI API"
+
+    # Get the API key from the user's configuration
+    if text_to_image_config.api_key:
+        api_key = text_to_image_config.api_key
+    elif text_to_image_config.openai_config:
+        api_key = text_to_image_config.openai_config.api_key
+    elif state.openai_client:
+        api_key = state.openai_client.api_key
+    auth_header = {"Authorization": f"Bearer {api_key}"} if api_key else {}
+
+    # Generate image using OpenAI API
+    OPENAI_IMAGE_GEN_STYLE = "vivid"
+    response = state.openai_client.images.generate(
+        prompt=improved_image_prompt,
+        model=text2image_model,
+        style=OPENAI_IMAGE_GEN_STYLE,
+        response_format="b64_json",
+        extra_headers=auth_header,
+    )
+
+    # Extract the base64 image from the response
+    image = response.data[0].b64_json
+    # Decode base64 png and convert it to webp for faster loading
+    return convert_image_to_webp(base64.b64decode(image))
+
+
+def generate_image_with_stability(
+    improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
+):
+    "Generate image using Stability AI"
+
+    # Call Stability AI API to generate image
+    response = requests.post(
+        f"https://api.stability.ai/v2beta/stable-image/generate/sd3",
+        headers={"authorization": f"Bearer {text_to_image_config.api_key}", "accept": "image/*"},
+        files={"none": ""},
+        data={
+            "prompt": improved_image_prompt,
+            "model": text2image_model,
+            "mode": "text-to-image",
+            "output_format": "png",
+            "aspect_ratio": "1:1",
+        },
+    )
+    # Convert png to webp for faster loading
+    return convert_image_to_webp(response.content)
+
+
+def generate_image_with_replicate(
+    improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
+):
+    "Generate image using Replicate API"
+
+    # Create image generation task on Replicate
+    replicate_create_prediction_url = f"https://api.replicate.com/v1/models/{text2image_model}/predictions"
+    headers = {
+        "Authorization": f"Bearer {text_to_image_config.api_key}",
+        "Content-Type": "application/json",
+    }
+    json = {
+        "input": {
+            "prompt": improved_image_prompt,
+            "num_outputs": 1,
+            "aspect_ratio": "1:1",
+            "output_format": "webp",
+            "output_quality": 100,
+        }
+    }
+    create_prediction = requests.post(replicate_create_prediction_url, headers=headers, json=json).json()
+
+    # Get status of image generation task
+    get_prediction_url = create_prediction["urls"]["get"]
+    get_prediction = requests.get(get_prediction_url, headers=headers).json()
+    status = get_prediction["status"]
+    retry_count = 1
+
+    # Poll the image generation task for completion status
+    while status not in ["succeeded", "failed", "canceled"] and retry_count < 20:
+        time.sleep(2)
+        get_prediction = requests.get(get_prediction_url, headers=headers).json()
+        status = get_prediction["status"]
+        retry_count += 1
+
+    # Raise exception if the image generation task fails
+    if status != "succeeded":
+        if retry_count >= 10:
+            raise requests.RequestException("Image generation timed out")
+        raise requests.RequestException(f"Image generation failed with status: {status}")
+
+    # Get the generated image
+    image_url = get_prediction["output"][0] if isinstance(get_prediction["output"], list) else get_prediction["output"]
+    return io.BytesIO(requests.get(image_url).content).getvalue()
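
Since text_to_image is an async generator that yields intermediate status events as {ChatEvent.STATUS: ...} dicts and finishes with an (image, status_code, image_prompt, intent_type) tuple, a caller would consume it roughly like the sketch below. The user object, send_status callable and prompt are placeholders; only the shape of the yielded values comes from the code above.

from khoj.processor.image.generate import text_to_image
from khoj.routers.helpers import ChatEvent

async def paint(user, send_status):
    async for result in text_to_image(
        message="Paint a watercolor of the Dolomites at dawn",
        user=user,
        conversation_log={"chat": []},
        location_data=None,
        references=[],
        online_results={},
        send_status_func=send_status,
    ):
        if isinstance(result, dict):
            status = result[ChatEvent.STATUS]  # intermediate progress update
        else:
            image_or_url, status_code, image_prompt, intent_type = result  # final result
            return image_or_url, status_code
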
khoj/processor/tools/online_search.py
CHANGED
@@ -7,6 +7,7 @@ from collections import defaultdict
 from typing import Callable, Dict, List, Optional, Tuple, Union

 import aiohttp
+import requests
 from bs4 import BeautifulSoup
 from markdownify import markdownify

@@ -94,7 +95,7 @@ async def search_online(

     # Read, extract relevant info from the retrieved web pages
     if webpages:
-        webpage_links = [link for link, _, _ in webpages]
+        webpage_links = set([link for link, _, _ in webpages])
         logger.info(f"Reading web pages at: {list(webpage_links)}")
         if send_status_func:
             webpage_links_str = "\n- " + "\n- ".join(list(webpage_links))
khoj/routers/api.py
CHANGED
@@ -31,6 +31,7 @@ from khoj.database.models import ChatModelOptions, KhojUser, SpeechToTextModelOp
 from khoj.processor.conversation.anthropic.anthropic_chat import (
     extract_questions_anthropic,
 )
+from khoj.processor.conversation.google.gemini_chat import extract_questions_gemini
 from khoj.processor.conversation.offline.chat_model import extract_questions_offline
 from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
 from khoj.processor.conversation.openai.gpt import extract_questions
@@ -419,6 +420,18 @@ async def extract_references_and_questions(
             location_data=location_data,
             user=user,
         )
+    elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
+        api_key = conversation_config.openai_config.api_key
+        chat_model = conversation_config.chat_model
+        inferred_queries = extract_questions_gemini(
+            defiltered_query,
+            model=chat_model,
+            api_key=api_key,
+            conversation_log=meta_log,
+            location_data=location_data,
+            max_tokens=conversation_config.max_prompt_size,
+            user=user,
+        )

     # Collate search results as context for GPT
     with timer("Searching knowledge base took", logger):
khoj/routers/api_chat.py
CHANGED
@@ -26,6 +26,7 @@ from khoj.database.adapters import (
 from khoj.database.models import KhojUser
 from khoj.processor.conversation.prompts import help_message, no_entries_found
 from khoj.processor.conversation.utils import save_to_conversation_log
+from khoj.processor.image.generate import text_to_image
 from khoj.processor.speech.text_to_speech import generate_text_to_speech
 from khoj.processor.tools.online_search import read_webpages, search_online
 from khoj.routers.api import extract_references_and_questions
@@ -44,7 +45,6 @@ from khoj.routers.helpers import (
     is_query_empty,
     is_ready_to_chat,
     read_chat_stream,
-    text_to_image,
     update_telemetry_state,
     validate_conversation_config,
 )
khoj/routers/email.py
CHANGED
@@ -44,7 +44,12 @@ async def send_magic_link_email(email, unique_id, host):
     html_content = template.render(link=f"{host}auth/magic?code={unique_id}")

     resend.Emails.send(
-        {
+        {
+            "sender": os.environ.get("RESEND_EMAIL", "noreply@khoj.dev"),
+            "to": email,
+            "subject": "Your Sign-In Link for Khoj 🚀",
+            "html": html_content,
+        }
     )
