continualcode 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- continualcode/__init__.py +17 -0
- continualcode/cli.py +67 -0
- continualcode/session.py +357 -0
- continualcode/tools.py +230 -0
- continualcode/tui.py +994 -0
- continualcode-0.1.0.dist-info/METADATA +115 -0
- continualcode-0.1.0.dist-info/RECORD +10 -0
- continualcode-0.1.0.dist-info/WHEEL +4 -0
- continualcode-0.1.0.dist-info/entry_points.txt +2 -0
- continualcode-0.1.0.dist-info/licenses/LICENSE +21 -0
continualcode/tui.py
ADDED
|
@@ -0,0 +1,994 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Minimal TUI for a coding agent with online context (prompt) distillation.
|
|
4
|
+
|
|
5
|
+
Generation uses a teacher prefix that includes a long policy prompt (POLICY_PATH).
|
|
6
|
+
Training distills that policy into weights by running cross-entropy on the approved
|
|
7
|
+
assistant message but with the policy removed (student prefix).
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import asyncio
|
|
13
|
+
import json
|
|
14
|
+
import os
|
|
15
|
+
import shlex
|
|
16
|
+
import subprocess
|
|
17
|
+
import tempfile
|
|
18
|
+
import time
|
|
19
|
+
from pathlib import Path
|
|
20
|
+
from typing import Any, Literal
|
|
21
|
+
|
|
22
|
+
from textual import events, on, work
|
|
23
|
+
from textual.app import App, ComposeResult
|
|
24
|
+
from textual.binding import Binding
|
|
25
|
+
from textual.containers import Horizontal, Vertical, VerticalScroll
|
|
26
|
+
from textual.widgets import Input, Label
|
|
27
|
+
|
|
28
|
+
try:
|
|
29
|
+
from continualcode.session import ContextDistillSession
|
|
30
|
+
TINKER_AVAILABLE = True
|
|
31
|
+
except Exception:
|
|
32
|
+
ContextDistillSession = None # type: ignore[assignment]
|
|
33
|
+
TINKER_AVAILABLE = False
|
|
34
|
+
|
|
35
|
+
from continualcode.tools import READONLY_TOOLS, execute_tool
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# --- Config ---
# Every value below is read from the process environment once, at import time.

# Sampling
MODEL = os.getenv("MODEL", "Qwen/Qwen3-4B-Instruct-2507")
TINKER_URL = os.getenv("TINKER_URL")
MAX_TOKENS = int(os.getenv("MAX_TOKENS", "4096"))
TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7"))

# Training
ENABLE_TRAINING = os.getenv("ENABLE_TRAINING", "1") == "1"
LEARNING_RATE = float(os.getenv("LEARNING_RATE", "1e-5"))
LORA_RANK = int(os.getenv("LORA_RANK", "32"))

# Approval behaviour
AUTO_APPROVE_READONLY = os.getenv("AUTO_APPROVE_READONLY", "0") == "1"

# Distillation
POLICY_PATH = os.getenv("POLICY_PATH", "./policy_memory.md")
TRAIN_SCALE_APPROVE = float(os.getenv("TRAIN_SCALE_APPROVE", "1.0"))
TRAIN_SCALE_EDITED = float(os.getenv("TRAIN_SCALE_EDITED", "2.0"))
DISTILL_MODE = os.getenv("DISTILL_MODE", "on_policy")  # "on_policy" or "off_policy"
TRAIN_MAX_TOKENS = int(os.getenv("TRAIN_MAX_TOKENS", "1024"))
TRAIN_TEMPERATURE = float(os.getenv("TRAIN_TEMPERATURE", "1.0"))
KL_COEF = float(os.getenv("KL_COEF", "1.0"))

# Directory used for manually triggered checkpoint saves.
CHECKPOINT_DIR = Path(os.getenv("CHECKPOINT_DIR", "./checkpoints"))
# Optional checkpoint to load at startup; empty string means "none".
LOAD_CHECKPOINT = os.getenv("LOAD_CHECKPOINT", "")

