khoj 1.41.1.dev4__py3-none-any.whl → 1.41.1.dev16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. khoj/interface/compiled/404/index.html +2 -2
  2. khoj/interface/compiled/_next/static/chunks/2327-aa22697ed9c8d54a.js +1 -0
  3. khoj/interface/compiled/_next/static/chunks/app/agents/layout-e00fb81dca656a10.js +1 -0
  4. khoj/interface/compiled/_next/static/chunks/app/agents/{page-996513ae80f8720c.js → page-ceeb9a91edea74ce.js} +1 -1
  5. khoj/interface/compiled/_next/static/chunks/app/automations/{page-2320231573aa9a49.js → page-e3cb78747ab98cc7.js} +1 -1
  6. khoj/interface/compiled/_next/static/chunks/app/chat/layout-33934fc2d6ae6838.js +1 -0
  7. khoj/interface/compiled/_next/static/chunks/app/chat/{page-6e81dbf18637a86e.js → page-7e780dc11eb5e5d3.js} +1 -1
  8. khoj/interface/compiled/_next/static/chunks/app/{page-d9a2e44bbcf49f82.js → page-a4053e1bb578b2ce.js} +1 -1
  9. khoj/interface/compiled/_next/static/chunks/app/search/layout-c02531d586972d7d.js +1 -0
  10. khoj/interface/compiled/_next/static/chunks/app/search/{page-31452bbda0e0a56f.js → page-8973da2f4c076fe1.js} +1 -1
  11. khoj/interface/compiled/_next/static/chunks/app/settings/{page-fdb72b15ca908b43.js → page-375136dbb400525b.js} +1 -1
  12. khoj/interface/compiled/_next/static/chunks/app/share/chat/layout-e8e5db7830bf3f47.js +1 -0
  13. khoj/interface/compiled/_next/static/chunks/app/share/chat/{page-5b7cb35d835af900.js → page-384b54fc953b18f2.js} +1 -1
  14. khoj/interface/compiled/_next/static/chunks/{webpack-b5922a670d3076e8.js → webpack-05ff3cbe22520b30.js} +1 -1
  15. khoj/interface/compiled/_next/static/css/f29752d6e1be7624.css +1 -0
  16. khoj/interface/compiled/_next/static/css/{0db53bacf81896f5.css → fca983d49c3dd1a3.css} +1 -1
  17. khoj/interface/compiled/agents/index.html +2 -2
  18. khoj/interface/compiled/agents/index.txt +2 -2
  19. khoj/interface/compiled/automations/index.html +2 -2
  20. khoj/interface/compiled/automations/index.txt +2 -2
  21. khoj/interface/compiled/chat/index.html +2 -2
  22. khoj/interface/compiled/chat/index.txt +2 -2
  23. khoj/interface/compiled/index.html +2 -2
  24. khoj/interface/compiled/index.txt +2 -2
  25. khoj/interface/compiled/search/index.html +2 -2
  26. khoj/interface/compiled/search/index.txt +2 -2
  27. khoj/interface/compiled/settings/index.html +2 -2
  28. khoj/interface/compiled/settings/index.txt +2 -2
  29. khoj/interface/compiled/share/chat/index.html +2 -2
  30. khoj/interface/compiled/share/chat/index.txt +2 -2
  31. khoj/processor/conversation/anthropic/anthropic_chat.py +4 -2
  32. khoj/processor/conversation/anthropic/utils.py +14 -3
  33. khoj/processor/conversation/openai/gpt.py +4 -2
  34. khoj/processor/conversation/openai/utils.py +334 -23
  35. khoj/processor/conversation/utils.py +7 -0
  36. khoj/routers/api_chat.py +87 -25
  37. khoj/routers/helpers.py +54 -119
  38. khoj/routers/research.py +7 -0
  39. {khoj-1.41.1.dev4.dist-info → khoj-1.41.1.dev16.dist-info}/METADATA +1 -1
  40. {khoj-1.41.1.dev4.dist-info → khoj-1.41.1.dev16.dist-info}/RECORD +45 -45
  41. khoj/interface/compiled/_next/static/chunks/2327-c99ead647a0ee901.js +0 -1
  42. khoj/interface/compiled/_next/static/chunks/app/agents/layout-4e2a134ec26aa606.js +0 -1
  43. khoj/interface/compiled/_next/static/chunks/app/chat/layout-ad4d1792ab1a4108.js +0 -1
  44. khoj/interface/compiled/_next/static/chunks/app/search/layout-f5881c7ae3ba0795.js +0 -1
  45. khoj/interface/compiled/_next/static/chunks/app/share/chat/layout-abb6c5f4239ad7be.js +0 -1
  46. khoj/interface/compiled/_next/static/css/55d4a822f8d94b67.css +0 -1
  47. /khoj/interface/compiled/_next/static/{jbvFiURrQG-AB37JAwuIG → h-E6l3I7yBCfhaSWaXDb_}/_buildManifest.js +0 -0
  48. /khoj/interface/compiled/_next/static/{jbvFiURrQG-AB37JAwuIG → h-E6l3I7yBCfhaSWaXDb_}/_ssgManifest.js +0 -0
  49. {khoj-1.41.1.dev4.dist-info → khoj-1.41.1.dev16.dist-info}/WHEEL +0 -0
  50. {khoj-1.41.1.dev4.dist-info → khoj-1.41.1.dev16.dist-info}/entry_points.txt +0 -0
  51. {khoj-1.41.1.dev4.dist-info → khoj-1.41.1.dev16.dist-info}/licenses/LICENSE +0 -0
khoj/processor/conversation/openai/utils.py CHANGED
@@ -1,12 +1,21 @@
 import logging
 import os
+from functools import partial
 from time import perf_counter
-from typing import AsyncGenerator, Dict, List
+from typing import AsyncGenerator, Dict, Generator, List, Literal, Optional, Union
 from urllib.parse import urlparse
 
 import openai
-from openai.types.chat.chat_completion import ChatCompletion
-from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
+from openai.lib.streaming.chat import (
+    ChatCompletionStream,
+    ChatCompletionStreamEvent,
+    ContentDeltaEvent,
+)
+from openai.types.chat.chat_completion_chunk import (
+    ChatCompletionChunk,
+    Choice,
+    ChoiceDelta,
+)
 from tenacity import (
     before_sleep_log,
     retry,
@@ -16,7 +25,11 @@ from tenacity import (
     wait_random_exponential,
 )
 
-from khoj.processor.conversation.utils import JsonSupport, commit_conversation_trace
+from khoj.processor.conversation.utils import (
+    JsonSupport,
+    ResponseWithThought,
+    commit_conversation_trace,
+)
 from khoj.utils.helpers import (
     get_chat_usage_metrics,
     get_openai_async_client,
@@ -59,6 +72,7 @@ def completion_with_backoff(
         client = get_openai_client(openai_api_key, api_base_url)
         openai_clients[client_key] = client
 
+    stream_processor = default_stream_processor
     formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
 
     # Tune reasoning models arguments
@@ -69,6 +83,24 @@ def completion_with_backoff(
     elif is_twitter_reasoning_model(model_name, api_base_url):
         reasoning_effort = "high" if deepthought else "low"
         model_kwargs["reasoning_effort"] = reasoning_effort
+    elif model_name.startswith("deepseek-reasoner"):
+        # Two successive messages cannot be from the same role. Should merge any back-to-back messages from the same role.
+        # The first message should always be a user message (except system message).
+        updated_messages: List[dict] = []
+        for i, message in enumerate(formatted_messages):
+            if i > 0 and message["role"] == formatted_messages[i - 1]["role"]:
+                updated_messages[-1]["content"] += " " + message["content"]
+            elif i == 1 and formatted_messages[i - 1]["role"] == "system" and message["role"] == "assistant":
+                updated_messages[-1]["content"] += " " + message["content"]
+            else:
+                updated_messages.append(message)
+        formatted_messages = updated_messages
+    elif is_qwen_reasoning_model(model_name, api_base_url):
+        stream_processor = partial(in_stream_thought_processor, thought_tag="think")
+        # Reasoning is enabled by default. Disable when deepthought is False.
+        # See https://qwenlm.github.io/blog/qwen3/#advanced-usages
+        if not deepthought and len(formatted_messages) > 0:
+            formatted_messages[-1]["content"] = formatted_messages[-1]["content"] + " /no_think"
 
     model_kwargs["stream_options"] = {"include_usage": True}
     if os.getenv("KHOJ_LLM_SEED"):
@@ -82,12 +114,11 @@ def completion_with_backoff(
         timeout=20,
         **model_kwargs,
     ) as chat:
-        for chunk in chat:
-            if chunk.type == "error":
-                logger.error(f"Openai api response error: {chunk.error}", exc_info=True)
-                continue
-            elif chunk.type == "content.delta":
+        for chunk in stream_processor(chat):
+            if chunk.type == "content.delta":
                 aggregated_response += chunk.delta
+            elif chunk.type == "thought.delta":
+                pass
 
     # Calculate cost of chat
     input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
@@ -124,14 +155,14 @@
 )
 async def chat_completion_with_backoff(
     messages,
-    model_name,
+    model_name: str,
     temperature,
     openai_api_key=None,
     api_base_url=None,
     deepthought=False,
     model_kwargs: dict = {},
     tracer: dict = {},
-) -> AsyncGenerator[str, None]:
+) -> AsyncGenerator[ResponseWithThought, None]:
     try:
         client_key = f"{openai_api_key}--{api_base_url}"
         client = openai_async_clients.get(client_key)
@@ -139,6 +170,7 @@ async def chat_completion_with_backoff(
             client = get_openai_async_client(openai_api_key, api_base_url)
             openai_async_clients[client_key] = client
 
+        stream_processor = adefault_stream_processor
        formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
 
         # Configure thinking for openai reasoning models
@@ -161,9 +193,11 @@ async def chat_completion_with_backoff(
                     "content"
                 ] = f"{first_system_message_content}\nFormatting re-enabled"
         elif is_twitter_reasoning_model(model_name, api_base_url):
+            stream_processor = adeepseek_stream_processor
            reasoning_effort = "high" if deepthought else "low"
            model_kwargs["reasoning_effort"] = reasoning_effort
         elif model_name.startswith("deepseek-reasoner"):
+            stream_processor = adeepseek_stream_processor
            # Two successive messages cannot be from the same role. Should merge any back-to-back messages from the same role.
            # The first message should always be a user message (except system message).
            updated_messages: List[dict] = []
@@ -174,8 +208,13 @@ async def chat_completion_with_backoff(
                     updated_messages[-1]["content"] += " " + message["content"]
                 else:
                     updated_messages.append(message)
-
             formatted_messages = updated_messages
+        elif is_qwen_reasoning_model(model_name, api_base_url):
+            stream_processor = partial(ain_stream_thought_processor, thought_tag="think")
+            # Reasoning is enabled by default. Disable when deepthought is False.
+            # See https://qwenlm.github.io/blog/qwen3/#advanced-usages
+            if not deepthought and len(formatted_messages) > 0:
+                formatted_messages[-1]["content"] = formatted_messages[-1]["content"] + " /no_think"
 
         stream = True
         model_kwargs["stream_options"] = {"include_usage": True}
@@ -193,24 +232,25 @@ async def chat_completion_with_backoff(
             timeout=20,
             **model_kwargs,
         )
-        async for chunk in chat_stream:
+        async for chunk in stream_processor(chat_stream):
             # Log the time taken to start response
             if final_chunk is None:
                 logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
             # Keep track of the last chunk for usage data
             final_chunk = chunk
-            # Handle streamed response chunk
+            # Skip empty chunks
             if len(chunk.choices) == 0:
                 continue
-            delta_chunk = chunk.choices[0].delta
-            text_chunk = ""
-            if isinstance(delta_chunk, str):
-                text_chunk = delta_chunk
-            elif delta_chunk and delta_chunk.content:
-                text_chunk = delta_chunk.content
-            if text_chunk:
-                aggregated_response += text_chunk
-                yield text_chunk
+            # Handle streamed response chunk
+            response_chunk: ResponseWithThought = None
+            response_delta = chunk.choices[0].delta
+            if response_delta.content:
+                response_chunk = ResponseWithThought(response=response_delta.content)
+                aggregated_response += response_chunk.response
+            elif response_delta.thought:
+                response_chunk = ResponseWithThought(thought=response_delta.thought)
+            if response_chunk:
+                yield response_chunk
 
         # Log the time taken to stream the entire response
         logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
@@ -264,3 +304,274 @@ def is_twitter_reasoning_model(model_name: str, api_base_url: str = None) -> bool:
         and api_base_url is not None
         and api_base_url.startswith("https://api.x.ai/v1")
     )
+
+
+def is_qwen_reasoning_model(model_name: str, api_base_url: str = None) -> bool:
+    """
+    Check if the model is a Qwen reasoning model
+    """
+    return "qwen3" in model_name.lower() and api_base_url is not None
+
+
+class ThoughtDeltaEvent(ContentDeltaEvent):
+    """
+    Chat completion chunk with thoughts, reasoning support.
+    """
+
+    type: Literal["thought.delta"]
+    """The thought or reasoning generated by the model."""
+
+
+ChatCompletionStreamWithThoughtEvent = Union[ChatCompletionStreamEvent, ThoughtDeltaEvent]
+
+
+class ChoiceDeltaWithThoughts(ChoiceDelta):
+    """
+    Chat completion chunk with thoughts, reasoning support.
+    """
+
+    thought: Optional[str] = None
+    """The thought or reasoning generated by the model."""
+
+
+class ChoiceWithThoughts(Choice):
+    delta: ChoiceDeltaWithThoughts
+
+
+class ChatCompletionWithThoughtsChunk(ChatCompletionChunk):
+    choices: List[ChoiceWithThoughts]  # Override the choices type
+
+
+def default_stream_processor(
+    chat_stream: ChatCompletionStream,
+) -> Generator[ChatCompletionStreamWithThoughtEvent, None, None]:
+    """
+    Async generator to cast and return chunks from the standard openai chat completions stream.
+    """
+    for chunk in chat_stream:
+        yield chunk
+
+
+async def adefault_stream_processor(
+    chat_stream: openai.AsyncStream[ChatCompletionChunk],
+) -> AsyncGenerator[ChatCompletionWithThoughtsChunk, None]:
+    """
+    Async generator to cast and return chunks from the standard openai chat completions stream.
+    """
+    async for chunk in chat_stream:
+        yield ChatCompletionWithThoughtsChunk.model_validate(chunk.model_dump())
+
+
+async def adeepseek_stream_processor(
+    chat_stream: openai.AsyncStream[ChatCompletionChunk],
+) -> AsyncGenerator[ChatCompletionWithThoughtsChunk, None]:
+    """
+    Async generator to cast and return chunks from the deepseek chat completions stream.
+    """
+    async for chunk in chat_stream:
+        tchunk = ChatCompletionWithThoughtsChunk.model_validate(chunk.model_dump())
+        if (
+            len(tchunk.choices) > 0
+            and hasattr(tchunk.choices[0].delta, "reasoning_content")
+            and tchunk.choices[0].delta.reasoning_content
+        ):
+            tchunk.choices[0].delta.thought = chunk.choices[0].delta.reasoning_content
+        yield tchunk
+
+
+def in_stream_thought_processor(
+    chat_stream: openai.Stream[ChatCompletionChunk], thought_tag="think"
+) -> Generator[ChatCompletionStreamWithThoughtEvent, None, None]:
+    """
+    Generator for chat completion with thought chunks.
+    Assumes <thought_tag>...</thought_tag> can only appear once at the start.
+    Handles partial tags across streamed chunks.
+    """
+    start_tag = f"<{thought_tag}>"
+    end_tag = f"</{thought_tag}>"
+    buf: str = ""
+    # Modes and transitions: detect_start > thought (optional) > message
+    mode = "detect_start"
+
+    for chunk in default_stream_processor(chat_stream):
+        if mode == "message" or chunk.type != "content.delta":
+            # Message mode is terminal, so just yield chunks, no processing
+            yield chunk
+            continue
+
+        buf += chunk.delta
+
+        if mode == "detect_start":
+            # Try to determine if we start with thought tag
+            if buf.startswith(start_tag):
+                # Found start tag, switch mode
+                buf = buf[len(start_tag) :]  # Remove start tag
+                mode = "thought"
+                # Fall through to process the rest of the buffer in 'thought' mode *within this iteration*
+            elif len(buf) >= len(start_tag):
+                # Buffer is long enough, definitely doesn't start with tag
+                chunk.delta = buf
+                yield chunk
+                mode = "message"
+                buf = ""
+                continue
+            elif start_tag.startswith(buf):
+                # Buffer is a prefix of the start tag, need more data
+                continue
+            else:
+                # Buffer doesn't match start tag prefix and is shorter than tag
+                chunk.delta = buf
+                yield chunk
+                mode = "message"
+                buf = ""
+                continue
+
+        if mode == "thought":
+            # Look for the end tag
+            idx = buf.find(end_tag)
+            if idx != -1:
+                # Found end tag. Yield thought content before it.
+                if idx > 0 and buf[:idx].strip():
+                    chunk.type = "thought.delta"
+                    chunk.delta = buf[:idx]
+                    yield chunk
+                # Process content *after* the tag as message
+                buf = buf[idx + len(end_tag) :]
+                if buf:
+                    chunk.delta = buf
+                    yield chunk
+                mode = "message"
+                buf = ""
+                continue
+            else:
+                # End tag not found yet. Yield thought content, holding back potential partial end tag.
+                send_upto = len(buf)
+                # Check if buffer ends with a prefix of end_tag
+                for i in range(len(end_tag) - 1, 0, -1):
+                    if buf.endswith(end_tag[:i]):
+                        send_upto = len(buf) - i  # Don't send the partial tag yet
+                        break
+                if send_upto > 0 and buf[:send_upto].strip():
+                    chunk.type = "thought.delta"
+                    chunk.delta = buf[:send_upto]
+                    yield chunk
+                buf = buf[send_upto:]  # Keep only the partial tag (or empty)
+                # Need more data to find the complete end tag
+                continue
+
+    # End of stream handling
+    if buf:
+        if mode == "thought":  # Stream ended before </think> was found
+            chunk.type = "thought.delta"
+            chunk.delta = buf
+            yield chunk
+        elif mode == "detect_start":  # Stream ended before start tag could be confirmed/denied
+            # If it wasn't a partial start tag, treat as message
+            if not start_tag.startswith(buf):
+                chunk.delta = buf
+                yield chunk
+            # else: discard partial <think>
+        # If mode == "message", buffer should be empty due to logic above, but yield just in case
+        elif mode == "message":
+            chunk.delta = buf
+            yield chunk
+
+
+async def ain_stream_thought_processor(
+    chat_stream: openai.AsyncStream[ChatCompletionChunk], thought_tag="think"
+) -> AsyncGenerator[ChatCompletionWithThoughtsChunk, None]:
+    """
+    Async generator for chat completion with thought chunks.
+    Assumes <thought_tag>...</thought_tag> can only appear once at the start.
+    Handles partial tags across streamed chunks.
+    """
+    start_tag = f"<{thought_tag}>"
+    end_tag = f"</{thought_tag}>"
+    buf: str = ""
+    # Modes and transitions: detect_start > thought (optional) > message
+    mode = "detect_start"
+
+    async for chunk in adefault_stream_processor(chat_stream):
+        if len(chunk.choices) == 0:
+            continue
+        if mode == "message":
+            # Message mode is terminal, so just yield chunks, no processing
+            yield chunk
+            continue
+
+        buf += chunk.choices[0].delta.content
+
+        if mode == "detect_start":
+            # Try to determine if we start with thought tag
+            if buf.startswith(start_tag):
+                # Found start tag, switch mode
+                buf = buf[len(start_tag) :]  # Remove start tag
+                mode = "thought"
+                # Fall through to process the rest of the buffer in 'thought' mode *within this iteration*
+            elif len(buf) >= len(start_tag):
+                # Buffer is long enough, definitely doesn't start with tag
+                chunk.choices[0].delta.content = buf
+                yield chunk
+                mode = "message"
+                buf = ""
+                continue
+            elif start_tag.startswith(buf):
+                # Buffer is a prefix of the start tag, need more data
+                continue
+            else:
+                # Buffer doesn't match start tag prefix and is shorter than tag
+                chunk.choices[0].delta.content = buf
+                yield chunk
+                mode = "message"
+                buf = ""
+                continue
+
+        if mode == "thought":
+            # Look for the end tag
+            idx = buf.find(end_tag)
+            if idx != -1:
+                # Found end tag. Yield thought content before it.
+                if idx > 0 and buf[:idx].strip():
+                    chunk.choices[0].delta.thought = buf[:idx]
+                    chunk.choices[0].delta.content = ""
+                    yield chunk
+                # Process content *after* the tag as message
+                buf = buf[idx + len(end_tag) :]
+                if buf:
+                    chunk.choices[0].delta.content = buf
+                    yield chunk
+                mode = "message"
+                buf = ""
+                continue
+            else:
+                # End tag not found yet. Yield thought content, holding back potential partial end tag.
+                send_upto = len(buf)
+                # Check if buffer ends with a prefix of end_tag
+                for i in range(len(end_tag) - 1, 0, -1):
+                    if buf.endswith(end_tag[:i]):
+                        send_upto = len(buf) - i  # Don't send the partial tag yet
+                        break
+                if send_upto > 0 and buf[:send_upto].strip():
+                    chunk.choices[0].delta.thought = buf[:send_upto]
+                    chunk.choices[0].delta.content = ""
+                    yield chunk
+                buf = buf[send_upto:]  # Keep only the partial tag (or empty)
+                # Need more data to find the complete end tag
+                continue
+
+    # End of stream handling
+    if buf:
+        if mode == "thought":  # Stream ended before </think> was found
+            chunk.choices[0].delta.thought = buf
+            chunk.choices[0].delta.content = ""
+            yield chunk
+        elif mode == "detect_start":  # Stream ended before start tag could be confirmed/denied
+            # If it wasn't a partial start tag, treat as message
+            if not start_tag.startswith(buf):
+                chunk.choices[0].delta.content = buf
+                yield chunk
+            # else: discard partial <think>
+        # If mode == "message", buffer should be empty due to logic above, but yield just in case
+        elif mode == "message":
+            chunk.choices[0].delta.content = buf
+            yield chunk
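
Note on the new stream processors above: in_stream_thought_processor and ain_stream_thought_processor run a small state machine (detect_start > thought > message) that splits a leading <think>...</think> block out of the streamed text, even when a tag straddles chunk boundaries. Below is a minimal standalone sketch of the same tag-splitting logic over plain string chunks rather than OpenAI stream events; the helper name split_thought_stream is illustrative and not part of the package.

    from typing import Generator, Iterable, Tuple


    def split_thought_stream(
        chunks: Iterable[str], thought_tag: str = "think"
    ) -> Generator[Tuple[str, str], None, None]:
        """Yield ("thought", text) or ("message", text) pairs from streamed text chunks.

        Mirrors the detect_start -> thought -> message state machine used above: the
        thought tag may only appear once, at the very start of the stream, and partial
        tags spanning chunk boundaries are buffered until they can be resolved.
        """
        start_tag, end_tag = f"<{thought_tag}>", f"</{thought_tag}>"
        buf, mode = "", "detect_start"

        for piece in chunks:
            if mode == "message":
                yield ("message", piece)
                continue
            buf += piece

            if mode == "detect_start":
                if buf.startswith(start_tag):
                    buf, mode = buf[len(start_tag):], "thought"
                elif start_tag.startswith(buf):
                    continue  # still a possible partial start tag, wait for more data
                else:
                    yield ("message", buf)
                    buf, mode = "", "message"
                    continue

            if mode == "thought":
                idx = buf.find(end_tag)
                if idx != -1:
                    if buf[:idx]:
                        yield ("thought", buf[:idx])
                    rest = buf[idx + len(end_tag):]
                    if rest:
                        yield ("message", rest)
                    buf, mode = "", "message"
                else:
                    # Hold back any suffix that could be the start of the end tag.
                    keep = 0
                    for i in range(len(end_tag) - 1, 0, -1):
                        if buf.endswith(end_tag[:i]):
                            keep = i
                            break
                    if buf[: len(buf) - keep]:
                        yield ("thought", buf[: len(buf) - keep])
                    buf = buf[len(buf) - keep:]

        # Flush whatever is left when the stream ends.
        if buf:
            if mode == "thought":
                yield ("thought", buf)
            elif not start_tag.startswith(buf):  # drop a dangling partial start tag
                yield ("message", buf)


    # e.g. chunks split mid-tag: ["<th", "ink>plan steps</th", "ink>final answer"]
    # yields ("thought", "plan steps"), then ("message", "final answer")
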
khoj/processor/conversation/utils.py CHANGED
@@ -191,6 +191,7 @@ class ChatEvent(Enum):
     REFERENCES = "references"
     GENERATED_ASSETS = "generated_assets"
     STATUS = "status"
+    THOUGHT = "thought"
     METADATA = "metadata"
     USAGE = "usage"
     END_RESPONSE = "end_response"
@@ -873,3 +874,9 @@ class JsonSupport(int, Enum):
     NONE = 0
     OBJECT = 1
     SCHEMA = 2
+
+
+class ResponseWithThought:
+    def __init__(self, response: str = None, thought: str = None):
+        self.response = response
+        self.thought = thought
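
With the ResponseWithThought class added above, chat_completion_with_backoff in openai/utils.py now yields these objects instead of plain strings, so callers must check which field is set on each chunk. A minimal consumer sketch follows; the fake_llm_stream stand-in is hypothetical and only illustrates the calling pattern.

    import asyncio
    from typing import AsyncGenerator


    class ResponseWithThought:
        """Same shape as the class added above: either `response` or `thought` is set."""

        def __init__(self, response: str = None, thought: str = None):
            self.response = response
            self.thought = thought


    async def fake_llm_stream() -> AsyncGenerator[ResponseWithThought, None]:
        # Stand-in for chat_completion_with_backoff(...), for illustration only.
        yield ResponseWithThought(thought="Considering the question...")
        yield ResponseWithThought(response="Here is the answer.")


    async def consume() -> str:
        aggregated = ""
        async for chunk in fake_llm_stream():
            if chunk.thought:
                # Thoughts feed the train-of-thought display, not the final message.
                print(f"[thought] {chunk.thought}")
            elif chunk.response:
                aggregated += chunk.response
        return aggregated


    if __name__ == "__main__":
        print(asyncio.run(consume()))
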
khoj/routers/api_chat.py CHANGED
@@ -25,7 +25,11 @@ from khoj.database.adapters import (
 from khoj.database.models import Agent, KhojUser
 from khoj.processor.conversation import prompts
 from khoj.processor.conversation.prompts import help_message, no_entries_found
-from khoj.processor.conversation.utils import defilter_query, save_to_conversation_log
+from khoj.processor.conversation.utils import (
+    ResponseWithThought,
+    defilter_query,
+    save_to_conversation_log,
+)
 from khoj.processor.image.generate import text_to_image
 from khoj.processor.speech.text_to_speech import generate_text_to_speech
 from khoj.processor.tools.online_search import (
@@ -679,14 +683,13 @@ async def chat(
         start_time = time.perf_counter()
         ttft = None
         chat_metadata: dict = {}
-        connection_alive = True
         user: KhojUser = request.user.object
         is_subscribed = has_required_scope(request, ["premium"])
-        event_delimiter = "␃🔚␗"
         q = unquote(q)
         train_of_thought = []
         nonlocal conversation_id
         nonlocal raw_query_files
+        cancellation_event = asyncio.Event()
 
         tracer: dict = {
             "mid": turn_id,
@@ -713,11 +716,33 @@ async def chat(
             for file in raw_query_files:
                 query_files[file.name] = file.content
 
+        # Create a task to monitor for disconnections
+        disconnect_monitor_task = None
+
+        async def monitor_disconnection():
+            try:
+                msg = await request.receive()
+                if msg["type"] == "http.disconnect":
+                    logger.debug(f"User {user} disconnected from {common.client} client.")
+                    cancellation_event.set()
+            except Exception as e:
+                logger.error(f"Error in disconnect monitor: {e}")
+
+        # Cancel the disconnect monitor task if it is still running
+        async def cancel_disconnect_monitor():
+            if disconnect_monitor_task and not disconnect_monitor_task.done():
+                logger.debug(f"Cancelling disconnect monitor task for user {user}")
+                disconnect_monitor_task.cancel()
+                try:
+                    await disconnect_monitor_task
+                except asyncio.CancelledError:
+                    pass
+
         async def send_event(event_type: ChatEvent, data: str | dict):
-            nonlocal connection_alive, ttft, train_of_thought
-            if not connection_alive or await request.is_disconnected():
-                connection_alive = False
-                logger.warning(f"User {user} disconnected from {common.client} client")
+            nonlocal ttft, train_of_thought
+            event_delimiter = "␃🔚␗"
+            if cancellation_event.is_set():
+                logger.debug(f"User {user} disconnected from {common.client} client. Setting cancellation event.")
                 return
             try:
                 if event_type == ChatEvent.END_LLM_RESPONSE:
@@ -726,23 +751,41 @@ async def chat(
                     ttft = time.perf_counter() - start_time
                 elif event_type == ChatEvent.STATUS:
                     train_of_thought.append({"type": event_type.value, "data": data})
+                elif event_type == ChatEvent.THOUGHT:
+                    # Append the data to the last thought as thoughts are streamed
+                    if (
+                        len(train_of_thought) > 0
+                        and train_of_thought[-1]["type"] == ChatEvent.THOUGHT.value
+                        and type(train_of_thought[-1]["data"]) == type(data) == str
+                    ):
+                        train_of_thought[-1]["data"] += data
+                    else:
+                        train_of_thought.append({"type": event_type.value, "data": data})
 
                 if event_type == ChatEvent.MESSAGE:
                     yield data
                 elif event_type == ChatEvent.REFERENCES or ChatEvent.METADATA or stream:
                     yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False)
             except asyncio.CancelledError as e:
-                connection_alive = False
-                logger.warn(f"User {user} disconnected from {common.client} client: {e}")
-                return
+                if cancellation_event.is_set():
+                    logger.debug(f"Request cancelled. User {user} disconnected from {common.client} client: {e}.")
             except Exception as e:
-                connection_alive = False
-                logger.error(f"Failed to stream chat API response to {user} on {common.client}: {e}", exc_info=True)
-                return
+                if not cancellation_event.is_set():
+                    logger.error(
+                        f"Failed to stream chat API response to {user} on {common.client}: {e}.",
+                        exc_info=True,
+                    )
             finally:
-                yield event_delimiter
+                if not cancellation_event.is_set():
+                    yield event_delimiter
+                # Cancel the disconnect monitor task if it is still running
+                if cancellation_event.is_set() or event_type == ChatEvent.END_RESPONSE:
+                    await cancel_disconnect_monitor()
 
         async def send_llm_response(response: str, usage: dict = None):
+            # Check if the client is still connected
+            if cancellation_event.is_set():
+                return
             # Send Chat Response
             async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
                 yield result
@@ -783,6 +826,9 @@ async def chat(
             metadata=chat_metadata,
         )
 
+        # Start the disconnect monitor in the background
+        disconnect_monitor_task = asyncio.create_task(monitor_disconnection())
+
         if is_query_empty(q):
             async for result in send_llm_response("Please ask your query to get started.", tracer.get("usage")):
                 yield result
@@ -900,6 +946,7 @@ async def chat(
                 file_filters=conversation.file_filters if conversation else [],
                 query_files=attached_file_context,
                 tracer=tracer,
+                cancellation_event=cancellation_event,
             ):
                 if isinstance(research_result, InformationCollectionIteration):
                     if research_result.summarizedResult:
@@ -1274,6 +1321,13 @@ async def chat(
            async for result in send_event(ChatEvent.STATUS, error_message):
                yield result
 
+        # Check if the user has disconnected
+        if cancellation_event.is_set():
+            logger.debug(f"User {user} disconnected from {common.client} client. Stopping LLM response.")
+            # Cancel the disconnect monitor task if it is still running
+            await cancel_disconnect_monitor()
+            return
+
         ## Generate Text Output
         async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"):
             yield result
@@ -1306,27 +1360,32 @@ async def chat(
             tracer,
         )
 
-        # Send Response
-        async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
-            yield result
-
-        continue_stream = True
         async for item in llm_response:
             # Should not happen with async generator, end is signaled by loop exit. Skip.
             if item is None:
                 continue
-            if not connection_alive or not continue_stream:
-                # Drain the generator if disconnected but keep processing internally
+            if cancellation_event.is_set():
+                break
+            message = item.response if isinstance(item, ResponseWithThought) else item
+            if isinstance(item, ResponseWithThought) and item.thought:
+                async for result in send_event(ChatEvent.THOUGHT, item.thought):
+                    yield result
                 continue
+
+            # Start sending response
+            async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
+                yield result
+
             try:
-                async for result in send_event(ChatEvent.MESSAGE, f"{item}"):
+                async for result in send_event(ChatEvent.MESSAGE, message):
                     yield result
             except Exception as e:
-                continue_stream = False
-                logger.info(f"User {user} disconnected or error during streaming. Stopping send: {e}")
+                if not cancellation_event.is_set():
+                    logger.warning(f"Error during streaming. Stopping send: {e}")
+                break
 
         # Signal end of LLM response after the loop finishes
-        if connection_alive:
+        if not cancellation_event.is_set():
             async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
                 yield result
             # Send Usage Metadata once llm interactions are complete
@@ -1337,6 +1396,9 @@ async def chat(
                 yield result
         logger.debug("Finished streaming response")
 
+        # Cancel the disconnect monitor task if it is still running
+        await cancel_disconnect_monitor()
+
     ## Stream Text Response
     if stream:
         return StreamingResponse(event_generator(q, images=raw_images), media_type="text/plain")
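
The api_chat.py changes above replace per-event request.is_disconnected() polling with a background task that awaits request.receive() and sets a shared asyncio.Event on disconnect, which the streaming loop then checks before producing more output. A stripped-down sketch of that pattern, assuming an ASGI/Starlette-style request object; the names are illustrative, not the package's API.

    import asyncio


    async def stream_with_disconnect_monitor(request, generate_chunks):
        """Yield chunks from `generate_chunks()` until done or the client disconnects.

        `request` is assumed to expose an ASGI-style `receive()` coroutine that
        returns {"type": "http.disconnect"} when the client drops the connection,
        as Starlette/FastAPI requests do.
        """
        cancellation_event = asyncio.Event()

        async def monitor_disconnection():
            try:
                message = await request.receive()
                if message["type"] == "http.disconnect":
                    cancellation_event.set()
            except Exception:
                pass  # Monitoring is best-effort; never crash the response stream.

        monitor_task = asyncio.create_task(monitor_disconnection())
        try:
            async for chunk in generate_chunks():
                if cancellation_event.is_set():
                    break  # Stop producing output once the client is gone.
                yield chunk
        finally:
            # Always clean up the background monitor, even on errors or early exit.
            if not monitor_task.done():
                monitor_task.cancel()
                try:
                    await monitor_task
                except asyncio.CancelledError:
                    pass
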