aru-code 0.20.1__tar.gz → 0.22.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {aru_code-0.20.1/aru_code.egg-info → aru_code-0.22.0}/PKG-INFO +1 -1
- aru_code-0.22.0/aru/__init__.py +1 -0
- aru_code-0.22.0/aru/checkpoints.py +189 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/aru/cli.py +167 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/aru/commands.py +2 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/aru/runtime.py +3 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/aru/session.py +18 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/aru/tools/codebase.py +167 -19
- {aru_code-0.20.1 → aru_code-0.22.0/aru_code.egg-info}/PKG-INFO +1 -1
- {aru_code-0.20.1 → aru_code-0.22.0}/aru_code.egg-info/SOURCES.txt +2 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/pyproject.toml +1 -1
- aru_code-0.22.0/tests/test_checkpoints.py +190 -0
- aru_code-0.20.1/aru/__init__.py +0 -1
- {aru_code-0.20.1 → aru_code-0.22.0}/LICENSE +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/README.md +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/aru/agent_factory.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/aru/agents/__init__.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/aru/agents/base.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/aru/agents/executor.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/aru/agents/planner.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/aru/cache_patch.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/aru/completers.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/aru/config.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/aru/context.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/aru/display.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/aru/history_blocks.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/aru/permissions.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/aru/plugins/__init__.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/aru/plugins/custom_tools.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/aru/plugins/hooks.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/aru/plugins/manager.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/aru/plugins/tool_api.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/aru/providers.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/aru/runner.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/aru/tools/__init__.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/aru/tools/ast_tools.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/aru/tools/gitignore.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/aru/tools/mcp_client.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/aru/tools/ranker.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/aru/tools/tasklist.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/aru_code.egg-info/dependency_links.txt +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/aru_code.egg-info/entry_points.txt +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/aru_code.egg-info/requires.txt +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/aru_code.egg-info/top_level.txt +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/setup.cfg +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/tests/test_agents_base.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/tests/test_cli.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/tests/test_cli_advanced.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/tests/test_cli_base.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/tests/test_cli_completers.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/tests/test_cli_new.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/tests/test_cli_run_cli.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/tests/test_cli_session.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/tests/test_cli_shell.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/tests/test_codebase.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/tests/test_confabulation_regression.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/tests/test_config.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/tests/test_context.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/tests/test_executor.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/tests/test_gitignore.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/tests/test_main.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/tests/test_mcp_client.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/tests/test_permissions.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/tests/test_planner.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/tests/test_plugins.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/tests/test_providers.py +0 -0
- {aru_code-0.20.1 → aru_code-0.22.0}/tests/test_ranker.py +0 -0
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Public version string of the aru package.
__version__ = "0.22.0"
|
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
"""File checkpoint system for undo/rewind support.
|
|
2
|
+
|
|
3
|
+
Tracks file state before tool mutations so changes can be reverted.
|
|
4
|
+
Inspired by Claude Code's fileHistory system.
|
|
5
|
+
|
|
6
|
+
Architecture:
|
|
7
|
+
- Each user message creates a "snapshot" identified by a turn index.
|
|
8
|
+
- Before any file mutation (write_file, write_files, edit_file, edit_files), the pre-edit
|
|
9
|
+
content is saved as a versioned backup in .aru/file-history/{session_id}/.
|
|
10
|
+
- On /undo, the most recent snapshot is applied: files are restored to
|
|
11
|
+
their pre-turn state and the conversation is rewound.
|
|
12
|
+
|
|
13
|
+
Backup naming: {sha256(path)[:16]}@v{version}
|
|
14
|
+
Snapshot: {turn_index: {file_path: BackupEntry}}
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import hashlib
|
|
20
|
+
import os
|
|
21
|
+
import shutil
|
|
22
|
+
import threading
|
|
23
|
+
from dataclasses import dataclass, field
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass
class BackupEntry:
    """A single file backup captured before the file was mutated."""

    backup_path: str | None  # None = file didn't exist before this turn
    version: int  # per-file version number, embedded in the backup filename
    original_path: str  # absolute path of the file that was backed up


@dataclass
class Snapshot:
    """Checkpoint at a specific conversation turn."""

    turn_index: int
    backups: dict[str, BackupEntry] = field(default_factory=dict)  # abs_path → BackupEntry


# Maximum number of turn snapshots kept; begin_turn() evicts the oldest one
# (and deletes its backup files) once this is exceeded.
MAX_SNAPSHOTS = 100


class CheckpointManager:
    """Manages file checkpoints for a session.

    Thread-safe: multiple tools may run in parallel within a turn. Every
    state change, including the backup copy made by ``track_edit``, happens
    under a single lock. This stops concurrent edits of the same file from
    racing each other.
    """

    def __init__(self, session_id: str, base_dir: str | None = None):
        self._session_id = session_id
        self._base_dir = base_dir or os.path.join(os.getcwd(), ".aru", "file-history", session_id)
        self._lock = threading.Lock()
        self._snapshots: list[Snapshot] = []
        self._current_turn: int = 0
        self._tracked_files: set[str] = set()
        # Per-file version counter (monotonic)
        self._file_versions: dict[str, int] = {}
        self._dir_created = False

    def _ensure_dir(self):
        """Create the backup directory on first use."""
        if not self._dir_created:
            os.makedirs(self._base_dir, exist_ok=True)
            self._dir_created = True

    def _backup_filename(self, file_path: str, version: int) -> str:
        """Return the backup file name ``{sha256(path)[:16]}@v{version}``."""
        path_hash = hashlib.sha256(file_path.encode("utf-8")).hexdigest()[:16]
        return f"{path_hash}@v{version}"

    def begin_turn(self, turn_index: int):
        """Start a new turn — creates a fresh snapshot for this turn."""
        with self._lock:
            self._current_turn = turn_index
            # Create snapshot for this turn (backups added lazily as files are edited)
            snapshot = Snapshot(turn_index=turn_index)
            self._snapshots.append(snapshot)
            # Enforce cap
            if len(self._snapshots) > MAX_SNAPSHOTS:
                evicted = self._snapshots.pop(0)
                self._cleanup_snapshot_backups(evicted)

    def track_edit(self, file_path: str):
        """Capture pre-edit state of a file before mutation.

        Call this BEFORE writing/editing a file. If the file was already
        captured in the current turn's snapshot, this is a no-op. If an
        existing file cannot be backed up, it is left untracked rather than
        recorded as "did not exist". That record would make undo delete it.
        """
        abs_path = os.path.abspath(file_path)

        # The check, the copy and the commit all run under the lock. Releasing
        # it between the "already tracked?" check and the commit would let two
        # concurrent edits of the same file both take a backup. The second
        # could be copied after the first edit landed and would then overwrite
        # the true pre-turn entry. Holding the lock also pins the entry to the
        # snapshot that was current when the check ran.
        with self._lock:
            if not self._snapshots:
                return

            current_snapshot = self._snapshots[-1]

            # Already tracked in this turn
            if abs_path in current_snapshot.backups:
                return

            version = self._file_versions.get(abs_path, 0) + 1
            backup_path = None
            if os.path.isfile(abs_path):
                try:
                    self._ensure_dir()
                    backup_path = os.path.join(
                        self._base_dir, self._backup_filename(abs_path, version)
                    )
                    shutil.copy2(abs_path, backup_path)
                except OSError:
                    # Existing file but no backup: do not record
                    # backup_path=None, or undo would delete the user's file.
                    # A later call this turn may retry. Reset the dir flag in
                    # case the history directory was removed.
                    self._dir_created = False
                    return

            self._file_versions[abs_path] = version
            self._tracked_files.add(abs_path)
            current_snapshot.backups[abs_path] = BackupEntry(
                backup_path=backup_path,
                version=version,
                original_path=abs_path,
            )

    def undo_last_turn(self) -> tuple[list[str], int]:
        """Revert files changed in the most recent snapshot.

        The snapshot is removed even if some files cannot be restored
        (best effort).

        Returns:
            (list of restored file paths, turn_index that was undone);
            ``([], 0)`` when there is no snapshot.
        """
        with self._lock:
            if not self._snapshots:
                return [], 0
            snapshot = self._snapshots.pop()

        restored: list[str] = []
        for abs_path, entry in snapshot.backups.items():
            try:
                if entry.backup_path is None:
                    # File didn't exist before — delete it
                    if os.path.isfile(abs_path):
                        os.unlink(abs_path)
                        restored.append(abs_path)
                elif os.path.isfile(entry.backup_path):
                    # Recreate the parent directory in case the turn removed it.
                    os.makedirs(os.path.dirname(abs_path), exist_ok=True)
                    shutil.copy2(entry.backup_path, abs_path)
                    restored.append(abs_path)
            except OSError:
                pass  # best effort: skip files that cannot be restored

        return restored, snapshot.turn_index

    def get_snapshot_count(self) -> int:
        """Return how many turn snapshots are currently held."""
        with self._lock:
            return len(self._snapshots)

    def get_last_snapshot_files(self) -> list[str]:
        """Return files that would be affected by undo."""
        with self._lock:
            if not self._snapshots:
                return []
            return list(self._snapshots[-1].backups.keys())

    def _cleanup_snapshot_backups(self, snapshot: Snapshot):
        """Remove backup files for an evicted snapshot (if not referenced by others).

        Called from begin_turn() while the lock is held.
        """
        # Collect all backup paths still referenced
        referenced = set()
        for s in self._snapshots:
            for entry in s.backups.values():
                if entry.backup_path:
                    referenced.add(entry.backup_path)

        # Delete unreferenced backups
        for entry in snapshot.backups.values():
            if entry.backup_path and entry.backup_path not in referenced:
                try:
                    os.unlink(entry.backup_path)
                except OSError:
                    pass

    def cleanup(self):
        """Remove all backup files for this session."""
        with self._lock:
            try:
                if os.path.isdir(self._base_dir):
                    shutil.rmtree(self._base_dir, ignore_errors=True)
            except OSError:
                pass
            # The directory is gone; make the next backup recreate it instead
            # of failing on a missing directory.
            self._dir_created = False
|
|
@@ -203,6 +203,11 @@ async def run_cli(skip_permissions: bool = False, resume_id: str | None = None):
|
|
|
203
203
|
ctx.on_file_mutation = session.invalidate_context_cache
|
|
204
204
|
atexit.register(lambda: cleanup_processes(ctx.tracked_processes))
|
|
205
205
|
|
|
206
|
+
# Initialize checkpoint manager for undo/rewind support
|
|
207
|
+
from aru.checkpoints import CheckpointManager
|
|
208
|
+
ctx.checkpoint_manager = CheckpointManager(session.session_id)
|
|
209
|
+
_turn_counter = 0
|
|
210
|
+
|
|
206
211
|
planner = None
|
|
207
212
|
executor = None
|
|
208
213
|
paste_state = PasteState()
|
|
@@ -329,6 +334,64 @@ async def run_cli(skip_permissions: bool = False, resume_id: str | None = None):
|
|
|
329
334
|
# Reset "allow all" approvals for each new user message
|
|
330
335
|
perm_reset_session()
|
|
331
336
|
|
|
337
|
+
if user_input.lower() == "/undo":
|
|
338
|
+
affected_files = ctx.checkpoint_manager.get_last_snapshot_files()
|
|
339
|
+
if not affected_files and not session.history:
|
|
340
|
+
console.print("[dim]Nothing to undo.[/dim]")
|
|
341
|
+
continue
|
|
342
|
+
|
|
343
|
+
# Show what will be reverted
|
|
344
|
+
if affected_files:
|
|
345
|
+
cwd = os.getcwd()
|
|
346
|
+
console.print("[bold]Files that will be restored:[/bold]")
|
|
347
|
+
for f in affected_files:
|
|
348
|
+
rel = os.path.relpath(f, cwd) if f.startswith(cwd) else f
|
|
349
|
+
console.print(f" [cyan]{rel}[/cyan]")
|
|
350
|
+
|
|
351
|
+
console.print()
|
|
352
|
+
console.print("[bold]Restore options:[/bold]")
|
|
353
|
+
console.print(" [cyan](b)[/cyan] Restore code and conversation (both)")
|
|
354
|
+
console.print(" [cyan](c)[/cyan] Restore only code (keep conversation)")
|
|
355
|
+
console.print(" [cyan](v)[/cyan] Restore only conversation (keep code)")
|
|
356
|
+
console.print(" [cyan](n)[/cyan] Cancel")
|
|
357
|
+
try:
|
|
358
|
+
choice = console.input("[bold yellow]Choice (b/c/v/n):[/bold yellow] ").strip().lower()
|
|
359
|
+
except (EOFError, KeyboardInterrupt):
|
|
360
|
+
choice = "n"
|
|
361
|
+
|
|
362
|
+
if choice in ("n", ""):
|
|
363
|
+
console.print("[dim]Cancelled.[/dim]")
|
|
364
|
+
continue
|
|
365
|
+
|
|
366
|
+
restored_files = []
|
|
367
|
+
msgs_removed = 0
|
|
368
|
+
|
|
369
|
+
if choice in ("b", "c"):
|
|
370
|
+
# Restore files from checkpoint
|
|
371
|
+
restored_files, _ = ctx.checkpoint_manager.undo_last_turn()
|
|
372
|
+
|
|
373
|
+
if choice in ("b", "v"):
|
|
374
|
+
# Remove last turn from conversation
|
|
375
|
+
msgs_removed = session.undo_last_turn()
|
|
376
|
+
|
|
377
|
+
parts = []
|
|
378
|
+
if restored_files:
|
|
379
|
+
cwd = os.getcwd()
|
|
380
|
+
for f in restored_files:
|
|
381
|
+
rel = os.path.relpath(f, cwd) if f.startswith(cwd) else f
|
|
382
|
+
parts.append(f" [cyan]{rel}[/cyan]")
|
|
383
|
+
console.print(f"[green]Restored {len(restored_files)} file(s):[/green]")
|
|
384
|
+
for p in parts:
|
|
385
|
+
console.print(p)
|
|
386
|
+
session.invalidate_context_cache()
|
|
387
|
+
if msgs_removed:
|
|
388
|
+
console.print(f"[green]Removed {msgs_removed} message(s) from conversation.[/green]")
|
|
389
|
+
if not restored_files and not msgs_removed:
|
|
390
|
+
console.print("[dim]Nothing was changed.[/dim]")
|
|
391
|
+
else:
|
|
392
|
+
store.save(session)
|
|
393
|
+
continue
|
|
394
|
+
|
|
332
395
|
if user_input.lower() in ("/quit", "/exit", "quit", "exit"):
|
|
333
396
|
store.save(session)
|
|
334
397
|
console.print(f"\n[dim]Session saved: {session.session_id}[/dim]")
|
|
@@ -455,6 +518,10 @@ async def run_cli(skip_permissions: bool = False, resume_id: str | None = None):
|
|
|
455
518
|
))
|
|
456
519
|
continue
|
|
457
520
|
|
|
521
|
+
# Begin a new checkpoint turn for undo support
|
|
522
|
+
_turn_counter += 1
|
|
523
|
+
ctx.checkpoint_manager.begin_turn(_turn_counter)
|
|
524
|
+
|
|
458
525
|
if user_input.startswith("! "):
|
|
459
526
|
cmd = user_input[2:].strip()
|
|
460
527
|
if not cmd:
|
|
@@ -609,6 +676,72 @@ def _list_sessions_and_exit():
|
|
|
609
676
|
console.print(f"\n[dim]Resume with: aru --resume <id>[/dim]")
|
|
610
677
|
|
|
611
678
|
|
|
679
|
+
async def run_oneshot(prompt: str, print_only: bool = False, skip_permissions: bool = False):
    """Run a single prompt non-interactively and exit.

    Args:
        prompt: The user prompt to execute.
        print_only: If True, run without tools (text-only response). The reply
            is written with ``print`` so it can be piped, and nothing further
            is emitted through ``console`` in this mode.
        skip_permissions: If True, skip all permission checks.
    """
    from aru.runtime import init_ctx
    from aru.config import load_config
    from aru.cache_patch import apply_cache_patch

    apply_cache_patch()
    ctx = init_ctx(console=console, skip_permissions=skip_permissions)

    config = load_config()
    session = Session()
    if config.default_model:
        session.model_ref = config.default_model

    ctx.model_id = session.model_id
    # Prefer an explicit "small" alias; otherwise pick a cheap model from the
    # same provider, falling back to the session's own model. (config is
    # already dereferenced above, so no None guard is needed here.)
    small_ref = config.model_aliases.get("small")
    if not small_ref:
        from aru.providers import resolve_model_ref
        provider_key, _ = resolve_model_ref(session.model_ref)
        _small_defaults = {
            "anthropic": "anthropic/claude-haiku-4-5",
            "openai": "openai/gpt-4o-mini",
            "groq": "groq/llama-3.1-8b-instant",
            "deepseek": "deepseek/deepseek-chat",
            "ollama": "ollama/llama3.1",
        }
        small_ref = _small_defaults.get(provider_key, session.model_ref)
    ctx.small_model_ref = small_ref

    extra_instructions = config.get_extra_instructions()

    if print_only:
        # Text-only mode: no tools, just a direct LLM call
        from agno.agent import Agent
        from aru.providers import create_model
        from aru.agents.base import build_instructions

        agent = Agent(
            name="Aru",
            model=create_model(session.model_ref, max_tokens=8192),
            tools=[],
            instructions=build_instructions("general", extra_instructions),
            markdown=True,
        )
        response = await agent.arun(prompt)
        if response and response.content:
            # Print raw text to stdout for piping
            print(response.content)
        # No token summary here. This agent is never given the session, so
        # session.token_summary presumably does not reflect this call. If
        # console writes to stdout, printing it would also pollute the
        # piped reply.
        return

    # Full mode with tools
    from aru.runner import build_env_context
    env_ctx = build_env_context(session)
    agent = create_general_agent(session, config, env_context=env_ctx)
    session.add_message("user", prompt)
    await run_agent_capture(agent, prompt, session)

    if session.token_summary:
        console.print(f"[dim]{session.token_summary}[/dim]")
|
|
743
|
+
|
|
744
|
+
|
|
612
745
|
def main():
|
|
613
746
|
"""Entry point for the aru CLI."""
|
|
614
747
|
from dotenv import load_dotenv
|
|
@@ -616,6 +749,7 @@ def main():
|
|
|
616
749
|
load_dotenv()
|
|
617
750
|
args = sys.argv[1:]
|
|
618
751
|
skip_permissions = "--dangerously-skip-permissions" in args
|
|
752
|
+
print_only = "--print" in args or "-p" in args
|
|
619
753
|
|
|
620
754
|
if "--list" in args:
|
|
621
755
|
_list_sessions_and_exit()
|
|
@@ -629,6 +763,39 @@ def main():
|
|
|
629
763
|
else:
|
|
630
764
|
resume_id = "last"
|
|
631
765
|
|
|
766
|
+
# Collect positional arguments (non-flag, non-flag-value)
|
|
767
|
+
flags_with_value = {"--resume"}
|
|
768
|
+
positional = []
|
|
769
|
+
skip_next = False
|
|
770
|
+
for i, arg in enumerate(args):
|
|
771
|
+
if skip_next:
|
|
772
|
+
skip_next = False
|
|
773
|
+
continue
|
|
774
|
+
if arg.startswith("--") or arg.startswith("-"):
|
|
775
|
+
if arg in flags_with_value:
|
|
776
|
+
skip_next = True
|
|
777
|
+
continue
|
|
778
|
+
positional.append(arg)
|
|
779
|
+
|
|
780
|
+
# Piped stdin: echo "fix bug" | aru
|
|
781
|
+
if not sys.stdin.isatty() and not positional:
|
|
782
|
+
piped_input = sys.stdin.read().strip()
|
|
783
|
+
if piped_input:
|
|
784
|
+
positional = [piped_input]
|
|
785
|
+
|
|
786
|
+
# One-shot mode: aru "fix the bug" or aru --print "explain this"
|
|
787
|
+
if positional:
|
|
788
|
+
prompt = " ".join(positional)
|
|
789
|
+
try:
|
|
790
|
+
asyncio.run(run_oneshot(prompt, print_only=print_only, skip_permissions=skip_permissions))
|
|
791
|
+
except (KeyboardInterrupt, asyncio.CancelledError, SystemExit):
|
|
792
|
+
pass
|
|
793
|
+
except Exception as e:
|
|
794
|
+
from rich.markup import escape
|
|
795
|
+
console.print(f"\n[bold red]Fatal error: {escape(str(e))}[/bold red]")
|
|
796
|
+
return
|
|
797
|
+
|
|
798
|
+
# Interactive REPL mode
|
|
632
799
|
try:
|
|
633
800
|
asyncio.run(run_cli(skip_permissions=skip_permissions, resume_id=resume_id))
|
|
634
801
|
except (KeyboardInterrupt, asyncio.CancelledError, SystemExit):
|
|
@@ -21,6 +21,7 @@ SLASH_COMMANDS = [
|
|
|
21
21
|
("/skills", "List available skills", "/skills"),
|
|
22
22
|
("/agents", "List custom agents", "/agents"),
|
|
23
23
|
("/mcp", "List loaded MCP tools", "/mcp"),
|
|
24
|
+
("/undo", "Undo last turn — restore files and/or conversation", "/undo"),
|
|
24
25
|
("/cost", "Show detailed token usage and cost", "/cost"),
|
|
25
26
|
("/quit", "Exit aru", "/quit"),
|
|
26
27
|
]
|
|
@@ -83,6 +84,7 @@ def _show_help(config) -> None:
|
|
|
83
84
|
table.add_row("/skills", "List available skills")
|
|
84
85
|
table.add_row("/agents", "List custom agents")
|
|
85
86
|
table.add_row("/mcp", "List loaded MCP tools")
|
|
87
|
+
table.add_row("/undo", "Undo last turn (restore files and/or conversation)")
|
|
86
88
|
table.add_row("/help", "Show this help")
|
|
87
89
|
table.add_row("/quit", "Exit aru")
|
|
88
90
|
table.add_row("! <cmd>", "Run shell command")
|
|
@@ -122,6 +122,9 @@ class RuntimeContext:
|
|
|
122
122
|
# -- Plugins --
|
|
123
123
|
plugin_manager: Any = None # aru.plugins.manager.PluginManager (lazy to avoid circular)
|
|
124
124
|
|
|
125
|
+
# -- Checkpoints --
|
|
126
|
+
checkpoint_manager: Any = None # aru.checkpoints.CheckpointManager (lazy)
|
|
127
|
+
|
|
125
128
|
|
|
126
129
|
# ── ContextVar plumbing ──────────────────────────────────────────────
|
|
127
130
|
|
|
@@ -386,6 +386,24 @@ class Session:
|
|
|
386
386
|
return f"[yellow]Token budget at {pct:.0f}%[/yellow]"
|
|
387
387
|
return None
|
|
388
388
|
|
|
389
|
+
def undo_last_turn(self) -> int:
    """Remove the last complete turn (user message + assistant/tool responses).

    Truncates history at the most recent message whose role is "user", so
    that message and everything after it are removed. If history contains no
    user message, nothing is removed. There is no turn boundary to rewind to,
    and popping blindly would wipe the entire history.

    Returns:
        The number of messages removed (0 when nothing was removed).
    """
    last_user = next(
        (
            idx
            for idx in range(len(self.history) - 1, -1, -1)
            if self.history[idx].get("role") == "user"
        ),
        None,
    )
    if last_user is None:
        return 0
    removed = len(self.history) - last_user
    # In-place truncation keeps the same list object, as pop() did.
    del self.history[last_user:]
    self.updated_at = datetime.now().isoformat(timespec="milliseconds")
    return removed
|
|
406
|
+
|
|
389
407
|
def add_message(self, role: str, content):
|
|
390
408
|
"""Append a message to history.
|
|
391
409
|
|
|
@@ -31,6 +31,16 @@ def _notify_file_mutation():
|
|
|
31
31
|
ctx.on_file_mutation()
|
|
32
32
|
|
|
33
33
|
|
|
34
|
+
def _checkpoint_file(file_path: str):
    """Snapshot *file_path* into the active checkpoint manager, if one exists.

    Invoke this before the file is written or edited so that /undo can bring
    back its previous contents. When the runtime context has no checkpoint
    manager configured (e.g. one-shot mode), this does nothing.
    """
    manager = get_ctx().checkpoint_manager
    if not manager:
        return
    manager.track_edit(file_path)