# Only file-modifying tools produce training examples.
TRAIN_TOOL_NAMES = {"edit", "write"}
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
# --- Permission Memory (from tinkercode_tui.py) ---
|
|
66
|
+
|
|
67
|
+
def _resolve_path(project_root: Path, raw_path: str) -> Path:
|
|
68
|
+
p = Path(raw_path).expanduser()
|
|
69
|
+
if not p.is_absolute():
|
|
70
|
+
p = project_root / p
|
|
71
|
+
return p.resolve()
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def _scope_for_path(project_root: Path, raw_path: str) -> tuple[Path, str]:
    """Choose the directory scope, and a display label, for a remembered permission.

    Returns ``(prefix, label)``:

    * path outside ``project_root``: its parent directory, labelled with that
      directory's full path;
    * path at least one directory deep inside the project: the top-level
      project subdirectory, labelled ``"<subdir>/"``;
    * anything else inside the project: the project root, labelled
      ``"<root name>/"``.
    """
    target = _resolve_path(project_root, raw_path)
    try:
        parts = target.relative_to(project_root).parts
    except ValueError:
        outside_dir = target.parent
        return outside_dir, str(outside_dir)

    if len(parts) <= 1:
        return project_root, f"{project_root.name}/"

    top_level = parts[0]
    return project_root / top_level, f"{top_level}/"
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class PermissionMemory:
    """Per-session record of tool calls the user chose not to be asked about again."""

    def __init__(self, project_root: Path) -> None:
        self.project_root = project_root.resolve()
        # Tools approved unconditionally, whatever their arguments.
        self.allowed_tools: set[str] = set()
        # Per tool: directories under which a call is approved.
        self.allowed_prefixes: dict[str, set[Path]] = {}

    def allows(self, tool_name: str, tool_args: dict[str, Any]) -> bool:
        """Return True when this call is already covered by a remembered approval."""
        if tool_name in self.allowed_tools:
            return True

        raw_path = tool_args.get("path")
        if not isinstance(raw_path, str):
            return False

        target = _resolve_path(self.project_root, raw_path)
        scopes = self.allowed_prefixes.get(tool_name, set())
        # A path is covered when it is an approved directory or lies beneath one.
        return any(scope == target or scope in target.parents for scope in scopes)

    def remember(self, tool_name: str, tool_args: dict[str, Any]) -> None:
        """Record an approval so matching future calls skip the prompt.

        ``bash`` and calls without a string ``path`` are approved tool-wide.
        Path-based calls are approved for the scope chosen by
        ``_scope_for_path``; ``write`` and ``edit`` always share that scope.
        """
        raw_path = tool_args.get("path")
        if tool_name == "bash" or not isinstance(raw_path, str):
            self.allowed_tools.add(tool_name)
            return

        scope, _label = _scope_for_path(self.project_root, raw_path)
        covered = ("write", "edit") if tool_name in ("write", "edit") else (tool_name,)
        for name in covered:
            self.allowed_prefixes.setdefault(name, set()).add(scope)

    def remember_text(self, tool_name: str, tool_args: dict[str, Any]) -> str:
        """Return the wording of the "approve and remember" option for this call."""
        if tool_name == "bash":
            return "Yes, and don't ask again for Bash commands"

        raw_path = tool_args.get("path")
        if isinstance(raw_path, str) and raw_path:
            label = _scope_for_path(self.project_root, raw_path)[1]
        else:
            label = f"{self.project_root.name}/"

        lead_by_tool = {
            "write": "Yes, allow all edits in ",
            "edit": "Yes, allow all edits in ",
            "read": "Yes, allow reading from ",
            "glob": "Yes, allow searching in ",
            "grep": "Yes, allow searching in ",
        }
        lead = lead_by_tool.get(tool_name)
        if lead is None:
            return "Yes, and don't ask again"
        return f"{lead}{label}"
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
class TinkerCodeApp(App):
|
|
150
|
+
"""Claude Code-style TUI with online prompt distillation."""
|
|
151
|
+
|
|
152
|
+
# Textual enables a command palette (Ctrl+P) by default. In a coding-agent TUI
|
|
153
|
+
# this is easy to trigger accidentally and looks like "random input". Disable it.
|
|
154
|
+
ENABLE_COMMAND_PALETTE = False
|
|
155
|
+
|
|
156
|
+
DEFAULT_CSS = """
|
|
157
|
+
Screen {
|
|
158
|
+
background: #1e1e1e;
|
|
159
|
+
}
|
|
160
|
+
|
|
161
|
+
/* Main output area */
|
|
162
|
+
#output {
|
|
163
|
+
height: 1fr;
|
|
164
|
+
padding: 0;
|
|
165
|
+
scrollbar-size: 1 1;
|
|
166
|
+
scrollbar-color: #333333;
|
|
167
|
+
background: #1e1e1e;
|
|
168
|
+
}
|
|
169
|
+
|
|
170
|
+
#spacer {
|
|
171
|
+
height: 1fr;
|
|
172
|
+
}
|
|
173
|
+
|
|
174
|
+
/* Message styling - all messages full width */
|
|
175
|
+
.msg {
|
|
176
|
+
height: auto;
|
|
177
|
+
width: 100%;
|
|
178
|
+
padding: 0 2;
|
|
179
|
+
}
|
|
180
|
+
|
|
181
|
+
/* User messages - dark background bar, full width */
|
|
182
|
+
.msg-user {
|
|
183
|
+
color: #ffffff;
|
|
184
|
+
background: #2a2a2a;
|
|
185
|
+
padding: 1 2;
|
|
186
|
+
text-style: bold;
|
|
187
|
+
}
|
|
188
|
+
|
|
189
|
+
/* Assistant messages */
|
|
190
|
+
.msg-assistant {
|
|
191
|
+
color: #cccccc;
|
|
192
|
+
padding: 1 2;
|
|
193
|
+
}
|
|
194
|
+
|
|
195
|
+
/* Tool results - indented */
|
|
196
|
+
.msg-tool {
|
|
197
|
+
color: #888888;
|
|
198
|
+
padding: 0 2 0 4;
|
|
199
|
+
}
|
|
200
|
+
|
|
201
|
+
/* System messages - subtle */
|
|
202
|
+
.msg-system {
|
|
203
|
+
color: #666666;
|
|
204
|
+
padding: 0 2;
|
|
205
|
+
}
|
|
206
|
+
|
|
207
|
+
/* Thinking indicator - italic like Claude Code */
|
|
208
|
+
.msg-thinking {
|
|
209
|
+
color: #666666;
|
|
210
|
+
text-style: italic;
|
|
211
|
+
padding: 1 2;
|
|
212
|
+
}
|
|
213
|
+
|
|
214
|
+
/* Correction feedback */
|
|
215
|
+
.msg-correction {
|
|
216
|
+
color: #dcdcaa;
|
|
217
|
+
padding: 0 2 0 4;
|
|
218
|
+
}
|
|
219
|
+
|
|
220
|
+
/* Status bar container - very bottom */
|
|
221
|
+
#status-bar {
|
|
222
|
+
height: 1;
|
|
223
|
+
dock: bottom;
|
|
224
|
+
background: #1e1e1e;
|
|
225
|
+
padding: 0 2;
|
|
226
|
+
}
|
|
227
|
+
|
|
228
|
+
#status-left {
|
|
229
|
+
width: auto;
|
|
230
|
+
color: #555555;
|
|
231
|
+
}
|
|
232
|
+
|
|
233
|
+
#status-right {
|
|
234
|
+
width: 1fr;
|
|
235
|
+
text-align: right;
|
|
236
|
+
color: #555555;
|
|
237
|
+
}
|
|
238
|
+
|
|
239
|
+
/* Bottom area - prompt OR approval */
|
|
240
|
+
#bottom-area {
|
|
241
|
+
dock: bottom;
|
|
242
|
+
height: auto;
|
|
243
|
+
max-height: 70%;
|
|
244
|
+
background: #252526;
|
|
245
|
+
border-bottom: solid #3c3c3c;
|
|
246
|
+
}
|
|
247
|
+
|
|
248
|
+
/* Prompt input row */
|
|
249
|
+
#prompt-row {
|
|
250
|
+
height: auto;
|
|
251
|
+
width: 100%;
|
|
252
|
+
padding: 1 2 1 2;
|
|
253
|
+
background: #252526;
|
|
254
|
+
}
|
|
255
|
+
|
|
256
|
+
#prompt-caret {
|
|
257
|
+
width: 2;
|
|
258
|
+
color: #cccccc;
|
|
259
|
+
height: 1;
|
|
260
|
+
}
|
|
261
|
+
|
|
262
|
+
#prompt-input {
|
|
263
|
+
width: 1fr;
|
|
264
|
+
border: none;
|
|
265
|
+
background: transparent;
|
|
266
|
+
padding: 0;
|
|
267
|
+
margin: 0;
|
|
268
|
+
height: 1;
|
|
269
|
+
color: #cccccc;
|
|
270
|
+
}
|
|
271
|
+
|
|
272
|
+
#prompt-input:focus {
|
|
273
|
+
border: none;
|
|
274
|
+
color: #ffffff;
|
|
275
|
+
}
|
|
276
|
+
|
|
277
|
+
/* Approval widget */
|
|
278
|
+
#approval {
|
|
279
|
+
display: none;
|
|
280
|
+
height: auto;
|
|
281
|
+
padding: 1 2;
|
|
282
|
+
background: #1e1e1e;
|
|
283
|
+
}
|
|
284
|
+
|
|
285
|
+
#approval.-visible {
|
|
286
|
+
display: block;
|
|
287
|
+
}
|
|
288
|
+
|
|
289
|
+
.approval-kicker {
|
|
290
|
+
color: #569cd6;
|
|
291
|
+
text-style: bold;
|
|
292
|
+
}
|
|
293
|
+
|
|
294
|
+
.approval-title {
|
|
295
|
+
color: #ffffff;
|
|
296
|
+
text-style: bold;
|
|
297
|
+
}
|
|
298
|
+
|
|
299
|
+
.approval-desc {
|
|
300
|
+
color: #9cdcfe;
|
|
301
|
+
padding: 0 0 0 2;
|
|
302
|
+
}
|
|
303
|
+
|
|
304
|
+
#approval-details {
|
|
305
|
+
height: auto;
|
|
306
|
+
max-height: 18;
|
|
307
|
+
overflow-y: auto;
|
|
308
|
+
padding: 1 0 1 2;
|
|
309
|
+
}
|
|
310
|
+
|
|
311
|
+
.diff-add {
|
|
312
|
+
color: #4ec994;
|
|
313
|
+
background: #1e3a2f;
|
|
314
|
+
padding: 0 1;
|
|
315
|
+
}
|
|
316
|
+
|
|
317
|
+
.diff-del {
|
|
318
|
+
color: #f97583;
|
|
319
|
+
background: #3d1f23;
|
|
320
|
+
padding: 0 1;
|
|
321
|
+
}
|
|
322
|
+
|
|
323
|
+
.diff-ctx {
|
|
324
|
+
color: #666666;
|
|
325
|
+
}
|
|
326
|
+
|
|
327
|
+
.approval-question {
|
|
328
|
+
color: #ffffff;
|
|
329
|
+
padding: 1 0;
|
|
330
|
+
}
|
|
331
|
+
|
|
332
|
+
.approval-option {
|
|
333
|
+
height: auto;
|
|
334
|
+
color: #666666;
|
|
335
|
+
}
|
|
336
|
+
|
|
337
|
+
.approval-option.-selected {
|
|
338
|
+
color: #ffffff;
|
|
339
|
+
}
|
|
340
|
+
|
|
341
|
+
.approval-hint {
|
|
342
|
+
color: #444444;
|
|
343
|
+
padding: 1 0 0 0;
|
|
344
|
+
}
|
|
345
|
+
|
|
346
|
+
#correction-input {
|
|
347
|
+
display: none;
|
|
348
|
+
margin: 1 0;
|
|
349
|
+
border: solid #333333;
|
|
350
|
+
background: #252526;
|
|
351
|
+
}
|
|
352
|
+
|
|
353
|
+
#correction-input.-visible {
|
|
354
|
+
display: block;
|
|
355
|
+
}
|
|
356
|
+
"""
|
|
357
|
+
|
|
358
|
+
BINDINGS = [
|
|
359
|
+
Binding("ctrl+c", "quit", "Quit"),
|
|
360
|
+
Binding("ctrl+l", "clear", "Clear"),
|
|
361
|
+
Binding("ctrl+s", "save_checkpoint", "Save"),
|
|
362
|
+
Binding("?", "help", "Help"),
|
|
363
|
+
]
|
|
364
|
+
|
|
365
|
+
def __init__(self) -> None:
    """Create the app with no session yet and an idle approval dialog."""
    super().__init__()

    # Backend session; on_mount assigns it.
    self.session: ContextDistillSession | None = None
    self.session_ready = False
    # Remembered approvals are scoped to the current working directory.
    self.permissions = PermissionMemory(Path.cwd())
    # True until on_mount has finished (successfully or not).
    self._loading = True

    # Numbers shown in the status bar.
    self.total_tokens = 0
    self.last_loss = 0.0

    # Approval dialog state; the future is non-None only while a request is pending.
    self._approval_future: asyncio.Future[dict[str, Any]] | None = None
    self._approval_selection = 0
    self._approval_options: list[tuple[str, str]] = []
    self._show_correction = False
    self._current_tool_name: str = ""
    self._current_tool_args: dict[str, Any] = {}
    # Args as rewritten by the user in $EDITOR, or None when untouched.
    self._edited_args: dict[str, Any] | None = None
