khoj 1.40.1.dev20__py3-none-any.whl → 1.40.1.dev28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- khoj/database/adapters/__init__.py +24 -22
- khoj/interface/compiled/404/index.html +2 -2
- khoj/interface/compiled/_next/static/chunks/{webpack-639e8a9d35bd5e13.js → webpack-05ff3cbe22520b30.js} +1 -1
- khoj/interface/compiled/agents/index.html +2 -2
- khoj/interface/compiled/agents/index.txt +2 -2
- khoj/interface/compiled/automations/index.html +2 -2
- khoj/interface/compiled/automations/index.txt +3 -3
- khoj/interface/compiled/chat/index.html +2 -2
- khoj/interface/compiled/chat/index.txt +2 -2
- khoj/interface/compiled/index.html +2 -2
- khoj/interface/compiled/index.txt +2 -2
- khoj/interface/compiled/search/index.html +2 -2
- khoj/interface/compiled/search/index.txt +2 -2
- khoj/interface/compiled/settings/index.html +2 -2
- khoj/interface/compiled/settings/index.txt +4 -4
- khoj/interface/compiled/share/chat/index.html +2 -2
- khoj/interface/compiled/share/chat/index.txt +2 -2
- khoj/processor/conversation/anthropic/anthropic_chat.py +23 -12
- khoj/processor/conversation/anthropic/utils.py +21 -64
- khoj/processor/conversation/google/gemini_chat.py +23 -12
- khoj/processor/conversation/google/utils.py +26 -64
- khoj/processor/conversation/offline/chat_model.py +86 -35
- khoj/processor/conversation/openai/gpt.py +23 -12
- khoj/processor/conversation/openai/utils.py +50 -70
- khoj/processor/conversation/utils.py +2 -38
- khoj/routers/api_chat.py +36 -28
- khoj/routers/helpers.py +13 -23
- khoj/utils/helpers.py +56 -0
- {khoj-1.40.1.dev20.dist-info → khoj-1.40.1.dev28.dist-info}/METADATA +1 -1
- {khoj-1.40.1.dev20.dist-info → khoj-1.40.1.dev28.dist-info}/RECORD +41 -41
- /khoj/interface/compiled/_next/static/{cbPHN2fL-jO8972HjaRbw → 7IDXliK3ZEqSeZICqmtPL}/_buildManifest.js +0 -0
- /khoj/interface/compiled/_next/static/{cbPHN2fL-jO8972HjaRbw → 7IDXliK3ZEqSeZICqmtPL}/_ssgManifest.js +0 -0
- /khoj/interface/compiled/_next/static/chunks/{1915-ab4353eaca76f690.js → 1915-1943ee8a628b893c.js} +0 -0
- /khoj/interface/compiled/_next/static/chunks/{2117-1c18aa2098982bf9.js → 2117-5a41630a2bd2eae8.js} +0 -0
- /khoj/interface/compiled/_next/static/chunks/{4363-4efaf12abe696251.js → 4363-e6ac2203564d1a3b.js} +0 -0
- /khoj/interface/compiled/_next/static/chunks/{4447-5d44807c40355b1a.js → 4447-e038b251d626c340.js} +0 -0
- /khoj/interface/compiled/_next/static/chunks/{8667-adbe6017a66cef10.js → 8667-8136f74e9a086fca.js} +0 -0
- /khoj/interface/compiled/_next/static/chunks/{9259-d8bcd9da9e80c81e.js → 9259-640fdd77408475df.js} +0 -0
- {khoj-1.40.1.dev20.dist-info → khoj-1.40.1.dev28.dist-info}/WHEEL +0 -0
- {khoj-1.40.1.dev20.dist-info → khoj-1.40.1.dev28.dist-info}/entry_points.txt +0 -0
- {khoj-1.40.1.dev20.dist-info → khoj-1.40.1.dev28.dist-info}/licenses/LICENSE +0 -0
--- a/khoj/processor/conversation/google/gemini_chat.py
+++ b/khoj/processor/conversation/google/gemini_chat.py
@@ -1,6 +1,7 @@
+import asyncio
 import logging
 from datetime import datetime, timedelta
-from typing import Dict, List, Optional
+from typing import AsyncGenerator, Dict, List, Optional

 import pyjson5
 from langchain.schema import ChatMessage
@@ -160,7 +161,7 @@ def gemini_send_message_to_model(
     )


-def converse_gemini(
+async def converse_gemini(
     references,
     user_query,
     online_results: Optional[Dict[str, Dict]] = None,
@@ -185,7 +186,7 @@ def converse_gemini(
     program_execution_context: List[str] = None,
     deepthought: Optional[bool] = False,
     tracer={},
-):
+) -> AsyncGenerator[str, None]:
     """
     Converse with user using Google's Gemini
     """
@@ -216,11 +217,17 @@ def converse_gemini(

     # Get Conversation Primer appropriate to Conversation Type
     if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
-
-
+        response = prompts.no_notes_found.format()
+        if completion_func:
+            asyncio.create_task(completion_func(chat_response=response))
+        yield response
+        return
     elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
-
-
+        response = prompts.no_online_results_found.format()
+        if completion_func:
+            asyncio.create_task(completion_func(chat_response=response))
+        yield response
+        return

     context_message = ""
     if not is_none_or_empty(references):
@@ -253,16 +260,20 @@ def converse_gemini(
     logger.debug(f"Conversation Context for Gemini: {messages_to_print(messages)}")

     # Get Response from Google AI
-
+    full_response = ""
+    async for chunk in gemini_chat_completion_with_backoff(
         messages=messages,
-        compiled_references=references,
-        online_results=online_results,
         model_name=model,
         temperature=temperature,
         api_key=api_key,
         api_base_url=api_base_url,
         system_prompt=system_prompt,
-        completion_func=completion_func,
         deepthought=deepthought,
         tracer=tracer,
-    )
+    ):
+        full_response += chunk
+        yield chunk
+
+    # Call completion_func once finish streaming and we have the full response
+    if completion_func:
+        asyncio.create_task(completion_func(chat_response=full_response))
--- a/khoj/processor/conversation/google/utils.py
+++ b/khoj/processor/conversation/google/utils.py
@@ -2,8 +2,8 @@ import logging
 import os
 import random
 from copy import deepcopy
-from
-from typing import Dict
+from time import perf_counter
+from typing import AsyncGenerator, AsyncIterator, Dict

 from google import genai
 from google.genai import errors as gerrors
@@ -19,14 +19,13 @@ from tenacity import (
 )

 from khoj.processor.conversation.utils import (
-    ThreadedGenerator,
     commit_conversation_trace,
     get_image_from_base64,
     get_image_from_url,
 )
 from khoj.utils.helpers import (
-    get_ai_api_info,
     get_chat_usage_metrics,
+    get_gemini_client,
     is_none_or_empty,
     is_promptrace_enabled,
 )
@@ -62,17 +61,6 @@ SAFETY_SETTINGS = [
 ]


-def get_gemini_client(api_key, api_base_url=None) -> genai.Client:
-    api_info = get_ai_api_info(api_key, api_base_url)
-    return genai.Client(
-        location=api_info.region,
-        project=api_info.project,
-        credentials=api_info.credentials,
-        api_key=api_info.api_key,
-        vertexai=api_info.api_key is None,
-    )
-
-
 @retry(
     wait=wait_random_exponential(min=1, max=10),
     stop=stop_after_attempt(2),
@@ -132,8 +120,8 @@ def gemini_completion_with_backoff(
     )

     # Aggregate cost of chat
-    input_tokens = response.usage_metadata.prompt_token_count if response else 0
-    output_tokens = response.usage_metadata.candidates_token_count if response else 0
+    input_tokens = response.usage_metadata.prompt_token_count or 0 if response else 0
+    output_tokens = response.usage_metadata.candidates_token_count or 0 if response else 0
     thought_tokens = response.usage_metadata.thoughts_token_count or 0 if response else 0
     tracer["usage"] = get_chat_usage_metrics(
         model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
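
The token-count tweak above is small but deliberate: the `or 0` guard only matters when the SDK reports a count as `None`, and because `or` binds tighter than a conditional expression, `x or 0 if response else 0` parses as `(x or 0) if response else 0`. A quick check of that precedence:

```python
# `or` binds tighter than `... if ... else ...`, so a None token count
# is coerced to 0 while a missing response still short-circuits to 0.
def safe_tokens(count, have_response):
    return count or 0 if have_response else 0


assert safe_tokens(None, True) == 0   # None coerced by `or 0`
assert safe_tokens(42, True) == 42
assert safe_tokens(42, False) == 0    # no response object at all
print("precedence check passed")
```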
@@ -154,52 +142,17 @@ def gemini_completion_with_backoff(
     before_sleep=before_sleep_log(logger, logging.DEBUG),
     reraise=True,
 )
-def gemini_chat_completion_with_backoff(
+async def gemini_chat_completion_with_backoff(
     messages,
-    compiled_references,
-    online_results,
     model_name,
     temperature,
     api_key,
     api_base_url,
     system_prompt,
-    completion_func=None,
     model_kwargs=None,
     deepthought=False,
     tracer: dict = {},
-):
-    g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
-    t = Thread(
-        target=gemini_llm_thread,
-        args=(
-            g,
-            messages,
-            system_prompt,
-            model_name,
-            temperature,
-            api_key,
-            api_base_url,
-            model_kwargs,
-            deepthought,
-            tracer,
-        ),
-    )
-    t.start()
-    return g
-
-
-def gemini_llm_thread(
-    g,
-    messages,
-    system_prompt,
-    model_name,
-    temperature,
-    api_key,
-    api_base_url=None,
-    model_kwargs=None,
-    deepthought=False,
-    tracer: dict = {},
-):
+) -> AsyncGenerator[str, None]:
     try:
         client = gemini_clients.get(api_key)
         if not client:
@@ -224,21 +177,32 @@ def gemini_llm_thread(
         )

         aggregated_response = ""
-
-
+        final_chunk = None
+        start_time = perf_counter()
+        chat_stream: AsyncIterator[gtypes.GenerateContentResponse] = await client.aio.models.generate_content_stream(
             model=model_name, config=config, contents=formatted_messages
-        )
+        )
+        async for chunk in chat_stream:
+            # Log the time taken to start response
+            if final_chunk is None:
+                logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
+            # Keep track of the last chunk for usage data
+            final_chunk = chunk
+            # Handle streamed response chunk
             message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
             message = message or chunk.text
             aggregated_response += message
-
+            yield message
             if stopped:
                 raise ValueError(message)

+        # Log the time taken to stream the entire response
+        logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
+
         # Calculate cost of chat
-        input_tokens =
-        output_tokens =
-        thought_tokens =
+        input_tokens = final_chunk.usage_metadata.prompt_token_count or 0 if final_chunk else 0
+        output_tokens = final_chunk.usage_metadata.candidates_token_count or 0 if final_chunk else 0
+        thought_tokens = final_chunk.usage_metadata.thoughts_token_count or 0 if final_chunk else 0
         tracer["usage"] = get_chat_usage_metrics(
             model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
         )
@@ -254,9 +218,7 @@ def gemini_llm_thread(
             + f"Last Message by {messages[-1].role}: {messages[-1].content}"
         )
     except Exception as e:
-        logger.error(f"Error in
-    finally:
-        g.close()
+        logger.error(f"Error in gemini_chat_completion_with_backoff stream: {e}", exc_info=True)


 def handle_gemini_response(
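
The rewritten `gemini_chat_completion_with_backoff` also adds latency instrumentation: it records time-to-first-chunk and total streaming time with `perf_counter`, and keeps the last chunk around so usage metadata can be read after the stream ends. A minimal sketch of that instrumentation pattern, with a hypothetical `stub_stream` standing in for the `client.aio.models.generate_content_stream(...)` call:

```python
import asyncio
import logging
from time import perf_counter
from typing import AsyncGenerator

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


async def stub_stream() -> AsyncGenerator[str, None]:
    # Stand-in for the provider's async streaming call (hypothetical).
    for part in ["The ", "answer ", "is 42."]:
        await asyncio.sleep(0.05)
        yield part


async def timed_stream() -> AsyncGenerator[str, None]:
    first_chunk_seen = False
    start_time = perf_counter()
    async for chunk in stub_stream():
        if not first_chunk_seen:
            # Time to first chunk is the latency users actually feel
            logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
            first_chunk_seen = True
        yield chunk
    logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")


async def main():
    async for chunk in timed_stream():
        print(chunk, end="")
    print()


asyncio.run(main())
```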
--- a/khoj/processor/conversation/offline/chat_model.py
+++ b/khoj/processor/conversation/offline/chat_model.py
@@ -1,9 +1,10 @@
-import
+import asyncio
 import logging
 import os
 from datetime import datetime, timedelta
 from threading import Thread
-from
+from time import perf_counter
+from typing import Any, AsyncGenerator, Dict, List, Optional, Union

 import pyjson5
 from langchain.schema import ChatMessage
@@ -13,7 +14,6 @@ from khoj.database.models import Agent, ChatModel, KhojUser
 from khoj.processor.conversation import prompts
 from khoj.processor.conversation.offline.utils import download_model
 from khoj.processor.conversation.utils import (
-    ThreadedGenerator,
     clean_json,
     commit_conversation_trace,
     generate_chatml_messages_with_context,
@@ -147,7 +147,7 @@ def filter_questions(questions: List[str]):
     return list(filtered_questions)


-def converse_offline(
+async def converse_offline(
     user_query,
     references=[],
     online_results={},
@@ -167,9 +167,9 @@ def converse_offline(
     additional_context: List[str] = None,
     generated_asset_results: Dict[str, Dict] = {},
     tracer: dict = {},
-) ->
+) -> AsyncGenerator[str, None]:
     """
-    Converse with user using Llama
+    Converse with user using Llama (Async Version)
     """
     # Initialize Variables
     assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
@@ -200,10 +200,17 @@ def converse_offline(

     # Get Conversation Primer appropriate to Conversation Type
     if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
-
+        response = prompts.no_notes_found.format()
+        if completion_func:
+            asyncio.create_task(completion_func(chat_response=response))
+        yield response
+        return
     elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
-
-
+        response = prompts.no_online_results_found.format()
+        if completion_func:
+            asyncio.create_task(completion_func(chat_response=response))
+        yield response
+        return

     context_message = ""
     if not is_none_or_empty(references):
@@ -240,33 +247,77 @@ def converse_offline(

     logger.debug(f"Conversation Context for {model_name}: {messages_to_print(messages)}")

-
-
-    t.start()
-    return g
-
-
-def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None, tracer: dict = {}):
+    # Use asyncio.Queue and a thread to bridge sync iterator
+    queue: asyncio.Queue = asyncio.Queue()
     stop_phrases = ["<s>", "INST]", "Notes:"]
-
-
-
-
-
-
-    )
-
-
-
-
-
-
-
-
-
-
-
-
+    aggregated_response_container = {"response": ""}
+
+    def _sync_llm_thread():
+        """Synchronous function to run in a separate thread."""
+        aggregated_response = ""
+        start_time = perf_counter()
+        state.chat_lock.acquire()
+        try:
+            response_iterator = send_message_to_model_offline(
+                messages,
+                loaded_model=offline_chat_model,
+                stop=stop_phrases,
+                max_prompt_size=max_prompt_size,
+                streaming=True,
+                tracer=tracer,
+            )
+            for response in response_iterator:
+                response_delta = response["choices"][0]["delta"].get("content", "")
+                # Log the time taken to start response
+                if aggregated_response == "" and response_delta != "":
+                    logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
+                # Handle response chunk
+                aggregated_response += response_delta
+                # Put chunk into the asyncio queue (non-blocking)
+                try:
+                    queue.put_nowait(response_delta)
+                except asyncio.QueueFull:
+                    # Should not happen with default queue size unless consumer is very slow
+                    logger.warning("Asyncio queue full during offline LLM streaming.")
+                    # Potentially block here or handle differently if needed
+                    asyncio.run(queue.put(response_delta))
+
+            # Log the time taken to stream the entire response
+            logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
+
+            # Save conversation trace
+            tracer["chat_model"] = model_name
+            if is_promptrace_enabled():
+                commit_conversation_trace(messages, aggregated_response, tracer)
+
+        except Exception as e:
+            logger.error(f"Error in offline LLM thread: {e}", exc_info=True)
+        finally:
+            state.chat_lock.release()
+            # Signal end of stream
+            queue.put_nowait(None)
+            aggregated_response_container["response"] = aggregated_response
+
+    # Start the synchronous thread
+    thread = Thread(target=_sync_llm_thread)
+    thread.start()
+
+    # Asynchronously consume from the queue
+    while True:
+        chunk = await queue.get()
+        if chunk is None:  # End of stream signal
+            queue.task_done()
+            break
+        yield chunk
+        queue.task_done()
+
+    # Wait for the thread to finish (optional, ensures cleanup)
+    loop = asyncio.get_running_loop()
+    await loop.run_in_executor(None, thread.join)
+
+    # Call the completion function after streaming is done
+    if completion_func:
+        asyncio.create_task(completion_func(chat_response=aggregated_response_container["response"]))


 def send_message_to_model_offline(
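
The `converse_offline` rewrite bridges the blocking llama.cpp iterator into an async generator by pushing chunks from a worker thread onto an `asyncio.Queue` and consuming them with `await queue.get()` until a `None` sentinel arrives. A minimal, self-contained sketch of that bridge follows; the `slow_sync_tokens` producer is hypothetical, and the sketch uses `loop.call_soon_threadsafe`, which is the thread-safe way to feed an `asyncio.Queue` from another thread, whereas the diff calls `queue.put_nowait` directly from the worker:

```python
import asyncio
import time
from threading import Thread
from typing import AsyncGenerator, Iterator


def slow_sync_tokens() -> Iterator[str]:
    # Stand-in for a blocking token iterator, e.g. llama.cpp (hypothetical).
    for token in ["stream", "ing ", "from ", "a ", "thread"]:
        time.sleep(0.05)
        yield token


async def bridge_sync_stream() -> AsyncGenerator[str, None]:
    queue: asyncio.Queue = asyncio.Queue()
    loop = asyncio.get_running_loop()

    def producer():
        try:
            for token in slow_sync_tokens():
                # Hand each token to the event loop thread-safely
                loop.call_soon_threadsafe(queue.put_nowait, token)
        finally:
            loop.call_soon_threadsafe(queue.put_nowait, None)  # end-of-stream sentinel

    thread = Thread(target=producer)
    thread.start()

    while True:
        token = await queue.get()
        if token is None:
            break
        yield token

    # Join the worker off the event loop so other tasks keep running
    await loop.run_in_executor(None, thread.join)


async def main():
    async for token in bridge_sync_stream():
        print(token, end="")
    print()


asyncio.run(main())
```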
--- a/khoj/processor/conversation/openai/gpt.py
+++ b/khoj/processor/conversation/openai/gpt.py
@@ -1,6 +1,7 @@
+import asyncio
 import logging
 from datetime import datetime, timedelta
-from typing import Dict, List, Optional
+from typing import AsyncGenerator, Dict, List, Optional

 import pyjson5
 from langchain.schema import ChatMessage
@@ -162,7 +163,7 @@ def send_message_to_model(
     )


-def converse_openai(
+async def converse_openai(
     references,
     user_query,
     online_results: Optional[Dict[str, Dict]] = None,
@@ -187,7 +188,7 @@ def converse_openai(
     program_execution_context: List[str] = None,
     deepthought: Optional[bool] = False,
     tracer: dict = {},
-):
+) -> AsyncGenerator[str, None]:
     """
     Converse with user using OpenAI's ChatGPT
     """
@@ -217,11 +218,17 @@ def converse_openai(

     # Get Conversation Primer appropriate to Conversation Type
     if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
-
-
+        response = prompts.no_notes_found.format()
+        if completion_func:
+            asyncio.create_task(completion_func(chat_response=response))
+        yield response
+        return
     elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
-
-
+        response = prompts.no_online_results_found.format()
+        if completion_func:
+            asyncio.create_task(completion_func(chat_response=response))
+        yield response
+        return

     context_message = ""
     if not is_none_or_empty(references):
@@ -255,19 +262,23 @@ def converse_openai(
     logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}")

     # Get Response from GPT
-
+    full_response = ""
+    async for chunk in chat_completion_with_backoff(
         messages=messages,
-        compiled_references=references,
-        online_results=online_results,
         model_name=model,
         temperature=temperature,
         openai_api_key=api_key,
         api_base_url=api_base_url,
-        completion_func=completion_func,
         deepthought=deepthought,
         model_kwargs={"stop": ["Notes:\n["]},
         tracer=tracer,
-    )
+    ):
+        full_response += chunk
+        yield chunk
+
+    # Call completion_func once finish streaming and we have the full response
+    if completion_func:
+        asyncio.create_task(completion_func(chat_response=full_response))
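
The router-side changes (`khoj/routers/api_chat.py`, `khoj/routers/helpers.py`) are listed above but not shown in this section. A caller of these now-async `converse_*` generators would typically hand them straight to a streaming HTTP response; a minimal sketch of that wiring with FastAPI, where the `fake_converse` generator and `/chat` route are hypothetical and not Khoj's actual endpoint:

```python
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


async def fake_converse(user_query: str):
    # Stand-in for converse_openai / converse_gemini (hypothetical).
    for chunk in ["You ", "asked: ", user_query]:
        yield chunk


@app.get("/chat")
async def chat(q: str):
    # StreamingResponse drains the async generator chunk by chunk
    return StreamingResponse(fake_converse(q), media_type="text/plain")

# Run with: uvicorn module_name:app
```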
--- a/khoj/processor/conversation/openai/utils.py
+++ b/khoj/processor/conversation/openai/utils.py
@@ -1,7 +1,7 @@
 import logging
 import os
-from
-from typing import Dict, List
+from time import perf_counter
+from typing import AsyncGenerator, Dict, List
 from urllib.parse import urlparse

 import openai
@@ -16,13 +16,10 @@ from tenacity import (
     wait_random_exponential,
 )

-from khoj.processor.conversation.utils import
-    JsonSupport,
-    ThreadedGenerator,
-    commit_conversation_trace,
-)
+from khoj.processor.conversation.utils import JsonSupport, commit_conversation_trace
 from khoj.utils.helpers import (
     get_chat_usage_metrics,
+    get_openai_async_client,
     get_openai_client,
     is_promptrace_enabled,
 )
@@ -30,6 +27,7 @@ from khoj.utils.helpers import (
 logger = logging.getLogger(__name__)

 openai_clients: Dict[str, openai.OpenAI] = {}
+openai_async_clients: Dict[str, openai.AsyncOpenAI] = {}


 @retry(
@@ -124,45 +122,22 @@ def completion_with_backoff(
     before_sleep=before_sleep_log(logger, logging.DEBUG),
     reraise=True,
 )
-def chat_completion_with_backoff(
+async def chat_completion_with_backoff(
     messages,
-    compiled_references,
-    online_results,
     model_name,
     temperature,
     openai_api_key=None,
     api_base_url=None,
-    completion_func=None,
-    deepthought=False,
-    model_kwargs=None,
-    tracer: dict = {},
-):
-    g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
-    t = Thread(
-        target=llm_thread,
-        args=(g, messages, model_name, temperature, openai_api_key, api_base_url, deepthought, model_kwargs, tracer),
-    )
-    t.start()
-    return g
-
-
-def llm_thread(
-    g,
-    messages,
-    model_name: str,
-    temperature,
-    openai_api_key=None,
-    api_base_url=None,
     deepthought=False,
     model_kwargs: dict = {},
     tracer: dict = {},
-):
+) -> AsyncGenerator[str, None]:
     try:
         client_key = f"{openai_api_key}--{api_base_url}"
-        client =
+        client = openai_async_clients.get(client_key)
         if not client:
-            client =
-
+            client = get_openai_async_client(openai_api_key, api_base_url)
+            openai_async_clients[client_key] = client

         formatted_messages = [{"role": message.role, "content": message.content} for message in messages]

@@ -207,53 +182,58 @@ def llm_thread(
         if os.getenv("KHOJ_LLM_SEED"):
             model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))

-
-
-
+        aggregated_response = ""
+        final_chunk = None
+        start_time = perf_counter()
+        chat_stream: openai.AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
+            messages=formatted_messages,  # type: ignore
+            model=model_name,
             stream=stream,
             temperature=temperature,
             timeout=20,
             **model_kwargs,
         )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        cost =
-
-
-
-
-
+        async for chunk in chat_stream:
+            # Log the time taken to start response
+            if final_chunk is None:
+                logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
+            # Keep track of the last chunk for usage data
+            final_chunk = chunk
+            # Handle streamed response chunk
+            if len(chunk.choices) == 0:
+                continue
+            delta_chunk = chunk.choices[0].delta
+            text_chunk = ""
+            if isinstance(delta_chunk, str):
+                text_chunk = delta_chunk
+            elif delta_chunk and delta_chunk.content:
+                text_chunk = delta_chunk.content
+            if text_chunk:
+                aggregated_response += text_chunk
+                yield text_chunk
+
+        # Log the time taken to stream the entire response
+        logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
+
+        # Calculate cost of chat after stream finishes
+        input_tokens, output_tokens, cost = 0, 0, 0
+        if final_chunk and hasattr(final_chunk, "usage") and final_chunk.usage:
+            input_tokens = final_chunk.usage.prompt_tokens
+            output_tokens = final_chunk.usage.completion_tokens
+            # Estimated costs returned by DeepInfra API
+            if final_chunk.usage.model_extra and "estimated_cost" in final_chunk.usage.model_extra:
+                cost = final_chunk.usage.model_extra.get("estimated_cost", 0)

         # Save conversation trace
         tracer["chat_model"] = model_name
         tracer["temperature"] = temperature
+        tracer["usage"] = get_chat_usage_metrics(
+            model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost
+        )
         if is_promptrace_enabled():
             commit_conversation_trace(messages, aggregated_response, tracer)
     except Exception as e:
-        logger.error(f"Error in
-    finally:
-        g.close()
+        logger.error(f"Error in chat_completion_with_backoff stream: {e}", exc_info=True)


 def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport:
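
The new `chat_completion_with_backoff` consumes an `AsyncOpenAI` stream directly instead of spinning up `llm_thread`. A minimal sketch of that consumption pattern against the OpenAI Python SDK: the model name and environment variable are placeholders, and `stream_options={"include_usage": True}` asks the API to attach token usage to a final, empty-choices chunk, which mirrors the `final_chunk` / usage handling above (providers that do not support it can omit the option):

```python
import asyncio
import os

import openai


async def main():
    # Assumes OPENAI_API_KEY is set; the model name is illustrative.
    client = openai.AsyncOpenAI(api_key=os.environ["OPENAI_API_KEY"])
    chat_stream = await client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "Say hello in five words."}],
        stream=True,
        # Attach token usage to the final (empty-choices) chunk
        stream_options={"include_usage": True},
    )

    aggregated_response = ""
    final_chunk = None
    async for chunk in chat_stream:
        final_chunk = chunk
        if not chunk.choices:  # the usage-only chunk has no choices
            continue
        delta = chunk.choices[0].delta
        if delta and delta.content:
            aggregated_response += delta.content
            print(delta.content, end="", flush=True)

    print()
    if final_chunk and final_chunk.usage:
        print(f"prompt={final_chunk.usage.prompt_tokens} completion={final_chunk.usage.completion_tokens}")


asyncio.run(main())
```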
|