|
|
42
|
+
|
|
43
|
+
|
|
34
44
|
def _get_small_model_ref() -> str:
|
|
35
45
|
"""Get the small model reference for sub-agents."""
|
|
36
46
|
return get_ctx().small_model_ref
|
|
@@ -266,6 +276,7 @@ def write_file(file_path: str, content: str) -> str:
|
|
|
266
276
|
if not check_permission("write", file_path, Group(header, Text(), diff)):
|
|
267
277
|
return f"PERMISSION DENIED by user: write to {file_path}. Do NOT retry this operation. Stop and ask the user for new instructions."
|
|
268
278
|
try:
|
|
279
|
+
_checkpoint_file(file_path)
|
|
269
280
|
os.makedirs(os.path.dirname(file_path) or ".", exist_ok=True)
|
|
270
281
|
with open(file_path, "w", encoding="utf-8") as f:
|
|
271
282
|
f.write(content)
|
|
@@ -305,6 +316,7 @@ def write_files(file_list: list[dict]) -> str:
|
|
|
305
316
|
errors.append("Error: missing 'path' in entry")
|
|
306
317
|
continue
|
|
307
318
|
try:
|
|
319
|
+
_checkpoint_file(path)
|
|
308
320
|
os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
|
|
309
321
|
with open(path, "w", encoding="utf-8") as f:
|
|
310
322
|
f.write(content)
|
|
@@ -363,6 +375,7 @@ def edit_file(file_path: str, old_string: str, new_string: str) -> str:
|
|
|
363
375
|
if not check_permission("edit", file_path, Group(header, Text(), diff)):
|
|
364
376
|
return f"PERMISSION DENIED by user: edit {file_path}. Do NOT retry this operation. Stop and ask the user for new instructions."
|
|
365
377
|
try:
|
|
378
|
+
_checkpoint_file(file_path)
|
|
366
379
|
with open(file_path, "r", encoding="utf-8") as f:
|
|
367
380
|
content = f.read()
|
|
368
381
|
|
|
@@ -424,6 +437,7 @@ def edit_files(edits: list[dict]) -> str:
|
|
|
424
437
|
continue
|
|
425
438
|
try:
|
|
426
439
|
if path not in cache:
|
|
440
|
+
_checkpoint_file(path)
|
|
427
441
|
with open(path, "r", encoding="utf-8") as f:
|
|
428
442
|
cache[path] = f.read()
|
|
429
443
|
|
|
@@ -915,37 +929,80 @@ async def bash(command: str, timeout: int = 60, working_directory: str = "") ->
|
|
|
915
929
|
|
|
916
930
|
|
|
917
931
|
class _HTMLToText(html.parser.HTMLParser):
    """HTML-to-text converter with improved content extraction.

    Emits lightweight Markdown-like text:
    - headings become ``#`` lines and ``<li>`` items become ``- `` bullets;
    - ``<pre>`` blocks are fenced with triple backticks and their content is
      kept verbatim;
    - inline ``<code>`` is wrapped in single backticks;
    - link targets are appended in parentheses.
    """

    # Elements whose whole subtree is dropped. Only elements that have a
    # closing tag may appear here: the skip depth is decremented in
    # handle_endtag, so a void element such as <input> (never closed) would
    # leave the parser skipping the rest of the document.
    SKIP_TAGS = {"script", "style", "svg", "noscript", "head", "nav", "footer",
                 "iframe", "form", "button", "select", "textarea"}
    # "br" lives here, so line breaks are produced by the block-tag branch.
    BLOCK_TAGS = {"p", "div", "br", "h1", "h2", "h3", "h4", "h5", "h6",
                  "li", "tr", "blockquote", "pre", "section", "article",
                  "header", "main", "figcaption", "details", "summary", "dt", "dd"}
    HEADING_TAGS = {"h1", "h2", "h3", "h4", "h5", "h6"}
    LIST_TAGS = {"li"}

    def __init__(self):
        super().__init__()
        self._pieces: list[str] = []
        self._skip_depth = 0
        self._in_pre = False
        self._in_anchor = False
        self._anchor_href = ""

    def handle_starttag(self, tag, attrs):
        if tag in self.SKIP_TAGS:
            self._skip_depth += 1
        elif self._skip_depth:
            return
        elif tag == "pre":
            self._in_pre = True
            self._pieces.append("\n```\n")
        elif tag == "code" and not self._in_pre:
            self._pieces.append("`")
        elif tag == "a":
            self._in_anchor = True
            # A bare `href` attribute is reported with value None.
            self._anchor_href = dict(attrs).get("href") or ""
        elif tag in self.HEADING_TAGS:
            level = int(tag[1])
            self._pieces.append(f"\n{'#' * level} ")
        elif tag in self.LIST_TAGS:
            self._pieces.append("\n- ")
        elif tag in self.BLOCK_TAGS:
            self._pieces.append("\n")

    def handle_endtag(self, tag):
        if tag in self.SKIP_TAGS:
            self._skip_depth = max(0, self._skip_depth - 1)
        elif self._skip_depth:
            return
        elif tag == "pre":
            self._in_pre = False
            self._pieces.append("\n```\n")
        elif tag == "code" and not self._in_pre:
            self._pieces.append("`")
        elif tag == "a":
            # Skip in-page anchors and javascript: pseudo-links.
            if self._anchor_href and not self._anchor_href.startswith(("#", "javascript:")):
                self._pieces.append(f" ({self._anchor_href})")
            self._in_anchor = False
            self._anchor_href = ""
        elif tag in self.HEADING_TAGS:
            self._pieces.append("\n")
        elif tag in self.BLOCK_TAGS:
            self._pieces.append("\n")

    def handle_data(self, data):
        if not self._skip_depth:
            self._pieces.append(data)

    def get_text(self) -> str:
        raw = "".join(self._pieces)
        lines = []
        in_fence = False
        for line in raw.splitlines():
            if line.startswith("```"):
                in_fence = not in_fence
                lines.append(line)
            elif in_fence:
                # Preformatted content: keep indentation, drop trailing blanks.
                lines.append(line.rstrip())
            else:
                # Collapse whitespace within lines, preserve line breaks
                lines.append(" ".join(line.split()))
        # Collapse multiple blank lines
        text = re.sub(r"\n{3,}", "\n\n", "\n".join(lines))
        return text.strip()