|
|
384
|
+
|
|
385
|
+
def _prompt_input(self) -> Input:
    """Look up the main prompt ``Input`` widget (id ``prompt-input``)."""
    widget = self.query_one("#prompt-input", Input)
    return widget
|
|
387
|
+
|
|
388
|
+
def _prompt_row_visible(self) -> bool:
|
|
389
|
+
return bool(self.query_one("#prompt-row").display)
|
|
390
|
+
|
|
391
|
+
def _sync_focus(self) -> None:
    """Put keyboard focus where the current UI mode expects it.

    No approval pending: focus the prompt, unless it is disabled.
    Approval pending: focus the correction field when it is showing,
    otherwise clear focus.
    """
    if self._approval_future is None:
        prompt = self.query_one("#prompt-input", Input)
        if not prompt.disabled:
            prompt.focus()
        return

    if self._show_correction:
        self.query_one("#correction-input", Input).focus()
    else:
        self.set_focus(None)
|
|
403
|
+
|
|
404
|
+
def _show_correction_input(self) -> None:
    """Reveal, enable and focus the free-text correction field."""
    field = self.query_one("#correction-input", Input)
    self._show_correction = True
    # Enable before showing and focusing; the field is kept disabled while hidden.
    field.disabled = False
    field.add_class("-visible")
    field.focus()
