asky-cli 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
asky/llm.py ADDED
@@ -0,0 +1,378 @@
1
+ """LLM integration and conversation loop."""
2
+
3
+ import json
4
+ import os
5
+ import re
6
+ import time
7
+ from typing import Any, Dict, List, Optional
8
+
9
+ import requests
10
+ from rich.console import Console
11
+ from rich.markdown import Markdown
12
+
13
+
14
+ from asky.config import (
15
+ MAX_TURNS,
16
+ MODELS,
17
+ TOOLS,
18
+ SYSTEM_PROMPT,
19
+ FORCE_SEARCH_PROMPT,
20
+ SYSTEM_PROMPT_SUFFIX,
21
+ DEEP_RESEARCH_PROMPT_TEMPLATE,
22
+ DEEP_DIVE_PROMPT_TEMPLATE,
23
+ QUERY_SUMMARY_MAX_CHARS,
24
+ ANSWER_SUMMARY_MAX_CHARS,
25
+ SUMMARIZATION_MODEL,
26
+ LLM_USER_AGENT,
27
+ SUMMARIZE_QUERY_PROMPT_TEMPLATE,
28
+ SUMMARIZE_ANSWER_PROMPT_TEMPLATE,
29
+ REQUEST_TIMEOUT,
30
+ DEFAULT_CONTEXT_SIZE,
31
+ CONTINUE_QUERY_THRESHOLD,
32
+ )
33
+ from asky.html import strip_think_tags
34
+ from asky.tools import dispatch_tool_call, reset_read_urls
35
+
36
+
37
# Pre-compiled markdown heuristics. Compiling once at import time avoids
# rebuilding all nine regexes on every call (is_markdown runs per answer).
_MARKDOWN_PATTERNS = [
    re.compile(p, re.M)
    for p in (
        r"^#+\s",         # Headers
        r"\*\*.*\*\*",    # Bold
        r"__.*__",        # Bold
        r"\*.*\* ",       # Italic (trailing space kept from original heuristic)
        r"_.*_",          # Italic
        r"\[.*\]\(.*\)",  # Links
        r"```",           # Code blocks
        r"^\s*[-*+]\s",   # Lists
        r"^\s*\d+\.\s",   # Numbered lists
    )
]


def is_markdown(text: str) -> bool:
    """Check if the text likely contains markdown formatting.

    Args:
        text: Candidate string (typically a model answer).

    Returns:
        True when any common markdown construct (header, bold/italic,
        link, code fence, list) is detected; False otherwise. This is a
        heuristic — e.g. snake_case words can trip the italic pattern.
    """
    return any(pattern.search(text) for pattern in _MARKDOWN_PATTERNS)
52
+
53
+
54
def parse_textual_tool_call(text: str) -> Optional[Dict[str, Any]]:
    """Parse tool calls from textual format (fallback for some models).

    Scans *text* for a ``to=functions.<name>`` marker and a JSON argument
    blob. The JSON is validated but returned as the raw string, since the
    tool dispatcher expects unparsed arguments.

    Returns:
        ``{"name": ..., "arguments": ...}`` on success, otherwise None.
    """
    if not text:
        return None

    name_match = re.search(r"to=functions\.([a-zA-Z0-9_]+)", text)
    if name_match is None:
        return None

    args_match = re.search(r"(\{.*\})", text, re.S)
    if args_match is None:
        return None

    raw_args = args_match.group(1)
    try:
        # Validate only — malformed JSON means this was not a real call.
        json.loads(raw_args)
    except Exception:
        return None
    return {"name": name_match.group(1), "arguments": raw_args}
70
+
71
+
72
class UsageTracker:
    """Accumulate token usage totals, keyed by model alias."""

    def __init__(self):
        # alias -> running token total
        self.usage: Dict[str, int] = {}

    def add_usage(self, model_alias: str, tokens: int):
        """Add *tokens* to the running total for *model_alias*."""
        current = self.usage.get(model_alias, 0)
        self.usage[model_alias] = current + tokens

    def get_total_usage(self, model_alias: str) -> int:
        """Return the total tokens recorded for *model_alias* (0 if unseen)."""
        return self.usage.get(model_alias, 0)
