khoj 1.16.1.dev15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- khoj/__init__.py +0 -0
- khoj/app/README.md +94 -0
- khoj/app/__init__.py +0 -0
- khoj/app/asgi.py +16 -0
- khoj/app/settings.py +192 -0
- khoj/app/urls.py +25 -0
- khoj/configure.py +424 -0
- khoj/database/__init__.py +0 -0
- khoj/database/adapters/__init__.py +1234 -0
- khoj/database/admin.py +290 -0
- khoj/database/apps.py +6 -0
- khoj/database/management/__init__.py +0 -0
- khoj/database/management/commands/__init__.py +0 -0
- khoj/database/management/commands/change_generated_images_url.py +61 -0
- khoj/database/management/commands/convert_images_png_to_webp.py +99 -0
- khoj/database/migrations/0001_khojuser.py +98 -0
- khoj/database/migrations/0002_googleuser.py +32 -0
- khoj/database/migrations/0003_vector_extension.py +10 -0
- khoj/database/migrations/0004_content_types_and_more.py +181 -0
- khoj/database/migrations/0005_embeddings_corpus_id.py +19 -0
- khoj/database/migrations/0006_embeddingsdates.py +33 -0
- khoj/database/migrations/0007_add_conversation.py +27 -0
- khoj/database/migrations/0008_alter_conversation_conversation_log.py +17 -0
- khoj/database/migrations/0009_khojapiuser.py +24 -0
- khoj/database/migrations/0010_chatmodeloptions_and_more.py +83 -0
- khoj/database/migrations/0010_rename_embeddings_entry_and_more.py +30 -0
- khoj/database/migrations/0011_merge_20231102_0138.py +14 -0
- khoj/database/migrations/0012_entry_file_source.py +21 -0
- khoj/database/migrations/0013_subscription.py +37 -0
- khoj/database/migrations/0014_alter_googleuser_picture.py +17 -0
- khoj/database/migrations/0015_alter_subscription_user.py +21 -0
- khoj/database/migrations/0016_alter_subscription_renewal_date.py +17 -0
- khoj/database/migrations/0017_searchmodel.py +32 -0
- khoj/database/migrations/0018_searchmodelconfig_delete_searchmodel.py +30 -0
- khoj/database/migrations/0019_alter_googleuser_family_name_and_more.py +27 -0
- khoj/database/migrations/0020_reflectivequestion.py +36 -0
- khoj/database/migrations/0021_speechtotextmodeloptions_and_more.py +42 -0
- khoj/database/migrations/0022_texttoimagemodelconfig.py +25 -0
- khoj/database/migrations/0023_usersearchmodelconfig.py +33 -0
- khoj/database/migrations/0024_alter_entry_embeddings.py +18 -0
- khoj/database/migrations/0025_clientapplication_khojuser_phone_number_and_more.py +46 -0
- khoj/database/migrations/0025_searchmodelconfig_embeddings_inference_endpoint_and_more.py +22 -0
- khoj/database/migrations/0026_searchmodelconfig_cross_encoder_inference_endpoint_and_more.py +22 -0
- khoj/database/migrations/0027_merge_20240118_1324.py +13 -0
- khoj/database/migrations/0028_khojuser_verified_phone_number.py +17 -0
- khoj/database/migrations/0029_userrequests.py +27 -0
- khoj/database/migrations/0030_conversation_slug_and_title.py +38 -0
- khoj/database/migrations/0031_agent_conversation_agent.py +53 -0
- khoj/database/migrations/0031_alter_googleuser_locale.py +30 -0
- khoj/database/migrations/0032_merge_20240322_0427.py +14 -0
- khoj/database/migrations/0033_rename_tuning_agent_personality.py +17 -0
- khoj/database/migrations/0034_alter_chatmodeloptions_chat_model.py +32 -0
- khoj/database/migrations/0035_processlock.py +26 -0
- khoj/database/migrations/0036_alter_processlock_name.py +19 -0
- khoj/database/migrations/0036_delete_offlinechatprocessorconversationconfig.py +15 -0
- khoj/database/migrations/0036_publicconversation.py +42 -0
- khoj/database/migrations/0037_chatmodeloptions_openai_config_and_more.py +51 -0
- khoj/database/migrations/0037_searchmodelconfig_bi_encoder_docs_encode_config_and_more.py +32 -0
- khoj/database/migrations/0038_merge_20240425_0857.py +14 -0
- khoj/database/migrations/0038_merge_20240426_1640.py +12 -0
- khoj/database/migrations/0039_merge_20240501_0301.py +12 -0
- khoj/database/migrations/0040_alter_processlock_name.py +26 -0
- khoj/database/migrations/0040_merge_20240504_1010.py +14 -0
- khoj/database/migrations/0041_merge_20240505_1234.py +14 -0
- khoj/database/migrations/0042_serverchatsettings.py +46 -0
- khoj/database/migrations/0043_alter_chatmodeloptions_model_type.py +21 -0
- khoj/database/migrations/0044_conversation_file_filters.py +17 -0
- khoj/database/migrations/0045_fileobject.py +37 -0
- khoj/database/migrations/0046_khojuser_email_verification_code_and_more.py +22 -0
- khoj/database/migrations/0047_alter_entry_file_type.py +31 -0
- khoj/database/migrations/0048_voicemodeloption_uservoicemodelconfig.py +52 -0
- khoj/database/migrations/0049_datastore.py +38 -0
- khoj/database/migrations/0049_texttoimagemodelconfig_api_key_and_more.py +58 -0
- khoj/database/migrations/0050_alter_processlock_name.py +25 -0
- khoj/database/migrations/0051_merge_20240702_1220.py +14 -0
- khoj/database/migrations/0052_alter_searchmodelconfig_bi_encoder_docs_encode_config_and_more.py +27 -0
- khoj/database/migrations/__init__.py +0 -0
- khoj/database/models/__init__.py +402 -0
- khoj/database/tests.py +3 -0
- khoj/interface/email/feedback.html +34 -0
- khoj/interface/email/magic_link.html +17 -0
- khoj/interface/email/task.html +40 -0
- khoj/interface/email/welcome.html +61 -0
- khoj/interface/web/404.html +56 -0
- khoj/interface/web/agent.html +312 -0
- khoj/interface/web/agents.html +276 -0
- khoj/interface/web/assets/icons/agents.svg +6 -0
- khoj/interface/web/assets/icons/automation.svg +37 -0
- khoj/interface/web/assets/icons/cancel.svg +3 -0
- khoj/interface/web/assets/icons/chat.svg +24 -0
- khoj/interface/web/assets/icons/collapse.svg +17 -0
- khoj/interface/web/assets/icons/computer.png +0 -0
- khoj/interface/web/assets/icons/confirm-icon.svg +1 -0
- khoj/interface/web/assets/icons/copy-button-success.svg +6 -0
- khoj/interface/web/assets/icons/copy-button.svg +5 -0
- khoj/interface/web/assets/icons/credit-card.png +0 -0
- khoj/interface/web/assets/icons/delete.svg +26 -0
- khoj/interface/web/assets/icons/docx.svg +7 -0
- khoj/interface/web/assets/icons/edit.svg +4 -0
- khoj/interface/web/assets/icons/favicon-128x128.ico +0 -0
- khoj/interface/web/assets/icons/favicon-128x128.png +0 -0
- khoj/interface/web/assets/icons/favicon-256x256.png +0 -0
- khoj/interface/web/assets/icons/favicon.icns +0 -0
- khoj/interface/web/assets/icons/github.svg +1 -0
- khoj/interface/web/assets/icons/key.svg +4 -0
- khoj/interface/web/assets/icons/khoj-logo-sideways-200.png +0 -0
- khoj/interface/web/assets/icons/khoj-logo-sideways-500.png +0 -0
- khoj/interface/web/assets/icons/khoj-logo-sideways.svg +5385 -0
- khoj/interface/web/assets/icons/logotype.svg +1 -0
- khoj/interface/web/assets/icons/markdown.svg +1 -0
- khoj/interface/web/assets/icons/new.svg +23 -0
- khoj/interface/web/assets/icons/notion.svg +4 -0
- khoj/interface/web/assets/icons/openai-logomark.svg +1 -0
- khoj/interface/web/assets/icons/org.svg +1 -0
- khoj/interface/web/assets/icons/pdf.svg +23 -0
- khoj/interface/web/assets/icons/pencil-edit.svg +5 -0
- khoj/interface/web/assets/icons/plaintext.svg +1 -0
- khoj/interface/web/assets/icons/question-mark-icon.svg +1 -0
- khoj/interface/web/assets/icons/search.svg +25 -0
- khoj/interface/web/assets/icons/send.svg +1 -0
- khoj/interface/web/assets/icons/share.svg +8 -0
- khoj/interface/web/assets/icons/speaker.svg +4 -0
- khoj/interface/web/assets/icons/stop-solid.svg +37 -0
- khoj/interface/web/assets/icons/sync.svg +4 -0
- khoj/interface/web/assets/icons/thumbs-down-svgrepo-com.svg +6 -0
- khoj/interface/web/assets/icons/thumbs-up-svgrepo-com.svg +6 -0
- khoj/interface/web/assets/icons/user-silhouette.svg +4 -0
- khoj/interface/web/assets/icons/voice.svg +8 -0
- khoj/interface/web/assets/icons/web.svg +2 -0
- khoj/interface/web/assets/icons/whatsapp.svg +17 -0
- khoj/interface/web/assets/khoj.css +237 -0
- khoj/interface/web/assets/markdown-it.min.js +8476 -0
- khoj/interface/web/assets/natural-cron.min.js +1 -0
- khoj/interface/web/assets/org.min.js +1823 -0
- khoj/interface/web/assets/pico.min.css +5 -0
- khoj/interface/web/assets/purify.min.js +3 -0
- khoj/interface/web/assets/samples/desktop-browse-draw-sample.png +0 -0
- khoj/interface/web/assets/samples/desktop-plain-chat-sample.png +0 -0
- khoj/interface/web/assets/samples/desktop-remember-plan-sample.png +0 -0
- khoj/interface/web/assets/samples/phone-browse-draw-sample.png +0 -0
- khoj/interface/web/assets/samples/phone-plain-chat-sample.png +0 -0
- khoj/interface/web/assets/samples/phone-remember-plan-sample.png +0 -0
- khoj/interface/web/assets/utils.js +33 -0
- khoj/interface/web/base_config.html +445 -0
- khoj/interface/web/chat.html +3546 -0
- khoj/interface/web/config.html +1011 -0
- khoj/interface/web/config_automation.html +1103 -0
- khoj/interface/web/content_source_computer_input.html +139 -0
- khoj/interface/web/content_source_github_input.html +216 -0
- khoj/interface/web/content_source_notion_input.html +94 -0
- khoj/interface/web/khoj.webmanifest +51 -0
- khoj/interface/web/login.html +219 -0
- khoj/interface/web/public_conversation.html +2006 -0
- khoj/interface/web/search.html +470 -0
- khoj/interface/web/utils.html +48 -0
- khoj/main.py +241 -0
- khoj/manage.py +22 -0
- khoj/migrations/__init__.py +0 -0
- khoj/migrations/migrate_offline_chat_default_model.py +69 -0
- khoj/migrations/migrate_offline_chat_default_model_2.py +71 -0
- khoj/migrations/migrate_offline_chat_schema.py +83 -0
- khoj/migrations/migrate_offline_model.py +29 -0
- khoj/migrations/migrate_processor_config_openai.py +67 -0
- khoj/migrations/migrate_server_pg.py +138 -0
- khoj/migrations/migrate_version.py +17 -0
- khoj/processor/__init__.py +0 -0
- khoj/processor/content/__init__.py +0 -0
- khoj/processor/content/docx/__init__.py +0 -0
- khoj/processor/content/docx/docx_to_entries.py +110 -0
- khoj/processor/content/github/__init__.py +0 -0
- khoj/processor/content/github/github_to_entries.py +224 -0
- khoj/processor/content/images/__init__.py +0 -0
- khoj/processor/content/images/image_to_entries.py +118 -0
- khoj/processor/content/markdown/__init__.py +0 -0
- khoj/processor/content/markdown/markdown_to_entries.py +165 -0
- khoj/processor/content/notion/notion_to_entries.py +260 -0
- khoj/processor/content/org_mode/__init__.py +0 -0
- khoj/processor/content/org_mode/org_to_entries.py +231 -0
- khoj/processor/content/org_mode/orgnode.py +532 -0
- khoj/processor/content/pdf/__init__.py +0 -0
- khoj/processor/content/pdf/pdf_to_entries.py +116 -0
- khoj/processor/content/plaintext/__init__.py +0 -0
- khoj/processor/content/plaintext/plaintext_to_entries.py +122 -0
- khoj/processor/content/text_to_entries.py +297 -0
- khoj/processor/conversation/__init__.py +0 -0
- khoj/processor/conversation/anthropic/__init__.py +0 -0
- khoj/processor/conversation/anthropic/anthropic_chat.py +206 -0
- khoj/processor/conversation/anthropic/utils.py +114 -0
- khoj/processor/conversation/offline/__init__.py +0 -0
- khoj/processor/conversation/offline/chat_model.py +231 -0
- khoj/processor/conversation/offline/utils.py +78 -0
- khoj/processor/conversation/offline/whisper.py +15 -0
- khoj/processor/conversation/openai/__init__.py +0 -0
- khoj/processor/conversation/openai/gpt.py +187 -0
- khoj/processor/conversation/openai/utils.py +129 -0
- khoj/processor/conversation/openai/whisper.py +13 -0
- khoj/processor/conversation/prompts.py +758 -0
- khoj/processor/conversation/utils.py +262 -0
- khoj/processor/embeddings.py +117 -0
- khoj/processor/speech/__init__.py +0 -0
- khoj/processor/speech/text_to_speech.py +51 -0
- khoj/processor/tools/__init__.py +0 -0
- khoj/processor/tools/online_search.py +225 -0
- khoj/routers/__init__.py +0 -0
- khoj/routers/api.py +626 -0
- khoj/routers/api_agents.py +43 -0
- khoj/routers/api_chat.py +1180 -0
- khoj/routers/api_config.py +434 -0
- khoj/routers/api_phone.py +86 -0
- khoj/routers/auth.py +181 -0
- khoj/routers/email.py +133 -0
- khoj/routers/helpers.py +1188 -0
- khoj/routers/indexer.py +349 -0
- khoj/routers/notion.py +91 -0
- khoj/routers/storage.py +35 -0
- khoj/routers/subscription.py +104 -0
- khoj/routers/twilio.py +36 -0
- khoj/routers/web_client.py +471 -0
- khoj/search_filter/__init__.py +0 -0
- khoj/search_filter/base_filter.py +15 -0
- khoj/search_filter/date_filter.py +217 -0
- khoj/search_filter/file_filter.py +30 -0
- khoj/search_filter/word_filter.py +29 -0
- khoj/search_type/__init__.py +0 -0
- khoj/search_type/text_search.py +241 -0
- khoj/utils/__init__.py +0 -0
- khoj/utils/cli.py +93 -0
- khoj/utils/config.py +81 -0
- khoj/utils/constants.py +24 -0
- khoj/utils/fs_syncer.py +249 -0
- khoj/utils/helpers.py +418 -0
- khoj/utils/initialization.py +146 -0
- khoj/utils/jsonl.py +43 -0
- khoj/utils/models.py +47 -0
- khoj/utils/rawconfig.py +160 -0
- khoj/utils/state.py +46 -0
- khoj/utils/yaml.py +43 -0
- khoj-1.16.1.dev15.dist-info/METADATA +178 -0
- khoj-1.16.1.dev15.dist-info/RECORD +242 -0
- khoj-1.16.1.dev15.dist-info/WHEEL +4 -0
- khoj-1.16.1.dev15.dist-info/entry_points.txt +2 -0
- khoj-1.16.1.dev15.dist-info/licenses/LICENSE +661 -0
|
@@ -0,0 +1,1234 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
import math
|
|
4
|
+
import random
|
|
5
|
+
import re
|
|
6
|
+
import secrets
|
|
7
|
+
import sys
|
|
8
|
+
from datetime import date, datetime, timedelta, timezone
|
|
9
|
+
from enum import Enum
|
|
10
|
+
from typing import Callable, Iterable, List, Optional, Type
|
|
11
|
+
|
|
12
|
+
import cron_descriptor
|
|
13
|
+
from apscheduler.job import Job
|
|
14
|
+
from asgiref.sync import sync_to_async
|
|
15
|
+
from django.contrib.sessions.backends.db import SessionStore
|
|
16
|
+
from django.db import models
|
|
17
|
+
from django.db.models import Q
|
|
18
|
+
from django.db.models.manager import BaseManager
|
|
19
|
+
from django.db.utils import IntegrityError
|
|
20
|
+
from django_apscheduler.models import DjangoJob, DjangoJobExecution
|
|
21
|
+
from fastapi import HTTPException
|
|
22
|
+
from pgvector.django import CosineDistance
|
|
23
|
+
from torch import Tensor
|
|
24
|
+
|
|
25
|
+
from khoj.database.models import (
|
|
26
|
+
Agent,
|
|
27
|
+
ChatModelOptions,
|
|
28
|
+
ClientApplication,
|
|
29
|
+
Conversation,
|
|
30
|
+
DataStore,
|
|
31
|
+
Entry,
|
|
32
|
+
FileObject,
|
|
33
|
+
GithubConfig,
|
|
34
|
+
GithubRepoConfig,
|
|
35
|
+
GoogleUser,
|
|
36
|
+
KhojApiUser,
|
|
37
|
+
KhojUser,
|
|
38
|
+
NotionConfig,
|
|
39
|
+
OpenAIProcessorConversationConfig,
|
|
40
|
+
ProcessLock,
|
|
41
|
+
PublicConversation,
|
|
42
|
+
ReflectiveQuestion,
|
|
43
|
+
SearchModelConfig,
|
|
44
|
+
ServerChatSettings,
|
|
45
|
+
SpeechToTextModelOptions,
|
|
46
|
+
Subscription,
|
|
47
|
+
TextToImageModelConfig,
|
|
48
|
+
UserConversationConfig,
|
|
49
|
+
UserRequests,
|
|
50
|
+
UserSearchModelConfig,
|
|
51
|
+
UserTextToImageModelConfig,
|
|
52
|
+
UserVoiceModelConfig,
|
|
53
|
+
VoiceModelOption,
|
|
54
|
+
)
|
|
55
|
+
from khoj.processor.conversation import prompts
|
|
56
|
+
from khoj.search_filter.date_filter import DateFilter
|
|
57
|
+
from khoj.search_filter.file_filter import FileFilter
|
|
58
|
+
from khoj.search_filter.word_filter import WordFilter
|
|
59
|
+
from khoj.utils import state
|
|
60
|
+
from khoj.utils.config import OfflineChatProcessorModel
|
|
61
|
+
from khoj.utils.helpers import generate_random_name, is_none_or_empty, timer
|
|
62
|
+
|
|
63
|
+
# Module-level logger named after this module (used by the process-lock helpers below).
logger = logging.getLogger(__name__)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class SubscriptionState(Enum):
    """Subscription states reported by subscription_to_state (values are plain strings)."""

    TRIAL = "trial"  # trial subscription still inside its validity window
    SUBSCRIBED = "subscribed"  # recurring subscription whose renewal date has not passed
    UNSUBSCRIBED = "unsubscribed"  # non-recurring subscription whose renewal date has not passed
    EXPIRED = "expired"  # trial window elapsed, or non-recurring with a past/missing renewal date
    INVALID = "invalid"  # no subscription at all, or no other state applies
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
async def set_notion_config(token: str, user: KhojUser):
    """Create or refresh the Notion integration token stored for ``user``.

    Returns:
        The created or updated NotionConfig row.
    """
    existing = await NotionConfig.objects.filter(user=user).afirst()
    if existing is None:
        return await NotionConfig.objects.acreate(token=token, user=user)

    existing.token = token
    await existing.asave()
    return existing
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def create_khoj_token(user: KhojUser, name=None):
    """Create and persist a new Khoj API key for ``user``.

    The key is ``kk-`` followed by a URL-safe random string from ``secrets``.
    When ``name`` is falsy, a random title-cased name is generated instead.
    """
    api_key = "kk-" + secrets.token_urlsafe(32)
    label = name if name else f"{generate_random_name().title()}"
    return KhojApiUser.objects.create(token=api_key, user=user, name=label)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
