khoj 1.27.2.dev15__py3-none-any.whl → 1.27.2.dev29__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their public registries.
Files changed (53)
  1. khoj/interface/compiled/404/index.html +1 -1
  2. khoj/interface/compiled/_next/static/chunks/1603-5138bb7c8035d9a6.js +1 -0
  3. khoj/interface/compiled/_next/static/chunks/app/agents/{page-2beaba7c9bb750bd.js → page-5ae1e540bb5be8a9.js} +1 -1
  4. khoj/interface/compiled/_next/static/chunks/app/automations/{page-9b5c77e0b0dd772c.js → page-774ae3e033f938cd.js} +1 -1
  5. khoj/interface/compiled/_next/static/chunks/app/chat/{page-151232d8417a1ea1.js → page-97f5b61aaf46d364.js} +1 -1
  6. khoj/interface/compiled/_next/static/chunks/app/factchecker/{page-798904432c2417c4.js → page-d82403db2866bad8.js} +1 -1
  7. khoj/interface/compiled/_next/static/chunks/app/{page-4b6008223ea79955.js → page-4dc472cf6d674004.js} +1 -1
  8. khoj/interface/compiled/_next/static/chunks/app/search/{page-ab2995529ece3140.js → page-9b64f61caa5bd7f9.js} +1 -1
  9. khoj/interface/compiled/_next/static/chunks/app/settings/{page-7946cabb9c54e22d.js → page-7a8c382af2a7e870.js} +1 -1
  10. khoj/interface/compiled/_next/static/chunks/app/share/chat/{page-6a01e07fb244c10c.js → page-eb9e282691858f2e.js} +1 -1
  11. khoj/interface/compiled/_next/static/chunks/{webpack-878569182b3af4c6.js → webpack-2b720658ccc746f2.js} +1 -1
  12. khoj/interface/compiled/_next/static/css/4cae6c0e5c72fb2d.css +1 -0
  13. khoj/interface/compiled/_next/static/css/ddcc0cf73e062476.css +1 -0
  14. khoj/interface/compiled/agents/index.html +1 -1
  15. khoj/interface/compiled/agents/index.txt +2 -2
  16. khoj/interface/compiled/automations/index.html +1 -1
  17. khoj/interface/compiled/automations/index.txt +2 -2
  18. khoj/interface/compiled/chat/index.html +1 -1
  19. khoj/interface/compiled/chat/index.txt +2 -2
  20. khoj/interface/compiled/factchecker/index.html +1 -1
  21. khoj/interface/compiled/factchecker/index.txt +2 -2
  22. khoj/interface/compiled/index.html +1 -1
  23. khoj/interface/compiled/index.txt +2 -2
  24. khoj/interface/compiled/search/index.html +1 -1
  25. khoj/interface/compiled/search/index.txt +2 -2
  26. khoj/interface/compiled/settings/index.html +1 -1
  27. khoj/interface/compiled/settings/index.txt +2 -2
  28. khoj/interface/compiled/share/chat/index.html +1 -1
  29. khoj/interface/compiled/share/chat/index.txt +2 -2
  30. khoj/processor/conversation/anthropic/anthropic_chat.py +6 -1
  31. khoj/processor/conversation/anthropic/utils.py +25 -5
  32. khoj/processor/conversation/google/gemini_chat.py +8 -2
  33. khoj/processor/conversation/google/utils.py +34 -10
  34. khoj/processor/conversation/offline/chat_model.py +31 -7
  35. khoj/processor/conversation/openai/gpt.py +14 -2
  36. khoj/processor/conversation/openai/utils.py +43 -9
  37. khoj/processor/conversation/prompts.py +0 -16
  38. khoj/processor/conversation/utils.py +168 -1
  39. khoj/processor/image/generate.py +2 -0
  40. khoj/processor/tools/online_search.py +14 -5
  41. khoj/routers/api.py +5 -0
  42. khoj/routers/api_chat.py +23 -2
  43. khoj/routers/helpers.py +65 -13
  44. {khoj-1.27.2.dev15.dist-info → khoj-1.27.2.dev29.dist-info}/METADATA +2 -1
  45. {khoj-1.27.2.dev15.dist-info → khoj-1.27.2.dev29.dist-info}/RECORD +50 -50
  46. khoj/interface/compiled/_next/static/chunks/1603-b9d95833e0e025e8.js +0 -1
  47. khoj/interface/compiled/_next/static/css/592ca99f5122e75a.css +0 -1
  48. khoj/interface/compiled/_next/static/css/d738728883c68af8.css +0 -1
  49. /khoj/interface/compiled/_next/static/{vcyFRDGArOFXwUVotHIuv → atzIseFarmC7TIwq2BgHC}/_buildManifest.js +0 -0
  50. /khoj/interface/compiled/_next/static/{vcyFRDGArOFXwUVotHIuv → atzIseFarmC7TIwq2BgHC}/_ssgManifest.js +0 -0
  51. {khoj-1.27.2.dev15.dist-info → khoj-1.27.2.dev29.dist-info}/WHEEL +0 -0
  52. {khoj-1.27.2.dev15.dist-info → khoj-1.27.2.dev29.dist-info}/entry_points.txt +0 -0
  53. {khoj-1.27.2.dev15.dist-info → khoj-1.27.2.dev29.dist-info}/licenses/LICENSE +0 -0

khoj/processor/conversation/google/utils.py

@@ -19,8 +19,13 @@ from tenacity import (
     wait_random_exponential,
 )
 
-from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url
-from khoj.utils.helpers import is_none_or_empty
+from khoj.processor.conversation.utils import (
+    ThreadedGenerator,
+    commit_conversation_trace,
+    get_image_from_url,
+)
+from khoj.utils import state
+from khoj.utils.helpers import in_debug_mode, is_none_or_empty
 
 logger = logging.getLogger(__name__)
 
@@ -35,7 +40,7 @@ MAX_OUTPUT_TOKENS_GEMINI = 8192
     reraise=True,
 )
 def gemini_completion_with_backoff(
-    messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None
+    messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, tracer={}
 ) -> str:
     genai.configure(api_key=api_key)
     model_kwargs = model_kwargs or dict()
@@ -60,16 +65,23 @@ def gemini_completion_with_backoff(
 
     try:
         # Generate the response. The last message is considered to be the current prompt
-        aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"])
-        return aggregated_response.text
+        response = chat_session.send_message(formatted_messages[-1]["parts"])
+        response_text = response.text
     except StopCandidateException as e:
-        response_message, _ = handle_gemini_response(e.args)
+        response_text, _ = handle_gemini_response(e.args)
         # Respond with reason for stopping
         logger.warning(
-            f"LLM Response Prevented for {model_name}: {response_message}.\n"
+            f"LLM Response Prevented for {model_name}: {response_text}.\n"
            + f"Last Message by {messages[-1].role}: {messages[-1].content}"
         )
-        return response_message
+
+    # Save conversation trace
+    tracer["chat_model"] = model_name
+    tracer["temperature"] = temperature
+    if in_debug_mode() or state.verbose > 1:
+        commit_conversation_trace(messages, response_text, tracer)
+
+    return response_text
 
 
 @retry(
@@ -88,17 +100,20 @@ def gemini_chat_completion_with_backoff(
     system_prompt,
     completion_func=None,
     model_kwargs=None,
+    tracer: dict = {},
 ):
     g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
     t = Thread(
         target=gemini_llm_thread,
-        args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs),
+        args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs, tracer),
     )
     t.start()
     return g
 
 
-def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None):
+def gemini_llm_thread(
+    g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None, tracer: dict = {}
+):
     try:
         genai.configure(api_key=api_key)
         model_kwargs = model_kwargs or dict()
@@ -117,16 +132,25 @@ def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_k
             },
         )
 
+        aggregated_response = ""
         formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
+
         # all messages up to the last are considered to be part of the chat history
         chat_session = model.start_chat(history=formatted_messages[0:-1])
         # the last message is considered to be the current prompt
        for chunk in chat_session.send_message(formatted_messages[-1]["parts"], stream=True):
             message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
             message = message or chunk.text
+            aggregated_response += message
             g.send(message)
             if stopped:
                 raise StopCandidateException(message)
+
+        # Save conversation trace
+        tracer["chat_model"] = model_name
+        tracer["temperature"] = temperature
+        if in_debug_mode() or state.verbose > 1:
+            commit_conversation_trace(messages, aggregated_response, tracer)
     except StopCandidateException as e:
         logger.warning(
             f"LLM Response Prevented for {model_name}: {e.args[0]}.\n"

khoj/processor/conversation/offline/chat_model.py

@@ -12,11 +12,12 @@ from khoj.processor.conversation import prompts
 from khoj.processor.conversation.offline.utils import download_model
 from khoj.processor.conversation.utils import (
     ThreadedGenerator,
+    commit_conversation_trace,
     generate_chatml_messages_with_context,
 )
 from khoj.utils import state
 from khoj.utils.constants import empty_escape_sequences
-from khoj.utils.helpers import ConversationCommand, is_none_or_empty
+from khoj.utils.helpers import ConversationCommand, in_debug_mode, is_none_or_empty
 from khoj.utils.rawconfig import LocationData
 
 logger = logging.getLogger(__name__)
@@ -34,6 +35,7 @@ def extract_questions_offline(
     max_prompt_size: int = None,
     temperature: float = 0.7,
     personality_context: Optional[str] = None,
+    tracer: dict = {},
 ) -> List[str]:
     """
     Infer search queries to retrieve relevant notes to answer user query
@@ -94,6 +96,7 @@ def extract_questions_offline(
             max_prompt_size=max_prompt_size,
             temperature=temperature,
             response_type="json_object",
+            tracer=tracer,
         )
     finally:
         state.chat_lock.release()
@@ -146,6 +149,7 @@ def converse_offline(
     location_data: LocationData = None,
     user_name: str = None,
     agent: Agent = None,
+    tracer: dict = {},
 ) -> Union[ThreadedGenerator, Iterator[str]]:
     """
     Converse with user using Llama
@@ -153,8 +157,9 @@ def converse_offline(
     # Initialize Variables
     assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
     offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
-    compiled_references = "\n\n".join({f"# File: {item['file']}\n## {item['compiled']}\n" for item in references})
+    tracer["chat_model"] = model
 
+    compiled_references = "\n\n".join({f"# File: {item['file']}\n## {item['compiled']}\n" for item in references})
     current_date = datetime.now()
 
     if agent and agent.personality:
@@ -215,13 +220,14 @@ def converse_offline(
     logger.debug(f"Conversation Context for {model}: {truncated_messages}")
 
     g = ThreadedGenerator(references, online_results, completion_func=completion_func)
-    t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size))
+    t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size, tracer))
     t.start()
     return g
 
 
-def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None):
+def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None, tracer: dict = {}):
     stop_phrases = ["<s>", "INST]", "Notes:"]
+    aggregated_response = ""
 
     state.chat_lock.acquire()
     try:
@@ -229,7 +235,14 @@ def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int
             messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True
         )
         for response in response_iterator:
-            g.send(response["choices"][0]["delta"].get("content", ""))
+            response_delta = response["choices"][0]["delta"].get("content", "")
+            aggregated_response += response_delta
+            g.send(response_delta)
+
+        # Save conversation trace
+        if in_debug_mode() or state.verbose > 1:
+            commit_conversation_trace(messages, aggregated_response, tracer)
+
     finally:
         state.chat_lock.release()
         g.close()
@@ -244,6 +257,7 @@ def send_message_to_model_offline(
     stop=[],
     max_prompt_size: int = None,
     response_type: str = "text",
+    tracer: dict = {},
 ):
     assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
     offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
@@ -251,7 +265,17 @@ def send_message_to_model_offline(
     response = offline_chat_model.create_chat_completion(
         messages_dict, stop=stop, stream=streaming, temperature=temperature, response_format={"type": response_type}
     )
+
     if streaming:
         return response
-    else:
-        return response["choices"][0]["message"].get("content", "")
+
+    response_text = response["choices"][0]["message"].get("content", "")
+
+    # Save conversation trace for non-streaming responses
+    # Streamed responses need to be saved by the calling function
+    tracer["chat_model"] = model
+    tracer["temperature"] = temperature
+    if in_debug_mode() or state.verbose > 1:
+        commit_conversation_trace(messages, response_text, tracer)
+
+    return response_text

khoj/processor/conversation/openai/gpt.py

@@ -33,6 +33,7 @@ def extract_questions(
     query_images: Optional[list[str]] = None,
     vision_enabled: bool = False,
     personality_context: Optional[str] = None,
+    tracer: dict = {},
 ):
     """
     Infer search queries to retrieve relevant notes to answer user query
@@ -82,7 +83,13 @@ def extract_questions(
     messages = [ChatMessage(content=prompt, role="user")]
 
     response = send_message_to_model(
-        messages, api_key, model, response_type="json_object", api_base_url=api_base_url, temperature=temperature
+        messages,
+        api_key,
+        model,
+        response_type="json_object",
+        api_base_url=api_base_url,
+        temperature=temperature,
+        tracer=tracer,
     )
 
     # Extract, Clean Message from GPT's Response
@@ -103,7 +110,9 @@ def extract_questions(
     return questions
 
 
-def send_message_to_model(messages, api_key, model, response_type="text", api_base_url=None, temperature=0):
+def send_message_to_model(
+    messages, api_key, model, response_type="text", api_base_url=None, temperature=0, tracer: dict = {}
+):
     """
     Send message to model
     """
@@ -116,6 +125,7 @@ def send_message_to_model(messages, api_key, model, response_type="text", api_ba
         temperature=temperature,
         api_base_url=api_base_url,
         model_kwargs={"response_format": {"type": response_type}},
+        tracer=tracer,
     )
 
 
@@ -137,6 +147,7 @@ def converse(
     agent: Agent = None,
     query_images: Optional[list[str]] = None,
     vision_available: bool = False,
+    tracer: dict = {},
 ):
     """
     Converse with user using OpenAI's ChatGPT
@@ -207,4 +218,5 @@ def converse(
         api_base_url=api_base_url,
         completion_func=completion_func,
         model_kwargs={"stop": ["Notes:\n["]},
+        tracer=tracer,
     )

khoj/processor/conversation/openai/utils.py

@@ -12,7 +12,12 @@ from tenacity import (
     wait_random_exponential,
 )
 
-from khoj.processor.conversation.utils import ThreadedGenerator
+from khoj.processor.conversation.utils import (
+    ThreadedGenerator,
+    commit_conversation_trace,
+)
+from khoj.utils import state
+from khoj.utils.helpers import in_debug_mode
 
 logger = logging.getLogger(__name__)
 
@@ -33,7 +38,7 @@ openai_clients: Dict[str, openai.OpenAI] = {}
     reraise=True,
 )
 def completion_with_backoff(
-    messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None
+    messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None, tracer: dict = {}
 ) -> str:
     client_key = f"{openai_api_key}--{api_base_url}"
     client: openai.OpenAI | None = openai_clients.get(client_key)
@@ -77,6 +82,12 @@ def completion_with_backoff(
         elif delta_chunk.content:
             aggregated_response += delta_chunk.content
 
+    # Save conversation trace
+    tracer["chat_model"] = model
+    tracer["temperature"] = temperature
+    if in_debug_mode() or state.verbose > 1:
+        commit_conversation_trace(messages, aggregated_response, tracer)
+
     return aggregated_response
 
 
@@ -103,26 +114,37 @@ def chat_completion_with_backoff(
     api_base_url=None,
     completion_func=None,
     model_kwargs=None,
+    tracer: dict = {},
 ):
     g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
     t = Thread(
-        target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, api_base_url, model_kwargs)
+        target=llm_thread,
+        args=(g, messages, model_name, temperature, openai_api_key, api_base_url, model_kwargs, tracer),
     )
     t.start()
     return g
 
 
-def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_base_url=None, model_kwargs=None):
+def llm_thread(
+    g,
+    messages,
+    model_name,
+    temperature,
+    openai_api_key=None,
+    api_base_url=None,
+    model_kwargs=None,
+    tracer: dict = {},
+):
     try:
         client_key = f"{openai_api_key}--{api_base_url}"
         if client_key not in openai_clients:
-            client: openai.OpenAI = openai.OpenAI(
+            client = openai.OpenAI(
                 api_key=openai_api_key,
                 base_url=api_base_url,
             )
             openai_clients[client_key] = client
         else:
-            client: openai.OpenAI = openai_clients[client_key]
+            client = openai_clients[client_key]
 
         formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
         stream = True
@@ -144,17 +166,29 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_ba
             **(model_kwargs or dict()),
         )
 
+        aggregated_response = ""
         if not stream:
-            g.send(chat.choices[0].message.content)
+            aggregated_response = chat.choices[0].message.content
+            g.send(aggregated_response)
         else:
             for chunk in chat:
                 if len(chunk.choices) == 0:
                     continue
                 delta_chunk = chunk.choices[0].delta
+                text_chunk = ""
                 if isinstance(delta_chunk, str):
-                    g.send(delta_chunk)
+                    text_chunk = delta_chunk
                 elif delta_chunk.content:
-                    g.send(delta_chunk.content)
+                    text_chunk = delta_chunk.content
+                if text_chunk:
+                    aggregated_response += text_chunk
+                    g.send(text_chunk)
+
+        # Save conversation trace
+        tracer["chat_model"] = model_name
+        tracer["temperature"] = temperature
+        if in_debug_mode() or state.verbose > 1:
+            commit_conversation_trace(messages, aggregated_response, tracer)
     except Exception as e:
         logger.error(f"Error in llm_thread: {e}", exc_info=True)
     finally:

khoj/processor/conversation/prompts.py

@@ -193,7 +193,6 @@ you need to convert the user's query to a description format that the novice art
 - ellipse
 - line
 - arrow
-- frame
 
 use these primitives to describe what sort of diagram the drawer should create. the artist must recreate the diagram every time, so include all relevant prior information in your description.
 
@@ -284,21 +283,6 @@ For text, you must use the `text` property to specify the text to be rendered. Y
     text: string,
 }}
 
-For frames, use the `children` property to specify the elements that are inside the frame by their ids.
-
-{{
-    type: "frame",
-    id: string,
-    x: number,
-    y: number,
-    width: number,
-    height: number,
-    name: string,
-    children: [
-        string
-    ]
-}}
-
 Here's an example of a valid diagram:
 
 Design Description: Create a diagram describing a circular development process with 3 stages: design, implementation and feedback. The design stage is connected to the implementation stage and the implementation stage is connected to the feedback stage and the feedback stage is connected to the design stage. Each stage should be labeled with the stage name.

khoj/processor/conversation/utils.py

@@ -2,6 +2,7 @@ import base64
 import logging
 import math
 import mimetypes
+import os
 import queue
 from dataclasses import dataclass
 from datetime import datetime
@@ -12,6 +13,8 @@ from typing import Any, Dict, List, Optional
 import PIL.Image
 import requests
 import tiktoken
+import yaml
+from git import Repo
 from langchain.schema import ChatMessage
 from llama_cpp.llama import Llama
 from transformers import AutoTokenizer
@@ -21,7 +24,7 @@ from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser
 from khoj.processor.conversation import prompts
 from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
 from khoj.utils import state
-from khoj.utils.helpers import is_none_or_empty, merge_dicts
+from khoj.utils.helpers import in_debug_mode, is_none_or_empty, merge_dicts
 
 logger = logging.getLogger(__name__)
 model_to_prompt_size = {
@@ -117,6 +120,7 @@ def save_to_conversation_log(
     conversation_id: str = None,
     automation_id: str = None,
     query_images: List[str] = None,
+    tracer: Dict[str, Any] = {},
 ):
     user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
     updated_conversation = message_to_log(
@@ -142,6 +146,9 @@ def save_to_conversation_log(
         user_message=q,
     )
 
+    if in_debug_mode() or state.verbose > 1:
+        merge_message_into_conversation_trace(q, chat_response, tracer)
+
     logger.info(
         f"""
 Saved Conversation Turn
@@ -354,3 +361,163 @@ def get_image_from_url(image_url: str, type="pil"):
     except requests.exceptions.RequestException as e:
         logger.error(f"Failed to get image from URL {image_url}: {e}")
         return ImageWithType(content=None, type=None)
+
+
+def commit_conversation_trace(
+    session: list[ChatMessage],
+    response: str | list[dict],
+    tracer: dict,
+    system_message: str | list[dict] = "",
+    repo_path: str = "/tmp/promptrace",
+) -> str:
+    """
+    Save trace of conversation step using git. Useful to visualize, compare and debug traces.
+    Returns the path to the repository.
+    """
+    # Serialize session, system message and response to yaml
+    system_message_yaml = yaml.dump(system_message, allow_unicode=True, sort_keys=False, default_flow_style=False)
+    response_yaml = yaml.dump(response, allow_unicode=True, sort_keys=False, default_flow_style=False)
+    formatted_session = [{"role": message.role, "content": message.content} for message in session]
+    session_yaml = yaml.dump(formatted_session, allow_unicode=True, sort_keys=False, default_flow_style=False)
+    query = (
+        yaml.dump(session[-1].content, allow_unicode=True, sort_keys=False, default_flow_style=False)
+        .strip()
+        .removeprefix("'")
+        .removesuffix("'")
+    )  # Extract serialized query from chat session
+
+    # Extract chat metadata for session
+    uid, cid, mid = tracer.get("uid", "main"), tracer.get("cid", "main"), tracer.get("mid")
+
+    # Infer repository path from environment variable or provided path
+    repo_path = os.getenv("PROMPTRACE_DIR", repo_path)
+
+    try:
+        # Prepare git repository
+        os.makedirs(repo_path, exist_ok=True)
+        repo = Repo.init(repo_path)
+
+        # Remove post-commit hook if it exists
+        hooks_dir = os.path.join(repo_path, ".git", "hooks")
+        post_commit_hook = os.path.join(hooks_dir, "post-commit")
+        if os.path.exists(post_commit_hook):
+            os.remove(post_commit_hook)
+
+        # Configure git user if not set
+        if not repo.config_reader().has_option("user", "email"):
+            repo.config_writer().set_value("user", "name", "Prompt Tracer").release()
+            repo.config_writer().set_value("user", "email", "promptracer@khoj.dev").release()
+
+        # Create an initial commit if the repository is newly created
+        if not repo.head.is_valid():
+            repo.index.commit("And then there was a trace")
+
+        # Check out the initial commit
+        initial_commit = repo.commit("HEAD~0")
+        repo.head.reference = initial_commit
+        repo.head.reset(index=True, working_tree=True)
+
+        # Create or switch to user branch from initial commit
+        user_branch = f"u_{uid}"
+        if user_branch not in repo.branches:
+            repo.create_head(user_branch)
+        repo.heads[user_branch].checkout()
+
+        # Create or switch to conversation branch from user branch
+        conv_branch = f"c_{cid}"
+        if conv_branch not in repo.branches:
+            repo.create_head(conv_branch)
+        repo.heads[conv_branch].checkout()
+
+        # Create or switch to message branch from conversation branch
+        msg_branch = f"m_{mid}" if mid else None
+        if msg_branch and msg_branch not in repo.branches:
+            repo.create_head(msg_branch)
+        if msg_branch:
+            repo.heads[msg_branch].checkout()
+
+        # Include file with content to commit
+        files_to_commit = {"query": session_yaml, "response": response_yaml, "system_prompt": system_message_yaml}
+
+        # Write files and stage them
+        for filename, content in files_to_commit.items():
+            file_path = os.path.join(repo_path, filename)
+            # Unescape special characters in content for better readability
+            content = content.strip().replace("\\n", "\n").replace("\\t", "\t")
+            with open(file_path, "w", encoding="utf-8") as f:
+                f.write(content)
+            repo.index.add([filename])
+
+        # Create commit
+        metadata_yaml = yaml.dump(tracer, allow_unicode=True, sort_keys=False, default_flow_style=False)
+        commit_message = f"""
+{query[:250]}
+
+Response:
+---
+{response[:500]}...
+
+Metadata
+---
+{metadata_yaml}
+""".strip()
+
+        repo.index.commit(commit_message)
+
+        logger.debug(f"Saved conversation trace to repo at {repo_path}")
+        return repo_path
+    except Exception as e:
+        logger.error(f"Failed to add conversation trace to repo: {str(e)}", exc_info=True)
+        return None
+
+
+def merge_message_into_conversation_trace(query: str, response: str, tracer: dict, repo_path="/tmp/promptrace") -> bool:
+    """
+    Merge the message branch into its parent conversation branch.
+
+    Args:
+        query: User query
+        response: Assistant response
+        tracer: Dictionary containing uid, cid and mid
+        repo_path: Path to the git repository
+
+    Returns:
+        bool: True if merge was successful, False otherwise
+    """
+    try:
+        # Extract branch names
+        msg_branch = f"m_{tracer['mid']}"
+        conv_branch = f"c_{tracer['cid']}"
+
+        # Infer repository path from environment variable or provided path
+        repo_path = os.getenv("PROMPTRACE_DIR", repo_path)
+        repo = Repo(repo_path)
+
+        # Checkout conversation branch
+        repo.heads[conv_branch].checkout()
+
+        # Create commit message
+        metadata_yaml = yaml.dump(tracer, allow_unicode=True, sort_keys=False, default_flow_style=False)
+        commit_message = f"""
+{query[:250]}
+
+Response:
+---
+{response[:500]}...
+
+Metadata
+---
+{metadata_yaml}
+""".strip()
+
+        # Merge message branch into conversation branch
+        repo.git.merge(msg_branch, no_ff=True, m=commit_message)
+
+        # Delete message branch after merge
+        repo.delete_head(msg_branch, force=True)
+
+        logger.debug(f"Successfully merged {msg_branch} into {conv_branch}")
+        return True
+    except Exception as e:
+        logger.error(f"Failed to merge message {msg_branch} into conversation {conv_branch}: {str(e)}", exc_info=True)
+        return False
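
Taken together, the new commit_conversation_trace and merge_message_into_conversation_trace helpers implement a git-backed prompt trace: each user gets a u_<uid> branch, each conversation a c_<cid> branch, and each message a m_<mid> branch that is merged back into the conversation branch when the turn is saved. Tracing only runs when in_debug_mode() or state.verbose > 1, and the repository location can be overridden with the PROMPTRACE_DIR environment variable. Below is a minimal sketch of how a caller might exercise these helpers directly; the uid/cid/mid values and the messages are illustrative placeholders, not taken from this diff.

    # Illustrative sketch only: trace a single chat turn by hand.
    # In Khoj the tracer dict is threaded through the routers and chat helpers.
    from langchain.schema import ChatMessage

    from khoj.processor.conversation.utils import (
        commit_conversation_trace,
        merge_message_into_conversation_trace,
    )

    tracer = {"uid": "user-1", "cid": "conv-42", "mid": "msg-7"}  # hypothetical ids
    messages = [
        ChatMessage(role="system", content="You are Khoj, a personal assistant."),
        ChatMessage(role="user", content="Summarize my notes on prompt tracing."),
    ]
    response = "Your notes describe saving each chat turn as a git commit."

    # Writes query/response/system_prompt files and commits them on branch m_msg-7
    # of the repo at $PROMPTRACE_DIR (default /tmp/promptrace).
    commit_conversation_trace(messages, response, tracer)

    # When the turn is saved to the conversation log, the message branch is
    # merged back into its conversation branch c_conv-42.
    merge_message_into_conversation_trace(messages[-1].content, response, tracer)

Because the trace is an ordinary git repository, it can be browsed with standard tooling, for example: git -C /tmp/promptrace log --all --graph --stat.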

khoj/processor/image/generate.py

@@ -28,6 +28,7 @@ async def text_to_image(
     send_status_func: Optional[Callable] = None,
     query_images: Optional[List[str]] = None,
     agent: Agent = None,
+    tracer: dict = {},
 ):
     status_code = 200
     image = None
@@ -68,6 +69,7 @@ async def text_to_image(
         query_images=query_images,
         user=user,
         agent=agent,
+        tracer=tracer,
     )
 
     if send_status_func: