superlinear 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- apps/__init__.py +4 -0
- apps/cli/__init__.py +8 -0
- apps/cli/bm25_rag.py +471 -0
- apps/cli/chat_repl.py +1497 -0
- apps/cli/client.py +195 -0
- apps/cli/docs_repl.py +2275 -0
- apps/cli/light_rag.py +729 -0
- apps/cli/local_snapshots.py +139 -0
- apps/cli/locks.py +214 -0
- apps/cli/main.py +457 -0
- apps/cli/output.py +32 -0
- apps/cli/server_cmds.py +516 -0
- apps/cli/session_cmds.py +491 -0
- apps/cli/snapshot_cmds.py +303 -0
- apps/cli/state.py +265 -0
- apps/server/__init__.py +4 -0
- apps/server/app.py +1363 -0
- apps/server/main.py +313 -0
- superlinear/__init__.py +114 -0
- superlinear/_version.py +3 -0
- superlinear/engine/__init__.py +10 -0
- superlinear/engine/adapters/__init__.py +12 -0
- superlinear/engine/adapters/base.py +91 -0
- superlinear/engine/adapters/superlinear.py +1233 -0
- superlinear/engine/chat_engine.py +1173 -0
- superlinear/engine/chat_types.py +130 -0
- superlinear/engine/registry.py +51 -0
- superlinear/engine/repetition.py +203 -0
- superlinear/engine/session_snapshots.py +451 -0
- superlinear/engine/tool_parser.py +83 -0
- superlinear/engine/types.py +42 -0
- superlinear/kernels/__init__.py +2 -0
- superlinear/kernels/common/__init__.py +21 -0
- superlinear/kernels/common/adjustment.py +106 -0
- superlinear/kernels/common/power.py +154 -0
- superlinear/kernels/superlinear/__init__.py +10 -0
- superlinear/kernels/superlinear/attention/__init__.py +78 -0
- superlinear/kernels/superlinear/attention/_prefill.py +940 -0
- superlinear/kernels/superlinear/attention/_sliding_window.py +1167 -0
- superlinear/kernels/superlinear/attention/api.py +433 -0
- superlinear/kernels/superlinear/search/__init__.py +33 -0
- superlinear/kernels/superlinear/search/_reference.py +204 -0
- superlinear/kernels/superlinear/search/_triton.py +488 -0
- superlinear/kernels/superlinear/search/_triton_gqa.py +534 -0
- superlinear/kernels/superlinear/search/api.py +200 -0
- superlinear/kernels/superlinear/span/__init__.py +41 -0
- superlinear/kernels/superlinear/span/_triton_bucketed_gqa.py +1461 -0
- superlinear/kernels/superlinear/span/_triton_forward.py +22 -0
- superlinear/kernels/superlinear/span/_triton_gqa.py +1226 -0
- superlinear/kernels/superlinear/span/_triton_impl.py +928 -0
- superlinear/kernels/superlinear/span/_triton_precomputed_sw.py +460 -0
- superlinear/kernels/superlinear/span/_triton_precomputed_sw_gqa.py +598 -0
- superlinear/kernels/superlinear/span/api.py +296 -0
- superlinear/kernels/superlinear/span/masks.py +187 -0
- superlinear/py.typed +0 -0
- superlinear/runtime.py +71 -0
- superlinear-0.1.0.dist-info/METADATA +469 -0
- superlinear-0.1.0.dist-info/RECORD +62 -0
- superlinear-0.1.0.dist-info/WHEEL +5 -0
- superlinear-0.1.0.dist-info/entry_points.txt +2 -0
- superlinear-0.1.0.dist-info/licenses/LICENSE +202 -0
- superlinear-0.1.0.dist-info/top_level.txt +2 -0
apps/cli/docs_repl.py
ADDED
|
@@ -0,0 +1,2275 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import atexit
|
|
4
|
+
import glob
|
|
5
|
+
import hashlib
|
|
6
|
+
import json
|
|
7
|
+
import os
|
|
8
|
+
import shutil
|
|
9
|
+
import shlex
|
|
10
|
+
import sys
|
|
11
|
+
import time
|
|
12
|
+
import textwrap
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
import re
|
|
15
|
+
from typing import Any
|
|
16
|
+
|
|
17
|
+
# Enable readline for arrow keys, history navigation, and line editing.
|
|
18
|
+
try:
|
|
19
|
+
import readline
|
|
20
|
+
except ImportError:
|
|
21
|
+
readline = None # type: ignore[assignment] # Windows fallback
|
|
22
|
+
|
|
23
|
+
from apps.cli.bm25_rag import Bm25RagConfig, Bm25RagRetriever
|
|
24
|
+
from apps.cli.client import HttpError, SuperlinearClient
|
|
25
|
+
from apps.cli.light_rag import LightRagConfig, LightRagRetriever, tokenize_query_terms
|
|
26
|
+
from apps.cli.local_snapshots import delete_local_snapshot, list_local_snapshots
|
|
27
|
+
from apps.cli.locks import AlreadyLockedError, SessionLock
|
|
28
|
+
from apps.cli.output import format_table
|
|
29
|
+
from apps.cli.state import DocsWorkspaceState, load_state, save_state
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _docs_history_file_path() -> Path:
|
|
33
|
+
return Path.home() / ".config" / "spl" / "docs_history"
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _setup_readline_history() -> None:
    """Set up persistent command history for the docs REPL.

    Loads any prior history from the on-disk history file, caps the in-memory
    history length, and registers an atexit hook so history is written back
    on interpreter shutdown. No-op when readline is unavailable (Windows).
    """
    if readline is None:
        return
    history_file = _docs_history_file_path()
    history_file.parent.mkdir(parents=True, exist_ok=True)
    try:
        readline.read_history_file(history_file)
    except OSError:
        # A missing, unreadable, or corrupt history file must not break
        # REPL startup (FileNotFoundError is a subclass of OSError).
        pass
    readline.set_history_length(1000)
    atexit.register(readline.write_history_file, history_file)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
# Commands for tab completion
# NOTE: _setup_completer() matches user input prefixes against this list, so
# it should stay in sync with the slash-commands handled by the REPL loop.
_DOCS_COMMANDS = [
    "/help", "/exit", "/clear", "/history", "/ls", "/rm", "/head", "/tail",
    "/show", "/add", "/sources", "/rag", "/reset", "/stats", "/save", "/load", "/info",
]
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _setup_completer() -> None:
    """Install readline tab-completion for the REPL's slash-commands."""
    if readline is None:
        return

    def completer(text: str, state: int) -> str | None:
        # Only slash-prefixed input completes; plain text gets no candidates.
        candidates: list[str] = []
        if text.startswith("/"):
            candidates = [cmd for cmd in _DOCS_COMMANDS if cmd.startswith(text)]
        if state >= len(candidates):
            return None
        return candidates[state]

    readline.set_completer(completer)
    readline.set_completer_delims(" \t\n")
    readline.parse_and_bind("tab: complete")
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def _cmd_history(n: int = 20) -> None:
    """Print the most recent *n* entries of the readline input history."""
    if readline is None:
        print("history not available (readline not loaded)", file=sys.stderr)
        return
    total = readline.get_current_history_length()
    if total == 0:
        print("(no history)")
        return
    # History indices are 1-based; show at most the last n entries.
    first = max(1, total - n + 1)
    for idx in range(first, total + 1):
        entry = readline.get_history_item(idx)
        if entry:
            print(f"{idx:4d} {entry}")
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def _cmd_history_clear() -> None:
    """Clear readline input history (both in-memory and on disk)."""
    if readline is None:
        print("history not available (readline not loaded)", file=sys.stderr)
        return
    try:
        readline.clear_history()
    except Exception as exc:
        print(f"failed to clear history: {exc}", file=sys.stderr)
        return

    # Persist the now-empty history so the on-disk file is truncated too;
    # failure here is best-effort and non-fatal.
    path = _docs_history_file_path()
    try:
        path.parent.mkdir(parents=True, exist_ok=True)
        readline.write_history_file(path)
    except Exception:
        pass
    print("cleared input history")
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class DocsReplError(RuntimeError):
    """Raised for user-facing failures inside the docs REPL."""
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
# Matches the server's "Prompt too long: N tokens (max=M)." error message:
# group(1) is the prompt's token count, group(2) the server-side cap.
_PROMPT_TOO_LONG_RE = re.compile(r"Prompt too long:\s*(\d+)\s*tokens\s*\(max=(\d+)\)\.")
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def _maybe_print_prompt_too_long_hint(*, msg: str, requested_max_seq_len: int | None) -> None:
    """Print a stderr hint when a 'Prompt too long' server error looks caused
    by requesting a --max-seq-len above the server's --max-prompt-tokens cap.

    Silently returns when the message does not match, when no max-seq-len was
    requested, or when the request already fits under the server cap.
    """
    match = _PROMPT_TOO_LONG_RE.search(msg or "")
    if match is None:
        return
    try:
        server_cap = int(match.group(2))
    except Exception:
        return
    if requested_max_seq_len is None:
        return
    try:
        requested = int(requested_max_seq_len)
    except Exception:
        return
    if requested <= server_cap:
        return

    hints = (
        "hint: the server is rejecting prompts longer than its configured --max-prompt-tokens. "
        f"You requested --max-seq-len={requested}, but the server cap is max_prompt_tokens={server_cap}.",
        f"hint: restart the server with e.g. `spl server start --model <model> --max-prompt-tokens {requested}` (or higher).",
    )
    for hint in hints:
        print(hint, file=sys.stderr)
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
# System prompt installed when a docs session is created. It defines the
# always-on grounding rules, the ingestion protocol (documents arrive wrapped
# in [SOURCE path=...] blocks and the model must reply exactly "OK"), and the
# step-by-step Q&A method (parse question -> source-specific reasoning ->
# mandatory verification -> answer with Sources line).
DOCS_INGEST_PROMPT = (
    "You are Superlinear Docs, a stateful long-context assistant for document-grounded Q&A.\n"
    "You will receive documents and later questions.\n"
    "\n"
    "## Global Rules (always)\n"
    "- Use ONLY the ingested documents provided in this session. Do not use external knowledge.\n"
    "- You MAY make logical inferences from document content. If the documents state 'A created the X series' and 'Y is part of the X series', you can conclude 'A created Y'.\n"
    "- Never invent citations or file paths.\n"
    "- CRITICAL: Before quoting anything, VERIFY which document it comes from by checking the [SOURCE path=...] tag.\n"
    "- Always end answers with: Sources: <comma-separated paths>\n"
    "\n"
    "## Ingestion Mode\n"
    "- I will send one or more documents wrapped in [SOURCE path=...] ... [/SOURCE].\n"
    "- Treat each block as a separate source and remember its path.\n"
    "- Do not answer questions, do not summarize, and do not add commentary while ingesting.\n"
    "- Reply with exactly: OK\n"
    "\n"
    "## Q&A Mode\n"
    "When the user asks a question, follow this method in <think>:\n"
    "\n"
    "### Step 1: PARSE THE QUESTION\n"
    "- Does it ask about a SPECIFIC document? (e.g., 'the LSTM article', 'the Transformer document')\n"
    " → If yes, you MUST only use content from that exact document. Ignore all other documents.\n"
    "- Does it ask 'list ALL', 'every', 'how many'?\n"
    " → If yes, scan exhaustively. Do not stop at first few examples.\n"
    "- Does it ask about presence/absence? ('does X mention Y', 'which do NOT mention')\n"
    " → If yes, you MUST search for the exact term in the specific document before answering.\n"
    "\n"
    "### Step 1.5: SOURCE-SPECIFIC REASONING (when question names a specific document)\n"
    "If the question asks about a specific article/document (e.g., 'According to the RNN article...'):\n"
    "\n"
    "**STRICT RULE: You must ONLY use content from that ONE document. Do NOT mention, quote, or reference ANY other document in your answer.**\n"
    "\n"
    "Reason through these steps:\n"
    "1. IDENTIFY: 'The question asks specifically about the [X] article.'\n"
    "2. LOCATE: 'I have the [X] document at [SOURCE path=...]. I will search ONLY within it.'\n"
    "3. SEARCH: 'Looking for [topic/terms] in the [X] document...'\n"
    "4. EXTRACT: 'Found in the [X] document: [exact quote with context]'\n"
    "5. CONFIRM: 'This quote is definitely from [X], not from another document.'\n"
    "6. FINAL CHECK: 'My answer mentions ONLY the [X] article. I have NOT included information from other articles.'\n"
    "\n"
    "**CRITICAL**: Before using ANY quote, check its [SOURCE path=...]. If the path contains a different filename than [X], you MUST NOT use that quote. Find a quote from the correct document or say the information is not in that document.\n"
    "\n"
    "Example: If asked 'What does the Transformer article say about X?' and you find a great quote from attention_machine_learning.txt → REJECT IT. Either find a quote from transformer_deep_learning.txt or say 'The Transformer article does not contain this.'\n"
    "\n"
    "If the answer is not in the named document, say 'The [X] article does not contain this information.' Do NOT supplement with other documents.\n"
    "\n"
    "### Step 2: VERIFICATION (MANDATORY)\n"
    "Before writing your answer, perform these checks in <think>:\n"
    "\n"
    "A) SOURCE VERIFICATION: For each quote you plan to use:\n"
    " - State: 'This quote appears in [document name] at [SOURCE path=...]'\n"
    " - If you cannot confirm the source, do not use the quote.\n"
    " - If asked about 'the X article' and your quote is from a different article, DISCARD IT.\n"
    "\n"
    "B) TERM SEARCH: For presence/absence questions ('does X mention Y'):\n"
    " - State: 'Searching for \"Y\" in [document]...'\n"
    " - Search for the term, partial matches, and variations (singular/plural).\n"
    " - State: 'Found: [quote]' or 'Not found after searching [sections checked]'\n"
    " - Only conclude absence after thorough search.\n"
    "\n"
    "C) ENUMERATION CHECK: For 'list all' questions:\n"
    " - After your initial list, ask: 'Did I miss anything?'\n"
    " - Re-scan the relevant sections.\n"
    " - Add any missed items.\n"
    "\n"
    "### Step 3: ANSWER\n"
    "- Quote with correct source attribution.\n"
    "- For enumeration: numbered/bulleted list.\n"
    "- End with: Sources: <paths>\n"
)
|
|
216
|
+
|
|
217
|
+
# Condensed per-question preamble used in Q&A mode. It restates the grounding
# rules and the parse/search/verify method, and additionally warns the model
# to ignore the "Retrieved excerpts" message for source-specific questions
# (excerpts may contain tempting quotes from the wrong document).
DOCS_QA_PROMPT = (
    "Answer using ONLY the ingested documents. Do not use external knowledge.\n"
    "You MAY make logical inferences from document content (e.g., if 'A created the X series' and 'Y is in the X series', conclude 'A created Y').\n"
    "\n"
    "## Method (in <think>)\n"
    "\n"
    "### Step 1: PARSE THE QUESTION\n"
    "- Does it ask about a SPECIFIC document? → Only use that document.\n"
    "- Does it ask 'list ALL', 'every', 'how many'? → Scan exhaustively.\n"
    "- Does it ask about presence/absence? → Search for exact term.\n"
    "\n"
    "### Step 1.5: SOURCE-SPECIFIC REASONING (when question names a specific document)\n"
    "If the question asks about a specific article (e.g., 'According to the RNN article...'):\n"
    "\n"
    "**STRICT RULE: ONLY use content from that ONE document. Do NOT mention or quote ANY other document.**\n"
    "\n"
    "Reason through:\n"
    "1. IDENTIFY: 'This question asks about the [X] article specifically.'\n"
    "2. LOCATE: 'I will search ONLY within the [X] document.'\n"
    "3. SEARCH: 'Looking for [topic] in [X]...'\n"
    "4. EXTRACT: 'Found in [X]: [exact quote]'\n"
    "5. CONFIRM: 'This is from [X], not another document.'\n"
    "6. FINAL CHECK: 'My answer mentions ONLY [X]. No other articles are referenced.'\n"
    "\n"
    "If not found in the named document, say 'The [X] article does not contain this.' Do NOT supplement with other documents.\n"
    "\n"
    "### Step 2: USE EXCERPTS (with source filtering)\n"
    "You received 'Retrieved excerpts' with passages from multiple documents.\n"
    "**WARNING: For source-specific questions, IGNORE the excerpts entirely.**\n"
    "**Search your full memory of the named document instead.**\n"
    "The excerpts may contain tempting quotes from OTHER documents - do not use them.\n"
    "\n"
    "### Step 3: VERIFICATION (MANDATORY)\n"
    "\n"
    "A) SOURCE VERIFICATION: For each quote:\n"
    " - State: 'This quote appears in [document] at [SOURCE path=...]'\n"
    " - If asked about 'the X article' and quote is from different article, DISCARD IT.\n"
    "\n"
    "B) TERM SEARCH: For presence/absence questions:\n"
    " - State: 'Searching for \"Y\" in [document]...'\n"
    " - Search for term and variations.\n"
    " - State: 'Found: [quote]' or 'Not found after searching [sections]'\n"
    "\n"
    "C) ENUMERATION CHECK: For 'list all':\n"
    " - After initial list, ask: 'Did I miss anything?'\n"
    " - Re-scan and add missed items.\n"
    "\n"
    "## Answer\n"
    "- Quote with correct source attribution.\n"
    "- **For source-specific questions: cite ONLY the named document, even if you mention topics that appear in other documents.**\n"
    "- End with: Sources: <paths>\n"
)
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
# Canned assistant acknowledgement inserted into the transcript after the Q&A
# preamble, so the model's context shows it accepting the grounding rules
# before the actual question arrives.
DOCS_QA_PRIMER_ASSISTANT = (
    "Understood. I will use any provided excerpts as hints, but verify against my full memory of the documents. "
    "I will search for the pivotal scenes and dialogue, trace the causation chain, "
    "and give a substantive answer with accurately quoted evidence."
)
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
# Follow-up detection for retrieval-query augmentation (see
# _should_augment_rag_query_with_prev_question):
# - anaphora ("this/that/it/above/...") suggests the question leans on the
#   previous exchange;
# - intent patterns like "which article/source ..." ask about a prior answer.
_FOLLOWUP_ANAPHORA_RE = re.compile(r"\b(this|that|it|above|previous|earlier|last)\b", re.IGNORECASE)
_FOLLOWUP_INTENT_RE = re.compile(
    r"\b(which\s+(article|source|file|doc|document)|where\s+.*\b(from|in)\b|what\s+(article|source|file|doc|document))\b",
    re.IGNORECASE,
)
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
def _should_augment_rag_query_with_prev_question(*, question: str, prev_question: str | None) -> bool:
    """Decide whether the previous user question should be folded into the
    retrieval query to re-anchor lexical retrieval on follow-ups."""
    if not isinstance(prev_question, str) or not prev_question.strip():
        return False
    q = (question or "").strip()
    if not q:
        return False

    # Follow-ups that use anaphora ("this/that/it") or ask "which article/source"
    # benefit from including the previous question in the retrieval query.
    if _FOLLOWUP_ANAPHORA_RE.search(q) or _FOLLOWUP_INTENT_RE.search(q):
        # Avoid hijacking very short single-entity queries like "RoPE?".
        if len(tokenize_query_terms(q, max_terms=32)) <= 1 and len(q) < 24:
            return False
        return True

    # A question with no content terms at all is likely a follow-up too.
    return not tokenize_query_terms(q, max_terms=32)
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
def _now_utc_compact() -> str:
|
|
307
|
+
return time.strftime("%Y%m%d_%H%M%S", time.gmtime())
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
def _sanitize_for_id(name: str) -> str:
|
|
311
|
+
out = []
|
|
312
|
+
for ch in name.strip():
|
|
313
|
+
if ch.isalnum() or ch in {"-", "_"}:
|
|
314
|
+
out.append(ch)
|
|
315
|
+
else:
|
|
316
|
+
out.append("_")
|
|
317
|
+
s = "".join(out).strip("_")
|
|
318
|
+
return s[:32] if s else "docs"
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
def _new_session_id(*, workspace_name: str) -> str:
    """Generate a unique docs session id embedding the sanitized workspace name,
    a compact UTC timestamp, and a short random suffix."""
    import secrets

    suffix = secrets.token_hex(3)
    return f"docs_{_sanitize_for_id(workspace_name)}_{_now_utc_compact()}_{suffix}"
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
def _ensure_reachable(client: SuperlinearClient) -> None:
    """Raise DocsReplError when the server health endpoint is unreachable."""
    try:
        client.health()
    except HttpError as exc:
        message = (
            f"Server unreachable at {client.base_url}. Start it with `spl server start --model <model>` "
            f"or pass `--url`.\n{exc}"
        )
        raise DocsReplError(message) from exc
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
def _session_exists(client: SuperlinearClient, session_id: str) -> bool:
    """Return True iff the server knows *session_id* (a 404 means missing;
    any other HTTP error propagates)."""
    try:
        client.request_json("GET", f"/v1/sessions/{session_id}", timeout_s=5.0)
    except HttpError as exc:
        if exc.status_code == 404:
            return False
        raise
    return True
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
def _create_session(client: SuperlinearClient, session_id: str) -> None:
    """Create a session on the server; a 409 (already exists) is ignored.

    Thin wrapper over _create_session_with_max_seq_len with no explicit
    context-length override, so both creation paths share one implementation
    instead of duplicating the POST + 409 handling.
    """
    _create_session_with_max_seq_len(client, session_id, max_seq_len=None)
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
def _create_session_with_max_seq_len(
    client: SuperlinearClient, session_id: str, *, max_seq_len: int | None
) -> None:
    """Create a session, optionally pinning its max_seq_len.

    A 409 (session already exists) is silently ignored; other HTTP errors
    propagate to the caller.
    """
    body: dict[str, Any] = {"session_id": session_id}
    if max_seq_len is not None:
        body["max_seq_len"] = int(max_seq_len)
    try:
        client.request_json("POST", "/v1/sessions", payload=body, timeout_s=30.0)
    except HttpError as exc:
        if exc.status_code != 409:
            raise
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
def _maybe_resize_session(
    client: SuperlinearClient,
    session_id: str,
    *,
    min_max_seq_len: int | None,
    strategy: str = "auto",
) -> None:
    """Grow the session's context length to at least *min_max_seq_len*.

    No-op when no minimum is requested, when the session info cannot be
    fetched, or when the current max_seq_len already satisfies the minimum.
    Raises DocsReplError with an actionable message when the server-side
    resize fails (commonly: target too large for GPU memory).
    """
    if min_max_seq_len is None:
        return
    try:
        info = client.request_json("GET", f"/v1/sessions/{session_id}", timeout_s=10.0)
    except HttpError:
        return
    if not isinstance(info, dict):
        return
    try:
        current = int(info.get("max_seq_len") or 0)
    except Exception:
        current = 0
    target = int(min_max_seq_len)
    if target <= 0:
        return
    if current > 0 and target <= current:
        return

    # Resize the *existing* session to at least the requested length.
    try:
        client.request_json(
            "POST",
            f"/v1/sessions/{session_id}/resize",
            payload={"max_seq_len": target, "strategy": strategy},
            timeout_s=300.0,
        )
    except HttpError as exc:
        # Provide a more actionable hint for common failure modes.
        raise DocsReplError(
            "Failed to resize session context length. "
            "This can happen if the target is too large for GPU memory. "
            f"(session_id={session_id} target_max_seq_len={target}): {exc}"
        ) from exc
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
def _get_session_pos(client: SuperlinearClient, session_id: str) -> int | None:
    """Return the session's current_pos, or None when it cannot be determined."""
    try:
        info = client.request_json("GET", f"/v1/sessions/{session_id}", timeout_s=10.0)
    except HttpError:
        return None
    if not isinstance(info, dict):
        return None
    pos = info.get("current_pos")
    return int(pos) if isinstance(pos, int) else None