async def acreate_khoj_token(user: KhojUser, name=None):
    """Async variant of create_khoj_token: persist a new ``kk-`` API key for ``user``.

    When ``name`` is falsy, a random title-cased name is generated instead.
    """
    api_key = "kk-" + secrets.token_urlsafe(32)
    label = name if name else f"{generate_random_name().title()}"
    new_key = await KhojApiUser.objects.acreate(token=api_key, user=user, name=label)
    return new_key
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def get_khoj_tokens(user: KhojUser):
    """Return every Khoj API key owned by ``user`` as an evaluated list."""
    user_keys = KhojApiUser.objects.filter(user=user)
    return list(user_keys)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
async def delete_khoj_token(user: KhojUser, token: str):
    """Delete the API key ``token`` if it belongs to ``user``; otherwise nothing is removed."""
    matching_keys = KhojApiUser.objects.filter(token=token, user=user)
    await matching_keys.adelete()
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
async def get_or_create_user(token: dict) -> KhojUser:
    """Return the user linked to a Google ``token``, creating one when none is linked yet."""
    existing_user = await get_user_by_token(token)
    if existing_user:
        return existing_user
    return await create_user_by_google_token(token)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
async def aget_or_create_user_by_phone_number(phone_number: str) -> KhojUser:
    """Return the user reachable at ``phone_number``, creating one if lookup finds nobody.

    Returns None when ``phone_number`` is None or empty.
    """
    if is_none_or_empty(phone_number):
        return None

    found = await aget_user_by_phone_number(phone_number)
    if found:
        return found
    return await acreate_user_by_phone_number(phone_number)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
