khoj 1.16.1.dev15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- khoj/__init__.py +0 -0
- khoj/app/README.md +94 -0
- khoj/app/__init__.py +0 -0
- khoj/app/asgi.py +16 -0
- khoj/app/settings.py +192 -0
- khoj/app/urls.py +25 -0
- khoj/configure.py +424 -0
- khoj/database/__init__.py +0 -0
- khoj/database/adapters/__init__.py +1234 -0
- khoj/database/admin.py +290 -0
- khoj/database/apps.py +6 -0
- khoj/database/management/__init__.py +0 -0
- khoj/database/management/commands/__init__.py +0 -0
- khoj/database/management/commands/change_generated_images_url.py +61 -0
- khoj/database/management/commands/convert_images_png_to_webp.py +99 -0
- khoj/database/migrations/0001_khojuser.py +98 -0
- khoj/database/migrations/0002_googleuser.py +32 -0
- khoj/database/migrations/0003_vector_extension.py +10 -0
- khoj/database/migrations/0004_content_types_and_more.py +181 -0
- khoj/database/migrations/0005_embeddings_corpus_id.py +19 -0
- khoj/database/migrations/0006_embeddingsdates.py +33 -0
- khoj/database/migrations/0007_add_conversation.py +27 -0
- khoj/database/migrations/0008_alter_conversation_conversation_log.py +17 -0
- khoj/database/migrations/0009_khojapiuser.py +24 -0
- khoj/database/migrations/0010_chatmodeloptions_and_more.py +83 -0
- khoj/database/migrations/0010_rename_embeddings_entry_and_more.py +30 -0
- khoj/database/migrations/0011_merge_20231102_0138.py +14 -0
- khoj/database/migrations/0012_entry_file_source.py +21 -0
- khoj/database/migrations/0013_subscription.py +37 -0
- khoj/database/migrations/0014_alter_googleuser_picture.py +17 -0
- khoj/database/migrations/0015_alter_subscription_user.py +21 -0
- khoj/database/migrations/0016_alter_subscription_renewal_date.py +17 -0
- khoj/database/migrations/0017_searchmodel.py +32 -0
- khoj/database/migrations/0018_searchmodelconfig_delete_searchmodel.py +30 -0
- khoj/database/migrations/0019_alter_googleuser_family_name_and_more.py +27 -0
- khoj/database/migrations/0020_reflectivequestion.py +36 -0
- khoj/database/migrations/0021_speechtotextmodeloptions_and_more.py +42 -0
- khoj/database/migrations/0022_texttoimagemodelconfig.py +25 -0
- khoj/database/migrations/0023_usersearchmodelconfig.py +33 -0
- khoj/database/migrations/0024_alter_entry_embeddings.py +18 -0
- khoj/database/migrations/0025_clientapplication_khojuser_phone_number_and_more.py +46 -0
- khoj/database/migrations/0025_searchmodelconfig_embeddings_inference_endpoint_and_more.py +22 -0
- khoj/database/migrations/0026_searchmodelconfig_cross_encoder_inference_endpoint_and_more.py +22 -0
- khoj/database/migrations/0027_merge_20240118_1324.py +13 -0
- khoj/database/migrations/0028_khojuser_verified_phone_number.py +17 -0
- khoj/database/migrations/0029_userrequests.py +27 -0
- khoj/database/migrations/0030_conversation_slug_and_title.py +38 -0
- khoj/database/migrations/0031_agent_conversation_agent.py +53 -0
- khoj/database/migrations/0031_alter_googleuser_locale.py +30 -0
- khoj/database/migrations/0032_merge_20240322_0427.py +14 -0
- khoj/database/migrations/0033_rename_tuning_agent_personality.py +17 -0
- khoj/database/migrations/0034_alter_chatmodeloptions_chat_model.py +32 -0
- khoj/database/migrations/0035_processlock.py +26 -0
- khoj/database/migrations/0036_alter_processlock_name.py +19 -0
- khoj/database/migrations/0036_delete_offlinechatprocessorconversationconfig.py +15 -0
- khoj/database/migrations/0036_publicconversation.py +42 -0
- khoj/database/migrations/0037_chatmodeloptions_openai_config_and_more.py +51 -0
- khoj/database/migrations/0037_searchmodelconfig_bi_encoder_docs_encode_config_and_more.py +32 -0
- khoj/database/migrations/0038_merge_20240425_0857.py +14 -0
- khoj/database/migrations/0038_merge_20240426_1640.py +12 -0
- khoj/database/migrations/0039_merge_20240501_0301.py +12 -0
- khoj/database/migrations/0040_alter_processlock_name.py +26 -0
- khoj/database/migrations/0040_merge_20240504_1010.py +14 -0
- khoj/database/migrations/0041_merge_20240505_1234.py +14 -0
- khoj/database/migrations/0042_serverchatsettings.py +46 -0
- khoj/database/migrations/0043_alter_chatmodeloptions_model_type.py +21 -0
- khoj/database/migrations/0044_conversation_file_filters.py +17 -0
- khoj/database/migrations/0045_fileobject.py +37 -0
- khoj/database/migrations/0046_khojuser_email_verification_code_and_more.py +22 -0
- khoj/database/migrations/0047_alter_entry_file_type.py +31 -0
- khoj/database/migrations/0048_voicemodeloption_uservoicemodelconfig.py +52 -0
- khoj/database/migrations/0049_datastore.py +38 -0
- khoj/database/migrations/0049_texttoimagemodelconfig_api_key_and_more.py +58 -0
- khoj/database/migrations/0050_alter_processlock_name.py +25 -0
- khoj/database/migrations/0051_merge_20240702_1220.py +14 -0
- khoj/database/migrations/0052_alter_searchmodelconfig_bi_encoder_docs_encode_config_and_more.py +27 -0
- khoj/database/migrations/__init__.py +0 -0
- khoj/database/models/__init__.py +402 -0
- khoj/database/tests.py +3 -0
- khoj/interface/email/feedback.html +34 -0
- khoj/interface/email/magic_link.html +17 -0
- khoj/interface/email/task.html +40 -0
- khoj/interface/email/welcome.html +61 -0
- khoj/interface/web/404.html +56 -0
- khoj/interface/web/agent.html +312 -0
- khoj/interface/web/agents.html +276 -0
- khoj/interface/web/assets/icons/agents.svg +6 -0
- khoj/interface/web/assets/icons/automation.svg +37 -0
- khoj/interface/web/assets/icons/cancel.svg +3 -0
- khoj/interface/web/assets/icons/chat.svg +24 -0
- khoj/interface/web/assets/icons/collapse.svg +17 -0
- khoj/interface/web/assets/icons/computer.png +0 -0
- khoj/interface/web/assets/icons/confirm-icon.svg +1 -0
- khoj/interface/web/assets/icons/copy-button-success.svg +6 -0
- khoj/interface/web/assets/icons/copy-button.svg +5 -0
- khoj/interface/web/assets/icons/credit-card.png +0 -0
- khoj/interface/web/assets/icons/delete.svg +26 -0
- khoj/interface/web/assets/icons/docx.svg +7 -0
- khoj/interface/web/assets/icons/edit.svg +4 -0
- khoj/interface/web/assets/icons/favicon-128x128.ico +0 -0
- khoj/interface/web/assets/icons/favicon-128x128.png +0 -0
- khoj/interface/web/assets/icons/favicon-256x256.png +0 -0
- khoj/interface/web/assets/icons/favicon.icns +0 -0
- khoj/interface/web/assets/icons/github.svg +1 -0
- khoj/interface/web/assets/icons/key.svg +4 -0
- khoj/interface/web/assets/icons/khoj-logo-sideways-200.png +0 -0
- khoj/interface/web/assets/icons/khoj-logo-sideways-500.png +0 -0
- khoj/interface/web/assets/icons/khoj-logo-sideways.svg +5385 -0
- khoj/interface/web/assets/icons/logotype.svg +1 -0
- khoj/interface/web/assets/icons/markdown.svg +1 -0
- khoj/interface/web/assets/icons/new.svg +23 -0
- khoj/interface/web/assets/icons/notion.svg +4 -0
- khoj/interface/web/assets/icons/openai-logomark.svg +1 -0
- khoj/interface/web/assets/icons/org.svg +1 -0
- khoj/interface/web/assets/icons/pdf.svg +23 -0
- khoj/interface/web/assets/icons/pencil-edit.svg +5 -0
- khoj/interface/web/assets/icons/plaintext.svg +1 -0
- khoj/interface/web/assets/icons/question-mark-icon.svg +1 -0
- khoj/interface/web/assets/icons/search.svg +25 -0
- khoj/interface/web/assets/icons/send.svg +1 -0
- khoj/interface/web/assets/icons/share.svg +8 -0
- khoj/interface/web/assets/icons/speaker.svg +4 -0
- khoj/interface/web/assets/icons/stop-solid.svg +37 -0
- khoj/interface/web/assets/icons/sync.svg +4 -0
- khoj/interface/web/assets/icons/thumbs-down-svgrepo-com.svg +6 -0
- khoj/interface/web/assets/icons/thumbs-up-svgrepo-com.svg +6 -0
- khoj/interface/web/assets/icons/user-silhouette.svg +4 -0
- khoj/interface/web/assets/icons/voice.svg +8 -0
- khoj/interface/web/assets/icons/web.svg +2 -0
- khoj/interface/web/assets/icons/whatsapp.svg +17 -0
- khoj/interface/web/assets/khoj.css +237 -0
- khoj/interface/web/assets/markdown-it.min.js +8476 -0
- khoj/interface/web/assets/natural-cron.min.js +1 -0
- khoj/interface/web/assets/org.min.js +1823 -0
- khoj/interface/web/assets/pico.min.css +5 -0
- khoj/interface/web/assets/purify.min.js +3 -0
- khoj/interface/web/assets/samples/desktop-browse-draw-sample.png +0 -0
- khoj/interface/web/assets/samples/desktop-plain-chat-sample.png +0 -0
- khoj/interface/web/assets/samples/desktop-remember-plan-sample.png +0 -0
- khoj/interface/web/assets/samples/phone-browse-draw-sample.png +0 -0
- khoj/interface/web/assets/samples/phone-plain-chat-sample.png +0 -0
- khoj/interface/web/assets/samples/phone-remember-plan-sample.png +0 -0
- khoj/interface/web/assets/utils.js +33 -0
- khoj/interface/web/base_config.html +445 -0
- khoj/interface/web/chat.html +3546 -0
- khoj/interface/web/config.html +1011 -0
- khoj/interface/web/config_automation.html +1103 -0
- khoj/interface/web/content_source_computer_input.html +139 -0
- khoj/interface/web/content_source_github_input.html +216 -0
- khoj/interface/web/content_source_notion_input.html +94 -0
- khoj/interface/web/khoj.webmanifest +51 -0
- khoj/interface/web/login.html +219 -0
- khoj/interface/web/public_conversation.html +2006 -0
- khoj/interface/web/search.html +470 -0
- khoj/interface/web/utils.html +48 -0
- khoj/main.py +241 -0
- khoj/manage.py +22 -0
- khoj/migrations/__init__.py +0 -0
- khoj/migrations/migrate_offline_chat_default_model.py +69 -0
- khoj/migrations/migrate_offline_chat_default_model_2.py +71 -0
- khoj/migrations/migrate_offline_chat_schema.py +83 -0
- khoj/migrations/migrate_offline_model.py +29 -0
- khoj/migrations/migrate_processor_config_openai.py +67 -0
- khoj/migrations/migrate_server_pg.py +138 -0
- khoj/migrations/migrate_version.py +17 -0
- khoj/processor/__init__.py +0 -0
- khoj/processor/content/__init__.py +0 -0
- khoj/processor/content/docx/__init__.py +0 -0
- khoj/processor/content/docx/docx_to_entries.py +110 -0
- khoj/processor/content/github/__init__.py +0 -0
- khoj/processor/content/github/github_to_entries.py +224 -0
- khoj/processor/content/images/__init__.py +0 -0
- khoj/processor/content/images/image_to_entries.py +118 -0
- khoj/processor/content/markdown/__init__.py +0 -0
- khoj/processor/content/markdown/markdown_to_entries.py +165 -0
- khoj/processor/content/notion/notion_to_entries.py +260 -0
- khoj/processor/content/org_mode/__init__.py +0 -0
- khoj/processor/content/org_mode/org_to_entries.py +231 -0
- khoj/processor/content/org_mode/orgnode.py +532 -0
- khoj/processor/content/pdf/__init__.py +0 -0
- khoj/processor/content/pdf/pdf_to_entries.py +116 -0
- khoj/processor/content/plaintext/__init__.py +0 -0
- khoj/processor/content/plaintext/plaintext_to_entries.py +122 -0
- khoj/processor/content/text_to_entries.py +297 -0
- khoj/processor/conversation/__init__.py +0 -0
- khoj/processor/conversation/anthropic/__init__.py +0 -0
- khoj/processor/conversation/anthropic/anthropic_chat.py +206 -0
- khoj/processor/conversation/anthropic/utils.py +114 -0
- khoj/processor/conversation/offline/__init__.py +0 -0
- khoj/processor/conversation/offline/chat_model.py +231 -0
- khoj/processor/conversation/offline/utils.py +78 -0
- khoj/processor/conversation/offline/whisper.py +15 -0
- khoj/processor/conversation/openai/__init__.py +0 -0
- khoj/processor/conversation/openai/gpt.py +187 -0
- khoj/processor/conversation/openai/utils.py +129 -0
- khoj/processor/conversation/openai/whisper.py +13 -0
- khoj/processor/conversation/prompts.py +758 -0
- khoj/processor/conversation/utils.py +262 -0
- khoj/processor/embeddings.py +117 -0
- khoj/processor/speech/__init__.py +0 -0
- khoj/processor/speech/text_to_speech.py +51 -0
- khoj/processor/tools/__init__.py +0 -0
- khoj/processor/tools/online_search.py +225 -0
- khoj/routers/__init__.py +0 -0
- khoj/routers/api.py +626 -0
- khoj/routers/api_agents.py +43 -0
- khoj/routers/api_chat.py +1180 -0
- khoj/routers/api_config.py +434 -0
- khoj/routers/api_phone.py +86 -0
- khoj/routers/auth.py +181 -0
- khoj/routers/email.py +133 -0
- khoj/routers/helpers.py +1188 -0
- khoj/routers/indexer.py +349 -0
- khoj/routers/notion.py +91 -0
- khoj/routers/storage.py +35 -0
- khoj/routers/subscription.py +104 -0
- khoj/routers/twilio.py +36 -0
- khoj/routers/web_client.py +471 -0
- khoj/search_filter/__init__.py +0 -0
- khoj/search_filter/base_filter.py +15 -0
- khoj/search_filter/date_filter.py +217 -0
- khoj/search_filter/file_filter.py +30 -0
- khoj/search_filter/word_filter.py +29 -0
- khoj/search_type/__init__.py +0 -0
- khoj/search_type/text_search.py +241 -0
- khoj/utils/__init__.py +0 -0
- khoj/utils/cli.py +93 -0
- khoj/utils/config.py +81 -0
- khoj/utils/constants.py +24 -0
- khoj/utils/fs_syncer.py +249 -0
- khoj/utils/helpers.py +418 -0
- khoj/utils/initialization.py +146 -0
- khoj/utils/jsonl.py +43 -0
- khoj/utils/models.py +47 -0
- khoj/utils/rawconfig.py +160 -0
- khoj/utils/state.py +46 -0
- khoj/utils/yaml.py +43 -0
- khoj-1.16.1.dev15.dist-info/METADATA +178 -0
- khoj-1.16.1.dev15.dist-info/RECORD +242 -0
- khoj-1.16.1.dev15.dist-info/WHEEL +4 -0
- khoj-1.16.1.dev15.dist-info/entry_points.txt +2 -0
- khoj-1.16.1.dev15.dist-info/licenses/LICENSE +661 -0
khoj/routers/helpers.py
ADDED
|
@@ -0,0 +1,1188 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import base64
|
|
3
|
+
import hashlib
|
|
4
|
+
import io
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
import math
|
|
8
|
+
import re
|
|
9
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
10
|
+
from datetime import datetime, timedelta, timezone
|
|
11
|
+
from functools import partial
|
|
12
|
+
from random import random
|
|
13
|
+
from typing import (
|
|
14
|
+
Annotated,
|
|
15
|
+
Any,
|
|
16
|
+
Callable,
|
|
17
|
+
Dict,
|
|
18
|
+
Iterator,
|
|
19
|
+
List,
|
|
20
|
+
Optional,
|
|
21
|
+
Tuple,
|
|
22
|
+
Union,
|
|
23
|
+
)
|
|
24
|
+
from urllib.parse import parse_qs, urlencode, urljoin, urlparse
|
|
25
|
+
|
|
26
|
+
import cron_descriptor
|
|
27
|
+
import openai
|
|
28
|
+
import pytz
|
|
29
|
+
import requests
|
|
30
|
+
from apscheduler.job import Job
|
|
31
|
+
from apscheduler.triggers.cron import CronTrigger
|
|
32
|
+
from asgiref.sync import sync_to_async
|
|
33
|
+
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
|
34
|
+
from PIL import Image
|
|
35
|
+
from starlette.authentication import has_required_scope
|
|
36
|
+
from starlette.requests import URL
|
|
37
|
+
|
|
38
|
+
from khoj.database.adapters import (
|
|
39
|
+
AgentAdapters,
|
|
40
|
+
AutomationAdapters,
|
|
41
|
+
ConversationAdapters,
|
|
42
|
+
EntryAdapters,
|
|
43
|
+
create_khoj_token,
|
|
44
|
+
get_khoj_tokens,
|
|
45
|
+
run_with_process_lock,
|
|
46
|
+
)
|
|
47
|
+
from khoj.database.models import (
|
|
48
|
+
ChatModelOptions,
|
|
49
|
+
ClientApplication,
|
|
50
|
+
Conversation,
|
|
51
|
+
KhojUser,
|
|
52
|
+
ProcessLock,
|
|
53
|
+
Subscription,
|
|
54
|
+
TextToImageModelConfig,
|
|
55
|
+
UserRequests,
|
|
56
|
+
)
|
|
57
|
+
from khoj.processor.conversation import prompts
|
|
58
|
+
from khoj.processor.conversation.anthropic.anthropic_chat import (
|
|
59
|
+
anthropic_send_message_to_model,
|
|
60
|
+
converse_anthropic,
|
|
61
|
+
)
|
|
62
|
+
from khoj.processor.conversation.offline.chat_model import (
|
|
63
|
+
converse_offline,
|
|
64
|
+
send_message_to_model_offline,
|
|
65
|
+
)
|
|
66
|
+
from khoj.processor.conversation.openai.gpt import converse, send_message_to_model
|
|
67
|
+
from khoj.processor.conversation.utils import (
|
|
68
|
+
ThreadedGenerator,
|
|
69
|
+
generate_chatml_messages_with_context,
|
|
70
|
+
save_to_conversation_log,
|
|
71
|
+
)
|
|
72
|
+
from khoj.routers.email import is_resend_enabled, send_task_email
|
|
73
|
+
from khoj.routers.storage import upload_image
|
|
74
|
+
from khoj.utils import state
|
|
75
|
+
from khoj.utils.config import OfflineChatProcessorModel
|
|
76
|
+
from khoj.utils.helpers import (
|
|
77
|
+
ConversationCommand,
|
|
78
|
+
ImageIntentType,
|
|
79
|
+
is_none_or_empty,
|
|
80
|
+
is_valid_url,
|
|
81
|
+
log_telemetry,
|
|
82
|
+
mode_descriptions_for_llm,
|
|
83
|
+
timer,
|
|
84
|
+
tool_descriptions_for_llm,
|
|
85
|
+
)
|
|
86
|
+
from khoj.utils.rawconfig import LocationData
|
|
87
|
+
|
|
88
|
+
# Module-level logger for the chat-helper routines in this file
logger = logging.getLogger(__name__)