83
+
84
+
85
def count_tokens(messages: List[Dict[str, Any]]) -> int:
    """Naive token estimate for *messages*: total characters // 4."""

    def _chars(message: Dict[str, Any]) -> int:
        # Message text plus serialized tool calls both occupy context.
        n = len(message.get("content") or "")
        tool_calls = message.get("tool_calls")
        if tool_calls:
            n += len(json.dumps(tool_calls))
        return n

    return sum(_chars(m) for m in messages) // 4
97
+
98
+
99
def get_llm_msg(
    model_id: str,
    messages: List[Dict[str, Any]],
    tools: Optional[List[Dict]] = TOOLS,
    verbose: bool = False,
    model_alias: Optional[str] = None,
    usage_tracker: Optional[UsageTracker] = None,
) -> Dict[str, Any]:
    """Send messages to the LLM and get a response.

    Posts an OpenAI-style chat-completions payload to the model's
    ``base_url`` and returns ``choices[0].message`` from the response.
    Retries with exponential backoff on transient request errors, and
    honors ``Retry-After`` on HTTP 429.

    Args:
        model_id: Provider model identifier; matched against MODELS entries
            by their "id" field to locate URL/auth config.
        messages: Chat history in OpenAI message format.
        tools: Tool schemas to advertise; when truthy, tool_choice is "auto".
            NOTE(review): the default binds the shared TOOLS list at import
            time — fine as long as callers never mutate it.
        verbose: Print a truncated debug dump of the outgoing request.
        model_alias: Friendly name used in token-count log lines and for
            usage tracking.
        usage_tracker: When given together with model_alias, records
            prompt+completion tokens for this call.

    Returns:
        The assistant message dict from the first choice.

    Raises:
        requests.exceptions.RequestException: after exhausting all retries,
            or immediately for non-429 HTTP errors.
    """
    # Find the model config based on model_id
    model_config = next((m for m in MODELS.values() if m["id"] == model_id), None)

    # url stays "" when the model is unknown; the POST below would then fail.
    url = ""
    headers = {
        "Content-Type": "application/json",
        "User-Agent": LLM_USER_AGENT,
    }

    if model_config and "base_url" in model_config:
        url = model_config["base_url"]

    # Auth: a literal "api_key" wins over an "api_key_env" env-var lookup.
    api_key = None
    if model_config and "api_key" in model_config:
        api_key = model_config["api_key"]
    elif model_config and "api_key_env" in model_config:
        api_key_env_var = model_config["api_key_env"]
        api_key = os.environ.get(api_key_env_var)
        if not api_key:
            # Warn but continue: some local endpoints need no key.
            print(f"Warning: {api_key_env_var} not found in environment variables.")

    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    payload = {
        "model": model_id,
        "messages": messages,
    }
    if tools:
        payload["tools"] = tools
        payload["tool_choice"] = "auto"

    # Retry policy: up to 10 attempts, doubling backoff capped at 60s.
    max_retries = 10
    backoff = 2
    max_backoff = 60

    if verbose:
        print(f"\n[DEBUG] Sending to LLM ({model_id})...")
        print(f"Tools enabled: {bool(tools)}")
        print("Last message sent:")
        if messages:
            last_msg = messages[-1]
            content = last_msg.get("content", "")
            # Long contents are truncated to keep the debug dump readable.
            if content and len(content) > 500:
                print(f" Role: {last_msg['role']}")
                print(f" Content (truncated): {content[:500]}...")
            else:
                print(f" {json.dumps(last_msg, indent=2)}")
        print("-" * 20)

    # Naive local estimate, also used as fallback when the API omits usage.
    tokens_sent = count_tokens(messages)
    if model_alias:
        print(f"[{model_alias}] Sent: {tokens_sent} tokens")
    else:
        print(f"[{model_id}] Sent: {tokens_sent} tokens")

    for attempt in range(max_retries):
        try:
            resp = requests.post(
                url, json=payload, headers=headers, timeout=REQUEST_TIMEOUT
            )
            resp.raise_for_status()
            resp_json = resp.json()

            # Extract usage if available, otherwise use naive count
            usage = resp_json.get("usage", {})
            prompt_tokens = usage.get("prompt_tokens", tokens_sent)
            completion_tokens = usage.get("completion_tokens", 0)
            if "completion_tokens" not in usage:
                # Fall back to chars//4 over the serialized reply message.
                completion_tokens = (
                    len(json.dumps(resp_json["choices"][0]["message"])) // 4
                )

            total_call_tokens = prompt_tokens + completion_tokens

            if usage_tracker and model_alias:
                usage_tracker.add_usage(model_alias, total_call_tokens)

            return resp_json["choices"][0]["message"]
        except requests.exceptions.HTTPError as e:
            # Only 429 (rate limit) is retried; other HTTP errors re-raise.
            if e.response is not None and e.response.status_code == 429:
                if attempt < max_retries - 1:
                    retry_after = e.response.headers.get("Retry-After")
                    if retry_after:
                        try:
                            # Handle potential floating point strings (e.g. "5.0")
                            wait_time = int(float(retry_after))
                        except ValueError:
                            wait_time = backoff
                        backoff = min(backoff * 2, max_backoff)
                    else:
                        wait_time = backoff
                        backoff = min(backoff * 2, max_backoff)

                    print(
                        f"Rate limit exceeded (429). Retrying in {wait_time} seconds..."
                    )
                    time.sleep(wait_time)
                    continue
            raise e
        except requests.exceptions.RequestException as e:
            # Connection/timeout errors: retry with plain backoff.
            if attempt < max_retries - 1:
                print(f"Request error: {e}. Retrying in {backoff} seconds...")
                time.sleep(backoff)
                backoff = min(backoff * 2, max_backoff)
                continue
            raise e
    raise requests.exceptions.RequestException("Max retries exceeded")