async def aset_user_phone_number(user: KhojUser, phone_number: str) -> KhojUser:
    """Attach ``phone_number`` to ``user``, normalising it to carry a leading ``+``.

    If a different account is returned by aget_user_by_phone_number for that number:
      * and it has no email, its conversations are reassigned to ``user`` and the
        account is deleted (it is effectively a throwaway phone-only account);
      * otherwise an HTTP 400 is raised.

    Returns:
        The saved user, or None when ``phone_number`` is None or empty.

    Raises:
        HTTPException: 400 when the number belongs to another account with an email.
    """
    if is_none_or_empty(phone_number):
        return None

    normalized = phone_number.strip()
    if not normalized.startswith("+"):
        normalized = "+" + normalized

    owner = await aget_user_by_phone_number(normalized)
    if owner is not None and owner.id != user.id:
        if not is_none_or_empty(owner.email):
            raise HTTPException(status_code=400, detail="Phone number already exists")

        # Move the phone-only account's history over before discarding that account
        async for convo in Conversation.objects.filter(user=owner).aiterator():
            convo.user = user
            await convo.asave()
        await owner.adelete()

    user.phone_number = normalized
    await user.asave()
    return user
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
async def aremove_phone_number(user: KhojUser) -> KhojUser:
    """Clear ``user``'s phone number and its verified flag, save, and return the user."""
    user.verified_phone_number = False
    user.phone_number = None
    await user.asave()
    return user
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
async def acreate_user_by_phone_number(phone_number: str) -> KhojUser:
    """Get or create the user owning ``phone_number`` and make sure it has a subscription.

    ``username`` and ``phone_number`` are (re)set to ``phone_number``. A trial
    Subscription is created only when the user does not already have one.

    Returns:
        The saved KhojUser, or None when ``phone_number`` is None or empty.
    """
    if is_none_or_empty(phone_number):
        return None

    # aupdate_or_create persists the row itself; no follow-up save is needed.
    user, _ = await KhojUser.objects.filter(phone_number=phone_number).aupdate_or_create(
        defaults={"username": phone_number, "phone_number": phone_number}
    )

    # This can match an existing account: aget_user_by_phone_number returns None for an
    # email account whose phone is unverified, so callers may land here for that account,
    # which already has a subscription. Only create one when missing.
    existing_subscription = await Subscription.objects.filter(user=user).afirst()
    if not existing_subscription:
        await Subscription.objects.acreate(user=user, type="trial")

    return user
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
async def aget_or_create_user_by_email(email: str) -> KhojUser:
    """Get or create the user with ``email`` and issue a fresh email verification code.

    ``username`` and ``email`` are (re)set to ``email``. A new
    ``email_verification_code`` is generated on every call; it is the value later
    matched by aget_user_validated_by_email_verification_code. A trial Subscription
    is created if the user has none.

    Returns:
        The saved KhojUser.
    """
    # aupdate_or_create persists the row and always returns an instance.
    user, _ = await KhojUser.objects.filter(email=email).aupdate_or_create(defaults={"username": email, "email": email})

    user.email_verification_code = secrets.token_urlsafe(18)
    await user.asave()

    user_subscription = await Subscription.objects.filter(user=user).afirst()
    if not user_subscription:
        await Subscription.objects.acreate(user=user, type="trial")

    return user
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
async def aget_user_validated_by_email_verification_code(code: str) -> KhojUser:
    """Consume an email verification ``code`` and mark its user's email as verified.

    The matched user's code is cleared (set to None) so it cannot be reused.

    Returns:
        The verified KhojUser, or None if ``code`` is None/empty or matches no user.
    """
    # Guard: Django turns filter(field=None) into an IS NULL lookup, which would match
    # every user whose code was already consumed (this function sets it to None below).
    if is_none_or_empty(code):
        return None

    user = await KhojUser.objects.filter(email_verification_code=code).afirst()
    if not user:
        return None

    user.email_verification_code = None
    user.verified_email = True
    await user.asave()

    return user
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
async def create_user_by_google_token(token: dict) -> KhojUser:
    """Create (or update) a KhojUser from a Google ``token`` and link a GoogleUser profile.

    The user is matched on the token's "email"; ``username``/``email`` are set from it
    and ``verified_email`` is set True. A GoogleUser row is created from the token's
    claims, and a trial Subscription is created only if the user has none.

    Args:
        token: mapping of Google claims, read via ``.get`` (sub, azp, email, name,
            given_name, family_name, picture, locale).

    Returns:
        The saved KhojUser.
    """
    email = token.get("email")
    # Setting verified_email in defaults avoids a separate follow-up write.
    user, _ = await KhojUser.objects.filter(email=email).aupdate_or_create(
        defaults={"username": email, "email": email, "verified_email": True}
    )

    await GoogleUser.objects.acreate(
        sub=token.get("sub"),
        azp=token.get("azp"),
        email=email,
        name=token.get("name"),
        given_name=token.get("given_name"),
        family_name=token.get("family_name"),
        picture=token.get("picture"),
        locale=token.get("locale"),
        user=user,
    )

    # A user who signed up by email first already has a subscription (created by
    # aget_or_create_user_by_email), so don't add a duplicate.
    # NOTE(review): the `subscription` related name used elsewhere suggests
    # Subscription.user is one-to-one, in which case a duplicate raises IntegrityError.
    existing_subscription = await Subscription.objects.filter(user=user).afirst()
    if not existing_subscription:
        await Subscription.objects.acreate(user=user, type="trial")

    return user
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def set_user_name(user: KhojUser, first_name: str, last_name: str) -> KhojUser:
    """Persist ``first_name`` and ``last_name`` on ``user`` and return it."""
    user.first_name, user.last_name = first_name, last_name
    user.save()
    return user
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def get_user_name(user: KhojUser):
    """Return a display name for ``user``.

    Prefers the Django full name; falls back to the linked Google profile's
    given name, or None when neither is available.
    """
    full_name = user.get_full_name()
    if is_none_or_empty(full_name):
        profile: GoogleUser = GoogleUser.objects.filter(user=user).first()
        return profile.given_name if profile else None
    return full_name
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def get_user_photo(user: KhojUser):
    """Return the ``picture`` stored on ``user``'s Google profile, or None without a profile."""
    profile: GoogleUser = GoogleUser.objects.filter(user=user).first()
    return profile.picture if profile else None
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def get_user_subscription(email: str) -> Optional[Subscription]:
    """Return the first Subscription whose user has ``email``, or None."""
    matches = Subscription.objects.filter(user__email=email)
    return matches.first()
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
async def set_user_subscription(
    email: str, is_recurring=None, renewal_date=None, type="standard"
) -> Optional[Subscription]:
    """Update the subscription of the user with ``email``, creating the user if needed.

    Note: the user is fetched via aget_or_create_user_by_email, which also rotates the
    user's email verification code as a side effect.

    Args:
        email: email identifying the user.
        is_recurring: new recurring flag; None leaves it unchanged.
        renewal_date: new renewal date; ``False`` clears it, None leaves it unchanged.
        type: subscription type to store.

    Returns:
        The saved Subscription.
    """
    user = await aget_or_create_user_by_email(email)
    subscription = await Subscription.objects.filter(user=user).afirst()

    subscription.type = type
    if is_recurring is not None:
        subscription.is_recurring = is_recurring

    # False is a sentinel meaning "clear"; None means "leave as is"
    if renewal_date is False:
        subscription.renewal_date = None
    elif renewal_date is not None:
        subscription.renewal_date = renewal_date

    await subscription.asave()
    return subscription
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def subscription_to_state(subscription: Subscription) -> str:
    """Map a Subscription row to a SubscriptionState value string.

    Rules:
      * no subscription -> invalid
      * trial -> trial for 14 days after ``created_at``, expired afterwards
      * recurring with ``renewal_date`` not yet passed -> subscribed
      * non-recurring without ``renewal_date`` -> expired
      * non-recurring with ``renewal_date`` not yet passed -> unsubscribed
      * non-recurring with ``renewal_date`` passed -> expired
      * anything else (e.g. recurring with a past or missing renewal date) -> invalid
    """
    if not subscription:
        return SubscriptionState.INVALID.value

    # Take one timestamp so every comparison below uses the same "now"
    now = datetime.now(tz=timezone.utc)

    if subscription.type == Subscription.Type.TRIAL:
        # Trial subscription is valid for 14 days
        if now - subscription.created_at > timedelta(days=14):
            return SubscriptionState.EXPIRED.value
        return SubscriptionState.TRIAL.value

    renewal_date = subscription.renewal_date
    if subscription.is_recurring:
        # A recurring row without a renewal date used to raise TypeError on comparison
        if renewal_date is not None and renewal_date >= now:
            return SubscriptionState.SUBSCRIBED.value
        return SubscriptionState.INVALID.value

    if renewal_date is None:
        return SubscriptionState.EXPIRED.value
    if renewal_date >= now:
        return SubscriptionState.UNSUBSCRIBED.value
    return SubscriptionState.EXPIRED.value
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def get_user_subscription_state(email: str) -> str:
    """Return the SubscriptionState value for the user with ``email``.

    Valid state transitions: trial -> subscribed <-> unsubscribed OR expired
    """
    return subscription_to_state(Subscription.objects.filter(user__email=email).first())
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
async def aget_user_subscription_state(user: KhojUser) -> str:
    """Async: return the SubscriptionState value for ``user``.

    Valid state transitions: trial -> subscribed <-> unsubscribed OR expired
    """
    return subscription_to_state(await Subscription.objects.filter(user=user).afirst())
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
async def get_user_by_email(email: str) -> KhojUser:
    """Return the first user with ``email``, or None."""
    query = KhojUser.objects.filter(email=email)
    return await query.afirst()
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
async def aget_user_by_uuid(uuid: str) -> KhojUser:
    """Return the first user whose ``uuid`` field equals ``uuid``, or None."""
    query = KhojUser.objects.filter(uuid=uuid)
    return await query.afirst()
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
async def get_user_by_token(token: dict) -> KhojUser:
    """Resolve a Google ``token`` to its KhojUser via the GoogleUser ``sub``; None if unknown."""
    linked = await GoogleUser.objects.filter(sub=token.get("sub")).select_related("user").afirst()
    return linked.user if linked else None
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
async def aget_user_by_phone_number(phone_number: str) -> KhojUser:
    """Find the user reachable at ``phone_number``.

    A phone-only account (no email) is returned as-is. An account that also has an
    email is returned only if its phone number is verified.

    Returns:
        The matching KhojUser, or None if the number is empty, unknown, or belongs
        to an email account with an unverified phone.
    """
    if is_none_or_empty(phone_number):
        return None

    candidate = await KhojUser.objects.filter(phone_number=phone_number).prefetch_related("subscription").afirst()
    if candidate is None:
        return None

    has_email = not is_none_or_empty(candidate.email)
    if has_email and not candidate.verified_phone_number:
        return None
    return candidate
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
async def retrieve_user(session_id: str) -> KhojUser:
    """Load the authenticated KhojUser for a Django session key.

    Raises:
        HTTPException: 401 if the session does not exist or its user cannot be found.
    """
    session = SessionStore(session_key=session_id)
    session_exists = await sync_to_async(session.exists)(session_key=session_id)
    if not session_exists:
        raise HTTPException(status_code=401, detail="Invalid session")

    session_data = await sync_to_async(session.load)()
    user_id = session_data.get("_auth_user_id")
    user = await KhojUser.objects.filter(id=user_id).afirst()
    if not user:
        raise HTTPException(status_code=401, detail="Invalid user")
    return user
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
def get_all_users() -> BaseManager[KhojUser]:
    """Return an (unevaluated) queryset over every KhojUser."""
    everyone = KhojUser.objects.all()
    return everyone
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
def get_user_github_config(user: KhojUser):
    """Return ``user``'s GithubConfig with its repo configs prefetched, or None."""
    return GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