# Single-worker pool: serializes blocking chat response generation off the async event loop
executor = ThreadPoolExecutor(max_workers=1)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def is_query_empty(query: str) -> bool:
    """Return True when the query contains no non-whitespace content."""
    stripped_query = query.strip()
    return is_none_or_empty(stripped_query)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def validate_conversation_config():
    """Ensure a usable default chat model is configured; raise HTTP 500 otherwise."""
    config = ConversationAdapters.get_default_conversation_config()
    # Missing config, or an OpenAI-type model without its OpenAI settings, is unusable
    misconfigured = config is None or (config.model_type == "openai" and not config.openai_config)
    if misconfigured:
        raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
async def is_ready_to_chat(user: KhojUser):
    """Check the user has a usable chat backend, lazily loading the offline model on first use.

    Returns True when a chat model is ready; raises HTTP 500 when none is configured.
    """
    # Prefer the user's own conversation config, fall back to the server default
    config = await ConversationAdapters.aget_user_conversation_config(user)
    if not config:
        config = await ConversationAdapters.aget_default_conversation_config()

    if config:
        if config.model_type == "offline":
            # Load the local chat model once and cache it in server state
            if state.offline_chat_processor_config is None:
                logger.info("Loading Offline Chat Model...")
                state.offline_chat_processor_config = OfflineChatProcessorModel(
                    config.chat_model, config.max_prompt_size
                )
            return True
        if config.model_type in ("openai", "anthropic") and config.openai_config:
            return True

    raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.")
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def update_telemetry_state(
    request: Request,
    telemetry_type: str,
    api: str,
    client: Optional[str] = None,
    user_agent: Optional[str] = None,
    referer: Optional[str] = None,
    host: Optional[str] = None,
    metadata: Optional[dict] = None,
):
    """Append a telemetry event for this request to the in-memory telemetry queue.

    Anonymous requests are recorded with null user/subscription fields; extra
    key/values in `metadata` are merged into the event properties.
    """
    authenticated = request.user.is_authenticated
    user: KhojUser = request.user.object if authenticated else None
    client_app: ClientApplication = request.user.client_app if authenticated else None
    subscription: Subscription = user.subscription if user and hasattr(user, "subscription") else None

    user_state = {
        "client_host": request.client.host if request.client else None,
        "user_agent": user_agent or "unknown",
        "referer": referer or "unknown",
        "host": host or "unknown",
        "server_id": str(user.uuid) if user else None,
        "subscription_type": subscription.type if subscription else None,
        "is_recurring": subscription.is_recurring if subscription else None,
        "client_id": str(client_app.name) if client_app else "default",
    }
    if metadata:
        user_state.update(metadata)

    event = log_telemetry(
        telemetry_type=telemetry_type, api=api, client=client, app_config=state.config.app, properties=user_state
    )
    state.telemetry += [event]
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def get_next_url(request: Request) -> str:
    """Build an absolute "next" redirect url anchored to the current domain.

    The `next` query param is only honored when it is a relative path or an
    absolute url on this server's own domain; anything else falls back to "/".
    This keeps redirects on-site.
    """
    parsed_next = urlparse(request.query_params.get("next", "/"))
    is_relative = is_none_or_empty(parsed_next.scheme)
    is_same_domain = parsed_next.netloc == request.base_url.netloc
    next_path = parsed_next.path if (is_relative or is_same_domain) else "/"
    # Anchor the chosen path onto the current domain
    current_base = str(request.base_url).rstrip("/")
    return urljoin(current_base, next_path)
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
    """Render the last `n` khoj turns of the conversation as a plain-text transcript.

    Only messages sent by "khoj" with a text intent ("remember"/"reminder") or an
    image-generation intent are included; generated image content is redacted to
    save prompt space.

    Args:
        conversation_history: Conversation log dict with a "chat" list of turns.
        n: Number of most recent turns to consider.
        agent_name: Display name used for the agent's lines.

    Returns:
        The formatted transcript, or an empty string when no turns match.
    """
    chat_history = ""
    for chat in conversation_history.get("chat", [])[-n:]:
        if chat["by"] != "khoj":
            continue
        # Guard against a missing intent or a None intent type, which previously
        # raised TypeError on the `"text-to-image" in ...` membership test
        intent = chat.get("intent") or {}
        intent_type = intent.get("type") or ""
        if intent_type in ["remember", "reminder"]:
            chat_history += f"User: {intent['query']}\n"
            chat_history += f"{agent_name}: {chat['message']}\n"
        elif "text-to-image" in intent_type:
            chat_history += f"User: {intent['query']}\n"
            chat_history += f"{agent_name}: [generated image redacted for space]\n"
    return chat_history
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def get_conversation_command(query: str, any_references: bool = False) -> ConversationCommand:
    """Map a slash-command prefix in the query to its ConversationCommand.

    When no slash command is present, returns General if no relevant references
    were found for the query, otherwise Default.
    """
    prefix_to_command = [
        ("/notes", ConversationCommand.Notes),
        ("/help", ConversationCommand.Help),
        ("/general", ConversationCommand.General),
        ("/online", ConversationCommand.Online),
        ("/webpage", ConversationCommand.Webpage),
        ("/image", ConversationCommand.Image),
        ("/automated_task", ConversationCommand.AutomatedTask),
        ("/summarize", ConversationCommand.Summarize),
    ]
    for prefix, command in prefix_to_command:
        if query.startswith(prefix):
            return command
    # No explicit command and no relevant notes found for the given query
    return ConversationCommand.General if not any_references else ConversationCommand.Default
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
async def agenerate_chat_response(*args):
    """Run the blocking `generate_chat_response` on the shared single-worker executor."""
    event_loop = asyncio.get_event_loop()
    blocking_call = partial(generate_chat_response, *args)
    return await event_loop.run_in_executor(executor, blocking_call)
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
async def acreate_title_from_query(query: str) -> str:
    """Ask the chat model to generate a conversation title for the given query."""
    prompt = prompts.subject_generation.format(query=query)
    with timer("Chat actor: Generate title from query", logger):
        raw_title = await send_message_to_model_wrapper(prompt)
    return raw_title.strip()
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
async def aget_relevant_information_sources(query: str, conversation_history: dict, is_task: bool):
    """Infer which information sources (tools) the agent should use to answer the query.

    Asks the chat model to pick from `tool_descriptions_for_llm` and converts its
    JSON answer into ConversationCommand values.

    Returns:
        List of ConversationCommand tools; falls back to [ConversationCommand.Default]
        when the model response cannot be parsed.
    """
    tool_options = dict()
    tool_options_str = ""

    for tool, description in tool_descriptions_for_llm.items():
        tool_options[tool.value] = description
        tool_options_str += f'- "{tool.value}": "{description}"\n'

    chat_history = construct_chat_history(conversation_history)

    relevant_tools_prompt = prompts.pick_relevant_information_collection_tools.format(
        query=query,
        tools=tool_options_str,
        chat_history=chat_history,
    )

    with timer("Chat actor: Infer information sources to refer", logger):
        response = await send_message_to_model_wrapper(relevant_tools_prompt, response_type="json_object")

    try:
        response = response.strip()
        response = json.loads(response)
        response = [q.strip() for q in response["source"] if q.strip()]
        # The comprehension above always yields a list, so only emptiness needs checking
        if not response:
            logger.error(f"Invalid response for determining relevant tools: {response}")
            # NOTE(review): this path returns the tool-description dict while every
            # other path returns a list of ConversationCommand — confirm callers cope
            return tool_options

        # Automated tasks always carry the AutomatedTask marker command
        final_response = [] if not is_task else [ConversationCommand.AutomatedTask]
        for llm_suggested_tool in response:
            # Only accept suggestions that map onto a valid ConversationCommand
            if llm_suggested_tool in tool_options:
                final_response.append(ConversationCommand(llm_suggested_tool))

        if is_none_or_empty(final_response):
            final_response = [ConversationCommand.Default]
        return final_response
    except Exception:
        logger.error(f"Invalid response for determining relevant tools: {response}")
        return [ConversationCommand.Default]
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
async def aget_relevant_output_modes(query: str, conversation_history: dict, is_task: bool = False):
    """Infer which output mode (from `mode_descriptions_for_llm`) best answers the query.

    Falls back to ConversationCommand.Text when the model's answer is empty,
    unknown, or unparsable.
    """
    mode_options = dict()
    mode_options_str = ""

    for mode, description in mode_descriptions_for_llm.items():
        # Automations must not be allowed to schedule further automations
        if is_task and mode == ConversationCommand.Automation:
            continue
        mode_options[mode.value] = description
        mode_options_str += f'- "{mode.value}": "{description}"\n'

    chat_history = construct_chat_history(conversation_history)

    relevant_mode_prompt = prompts.pick_relevant_output_mode.format(
        query=query,
        modes=mode_options_str,
        chat_history=chat_history,
    )

    with timer("Chat actor: Infer output mode for chat response", logger):
        response = await send_message_to_model_wrapper(relevant_mode_prompt)

    try:
        response = response.strip()
        # Accept the answer only when it names one of the offered modes
        if not is_none_or_empty(response) and response in mode_options:
            return ConversationCommand(response)
        return ConversationCommand.Text
    except Exception:
        logger.error(f"Invalid response for determining relevant mode: {response}")
        return ConversationCommand.Text
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
async def infer_webpage_urls(q: str, conversation_history: dict, location_data: LocationData) -> List[str]:
    """Infer which webpage urls to read to answer the given query.

    Args:
        q: The user query.
        conversation_history: Conversation log used for prompt context.
        location_data: Optional user location woven into the prompt.

    Returns:
        De-duplicated list of valid urls suggested by the chat model.

    Raises:
        ValueError: When the model response is not a non-empty list of valid urls.
    """
    location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
    chat_history = construct_chat_history(conversation_history)

    # Timezone-aware now() for consistency with other prompt builders in this file;
    # datetime.utcnow() is naive and deprecated
    utc_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
    online_queries_prompt = prompts.infer_webpages_to_read.format(
        current_date=utc_date,
        query=q,
        chat_history=chat_history,
        location=location,
    )

    with timer("Chat actor: Infer webpage urls to read", logger):
        response = await send_message_to_model_wrapper(online_queries_prompt, response_type="json_object")

    # Validate that the response is a non-empty, JSON-serializable list of urls
    try:
        response = response.strip()
        urls = json.loads(response)
        valid_unique_urls = {str(url).strip() for url in urls["links"] if is_valid_url(url)}
    except Exception as e:
        # Chain the parse failure so the root cause stays visible in tracebacks
        raise ValueError(f"Invalid list of urls: {response}") from e
    if is_none_or_empty(valid_unique_urls):
        raise ValueError(f"Invalid list of urls: {response}")
    return list(valid_unique_urls)
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
async def generate_online_subqueries(q: str, conversation_history: dict, location_data: LocationData) -> List[str]:
    """Generate online search subqueries for the given query.

    Falls back to the original query when the model response cannot be parsed
    into a non-empty list of subqueries.

    Args:
        q: The user query.
        conversation_history: Conversation log used for prompt context.
        location_data: Optional user location woven into the prompt.

    Returns:
        List of search subqueries, or [q] on parse failure.
    """
    location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
    chat_history = construct_chat_history(conversation_history)

    # Timezone-aware now() for consistency with other prompt builders in this file;
    # datetime.utcnow() is naive and deprecated
    utc_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
    online_queries_prompt = prompts.online_search_conversation_subqueries.format(
        current_date=utc_date,
        query=q,
        chat_history=chat_history,
        location=location,
    )

    with timer("Chat actor: Generate online search subqueries", logger):
        response = await send_message_to_model_wrapper(online_queries_prompt, response_type="json_object")

    # Validate that the response is a non-empty, JSON-serializable list
    try:
        response = response.strip()
        subqueries = [sq.strip() for sq in json.loads(response)["queries"] if sq.strip()]
        # The comprehension always yields a list, so only emptiness needs checking
        if not subqueries:
            logger.error(f"Invalid response for constructing subqueries: {subqueries}. Returning original query: {q}")
            return [q]
        return subqueries
    except Exception:
        logger.error(f"Invalid response for constructing subqueries: {response}. Returning original query: {q}")
        return [q]
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
async def schedule_query(q: str, conversation_history: dict) -> Tuple[str, ...]:
    """Extract (crontime, query, subject) for scheduling the query; server timezone is assumed UTC.

    Raises:
        AssertionError: When the model response is not a dict of exactly those three fields.
    """
    chat_history = construct_chat_history(conversation_history)
    crontime_prompt = prompts.crontime_prompt.format(query=q, chat_history=chat_history)

    raw_response = await send_message_to_model_wrapper(crontime_prompt, response_type="json_object")

    # Validate the response is a JSON object with exactly the three expected fields
    try:
        raw_response = raw_response.strip()
        parsed: Dict[str, str] = json.loads(raw_response)
        if parsed and isinstance(parsed, Dict) and len(parsed) == 3:
            return parsed.get("crontime"), parsed.get("query"), parsed.get("subject")
        # Shape mismatch: re-raised below with the raw response attached
        raise AssertionError(f"Invalid response for scheduling query : {parsed}")
    except Exception:
        raise AssertionError(f"Invalid response for scheduling query: {raw_response}")