|
|
@@ -967,6 +1024,67 @@ def web_search(query: str, max_results: int = 5) -> str:
|
|
|
967
1024
|
import re as _re
|
|
968
1025
|
import urllib.parse
|
|
969
1026
|
|
|
1027
|
+
# Try DuckDuckGo Lite (simpler, more stable HTML than full version)
|
|
1028
|
+
results = _ddg_lite_search(query, max_results)
|
|
1029
|
+
if not results:
|
|
1030
|
+
# Fallback to DuckDuckGo HTML (classic scraping)
|
|
1031
|
+
results = _ddg_html_search(query, max_results)
|
|
1032
|
+
if not results:
|
|
1033
|
+
return f"No results found for: {query}"
|
|
1034
|
+
return "\n\n".join(results)
|
|
1035
|
+
|
|
1036
|
+
|
|
1037
|
+
def _ddg_lite_search(query: str, max_results: int) -> list[str]:
    """Search via DuckDuckGo Lite — minimal HTML, more stable parsing.

    Args:
        query: Free-text search query.
        max_results: Maximum number of formatted results to return.

    Returns:
        Formatted result strings (numbered title, URL, snippet). Returns an
        empty list when the request fails (transport error or HTTP error
        status) or nothing can be parsed, so the caller can fall back to
        another backend.
    """
    import re as _re
    import urllib.parse

    try:
        with httpx.Client(follow_redirects=True, timeout=15) as client:
            resp = client.post(
                "https://lite.duckduckgo.com/lite/",
                data={"q": query},
                headers={"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"},
            )
            resp.raise_for_status()
    except httpx.HTTPError:
        # HTTPError is the base of both RequestError (transport failures) and
        # HTTPStatusError (raised by raise_for_status). Catching only
        # RequestError let error statuses escape and skip the fallback.
        return []

    html_text = resp.text
    results = []

    # Titles are <a> tags carrying the "result-link" class; snippets are <td>
    # cells carrying "result-snippet". Attribute order and quote style are
    # not fixed, so match each tag first and then pick attributes out of it.
    anchor_pattern = _re.compile(r"<a\b([^>]*)>(.*?)</a>", _re.DOTALL | _re.IGNORECASE)
    link_class_pattern = _re.compile(r"""\bclass\s*=\s*["'][^"']*\bresult-link\b""", _re.IGNORECASE)
    href_pattern = _re.compile(r"""\bhref\s*=\s*["']([^"']*)["']""", _re.IGNORECASE)
    snippet_pattern = _re.compile(
        r"""<td\b[^>]*\bclass\s*=\s*["'][^"']*\bresult-snippet\b[^>]*>(.*?)</td>""",
        _re.DOTALL | _re.IGNORECASE,
    )

    links = []
    for attrs, title in anchor_pattern.findall(html_text):
        if not link_class_pattern.search(attrs):
            continue
        href_match = href_pattern.search(attrs)
        if href_match is None:
            continue
        links.append((href_match.group(1), title))
    snippets = snippet_pattern.findall(html_text)

    for i, (url, title) in enumerate(links[:max_results]):
        title_clean = _re.sub(r"<[^>]+>", "", title).strip()
        snippet_clean = _re.sub(r"<[^>]+>", "", snippets[i]).strip() if i < len(snippets) else ""
        # Decode DuckDuckGo redirect URLs
        actual_url = url
        ud_match = _re.search(r"uddg=([^&]+)", url)
        if ud_match:
            actual_url = urllib.parse.unquote(ud_match.group(1))
        results.append(f"{i + 1}. {title_clean}\n {actual_url}\n {snippet_clean}")

    return results