def get_user_notion_config(user: KhojUser):
    """Return ``user``'s NotionConfig, or None if they have not configured Notion."""
    return NotionConfig.objects.filter(user=user).first()
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
def delete_user_requests(window: timedelta = timedelta(days=1)):
    """Delete UserRequests rows created at or before ``window`` ago (UTC).

    Returns:
        The result of Django's ``QuerySet.delete()``.
    """
    cutoff = datetime.now(tz=timezone.utc) - window
    stale_requests = UserRequests.objects.filter(created_at__lte=cutoff)
    return stale_requests.delete()
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
async def aget_user_name(user: KhojUser):
    """Async: return a display name for ``user``.

    Prefers the Django full name; falls back to the linked Google profile's
    given name, or None when neither is available.
    """
    full_name = user.get_full_name()
    if is_none_or_empty(full_name):
        profile: GoogleUser = await GoogleUser.objects.filter(user=user).afirst()
        return profile.given_name if profile else None
    return full_name
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
async def set_text_content_config(user: KhojUser, object: Type[models.Model], updated_config):
    """Replace ``user``'s content config of model type ``object`` with ``updated_config``.

    All existing ``object`` rows for the user are deleted, then one row is created with
    de-duplicated ``input_files`` / ``input_filter`` (None when empty) and
    ``index_heading_entries`` copied from ``updated_config``.

    Args:
        user: owner of the config.
        object: the Django model class to write (name shadows the builtin but is part
            of the public keyword interface).
        updated_config: object exposing ``input_files``, ``input_filter`` and
            ``index_heading_entries`` attributes.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order; set() would store the
    # entries in an order that varies between processes (str hash randomisation).
    deduped_files = list(dict.fromkeys(updated_config.input_files)) if updated_config.input_files else None
    deduped_filters = list(dict.fromkeys(updated_config.input_filter)) if updated_config.input_filter else None
    await object.objects.filter(user=user).adelete()
    await object.objects.acreate(
        input_files=deduped_files,
        input_filter=deduped_filters,
        index_heading_entries=updated_config.index_heading_entries,
        user=user,
    )
|
|
387
|
+
|
|
388
|
+
|
|
389
|
+
async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
    """Create or update ``user``'s GitHub config and replace its repository list.

    Each item of ``repos`` is indexed by the "name", "owner" and "branch" keys.
    Any previous repo configs attached to the config are deleted first.

    Returns:
        The saved GithubConfig.
    """
    config = await GithubConfig.objects.filter(user=user).afirst()
    if config:
        config.pat_token = pat_token
        await config.asave()
    else:
        config = await GithubConfig.objects.acreate(pat_token=pat_token, user=user)

    await config.githubrepoconfig.all().adelete()

    for entry in repos:
        await GithubRepoConfig.objects.acreate(
            name=entry["name"],
            owner=entry["owner"],
            branch=entry["branch"],
            github_config=config,
        )
    return config
|
|
404
|
+
|
|
405
|
+
|
|
406
|
+
def get_user_search_model_or_default(user=None):
    """Return the SearchModelConfig to use for ``user``.

    Resolution order:
      1. the config the user selected (via UserSearchModelConfig), if ``user`` is given;
      2. the SearchModelConfig named "default";
      3. otherwise a SearchModelConfig is created with model defaults and the first
         SearchModelConfig row is returned.
    """
    # A single first() per lookup replaces exists() + first(): half the queries, and no
    # window where the row vanishes between the two and .first() returns None.
    if user:
        user_config = UserSearchModelConfig.objects.filter(user=user).first()
        if user_config:
            return user_config.setting

    default_config = SearchModelConfig.objects.filter(name="default").first()
    if default_config:
        return default_config

    SearchModelConfig.objects.create()
    return SearchModelConfig.objects.first()
|
|
416
|
+
|
|
417
|
+
|
|
418
|
+
def get_or_create_search_models():
    """Return all SearchModelConfig rows, creating one with model defaults if none exist."""
    if SearchModelConfig.objects.all().count() == 0:
        SearchModelConfig.objects.create()
    return SearchModelConfig.objects.all()
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
async def aset_user_search_model(user: KhojUser, search_model_config_id: int):
    """Point ``user`` at the SearchModelConfig with id ``search_model_config_id``.

    Returns:
        The created or updated UserSearchModelConfig, or None if no such config exists.
    """
    chosen = await SearchModelConfig.objects.filter(id=search_model_config_id).afirst()
    if chosen is None:
        return None
    user_config, _created = await UserSearchModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": chosen})
    return user_config
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
async def aget_user_search_model(user: KhojUser):
    """Return ``user``'s selected SearchModelConfig, or None if they have not chosen one."""
    user_config = await UserSearchModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
    return user_config.setting if user_config else None
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
class ProcessLockAdapters:
|
|
443
|
+
@staticmethod
def get_process_lock(process_name: str):
    """Return the ProcessLock row named ``process_name``, or None."""
    matching = ProcessLock.objects.filter(name=process_name)
    return matching.first()
|
|
446
|
+
|
|
447
|
+
@staticmethod
def set_process_lock(process_name: str, max_duration_in_seconds: int = 600):
    """Create and return a ProcessLock row for ``process_name``.

    ``max_duration_in_seconds`` is how long the lock counts as live before
    is_process_locked treats it as stale.
    """
    return ProcessLock.objects.create(
        name=process_name,
        max_duration_in_seconds=max_duration_in_seconds,
    )