|
|
400
|
+
|
|
401
|
+
|
|
402
|
+
async def extract_relevant_info(q: str, corpus: str) -> Union[str, None]:
    """Distill the parts of the given corpus relevant to the query via the summarizer model."""
    # Nothing to extract from an empty corpus or query
    if is_none_or_empty(corpus) or is_none_or_empty(q):
        return None

    extraction_prompt = prompts.extract_relevant_information.format(query=q, corpus=corpus.strip())
    summarizer_model: ChatModelOptions = await ConversationAdapters.aget_summarizer_conversation_config()

    with timer("Chat actor: Extract relevant information from data", logger):
        response = await send_message_to_model_wrapper(
            extraction_prompt,
            prompts.system_prompt_extract_relevant_information,
            chat_model_option=summarizer_model,
        )
    return response.strip()
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
async def extract_relevant_summary(q: str, corpus: str) -> Union[str, None]:
    """Summarize the parts of `corpus` that are relevant to the query `q`.

    Mirrors extract_relevant_info but uses the summary-oriented prompts.
    Returns the stripped model response, or None for empty inputs.
    """
    if is_none_or_empty(q) or is_none_or_empty(corpus):
        return None

    # Build the summarization prompt from the query and the trimmed corpus
    summary_prompt = prompts.extract_relevant_summary.format(query=q, corpus=corpus.strip())

    # Use the dedicated summarizer chat model for this actor
    summarizer_model: ChatModelOptions = await ConversationAdapters.aget_summarizer_conversation_config()

    with timer("Chat actor: Extract relevant information from data", logger):
        response = await send_message_to_model_wrapper(
            summary_prompt,
            prompts.system_prompt_extract_relevant_summary,
            chat_model_option=summarizer_model,
        )

    return response.strip()