|
|
410
|
+
|
|
411
|
+
def _hide_correction_input(self) -> None:
    """Hide and disable the correction field, then clear keyboard focus."""
    field = self.query_one("#correction-input", Input)
    self._show_correction = False
    field.remove_class("-visible")
    # Disabled while hidden so it cannot take focus.
    field.disabled = True
    self.set_focus(None)
|
|
417
|
+
|
|
418
|
+
def compose(self) -> ComposeResult:
    """Build the widget tree: output log, prompt/approval area, status bar.

    Labels whose text is filled in later from tool arguments (file paths,
    shell commands, remembered-scope labels) are created with
    ``markup=False`` so that bracketed text such as ``[a-z]`` is displayed
    literally instead of being parsed as console markup.

    NOTE(review): the nesting below was reconstructed from a listing with
    flattened indentation. The status bar is assumed to be a sibling of
    ``#bottom-area`` rather than a child of it; confirm against the
    original layout.
    """
    with VerticalScroll(id="output", can_focus=False, can_focus_children=False):
        yield Label("", id="spacer")

    with Vertical(id="bottom-area"):
        with Horizontal(id="prompt-row"):
            yield Label("> ", id="prompt-caret")
            yield Input(id="prompt-input")

        with Vertical(id="approval"):
            yield Label("Tool use", classes="approval-kicker")
            yield Label("", id="approval-title", classes="approval-title", markup=False)
            yield Label("", id="approval-desc", classes="approval-desc", markup=False)
            # Filled with diff/preview labels by the approval refresh code.
            with Vertical(id="approval-details"):
                pass
            yield Label("", id="approval-question", classes="approval-question", markup=False)
            yield Label("", id="opt-0", classes="approval-option", markup=False)
            yield Label("", id="opt-1", classes="approval-option", markup=False)
            yield Label("", id="opt-2", classes="approval-option", markup=False)
            yield Label("", id="opt-3", classes="approval-option", markup=False)
            yield Input(placeholder="Type what to change…", id="correction-input")
            yield Label("Esc cancel · e edit · Tab correct", classes="approval-hint")

    with Horizontal(id="status-bar"):
        yield Label("? for shortcuts", id="status-left")
        yield Label("", id="status-right")
|
|
441
|
+
|
|
442
|
+
async def on_mount(self) -> None:
    """Prepare the widgets, then bring up the Tinker session.

    The prompt is hidden and disabled up front and is only shown again once
    the session has initialised. When Tinker is unavailable, or start-up
    raises, a system message is printed and the prompt stays disabled.
    """
    # The hidden correction field must not be able to take focus.
    self.query_one("#correction-input", Input).disabled = True

    # Start-up can be slow; keep the prompt out of the way so stray
    # keystrokes are not echoed while we wait.
    prompt_row = self.query_one("#prompt-row")
    prompt_row.display = False
    self._prompt_input().disabled = True

    self.call_after_refresh(self._sync_focus)
    self._update_status()

    if not TINKER_AVAILABLE:
        self.output("Tinker not available.", style="system")
        self._prompt_input().disabled = True
        self._loading = False
        return

    init_msg = self.output("Initializing Tinker session...", style="thinking")
    try:
        session = ContextDistillSession(
            model=MODEL,
            tinker_url=TINKER_URL,
            max_tokens=MAX_TOKENS,
            temperature=TEMPERATURE,
            enable_training=ENABLE_TRAINING,
            lora_rank=LORA_RANK,
            learning_rate=LEARNING_RATE,
            policy_path=POLICY_PATH,
            distill_mode=DISTILL_MODE,
            train_max_tokens=TRAIN_MAX_TOKENS,
            train_temperature=TRAIN_TEMPERATURE,
            kl_coef=KL_COEF,
        )
        self.session = session
        await session.init()

        # An optional startup checkpoint is only loaded when a training client exists.
        if LOAD_CHECKPOINT and session.training_client is not None:
            loaded = session.load_checkpoint(LOAD_CHECKPOINT)
            outcome = "Loaded checkpoint" if loaded else "Failed to load checkpoint"
            self.output(f"{outcome}: {LOAD_CHECKPOINT}", style="system")

        self.session_ready = True
        init_msg.remove()

        status = f"Ready. Model: {MODEL} | policy={POLICY_PATH} | mode={DISTILL_MODE}"
        if session.training_client is not None:
            status += " [distillation ON]"
        self.output(status, style="system")

        # Ready: hand the prompt back to the user.
        self._loading = False
        prompt_row.display = True
        self._prompt_input().disabled = False
        self.call_after_refresh(self._sync_focus)
    except Exception as e:
        init_msg.remove()
        self.output(
            f"Failed to initialize Tinker: {e}\n"
            "Set TINKER_API_KEY (or pass api_key to the client) and restart.",
            style="system",
        )
        # Leave the prompt disabled: it would accept input that nothing can run.
        self._prompt_input().disabled = True
        self._loading = False