|
|
1081
|
+
|
|
1082
|
+
|
|
1083
|
+
def _ddg_html_search(query: str, max_results: int) -> list[str]:
|
|
1084
|
+
"""Fallback: search via DuckDuckGo HTML version."""
|
|
1085
|
+
import re as _re
|
|
1086
|
+
import urllib.parse
|
|
1087
|
+
|
|
970
1088
|
encoded = urllib.parse.quote_plus(query)
|
|
971
1089
|
url = f"https://html.duckduckgo.com/html/?q={encoded}"
|
|
972
1090
|
|
|
@@ -976,42 +1094,74 @@ def web_search(query: str, max_results: int = 5) -> str:
|
|
|
976
1094
|
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
|
|
977
1095
|
})
|
|
978
1096
|
resp.raise_for_status()
|
|
979
|
-
except httpx.RequestError
|
|
980
|
-
return
|
|
1097
|
+
except httpx.RequestError:
|
|
1098
|
+
return []
|
|
981
1099
|
|
|
982
|
-
|
|
1100
|
+
html_text = resp.text
|
|
983
1101
|
results = []
|
|
984
1102
|
|
|
985
|
-
# Parse DuckDuckGo HTML results
|
|
986
1103
|
blocks = _re.findall(
|
|
987
1104
|
r'<a[^>]+class="result__a"[^>]*href="([^"]*)"[^>]*>(.*?)</a>.*?'
|
|
988
1105
|
r'<a[^>]+class="result__snippet"[^>]*>(.*?)</a>',
|
|
989
|
-
|
|
1106
|
+
html_text, _re.DOTALL,
|
|
990
1107
|
)
|
|
991
1108
|
|
|
992
1109
|
for i, (link, title, snippet) in enumerate(blocks[:max_results], 1):
|
|
993
|
-
# Clean HTML tags
|
|
994
1110
|
title_clean = _re.sub(r"<[^>]+>", "", title).strip()
|
|
995
1111
|
snippet_clean = _re.sub(r"<[^>]+>", "", snippet).strip()
|
|
996
|
-
# DuckDuckGo wraps URLs in a redirect — extract the actual URL
|
|
997
1112
|
actual_url = link
|
|
998
1113
|
ud_match = _re.search(r"uddg=([^&]+)", link)
|
|
999
1114
|
if ud_match:
|
|
1000
1115
|
actual_url = urllib.parse.unquote(ud_match.group(1))
|
|
1001
1116
|
results.append(f"{i}. {title_clean}\n {actual_url}\n {snippet_clean}")
|
|
1002
1117
|
|
|
1003
|
-
|
|
1004
|
-
return f"No results found for: {query}"
|
|
1005
|
-
return "\n\n".join(results)
|
|
1118
|
+
return results
|
|
1006
1119
|
|
|
1007
1120
|
|
|
1008
1121
|
def web_fetch(url: str, max_chars: int = 8000) -> str:
    """Fetch a URL and return content as text.

    Uses Jina Reader (r.jina.ai) for clean content extraction from HTML pages.
    Falls back to direct fetch with local HTML-to-text conversion if Jina is
    unavailable.

    URLs whose *path* ends in a raw-data extension (.json, .txt, .xml, .csv,
    .pdf) skip Jina and are fetched directly. Only the path component is
    inspected, case-insensitively, so ``data.JSON`` and ``data.json?raw=1``
    are recognised as raw data too.

    Args:
        url: The URL to fetch.
        max_chars: Max characters to return (default 8000).
    """
    import urllib.parse

    raw_suffixes = (".json", ".txt", ".xml", ".csv", ".pdf")
    try:
        # Drop query string / fragment so "file.json?x=1" still counts as raw.
        path = urllib.parse.urlsplit(url).path
    except ValueError:
        # Malformed URL (e.g. unbalanced IPv6 brackets): fall back to checking
        # the raw string, as the previous implementation did.
        path = url

    # Try Jina Reader first for HTML-ish URLs — produces clean markdown.
    if not path.lower().endswith(raw_suffixes):
        jina_text = _fetch_via_jina(url, max_chars)
        if jina_text:
            return _truncate_output(jina_text, source_tool="web_fetch")

    # Direct fetch: fallback for Jina failures and primary path for raw data.
    return _fetch_direct(url, max_chars)
