khoj 2.0.0b14.dev51__py3-none-any.whl → 2.0.0b15.dev22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. khoj/database/adapters/__init__.py +59 -20
  2. khoj/database/admin.py +4 -0
  3. khoj/database/migrations/0094_serverchatsettings_think_free_deep_and_more.py +61 -0
  4. khoj/database/models/__init__.py +18 -2
  5. khoj/interface/compiled/404/index.html +2 -2
  6. khoj/interface/compiled/_next/static/chunks/{9808-0ae18d938933fea3.js → 9808-bd5d7361ad026094.js} +1 -1
  7. khoj/interface/compiled/_next/static/css/{2945c4a857922f3b.css → c34713c98384ee87.css} +1 -1
  8. khoj/interface/compiled/_next/static/css/fb7ea16e60b40ecd.css +1 -0
  9. khoj/interface/compiled/agents/index.html +2 -2
  10. khoj/interface/compiled/agents/index.txt +2 -2
  11. khoj/interface/compiled/automations/index.html +2 -2
  12. khoj/interface/compiled/automations/index.txt +3 -3
  13. khoj/interface/compiled/chat/index.html +2 -2
  14. khoj/interface/compiled/chat/index.txt +3 -3
  15. khoj/interface/compiled/index.html +2 -2
  16. khoj/interface/compiled/index.txt +2 -2
  17. khoj/interface/compiled/search/index.html +2 -2
  18. khoj/interface/compiled/search/index.txt +2 -2
  19. khoj/interface/compiled/settings/index.html +2 -2
  20. khoj/interface/compiled/settings/index.txt +4 -4
  21. khoj/interface/compiled/share/chat/index.html +2 -2
  22. khoj/interface/compiled/share/chat/index.txt +2 -2
  23. khoj/processor/conversation/anthropic/anthropic_chat.py +4 -88
  24. khoj/processor/conversation/anthropic/utils.py +1 -2
  25. khoj/processor/conversation/google/gemini_chat.py +4 -88
  26. khoj/processor/conversation/google/utils.py +6 -3
  27. khoj/processor/conversation/openai/gpt.py +16 -93
  28. khoj/processor/conversation/openai/utils.py +38 -30
  29. khoj/processor/conversation/prompts.py +30 -39
  30. khoj/processor/conversation/utils.py +70 -84
  31. khoj/processor/image/generate.py +69 -15
  32. khoj/processor/tools/run_code.py +3 -2
  33. khoj/routers/api_chat.py +8 -21
  34. khoj/routers/helpers.py +243 -156
  35. khoj/routers/research.py +6 -6
  36. khoj/utils/helpers.py +6 -2
  37. {khoj-2.0.0b14.dev51.dist-info → khoj-2.0.0b15.dev22.dist-info}/METADATA +1 -1
  38. {khoj-2.0.0b14.dev51.dist-info → khoj-2.0.0b15.dev22.dist-info}/RECORD +51 -50
  39. khoj/interface/compiled/_next/static/css/ecea704005ba630c.css +0 -1
  40. /khoj/interface/compiled/_next/static/chunks/{1327-511bb0a862efce80.js → 1327-e254819a9172cfa7.js} +0 -0
  41. /khoj/interface/compiled/_next/static/chunks/{1915-fbfe167c84ad60c5.js → 1915-5c6508f6ebb62a30.js} +0 -0
  42. /khoj/interface/compiled/_next/static/chunks/{2117-e78b6902ad6f75ec.js → 2117-080746c8e170c81a.js} +0 -0
  43. /khoj/interface/compiled/_next/static/chunks/{2939-4d4084c5b888b960.js → 2939-4af3fd24b8ffc9ad.js} +0 -0
  44. /khoj/interface/compiled/_next/static/chunks/{4447-d6cf93724d57e34b.js → 4447-cd95608f8e93e711.js} +0 -0
  45. /khoj/interface/compiled/_next/static/chunks/{8667-4b7790573b08c50d.js → 8667-50b03a89e82e0ba7.js} +0 -0
  46. /khoj/interface/compiled/_next/static/chunks/{9139-ce1ae935dac9c871.js → 9139-8ac4d9feb10f8869.js} +0 -0
  47. /khoj/interface/compiled/_next/static/chunks/{webpack-e572645654c4335e.js → webpack-5393aad3d824e0cb.js} +0 -0
  48. /khoj/interface/compiled/_next/static/{yBzbL9kxl5BudSA9F4Gr6 → t8O_8CJ9p3UtV9kEsAAWT}/_buildManifest.js +0 -0
  49. /khoj/interface/compiled/_next/static/{yBzbL9kxl5BudSA9F4Gr6 → t8O_8CJ9p3UtV9kEsAAWT}/_ssgManifest.js +0 -0
  50. {khoj-2.0.0b14.dev51.dist-info → khoj-2.0.0b15.dev22.dist-info}/WHEEL +0 -0
  51. {khoj-2.0.0b14.dev51.dist-info → khoj-2.0.0b15.dev22.dist-info}/entry_points.txt +0 -0
  52. {khoj-2.0.0b14.dev51.dist-info → khoj-2.0.0b15.dev22.dist-info}/licenses/LICENSE +0 -0
khoj/processor/conversation/utils.py CHANGED
@@ -23,7 +23,6 @@ from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizer
 from khoj.database.adapters import ConversationAdapters
 from khoj.database.models import (
     ChatMessageModel,
-    ChatModel,
     ClientApplication,
     Intent,
     KhojUser,
@@ -263,7 +262,7 @@ def construct_question_history(
             continue
 
         message = chat.message
-        inferred_queries_list = chat.intent.inferred_queries or []
+        inferred_queries_list = chat.intent.inferred_queries or [] if chat.intent else []
 
         # Ensure inferred_queries_list is a list, defaulting to the original query in a list
         if not inferred_queries_list:
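The new guard above leans on Python's conditional-expression grouping: `a or [] if cond else []` parses as `(a or []) if cond else []`. A minimal standalone sketch of the behaviour (not package code; the stand-in objects are hypothetical):

```python
from types import SimpleNamespace

def inferred_queries_for(chat):
    # Mirrors the updated guard: fall back to [] when the message has no intent
    return (chat.intent.inferred_queries or []) if chat.intent else []

with_intent = SimpleNamespace(intent=SimpleNamespace(inferred_queries=["rewritten query"]))
without_intent = SimpleNamespace(intent=None)
empty_intent = SimpleNamespace(intent=SimpleNamespace(inferred_queries=None))

print(inferred_queries_for(with_intent))     # ['rewritten query']
print(inferred_queries_for(without_intent))  # []
print(inferred_queries_for(empty_intent))    # []
```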
@@ -450,7 +449,6 @@ async def save_to_conversation_log(
     query_images: List[str] = None,
     raw_query_files: List[FileAttachment] = [],
     generated_images: List[str] = [],
-    raw_generated_files: List[FileAttachment] = [],
     generated_mermaidjs_diagram: str = None,
     research_results: Optional[List[ResearchIteration]] = None,
     train_of_thought: List[Any] = [],
@@ -475,7 +473,6 @@ async def save_to_conversation_log(
         "trainOfThought": train_of_thought,
         "turnId": turn_id,
         "images": generated_images,
-        "queryFiles": [file.model_dump(mode="json") for file in raw_generated_files],
     }
 
     if generated_mermaidjs_diagram:
@@ -528,29 +525,18 @@ def construct_structured_message(
 
     Assume vision is enabled and chat model provider supports messages in chatml format, unless specified otherwise.
     """
-    if not model_type or model_type in [
-        ChatModel.ModelType.OPENAI,
-        ChatModel.ModelType.GOOGLE,
-        ChatModel.ModelType.ANTHROPIC,
-    ]:
-        constructed_messages: List[dict[str, Any]] = []
-        if not is_none_or_empty(message):
-            constructed_messages += [{"type": "text", "text": message}] if isinstance(message, str) else message
-        # Drop image message passed by caller if chat model does not have vision enabled
-        if not vision_enabled:
-            constructed_messages = [m for m in constructed_messages if m.get("type") != "image_url"]
-        if not is_none_or_empty(attached_file_context):
-            constructed_messages += [{"type": "text", "text": attached_file_context}]
-        if vision_enabled and images:
-            for image in images:
-                constructed_messages += [{"type": "image_url", "image_url": {"url": image}}]
-        return constructed_messages
-
-    message = message if isinstance(message, str) else "\n\n".join(m["text"] for m in message)
+    constructed_messages: List[dict[str, Any]] = []
+    if not is_none_or_empty(message):
+        constructed_messages += [{"type": "text", "text": message}] if isinstance(message, str) else message
+    # Drop image message passed by caller if chat model does not have vision enabled
+    if not vision_enabled:
+        constructed_messages = [m for m in constructed_messages if m.get("type") != "image_url"]
     if not is_none_or_empty(attached_file_context):
-        return f"{attached_file_context}\n\n{message}"
-
-    return message
+        constructed_messages += [{"type": "text", "text": attached_file_context}]
+    if vision_enabled and images:
+        for image in images:
+            constructed_messages += [{"type": "image_url", "image_url": {"url": image}}]
+    return constructed_messages
 
 
 def gather_raw_query_files(
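With the provider-specific branch removed, construct_structured_message always returns a list of chatml-style content parts. A standalone mirror of that flow, with hypothetical inputs, illustrates the resulting structure:

```python
from typing import Any, Dict, List, Optional, Union

def build_content_parts(
    message: Union[str, List[Dict[str, Any]]],
    images: Optional[List[str]] = None,
    attached_file_context: Optional[str] = None,
    vision_enabled: bool = False,
) -> List[Dict[str, Any]]:
    # Standalone sketch mirroring the simplified construct_structured_message flow
    parts: List[Dict[str, Any]] = []
    if message:
        parts += [{"type": "text", "text": message}] if isinstance(message, str) else message
    if not vision_enabled:
        # Drop image parts when the chat model has no vision support
        parts = [p for p in parts if p.get("type") != "image_url"]
    if attached_file_context:
        parts.append({"type": "text", "text": attached_file_context})
    if vision_enabled and images:
        parts += [{"type": "image_url", "image_url": {"url": image}} for image in images]
    return parts

# Text parts come first, then any image parts
print(build_content_parts("What does this chart show?", ["https://example.com/chart.png"], "File: notes.md", vision_enabled=True))
```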
@@ -570,20 +556,21 @@ def gather_raw_query_files(
 
 
 def generate_chatml_messages_with_context(
+    # Context
     user_message: str,
-    system_message: str = None,
+    query_files: str = None,
+    query_images=None,
+    context_message="",
+    generated_asset_results: Dict[str, Dict] = {},
+    program_execution_context: List[str] = [],
     chat_history: list[ChatMessageModel] = [],
+    system_message: str = None,
+    # Model Config
     model_name="gpt-4o-mini",
+    model_type="",
     max_prompt_size=None,
     tokenizer_name=None,
-    query_images=None,
     vision_enabled=False,
-    model_type="",
-    context_message="",
-    query_files: str = None,
-    generated_files: List[FileAttachment] = None,
-    generated_asset_results: Dict[str, Dict] = {},
-    program_execution_context: List[str] = [],
 ):
     """Generate chat messages with appropriate context from previous conversation to send to the chat model"""
     # Set max prompt size from user config or based on pre-configured for model and machine specs
@@ -605,18 +592,10 @@ def generate_chatml_messages_with_context(
         role = "user" if chat.by == "you" else "assistant"
 
         # Legacy code to handle excalidraw diagrams prior to Dec 2024
-        if chat.by == "khoj" and "excalidraw" in chat.intent.type or "":
+        if chat.by == "khoj" and chat.intent and "excalidraw" in chat.intent.type:
             chat_message = (chat.intent.inferred_queries or [])[0]
 
-        if chat.queryFiles:
-            raw_query_files = chat.queryFiles
-            query_files_dict = dict()
-            for file in raw_query_files:
-                query_files_dict[file["name"]] = file["content"]
-
-            message_attached_files = gather_raw_query_files(query_files_dict)
-            chatml_messages.append(ChatMessage(content=message_attached_files, role=role))
-
+        # Add search and action context
         if not is_none_or_empty(chat.onlineContext):
             message_context += [
                 {
@@ -655,11 +634,12 @@ def generate_chatml_messages_with_context(
 
         if not is_none_or_empty(message_context):
             reconstructed_context_message = ChatMessage(content=message_context, role="user")
-            chatml_messages.insert(0, reconstructed_context_message)
+            chatml_messages.append(reconstructed_context_message)
 
+        # Add generated assets
         if not is_none_or_empty(chat.images) and role == "assistant":
             generated_assets["image"] = {
-                "query": (chat.intent.inferred_queries or [user_message])[0],
+                "description": (chat.intent.inferred_queries or [user_message])[0],
             }
 
         if not is_none_or_empty(chat.mermaidjsDiagram) and role == "assistant":
@@ -675,8 +655,17 @@ def generate_chatml_messages_with_context(
                 )
             )
 
+        # Add user query with attached file, images or khoj response
+        if chat.queryFiles:
+            raw_query_files = chat.queryFiles
+            query_files_dict = dict()
+            for file in raw_query_files:
+                query_files_dict[file["name"]] = file["content"]
+
+            message_attached_files = gather_raw_query_files(query_files_dict)
+
         message_content = construct_structured_message(
-            chat_message, chat.images if role == "user" else [], model_type, vision_enabled
+            chat_message, chat.images if role == "user" else [], model_type, vision_enabled, message_attached_files
         )
 
         reconstructed_message = ChatMessage(
@@ -684,19 +673,32 @@ def generate_chatml_messages_with_context(
             role=role,
             additional_kwargs={"message_type": chat.intent.type if chat.intent else None},
         )
-        chatml_messages.insert(0, reconstructed_message)
+        chatml_messages.append(reconstructed_message)
 
         if len(chatml_messages) >= 3 * lookback_turns:
             break
 
     messages: list[ChatMessage] = []
 
+    if not is_none_or_empty(system_message):
+        messages.append(ChatMessage(content=system_message, role="system"))
+
+    if len(chatml_messages) > 0:
+        messages += chatml_messages
+
+    if program_execution_context:
+        program_context_text = "\n".join(program_execution_context)
+        context_message += f"{prompts.additional_program_context.format(context=program_context_text)}\n"
+
+    if not is_none_or_empty(context_message):
+        messages.append(ChatMessage(content=context_message, role="user"))
+
     if not is_none_or_empty(generated_asset_results):
         messages.append(
             ChatMessage(
-                content=f"{prompts.generated_assets_context.format(generated_assets=yaml_dump(generated_asset_results))}\n\n",
+                content=prompts.generated_assets_context.format(generated_assets=yaml_dump(generated_asset_results)),
                 role="user",
-            )
+            ),
         )
 
     if not is_none_or_empty(user_message):
@@ -709,23 +711,6 @@ def generate_chatml_messages_with_context(
             )
         )
 
-    if generated_files:
-        message_attached_files = gather_raw_query_files({file.name: file.content for file in generated_files})
-        messages.append(ChatMessage(content=message_attached_files, role="assistant"))
-
-    if program_execution_context:
-        program_context_text = "\n".join(program_execution_context)
-        context_message += f"{prompts.additional_program_context.format(context=program_context_text)}\n"
-
-    if not is_none_or_empty(context_message):
-        messages.append(ChatMessage(content=context_message, role="user"))
-
-    if len(chatml_messages) > 0:
-        messages += chatml_messages
-
-    if not is_none_or_empty(system_message):
-        messages.append(ChatMessage(content=system_message, role="system"))
-
     # Normalize message content to list of chatml dictionaries
     for message in messages:
         if isinstance(message.content, str):
@@ -734,8 +719,8 @@ def generate_chatml_messages_with_context(
     # Truncate oldest messages from conversation history until under max supported prompt size by model
     messages = truncate_messages(messages, max_prompt_size, model_name, tokenizer_name)
 
-    # Return message in chronological order
-    return messages[::-1]
+    # Return messages in chronological order
+    return messages
 
 
 def get_encoder(
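Taken together, the reordered assembly above builds the prompt oldest-first and no longer reverses it before returning. A rough standalone sketch of the resulting layout (illustrative only; roles and helpers are simplified stand-ins, not package code):

```python
from typing import List, Tuple

def assemble_prompt(
    system_message: str,
    history: List[Tuple[str, str]],  # (role, content) pairs, oldest first
    context_message: str = "",
    generated_assets_note: str = "",
    user_message: str = "",
) -> List[Tuple[str, str]]:
    # Mirrors the new ordering: system prompt first, then conversation history,
    # then gathered context, then generated assets, and the latest user query last.
    messages: List[Tuple[str, str]] = []
    if system_message:
        messages.append(("system", system_message))
    messages += history
    if context_message:
        messages.append(("user", context_message))
    if generated_assets_note:
        messages.append(("user", generated_assets_note))
    if user_message:
        messages.append(("user", user_message))
    return messages  # already chronological; no reversal needed

print(assemble_prompt("You are Khoj.", [("user", "hi"), ("assistant", "hello")], user_message="what's new?"))
```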
@@ -806,7 +791,9 @@ def count_tokens(
 
 def count_total_tokens(messages: list[ChatMessage], encoder, system_message: Optional[ChatMessage]) -> Tuple[int, int]:
     """Count total tokens in messages including system message"""
-    system_message_tokens = count_tokens(system_message.content, encoder) if system_message else 0
+    system_message_tokens = (
+        sum([count_tokens(message.content, encoder) for message in system_message]) if system_message else 0
+    )
     message_tokens = sum([count_tokens(message.content, encoder) for message in messages])
     # Reserves 4 tokens to demarcate each message (e.g <|im_start|>user, <|im_end|>, <|endoftext|> etc.)
     total_tokens = message_tokens + system_message_tokens + 4 * len(messages)
@@ -823,11 +810,14 @@ def truncate_messages(
     encoder = get_encoder(model_name, tokenizer_name)
 
     # Extract system message from messages
-    system_message = None
-    for idx, message in enumerate(messages):
+    system_message = []
+    non_system_messages = []
+    for message in messages:
         if message.role == "system":
-            system_message = messages.pop(idx)
-            break
+            system_message.append(message)
+        else:
+            non_system_messages.append(message)
+    messages = non_system_messages
 
     # Drop older messages until under max supported prompt size by model
     total_tokens, system_message_tokens = count_total_tokens(messages, encoder, system_message)
@@ -835,20 +825,20 @@ def truncate_messages(
     while total_tokens > max_prompt_size and (len(messages) > 1 or len(messages[0].content) > 1):
         # If the last message has more than one content part, pop the oldest content part.
         # For tool calls, the whole message should dropped, assistant's tool call content being truncated annoys AI APIs.
-        if len(messages[-1].content) > 1 and messages[-1].additional_kwargs.get("message_type") != "tool_call":
+        if len(messages[0].content) > 1 and messages[0].additional_kwargs.get("message_type") != "tool_call":
             # The oldest content part is earlier in content list. So pop from the front.
-            messages[-1].content.pop(0)
+            messages[0].content.pop(0)
         # Otherwise, pop the last message if it has only one content part or is a tool call.
         else:
             # The oldest message is the last one. So pop from the back.
-            dropped_message = messages.pop()
+            dropped_message = messages.pop(0)
             # Drop tool result pair of tool call, if tool call message has been removed
             if (
                 dropped_message.additional_kwargs.get("message_type") == "tool_call"
                 and messages
-                and messages[-1].additional_kwargs.get("message_type") == "tool_result"
+                and messages[0].additional_kwargs.get("message_type") == "tool_result"
             ):
-                messages.pop()
+                messages.pop(0)
 
         total_tokens, _ = count_total_tokens(messages, encoder, system_message)
 
@@ -887,11 +877,7 @@ def truncate_messages(
         f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_snippet}"
     )
 
-    if system_message:
-        # Default system message role is system.
-        # Fallback to system message role of user for models that do not support this role like gemma-2 and openai's o1 model series.
-        system_message.role = "user" if "gemma-2" in model_name or model_name.startswith("o1") else "system"
-    return messages + [system_message] if system_message else messages
+    return system_message + messages if system_message else messages
 
 
 def reciprocal_conversation_to_chatml(message_pair):
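Because history is now kept oldest-first, truncation drops from the front of the list instead of the back, and system messages are collected separately and prepended on return. A condensed, standalone sketch of that policy (it omits the tool-call/tool-result pairing handled above, and the token counter is a stand-in):

```python
from typing import Callable, List

def truncate_front(messages: List[List[str]], count_tokens: Callable[[str], int], max_tokens: int) -> List[List[str]]:
    # Each message is a list of content-part strings, oldest message first.
    def total(msgs: List[List[str]]) -> int:
        return sum(count_tokens(part) for msg in msgs for part in msg)

    while messages and total(messages) > max_tokens and (len(messages) > 1 or len(messages[0]) > 1):
        if len(messages[0]) > 1:
            messages[0].pop(0)   # drop the oldest content part of the oldest message
        else:
            messages.pop(0)      # drop the oldest message entirely
    return messages

history = [["old question"], ["old answer", "with extra context"], ["new question"]]
print(truncate_front(history, count_tokens=lambda s: len(s.split()), max_tokens=4))  # [['new question']]
```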
khoj/processor/image/generate.py CHANGED
@@ -1,6 +1,7 @@
 import base64
 import io
 import logging
+import os
 import time
 from typing import Any, Callable, Dict, List, Optional
 
@@ -21,11 +22,12 @@ from khoj.database.adapters import ConversationAdapters
 from khoj.database.models import (
     Agent,
     ChatMessageModel,
+    Intent,
     KhojUser,
     TextToImageModelConfig,
 )
 from khoj.processor.conversation.google.utils import _is_retryable_error
-from khoj.routers.helpers import ChatEvent, generate_better_image_prompt
+from khoj.routers.helpers import ChatEvent, ImageShape, generate_better_image_prompt
 from khoj.routers.storage import upload_generated_image_to_bucket
 from khoj.utils import state
 from khoj.utils.helpers import convert_image_to_webp, timer
@@ -60,14 +62,17 @@ async def text_to_image(
         return
 
     text2image_model = text_to_image_config.model_name
-    chat_history_str = ""
+    image_chat_history: List[ChatMessageModel] = []
+    default_intent = Intent(type="remember")
    for chat in chat_history[-4:]:
        if chat.by == "you":
-            chat_history_str += f"Q: {chat.message}\n"
+            image_chat_history += [ChatMessageModel(by=chat.by, message=chat.message, intent=default_intent)]
+        elif chat.by == "khoj" and chat.images and chat.intent and chat.intent.inferred_queries:
+            image_chat_history += [
+                ChatMessageModel(by=chat.by, message=chat.intent.inferred_queries[0], intent=default_intent)
+            ]
        elif chat.by == "khoj" and chat.intent and chat.intent.type in ["remember", "reminder"]:
-            chat_history_str += f"A: {chat.message}\n"
-        elif chat.by == "khoj" and chat.images:
-            chat_history_str += f"A: Improved Prompt: {chat.intent.inferred_queries[0]}\n"
+            image_chat_history += [ChatMessageModel(by=chat.by, message=chat.message, intent=default_intent)]
 
     if send_status_func:
         async for event in send_status_func("**Enhancing the Painting Prompt**"):
@@ -75,9 +80,9 @@ async def text_to_image(
 
     # Generate a better image prompt
     # Use the user's message, chat history, and other context
-    image_prompt = await generate_better_image_prompt(
+    image_prompt_response = await generate_better_image_prompt(
         message,
-        chat_history_str,
+        image_chat_history,
         location_data=location_data,
         note_references=references,
         online_results=online_results,
@@ -88,6 +93,8 @@ async def text_to_image(
         query_files=query_files,
         tracer=tracer,
     )
+    image_prompt = image_prompt_response["description"]
+    image_shape = image_prompt_response["shape"]
 
     if send_status_func:
         async for event in send_status_func(f"**Painting to Imagine**:\n{image_prompt}"):
@@ -97,13 +104,19 @@
     with timer(f"Generate image with {text_to_image_config.model_type}", logger):
         try:
             if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
-                webp_image_bytes = generate_image_with_openai(image_prompt, text_to_image_config, text2image_model)
+                webp_image_bytes = generate_image_with_openai(
+                    image_prompt, text_to_image_config, text2image_model, image_shape
+                )
             elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI:
                 webp_image_bytes = generate_image_with_stability(image_prompt, text_to_image_config, text2image_model)
             elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.REPLICATE:
-                webp_image_bytes = generate_image_with_replicate(image_prompt, text_to_image_config, text2image_model)
+                webp_image_bytes = generate_image_with_replicate(
+                    image_prompt, text_to_image_config, text2image_model, image_shape
+                )
             elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.GOOGLE:
-                webp_image_bytes = generate_image_with_google(image_prompt, text_to_image_config, text2image_model)
+                webp_image_bytes = generate_image_with_google(
+                    image_prompt, text_to_image_config, text2image_model, image_shape
+                )
         except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
             if "content_policy_violation" in e.message:
                 logger.error(f"Image Generation blocked by OpenAI: {e}")
@@ -154,7 +167,10 @@
     reraise=True,
 )
 def generate_image_with_openai(
-    improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
+    improved_image_prompt: str,
+    text_to_image_config: TextToImageModelConfig,
+    text2image_model: str,
+    shape: ImageShape = ImageShape.SQUARE,
 ):
     "Generate image using OpenAI (compatible) API"
 
@@ -170,12 +186,21 @@ def generate_image_with_openai(
     elif state.openai_client:
         openai_client = state.openai_client
 
+    # Convert shape to size for OpenAI
+    if shape == ImageShape.PORTRAIT:
+        size = "1024x1536"
+    elif shape == ImageShape.LANDSCAPE:
+        size = "1536x1024"
+    else:  # Square
+        size = "1024x1024"
+
     # Generate image using OpenAI API
     OPENAI_IMAGE_GEN_STYLE = "vivid"
     response = openai_client.images.generate(
         prompt=improved_image_prompt,
         model=text2image_model,
         style=OPENAI_IMAGE_GEN_STYLE,
+        size=size,
         response_format="b64_json",
     )
 
@@ -222,10 +247,22 @@ def generate_image_with_stability(
     reraise=True,
 )
 def generate_image_with_replicate(
-    improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
+    improved_image_prompt: str,
+    text_to_image_config: TextToImageModelConfig,
+    text2image_model: str,
+    shape: ImageShape = ImageShape.SQUARE,
 ):
     "Generate image using Replicate API"
 
+    # Convert shape to aspect ratio for Replicate
+    # Replicate supports only 1:1, 3:4, and 4:3 aspect ratios
+    if shape == ImageShape.PORTRAIT:
+        aspect_ratio = "3:4"
+    elif shape == ImageShape.LANDSCAPE:
+        aspect_ratio = "4:3"
+    else:  # Square
+        aspect_ratio = "1:1"
+
     # Create image generation task on Replicate
     replicate_create_prediction_url = f"https://api.replicate.com/v1/models/{text2image_model}/predictions"
     headers = {
@@ -236,11 +273,16 @@
         "input": {
             "prompt": improved_image_prompt,
             "num_outputs": 1,
-            "aspect_ratio": "1:1",
+            "aspect_ratio": aspect_ratio,
             "output_format": "webp",
             "output_quality": 100,
         }
     }
+
+    seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
+    if seed:
+        json["input"]["seed"] = seed
+
     create_prediction = requests.post(replicate_create_prediction_url, headers=headers, json=json).json()
 
     # Get status of image generation task
@@ -276,7 +318,10 @@
     reraise=True,
 )
 def generate_image_with_google(
-    improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
+    improved_image_prompt: str,
+    text_to_image_config: TextToImageModelConfig,
+    text2image_model: str,
+    shape: ImageShape = ImageShape.SQUARE,
 ):
     """Generate image using Google's AI over API"""
 
@@ -284,6 +329,14 @@
     api_key = text_to_image_config.api_key or text_to_image_config.ai_model_api.api_key
     client = genai.Client(api_key=api_key)
 
+    # Convert shape to aspect ratio for Google
+    if shape == ImageShape.PORTRAIT:
+        aspect_ratio = "3:4"
+    elif shape == ImageShape.LANDSCAPE:
+        aspect_ratio = "4:3"
+    else:  # Square
+        aspect_ratio = "1:1"
+
     # Configure image generation settings
     config = gtypes.GenerateImagesConfig(
         number_of_images=1,
@@ -291,6 +344,7 @@
         person_generation=gtypes.PersonGeneration.ALLOW_ADULT,
         include_rai_reason=True,
         output_mime_type="image/png",
+        aspect_ratio=aspect_ratio,
     )
 
     # Call the Gemini API to generate the image
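All three providers translate the same ImageShape hint into their own geometry: OpenAI takes explicit pixel sizes, while Replicate and Google take aspect-ratio strings. A standalone sketch of that mapping, assuming a three-member ImageShape enum with the PORTRAIT, LANDSCAPE, and SQUARE members referenced in the hunks above (the enum's string values here are made up):

```python
from enum import Enum

class ImageShape(Enum):
    # Assumed members, matching the names referenced in the diff
    PORTRAIT = "Portrait"
    LANDSCAPE = "Landscape"
    SQUARE = "Square"

def openai_size(shape: ImageShape) -> str:
    # OpenAI's image API expects "WIDTHxHEIGHT" strings
    return {
        ImageShape.PORTRAIT: "1024x1536",
        ImageShape.LANDSCAPE: "1536x1024",
    }.get(shape, "1024x1024")

def aspect_ratio(shape: ImageShape) -> str:
    # Replicate and Google take aspect-ratio strings instead
    return {
        ImageShape.PORTRAIT: "3:4",
        ImageShape.LANDSCAPE: "4:3",
    }.get(shape, "1:1")

print(openai_size(ImageShape.LANDSCAPE), aspect_ratio(ImageShape.PORTRAIT))  # 1536x1024 3:4
```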
khoj/processor/tools/run_code.py CHANGED
@@ -156,10 +156,11 @@ async def generate_python_code(
 
     response = await send_message_to_model_wrapper(
         code_generation_prompt,
-        query_images=query_images,
         query_files=query_files,
-        user=user,
+        query_images=query_images,
+        fast_model=False,
         agent_chat_model=agent_chat_model,
+        user=user,
         tracer=tracer,
     )
 
khoj/routers/api_chat.py CHANGED
@@ -90,7 +90,6 @@ from khoj.utils.helpers import (
     is_operator_enabled,
 )
 from khoj.utils.rawconfig import (
-    FileAttachment,
     FileFilterRequest,
     FilesFilterRequest,
     LocationData,
@@ -732,7 +731,6 @@ async def event_generator(
         attached_file_context = gather_raw_query_files(query_files)
 
         generated_images: List[str] = []
-        generated_files: List[FileAttachment] = []
         generated_mermaidjs_diagram: str = None
         generated_asset_results: Dict = dict()
         program_execution_context: List[str] = []
@@ -769,7 +767,6 @@
                 train_of_thought=train_of_thought,
                 raw_query_files=raw_query_files,
                 generated_images=generated_images,
-                raw_generated_files=generated_asset_results,
                 generated_mermaidjs_diagram=generated_mermaidjs_diagram,
                 user_message_time=user_message_time,
                 tracer=tracer,
@@ -816,7 +813,6 @@
                 train_of_thought=train_of_thought,
                 raw_query_files=raw_query_files,
                 generated_images=generated_images,
-                raw_generated_files=generated_asset_results,
                 generated_mermaidjs_diagram=generated_mermaidjs_diagram,
                 user_message_time=user_message_time,
                 tracer=tracer,
@@ -927,9 +923,7 @@
 
         # Automated tasks are handled before to allow mixing them with other conversation commands
         cmds_to_rate_limit = []
-        is_automated_task = False
         if q.startswith("/automated_task"):
-            is_automated_task = True
             q = q.replace("/automated_task", "").lstrip()
             cmds_to_rate_limit += [ConversationCommand.AutomatedTask]
 
@@ -989,7 +983,6 @@
             chosen_io = await aget_data_sources_and_output_format(
                 q,
                 chat_history,
-                is_automated_task,
                 user=user,
                 query_images=uploaded_images,
                 agent=agent,
@@ -1021,7 +1014,6 @@
                 return
 
         defiltered_query = defilter_query(q)
-        file_filters = conversation.file_filters if conversation and conversation.file_filters else []
 
         if conversation_commands == [ConversationCommand.Research]:
             async for research_result in research(
@@ -1035,12 +1027,11 @@
                 send_status_func=partial(send_event, ChatEvent.STATUS),
                 user_name=user_name,
                 location=location,
-                file_filters=file_filters,
                 query_files=attached_file_context,
-                tracer=tracer,
                 cancellation_event=cancellation_event,
                 interrupt_queue=child_interrupt_queue,
                 abort_message=ChatEvent.END_EVENT.value,
+                tracer=tracer,
             ):
                 if isinstance(research_result, ResearchIteration):
                     if research_result.summarizedResult:
@@ -1075,8 +1066,8 @@
         logger.debug(f"Researched Results: {''.join(r.summarizedResult or '' for r in research_results)}")
 
         # Gather Context
-        ## Extract Document References
-        if ConversationCommand.Research not in conversation_commands:
+        ## Gather Document References
+        if ConversationCommand.Notes in conversation_commands:
             try:
                 async for result in search_documents(
                     q,
@@ -1194,7 +1185,7 @@
             ):
                 yield result
 
-        ## Gather Code Results
+        ## Run Code
         if ConversationCommand.Code in conversation_commands:
             try:
                 context = f"# Iteration 1:\n#---\nNotes:\n{compiled_references}\n\nOnline Results:{online_results}"
@@ -1220,6 +1211,8 @@
                     f"Failed to use code tool: {e}. Attempting to respond without code results",
                     exc_info=True,
                 )
+
+        ## Operate Computer
         if ConversationCommand.Operator in conversation_commands:
             try:
                 async for result in operate_environment(
@@ -1300,7 +1293,7 @@
                     generated_images.append(generated_image)
 
                 generated_asset_results["images"] = {
-                    "query": improved_image_prompt,
+                    "description": improved_image_prompt,
                 }
 
                 async for result in send_event(
@@ -1316,8 +1309,6 @@
                     yield result
 
                 inferred_queries = []
-                diagram_description = ""
-
                 async for result in generate_mermaidjs_diagram(
                     q=defiltered_query,
                     chat_history=chat_history,
@@ -1337,9 +1328,7 @@
                     better_diagram_description_prompt, mermaidjs_diagram_description = result
                     if better_diagram_description_prompt and mermaidjs_diagram_description:
                         inferred_queries.append(better_diagram_description_prompt)
-                        diagram_description = mermaidjs_diagram_description
-
-                    generated_mermaidjs_diagram = diagram_description
+                        generated_mermaidjs_diagram = mermaidjs_diagram_description
 
                     generated_asset_results["diagrams"] = {
                         "query": better_diagram_description_prompt,
@@ -1386,7 +1375,6 @@
             user_name,
             uploaded_images,
             attached_file_context,
-            generated_files,
             program_execution_context,
             generated_asset_results,
             is_subscribed,
@@ -1447,7 +1435,6 @@
             train_of_thought=train_of_thought,
             raw_query_files=raw_query_files,
             generated_images=generated_images,
-            raw_generated_files=generated_files,
             generated_mermaidjs_diagram=generated_mermaidjs_diagram,
             tracer=tracer,
         )