|
|
450
|
+
|
|
451
|
+
@staticmethod
|
|
452
|
+
def is_process_locked(process_name: str):
|
|
453
|
+
process_lock = ProcessLock.objects.filter(name=process_name).first()
|
|
454
|
+
if not process_lock:
|
|
455
|
+
return False
|
|
456
|
+
if process_lock.started_at + timedelta(seconds=process_lock.max_duration_in_seconds) < datetime.now(
|
|
457
|
+
tz=timezone.utc
|
|
458
|
+
):
|
|
459
|
+
process_lock.delete()
|
|
460
|
+
logger.info(f"🔓 Deleted stale {process_name} process lock on timeout")
|
|
461
|
+
return False
|
|
462
|
+
return True
|
|
463
|
+
|
|
464
|
+
@staticmethod
|
|
465
|
+
def remove_process_lock(process_lock: ProcessLock):
|
|
466
|
+
return process_lock.delete()
|
|
467
|
+
|
|
468
|
+
@staticmethod
|
|
469
|
+
def run_with_lock(func: Callable, operation: ProcessLock.Operation, max_duration_in_seconds: int = 600, **kwargs):
|
|
470
|
+
# Exit early if process lock is already taken
|
|
471
|
+
if ProcessLockAdapters.is_process_locked(operation):
|
|
472
|
+
logger.debug(f"🔒 Skip executing {func} as {operation} lock is already taken")
|
|
473
|
+
return
|
|
474
|
+
|
|
475
|
+
success = False
|
|
476
|
+
process_lock = None
|
|
477
|
+
try:
|
|
478
|
+
# Set process lock
|
|
479
|
+
process_lock = ProcessLockAdapters.set_process_lock(operation, max_duration_in_seconds)
|
|
480
|
+
logger.info(f"🔐 Locked {operation} to execute {func}")
|
|
481
|
+
|
|
482
|
+
# Execute Function
|
|
483
|
+
with timer(f"🔒 Run {func} with {operation} process lock", logger):
|
|
484
|
+
func(**kwargs)
|
|
485
|
+
success = True
|
|
486
|
+
except IntegrityError as e:
|
|
487
|
+
logger.debug(f"⚠️ Unable to create the process lock for {func} with {operation}: {e}")
|
|
488
|
+
success = False
|
|
489
|
+
except Exception as e:
|
|
490
|
+
logger.error(f"🚨 Error executing {func} with {operation} process lock: {e}", exc_info=True)
|
|
491
|
+
success = False
|
|
492
|
+
finally:
|
|
493
|
+
# Remove Process Lock
|
|
494
|
+
if process_lock:
|
|
495
|
+
ProcessLockAdapters.remove_process_lock(process_lock)
|
|
496
|
+
logger.info(
|
|
497
|
+
f"🔓 Unlocked {operation} process after executing {func} {'Succeeded' if success else 'Failed'}"
|
|
498
|
+
)
|
|
499
|
+
else:
|
|
500
|
+
logger.debug(f"Skip removing {operation} process lock as it was not set")
|
|
501
|
+
|
|
502
|
+
|
|
503
|
+
def run_with_process_lock(*args, **kwargs):
    """Module-level wrapper used when scheduling jobs.

    Required because APScheduler can't discover the ``ProcessLockAdapters.run_with_lock`` static
    method on its own. All positional and keyword arguments are forwarded unchanged, and its
    return value is passed back.
    """
    locked_runner = ProcessLockAdapters.run_with_lock
    return locked_runner(*args, **kwargs)
|
|
508
|
+
|
|
509
|
+
|
|
510
|
+
class ClientApplicationAdapters:
    """Lookups for registered ``ClientApplication`` credentials."""

    @staticmethod
    async def aget_client_application_by_id(client_id: str, client_secret: str):
        """Return the ``ClientApplication`` matching both ``client_id`` and ``client_secret``, or ``None``."""
        matching_apps = ClientApplication.objects.filter(client_id=client_id, client_secret=client_secret)
        return await matching_apps.afirst()
|
|
514
|
+
|
|
515
|
+
|
|
516
|
+
class AgentAdapters:
    """Query and bootstrap helpers for ``Agent`` records, including the built-in default agent."""

    DEFAULT_AGENT_NAME = "Khoj"
    DEFAULT_AGENT_AVATAR = "https://assets.khoj.dev/lamp-128.png"
    DEFAULT_AGENT_SLUG = "khoj"

    @staticmethod
    async def aget_agent_by_slug(agent_slug: str, user: KhojUser):
        """Return the agent whose slug matches ``agent_slug`` case-insensitively and that is
        public or created by ``user``, or ``None``."""
        slug_matches = Q(slug__iexact=agent_slug.lower())
        visible_to_user = Q(public=True) | Q(creator=user)
        return await Agent.objects.filter(slug_matches & visible_to_user).afirst()

    @staticmethod
    def get_agent_by_slug(slug: str, user: KhojUser = None):
        """Return the agent whose slug matches ``slug`` case-insensitively, or ``None``.

        With a ``user``, agents they created are eligible as well as public ones; without a user
        only public agents are considered.
        """
        slug_matches = Q(slug__iexact=slug.lower())
        if not user:
            return Agent.objects.filter(slug_matches, public=True).first()
        return Agent.objects.filter(slug_matches & (Q(public=True) | Q(creator=user))).first()

    @staticmethod
    def get_all_accessible_agents(user: KhojUser = None):
        """Return a queryset of agents visible to ``user``, oldest first.

        Without a ``user`` only public agents are included.
        """
        if not user:
            return Agent.objects.filter(public=True).order_by("created_at")
        visible_to_user = Q(public=True) | Q(creator=user)
        return Agent.objects.filter(visible_to_user).distinct().order_by("created_at")

    @staticmethod
    async def aget_all_accessible_agents(user: KhojUser = None) -> List[Agent]:
        """Async variant of ``get_all_accessible_agents`` that evaluates the queryset into a list."""
        accessible = await sync_to_async(AgentAdapters.get_all_accessible_agents)(user)
        return await sync_to_async(list)(accessible)

    @staticmethod
    def get_conversation_agent_by_id(agent_id: int):
        """Return the agent with id ``agent_id``.

        Returns ``None`` when it is the default agent (so the default application code path is used)
        or when no agent has that id.
        """
        agent = Agent.objects.filter(id=agent_id).first()
        return None if agent == AgentAdapters.get_default_agent() else agent

    @staticmethod
    def get_default_agent():
        """Return the agent named ``DEFAULT_AGENT_NAME``, or ``None`` if it does not exist yet."""
        return Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).first()

    @staticmethod
    def create_default_agent():
        """Create the default agent, or refresh the existing one, and assign it to agent-less conversations.

        Returns the default agent, or ``None`` (changing nothing) when no default conversation
        config exists.
        """
        chat_model = ConversationAdapters.get_default_conversation_config()
        if chat_model is None:
            logger.info("No default conversation config found, skipping default agent creation")
            return None

        personality = prompts.personality.format(current_date="placeholder")
        default_agent = Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).first()

        if not default_agent:
            # The default agent is public and managed by the admin. It's handled a little differently than other agents.
            default_agent = Agent.objects.create(
                name=AgentAdapters.DEFAULT_AGENT_NAME,
                public=True,
                managed_by_admin=True,
                chat_model=chat_model,
                personality=personality,
                tools=["*"],
                avatar=AgentAdapters.DEFAULT_AGENT_AVATAR,
                slug=AgentAdapters.DEFAULT_AGENT_SLUG,
            )
        else:
            # Re-sync the existing default agent's personality, chat model, slug and name.
            default_agent.personality = personality
            default_agent.chat_model = chat_model
            default_agent.slug = AgentAdapters.DEFAULT_AGENT_SLUG
            default_agent.name = AgentAdapters.DEFAULT_AGENT_NAME
            default_agent.save()

        # Assign the default agent to every conversation that currently has no agent.
        Conversation.objects.filter(agent=None).update(agent=default_agent)
        return default_agent

    @staticmethod
    async def aget_default_agent():
        """Async variant of ``get_default_agent``."""
        default_agent_query = Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME)
        return await default_agent_query.afirst()
|
|
591
|
+
|
|
592
|
+
|
|
593
|
+
class PublicConversationAdapters:
    """Helpers for shareable, read-only copies of conversations."""

    @staticmethod
    def get_public_conversation_by_slug(slug: str):
        """Return the ``PublicConversation`` with the given ``slug``, or ``None``."""
        matches = PublicConversation.objects.filter(slug=slug)
        return matches.first()

    @staticmethod
    def get_public_conversation_url(public_conversation: PublicConversation):
        """Return the relative share URL of ``public_conversation``.

        Public conversations are viewable by anyone, but not editable.
        """
        share_url = f"/share/chat/{public_conversation.slug}/"
        return share_url
|
|
602
|
+
|
|
603
|
+
|
|
604
|
+
class DataStoreAdapters:
    """Persistence helpers for keyed dict payloads stored in ``DataStore``."""

    @staticmethod
    async def astore_data(data: dict, key: str, user: KhojUser, private: bool = True):
        """Store ``data`` under ``key``, owned by ``user``, unless an entry with ``key`` already exists.

        Always returns ``key``. An existing entry is left untouched, and its owner is not checked.
        """
        key_taken = await DataStore.objects.filter(key=key).aexists()
        if not key_taken:
            await DataStore.objects.acreate(value=data, key=key, owner=user, private=private)
        return key

    @staticmethod
    async def aretrieve_public_data(key: str):
        """Return the non-private ``DataStore`` entry for ``key``, or ``None``."""
        public_entries = DataStore.objects.filter(private=False, key=key)
        return await public_entries.afirst()