|
|
448
|
+
|
|
449
|
+
|
|
450
|
+
async def generate_better_image_prompt(
    q: str,
    conversation_history: str,
    location_data: LocationData,
    note_references: List[Dict[str, Any]],
    online_results: Optional[dict] = None,
    model_type: Optional[str] = None,
) -> str:
    """
    Generate a better image prompt from the given query.

    Enriches the user's query with chat history, location, note references and
    simplified online results, then asks the summarizer model to produce an
    improved prompt tailored to the target image model family.

    Raises:
        ValueError: if model_type is neither the OpenAI nor StabilityAI model type.
        (Previously an unknown model type left `image_prompt` unbound and crashed
        with UnboundLocalError further down.)
    """
    today_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d, %A")
    # Default to the OpenAI prompt style when no model type is specified
    model_type = model_type or TextToImageModelConfig.ModelType.OPENAI

    if location_data:
        location = f"{location_data.city}, {location_data.region}, {location_data.country}"
        location_prompt = prompts.user_location.format(location=location)
    else:
        location_prompt = "Unknown"

    # Render each referenced note as a markdown heading block
    user_references = "\n\n".join([f"# {item['compiled']}" for item in note_references])

    # Keep only the most useful part of each online result (answer box or webpages)
    simplified_online_results = {}
    if online_results:
        for result, result_data in online_results.items():
            if result_data.get("answerBox"):
                simplified_online_results[result] = result_data["answerBox"]
            elif result_data.get("webpages"):
                simplified_online_results[result] = result_data["webpages"]

    if model_type == TextToImageModelConfig.ModelType.OPENAI:
        image_prompt = prompts.image_generation_improve_prompt_dalle.format(
            query=q,
            chat_history=conversation_history,
            location=location_prompt,
            current_date=today_date,
            references=user_references,
            online_results=simplified_online_results,
        )
    elif model_type == TextToImageModelConfig.ModelType.STABILITYAI:
        image_prompt = prompts.image_generation_improve_prompt_sd.format(
            query=q,
            chat_history=conversation_history,
            location=location_prompt,
            current_date=today_date,
            references=user_references,
            online_results=simplified_online_results,
        )
    else:
        # Guard against unsupported model types instead of crashing with UnboundLocalError
        raise ValueError(f"Unsupported text to image model type: {model_type}")

    summarizer_model: ChatModelOptions = await ConversationAdapters.aget_summarizer_conversation_config()

    with timer("Chat actor: Generate contextual image prompt", logger):
        response = await send_message_to_model_wrapper(image_prompt, chat_model_option=summarizer_model)
        response = response.strip()
        # Strip a single pair of wrapping quotes the model sometimes adds
        if response.startswith(('"', "'")) and response.endswith(('"', "'")):
            response = response[1:-1]

    return response