|
|
506
|
+
|
|
507
|
+
async def on_key(self, event: events.Key) -> None:
    """Drop single printable key presses while the app is still loading.

    Keys whose name is longer than one character are passed on to the base
    handler, so control bindings keep working during start-up.

    NOTE(review): this class defines a second ``on_key`` further down. The
    later definition replaces this one in the class namespace, so this
    handler does not run as the file stands.
    """
    pressed = event.key
    swallow = self._loading and len(pressed) == 1 and pressed.isprintable()
    if not swallow:
        # Everything else gets the default behaviour.
        await super().on_key(event)
        return
    event.stop()
|
|
517
|
+
|
|
518
|
+
def action_help(self) -> None:
    """Print the keyboard shortcut summary as a system message."""
    shortcuts = (
        "Keys: Enter=send · Ctrl+L=clear · Ctrl+S=save · Ctrl+C=quit · "
        "Approval: 1/y approve · 2 approve+remember · 3/n/Esc deny · 4/Tab correct · e edit args"
    )
    self.output(shortcuts, style="system")
|
|
524
|
+
|
|
525
|
+
def output(self, text: str, style: str = "assistant", *, markup: bool = False) -> Label:
    """Append one message to the output pane and return its label.

    Args:
        text: Message body. It is shown literally by default, because it
            is often arbitrary model output, tool output or an exception
            message that may contain square brackets.
        style: Message kind; selects the ``msg-<style>`` CSS class and the
            leading glyph (``user``, ``assistant``, ``tool``,
            ``correction``, ``thinking``; other kinds get no glyph).
        markup: Pass True only for trusted text that intentionally uses
            console markup.

    Returns:
        The mounted ``Label``, so callers can remove it later.
    """
    out = self.query_one("#output", VerticalScroll)
    glyphs = {
        "user": "> ",
        "assistant": "● ",
        "tool": "⎿ ",
        "correction": "⎿ ",
        "thinking": "● ",
    }
    prefix = glyphs.get(style, "")
    label = Label(f"{prefix}{text}", classes=f"msg msg-{style}", markup=markup)
    out.mount(label)
    out.scroll_end()
    self._update_status()
    # Mounting can disturb focus; restore it once the refresh has happened.
    self.call_after_refresh(self._sync_focus)
    return label
|
|
540
|
+
|
|
541
|
+
def _show_approval(self, tool_name: str, tool_args: dict[str, Any], remember_text: str) -> None:
    """Swap the prompt row for the inline approval dialog for this tool call."""
    # Remember what is being approved; a copy keeps the caller's dict untouched.
    self._current_tool_name = tool_name
    self._current_tool_args = dict(tool_args)
    self._edited_args = None

    # Hide and disable the prompt while the dialog is up.
    self.query_one("#prompt-row").display = False
    self.query_one("#prompt-input", Input).disabled = True

    dialog = self.query_one("#approval")
    dialog.add_class("-visible")
    self.set_focus(None)

    self._refresh_approval_ui(tool_name, tool_args, remember_text)
    self.call_after_refresh(self._sync_focus)
|
|
555
|
+
|
|
556
|
+
def _refresh_approval_ui(self, tool_name: str, tool_args: dict[str, Any], remember_text: str) -> None:
    """Redraw the approval dialog for ``tool_name`` with ``tool_args``.

    Updates the title, description, preview details, question and the four
    options, resets the selection to the first option and hides the
    correction field.

    Previews show at most 18 lines per section. When more exist, a
    "... (N more lines)" notice is added so the user knows the preview is
    incomplete; this applies to both sides of an edit as well as to a
    write. Preview labels are created with ``markup=False`` so source code
    containing square brackets is displayed literally.

    Args:
        tool_name: Name of the tool awaiting approval.
        tool_args: Arguments to display (original or user-edited).
        remember_text: Wording for the "approve and remember" option.
    """
    file_path = str(tool_args.get("path", ""))
    edited_suffix = " (edited)" if self._edited_args is not None else ""

    # Title
    titles = {
        "edit": f"Edit file {file_path}",
        "write": f"Create file {file_path}",
        "read": f"Read file {file_path}",
        "bash": "Bash command",
    }
    title = f"{titles.get(tool_name, tool_name)}{edited_suffix}"
    self.query_one("#approval-title", Label).update(title)

    # Description: only bash shows its command here.
    description = str(tool_args.get("cmd", "")) if tool_name == "bash" else ""
    self.query_one("#approval-desc", Label).update(description)

    # Details
    details = self.query_one("#approval-details", Vertical)
    details.remove_children()
    max_lines = 18

    def add_detail(text: str, css_class: str) -> None:
        details.mount(Label(text, classes=css_class, markup=False))

    def add_numbered(lines: list[str], sign: str, css_class: str) -> None:
        for number, line in enumerate(lines[:max_lines], 1):
            add_detail(f"{number:3} {sign} {line}", css_class)
        hidden = len(lines) - max_lines
        if hidden > 0:
            add_detail(f" ... ({hidden} more lines)", "diff-ctx")

    if tool_name == "edit":
        add_numbered(str(tool_args.get("old", "")).split("\n"), "-", "diff-del")
        add_numbered(str(tool_args.get("new", "")).split("\n"), "+", "diff-add")
    elif tool_name == "write":
        add_numbered(str(tool_args.get("content", "")).split("\n"), "+", "diff-add")
    elif tool_name == "bash":
        cmd = str(tool_args.get("cmd", ""))
        add_detail(f"$ {cmd}", "diff-ctx")
    elif tool_name in ("read", "glob", "grep"):
        path = tool_args.get("path", tool_args.get("pat", ""))
        add_detail(f"{path}", "diff-ctx")

    # Question
    name = Path(file_path).name if file_path else "this"
    questions = {
        "edit": f"Do you want to make this edit to {name}?",
        "write": f"Do you want to create {name}?",
        "read": f"Do you want to read {name}?",
        "bash": "Do you want to run this command?",
    }
    question = questions.get(tool_name, "Do you want to proceed?")
    self.query_one("#approval-question", Label).update(question)

    # Options (four entries, matching the opt-0..opt-3 labels)
    self._approval_options = [
        ("yes", "Yes"),
        ("yes_always", f"{remember_text} (shift+Tab)"),
        ("no", "No"),
        ("correct", "Type correction (Tab)"),
    ]
    self._approval_selection = 0
    self._sync_approval_options()

    # Reset the correction field to hidden, disabled and empty.
    correction_input = self.query_one("#correction-input", Input)
    self._show_correction = False
    correction_input.remove_class("-visible")
    correction_input.disabled = True
    correction_input.value = ""