216
+
217
+
218
def extract_calls(msg: Dict[str, Any], turn: int) -> List[Dict[str, Any]]:
    """Extract tool calls from an LLM message.

    Prefers the structured ``tool_calls`` field; otherwise falls back to
    parsing a textual call out of the message content. Returns an empty
    list when neither form is present.
    """
    structured = msg.get("tool_calls")
    if structured:
        return structured

    textual = parse_textual_tool_call(msg.get("content", ""))
    if textual is None:
        return []
    # Synthesize a per-turn id so tool results can be correlated back.
    return [{"id": f"textual_call_{turn}", "function": textual}]
227
+
228
+
229
def construct_system_prompt(
    deep_research_n: int, deep_dive: bool, force_search: bool
) -> str:
    """Build the system prompt based on mode flags.

    Args:
        deep_research_n: When positive, append the deep-research template
            formatted with this source count.
        deep_dive: When True, append the deep-dive template.
        force_search: When True, insert the force-search prompt before the
            standard suffix.

    Returns:
        The assembled system prompt string.
    """
    parts = [SYSTEM_PROMPT]
    if force_search:
        parts.append(FORCE_SEARCH_PROMPT)
    parts.append(SYSTEM_PROMPT_SUFFIX)
    if deep_research_n > 0:
        parts.append(DEEP_RESEARCH_PROMPT_TEMPLATE.format(n=deep_research_n))
    if deep_dive:
        parts.append(DEEP_DIVE_PROMPT_TEMPLATE)
    return "".join(parts)