|
|
419
|
+
|
|
420
|
+
|
|
421
|
+
def _encode_sources_description(sources: list[dict[str, Any]]) -> str:
|
|
422
|
+
return json.dumps({"spl_docs_sources_v1": sources}, ensure_ascii=False, sort_keys=True)
|
|
423
|
+
|
|
424
|
+
|
|
425
|
+
def _decode_sources_description(desc: str | None) -> list[dict[str, Any]] | None:
|
|
426
|
+
if not desc or not isinstance(desc, str):
|
|
427
|
+
return None
|
|
428
|
+
try:
|
|
429
|
+
obj = json.loads(desc)
|
|
430
|
+
except Exception:
|
|
431
|
+
return None
|
|
432
|
+
if not isinstance(obj, dict):
|
|
433
|
+
return None
|
|
434
|
+
raw = obj.get("spl_docs_sources_v1")
|
|
435
|
+
if not isinstance(raw, list):
|
|
436
|
+
return None
|
|
437
|
+
out: list[dict[str, Any]] = []
|
|
438
|
+
for s in raw:
|
|
439
|
+
if not isinstance(s, dict):
|
|
440
|
+
continue
|
|
441
|
+
path = s.get("path")
|
|
442
|
+
if not isinstance(path, str) or not path:
|
|
443
|
+
continue
|
|
444
|
+
item: dict[str, Any] = {"path": path}
|
|
445
|
+
title = s.get("title")
|
|
446
|
+
if isinstance(title, str) and title.strip():
|
|
447
|
+
item["title"] = title.strip()
|
|
448
|
+
source = s.get("source")
|
|
449
|
+
if isinstance(source, str) and source.strip():
|
|
450
|
+
item["source"] = source.strip()
|
|
451
|
+
url = s.get("url")
|
|
452
|
+
if isinstance(url, str) and url.strip():
|
|
453
|
+
item["url"] = url.strip()
|
|
454
|
+
b = s.get("bytes")
|
|
455
|
+
if isinstance(b, int) and b >= 0:
|
|
456
|
+
item["bytes"] = b
|
|
457
|
+
sha = s.get("sha256")
|
|
458
|
+
if isinstance(sha, str) and sha:
|
|
459
|
+
item["sha256"] = sha
|
|
460
|
+
added = s.get("added_at_unix_s")
|
|
461
|
+
if isinstance(added, int) and added > 0:
|
|
462
|
+
item["added_at_unix_s"] = added
|
|
463
|
+
out.append(item)
|
|
464
|
+
return out
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
def _extract_doc_metadata(*, path: Path, text: str) -> dict[str, str]:
|
|
468
|
+
"""Best-effort extraction of title/source/url from a document.
|
|
469
|
+
|
|
470
|
+
Supports the wiki test corpus header format:
|
|
471
|
+
Title: ...\nSource: ...\nURL: ...
|
|
472
|
+
and common Markdown titles.
|
|
473
|
+
"""
|
|
474
|
+
|
|
475
|
+
title: str | None = None
|
|
476
|
+
source: str | None = None
|
|
477
|
+
url: str | None = None
|
|
478
|
+
|
|
479
|
+
lines = text.replace("\r", "").split("\n")
|
|
480
|
+
head = lines[:80]
|
|
481
|
+
|
|
482
|
+
for ln in head:
|
|
483
|
+
s = ln.strip()
|
|
484
|
+
if not s:
|
|
485
|
+
continue
|
|
486
|
+
if s.lower().startswith("title:") and title is None:
|
|
487
|
+
title = s.split(":", 1)[1].strip()
|
|
488
|
+
continue
|
|
489
|
+
if s.lower().startswith("source:") and source is None:
|
|
490
|
+
source = s.split(":", 1)[1].strip()
|
|
491
|
+
continue
|
|
492
|
+
if s.lower().startswith("url:") and url is None:
|
|
493
|
+
url = s.split(":", 1)[1].strip()
|
|
494
|
+
continue
|
|
495
|
+
if s.startswith("#") and title is None:
|
|
496
|
+
# Markdown heading.
|
|
497
|
+
title = s.lstrip("#").strip()
|
|
498
|
+
continue
|
|
499
|
+
|
|
500
|
+
if not title:
|
|
501
|
+
title = path.stem
|
|
502
|
+
|
|
503
|
+
out: dict[str, str] = {"title": title}
|
|
504
|
+
if source:
|
|
505
|
+
out["source"] = source
|
|
506
|
+
if url:
|
|
507
|
+
out["url"] = url
|
|
508
|
+
return out
|
|
509
|
+
|
|
510
|
+
|
|
511
|
+
def _build_qa_bootstrap_message(*, sources: list[dict[str, Any]]) -> str:
|
|
512
|
+
"""A strong, near-generation instruction + index injected as a USER message.
|
|
513
|
+
|
|
514
|
+
This avoids relying on a late system prompt in session mode: the server intentionally
|
|
515
|
+
drops additional system messages once the transcript already has a leading system.
|
|
516
|
+
"""
|
|
517
|
+
|
|
518
|
+
# Keep the index compact and resilient.
|
|
519
|
+
max_items = 200
|
|
520
|
+
items = sources[:max_items]
|
|
521
|
+
|
|
522
|
+
lines: list[str] = []
|
|
523
|
+
lines.append("You are now in docs Q&A mode for this session.")
|
|
524
|
+
lines.append("You MUST answer using ONLY the ingested documents in this session. Do not use external knowledge.")
|
|
525
|
+
lines.append("If the documents do not contain the answer, say you don't know. Do not guess.")
|
|
526
|
+
lines.append("Never invent citations or file paths.")
|
|
527
|
+
lines.append("")
|
|
528
|
+
lines.append("Method (follow this every time):")
|
|
529
|
+
lines.append("1) In <think>, pick the 3–8 most likely sources from the index by title/source/path.")
|
|
530
|
+
lines.append("2) In <think>, thoroughly search those sources for relevant passages and extract concrete facts.")
|
|
531
|
+
lines.append("3) In <think>, reconcile conflicts and consolidate into a coherent answer.")
|
|
532
|
+
lines.append("4) In the final answer, include short quotes (1–3 lines) for key claims when possible.")
|
|
533
|
+
lines.append("5) Always end with a final line exactly: Sources: <comma-separated paths>.")
|
|
534
|
+
lines.append("")
|
|
535
|
+
lines.append("Note: You may receive a 'Retrieved excerpts' message immediately before a question.")
|
|
536
|
+
lines.append("- Use those excerpts as PRIMARY evidence when relevant.")
|
|
537
|
+
lines.append("- They are not exhaustive; consult other sources from the index if needed.")
|
|
538
|
+
lines.append("")
|
|
539
|
+
lines.append("Quality bar:")
|
|
540
|
+
lines.append("- If you cannot find direct support in the docs, respond 'I don't know'.")
|
|
541
|
+
lines.append(" Still include Sources listing the most relevant documents you checked (1–8 paths).")
|
|
542
|
+
lines.append("- If partially supported, clearly separate supported vs missing details.")
|
|
543
|
+
lines.append("- Prefer fewer, higher-confidence claims over broad speculation.")
|
|
544
|
+
lines.append("- IMPORTANT: Do not rush. Use an extended <think> phase to do the source-selection and extraction steps.")
|
|
545
|
+
lines.append("")
|
|
546
|
+
lines.append("Available documents (index):")
|
|
547
|
+
|
|
548
|
+
for i, s in enumerate(items, start=1):
|
|
549
|
+
path = str(s.get("path") or "")
|
|
550
|
+
title = str(s.get("title") or "").strip() or Path(path).name
|
|
551
|
+
title = title.replace("\n", " ").strip()
|
|
552
|
+
if len(title) > 120:
|
|
553
|
+
title = title[:117] + "…"
|
|
554
|
+
src = s.get("source")
|
|
555
|
+
url = s.get("url")
|
|
556
|
+
extra: list[str] = []
|
|
557
|
+
if isinstance(src, str) and src.strip():
|
|
558
|
+
extra.append(f"source={src.strip()}")
|
|
559
|
+
if isinstance(url, str) and url.strip():
|
|
560
|
+
extra.append(f"url={url.strip()}")
|
|
561
|
+
extra_s = (" | " + " | ".join(extra)) if extra else ""
|
|
562
|
+
lines.append(f"- {i}. {title} | path={path}{extra_s}")
|
|
563
|
+
|
|
564
|
+
if len(sources) > max_items:
|
|
565
|
+
lines.append(f"(and {len(sources) - max_items} more not shown)")
|
|
566
|
+
|
|
567
|
+
return "\n".join(lines).strip() + "\n"
|
|
568
|
+
|
|
569
|
+
|
|
570
|
+
def _hydrate_sources_from_snapshot(client: SuperlinearClient, snapshot_id: str) -> list[dict[str, Any]] | None:
    """Recover the docs sources list stored in a snapshot's metadata
    description, or None if the snapshot/manifest is unavailable or malformed."""
    try:
        manifest = client.request_json("GET", f"/v1/snapshots/{snapshot_id}", timeout_s=30.0)
    except HttpError:
        return None
    if not isinstance(manifest, dict):
        return None
    meta = manifest.get("metadata")
    if not isinstance(meta, dict):
        return None
    return _decode_sources_description(meta.get("description"))