|
|
510
|
+
|
|
511
|
+
|
|
512
|
+
async def send_message_to_model_wrapper(
    message: str,
    system_message: str = "",
    response_type: str = "text",
    chat_model_option: ChatModelOptions = None,
):
    """
    Send a single message to the configured chat model and return its response.

    Dispatches on the model type of `chat_model_option` (or the server's
    default chat model when none is given): "offline", "openai" or "anthropic".
    Each branch truncates the message to the model's prompt window before sending.

    Raises:
        HTTPException(500): if no chat model is configured or the model type is unknown.
    """
    # Fall back to the server-wide default chat model when no option is passed
    conversation_config: ChatModelOptions = (
        chat_model_option or await ConversationAdapters.aget_default_conversation_config()
    )

    if conversation_config is None:
        raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")

    chat_model = conversation_config.chat_model
    max_tokens = conversation_config.max_prompt_size
    tokenizer = conversation_config.tokenizer

    if conversation_config.model_type == "offline":
        # Lazily load the local model into process-wide state on first use
        if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
            state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)

        loaded_model = state.offline_chat_processor_config.loaded_model
        truncated_messages = generate_chatml_messages_with_context(
            user_message=message,
            system_message=system_message,
            model_name=chat_model,
            loaded_model=loaded_model,
            tokenizer_name=tokenizer,
            max_prompt_size=max_tokens,
        )

        return send_message_to_model_offline(
            messages=truncated_messages,
            loaded_model=loaded_model,
            model=chat_model,
            streaming=False,
        )

    elif conversation_config.model_type == "openai":
        openai_chat_config = conversation_config.openai_config
        api_key = openai_chat_config.api_key
        api_base_url = openai_chat_config.api_base_url
        truncated_messages = generate_chatml_messages_with_context(
            user_message=message,
            system_message=system_message,
            model_name=chat_model,
            max_prompt_size=max_tokens,
            tokenizer_name=tokenizer,
        )

        openai_response = send_message_to_model(
            messages=truncated_messages,
            api_key=api_key,
            model=chat_model,
            response_type=response_type,
            api_base_url=api_base_url,
        )

        return openai_response
    elif conversation_config.model_type == "anthropic":
        # NOTE(review): Anthropic credentials are read from the openai_config
        # relation — presumably reused as a generic API-key holder; verify.
        api_key = conversation_config.openai_config.api_key
        truncated_messages = generate_chatml_messages_with_context(
            user_message=message,
            system_message=system_message,
            model_name=chat_model,
            max_prompt_size=max_tokens,
            tokenizer_name=tokenizer,
        )

        return anthropic_send_message_to_model(
            messages=truncated_messages,
            api_key=api_key,
            model=chat_model,
        )
    else:
        raise HTTPException(status_code=500, detail="Invalid conversation config")
|
|
588
|
+
|
|
589
|
+
|
|
590
|
+
def send_message_to_model_wrapper_sync(
    message: str,
    system_message: str = "",
    response_type: str = "text",
):
    """
    Synchronous twin of send_message_to_model_wrapper.

    Sends a single message to the server's default chat model and returns its
    response. Dispatches on model type: "offline", "openai" or "anthropic".

    Fixes over the previous revision (for consistency with the async wrapper):
    - offline and openai branches now pass max_prompt_size so messages are
      truncated to the model's prompt window,
    - the openai branch now honors the configured custom api_base_url.

    Raises:
        HTTPException(500): if no default chat model is set or the model type is unknown.
    """
    conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config()

    if conversation_config is None:
        raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")

    chat_model = conversation_config.chat_model
    max_tokens = conversation_config.max_prompt_size

    if conversation_config.model_type == "offline":
        # Lazily load the local model into process-wide state on first use
        if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
            state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)

        loaded_model = state.offline_chat_processor_config.loaded_model
        truncated_messages = generate_chatml_messages_with_context(
            user_message=message,
            system_message=system_message,
            model_name=chat_model,
            loaded_model=loaded_model,
            max_prompt_size=max_tokens,
        )

        return send_message_to_model_offline(
            messages=truncated_messages,
            loaded_model=loaded_model,
            model=chat_model,
            streaming=False,
        )

    elif conversation_config.model_type == "openai":
        openai_chat_config = conversation_config.openai_config
        api_key = openai_chat_config.api_key
        truncated_messages = generate_chatml_messages_with_context(
            user_message=message,
            system_message=system_message,
            model_name=chat_model,
            max_prompt_size=max_tokens,
        )

        openai_response = send_message_to_model(
            messages=truncated_messages,
            api_key=api_key,
            model=chat_model,
            response_type=response_type,
            api_base_url=openai_chat_config.api_base_url,
        )

        return openai_response

    elif conversation_config.model_type == "anthropic":
        # NOTE(review): Anthropic credentials are read from the openai_config
        # relation — presumably reused as a generic API-key holder; verify.
        api_key = conversation_config.openai_config.api_key
        truncated_messages = generate_chatml_messages_with_context(
            user_message=message,
            system_message=system_message,
            model_name=chat_model,
            max_prompt_size=max_tokens,
        )

        return anthropic_send_message_to_model(
            messages=truncated_messages,
            api_key=api_key,
            model=chat_model,
        )
    else:
        raise HTTPException(status_code=500, detail="Invalid conversation config")
