khoj 2.0.0b14.dev51__py3-none-any.whl → 2.0.0b15.dev22__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions exactly as they appear in their public registry.
- khoj/database/adapters/__init__.py +59 -20
- khoj/database/admin.py +4 -0
- khoj/database/migrations/0094_serverchatsettings_think_free_deep_and_more.py +61 -0
- khoj/database/models/__init__.py +18 -2
- khoj/interface/compiled/404/index.html +2 -2
- khoj/interface/compiled/_next/static/chunks/{9808-0ae18d938933fea3.js → 9808-bd5d7361ad026094.js} +1 -1
- khoj/interface/compiled/_next/static/css/{2945c4a857922f3b.css → c34713c98384ee87.css} +1 -1
- khoj/interface/compiled/_next/static/css/fb7ea16e60b40ecd.css +1 -0
- khoj/interface/compiled/agents/index.html +2 -2
- khoj/interface/compiled/agents/index.txt +2 -2
- khoj/interface/compiled/automations/index.html +2 -2
- khoj/interface/compiled/automations/index.txt +3 -3
- khoj/interface/compiled/chat/index.html +2 -2
- khoj/interface/compiled/chat/index.txt +3 -3
- khoj/interface/compiled/index.html +2 -2
- khoj/interface/compiled/index.txt +2 -2
- khoj/interface/compiled/search/index.html +2 -2
- khoj/interface/compiled/search/index.txt +2 -2
- khoj/interface/compiled/settings/index.html +2 -2
- khoj/interface/compiled/settings/index.txt +4 -4
- khoj/interface/compiled/share/chat/index.html +2 -2
- khoj/interface/compiled/share/chat/index.txt +2 -2
- khoj/processor/conversation/anthropic/anthropic_chat.py +4 -88
- khoj/processor/conversation/anthropic/utils.py +1 -2
- khoj/processor/conversation/google/gemini_chat.py +4 -88
- khoj/processor/conversation/google/utils.py +6 -3
- khoj/processor/conversation/openai/gpt.py +16 -93
- khoj/processor/conversation/openai/utils.py +38 -30
- khoj/processor/conversation/prompts.py +30 -39
- khoj/processor/conversation/utils.py +70 -84
- khoj/processor/image/generate.py +69 -15
- khoj/processor/tools/run_code.py +3 -2
- khoj/routers/api_chat.py +8 -21
- khoj/routers/helpers.py +243 -156
- khoj/routers/research.py +6 -6
- khoj/utils/helpers.py +6 -2
- {khoj-2.0.0b14.dev51.dist-info → khoj-2.0.0b15.dev22.dist-info}/METADATA +1 -1
- {khoj-2.0.0b14.dev51.dist-info → khoj-2.0.0b15.dev22.dist-info}/RECORD +51 -50
- khoj/interface/compiled/_next/static/css/ecea704005ba630c.css +0 -1
- /khoj/interface/compiled/_next/static/chunks/{1327-511bb0a862efce80.js → 1327-e254819a9172cfa7.js} +0 -0
- /khoj/interface/compiled/_next/static/chunks/{1915-fbfe167c84ad60c5.js → 1915-5c6508f6ebb62a30.js} +0 -0
- /khoj/interface/compiled/_next/static/chunks/{2117-e78b6902ad6f75ec.js → 2117-080746c8e170c81a.js} +0 -0
- /khoj/interface/compiled/_next/static/chunks/{2939-4d4084c5b888b960.js → 2939-4af3fd24b8ffc9ad.js} +0 -0
- /khoj/interface/compiled/_next/static/chunks/{4447-d6cf93724d57e34b.js → 4447-cd95608f8e93e711.js} +0 -0
- /khoj/interface/compiled/_next/static/chunks/{8667-4b7790573b08c50d.js → 8667-50b03a89e82e0ba7.js} +0 -0
- /khoj/interface/compiled/_next/static/chunks/{9139-ce1ae935dac9c871.js → 9139-8ac4d9feb10f8869.js} +0 -0
- /khoj/interface/compiled/_next/static/chunks/{webpack-e572645654c4335e.js → webpack-5393aad3d824e0cb.js} +0 -0
- /khoj/interface/compiled/_next/static/{yBzbL9kxl5BudSA9F4Gr6 → t8O_8CJ9p3UtV9kEsAAWT}/_buildManifest.js +0 -0
- /khoj/interface/compiled/_next/static/{yBzbL9kxl5BudSA9F4Gr6 → t8O_8CJ9p3UtV9kEsAAWT}/_ssgManifest.js +0 -0
- {khoj-2.0.0b14.dev51.dist-info → khoj-2.0.0b15.dev22.dist-info}/WHEEL +0 -0
- {khoj-2.0.0b14.dev51.dist-info → khoj-2.0.0b15.dev22.dist-info}/entry_points.txt +0 -0
- {khoj-2.0.0b14.dev51.dist-info → khoj-2.0.0b15.dev22.dist-info}/licenses/LICENSE +0 -0
khoj/processor/conversation/utils.py
CHANGED
@@ -23,7 +23,6 @@ from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
 from khoj.database.adapters import ConversationAdapters
 from khoj.database.models import (
     ChatMessageModel,
-    ChatModel,
     ClientApplication,
     Intent,
     KhojUser,
@@ -263,7 +262,7 @@ def construct_question_history(
             continue

         message = chat.message
-        inferred_queries_list = chat.intent.inferred_queries or []
+        inferred_queries_list = chat.intent.inferred_queries or [] if chat.intent else []

         # Ensure inferred_queries_list is a list, defaulting to the original query in a list
         if not inferred_queries_list:
@@ -450,7 +449,6 @@ async def save_to_conversation_log(
    query_images: List[str] = None,
    raw_query_files: List[FileAttachment] = [],
    generated_images: List[str] = [],
-   raw_generated_files: List[FileAttachment] = [],
    generated_mermaidjs_diagram: str = None,
    research_results: Optional[List[ResearchIteration]] = None,
    train_of_thought: List[Any] = [],
@@ -475,7 +473,6 @@ async def save_to_conversation_log(
        "trainOfThought": train_of_thought,
        "turnId": turn_id,
        "images": generated_images,
-       "queryFiles": [file.model_dump(mode="json") for file in raw_generated_files],
    }

    if generated_mermaidjs_diagram:
@@ -528,29 +525,18 @@ def construct_structured_message(

     Assume vision is enabled and chat model provider supports messages in chatml format, unless specified otherwise.
     """
-
-
-
-
-
-        constructed_messages
-        if not is_none_or_empty(message):
-            constructed_messages += [{"type": "text", "text": message}] if isinstance(message, str) else message
-        # Drop image message passed by caller if chat model does not have vision enabled
-        if not vision_enabled:
-            constructed_messages = [m for m in constructed_messages if m.get("type") != "image_url"]
-        if not is_none_or_empty(attached_file_context):
-            constructed_messages += [{"type": "text", "text": attached_file_context}]
-        if vision_enabled and images:
-            for image in images:
-                constructed_messages += [{"type": "image_url", "image_url": {"url": image}}]
-        return constructed_messages
-
-    message = message if isinstance(message, str) else "\n\n".join(m["text"] for m in message)
+    constructed_messages: List[dict[str, Any]] = []
+    if not is_none_or_empty(message):
+        constructed_messages += [{"type": "text", "text": message}] if isinstance(message, str) else message
+    # Drop image message passed by caller if chat model does not have vision enabled
+    if not vision_enabled:
+        constructed_messages = [m for m in constructed_messages if m.get("type") != "image_url"]
     if not is_none_or_empty(attached_file_context):
-
-
-
+        constructed_messages += [{"type": "text", "text": attached_file_context}]
+    if vision_enabled and images:
+        for image in images:
+            constructed_messages += [{"type": "image_url", "image_url": {"url": image}}]
+    return constructed_messages


 def gather_raw_query_files(
@@ -570,20 +556,21 @@ def gather_raw_query_files(


 def generate_chatml_messages_with_context(
+    # Context
     user_message: str,
-
+    query_files: str = None,
+    query_images=None,
+    context_message="",
+    generated_asset_results: Dict[str, Dict] = {},
+    program_execution_context: List[str] = [],
     chat_history: list[ChatMessageModel] = [],
+    system_message: str = None,
+    # Model Config
     model_name="gpt-4o-mini",
+    model_type="",
     max_prompt_size=None,
     tokenizer_name=None,
-    query_images=None,
     vision_enabled=False,
-    model_type="",
-    context_message="",
-    query_files: str = None,
-    generated_files: List[FileAttachment] = None,
-    generated_asset_results: Dict[str, Dict] = {},
-    program_execution_context: List[str] = [],
 ):
     """Generate chat messages with appropriate context from previous conversation to send to the chat model"""
     # Set max prompt size from user config or based on pre-configured for model and machine specs
@@ -605,18 +592,10 @@ def generate_chatml_messages_with_context(
         role = "user" if chat.by == "you" else "assistant"

         # Legacy code to handle excalidraw diagrams prior to Dec 2024
-        if chat.by == "khoj" and "excalidraw" in chat.intent.type:
+        if chat.by == "khoj" and chat.intent and "excalidraw" in chat.intent.type:
             chat_message = (chat.intent.inferred_queries or [])[0]

-
-        raw_query_files = chat.queryFiles
-        query_files_dict = dict()
-        for file in raw_query_files:
-            query_files_dict[file["name"]] = file["content"]
-
-        message_attached_files = gather_raw_query_files(query_files_dict)
-        chatml_messages.append(ChatMessage(content=message_attached_files, role=role))
-
+        # Add search and action context
         if not is_none_or_empty(chat.onlineContext):
             message_context += [
                 {
@@ -655,11 +634,12 @@ def generate_chatml_messages_with_context(

         if not is_none_or_empty(message_context):
             reconstructed_context_message = ChatMessage(content=message_context, role="user")
-            chatml_messages.
+            chatml_messages.append(reconstructed_context_message)

+        # Add generated assets
         if not is_none_or_empty(chat.images) and role == "assistant":
             generated_assets["image"] = {
-                "
+                "description": (chat.intent.inferred_queries or [user_message])[0],
             }

         if not is_none_or_empty(chat.mermaidjsDiagram) and role == "assistant":
@@ -675,8 +655,17 @@ def generate_chatml_messages_with_context(
             )
         )

+        # Add user query with attached file, images or khoj response
+        if chat.queryFiles:
+            raw_query_files = chat.queryFiles
+            query_files_dict = dict()
+            for file in raw_query_files:
+                query_files_dict[file["name"]] = file["content"]
+
+            message_attached_files = gather_raw_query_files(query_files_dict)
+
         message_content = construct_structured_message(
-            chat_message, chat.images if role == "user" else [], model_type, vision_enabled
+            chat_message, chat.images if role == "user" else [], model_type, vision_enabled, message_attached_files
         )

         reconstructed_message = ChatMessage(
@@ -684,19 +673,32 @@ def generate_chatml_messages_with_context(
             role=role,
             additional_kwargs={"message_type": chat.intent.type if chat.intent else None},
         )
-        chatml_messages.
+        chatml_messages.append(reconstructed_message)

         if len(chatml_messages) >= 3 * lookback_turns:
             break

     messages: list[ChatMessage] = []

+    if not is_none_or_empty(system_message):
+        messages.append(ChatMessage(content=system_message, role="system"))
+
+    if len(chatml_messages) > 0:
+        messages += chatml_messages
+
+    if program_execution_context:
+        program_context_text = "\n".join(program_execution_context)
+        context_message += f"{prompts.additional_program_context.format(context=program_context_text)}\n"
+
+    if not is_none_or_empty(context_message):
+        messages.append(ChatMessage(content=context_message, role="user"))
+
     if not is_none_or_empty(generated_asset_results):
         messages.append(
             ChatMessage(
-                content=
+                content=prompts.generated_assets_context.format(generated_assets=yaml_dump(generated_asset_results)),
                 role="user",
-            )
+            ),
         )

     if not is_none_or_empty(user_message):
@@ -709,23 +711,6 @@ def generate_chatml_messages_with_context(
             )
         )

-    if generated_files:
-        message_attached_files = gather_raw_query_files({file.name: file.content for file in generated_files})
-        messages.append(ChatMessage(content=message_attached_files, role="assistant"))
-
-    if program_execution_context:
-        program_context_text = "\n".join(program_execution_context)
-        context_message += f"{prompts.additional_program_context.format(context=program_context_text)}\n"
-
-    if not is_none_or_empty(context_message):
-        messages.append(ChatMessage(content=context_message, role="user"))
-
-    if len(chatml_messages) > 0:
-        messages += chatml_messages
-
-    if not is_none_or_empty(system_message):
-        messages.append(ChatMessage(content=system_message, role="system"))
-
     # Normalize message content to list of chatml dictionaries
     for message in messages:
         if isinstance(message.content, str):
@@ -734,8 +719,8 @@ def generate_chatml_messages_with_context(
     # Truncate oldest messages from conversation history until under max supported prompt size by model
     messages = truncate_messages(messages, max_prompt_size, model_name, tokenizer_name)

-    # Return
-    return messages
+    # Return messages in chronological order
+    return messages


 def get_encoder(
@@ -806,7 +791,9 @@ def count_tokens(

 def count_total_tokens(messages: list[ChatMessage], encoder, system_message: Optional[ChatMessage]) -> Tuple[int, int]:
     """Count total tokens in messages including system message"""
-    system_message_tokens =
+    system_message_tokens = (
+        sum([count_tokens(message.content, encoder) for message in system_message]) if system_message else 0
+    )
     message_tokens = sum([count_tokens(message.content, encoder) for message in messages])
     # Reserves 4 tokens to demarcate each message (e.g <|im_start|>user, <|im_end|>, <|endoftext|> etc.)
     total_tokens = message_tokens + system_message_tokens + 4 * len(messages)
@@ -823,11 +810,14 @@ def truncate_messages(
     encoder = get_encoder(model_name, tokenizer_name)

     # Extract system message from messages
-    system_message =
-
+    system_message = []
+    non_system_messages = []
+    for message in messages:
         if message.role == "system":
-            system_message
-
+            system_message.append(message)
+        else:
+            non_system_messages.append(message)
+    messages = non_system_messages

     # Drop older messages until under max supported prompt size by model
     total_tokens, system_message_tokens = count_total_tokens(messages, encoder, system_message)
@@ -835,20 +825,20 @@ def truncate_messages(
     while total_tokens > max_prompt_size and (len(messages) > 1 or len(messages[0].content) > 1):
         # If the last message has more than one content part, pop the oldest content part.
         # For tool calls, the whole message should dropped, assistant's tool call content being truncated annoys AI APIs.
-        if len(messages[
+        if len(messages[0].content) > 1 and messages[0].additional_kwargs.get("message_type") != "tool_call":
             # The oldest content part is earlier in content list. So pop from the front.
-            messages[
+            messages[0].content.pop(0)
         # Otherwise, pop the last message if it has only one content part or is a tool call.
         else:
             # The oldest message is the last one. So pop from the back.
-            dropped_message = messages.pop()
+            dropped_message = messages.pop(0)
             # Drop tool result pair of tool call, if tool call message has been removed
             if (
                 dropped_message.additional_kwargs.get("message_type") == "tool_call"
                 and messages
-                and messages[
+                and messages[0].additional_kwargs.get("message_type") == "tool_result"
             ):
-                messages.pop()
+                messages.pop(0)

         total_tokens, _ = count_total_tokens(messages, encoder, system_message)

@@ -887,11 +877,7 @@ def truncate_messages(
                 f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_snippet}"
             )

-    if system_message:
-        # Default system message role is system.
-        # Fallback to system message role of user for models that do not support this role like gemma-2 and openai's o1 model series.
-        system_message.role = "user" if "gemma-2" in model_name or model_name.startswith("o1") else "system"
-    return messages + [system_message] if system_message else messages
+    return system_message + messages if system_message else messages


 def reciprocal_conversation_to_chatml(message_pair):
khoj/processor/image/generate.py
CHANGED
@@ -1,6 +1,7 @@
 import base64
 import io
 import logging
+import os
 import time
 from typing import Any, Callable, Dict, List, Optional

@@ -21,11 +22,12 @@ from khoj.database.adapters import ConversationAdapters
 from khoj.database.models import (
     Agent,
     ChatMessageModel,
+    Intent,
     KhojUser,
     TextToImageModelConfig,
 )
 from khoj.processor.conversation.google.utils import _is_retryable_error
-from khoj.routers.helpers import ChatEvent, generate_better_image_prompt
+from khoj.routers.helpers import ChatEvent, ImageShape, generate_better_image_prompt
 from khoj.routers.storage import upload_generated_image_to_bucket
 from khoj.utils import state
 from khoj.utils.helpers import convert_image_to_webp, timer
@@ -60,14 +62,17 @@ async def text_to_image(
         return

     text2image_model = text_to_image_config.model_name
-
+    image_chat_history: List[ChatMessageModel] = []
+    default_intent = Intent(type="remember")
     for chat in chat_history[-4:]:
         if chat.by == "you":
-
+            image_chat_history += [ChatMessageModel(by=chat.by, message=chat.message, intent=default_intent)]
+        elif chat.by == "khoj" and chat.images and chat.intent and chat.intent.inferred_queries:
+            image_chat_history += [
+                ChatMessageModel(by=chat.by, message=chat.intent.inferred_queries[0], intent=default_intent)
+            ]
         elif chat.by == "khoj" and chat.intent and chat.intent.type in ["remember", "reminder"]:
-
-        elif chat.by == "khoj" and chat.images:
-            chat_history_str += f"A: Improved Prompt: {chat.intent.inferred_queries[0]}\n"
+            image_chat_history += [ChatMessageModel(by=chat.by, message=chat.message, intent=default_intent)]

     if send_status_func:
         async for event in send_status_func("**Enhancing the Painting Prompt**"):
@@ -75,9 +80,9 @@ async def text_to_image(

     # Generate a better image prompt
     # Use the user's message, chat history, and other context
-
+    image_prompt_response = await generate_better_image_prompt(
         message,
-
+        image_chat_history,
         location_data=location_data,
         note_references=references,
         online_results=online_results,
@@ -88,6 +93,8 @@ async def text_to_image(
         query_files=query_files,
         tracer=tracer,
     )
+    image_prompt = image_prompt_response["description"]
+    image_shape = image_prompt_response["shape"]

     if send_status_func:
         async for event in send_status_func(f"**Painting to Imagine**:\n{image_prompt}"):
@@ -97,13 +104,19 @@ async def text_to_image(
     with timer(f"Generate image with {text_to_image_config.model_type}", logger):
         try:
             if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
-                webp_image_bytes = generate_image_with_openai(image_prompt, text_to_image_config, text2image_model)
+                webp_image_bytes = generate_image_with_openai(
+                    image_prompt, text_to_image_config, text2image_model, image_shape
+                )
             elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI:
                 webp_image_bytes = generate_image_with_stability(image_prompt, text_to_image_config, text2image_model)
             elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.REPLICATE:
-                webp_image_bytes = generate_image_with_replicate(image_prompt, text_to_image_config, text2image_model)
+                webp_image_bytes = generate_image_with_replicate(
+                    image_prompt, text_to_image_config, text2image_model, image_shape
+                )
             elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.GOOGLE:
-                webp_image_bytes = generate_image_with_google(image_prompt, text_to_image_config, text2image_model)
+                webp_image_bytes = generate_image_with_google(
+                    image_prompt, text_to_image_config, text2image_model, image_shape
+                )
         except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
             if "content_policy_violation" in e.message:
                 logger.error(f"Image Generation blocked by OpenAI: {e}")
@@ -154,7 +167,10 @@ async def text_to_image(
     reraise=True,
 )
 def generate_image_with_openai(
-    improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
+    improved_image_prompt: str,
+    text_to_image_config: TextToImageModelConfig,
+    text2image_model: str,
+    shape: ImageShape = ImageShape.SQUARE,
 ):
     "Generate image using OpenAI (compatible) API"

@@ -170,12 +186,21 @@ def generate_image_with_openai(
     elif state.openai_client:
         openai_client = state.openai_client

+    # Convert shape to size for OpenAI
+    if shape == ImageShape.PORTRAIT:
+        size = "1024x1536"
+    elif shape == ImageShape.LANDSCAPE:
+        size = "1536x1024"
+    else:  # Square
+        size = "1024x1024"
+
     # Generate image using OpenAI API
     OPENAI_IMAGE_GEN_STYLE = "vivid"
     response = openai_client.images.generate(
         prompt=improved_image_prompt,
         model=text2image_model,
         style=OPENAI_IMAGE_GEN_STYLE,
+        size=size,
         response_format="b64_json",
     )

@@ -222,10 +247,22 @@ def generate_image_with_stability(
     reraise=True,
 )
 def generate_image_with_replicate(
-    improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
+    improved_image_prompt: str,
+    text_to_image_config: TextToImageModelConfig,
+    text2image_model: str,
+    shape: ImageShape = ImageShape.SQUARE,
 ):
     "Generate image using Replicate API"

+    # Convert shape to aspect ratio for Replicate
+    # Replicate supports only 1:1, 3:4, and 4:3 aspect ratios
+    if shape == ImageShape.PORTRAIT:
+        aspect_ratio = "3:4"
+    elif shape == ImageShape.LANDSCAPE:
+        aspect_ratio = "4:3"
+    else:  # Square
+        aspect_ratio = "1:1"
+
     # Create image generation task on Replicate
     replicate_create_prediction_url = f"https://api.replicate.com/v1/models/{text2image_model}/predictions"
     headers = {
@@ -236,11 +273,16 @@ def generate_image_with_replicate(
         "input": {
             "prompt": improved_image_prompt,
             "num_outputs": 1,
-            "aspect_ratio":
+            "aspect_ratio": aspect_ratio,
             "output_format": "webp",
             "output_quality": 100,
         }
     }
+
+    seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
+    if seed:
+        json["input"]["seed"] = seed
+
     create_prediction = requests.post(replicate_create_prediction_url, headers=headers, json=json).json()

     # Get status of image generation task
@@ -276,7 +318,10 @@ def generate_image_with_replicate(
     reraise=True,
 )
 def generate_image_with_google(
-    improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
+    improved_image_prompt: str,
+    text_to_image_config: TextToImageModelConfig,
+    text2image_model: str,
+    shape: ImageShape = ImageShape.SQUARE,
 ):
     """Generate image using Google's AI over API"""

@@ -284,6 +329,14 @@ def generate_image_with_google(
     api_key = text_to_image_config.api_key or text_to_image_config.ai_model_api.api_key
     client = genai.Client(api_key=api_key)

+    # Convert shape to aspect ratio for Google
+    if shape == ImageShape.PORTRAIT:
+        aspect_ratio = "3:4"
+    elif shape == ImageShape.LANDSCAPE:
+        aspect_ratio = "4:3"
+    else:  # Square
+        aspect_ratio = "1:1"
+
     # Configure image generation settings
     config = gtypes.GenerateImagesConfig(
         number_of_images=1,
@@ -291,6 +344,7 @@ def generate_image_with_google(
         person_generation=gtypes.PersonGeneration.ALLOW_ADULT,
         include_rai_reason=True,
         output_mime_type="image/png",
+        aspect_ratio=aspect_ratio,
     )

     # Call the Gemini API to generate the image
khoj/processor/tools/run_code.py
CHANGED
@@ -156,10 +156,11 @@ async def generate_python_code(

     response = await send_message_to_model_wrapper(
         code_generation_prompt,
-        query_images=query_images,
         query_files=query_files,
-
+        query_images=query_images,
+        fast_model=False,
         agent_chat_model=agent_chat_model,
+        user=user,
         tracer=tracer,
     )

khoj/routers/api_chat.py
CHANGED
@@ -90,7 +90,6 @@ from khoj.utils.helpers import (
     is_operator_enabled,
 )
 from khoj.utils.rawconfig import (
-    FileAttachment,
     FileFilterRequest,
     FilesFilterRequest,
     LocationData,
@@ -732,7 +731,6 @@ async def event_generator(
     attached_file_context = gather_raw_query_files(query_files)

     generated_images: List[str] = []
-    generated_files: List[FileAttachment] = []
     generated_mermaidjs_diagram: str = None
     generated_asset_results: Dict = dict()
     program_execution_context: List[str] = []
@@ -769,7 +767,6 @@ async def event_generator(
         train_of_thought=train_of_thought,
         raw_query_files=raw_query_files,
         generated_images=generated_images,
-        raw_generated_files=generated_asset_results,
         generated_mermaidjs_diagram=generated_mermaidjs_diagram,
         user_message_time=user_message_time,
         tracer=tracer,
@@ -816,7 +813,6 @@ async def event_generator(
         train_of_thought=train_of_thought,
         raw_query_files=raw_query_files,
         generated_images=generated_images,
-        raw_generated_files=generated_asset_results,
         generated_mermaidjs_diagram=generated_mermaidjs_diagram,
         user_message_time=user_message_time,
         tracer=tracer,
@@ -927,9 +923,7 @@ async def event_generator(

     # Automated tasks are handled before to allow mixing them with other conversation commands
     cmds_to_rate_limit = []
-    is_automated_task = False
     if q.startswith("/automated_task"):
-        is_automated_task = True
         q = q.replace("/automated_task", "").lstrip()
         cmds_to_rate_limit += [ConversationCommand.AutomatedTask]

@@ -989,7 +983,6 @@ async def event_generator(
     chosen_io = await aget_data_sources_and_output_format(
         q,
         chat_history,
-        is_automated_task,
         user=user,
         query_images=uploaded_images,
         agent=agent,
@@ -1021,7 +1014,6 @@ async def event_generator(
         return

     defiltered_query = defilter_query(q)
-    file_filters = conversation.file_filters if conversation and conversation.file_filters else []

     if conversation_commands == [ConversationCommand.Research]:
         async for research_result in research(
@@ -1035,12 +1027,11 @@ async def event_generator(
             send_status_func=partial(send_event, ChatEvent.STATUS),
             user_name=user_name,
             location=location,
-            file_filters=file_filters,
             query_files=attached_file_context,
-            tracer=tracer,
             cancellation_event=cancellation_event,
             interrupt_queue=child_interrupt_queue,
             abort_message=ChatEvent.END_EVENT.value,
+            tracer=tracer,
         ):
             if isinstance(research_result, ResearchIteration):
                 if research_result.summarizedResult:
@@ -1075,8 +1066,8 @@ async def event_generator(
     logger.debug(f"Researched Results: {''.join(r.summarizedResult or '' for r in research_results)}")

     # Gather Context
-    ##
-    if ConversationCommand.
+    ## Gather Document References
+    if ConversationCommand.Notes in conversation_commands:
         try:
             async for result in search_documents(
                 q,
@@ -1194,7 +1185,7 @@ async def event_generator(
     ):
         yield result

-    ##
+    ## Run Code
     if ConversationCommand.Code in conversation_commands:
         try:
             context = f"# Iteration 1:\n#---\nNotes:\n{compiled_references}\n\nOnline Results:{online_results}"
@@ -1220,6 +1211,8 @@ async def event_generator(
                 f"Failed to use code tool: {e}. Attempting to respond without code results",
                 exc_info=True,
             )
+
+    ## Operate Computer
     if ConversationCommand.Operator in conversation_commands:
         try:
             async for result in operate_environment(
@@ -1300,7 +1293,7 @@ async def event_generator(
             generated_images.append(generated_image)

             generated_asset_results["images"] = {
-                "
+                "description": improved_image_prompt,
             }

             async for result in send_event(
@@ -1316,8 +1309,6 @@ async def event_generator(
                 yield result

             inferred_queries = []
-            diagram_description = ""
-
             async for result in generate_mermaidjs_diagram(
                 q=defiltered_query,
                 chat_history=chat_history,
@@ -1337,9 +1328,7 @@ async def event_generator(
                 better_diagram_description_prompt, mermaidjs_diagram_description = result
                 if better_diagram_description_prompt and mermaidjs_diagram_description:
                     inferred_queries.append(better_diagram_description_prompt)
-
-
-                    generated_mermaidjs_diagram = diagram_description
+                    generated_mermaidjs_diagram = mermaidjs_diagram_description

                 generated_asset_results["diagrams"] = {
                     "query": better_diagram_description_prompt,
@@ -1386,7 +1375,6 @@ async def event_generator(
         user_name,
         uploaded_images,
         attached_file_context,
-        generated_files,
         program_execution_context,
         generated_asset_results,
         is_subscribed,
@@ -1447,7 +1435,6 @@ async def event_generator(
         train_of_thought=train_of_thought,
         raw_query_files=raw_query_files,
         generated_images=generated_images,
-        raw_generated_files=generated_files,
         generated_mermaidjs_diagram=generated_mermaidjs_diagram,
         tracer=tracer,
     )