|
|
637
|
+
|
|
638
|
+
def _sync_approval_options(self) -> None:
    """Rewrite the option labels so the selected one carries the ``>`` marker.

    The marker and its blank counterpart are the same width, so option
    text stays in the same column as the selection moves.
    """
    selected_marker = "> "
    blank_marker = " " * len(selected_marker)
    for index, (_decision, text) in enumerate(self._approval_options):
        is_selected = index == self._approval_selection
        label = self.query_one(f"#opt-{index}", Label)
        marker = selected_marker if is_selected else blank_marker
        label.update(f"{marker}{index + 1}. {text}")
        label.set_class(is_selected, "-selected")
|
|
645
|
+
|
|
646
|
+
def _hide_approval(self) -> None:
    """Take down the approval dialog and bring the prompt row back."""
    dialog = self.query_one("#approval")
    dialog.remove_class("-visible")

    prompt_row = self.query_one("#prompt-row")
    prompt_row.display = True
    self.query_one("#prompt-input", Input).disabled = False

    self._hide_correction_input()
    self.call_after_refresh(self._sync_focus)
|
|
653
|
+
|
|
654
|
+
def _resolve_approval(self, decision: str, correction: str | None = None) -> None:
|
|
655
|
+
"""Complete the approval with a decision."""
|
|
656
|
+
if self._approval_future is None:
|
|
657
|
+
return
|
|
658
|
+
|
|
659
|
+
result = {
|
|
660
|
+
"decision": decision,
|
|
661
|
+
"additional_instructions": None,
|
|
662
|
+
"edited_args": self._edited_args,
|
|
663
|
+
"correction": correction,
|
|
664
|
+
}
|
|
665
|
+
|
|
666
|
+
self._hide_approval()
|
|
667
|
+
self._approval_future.set_result(result)
|
|
668
|
+
self._approval_future = None
|
|
669
|
+
|
|
670
|
+
def _edit_args_in_editor(self) -> None:
    """Let the user rewrite the pending tool args as JSON in ``$EDITOR``.

    The current args (the previously edited ones, if any) are written to a
    temporary ``.json`` file, the app is suspended while the editor runs,
    and the file is read back afterwards.

    * Changed JSON object: stored as the edited args and the dialog is redrawn.
    * Unchanged content: ignored, so the call is not marked "(edited)".
    * Valid JSON that is not an object: reported and ignored.
    * Any failure (editor missing, invalid JSON, ...): reported as
      "Edit failed: ...".

    The temporary file is always removed.
    """
    editor_raw = os.environ.get("EDITOR") or "vim"
    # A whitespace-only $EDITOR splits to an empty command; fall back to vim.
    editor_cmd = shlex.split(editor_raw) or ["vim"]

    # Start from the edited args if the user has already edited once.
    args_to_edit = self._edited_args if self._edited_args is not None else self._current_tool_args

    tmp_path: str | None = None
    try:
        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False, encoding="utf-8") as f:
            json.dump(args_to_edit, f, indent=2, ensure_ascii=False)
            f.write("\n")
            tmp_path = f.name

        # Hand the terminal over to the editor for the duration of the edit.
        with self.suspend():
            subprocess.run([*editor_cmd, tmp_path], check=False)

        with open(tmp_path, "r", encoding="utf-8") as f:
            edited = json.load(f)

        if not isinstance(edited, dict):
            self.output("Edit ignored: tool args must be a JSON object", style="system")
        elif edited != args_to_edit:
            self._edited_args = edited
            self._current_tool_args = edited
            # Redraw the dialog so it shows the edited content.
            remember_text = self.permissions.remember_text(self._current_tool_name, edited)
            self._refresh_approval_ui(self._current_tool_name, edited, remember_text)
        self.call_after_refresh(self._sync_focus)

    except Exception as e:
        self.output(f"Edit failed: {e}", style="system")
    finally:
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                # Best-effort cleanup; a leftover temp file is harmless.
                pass
|
|
707
|
+
|
|
708
|
+
async def request_approval(self, tool_name: str, tool_args: dict[str, Any]) -> tuple[dict[str, Any], int]:
    """Ask the user whether a tool call may run.

    The call is approved without prompting when read-only auto-approval is
    enabled and the tool is read-only, or when a remembered permission
    covers it. Otherwise the approval dialog is shown and this coroutine
    waits for the user's decision.

    Args:
        tool_name: Name of the tool to run.
        tool_args: Arguments the tool would be called with.

    Returns:
        ``(result, latency_ms)``. ``result`` has the keys ``decision``,
        ``additional_instructions``, ``edited_args``, ``correction`` and
        ``auto_approved``. ``latency_ms`` is how long the user took to
        decide, or 0 for auto-approved calls.
    """
    auto_approved = (
        AUTO_APPROVE_READONLY and tool_name in READONLY_TOOLS
    ) or self.permissions.allows(tool_name, tool_args)
    if auto_approved:
        return {
            "decision": "yes",
            "additional_instructions": None,
            "edited_args": None,
            "correction": None,
            "auto_approved": True,
        }, 0

    remember_text = self.permissions.remember_text(tool_name, tool_args)
    # Monotonic clock: the measured latency cannot be skewed by wall-clock changes.
    start = time.monotonic()

    # Create the future on the running loop, then show the dialog.
    self._approval_future = asyncio.get_running_loop().create_future()
    self._show_approval(tool_name, tool_args, remember_text)

    # Wait for the user's decision.
    result = await self._approval_future
    latency = int((time.monotonic() - start) * 1000)

    if result.get("decision") == "yes_always":
        # Remember the scope of what will actually run (edited args win).
        final_args = result.get("edited_args") or tool_args
        self.permissions.remember(tool_name, final_args)

    result["auto_approved"] = False
    return result, latency