|
|
647
|
+
|
|
648
|
+
|
|
649
|
+
def generate_chat_response(
    q: str,
    meta_log: dict,
    conversation: Conversation,
    compiled_references: List[Dict] = [],
    online_results: Dict[str, Dict] = {},
    inferred_queries: List[str] = [],
    conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
    user: KhojUser = None,
    client_application: ClientApplication = None,
    conversation_id: int = None,
    location_data: LocationData = None,
    user_name: Optional[str] = None,
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
    """
    Produce a (possibly streaming) chat response for query `q`.

    Dispatches to the offline, OpenAI or Anthropic converse implementation
    based on the conversation's valid chat model config, wiring in references,
    online results and a completion callback that persists the exchange to the
    conversation log. Returns the response generator/iterator and a metadata
    dict containing the chat model used.

    NOTE(review): the mutable default arguments ([] / {}) are shared across
    calls; safe only if no callee mutates them — confirm before relying on it.

    Raises:
        HTTPException(500): wrapping any error raised while generating the response.
    """
    # Initialize Variables
    chat_response = None
    logger.debug(f"Conversation Types: {conversation_commands}")

    metadata = {}
    agent = AgentAdapters.get_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None

    try:
        # Callback invoked on completion to persist the exchange to the conversation log
        partial_completion = partial(
            save_to_conversation_log,
            q,
            user=user,
            meta_log=meta_log,
            compiled_references=compiled_references,
            online_results=online_results,
            inferred_queries=inferred_queries,
            client_application=client_application,
            conversation_id=conversation_id,
        )

        conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
        if conversation_config.model_type == "offline":
            loaded_model = state.offline_chat_processor_config.loaded_model
            chat_response = converse_offline(
                references=compiled_references,
                online_results=online_results,
                user_query=q,
                loaded_model=loaded_model,
                conversation_log=meta_log,
                completion_func=partial_completion,
                conversation_commands=conversation_commands,
                model=conversation_config.chat_model,
                max_prompt_size=conversation_config.max_prompt_size,
                tokenizer_name=conversation_config.tokenizer,
                location_data=location_data,
                user_name=user_name,
                agent=agent,
            )

        elif conversation_config.model_type == "openai":
            openai_chat_config = conversation_config.openai_config
            api_key = openai_chat_config.api_key
            chat_model = conversation_config.chat_model
            chat_response = converse(
                compiled_references,
                q,
                online_results=online_results,
                conversation_log=meta_log,
                model=chat_model,
                api_key=api_key,
                api_base_url=openai_chat_config.api_base_url,
                completion_func=partial_completion,
                conversation_commands=conversation_commands,
                max_prompt_size=conversation_config.max_prompt_size,
                tokenizer_name=conversation_config.tokenizer,
                location_data=location_data,
                user_name=user_name,
                agent=agent,
            )

        elif conversation_config.model_type == "anthropic":
            # NOTE(review): Anthropic key read from openai_config — presumably a generic key holder; verify.
            api_key = conversation_config.openai_config.api_key
            chat_response = converse_anthropic(
                compiled_references,
                q,
                online_results,
                meta_log,
                model=conversation_config.chat_model,
                api_key=api_key,
                completion_func=partial_completion,
                conversation_commands=conversation_commands,
                max_prompt_size=conversation_config.max_prompt_size,
                tokenizer_name=conversation_config.tokenizer,
                location_data=location_data,
                user_name=user_name,
                agent=agent,
            )

        metadata.update({"chat_model": conversation_config.chat_model})

    except Exception as e:
        logger.error(e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))

    return chat_response, metadata
|
|
748
|
+
|
|
749
|
+
|
|
750
|
+
async def text_to_image(
    message: str,
    user: KhojUser,
    conversation_log: dict,
    location_data: LocationData,
    references: List[Dict[str, Any]],
    online_results: Dict[str, Any],
    send_status_func: Optional[Callable] = None,
) -> Tuple[Optional[str], int, Optional[str], str]:
    """
    Generate an image for `message` using the user's configured text-to-image model.

    Enhances the prompt with recent chat history and context, generates the image
    via OpenAI or Stability AI, converts it to webp and uploads it to S3 when
    possible. Returns (image_url_or_b64_image, status_code, message_or_prompt,
    intent_type) — on failure the third element is an error message instead of
    the improved prompt.

    Fixes over the previous revision:
    - the OpenAI except clause used `A or B or C`, which catches only the first
      class; now a proper exception tuple,
    - `api_key` could be unbound if no credential source matched,
    - the Stability AI error path read `e.status_code`, an attribute
      requests.RequestException does not have (AttributeError masking the real error).
    """
    status_code = 200
    image = None
    response = None
    image_url = None
    intent_type = ImageIntentType.TEXT_TO_IMAGE_V3

    text_to_image_config = await ConversationAdapters.aget_user_text_to_image_model(user)
    if not text_to_image_config:
        # If the user has not configured a text to image model, return an unsupported on server error
        status_code = 501
        message = "Failed to generate image. Setup image generation on the server."
        return image_url or image, status_code, message, intent_type.value

    text2image_model = text_to_image_config.model_name
    # Collect the last few turns of relevant chat history to contextualize the prompt
    chat_history = ""
    for chat in conversation_log.get("chat", [])[-4:]:
        if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder"]:
            chat_history += f"Q: {chat['intent']['query']}\n"
            chat_history += f"A: {chat['message']}\n"
        elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"):
            chat_history += f"Q: Prompt: {chat['intent']['query']}\n"
            chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n"

    with timer("Improve the original user query", logger):
        if send_status_func:
            await send_status_func("**✍🏽 Enhancing the Painting Prompt**")
        improved_image_prompt = await generate_better_image_prompt(
            message,
            chat_history,
            location_data=location_data,
            note_references=references,
            online_results=online_results,
            model_type=text_to_image_config.model_type,
        )

    if send_status_func:
        await send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}")

    if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
        with timer("Generate image with OpenAI", logger):
            # Resolve the API key from the most specific credential source available
            api_key = None
            if text_to_image_config.api_key:
                api_key = text_to_image_config.api_key
            elif text_to_image_config.openai_config:
                api_key = text_to_image_config.openai_config.api_key
            elif state.openai_client:
                api_key = state.openai_client.api_key
            auth_header = {"Authorization": f"Bearer {api_key}"} if api_key else {}
            try:
                response = state.openai_client.images.generate(
                    prompt=improved_image_prompt,
                    model=text2image_model,
                    response_format="b64_json",
                    extra_headers=auth_header,
                )
                image = response.data[0].b64_json
                decoded_image = base64.b64decode(image)
            except (openai.OpenAIError, openai.BadRequestError, openai.APIConnectionError) as e:
                # APIConnectionError has no .message attribute; fall back to str(e)
                error_detail = getattr(e, "message", str(e))
                status_code = getattr(e, "status_code", 500)
                if "content_policy_violation" in error_detail:
                    logger.error(f"Image Generation blocked by OpenAI: {e}")
                    message = f"Image generation blocked by OpenAI: {error_detail}"
                else:
                    logger.error(f"Image Generation failed with {e}", exc_info=True)
                    message = f"Image generation failed with OpenAI error: {error_detail}"
                return image_url or image, status_code, message, intent_type.value

    elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI:
        with timer("Generate image with Stability AI", logger):
            try:
                response = requests.post(
                    f"https://api.stability.ai/v2beta/stable-image/generate/sd3",
                    headers={"authorization": f"Bearer {text_to_image_config.api_key}", "accept": "image/*"},
                    files={"none": ""},
                    data={
                        "prompt": improved_image_prompt,
                        "model": text2image_model,
                        "mode": "text-to-image",
                        "output_format": "png",
                        "aspect_ratio": "1:1",
                    },
                )
                decoded_image = response.content
            except requests.RequestException as e:
                logger.error(f"Image Generation failed with {e}", exc_info=True)
                message = f"Image generation failed with Stability AI error: {e}"
                # RequestException has no .status_code; read it off the response when present
                status_code = e.response.status_code if e.response is not None else 500
                return image_url or image, status_code, message, intent_type.value

    with timer("Convert image to webp", logger):
        # Convert png to webp for faster loading
        image_io = io.BytesIO(decoded_image)
        png_image = Image.open(image_io)
        webp_image_io = io.BytesIO()
        png_image.save(webp_image_io, "WEBP")
        webp_image_bytes = webp_image_io.getvalue()
        webp_image_io.close()
        image_io.close()

    with timer("Upload image to S3", logger):
        image_url = upload_image(webp_image_bytes, user.uuid)
    if image_url:
        intent_type = ImageIntentType.TEXT_TO_IMAGE2
    else:
        # Fall back to inlining the image as base64 when upload is unavailable
        intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
        image = base64.b64encode(webp_image_bytes).decode("utf-8")

    return image_url or image, status_code, improved_image_prompt, intent_type.value