243
+
244
+
245
def run_conversation_loop(
    model_config: Dict[str, Any],
    messages: List[Dict[str, Any]],
    summarize: bool,
    verbose: bool = False,
    usage_tracker: Optional[UsageTracker] = None,
) -> str:
    """Run the multi-turn conversation loop with tool execution.

    Args:
        model_config: A MODELS entry; must provide "id" and "max_chars",
            may provide "alias" and "context_size".
        messages: Mutable chat history. When messages[0] is a system
            prompt, a live status footer is (re)appended to it each turn.
        summarize: Forwarded to dispatch_tool_call to control tool-output
            summarization.
        verbose: Enables request debugging in get_llm_msg.
        usage_tracker: Optional token accounting, forwarded per call.

    Returns:
        The final answer text, or "" if the loop errored or hit the turn
        limit before producing one.
    """
    turn = 0
    start_time = time.perf_counter()
    final_answer = ""
    # Keep the pristine system prompt so the per-turn status footer is
    # replaced each turn rather than accumulated.
    original_system_prompt = (
        messages[0]["content"] if messages and messages[0]["role"] == "system" else ""
    )

    # Reset read URLs for new conversation
    reset_read_urls()

    try:
        while turn < MAX_TURNS:
            turn += 1

            # Token & Turn Tracking
            total_tokens = count_tokens(messages)
            context_size = model_config.get("context_size", DEFAULT_CONTEXT_SIZE)
            turns_left = MAX_TURNS - turn + 1

            # Bugfix: a "\n" was missing after the Context Used line, which
            # fused it with the Turns Remaining line in the prompt text.
            status_msg = (
                f"\n\n[SYSTEM UPDATE]:\n"
                f"- Context Used: {total_tokens / context_size * 100:.2f}%\n"
                f"- Turns Remaining: {turns_left} (out of {MAX_TURNS})\n"
                f"Please manage your context usage efficiently."
            )
            if messages and messages[0]["role"] == "system":
                messages[0]["content"] = original_system_prompt + status_msg

            msg = get_llm_msg(
                model_config["id"],
                messages,
                verbose=verbose,
                model_alias=model_config.get("alias"),
                usage_tracker=usage_tracker,
            )
            calls = extract_calls(msg, turn)
            if not calls:
                # No tool calls means the model produced its final answer.
                final_answer = strip_think_tags(msg.get("content", ""))
                if is_markdown(final_answer):
                    console = Console()
                    console.print(Markdown(final_answer))
                else:
                    print(final_answer)
                break
            messages.append(msg)
            for call in calls:
                result = dispatch_tool_call(call, model_config["max_chars"], summarize)
                messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": call["id"],
                        "content": json.dumps(result),
                    }
                )
            if turn >= MAX_TURNS:
                print("Error: Max turns reached.")
    except Exception as e:
        # Top-level boundary: report the failure and fall through so the
        # timing line still prints and callers get "" back.
        print(f"Error: {str(e)}")
    finally:
        print(f"\nQuery completed in {time.perf_counter() - start_time:.2f} seconds")
    return final_answer
314
+
315
+
316
def _summarize_text(
    system_prompt: str,
    text: str,
    max_chars: int,
    usage_tracker: Optional[UsageTracker] = None,
) -> str:
    """Run the summarization model over *text*, clamping output to *max_chars*."""
    msgs = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": text},
    ]
    model_id = MODELS[SUMMARIZATION_MODEL]["id"]
    model_alias = MODELS[SUMMARIZATION_MODEL].get("alias", SUMMARIZATION_MODEL)
    msg = get_llm_msg(
        model_id,
        msgs,
        tools=None,
        model_alias=model_alias,
        usage_tracker=usage_tracker,
    )
    summary = strip_think_tags(msg.get("content", "")).strip()
    if len(summary) > max_chars:
        # Trim with an ellipsis so the stored summary stays within budget.
        summary = summary[: max_chars - 3] + "..."
    return summary


