continualcode 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
continualcode/tui.py ADDED
@@ -0,0 +1,994 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Minimal TUI for a coding agent with online context (prompt) distillation.
4
+
5
+ Generation uses a teacher prefix that includes a long policy prompt (POLICY_PATH).
6
+ Training distills that policy into weights by running cross-entropy on the approved
7
+ assistant message but with the policy removed (student prefix).
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import asyncio
13
+ import json
14
+ import os
15
+ import shlex
16
+ import subprocess
17
+ import tempfile
18
+ import time
19
+ from pathlib import Path
20
+ from typing import Any, Literal
21
+
22
+ from textual import events, on, work
23
+ from textual.app import App, ComposeResult
24
+ from textual.binding import Binding
25
+ from textual.containers import Horizontal, Vertical, VerticalScroll
26
+ from textual.widgets import Input, Label
27
+
28
+ try:
29
+ from continualcode.session import ContextDistillSession
30
+ TINKER_AVAILABLE = True
31
+ except Exception:
32
+ ContextDistillSession = None # type: ignore[assignment]
33
+ TINKER_AVAILABLE = False
34
+
35
+ from continualcode.tools import READONLY_TOOLS, execute_tool
36
+
37
+
38
# --- Config ---
# Every setting is read once, at import time, from the process environment.

# Sampling
MODEL = os.getenv("MODEL", "Qwen/Qwen3-4B-Instruct-2507")
TINKER_URL = os.getenv("TINKER_URL")  # None when the variable is unset
MAX_TOKENS = int(os.getenv("MAX_TOKENS", "4096"))
TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7"))

# Training; boolean flags are true only for the exact string "1"
ENABLE_TRAINING = os.getenv("ENABLE_TRAINING", "1") == "1"
LEARNING_RATE = float(os.getenv("LEARNING_RATE", "1e-5"))
LORA_RANK = int(os.getenv("LORA_RANK", "32"))
AUTO_APPROVE_READONLY = os.getenv("AUTO_APPROVE_READONLY", "0") == "1"

# Distillation
POLICY_PATH = os.getenv("POLICY_PATH", "./policy_memory.md")
TRAIN_SCALE_APPROVE = float(os.getenv("TRAIN_SCALE_APPROVE", "1.0"))
TRAIN_SCALE_EDITED = float(os.getenv("TRAIN_SCALE_EDITED", "2.0"))
DISTILL_MODE = os.getenv("DISTILL_MODE", "on_policy")  # "on_policy" or "off_policy"
TRAIN_MAX_TOKENS = int(os.getenv("TRAIN_MAX_TOKENS", "1024"))
TRAIN_TEMPERATURE = float(os.getenv("TRAIN_TEMPERATURE", "1.0"))
KL_COEF = float(os.getenv("KL_COEF", "1.0"))

# Checkpoint directory for manual saves
CHECKPOINT_DIR = Path(os.getenv("CHECKPOINT_DIR", "./checkpoints"))
# Checkpoint to load on startup (optional; empty string means none)
LOAD_CHECKPOINT = os.getenv("LOAD_CHECKPOINT", "")

# Only train on file modification tools
TRAIN_TOOL_NAMES = {"edit", "write"}
63
+
64
+
65
+ # --- Permission Memory (from tinkercode_tui.py) ---
66
+
67
+ def _resolve_path(project_root: Path, raw_path: str) -> Path:
68
+ p = Path(raw_path).expanduser()
69
+ if not p.is_absolute():
70
+ p = project_root / p
71
+ return p.resolve()
72
+
73
+
74
def _scope_for_path(project_root: Path, raw_path: str) -> tuple[Path, str]:
    """Pick the directory scope (and a display label) that covers ``raw_path``.

    Returns a ``(prefix, label)`` pair:

    * path outside ``project_root``: its parent directory, labelled with that
      directory's full path;
    * path nested at least one directory below ``project_root``: the first
      directory under the root, labelled ``"<dir>/"``;
    * anything else under the root: the root itself, labelled ``"<root name>/"``.
    """
    target = _resolve_path(project_root, raw_path)
    try:
        inside = target.relative_to(project_root)
    except ValueError:
        # relative_to() raises when target is not located under project_root.
        outside_dir = target.parent
        return outside_dir, str(outside_dir)

    if len(inside.parts) <= 1:
        return project_root, f"{project_root.name}/"

    top_level = inside.parts[0]
    return project_root / top_level, f"{top_level}/"