|
|
868
|
+
|
|
869
|
+
|
|
870
|
+
class ApiUserRateLimiter:
    """Per-user sliding-window rate limiter for API endpoints.

    Counts a user's requests for a given slug within the last `window`
    seconds and rejects the call with HTTP 429 once the applicable limit
    (free vs. premium-subscribed) is reached. A no-op when billing is
    disabled or the request is unauthenticated.
    """

    def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str):
        self.requests = requests
        self.subscribed_requests = subscribed_requests
        self.window = window
        self.slug = slug

    def __call__(self, request: Request):
        # Rate limiting disabled if billing is disabled
        if state.billing_enabled is False:
            return

        # Rate limiting is disabled if user unauthenticated.
        # Other systems handle authentication
        if not request.user.is_authenticated:
            return

        user: KhojUser = request.user.object
        subscribed = has_required_scope(request, ["premium"])

        # Count only the requests that fall inside the sliding window
        window_start = datetime.now(tz=timezone.utc) - timedelta(seconds=self.window)
        request_count = UserRequests.objects.filter(user=user, created_at__gte=window_start, slug=self.slug).count()

        # Reject once the limit for the user's tier has been reached
        if subscribed and request_count >= self.subscribed_requests:
            raise HTTPException(status_code=429, detail="Slow down! Too Many Requests")
        if not subscribed and request_count >= self.requests:
            # When the free limit is no lower than the subscribed one, upselling makes no sense
            if self.requests >= self.subscribed_requests:
                raise HTTPException(
                    status_code=429,
                    detail="Slow down! Too Many Requests",
                )
            raise HTTPException(
                status_code=429,
                detail="We're glad you're enjoying Khoj! You've exceeded your usage limit for today. Come back tomorrow or subscribe to increase your usage limit via [your settings](https://app.khoj.dev/config).",
            )

        # Record this request so it counts against future windows
        UserRequests.objects.create(user=user, slug=self.slug)
|
|
910
|
+
|
|
911
|
+
|
|
912
|
+
class ConversationCommandRateLimiter:
    """
    24-hour rate limiter for restricted conversation commands (online, image).

    Tracks per-user, per-command usage via UserRequests rows and rejects with
    HTTP 429 once the trial or subscribed daily limit is reached. A no-op when
    billing is disabled, the request is unauthenticated, or the command is not
    in the restricted set.
    """

    def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str):
        self.slug = slug
        self.trial_rate_limit = trial_rate_limit
        self.subscribed_rate_limit = subscribed_rate_limit
        # Only these commands are subject to rate limiting
        self.restricted_commands = [ConversationCommand.Online, ConversationCommand.Image]

    async def update_and_check_if_valid(self, request: Request, conversation_command: ConversationCommand):
        """Record this command invocation and raise HTTP 429 if over the daily limit."""
        if state.billing_enabled is False:
            return

        if not request.user.is_authenticated:
            return

        if conversation_command not in self.restricted_commands:
            return

        user: KhojUser = request.user.object
        subscribed = has_required_scope(request, ["premium"])

        # Remove requests outside of the 24-hr time window
        cutoff = datetime.now(tz=timezone.utc) - timedelta(seconds=60 * 60 * 24)
        # Usage is tracked per command under a composite slug
        command_slug = f"{self.slug}_{conversation_command.value}"
        count_requests = await UserRequests.objects.filter(
            user=user, created_at__gte=cutoff, slug=command_slug
        ).acount()

        if subscribed and count_requests >= self.subscribed_rate_limit:
            raise HTTPException(status_code=429, detail="Slow down! Too Many Requests")
        if not subscribed and count_requests >= self.trial_rate_limit:
            raise HTTPException(
                status_code=429,
                detail=f"We're glad you're enjoying Khoj! You've exceeded your `/{conversation_command.value}` command usage limit for today. Subscribe to increase your usage limit via [your settings](https://app.khoj.dev/config).",
            )
        # Record this invocation so it counts against future checks
        await UserRequests.objects.acreate(user=user, slug=command_slug)
        return
|
|
948
|
+
|
|
949
|
+
|
|
950
|
+
class ApiIndexedDataLimiter:
    """
    Limits the amount of data (in MB) a user may index, per upload and in total.

    Also deletes previously indexed entries for any uploaded file with size 0 —
    an empty upload is treated as a deletion request for that file. A no-op when
    billing is disabled or the request is unauthenticated.
    """

    def __init__(
        self,
        incoming_entries_size_limit: float,
        subscribed_incoming_entries_size_limit: float,
        total_entries_size_limit: float,
        subscribed_total_entries_size_limit: float,
    ):
        # Per-upload size limits (MB), free vs subscribed
        self.num_entries_size = incoming_entries_size_limit
        self.subscribed_num_entries_size = subscribed_incoming_entries_size_limit
        # Total indexed data limits (MB), free vs subscribed
        self.total_entries_size_limit = total_entries_size_limit
        self.subscribed_total_entries_size = subscribed_total_entries_size_limit

    def __call__(self, request: Request, files: List[UploadFile]):
        if state.billing_enabled is False:
            return
        subscribed = has_required_scope(request, ["premium"])
        incoming_data_size_mb = 0.0
        deletion_file_names = set()

        if not request.user.is_authenticated:
            return

        user: KhojUser = request.user.object

        # Empty files mark their previously indexed entries for deletion
        for file in files:
            if file.size == 0:
                deletion_file_names.add(file.filename)

            incoming_data_size_mb += file.size / 1024 / 1024

        # Process deletions before enforcing the size limits
        num_deleted_entries = 0
        for file_path in deletion_file_names:
            deleted_count = EntryAdapters.delete_entry_by_file(user, file_path)
            num_deleted_entries += deleted_count

        logger.info(f"Deleted {num_deleted_entries} entries for user: {user}.")

        # Enforce the per-upload size limit for the user's tier
        if subscribed and incoming_data_size_mb >= self.subscribed_num_entries_size:
            raise HTTPException(status_code=429, detail="Too much data indexed.")
        if not subscribed and incoming_data_size_mb >= self.num_entries_size:
            raise HTTPException(
                status_code=429, detail="Too much data indexed. Subscribe to increase your data index limit."
            )

        # Enforce the total indexed data limit including this upload
        user_size_data = EntryAdapters.get_size_of_indexed_data_in_mb(user)
        if subscribed and user_size_data + incoming_data_size_mb >= self.subscribed_total_entries_size:
            raise HTTPException(status_code=429, detail="Too much data indexed.")
        if not subscribed and user_size_data + incoming_data_size_mb >= self.total_entries_size_limit:
            raise HTTPException(
                status_code=429, detail="Too much data indexed. Subscribe to increase your data index limit."
            )
|
|
1002
|
+
|
|
1003
|
+
|
|
1004
|
+
class CommonQueryParamsClass:
    """Common optional request metadata captured from query params and headers.

    Holds the client identifier (query param) and the User-Agent, Referer and
    Host headers so endpoints can receive them via a single dependency.
    """

    def __init__(
        self,
        client: Optional[str] = None,
        user_agent: Optional[str] = Header(None),
        referer: Optional[str] = Header(None),
        host: Optional[str] = Header(None),
    ):
        # Stash each parameter on the instance unchanged
        self.client, self.user_agent, self.referer, self.host = client, user_agent, referer, host
|
|
1016
|
+
|
|
1017
|
+
|
|
1018
|
+
# FastAPI dependency alias: annotate an endpoint parameter with CommonQueryParams
# to have CommonQueryParamsClass injected via Depends().
CommonQueryParams = Annotated[CommonQueryParamsClass, Depends()]
|
|
1019
|
+
|
|
1020
|
+
|
|
1021
|
+
def should_notify(original_query: str, executed_query: str, ai_response: str) -> bool:
    """
    Decide whether to notify the user of the AI response.
    Default to notifying the user for now.

    Asks the chat model whether the automation response is worth sending; any
    failure in that call falls back to notifying. Returns False only when any
    of the inputs is empty or the model explicitly answers "no".
    """
    if any(is_none_or_empty(message) for message in [original_query, executed_query, ai_response]):
        return False

    to_notify_or_not = prompts.to_notify_or_not.format(
        original_query=original_query,
        executed_query=executed_query,
        response=ai_response,
    )

    with timer("Chat actor: Decide to notify user of automation response", logger):
        try:
            response = send_message_to_model_wrapper_sync(to_notify_or_not)
            # Any response not containing "no" is treated as a yes
            should_notify_result = "no" not in response.lower()
            logger.info(f'Decided to {"not " if not should_notify_result else ""}notify user of automation response.')
            return should_notify_result
        # Narrowed from a bare `except:` which also swallowed SystemExit/KeyboardInterrupt
        except Exception:
            logger.warning(f"Fallback to notify user of automation response as failed to infer should notify or not.")
            return True