|
|
732
|
+
|
|
733
|
+
def on_key(self, event: events.Key) -> None:
    """Handle key presses for start-up, prompt focus recovery and the approval dialog.

    This is the only ``on_key`` that takes effect for this class: an
    earlier definition with the same name in the class body is replaced by
    this one. Its start-up guard is therefore applied here.

    * While loading, single printable key presses are swallowed.
    * With no approval pending, a printable key typed while the prompt has
      lost focus is redirected into the prompt.
    * With an approval pending:
      ``1``/``y`` approve, ``2``/``shift+tab`` approve and remember,
      ``3``/``n``/``escape`` deny, ``4``/``tab`` open the correction field,
      arrows move the selection, ``enter`` confirms it, ``e`` opens the
      args in ``$EDITOR``.

    NOTE(review): the source listing had flattened indentation. While the
    correction field is showing, this handler acts on ``escape`` only and
    ignores every other key, as the original comment describes; confirm
    that matches the intended nesting.
    """
    if self._approval_future is None:
        # Start-up guard: do not let impatient typing leak into the UI.
        if self._loading and len(event.key) == 1 and event.key.isprintable():
            event.stop()
            return

        # Recover if focus was lost (e.g. the user clicked the output pane),
        # so typing keeps going into the prompt.
        prompt_input = self.query_one("#prompt-input", Input)
        if (
            event.character
            and event.character.isprintable()
            and self.query_one("#prompt-row").display
            and not prompt_input.disabled
            and self.screen.focused is not prompt_input
        ):
            prompt_input.focus()
            prompt_input.insert_text_at_cursor(event.character)
            event.stop()
        return

    key = event.key

    # While the correction field is showing, only escape is handled here.
    if self._show_correction:
        if key == "escape":
            self._hide_correction_input()
            event.stop()
        return

    if key in ("1", "y"):
        self._resolve_approval("yes")
    elif key in ("2", "shift+tab"):
        self._resolve_approval("yes_always")
    elif key in ("3", "n", "escape"):
        self._resolve_approval("no")
    elif key in ("4", "tab"):
        # Option 4: type a correction.
        self._show_correction_input()
    elif key in ("up", "left"):
        self._approval_selection = max(0, self._approval_selection - 1)
        self._sync_approval_options()
    elif key in ("down", "right"):
        last_index = len(self._approval_options) - 1
        self._approval_selection = min(last_index, self._approval_selection + 1)
        self._sync_approval_options()
    elif key == "enter":
        decision = self._approval_options[self._approval_selection][0]
        if decision == "correct":
            self._show_correction_input()
        else:
            self._resolve_approval(decision)
    elif key == "e":
        # Edit the args in $EDITOR.
        self._edit_args_in_editor()
    else:
        # Not an approval key: leave the event alone.
        return

    event.stop()
|
|
799
|
+
|
|
800
|
+
@on(Input.Submitted, "#correction-input")
def on_correction_submitted(self, event: Input.Submitted) -> None:
    """Resolve the pending approval as a correction when Enter is pressed.

    The submitted text is stripped; a blank submission is forwarded as
    ``None`` so the approval carries no correction text.
    """
    # Read the text before hiding the input widget.
    feedback = event.input.value.strip() or None
    self._hide_correction_input()
    # A correction resolves the approval with decision "correct" plus feedback.
    self._resolve_approval("correct", correction=feedback)
    event.stop()
|
|
808
|
+
|
|
809
|
+
@on(Input.Submitted, "#prompt-input")
def on_submit(self, event: Input.Submitted) -> None:
    """Send a non-empty prompt to the agent loop.

    Whitespace-only submissions are ignored. Otherwise the input box is
    cleared, the prompt is echoed to the output, and the agent loop is
    started with it.
    """
    prompt = event.value.strip()
    if prompt:
        event.input.value = ""
        self.output(prompt, style="user")
        self._run_agent_loop(prompt)