84
+
85
+
86
class PermissionMemory:
    """Session-scoped permission memory (similar to Claude Code 'don't ask again').

    Two kinds of grants are stored: tools allowed unconditionally
    (``allowed_tools``) and, per tool, directory prefixes under which any path
    is allowed (``allowed_prefixes``).
    """

    # Human-readable action used in the "remember" option, keyed by tool name.
    _ACTION_PHRASES = {
        "write": "all edits in",
        "edit": "all edits in",
        "read": "reading from",
        "glob": "searching in",
        "grep": "searching in",
    }

    def __init__(self, project_root: Path) -> None:
        self.project_root = project_root.resolve()
        self.allowed_tools: set[str] = set()
        self.allowed_prefixes: dict[str, set[Path]] = {}

    @staticmethod
    def _is_within(target: Path, directory: Path) -> bool:
        """Return True when ``target`` is ``directory`` or lies below it."""
        try:
            target.relative_to(directory)
        except ValueError:
            return False
        return True

    def allows(self, tool_name: str, tool_args: dict[str, Any]) -> bool:
        """Return True when a remembered grant covers this tool call.

        A blanket grant for ``tool_name`` always matches. Otherwise the call
        must carry a string ``"path"`` argument that resolves to a location
        under one of the prefixes remembered for the tool.
        """
        if tool_name in self.allowed_tools:
            return True

        candidate = tool_args.get("path")
        if not isinstance(candidate, str):
            return False

        target = _resolve_path(self.project_root, candidate)
        granted = self.allowed_prefixes.get(tool_name, set())
        return any(self._is_within(target, directory) for directory in granted)

    def remember(self, tool_name: str, tool_args: dict[str, Any]) -> None:
        """Record a "don't ask again" grant derived from this tool call.

        ``bash`` calls, and calls without a string ``"path"``, are remembered as
        a blanket grant for the tool. Calls with a path are remembered as a
        directory-prefix grant; ``write`` and ``edit`` always share the grant.
        """
        candidate = tool_args.get("path")
        if tool_name == "bash" or not isinstance(candidate, str):
            self.allowed_tools.add(tool_name)
            return

        scope, _label = _scope_for_path(self.project_root, candidate)
        paired = ("write", "edit")
        recipients = paired if tool_name in paired else (tool_name,)
        for recipient in recipients:
            self.allowed_prefixes.setdefault(recipient, set()).add(scope)

    def remember_text(self, tool_name: str, tool_args: dict[str, Any]) -> str:
        """Build the label for the "yes, and remember" approval option."""
        if tool_name == "bash":
            return "Yes, and don't ask again for Bash commands"

        candidate = tool_args.get("path")
        if isinstance(candidate, str) and candidate:
            _scope, label = _scope_for_path(self.project_root, candidate)
        else:
            # No usable path: describe the grant in terms of the project root.
            label = f"{self.project_root.name}/"

        phrase = self._ACTION_PHRASES.get(tool_name)
        if phrase is None:
            return "Yes, and don't ask again"
        return f"Yes, allow {phrase} {label}"