|
|
1044
|
+
|
|
1045
|
+
|
|
1046
|
+
def scheduled_chat(
    query_to_run: str, scheduling_request: str, subject: str, user: KhojUser, calling_url: URL, job_id: str = None
):
    """
    Run a scheduled automation query against the chat API on the user's behalf.

    Rebuilds the original request URL with the scheduled query, calls the chat
    endpoint with the user's token (in non-anonymous mode), and emails the
    result to the user when it is deemed notification-worthy and Resend is
    configured; otherwise returns the raw HTTP response. Skips execution if the
    job already ran within the last 6 hours.
    """
    logger.info(f"Processing scheduled_chat: {query_to_run}")
    if job_id:
        # Get the job object and check whether the time is valid for it to run. This helps avoid race conditions that cause the same job to be run multiple times.
        job = AutomationAdapters.get_automation(user, job_id)
        last_run_time = AutomationAdapters.get_job_last_run(user, job)

        # Convert last_run_time from %Y-%m-%d %I:%M %p %Z to datetime object
        if last_run_time:
            last_run_time = datetime.strptime(last_run_time, "%Y-%m-%d %I:%M %p %Z").replace(tzinfo=timezone.utc)

            # If the last run time was within the last 6 hours, don't run it again. This helps avoid multithreading issues and rate limits.
            if (datetime.now(timezone.utc) - last_run_time).total_seconds() < 21600:
                logger.info(f"Skipping scheduled chat {job_id} as the next run time is in the future.")
                return

    # Extract relevant params from the original URL
    scheme = "http" if not calling_url.is_secure else "https"
    query_dict = parse_qs(calling_url.query)

    # Pop the stream value from query_dict if it exists
    query_dict.pop("stream", None)

    # Replace the original scheduling query with the scheduled query
    query_dict["q"] = [query_to_run]

    # Construct the URL to call the chat API with the scheduled query string
    encoded_query = urlencode(query_dict, doseq=True)
    url = f"{scheme}://{calling_url.netloc}/api/chat?{encoded_query}"

    # Construct the Headers for the chat API
    headers = {"User-Agent": "Khoj"}
    if not state.anonymous_mode:
        # Add authorization request header in non-anonymous mode
        token = get_khoj_tokens(user)
        if is_none_or_empty(token):
            # No existing token: mint a fresh one for this user
            token = create_khoj_token(user).token
        else:
            token = token[0].token
        headers["Authorization"] = f"Bearer {token}"

    # Call the chat API endpoint with authenticated user token and query
    raw_response = requests.get(url, headers=headers)

    # Stop if the chat API call was not successful
    if raw_response.status_code != 200:
        logger.error(f"Failed to run schedule chat: {raw_response.text}, user: {user}, query: {query_to_run}")
        return None

    # Extract the AI response from the chat API response
    cleaned_query = re.sub(r"^/automated_task\s*", "", query_to_run).strip()
    is_image = False
    if raw_response.headers.get("Content-Type") == "application/json":
        response_map = raw_response.json()
        # The API returns either a text response or a generated image
        ai_response = response_map.get("response") or response_map.get("image")
        is_image = response_map.get("image") is not None
    else:
        ai_response = raw_response.text

    # Notify user if the AI response is satisfactory
    if should_notify(original_query=scheduling_request, executed_query=cleaned_query, ai_response=ai_response):
        if is_resend_enabled():
            send_task_email(user.get_short_name(), user.email, cleaned_query, ai_response, subject, is_image)
        else:
            # No email provider configured: hand the response back to the caller
            return raw_response
|
|
1113
|
+
|
|
1114
|
+
|
|
1115
|
+
async def create_automation(q: str, timezone: str, user: KhojUser, calling_url: URL, meta_log: dict = None):
    """Infer a schedule from the user's request and register the automation job.

    Args:
        q: The user's natural-language scheduling request.
        timezone: Timezone name the schedule should be interpreted in.
        user: The user the automation belongs to.
        calling_url: URL of the originating request, reused to call the chat API on each run.
        meta_log: Optional conversation history passed to schedule inference.

    Returns:
        Tuple of (scheduled job, crontime, query to run, subject).
    """
    # Use a None sentinel instead of a mutable `{}` default: a dict default is
    # created once at definition time and shared across all calls.
    if meta_log is None:
        meta_log = {}
    crontime, query_to_run, subject = await schedule_query(q, meta_log)
    job = await schedule_automation(query_to_run, subject, crontime, timezone, q, user, calling_url)
    return job, crontime, query_to_run, subject
|
|
1119
|
+
|
|
1120
|
+
|
|
1121
|
+
async def schedule_automation(
    query_to_run: str,
    subject: str,
    crontime: str,
    timezone: str,
    scheduling_request: str,
    user: KhojUser,
    calling_url: URL,
):
    """Register a recurring chat task with the process scheduler.

    Builds a cron trigger in the user's timezone, packs the run metadata into
    the job name, and schedules the chat under a per-user, per-query job id.

    Returns:
        The scheduled job.
    """
    # Disallow minute-level recurrence: when the minute field is not a fixed
    # number (e.g. "*" or "*/5"), pin it to a random minute so the request
    # load from automations is spread across the hour.
    cron_fields = crontime.split(" ")
    if not cron_fields[0].isdigit():
        cron_fields[0] = str(int(random() * 60))
        crontime = " ".join(cron_fields)

    # Trigger fires on the user's local schedule, jittered to stagger runs.
    trigger = CronTrigger.from_crontab(crontime, pytz.timezone(timezone))
    trigger.jitter = 60

    # Metadata used by the task scheduler and process locks for the task runs.
    job_metadata = json.dumps(
        {
            "query_to_run": query_to_run,
            "scheduling_request": scheduling_request,
            "subject": subject,
            "crontime": crontime,
        }
    )
    # Stable id derived from the query + schedule, namespaced by user.
    query_id = hashlib.md5(f"{query_to_run}_{crontime}".encode("utf-8")).hexdigest()
    job_id = f"automation_{user.uuid}_{query_id}"

    return await sync_to_async(state.scheduler.add_job)(
        run_with_process_lock,
        trigger=trigger,
        args=(
            scheduled_chat,
            f"{ProcessLock.Operation.SCHEDULED_JOB}_{user.uuid}_{query_id}",
        ),
        kwargs={
            "query_to_run": query_to_run,
            "scheduling_request": scheduling_request,
            "subject": subject,
            "user": user,
            "calling_url": calling_url,
            "job_id": job_id,
        },
        id=job_id,
        name=job_metadata,
        max_instances=2,  # Allow second instance to kill any previous instance with stale lock
    )
|
|
1170
|
+
|
|
1171
|
+
|
|
1172
|
+
def construct_automation_created_message(automation: Job, crontime: str, query_to_run: str, subject: str):
    """Build the markdown confirmation message shown after an automation is created.

    Args:
        automation: The scheduled job returned by the task scheduler.
        crontime: Cron expression the automation runs on.
        query_to_run: The (possibly "/automated_task"-prefixed) query the automation executes.
        subject: Subject line associated with the automation.

    Returns:
        A markdown string summarizing the new automation.
    """
    # Display next run time in user timezone instead of UTC
    schedule = f'{cron_descriptor.get_description(crontime)} {automation.next_run_time.strftime("%Z")}'
    next_run_time = automation.next_run_time.strftime("%Y-%m-%d %I:%M %p %Z")
    # Remove /automated_task prefix from inferred_query ("/" needs no escaping in a regex)
    unprefixed_query_to_run = re.sub(r"^/automated_task\s*", "", query_to_run)
    # Create the automation response (dropped unused automation_icon_url local)
    return f"""
###  Created Automation
- Subject: **{subject}**
- Query to Run: "{unprefixed_query_to_run}"
- Schedule: `{schedule}`
- Next Run At: {next_run_time}

Manage your automations [here](/automations).
""".strip()
|