khoj 1.40.1.dev20__py3-none-any.whl → 1.40.1.dev28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. khoj/database/adapters/__init__.py +24 -22
  2. khoj/interface/compiled/404/index.html +2 -2
  3. khoj/interface/compiled/_next/static/chunks/{webpack-639e8a9d35bd5e13.js → webpack-05ff3cbe22520b30.js} +1 -1
  4. khoj/interface/compiled/agents/index.html +2 -2
  5. khoj/interface/compiled/agents/index.txt +2 -2
  6. khoj/interface/compiled/automations/index.html +2 -2
  7. khoj/interface/compiled/automations/index.txt +3 -3
  8. khoj/interface/compiled/chat/index.html +2 -2
  9. khoj/interface/compiled/chat/index.txt +2 -2
  10. khoj/interface/compiled/index.html +2 -2
  11. khoj/interface/compiled/index.txt +2 -2
  12. khoj/interface/compiled/search/index.html +2 -2
  13. khoj/interface/compiled/search/index.txt +2 -2
  14. khoj/interface/compiled/settings/index.html +2 -2
  15. khoj/interface/compiled/settings/index.txt +4 -4
  16. khoj/interface/compiled/share/chat/index.html +2 -2
  17. khoj/interface/compiled/share/chat/index.txt +2 -2
  18. khoj/processor/conversation/anthropic/anthropic_chat.py +23 -12
  19. khoj/processor/conversation/anthropic/utils.py +21 -64
  20. khoj/processor/conversation/google/gemini_chat.py +23 -12
  21. khoj/processor/conversation/google/utils.py +26 -64
  22. khoj/processor/conversation/offline/chat_model.py +86 -35
  23. khoj/processor/conversation/openai/gpt.py +23 -12
  24. khoj/processor/conversation/openai/utils.py +50 -70
  25. khoj/processor/conversation/utils.py +2 -38
  26. khoj/routers/api_chat.py +36 -28
  27. khoj/routers/helpers.py +13 -23
  28. khoj/utils/helpers.py +56 -0
  29. {khoj-1.40.1.dev20.dist-info → khoj-1.40.1.dev28.dist-info}/METADATA +1 -1
  30. {khoj-1.40.1.dev20.dist-info → khoj-1.40.1.dev28.dist-info}/RECORD +41 -41
  31. /khoj/interface/compiled/_next/static/{cbPHN2fL-jO8972HjaRbw → 7IDXliK3ZEqSeZICqmtPL}/_buildManifest.js +0 -0
  32. /khoj/interface/compiled/_next/static/{cbPHN2fL-jO8972HjaRbw → 7IDXliK3ZEqSeZICqmtPL}/_ssgManifest.js +0 -0
  33. /khoj/interface/compiled/_next/static/chunks/{1915-ab4353eaca76f690.js → 1915-1943ee8a628b893c.js} +0 -0
  34. /khoj/interface/compiled/_next/static/chunks/{2117-1c18aa2098982bf9.js → 2117-5a41630a2bd2eae8.js} +0 -0
  35. /khoj/interface/compiled/_next/static/chunks/{4363-4efaf12abe696251.js → 4363-e6ac2203564d1a3b.js} +0 -0
  36. /khoj/interface/compiled/_next/static/chunks/{4447-5d44807c40355b1a.js → 4447-e038b251d626c340.js} +0 -0
  37. /khoj/interface/compiled/_next/static/chunks/{8667-adbe6017a66cef10.js → 8667-8136f74e9a086fca.js} +0 -0
  38. /khoj/interface/compiled/_next/static/chunks/{9259-d8bcd9da9e80c81e.js → 9259-640fdd77408475df.js} +0 -0
  39. {khoj-1.40.1.dev20.dist-info → khoj-1.40.1.dev28.dist-info}/WHEEL +0 -0
  40. {khoj-1.40.1.dev20.dist-info → khoj-1.40.1.dev28.dist-info}/entry_points.txt +0 -0
  41. {khoj-1.40.1.dev20.dist-info → khoj-1.40.1.dev28.dist-info}/licenses/LICENSE +0 -0
khoj/processor/conversation/google/gemini_chat.py
@@ -1,6 +1,7 @@
+ import asyncio
  import logging
  from datetime import datetime, timedelta
- from typing import Dict, List, Optional
+ from typing import AsyncGenerator, Dict, List, Optional

  import pyjson5
  from langchain.schema import ChatMessage
@@ -160,7 +161,7 @@ def gemini_send_message_to_model(
      )


- def converse_gemini(
+ async def converse_gemini(
      references,
      user_query,
      online_results: Optional[Dict[str, Dict]] = None,
@@ -185,7 +186,7 @@ def converse_gemini(
      program_execution_context: List[str] = None,
      deepthought: Optional[bool] = False,
      tracer={},
- ):
+ ) -> AsyncGenerator[str, None]:
      """
      Converse with user using Google's Gemini
      """
@@ -216,11 +217,17 @@ def converse_gemini(

      # Get Conversation Primer appropriate to Conversation Type
      if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
-         completion_func(chat_response=prompts.no_notes_found.format())
-         return iter([prompts.no_notes_found.format()])
+         response = prompts.no_notes_found.format()
+         if completion_func:
+             asyncio.create_task(completion_func(chat_response=response))
+         yield response
+         return
      elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
-         completion_func(chat_response=prompts.no_online_results_found.format())
-         return iter([prompts.no_online_results_found.format()])
+         response = prompts.no_online_results_found.format()
+         if completion_func:
+             asyncio.create_task(completion_func(chat_response=response))
+         yield response
+         return

      context_message = ""
      if not is_none_or_empty(references):
@@ -253,16 +260,20 @@ def converse_gemini(
      logger.debug(f"Conversation Context for Gemini: {messages_to_print(messages)}")

      # Get Response from Google AI
-     return gemini_chat_completion_with_backoff(
+     full_response = ""
+     async for chunk in gemini_chat_completion_with_backoff(
          messages=messages,
-         compiled_references=references,
-         online_results=online_results,
          model_name=model,
          temperature=temperature,
          api_key=api_key,
          api_base_url=api_base_url,
          system_prompt=system_prompt,
-         completion_func=completion_func,
          deepthought=deepthought,
          tracer=tracer,
-     )
+     ):
+         full_response += chunk
+         yield chunk
+
+     # Call completion_func once finish streaming and we have the full response
+     if completion_func:
+         asyncio.create_task(completion_func(chat_response=full_response))
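With this change, converse_gemini is an async generator rather than a function returning a ThreadedGenerator, so callers stream the reply with async for and the completion callback runs only after the stream ends. A minimal consumption sketch, assuming the import path below; the keyword arguments shown are a trimmed, illustrative subset rather than the full signature:

    import asyncio

    from khoj.processor.conversation.google.gemini_chat import converse_gemini

    async def print_reply() -> None:
        # Stream the reply chunk by chunk; arguments besides references and
        # user_query are placeholders, not the complete parameter list.
        async for chunk in converse_gemini(
            references=[],
            user_query="Summarize my notes on lentil recipes",
            api_key="GEMINI_API_KEY",
        ):
            print(chunk, end="", flush=True)

    # asyncio.run(print_reply())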
khoj/processor/conversation/google/utils.py
@@ -2,8 +2,8 @@ import logging
  import os
  import random
  from copy import deepcopy
- from threading import Thread
- from typing import Dict
+ from time import perf_counter
+ from typing import AsyncGenerator, AsyncIterator, Dict

  from google import genai
  from google.genai import errors as gerrors
@@ -19,14 +19,13 @@ from tenacity import (
  )

  from khoj.processor.conversation.utils import (
-     ThreadedGenerator,
      commit_conversation_trace,
      get_image_from_base64,
      get_image_from_url,
  )
  from khoj.utils.helpers import (
-     get_ai_api_info,
      get_chat_usage_metrics,
+     get_gemini_client,
      is_none_or_empty,
      is_promptrace_enabled,
  )
@@ -62,17 +61,6 @@ SAFETY_SETTINGS = [
  ]


- def get_gemini_client(api_key, api_base_url=None) -> genai.Client:
-     api_info = get_ai_api_info(api_key, api_base_url)
-     return genai.Client(
-         location=api_info.region,
-         project=api_info.project,
-         credentials=api_info.credentials,
-         api_key=api_info.api_key,
-         vertexai=api_info.api_key is None,
-     )
-
-
  @retry(
      wait=wait_random_exponential(min=1, max=10),
      stop=stop_after_attempt(2),
@@ -132,8 +120,8 @@ def gemini_completion_with_backoff(
      )

      # Aggregate cost of chat
-     input_tokens = response.usage_metadata.prompt_token_count if response else 0
-     output_tokens = response.usage_metadata.candidates_token_count if response else 0
+     input_tokens = response.usage_metadata.prompt_token_count or 0 if response else 0
+     output_tokens = response.usage_metadata.candidates_token_count or 0 if response else 0
      thought_tokens = response.usage_metadata.thoughts_token_count or 0 if response else 0
      tracer["usage"] = get_chat_usage_metrics(
          model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
@@ -154,52 +142,17 @@ def gemini_completion_with_backoff(
      before_sleep=before_sleep_log(logger, logging.DEBUG),
      reraise=True,
  )
- def gemini_chat_completion_with_backoff(
+ async def gemini_chat_completion_with_backoff(
      messages,
-     compiled_references,
-     online_results,
      model_name,
      temperature,
      api_key,
      api_base_url,
      system_prompt,
-     completion_func=None,
      model_kwargs=None,
      deepthought=False,
      tracer: dict = {},
- ):
-     g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
-     t = Thread(
-         target=gemini_llm_thread,
-         args=(
-             g,
-             messages,
-             system_prompt,
-             model_name,
-             temperature,
-             api_key,
-             api_base_url,
-             model_kwargs,
-             deepthought,
-             tracer,
-         ),
-     )
-     t.start()
-     return g
-
-
- def gemini_llm_thread(
-     g,
-     messages,
-     system_prompt,
-     model_name,
-     temperature,
-     api_key,
-     api_base_url=None,
-     model_kwargs=None,
-     deepthought=False,
-     tracer: dict = {},
- ):
+ ) -> AsyncGenerator[str, None]:
      try:
          client = gemini_clients.get(api_key)
          if not client:
@@ -224,21 +177,32 @@ def gemini_llm_thread(
          )

          aggregated_response = ""
-
-         for chunk in client.models.generate_content_stream(
+         final_chunk = None
+         start_time = perf_counter()
+         chat_stream: AsyncIterator[gtypes.GenerateContentResponse] = await client.aio.models.generate_content_stream(
              model=model_name, config=config, contents=formatted_messages
-         ):
+         )
+         async for chunk in chat_stream:
+             # Log the time taken to start response
+             if final_chunk is None:
+                 logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
+             # Keep track of the last chunk for usage data
+             final_chunk = chunk
+             # Handle streamed response chunk
              message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
              message = message or chunk.text
              aggregated_response += message
-             g.send(message)
+             yield message
              if stopped:
                  raise ValueError(message)

+         # Log the time taken to stream the entire response
+         logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
+
          # Calculate cost of chat
-         input_tokens = chunk.usage_metadata.prompt_token_count
-         output_tokens = chunk.usage_metadata.candidates_token_count
-         thought_tokens = chunk.usage_metadata.thoughts_token_count or 0
+         input_tokens = final_chunk.usage_metadata.prompt_token_count or 0 if final_chunk else 0
+         output_tokens = final_chunk.usage_metadata.candidates_token_count or 0 if final_chunk else 0
+         thought_tokens = final_chunk.usage_metadata.thoughts_token_count or 0 if final_chunk else 0
          tracer["usage"] = get_chat_usage_metrics(
              model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
          )
@@ -254,9 +218,7 @@ def gemini_llm_thread(
                  + f"Last Message by {messages[-1].role}: {messages[-1].content}"
              )
      except Exception as e:
-         logger.error(f"Error in gemini_llm_thread: {e}", exc_info=True)
-     finally:
-         g.close()
+         logger.error(f"Error in gemini_chat_completion_with_backoff stream: {e}", exc_info=True)


  def handle_gemini_response(
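The streaming helper now drives the google-genai SDK's async client surface (client.aio) directly and reads token usage off the last streamed chunk, instead of handing work to a gemini_llm_thread. A standalone sketch of that pattern, assuming the google-genai package and an illustrative model name:

    import asyncio

    from google import genai

    async def stream_gemini(prompt: str, api_key: str) -> None:
        # The async API lives under client.aio, as used in the diff above.
        client = genai.Client(api_key=api_key)
        final_chunk = None
        stream = await client.aio.models.generate_content_stream(
            model="gemini-2.0-flash",  # illustrative model name
            contents=prompt,
        )
        async for chunk in stream:
            final_chunk = chunk  # the last chunk carries usage_metadata
            print(chunk.text or "", end="", flush=True)
        if final_chunk and final_chunk.usage_metadata:
            print(f"\nPrompt tokens: {final_chunk.usage_metadata.prompt_token_count or 0}")

    # asyncio.run(stream_gemini("Hello", api_key="GEMINI_API_KEY"))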
khoj/processor/conversation/offline/chat_model.py
@@ -1,9 +1,10 @@
- import json
+ import asyncio
  import logging
  import os
  from datetime import datetime, timedelta
  from threading import Thread
- from typing import Any, Dict, Iterator, List, Optional, Union
+ from time import perf_counter
+ from typing import Any, AsyncGenerator, Dict, List, Optional, Union

  import pyjson5
  from langchain.schema import ChatMessage
@@ -13,7 +14,6 @@ from khoj.database.models import Agent, ChatModel, KhojUser
  from khoj.processor.conversation import prompts
  from khoj.processor.conversation.offline.utils import download_model
  from khoj.processor.conversation.utils import (
-     ThreadedGenerator,
      clean_json,
      commit_conversation_trace,
      generate_chatml_messages_with_context,
@@ -147,7 +147,7 @@ def filter_questions(questions: List[str]):
      return list(filtered_questions)


- def converse_offline(
+ async def converse_offline(
      user_query,
      references=[],
      online_results={},
@@ -167,9 +167,9 @@ def converse_offline(
      additional_context: List[str] = None,
      generated_asset_results: Dict[str, Dict] = {},
      tracer: dict = {},
- ) -> Union[ThreadedGenerator, Iterator[str]]:
+ ) -> AsyncGenerator[str, None]:
      """
-     Converse with user using Llama
+     Converse with user using Llama (Async Version)
      """
      # Initialize Variables
      assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
@@ -200,10 +200,17 @@ def converse_offline(

      # Get Conversation Primer appropriate to Conversation Type
      if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
-         return iter([prompts.no_notes_found.format()])
+         response = prompts.no_notes_found.format()
+         if completion_func:
+             asyncio.create_task(completion_func(chat_response=response))
+         yield response
+         return
      elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
-         completion_func(chat_response=prompts.no_online_results_found.format())
-         return iter([prompts.no_online_results_found.format()])
+         response = prompts.no_online_results_found.format()
+         if completion_func:
+             asyncio.create_task(completion_func(chat_response=response))
+         yield response
+         return

      context_message = ""
      if not is_none_or_empty(references):
@@ -240,33 +247,77 @@ def converse_offline(

      logger.debug(f"Conversation Context for {model_name}: {messages_to_print(messages)}")

-     g = ThreadedGenerator(references, online_results, completion_func=completion_func)
-     t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size, tracer))
-     t.start()
-     return g
-
-
- def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None, tracer: dict = {}):
+     # Use asyncio.Queue and a thread to bridge sync iterator
+     queue: asyncio.Queue = asyncio.Queue()
      stop_phrases = ["<s>", "INST]", "Notes:"]
-     aggregated_response = ""
-
-     state.chat_lock.acquire()
-     try:
-         response_iterator = send_message_to_model_offline(
-             messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True
-         )
-         for response in response_iterator:
-             response_delta = response["choices"][0]["delta"].get("content", "")
-             aggregated_response += response_delta
-             g.send(response_delta)
-
-         # Save conversation trace
-         if is_promptrace_enabled():
-             commit_conversation_trace(messages, aggregated_response, tracer)
-
-     finally:
-         state.chat_lock.release()
-         g.close()
+     aggregated_response_container = {"response": ""}
+
+     def _sync_llm_thread():
+         """Synchronous function to run in a separate thread."""
+         aggregated_response = ""
+         start_time = perf_counter()
+         state.chat_lock.acquire()
+         try:
+             response_iterator = send_message_to_model_offline(
+                 messages,
+                 loaded_model=offline_chat_model,
+                 stop=stop_phrases,
+                 max_prompt_size=max_prompt_size,
+                 streaming=True,
+                 tracer=tracer,
+             )
+             for response in response_iterator:
+                 response_delta = response["choices"][0]["delta"].get("content", "")
+                 # Log the time taken to start response
+                 if aggregated_response == "" and response_delta != "":
+                     logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
+                 # Handle response chunk
+                 aggregated_response += response_delta
+                 # Put chunk into the asyncio queue (non-blocking)
+                 try:
+                     queue.put_nowait(response_delta)
+                 except asyncio.QueueFull:
+                     # Should not happen with default queue size unless consumer is very slow
+                     logger.warning("Asyncio queue full during offline LLM streaming.")
+                     # Potentially block here or handle differently if needed
+                     asyncio.run(queue.put(response_delta))
+
+             # Log the time taken to stream the entire response
+             logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
+
+             # Save conversation trace
+             tracer["chat_model"] = model_name
+             if is_promptrace_enabled():
+                 commit_conversation_trace(messages, aggregated_response, tracer)
+
+         except Exception as e:
+             logger.error(f"Error in offline LLM thread: {e}", exc_info=True)
+         finally:
+             state.chat_lock.release()
+             # Signal end of stream
+             queue.put_nowait(None)
+             aggregated_response_container["response"] = aggregated_response
+
+     # Start the synchronous thread
+     thread = Thread(target=_sync_llm_thread)
+     thread.start()
+
+     # Asynchronously consume from the queue
+     while True:
+         chunk = await queue.get()
+         if chunk is None:  # End of stream signal
+             queue.task_done()
+             break
+         yield chunk
+         queue.task_done()
+
+     # Wait for the thread to finish (optional, ensures cleanup)
+     loop = asyncio.get_running_loop()
+     await loop.run_in_executor(None, thread.join)
+
+     # Call the completion function after streaming is done
+     if completion_func:
+         asyncio.create_task(completion_func(chat_response=aggregated_response_container["response"]))


  def send_message_to_model_offline(
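The llama.cpp-backed model still produces a synchronous iterator, so the new converse_offline runs it on a worker thread and forwards chunks to the async consumer through an asyncio.Queue, with None as an end-of-stream sentinel. A self-contained sketch of the same bridging idea under hypothetical names, using call_soon_threadsafe to hand each chunk to the event loop:

    import asyncio
    from threading import Thread
    from typing import AsyncGenerator, Iterable

    async def bridge_sync_stream(sync_chunks: Iterable[str]) -> AsyncGenerator[str, None]:
        """Yield items from a blocking iterator without blocking the event loop."""
        loop = asyncio.get_running_loop()
        queue: asyncio.Queue = asyncio.Queue()

        def worker() -> None:
            # Runs in a plain thread; schedule each queue put on the loop thread.
            for chunk in sync_chunks:
                loop.call_soon_threadsafe(queue.put_nowait, chunk)
            loop.call_soon_threadsafe(queue.put_nowait, None)  # end-of-stream sentinel

        Thread(target=worker, daemon=True).start()
        while (chunk := await queue.get()) is not None:
            yield chunk

    async def demo() -> None:
        async for piece in bridge_sync_stream(iter(["Hello, ", "world"])):
            print(piece, end="", flush=True)

    # asyncio.run(demo())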
khoj/processor/conversation/openai/gpt.py
@@ -1,6 +1,7 @@
+ import asyncio
  import logging
  from datetime import datetime, timedelta
- from typing import Dict, List, Optional
+ from typing import AsyncGenerator, Dict, List, Optional

  import pyjson5
  from langchain.schema import ChatMessage
@@ -162,7 +163,7 @@ def send_message_to_model(
  )


- def converse_openai(
+ async def converse_openai(
      references,
      user_query,
      online_results: Optional[Dict[str, Dict]] = None,
@@ -187,7 +188,7 @@ def converse_openai(
      program_execution_context: List[str] = None,
      deepthought: Optional[bool] = False,
      tracer: dict = {},
- ):
+ ) -> AsyncGenerator[str, None]:
      """
      Converse with user using OpenAI's ChatGPT
      """
@@ -217,11 +218,17 @@ def converse_openai(

      # Get Conversation Primer appropriate to Conversation Type
      if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
-         completion_func(chat_response=prompts.no_notes_found.format())
-         return iter([prompts.no_notes_found.format()])
+         response = prompts.no_notes_found.format()
+         if completion_func:
+             asyncio.create_task(completion_func(chat_response=response))
+         yield response
+         return
      elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
-         completion_func(chat_response=prompts.no_online_results_found.format())
-         return iter([prompts.no_online_results_found.format()])
+         response = prompts.no_online_results_found.format()
+         if completion_func:
+             asyncio.create_task(completion_func(chat_response=response))
+         yield response
+         return

      context_message = ""
      if not is_none_or_empty(references):
@@ -255,19 +262,23 @@ def converse_openai(
      logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}")

      # Get Response from GPT
-     return chat_completion_with_backoff(
+     full_response = ""
+     async for chunk in chat_completion_with_backoff(
          messages=messages,
-         compiled_references=references,
-         online_results=online_results,
          model_name=model,
          temperature=temperature,
          openai_api_key=api_key,
          api_base_url=api_base_url,
-         completion_func=completion_func,
          deepthought=deepthought,
          model_kwargs={"stop": ["Notes:\n["]},
          tracer=tracer,
-     )
+     ):
+         full_response += chunk
+         yield chunk
+
+     # Call completion_func once finish streaming and we have the full response
+     if completion_func:
+         asyncio.create_task(completion_func(chat_response=full_response))


  def clean_response_schema(schema: BaseModel | dict) -> dict:
khoj/processor/conversation/openai/utils.py
@@ -1,7 +1,7 @@
  import logging
  import os
- from threading import Thread
- from typing import Dict, List
+ from time import perf_counter
+ from typing import AsyncGenerator, Dict, List
  from urllib.parse import urlparse

  import openai
@@ -16,13 +16,10 @@ from tenacity import (
      wait_random_exponential,
  )

- from khoj.processor.conversation.utils import (
-     JsonSupport,
-     ThreadedGenerator,
-     commit_conversation_trace,
- )
+ from khoj.processor.conversation.utils import JsonSupport, commit_conversation_trace
  from khoj.utils.helpers import (
      get_chat_usage_metrics,
+     get_openai_async_client,
      get_openai_client,
      is_promptrace_enabled,
  )
@@ -30,6 +27,7 @@ from khoj.utils.helpers import (
  logger = logging.getLogger(__name__)

  openai_clients: Dict[str, openai.OpenAI] = {}
+ openai_async_clients: Dict[str, openai.AsyncOpenAI] = {}


  @retry(
@@ -124,45 +122,22 @@ def completion_with_backoff(
      before_sleep=before_sleep_log(logger, logging.DEBUG),
      reraise=True,
  )
- def chat_completion_with_backoff(
+ async def chat_completion_with_backoff(
      messages,
-     compiled_references,
-     online_results,
      model_name,
      temperature,
      openai_api_key=None,
      api_base_url=None,
-     completion_func=None,
-     deepthought=False,
-     model_kwargs=None,
-     tracer: dict = {},
- ):
-     g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
-     t = Thread(
-         target=llm_thread,
-         args=(g, messages, model_name, temperature, openai_api_key, api_base_url, deepthought, model_kwargs, tracer),
-     )
-     t.start()
-     return g
-
-
- def llm_thread(
-     g,
-     messages,
-     model_name: str,
-     temperature,
-     openai_api_key=None,
-     api_base_url=None,
      deepthought=False,
      model_kwargs: dict = {},
      tracer: dict = {},
- ):
+ ) -> AsyncGenerator[str, None]:
      try:
          client_key = f"{openai_api_key}--{api_base_url}"
-         client = openai_clients.get(client_key)
+         client = openai_async_clients.get(client_key)
          if not client:
-             client = get_openai_client(openai_api_key, api_base_url)
-             openai_clients[client_key] = client
+             client = get_openai_async_client(openai_api_key, api_base_url)
+             openai_async_clients[client_key] = client

          formatted_messages = [{"role": message.role, "content": message.content} for message in messages]

@@ -207,53 +182,58 @@ def llm_thread(
          if os.getenv("KHOJ_LLM_SEED"):
              model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))

-         chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create(
-             messages=formatted_messages,
-             model=model_name,  # type: ignore
+         aggregated_response = ""
+         final_chunk = None
+         start_time = perf_counter()
+         chat_stream: openai.AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
+             messages=formatted_messages,  # type: ignore
+             model=model_name,
              stream=stream,
              temperature=temperature,
              timeout=20,
              **model_kwargs,
          )
-
-         aggregated_response = ""
-         if not stream:
-             chunk = chat
-             aggregated_response = chunk.choices[0].message.content
-             g.send(aggregated_response)
-         else:
-             for chunk in chat:
-                 if len(chunk.choices) == 0:
-                     continue
-                 delta_chunk = chunk.choices[0].delta
-                 text_chunk = ""
-                 if isinstance(delta_chunk, str):
-                     text_chunk = delta_chunk
-                 elif delta_chunk.content:
-                     text_chunk = delta_chunk.content
-                 if text_chunk:
-                     aggregated_response += text_chunk
-                     g.send(text_chunk)
-
-         # Calculate cost of chat
-         input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
-         output_tokens = chunk.usage.completion_tokens if hasattr(chunk, "usage") and chunk.usage else 0
-         cost = (
-             chunk.usage.model_extra.get("estimated_cost", 0) if hasattr(chunk, "usage") and chunk.usage else 0
-         )  # Estimated costs returned by DeepInfra API
-         tracer["usage"] = get_chat_usage_metrics(
-             model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost
-         )
+         async for chunk in chat_stream:
+             # Log the time taken to start response
+             if final_chunk is None:
+                 logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
+             # Keep track of the last chunk for usage data
+             final_chunk = chunk
+             # Handle streamed response chunk
+             if len(chunk.choices) == 0:
+                 continue
+             delta_chunk = chunk.choices[0].delta
+             text_chunk = ""
+             if isinstance(delta_chunk, str):
+                 text_chunk = delta_chunk
+             elif delta_chunk and delta_chunk.content:
+                 text_chunk = delta_chunk.content
+             if text_chunk:
+                 aggregated_response += text_chunk
+                 yield text_chunk
+
+         # Log the time taken to stream the entire response
+         logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
+
+         # Calculate cost of chat after stream finishes
+         input_tokens, output_tokens, cost = 0, 0, 0
+         if final_chunk and hasattr(final_chunk, "usage") and final_chunk.usage:
+             input_tokens = final_chunk.usage.prompt_tokens
+             output_tokens = final_chunk.usage.completion_tokens
+             # Estimated costs returned by DeepInfra API
+             if final_chunk.usage.model_extra and "estimated_cost" in final_chunk.usage.model_extra:
+                 cost = final_chunk.usage.model_extra.get("estimated_cost", 0)

          # Save conversation trace
          tracer["chat_model"] = model_name
          tracer["temperature"] = temperature
+         tracer["usage"] = get_chat_usage_metrics(
+             model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost
+         )
          if is_promptrace_enabled():
              commit_conversation_trace(messages, aggregated_response, tracer)
      except Exception as e:
-         logger.error(f"Error in llm_thread: {e}", exc_info=True)
-     finally:
-         g.close()
+         logger.error(f"Error in chat_completion_with_backoff stream: {e}", exc_info=True)


  def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport:
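chat_completion_with_backoff is now an async generator built on a cached openai.AsyncOpenAI client: it awaits the streaming create call, yields text deltas, and reads usage from the final chunk. A minimal standalone sketch of that streaming pattern; the model name and prompt below are illustrative:

    import asyncio

    import openai

    async def stream_chat(prompt: str, api_key: str, model: str = "gpt-4o-mini") -> str:
        # Async client; the create call is awaited and returns an async stream.
        client = openai.AsyncOpenAI(api_key=api_key)
        response = ""
        stream = await client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            stream=True,
        )
        async for chunk in stream:
            # Each chunk carries an incremental delta; content may be None.
            if chunk.choices and chunk.choices[0].delta.content:
                piece = chunk.choices[0].delta.content
                response += piece
                print(piece, end="", flush=True)
        return response

    # asyncio.run(stream_chat("Say hello", api_key="OPENAI_API_KEY"))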