|
|
611
|
+
|
|
612
|
+
@staticmethod
|
|
613
|
+
async def aretrieve_public_data(key: str):
|
|
614
|
+
return await DataStore.objects.filter(key=key, private=False).afirst()
|
|
615
|
+
|
|
616
|
+
|
|
617
|
+
class ConversationAdapters:
    """Query and persistence helpers for conversations and for the chat, voice, speech-to-text and
    text-to-image model settings used to serve them."""

    @staticmethod
    def make_public_conversation_copy(conversation: Conversation):
        """Create and return a ``PublicConversation`` snapshot of ``conversation``.

        Copies the owner (as ``source_owner``), agent, conversation log, slug and title.
        """
        return PublicConversation.objects.create(
            source_owner=conversation.user,
            agent=conversation.agent,
            conversation_log=conversation.conversation_log,
            slug=conversation.slug,
            title=conversation.title,
        )

    @staticmethod
    def get_conversation_by_user(
        user: KhojUser, client_application: ClientApplication = None, conversation_id: int = None
    ) -> Optional[Conversation]:
        """Return a conversation of ``user`` on ``client_application``.

        With ``conversation_id``, returns that conversation, or ``None`` if it does not belong to
        ``user`` on ``client_application``. Without it, returns the most recently updated
        conversation, creating one with the default agent when the user has none.
        """
        if conversation_id:
            return (
                Conversation.objects.filter(user=user, client=client_application, id=conversation_id)
                .order_by("-updated_at")
                .first()
            )

        agent = AgentAdapters.get_default_agent()
        return (
            Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").first()
        ) or Conversation.objects.create(user=user, client=client_application, agent=agent)

    @staticmethod
    def get_conversation_sessions(user: KhojUser, client_application: ClientApplication = None):
        """Return a queryset of ``user``'s conversations on ``client_application``, most recently updated first."""
        return Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at")

    @staticmethod
    async def aset_conversation_title(
        user: KhojUser, client_application: ClientApplication, conversation_id: int, title: str
    ):
        """Set the title of ``user``'s conversation ``conversation_id``.

        Returns the updated conversation, or ``None`` if it was not found.
        """
        conversation = await Conversation.objects.filter(
            user=user, client=client_application, id=conversation_id
        ).afirst()
        if conversation:
            conversation.title = title
            await conversation.asave()
            return conversation
        return None

    @staticmethod
    def get_conversation_by_id(conversation_id: int):
        """Return the conversation with id ``conversation_id``, regardless of owner, or ``None``."""
        return Conversation.objects.filter(id=conversation_id).first()

    @staticmethod
    async def acreate_conversation_session(
        user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None
    ):
        """Create and return a new conversation for ``user``.

        Uses the agent identified by ``agent_slug`` (which must be public or created by ``user``),
        or the default agent when no slug is given.

        Raises:
            HTTPException: 400 if ``agent_slug`` does not match an accessible agent.
        """
        if agent_slug:
            agent = await AgentAdapters.aget_agent_by_slug(agent_slug, user)
            if agent is None:
                raise HTTPException(status_code=400, detail="No such agent currently exists.")
            return await Conversation.objects.acreate(user=user, client=client_application, agent=agent)
        agent = await AgentAdapters.aget_default_agent()
        return await Conversation.objects.acreate(user=user, client=client_application, agent=agent)

    @staticmethod
    async def aget_conversation_by_user(
        user: KhojUser, client_application: ClientApplication = None, conversation_id: int = None, title: str = None
    ) -> Optional[Conversation]:
        """Async lookup of a conversation of ``user`` on ``client_application``.

        Lookup order:
        - by ``conversation_id``, returning ``None`` when not found;
        - else by exact ``title``, returning ``None`` when not found;
        - else the most recently updated conversation, creating one with the default agent when
          the user has none.
        """
        if conversation_id:
            return await Conversation.objects.filter(user=user, client=client_application, id=conversation_id).afirst()
        elif title:
            return await Conversation.objects.filter(user=user, client=client_application, title=title).afirst()

        # A single query replaces the previous aexists() + afirst() (+ afirst() again) sequence.
        conversation = (
            await Conversation.objects.filter(user=user, client=client_application)
            .order_by("-updated_at")
            .prefetch_related("agent")
            .afirst()
        )
        if conversation:
            return conversation

        # As in get_conversation_by_user and acreate_conversation_session, new conversations get the
        # default agent instead of leaving it unset.
        agent = await AgentAdapters.aget_default_agent()
        return await Conversation.objects.acreate(user=user, client=client_application, agent=agent)

    @staticmethod
    async def adelete_conversation_by_user(
        user: KhojUser, client_application: ClientApplication = None, conversation_id: int = None
    ):
        """Delete one of ``user``'s conversations, or all of them, on ``client_application``.

        Deletes conversation ``conversation_id`` when an id is given, otherwise every conversation of
        ``user`` on ``client_application``. Returns Django's ``delete()`` result.
        """
        if conversation_id:
            return await Conversation.objects.filter(user=user, client=client_application, id=conversation_id).adelete()
        return await Conversation.objects.filter(user=user, client=client_application).adelete()

    @staticmethod
    def has_any_conversation_config(user: KhojUser):
        """Return whether any ``ChatModelOptions`` row is associated with ``user``."""
        return ChatModelOptions.objects.filter(user=user).exists()

    @staticmethod
    def get_openai_conversation_config():
        """Return the first ``OpenAIProcessorConversationConfig``, or ``None``."""
        return OpenAIProcessorConversationConfig.objects.filter().first()

    @staticmethod
    def has_valid_openai_conversation_config():
        """Return whether any ``OpenAIProcessorConversationConfig`` exists."""
        return OpenAIProcessorConversationConfig.objects.filter().exists()

    @staticmethod
    async def aset_user_conversation_processor(user: KhojUser, conversation_processor_config_id: int):
        """Select chat model ``conversation_processor_config_id`` for ``user``.

        Returns the created or updated ``UserConversationConfig``, or ``None`` if no chat model has
        that id.
        """
        config = await ChatModelOptions.objects.filter(id=conversation_processor_config_id).afirst()
        if not config:
            return None
        # aupdate_or_create returns (object, created). Return only the object, consistent with the
        # other aset_user_* helpers.
        new_config, _ = await UserConversationConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
        return new_config

    @staticmethod
    async def aset_user_voice_model(user: KhojUser, model_id: str):
        """Select the voice model with ``model_id`` for ``user``.

        Returns the created or updated ``UserVoiceModelConfig``, or ``None`` if no voice model has
        that ``model_id``.
        """
        config = await VoiceModelOption.objects.filter(model_id=model_id).afirst()
        if not config:
            return None
        # aupdate_or_create returns (object, created); return only the object.
        new_config, _ = await UserVoiceModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
        return new_config

    @staticmethod
    def get_conversation_config(user: KhojUser):
        """Return the chat model ``user`` has selected, or ``None``."""
        config = UserConversationConfig.objects.filter(user=user).first()
        if not config:
            return None
        return config.setting

    @staticmethod
    async def aget_conversation_config(user: KhojUser):
        """Async variant of ``get_conversation_config`` (prefetches the selected setting)."""
        config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst()
        if not config:
            return None
        return config.setting

    @staticmethod
    async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]:
        """Return the voice model ``user`` has selected, or ``None``."""
        voice_model_config = await UserVoiceModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
        if voice_model_config:
            return voice_model_config.setting
        return None

    @staticmethod
    def get_voice_model_options():
        """Return a queryset of all available voice models."""
        return VoiceModelOption.objects.all()

    @staticmethod
    def get_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]:
        """Sync variant of ``aget_voice_model_config``."""
        voice_model_config = UserVoiceModelConfig.objects.filter(user=user).prefetch_related("setting").first()
        if voice_model_config:
            return voice_model_config.setting
        return None

    @staticmethod
    def get_default_conversation_config():
        """Return the server's default chat model.

        Uses ``ServerChatSettings.default_model`` when set, otherwise the first ``ChatModelOptions``
        row (or ``None`` if there is none).
        """
        server_chat_settings = ServerChatSettings.objects.first()
        if server_chat_settings is None or server_chat_settings.default_model is None:
            return ChatModelOptions.objects.filter().first()
        return server_chat_settings.default_model

    @staticmethod
    async def aget_default_conversation_config():
        """Async variant of ``get_default_conversation_config`` (prefetches the model's OpenAI config)."""
        server_chat_settings: ServerChatSettings = (
            await ServerChatSettings.objects.filter()
            .prefetch_related("default_model", "default_model__openai_config")
            .afirst()
        )
        if server_chat_settings is None or server_chat_settings.default_model is None:
            return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst()
        return server_chat_settings.default_model

    @staticmethod
    async def aget_summarizer_conversation_config():
        """Return the chat model to use for summarization.

        Preference order: ``ServerChatSettings.summarizer_model``, then
        ``ServerChatSettings.default_model``, then the first ``ChatModelOptions`` row.
        """
        server_chat_settings: ServerChatSettings = (
            await ServerChatSettings.objects.filter()
            .prefetch_related(
                "summarizer_model", "default_model", "default_model__openai_config", "summarizer_model__openai_config"
            )
            .afirst()
        )
        if server_chat_settings is None or (
            server_chat_settings.summarizer_model is None and server_chat_settings.default_model is None
        ):
            return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst()
        return server_chat_settings.summarizer_model or server_chat_settings.default_model

    @staticmethod
    def create_conversation_from_public_conversation(
        user: KhojUser, public_conversation: PublicConversation, client_app: ClientApplication
    ):
        """Create and return a conversation for ``user`` on ``client_app`` copied from ``public_conversation``.

        Copies the conversation log, slug, title and agent.
        """
        return Conversation.objects.create(
            user=user,
            conversation_log=public_conversation.conversation_log,
            client=client_app,
            slug=public_conversation.slug,
            title=public_conversation.title,
            agent=public_conversation.agent,
        )

    @staticmethod
    def save_conversation(
        user: KhojUser,
        conversation_log: dict,
        client_application: ClientApplication = None,
        conversation_id: int = None,
        user_message: str = None,
    ):
        """Persist ``conversation_log`` on one of ``user``'s conversations.

        Targets conversation ``conversation_id``, or the most recently updated conversation when no id
        is given. Also sets the slug to the first 200 characters of the stripped ``user_message`` (or
        ``None``) and bumps ``updated_at``. When no matching conversation is found, a new one is
        created with the log and slug.
        """
        slug = user_message.strip()[:200] if user_message else None
        if conversation_id:
            conversation = Conversation.objects.filter(user=user, client=client_application, id=conversation_id).first()
        else:
            conversation = (
                Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").first()
            )

        if conversation:
            conversation.conversation_log = conversation_log
            conversation.slug = slug
            conversation.updated_at = datetime.now(tz=timezone.utc)
            conversation.save()
        else:
            Conversation.objects.create(
                user=user, conversation_log=conversation_log, client=client_application, slug=slug
            )

    @staticmethod
    def get_conversation_processor_options():
        """Return a queryset of all configured chat models."""
        return ChatModelOptions.objects.all()

    @staticmethod
    def set_conversation_processor_config(user: KhojUser, new_config: ChatModelOptions):
        """Set ``new_config`` as ``user``'s chat model, creating their ``UserConversationConfig`` if needed."""
        user_conversation_config, _ = UserConversationConfig.objects.get_or_create(user=user)
        user_conversation_config.setting = new_config
        user_conversation_config.save()

    @staticmethod
    async def aget_user_conversation_config(user: KhojUser):
        """Return ``user``'s selected chat model (with its OpenAI config prefetched), or ``None``."""
        config = (
            await UserConversationConfig.objects.filter(user=user).prefetch_related("setting__openai_config").afirst()
        )
        if not config:
            return None
        return config.setting

    @staticmethod
    async def get_speech_to_text_config():
        """Return the first ``SpeechToTextModelOptions`` row, or ``None``. (Async despite the name.)"""
        return await SpeechToTextModelOptions.objects.filter().afirst()

    @staticmethod
    async def aget_conversation_starters(user: KhojUser, max_results=3):
        """Return up to ``max_results`` randomly sampled reflective questions for ``user``.

        Questions written for ``user`` are used when any exist. Otherwise the global questions
        (those with no user) are used. If fewer than ``max_results`` are available, all of them are
        returned.
        """
        user_questions = ReflectiveQuestion.objects.filter(user=user)
        if await user_questions.aexists():
            questions_query = user_questions
        else:
            questions_query = ReflectiveQuestion.objects.filter(user=None)

        all_questions = await sync_to_async(list)(questions_query.values_list("question", flat=True))
        if len(all_questions) < max_results:
            return all_questions

        return random.sample(all_questions, max_results)

    @staticmethod
    def get_valid_conversation_config(user: KhojUser, conversation: Conversation):
        """Return the chat model to use for ``conversation``.

        Preference order: the conversation's (non-default) agent's chat model, then ``user``'s
        selected chat model, then the server default. For offline models, this also lazily loads
        the offline chat model into ``state`` if it is not loaded yet.

        Raises:
            ValueError: if no chat model is configured, or the resolved one is neither offline nor
                an openai/anthropic model with an ``openai_config``.
        """
        agent: Agent = conversation.agent if AgentAdapters.get_default_agent() != conversation.agent else None
        if agent and agent.chat_model:
            conversation_config = agent.chat_model
        else:
            conversation_config = ConversationAdapters.get_conversation_config(user)

        if conversation_config is None:
            conversation_config = ConversationAdapters.get_default_conversation_config()

        if conversation_config is None:
            # No chat model is configured anywhere. Raise the same error type callers already handle
            # instead of failing with an AttributeError on .model_type below.
            raise ValueError("Invalid conversation config - either configure offline chat or openai chat")

        if conversation_config.model_type == "offline":
            if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
                chat_model = conversation_config.chat_model
                max_tokens = conversation_config.max_prompt_size
                state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)

            return conversation_config

        if conversation_config.model_type in ("openai", "anthropic") and conversation_config.openai_config:
            return conversation_config

        raise ValueError("Invalid conversation config - either configure offline chat or openai chat")

    @staticmethod
    async def aget_text_to_image_model_config():
        """Return the first text-to-image model config (with its OpenAI config prefetched), or ``None``."""
        return await TextToImageModelConfig.objects.filter().prefetch_related("openai_config").afirst()

    @staticmethod
    def get_text_to_image_model_config():
        """Return the first text-to-image model config, or ``None``."""
        return TextToImageModelConfig.objects.filter().first()

    @staticmethod
    def get_text_to_image_model_options():
        """Return a queryset of all text-to-image model configs."""
        return TextToImageModelConfig.objects.all()

    @staticmethod
    def get_user_text_to_image_model_config(user: KhojUser):
        """Return ``user``'s selected text-to-image model, else the server default (or ``None``)."""
        config = UserTextToImageModelConfig.objects.filter(user=user).first()
        if config:
            return config.setting
        return ConversationAdapters.get_text_to_image_model_config()

    @staticmethod
    async def aget_user_text_to_image_model(user: KhojUser) -> Optional[TextToImageModelConfig]:
        """Async variant of ``get_user_text_to_image_model_config``."""
        config = await UserTextToImageModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
        if config:
            return config.setting
        return await ConversationAdapters.aget_text_to_image_model_config()

    @staticmethod
    async def aset_user_text_to_image_model(user: KhojUser, text_to_image_model_config_id: int):
        """Select text-to-image model ``text_to_image_model_config_id`` for ``user``.

        Returns the created or updated ``UserTextToImageModelConfig``, or ``None`` if no model has
        that id.
        """
        config = await TextToImageModelConfig.objects.filter(id=text_to_image_model_config_id).afirst()
        if not config:
            return None
        new_config, _ = await UserTextToImageModelConfig.objects.aupdate_or_create(
            user=user, defaults={"setting": config}
        )
        return new_config
