khoj 1.40.1.dev18__py3-none-any.whl → 1.40.1.dev27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- khoj/database/adapters/__init__.py +24 -22
- khoj/interface/compiled/404/index.html +2 -2
- khoj/interface/compiled/_next/static/chunks/{2327-50940053d8852cae.js → 2327-64e90eab8ee88c9c.js} +1 -1
- khoj/interface/compiled/_next/static/chunks/{8515-010dd769c584b672.js → 8515-f305779d95dd5780.js} +1 -1
- khoj/interface/compiled/_next/static/chunks/app/agents/layout-4e2a134ec26aa606.js +1 -0
- khoj/interface/compiled/_next/static/chunks/app/agents/{page-c9ceb9b94e24b94a.js → page-ceeb9a91edea74ce.js} +1 -1
- khoj/interface/compiled/_next/static/chunks/app/automations/{page-3dc59a0df3827dc7.js → page-e3cb78747ab98cc7.js} +1 -1
- khoj/interface/compiled/_next/static/chunks/app/chat/layout-ad4d1792ab1a4108.js +1 -0
- khoj/interface/compiled/_next/static/chunks/app/chat/{page-ff3211ac7a5a1f48.js → page-5174ed911079ee57.js} +1 -1
- khoj/interface/compiled/_next/static/chunks/app/{page-38f1f125d7aeb4c7.js → page-a4053e1bb578b2ce.js} +1 -1
- khoj/interface/compiled/_next/static/chunks/app/search/layout-c02531d586972d7d.js +1 -0
- khoj/interface/compiled/_next/static/chunks/app/search/{page-26d4492fb1200e0e.js → page-8973da2f4c076fe1.js} +1 -1
- khoj/interface/compiled/_next/static/chunks/app/settings/{page-bf1a4e488b29fceb.js → page-375136dbb400525b.js} +1 -1
- khoj/interface/compiled/_next/static/chunks/app/share/chat/layout-e8e5db7830bf3f47.js +1 -0
- khoj/interface/compiled/_next/static/chunks/app/share/chat/{page-a1f10c96366c3a4f.js → page-384b54fc953b18f2.js} +1 -1
- khoj/interface/compiled/_next/static/chunks/{webpack-2707657083ef5456.js → webpack-21f76f7f59582bc7.js} +1 -1
- khoj/interface/compiled/_next/static/css/37a73b87f02df402.css +1 -0
- khoj/interface/compiled/_next/static/css/f29752d6e1be7624.css +1 -0
- khoj/interface/compiled/_next/static/css/{0db53bacf81896f5.css → fca983d49c3dd1a3.css} +1 -1
- khoj/interface/compiled/agents/index.html +2 -2
- khoj/interface/compiled/agents/index.txt +2 -2
- khoj/interface/compiled/automations/index.html +2 -2
- khoj/interface/compiled/automations/index.txt +3 -3
- khoj/interface/compiled/chat/index.html +2 -2
- khoj/interface/compiled/chat/index.txt +2 -2
- khoj/interface/compiled/index.html +2 -2
- khoj/interface/compiled/index.txt +2 -2
- khoj/interface/compiled/search/index.html +2 -2
- khoj/interface/compiled/search/index.txt +2 -2
- khoj/interface/compiled/settings/index.html +2 -2
- khoj/interface/compiled/settings/index.txt +4 -4
- khoj/interface/compiled/share/chat/index.html +2 -2
- khoj/interface/compiled/share/chat/index.txt +2 -2
- khoj/processor/conversation/anthropic/anthropic_chat.py +22 -12
- khoj/processor/conversation/anthropic/utils.py +21 -64
- khoj/processor/conversation/google/gemini_chat.py +22 -12
- khoj/processor/conversation/google/utils.py +26 -64
- khoj/processor/conversation/offline/chat_model.py +86 -35
- khoj/processor/conversation/openai/gpt.py +22 -12
- khoj/processor/conversation/openai/utils.py +52 -71
- khoj/processor/conversation/utils.py +2 -38
- khoj/routers/api.py +1 -3
- khoj/routers/api_chat.py +19 -16
- khoj/routers/helpers.py +13 -23
- khoj/routers/research.py +1 -2
- khoj/utils/helpers.py +56 -0
- {khoj-1.40.1.dev18.dist-info → khoj-1.40.1.dev27.dist-info}/METADATA +1 -1
- {khoj-1.40.1.dev18.dist-info → khoj-1.40.1.dev27.dist-info}/RECORD +59 -59
- khoj/interface/compiled/_next/static/chunks/app/agents/layout-e49165209d2e406c.js +0 -1
- khoj/interface/compiled/_next/static/chunks/app/chat/layout-d5ae861e1ade9d08.js +0 -1
- khoj/interface/compiled/_next/static/chunks/app/search/layout-f5881c7ae3ba0795.js +0 -1
- khoj/interface/compiled/_next/static/chunks/app/share/chat/layout-64a53f8ec4afa6b3.js +0 -1
- khoj/interface/compiled/_next/static/css/bb7ea98028b368f3.css +0 -1
- khoj/interface/compiled/_next/static/css/ee66643a6a5bf71c.css +0 -1
- /khoj/interface/compiled/_next/static/chunks/{1915-1943ee8a628b893c.js → 1915-ab4353eaca76f690.js} +0 -0
- /khoj/interface/compiled/_next/static/chunks/{2117-5a41630a2bd2eae8.js → 2117-1c18aa2098982bf9.js} +0 -0
- /khoj/interface/compiled/_next/static/chunks/{4363-e6ac2203564d1a3b.js → 4363-4efaf12abe696251.js} +0 -0
- /khoj/interface/compiled/_next/static/chunks/{4447-e038b251d626c340.js → 4447-5d44807c40355b1a.js} +0 -0
- /khoj/interface/compiled/_next/static/chunks/{8667-8136f74e9a086fca.js → 8667-adbe6017a66cef10.js} +0 -0
- /khoj/interface/compiled/_next/static/chunks/{9259-640fdd77408475df.js → 9259-d8bcd9da9e80c81e.js} +0 -0
- /khoj/interface/compiled/_next/static/{_n7QmqkwoEnXgFtkxOxET → jpAPJYOTptFnVVNIEkfa8}/_buildManifest.js +0 -0
- /khoj/interface/compiled/_next/static/{_n7QmqkwoEnXgFtkxOxET → jpAPJYOTptFnVVNIEkfa8}/_ssgManifest.js +0 -0
- {khoj-1.40.1.dev18.dist-info → khoj-1.40.1.dev27.dist-info}/WHEEL +0 -0
- {khoj-1.40.1.dev18.dist-info → khoj-1.40.1.dev27.dist-info}/entry_points.txt +0 -0
- {khoj-1.40.1.dev18.dist-info → khoj-1.40.1.dev27.dist-info}/licenses/LICENSE +0 -0

khoj/processor/conversation/google/gemini_chat.py

@@ -1,6 +1,6 @@
 import logging
 from datetime import datetime, timedelta
-from typing import Dict, List, Optional
+from typing import AsyncGenerator, Dict, List, Optional

 import pyjson5
 from langchain.schema import ChatMessage
@@ -160,7 +160,7 @@ def gemini_send_message_to_model(
     )


-def converse_gemini(
+async def converse_gemini(
     references,
     user_query,
     online_results: Optional[Dict[str, Dict]] = None,
@@ -185,7 +185,7 @@ def converse_gemini(
     program_execution_context: List[str] = None,
     deepthought: Optional[bool] = False,
     tracer={},
-):
+) -> AsyncGenerator[str, None]:
     """
     Converse with user using Google's Gemini
     """
@@ -216,11 +216,17 @@ def converse_gemini(

     # Get Conversation Primer appropriate to Conversation Type
     if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
-
-
+        response = prompts.no_notes_found.format()
+        if completion_func:
+            await completion_func(chat_response=response)
+        yield response
+        return
     elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
-
-
+        response = prompts.no_online_results_found.format()
+        if completion_func:
+            await completion_func(chat_response=response)
+        yield response
+        return

     context_message = ""
     if not is_none_or_empty(references):
@@ -253,16 +259,20 @@ def converse_gemini(
     logger.debug(f"Conversation Context for Gemini: {messages_to_print(messages)}")

     # Get Response from Google AI
-
+    full_response = ""
+    async for chunk in gemini_chat_completion_with_backoff(
         messages=messages,
-        compiled_references=references,
-        online_results=online_results,
         model_name=model,
         temperature=temperature,
         api_key=api_key,
         api_base_url=api_base_url,
         system_prompt=system_prompt,
-        completion_func=completion_func,
         deepthought=deepthought,
         tracer=tracer,
-    )
+    ):
+        full_response += chunk
+        yield chunk
+
+    # Call completion_func once finish streaming and we have the full response
+    if completion_func:
+        await completion_func(chat_response=full_response)

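Note: the hunks above convert converse_gemini from a ThreadedGenerator-based function into an async generator that yields response chunks and awaits completion_func once the stream finishes. A minimal, self-contained sketch of how a caller consumes a function written in this style; fake_converse below is a stand-in, not Khoj code:

```python
import asyncio
from typing import AsyncGenerator


async def fake_converse(user_query: str) -> AsyncGenerator[str, None]:
    # Stand-in for converse_gemini / converse_openai: yields the reply in chunks
    for chunk in ["Hello", ", ", "world", "!"]:
        await asyncio.sleep(0)  # simulate waiting on the upstream stream
        yield chunk


async def main() -> None:
    full_response = ""
    # Consume the stream chunk by chunk, accumulating the full reply
    async for chunk in fake_converse("hi"):
        print(chunk, end="", flush=True)
        full_response += chunk
    print(f"\nfull response: {full_response}")


asyncio.run(main())
```
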
khoj/processor/conversation/google/utils.py

@@ -2,8 +2,8 @@ import logging
 import os
 import random
 from copy import deepcopy
-from
-from typing import Dict
+from time import perf_counter
+from typing import AsyncGenerator, AsyncIterator, Dict

 from google import genai
 from google.genai import errors as gerrors
@@ -19,14 +19,13 @@ from tenacity import (
 )

 from khoj.processor.conversation.utils import (
-    ThreadedGenerator,
     commit_conversation_trace,
     get_image_from_base64,
     get_image_from_url,
 )
 from khoj.utils.helpers import (
-    get_ai_api_info,
     get_chat_usage_metrics,
+    get_gemini_client,
     is_none_or_empty,
     is_promptrace_enabled,
 )
@@ -62,17 +61,6 @@ SAFETY_SETTINGS = [
 ]


-def get_gemini_client(api_key, api_base_url=None) -> genai.Client:
-    api_info = get_ai_api_info(api_key, api_base_url)
-    return genai.Client(
-        location=api_info.region,
-        project=api_info.project,
-        credentials=api_info.credentials,
-        api_key=api_info.api_key,
-        vertexai=api_info.api_key is None,
-    )
-
-
 @retry(
     wait=wait_random_exponential(min=1, max=10),
     stop=stop_after_attempt(2),
@@ -132,8 +120,8 @@ def gemini_completion_with_backoff(
     )

     # Aggregate cost of chat
-    input_tokens = response.usage_metadata.prompt_token_count if response else 0
-    output_tokens = response.usage_metadata.candidates_token_count if response else 0
+    input_tokens = response.usage_metadata.prompt_token_count or 0 if response else 0
+    output_tokens = response.usage_metadata.candidates_token_count or 0 if response else 0
     thought_tokens = response.usage_metadata.thoughts_token_count or 0 if response else 0
     tracer["usage"] = get_chat_usage_metrics(
         model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
@@ -154,52 +142,17 @@ def gemini_completion_with_backoff(
     before_sleep=before_sleep_log(logger, logging.DEBUG),
     reraise=True,
 )
-def gemini_chat_completion_with_backoff(
+async def gemini_chat_completion_with_backoff(
     messages,
-    compiled_references,
-    online_results,
     model_name,
     temperature,
     api_key,
     api_base_url,
     system_prompt,
-    completion_func=None,
     model_kwargs=None,
     deepthought=False,
     tracer: dict = {},
-):
-    g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
-    t = Thread(
-        target=gemini_llm_thread,
-        args=(
-            g,
-            messages,
-            system_prompt,
-            model_name,
-            temperature,
-            api_key,
-            api_base_url,
-            model_kwargs,
-            deepthought,
-            tracer,
-        ),
-    )
-    t.start()
-    return g
-
-
-def gemini_llm_thread(
-    g,
-    messages,
-    system_prompt,
-    model_name,
-    temperature,
-    api_key,
-    api_base_url=None,
-    model_kwargs=None,
-    deepthought=False,
-    tracer: dict = {},
-):
+) -> AsyncGenerator[str, None]:
     try:
         client = gemini_clients.get(api_key)
         if not client:
@@ -224,21 +177,32 @@ def gemini_llm_thread(
         )

         aggregated_response = ""
-
-
+        final_chunk = None
+        start_time = perf_counter()
+        chat_stream: AsyncIterator[gtypes.GenerateContentResponse] = await client.aio.models.generate_content_stream(
             model=model_name, config=config, contents=formatted_messages
-        )
+        )
+        async for chunk in chat_stream:
+            # Log the time taken to start response
+            if final_chunk is None:
+                logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
+            # Keep track of the last chunk for usage data
+            final_chunk = chunk
+            # Handle streamed response chunk
             message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
             message = message or chunk.text
             aggregated_response += message
-
+            yield message
             if stopped:
                 raise ValueError(message)

+        # Log the time taken to stream the entire response
+        logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
+
         # Calculate cost of chat
-        input_tokens =
-        output_tokens =
-        thought_tokens =
+        input_tokens = final_chunk.usage_metadata.prompt_token_count or 0 if final_chunk else 0
+        output_tokens = final_chunk.usage_metadata.candidates_token_count or 0 if final_chunk else 0
+        thought_tokens = final_chunk.usage_metadata.thoughts_token_count or 0 if final_chunk else 0
         tracer["usage"] = get_chat_usage_metrics(
             model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
         )
@@ -254,9 +218,7 @@ def gemini_llm_thread(
             + f"Last Message by {messages[-1].role}: {messages[-1].content}"
         )
     except Exception as e:
-        logger.error(f"Error in
-    finally:
-        g.close()
+        logger.error(f"Error in gemini_chat_completion_with_backoff stream: {e}", exc_info=True)


 def handle_gemini_response(

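Note: the streaming helper above times the first chunk and keeps a reference to the last chunk so token usage can be read once the stream ends. A library-agnostic sketch of that bookkeeping pattern; the Chunk dataclass and its usage field are assumptions that mirror the diff, not the google-genai types:

```python
import asyncio
import logging
from dataclasses import dataclass
from time import perf_counter
from typing import AsyncIterator, Optional

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class Chunk:
    text: str
    usage: Optional[dict] = None  # only the final chunk carries usage in this sketch


async def stream_with_metrics(chat_stream: AsyncIterator[Chunk]) -> AsyncIterator[str]:
    start_time = perf_counter()
    final_chunk: Optional[Chunk] = None
    async for chunk in chat_stream:
        # Log time to first chunk, then remember the last chunk seen
        if final_chunk is None:
            logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
        final_chunk = chunk
        yield chunk.text
    logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
    # Usage metadata is only complete once the stream has finished
    if final_chunk and final_chunk.usage:
        logger.info(f"Token usage: {final_chunk.usage}")


async def demo_stream() -> AsyncIterator[Chunk]:
    yield Chunk("partial ")
    yield Chunk("reply", usage={"prompt_tokens": 12, "completion_tokens": 2})


async def main() -> None:
    async for text in stream_with_metrics(demo_stream()):
        print(text, end="", flush=True)
    print()


asyncio.run(main())
```
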
khoj/processor/conversation/offline/chat_model.py

@@ -1,9 +1,10 @@
-import
+import asyncio
 import logging
 import os
 from datetime import datetime, timedelta
 from threading import Thread
-from
+from time import perf_counter
+from typing import Any, AsyncGenerator, Dict, List, Optional, Union

 import pyjson5
 from langchain.schema import ChatMessage
@@ -13,7 +14,6 @@ from khoj.database.models import Agent, ChatModel, KhojUser
 from khoj.processor.conversation import prompts
 from khoj.processor.conversation.offline.utils import download_model
 from khoj.processor.conversation.utils import (
-    ThreadedGenerator,
     clean_json,
     commit_conversation_trace,
     generate_chatml_messages_with_context,
@@ -147,7 +147,7 @@ def filter_questions(questions: List[str]):
     return list(filtered_questions)


-def converse_offline(
+async def converse_offline(
     user_query,
     references=[],
     online_results={},
@@ -167,9 +167,9 @@ def converse_offline(
     additional_context: List[str] = None,
     generated_asset_results: Dict[str, Dict] = {},
     tracer: dict = {},
-) ->
+) -> AsyncGenerator[str, None]:
     """
-    Converse with user using Llama
+    Converse with user using Llama (Async Version)
     """
     # Initialize Variables
     assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
@@ -200,10 +200,17 @@ def converse_offline(

     # Get Conversation Primer appropriate to Conversation Type
     if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
-
+        response = prompts.no_notes_found.format()
+        if completion_func:
+            await completion_func(chat_response=response)
+        yield response
+        return
     elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
-
-
+        response = prompts.no_online_results_found.format()
+        if completion_func:
+            await completion_func(chat_response=response)
+        yield response
+        return

     context_message = ""
     if not is_none_or_empty(references):
@@ -240,33 +247,77 @@ def converse_offline(

     logger.debug(f"Conversation Context for {model_name}: {messages_to_print(messages)}")

-
-
-    t.start()
-    return g
-
-
-def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None, tracer: dict = {}):
+    # Use asyncio.Queue and a thread to bridge sync iterator
+    queue: asyncio.Queue = asyncio.Queue()
     stop_phrases = ["<s>", "INST]", "Notes:"]
-
-
-
-
-
-
-    )
-
-
-
-
-
-
-
-
-
-
-
-
+    aggregated_response_container = {"response": ""}
+
+    def _sync_llm_thread():
+        """Synchronous function to run in a separate thread."""
+        aggregated_response = ""
+        start_time = perf_counter()
+        state.chat_lock.acquire()
+        try:
+            response_iterator = send_message_to_model_offline(
+                messages,
+                loaded_model=offline_chat_model,
+                stop=stop_phrases,
+                max_prompt_size=max_prompt_size,
+                streaming=True,
+                tracer=tracer,
+            )
+            for response in response_iterator:
+                response_delta = response["choices"][0]["delta"].get("content", "")
+                # Log the time taken to start response
+                if aggregated_response == "" and response_delta != "":
+                    logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
+                # Handle response chunk
+                aggregated_response += response_delta
+                # Put chunk into the asyncio queue (non-blocking)
+                try:
+                    queue.put_nowait(response_delta)
+                except asyncio.QueueFull:
+                    # Should not happen with default queue size unless consumer is very slow
+                    logger.warning("Asyncio queue full during offline LLM streaming.")
+                    # Potentially block here or handle differently if needed
+                    asyncio.run(queue.put(response_delta))
+
+            # Log the time taken to stream the entire response
+            logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
+
+            # Save conversation trace
+            tracer["chat_model"] = model_name
+            if is_promptrace_enabled():
+                commit_conversation_trace(messages, aggregated_response, tracer)
+
+        except Exception as e:
+            logger.error(f"Error in offline LLM thread: {e}", exc_info=True)
+        finally:
+            state.chat_lock.release()
+            # Signal end of stream
+            queue.put_nowait(None)
+            aggregated_response_container["response"] = aggregated_response
+
+    # Start the synchronous thread
+    thread = Thread(target=_sync_llm_thread)
+    thread.start()
+
+    # Asynchronously consume from the queue
+    while True:
+        chunk = await queue.get()
+        if chunk is None:  # End of stream signal
+            queue.task_done()
+            break
+        yield chunk
+        queue.task_done()
+
+    # Wait for the thread to finish (optional, ensures cleanup)
+    loop = asyncio.get_running_loop()
+    await loop.run_in_executor(None, thread.join)
+
+    # Call the completion function after streaming is done
+    if completion_func:
+        await completion_func(chat_response=aggregated_response_container["response"])


 def send_message_to_model_offline(

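Note: converse_offline bridges the synchronous llama.cpp iterator to an async generator by having a worker thread push chunks into an asyncio.Queue. The sketch below shows the same bridge in isolation; unlike the diff it hands chunks to the event loop via call_soon_threadsafe, since asyncio.Queue is not thread-safe on its own, and blocking_chunks is a stand-in for send_message_to_model_offline:

```python
import asyncio
import time
from threading import Thread
from typing import AsyncGenerator, Iterator


def blocking_chunks() -> Iterator[str]:
    # Stand-in for the synchronous llama.cpp streaming iterator
    for piece in ["stream", "ing ", "offline ", "reply"]:
        time.sleep(0.1)
        yield piece


async def bridged_stream() -> AsyncGenerator[str, None]:
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()

    def worker() -> None:
        try:
            for chunk in blocking_chunks():
                # Hand each chunk to the event loop thread-safely
                loop.call_soon_threadsafe(queue.put_nowait, chunk)
        finally:
            loop.call_soon_threadsafe(queue.put_nowait, None)  # end-of-stream sentinel

    Thread(target=worker, daemon=True).start()
    while True:
        chunk = await queue.get()
        if chunk is None:
            break
        yield chunk


async def main() -> None:
    async for chunk in bridged_stream():
        print(chunk, end="", flush=True)
    print()


asyncio.run(main())
```
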
khoj/processor/conversation/openai/gpt.py

@@ -1,6 +1,6 @@
 import logging
 from datetime import datetime, timedelta
-from typing import Dict, List, Optional
+from typing import AsyncGenerator, Dict, List, Optional

 import pyjson5
 from langchain.schema import ChatMessage
@@ -162,7 +162,7 @@ def send_message_to_model(
     )


-def converse_openai(
+async def converse_openai(
     references,
     user_query,
     online_results: Optional[Dict[str, Dict]] = None,
@@ -187,7 +187,7 @@ def converse_openai(
     program_execution_context: List[str] = None,
     deepthought: Optional[bool] = False,
     tracer: dict = {},
-):
+) -> AsyncGenerator[str, None]:
     """
     Converse with user using OpenAI's ChatGPT
     """
@@ -217,11 +217,17 @@ def converse_openai(

     # Get Conversation Primer appropriate to Conversation Type
     if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
-
-
+        response = prompts.no_notes_found.format()
+        if completion_func:
+            await completion_func(chat_response=response)
+        yield response
+        return
     elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
-
-
+        response = prompts.no_online_results_found.format()
+        if completion_func:
+            await completion_func(chat_response=response)
+        yield response
+        return

     context_message = ""
     if not is_none_or_empty(references):
@@ -255,19 +261,23 @@ def converse_openai(
     logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}")

     # Get Response from GPT
-
+    full_response = ""
+    async for chunk in chat_completion_with_backoff(
         messages=messages,
-        compiled_references=references,
-        online_results=online_results,
         model_name=model,
         temperature=temperature,
         openai_api_key=api_key,
         api_base_url=api_base_url,
-        completion_func=completion_func,
         deepthought=deepthought,
         model_kwargs={"stop": ["Notes:\n["]},
         tracer=tracer,
-    )
+    ):
+        full_response += chunk
+        yield chunk
+
+    # Call completion_func once finish streaming and we have the full response
+    if completion_func:
+        await completion_func(chat_response=full_response)


 def clean_response_schema(schema: BaseModel | dict) -> dict:

khoj/processor/conversation/openai/utils.py

@@ -1,7 +1,7 @@
 import logging
 import os
-from
-from typing import Dict, List
+from time import perf_counter
+from typing import AsyncGenerator, Dict, List
 from urllib.parse import urlparse

 import openai
@@ -16,13 +16,10 @@ from tenacity import (
     wait_random_exponential,
 )

-from khoj.processor.conversation.utils import (
-    JsonSupport,
-    ThreadedGenerator,
-    commit_conversation_trace,
-)
+from khoj.processor.conversation.utils import JsonSupport, commit_conversation_trace
 from khoj.utils.helpers import (
     get_chat_usage_metrics,
+    get_openai_async_client,
     get_openai_client,
     is_promptrace_enabled,
 )
@@ -30,6 +27,7 @@ from khoj.utils.helpers import (
 logger = logging.getLogger(__name__)

 openai_clients: Dict[str, openai.OpenAI] = {}
+openai_async_clients: Dict[str, openai.AsyncOpenAI] = {}


 @retry(
@@ -124,45 +122,22 @@ def completion_with_backoff(
     before_sleep=before_sleep_log(logger, logging.DEBUG),
     reraise=True,
 )
-def chat_completion_with_backoff(
+async def chat_completion_with_backoff(
     messages,
-    compiled_references,
-    online_results,
     model_name,
     temperature,
     openai_api_key=None,
     api_base_url=None,
-    completion_func=None,
-    deepthought=False,
-    model_kwargs=None,
-    tracer: dict = {},
-):
-    g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
-    t = Thread(
-        target=llm_thread,
-        args=(g, messages, model_name, temperature, openai_api_key, api_base_url, deepthought, model_kwargs, tracer),
-    )
-    t.start()
-    return g
-
-
-def llm_thread(
-    g,
-    messages,
-    model_name: str,
-    temperature,
-    openai_api_key=None,
-    api_base_url=None,
     deepthought=False,
     model_kwargs: dict = {},
     tracer: dict = {},
-):
+) -> AsyncGenerator[str, None]:
     try:
         client_key = f"{openai_api_key}--{api_base_url}"
-        client =
+        client = openai_async_clients.get(client_key)
         if not client:
-            client =
-
+            client = get_openai_async_client(openai_api_key, api_base_url)
+            openai_async_clients[client_key] = client

         formatted_messages = [{"role": message.role, "content": message.content} for message in messages]

@@ -181,9 +156,10 @@ def llm_thread(
         ]
         if len(system_messages) > 0:
             first_system_message_index, first_system_message = system_messages[0]
+            first_system_message_content = first_system_message["content"]
             formatted_messages[first_system_message_index][
                 "content"
-            ] = f"{
+            ] = f"{first_system_message_content}\nFormatting re-enabled"
         elif is_twitter_reasoning_model(model_name, api_base_url):
             reasoning_effort = "high" if deepthought else "low"
             model_kwargs["reasoning_effort"] = reasoning_effort
@@ -206,53 +182,58 @@ def llm_thread(
         if os.getenv("KHOJ_LLM_SEED"):
             model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))

-
-
-
+        aggregated_response = ""
+        final_chunk = None
+        start_time = perf_counter()
+        chat_stream: openai.AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
+            messages=formatted_messages, # type: ignore
+            model=model_name,
             stream=stream,
             temperature=temperature,
             timeout=20,
             **model_kwargs,
         )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        cost =
-
-
-
-
-
+        async for chunk in chat_stream:
+            # Log the time taken to start response
+            if final_chunk is None:
+                logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
+            # Keep track of the last chunk for usage data
+            final_chunk = chunk
+            # Handle streamed response chunk
+            if len(chunk.choices) == 0:
+                continue
+            delta_chunk = chunk.choices[0].delta
+            text_chunk = ""
+            if isinstance(delta_chunk, str):
+                text_chunk = delta_chunk
+            elif delta_chunk and delta_chunk.content:
+                text_chunk = delta_chunk.content
+            if text_chunk:
+                aggregated_response += text_chunk
+                yield text_chunk
+
+        # Log the time taken to stream the entire response
+        logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
+
+        # Calculate cost of chat after stream finishes
+        input_tokens, output_tokens, cost = 0, 0, 0
+        if final_chunk and hasattr(final_chunk, "usage") and final_chunk.usage:
+            input_tokens = final_chunk.usage.prompt_tokens
+            output_tokens = final_chunk.usage.completion_tokens
+            # Estimated costs returned by DeepInfra API
+            if final_chunk.usage.model_extra and "estimated_cost" in final_chunk.usage.model_extra:
+                cost = final_chunk.usage.model_extra.get("estimated_cost", 0)

         # Save conversation trace
         tracer["chat_model"] = model_name
         tracer["temperature"] = temperature
+        tracer["usage"] = get_chat_usage_metrics(
+            model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost
+        )
         if is_promptrace_enabled():
             commit_conversation_trace(messages, aggregated_response, tracer)
     except Exception as e:
-        logger.error(f"Error in
-    finally:
-        g.close()
+        logger.error(f"Error in chat_completion_with_backoff stream: {e}", exc_info=True)


 def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport:

|