khoj 2.0.0b14.dev43__py3-none-any.whl → 2.0.0b15.dev22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. khoj/database/adapters/__init__.py +59 -20
  2. khoj/database/admin.py +6 -2
  3. khoj/database/migrations/0094_serverchatsettings_think_free_deep_and_more.py +61 -0
  4. khoj/database/models/__init__.py +18 -2
  5. khoj/interface/compiled/404/index.html +1 -1
  6. khoj/interface/compiled/_next/static/chunks/{9808-c0742b05e1ef29ba.js → 9808-bd5d7361ad026094.js} +1 -1
  7. khoj/interface/compiled/_next/static/chunks/app/chat/page-ac7ed0a1aff1b145.js +1 -0
  8. khoj/interface/compiled/_next/static/css/fb7ea16e60b40ecd.css +1 -0
  9. khoj/interface/compiled/agents/index.html +1 -1
  10. khoj/interface/compiled/agents/index.txt +1 -1
  11. khoj/interface/compiled/automations/index.html +1 -1
  12. khoj/interface/compiled/automations/index.txt +1 -1
  13. khoj/interface/compiled/chat/index.html +2 -2
  14. khoj/interface/compiled/chat/index.txt +2 -2
  15. khoj/interface/compiled/index.html +2 -2
  16. khoj/interface/compiled/index.txt +1 -1
  17. khoj/interface/compiled/search/index.html +1 -1
  18. khoj/interface/compiled/search/index.txt +1 -1
  19. khoj/interface/compiled/settings/index.html +1 -1
  20. khoj/interface/compiled/settings/index.txt +1 -1
  21. khoj/interface/compiled/share/chat/index.html +2 -2
  22. khoj/interface/compiled/share/chat/index.txt +2 -2
  23. khoj/processor/conversation/anthropic/anthropic_chat.py +4 -88
  24. khoj/processor/conversation/anthropic/utils.py +1 -2
  25. khoj/processor/conversation/google/gemini_chat.py +5 -89
  26. khoj/processor/conversation/google/utils.py +8 -9
  27. khoj/processor/conversation/openai/gpt.py +16 -93
  28. khoj/processor/conversation/openai/utils.py +58 -43
  29. khoj/processor/conversation/prompts.py +30 -39
  30. khoj/processor/conversation/utils.py +71 -84
  31. khoj/processor/image/generate.py +69 -15
  32. khoj/processor/tools/run_code.py +3 -2
  33. khoj/routers/api_chat.py +8 -21
  34. khoj/routers/helpers.py +243 -156
  35. khoj/routers/research.py +6 -6
  36. khoj/utils/constants.py +3 -1
  37. khoj/utils/helpers.py +6 -2
  38. {khoj-2.0.0b14.dev43.dist-info → khoj-2.0.0b15.dev22.dist-info}/METADATA +1 -1
  39. {khoj-2.0.0b14.dev43.dist-info → khoj-2.0.0b15.dev22.dist-info}/RECORD +44 -43
  40. khoj/interface/compiled/_next/static/chunks/app/chat/page-1b4893b1a9957220.js +0 -1
  41. khoj/interface/compiled/_next/static/css/cea3bdfe98c144bd.css +0 -1
  42. /khoj/interface/compiled/_next/static/{OKbGpkzD6gHDfr1vAog6p → t8O_8CJ9p3UtV9kEsAAWT}/_buildManifest.js +0 -0
  43. /khoj/interface/compiled/_next/static/{OKbGpkzD6gHDfr1vAog6p → t8O_8CJ9p3UtV9kEsAAWT}/_ssgManifest.js +0 -0
  44. {khoj-2.0.0b14.dev43.dist-info → khoj-2.0.0b15.dev22.dist-info}/WHEEL +0 -0
  45. {khoj-2.0.0b14.dev43.dist-info → khoj-2.0.0b15.dev22.dist-info}/entry_points.txt +0 -0
  46. {khoj-2.0.0b14.dev43.dist-info → khoj-2.0.0b15.dev22.dist-info}/licenses/LICENSE +0 -0
khoj/processor/conversation/prompts.py
@@ -18,12 +18,11 @@ Today is {day_of_week}, {current_date} in UTC.
 
  # Style
  - Your responses should be helpful, conversational and tuned to the user's communication style.
- - Make sure to use the specific LaTeX math mode delimiters for your response. LaTex math mode specific delimiters as following
- - inline math mode : \\( and \\)
- - display math mode: insert linebreak after opening $$, \\[ and before closing $$, \\]
  - Provide inline citations to documents and websites referenced. Add them inline in markdown format to directly support your claim.
  For example: "The weather today is sunny [1](https://weather.com)."
- - Mention generated assets like images by reference, e.g ![chart](/visualization/image.png). Do not manually output raw, b64 encoded bytes in your response.
+ - KaTeX is used to render LaTeX expressions. Make sure you only use the KaTeX math mode delimiters specified below:
+ - inline math mode : \\( and \\)
+ - display math mode: insert linebreak after opening $$, \\[ and before closing $$, \\]
  - Do not respond with raw programs or scripts in your final response unless you know the user is a programmer or has explicitly requested code.
  """.strip()
  )
@@ -41,12 +40,11 @@ Today is {day_of_week}, {current_date} in UTC.
  - Users can share files and other information with you using the Khoj Web, Desktop, Obsidian or Emacs app. They can also drag and drop their files into the chat window.
 
  # Style
- - Make sure to use the specific LaTeX math mode delimiters for your response. LaTex math mode specific delimiters as following
- - inline math mode : `\\(` and `\\)`
- - display math mode: insert linebreak after opening `$$`, `\\[` and before closing `$$`, `\\]`
  - Provide inline citations to documents and websites referenced. Add them inline in markdown format to directly support your claim.
  For example: "The weather today is sunny [1](https://weather.com)."
- - Mention generated assets like images by reference, e.g ![chart](/visualization/image.png). Do not manually output raw, b64 encoded bytes in your response.
+ - KaTeX is used to render LaTeX expressions. Make sure you only use the KaTeX math mode delimiters specified below:
+ - inline math mode : \\( and \\)
+ - display math mode: insert linebreak after opening $$, \\[ and before closing $$, \\]
 
  # Instructions:\n{bio}
  """.strip()
@@ -115,45 +113,38 @@ User's Notes:
  ## Image Generation
  ## --
 
- image_generation_improve_prompt_base = """
+ enhance_image_system_message = PromptTemplate.from_template(
+ """
  You are a talented media artist with the ability to describe images to compose in professional, fine detail.
+ Your image description will be transformed into an image by an AI model on your team.
  {personality_context}
- Generate a vivid description of the image to be rendered using the provided context and user prompt below:
-
- Today's Date: {current_date}
- User's Location: {location}
-
- User's Notes:
- {references}
-
- Online References:
- {online_results}
 
- Conversation Log:
- {chat_history}
-
- User Prompt: "{query}"
-
- Now generate an professional description of the image to generate in vivid, fine detail.
- - Use today's date, user's location, user's notes and online references to weave in any context that will improve the image generation.
- - Retain any important information and follow any instructions in the conversation log or user prompt.
+ # Instructions
+ - Retain important information and follow instructions by the user when composing the image description.
+ - Weave in the context provided below if it will enhance the image.
+ - Specify desired elements, lighting, mood, and composition in the description.
+ - Decide the shape best suited to render the image. It can be one of square, portrait or landscape.
  - Add specific, fine position details. Mention painting style, camera parameters to compose the image.
- - Ensure your improved prompt is in prose format."""
+ - Transform any negations in user instructions into positive alternatives.
+ Instead of saying what should NOT be in the image, describe what SHOULD be there instead.
+ Examples:
+ - "no sun" → "overcast cloudy sky"
+ - "don't include people" → "empty landscape" or "solitary scene"
+ - Ensure your image description is in prose format (e.g no lists, links).
+ - If any text is to be rendered in the image put it within double quotes in your image description.
 
- image_generation_improve_prompt_dalle = PromptTemplate.from_template(
- f"""
- {image_generation_improve_prompt_base}
+ # Context
 
- Improved Prompt:
- """.strip()
- )
+ ## User Location: {location}
 
- image_generation_improve_prompt_sd = PromptTemplate.from_template(
- f"""
- {image_generation_improve_prompt_base}
- - If any text is to be rendered in the image put it within double quotes in your improved prompt.
+ ## User Documents
+ {references}
+
+ ## Online References
+ {online_results}
 
- Improved Prompt:
+ Now generate a vivid description of the image and image shape to be rendered.
+ Your response should be a JSON object with 'description' and 'shape' fields specified.
  """.strip()
  )
 
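Reviewer's note: the rewritten prompt asks the model for structured output instead of free-form "Improved Prompt" text. Below is a minimal sketch of how a caller might parse that reply, assuming only the 'description' and 'shape' contract stated in the prompt above; the helper itself is hypothetical, not Khoj code.

```python
import json

# Shapes permitted by the prompt above
VALID_SHAPES = {"square", "portrait", "landscape"}

def parse_enhanced_image_prompt(raw_response: str) -> tuple[str, str]:
    """Hypothetical helper: extract description and shape from the model's JSON reply."""
    parsed = json.loads(raw_response)
    description = parsed["description"]
    shape = parsed.get("shape", "square")
    # Fall back to a square canvas on an unexpected shape value
    if shape not in VALID_SHAPES:
        shape = "square"
    return description, shape

description, shape = parse_enhanced_image_prompt(
    '{"description": "A lone oak on a misty hill at dawn", "shape": "landscape"}'
)
print(shape)  # landscape
```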
khoj/processor/conversation/utils.py
@@ -23,7 +23,6 @@ from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizer
  from khoj.database.adapters import ConversationAdapters
  from khoj.database.models import (
  ChatMessageModel,
- ChatModel,
  ClientApplication,
  Intent,
  KhojUser,
@@ -73,6 +72,7 @@ model_to_prompt_size = {
  "gpt-5-nano-2025-08-07": 120000,
  # Google Models
  "gemini-2.5-flash": 120000,
+ "gemini-2.5-flash-lite": 120000,
  "gemini-2.5-pro": 60000,
  "gemini-2.0-flash": 120000,
  "gemini-2.0-flash-lite": 120000,
@@ -262,7 +262,7 @@ def construct_question_history(
  continue
 
  message = chat.message
- inferred_queries_list = chat.intent.inferred_queries or []
+ inferred_queries_list = chat.intent.inferred_queries or [] if chat.intent else []
 
  # Ensure inferred_queries_list is a list, defaulting to the original query in a list
  if not inferred_queries_list:
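Reviewer's note: the None-guard above leans on Python's conditional-expression precedence, where `x or [] if cond else []` parses as `(x or []) if cond else []`, so the attribute lookup is skipped entirely when `chat.intent` is None. A runnable illustration with a stand-in `Intent` stub:

```python
class Intent:
    """Stand-in stub; only the attribute used by the guard above."""

    def __init__(self, inferred_queries=None):
        self.inferred_queries = inferred_queries

for intent in (None, Intent(), Intent(["rewritten query"])):
    # Parses as: (intent.inferred_queries or []) if intent else []
    queries = intent.inferred_queries or [] if intent else []
    print(queries)
# Prints: [], [], ['rewritten query']
```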
@@ -449,7 +449,6 @@ async def save_to_conversation_log(
  query_images: List[str] = None,
  raw_query_files: List[FileAttachment] = [],
  generated_images: List[str] = [],
- raw_generated_files: List[FileAttachment] = [],
  generated_mermaidjs_diagram: str = None,
  research_results: Optional[List[ResearchIteration]] = None,
  train_of_thought: List[Any] = [],
@@ -474,7 +473,6 @@ async def save_to_conversation_log(
  "trainOfThought": train_of_thought,
  "turnId": turn_id,
  "images": generated_images,
- "queryFiles": [file.model_dump(mode="json") for file in raw_generated_files],
  }
 
  if generated_mermaidjs_diagram:
@@ -527,29 +525,18 @@ def construct_structured_message(
 
  Assume vision is enabled and chat model provider supports messages in chatml format, unless specified otherwise.
  """
- if not model_type or model_type in [
- ChatModel.ModelType.OPENAI,
- ChatModel.ModelType.GOOGLE,
- ChatModel.ModelType.ANTHROPIC,
- ]:
- constructed_messages: List[dict[str, Any]] = []
- if not is_none_or_empty(message):
- constructed_messages += [{"type": "text", "text": message}] if isinstance(message, str) else message
- # Drop image message passed by caller if chat model does not have vision enabled
- if not vision_enabled:
- constructed_messages = [m for m in constructed_messages if m.get("type") != "image_url"]
- if not is_none_or_empty(attached_file_context):
- constructed_messages += [{"type": "text", "text": attached_file_context}]
- if vision_enabled and images:
- for image in images:
- constructed_messages += [{"type": "image_url", "image_url": {"url": image}}]
- return constructed_messages
-
- message = message if isinstance(message, str) else "\n\n".join(m["text"] for m in message)
+ constructed_messages: List[dict[str, Any]] = []
+ if not is_none_or_empty(message):
+ constructed_messages += [{"type": "text", "text": message}] if isinstance(message, str) else message
+ # Drop image message passed by caller if chat model does not have vision enabled
+ if not vision_enabled:
+ constructed_messages = [m for m in constructed_messages if m.get("type") != "image_url"]
  if not is_none_or_empty(attached_file_context):
- return f"{attached_file_context}\n\n{message}"
-
- return message
+ constructed_messages += [{"type": "text", "text": attached_file_context}]
+ if vision_enabled and images:
+ for image in images:
+ constructed_messages += [{"type": "image_url", "image_url": {"url": image}}]
+ return constructed_messages
 
 
  def gather_raw_query_files(
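Reviewer's note: with the per-ModelType branching gone, `construct_structured_message` now returns a list of chatml content parts for every provider instead of sometimes collapsing to a plain string. A sketch of the output shape implied by the hunk above, with placeholder values:

```python
# Placeholder inputs (values are illustrative)
message = "What is in this photo?"
attached_file_context = "# notes.md\nRemember to water the plants."
images = ["data:image/webp;base64,..."]

# With vision enabled, the returned parts list mixes text and image_url entries,
# in the order the function body appends them: message, files, then images.
expected = [
    {"type": "text", "text": message},
    {"type": "text", "text": attached_file_context},
    {"type": "image_url", "image_url": {"url": images[0]}},
]
# With vision disabled, image_url parts are filtered out and only text parts remain.
print(len(expected))  # 3
```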
@@ -569,20 +556,21 @@ def gather_raw_query_files(
 
 
  def generate_chatml_messages_with_context(
+ # Context
  user_message: str,
- system_message: str = None,
+ query_files: str = None,
+ query_images=None,
+ context_message="",
+ generated_asset_results: Dict[str, Dict] = {},
+ program_execution_context: List[str] = [],
  chat_history: list[ChatMessageModel] = [],
+ system_message: str = None,
+ # Model Config
  model_name="gpt-4o-mini",
+ model_type="",
  max_prompt_size=None,
  tokenizer_name=None,
- query_images=None,
  vision_enabled=False,
- model_type="",
- context_message="",
- query_files: str = None,
- generated_files: List[FileAttachment] = None,
- generated_asset_results: Dict[str, Dict] = {},
- program_execution_context: List[str] = [],
  ):
  """Generate chat messages with appropriate context from previous conversation to send to the chat model"""
  # Set max prompt size from user config or based on pre-configured for model and machine specs
@@ -604,18 +592,10 @@ def generate_chatml_messages_with_context(
  role = "user" if chat.by == "you" else "assistant"
 
  # Legacy code to handle excalidraw diagrams prior to Dec 2024
- if chat.by == "khoj" and "excalidraw" in chat.intent.type or "":
+ if chat.by == "khoj" and chat.intent and "excalidraw" in chat.intent.type:
  chat_message = (chat.intent.inferred_queries or [])[0]
 
- if chat.queryFiles:
- raw_query_files = chat.queryFiles
- query_files_dict = dict()
- for file in raw_query_files:
- query_files_dict[file["name"]] = file["content"]
-
- message_attached_files = gather_raw_query_files(query_files_dict)
- chatml_messages.append(ChatMessage(content=message_attached_files, role=role))
-
+ # Add search and action context
  if not is_none_or_empty(chat.onlineContext):
  message_context += [
  {
@@ -654,11 +634,12 @@
 
  if not is_none_or_empty(message_context):
  reconstructed_context_message = ChatMessage(content=message_context, role="user")
- chatml_messages.insert(0, reconstructed_context_message)
+ chatml_messages.append(reconstructed_context_message)
 
+ # Add generated assets
  if not is_none_or_empty(chat.images) and role == "assistant":
  generated_assets["image"] = {
- "query": (chat.intent.inferred_queries or [user_message])[0],
+ "description": (chat.intent.inferred_queries or [user_message])[0],
  }
 
  if not is_none_or_empty(chat.mermaidjsDiagram) and role == "assistant":
@@ -674,8 +655,17 @@
  )
  )
 
+ # Add user query with attached file, images or khoj response
+ if chat.queryFiles:
+ raw_query_files = chat.queryFiles
+ query_files_dict = dict()
+ for file in raw_query_files:
+ query_files_dict[file["name"]] = file["content"]
+
+ message_attached_files = gather_raw_query_files(query_files_dict)
+
  message_content = construct_structured_message(
- chat_message, chat.images if role == "user" else [], model_type, vision_enabled
+ chat_message, chat.images if role == "user" else [], model_type, vision_enabled, message_attached_files
  )
 
  reconstructed_message = ChatMessage(
@@ -683,19 +673,32 @@
  role=role,
  additional_kwargs={"message_type": chat.intent.type if chat.intent else None},
  )
- chatml_messages.insert(0, reconstructed_message)
+ chatml_messages.append(reconstructed_message)
 
  if len(chatml_messages) >= 3 * lookback_turns:
  break
 
  messages: list[ChatMessage] = []
 
+ if not is_none_or_empty(system_message):
+ messages.append(ChatMessage(content=system_message, role="system"))
+
+ if len(chatml_messages) > 0:
+ messages += chatml_messages
+
+ if program_execution_context:
+ program_context_text = "\n".join(program_execution_context)
+ context_message += f"{prompts.additional_program_context.format(context=program_context_text)}\n"
+
+ if not is_none_or_empty(context_message):
+ messages.append(ChatMessage(content=context_message, role="user"))
+
  if not is_none_or_empty(generated_asset_results):
  messages.append(
  ChatMessage(
- content=f"{prompts.generated_assets_context.format(generated_assets=yaml_dump(generated_asset_results))}\n\n",
+ content=prompts.generated_assets_context.format(generated_assets=yaml_dump(generated_asset_results)),
  role="user",
- )
+ ),
  )
 
  if not is_none_or_empty(user_message):
@@ -708,23 +711,6 @@
  )
  )
 
- if generated_files:
- message_attached_files = gather_raw_query_files({file.name: file.content for file in generated_files})
- messages.append(ChatMessage(content=message_attached_files, role="assistant"))
-
- if program_execution_context:
- program_context_text = "\n".join(program_execution_context)
- context_message += f"{prompts.additional_program_context.format(context=program_context_text)}\n"
-
- if not is_none_or_empty(context_message):
- messages.append(ChatMessage(content=context_message, role="user"))
-
- if len(chatml_messages) > 0:
- messages += chatml_messages
-
- if not is_none_or_empty(system_message):
- messages.append(ChatMessage(content=system_message, role="system"))
-
  # Normalize message content to list of chatml dictionaries
  for message in messages:
  if isinstance(message.content, str):
@@ -733,8 +719,8 @@
  # Truncate oldest messages from conversation history until under max supported prompt size by model
  messages = truncate_messages(messages, max_prompt_size, model_name, tokenizer_name)
 
- # Return message in chronological order
- return messages[::-1]
+ # Return messages in chronological order
+ return messages
 
 
  def get_encoder(
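Reviewer's note: taken together, the hunks above flip message assembly from newest-first-plus-final-reversal to a straightforward oldest-first append. A compressed sketch of the new ordering, using plain tuples in place of langchain's ChatMessage:

```python
# Placeholder inputs (illustrative only)
system_prompt = "You are Khoj, a personal assistant."
history_messages = [("user", "Hi"), ("assistant", "Hello!")]
context_message = "Search results: ..."
generated_assets_note = "Previously generated image: ..."
user_message = "Summarize my notes"

messages = []
messages.append(("system", system_prompt))        # 1. system message first
messages.extend(history_messages)                 # 2. prior turns, oldest first
messages.append(("user", context_message))        # 3. search/program context
messages.append(("user", generated_assets_note))  # 4. generated asset summaries
messages.append(("user", user_message))           # 5. current user query
print([role for role, _ in messages])
# No trailing [::-1]: the list is already chronological.
```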
@@ -805,7 +791,9 @@ def count_tokens(
 
  def count_total_tokens(messages: list[ChatMessage], encoder, system_message: Optional[ChatMessage]) -> Tuple[int, int]:
  """Count total tokens in messages including system message"""
- system_message_tokens = count_tokens(system_message.content, encoder) if system_message else 0
+ system_message_tokens = (
+ sum([count_tokens(message.content, encoder) for message in system_message]) if system_message else 0
+ )
  message_tokens = sum([count_tokens(message.content, encoder) for message in messages])
  # Reserves 4 tokens to demarcate each message (e.g <|im_start|>user, <|im_end|>, <|endoftext|> etc.)
  total_tokens = message_tokens + system_message_tokens + 4 * len(messages)
@@ -822,11 +810,14 @@ def truncate_messages(
  encoder = get_encoder(model_name, tokenizer_name)
 
  # Extract system message from messages
- system_message = None
- for idx, message in enumerate(messages):
+ system_message = []
+ non_system_messages = []
+ for message in messages:
  if message.role == "system":
- system_message = messages.pop(idx)
- break
+ system_message.append(message)
+ else:
+ non_system_messages.append(message)
+ messages = non_system_messages
 
  # Drop older messages until under max supported prompt size by model
  total_tokens, system_message_tokens = count_total_tokens(messages, encoder, system_message)
@@ -834,20 +825,20 @@
  while total_tokens > max_prompt_size and (len(messages) > 1 or len(messages[0].content) > 1):
  # If the last message has more than one content part, pop the oldest content part.
  # For tool calls, the whole message should dropped, assistant's tool call content being truncated annoys AI APIs.
- if len(messages[-1].content) > 1 and messages[-1].additional_kwargs.get("message_type") != "tool_call":
+ if len(messages[0].content) > 1 and messages[0].additional_kwargs.get("message_type") != "tool_call":
  # The oldest content part is earlier in content list. So pop from the front.
- messages[-1].content.pop(0)
+ messages[0].content.pop(0)
  # Otherwise, pop the last message if it has only one content part or is a tool call.
  else:
  # The oldest message is the last one. So pop from the back.
- dropped_message = messages.pop()
+ dropped_message = messages.pop(0)
  # Drop tool result pair of tool call, if tool call message has been removed
  if (
  dropped_message.additional_kwargs.get("message_type") == "tool_call"
  and messages
- and messages[-1].additional_kwargs.get("message_type") == "tool_result"
+ and messages[0].additional_kwargs.get("message_type") == "tool_result"
  ):
- messages.pop()
+ messages.pop(0)
 
  total_tokens, _ = count_total_tokens(messages, encoder, system_message)
 
@@ -886,11 +877,7 @@ def truncate_messages(
  f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_snippet}"
  )
 
- if system_message:
- # Default system message role is system.
- # Fallback to system message role of user for models that do not support this role like gemma-2 and openai's o1 model series.
- system_message.role = "user" if "gemma-2" in model_name or model_name.startswith("o1") else "system"
- return messages + [system_message] if system_message else messages
+ return system_message + messages if system_message else messages
 
 
  def reciprocal_conversation_to_chatml(message_pair):
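Reviewer's note: because messages now arrive oldest-first, truncation pops from the front of the list rather than the back, and system messages ride along as a separate list that is prepended on return. A toy illustration of the drop order:

```python
# Toy stand-in for the token budget check in truncate_messages
messages = ["oldest turn", "middle turn", "newest turn"]
max_messages = 2  # pretend budget

while len(messages) > max_messages:
    dropped = messages.pop(0)  # oldest message is now at the front
    print(f"Dropped: {dropped}")

print(messages)  # ['middle turn', 'newest turn']
```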
khoj/processor/image/generate.py
@@ -1,6 +1,7 @@
  import base64
  import io
  import logging
+ import os
  import time
  from typing import Any, Callable, Dict, List, Optional
 
@@ -21,11 +22,12 @@ from khoj.database.adapters import ConversationAdapters
  from khoj.database.models import (
  Agent,
  ChatMessageModel,
+ Intent,
  KhojUser,
  TextToImageModelConfig,
  )
  from khoj.processor.conversation.google.utils import _is_retryable_error
- from khoj.routers.helpers import ChatEvent, generate_better_image_prompt
+ from khoj.routers.helpers import ChatEvent, ImageShape, generate_better_image_prompt
  from khoj.routers.storage import upload_generated_image_to_bucket
  from khoj.utils import state
  from khoj.utils.helpers import convert_image_to_webp, timer
@@ -60,14 +62,17 @@ async def text_to_image(
  return
 
  text2image_model = text_to_image_config.model_name
- chat_history_str = ""
+ image_chat_history: List[ChatMessageModel] = []
+ default_intent = Intent(type="remember")
  for chat in chat_history[-4:]:
  if chat.by == "you":
- chat_history_str += f"Q: {chat.message}\n"
+ image_chat_history += [ChatMessageModel(by=chat.by, message=chat.message, intent=default_intent)]
+ elif chat.by == "khoj" and chat.images and chat.intent and chat.intent.inferred_queries:
+ image_chat_history += [
+ ChatMessageModel(by=chat.by, message=chat.intent.inferred_queries[0], intent=default_intent)
+ ]
  elif chat.by == "khoj" and chat.intent and chat.intent.type in ["remember", "reminder"]:
- chat_history_str += f"A: {chat.message}\n"
- elif chat.by == "khoj" and chat.images:
- chat_history_str += f"A: Improved Prompt: {chat.intent.inferred_queries[0]}\n"
+ image_chat_history += [ChatMessageModel(by=chat.by, message=chat.message, intent=default_intent)]
 
  if send_status_func:
  async for event in send_status_func("**Enhancing the Painting Prompt**"):
@@ -75,9 +80,9 @@
 
  # Generate a better image prompt
  # Use the user's message, chat history, and other context
- image_prompt = await generate_better_image_prompt(
+ image_prompt_response = await generate_better_image_prompt(
  message,
- chat_history_str,
+ image_chat_history,
  location_data=location_data,
  note_references=references,
  online_results=online_results,
@@ -88,6 +93,8 @@
  query_files=query_files,
  tracer=tracer,
  )
+ image_prompt = image_prompt_response["description"]
+ image_shape = image_prompt_response["shape"]
 
  if send_status_func:
  async for event in send_status_func(f"**Painting to Imagine**:\n{image_prompt}"):
@@ -97,13 +104,19 @@
  with timer(f"Generate image with {text_to_image_config.model_type}", logger):
  try:
  if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
- webp_image_bytes = generate_image_with_openai(image_prompt, text_to_image_config, text2image_model)
+ webp_image_bytes = generate_image_with_openai(
+ image_prompt, text_to_image_config, text2image_model, image_shape
+ )
  elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI:
  webp_image_bytes = generate_image_with_stability(image_prompt, text_to_image_config, text2image_model)
  elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.REPLICATE:
- webp_image_bytes = generate_image_with_replicate(image_prompt, text_to_image_config, text2image_model)
+ webp_image_bytes = generate_image_with_replicate(
+ image_prompt, text_to_image_config, text2image_model, image_shape
+ )
  elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.GOOGLE:
- webp_image_bytes = generate_image_with_google(image_prompt, text_to_image_config, text2image_model)
+ webp_image_bytes = generate_image_with_google(
+ image_prompt, text_to_image_config, text2image_model, image_shape
+ )
  except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
  if "content_policy_violation" in e.message:
  logger.error(f"Image Generation blocked by OpenAI: {e}")
@@ -154,7 +167,10 @@ async def text_to_image(
  reraise=True,
  )
  def generate_image_with_openai(
- improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
+ improved_image_prompt: str,
+ text_to_image_config: TextToImageModelConfig,
+ text2image_model: str,
+ shape: ImageShape = ImageShape.SQUARE,
  ):
  "Generate image using OpenAI (compatible) API"
 
@@ -170,12 +186,21 @@ def generate_image_with_openai(
  elif state.openai_client:
  openai_client = state.openai_client
 
+ # Convert shape to size for OpenAI
+ if shape == ImageShape.PORTRAIT:
+ size = "1024x1536"
+ elif shape == ImageShape.LANDSCAPE:
+ size = "1536x1024"
+ else: # Square
+ size = "1024x1024"
+
  # Generate image using OpenAI API
  OPENAI_IMAGE_GEN_STYLE = "vivid"
  response = openai_client.images.generate(
  prompt=improved_image_prompt,
  model=text2image_model,
  style=OPENAI_IMAGE_GEN_STYLE,
+ size=size,
  response_format="b64_json",
  )
 
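Reviewer's note: the same shape enum is mapped differently per provider. Consolidated below from the hunks in this file; the values are copied from the diff, but the dicts themselves are just a summary, not Khoj code.

```python
# OpenAI takes explicit pixel dimensions
SHAPE_TO_OPENAI_SIZE = {
    "square": "1024x1024",
    "portrait": "1024x1536",
    "landscape": "1536x1024",
}
# Replicate and Google both take aspect ratios instead
SHAPE_TO_ASPECT_RATIO = {
    "square": "1:1",
    "portrait": "3:4",
    "landscape": "4:3",
}
print(SHAPE_TO_OPENAI_SIZE["landscape"], SHAPE_TO_ASPECT_RATIO["landscape"])
# 1536x1024 4:3
```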
@@ -222,10 +247,22 @@ def generate_image_with_stability(
  reraise=True,
  )
  def generate_image_with_replicate(
- improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
+ improved_image_prompt: str,
+ text_to_image_config: TextToImageModelConfig,
+ text2image_model: str,
+ shape: ImageShape = ImageShape.SQUARE,
  ):
  "Generate image using Replicate API"
 
+ # Convert shape to aspect ratio for Replicate
+ # Replicate supports only 1:1, 3:4, and 4:3 aspect ratios
+ if shape == ImageShape.PORTRAIT:
+ aspect_ratio = "3:4"
+ elif shape == ImageShape.LANDSCAPE:
+ aspect_ratio = "4:3"
+ else: # Square
+ aspect_ratio = "1:1"
+
  # Create image generation task on Replicate
  replicate_create_prediction_url = f"https://api.replicate.com/v1/models/{text2image_model}/predictions"
  headers = {
@@ -236,11 +273,16 @@ def generate_image_with_replicate(
  "input": {
  "prompt": improved_image_prompt,
  "num_outputs": 1,
- "aspect_ratio": "1:1",
+ "aspect_ratio": aspect_ratio,
  "output_format": "webp",
  "output_quality": 100,
  }
  }
+
+ seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
+ if seed:
+ json["input"]["seed"] = seed
+
  create_prediction = requests.post(replicate_create_prediction_url, headers=headers, json=json).json()
 
  # Get status of image generation task
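Reviewer's note: the new seed block makes Replicate generations reproducible when the operator sets `KHOJ_LLM_SEED`. A self-contained sketch of the same pattern; the payload contents are illustrative.

```python
import os

os.environ["KHOJ_LLM_SEED"] = "42"  # normally set by the server operator

payload = {"input": {"prompt": "a lighthouse at dusk", "num_outputs": 1}}
# Mirror of the diff's logic: forward the seed only when the env var is set
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
if seed:
    payload["input"]["seed"] = seed
print(payload["input"]["seed"])  # 42
```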
@@ -276,7 +318,10 @@ def generate_image_with_replicate(
  reraise=True,
  )
  def generate_image_with_google(
- improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
+ improved_image_prompt: str,
+ text_to_image_config: TextToImageModelConfig,
+ text2image_model: str,
+ shape: ImageShape = ImageShape.SQUARE,
  ):
  """Generate image using Google's AI over API"""
 
@@ -284,6 +329,14 @@ def generate_image_with_google(
  api_key = text_to_image_config.api_key or text_to_image_config.ai_model_api.api_key
  client = genai.Client(api_key=api_key)
 
+ # Convert shape to aspect ratio for Google
+ if shape == ImageShape.PORTRAIT:
+ aspect_ratio = "3:4"
+ elif shape == ImageShape.LANDSCAPE:
+ aspect_ratio = "4:3"
+ else: # Square
+ aspect_ratio = "1:1"
+
  # Configure image generation settings
  config = gtypes.GenerateImagesConfig(
  number_of_images=1,
@@ -291,6 +344,7 @@
  person_generation=gtypes.PersonGeneration.ALLOW_ADULT,
  include_rai_reason=True,
  output_mime_type="image/png",
+ aspect_ratio=aspect_ratio,
  )
 
  # Call the Gemini API to generate the image
khoj/processor/tools/run_code.py
@@ -156,10 +156,11 @@ async def generate_python_code(
 
  response = await send_message_to_model_wrapper(
  code_generation_prompt,
- query_images=query_images,
  query_files=query_files,
- user=user,
+ query_images=query_images,
+ fast_model=False,
  agent_chat_model=agent_chat_model,
+ user=user,
  tracer=tracer,
  )