comate-cli 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- comate_cli/__init__.py +5 -0
- comate_cli/__main__.py +5 -0
- comate_cli/main.py +128 -0
- comate_cli/terminal_agent/__init__.py +2 -0
- comate_cli/terminal_agent/animations.py +283 -0
- comate_cli/terminal_agent/app.py +261 -0
- comate_cli/terminal_agent/assistant_render.py +243 -0
- comate_cli/terminal_agent/env_utils.py +37 -0
- comate_cli/terminal_agent/error_display.py +46 -0
- comate_cli/terminal_agent/event_renderer.py +867 -0
- comate_cli/terminal_agent/fragment_utils.py +25 -0
- comate_cli/terminal_agent/history_printer.py +150 -0
- comate_cli/terminal_agent/input_geometry.py +92 -0
- comate_cli/terminal_agent/layout_coordinator.py +188 -0
- comate_cli/terminal_agent/logging_adapter.py +147 -0
- comate_cli/terminal_agent/logo.py +58 -0
- comate_cli/terminal_agent/markdown_render.py +24 -0
- comate_cli/terminal_agent/mention_completer.py +293 -0
- comate_cli/terminal_agent/message_style.py +33 -0
- comate_cli/terminal_agent/models.py +89 -0
- comate_cli/terminal_agent/question_view.py +584 -0
- comate_cli/terminal_agent/rewind_store.py +712 -0
- comate_cli/terminal_agent/rpc_protocol.py +103 -0
- comate_cli/terminal_agent/rpc_stdio.py +280 -0
- comate_cli/terminal_agent/selection_menu.py +305 -0
- comate_cli/terminal_agent/session_view.py +99 -0
- comate_cli/terminal_agent/slash_commands.py +142 -0
- comate_cli/terminal_agent/startup.py +77 -0
- comate_cli/terminal_agent/status_bar.py +258 -0
- comate_cli/terminal_agent/text_effects.py +30 -0
- comate_cli/terminal_agent/tool_view.py +584 -0
- comate_cli/terminal_agent/tui.py +1006 -0
- comate_cli/terminal_agent/tui_parts/__init__.py +17 -0
- comate_cli/terminal_agent/tui_parts/commands.py +759 -0
- comate_cli/terminal_agent/tui_parts/history_sync.py +262 -0
- comate_cli/terminal_agent/tui_parts/input_behavior.py +324 -0
- comate_cli/terminal_agent/tui_parts/key_bindings.py +307 -0
- comate_cli/terminal_agent/tui_parts/render_panels.py +537 -0
- comate_cli/terminal_agent/tui_parts/slash_command_registry.py +45 -0
- comate_cli/terminal_agent/tui_parts/ui_mode.py +9 -0
- comate_cli-0.1.0.dist-info/METADATA +37 -0
- comate_cli-0.1.0.dist-info/RECORD +44 -0
- comate_cli-0.1.0.dist-info/WHEEL +4 -0
- comate_cli-0.1.0.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,712 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import difflib
|
|
4
|
+
import hashlib
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
import os
|
|
8
|
+
import tempfile
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
from datetime import datetime, timezone
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Any, Literal
|
|
13
|
+
|
|
14
|
+
from comate_agent_sdk.agent import ChatSession
|
|
15
|
+
from comate_agent_sdk.context.items import ItemType
|
|
16
|
+
|
|
17
|
+
logger = logging.getLogger(__name__)
|
|
18
|
+
|
|
19
|
+
_TRACKED_TOOLS = {"Write", "Edit", "MultiEdit"}
|
|
20
|
+
_SCHEMA_VERSION = 1
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass(frozen=True)
class RewindCheckpoint:
    """Immutable, user-facing view of one stored rewind checkpoint."""

    # Stored as "id" in the index; 0 is reserved for the hidden baseline.
    checkpoint_id: int
    # Conversation turn the checkpoint was recorded for (0 = baseline).
    turn_number: int
    # Whitespace-collapsed, length-limited copy of the user message.
    user_preview: str
    # Full (stripped) user message; falls back to the preview when absent.
    user_message: str
    # ISO-8601 timestamp string taken when the checkpoint was written.
    created_at: str
    # Project-relative paths edited by tracked tools during that turn.
    touched_files: tuple[str, ...]
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass(frozen=True)
class RewindPlanFile:
    """One file entry of a restore plan and what restoring will do to it."""

    # Project-relative POSIX path of the file.
    relpath: str
    # "write": restore stored content; "delete": remove the file;
    # "skip_binary" / "skip_unknown": left untouched (see ``note``).
    action: Literal["write", "delete", "skip_binary", "skip_unknown"]
    # Line delta of current content -> restored content.
    added_lines: int = 0
    removed_lines: int = 0
    # Human-readable reason, set for skipped entries.
    note: str | None = None
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass(frozen=True)
class RewindRestorePlan:
    """Summary of restoring the project to a checkpoint's recorded state."""

    # Checkpoint whose state is restored.
    checkpoint: RewindCheckpoint
    # Per-file actions; files already matching the target are omitted.
    files: tuple[RewindPlanFile, ...]
    # Line totals summed over "write" and "delete" entries only.
    total_added_lines: int
    total_removed_lines: int
    # Number of "write" + "delete" entries.
    writable_files_count: int
    # Number of entries skipped because content is binary.
    skipped_binary_count: int
    # Number of entries skipped because the target state is unknown.
    skipped_unknown_count: int
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class RewindStore:
    """Snapshot files edited by the agent and roll the project back on demand.

    A checkpoint is recorded per user turn. Each checkpoint carries a
    cumulative ``manifest`` mapping project-relative POSIX paths to
    ``{"exists", "binary", "sha256"}`` states; every new checkpoint copies the
    previous manifest and overlays the files touched in its turn. Text
    contents are stored once, content-addressed; binary files are recorded
    only as existing and are never restored.

    On-disk layout below ``<session storage root>/rewind/``::

        index.json             checkpoint list and the next checkpoint id
        checkpoints/<id>.json  standalone copy of each checkpoint record
        blobs/<sha256>.bin     text file contents keyed by SHA-256

    A hidden baseline checkpoint (id 0, turn 0) is created before the first
    tracked tool call so the state preceding the first user turn can be
    restored; it is excluded from :meth:`list_checkpoints`.
    """

    def __init__(self, *, session: ChatSession, project_root: Path) -> None:
        """Bind to *session* and register the PreToolUse hook on its agent.

        Args:
            session: Chat session whose agent is observed and whose storage
                root hosts the rewind data.
            project_root: Directory all tracked paths are relative to. It is
                resolved once so containment checks compare absolute paths.
        """
        self._session = session
        self._project_root = project_root.resolve()
        self._register_hook()

    def bind_session(self, session: ChatSession) -> None:
        """Switch to *session* and register the PreToolUse hook on its agent."""
        self._session = session
        self._register_hook()

    def _register_hook(self) -> None:
        """Register :meth:`_on_pre_tool_use` if the agent supports Python hooks.

        Agents without ``register_python_hook`` are left untracked.
        """
        if hasattr(self._session._agent, "register_python_hook"):
            self._session._agent.register_python_hook(
                event_name="PreToolUse",
                callback=self._on_pre_tool_use,
            )

    async def _on_pre_tool_use(self, hook_input: Any) -> Any:
        """Record a file's pre-edit state before a tracked tool modifies it.

        The state is written into the latest checkpoint's manifest (creating
        the hidden baseline checkpoint on first use), but only when that
        manifest does not already know the path, so an earlier known state is
        never overwritten. Capture/save failures are logged, not raised.

        Args:
            hook_input: Hook payload; ``tool_name`` and ``tool_input`` (a dict
                with ``file_path``) are read via ``getattr``.

        Returns:
            Always ``None``.
        """
        tool_name = getattr(hook_input, "tool_name", "")
        if tool_name not in _TRACKED_TOOLS:
            return None

        args = getattr(hook_input, "tool_input", {})
        if not isinstance(args, dict):
            return None

        file_path = args.get("file_path")
        if not file_path:
            return None

        relpath = self._normalize_relpath(str(file_path))
        if relpath is None:
            return None

        index = self._load_index()
        checkpoints = index.get("checkpoints", [])

        if not checkpoints:
            # Before the first tracked tool call, create the baseline checkpoint
            # (id=0, turn=0). It provides the original file states preceding the
            # first user checkpoint, for restoring "before the first checkpoint".
            # Turn-0 checkpoints never appear in the user-visible /rewind list.
            baseline: dict[str, Any] = {
                "id": 0,
                "turn_number": 0,
                "user_preview": "",
                "user_message": "",
                "created_at": datetime.now(timezone.utc).isoformat(),
                "touched_files": [],
                "file_events": {},
                "manifest": {},
            }
            checkpoints.append(baseline)
            index["checkpoints"] = checkpoints

        latest_cp = checkpoints[-1]
        manifest = latest_cp.get("manifest")
        if not isinstance(manifest, dict):
            # Attach the manifest to the checkpoint itself; mutating a detached
            # default dict would silently drop the captured state.
            manifest = {}
            latest_cp["manifest"] = manifest

        if relpath not in manifest:
            try:
                manifest[relpath] = self._capture_file_state(relpath)
                self._save_checkpoint_file(latest_cp)
                self._save_index(index)
            except Exception as e:
                # Best effort: snapshotting must not break the tool call.
                logger.warning("Failed to capture pre-edit baseline for %s: %s", relpath, e)

        return None

    @property
    def session_id(self) -> str:
        """Id of the currently bound session."""
        return self._session.session_id

    @property
    def storage_root(self) -> Path:
        """Session storage directory that hosts the ``rewind`` folder."""
        return Path(self._session._storage_root)

    @property
    def rewind_root(self) -> Path:
        """Root directory of all rewind data."""
        return self.storage_root / "rewind"

    @property
    def checkpoint_root(self) -> Path:
        """Directory holding one ``<id>.json`` file per checkpoint."""
        return self.rewind_root / "checkpoints"

    @property
    def blob_root(self) -> Path:
        """Directory holding content-addressed ``<sha256>.bin`` blobs."""
        return self.rewind_root / "blobs"

    @property
    def index_path(self) -> Path:
        """Path of the checkpoint index file."""
        return self.rewind_root / "index.json"

    def list_checkpoints(self) -> list[RewindCheckpoint]:
        """Return user-visible checkpoints (turn > 0), sorted by id."""
        index = self._load_index()
        checkpoints = [
            self._checkpoint_from_dict(cp)
            for cp in index.get("checkpoints", [])
            if int(cp.get("turn_number", 0)) > 0  # hide the turn-0 baseline
        ]
        checkpoints.sort(key=lambda cp: cp.checkpoint_id)
        return checkpoints

    def prune_after_checkpoint(self, *, checkpoint_id: int) -> int:
        """Delete every checkpoint whose id is greater than *checkpoint_id*.

        The index is rewritten first, then the per-checkpoint files are
        removed. Blobs are left in place.

        Returns:
            Number of checkpoints dropped.
        """
        index = self._load_index()
        raw_checkpoints = index.get("checkpoints", [])
        if not raw_checkpoints:
            return 0

        kept: list[dict[str, Any]] = []
        dropped_ids: list[int] = []
        for cp in raw_checkpoints:
            cp_id = int(cp.get("id", 0))
            if cp_id <= checkpoint_id:
                kept.append(cp)
            else:
                dropped_ids.append(cp_id)

        if not dropped_ids:
            return 0

        index["checkpoints"] = kept
        index["next_checkpoint_id"] = max((int(cp.get("id", 0)) for cp in kept), default=0) + 1
        self._save_index(index)

        for cp_id in dropped_ids:
            cp_path = self.checkpoint_root / f"{cp_id}.json"
            try:
                cp_path.unlink()
            except FileNotFoundError:
                pass

        return len(dropped_ids)

    def restore_code_before_checkpoint(self, *, checkpoint_id: int) -> RewindRestorePlan:
        """Restore code to the state at the start of *checkpoint_id*'s turn.

        Finds the checkpoint immediately preceding *checkpoint_id* by id
        (including the baseline, id 0) and restores its manifest state. If
        there is no predecessor (legacy data without a baseline), nothing is
        changed and an empty plan is returned.

        Raises:
            ValueError: If *checkpoint_id* does not exist.
        """
        index = self._load_index()
        checkpoints = index.get("checkpoints", [])
        by_id = {int(cp.get("id", 0)): cp for cp in checkpoints}

        if checkpoint_id not in by_id:
            raise ValueError(f"Checkpoint not found: {checkpoint_id}")

        sorted_ids = sorted(by_id.keys())
        target_idx = sorted_ids.index(checkpoint_id)

        if target_idx > 0:
            pred_id = sorted_ids[target_idx - 1]
            return self.restore_code_to_checkpoint(checkpoint_id=pred_id)

        # No predecessor at all (not even a baseline): return an empty plan.
        target_raw = by_id[checkpoint_id]
        return RewindRestorePlan(
            checkpoint=self._checkpoint_from_dict(target_raw),
            files=(),
            total_added_lines=0,
            total_removed_lines=0,
            writable_files_count=0,
            skipped_binary_count=0,
            skipped_unknown_count=0,
        )

    def capture_checkpoint_for_latest_turn(self, *, user_preview: str) -> RewindCheckpoint | None:
        """Record a checkpoint for the agent's current turn.

        The new manifest is the previous checkpoint's manifest with the files
        edited by tracked tools in this turn re-captured from disk.

        Args:
            user_preview: The user's message for this turn; stored stripped as
                ``user_message`` and normalized as ``user_preview``.

        Returns:
            The new checkpoint, or ``None`` when the turn number is not
            positive or this turn already has a checkpoint.
        """
        current_turn = self._current_turn()
        if current_turn <= 0:
            return None

        index = self._load_index()
        checkpoints = index.get("checkpoints", [])
        if any(int(cp.get("turn_number", 0)) == current_turn for cp in checkpoints):
            return None

        prev_manifest = self._normalize_manifest(
            checkpoints[-1].get("manifest", {}) if checkpoints else {}
        )
        touched_files, file_events = self._collect_turn_file_events(turn_number=current_turn)

        manifest: dict[str, dict[str, Any]] = {k: dict(v) for k, v in prev_manifest.items()}
        for relpath in touched_files:
            manifest[relpath] = self._capture_file_state(relpath)

        checkpoint_id = int(index.get("next_checkpoint_id", 1))
        checkpoint = {
            "id": checkpoint_id,
            "turn_number": current_turn,
            "user_preview": self._normalize_preview(user_preview),
            "user_message": str(user_preview).strip(),
            "created_at": datetime.now(timezone.utc).isoformat(),
            "touched_files": sorted(touched_files),
            "file_events": file_events,
            "manifest": manifest,
        }
        checkpoints.append(checkpoint)
        index["next_checkpoint_id"] = checkpoint_id + 1
        index["checkpoints"] = checkpoints
        self._save_index(index)
        self._save_checkpoint_file(checkpoint)
        return self._checkpoint_from_dict(checkpoint)

    def build_restore_plan(self, *, checkpoint_id: int) -> RewindRestorePlan:
        """Compute, without touching project files, how to restore a checkpoint.

        Every path known to the latest or the target manifest is compared with
        its current on-disk state. Capturing the current state stores blobs of
        current text files as a side effect.

        Raises:
            ValueError: If there are no checkpoints or *checkpoint_id* is unknown.
        """
        index = self._load_index()
        checkpoints = index.get("checkpoints", [])
        if not checkpoints:
            raise ValueError("No checkpoints available")

        by_id = {int(cp.get("id", 0)): cp for cp in checkpoints}
        target_raw = by_id.get(checkpoint_id)
        if target_raw is None:
            raise ValueError(f"Checkpoint not found: {checkpoint_id}")

        latest_manifest = self._normalize_manifest(checkpoints[-1].get("manifest", {}))
        target_manifest = self._normalize_manifest(target_raw.get("manifest", {}))
        tracked_paths = sorted(set(latest_manifest.keys()) | set(target_manifest.keys()))

        planned_files: list[RewindPlanFile] = []
        total_added = 0
        total_removed = 0
        writable_count = 0
        skipped_binary = 0
        skipped_unknown = 0

        target_id = int(target_raw.get("id", 0))
        for relpath in tracked_paths:
            current_state = self._capture_file_state(relpath)
            target_state = target_manifest.get(relpath)
            if target_state is None:
                target_state = self._infer_missing_target_state(
                    relpath=relpath,
                    checkpoints=checkpoints,
                    target_checkpoint_id=target_id,
                )
            if target_state is None:
                # Target state unknown: never guess; report existing files.
                if current_state.get("exists", False):
                    skipped_unknown += 1
                    planned_files.append(
                        RewindPlanFile(
                            relpath=relpath,
                            action="skip_unknown",
                            note="missing baseline before first tracked change",
                        )
                    )
                continue

            file_plan = self._plan_file_change(
                relpath=relpath,
                current_state=current_state,
                target_state=target_state,
            )
            if file_plan is None:
                continue
            planned_files.append(file_plan)

            if file_plan.action in ("write", "delete"):
                writable_count += 1
                total_added += file_plan.added_lines
                total_removed += file_plan.removed_lines
            elif file_plan.action == "skip_binary":
                skipped_binary += 1
            elif file_plan.action == "skip_unknown":
                skipped_unknown += 1

        return RewindRestorePlan(
            checkpoint=self._checkpoint_from_dict(target_raw),
            files=tuple(planned_files),
            total_added_lines=total_added,
            total_removed_lines=total_removed,
            writable_files_count=writable_count,
            skipped_binary_count=skipped_binary,
            skipped_unknown_count=skipped_unknown,
        )

    def restore_code_to_checkpoint(self, *, checkpoint_id: int) -> RewindRestorePlan:
        """Rewrite project files so they match checkpoint *checkpoint_id*.

        Applies :meth:`build_restore_plan`: ``delete`` entries remove the file
        and ``write`` entries copy the stored blob back, keeping the existing
        file's permission bits (or the umask default for recreated files).
        Paths resolving outside the project root and entries whose blob is
        missing are skipped with a warning.

        Returns:
            The plan that was applied.

        Raises:
            ValueError: If the checkpoint does not exist.
        """
        index = self._load_index()
        checkpoints = index.get("checkpoints", [])
        by_id = {int(cp.get("id", 0)): cp for cp in checkpoints}
        target_raw = by_id.get(checkpoint_id)
        if target_raw is None:
            raise ValueError(f"Checkpoint not found: {checkpoint_id}")
        target_manifest = self._normalize_manifest(target_raw.get("manifest", {}))

        plan = self.build_restore_plan(checkpoint_id=checkpoint_id)
        for file_plan in plan.files:
            relpath = file_plan.relpath
            abs_path = (self._project_root / relpath).resolve()
            if not self._is_under(abs_path, self._project_root):
                # Only reachable with a corrupted/hand-edited index: never write
                # or delete anything outside the project.
                logger.warning("rewind skipped path outside project root: %s", relpath)
                continue
            if file_plan.action == "delete":
                if abs_path.is_file():
                    abs_path.unlink()
                continue
            if file_plan.action != "write":
                continue

            target_state = target_manifest.get(relpath) or {}
            sha256 = str(target_state.get("sha256") or "").strip()
            if not sha256:
                continue
            blob_path = self._blob_path(sha256)
            if not blob_path.exists():
                logger.warning("rewind blob missing for %s: %s", relpath, sha256)
                continue
            data = blob_path.read_bytes()
            # The atomic write goes through a mkstemp file (mode 0600); pass an
            # explicit mode so the restored file keeps sensible permissions.
            self._atomic_write_bytes(abs_path, data, mode=self._file_mode_for_restore(abs_path))
        return plan

    def _plan_file_change(
        self,
        *,
        relpath: str,
        current_state: dict[str, Any],
        target_state: dict[str, Any],
    ) -> RewindPlanFile | None:
        """Decide what restoring *relpath* from current to target state needs.

        Returns:
            A plan entry, or ``None`` when nothing needs to happen (already
            matching, both absent, or a binary target that is absent now).
        """
        current_exists = bool(current_state.get("exists", False))
        target_exists = bool(target_state.get("exists", False))

        if target_exists and bool(target_state.get("binary", False)):
            if current_exists:
                return RewindPlanFile(
                    relpath=relpath,
                    action="skip_binary",
                    note="target is binary; skipped",
                )
            return None

        if not target_exists:
            if not current_exists:
                return None
            if bool(current_state.get("binary", False)):
                return RewindPlanFile(
                    relpath=relpath,
                    action="skip_binary",
                    note="current is binary; skipped delete",
                )
            before_bytes = self._read_blob_for_state(current_state)
            added, removed = self._compute_line_delta(before=before_bytes, after=b"")
            return RewindPlanFile(
                relpath=relpath,
                action="delete",
                added_lines=added,
                removed_lines=removed,
            )

        if bool(current_state.get("binary", False)):
            return RewindPlanFile(
                relpath=relpath,
                action="skip_binary",
                note="current is binary; skipped",
            )

        target_bytes = self._read_blob_for_state(target_state)
        if target_bytes is None:
            return RewindPlanFile(
                relpath=relpath,
                action="skip_unknown",
                note="target blob missing",
            )

        if not current_exists:
            added, removed = self._compute_line_delta(before=b"", after=target_bytes)
            return RewindPlanFile(
                relpath=relpath,
                action="write",
                added_lines=added,
                removed_lines=removed,
            )

        current_bytes = self._read_blob_for_state(current_state)
        if current_bytes is None:
            return RewindPlanFile(
                relpath=relpath,
                action="skip_unknown",
                note="current blob missing",
            )
        if current_bytes == target_bytes:
            return None

        added, removed = self._compute_line_delta(before=current_bytes, after=target_bytes)
        return RewindPlanFile(
            relpath=relpath,
            action="write",
            added_lines=added,
            removed_lines=removed,
        )

    @staticmethod
    def _compute_line_delta(*, before: bytes, after: bytes) -> tuple[int, int]:
        """Count lines added and removed going from *before* to *after*.

        Both inputs are decoded as UTF-8 (undecodable bytes replaced). Counts
        come straight from the diff opcodes, so lines whose content happens to
        start with ``"-- "`` or ``"++ "`` are counted like any other line
        (parsing unified-diff text would mistake them for ``---``/``+++``
        headers).

        Returns:
            ``(added, removed)``.
        """
        before_lines = before.decode("utf-8", errors="replace").splitlines()
        after_lines = after.decode("utf-8", errors="replace").splitlines()
        added = 0
        removed = 0
        matcher = difflib.SequenceMatcher(None, before_lines, after_lines)
        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            if tag in ("replace", "delete"):
                removed += i2 - i1
            if tag in ("replace", "insert"):
                added += j2 - j1
        return added, removed

    def _read_blob_for_state(self, state: dict[str, Any]) -> bytes | None:
        """Return stored bytes for *state*: ``b""`` if absent, ``None`` if unavailable."""
        if not bool(state.get("exists", False)):
            return b""
        sha256 = str(state.get("sha256") or "").strip()
        if not sha256:
            return None
        blob_path = self._blob_path(sha256)
        if not blob_path.exists():
            return None
        return blob_path.read_bytes()

    def _infer_missing_target_state(
        self,
        *,
        relpath: str,
        checkpoints: list[dict[str, Any]],
        target_checkpoint_id: int,
    ) -> dict[str, Any] | None:
        """Guess *relpath*'s state at the target checkpoint when unrecorded.

        Only one inference is made: if the first checkpoint that records the
        path comes after the target and its file event says the tool created
        the file, the file did not exist at the target.

        Returns:
            A "does not exist" state, or ``None`` when unknown.
        """
        first_seen: dict[str, Any] | None = None
        for cp in checkpoints:
            if relpath in self._normalize_manifest(cp.get("manifest", {})):
                first_seen = cp
                break
        if first_seen is None:
            return None

        if int(first_seen.get("id", 0)) <= target_checkpoint_id:
            return None

        events = first_seen.get("file_events", {}) or {}
        event_meta = events.get(relpath, {}) if isinstance(events, dict) else {}
        if event_meta.get("created") is True:
            return {"exists": False, "binary": False, "sha256": None}
        return None

    def _collect_turn_file_events(
        self,
        *,
        turn_number: int,
    ) -> tuple[set[str], dict[str, dict[str, Any]]]:
        """Collect files edited by successful tracked tool calls in a turn.

        Scans the agent conversation's tool-result items created in
        *turn_number* and reads the path from the result's
        ``tool_raw_envelope`` metadata (``data.relpath`` or
        ``meta.file_path``).

        Returns:
            ``(touched relpaths, {relpath: {"tool", "operation", "created"}})``;
            a later event for the same path replaces an earlier one.
        """
        touched: set[str] = set()
        events: dict[str, dict[str, Any]] = {}
        items = getattr(self._session._agent._context.conversation, "items", [])
        for item in items:
            if item.item_type != ItemType.TOOL_RESULT:
                continue
            if int(getattr(item, "created_turn", 0) or 0) != turn_number:
                continue
            if bool(getattr(item, "is_tool_error", False)):
                continue

            tool_name = str(getattr(item, "tool_name", None) or "").strip()
            message = getattr(item, "message", None)
            if not tool_name and message is not None:
                tool_name = str(getattr(message, "tool_name", None) or "").strip()
            if tool_name not in _TRACKED_TOOLS:
                continue

            envelope = (getattr(item, "metadata", None) or {}).get("tool_raw_envelope")
            if not isinstance(envelope, dict):
                continue

            meta = envelope.get("meta", {})
            data = envelope.get("data", {})
            if not isinstance(meta, dict):
                meta = {}
            if not isinstance(data, dict):
                data = {}

            relpath = self._normalize_relpath(data.get("relpath") or meta.get("file_path"))
            if relpath is None:
                continue
            touched.add(relpath)

            created_val = data.get("created")
            events[relpath] = {
                "tool": tool_name,
                "operation": str(meta.get("operation") or "").strip() or None,
                # Only a real bool is trusted; anything else means "unknown".
                "created": created_val if isinstance(created_val, bool) else None,
            }
        return touched, events

    @staticmethod
    def _normalize_preview(text: str) -> str:
        """Collapse whitespace and cap the preview at 80 characters."""
        line = " ".join(str(text).strip().split())
        if len(line) <= 80:
            return line
        return f"{line[:77]}..."

    def _current_turn(self) -> int:
        """Read the agent context's turn number; 0 if unavailable."""
        try:
            return int(getattr(self._session._agent._context, "_turn_number", 0) or 0)
        except Exception:
            return 0

    def _capture_file_state(self, relpath: str) -> dict[str, Any]:
        """Snapshot *relpath*: store text content as a blob and describe it.

        Paths outside the project root, missing paths and non-regular files
        are reported as not existing; binary files are recorded without a
        hash and without storing content.
        """
        abs_path = (self._project_root / relpath).resolve()
        if not self._is_under(abs_path, self._project_root):
            return {"exists": False, "binary": False, "sha256": None}
        if not abs_path.exists() or not abs_path.is_file():
            return {"exists": False, "binary": False, "sha256": None}

        data = abs_path.read_bytes()
        if self._is_binary_content(data):
            return {"exists": True, "binary": True, "sha256": None}

        sha256 = hashlib.sha256(data).hexdigest()
        self._store_blob(sha256, data)
        return {"exists": True, "binary": False, "sha256": sha256}

    def _save_checkpoint_file(self, checkpoint: dict[str, Any]) -> None:
        """Write ``checkpoints/<id>.json`` atomically, like the index and blobs."""
        cp_id = int(checkpoint.get("id", 0))
        payload = json.dumps(checkpoint, ensure_ascii=False, indent=2).encode("utf-8")
        self._atomic_write_bytes(self.checkpoint_root / f"{cp_id}.json", payload)

    def _blob_path(self, sha256: str) -> Path:
        """Location of the blob for content hash *sha256*."""
        return self.blob_root / f"{sha256}.bin"

    def _store_blob(self, sha256: str, data: bytes) -> None:
        """Store *data* under its hash unless a blob with that hash exists."""
        self.blob_root.mkdir(parents=True, exist_ok=True)
        path = self._blob_path(sha256)
        if path.exists():
            return
        self._atomic_write_bytes(path, data)

    @staticmethod
    def _is_binary_content(data: bytes) -> bool:
        """Treat data as binary if it has a NUL in the first 4 KiB or is not UTF-8."""
        if b"\x00" in data[:4096]:
            return True
        try:
            data.decode("utf-8")
        except UnicodeDecodeError:
            return True
        return False

    @staticmethod
    def _checkpoint_from_dict(data: dict[str, Any]) -> RewindCheckpoint:
        """Build a :class:`RewindCheckpoint` from a stored checkpoint record."""
        touched = data.get("touched_files", []) or []
        user_message = str(data.get("user_message", "")).strip()
        user_preview = str(data.get("user_preview", ""))
        if not user_message:
            user_message = user_preview
        return RewindCheckpoint(
            checkpoint_id=int(data.get("id", 0)),
            turn_number=int(data.get("turn_number", 0)),
            user_preview=user_preview,
            user_message=user_message,
            created_at=str(data.get("created_at", "")),
            touched_files=tuple(str(x) for x in touched),
        )

    @staticmethod
    def _normalize_manifest(raw_manifest: Any) -> dict[str, dict[str, Any]]:
        """Return a clean copy of a stored manifest.

        Keys get forward slashes with leading/trailing slashes removed (empty
        keys dropped); values are reduced to ``exists``/``binary``/``sha256``.
        """
        if not isinstance(raw_manifest, dict):
            return {}
        manifest: dict[str, dict[str, Any]] = {}
        for key, raw in raw_manifest.items():
            relpath = str(key).replace("\\", "/").strip("/")
            if not relpath:
                continue
            state = raw if isinstance(raw, dict) else {}
            manifest[relpath] = {
                "exists": bool(state.get("exists", False)),
                "binary": bool(state.get("binary", False)),
                "sha256": state.get("sha256"),
            }
        return manifest

    def _normalize_relpath(self, raw_path: Any) -> str | None:
        """Convert a tool-supplied path to a project-relative POSIX path.

        Relative paths are interpreted against the project root.

        Returns:
            The relative path, or ``None`` for empty input or paths that
            resolve outside the project root.
        """
        if raw_path is None:
            return None
        text = str(raw_path).strip()
        if not text:
            return None
        path = Path(text)
        abs_path = path.resolve() if path.is_absolute() else (self._project_root / path).resolve()
        if not self._is_under(abs_path, self._project_root):
            return None
        return abs_path.relative_to(self._project_root).as_posix()

    @staticmethod
    def _is_under(path: Path, root: Path) -> bool:
        """Return True if *path* equals *root* or lies below it (lexically)."""
        try:
            path.relative_to(root)
        except ValueError:
            return False
        return True

    def _empty_index(self) -> dict[str, Any]:
        """Fresh index used when none exists or the stored one is unreadable."""
        return {
            "schema_version": _SCHEMA_VERSION,
            "session_id": self._session.session_id,
            "next_checkpoint_id": 1,
            "checkpoints": [],
        }

    @staticmethod
    def _coerce_int(value: Any, default: int) -> int:
        """Return ``int(value)``, or *default* when missing or malformed."""
        try:
            return int(value)
        except (TypeError, ValueError):
            return default

    def _load_index(self) -> dict[str, Any]:
        """Load ``index.json``, repairing what it can.

        Unreadable or non-object files yield a fresh empty index. Non-dict
        checkpoint entries are dropped, and ``next_checkpoint_id`` is forced
        above every existing id so new checkpoints cannot collide.
        """
        if not self.index_path.exists():
            return self._empty_index()
        try:
            data = json.loads(self.index_path.read_text(encoding="utf-8"))
        except Exception:
            logger.warning("failed to load rewind index: %s", self.index_path)
            return self._empty_index()

        if not isinstance(data, dict):
            return self._empty_index()

        raw_checkpoints = data.get("checkpoints")
        if not isinstance(raw_checkpoints, list):
            raw_checkpoints = []
        # Every caller does ``cp.get(...)``; malformed entries would crash them.
        checkpoints = [cp for cp in raw_checkpoints if isinstance(cp, dict)]
        data["checkpoints"] = checkpoints

        stored_next = self._coerce_int(data.get("next_checkpoint_id"), len(checkpoints) + 1)
        highest_id = max((self._coerce_int(cp.get("id"), 0) for cp in checkpoints), default=0)
        data["next_checkpoint_id"] = max(stored_next, highest_id + 1)
        data["schema_version"] = _SCHEMA_VERSION
        data["session_id"] = self._session.session_id
        return data

    def _save_index(self, index: dict[str, Any]) -> None:
        """Atomically write *index* as pretty-printed UTF-8 JSON."""
        self.rewind_root.mkdir(parents=True, exist_ok=True)
        payload = json.dumps(index, ensure_ascii=False, indent=2).encode("utf-8")
        self._atomic_write_bytes(self.index_path, payload)

    @staticmethod
    def _file_mode_for_restore(path: Path) -> int:
        """Permission bits a restored file should get.

        An existing file keeps its current mode. For a missing file the
        process umask default (``0o666 & ~umask``) is used. The umask is read
        by briefly installing a *stricter* mask, so any file created
        concurrently in that window can only end up less permissive.
        """
        try:
            return path.stat().st_mode & 0o7777
        except OSError:
            pass
        current_umask = os.umask(0o077)
        os.umask(current_umask)
        return 0o666 & ~current_umask

    @staticmethod
    def _atomic_write_bytes(path: Path, payload: bytes, *, mode: int | None = None) -> None:
        """Write *payload* to *path* via a synced temp file and ``os.replace``.

        Args:
            path: Destination; parent directories are created as needed.
            payload: Bytes to write.
            mode: Optional permission bits applied before the rename. Without
                it the file keeps the temp file's owner-only (0600) mode.
        """
        path.parent.mkdir(parents=True, exist_ok=True)
        fd, tmp_name = tempfile.mkstemp(prefix=".tmp_", dir=str(path.parent))
        try:
            with os.fdopen(fd, "wb") as f:
                f.write(payload)
                f.flush()
                os.fsync(f.fileno())
            if mode is not None:
                os.chmod(tmp_name, mode)
            os.replace(tmp_name, path)
        finally:
            # After a successful replace the temp name is gone; otherwise clean up.
            try:
                os.remove(tmp_name)
            except FileNotFoundError:
                pass
|