|
|
817
|
+
|
|
818
|
+
@work(exclusive=True)
async def _run_agent_loop(self, user_input: str) -> None:
    """Drive the agent for one user message until it stops calling tools.

    Appends ``user_input`` to the session, then repeatedly samples an
    assistant message, shows its text, and asks for approval of every tool
    call via ``request_approval``. Approved calls run through
    ``execute_tool`` (with the user's edited arguments when supplied);
    denied or corrected calls get a JSON "denied" payload as their tool
    result. The loop ends when a reply carries no tool calls, when sampling
    raises, or when the reply is reported as not parsed.

    A manually approved, single tool call whose name is in
    ``TRAIN_TOOL_NAMES`` triggers ``session.train_on_approval`` when the
    session has a training client.

    Focus is re-synced after the next refresh however the loop exits.

    Args:
        user_input: The prompt text to add to the conversation.
    """
    session = self.session
    if session is None:
        return

    session.add_user(user_input)

    try:
        while True:
            thinking = self.output("Thinking...", style="thinking")
            # Snapshot the conversation as it stood before this assistant
            # turn; it is the prompt handed to train_on_approval below.
            prompt_messages = list(session.messages)

            try:
                message, ok = await session.sample()
            except Exception as e:
                thinking.remove()
                self.output(f"Error: {e}", style="system")
                return

            thinking.remove()

            if not ok:
                self.output("Parse failed", style="system")
                return

            content = message.get("content", "")
            if isinstance(content, list):
                # Multi-part content: show only the text parts, concatenated.
                text = "".join(p.get("text", "") for p in content if p.get("type") == "text")
            else:
                text = content
            if text and str(text).strip():
                self.output(str(text), style="assistant")

            session.add_assistant(message)

            tool_calls = message.get("tool_calls", [])
            if not tool_calls:
                return

            for tc in tool_calls:
                name = tc.function.name
                try:
                    model_args = json.loads(tc.function.arguments or "{}")
                except json.JSONDecodeError:
                    model_args = {}
                if not isinstance(model_args, dict):
                    # Valid JSON that is not an object (list, string, number,
                    # null) cannot serve as tool arguments; treat it like
                    # unparseable input.
                    model_args = {}

                approval, _latency = await self.request_approval(name, model_args)

                decision = approval.get("decision", "no")
                correction = approval.get("correction")
                approved = decision in ("yes", "yes_always")
                raw_edit = approval.get("edited_args")
                edited_args = raw_edit if isinstance(raw_edit, dict) else None
                # Single source of truth for "the user edited the arguments".
                # Tested with `is not None` so an edit down to {} still counts.
                was_edited = approved and edited_args is not None
                final_args = edited_args if was_edited else model_args

                # If the user edited args, mutate the message tool call so the
                # conversation (and any training) reflects the executed tool.
                if was_edited:
                    tc.function.arguments = json.dumps(final_args, ensure_ascii=False)

                if approved:
                    result = execute_tool(name, final_args)
                    lines = result.split("\n")
                    if name == "read":
                        self.output(f"Read {len(lines)} lines", style="tool")
                    else:
                        # str.split always yields at least one element.
                        preview = lines[0][:100]
                        if len(lines) > 1:
                            preview += f" ... +{len(lines) - 1} lines"
                        self.output(preview, style="tool")
                else:
                    # Denied or corrected: report the refusal (and any
                    # correction text) back to the model as the tool result.
                    deny_reason = correction or "User denied this tool call."
                    result = json.dumps(
                        {
                            "status": "denied",
                            "tool_name": name,
                            "tool_call_id": tc.id,
                            "tool_args": model_args,
                            "reason": deny_reason,
                            "correction": correction,
                        },
                        ensure_ascii=False,
                    )
                    if correction:
                        self.output(correction, style="correction")
                    else:
                        self.output("Denied", style="tool")

                session.add_tool_result(tc.id, result)

                auto = bool(approval.get("auto_approved", False))

                # Train on manually approved write/edit (single tool call only).
                if (
                    approved
                    and not auto
                    and len(tool_calls) == 1
                    and name in TRAIN_TOOL_NAMES
                    and session.training_client is not None
                ):
                    scale = TRAIN_SCALE_EDITED if was_edited else TRAIN_SCALE_APPROVE
                    metrics = session.train_on_approval(
                        prompt_messages=prompt_messages,
                        assistant_message=message,
                        scale=scale,
                    )
                    if metrics:
                        self.last_loss = metrics.get("loss", 0.0)
                        self.total_tokens += int(metrics.get("tokens", 0))
                        tag = "edit" if was_edited else "approve"
                        extra = ""
                        if "kl_student_teacher" in metrics:
                            extra = (
                                f" st={metrics.get('kl_student_teacher', 0.0):.3f}"
                                f" kl={metrics.get('approx_kl', 0.0):.4f}"
                                f" r={metrics.get('ratio', 1.0):.3f}"
                            )
                        self.output(
                            f"cd[{tag}] step {int(metrics.get('step', 0))} | "
                            f"loss={self.last_loss:.3f} tok={int(metrics.get('tokens', 0))} w={scale:g}{extra}",
                            style="system",
                        )
                        self._update_status()
    finally:
        self.call_after_refresh(self._sync_focus)
|
|
955
|
+
|
|
956
|
+
def _update_status(self) -> None:
    """Refresh the ``#status-right`` label with training progress.

    Shows step count, last loss and token total (in thousands) once the
    session has taken at least one training step; otherwise the label is
    blanked.
    """
    session = self.session
    status = ""
    if session and session.train_steps > 0:
        thousands = self.total_tokens / 1000
        status = f"Steps: {session.train_steps} | Loss: {self.last_loss:.2f} | Tokens: {thousands:.1f}k"
    self.query_one("#status-right", Label).update(status)
|
|
963
|
+
|
|
964
|
+
def action_clear(self) -> None:
    """Empty the output pane and reset the session conversation.

    The output container is emptied and given a fresh empty spacer label,
    the session (if any) is cleared, the status label is refreshed, and a
    "Cleared." notice is printed before focus is re-synced.
    """
    pane = self.query_one("#output", VerticalScroll)
    pane.remove_children()
    spacer = Label("", id="spacer")
    pane.mount(spacer)

    session = self.session
    if session:
        session.clear()

    self._update_status()
    self.output("Cleared.", style="system")
    self.call_after_refresh(self._sync_focus)
|
|
973
|
+
|
|
974
|
+
def action_save_checkpoint(self) -> None:
    """Save current LoRA weights to disk (Ctrl+S).

    Does nothing beyond printing a notice when there is no session or the
    session has no training client. Otherwise the checkpoint is written
    under ``CHECKPOINT_DIR`` in a path named after the current training
    step, and the outcome is reported in the output pane.
    """
    session = self.session
    if not session or session.training_client is None:
        self.output("No training session to save.", style="system")
        return

    # Make sure the target directory exists before asking the session to save.
    CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

    # Zero-padded step count keeps checkpoint names sortable.
    target = CHECKPOINT_DIR / f"step_{session.train_steps:06d}"

    saved = session.save_checkpoint(str(target))
    notice = f"Saved checkpoint: {target}" if saved else "Failed to save checkpoint."
    self.output(notice, style="system")
|
|
991
|
+
|
|
992
|
+
|
|
993
|
+
# Launch the TUI when this module is executed as a script.
if __name__ == "__main__":
    app = TinkerCodeApp()
    app.run()
|