khoj 1.16.1.dev15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- khoj/__init__.py +0 -0
- khoj/app/README.md +94 -0
- khoj/app/__init__.py +0 -0
- khoj/app/asgi.py +16 -0
- khoj/app/settings.py +192 -0
- khoj/app/urls.py +25 -0
- khoj/configure.py +424 -0
- khoj/database/__init__.py +0 -0
- khoj/database/adapters/__init__.py +1234 -0
- khoj/database/admin.py +290 -0
- khoj/database/apps.py +6 -0
- khoj/database/management/__init__.py +0 -0
- khoj/database/management/commands/__init__.py +0 -0
- khoj/database/management/commands/change_generated_images_url.py +61 -0
- khoj/database/management/commands/convert_images_png_to_webp.py +99 -0
- khoj/database/migrations/0001_khojuser.py +98 -0
- khoj/database/migrations/0002_googleuser.py +32 -0
- khoj/database/migrations/0003_vector_extension.py +10 -0
- khoj/database/migrations/0004_content_types_and_more.py +181 -0
- khoj/database/migrations/0005_embeddings_corpus_id.py +19 -0
- khoj/database/migrations/0006_embeddingsdates.py +33 -0
- khoj/database/migrations/0007_add_conversation.py +27 -0
- khoj/database/migrations/0008_alter_conversation_conversation_log.py +17 -0
- khoj/database/migrations/0009_khojapiuser.py +24 -0
- khoj/database/migrations/0010_chatmodeloptions_and_more.py +83 -0
- khoj/database/migrations/0010_rename_embeddings_entry_and_more.py +30 -0
- khoj/database/migrations/0011_merge_20231102_0138.py +14 -0
- khoj/database/migrations/0012_entry_file_source.py +21 -0
- khoj/database/migrations/0013_subscription.py +37 -0
- khoj/database/migrations/0014_alter_googleuser_picture.py +17 -0
- khoj/database/migrations/0015_alter_subscription_user.py +21 -0
- khoj/database/migrations/0016_alter_subscription_renewal_date.py +17 -0
- khoj/database/migrations/0017_searchmodel.py +32 -0
- khoj/database/migrations/0018_searchmodelconfig_delete_searchmodel.py +30 -0
- khoj/database/migrations/0019_alter_googleuser_family_name_and_more.py +27 -0
- khoj/database/migrations/0020_reflectivequestion.py +36 -0
- khoj/database/migrations/0021_speechtotextmodeloptions_and_more.py +42 -0
- khoj/database/migrations/0022_texttoimagemodelconfig.py +25 -0
- khoj/database/migrations/0023_usersearchmodelconfig.py +33 -0
- khoj/database/migrations/0024_alter_entry_embeddings.py +18 -0
- khoj/database/migrations/0025_clientapplication_khojuser_phone_number_and_more.py +46 -0
- khoj/database/migrations/0025_searchmodelconfig_embeddings_inference_endpoint_and_more.py +22 -0
- khoj/database/migrations/0026_searchmodelconfig_cross_encoder_inference_endpoint_and_more.py +22 -0
- khoj/database/migrations/0027_merge_20240118_1324.py +13 -0
- khoj/database/migrations/0028_khojuser_verified_phone_number.py +17 -0
- khoj/database/migrations/0029_userrequests.py +27 -0
- khoj/database/migrations/0030_conversation_slug_and_title.py +38 -0
- khoj/database/migrations/0031_agent_conversation_agent.py +53 -0
- khoj/database/migrations/0031_alter_googleuser_locale.py +30 -0
- khoj/database/migrations/0032_merge_20240322_0427.py +14 -0
- khoj/database/migrations/0033_rename_tuning_agent_personality.py +17 -0
- khoj/database/migrations/0034_alter_chatmodeloptions_chat_model.py +32 -0
- khoj/database/migrations/0035_processlock.py +26 -0
- khoj/database/migrations/0036_alter_processlock_name.py +19 -0
- khoj/database/migrations/0036_delete_offlinechatprocessorconversationconfig.py +15 -0
- khoj/database/migrations/0036_publicconversation.py +42 -0
- khoj/database/migrations/0037_chatmodeloptions_openai_config_and_more.py +51 -0
- khoj/database/migrations/0037_searchmodelconfig_bi_encoder_docs_encode_config_and_more.py +32 -0
- khoj/database/migrations/0038_merge_20240425_0857.py +14 -0
- khoj/database/migrations/0038_merge_20240426_1640.py +12 -0
- khoj/database/migrations/0039_merge_20240501_0301.py +12 -0
- khoj/database/migrations/0040_alter_processlock_name.py +26 -0
- khoj/database/migrations/0040_merge_20240504_1010.py +14 -0
- khoj/database/migrations/0041_merge_20240505_1234.py +14 -0
- khoj/database/migrations/0042_serverchatsettings.py +46 -0
- khoj/database/migrations/0043_alter_chatmodeloptions_model_type.py +21 -0
- khoj/database/migrations/0044_conversation_file_filters.py +17 -0
- khoj/database/migrations/0045_fileobject.py +37 -0
- khoj/database/migrations/0046_khojuser_email_verification_code_and_more.py +22 -0
- khoj/database/migrations/0047_alter_entry_file_type.py +31 -0
- khoj/database/migrations/0048_voicemodeloption_uservoicemodelconfig.py +52 -0
- khoj/database/migrations/0049_datastore.py +38 -0
- khoj/database/migrations/0049_texttoimagemodelconfig_api_key_and_more.py +58 -0
- khoj/database/migrations/0050_alter_processlock_name.py +25 -0
- khoj/database/migrations/0051_merge_20240702_1220.py +14 -0
- khoj/database/migrations/0052_alter_searchmodelconfig_bi_encoder_docs_encode_config_and_more.py +27 -0
- khoj/database/migrations/__init__.py +0 -0
- khoj/database/models/__init__.py +402 -0
- khoj/database/tests.py +3 -0
- khoj/interface/email/feedback.html +34 -0
- khoj/interface/email/magic_link.html +17 -0
- khoj/interface/email/task.html +40 -0
- khoj/interface/email/welcome.html +61 -0
- khoj/interface/web/404.html +56 -0
- khoj/interface/web/agent.html +312 -0
- khoj/interface/web/agents.html +276 -0
- khoj/interface/web/assets/icons/agents.svg +6 -0
- khoj/interface/web/assets/icons/automation.svg +37 -0
- khoj/interface/web/assets/icons/cancel.svg +3 -0
- khoj/interface/web/assets/icons/chat.svg +24 -0
- khoj/interface/web/assets/icons/collapse.svg +17 -0
- khoj/interface/web/assets/icons/computer.png +0 -0
- khoj/interface/web/assets/icons/confirm-icon.svg +1 -0
- khoj/interface/web/assets/icons/copy-button-success.svg +6 -0
- khoj/interface/web/assets/icons/copy-button.svg +5 -0
- khoj/interface/web/assets/icons/credit-card.png +0 -0
- khoj/interface/web/assets/icons/delete.svg +26 -0
- khoj/interface/web/assets/icons/docx.svg +7 -0
- khoj/interface/web/assets/icons/edit.svg +4 -0
- khoj/interface/web/assets/icons/favicon-128x128.ico +0 -0
- khoj/interface/web/assets/icons/favicon-128x128.png +0 -0
- khoj/interface/web/assets/icons/favicon-256x256.png +0 -0
- khoj/interface/web/assets/icons/favicon.icns +0 -0
- khoj/interface/web/assets/icons/github.svg +1 -0
- khoj/interface/web/assets/icons/key.svg +4 -0
- khoj/interface/web/assets/icons/khoj-logo-sideways-200.png +0 -0
- khoj/interface/web/assets/icons/khoj-logo-sideways-500.png +0 -0
- khoj/interface/web/assets/icons/khoj-logo-sideways.svg +5385 -0
- khoj/interface/web/assets/icons/logotype.svg +1 -0
- khoj/interface/web/assets/icons/markdown.svg +1 -0
- khoj/interface/web/assets/icons/new.svg +23 -0
- khoj/interface/web/assets/icons/notion.svg +4 -0
- khoj/interface/web/assets/icons/openai-logomark.svg +1 -0
- khoj/interface/web/assets/icons/org.svg +1 -0
- khoj/interface/web/assets/icons/pdf.svg +23 -0
- khoj/interface/web/assets/icons/pencil-edit.svg +5 -0
- khoj/interface/web/assets/icons/plaintext.svg +1 -0
- khoj/interface/web/assets/icons/question-mark-icon.svg +1 -0
- khoj/interface/web/assets/icons/search.svg +25 -0
- khoj/interface/web/assets/icons/send.svg +1 -0
- khoj/interface/web/assets/icons/share.svg +8 -0
- khoj/interface/web/assets/icons/speaker.svg +4 -0
- khoj/interface/web/assets/icons/stop-solid.svg +37 -0
- khoj/interface/web/assets/icons/sync.svg +4 -0
- khoj/interface/web/assets/icons/thumbs-down-svgrepo-com.svg +6 -0
- khoj/interface/web/assets/icons/thumbs-up-svgrepo-com.svg +6 -0
- khoj/interface/web/assets/icons/user-silhouette.svg +4 -0
- khoj/interface/web/assets/icons/voice.svg +8 -0
- khoj/interface/web/assets/icons/web.svg +2 -0
- khoj/interface/web/assets/icons/whatsapp.svg +17 -0
- khoj/interface/web/assets/khoj.css +237 -0
- khoj/interface/web/assets/markdown-it.min.js +8476 -0
- khoj/interface/web/assets/natural-cron.min.js +1 -0
- khoj/interface/web/assets/org.min.js +1823 -0
- khoj/interface/web/assets/pico.min.css +5 -0
- khoj/interface/web/assets/purify.min.js +3 -0
- khoj/interface/web/assets/samples/desktop-browse-draw-sample.png +0 -0
- khoj/interface/web/assets/samples/desktop-plain-chat-sample.png +0 -0
- khoj/interface/web/assets/samples/desktop-remember-plan-sample.png +0 -0
- khoj/interface/web/assets/samples/phone-browse-draw-sample.png +0 -0
- khoj/interface/web/assets/samples/phone-plain-chat-sample.png +0 -0
- khoj/interface/web/assets/samples/phone-remember-plan-sample.png +0 -0
- khoj/interface/web/assets/utils.js +33 -0
- khoj/interface/web/base_config.html +445 -0
- khoj/interface/web/chat.html +3546 -0
- khoj/interface/web/config.html +1011 -0
- khoj/interface/web/config_automation.html +1103 -0
- khoj/interface/web/content_source_computer_input.html +139 -0
- khoj/interface/web/content_source_github_input.html +216 -0
- khoj/interface/web/content_source_notion_input.html +94 -0
- khoj/interface/web/khoj.webmanifest +51 -0
- khoj/interface/web/login.html +219 -0
- khoj/interface/web/public_conversation.html +2006 -0
- khoj/interface/web/search.html +470 -0
- khoj/interface/web/utils.html +48 -0
- khoj/main.py +241 -0
- khoj/manage.py +22 -0
- khoj/migrations/__init__.py +0 -0
- khoj/migrations/migrate_offline_chat_default_model.py +69 -0
- khoj/migrations/migrate_offline_chat_default_model_2.py +71 -0
- khoj/migrations/migrate_offline_chat_schema.py +83 -0
- khoj/migrations/migrate_offline_model.py +29 -0
- khoj/migrations/migrate_processor_config_openai.py +67 -0
- khoj/migrations/migrate_server_pg.py +138 -0
- khoj/migrations/migrate_version.py +17 -0
- khoj/processor/__init__.py +0 -0
- khoj/processor/content/__init__.py +0 -0
- khoj/processor/content/docx/__init__.py +0 -0
- khoj/processor/content/docx/docx_to_entries.py +110 -0
- khoj/processor/content/github/__init__.py +0 -0
- khoj/processor/content/github/github_to_entries.py +224 -0
- khoj/processor/content/images/__init__.py +0 -0
- khoj/processor/content/images/image_to_entries.py +118 -0
- khoj/processor/content/markdown/__init__.py +0 -0
- khoj/processor/content/markdown/markdown_to_entries.py +165 -0
- khoj/processor/content/notion/notion_to_entries.py +260 -0
- khoj/processor/content/org_mode/__init__.py +0 -0
- khoj/processor/content/org_mode/org_to_entries.py +231 -0
- khoj/processor/content/org_mode/orgnode.py +532 -0
- khoj/processor/content/pdf/__init__.py +0 -0
- khoj/processor/content/pdf/pdf_to_entries.py +116 -0
- khoj/processor/content/plaintext/__init__.py +0 -0
- khoj/processor/content/plaintext/plaintext_to_entries.py +122 -0
- khoj/processor/content/text_to_entries.py +297 -0
- khoj/processor/conversation/__init__.py +0 -0
- khoj/processor/conversation/anthropic/__init__.py +0 -0
- khoj/processor/conversation/anthropic/anthropic_chat.py +206 -0
- khoj/processor/conversation/anthropic/utils.py +114 -0
- khoj/processor/conversation/offline/__init__.py +0 -0
- khoj/processor/conversation/offline/chat_model.py +231 -0
- khoj/processor/conversation/offline/utils.py +78 -0
- khoj/processor/conversation/offline/whisper.py +15 -0
- khoj/processor/conversation/openai/__init__.py +0 -0
- khoj/processor/conversation/openai/gpt.py +187 -0
- khoj/processor/conversation/openai/utils.py +129 -0
- khoj/processor/conversation/openai/whisper.py +13 -0
- khoj/processor/conversation/prompts.py +758 -0
- khoj/processor/conversation/utils.py +262 -0
- khoj/processor/embeddings.py +117 -0
- khoj/processor/speech/__init__.py +0 -0
- khoj/processor/speech/text_to_speech.py +51 -0
- khoj/processor/tools/__init__.py +0 -0
- khoj/processor/tools/online_search.py +225 -0
- khoj/routers/__init__.py +0 -0
- khoj/routers/api.py +626 -0
- khoj/routers/api_agents.py +43 -0
- khoj/routers/api_chat.py +1180 -0
- khoj/routers/api_config.py +434 -0
- khoj/routers/api_phone.py +86 -0
- khoj/routers/auth.py +181 -0
- khoj/routers/email.py +133 -0
- khoj/routers/helpers.py +1188 -0
- khoj/routers/indexer.py +349 -0
- khoj/routers/notion.py +91 -0
- khoj/routers/storage.py +35 -0
- khoj/routers/subscription.py +104 -0
- khoj/routers/twilio.py +36 -0
- khoj/routers/web_client.py +471 -0
- khoj/search_filter/__init__.py +0 -0
- khoj/search_filter/base_filter.py +15 -0
- khoj/search_filter/date_filter.py +217 -0
- khoj/search_filter/file_filter.py +30 -0
- khoj/search_filter/word_filter.py +29 -0
- khoj/search_type/__init__.py +0 -0
- khoj/search_type/text_search.py +241 -0
- khoj/utils/__init__.py +0 -0
- khoj/utils/cli.py +93 -0
- khoj/utils/config.py +81 -0
- khoj/utils/constants.py +24 -0
- khoj/utils/fs_syncer.py +249 -0
- khoj/utils/helpers.py +418 -0
- khoj/utils/initialization.py +146 -0
- khoj/utils/jsonl.py +43 -0
- khoj/utils/models.py +47 -0
- khoj/utils/rawconfig.py +160 -0
- khoj/utils/state.py +46 -0
- khoj/utils/yaml.py +43 -0
- khoj-1.16.1.dev15.dist-info/METADATA +178 -0
- khoj-1.16.1.dev15.dist-info/RECORD +242 -0
- khoj-1.16.1.dev15.dist-info/WHEEL +4 -0
- khoj-1.16.1.dev15.dist-info/entry_points.txt +2 -0
- khoj-1.16.1.dev15.dist-info/licenses/LICENSE +661 -0
khoj/utils/helpers.py
ADDED
|
@@ -0,0 +1,418 @@
|
|
|
1
|
+
from __future__ import annotations # to avoid quoting type hints
|
|
2
|
+
|
|
3
|
+
import datetime
|
|
4
|
+
import logging
|
|
5
|
+
import os
|
|
6
|
+
import platform
|
|
7
|
+
import random
|
|
8
|
+
import uuid
|
|
9
|
+
from collections import OrderedDict
|
|
10
|
+
from enum import Enum
|
|
11
|
+
from importlib import import_module
|
|
12
|
+
from importlib.metadata import version
|
|
13
|
+
from itertools import islice
|
|
14
|
+
from os import path
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
from time import perf_counter
|
|
17
|
+
from typing import TYPE_CHECKING, Optional, Union
|
|
18
|
+
from urllib.parse import urlparse
|
|
19
|
+
|
|
20
|
+
import psutil
|
|
21
|
+
import requests
|
|
22
|
+
import torch
|
|
23
|
+
from asgiref.sync import sync_to_async
|
|
24
|
+
from magika import Magika
|
|
25
|
+
|
|
26
|
+
from khoj.utils import constants
|
|
27
|
+
|
|
28
|
+
if TYPE_CHECKING:
|
|
29
|
+
from sentence_transformers import CrossEncoder, SentenceTransformer
|
|
30
|
+
|
|
31
|
+
from khoj.utils.models import BaseEncoder
|
|
32
|
+
from khoj.utils.rawconfig import AppConfig
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# Initialize Magika for file type identification
|
|
36
|
+
magika = Magika()
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class AsyncIteratorWrapper:
    """Adapt a synchronous iterable for consumption with `async for`.

    Each `next()` call on the wrapped iterator is run via `sync_to_async`,
    so a blocking iterator does not stall the event loop.
    """

    def __init__(self, obj):
        # Underlying synchronous iterator being adapted
        self._it = iter(obj)

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            value = await self.next_async()
        except StopAsyncIteration:
            # Re-raise so the async iteration protocol terminates.
            # (Previously this `return`ed None, which does NOT stop an
            # `async for` loop — consumers received None forever.)
            raise
        return value

    @sync_to_async
    def next_async(self):
        try:
            return next(self._it)
        except StopIteration:
            # Translate sync exhaustion into its async counterpart
            raise StopAsyncIteration
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def is_none_or_empty(item):
    """Check if item is None, an empty iterable, or an empty string.

    Note: iterables without `__len__` (e.g. generators) are not supported.
    """
    # `is None` instead of `== None`: identity check is the correct idiom
    # and avoids surprises from custom __eq__ implementations.
    return item is None or (hasattr(item, "__iter__") and len(item) == 0) or item == ""
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def to_snake_case_from_dash(item: str):
    """Replace underscores in *item* with dashes.

    NOTE(review): despite the name suggesting dash -> snake_case, this
    converts snake_case -> dash-case; callers rely on this behavior, so
    it is preserved as-is — confirm intent before renaming or inverting.
    """
    return "-".join(item.split("_"))
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def get_absolute_path(filepath: Union[str, Path]) -> str:
    """Return *filepath* as an absolute path string, with `~` expanded."""
    expanded = Path(filepath).expanduser()
    return str(expanded.absolute())
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def resolve_absolute_path(filepath: Union[str, Optional[Path]], strict=False) -> Path:
    """Expand `~`, absolutize and fully resolve *filepath* into a Path.

    With strict=True, raises if the path does not exist (per Path.resolve).
    """
    expanded = Path(filepath).expanduser()
    return expanded.absolute().resolve(strict=strict)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def get_from_dict(dictionary, *args):
    """Null-aware get from a nested dictionary.

    Returns dictionary[args[0]][args[1]]... or None if any key along the
    path is missing or an intermediate value is not a container.
    """
    current = dictionary
    for arg in args:
        # Bail out if current is not indexable or lacks the key.
        # (`arg not in current` is the idiomatic spelling of `not arg in`.)
        if not hasattr(current, "__iter__") or arg not in current:
            return None
        current = current[arg]
    return current
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def merge_dicts(priority_dict: dict, default_dict: dict):
    """Recursively merge two dicts, preferring values from priority_dict.

    Keys present only in default_dict are added; keys whose values are
    dicts in both inputs are merged recursively. Inputs are not mutated.
    """
    merged_dict = priority_dict.copy()
    # Iterate keys directly — the values were previously fetched and
    # discarded via `.items()` (Perflint PERF102).
    for key in default_dict:
        if key not in priority_dict:
            merged_dict[key] = default_dict[key]
        elif isinstance(priority_dict[key], dict) and isinstance(default_dict[key], dict):
            merged_dict[key] = merge_dicts(priority_dict[key], default_dict[key])
    return merged_dict
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def get_file_type(file_type: str, file_content: bytes) -> tuple[str, str]:
    """Map a mime type (plus content sniffing) to a Khoj content kind.

    Returns (kind, encoding) where kind is one of markdown/org/pdf/docx/
    image/plaintext/other and encoding is the charset parsed from the
    mime string, or None if absent.
    """
    # Split any "; charset=..." parameter off the mime type
    if ";" in file_type:
        encoding = file_type.split("=")[1].strip().lower()
        file_type = file_type.split(";")[0].strip()
    else:
        encoding = None

    # Sniff content group from the raw bytes; fall back to "unknown"
    # so mime-based detection below still works if sniffing fails.
    try:
        content_group = magika.identify_bytes(file_content).output.group
    except Exception:
        content_group = "unknown"

    # Known mime types map directly to a content kind
    mime_to_kind = {
        "text/markdown": "markdown",
        "text/org": "org",
        "application/pdf": "pdf",
        "application/msword": "docx",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document": "docx",
        "image/jpeg": "image",
        "image/png": "image",
    }
    if file_type in mime_to_kind:
        return mime_to_kind[file_type], encoding
    # Anything that sniffs as code or text is treated as plaintext
    if content_group in ("code", "text"):
        return "plaintext", encoding
    return "other", encoding
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def load_model(
    model_name: str, model_type, model_dir=None, device: str = None
) -> Union[BaseEncoder, SentenceTransformer, CrossEncoder]:
    """Load a model from the local disk cache or download it by name.

    model_type may be a class or a dotted "module.Class" string.
    """
    logger = logging.getLogger(__name__)
    # Location the model would be cached at, if a model_dir was given
    model_path = None if model_dir is None else path.join(model_dir, model_name.replace("/", "_"))

    # Resolve the model class from its dotted-string name when needed
    if isinstance(model_type, str):
        model_type_class = get_class_by_name(model_type)
    else:
        model_type_class = model_type

    # Prefer the on-disk copy when it exists
    if model_path is not None and resolve_absolute_path(model_path).exists():
        logger.debug(f"Loading {model_name} model from disk")
        return model_type_class(get_absolute_path(model_path), device=device)

    # Otherwise download by name, caching to disk when a dir was given
    logger.info(f"🤖 Downloading {model_name} model from web")
    model = model_type_class(model_name, device=device)
    if model_path is not None:
        logger.info(f"📩 Saved {model_name} model to disk")
        model.save(model_path)
    return model
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def get_class_by_name(name: str) -> object:
    """Import and return the class named by a dotted "module.Class" string."""
    module_path, _, class_name = name.rpartition(".")
    module = import_module(module_path)
    return getattr(module, class_name)
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
class timer:
    """Context manager that logs the wall-clock time spent in its block."""

    def __init__(self, message: str, logger: logging.Logger, device: torch.device = None):
        self.message = message
        self.logger = logger
        self.device = device

    def __enter__(self):
        # Record start timestamp on entry
        self.start = perf_counter()
        return self

    def __exit__(self, *_):
        elapsed = perf_counter() - self.start
        # Mention the device only when one was supplied
        suffix = "" if self.device is None else f" on device: {self.device}"
        self.logger.debug(f"{self.message}: {elapsed:.3f} seconds{suffix}")
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
class LRU(OrderedDict):
|
|
181
|
+
def __init__(self, *args, capacity=128, **kwargs):
|
|
182
|
+
self.capacity = capacity
|
|
183
|
+
super().__init__(*args, **kwargs)
|
|
184
|
+
|
|
185
|
+
def __getitem__(self, key):
|
|
186
|
+
value = super().__getitem__(key)
|
|
187
|
+
self.move_to_end(key)
|
|
188
|
+
return value
|
|
189
|
+
|
|
190
|
+
def __setitem__(self, key, value):
|
|
191
|
+
super().__setitem__(key, value)
|
|
192
|
+
if len(self) > self.capacity:
|
|
193
|
+
oldest = next(iter(self))
|
|
194
|
+
del self[oldest]
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def get_server_id():
    """Get, Generate Persistent, Random ID per server install.
    Helps count distinct khoj servers deployed.
    Maintains anonymity by using non-PII random id."""
    # Initialize server_id to None
    server_id = None
    # Expand path to the khoj env file. It contains persistent internal app data
    app_env_filename = path.expanduser(constants.app_env_filepath)

    # Check if the file exists
    if path.exists(app_env_filename):
        # Read the contents of the file
        with open(app_env_filename, "r") as f:
            contents = f.readlines()

        # Extract the server_id from the contents
        # NOTE(review): assumes every line is exactly "key=value" — a blank
        # line or a value containing "=" would raise ValueError on unpack.
        # Confirm this env file is only ever written by this function.
        for line in contents:
            key, value = line.strip().split("=")
            if key.strip() == "server_id":
                server_id = value.strip()
                break

        # If server_id is not found, generate and write to env file
        if server_id is None:
            # If server_id is not found, generate a new one
            server_id = str(uuid.uuid4())

            # Append (not overwrite) to preserve other keys in the env file
            with open(app_env_filename, "a") as f:
                f.write("server_id=" + server_id + "\n")
    else:
        # Env file missing entirely: generate a new id
        server_id = str(uuid.uuid4())

        # Create khoj config directory if it doesn't exist
        os.makedirs(path.dirname(app_env_filename), exist_ok=True)

        # Write the server_id to the env file
        with open(app_env_filename, "w") as f:
            f.write("server_id=" + server_id + "\n")

    return server_id
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def telemetry_disabled(app_config: AppConfig):
    """Return True when usage telemetry should not be logged.

    Telemetry is off when no app config exists or its
    should_log_telemetry flag is unset.
    """
    if not app_config:
        return True
    return not app_config.should_log_telemetry
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def log_telemetry(
    telemetry_type: str,
    api: str = None,
    client: Optional[str] = None,
    app_config: Optional[AppConfig] = None,
    properties: dict = None,
):
    """Log basic app usage telemetry like client, os, api called.

    Returns the telemetry request body dict for the caller to batch and
    ship, or an empty list when telemetry is disabled.
    """
    # Do not log usage telemetry, if telemetry is disabled via app config
    if telemetry_disabled(app_config):
        return []

    # Guard against properties=None — the previous code called
    # properties.get(...) directly and crashed when no properties were passed.
    properties = properties or {}
    if properties.get("server_id") is None:
        properties["server_id"] = get_server_id()

    # Populate telemetry data to log
    request_body = {
        "telemetry_type": telemetry_type,
        "server_version": version("khoj"),
        "os": platform.system(),
        "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    }
    request_body.update(properties)
    if api:
        # API endpoint on server called by client
        request_body["api"] = api
    if client:
        # Client from which the API was called. E.g. Emacs, Obsidian
        request_body["client"] = client

    # Return telemetry data for the caller to send to the telemetry endpoint
    return request_body
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def get_device_memory() -> int:
    """Get total memory of the active compute device, in bytes."""
    device = get_device()
    if device.type == "cuda":
        # Total memory of the CUDA device
        return torch.cuda.get_device_properties(device).total_memory
    if device.type == "mps":
        # Memory the Metal driver has allocated
        return torch.mps.driver_allocated_memory()
    # CPU fallback: total system RAM
    return psutil.virtual_memory().total
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
def get_device() -> torch.device:
    """Pick the best available compute device: CUDA GPU > Apple MPS > CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
class ConversationCommand(str, Enum):
    """Commands that steer how a chat message is handled.

    Subclasses str so members compare and serialize as their plain
    string values.
    """

    Default = "default"
    General = "general"
    Notes = "notes"
    Help = "help"
    Online = "online"
    Webpage = "webpage"
    Image = "image"
    Text = "text"
    Automation = "automation"
    AutomatedTask = "automated_task"
    Summarize = "summarize"
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
# Human-readable description of each conversation command
command_descriptions = {
    ConversationCommand.General: "Only talk about information that relies on Khoj's general knowledge, not your personal knowledge base.",
    ConversationCommand.Notes: "Only talk about information that is available in your knowledge base.",
    ConversationCommand.Default: "The default command when no command specified. It intelligently auto-switches between general and notes mode.",
    ConversationCommand.Online: "Search for information on the internet.",
    ConversationCommand.Webpage: "Get information from webpage links provided by you.",
    ConversationCommand.Image: "Generate images by describing your imagination in words.",
    ConversationCommand.Automation: "Automatically run your query at a specified time or interval.",
    ConversationCommand.Help: "Get help with how to use or setup Khoj from the documentation",
    ConversationCommand.Summarize: "Create an appropriate summary using provided documents.",
}

# Descriptions given to the LLM to pick which data source/tool to use
tool_descriptions_for_llm = {
    ConversationCommand.Default: "To use a mix of your internal knowledge and the user's personal knowledge, or if you don't entirely understand the query.",
    ConversationCommand.General: "To use when you can answer the question without any outside information or personal knowledge",
    ConversationCommand.Notes: "To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents.",
    ConversationCommand.Online: "To search for the latest, up-to-date information from the internet. Note: **Questions about Khoj should always use this data source**",
    ConversationCommand.Webpage: "To use if the user has directly provided the webpage urls or you are certain of the webpage urls to read.",
    ConversationCommand.Summarize: "To create a summary of the document provided by the user.",
}

# Descriptions given to the LLM to pick the response mode (image/automation/text)
mode_descriptions_for_llm = {
    ConversationCommand.Image: "Use this if the user is requesting an image or visual response to their query.",
    ConversationCommand.Automation: "Use this if the user is requesting a response at a scheduled date or time.",
    ConversationCommand.Text: "Use this if the other response modes don't seem to fit the query.",
}
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
class ImageIntentType(Enum):
    """
    Chat message intent by Khoj for image responses.
    Marks the schema used to reference image in chat messages.
    Each member tracks a different image payload encoding (see comments).
    """

    # Images as Inline PNG
    TEXT_TO_IMAGE = "text-to-image"
    # Images as URLs
    TEXT_TO_IMAGE2 = "text-to-image2"
    # Images as Inline WebP
    TEXT_TO_IMAGE_V3 = "text-to-image-v3"
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
def generate_random_name():
    """Generate a readable random name like "brave falcon"."""
    # Word pools to draw from
    adjectives = [
        "happy",
        "serendipitous",
        "exuberant",
        "calm",
        "brave",
        "scared",
        "energetic",
        "chivalrous",
        "kind",
        "suave",
    ]
    nouns = ["dog", "cat", "falcon", "whale", "turtle", "rabbit", "hamster", "snake", "spider", "elephant"]

    # One random pick from each pool, joined by a space
    return f"{random.choice(adjectives)} {random.choice(nouns)}"
|
|
381
|
+
|
|
382
|
+
|
|
383
|
+
def batcher(iterable, max_n):
    """Yield chunks of up to max_n items from iterable.

    Each chunk is a generator that skips None entries.
    """
    it = iter(iterable)
    # islice returns an empty list once the iterator is exhausted
    while chunk := list(islice(it, max_n)):
        yield (item for item in chunk if item is not None)
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
def is_env_var_true(env_var: str, default: str = "false") -> bool:
    """Get state of boolean environment variable"""
    value = os.getenv(env_var, default)
    return value.lower() == "true"
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
def in_debug_mode():
    """Report whether Khoj is running in debug mode.

    Enabled by setting the KHOJ_DEBUG environment variable to true."""
    debug_flag = "KHOJ_DEBUG"
    return is_env_var_true(debug_flag)
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
def is_valid_url(url: str) -> bool:
    """Check if a string is a valid URL (has both a scheme and a host)."""
    try:
        result = urlparse(url.strip())
        return all([result.scheme, result.netloc])
    except (AttributeError, ValueError):
        # AttributeError: url is not a string (e.g. None); ValueError:
        # malformed URL rejected by urlparse. Narrowed from a bare
        # `except:` which also swallowed KeyboardInterrupt/SystemExit.
        return False
|
|
411
|
+
|
|
412
|
+
|
|
413
|
+
def is_internet_connected():
    """Best-effort connectivity check via a HEAD request to google.com.

    Returns False on any request failure instead of raising.
    """
    try:
        response = requests.head("https://www.google.com")
        return response.status_code == 200
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # are no longer swallowed; network failures still return False.
        return False
|
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
from khoj.database.adapters import ConversationAdapters
|
|
5
|
+
from khoj.database.models import (
|
|
6
|
+
ChatModelOptions,
|
|
7
|
+
KhojUser,
|
|
8
|
+
OpenAIProcessorConversationConfig,
|
|
9
|
+
SpeechToTextModelOptions,
|
|
10
|
+
TextToImageModelConfig,
|
|
11
|
+
)
|
|
12
|
+
from khoj.processor.conversation.utils import model_to_prompt_size, model_to_tokenizer
|
|
13
|
+
from khoj.utils.constants import default_offline_chat_model, default_online_chat_model
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def initialization():
    """Interactively set up the Khoj server on first run.

    Creates the admin superuser and the initial chat model configuration,
    prompting on stdin for any values not supplied via environment
    variables. Each step is skipped when its configuration already exists.
    Returns early (without raising) if the environment does not support
    interactive input.
    """

    def _create_admin_user():
        # Create a Django superuser from env vars, falling back to stdin.
        logger.info(
            "👩✈️ Setting up admin user. These credentials will allow you to configure your server at /server/admin."
        )
        email_addr = os.getenv("KHOJ_ADMIN_EMAIL") or input("Email: ")
        password = os.getenv("KHOJ_ADMIN_PASSWORD") or input("Password: ")
        admin_user = KhojUser.objects.create_superuser(email=email_addr, username=email_addr, password=password)
        logger.info(f"👩✈️ Created admin user: {admin_user.email}")

    def _create_chat_configuration():
        # Prompt for offline/OpenAI chat, speech-to-text and text-to-image
        # model settings and persist them to the database.
        logger.info(
            "🗣️ Configure chat models available to your server. You can always update these at /server/admin using the credentials of your admin account"
        )

        try:
            use_offline_model = input("Use offline chat model? (y/n): ")
            if use_offline_model == "y":
                logger.info("🗣️ Setting up offline chat model")

                offline_chat_model = input(
                    f"Enter the offline chat model you want to use. See HuggingFace for available GGUF models (default: {default_offline_chat_model}): "
                )
                if offline_chat_model == "":
                    ChatModelOptions.objects.create(
                        chat_model=default_offline_chat_model, model_type=ChatModelOptions.ModelType.OFFLINE
                    )
                else:
                    default_max_tokens = model_to_prompt_size.get(offline_chat_model, 2000)
                    max_tokens = input(
                        f"Enter the maximum number of tokens to use for the offline chat model (default {default_max_tokens}):"
                    )
                    max_tokens = max_tokens or default_max_tokens

                    default_tokenizer = model_to_tokenizer.get(
                        offline_chat_model, "hf-internal-testing/llama-tokenizer"
                    )
                    tokenizer = input(
                        f"Enter the tokenizer to use for the offline chat model (default: {default_tokenizer}):"
                    )
                    tokenizer = tokenizer or default_tokenizer

                    ChatModelOptions.objects.create(
                        chat_model=offline_chat_model,
                        model_type=ChatModelOptions.ModelType.OFFLINE,
                        max_prompt_size=max_tokens,
                        tokenizer=tokenizer,
                    )
        except ModuleNotFoundError:
            # Offline chat depends on optional native packages that may not
            # be installable on this device; degrade gracefully.
            logger.warning("Offline models are not supported on this device.")

        use_openai_model = input("Use OpenAI models? (y/n): ")
        if use_openai_model == "y":
            logger.info("🗣️ Setting up your OpenAI configuration")
            api_key = input("Enter your OpenAI API key: ")
            OpenAIProcessorConversationConfig.objects.create(api_key=api_key)

            openai_chat_model = input(
                f"Enter the OpenAI chat model you want to use (default: {default_online_chat_model}): "
            )
            openai_chat_model = openai_chat_model or default_online_chat_model

            default_max_tokens = model_to_prompt_size.get(openai_chat_model, 2000)
            max_tokens = input(
                f"Enter the maximum number of tokens to use for the OpenAI chat model (default: {default_max_tokens}): "
            )
            max_tokens = max_tokens or default_max_tokens
            ChatModelOptions.objects.create(
                chat_model=openai_chat_model, model_type=ChatModelOptions.ModelType.OPENAI, max_prompt_size=max_tokens
            )

            default_speech2text_model = "whisper-1"
            openai_speech2text_model = input(
                f"Enter the OpenAI speech to text model you want to use (default: {default_speech2text_model}): "
            )
            openai_speech2text_model = openai_speech2text_model or default_speech2text_model
            SpeechToTextModelOptions.objects.create(
                model_name=openai_speech2text_model, model_type=SpeechToTextModelOptions.ModelType.OPENAI
            )

            default_text_to_image_model = "dall-e-3"
            openai_text_to_image_model = input(
                f"Enter the OpenAI text to image model you want to use (default: {default_text_to_image_model}): "
            )
            # Bug fix: the default was previously assigned to
            # openai_speech2text_model, so accepting the default left the
            # text-to-image model name empty.
            openai_text_to_image_model = openai_text_to_image_model or default_text_to_image_model
            TextToImageModelConfig.objects.create(
                model_name=openai_text_to_image_model, model_type=TextToImageModelConfig.ModelType.OPENAI
            )

        if use_offline_model == "y" or use_openai_model == "y":
            logger.info("🗣️ Chat model configuration complete")

        use_offline_speech2text_model = input("Use offline speech to text model? (y/n): ")
        if use_offline_speech2text_model == "y":
            logger.info("🗣️ Setting up offline speech to text model")
            # Delete any existing speech to text model options. There can only be one.
            SpeechToTextModelOptions.objects.all().delete()

            default_offline_speech2text_model = "base"
            offline_speech2text_model = input(
                f"Enter the Whisper model to use Offline (default: {default_offline_speech2text_model}): "
            )
            offline_speech2text_model = offline_speech2text_model or default_offline_speech2text_model
            SpeechToTextModelOptions.objects.create(
                model_name=offline_speech2text_model, model_type=SpeechToTextModelOptions.ModelType.OFFLINE
            )

            logger.info(f"🗣️ Offline speech to text model configured to {offline_speech2text_model}")

    # Only set up an admin account when none exists yet; retry on failure
    # (e.g. a typo in the email) until one is created.
    admin_user = KhojUser.objects.filter(is_staff=True).first()
    if admin_user is None:
        while True:
            try:
                _create_admin_user()
                break
            except Exception as e:
                logger.error(f"🚨 Failed to create admin user: {e}", exc_info=True)

    # Only run the chat setup wizard on a fresh install (no admin user
    # existed and no default chat configuration exists).
    chat_config = ConversationAdapters.get_default_conversation_config()
    if admin_user is None and chat_config is None:
        while True:
            try:
                _create_chat_configuration()
                break
            # Some environments don't support interactive input. We catch the exception and return if that's the case. The admin can still configure their settings from the admin page.
            except EOFError:
                return
            except Exception as e:
                logger.error(f"🚨 Failed to create chat configuration: {e}", exc_info=True)
|
khoj/utils/jsonl.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
import gzip
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
|
|
5
|
+
from khoj.utils.constants import empty_escape_sequences
|
|
6
|
+
from khoj.utils.helpers import get_absolute_path
|
|
7
|
+
|
|
8
|
+
logger = logging.getLogger(__name__)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def load_jsonl(input_path):
    """Read a list of JSON objects from a JSON Lines file.

    Transparently handles gzip-compressed files (``.gz`` suffix).
    The file handle is closed even if a line fails to parse
    (previously a parse error leaked the open handle).
    """
    # Choose the opener based on the file suffix; both yield text-mode handles
    if input_path.suffix == ".gz":
        jsonl_file = gzip.open(get_absolute_path(input_path), "rt", encoding="utf-8")
    else:
        jsonl_file = open(get_absolute_path(input_path), "r", encoding="utf-8")

    # Parse one JSON object per line; `with` guarantees the file is closed
    with jsonl_file:
        data = [json.loads(line.strip(empty_escape_sequences)) for line in jsonl_file]

    # Log JSONL entries loaded
    logger.debug(f"Loaded {len(data)} records from {input_path}")

    return data
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def compress_jsonl_data(jsonl_data, output_path):
    """Write the given jsonl string to a gzip-compressed file at output_path."""
    # Make sure the destination directory exists before writing
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Stream the jsonl text through gzip in text mode
    with gzip.open(output_path, "wt", encoding="utf-8") as compressed_jsonl_file:
        compressed_jsonl_file.write(jsonl_data)

    logger.debug(f"Wrote jsonl data to gzip compressed jsonl at {output_path}")
|
khoj/utils/models.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import List
|
|
3
|
+
|
|
4
|
+
import openai
|
|
5
|
+
import torch
|
|
6
|
+
from tqdm import trange
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class BaseEncoder(ABC):
    """Interface for embedding models that encode text entries into tensors."""

    @abstractmethod
    def __init__(self, model_name: str, device: torch.device = None, **kwargs):
        """Load or connect to the embedding model named `model_name`."""
        ...

    @abstractmethod
    def encode(self, entries: List[str], device: torch.device = None, **kwargs) -> torch.Tensor:
        """Return a tensor of embeddings, one row per input entry."""
        ...
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class OpenAI(BaseEncoder):
    """Embedding encoder backed by the OpenAI embeddings API."""

    def __init__(self, model_name, client: openai.OpenAI, device=None):
        self.model_name = model_name
        self.openai_client = client
        # Learned from the first successful API response; None until then
        self.embedding_dimensions = None

    def encode(self, entries, device=None, **kwargs):
        """Embed each entry via the OpenAI API and return a stacked tensor.

        Entries whose embedding request fails are represented by zero
        vectors so output rows stay aligned with the input entries.
        """
        embedding_tensors = []

        for index in trange(0, len(entries)):
            # OpenAI models create better embeddings for entries without newlines
            processed_entry = entries[index].replace("\n", " ")

            try:
                response = self.openai_client.embeddings.create(input=processed_entry, model=self.model_name)
                embedding_tensors += [torch.tensor(response.data[0].embedding, device=device)]
                # Record this model's embedding dimension from the first
                # successful response. Bug fix: the previous conditional was
                # inverted and reset the dimension to 1536 on every success
                # after the first, regardless of the actual model.
                if self.embedding_dimensions is None:
                    self.embedding_dimensions = len(response.data[0].embedding)
            except Exception as e:
                print(
                    f"Failed to encode entry {index} of length: {len(entries[index])}\n\n{entries[index][:1000]}...\n\n{e}"
                )
                # Use zero embedding vector for entries with failed embeddings
                # This ensures entry embeddings match the order of the source entries
                # And they have minimal similarity to other entries (as zero vectors are always orthogonal to other vector)
                # Fall back to the text-embedding-ada-002 dimension (1536) when
                # no request has succeeded yet (previously torch.zeros(None)
                # raised if the very first entry failed).
                embedding_tensors += [torch.zeros(self.embedding_dimensions or 1536, device=device)]

        return torch.stack(embedding_tensors)
|