|
|
581
|
+
|
|
582
|
+
|
|
583
|
+
def _banner(
|
|
584
|
+
*,
|
|
585
|
+
url: str,
|
|
586
|
+
name: str,
|
|
587
|
+
session_id: str,
|
|
588
|
+
resumed: bool,
|
|
589
|
+
phase: str,
|
|
590
|
+
source_count: int,
|
|
591
|
+
base_snapshot_id: str | None,
|
|
592
|
+
rag_status: str,
|
|
593
|
+
) -> None:
|
|
594
|
+
mode = "resumed" if resumed else "new"
|
|
595
|
+
print(f"server={url}")
|
|
596
|
+
print(f"workspace={name} ({mode})")
|
|
597
|
+
print(f"session_id={session_id}")
|
|
598
|
+
print(f"phase={phase} sources={source_count}")
|
|
599
|
+
print(rag_status)
|
|
600
|
+
if base_snapshot_id:
|
|
601
|
+
print(f"base_snapshot_id={base_snapshot_id}")
|
|
602
|
+
print("type /help for commands")
|
|
603
|
+
|
|
604
|
+
|
|
605
|
+
def _rag_backend_from_ws(ws: DocsWorkspaceState) -> str:
|
|
606
|
+
backend = getattr(ws, "rag_backend", None)
|
|
607
|
+
if isinstance(backend, str):
|
|
608
|
+
b = backend.strip().lower()
|
|
609
|
+
if b in {"light", "bm25", "off"}:
|
|
610
|
+
return b
|
|
611
|
+
return "light" if bool(ws.light_rag_enabled) else "off"
|
|
612
|
+
|
|
613
|
+
|
|
614
|
+
def _apply_rag_backend_to_ws(ws: DocsWorkspaceState, backend: str) -> None:
|
|
615
|
+
b = (backend or "").strip().lower()
|
|
616
|
+
if b not in {"light", "bm25", "off"}:
|
|
617
|
+
b = "light"
|
|
618
|
+
ws.rag_backend = b
|
|
619
|
+
ws.light_rag_enabled = b != "off"
|
|
620
|
+
|
|
621
|
+
|
|
622
|
+
def _light_rag_config_from_ws(ws: DocsWorkspaceState) -> LightRagConfig:
    """Build a sanitized LightRagConfig from the workspace's stored fields."""
    cfg = LightRagConfig(
        enabled=_rag_backend_from_ws(ws) != "off",
        k=int(ws.light_rag_k),
        total_chars=int(ws.light_rag_total_chars),
        per_source_chars=int(ws.light_rag_per_source_chars),
        debug=bool(ws.light_rag_debug),
    )
    return cfg.sanitized()
|
|
631
|
+
|
|
632
|
+
|
|
633
|
+
def _apply_light_rag_config_to_ws(ws: DocsWorkspaceState, cfg: LightRagConfig) -> None:
|
|
634
|
+
ws.light_rag_enabled = bool(cfg.enabled)
|
|
635
|
+
ws.light_rag_k = int(cfg.k)
|
|
636
|
+
ws.light_rag_total_chars = int(cfg.total_chars)
|
|
637
|
+
ws.light_rag_per_source_chars = int(cfg.per_source_chars)
|
|
638
|
+
ws.light_rag_debug = bool(cfg.debug)
|
|
639
|
+
|
|
640
|
+
|
|
641
|
+
def _format_light_rag_status(cfg: LightRagConfig) -> str:
|
|
642
|
+
cfg = cfg.sanitized()
|
|
643
|
+
return (
|
|
644
|
+
f"lightRAG={'on' if cfg.enabled else 'off'} "
|
|
645
|
+
f"k={cfg.k} chars={cfg.total_chars} per_source_chars={cfg.per_source_chars} "
|
|
646
|
+
f"debug={'on' if cfg.debug else 'off'}"
|
|
647
|
+
)
|
|
648
|
+
|
|
649
|
+
|
|
650
|
+
def _bm25_rag_config_from_ws(ws: DocsWorkspaceState) -> Bm25RagConfig:
    """Derive a sanitized BM25 RAG config, reusing the common light-RAG fields.

    Workspace-specific BM25 knobs (bm25_k_sources, bm25_k_paragraphs) fall back
    to defaults when missing or non-numeric.
    """
    base = _light_rag_config_from_ws(ws).sanitized()

    def _coerce(getter, fallback: int) -> int:
        # Attribute access happens inside the try so missing/invalid workspace
        # fields also fall back, matching a best-effort read of stored state.
        try:
            return int(getter())
        except Exception:
            return fallback

    k_sources = _coerce(lambda: ws.bm25_k_sources, 0)
    if k_sources <= 0:
        # No explicit source count configured: mirror the light-RAG k.
        k_sources = int(base.k)
    k_paragraphs = _coerce(lambda: ws.bm25_k_paragraphs, 40)

    cfg = Bm25RagConfig(
        enabled=bool(base.enabled),
        k_sources=k_sources,
        total_chars=int(base.total_chars),
        per_source_chars=int(base.per_source_chars),
        debug=bool(base.debug),
        k_paragraphs=k_paragraphs,
        max_terms=int(base.max_terms),
        max_paragraphs_per_source=int(base.max_paragraphs_per_source),
        max_paragraph_chars=int(base.max_paragraph_chars),
    )
    return cfg.sanitized()
|
|
673
|
+
|
|
674
|
+
|
|
675
|
+
def _format_rag_status(*, ws: DocsWorkspaceState, bm25_available: bool) -> str:
    """Render the current RAG configuration as a one-line status string."""
    backend = _rag_backend_from_ws(ws)
    common = _light_rag_config_from_ws(ws).sanitized()
    debug_s = "on" if common.debug else "off"

    if backend == "off" or not common.enabled:
        return f"rag=off debug={debug_s}"

    if backend == "bm25":
        cfg = _bm25_rag_config_from_ws(ws)
        suffix = "" if bm25_available else " (unavailable -> fallback light)"
        bits = [
            "rag=on backend=bm25",
            f"k_sources={cfg.k_sources}",
            f"k_paragraphs={cfg.k_paragraphs}",
            f"chars={common.total_chars}",
            f"per_source_chars={common.per_source_chars}",
            f"debug={debug_s}",
        ]
        return " ".join(bits) + suffix

    bits = [
        "rag=on backend=light",
        f"k={common.k}",
        f"chars={common.total_chars}",
        f"per_source_chars={common.per_source_chars}",
        f"debug={debug_s}",
    ]
    return " ".join(bits)
|
|
692
|
+
|
|
693
|
+
|
|
694
|
+
def _format_sources_table(sources: list[dict[str, Any]]) -> str:
    """Render the loaded sources as an aligned table: #, path, bytes, sha256 prefix."""
    table_rows: list[list[str]] = []
    for idx, src in enumerate(sources, start=1):
        size = src.get("bytes")
        digest = src.get("sha256")
        table_rows.append(
            [
                str(idx),
                str(src.get("path") or ""),
                str(size) if size is not None else "",
                # Only the first 12 hex chars are shown to keep the table narrow.
                digest[:12] if isinstance(digest, str) else "",
            ]
        )
    return format_table(["#", "path", "bytes", "sha256"], table_rows)