|
|
946
|
+
|
|
947
|
+
|
|
948
|
+
class FileObjectAdapters:
    """Sync and async helpers for ``FileObject`` rows, which store a file name and raw text per user."""

    @staticmethod
    def _user_files(user: KhojUser, **lookups):
        """Return a queryset of ``user``'s file objects, further narrowed by any field ``lookups``."""
        return FileObject.objects.filter(user=user, **lookups)

    @staticmethod
    def update_raw_text(file_object: FileObject, new_raw_text: str):
        """Overwrite ``file_object``'s raw text with ``new_raw_text`` and save it."""
        file_object.raw_text = new_raw_text
        file_object.save()

    @staticmethod
    def create_file_object(user: KhojUser, file_name: str, raw_text: str):
        """Create and return a file object for ``user``."""
        return FileObject.objects.create(user=user, file_name=file_name, raw_text=raw_text)

    @staticmethod
    def get_file_object_by_name(user: KhojUser, file_name: str):
        """Return ``user``'s first file object named ``file_name``, or ``None``."""
        return FileObjectAdapters._user_files(user, file_name=file_name).first()

    @staticmethod
    def get_all_file_objects(user: KhojUser):
        """Return a queryset of all of ``user``'s file objects."""
        return FileObjectAdapters._user_files(user).all()

    @staticmethod
    def delete_file_object_by_name(user: KhojUser, file_name: str):
        """Delete ``user``'s file objects named ``file_name``; returns Django's ``delete()`` result."""
        return FileObjectAdapters._user_files(user, file_name=file_name).delete()

    @staticmethod
    def delete_all_file_objects(user: KhojUser):
        """Delete all of ``user``'s file objects; returns Django's ``delete()`` result."""
        return FileObjectAdapters._user_files(user).delete()

    @staticmethod
    async def async_update_raw_text(file_object: FileObject, new_raw_text: str):
        """Async variant of ``update_raw_text``."""
        file_object.raw_text = new_raw_text
        await file_object.asave()

    @staticmethod
    async def async_create_file_object(user: KhojUser, file_name: str, raw_text: str):
        """Async variant of ``create_file_object``."""
        return await FileObject.objects.acreate(user=user, file_name=file_name, raw_text=raw_text)

    @staticmethod
    async def async_get_file_objects_by_name(user: KhojUser, file_name: str):
        """Return a list of all of ``user``'s file objects named ``file_name``."""
        matches = FileObjectAdapters._user_files(user, file_name=file_name)
        return await sync_to_async(list)(matches)

    @staticmethod
    async def async_get_all_file_objects(user: KhojUser):
        """Return a list of all of ``user``'s file objects."""
        every_file = FileObjectAdapters._user_files(user)
        return await sync_to_async(list)(every_file)

    @staticmethod
    async def async_delete_file_object_by_name(user: KhojUser, file_name: str):
        """Async variant of ``delete_file_object_by_name``."""
        return await FileObjectAdapters._user_files(user, file_name=file_name).adelete()

    @staticmethod
    async def async_delete_all_file_objects(user: KhojUser):
        """Async variant of ``delete_all_file_objects``."""
        return await FileObjectAdapters._user_files(user).adelete()