def generate_summaries(
    query: str, answer: str, usage_tracker: Optional[UsageTracker] = None
) -> tuple[str, str]:
    """Generate summaries for query and answer using the summarization model.

    The query is summarized only when it exceeds CONTINUE_QUERY_THRESHOLD;
    the answer is always summarized. On any summarization failure the
    corresponding raw text is truncated as a best-effort fallback.

    Args:
        query: The user's original query text.
        answer: The final answer text.
        usage_tracker: Optional token accounting forwarded to the LLM call.

    Returns:
        (query_summary, answer_summary); query_summary is "" for short queries.
    """
    query_summary = ""

    # Generate Query Summary (if needed)
    if len(query) > CONTINUE_QUERY_THRESHOLD:
        try:
            query_summary = _summarize_text(
                SUMMARIZE_QUERY_PROMPT_TEMPLATE.format(
                    QUERY_SUMMARY_MAX_CHARS=QUERY_SUMMARY_MAX_CHARS
                ),
                query[:1000],  # cap prompt size for the summarizer
                QUERY_SUMMARY_MAX_CHARS,
                usage_tracker,
            )
        except Exception as e:
            print(f"Error summarizing query: {e}")
            query_summary = query[:QUERY_SUMMARY_MAX_CHARS]

    # Generate Answer Summary (Always)
    try:
        answer_summary = _summarize_text(
            SUMMARIZE_ANSWER_PROMPT_TEMPLATE.format(
                ANSWER_SUMMARY_MAX_CHARS=ANSWER_SUMMARY_MAX_CHARS
            ),
            answer[:5000],  # cap prompt size for the summarizer
            ANSWER_SUMMARY_MAX_CHARS,
            usage_tracker,
        )
    except Exception as e:
        print(f"Error summarizing answer: {e}")
        answer_summary = answer[:ANSWER_SUMMARY_MAX_CHARS]

    return query_summary, answer_summary
asky/storage.py ADDED
@@ -0,0 +1,157 @@
1
+ """SQLite database functions for conversation history."""
2
+
3
+ import os
4
+ import sqlite3
5
+ from datetime import datetime
6
+ from typing import List, Optional
7
+
8
+ from asky.config import DB_PATH, CONTINUE_QUERY_THRESHOLD
9
+
10
+
11
def init_db() -> None:
    """Initialize the SQLite database and create tables if they don't exist.

    Creates the parent directory for DB_PATH when missing. The connection
    is closed even if the DDL fails (the original leaked it on error).
    """
    os.makedirs(DB_PATH.parent, exist_ok=True)
    conn = sqlite3.connect(DB_PATH)
    try:
        c = conn.cursor()
        c.execute("""
            CREATE TABLE IF NOT EXISTS history (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp TEXT,
                query TEXT,
                query_summary TEXT,
                answer TEXT,
                answer_summary TEXT,
                model TEXT
            )
        """)
        conn.commit()
    finally:
        conn.close()