|
|
709
|
+
|
|
710
|
+
|
|
711
|
+
class _TurnStats:
|
|
712
|
+
def __init__(self) -> None:
|
|
713
|
+
self.finish_reason: str | None = None
|
|
714
|
+
self.ttft_s: float | None = None
|
|
715
|
+
self.total_s: float | None = None
|
|
716
|
+
self.prompt_tokens: int | None = None
|
|
717
|
+
self.completion_tokens: int | None = None
|
|
718
|
+
self.tok_per_s: float | None = None
|
|
719
|
+
self.server_prefill_s: float | None = None
|
|
720
|
+
self.server_decode_s: float | None = None
|
|
721
|
+
self.server_total_s: float | None = None
|
|
722
|
+
|
|
723
|
+
|
|
724
|
+
def _stats_footer(stats: _TurnStats) -> str:
|
|
725
|
+
parts: list[str] = []
|
|
726
|
+
if stats.ttft_s is not None:
|
|
727
|
+
parts.append(f"ttft={stats.ttft_s:.3f}s")
|
|
728
|
+
if stats.tok_per_s is not None:
|
|
729
|
+
parts.append(f"tok/s={stats.tok_per_s:.2f}")
|
|
730
|
+
if stats.prompt_tokens is not None and stats.completion_tokens is not None:
|
|
731
|
+
parts.append(f"tokens={stats.prompt_tokens}+{stats.completion_tokens}")
|
|
732
|
+
if stats.finish_reason is not None:
|
|
733
|
+
parts.append(f"finish={stats.finish_reason}")
|
|
734
|
+
if stats.total_s is not None:
|
|
735
|
+
parts.append(f"wall={stats.total_s:.3f}s")
|
|
736
|
+
return " ".join(parts) if parts else ""
|
|
737
|
+
|
|
738
|
+
|
|
739
|
+
def _stats_detail(stats: _TurnStats) -> str:
|
|
740
|
+
lines: list[str] = []
|
|
741
|
+
if stats.finish_reason is not None:
|
|
742
|
+
lines.append(f"finish_reason={stats.finish_reason}")
|
|
743
|
+
if stats.ttft_s is not None:
|
|
744
|
+
lines.append(f"ttft_s={stats.ttft_s:.6f}")
|
|
745
|
+
if stats.total_s is not None:
|
|
746
|
+
lines.append(f"wall_s={stats.total_s:.6f}")
|
|
747
|
+
if stats.prompt_tokens is not None:
|
|
748
|
+
lines.append(f"prompt_tokens={stats.prompt_tokens}")
|
|
749
|
+
if stats.completion_tokens is not None:
|
|
750
|
+
lines.append(f"completion_tokens={stats.completion_tokens}")
|
|
751
|
+
if stats.tok_per_s is not None:
|
|
752
|
+
lines.append(f"tok_per_s={stats.tok_per_s:.6f}")
|
|
753
|
+
if stats.server_prefill_s is not None:
|
|
754
|
+
lines.append(f"server_prefill_s={stats.server_prefill_s:.6f}")
|
|
755
|
+
if stats.server_decode_s is not None:
|
|
756
|
+
lines.append(f"server_decode_s={stats.server_decode_s:.6f}")
|
|
757
|
+
if stats.server_total_s is not None:
|
|
758
|
+
lines.append(f"server_total_s={stats.server_total_s:.6f}")
|
|
759
|
+
return "\n".join(lines)
|
|
760
|
+
|
|
761
|
+
|
|
762
|
+
def _stream_request(
    *,
    client: SuperlinearClient,
    session_id: str,
    messages: list[dict[str, Any]],
    max_completion_tokens: int = 32768,
    think_budget: int | None = 8192,
    temperature: float = 0.3,
    top_p: float = 0.95,
    print_output: bool = True,
) -> _TurnStats:
    """Stream one chat completion over SSE and return the turn's metrics.

    While streaming, optionally renders a live multi-line "thinking" panel
    (the text between <think>...</think> markers) using ANSI cursor-movement
    escape sequences, then prints the assistant answer as it arrives.

    Parameters:
        client: HTTP client used for the SSE request.
        session_id: server-side session the turn is appended to.
        messages: chat messages payload sent to /v1/chat/completions.
        max_completion_tokens / temperature / top_p: sampling parameters.
        think_budget: reasoning-token budget; a positive value (with
            print_output) enables the thinking UI and the server-side
            reasoning options.
        print_output: when False, nothing is written to stdout and only
            metrics are collected.

    Raises:
        DocsReplError: when the stream delivers an "error" event.
        KeyboardInterrupt: re-raised after closing the SSE generator.

    Returns:
        _TurnStats with timing/usage fields filled from client-side clocks
        and from the server's "usage" / "x_superlinear_timing" events.
    """
    # Thinking UI requires both console output and a positive budget.
    enable_thinking_ui = print_output and think_budget is not None and think_budget > 0

    payload: dict[str, Any] = {
        "stream": True,
        "session_id": session_id,
        "messages": messages,
        "max_completion_tokens": int(max_completion_tokens),
        "temperature": float(temperature),
        "top_p": float(top_p),
    }

    if enable_thinking_ui:
        # Ask the server to stream reasoning separately and drop it from history.
        payload["reasoning_budget"] = int(think_budget)
        payload["discard_thinking"] = True
        payload["stream_thinking"] = True

    started = time.monotonic()
    ttft_s: float | None = None          # time to first token (thinking, content, or tool call)
    finish_reason: str | None = None
    usage: dict[str, Any] | None = None   # last "usage" event seen
    timing: dict[str, Any] | None = None  # last "x_superlinear_timing" event seen

    # Streaming/rendering state shared with the nested helpers via nonlocal.
    started_answer = False                # "assistant: " prefix already printed
    in_think = False                      # currently inside a <think>...</think> span
    thinking_accum: str = ""              # thinking text accumulated for the panel
    thinking_panel_active = False
    thinking_panel_lines = 0              # lines the panel currently occupies on screen
    content_buf = ""                      # buffer for the content-based <think> fallback parser
    saw_thinking_delta = False            # server sent delta.thinking at least once
    thinking_start_time: float | None = None
    thinking_end_time: float | None = None

    def _thinking_panel_format(text: str) -> list[str]:
        """Wrap *text* to the terminal width and keep only the last 10 lines,
        each prefixed with "thinking: "."""
        cols = shutil.get_terminal_size(fallback=(120, 24)).columns
        prefix = "thinking: "
        width = max(20, cols - len(prefix) - 1)

        normalized = text.replace("\r", "")
        wrapped: list[str] = []
        for logical in normalized.split("\n"):
            parts = textwrap.wrap(
                logical,
                width=width,
                replace_whitespace=False,
                drop_whitespace=False,
                break_long_words=True,
                break_on_hyphens=False,
            )
            if not parts:
                # textwrap.wrap returns [] for an empty line; keep the blank.
                wrapped.append("")
            else:
                wrapped.extend(parts)

        tail = wrapped[-10:]
        if not tail:
            tail = [""]
        return [prefix + ln for ln in tail]

    def _thinking_panel_move_to_top() -> None:
        """Move the cursor up to the panel's first line (ANSI CSI nA)."""
        nonlocal thinking_panel_lines
        if thinking_panel_lines > 1:
            sys.stdout.write(f"\x1b[{thinking_panel_lines - 1}A")

    def _thinking_panel_render(text: str) -> None:
        """Erase the previous panel in place and redraw it with *text*."""
        nonlocal thinking_panel_active, thinking_panel_lines
        if not print_output:
            return
        lines = _thinking_panel_format(text)

        if not thinking_panel_active:
            sys.stdout.write("\n")
            thinking_panel_active = True
            thinking_panel_lines = 1

        _thinking_panel_move_to_top()

        # Clear every previously drawn panel line ("\r\x1b[2K" erases a line).
        for i in range(thinking_panel_lines):
            sys.stdout.write("\r\x1b[2K")
            if i < thinking_panel_lines - 1:
                sys.stdout.write("\n")
        _thinking_panel_move_to_top()

        for i, ln in enumerate(lines):
            sys.stdout.write("\r\x1b[2K" + ln)
            if i < len(lines) - 1:
                sys.stdout.write("\n")

        thinking_panel_lines = len(lines)
        sys.stdout.flush()

    def _thinking_panel_clear() -> None:
        """Erase the panel and, if timing was captured, print a duration line."""
        nonlocal thinking_panel_active, thinking_panel_lines, thinking_start_time, thinking_end_time
        if not print_output:
            return
        if not thinking_panel_active:
            return

        _thinking_panel_move_to_top()
        for i in range(thinking_panel_lines):
            sys.stdout.write("\r\x1b[2K")
            if i < thinking_panel_lines - 1:
                sys.stdout.write("\n")
        _thinking_panel_move_to_top()

        thinking_panel_active = False
        thinking_panel_lines = 0

        if thinking_start_time is not None and thinking_end_time is not None:
            duration = thinking_end_time - thinking_start_time
            if duration >= 60:
                minutes = int(duration // 60)
                seconds = duration % 60
                sys.stdout.write(
                    f"[thinking complete] duration: {minutes} minute{'s' if minutes != 1 else ''} {seconds:.1f} seconds\n"
                )
            else:
                sys.stdout.write(f"[thinking complete] duration: {duration:.1f} seconds\n")

        sys.stdout.flush()

    def _answer_start_if_needed() -> None:
        """Print the "assistant: " prefix once, clearing any thinking panel first."""
        nonlocal started_answer
        if not print_output:
            return
        if not started_answer:
            _thinking_panel_clear()
            print("assistant: ", end="", flush=True)
            started_answer = True

    gen = client.request_sse("POST", "/v1/chat/completions", payload=payload, timeout_s=3600.0)
    try:
        for event in gen:
            # Server-side errors arrive as {"error": {...}} events.
            if isinstance(event, dict) and "error" in event:
                err = event.get("error")
                msg = err.get("message") if isinstance(err, dict) else str(err)
                raise DocsReplError(str(msg))

            if not isinstance(event, dict):
                continue

            choices = event.get("choices")
            if isinstance(choices, list) and choices:
                ch0 = choices[0]
                if isinstance(ch0, dict):
                    delta = ch0.get("delta") if isinstance(ch0.get("delta"), dict) else {}
                    if isinstance(delta, dict):
                        thinking = delta.get("thinking")
                        if enable_thinking_ui and isinstance(thinking, str) and thinking:
                            # Preferred path: the server streams reasoning in
                            # delta.thinking, still wrapped in <think> markers.
                            saw_thinking_delta = True
                            if ttft_s is None:
                                ttft_s = time.monotonic() - started

                            buf = thinking
                            while buf:
                                if not in_think:
                                    start_idx = buf.find("<think>")
                                    if start_idx == -1:
                                        break
                                    buf = buf[start_idx + len("<think>") :]
                                    in_think = True
                                    thinking_accum = ""
                                    if thinking_start_time is None:
                                        thinking_start_time = time.monotonic()
                                    _thinking_panel_render(thinking_accum)
                                    continue

                                end_idx = buf.find("</think>")
                                if end_idx == -1:
                                    # No close tag yet: everything is thinking text.
                                    thinking_accum += buf
                                    buf = ""
                                    _thinking_panel_render(thinking_accum)
                                    break

                                thinking_accum += buf[:end_idx]
                                buf = buf[end_idx + len("</think>") :]
                                if thinking_start_time is not None:
                                    thinking_end_time = time.monotonic()
                                _thinking_panel_clear()
                                in_think = False
                                break

                        content = delta.get("content")
                        if isinstance(content, str) and content:
                            if ttft_s is None:
                                ttft_s = time.monotonic() - started
                            if not enable_thinking_ui or saw_thinking_delta:
                                # Thinking handled elsewhere (or disabled):
                                # content is pure answer text.
                                _answer_start_if_needed()
                                sys.stdout.write(content)
                                sys.stdout.flush()
                            else:
                                # Fallback: parse <think> tags from content if the server isn't
                                # sending delta.thinking.
                                content_buf += content
                                while content_buf:
                                    if in_think:
                                        end_idx = content_buf.find("</think>")
                                        if end_idx == -1:
                                            thinking_accum += content_buf
                                            content_buf = ""
                                            _thinking_panel_render(thinking_accum)
                                            break

                                        thinking_accum += content_buf[:end_idx]
                                        content_buf = content_buf[end_idx + len("</think>") :]
                                        in_think = False
                                        if thinking_start_time is not None:
                                            thinking_end_time = time.monotonic()
                                        _thinking_panel_clear()
                                        continue

                                    start_idx = content_buf.find("<think>")
                                    if start_idx == -1:
                                        # No marker: flush everything as answer text.
                                        _answer_start_if_needed()
                                        sys.stdout.write(content_buf)
                                        sys.stdout.flush()
                                        content_buf = ""
                                        break

                                    if start_idx > 0:
                                        # Answer text preceding the <think> marker.
                                        _answer_start_if_needed()
                                        sys.stdout.write(content_buf[:start_idx])
                                        sys.stdout.flush()

                                    content_buf = content_buf[start_idx + len("<think>") :]
                                    in_think = True
                                    thinking_accum = ""
                                    if thinking_start_time is None:
                                        thinking_start_time = time.monotonic()
                                    _thinking_panel_render(thinking_accum)
                                    continue

                        tool_calls = delta.get("tool_calls")
                        if tool_calls is not None:
                            if ttft_s is None:
                                ttft_s = time.monotonic() - started
                            if print_output:
                                sys.stdout.write(
                                    f"\n<tool_calls {len(tool_calls) if isinstance(tool_calls, list) else 1}>\n"
                                )
                                sys.stdout.flush()

                    fr = ch0.get("finish_reason")
                    if fr is not None:
                        finish_reason = str(fr)

            # Usage/timing typically arrive on the final event(s).
            if isinstance(event.get("usage"), dict):
                usage = event["usage"]
            if isinstance(event.get("x_superlinear_timing"), dict):
                timing = event["x_superlinear_timing"]
    except KeyboardInterrupt:
        # Close the SSE generator before propagating Ctrl-C.
        try:
            gen.close()
        except Exception:
            pass
        raise
    finally:
        try:
            gen.close()
        except Exception:
            pass

    _thinking_panel_clear()
    if enable_thinking_ui and in_think and thinking_start_time is not None and thinking_end_time is None:
        # Stream ended mid-thought (e.g. budget exhausted or connection drop).
        if print_output:
            sys.stdout.write("[thinking incomplete] (no </think> received before stream ended)\n")
            sys.stdout.flush()

    ended = time.monotonic()
    if print_output:
        sys.stdout.write("\n")
        sys.stdout.flush()

    stats = _TurnStats()
    stats.finish_reason = finish_reason
    stats.ttft_s = ttft_s
    stats.total_s = max(ended - started, 0.0)

    if usage is not None:
        pt = usage.get("prompt_tokens")
        ct = usage.get("completion_tokens")
        if isinstance(pt, int):
            stats.prompt_tokens = pt
        if isinstance(ct, int):
            stats.completion_tokens = ct

    if timing is not None:
        prefill_s = timing.get("prefill_s")
        decode_s = timing.get("decode_s")
        total_s = timing.get("total_s")
        tok_per_s = timing.get("tok_per_s")
        if isinstance(prefill_s, (float, int)):
            stats.server_prefill_s = float(prefill_s)
        if isinstance(decode_s, (float, int)):
            stats.server_decode_s = float(decode_s)
        if isinstance(total_s, (float, int)):
            stats.server_total_s = float(total_s)
        if isinstance(tok_per_s, (float, int)):
            stats.tok_per_s = float(tok_per_s)

    return stats
|
|
1073
|
+
|
|
1074
|
+
|
|
1075
|
+
def _expand_files(args: list[str]) -> list[Path]:
|
|
1076
|
+
out: list[Path] = []
|
|
1077
|
+
for a in args:
|
|
1078
|
+
matches = glob.glob(a, recursive=True)
|
|
1079
|
+
candidates = matches if matches else [a]
|
|
1080
|
+
for c in candidates:
|
|
1081
|
+
p = Path(c).expanduser()
|
|
1082
|
+
if p.is_dir():
|
|
1083
|
+
for f in sorted(p.rglob("*")):
|
|
1084
|
+
if f.is_file():
|
|
1085
|
+
out.append(f)
|
|
1086
|
+
else:
|
|
1087
|
+
out.append(p)
|
|
1088
|
+
|
|
1089
|
+
seen: set[str] = set()
|
|
1090
|
+
deduped: list[Path] = []
|
|
1091
|
+
for p in out:
|
|
1092
|
+
key = str(p.resolve()) if p.exists() else str(p)
|
|
1093
|
+
if key in seen:
|
|
1094
|
+
continue
|
|
1095
|
+
seen.add(key)
|
|
1096
|
+
deduped.append(p)
|
|
1097
|
+
return deduped
|
|
1098
|
+
|
|
1099
|
+
|
|
1100
|
+
def _read_text_file(path: Path) -> tuple[str, int, str]:
|
|
1101
|
+
data = path.read_bytes()
|
|
1102
|
+
if b"\x00" in data:
|
|
1103
|
+
raise DocsReplError(f"Refusing to ingest binary file (NUL byte found): {path}")
|
|
1104
|
+
text = data.decode("utf-8", errors="replace")
|
|
1105
|
+
sha = hashlib.sha256(data).hexdigest()
|
|
1106
|
+
return text, len(data), sha
|
|
1107
|
+
|
|
1108
|
+
|
|
1109
|
+
def _build_docs_message(contents: list[tuple[str, str]]) -> str:
|
|
1110
|
+
# contents: list[(path_str, text)]
|
|
1111
|
+
blocks: list[str] = []
|
|
1112
|
+
for path_str, text in contents:
|
|
1113
|
+
blocks.append(f"[SOURCE path={path_str}]\n{text}\n[/SOURCE]\n")
|
|
1114
|
+
return "\n".join(blocks).strip() + "\n"
|
|
1115
|
+
|
|
1116
|
+
|
|
1117
|
+
def _is_prompt_level_message(m: dict) -> bool:
|
|
1118
|
+
"""Return True if this is a prompt-level message (system, document ingest, or ingest ack)."""
|
|
1119
|
+
role = m.get("role")
|
|
1120
|
+
content = m.get("content") or ""
|
|
1121
|
+
if role == "system":
|
|
1122
|
+
return True
|
|
1123
|
+
if role == "user" and "[SOURCE path=" in content:
|
|
1124
|
+
return True
|
|
1125
|
+
if role == "assistant" and content.strip() == "OK":
|
|
1126
|
+
return True
|
|
1127
|
+
return False
|
|
1128
|
+
|
|
1129
|
+
|
|
1130
|
+
def _cmd_head(*, client: SuperlinearClient, session_id: str, limit: int = 10) -> None:
    """Show first n Q&A messages (excludes prompt-level messages)."""
    try:
        resp = client.request_json("GET", f"/v1/sessions/{session_id}/history", timeout_s=10.0)
    except HttpError as exc:
        raise DocsReplError(str(exc)) from exc

    msgs = resp.get("messages") if isinstance(resp, dict) else None
    if not isinstance(msgs, list):
        raise DocsReplError("Invalid response from server for /head")

    # Keep conversational messages only, remembering original 1-based indices
    # so /show can address them.
    qa = [
        (idx, msg)
        for idx, msg in enumerate(msgs, 1)
        if isinstance(msg, dict) and not _is_prompt_level_message(msg)
    ]

    count = max(1, min(int(limit), 200))
    selected = qa[:count]
    if not selected:
        print("(no Q&A messages yet)")
        return

    for idx, msg in selected:
        summary = (msg.get("content") or "").replace("\r", "").replace("\n", " ").strip()
        if len(summary) > 200:
            summary = summary[:197] + "…"
        print(f"{idx:>4} {msg.get('role')}: {summary}")
|
|
1157
|
+
|
|
1158
|
+
|
|
1159
|
+
def _cmd_tail(*, client: SuperlinearClient, session_id: str, limit: int = 10) -> None:
    """Show last n Q&A messages (excludes prompt-level messages)."""
    try:
        resp = client.request_json("GET", f"/v1/sessions/{session_id}/history", timeout_s=10.0)
    except HttpError as exc:
        raise DocsReplError(str(exc)) from exc

    msgs = resp.get("messages") if isinstance(resp, dict) else None
    if not isinstance(msgs, list):
        raise DocsReplError("Invalid response from server for /tail")

    # Keep conversational messages only, remembering original 1-based indices
    # so /show can address them.
    qa = [
        (idx, msg)
        for idx, msg in enumerate(msgs, 1)
        if isinstance(msg, dict) and not _is_prompt_level_message(msg)
    ]

    count = max(1, min(int(limit), 200))
    selected = qa[-count:]
    if not selected:
        print("(no Q&A messages yet)")
        return

    for idx, msg in selected:
        summary = (msg.get("content") or "").replace("\r", "").replace("\n", " ").strip()
        if len(summary) > 200:
            summary = summary[:197] + "…"
        print(f"{idx:>4} {msg.get('role')}: {summary}")
|
|
1186
|
+
|
|
1187
|
+
|
|
1188
|
+
def _wrap_for_terminal(text: str, *, indent: str = "", width: int | None = None) -> str:
|
|
1189
|
+
cols = shutil.get_terminal_size(fallback=(120, 24)).columns
|
|
1190
|
+
target_width = cols if width is None else int(width)
|
|
1191
|
+
target_width = max(20, target_width)
|
|
1192
|
+
|
|
1193
|
+
normalized = text.replace("\r", "")
|
|
1194
|
+
out_lines: list[str] = []
|
|
1195
|
+
for logical in normalized.split("\n"):
|
|
1196
|
+
if not logical:
|
|
1197
|
+
out_lines.append(indent)
|
|
1198
|
+
continue
|
|
1199
|
+
wrapped = textwrap.wrap(
|
|
1200
|
+
logical,
|
|
1201
|
+
width=max(10, target_width - len(indent)),
|
|
1202
|
+
replace_whitespace=False,
|
|
1203
|
+
drop_whitespace=False,
|
|
1204
|
+
break_long_words=True,
|
|
1205
|
+
break_on_hyphens=False,
|
|
1206
|
+
)
|
|
1207
|
+
if not wrapped:
|
|
1208
|
+
out_lines.append(indent)
|
|
1209
|
+
else:
|
|
1210
|
+
out_lines.extend([indent + w for w in wrapped])
|
|
1211
|
+
return "\n".join(out_lines)
|
|
1212
|
+
|
|
1213
|
+
|
|
1214
|
+
def _cmd_show(*, client: SuperlinearClient, session_id: str, index: int) -> None:
    """Show a single message in full by 1-based index.

    Use the index shown by /head or /tail (it is the original index in the session history).
    """
    try:
        resp = client.request_json("GET", f"/v1/sessions/{session_id}/history", timeout_s=10.0)
    except HttpError as exc:
        raise DocsReplError(str(exc)) from exc

    msgs = resp.get("messages") if isinstance(resp, dict) else None
    if not isinstance(msgs, list):
        raise DocsReplError("Invalid response from server for /show")

    total = len(msgs)
    if total == 0:
        print("(empty)")
        return
    if not (1 <= index <= total):
        raise DocsReplError(f"Message index out of range: {index} (1..{total})")

    msg = msgs[index - 1]
    if not isinstance(msg, dict):
        raise DocsReplError("Invalid message format")

    role = str(msg.get("role") or "")
    content = msg.get("content")
    tool_calls = msg.get("tool_calls")

    # Messages carrying only tool calls have no content; show a placeholder.
    if content is None and tool_calls is not None:
        n_calls = len(tool_calls) if isinstance(tool_calls, list) else 1
        body = f"<tool_calls {n_calls}>"
    else:
        body = "" if content is None else str(content)

    print(f"{index:>4} {role}:")
    if body:
        print(_wrap_for_terminal(body, indent=" "))
    else:
        print(" (empty)")
|
|
1254
|
+
|
|
1255
|
+
|
|
1256
|
+
def _cmd_ls(*, client: SuperlinearClient, current_session_id: str, docs_workspaces: dict[str, DocsWorkspaceState]) -> None:
    """List all sessions and snapshots."""
    # Server-side sessions.
    try:
        payload = client.request_json("GET", "/v1/sessions", timeout_s=10.0)
    except HttpError as exc:
        raise DocsReplError(str(exc)) from exc

    raw = payload.get("sessions") if isinstance(payload, dict) else None
    session_ids = [s for s in raw if isinstance(s, str)] if isinstance(raw, list) else []

    # Map session_id -> (workspace name, state) so docs sessions get extra detail.
    by_session: dict[str, tuple[str, DocsWorkspaceState]] = {
        state.session_id: (ws_name, state) for ws_name, state in docs_workspaces.items()
    }

    print("sessions:")
    if not session_ids:
        print(" (none)")
    else:
        for sid in session_ids:
            marker = " *" if sid == current_session_id else ""
            ws_entry = by_session.get(sid)
            if ws_entry is None:
                print(f" {sid}{marker}")
            else:
                ws_name, ws_state = ws_entry
                src_count = len(ws_state.sources) if ws_state.sources else 0
                snap_info = f" snap={ws_state.base_snapshot_id[:8]}..." if ws_state.base_snapshot_id else ""
                print(f" {sid}{marker} (docs:{ws_name} {src_count} sources{snap_info})")

    # Locally stored snapshots.
    snapshots = list_local_snapshots()
    print("\nsnapshots:")
    if not snapshots:
        print(" (none)")
    else:
        for snap in snapshots:
            snap_id = snap.get("snapshot_id") or ""
            title = snap.get("title") or ""
            print(f" {snap_id} {title}" if title else f" {snap_id}")
|
|
1302
|
+
|
|
1303
|
+
|
|
1304
|
+
def _cmd_rm(
    *,
    client: SuperlinearClient,
    target_ids: list[str],
    current_session_id: str,
) -> bool:
    """Remove session(s) and/or snapshot(s). Returns True if current session was removed."""
    removed_current = False
    for target in target_ids:
        # Snapshot ids are 32 lowercase hex chars (optionally "snap-" prefixed);
        # anything else is treated as a session id.
        bare = target[5:] if target.startswith("snap-") else target
        looks_like_snapshot = len(bare) == 32 and all(c in "0123456789abcdef" for c in bare.lower())

        if looks_like_snapshot:
            if delete_local_snapshot(bare):
                print(f"removed snapshot_id={bare}")
            else:
                print(f"error: snapshot not found: {bare}", file=sys.stderr)
            continue

        try:
            client.request_json("DELETE", f"/v1/sessions/{target}", timeout_s=10.0)
        except HttpError as exc:
            if exc.status_code == 404:
                print(f"error: session not found: {target}", file=sys.stderr)
            else:
                print(f"error: failed to remove {target}: {exc}", file=sys.stderr)
        else:
            print(f"removed session_id={target}")
            if target == current_session_id:
                removed_current = True
    return removed_current
|
|
1337
|
+
|
|
1338
|
+
|
|
1339
|
+
def _cmd_help() -> None:
|
|
1340
|
+
print(
|
|
1341
|
+
"\n".join(
|
|
1342
|
+
[
|
|
1343
|
+
"commands:",
|
|
1344
|
+
" /help",
|
|
1345
|
+
" /exit [-c] exit (--clean/-c: delete workspace)",
|
|
1346
|
+
" /clear clear screen",
|
|
1347
|
+
" /history [n] show last n input commands (default 20)",
|
|
1348
|
+
" /history clear clear input command history",
|
|
1349
|
+
" /info show workspace info",
|
|
1350
|
+
" /ls list sessions and snapshots",
|
|
1351
|
+
" /rm delete current session, start fresh",
|
|
1352
|
+
" /rm <id...> delete session(s) or snapshot(s)",
|
|
1353
|
+
" /head [n] show first n Q&A messages",
|
|
1354
|
+
" /tail [n] show last n Q&A messages",
|
|
1355
|
+
" /show <i> show message i in full (use /tail to find ids)",
|
|
1356
|
+
" /add <paths...> [-s] add documents (--save/-s: save snapshot)",
|
|
1357
|
+
" /sources list loaded sources",
|
|
1358
|
+
" /rag ... configure RAG backend",
|
|
1359
|
+
" /reset reset workspace to base snapshot",
|
|
1360
|
+
" /stats show last turn metrics",
|
|
1361
|
+
" /save [title] save snapshot",
|
|
1362
|
+
" /load <snap> load snapshot into new workspace",
|
|
1363
|
+
]
|
|
1364
|
+
)
|
|
1365
|
+
)
|
|
1366
|
+
|
|
1367
|
+
|
|
1368
|
+
def docs_repl(
|
|
1369
|
+
*,
|
|
1370
|
+
url: str,
|
|
1371
|
+
name: str,
|
|
1372
|
+
load_snapshot_id: str | None = None,
|
|
1373
|
+
max_seq_len: int | None = None,
|
|
1374
|
+
think_budget: int | None = 32768,
|
|
1375
|
+
temperature: float = 0.3,
|
|
1376
|
+
top_p: float = 0.95,
|
|
1377
|
+
system_prompt: str | None = None,
|
|
1378
|
+
) -> int:
|
|
1379
|
+
_setup_readline_history()
|
|
1380
|
+
_setup_completer()
|
|
1381
|
+
if not name or not isinstance(name, str):
|
|
1382
|
+
print("docs name is required: `spl docs <name>`", file=sys.stderr)
|
|
1383
|
+
return 2
|
|
1384
|
+
|
|
1385
|
+
client = SuperlinearClient(base_url=url, timeout_s=3600.0)
|
|
1386
|
+
try:
|
|
1387
|
+
_ensure_reachable(client)
|
|
1388
|
+
except DocsReplError as exc:
|
|
1389
|
+
print(str(exc), file=sys.stderr)
|
|
1390
|
+
return 1
|
|
1391
|
+
|
|
1392
|
+
state = load_state()
|
|
1393
|
+
existing_ws = state.docs_workspaces.get(name)
|
|
1394
|
+
resumed = existing_ws is not None
|
|
1395
|
+
|
|
1396
|
+
if existing_ws is None:
|
|
1397
|
+
ws = DocsWorkspaceState(
|
|
1398
|
+
session_id=_new_session_id(workspace_name=name),
|
|
1399
|
+
phase="INGEST",
|
|
1400
|
+
base_snapshot_id=None,
|
|
1401
|
+
sources=[],
|
|
1402
|
+
)
|
|
1403
|
+
else:
|
|
1404
|
+
ws = existing_ws
|
|
1405
|
+
|
|
1406
|
+
lock = SessionLock(session_id=ws.session_id, kind="docs", label=f"spl docs {name}")
|
|
1407
|
+
try:
|
|
1408
|
+
try:
|
|
1409
|
+
lock.acquire()
|
|
1410
|
+
except AlreadyLockedError as exc:
|
|
1411
|
+
print(
|
|
1412
|
+
f"error: workspace is already open in another REPL (workspace={name} session_id={ws.session_id} pid={exc.info.pid}).",
|
|
1413
|
+
file=sys.stderr,
|
|
1414
|
+
)
|
|
1415
|
+
print("next steps: close the other REPL, or choose a different docs workspace name.", file=sys.stderr)
|
|
1416
|
+
return 2
|
|
1417
|
+
|
|
1418
|
+
# Handle --load flag: load from a snapshot
|
|
1419
|
+
if load_snapshot_id is not None:
|
|
1420
|
+
snap_id = load_snapshot_id.strip()
|
|
1421
|
+
# Normalize: accept raw 32-char hex or snap- prefix
|
|
1422
|
+
if snap_id.startswith("snap-"):
|
|
1423
|
+
snap_id = snap_id[5:]
|
|
1424
|
+
if len(snap_id) != 32:
|
|
1425
|
+
print(f"error: invalid snapshot id: {load_snapshot_id}", file=sys.stderr)
|
|
1426
|
+
return 2
|
|
1427
|
+
|
|
1428
|
+
# Delete existing session if it exists
|
|
1429
|
+
if _session_exists(client, ws.session_id):
|
|
1430
|
+
try:
|
|
1431
|
+
client.request_json("DELETE", f"/v1/sessions/{ws.session_id}", timeout_s=30.0)
|
|
1432
|
+
except HttpError:
|
|
1433
|
+
pass # Ignore deletion errors
|
|
1434
|
+
|
|
1435
|
+
# Load from snapshot into a fresh session
|
|
1436
|
+
try:
|
|
1437
|
+
client.request_json(
|
|
1438
|
+
"POST",
|
|
1439
|
+
f"/v1/snapshots/{snap_id}/load",
|
|
1440
|
+
payload={"session_id": ws.session_id},
|
|
1441
|
+
timeout_s=300.0,
|
|
1442
|
+
)
|
|
1443
|
+
except HttpError as exc:
|
|
1444
|
+
if exc.status_code == 404:
|
|
1445
|
+
print(f"error: snapshot not found: {snap_id} (use `spl snapshot ls`)", file=sys.stderr)
|
|
1446
|
+
elif exc.status_code == 429:
|
|
1447
|
+
print("Server is busy (429). Try again.", file=sys.stderr)
|
|
1448
|
+
else:
|
|
1449
|
+
print(str(exc), file=sys.stderr)
|
|
1450
|
+
return 1
|
|
1451
|
+
|
|
1452
|
+
# Update workspace state with the loaded snapshot
|
|
1453
|
+
ws.base_snapshot_id = snap_id
|
|
1454
|
+
ws.phase = "INGEST"
|
|
1455
|
+
ws.sources = _hydrate_sources_from_snapshot(client, snap_id) or []
|
|
1456
|
+
resumed = True # Mark as resumed since we loaded from snapshot
|
|
1457
|
+
|
|
1458
|
+
state.docs_workspaces[name] = ws
|
|
1459
|
+
save_state(state)
|
|
1460
|
+
|
|
1461
|
+
_maybe_resize_session(client, ws.session_id, min_max_seq_len=max_seq_len)
|
|
1462
|
+
|
|
1463
|
+
elif existing_ws is None:
|
|
1464
|
+
_create_session_with_max_seq_len(client, ws.session_id, max_seq_len=max_seq_len)
|
|
1465
|
+
_maybe_resize_session(client, ws.session_id, min_max_seq_len=max_seq_len)
|
|
1466
|
+
else:
|
|
1467
|
+
if _session_exists(client, ws.session_id):
|
|
1468
|
+
_maybe_resize_session(client, ws.session_id, min_max_seq_len=max_seq_len)
|
|
1469
|
+
else:
|
|
1470
|
+
print(f"note: session not found on server: {ws.session_id}", file=sys.stderr)
|
|
1471
|
+
if ws.base_snapshot_id:
|
|
1472
|
+
try:
|
|
1473
|
+
client.request_json(
|
|
1474
|
+
"POST",
|
|
1475
|
+
f"/v1/snapshots/{ws.base_snapshot_id}/load",
|
|
1476
|
+
payload={"session_id": ws.session_id},
|
|
1477
|
+
timeout_s=300.0,
|
|
1478
|
+
)
|
|
1479
|
+
except HttpError as exc:
|
|
1480
|
+
if exc.status_code == 404:
|
|
1481
|
+
print(
|
|
1482
|
+
f"error: base snapshot not found: {ws.base_snapshot_id} (use `spl snapshot ls`).",
|
|
1483
|
+
file=sys.stderr,
|
|
1484
|
+
)
|
|
1485
|
+
ws.base_snapshot_id = None
|
|
1486
|
+
ws.sources = []
|
|
1487
|
+
_create_session(client, ws.session_id)
|
|
1488
|
+
else:
|
|
1489
|
+
raise
|
|
1490
|
+
ws.phase = "INGEST"
|
|
1491
|
+
else:
|
|
1492
|
+
ws.phase = "INGEST"
|
|
1493
|
+
ws.sources = []
|
|
1494
|
+
ws.base_snapshot_id = None
|
|
1495
|
+
_create_session_with_max_seq_len(client, ws.session_id, max_seq_len=max_seq_len)
|
|
1496
|
+
_maybe_resize_session(client, ws.session_id, min_max_seq_len=max_seq_len)
|
|
1497
|
+
|
|
1498
|
+
# If we restored from a snapshot above, its max_seq_len may be smaller than the user's
|
|
1499
|
+
# requested context length. Apply a best-effort resize now.
|
|
1500
|
+
_maybe_resize_session(client, ws.session_id, min_max_seq_len=max_seq_len)
|
|
1501
|
+
|
|
1502
|
+
# Best-effort hydration of sources from the base snapshot metadata.
|
|
1503
|
+
if (not ws.sources) and ws.base_snapshot_id:
|
|
1504
|
+
snap_sources = _hydrate_sources_from_snapshot(client, ws.base_snapshot_id)
|
|
1505
|
+
if snap_sources is not None:
|
|
1506
|
+
ws.sources = snap_sources
|
|
1507
|
+
|
|
1508
|
+
state.docs_workspaces[name] = ws
|
|
1509
|
+
save_state(state)
|
|
1510
|
+
|
|
1511
|
+
light_retriever = LightRagRetriever()
|
|
1512
|
+
bm25_retriever = Bm25RagRetriever()
|
|
1513
|
+
bm25_hint_printed = False
|
|
1514
|
+
|
|
1515
|
+
_banner(
|
|
1516
|
+
url=client.base_url,
|
|
1517
|
+
name=name,
|
|
1518
|
+
session_id=ws.session_id,
|
|
1519
|
+
resumed=resumed,
|
|
1520
|
+
phase=ws.phase,
|
|
1521
|
+
source_count=len(ws.sources),
|
|
1522
|
+
base_snapshot_id=ws.base_snapshot_id,
|
|
1523
|
+
rag_status=_format_rag_status(ws=ws, bm25_available=bm25_retriever.is_available()),
|
|
1524
|
+
)
|
|
1525
|
+
|
|
1526
|
+
last_stats: _TurnStats | None = None
|
|
1527
|
+
last_substantive_user_question: str | None = None
|
|
1528
|
+
|
|
1529
|
+
while True:
|
|
1530
|
+
prompt = f"spl(docs:{name}:{ws.phase})> "
|
|
1531
|
+
try:
|
|
1532
|
+
line = input(prompt)
|
|
1533
|
+
except EOFError:
|
|
1534
|
+
print()
|
|
1535
|
+
return 0
|
|
1536
|
+
except KeyboardInterrupt:
|
|
1537
|
+
print("^C")
|
|
1538
|
+
continue
|
|
1539
|
+
|
|
1540
|
+
line = line.strip()
|
|
1541
|
+
if not line:
|
|
1542
|
+
continue
|
|
1543
|
+
|
|
1544
|
+
if line.startswith("/"):
|
|
1545
|
+
cmdline = line[1:].strip()
|
|
1546
|
+
try:
|
|
1547
|
+
parts = shlex.split(cmdline)
|
|
1548
|
+
except ValueError as exc:
|
|
1549
|
+
print(f"parse error: {exc}", file=sys.stderr)
|
|
1550
|
+
continue
|
|
1551
|
+
if not parts:
|
|
1552
|
+
continue
|
|
1553
|
+
cmd, args = parts[0], parts[1:]
|
|
1554
|
+
|
|
1555
|
+
if cmd in {"exit", "quit"}:
|
|
1556
|
+
clean = "--clean" in args or "-c" in args
|
|
1557
|
+
if clean:
|
|
1558
|
+
try:
|
|
1559
|
+
client.request_json("DELETE", f"/v1/sessions/{ws.session_id}", timeout_s=30.0)
|
|
1560
|
+
except HttpError:
|
|
1561
|
+
pass
|
|
1562
|
+
del state.docs_workspaces[name]
|
|
1563
|
+
save_state(state)
|
|
1564
|
+
print(f"deleted workspace: {name}")
|
|
1565
|
+
return 0
|
|
1566
|
+
if cmd == "clear":
|
|
1567
|
+
print("\033[2J\033[H", end="", flush=True)
|
|
1568
|
+
continue
|
|
1569
|
+
if cmd == "help":
|
|
1570
|
+
_cmd_help()
|
|
1571
|
+
continue
|
|
1572
|
+
if cmd == "session":
|
|
1573
|
+
print(ws.session_id)
|
|
1574
|
+
continue
|
|
1575
|
+
if cmd == "info":
|
|
1576
|
+
print(f"workspace={name}")
|
|
1577
|
+
print(f"session_id={ws.session_id}")
|
|
1578
|
+
print(f"phase={ws.phase}")
|
|
1579
|
+
print(f"sources={len(ws.sources) if ws.sources else 0}")
|
|
1580
|
+
if ws.base_snapshot_id:
|
|
1581
|
+
print(f"base_snapshot_id={ws.base_snapshot_id}")
|
|
1582
|
+
print(_format_rag_status(ws=ws, bm25_available=bm25_retriever.is_available()))
|
|
1583
|
+
continue
|
|
1584
|
+
if cmd == "stats":
|
|
1585
|
+
if last_stats is None:
|
|
1586
|
+
print("no stats yet")
|
|
1587
|
+
else:
|
|
1588
|
+
print(_stats_detail(last_stats))
|
|
1589
|
+
continue
|
|
1590
|
+
if cmd == "head":
|
|
1591
|
+
n = 10
|
|
1592
|
+
if len(args) == 1:
|
|
1593
|
+
try:
|
|
1594
|
+
n = int(args[0])
|
|
1595
|
+
except Exception:
|
|
1596
|
+
print("usage: /head [n]", file=sys.stderr)
|
|
1597
|
+
continue
|
|
1598
|
+
elif len(args) > 1:
|
|
1599
|
+
print("usage: /head [n]", file=sys.stderr)
|
|
1600
|
+
continue
|
|
1601
|
+
try:
|
|
1602
|
+
_cmd_head(client=client, session_id=ws.session_id, limit=n)
|
|
1603
|
+
except DocsReplError as exc:
|
|
1604
|
+
print(str(exc), file=sys.stderr)
|
|
1605
|
+
continue
|
|
1606
|
+
if cmd == "tail":
|
|
1607
|
+
n = 10
|
|
1608
|
+
if len(args) == 1:
|
|
1609
|
+
try:
|
|
1610
|
+
n = int(args[0])
|
|
1611
|
+
except Exception:
|
|
1612
|
+
print("usage: /tail [n]", file=sys.stderr)
|
|
1613
|
+
continue
|
|
1614
|
+
elif len(args) > 1:
|
|
1615
|
+
print("usage: /tail [n]", file=sys.stderr)
|
|
1616
|
+
continue
|
|
1617
|
+
try:
|
|
1618
|
+
_cmd_tail(client=client, session_id=ws.session_id, limit=n)
|
|
1619
|
+
except DocsReplError as exc:
|
|
1620
|
+
print(str(exc), file=sys.stderr)
|
|
1621
|
+
continue
|
|
1622
|
+
if cmd == "show":
|
|
1623
|
+
if len(args) != 1:
|
|
1624
|
+
print("usage: /show <i>", file=sys.stderr)
|
|
1625
|
+
continue
|
|
1626
|
+
try:
|
|
1627
|
+
i = int(args[0])
|
|
1628
|
+
except Exception:
|
|
1629
|
+
print("usage: /show <i>", file=sys.stderr)
|
|
1630
|
+
continue
|
|
1631
|
+
try:
|
|
1632
|
+
_cmd_show(client=client, session_id=ws.session_id, index=i)
|
|
1633
|
+
except DocsReplError as exc:
|
|
1634
|
+
print(str(exc), file=sys.stderr)
|
|
1635
|
+
continue
|
|
1636
|
+
if cmd == "history":
|
|
1637
|
+
if args in [["clear"], ["--clear"], ["-c"]]:
|
|
1638
|
+
_cmd_history_clear()
|
|
1639
|
+
continue
|
|
1640
|
+
n = 20
|
|
1641
|
+
if len(args) == 1:
|
|
1642
|
+
try:
|
|
1643
|
+
n = int(args[0])
|
|
1644
|
+
except Exception:
|
|
1645
|
+
print("usage: /history [n] | /history clear", file=sys.stderr)
|
|
1646
|
+
continue
|
|
1647
|
+
elif len(args) > 1:
|
|
1648
|
+
print("usage: /history [n] | /history clear", file=sys.stderr)
|
|
1649
|
+
continue
|
|
1650
|
+
_cmd_history(n)
|
|
1651
|
+
continue
|
|
1652
|
+
if cmd == "ls":
|
|
1653
|
+
try:
|
|
1654
|
+
_cmd_ls(client=client, current_session_id=ws.session_id, docs_workspaces=state.docs_workspaces)
|
|
1655
|
+
except DocsReplError as exc:
|
|
1656
|
+
print(str(exc), file=sys.stderr)
|
|
1657
|
+
continue
|
|
1658
|
+
if cmd == "rm":
|
|
1659
|
+
if not args:
|
|
1660
|
+
# /rm with no args = delete current session and reset workspace
|
|
1661
|
+
try:
|
|
1662
|
+
client.request_json("DELETE", f"/v1/sessions/{ws.session_id}", timeout_s=10.0)
|
|
1663
|
+
print(f"removed session_id={ws.session_id}")
|
|
1664
|
+
except HttpError as exc:
|
|
1665
|
+
if exc.status_code != 404:
|
|
1666
|
+
print(f"error: {exc}", file=sys.stderr)
|
|
1667
|
+
# Reset workspace and create new session
|
|
1668
|
+
new_session = _new_session_id(workspace_name=name)
|
|
1669
|
+
_create_session_with_max_seq_len(client, new_session, max_seq_len=max_seq_len)
|
|
1670
|
+
lock.release()
|
|
1671
|
+
lock = SessionLock(session_id=new_session, kind="docs", label=f"spl docs {name}")
|
|
1672
|
+
lock.acquire()
|
|
1673
|
+
ws.session_id = new_session
|
|
1674
|
+
ws.phase = "INGEST"
|
|
1675
|
+
ws.base_snapshot_id = None
|
|
1676
|
+
ws.sources = []
|
|
1677
|
+
state.docs_workspaces[name] = ws
|
|
1678
|
+
save_state(state)
|
|
1679
|
+
light_retriever.clear_cache()
|
|
1680
|
+
bm25_retriever.clear_index()
|
|
1681
|
+
print(f"new session_id={new_session}")
|
|
1682
|
+
continue
|
|
1683
|
+
try:
|
|
1684
|
+
removed_current = _cmd_rm(
|
|
1685
|
+
client=client,
|
|
1686
|
+
target_ids=args,
|
|
1687
|
+
current_session_id=ws.session_id,
|
|
1688
|
+
)
|
|
1689
|
+
if removed_current:
|
|
1690
|
+
# Current session removed - need to recreate
|
|
1691
|
+
print("current session deleted; recreating...")
|
|
1692
|
+
new_session = _new_session_id(workspace_name=name)
|
|
1693
|
+
_create_session_with_max_seq_len(client, new_session, max_seq_len=max_seq_len)
|
|
1694
|
+
ws.session_id = new_session
|
|
1695
|
+
ws.phase = "INGEST"
|
|
1696
|
+
ws.base_snapshot_id = None
|
|
1697
|
+
ws.sources = []
|
|
1698
|
+
state.docs_workspaces[name] = ws
|
|
1699
|
+
save_state(state)
|
|
1700
|
+
lock.release()
|
|
1701
|
+
lock = SessionLock(session_id=new_session, kind="docs", label=f"spl docs {name}")
|
|
1702
|
+
lock.acquire()
|
|
1703
|
+
light_retriever.clear_cache()
|
|
1704
|
+
bm25_retriever.clear_index()
|
|
1705
|
+
except DocsReplError as exc:
|
|
1706
|
+
print(str(exc), file=sys.stderr)
|
|
1707
|
+
continue
|
|
1708
|
+
if cmd == "sources":
|
|
1709
|
+
if (not ws.sources) and ws.base_snapshot_id:
|
|
1710
|
+
snap_sources = _hydrate_sources_from_snapshot(client, ws.base_snapshot_id)
|
|
1711
|
+
if snap_sources is not None:
|
|
1712
|
+
ws.sources = snap_sources
|
|
1713
|
+
state.docs_workspaces[name] = ws
|
|
1714
|
+
save_state(state)
|
|
1715
|
+
print(_format_sources_table(ws.sources))
|
|
1716
|
+
continue
|
|
1717
|
+
if cmd == "rag":
|
|
1718
|
+
if not args:
|
|
1719
|
+
print(_format_rag_status(ws=ws, bm25_available=bm25_retriever.is_available()))
|
|
1720
|
+
print(
|
|
1721
|
+
"usage: /rag off | /rag backend=light|bm25 | /rag on | /rag k=<int> | /rag chars=<int> | "
|
|
1722
|
+
"/rag per_source_chars=<int> | /rag debug on|off | /rag bm25_k_paragraphs=<int> | /rag bm25_k_sources=<int>"
|
|
1723
|
+
)
|
|
1724
|
+
continue
|
|
1725
|
+
|
|
1726
|
+
backend = _rag_backend_from_ws(ws)
|
|
1727
|
+
cfg = _light_rag_config_from_ws(ws).sanitized()
|
|
1728
|
+
k = int(cfg.k)
|
|
1729
|
+
total_chars = int(cfg.total_chars)
|
|
1730
|
+
per_source_chars = int(cfg.per_source_chars)
|
|
1731
|
+
debug = bool(cfg.debug)
|
|
1732
|
+
bm25_k_paragraphs = int(getattr(ws, "bm25_k_paragraphs", 40))
|
|
1733
|
+
bm25_k_sources = int(getattr(ws, "bm25_k_sources", 0))
|
|
1734
|
+
backend_requested_bm25 = False
|
|
1735
|
+
|
|
1736
|
+
ok = True
|
|
1737
|
+
i = 0
|
|
1738
|
+
while i < len(args):
|
|
1739
|
+
a = args[i]
|
|
1740
|
+
if a in {"on", "off"}:
|
|
1741
|
+
if a == "off":
|
|
1742
|
+
backend = "off"
|
|
1743
|
+
elif backend == "off":
|
|
1744
|
+
backend = "light"
|
|
1745
|
+
i += 1
|
|
1746
|
+
continue
|
|
1747
|
+
|
|
1748
|
+
if a == "backend":
|
|
1749
|
+
if i + 1 >= len(args):
|
|
1750
|
+
print("usage: /rag backend light|bm25|off", file=sys.stderr)
|
|
1751
|
+
ok = False
|
|
1752
|
+
break
|
|
1753
|
+
b = args[i + 1].strip().lower()
|
|
1754
|
+
if b not in {"light", "bm25", "off"}:
|
|
1755
|
+
print(f"invalid /rag backend: {b!r}", file=sys.stderr)
|
|
1756
|
+
ok = False
|
|
1757
|
+
break
|
|
1758
|
+
backend = b
|
|
1759
|
+
backend_requested_bm25 = backend_requested_bm25 or (b == "bm25")
|
|
1760
|
+
i += 2
|
|
1761
|
+
continue
|
|
1762
|
+
|
|
1763
|
+
if a.startswith("backend="):
|
|
1764
|
+
b = a.split("=", 1)[1].strip().lower()
|
|
1765
|
+
if b not in {"light", "bm25", "off"}:
|
|
1766
|
+
print(f"invalid /rag backend: {b!r}", file=sys.stderr)
|
|
1767
|
+
ok = False
|
|
1768
|
+
break
|
|
1769
|
+
backend = b
|
|
1770
|
+
backend_requested_bm25 = backend_requested_bm25 or (b == "bm25")
|
|
1771
|
+
i += 1
|
|
1772
|
+
continue
|
|
1773
|
+
|
|
1774
|
+
if a == "debug":
|
|
1775
|
+
if i + 1 >= len(args) or args[i + 1] not in {"on", "off"}:
|
|
1776
|
+
print("usage: /rag debug on|off", file=sys.stderr)
|
|
1777
|
+
ok = False
|
|
1778
|
+
break
|
|
1779
|
+
debug = args[i + 1] == "on"
|
|
1780
|
+
i += 2
|
|
1781
|
+
continue
|
|
1782
|
+
|
|
1783
|
+
if a.startswith("k="):
|
|
1784
|
+
try:
|
|
1785
|
+
k = int(a.split("=", 1)[1])
|
|
1786
|
+
except Exception:
|
|
1787
|
+
print(f"invalid /rag k: {a!r}", file=sys.stderr)
|
|
1788
|
+
ok = False
|
|
1789
|
+
break
|
|
1790
|
+
i += 1
|
|
1791
|
+
continue
|
|
1792
|
+
|
|
1793
|
+
if a.startswith("chars="):
|
|
1794
|
+
try:
|
|
1795
|
+
total_chars = int(a.split("=", 1)[1])
|
|
1796
|
+
except Exception:
|
|
1797
|
+
print(f"invalid /rag chars: {a!r}", file=sys.stderr)
|
|
1798
|
+
ok = False
|
|
1799
|
+
break
|
|
1800
|
+
i += 1
|
|
1801
|
+
continue
|
|
1802
|
+
|
|
1803
|
+
if a.startswith("per_source_chars=") or a.startswith("per-source-chars="):
|
|
1804
|
+
try:
|
|
1805
|
+
per_source_chars = int(a.split("=", 1)[1])
|
|
1806
|
+
except Exception:
|
|
1807
|
+
print(f"invalid /rag per_source_chars: {a!r}", file=sys.stderr)
|
|
1808
|
+
ok = False
|
|
1809
|
+
break
|
|
1810
|
+
i += 1
|
|
1811
|
+
continue
|
|
1812
|
+
|
|
1813
|
+
if a.startswith("bm25_k_paragraphs=") or a.startswith("bm25-k-paragraphs="):
|
|
1814
|
+
try:
|
|
1815
|
+
bm25_k_paragraphs = int(a.split("=", 1)[1])
|
|
1816
|
+
except Exception:
|
|
1817
|
+
print(f"invalid /rag bm25_k_paragraphs: {a!r}", file=sys.stderr)
|
|
1818
|
+
ok = False
|
|
1819
|
+
break
|
|
1820
|
+
i += 1
|
|
1821
|
+
continue
|
|
1822
|
+
|
|
1823
|
+
if a.startswith("bm25_k_sources=") or a.startswith("bm25-k-sources="):
|
|
1824
|
+
try:
|
|
1825
|
+
bm25_k_sources = int(a.split("=", 1)[1])
|
|
1826
|
+
except Exception:
|
|
1827
|
+
print(f"invalid /rag bm25_k_sources: {a!r}", file=sys.stderr)
|
|
1828
|
+
ok = False
|
|
1829
|
+
break
|
|
1830
|
+
i += 1
|
|
1831
|
+
continue
|
|
1832
|
+
|
|
1833
|
+
print(f"unknown /rag option: {a!r}", file=sys.stderr)
|
|
1834
|
+
ok = False
|
|
1835
|
+
break
|
|
1836
|
+
|
|
1837
|
+
if not ok:
|
|
1838
|
+
continue
|
|
1839
|
+
|
|
1840
|
+
new_cfg = LightRagConfig(
|
|
1841
|
+
enabled=backend != "off",
|
|
1842
|
+
k=k,
|
|
1843
|
+
total_chars=total_chars,
|
|
1844
|
+
per_source_chars=per_source_chars,
|
|
1845
|
+
debug=debug,
|
|
1846
|
+
).sanitized()
|
|
1847
|
+
_apply_rag_backend_to_ws(ws, backend)
|
|
1848
|
+
_apply_light_rag_config_to_ws(ws, new_cfg)
|
|
1849
|
+
ws.bm25_k_paragraphs = max(1, min(int(bm25_k_paragraphs), 1000))
|
|
1850
|
+
ws.bm25_k_sources = max(0, min(int(bm25_k_sources), 50))
|
|
1851
|
+
state.docs_workspaces[name] = ws
|
|
1852
|
+
save_state(state)
|
|
1853
|
+
if backend_requested_bm25 and backend == "bm25" and not bm25_retriever.is_available():
|
|
1854
|
+
print(
|
|
1855
|
+
"note: `rank-bm25` is not installed; BM25 backend will fall back to lightRAG "
|
|
1856
|
+
"(install with `pip install rank-bm25`).",
|
|
1857
|
+
file=sys.stderr,
|
|
1858
|
+
)
|
|
1859
|
+
bm25_hint_printed = True
|
|
1860
|
+
if backend == "bm25" and bm25_retriever.is_available():
|
|
1861
|
+
bm25_cfg = _bm25_rag_config_from_ws(ws)
|
|
1862
|
+
for dbg in bm25_retriever.ensure_index(sources=ws.sources, debug=bm25_cfg.debug):
|
|
1863
|
+
print(dbg, file=sys.stderr)
|
|
1864
|
+
print(_format_rag_status(ws=ws, bm25_available=bm25_retriever.is_available()))
|
|
1865
|
+
continue
|
|
1866
|
+
if cmd == "add":
|
|
1867
|
+
if ws.phase != "INGEST":
|
|
1868
|
+
print("cannot /add in QA phase; use /clear to return to INGEST", file=sys.stderr)
|
|
1869
|
+
continue
|
|
1870
|
+
# Parse --save / -s flag
|
|
1871
|
+
save_after = "--save" in args or "-s" in args
|
|
1872
|
+
path_args = [a for a in args if a not in ("--save", "-s")]
|
|
1873
|
+
if not path_args:
|
|
1874
|
+
print("usage: /add <paths...> [--save|-s]", file=sys.stderr)
|
|
1875
|
+
continue
|
|
1876
|
+
files = _expand_files(path_args)
|
|
1877
|
+
if not files:
|
|
1878
|
+
print("no files matched", file=sys.stderr)
|
|
1879
|
+
continue
|
|
1880
|
+
|
|
1881
|
+
contents: list[tuple[str, str]] = []
|
|
1882
|
+
new_sources: list[dict[str, Any]] = []
|
|
1883
|
+
total_bytes = 0
|
|
1884
|
+
now = int(time.time())
|
|
1885
|
+
for p in files:
|
|
1886
|
+
try:
|
|
1887
|
+
text, nbytes, sha = _read_text_file(p)
|
|
1888
|
+
except Exception as exc:
|
|
1889
|
+
print(str(exc), file=sys.stderr)
|
|
1890
|
+
continue
|
|
1891
|
+
path_str = str(p.resolve())
|
|
1892
|
+
meta = _extract_doc_metadata(path=p, text=text)
|
|
1893
|
+
contents.append((path_str, text))
|
|
1894
|
+
total_bytes += nbytes
|
|
1895
|
+
new_sources.append(
|
|
1896
|
+
{
|
|
1897
|
+
"path": path_str,
|
|
1898
|
+
**meta,
|
|
1899
|
+
"bytes": nbytes,
|
|
1900
|
+
"sha256": sha,
|
|
1901
|
+
"added_at_unix_s": now,
|
|
1902
|
+
}
|
|
1903
|
+
)
|
|
1904
|
+
|
|
1905
|
+
if not contents:
|
|
1906
|
+
print("no ingestible text files found", file=sys.stderr)
|
|
1907
|
+
continue
|
|
1908
|
+
|
|
1909
|
+
include_ingest_prompt = (not ws.sources) and (ws.base_snapshot_id is None)
|
|
1910
|
+
messages: list[dict[str, Any]] = []
|
|
1911
|
+
if include_ingest_prompt:
|
|
1912
|
+
ingest_prompt = system_prompt if system_prompt else DOCS_INGEST_PROMPT
|
|
1913
|
+
messages.append({"role": "system", "content": ingest_prompt})
|
|
1914
|
+
messages.append({"role": "user", "content": _build_docs_message(contents)})
|
|
1915
|
+
|
|
1916
|
+
pos_before = _get_session_pos(client, ws.session_id)
|
|
1917
|
+
try:
|
|
1918
|
+
ingest_stats = _stream_request(
|
|
1919
|
+
client=client,
|
|
1920
|
+
session_id=ws.session_id,
|
|
1921
|
+
messages=messages,
|
|
1922
|
+
max_completion_tokens=16,
|
|
1923
|
+
think_budget=None,
|
|
1924
|
+
temperature=temperature,
|
|
1925
|
+
top_p=top_p,
|
|
1926
|
+
print_output=False,
|
|
1927
|
+
)
|
|
1928
|
+
last_stats = ingest_stats
|
|
1929
|
+
except KeyboardInterrupt:
|
|
1930
|
+
print("(cancelled)")
|
|
1931
|
+
continue
|
|
1932
|
+
except DocsReplError as exc:
|
|
1933
|
+
msg = str(exc)
|
|
1934
|
+
_maybe_print_prompt_too_long_hint(msg=msg, requested_max_seq_len=max_seq_len)
|
|
1935
|
+
print(msg, file=sys.stderr)
|
|
1936
|
+
continue
|
|
1937
|
+
except HttpError as exc:
|
|
1938
|
+
if exc.status_code == 429:
|
|
1939
|
+
print("Server is busy (429). Try again.", file=sys.stderr)
|
|
1940
|
+
else:
|
|
1941
|
+
print(str(exc), file=sys.stderr)
|
|
1942
|
+
continue
|
|
1943
|
+
|
|
1944
|
+
# Update source list (dedupe by path+sha256).
|
|
1945
|
+
existing_keys = {(s.get("path"), s.get("sha256")) for s in ws.sources if isinstance(s, dict)}
|
|
1946
|
+
for s in new_sources:
|
|
1947
|
+
key = (s.get("path"), s.get("sha256"))
|
|
1948
|
+
if key in existing_keys:
|
|
1949
|
+
continue
|
|
1950
|
+
ws.sources.append(s)
|
|
1951
|
+
existing_keys.add(key)
|
|
1952
|
+
|
|
1953
|
+
state.docs_workspaces[name] = ws
|
|
1954
|
+
save_state(state)
|
|
1955
|
+
|
|
1956
|
+
if _rag_backend_from_ws(ws) == "bm25" and bm25_retriever.is_available():
|
|
1957
|
+
bm25_cfg = _bm25_rag_config_from_ws(ws)
|
|
1958
|
+
for dbg in bm25_retriever.ensure_index(sources=ws.sources, debug=bm25_cfg.debug):
|
|
1959
|
+
print(dbg, file=sys.stderr)
|
|
1960
|
+
|
|
1961
|
+
pos_after = _get_session_pos(client, ws.session_id)
|
|
1962
|
+
prompt_delta = None
|
|
1963
|
+
if (
|
|
1964
|
+
pos_before is not None
|
|
1965
|
+
and pos_after is not None
|
|
1966
|
+
and ingest_stats.completion_tokens is not None
|
|
1967
|
+
and pos_after >= pos_before
|
|
1968
|
+
):
|
|
1969
|
+
prompt_delta = max(pos_after - pos_before - int(ingest_stats.completion_tokens), 0)
|
|
1970
|
+
|
|
1971
|
+
extra = []
|
|
1972
|
+
if prompt_delta is not None:
|
|
1973
|
+
extra.append(f"prompt_delta_tokens={prompt_delta}")
|
|
1974
|
+
if ingest_stats.server_prefill_s and prompt_delta is not None and ingest_stats.server_prefill_s > 0:
|
|
1975
|
+
extra.append(f"prefill_tok/s={prompt_delta/ingest_stats.server_prefill_s:.2f}")
|
|
1976
|
+
|
|
1977
|
+
snap_msg = f" base_snapshot_id={ws.base_snapshot_id}" if ws.base_snapshot_id else ""
|
|
1978
|
+
print(
|
|
1979
|
+
f"added files={len(contents)} bytes={total_bytes} sources={len(ws.sources)}{snap_msg}"
|
|
1980
|
+
+ ((" " + " ".join(extra)) if extra else "")
|
|
1981
|
+
)
|
|
1982
|
+
footer = _stats_footer(ingest_stats)
|
|
1983
|
+
if footer:
|
|
1984
|
+
print(footer)
|
|
1985
|
+
|
|
1986
|
+
# Auto-save if --save/-s flag was passed
|
|
1987
|
+
if save_after:
|
|
1988
|
+
payload: dict[str, Any] = {"description": _encode_sources_description(ws.sources)}
|
|
1989
|
+
try:
|
|
1990
|
+
resp = client.request_json(
|
|
1991
|
+
"POST",
|
|
1992
|
+
f"/v1/sessions/{ws.session_id}/save",
|
|
1993
|
+
payload=payload,
|
|
1994
|
+
timeout_s=300.0,
|
|
1995
|
+
)
|
|
1996
|
+
except HttpError as exc:
|
|
1997
|
+
if exc.status_code == 429:
|
|
1998
|
+
print("Server is busy (429). Try again.", file=sys.stderr)
|
|
1999
|
+
else:
|
|
2000
|
+
print(str(exc), file=sys.stderr)
|
|
2001
|
+
continue
|
|
2002
|
+
snap_id = resp.get("snapshot_id") if isinstance(resp, dict) else None
|
|
2003
|
+
if isinstance(snap_id, str) and snap_id:
|
|
2004
|
+
ws.base_snapshot_id = snap_id
|
|
2005
|
+
state.docs_workspaces[name] = ws
|
|
2006
|
+
save_state(state)
|
|
2007
|
+
print(f"saved base_snapshot_id={snap_id}")
|
|
2008
|
+
else:
|
|
2009
|
+
print("saved", file=sys.stderr)
|
|
2010
|
+
continue
|
|
2011
|
+
if cmd == "reset":
|
|
2012
|
+
if ws.base_snapshot_id:
|
|
2013
|
+
try:
|
|
2014
|
+
client.request_json(
|
|
2015
|
+
"POST",
|
|
2016
|
+
f"/v1/snapshots/{ws.base_snapshot_id}/load",
|
|
2017
|
+
payload={"session_id": ws.session_id, "force": True},
|
|
2018
|
+
timeout_s=300.0,
|
|
2019
|
+
)
|
|
2020
|
+
except HttpError as exc:
|
|
2021
|
+
if exc.status_code == 404:
|
|
2022
|
+
print(
|
|
2023
|
+
f"error: base snapshot not found: {ws.base_snapshot_id} (use `spl snapshot ls`).",
|
|
2024
|
+
file=sys.stderr,
|
|
2025
|
+
)
|
|
2026
|
+
else:
|
|
2027
|
+
print(str(exc), file=sys.stderr)
|
|
2028
|
+
continue
|
|
2029
|
+
ws.phase = "INGEST"
|
|
2030
|
+
state.docs_workspaces[name] = ws
|
|
2031
|
+
save_state(state)
|
|
2032
|
+
last_stats = None
|
|
2033
|
+
last_substantive_user_question = None
|
|
2034
|
+
bm25_hint_printed = False
|
|
2035
|
+
light_retriever.clear_cache()
|
|
2036
|
+
bm25_retriever.clear_index()
|
|
2037
|
+
print(f"cleared to base_snapshot_id={ws.base_snapshot_id}")
|
|
2038
|
+
else:
|
|
2039
|
+
# No checkpoint yet: return to empty ingest state.
|
|
2040
|
+
try:
|
|
2041
|
+
client.request_json("DELETE", f"/v1/sessions/{ws.session_id}", timeout_s=30.0)
|
|
2042
|
+
except HttpError:
|
|
2043
|
+
pass
|
|
2044
|
+
_create_session(client, ws.session_id)
|
|
2045
|
+
ws.phase = "INGEST"
|
|
2046
|
+
ws.sources = []
|
|
2047
|
+
ws.base_snapshot_id = None
|
|
2048
|
+
state.docs_workspaces[name] = ws
|
|
2049
|
+
save_state(state)
|
|
2050
|
+
last_stats = None
|
|
2051
|
+
last_substantive_user_question = None
|
|
2052
|
+
bm25_hint_printed = False
|
|
2053
|
+
light_retriever.clear_cache()
|
|
2054
|
+
bm25_retriever.clear_index()
|
|
2055
|
+
print("cleared (empty)")
|
|
2056
|
+
continue
|
|
2057
|
+
if cmd == "save":
|
|
2058
|
+
if ws.phase != "INGEST":
|
|
2059
|
+
print("docs /save is base-only; use /clear to return to INGEST first", file=sys.stderr)
|
|
2060
|
+
continue
|
|
2061
|
+
title = " ".join(args).strip() if args else ""
|
|
2062
|
+
payload: dict[str, Any] = {"description": _encode_sources_description(ws.sources)}
|
|
2063
|
+
if title:
|
|
2064
|
+
payload["title"] = title
|
|
2065
|
+
try:
|
|
2066
|
+
resp = client.request_json(
|
|
2067
|
+
"POST",
|
|
2068
|
+
f"/v1/sessions/{ws.session_id}/save",
|
|
2069
|
+
payload=payload,
|
|
2070
|
+
timeout_s=300.0,
|
|
2071
|
+
)
|
|
2072
|
+
except HttpError as exc:
|
|
2073
|
+
if exc.status_code == 429:
|
|
2074
|
+
print("Server is busy (429). Try again.", file=sys.stderr)
|
|
2075
|
+
else:
|
|
2076
|
+
print(str(exc), file=sys.stderr)
|
|
2077
|
+
continue
|
|
2078
|
+
snap_id = resp.get("snapshot_id") if isinstance(resp, dict) else None
|
|
2079
|
+
if isinstance(snap_id, str) and snap_id:
|
|
2080
|
+
ws.base_snapshot_id = snap_id
|
|
2081
|
+
state.docs_workspaces[name] = ws
|
|
2082
|
+
save_state(state)
|
|
2083
|
+
print(f"saved base_snapshot_id={snap_id}")
|
|
2084
|
+
else:
|
|
2085
|
+
print("saved", file=sys.stderr)
|
|
2086
|
+
continue
|
|
2087
|
+
if cmd == "load":
|
|
2088
|
+
if not args:
|
|
2089
|
+
print("usage: /load <snapshot_id> --as <new_name>", file=sys.stderr)
|
|
2090
|
+
continue
|
|
2091
|
+
snap_id = args[0]
|
|
2092
|
+
rest = args[1:]
|
|
2093
|
+
if len(rest) != 2 or rest[0] != "--as":
|
|
2094
|
+
print("usage: /load <snapshot_id> --as <new_name>", file=sys.stderr)
|
|
2095
|
+
continue
|
|
2096
|
+
new_name = rest[1]
|
|
2097
|
+
if not new_name:
|
|
2098
|
+
print("new_name must be non-empty", file=sys.stderr)
|
|
2099
|
+
continue
|
|
2100
|
+
if new_name in state.docs_workspaces:
|
|
2101
|
+
print(f"workspace already exists: {new_name}", file=sys.stderr)
|
|
2102
|
+
continue
|
|
2103
|
+
new_session = _new_session_id(workspace_name=new_name)
|
|
2104
|
+
next_lock = SessionLock(session_id=new_session, kind="docs", label=f"spl docs {new_name}")
|
|
2105
|
+
try:
|
|
2106
|
+
next_lock.acquire()
|
|
2107
|
+
except AlreadyLockedError as exc:
|
|
2108
|
+
print(
|
|
2109
|
+
f"error: target session is already open in another REPL (session_id={new_session} pid={exc.info.pid}).",
|
|
2110
|
+
file=sys.stderr,
|
|
2111
|
+
)
|
|
2112
|
+
continue
|
|
2113
|
+
|
|
2114
|
+
try:
|
|
2115
|
+
client.request_json(
|
|
2116
|
+
"POST",
|
|
2117
|
+
f"/v1/snapshots/{snap_id}/load",
|
|
2118
|
+
payload={"session_id": new_session},
|
|
2119
|
+
timeout_s=300.0,
|
|
2120
|
+
)
|
|
2121
|
+
except HttpError as exc:
|
|
2122
|
+
next_lock.release()
|
|
2123
|
+
if exc.status_code == 404:
|
|
2124
|
+
print(f"Snapshot not found: {snap_id} (use `spl snapshot ls`).", file=sys.stderr)
|
|
2125
|
+
elif exc.status_code == 409:
|
|
2126
|
+
print(f"Target session already exists: {new_session} (try again).", file=sys.stderr)
|
|
2127
|
+
elif exc.status_code == 429:
|
|
2128
|
+
print("Server is busy (429). Try again.", file=sys.stderr)
|
|
2129
|
+
else:
|
|
2130
|
+
print(str(exc), file=sys.stderr)
|
|
2131
|
+
continue
|
|
2132
|
+
|
|
2133
|
+
new_sources = _hydrate_sources_from_snapshot(client, snap_id) or []
|
|
2134
|
+
state.docs_workspaces[new_name] = DocsWorkspaceState(
|
|
2135
|
+
session_id=new_session,
|
|
2136
|
+
phase="INGEST",
|
|
2137
|
+
base_snapshot_id=snap_id,
|
|
2138
|
+
sources=new_sources,
|
|
2139
|
+
)
|
|
2140
|
+
save_state(state)
|
|
2141
|
+
|
|
2142
|
+
lock.release()
|
|
2143
|
+
lock = next_lock
|
|
2144
|
+
name = new_name
|
|
2145
|
+
ws = state.docs_workspaces[name]
|
|
2146
|
+
resumed = False
|
|
2147
|
+
last_stats = None
|
|
2148
|
+
light_retriever.clear_cache()
|
|
2149
|
+
bm25_retriever.clear_index()
|
|
2150
|
+
bm25_hint_printed = False
|
|
2151
|
+
_banner(
|
|
2152
|
+
url=client.base_url,
|
|
2153
|
+
name=name,
|
|
2154
|
+
session_id=ws.session_id,
|
|
2155
|
+
resumed=False,
|
|
2156
|
+
phase=ws.phase,
|
|
2157
|
+
source_count=len(ws.sources),
|
|
2158
|
+
base_snapshot_id=ws.base_snapshot_id,
|
|
2159
|
+
rag_status=_format_rag_status(ws=ws, bm25_available=bm25_retriever.is_available()),
|
|
2160
|
+
)
|
|
2161
|
+
continue
|
|
2162
|
+
|
|
2163
|
+
print(f"unknown command: /{cmd}", file=sys.stderr)
|
|
2164
|
+
continue
|
|
2165
|
+
|
|
2166
|
+
# Regular question input.
|
|
2167
|
+
qa_messages: list[dict[str, Any]] = []
|
|
2168
|
+
if ws.phase == "INGEST":
|
|
2169
|
+
# Transition to QA; lock docs base.
|
|
2170
|
+
ws.phase = "QA"
|
|
2171
|
+
state.docs_workspaces[name] = ws
|
|
2172
|
+
save_state(state)
|
|
2173
|
+
# IMPORTANT: do not rely on a system message here. In session mode the server
|
|
2174
|
+
# drops additional system messages once a leading system already exists.
|
|
2175
|
+
qa_messages.append({"role": "user", "content": _build_qa_bootstrap_message(sources=ws.sources)})
|
|
2176
|
+
qa_messages.append({"role": "assistant", "content": DOCS_QA_PRIMER_ASSISTANT})
|
|
2177
|
+
|
|
2178
|
+
backend = _rag_backend_from_ws(ws)
|
|
2179
|
+
rag_cfg = _light_rag_config_from_ws(ws)
|
|
2180
|
+
|
|
2181
|
+
rag_question = line
|
|
2182
|
+
if ws.phase == "QA" and _should_augment_rag_query_with_prev_question(
|
|
2183
|
+
question=line, prev_question=last_substantive_user_question
|
|
2184
|
+
):
|
|
2185
|
+
rag_question = f"{line}\n\nContext from previous user question: {last_substantive_user_question}".strip()
|
|
2186
|
+
|
|
2187
|
+
rag_msg: str | None = None
|
|
2188
|
+
rag_debug: list[str] = []
|
|
2189
|
+
if backend == "light":
|
|
2190
|
+
rag_msg, rag_debug = light_retriever.build_retrieved_excerpts_message(
|
|
2191
|
+
question=rag_question,
|
|
2192
|
+
sources=ws.sources,
|
|
2193
|
+
config=rag_cfg,
|
|
2194
|
+
)
|
|
2195
|
+
elif backend == "bm25":
|
|
2196
|
+
if not bm25_retriever.is_available():
|
|
2197
|
+
if not bm25_hint_printed:
|
|
2198
|
+
print(
|
|
2199
|
+
"note: BM25 backend requested but `rank-bm25` is not installed; falling back to lightRAG "
|
|
2200
|
+
"(install with `pip install rank-bm25`).",
|
|
2201
|
+
file=sys.stderr,
|
|
2202
|
+
)
|
|
2203
|
+
bm25_hint_printed = True
|
|
2204
|
+
rag_msg, rag_debug = light_retriever.build_retrieved_excerpts_message(
|
|
2205
|
+
question=rag_question,
|
|
2206
|
+
sources=ws.sources,
|
|
2207
|
+
config=rag_cfg,
|
|
2208
|
+
)
|
|
2209
|
+
else:
|
|
2210
|
+
bm25_cfg = _bm25_rag_config_from_ws(ws)
|
|
2211
|
+
rag_msg, rag_debug = bm25_retriever.build_retrieved_excerpts_message(
|
|
2212
|
+
question=rag_question,
|
|
2213
|
+
sources=ws.sources,
|
|
2214
|
+
config=bm25_cfg,
|
|
2215
|
+
)
|
|
2216
|
+
|
|
2217
|
+
if rag_debug:
|
|
2218
|
+
for dbg in rag_debug:
|
|
2219
|
+
print(dbg, file=sys.stderr)
|
|
2220
|
+
if rag_msg:
|
|
2221
|
+
qa_messages.append({"role": "user", "content": rag_msg})
|
|
2222
|
+
|
|
2223
|
+
qa_messages.append({"role": "user", "content": line})
|
|
2224
|
+
|
|
2225
|
+
# Track the last substantive user question so follow-up questions can reuse it for retrieval.
|
|
2226
|
+
# We intentionally do this after appending the message so the chat itself remains unchanged.
|
|
2227
|
+
cur_terms = tokenize_query_terms(line, max_terms=rag_cfg.max_terms)
|
|
2228
|
+
if cur_terms and not _should_augment_rag_query_with_prev_question(
|
|
2229
|
+
question=line, prev_question=last_substantive_user_question
|
|
2230
|
+
):
|
|
2231
|
+
last_substantive_user_question = line
|
|
2232
|
+
|
|
2233
|
+
try:
|
|
2234
|
+
stats = _stream_request(
|
|
2235
|
+
client=client,
|
|
2236
|
+
session_id=ws.session_id,
|
|
2237
|
+
messages=qa_messages,
|
|
2238
|
+
max_completion_tokens=32768,
|
|
2239
|
+
think_budget=think_budget,
|
|
2240
|
+
temperature=temperature,
|
|
2241
|
+
top_p=top_p,
|
|
2242
|
+
print_output=True,
|
|
2243
|
+
)
|
|
2244
|
+
last_stats = stats
|
|
2245
|
+
footer = _stats_footer(stats)
|
|
2246
|
+
if footer:
|
|
2247
|
+
print(footer)
|
|
2248
|
+
except KeyboardInterrupt:
|
|
2249
|
+
print("\n(cancelled)")
|
|
2250
|
+
continue
|
|
2251
|
+
except DocsReplError as exc:
|
|
2252
|
+
msg = str(exc)
|
|
2253
|
+
_maybe_print_prompt_too_long_hint(msg=msg, requested_max_seq_len=max_seq_len)
|
|
2254
|
+
print(msg, file=sys.stderr)
|
|
2255
|
+
continue
|
|
2256
|
+
except HttpError as exc:
|
|
2257
|
+
if exc.status_code == 429:
|
|
2258
|
+
print("Server is busy (429). Try again.", file=sys.stderr)
|
|
2259
|
+
else:
|
|
2260
|
+
print(str(exc), file=sys.stderr)
|
|
2261
|
+
continue
|
|
2262
|
+
except Exception as exc: # pragma: no cover
|
|
2263
|
+
print(f"error: {exc}", file=sys.stderr)
|
|
2264
|
+
continue
|
|
2265
|
+
except (DocsReplError, HttpError) as exc:
|
|
2266
|
+
print(str(exc), file=sys.stderr)
|
|
2267
|
+
return 1
|
|
2268
|
+
except Exception as exc: # pragma: no cover
|
|
2269
|
+
print(f"error: {exc}", file=sys.stderr)
|
|
2270
|
+
return 1
|
|
2271
|
+
finally:
|
|
2272
|
+
try:
|
|
2273
|
+
lock.release()
|
|
2274
|
+
except Exception:
|
|
2275
|
+
pass
|