asky-cli 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- asky/__init__.py +7 -0
- asky/__main__.py +6 -0
- asky/banner.py +123 -0
- asky/cli.py +506 -0
- asky/config.py +270 -0
- asky/config.toml +226 -0
- asky/html.py +62 -0
- asky/llm.py +378 -0
- asky/storage.py +157 -0
- asky/tools.py +314 -0
- asky_cli-0.1.6.dist-info/METADATA +290 -0
- asky_cli-0.1.6.dist-info/RECORD +14 -0
- asky_cli-0.1.6.dist-info/WHEEL +4 -0
- asky_cli-0.1.6.dist-info/entry_points.txt +3 -0
asky/llm.py
ADDED
|
@@ -0,0 +1,378 @@
|
|
|
1
|
+
"""LLM integration and conversation loop."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
import re
|
|
6
|
+
import time
|
|
7
|
+
from typing import Any, Dict, List, Optional
|
|
8
|
+
|
|
9
|
+
import requests
|
|
10
|
+
from rich.console import Console
|
|
11
|
+
from rich.markdown import Markdown
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
from asky.config import (
|
|
15
|
+
MAX_TURNS,
|
|
16
|
+
MODELS,
|
|
17
|
+
TOOLS,
|
|
18
|
+
SYSTEM_PROMPT,
|
|
19
|
+
FORCE_SEARCH_PROMPT,
|
|
20
|
+
SYSTEM_PROMPT_SUFFIX,
|
|
21
|
+
DEEP_RESEARCH_PROMPT_TEMPLATE,
|
|
22
|
+
DEEP_DIVE_PROMPT_TEMPLATE,
|
|
23
|
+
QUERY_SUMMARY_MAX_CHARS,
|
|
24
|
+
ANSWER_SUMMARY_MAX_CHARS,
|
|
25
|
+
SUMMARIZATION_MODEL,
|
|
26
|
+
LLM_USER_AGENT,
|
|
27
|
+
SUMMARIZE_QUERY_PROMPT_TEMPLATE,
|
|
28
|
+
SUMMARIZE_ANSWER_PROMPT_TEMPLATE,
|
|
29
|
+
REQUEST_TIMEOUT,
|
|
30
|
+
DEFAULT_CONTEXT_SIZE,
|
|
31
|
+
CONTINUE_QUERY_THRESHOLD,
|
|
32
|
+
)
|
|
33
|
+
from asky.html import strip_think_tags
|
|
34
|
+
from asky.tools import dispatch_tool_call, reset_read_urls
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def is_markdown(text: str) -> bool:
    """Heuristically decide whether *text* contains Markdown formatting.

    A single match of any pattern below is enough: headers, bold/italic
    markers, links, fenced code, bullet or numbered list items. Patterns are
    searched in multiline mode, so ``^`` matches at the start of every line.
    Note: the underscore pattern is loose and also matches identifiers such
    as ``snake_case_names``.
    """
    markdown_patterns = (
        r"^#+\s",  # Headers
        r"\*\*.*\*\*",  # Bold
        r"__.*__",  # Bold
        r"\*.*\* ",  # Italic
        r"_.*_",  # Italic
        r"\[.*\]\(.*\)",  # Links
        r"```",  # Code blocks
        r"^\s*[-*+]\s",  # Lists
        r"^\s*\d+\.\s",  # Numbered lists
    )
    for pattern in markdown_patterns:
        if re.search(pattern, text, flags=re.MULTILINE):
            return True
    return False
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def parse_textual_tool_call(text: str) -> Optional[Dict[str, Any]]:
    """Parse a tool call embedded in message text (fallback for some models).

    Some models write tool calls inline, e.g.
    ``... to=functions.web_search ... {"q": "python"}``, instead of filling
    the structured ``tool_calls`` field. The function name is taken from the
    ``to=functions.<name>`` marker. The arguments are the first complete JSON
    object found after that marker; if none follows it, the first one
    anywhere in the text is used.

    Args:
        text: Raw assistant message content; may be empty or ``None``.

    Returns:
        ``{"name": <function name>, "arguments": <JSON object source text>}``,
        or ``None`` when there is no marker or no valid JSON object.
    """
    if not text:
        return None
    marker = re.search(r"to=functions\.([a-zA-Z0-9_]+)", text)
    if not marker:
        return None
    name = marker.group(1)

    # Try each "{" and let the JSON decoder decide where the object ends.
    # A greedy "{.*}" regex would run from the first "{" to the LAST "}"
    # and fail to parse whenever any brace appears after the arguments.
    decoder = json.JSONDecoder()
    for search_from in (marker.end(), 0):
        start = text.find("{", search_from)
        while start != -1:
            try:
                _, end = decoder.raw_decode(text, start)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                start = text.find("{", start + 1)
                continue
            return {"name": name, "arguments": text[start:end]}
    return None
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class UsageTracker:
    """Accumulate token counts keyed by model alias."""

    def __init__(self):
        # model alias -> running token total
        self.usage: Dict[str, int] = {}

    def add_usage(self, model_alias: str, tokens: int):
        """Add *tokens* to the running total for *model_alias*."""
        previous = self.usage.get(model_alias, 0)
        self.usage[model_alias] = previous + tokens

    def get_total_usage(self, model_alias: str) -> int:
        """Return the accumulated total for *model_alias* (0 if never seen)."""
        if model_alias in self.usage:
            return self.usage[model_alias]
        return 0
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def count_tokens(messages: List[Dict[str, Any]]) -> int:
    """Estimate the token count of *messages* as total characters // 4.

    Counted per message: the length of a truthy ``content`` plus the length
    of the JSON-serialised ``tool_calls`` when present.
    """

    def _message_chars(message: Dict[str, Any]) -> int:
        body = message.get("content")
        calls = message.get("tool_calls")
        body_chars = len(body) if body else 0
        call_chars = len(json.dumps(calls)) if calls else 0
        return body_chars + call_chars

    return sum(_message_chars(message) for message in messages) // 4
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def _parse_retry_after(value: Optional[str]) -> Optional[int]:
    """Return a non-negative wait in seconds from a Retry-After value, or None.

    Only the numeric delay form is understood. HTTP-date values, garbage,
    NaN and infinities yield None, so the caller falls back to backoff.
    """
    if not value:
        return None
    try:
        # Handle potential floating point strings (e.g. "5.0"); clamp negatives
        # because time.sleep() rejects them.
        return max(0, int(float(value)))
    except (ValueError, OverflowError):
        return None


def get_llm_msg(
    model_id: str,
    messages: List[Dict[str, Any]],
    tools: Optional[List[Dict]] = TOOLS,
    verbose: bool = False,
    model_alias: Optional[str] = None,
    usage_tracker: Optional[UsageTracker] = None,
) -> Dict[str, Any]:
    """Send *messages* to the model's chat endpoint and return its reply.

    The URL and credentials come from the ``MODELS`` entry whose ``id`` equals
    *model_id*. ``base_url`` is the POST target. The bearer token is either
    ``api_key`` or the environment variable named by ``api_key_env``.

    Retry behaviour:
    * HTTP 429 is retried, honouring a numeric ``Retry-After`` header and
      otherwise using exponential backoff capped at 60 seconds.
    * Other request errors (e.g. connection failures, timeouts) are retried
      with the same backoff.
    * A missing or invalid URL and other HTTP error statuses are raised
      immediately.

    Args:
        model_id: Model identifier sent as ``"model"`` in the payload.
        messages: Chat messages to send.
        tools: Tool schemas; when falsy, no tools are offered.
        verbose: Print debug information about the outgoing last message.
        model_alias: Label used for logging and usage accounting.
        usage_tracker: If given together with *model_alias*, receives the
            prompt + completion tokens of the successful call.

    Returns:
        The ``choices[0]["message"]`` dict of the response.

    Raises:
        requests.exceptions.RequestException: on non-retryable errors or once
            the retry budget is exhausted.
    """
    # Find the model config based on model_id
    model_config = next((m for m in MODELS.values() if m["id"] == model_id), None)

    url = ""
    headers = {
        "Content-Type": "application/json",
        "User-Agent": LLM_USER_AGENT,
    }

    if model_config and "base_url" in model_config:
        url = model_config["base_url"]

    api_key = None
    if model_config and "api_key" in model_config:
        api_key = model_config["api_key"]
    elif model_config and "api_key_env" in model_config:
        api_key_env_var = model_config["api_key_env"]
        api_key = os.environ.get(api_key_env_var)
        if not api_key:
            print(f"Warning: {api_key_env_var} not found in environment variables.")

    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    payload = {
        "model": model_id,
        "messages": messages,
    }
    if tools:
        payload["tools"] = tools
        payload["tool_choice"] = "auto"

    max_retries = 10
    backoff = 2
    max_backoff = 60

    if verbose:
        print(f"\n[DEBUG] Sending to LLM ({model_id})...")
        print(f"Tools enabled: {bool(tools)}")
        print("Last message sent:")
        if messages:
            last_msg = messages[-1]
            content = last_msg.get("content", "")
            if content and len(content) > 500:
                print(f" Role: {last_msg['role']}")
                print(f" Content (truncated): {content[:500]}...")
            else:
                print(f" {json.dumps(last_msg, indent=2)}")
        print("-" * 20)

    tokens_sent = count_tokens(messages)
    print(f"[{model_alias or model_id}] Sent: {tokens_sent} tokens")

    # Configuration errors (e.g. no base_url for this model -> url == "").
    # Retrying cannot fix them and would only sleep for minutes.
    non_retryable = (
        requests.exceptions.MissingSchema,
        requests.exceptions.InvalidSchema,
        requests.exceptions.InvalidURL,
    )

    for attempt in range(max_retries):
        try:
            resp = requests.post(
                url, json=payload, headers=headers, timeout=REQUEST_TIMEOUT
            )
            resp.raise_for_status()
            resp_json = resp.json()
        except requests.exceptions.HTTPError as e:
            rate_limited = e.response is not None and e.response.status_code == 429
            if rate_limited and attempt < max_retries - 1:
                wait_time = _parse_retry_after(e.response.headers.get("Retry-After"))
                if wait_time is None:
                    wait_time = backoff
                    backoff = min(backoff * 2, max_backoff)
                print(f"Rate limit exceeded (429). Retrying in {wait_time} seconds...")
                time.sleep(wait_time)
                continue
            raise
        except non_retryable:
            raise
        except requests.exceptions.RequestException as e:
            if attempt < max_retries - 1:
                print(f"Request error: {e}. Retrying in {backoff} seconds...")
                time.sleep(backoff)
                backoff = min(backoff * 2, max_backoff)
                continue
            raise

        message = resp_json["choices"][0]["message"]

        # Prefer provider-reported usage; fall back to the naive estimate.
        # Some providers send "usage": null or null token fields, so treat
        # None the same as missing instead of crashing on a good response.
        usage = resp_json.get("usage") or {}
        prompt_tokens = usage.get("prompt_tokens")
        if prompt_tokens is None:
            prompt_tokens = tokens_sent
        completion_tokens = usage.get("completion_tokens")
        if completion_tokens is None:
            completion_tokens = len(json.dumps(message)) // 4

        if usage_tracker and model_alias:
            usage_tracker.add_usage(model_alias, prompt_tokens + completion_tokens)

        return message
    raise requests.exceptions.RequestException("Max retries exceeded")
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def extract_calls(msg: Dict[str, Any], turn: int) -> List[Dict[str, Any]]:
    """Return the tool calls requested by an assistant message.

    A non-empty ``tool_calls`` field is returned as-is. Otherwise the message
    text is passed to ``parse_textual_tool_call``. A hit is wrapped as a
    single call with id ``textual_call_<turn>``. An empty list means no tool
    call was requested.
    """
    structured = msg.get("tool_calls")
    if structured:
        return structured
    textual = parse_textual_tool_call(msg.get("content", ""))
    return [{"id": f"textual_call_{turn}", "function": textual}] if textual else []
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def construct_system_prompt(
    deep_research_n: int, deep_dive: bool, force_search: bool
) -> str:
    """Assemble the system prompt for the selected modes.

    The pieces are concatenated in this order:
    1. the base prompt;
    2. the force-search text, when requested;
    3. the common suffix;
    4. the deep-research instructions formatted with ``n``, when
       ``deep_research_n > 0``;
    5. the deep-dive instructions, when requested.
    """
    parts = [SYSTEM_PROMPT]
    if force_search:
        parts.append(FORCE_SEARCH_PROMPT)
    parts.append(SYSTEM_PROMPT_SUFFIX)
    if deep_research_n > 0:
        parts.append(DEEP_RESEARCH_PROMPT_TEMPLATE.format(n=deep_research_n))
    if deep_dive:
        parts.append(DEEP_DIVE_PROMPT_TEMPLATE)
    return "".join(parts)
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
def _status_message(total_tokens: int, context_size: int, turns_left: int) -> str:
    """Build the per-turn status block appended to the system prompt."""
    # Guard against a misconfigured context_size of 0 (ZeroDivisionError).
    usage_pct = total_tokens / context_size * 100 if context_size else 0.0
    return (
        f"\n\n[SYSTEM UPDATE]:\n"
        f"- Context Used: {usage_pct:.2f}%\n"
        f"- Turns Remaining: {turns_left} (out of {MAX_TURNS})\n"
        f"Please manage your context usage efficiently."
    )


def run_conversation_loop(
    model_config: Dict[str, Any],
    messages: List[Dict[str, Any]],
    summarize: bool,
    verbose: bool = False,
    usage_tracker: Optional[UsageTracker] = None,
) -> str:
    """Run the multi-turn tool-calling conversation until a final answer.

    Each turn does the following:
    * If the first message is a system prompt, it is reset to its original
      text plus a fresh status block (estimated context usage and turns
      remaining).
    * The model is queried.
    * If the reply requests tools, each call is dispatched and its JSON
      result is appended as a ``tool`` message.
    * Otherwise the reply is the final answer. It is printed, rendered as
      Markdown when it looks like Markdown.

    *messages* is mutated in place: assistant and tool messages are appended,
    and the system message content is rewritten each turn.

    Args:
        model_config: Model entry; ``id`` and ``max_chars`` are read directly,
            ``alias`` and ``context_size`` are optional.
        messages: Conversation so far, optionally starting with a system
            message.
        summarize: Forwarded to ``dispatch_tool_call``.
        verbose: Forwarded to ``get_llm_msg``.
        usage_tracker: Forwarded to ``get_llm_msg`` for token accounting.

    Returns:
        The final answer with think tags stripped, or ``""`` if none was
        produced (turn limit reached, or an error before the answer).
    """
    turn = 0
    start_time = time.perf_counter()
    final_answer = ""
    has_system_prompt = bool(messages) and messages[0]["role"] == "system"
    original_system_prompt = messages[0]["content"] if has_system_prompt else ""

    # Reset read URLs for new conversation
    reset_read_urls()

    try:
        while turn < MAX_TURNS:
            turn += 1

            if has_system_prompt:
                status_msg = _status_message(
                    count_tokens(messages),
                    model_config.get("context_size", DEFAULT_CONTEXT_SIZE),
                    MAX_TURNS - turn + 1,
                )
                messages[0]["content"] = original_system_prompt + status_msg

            msg = get_llm_msg(
                model_config["id"],
                messages,
                verbose=verbose,
                model_alias=model_config.get("alias"),
                usage_tracker=usage_tracker,
            )
            calls = extract_calls(msg, turn)
            if not calls:
                # "content" may be present but null; treat that as empty text.
                final_answer = strip_think_tags(msg.get("content") or "")
                if is_markdown(final_answer):
                    Console().print(Markdown(final_answer))
                else:
                    print(final_answer)
                break

            messages.append(msg)
            for call in calls:
                result = dispatch_tool_call(call, model_config["max_chars"], summarize)
                messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": call["id"],
                        "content": json.dumps(result),
                    }
                )
        else:
            # Only reached when the loop ran out of turns without a final
            # answer; the `break` above skips this, so an answer produced on
            # the very last turn is no longer reported as an error.
            print("Error: Max turns reached.")
    except Exception as e:
        print(f"Error: {str(e)}")
    finally:
        print(f"\nQuery completed in {time.perf_counter() - start_time:.2f} seconds")
    return final_answer
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
def _summarize_text(
    text: str,
    prompt_template: str,
    template_kwargs: Dict[str, Any],
    max_chars: int,
    input_limit: int,
    label: str,
    usage_tracker: Optional[UsageTracker],
) -> str:
    """Summarize *text* with the summarization model, bounded to *max_chars*.

    Only the first *input_limit* characters of *text* are sent. Falls back to
    ``text[:max_chars]`` when the model call fails or returns no usable
    content. A summary longer than *max_chars* is truncated and suffixed
    with ``"..."``.
    """
    try:
        msgs = [
            {"role": "system", "content": prompt_template.format(**template_kwargs)},
            {"role": "user", "content": text[:input_limit]},
        ]
        model = MODELS[SUMMARIZATION_MODEL]
        msg = get_llm_msg(
            model["id"],
            msgs,
            tools=None,
            model_alias=model.get("alias", SUMMARIZATION_MODEL),
            usage_tracker=usage_tracker,
        )
        # "content" may be present but null.
        summary = strip_think_tags(msg.get("content") or "").strip()
    except Exception as e:
        print(f"Error summarizing {label}: {e}")
        return text[:max_chars]

    if not summary:
        # Avoid persisting an empty summary; use the same fallback as errors.
        return text[:max_chars]
    if len(summary) > max_chars:
        summary = summary[: max_chars - 3] + "..."
    return summary


def generate_summaries(
    query: str, answer: str, usage_tracker: Optional[UsageTracker] = None
) -> tuple[str, str]:
    """Generate short summaries of a query and its answer.

    The query is summarized only when it is longer than
    ``CONTINUE_QUERY_THRESHOLD`` characters; otherwise its summary is ``""``.
    The answer is always summarized. Failures fall back to truncated text
    rather than raising.

    Returns:
        ``(query_summary, answer_summary)``.
    """
    query_summary = ""
    if len(query) > CONTINUE_QUERY_THRESHOLD:
        query_summary = _summarize_text(
            query,
            SUMMARIZE_QUERY_PROMPT_TEMPLATE,
            {"QUERY_SUMMARY_MAX_CHARS": QUERY_SUMMARY_MAX_CHARS},
            QUERY_SUMMARY_MAX_CHARS,
            1000,
            "query",
            usage_tracker,
        )

    answer_summary = _summarize_text(
        answer,
        SUMMARIZE_ANSWER_PROMPT_TEMPLATE,
        {"ANSWER_SUMMARY_MAX_CHARS": ANSWER_SUMMARY_MAX_CHARS},
        ANSWER_SUMMARY_MAX_CHARS,
        5000,
        "answer",
        usage_tracker,
    )
    return query_summary, answer_summary
|
asky/storage.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
"""SQLite database functions for conversation history."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import sqlite3
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from typing import List, Optional
|
|
7
|
+
|
|
8
|
+
from asky.config import DB_PATH, CONTINUE_QUERY_THRESHOLD
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def init_db() -> None:
    """Create the database directory and the ``history`` table if missing.

    Both steps are no-ops when already done. The connection is closed even
    if table creation fails.
    """
    os.makedirs(DB_PATH.parent, exist_ok=True)
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.execute("""
        CREATE TABLE IF NOT EXISTS history (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            timestamp TEXT,
            query TEXT,
            query_summary TEXT,
            answer TEXT,
            answer_summary TEXT,
            model TEXT
        )
        """)
        conn.commit()
    finally:
        conn.close()
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def get_history(limit: int = 10) -> List[tuple]:
    """Return up to *limit* history rows, highest id (newest) first.

    Each row is ``(id, timestamp, query, query_summary, answer_summary, model)``.
    The connection is closed even if the query fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        cur = conn.execute(
            "SELECT id, timestamp, query, query_summary, answer_summary, model FROM history ORDER BY id DESC LIMIT ?",
            (limit,),
        )
        return cur.fetchall()
    finally:
        conn.close()
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def get_db_record_count() -> int:
    """Return the number of rows in the ``history`` table.

    The connection is closed even if the query fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        return conn.execute("SELECT COUNT(*) FROM history").fetchone()[0]
    finally:
        conn.close()
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def get_interaction_context(ids: List[int], full: bool = False) -> str:
    """Format previous interactions as context text for a follow-up query.

    Each existing id yields ``"Query <id>: ...\\nAnswer <id>: ..."``. Entries
    are ordered by id and separated by a blank line. Unknown ids are skipped,
    and an empty *ids* returns ``""``.

    Query text: the stored query summary is used when it is non-empty and the
    original query is at least ``CONTINUE_QUERY_THRESHOLD`` characters long;
    otherwise the original query is used.

    Answer text: the full stored answer when *full* is true, otherwise the
    stored answer summary.
    """
    if not ids:
        return ""
    placeholders = ",".join(["?"] * len(ids))
    # ORDER BY makes the output order deterministic; SQL leaves it
    # unspecified without one.
    query_str = (
        "SELECT id, query, query_summary, answer, answer_summary "
        f"FROM history WHERE id IN ({placeholders}) ORDER BY id"
    )
    conn = sqlite3.connect(DB_PATH)
    try:
        results = conn.execute(query_str, list(ids)).fetchall()
    finally:
        conn.close()

    context_parts = []
    for rid, query, q_sum, answer, a_sum in results:
        # Use summary if available and original query is long enough.
        # Columns are nullable, so guard len() against a NULL query.
        if q_sum and len(query or "") >= CONTINUE_QUERY_THRESHOLD:
            q_text = q_sum
        else:
            q_text = query
        a_text = answer if full else a_sum
        context_parts.append(f"Query {rid}: {q_text}\nAnswer {rid}: {a_text}")

    return "\n\n".join(context_parts)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def cleanup_db(target: Optional[str], delete_all: bool = False) -> None:
    """Delete history rows: everything, or those selected by *target*.

    When *delete_all* is set, every row is removed and the ``history`` entry
    in ``sqlite_sequence`` is deleted as well; *target* is ignored.
    Otherwise *target* is an id, a ``start-end`` range, or a comma-separated
    id list.

    Errors are printed rather than raised, and the connection is always
    closed.
    """
    conn = sqlite3.connect(DB_PATH)
    cur = conn.cursor()

    try:
        if delete_all:
            for statement in (
                "DELETE FROM history",
                "DELETE FROM sqlite_sequence WHERE name='history'",
            ):
                cur.execute(statement)
            print("dataset cleaned completely.")
        elif not target:
            print(
                "Error: No target specified for cleanup. Use --all or provide IDs/range."
            )
        else:
            _delete_by_target(cur, target)
        conn.commit()
    except Exception as exc:
        print(f"Error during cleanup: {exc}")
    finally:
        conn.close()
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def _delete_by_target(cursor: sqlite3.Cursor, target: str) -> None:
|
|
106
|
+
"""Helper to delete records by target specification."""
|
|
107
|
+
# Check for range (e.g., "1-5")
|
|
108
|
+
if "-" in target:
|
|
109
|
+
try:
|
|
110
|
+
start, end = map(int, target.split("-"))
|
|
111
|
+
if start > end:
|
|
112
|
+
start, end = end, start
|
|
113
|
+
cursor.execute(
|
|
114
|
+
"DELETE FROM history WHERE id >= ? AND id <= ?", (start, end)
|
|
115
|
+
)
|
|
116
|
+
print(f"deleted records from {start} to {end}.")
|
|
117
|
+
except ValueError:
|
|
118
|
+
print("Error: Invalid range format. Use 'start-end' (e.g., 1-5).")
|
|
119
|
+
# Check for comma-separated list (e.g., "1,3,5")
|
|
120
|
+
elif "," in target:
|
|
121
|
+
try:
|
|
122
|
+
ids = [int(x.strip()) for x in target.split(",")]
|
|
123
|
+
placeholders = ",".join(["?"] * len(ids))
|
|
124
|
+
cursor.execute(f"DELETE FROM history WHERE id IN ({placeholders})", ids)
|
|
125
|
+
print(f"deleted records: {', '.join(map(str, ids))}.")
|
|
126
|
+
except ValueError:
|
|
127
|
+
print("Error: Invalid list format. Use comma-separated integers.")
|
|
128
|
+
# Single ID
|
|
129
|
+
else:
|
|
130
|
+
try:
|
|
131
|
+
rid = int(target)
|
|
132
|
+
cursor.execute("DELETE FROM history WHERE id = ?", (rid,))
|
|
133
|
+
print(f"deleted record {rid}.")
|
|
134
|
+
except ValueError:
|
|
135
|
+
print("Error: Invalid ID format. Must be an integer.")
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def save_interaction(
    query: str,
    answer: str,
    model: str,
    query_summary: str = "",
    answer_summary: str = "",
) -> None:
    """Insert one interaction into ``history`` and commit it.

    The timestamp is the local time from ``datetime.now().isoformat()``.
    The connection is closed even if the insert fails.
    """
    timestamp = datetime.now().isoformat()
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.execute(
            """
            INSERT INTO history (timestamp, query, query_summary, answer, answer_summary, model)
            VALUES (?, ?, ?, ?, ?, ?)
            """,
            (timestamp, query, query_summary, answer, answer_summary, model),
        )
        conn.commit()
    finally:
        conn.close()
|