29
+
30
+
31
def get_history(limit: int = 10) -> List[tuple]:
    """Retrieve the most recent *limit* history entries, newest first.

    Returns:
        Rows of (id, timestamp, query, query_summary, answer_summary, model).
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        c = conn.cursor()
        c.execute(
            "SELECT id, timestamp, query, query_summary, answer_summary, model FROM history ORDER BY id DESC LIMIT ?",
            (limit,),
        )
        return c.fetchall()
    finally:
        # Close even when the query raises, so connections never leak.
        conn.close()
42
+
43
+
44
def get_db_record_count() -> int:
    """Return the total number of entries in the history table."""
    conn = sqlite3.connect(DB_PATH)
    try:
        c = conn.cursor()
        c.execute("SELECT COUNT(*) FROM history")
        return c.fetchone()[0]
    finally:
        # Close even when the query raises, so connections never leak.
        conn.close()
52
+
53
+
54
def get_interaction_context(ids: List[int], full: bool = False) -> str:
    """Build a context string from previous interactions by their IDs.

    Args:
        ids: History row IDs to include; an empty list yields "".
        full: When True use the full stored answer, otherwise the summary.

    Returns:
        "Query <id>: ...\\nAnswer <id>: ..." sections joined by blank lines.
    """
    if not ids:
        return ""
    conn = sqlite3.connect(DB_PATH)
    try:
        c = conn.cursor()
        placeholders = ",".join(["?"] * len(ids))
        query_str = f"SELECT id, query, query_summary, answer, answer_summary FROM history WHERE id IN ({placeholders})"
        c.execute(query_str, ids)
        results = c.fetchall()
    finally:
        # Close even when the query raises, so connections never leak.
        conn.close()

    context_parts = []
    for rid, query, q_sum, answer, a_sum in results:
        # Prefer the summary only when one exists and the original query was
        # long enough to have been summarized.
        if q_sum and len(query) >= CONTINUE_QUERY_THRESHOLD:
            q_text = q_sum
        else:
            q_text = query

        a_text = answer if full else a_sum
        context_parts.append(f"Query {rid}: {q_text}\nAnswer {rid}: {a_text}")

    return "\n\n".join(context_parts)
79
+
80
+
81
def cleanup_db(target: Optional[str], delete_all: bool = False) -> None:
    """Delete history records by ID, range, list, or all.

    Errors during deletion are reported on stdout rather than raised;
    the connection is always closed.
    """
    conn = sqlite3.connect(DB_PATH)
    cur = conn.cursor()

    try:
        if delete_all:
            # Clear every row, then reset the AUTOINCREMENT sequence so
            # new records start from id 1 again.
            cur.execute("DELETE FROM history")
            cur.execute("DELETE FROM sqlite_sequence WHERE name='history'")
            print("dataset cleaned completely.")
        elif target:
            _delete_by_target(cur, target)
        else:
            print(
                "Error: No target specified for cleanup. Use --all or provide IDs/range."
            )

        conn.commit()
    except Exception as e:
        print(f"Error during cleanup: {e}")
    finally:
        conn.close()
103
+
104
+
105
+ def _delete_by_target(cursor: sqlite3.Cursor, target: str) -> None:
106
+ """Helper to delete records by target specification."""
107
+ # Check for range (e.g., "1-5")
108
+ if "-" in target:
109
+ try:
110
+ start, end = map(int, target.split("-"))
111
+ if start > end:
112
+ start, end = end, start
113
+ cursor.execute(
114
+ "DELETE FROM history WHERE id >= ? AND id <= ?", (start, end)
115
+ )
116
+ print(f"deleted records from {start} to {end}.")
117
+ except ValueError:
118
+ print("Error: Invalid range format. Use 'start-end' (e.g., 1-5).")
119
+ # Check for comma-separated list (e.g., "1,3,5")
120
+ elif "," in target:
121
+ try:
122
+ ids = [int(x.strip()) for x in target.split(",")]
123
+ placeholders = ",".join(["?"] * len(ids))
124
+ cursor.execute(f"DELETE FROM history WHERE id IN ({placeholders})", ids)
125
+ print(f"deleted records: {', '.join(map(str, ids))}.")
126
+ except ValueError:
127
+ print("Error: Invalid list format. Use comma-separated integers.")
128
+ # Single ID
129
+ else:
130
+ try:
131
+ rid = int(target)
132
+ cursor.execute("DELETE FROM history WHERE id = ?", (rid,))
133
+ print(f"deleted record {rid}.")
134
+ except ValueError:
135
+ print("Error: Invalid ID format. Must be an integer.")
136
+
137
+
138
def save_interaction(
    query: str,
    answer: str,
    model: str,
    query_summary: str = "",
    answer_summary: str = "",
) -> None:
    """Save an interaction to the database.

    Args:
        query: The user's query text.
        answer: The final answer text.
        model: Model alias/identifier used for the interaction.
        query_summary: Optional pre-computed query summary.
        answer_summary: Optional pre-computed answer summary.

    The connection is closed even if the INSERT fails (the original
    leaked it on error).
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        c = conn.cursor()
        timestamp = datetime.now().isoformat()
        c.execute(
            """
            INSERT INTO history (timestamp, query, query_summary, answer, answer_summary, model)
            VALUES (?, ?, ?, ?, ?, ?)
            """,
            (timestamp, query, query_summary, answer, answer_summary, model),
        )
        conn.commit()
    finally:
        conn.close()