147
+
148
+
149
class TinkerCodeApp(App):
    """Claude Code-style TUI with online prompt distillation.

    The screen has a scrolling transcript, a one-line prompt, and an inline
    approval panel that replaces the prompt whenever the model requests a tool
    call. Approved, non-auto-approved ``edit``/``write`` calls are passed to
    ``session.train_on_approval`` when the session has a training client.
    """

    # Textual enables a command palette (Ctrl+P) by default. In a coding-agent TUI
    # this is easy to trigger accidentally and looks like "random input". Disable it.
    ENABLE_COMMAND_PALETTE = False

    DEFAULT_CSS = """
    Screen {
        background: #1e1e1e;
    }

    /* Main output area */
    #output {
        height: 1fr;
        padding: 0;
        scrollbar-size: 1 1;
        scrollbar-color: #333333;
        background: #1e1e1e;
    }

    #spacer {
        height: 1fr;
    }

    /* Message styling - all messages full width */
    .msg {
        height: auto;
        width: 100%;
        padding: 0 2;
    }

    /* User messages - dark background bar, full width */
    .msg-user {
        color: #ffffff;
        background: #2a2a2a;
        padding: 1 2;
        text-style: bold;
    }

    /* Assistant messages */
    .msg-assistant {
        color: #cccccc;
        padding: 1 2;
    }

    /* Tool results - indented */
    .msg-tool {
        color: #888888;
        padding: 0 2 0 4;
    }

    /* System messages - subtle */
    .msg-system {
        color: #666666;
        padding: 0 2;
    }

    /* Thinking indicator - italic like Claude Code */
    .msg-thinking {
        color: #666666;
        text-style: italic;
        padding: 1 2;
    }

    /* Correction feedback */
    .msg-correction {
        color: #dcdcaa;
        padding: 0 2 0 4;
    }

    /* Status bar container - very bottom */
    #status-bar {
        height: 1;
        dock: bottom;
        background: #1e1e1e;
        padding: 0 2;
    }

    #status-left {
        width: auto;
        color: #555555;
    }

    #status-right {
        width: 1fr;
        text-align: right;
        color: #555555;
    }

    /* Bottom area - prompt OR approval */
    #bottom-area {
        dock: bottom;
        height: auto;
        max-height: 70%;
        background: #252526;
        border-bottom: solid #3c3c3c;
    }

    /* Prompt input row */
    #prompt-row {
        height: auto;
        width: 100%;
        padding: 1 2 1 2;
        background: #252526;
    }

    #prompt-caret {
        width: 2;
        color: #cccccc;
        height: 1;
    }

    #prompt-input {
        width: 1fr;
        border: none;
        background: transparent;
        padding: 0;
        margin: 0;
        height: 1;
        color: #cccccc;
    }

    #prompt-input:focus {
        border: none;
        color: #ffffff;
    }

    /* Approval widget */
    #approval {
        display: none;
        height: auto;
        padding: 1 2;
        background: #1e1e1e;
    }

    #approval.-visible {
        display: block;
    }

    .approval-kicker {
        color: #569cd6;
        text-style: bold;
    }

    .approval-title {
        color: #ffffff;
        text-style: bold;
    }

    .approval-desc {
        color: #9cdcfe;
        padding: 0 0 0 2;
    }

    #approval-details {
        height: auto;
        max-height: 18;
        overflow-y: auto;
        padding: 1 0 1 2;
    }

    .diff-add {
        color: #4ec994;
        background: #1e3a2f;
        padding: 0 1;
    }

    .diff-del {
        color: #f97583;
        background: #3d1f23;
        padding: 0 1;
    }

    .diff-ctx {
        color: #666666;
    }

    .approval-question {
        color: #ffffff;
        padding: 1 0;
    }

    .approval-option {
        height: auto;
        color: #666666;
    }

    .approval-option.-selected {
        color: #ffffff;
    }

    .approval-hint {
        color: #444444;
        padding: 1 0 0 0;
    }

    #correction-input {
        display: none;
        margin: 1 0;
        border: solid #333333;
        background: #252526;
    }

    #correction-input.-visible {
        display: block;
    }
    """

    BINDINGS = [
        Binding("ctrl+c", "quit", "Quit"),
        Binding("ctrl+l", "clear", "Clear"),
        Binding("ctrl+s", "save_checkpoint", "Save"),
        Binding("?", "help", "Help"),
    ]

    def __init__(self) -> None:
        super().__init__()
        self.session: ContextDistillSession | None = None
        self.session_ready = False
        self.permissions = PermissionMemory(Path.cwd())
        # True from construction until on_mount finishes (successfully or not).
        self._loading = True

        # Metrics for status bar
        self.total_tokens = 0
        self.last_loss = 0.0

        # Approval state. _approval_future is non-None exactly while an
        # approval prompt is awaiting the user's decision.
        self._approval_future: asyncio.Future[dict[str, Any]] | None = None
        self._approval_selection = 0
        self._approval_options: list[tuple[str, str]] = []
        self._show_correction = False
        self._current_tool_name: str = ""
        self._current_tool_args: dict[str, Any] = {}
        self._edited_args: dict[str, Any] | None = None

    # ------------------------------------------------------------------
    # Widget lookup and focus management
    # ------------------------------------------------------------------

    def _prompt_input(self) -> Input:
        """Return the main prompt ``Input`` widget."""
        return self.query_one("#prompt-input", Input)

    def _prompt_row_visible(self) -> bool:
        """Return True when the prompt row is currently displayed."""
        return bool(self.query_one("#prompt-row").display)

    def _sync_focus(self) -> None:
        """Keep focus on the right input widget for the current UI mode."""
        if self._approval_future is not None:
            if self._show_correction:
                self.query_one("#correction-input", Input).focus()
            else:
                # Approval keys are handled by the app-level on_key, so no
                # widget should hold focus while the option list is shown.
                self.set_focus(None)
            return

        prompt_input = self._prompt_input()
        if not prompt_input.disabled:
            prompt_input.focus()

    def _show_correction_input(self) -> None:
        """Reveal, enable, and focus the correction text box."""
        correction_input = self.query_one("#correction-input", Input)
        self._show_correction = True
        correction_input.disabled = False
        correction_input.add_class("-visible")
        correction_input.focus()

    def _hide_correction_input(self) -> None:
        """Hide and disable the correction text box and drop focus."""
        correction_input = self.query_one("#correction-input", Input)
        self._show_correction = False
        correction_input.remove_class("-visible")
        correction_input.disabled = True
        self.set_focus(None)

    # ------------------------------------------------------------------
    # Layout and startup
    # ------------------------------------------------------------------

    def compose(self) -> ComposeResult:
        """Build the widget tree: transcript, prompt/approval area, status bar."""
        with VerticalScroll(id="output", can_focus=False, can_focus_children=False):
            yield Label("", id="spacer")
        with Vertical(id="bottom-area"):
            with Horizontal(id="prompt-row"):
                yield Label("> ", id="prompt-caret")
                yield Input(id="prompt-input")
            with Vertical(id="approval"):
                yield Label("Tool use", classes="approval-kicker")
                yield Label("", id="approval-title", classes="approval-title")
                yield Label("", id="approval-desc", classes="approval-desc")
                with Vertical(id="approval-details"):
                    pass
                yield Label("", id="approval-question", classes="approval-question")
                yield Label("", id="opt-0", classes="approval-option")
                yield Label("", id="opt-1", classes="approval-option")
                yield Label("", id="opt-2", classes="approval-option")
                yield Label("", id="opt-3", classes="approval-option")
                yield Input(placeholder="Type what to change…", id="correction-input")
                yield Label("Esc cancel · e edit · Tab correct", classes="approval-hint")
        with Horizontal(id="status-bar"):
            yield Label("? for shortcuts", id="status-left")
            yield Label("", id="status-right")

    async def on_mount(self) -> None:
        """Create and initialise the session, then enable the prompt.

        On any failure (or when the session class could not be imported) a
        message is written to the transcript and the prompt stays disabled.
        """
        # The correction input starts hidden; keep it disabled so it can't steal focus.
        self.query_one("#correction-input", Input).disabled = True

        # Init can take a while (model load, network, auth). Hide/disable the prompt
        # during init so stray keystrokes don't show up as random characters.
        self.query_one("#prompt-row").display = False
        self._prompt_input().disabled = True

        self.call_after_refresh(self._sync_focus)
        self._update_status()

        if not TINKER_AVAILABLE:
            self.output("Tinker not available.", style="system")
            self._prompt_input().disabled = True
            self._loading = False
            return

        init_msg = self.output("Initializing Tinker session...", style="thinking")
        try:
            self.session = ContextDistillSession(
                model=MODEL,
                tinker_url=TINKER_URL,
                max_tokens=MAX_TOKENS,
                temperature=TEMPERATURE,
                enable_training=ENABLE_TRAINING,
                lora_rank=LORA_RANK,
                learning_rate=LEARNING_RATE,
                policy_path=POLICY_PATH,
                distill_mode=DISTILL_MODE,
                train_max_tokens=TRAIN_MAX_TOKENS,
                train_temperature=TRAIN_TEMPERATURE,
                kl_coef=KL_COEF,
            )
            await self.session.init()

            # Load checkpoint if specified
            if LOAD_CHECKPOINT and self.session.training_client is not None:
                if self.session.load_checkpoint(LOAD_CHECKPOINT):
                    self.output(f"Loaded checkpoint: {LOAD_CHECKPOINT}", style="system")
                else:
                    self.output(f"Failed to load checkpoint: {LOAD_CHECKPOINT}", style="system")

            self.session_ready = True
            init_msg.remove()
            status = f"Ready. Model: {MODEL} | policy={POLICY_PATH} | mode={DISTILL_MODE}"
            if self.session.training_client is not None:
                status += " [distillation ON]"
            self.output(status, style="system")

            # Now that we're ready, enable the prompt.
            self._loading = False
            self.query_one("#prompt-row").display = True
            self._prompt_input().disabled = False
            self.call_after_refresh(self._sync_focus)
        except Exception as e:
            init_msg.remove()
            self.output(
                f"Failed to initialize Tinker: {e}\n"
                "Set TINKER_API_KEY (or pass api_key to the client) and restart.",
                style="system",
            )
            # Avoid a confusing UX where the prompt accepts input but nothing can run.
            self._prompt_input().disabled = True
            self._loading = False

    # ------------------------------------------------------------------
    # Keyboard handling
    # ------------------------------------------------------------------

    def on_key(self, event: events.Key) -> None:
        """Route key presses for the loading, prompt, and approval modes.

        This is the single key handler for the app. The class previously
        defined ``on_key`` twice; the later definition replaced the earlier one,
        so the loading-time guard below never ran.

        * While loading, single printable keys are stopped so impatient typing
          does not show up as stray characters.
        * With no approval pending, a printable key typed while the prompt has
          lost focus refocuses the prompt and inserts the character.
        * With the correction box open, only Escape is handled here (it closes
          the box).
        * Otherwise the key is interpreted as an approval shortcut.
        """
        if self._loading and len(event.key) == 1 and event.key.isprintable():
            event.stop()
            return

        if self._approval_future is None:
            self._recover_prompt_focus(event)
            return

        if self._show_correction:
            if event.key == "escape":
                self._hide_correction_input()
                event.stop()
            return

        self._handle_approval_key(event)

    def _recover_prompt_focus(self, event: events.Key) -> None:
        """Refocus the prompt and insert the typed character if focus was lost."""
        prompt_input = self._prompt_input()
        if (
            event.character
            and event.character.isprintable()
            and self._prompt_row_visible()
            and not prompt_input.disabled
            and self.screen.focused is not prompt_input
        ):
            prompt_input.focus()
            prompt_input.insert_text_at_cursor(event.character)
            event.stop()

    def _handle_approval_key(self, event: events.Key) -> None:
        """Apply one key press to the approval option list.

        Keys that are not approval shortcuts are left untouched (not stopped).
        """
        key = event.key
        last_index = len(self._approval_options) - 1

        if key in ("1", "y"):
            self._resolve_approval("yes")
        elif key in ("2", "shift+tab"):
            self._resolve_approval("yes_always")
        elif key in ("3", "n", "escape"):
            self._resolve_approval("no")
        elif key in ("4", "tab"):
            # Option 4 / Tab: type a correction instead of approving.
            self._show_correction_input()
        elif key in ("up", "left"):
            self._approval_selection = max(0, self._approval_selection - 1)
            self._sync_approval_options()
        elif key in ("down", "right"):
            self._approval_selection = min(last_index, self._approval_selection + 1)
            self._sync_approval_options()
        elif key == "enter":
            decision = self._approval_options[self._approval_selection][0]
            if decision == "correct":
                self._show_correction_input()
            else:
                self._resolve_approval(decision)
        elif key == "e":
            # Edit args in $EDITOR
            self._edit_args_in_editor()
        else:
            return
        event.stop()

    def action_help(self) -> None:
        """Print the keyboard shortcut summary to the transcript."""
        self.output(
            "Keys: Enter=send · Ctrl+L=clear · Ctrl+S=save · Ctrl+C=quit · "
            "Approval: 1/y approve · 2 approve+remember · 3/n/Esc deny · 4/Tab correct · e edit args",
            style="system",
        )

    # ------------------------------------------------------------------
    # Transcript output
    # ------------------------------------------------------------------

    def output(self, text: str, style: str = "assistant") -> Label:
        """Append a message to the transcript and return its ``Label``.

        ``style`` selects both the CSS class (``msg-<style>``) and the leading
        marker; styles without a marker get none. The returned label can be
        removed later, which is how transient messages are cleared.
        """
        out = self.query_one("#output", VerticalScroll)
        prefix = {
            "user": "> ",
            "assistant": "● ",
            "tool": "⎿ ",
            "correction": "⎿ ",
            "thinking": "● ",
        }.get(style, "")
        label = Label(f"{prefix}{text}", classes=f"msg msg-{style}")
        out.mount(label)
        out.scroll_end()
        self._update_status()
        self.call_after_refresh(self._sync_focus)
        return label

    # ------------------------------------------------------------------
    # Approval UI
    # ------------------------------------------------------------------

    def _show_approval(self, tool_name: str, tool_args: dict[str, Any], remember_text: str) -> None:
        """Show the inline approval UI in place of the prompt."""
        self._current_tool_name = tool_name
        self._current_tool_args = dict(tool_args)
        self._edited_args = None

        self.query_one("#prompt-row").display = False
        self._prompt_input().disabled = True
        self.query_one("#approval").add_class("-visible")
        self.set_focus(None)

        self._refresh_approval_ui(tool_name, tool_args, remember_text)
        self.call_after_refresh(self._sync_focus)

    @staticmethod
    def _mount_numbered_lines(
        details: Vertical, text: str, sign: str, css_class: str, max_lines: int
    ) -> None:
        """Mount up to ``max_lines`` numbered lines of ``text`` into ``details``.

        When ``text`` has more lines than are shown, a trailing note states how
        many were omitted so the user knows the preview is incomplete.
        """
        lines = text.split("\n")
        for number, line in enumerate(lines[:max_lines], 1):
            details.mount(Label(f"{number:3} {sign} {line}", classes=css_class))
        hidden = len(lines) - max_lines
        if hidden > 0:
            details.mount(Label(f" ... ({hidden} more lines)", classes="diff-ctx"))

    def _refresh_approval_ui(self, tool_name: str, tool_args: dict[str, Any], remember_text: str) -> None:
        """Refresh the approval UI with current tool args.

        Rebuilds the title, description, detail preview, question, and the four
        options; resets the selection to the first option and hides the
        correction box.
        """
        file_path = str(tool_args.get("path", ""))
        edited_suffix = " (edited)" if self._edited_args is not None else ""

        if tool_name == "edit":
            title = f"Edit file {file_path}{edited_suffix}"
        elif tool_name == "write":
            title = f"Create file {file_path}{edited_suffix}"
        elif tool_name == "read":
            title = f"Read file {file_path}{edited_suffix}"
        elif tool_name == "bash":
            title = f"Bash command{edited_suffix}"
        else:
            title = f"{tool_name}{edited_suffix}"
        self.query_one("#approval-title", Label).update(title)

        # Set description
        description = str(tool_args.get("cmd", "")) if tool_name == "bash" else ""
        self.query_one("#approval-desc", Label).update(description)

        # Set details
        details = self.query_one("#approval-details", Vertical)
        details.remove_children()
        max_lines = 18

        if tool_name == "edit":
            # Both sides are truncated independently; each reports omitted
            # lines so a long edit is never approved without that being visible.
            self._mount_numbered_lines(details, str(tool_args.get("old", "")), "-", "diff-del", max_lines)
            self._mount_numbered_lines(details, str(tool_args.get("new", "")), "+", "diff-add", max_lines)
        elif tool_name == "write":
            self._mount_numbered_lines(details, str(tool_args.get("content", "")), "+", "diff-add", max_lines)
        elif tool_name == "bash":
            cmd = str(tool_args.get("cmd", ""))
            details.mount(Label(f"$ {cmd}", classes="diff-ctx"))
        elif tool_name in ("read", "glob", "grep"):
            path = tool_args.get("path", tool_args.get("pat", ""))
            details.mount(Label(f"{path}", classes="diff-ctx"))

        # Set question
        name = Path(file_path).name if file_path else "this"
        if tool_name == "edit":
            question = f"Do you want to make this edit to {name}?"
        elif tool_name == "write":
            question = f"Do you want to create {name}?"
        elif tool_name == "read":
            question = f"Do you want to read {name}?"
        elif tool_name == "bash":
            question = "Do you want to run this command?"
        else:
            question = "Do you want to proceed?"
        self.query_one("#approval-question", Label).update(question)

        # Set options (4 options like Claude Code)
        self._approval_options = [
            ("yes", "Yes"),
            ("yes_always", f"{remember_text} (shift+Tab)"),
            ("no", "No"),
            ("correct", "Type correction (Tab)"),
        ]
        self._approval_selection = 0
        self._sync_approval_options()

        # Reset correction input
        correction_input = self.query_one("#correction-input", Input)
        self._show_correction = False
        correction_input.remove_class("-visible")
        correction_input.disabled = True
        correction_input.value = ""

    def _sync_approval_options(self) -> None:
        """Update option labels with selection indicator."""
        for i, (_decision, text) in enumerate(self._approval_options):
            label = self.query_one(f"#opt-{i}", Label)
            selected = i == self._approval_selection
            # Both prefixes are two characters wide so the option text does not
            # shift sideways as the selection moves.
            prefix = "> " if selected else "  "
            label.update(f"{prefix}{i + 1}. {text}")
            label.set_class(selected, "-selected")

    def _hide_approval(self) -> None:
        """Hide approval UI, show prompt."""
        self.query_one("#approval").remove_class("-visible")
        self.query_one("#prompt-row").display = True
        self._prompt_input().disabled = False
        self._hide_correction_input()
        self.call_after_refresh(self._sync_focus)

    def _resolve_approval(self, decision: str, correction: str | None = None) -> None:
        """Complete the pending approval with ``decision``.

        Does nothing when no approval is pending. The result dict carries the
        decision, any args edited in the editor, and the optional correction
        text.
        """
        future = self._approval_future
        if future is None:
            return

        result = {
            "decision": decision,
            "additional_instructions": None,
            "edited_args": self._edited_args,
            "correction": correction,
        }

        self._hide_approval()
        # set_result() raises InvalidStateError on a future that is already
        # done or cancelled, so only resolve one that is still pending.
        if not future.done():
            future.set_result(result)
        self._approval_future = None

    def _edit_args_in_editor(self) -> None:
        """Open $EDITOR (default ``vim``) on the tool args as JSON.

        The args are written to a temporary ``.json`` file, the app is suspended
        while the editor runs, and the file is read back. A JSON object replaces
        the current args and refreshes the approval UI; anything else is
        reported and ignored. The temporary file is always removed.
        """
        editor_raw = os.environ.get("EDITOR") or "vim"
        editor_cmd = shlex.split(editor_raw)

        # Use edited args if already edited, otherwise original
        args_to_edit = self._edited_args if self._edited_args is not None else self._current_tool_args

        tmp_path: str | None = None
        try:
            with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False, encoding="utf-8") as f:
                json.dump(args_to_edit, f, indent=2, ensure_ascii=False)
                f.write("\n")
                tmp_path = f.name

            with self.suspend():
                subprocess.run([*editor_cmd, tmp_path], check=False)

            with open(tmp_path, "r", encoding="utf-8") as f:
                edited = json.load(f)

            if isinstance(edited, dict):
                self._edited_args = edited
                self._current_tool_args = edited
                # Refresh the UI to show edited content
                remember_text = self.permissions.remember_text(self._current_tool_name, edited)
                self._refresh_approval_ui(self._current_tool_name, edited, remember_text)
                self.call_after_refresh(self._sync_focus)
            else:
                # Tool args must stay a mapping; tell the user instead of
                # silently discarding what they typed.
                self.output("Edit ignored: tool args must be a JSON object.", style="system")

        except Exception as e:
            self.output(f"Edit failed: {e}", style="system")
        finally:
            if tmp_path is not None:
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass

    async def request_approval(self, tool_name: str, tool_args: dict[str, Any]) -> tuple[dict[str, Any], int]:
        """Obtain a decision for a tool call, asking the user if necessary.

        Returns ``(result, latency_ms)``. ``result`` has the keys ``decision``,
        ``additional_instructions``, ``edited_args``, ``correction`` and
        ``auto_approved``. Calls covered by ``AUTO_APPROVE_READONLY`` or by a
        remembered permission are approved without prompting, with latency 0.
        A ``yes_always`` decision is recorded in the permission memory.
        """
        auto_result = {
            "decision": "yes",
            "additional_instructions": None,
            "edited_args": None,
            "correction": None,
            "auto_approved": True,
        }
        if AUTO_APPROVE_READONLY and tool_name in READONLY_TOOLS:
            return auto_result, 0

        if self.permissions.allows(tool_name, tool_args):
            return auto_result, 0

        remember_text = self.permissions.remember_text(tool_name, tool_args)
        # Monotonic clock: the measured latency is unaffected by wall-clock
        # adjustments while the user is deciding.
        start = time.monotonic()

        # Create future and show UI
        self._approval_future = asyncio.get_running_loop().create_future()
        self._show_approval(tool_name, tool_args, remember_text)

        # Wait for user decision
        result = await self._approval_future
        latency = int((time.monotonic() - start) * 1000)

        if result.get("decision") == "yes_always":
            final_args = result.get("edited_args") or tool_args
            self.permissions.remember(tool_name, final_args)

        result["auto_approved"] = False
        return result, latency

    # ------------------------------------------------------------------
    # Input submission
    # ------------------------------------------------------------------

    @on(Input.Submitted, "#correction-input")
    def on_correction_submitted(self, event: Input.Submitted) -> None:
        """Handle Enter in correction input - deny with correction text."""
        correction = event.input.value.strip()
        self._hide_correction_input()
        # A correction resolves as decision "correct", which the agent loop
        # treats as not approved; empty text is passed on as None.
        self._resolve_approval("correct", correction=correction if correction else None)
        event.stop()

    @on(Input.Submitted, "#prompt-input")
    def on_submit(self, event: Input.Submitted) -> None:
        """Send non-empty prompt text to the agent loop and clear the prompt."""
        text = event.value.strip()
        if not text:
            return
        event.input.value = ""
        self.output(text, style="user")
        self._run_agent_loop(text)

    # ------------------------------------------------------------------
    # Agent loop
    # ------------------------------------------------------------------

    @staticmethod
    def _parse_tool_args(raw: Any) -> dict[str, Any]:
        """Decode a tool call's JSON argument string into a dict.

        Invalid JSON and valid JSON that is not an object (a list, string,
        number, ...) both yield ``{}``, because every consumer of the args
        calls ``.get`` on them.
        """
        try:
            parsed = json.loads(raw or "{}")
        except json.JSONDecodeError:
            return {}
        return parsed if isinstance(parsed, dict) else {}

    def _report_training(self, metrics: dict[str, Any], scale: float, edited: bool) -> None:
        """Record training metrics and print a one-line summary."""
        self.last_loss = metrics.get("loss", 0.0)
        tokens = int(metrics.get("tokens", 0))
        self.total_tokens += tokens
        tag = "edit" if edited else "approve"
        extra = ""
        if "kl_student_teacher" in metrics:
            extra = (
                f" st={metrics.get('kl_student_teacher', 0.0):.3f}"
                f" kl={metrics.get('approx_kl', 0.0):.4f}"
                f" r={metrics.get('ratio', 1.0):.3f}"
            )
        self.output(
            f"cd[{tag}] step {int(metrics.get('step', 0))} | "
            f"loss={self.last_loss:.3f} tok={tokens} w={scale:g}{extra}",
            style="system",
        )

    @work(exclusive=True)
    async def _run_agent_loop(self, user_input: str) -> None:
        """Run sample -> approve -> execute rounds until the model stops calling tools.

        Each round samples one assistant message, shows its text, and for every
        tool call asks for approval, executes or denies it, and records the
        result in the session. The loop ends when a message has no tool calls,
        when sampling raises, or when the session reports a parse failure.
        """
        session = self.session
        if session is None:
            return

        session.add_user(user_input)

        try:
            while True:
                thinking = self.output("Thinking...", style="thinking")
                # Snapshot of the conversation before this assistant turn; used
                # as the prompt if the turn is trained on.
                prompt_messages = list(session.messages)

                try:
                    message, ok = await session.sample()
                except Exception as e:
                    thinking.remove()
                    self.output(f"Error: {e}", style="system")
                    return

                thinking.remove()

                if not ok:
                    self.output("Parse failed", style="system")
                    return

                content = message.get("content", "")
                if isinstance(content, list):
                    text = "".join(p.get("text", "") for p in content if p.get("type") == "text")
                else:
                    text = content
                if text and str(text).strip():
                    self.output(str(text), style="assistant")

                session.add_assistant(message)

                tool_calls = message.get("tool_calls", [])
                if not tool_calls:
                    return

                for tc in tool_calls:
                    name = tc.function.name
                    model_args = self._parse_tool_args(tc.function.arguments)

                    approval, _latency_ms = await self.request_approval(name, model_args)

                    decision = approval.get("decision", "no")
                    correction = approval.get("correction")
                    approved = decision in ("yes", "yes_always")
                    raw_edited = approval.get("edited_args")
                    edited_args = raw_edited if isinstance(raw_edited, dict) else None
                    was_edited = approved and edited_args is not None
                    final_args = edited_args if was_edited else model_args

                    # If the user edited args, mutate the message tool call so the conversation
                    # (and any training) reflects the actual executed tool.
                    if was_edited:
                        tc.function.arguments = json.dumps(final_args, ensure_ascii=False)

                    if approved:
                        # NOTE(review): execute_tool is called synchronously here,
                        # so it appears the UI cannot update until it returns.
                        result = execute_tool(name, final_args)
                        lines = result.split("\n")
                        if name == "read":
                            self.output(f"Read {len(lines)} lines", style="tool")
                        else:
                            preview = lines[0][:100] if lines else ""
                            if len(lines) > 1:
                                preview += f" ... +{len(lines) - 1} lines"
                            self.output(preview, style="tool")
                    else:
                        # Denied or corrected
                        deny_reason = correction or "User denied this tool call."
                        result = json.dumps(
                            {
                                "status": "denied",
                                "tool_name": name,
                                "tool_call_id": tc.id,
                                "tool_args": model_args,
                                "reason": deny_reason,
                                "correction": correction,
                            },
                            ensure_ascii=False,
                        )
                        if correction:
                            self.output(correction, style="correction")
                        else:
                            self.output("Denied", style="tool")

                    session.add_tool_result(tc.id, result)

                    # Train on approved write/edit (single tool call only)
                    auto = bool(approval.get("auto_approved", False))
                    if (
                        approved
                        and not auto
                        and len(tool_calls) == 1
                        and name in TRAIN_TOOL_NAMES
                        and session.training_client is not None
                    ):
                        # Same "edited" test as the one that chose final_args, so
                        # the weight always matches what was actually executed.
                        scale = TRAIN_SCALE_EDITED if was_edited else TRAIN_SCALE_APPROVE
                        metrics = session.train_on_approval(
                            prompt_messages=prompt_messages,
                            assistant_message=message,
                            scale=scale,
                        )
                        if metrics:
                            self._report_training(metrics, scale, was_edited)
                        self._update_status()
        finally:
            self.call_after_refresh(self._sync_focus)

    # ------------------------------------------------------------------
    # Status bar and actions
    # ------------------------------------------------------------------

    def _update_status(self) -> None:
        """Show step/loss/token totals in the status bar once training has run."""
        if self.session and self.session.train_steps > 0:
            tokens_k = self.total_tokens / 1000
            status = f"Steps: {self.session.train_steps} | Loss: {self.last_loss:.2f} | Tokens: {tokens_k:.1f}k"
        else:
            status = ""
        self.query_one("#status-right", Label).update(status)

    def action_clear(self) -> None:
        """Clear the transcript and the session's conversation (Ctrl+L)."""
        out = self.query_one("#output", VerticalScroll)
        out.remove_children()
        out.mount(Label("", id="spacer"))
        if self.session:
            self.session.clear()
        self._update_status()
        self.output("Cleared.", style="system")
        self.call_after_refresh(self._sync_focus)

    def action_save_checkpoint(self) -> None:
        """Save a checkpoint via the session (Ctrl+S).

        The checkpoint is written under ``CHECKPOINT_DIR`` as
        ``step_<train_steps, zero-padded to 6 digits>``.
        """
        if not self.session or self.session.training_client is None:
            self.output("No training session to save.", style="system")
            return

        # Create checkpoint directory if needed
        CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

        # Generate checkpoint name with step count
        step = self.session.train_steps
        checkpoint_path = CHECKPOINT_DIR / f"step_{step:06d}"

        if self.session.save_checkpoint(str(checkpoint_path)):
            self.output(f"Saved checkpoint: {checkpoint_path}", style="system")
        else:
            self.output("Failed to save checkpoint.", style="system")
991
+
992
+
993
# Script entry point: build the app and hand control to Textual's run loop.
if __name__ == "__main__":
    app = TinkerCodeApp()
    app.run()