|
|
1140
|
+
|
|
1141
|
+
|
|
1142
|
+
def _fetch_via_jina(url: str, max_chars: int) -> str | None:
    """Return *url* as rendered by Jina Reader, or ``None`` on failure.

    The page is requested as ``text/plain`` from ``https://r.jina.ai/<url>``.
    ``None`` is returned for a non-200 status, for a stripped body shorter
    than 50 characters, or when httpx raises a request/status error.
    Bodies longer than *max_chars* are cut and tagged with a truncation note.
    """
    reader_endpoint = f"https://r.jina.ai/{url}"
    request_headers = {
        "Accept": "text/plain",
        "User-Agent": "Mozilla/5.0 (compatible; aru-agent/0.1)",
    }
    try:
        with httpx.Client(follow_redirects=True, timeout=30) as client:
            response = client.get(reader_endpoint, headers=request_headers)
            if response.status_code != 200:
                return None

            body = response.text.strip()
            # Empty or very short bodies are treated as a failed extraction.
            if len(body) < 50:
                return None

            if len(body) <= max_chars:
                return body
            return body[:max_chars] + f"\n\n... [truncated at {max_chars} chars]"
    except (httpx.RequestError, httpx.HTTPStatusError):
        return None
|
|
1161
|
+
|
|
1162
|
+
|
|
1163
|
+
def _fetch_direct(url: str, max_chars: int) -> str:
|
|
1164
|
+
"""Direct URL fetch with local HTML-to-text conversion."""
|
|
1015
1165
|
try:
|
|
1016
1166
|
with httpx.Client(follow_redirects=True, timeout=30) as client:
|
|
1017
1167
|
resp = client.get(url, headers={
|
|
@@ -1028,12 +1178,10 @@ def web_fetch(url: str, max_chars: int = 8000) -> str:
|
|
|
1028
1178
|
body = resp.text
|
|
1029
1179
|
|
|
1030
1180
|
if "json" in content_type:
|
|
1031
|
-
# JSON — return as-is (already readable)
|
|
1032
1181
|
text = body
|
|
1033
1182
|
elif "html" in content_type:
|
|
1034
1183
|
text = _html_to_text(body)
|
|
1035
1184
|
else:
|
|
1036
|
-
# Plain text or other
|
|
1037
1185
|
text = body
|
|
1038
1186
|
|
|
1039
1187
|
if len(text) > max_chars:
|
|
@@ -4,6 +4,7 @@ pyproject.toml
|
|
|
4
4
|
aru/__init__.py
|
|
5
5
|
aru/agent_factory.py
|
|
6
6
|
aru/cache_patch.py
|
|
7
|
+
aru/checkpoints.py
|
|
7
8
|
aru/cli.py
|
|
8
9
|
aru/commands.py
|
|
9
10
|
aru/completers.py
|
|
@@ -39,6 +40,7 @@ aru_code.egg-info/entry_points.txt
|
|
|
39
40
|
aru_code.egg-info/requires.txt
|
|
40
41
|
aru_code.egg-info/top_level.txt
|
|
41
42
|
tests/test_agents_base.py
|
|
43
|
+
tests/test_checkpoints.py
|
|
42
44
|
tests/test_cli.py
|
|
43
45
|
tests/test_cli_advanced.py
|
|
44
46
|
tests/test_cli_base.py
|
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
"""Tests for the checkpoint/undo system."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import tempfile
|
|
5
|
+
|
|
6
|
+
import pytest
|
|
7
|
+
|
|
8
|
+
from aru.checkpoints import CheckpointManager
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@pytest.fixture
def tmp_workspace(tmp_path):
    """Seed the per-test temp dir with a small Python file and a JSON file."""
    seed_files = {
        "hello.py": "print('hello')\n",
        "config.json": '{"key": "value"}\n',
    }
    for name, text in seed_files.items():
        (tmp_path / name).write_text(text)
    return tmp_path
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@pytest.fixture
def manager(tmp_path):
    """Provide a CheckpointManager whose backups live under the test's temp dir."""
    return CheckpointManager("test-session", base_dir=str(tmp_path / "backups"))
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class TestCheckpointManager:
    """Unit tests for CheckpointManager: snapshots, undo, limits and cleanup.

    File contents are read and written through small ``with``-based helpers,
    so every handle is closed deterministically. A bare ``open(p).read()``
    leaks the handle until garbage collection and emits ResourceWarning.
    """

    @staticmethod
    def _read(path):
        """Return the full text of *path*, closing the handle afterwards."""
        with open(path) as f:
            return f.read()

    @staticmethod
    def _write(path, content):
        """Overwrite *path* with *content* (simulates an agent edit)."""
        with open(path, "w") as f:
            f.write(content)

    def test_begin_turn_creates_snapshot(self, manager):
        manager.begin_turn(1)
        assert manager.get_snapshot_count() == 1

    def test_track_edit_captures_file_state(self, manager, tmp_workspace):
        manager.begin_turn(1)
        file_path = str(tmp_workspace / "hello.py")
        manager.track_edit(file_path)

        affected = manager.get_last_snapshot_files()
        assert os.path.abspath(file_path) in affected

    def test_track_edit_idempotent_within_turn(self, manager, tmp_workspace):
        manager.begin_turn(1)
        file_path = str(tmp_workspace / "hello.py")
        manager.track_edit(file_path)
        manager.track_edit(file_path)  # second call in the same turn: no-op

        affected = manager.get_last_snapshot_files()
        assert len(affected) == 1

    def test_undo_restores_edited_file(self, manager, tmp_workspace):
        file_path = str(tmp_workspace / "hello.py")
        original_content = "print('hello')\n"

        manager.begin_turn(1)
        manager.track_edit(file_path)

        # Simulate edit
        self._write(file_path, "print('CHANGED')\n")
        assert self._read(file_path) == "print('CHANGED')\n"

        # Undo
        restored, turn = manager.undo_last_turn()
        assert turn == 1
        assert os.path.abspath(file_path) in restored
        assert self._read(file_path) == original_content

    def test_undo_deletes_newly_created_file(self, manager, tmp_workspace):
        new_file = str(tmp_workspace / "new_file.py")

        manager.begin_turn(1)
        manager.track_edit(new_file)  # file doesn't exist yet

        # Simulate creation
        self._write(new_file, "new content\n")
        assert os.path.isfile(new_file)

        # Undo should delete the file
        restored, _ = manager.undo_last_turn()
        assert os.path.abspath(new_file) in restored
        assert not os.path.isfile(new_file)

    def test_undo_multiple_files(self, manager, tmp_workspace):
        file1 = str(tmp_workspace / "hello.py")
        file2 = str(tmp_workspace / "config.json")

        manager.begin_turn(1)
        manager.track_edit(file1)
        manager.track_edit(file2)

        # Edit both
        self._write(file1, "changed1\n")
        self._write(file2, "changed2\n")

        # Undo
        restored, _ = manager.undo_last_turn()
        assert len(restored) == 2
        assert self._read(file1) == "print('hello')\n"
        assert self._read(file2) == '{"key": "value"}\n'

    def test_undo_only_affects_last_turn(self, manager, tmp_workspace):
        file_path = str(tmp_workspace / "hello.py")

        # Turn 1: edit file
        manager.begin_turn(1)
        manager.track_edit(file_path)
        self._write(file_path, "turn1\n")

        # Turn 2: edit file again
        manager.begin_turn(2)
        manager.track_edit(file_path)
        self._write(file_path, "turn2\n")

        # Undo turn 2 → should restore to turn1 state
        restored, turn = manager.undo_last_turn()
        assert turn == 2
        assert self._read(file_path) == "turn1\n"

        # Undo turn 1 → should restore to original
        restored, turn = manager.undo_last_turn()
        assert turn == 1
        assert self._read(file_path) == "print('hello')\n"

    def test_undo_empty_returns_empty(self, manager):
        restored, turn = manager.undo_last_turn()
        assert restored == []
        assert turn == 0

    def test_get_last_snapshot_files_empty(self, manager):
        assert manager.get_last_snapshot_files() == []

    def test_max_snapshots_enforced(self, manager, tmp_workspace):
        file_path = str(tmp_workspace / "hello.py")
        for i in range(105):
            manager.begin_turn(i)
            manager.track_edit(file_path)
            self._write(file_path, f"v{i}\n")

        # NOTE(review): 100 presumably mirrors CheckpointManager's snapshot
        # cap — keep in sync if that limit changes.
        assert manager.get_snapshot_count() == 100

    def test_cleanup_removes_backup_dir(self, manager, tmp_workspace):
        file_path = str(tmp_workspace / "hello.py")
        manager.begin_turn(1)
        manager.track_edit(file_path)

        assert os.path.isdir(manager._base_dir)
        manager.cleanup()
        assert not os.path.isdir(manager._base_dir)
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
class TestSessionUndoLastTurn:
    """Tests for Session.undo_last_turn (conversation history only)."""

    @staticmethod
    def _session_with(*messages):
        """Build a fresh Session pre-populated with ``(role, content)`` pairs."""
        from aru.session import Session

        session = Session()
        for role, content in messages:
            session.add_message(role, content)
        return session

    def test_undo_removes_last_turn(self):
        session = self._session_with(
            ("user", "hello"),
            ("assistant", "hi there"),
            ("user", "how are you"),
            ("assistant", "good"),
        )

        removed = session.undo_last_turn()
        # Only the final user message and its reply are dropped.
        assert removed == 2
        assert len(session.history) == 2
        assert session.history[-1]["role"] == "assistant"

    def test_undo_removes_tool_messages(self):
        session = self._session_with(
            ("user", "fix the bug"),
            ("assistant", "reading file"),
            ("tool", "file contents here"),
            ("assistant", "done"),
        )

        removed = session.undo_last_turn()
        # Messages are popped from the end up to and including the most recent
        # user message: "done" (assistant), the tool output, "reading file"
        # (assistant) and "fix the bug" (user), four in total.
        assert removed == 4
        assert len(session.history) == 0

    def test_undo_empty_history(self):
        session = self._session_with()
        assert session.undo_last_turn() == 0
|
aru_code-0.20.1/aru/__init__.py
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
__version__ = "0.20.1"
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|