khoj 1.22.3.dev5__py3-none-any.whl → 1.23.3.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. khoj/database/adapters/__init__.py +8 -3
  2. khoj/database/migrations/0061_alter_chatmodeloptions_model_type.py +26 -0
  3. khoj/database/migrations/0061_alter_texttoimagemodelconfig_model_type.py +21 -0
  4. khoj/database/migrations/0062_merge_20240913_0222.py +14 -0
  5. khoj/database/models/__init__.py +2 -0
  6. khoj/interface/compiled/404/index.html +1 -1
  7. khoj/interface/compiled/_next/static/chunks/app/agents/{page-3c01900e7b5c7e50.js → page-1ac024e05374f91f.js} +1 -1
  8. khoj/interface/compiled/_next/static/chunks/app/automations/{page-6ea3381528603372.js → page-85e9176b460c5e33.js} +1 -1
  9. khoj/interface/compiled/_next/static/chunks/app/chat/page-ababf339318a3b50.js +1 -0
  10. khoj/interface/compiled/_next/static/chunks/app/factchecker/{page-04a19ab1a988976f.js → page-21cf46aca7e6d487.js} +1 -1
  11. khoj/interface/compiled/_next/static/chunks/app/{page-95ecd0acac7ece82.js → page-b406302925829b15.js} +1 -1
  12. khoj/interface/compiled/_next/static/chunks/app/search/{page-fa15807b1ad7e30b.js → page-fde8c956cc33a187.js} +1 -1
  13. khoj/interface/compiled/_next/static/chunks/app/settings/{page-1a2acc46cdabaf4a.js → page-88737126debb4712.js} +1 -1
  14. khoj/interface/compiled/_next/static/chunks/app/share/chat/{page-e20f54450d3ce6c0.js → page-f11b4fb0f2bc3381.js} +1 -1
  15. khoj/interface/compiled/_next/static/chunks/{webpack-07fad5db87344b82.js → webpack-f162a207b26413cd.js} +1 -1
  16. khoj/interface/compiled/_next/static/css/17e284bae7dc4881.css +1 -0
  17. khoj/interface/compiled/_next/static/css/37a313cb39403a84.css +1 -0
  18. khoj/interface/compiled/_next/static/css/4cae6c0e5c72fb2d.css +1 -0
  19. khoj/interface/compiled/_next/static/css/6bde1f2045622ef7.css +1 -0
  20. khoj/interface/compiled/agents/index.html +1 -1
  21. khoj/interface/compiled/agents/index.txt +2 -2
  22. khoj/interface/compiled/automations/index.html +1 -1
  23. khoj/interface/compiled/automations/index.txt +2 -2
  24. khoj/interface/compiled/chat/index.html +1 -1
  25. khoj/interface/compiled/chat/index.txt +2 -2
  26. khoj/interface/compiled/factchecker/index.html +1 -1
  27. khoj/interface/compiled/factchecker/index.txt +2 -2
  28. khoj/interface/compiled/index.html +1 -1
  29. khoj/interface/compiled/index.txt +2 -2
  30. khoj/interface/compiled/search/index.html +1 -1
  31. khoj/interface/compiled/search/index.txt +2 -2
  32. khoj/interface/compiled/settings/index.html +1 -1
  33. khoj/interface/compiled/settings/index.txt +2 -2
  34. khoj/interface/compiled/share/chat/index.html +1 -1
  35. khoj/interface/compiled/share/chat/index.txt +2 -2
  36. khoj/interface/email/magic_link.html +1 -1
  37. khoj/interface/email/task.html +1 -1
  38. khoj/interface/email/welcome.html +1 -1
  39. khoj/processor/conversation/google/__init__.py +0 -0
  40. khoj/processor/conversation/google/gemini_chat.py +221 -0
  41. khoj/processor/conversation/google/utils.py +192 -0
  42. khoj/processor/conversation/openai/gpt.py +2 -0
  43. khoj/processor/conversation/openai/utils.py +35 -10
  44. khoj/processor/conversation/prompts.py +6 -6
  45. khoj/processor/conversation/utils.py +16 -5
  46. khoj/processor/image/generate.py +212 -0
  47. khoj/processor/tools/online_search.py +2 -1
  48. khoj/routers/api.py +13 -0
  49. khoj/routers/api_chat.py +1 -1
  50. khoj/routers/email.py +6 -1
  51. khoj/routers/helpers.py +86 -164
  52. {khoj-1.22.3.dev5.dist-info → khoj-1.23.3.dev1.dist-info}/METADATA +2 -1
  53. {khoj-1.22.3.dev5.dist-info → khoj-1.23.3.dev1.dist-info}/RECORD +61 -54
  54. khoj/interface/compiled/_next/static/chunks/app/chat/page-132e5199f954559f.js +0 -1
  55. khoj/interface/compiled/_next/static/css/149c5104fe3d38b8.css +0 -1
  56. khoj/interface/compiled/_next/static/css/2272c73fc7a3b571.css +0 -1
  57. khoj/interface/compiled/_next/static/css/553f9cdcc7a2bcd6.css +0 -1
  58. khoj/interface/compiled/_next/static/css/a3530ec58b0b660f.css +0 -1
  59. /khoj/interface/compiled/_next/static/{vjWGo1xJFCitZUk51rujk → BtK3cBCv0oGm04ZdaAvMB}/_buildManifest.js +0 -0
  60. /khoj/interface/compiled/_next/static/{vjWGo1xJFCitZUk51rujk → BtK3cBCv0oGm04ZdaAvMB}/_ssgManifest.js +0 -0
  61. /khoj/interface/compiled/_next/static/chunks/{8423-ce22327cf2d2edae.js → 8423-14fc72aec9104ce9.js} +0 -0
  62. /khoj/interface/compiled/_next/static/chunks/{9178-3a0baad1c172d515.js → 9178-c153fc402c970365.js} +0 -0
  63. /khoj/interface/compiled/_next/static/chunks/{9417-2e54c6fd056982d8.js → 9417-5d14ac74aaab2c66.js} +0 -0
  64. {khoj-1.22.3.dev5.dist-info → khoj-1.23.3.dev1.dist-info}/WHEEL +0 -0
  65. {khoj-1.22.3.dev5.dist-info → khoj-1.23.3.dev1.dist-info}/entry_points.txt +0 -0
  66. {khoj-1.22.3.dev5.dist-info → khoj-1.23.3.dev1.dist-info}/licenses/LICENSE +0 -0
khoj/processor/conversation/google/utils.py ADDED
@@ -0,0 +1,192 @@
+ import logging
+ import random
+ from threading import Thread
+
+ import google.generativeai as genai
+ from google.generativeai.types.answer_types import FinishReason
+ from google.generativeai.types.generation_types import (
+     GenerateContentResponse,
+     StopCandidateException,
+ )
+ from google.generativeai.types.safety_types import (
+     HarmBlockThreshold,
+     HarmCategory,
+     HarmProbability,
+ )
+ from tenacity import (
+     before_sleep_log,
+     retry,
+     stop_after_attempt,
+     wait_exponential,
+     wait_random_exponential,
+ )
+
+ from khoj.processor.conversation.utils import ThreadedGenerator
+
+ logger = logging.getLogger(__name__)
+
+
+ DEFAULT_MAX_TOKENS_GEMINI = 8192
+
+
+ @retry(
+     wait=wait_random_exponential(min=1, max=10),
+     stop=stop_after_attempt(2),
+     before_sleep=before_sleep_log(logger, logging.DEBUG),
+     reraise=True,
+ )
+ def gemini_completion_with_backoff(
+     messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None
+ ) -> str:
+     genai.configure(api_key=api_key)
+     max_tokens = max_tokens or DEFAULT_MAX_TOKENS_GEMINI
+     model_kwargs = model_kwargs or dict()
+     model_kwargs["temperature"] = temperature
+     model_kwargs["max_output_tokens"] = max_tokens
+     model = genai.GenerativeModel(
+         model_name,
+         generation_config=model_kwargs,
+         system_instruction=system_prompt,
+         safety_settings={
+             HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+             HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+             HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+             HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+         },
+     )
+
+     formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
+
+     # Start chat session. All messages up to the last are considered to be part of the chat history
+     chat_session = model.start_chat(history=formatted_messages[0:-1])
+
+     try:
+         # Generate the response. The last message is considered to be the current prompt
+         aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"][0])
+         return aggregated_response.text
+     except StopCandidateException as e:
+         response_message, _ = handle_gemini_response(e.args)
+         # Respond with reason for stopping
+         logger.warning(
+             f"LLM Response Prevented for {model_name}: {response_message}.\n"
+             + f"Last Message by {messages[-1].role}: {messages[-1].content}"
+         )
+         return response_message
+
+
+ @retry(
+     wait=wait_exponential(multiplier=1, min=4, max=10),
+     stop=stop_after_attempt(2),
+     before_sleep=before_sleep_log(logger, logging.DEBUG),
+     reraise=True,
+ )
+ def gemini_chat_completion_with_backoff(
+     messages,
+     compiled_references,
+     online_results,
+     model_name,
+     temperature,
+     api_key,
+     system_prompt,
+     max_prompt_size=None,
+     completion_func=None,
+     model_kwargs=None,
+ ):
+     g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
+     t = Thread(
+         target=gemini_llm_thread,
+         args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs),
+     )
+     t.start()
+     return g
+
+
+ def gemini_llm_thread(
+     g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None
+ ):
+     try:
+         genai.configure(api_key=api_key)
+         max_tokens = max_prompt_size or DEFAULT_MAX_TOKENS_GEMINI
+         model_kwargs = model_kwargs or dict()
+         model_kwargs["temperature"] = temperature
+         model_kwargs["max_output_tokens"] = max_tokens
+         model_kwargs["stop_sequences"] = ["Notes:\n["]
+         model = genai.GenerativeModel(
+             model_name,
+             generation_config=model_kwargs,
+             system_instruction=system_prompt,
+             safety_settings={
+                 HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+                 HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+                 HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+                 HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+             },
+         )
+
+         formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
+         # all messages up to the last are considered to be part of the chat history
+         chat_session = model.start_chat(history=formatted_messages[0:-1])
+         # the last message is considered to be the current prompt
+         for chunk in chat_session.send_message(formatted_messages[-1]["parts"][0], stream=True):
+             message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
+             message = message or chunk.text
+             g.send(message)
+             if stopped:
+                 raise StopCandidateException(message)
+     except StopCandidateException as e:
+         logger.warning(
+             f"LLM Response Prevented for {model_name}: {e.args[0]}.\n"
+             + f"Last Message by {messages[-1].role}: {messages[-1].content}"
+         )
+     except Exception as e:
+         logger.error(f"Error in gemini_llm_thread: {e}", exc_info=True)
+     finally:
+         g.close()
+
+
+ def handle_gemini_response(candidates, prompt_feedback=None):
+     """Check if Gemini response was blocked and return an explanatory error message."""
+     # Check if the response was blocked due to safety concerns with the prompt
+     if len(candidates) == 0 and prompt_feedback:
+         message = f"\nI'd prefer to not respond to that due to **{prompt_feedback.block_reason.name}** issues with your query."
+         stopped = True
+     # Check if the response was blocked due to safety concerns with the generated content
+     elif candidates[0].finish_reason == FinishReason.SAFETY:
+         message = generate_safety_response(candidates[0].safety_ratings)
+         stopped = True
+     # Check if the response was stopped due to reaching maximum token limit or other reasons
+     elif candidates[0].finish_reason != FinishReason.STOP:
+         message = f"\nI can't talk further about that because of **{candidates[0].finish_reason.name} issue.**"
+         stopped = True
+     # Otherwise, the response is valid and can be used
+     else:
+         message = None
+         stopped = False
+     return message, stopped
+
+
+ def generate_safety_response(safety_ratings):
+     """Generate a conversational response based on the safety ratings of the response."""
+     # Get the safety rating with the highest probability
+     max_safety_rating = sorted(safety_ratings, key=lambda x: x.probability, reverse=True)[0]
+     # Remove the "HARM_CATEGORY_" prefix and title case the category name
+     max_safety_category = " ".join(max_safety_rating.category.name.split("_")[2:]).title()
+     # Add a bit of variety to the discomfort level based on the safety rating probability
+     discomfort_level = {
+         HarmProbability.HARM_PROBABILITY_UNSPECIFIED: " ",
+         HarmProbability.LOW: "a bit ",
+         HarmProbability.MEDIUM: "moderately ",
+         HarmProbability.HIGH: random.choice(["very ", "quite ", "fairly "]),
+     }[max_safety_rating.probability]
+     # Generate a response using a random response template
+     safety_response_choice = random.choice(
+         [
+             "\nUmm, I'd rather not to respond to that. The conversation has some probability of going into **{category}** territory.",
+             "\nI'd prefer not to talk about **{category}** related topics. It makes me {discomfort_level}uncomfortable.",
+             "\nI feel {discomfort_level}squeamish talking about **{category}** related stuff! Can we talk about something less controversial?",
+             "\nThat sounds {discomfort_level}outside the [Overtone Window](https://en.wikipedia.org/wiki/Overton_window) of acceptable conversation. Should we stick to something less {category} related?",
+         ]
+     )
+     return safety_response_choice.format(
+         category=max_safety_category, probability=max_safety_rating.probability.name, discomfort_level=discomfort_level
+     )
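Both Gemini entry points above share the same tenacity retry policy: randomized exponential backoff, two attempts, and reraise so the caller sees the final failure. A minimal, self-contained sketch of that pattern, with a stand-in flaky function in place of the real Gemini call:

```python
import logging

from tenacity import (
    before_sleep_log,
    retry,
    stop_after_attempt,
    wait_random_exponential,
)

logger = logging.getLogger(__name__)
attempts = {"count": 0}


@retry(
    wait=wait_random_exponential(min=1, max=10),  # same policy as gemini_completion_with_backoff
    stop=stop_after_attempt(2),
    before_sleep=before_sleep_log(logger, logging.DEBUG),
    reraise=True,  # surface the original exception after the final attempt
)
def flaky_completion() -> str:
    # Stand-in for the Gemini API call; fails once, then succeeds
    attempts["count"] += 1
    if attempts["count"] < 2:
        raise ConnectionError("transient API error")
    return "ok"


print(flaky_completion())  # retries once, then prints "ok"
```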
khoj/processor/conversation/openai/gpt.py CHANGED
@@ -14,6 +14,7 @@ from khoj.processor.conversation.openai.utils import (
  from khoj.processor.conversation.utils import (
      construct_structured_message,
      generate_chatml_messages_with_context,
+     remove_json_codeblock,
  )
  from khoj.utils.helpers import ConversationCommand, is_none_or_empty
  from khoj.utils.rawconfig import LocationData
@@ -85,6 +86,7 @@ def extract_questions(
      # Extract, Clean Message from GPT's Response
      try:
          response = response.strip()
+         response = remove_json_codeblock(response)
          response = json.loads(response)
          response = [q.strip() for q in response["queries"] if q.strip()]
          if not isinstance(response, list) or not response:
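This hunk strips any markdown fence before parsing the model's JSON. A small sketch of why that guard matters, reusing the `remove_json_codeblock` helper that this same release adds to `khoj/processor/conversation/utils.py`:

```python
import json


def remove_json_codeblock(response: str) -> str:
    # Same logic as the helper added in khoj/processor/conversation/utils.py below
    if response.startswith("```json") and response.endswith("```"):
        response = response[7:-3]
    return response


raw = '```json\n{"queries": ["weather in Berlin", " "]}\n```'
# json.loads(raw) would raise json.JSONDecodeError on the fence itself
cleaned = remove_json_codeblock(raw)
queries = [q.strip() for q in json.loads(cleaned)["queries"] if q.strip()]
print(queries)  # ['weather in Berlin']
```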
khoj/processor/conversation/openai/utils.py CHANGED
@@ -45,15 +45,28 @@ def completion_with_backoff(
          openai_clients[client_key] = client

      formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
+     stream = True
+
+     # Update request parameters for compatability with o1 model series
+     # Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations
+     if model.startswith("o1"):
+         stream = False
+         temperature = 1
+         model_kwargs.pop("stop", None)
+         model_kwargs.pop("response_format", None)

      chat = client.chat.completions.create(
-         stream=True,
+         stream=stream,
          messages=formatted_messages,  # type: ignore
          model=model,  # type: ignore
          temperature=temperature,
          timeout=20,
          **(model_kwargs or dict()),
      )
+
+     if not stream:
+         return chat.choices[0].message.content
+
      aggregated_response = ""
      for chunk in chat:
          if len(chunk.choices) == 0:
@@ -112,9 +125,18 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_ba
              client: openai.OpenAI = openai_clients[client_key]

          formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
+         stream = True
+
+         # Update request parameters for compatability with o1 model series
+         # Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations
+         if model_name.startswith("o1"):
+             stream = False
+             temperature = 1
+             model_kwargs.pop("stop", None)
+             model_kwargs.pop("response_format", None)

          chat = client.chat.completions.create(
-             stream=True,
+             stream=stream,
              messages=formatted_messages,
              model=model_name,  # type: ignore
              temperature=temperature,
@@ -122,14 +144,17 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_ba
              **(model_kwargs or dict()),
          )

-         for chunk in chat:
-             if len(chunk.choices) == 0:
-                 continue
-             delta_chunk = chunk.choices[0].delta
-             if isinstance(delta_chunk, str):
-                 g.send(delta_chunk)
-             elif delta_chunk.content:
-                 g.send(delta_chunk.content)
+         if not stream:
+             g.send(chat.choices[0].message.content)
+         else:
+             for chunk in chat:
+                 if len(chunk.choices) == 0:
+                     continue
+                 delta_chunk = chunk.choices[0].delta
+                 if isinstance(delta_chunk, str):
+                     g.send(delta_chunk)
+                 elif delta_chunk.content:
+                     g.send(delta_chunk.content)
      except Exception as e:
          logger.error(f"Error in llm_thread: {e}", exc_info=True)
      finally:
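The o1-series adjustments appear verbatim in both `completion_with_backoff` and `llm_thread`: disable streaming, pin temperature to 1, and drop the unsupported `stop` and `response_format` parameters. A sketch of that shared normalization factored into one helper (the helper name is mine, not in the package):

```python
from typing import Any, Dict, Tuple


def normalize_for_o1(model: str, temperature: float, model_kwargs: Dict[str, Any]) -> Tuple[bool, float, Dict[str, Any]]:
    """Mirror the o1 beta limitations handled in this diff: no streaming,
    fixed temperature of 1, no stop or response_format parameters."""
    stream = True
    if model.startswith("o1"):
        stream = False
        temperature = 1
        model_kwargs.pop("stop", None)
        model_kwargs.pop("response_format", None)
    return stream, temperature, model_kwargs


stream, temp, kwargs = normalize_for_o1("o1-mini", 0.2, {"stop": ["\n"], "response_format": {"type": "json_object"}})
print(stream, temp, kwargs)  # False 1 {}
```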
khoj/processor/conversation/prompts.py CHANGED
@@ -13,8 +13,8 @@ You were created by Khoj Inc. with the following capabilities:
  - You *CAN* generate images, look-up real-time information from the internet, set reminders and answer questions based on the user's notes.
  - Say "I don't know" or "I don't understand" if you don't know what to say or if you don't know the answer to a question.
  - Make sure to use the specific LaTeX math mode delimiters for your response. LaTex math mode specific delimiters as following
-   - inline math mode : `\\(` and `\\)`
-   - display math mode: insert linebreak after opening `$$`, `\\[` and before closing `$$`, `\\]`
+   - inline math mode : \\( and \\)
+   - display math mode: insert linebreak after opening $$, \\[ and before closing $$, \\]
  - Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided notes or past conversations.
  - Sometimes the user will share personal information that needs to be remembered, like an account ID or a residential address. These can be acknowledged with a simple "Got it" or "Okay".
  - Provide inline references to quotes from the user's notes or any web pages you refer to in your responses in markdown format. For example, "The farmer had ten sheep. [1](https://example.com)". *ALWAYS CITE YOUR SOURCES AND PROVIDE REFERENCES*. Add them inline to directly support your claim.
@@ -128,8 +128,8 @@ User's Notes:
  ## --

  image_generation_improve_prompt_base = """
- You are a talented creator with the ability to describe images to compose in vivid, fine detail.
- Use the provided context and user prompt to generate a more detailed prompt to create an image:
+ You are a talented media artist with the ability to describe images to compose in professional, fine detail.
+ Generate a vivid description of the image to be rendered using the provided context and user prompt below:

  Today's Date: {current_date}
  User's Location: {location}
@@ -145,10 +145,10 @@ Conversation Log:

  User Prompt: "{query}"

- Now generate an improved prompt describing the image to generate in vivid, fine detail.
+ Now generate an professional description of the image to generate in vivid, fine detail.
  - Use today's date, user's location, user's notes and online references to weave in any context that will improve the image generation.
  - Retain any important information and follow any instructions in the conversation log or user prompt.
-  - Add specific, fine position details to compose the image.
+  - Add specific, fine position details. Mention painting style, camera parameters to compose the image.
  - Ensure your improved prompt is in prose format."""

  image_generation_improve_prompt_dalle = PromptTemplate.from_template(
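These prompt templates are rendered with langchain's `PromptTemplate`, as the last context line shows. A quick sketch of how placeholders like `{current_date}` and `{query}` get filled (the two-line template excerpt and values here are illustrative, assuming langchain's standard API):

```python
from langchain.prompts import PromptTemplate

# Hypothetical two-line excerpt; the real template above is much longer
template = PromptTemplate.from_template('Today\'s Date: {current_date}\nUser Prompt: "{query}"')
print(template.format(current_date="2024-09-13", query="paint a sunset over the harbor"))
```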
khoj/processor/conversation/utils.py CHANGED
@@ -1,4 +1,3 @@
- import json
  import logging
  import math
  import queue
@@ -24,6 +23,8 @@ model_to_prompt_size = {
      "gpt-4-0125-preview": 20000,
      "gpt-4-turbo-preview": 20000,
      "gpt-4o-mini": 20000,
+     "o1-preview": 20000,
+     "o1-mini": 20000,
      "TheBloke/Mistral-7B-Instruct-v0.2-GGUF": 3500,
      "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF": 3500,
      "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF": 20000,
@@ -220,8 +221,9 @@ def truncate_messages(
      try:
          if loaded_model:
              encoder = loaded_model.tokenizer()
-         elif model_name.startswith("gpt-"):
-             encoder = tiktoken.encoding_for_model(model_name)
+         elif model_name.startswith("gpt-") or model_name.startswith("o1"):
+             # as tiktoken doesn't recognize o1 model series yet
+             encoder = tiktoken.encoding_for_model("gpt-4o" if model_name.startswith("o1") else model_name)
          elif tokenizer_name:
              if tokenizer_name in state.pretrained_tokenizers:
                  encoder = state.pretrained_tokenizers[tokenizer_name]
@@ -236,7 +238,7 @@ def truncate_messages(
          else:
              encoder = AutoTokenizer.from_pretrained(default_tokenizer)
              state.pretrained_tokenizers[default_tokenizer] = encoder
-         logger.warning(
+         logger.debug(
              f"Fallback to default chat model tokenizer: {tokenizer_name}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
          )

@@ -278,10 +280,19 @@ def truncate_messages(
      )

      if system_message:
-         system_message.role = "user" if "gemma-2" in model_name else "system"
+         # Default system message role is system.
+         # Fallback to system message role of user for models that do not support this role like gemma-2 and openai's o1 model series.
+         system_message.role = "user" if "gemma-2" in model_name or model_name.startswith("o1") else "system"
      return messages + [system_message] if system_message else messages


  def reciprocal_conversation_to_chatml(message_pair):
      """Convert a single back and forth between user and assistant to chatml format"""
      return [ChatMessage(content=message, role=role) for message, role in zip(message_pair, ["user", "assistant"])]
+
+
+ def remove_json_codeblock(response):
+     """Remove any markdown json codeblock formatting if present. Useful for non schema enforceable models"""
+     if response.startswith("```json") and response.endswith("```"):
+         response = response[7:-3]
+     return response
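The tokenizer fallback above maps the o1 series onto the gpt-4o encoding because tiktoken had no o1 entry yet. A sketch of that token-counting path in isolation (assumes a tiktoken version that knows `gpt-4o`):

```python
import tiktoken


def count_tokens(text: str, model_name: str) -> int:
    # Borrow gpt-4o's encoding for the o1 series, mirroring truncate_messages
    lookup = "gpt-4o" if model_name.startswith("o1") else model_name
    encoder = tiktoken.encoding_for_model(lookup)
    return len(encoder.encode(text))


print(count_tokens("Hello from Khoj", "o1-preview"))
```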
khoj/processor/image/generate.py ADDED
@@ -0,0 +1,212 @@
+ import base64
+ import io
+ import logging
+ import time
+ from typing import Any, Callable, Dict, List, Optional
+
+ import openai
+ import requests
+
+ from khoj.database.adapters import ConversationAdapters
+ from khoj.database.models import KhojUser, TextToImageModelConfig
+ from khoj.routers.helpers import ChatEvent, generate_better_image_prompt
+ from khoj.routers.storage import upload_image
+ from khoj.utils import state
+ from khoj.utils.helpers import ImageIntentType, convert_image_to_webp, timer
+ from khoj.utils.rawconfig import LocationData
+
+ logger = logging.getLogger(__name__)
+
+
+ async def text_to_image(
+     message: str,
+     user: KhojUser,
+     conversation_log: dict,
+     location_data: LocationData,
+     references: List[Dict[str, Any]],
+     online_results: Dict[str, Any],
+     subscribed: bool = False,
+     send_status_func: Optional[Callable] = None,
+     uploaded_image_url: Optional[str] = None,
+ ):
+     status_code = 200
+     image = None
+     image_url = None
+     intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
+
+     text_to_image_config = await ConversationAdapters.aget_user_text_to_image_model(user)
+     if not text_to_image_config:
+         # If the user has not configured a text to image model, return an unsupported on server error
+         status_code = 501
+         message = "Failed to generate image. Setup image generation on the server."
+         yield image_url or image, status_code, message, intent_type.value
+         return
+
+     text2image_model = text_to_image_config.model_name
+     chat_history = ""
+     for chat in conversation_log.get("chat", [])[-4:]:
+         if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder"]:
+             chat_history += f"Q: {chat['intent']['query']}\n"
+             chat_history += f"A: {chat['message']}\n"
+         elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"):
+             chat_history += f"Q: Prompt: {chat['intent']['query']}\n"
+             chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n"
+
+     if send_status_func:
+         async for event in send_status_func("**Enhancing the Painting Prompt**"):
+             yield {ChatEvent.STATUS: event}
+
+     # Generate a better image prompt
+     # Use the user's message, chat history, and other context
+     image_prompt = await generate_better_image_prompt(
+         message,
+         chat_history,
+         location_data=location_data,
+         note_references=references,
+         online_results=online_results,
+         model_type=text_to_image_config.model_type,
+         subscribed=subscribed,
+         uploaded_image_url=uploaded_image_url,
+     )
+
+     if send_status_func:
+         async for event in send_status_func(f"**Painting to Imagine**:\n{image_prompt}"):
+             yield {ChatEvent.STATUS: event}
+
+     # Generate image using the configured model and API
+     with timer(f"Generate image with {text_to_image_config.model_type}", logger):
+         try:
+             if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
+                 webp_image_bytes = generate_image_with_openai(image_prompt, text_to_image_config, text2image_model)
+             elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI:
+                 webp_image_bytes = generate_image_with_stability(image_prompt, text_to_image_config, text2image_model)
+             elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.REPLICATE:
+                 webp_image_bytes = generate_image_with_replicate(image_prompt, text_to_image_config, text2image_model)
+         except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
+             if "content_policy_violation" in e.message:
+                 logger.error(f"Image Generation blocked by OpenAI: {e}")
+                 status_code = e.status_code  # type: ignore
+                 message = f"Image generation blocked by OpenAI: {e.message}"  # type: ignore
+                 yield image_url or image, status_code, message, intent_type.value
+                 return
+             else:
+                 logger.error(f"Image Generation failed with {e}", exc_info=True)
+                 message = f"Image generation failed with OpenAI error: {e.message}"  # type: ignore
+                 status_code = e.status_code  # type: ignore
+                 yield image_url or image, status_code, message, intent_type.value
+                 return
+         except requests.RequestException as e:
+             logger.error(f"Image Generation failed with {e}", exc_info=True)
+             message = f"Image generation using {text2image_model} via {text_to_image_config.model_type} failed with error: {e}"
+             status_code = 502
+             yield image_url or image, status_code, message, intent_type.value
+             return
+
+     # Decide how to store the generated image
+     with timer("Upload image to S3", logger):
+         image_url = upload_image(webp_image_bytes, user.uuid)
+         if image_url:
+             intent_type = ImageIntentType.TEXT_TO_IMAGE2
+         else:
+             intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
+             image = base64.b64encode(webp_image_bytes).decode("utf-8")
+
+     yield image_url or image, status_code, image_prompt, intent_type.value
+
+
+ def generate_image_with_openai(
+     improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
+ ):
+     "Generate image using OpenAI API"
+
+     # Get the API key from the user's configuration
+     if text_to_image_config.api_key:
+         api_key = text_to_image_config.api_key
+     elif text_to_image_config.openai_config:
+         api_key = text_to_image_config.openai_config.api_key
+     elif state.openai_client:
+         api_key = state.openai_client.api_key
+     auth_header = {"Authorization": f"Bearer {api_key}"} if api_key else {}
+
+     # Generate image using OpenAI API
+     OPENAI_IMAGE_GEN_STYLE = "vivid"
+     response = state.openai_client.images.generate(
+         prompt=improved_image_prompt,
+         model=text2image_model,
+         style=OPENAI_IMAGE_GEN_STYLE,
+         response_format="b64_json",
+         extra_headers=auth_header,
+     )
+
+     # Extract the base64 image from the response
+     image = response.data[0].b64_json
+     # Decode base64 png and convert it to webp for faster loading
+     return convert_image_to_webp(base64.b64decode(image))
+
+
+ def generate_image_with_stability(
+     improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
+ ):
+     "Generate image using Stability AI"
+
+     # Call Stability AI API to generate image
+     response = requests.post(
+         f"https://api.stability.ai/v2beta/stable-image/generate/sd3",
+         headers={"authorization": f"Bearer {text_to_image_config.api_key}", "accept": "image/*"},
+         files={"none": ""},
+         data={
+             "prompt": improved_image_prompt,
+             "model": text2image_model,
+             "mode": "text-to-image",
+             "output_format": "png",
+             "aspect_ratio": "1:1",
+         },
+     )
+     # Convert png to webp for faster loading
+     return convert_image_to_webp(response.content)
+
+
+ def generate_image_with_replicate(
+     improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
+ ):
+     "Generate image using Replicate API"
+
+     # Create image generation task on Replicate
+     replicate_create_prediction_url = f"https://api.replicate.com/v1/models/{text2image_model}/predictions"
+     headers = {
+         "Authorization": f"Bearer {text_to_image_config.api_key}",
+         "Content-Type": "application/json",
+     }
+     json = {
+         "input": {
+             "prompt": improved_image_prompt,
+             "num_outputs": 1,
+             "aspect_ratio": "1:1",
+             "output_format": "webp",
+             "output_quality": 100,
+         }
+     }
+     create_prediction = requests.post(replicate_create_prediction_url, headers=headers, json=json).json()
+
+     # Get status of image generation task
+     get_prediction_url = create_prediction["urls"]["get"]
+     get_prediction = requests.get(get_prediction_url, headers=headers).json()
+     status = get_prediction["status"]
+     retry_count = 1
+
+     # Poll the image generation task for completion status
+     while status not in ["succeeded", "failed", "canceled"] and retry_count < 20:
+         time.sleep(2)
+         get_prediction = requests.get(get_prediction_url, headers=headers).json()
+         status = get_prediction["status"]
+         retry_count += 1
+
+     # Raise exception if the image generation task fails
+     if status != "succeeded":
+         if retry_count >= 10:
+             raise requests.RequestException("Image generation timed out")
+         raise requests.RequestException(f"Image generation failed with status: {status}")
+
+     # Get the generated image
+     image_url = get_prediction["output"][0] if isinstance(get_prediction["output"], list) else get_prediction["output"]
+     return io.BytesIO(requests.get(image_url).content).getvalue()
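`generate_image_with_replicate` is a create-then-poll flow: submit a prediction, then poll its status URL until a terminal state. A condensed sketch of just the polling loop, with a callable standing in for the `requests.get(...).json()` status check:

```python
import time
from typing import Callable


def poll_until_done(get_status: Callable[[], str], max_retries: int = 20, interval: float = 2.0) -> str:
    """Poll a status source until it reaches a terminal state, as the
    Replicate branch above does with its prediction URL."""
    status = get_status()
    retry_count = 1
    while status not in ["succeeded", "failed", "canceled"] and retry_count < max_retries:
        time.sleep(interval)
        status = get_status()
        retry_count += 1
    return status


# Stand-in status source: succeeds on the third check
states = iter(["starting", "processing", "succeeded"])
print(poll_until_done(lambda: next(states), interval=0))  # succeeded
```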
khoj/processor/tools/online_search.py CHANGED
@@ -7,6 +7,7 @@ from collections import defaultdict
  from typing import Callable, Dict, List, Optional, Tuple, Union

  import aiohttp
+ import requests
  from bs4 import BeautifulSoup
  from markdownify import markdownify

@@ -94,7 +95,7 @@ async def search_online(

      # Read, extract relevant info from the retrieved web pages
      if webpages:
-         webpage_links = [link for link, _, _ in webpages]
+         webpage_links = set([link for link, _, _ in webpages])
          logger.info(f"Reading web pages at: {list(webpage_links)}")
          if send_status_func:
              webpage_links_str = "\n- " + "\n- ".join(list(webpage_links))
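The `search_online` fix wraps the link list in a `set` so duplicate URLs are fetched only once; note this also discards ordering. A tiny sketch, with an order-preserving `dict.fromkeys` alternative shown for comparison (my suggestion, not in the package):

```python
webpages = [
    ("https://example.com/a", "content", "query1"),
    ("https://example.com/b", "content", "query2"),
    ("https://example.com/a", "content", "query3"),  # duplicate link
]

webpage_links = set([link for link, _, _ in webpages])  # as in the diff; unordered
print(webpage_links)  # {'https://example.com/a', 'https://example.com/b'}

ordered = list(dict.fromkeys(link for link, _, _ in webpages))  # order-preserving alternative
print(ordered)  # ['https://example.com/a', 'https://example.com/b']
```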
khoj/routers/api.py CHANGED
@@ -31,6 +31,7 @@ from khoj.database.models import ChatModelOptions, KhojUser, SpeechToTextModelOp
  from khoj.processor.conversation.anthropic.anthropic_chat import (
      extract_questions_anthropic,
  )
+ from khoj.processor.conversation.google.gemini_chat import extract_questions_gemini
  from khoj.processor.conversation.offline.chat_model import extract_questions_offline
  from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
  from khoj.processor.conversation.openai.gpt import extract_questions
@@ -419,6 +420,18 @@ async def extract_references_and_questions(
              location_data=location_data,
              user=user,
          )
+     elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
+         api_key = conversation_config.openai_config.api_key
+         chat_model = conversation_config.chat_model
+         inferred_queries = extract_questions_gemini(
+             defiltered_query,
+             model=chat_model,
+             api_key=api_key,
+             conversation_log=meta_log,
+             location_data=location_data,
+             max_tokens=conversation_config.max_prompt_size,
+             user=user,
+         )

      # Collate search results as context for GPT
      with timer("Searching knowledge base took", logger):
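The new `GOOGLE` branch extends the provider dispatch in `extract_references_and_questions`; note it reads the Gemini key from the `openai_config` relation on the chat model config. A schematic of the dispatch shape (the enum and stub bodies here are mine; the real branches call the provider-specific extractors):

```python
from enum import Enum


class ModelType(str, Enum):
    OPENAI = "openai"
    OFFLINE = "offline"
    ANTHROPIC = "anthropic"
    GOOGLE = "google"


def extract_questions_for(model_type: ModelType, query: str) -> list:
    # Stub dispatch mirroring the elif chain in extract_references_and_questions
    extractors = {
        ModelType.GOOGLE: lambda q: [f"extract_questions_gemini({q!r})"],
        ModelType.ANTHROPIC: lambda q: [f"extract_questions_anthropic({q!r})"],
        ModelType.OPENAI: lambda q: [f"extract_questions({q!r})"],
        ModelType.OFFLINE: lambda q: [f"extract_questions_offline({q!r})"],
    }
    return extractors[model_type](query)


print(extract_questions_for(ModelType.GOOGLE, "what changed in khoj?"))
```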
khoj/routers/api_chat.py CHANGED
@@ -26,6 +26,7 @@ from khoj.database.adapters import (
26
26
  from khoj.database.models import KhojUser
27
27
  from khoj.processor.conversation.prompts import help_message, no_entries_found
28
28
  from khoj.processor.conversation.utils import save_to_conversation_log
29
+ from khoj.processor.image.generate import text_to_image
29
30
  from khoj.processor.speech.text_to_speech import generate_text_to_speech
30
31
  from khoj.processor.tools.online_search import read_webpages, search_online
31
32
  from khoj.routers.api import extract_references_and_questions
@@ -44,7 +45,6 @@ from khoj.routers.helpers import (
44
45
  is_query_empty,
45
46
  is_ready_to_chat,
46
47
  read_chat_stream,
47
- text_to_image,
48
48
  update_telemetry_state,
49
49
  validate_conversation_config,
50
50
  )
khoj/routers/email.py CHANGED
@@ -44,7 +44,12 @@ async def send_magic_link_email(email, unique_id, host):
      html_content = template.render(link=f"{host}auth/magic?code={unique_id}")

      resend.Emails.send(
-         {"sender": "noreply@khoj.dev", "to": email, "subject": "Your Sign-In Link for Khoj 🚀", "html": html_content}
+         {
+             "sender": os.environ.get("RESEND_EMAIL", "noreply@khoj.dev"),
+             "to": email,
+             "subject": "Your Sign-In Link for Khoj 🚀",
+             "html": html_content,
+         }
      )

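The email fix replaces the hard-coded sender with an environment override, falling back to the old address. The pattern in isolation (assumes `os` is already imported in `email.py`, as the hunk implies):

```python
import os

# RESEND_EMAIL overrides the sender; the previous hard-coded address is the default
sender = os.environ.get("RESEND_EMAIL", "noreply@khoj.dev")
print(sender)
```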