|
|
998
|
+
|
|
999
|
+
|
|
1000
|
+
class EntryAdapters:
    """Query, deletion and search helpers for a user's indexed ``Entry`` rows."""

    # NOTE: "word_filer" (sic) is kept as-is because other code may reference this attribute name.
    word_filer = WordFilter()
    file_filter = FileFilter()
    date_filter = DateFilter()

    @staticmethod
    def does_entry_exist(user: KhojUser, hashed_value: str) -> bool:
        """Return whether ``user`` has an entry with the given ``hashed_value``."""
        return Entry.objects.filter(user=user, hashed_value=hashed_value).exists()

    @staticmethod
    def delete_entry_by_file(user: KhojUser, file_path: str):
        """Delete ``user``'s entries for ``file_path``; returns the object count reported by ``delete()``."""
        deleted_count, _ = Entry.objects.filter(user=user, file_path=file_path).delete()
        return deleted_count

    @staticmethod
    def get_filtered_entries(user: KhojUser, file_type: str = None, file_source: str = None):
        """Return a queryset of ``user``'s entries, optionally restricted to ``file_type`` and/or ``file_source``."""
        queryset = Entry.objects.filter(user=user)

        if file_type is not None:
            queryset = queryset.filter(file_type=file_type)

        if file_source is not None:
            queryset = queryset.filter(file_source=file_source)

        return queryset

    @staticmethod
    def delete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000):
        """Delete ``user``'s entries matching the optional filters, ``batch_size`` ids at a time.

        Returns the total object count reported by the batched ``delete()`` calls.

        Raises:
            ValueError: if ``batch_size`` is less than 1 (which would otherwise loop forever).
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        deleted_count = 0
        queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source)
        while queryset.exists():
            batch_ids = list(queryset.values_list("id", flat=True)[:batch_size])
            batch = Entry.objects.filter(id__in=batch_ids, user=user)
            count, _ = batch.delete()
            deleted_count += count
        return deleted_count

    @staticmethod
    async def adelete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000):
        """Async variant of ``delete_all_entries``.

        Raises:
            ValueError: if ``batch_size`` is less than 1 (which would otherwise loop forever).
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        deleted_count = 0
        queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source)
        while await queryset.aexists():
            batch_ids = await sync_to_async(list)(queryset.values_list("id", flat=True)[:batch_size])
            batch = Entry.objects.filter(id__in=batch_ids, user=user)
            count, _ = await batch.adelete()
            deleted_count += count
        return deleted_count

    @staticmethod
    def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str):
        """Return a lazy flat ``values_list`` of the ``hashed_value`` of ``user``'s entries for ``file_path``."""
        return Entry.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True)

    @staticmethod
    def delete_entry_by_hash(user: KhojUser, hashed_values: List[str]):
        """Delete ``user``'s entries whose ``hashed_value`` is in ``hashed_values``."""
        Entry.objects.filter(user=user, hashed_value__in=hashed_values).delete()

    @staticmethod
    def get_entries_by_date_filter(entry: BaseManager[Entry], start_date: date, end_date: date):
        """Narrow ``entry`` to entries with an associated date in [``start_date``, ``end_date``]."""
        # NOTE(review): this uses the "entrydates" lookup while apply_filters uses "embeddings_dates".
        # Likely only one of them matches the date relation's query name on Entry — confirm which.
        return entry.filter(
            entrydates__date__gte=start_date,
            entrydates__date__lte=end_date,
        )

    @staticmethod
    def user_has_entries(user: KhojUser):
        """Return whether ``user`` has any indexed entries."""
        return Entry.objects.filter(user=user).exists()

    @staticmethod
    async def auser_has_entries(user: KhojUser):
        """Async variant of ``user_has_entries``."""
        return await Entry.objects.filter(user=user).aexists()

    @staticmethod
    async def adelete_entry_by_file(user: KhojUser, file_path: str):
        """Delete ``user``'s entries for ``file_path``; returns Django's ``(count, per_model)`` ``delete()`` result."""
        return await Entry.objects.filter(user=user, file_path=file_path).adelete()

    @staticmethod
    def get_all_filenames_by_source(user: KhojUser, file_source: str):
        """Return the distinct ``file_path`` values of ``user``'s entries from ``file_source``."""
        return (
            Entry.objects.filter(user=user, file_source=file_source)
            .distinct("file_path")
            .values_list("file_path", flat=True)
        )

    @staticmethod
    def get_size_of_indexed_data_in_mb(user: KhojUser):
        """Return the summed in-memory size (``sys.getsizeof``) of ``user``'s compiled entry texts, in MiB."""
        # Fetch only the ``compiled`` column. Iterating full model instances would also load every
        # entry's embeddings, which are not needed here.
        compiled_texts = Entry.objects.filter(user=user).values_list("compiled", flat=True).iterator()
        total_size = sum(sys.getsizeof(compiled) for compiled in compiled_texts)
        return total_size / 1024 / 1024

    @staticmethod
    def apply_filters(user: KhojUser, query: str, file_type_filter: str = None):
        """Return ``user``'s entries narrowed by the filters embedded in ``query`` and by ``file_type_filter``.

        Filters parsed from ``query``:
        - word terms: ``+term`` requires, and ``-term`` excludes, entries whose ``raw`` text contains
          ``term`` (case-insensitive);
        - file terms: keep entries whose ``file_path`` matches any of the regexes;
        - date range: keep entries with an associated date within the parsed [min, max] bounds.
        ``file_type_filter``, when given, restricts results to that ``file_type`` on every path.
        """
        relevant_entries = Entry.objects.filter(user=user)
        # Apply the file type restriction up front so it also holds when the query has no filter terms.
        if file_type_filter:
            relevant_entries = relevant_entries.filter(file_type=file_type_filter)

        explicit_word_terms = EntryAdapters.word_filer.get_filter_terms(query)
        file_filters = EntryAdapters.file_filter.get_filter_terms(query)
        date_filters = EntryAdapters.date_filter.get_query_date_range(query)

        if not explicit_word_terms and not file_filters and not date_filters:
            return relevant_entries

        q_filter_terms = Q()
        for term in explicit_word_terms:
            if term.startswith("+"):
                q_filter_terms &= Q(raw__icontains=term[1:])
            elif term.startswith("-"):
                q_filter_terms &= ~Q(raw__icontains=term[1:])

        if file_filters:
            # An entry passes the file filter if its path matches any one of the file terms.
            q_file_filter_terms = Q()
            for term in file_filters:
                q_file_filter_terms |= Q(file_path__regex=term)
            q_filter_terms &= q_file_filter_terms

        if date_filters:
            # min_date/max_date are presumably POSIX timestamps (they are passed to date.fromtimestamp).
            min_date, max_date = date_filters
            if min_date is not None:
                # Convert the min_date timestamp to yyyy-mm-dd format
                formatted_min_date = date.fromtimestamp(min_date).strftime("%Y-%m-%d")
                q_filter_terms &= Q(embeddings_dates__date__gte=formatted_min_date)
            if max_date is not None:
                # Convert the max_date timestamp to yyyy-mm-dd format
                formatted_max_date = date.fromtimestamp(max_date).strftime("%Y-%m-%d")
                q_filter_terms &= Q(embeddings_dates__date__lte=formatted_max_date)

        return relevant_entries.filter(q_filter_terms)

    @staticmethod
    def search_with_embeddings(
        user: KhojUser,
        embeddings: Tensor,
        max_results: int = 10,
        file_type_filter: str = None,
        raw_query: str = None,
        max_distance: float = math.inf,
    ):
        """Return up to ``max_results`` of ``user``'s entries nearest to ``embeddings`` by cosine distance.

        Candidates are first narrowed with ``apply_filters`` (query filters and ``file_type_filter``).
        They are then limited to ``distance <= max_distance`` and ordered nearest first. Each result
        carries its ``distance`` annotation.
        """
        # apply_filters already scopes the queryset to ``user`` and ``file_type_filter`` on every path.
        relevant_entries = EntryAdapters.apply_filters(user, raw_query, file_type_filter)
        relevant_entries = relevant_entries.annotate(distance=CosineDistance("embeddings", embeddings))
        relevant_entries = relevant_entries.filter(distance__lte=max_distance).order_by("distance")
        return relevant_entries[:max_results]

    @staticmethod
    def get_unique_file_types(user: KhojUser):
        """Return the distinct ``file_type`` values among ``user``'s entries."""
        return Entry.objects.filter(user=user).values_list("file_type", flat=True).distinct()

    @staticmethod
    def get_unique_file_sources(user: KhojUser):
        """Return the distinct ``file_source`` values among ``user``'s entries."""
        return Entry.objects.filter(user=user).values_list("file_source", flat=True).distinct().all()
|
|
1159
|
+
|
|
1160
|
+
|
|
1161
|
+
class AutomationAdapters:
|
|
1162
|
+
@staticmethod
|
|
1163
|
+
def get_automations(user: KhojUser) -> Iterable[Job]:
|
|
1164
|
+
all_automations: Iterable[Job] = state.scheduler.get_jobs()
|
|
1165
|
+
for automation in all_automations:
|
|
1166
|
+
if automation.id.startswith(f"automation_{user.uuid}_"):
|
|
1167
|
+
yield automation
|
|
1168
|
+
|
|
1169
|
+
@staticmethod
|
|
1170
|
+
def get_automation_metadata(user: KhojUser, automation: Job):
|
|
1171
|
+
# Perform validation checks
|
|
1172
|
+
# Check if user is allowed to delete this automation id
|
|
1173
|
+
if not automation.id.startswith(f"automation_{user.uuid}_"):
|
|
1174
|
+
raise ValueError("Invalid automation id")
|
|
1175
|
+
|
|
1176
|
+
automation_metadata = json.loads(automation.name)
|
|
1177
|
+
crontime = automation_metadata["crontime"]
|
|
1178
|
+
timezone = automation.next_run_time.strftime("%Z")
|
|
1179
|
+
schedule = f"{cron_descriptor.get_description(crontime)} {timezone}"
|
|
1180
|
+
return {
|
|
1181
|
+
"id": automation.id,
|
|
1182
|
+
"subject": automation_metadata["subject"],
|
|
1183
|
+
"query_to_run": re.sub(r"^/automated_task\s*", "", automation_metadata["query_to_run"]),
|
|
1184
|
+
"scheduling_request": automation_metadata["scheduling_request"],
|
|
1185
|
+
"schedule": schedule,
|
|
1186
|
+
"crontime": crontime,
|
|
1187
|
+
"next": automation.next_run_time.strftime("%Y-%m-%d %I:%M %p %Z"),
|
|
1188
|
+
}
|
|
1189
|
+
|
|
1190
|
+
@staticmethod
|
|
1191
|
+
def get_job_last_run(user: KhojUser, automation: Job):
|
|
1192
|
+
# Perform validation checks
|
|
1193
|
+
# Check if user is allowed to delete this automation id
|
|
1194
|
+
if not automation.id.startswith(f"automation_{user.uuid}_"):
|
|
1195
|
+
raise ValueError("Invalid automation id")
|
|
1196
|
+
|
|
1197
|
+
django_job = DjangoJob.objects.filter(id=automation.id).first()
|
|
1198
|
+
execution = DjangoJobExecution.objects.filter(job=django_job, status="Executed")
|
|
1199
|
+
|
|
1200
|
+
last_run_time = None
|
|
1201
|
+
|
|
1202
|
+
if execution.exists():
|
|
1203
|
+
last_run_time = execution.latest("run_time").run_time
|
|
1204
|
+
|
|
1205
|
+
return last_run_time.strftime("%Y-%m-%d %I:%M %p %Z") if last_run_time else None
|
|
1206
|
+
|
|
1207
|
+
@staticmethod
|
|
1208
|
+
def get_automations_metadata(user: KhojUser):
|
|
1209
|
+
for automation in AutomationAdapters.get_automations(user):
|
|
1210
|
+
yield AutomationAdapters.get_automation_metadata(user, automation)
|
|
1211
|
+
|
|
1212
|
+
@staticmethod
|
|
1213
|
+
def get_automation(user: KhojUser, automation_id: str) -> Job:
|
|
1214
|
+
# Perform validation checks
|
|
1215
|
+
# Check if user is allowed to delete this automation id
|
|
1216
|
+
if not automation_id.startswith(f"automation_{user.uuid}_"):
|
|
1217
|
+
raise ValueError("Invalid automation id")
|
|
1218
|
+
# Check if automation with this id exist
|
|
1219
|
+
automation: Job = state.scheduler.get_job(job_id=automation_id)
|
|
1220
|
+
if not automation:
|
|
1221
|
+
raise ValueError("Invalid automation id")
|
|
1222
|
+
|
|
1223
|
+
return automation
|
|
1224
|
+
|
|
1225
|
+
@staticmethod
|
|
1226
|
+
def delete_automation(user: KhojUser, automation_id: str):
|
|
1227
|
+
# Get valid, user-owned automation
|
|
1228
|
+
automation: Job = AutomationAdapters.get_automation(user, automation_id)
|
|
1229
|
+
|
|
1230
|
+
# Collate info about user automation to be deleted
|
|
1231
|
+
automation_metadata = AutomationAdapters.get_automation_metadata(user, automation)
|
|
1232
|
+
|
|
1233
|
+
automation.remove()
|
|
1234
|
+
return automation_metadata
|