zai-cli 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zai/__init__.py +1 -0
- zai/__main__.py +4 -0
- zai/cli/__init__.py +1 -0
- zai/cli/common.py +16 -0
- zai/cli/integrations.py +319 -0
- zai/cli/interactive.py +518 -0
- zai/cli/settings.py +436 -0
- zai/cli/utilities.py +227 -0
- zai/cli/workflows.py +137 -0
- zai/commands/commit.md +24 -0
- zai/commands/explain.md +17 -0
- zai/commands/feature.md +34 -0
- zai/commands/fix.md +14 -0
- zai/commands/review.md +22 -0
- zai/config.py +307 -0
- zai/core/__init__.py +0 -0
- zai/core/agent.py +701 -0
- zai/core/cancellation.py +67 -0
- zai/core/commands.py +85 -0
- zai/core/context.py +299 -0
- zai/core/errors.py +125 -0
- zai/core/fallback.py +171 -0
- zai/core/hooks.py +115 -0
- zai/core/memory.py +57 -0
- zai/core/process.py +204 -0
- zai/core/repomap.py +381 -0
- zai/core/runtime.py +29 -0
- zai/core/security.py +33 -0
- zai/core/session.py +425 -0
- zai/core/storage.py +193 -0
- zai/core/streaming.py +157 -0
- zai/core/tool_schema.py +133 -0
- zai/core/undo.py +443 -0
- zai/core/watch.py +80 -0
- zai/main.py +210 -0
- zai/mcp/__init__.py +0 -0
- zai/mcp/client.py +431 -0
- zai/mcp/manager.py +118 -0
- zai/plugins/__init__.py +2 -0
- zai/plugins/base.py +49 -0
- zai/plugins/loader.py +404 -0
- zai/providers/__init__.py +22 -0
- zai/providers/anthropic.py +131 -0
- zai/providers/base.py +67 -0
- zai/providers/cerebras.py +57 -0
- zai/providers/gemini.py +119 -0
- zai/providers/groq.py +116 -0
- zai/providers/ollama.py +62 -0
- zai/providers/openai.py +124 -0
- zai/providers/openrouter.py +63 -0
- zai/providers/qwen.py +47 -0
- zai/skills/__init__.py +0 -0
- zai/skills/registry.py +52 -0
- zai/tools/__init__.py +0 -0
- zai/tools/browser.py +224 -0
- zai/tools/code_runner.py +49 -0
- zai/tools/files.py +53 -0
- zai/tools/git.py +38 -0
- zai/tools/search.py +157 -0
- zai/tools/vision.py +128 -0
- zai/ui/__init__.py +0 -0
- zai/ui/input.py +199 -0
- zai_cli-0.1.0.dist-info/METADATA +722 -0
- zai_cli-0.1.0.dist-info/RECORD +68 -0
- zai_cli-0.1.0.dist-info/WHEEL +5 -0
- zai_cli-0.1.0.dist-info/entry_points.txt +2 -0
- zai_cli-0.1.0.dist-info/licenses/LICENSE +21 -0
- zai_cli-0.1.0.dist-info/top_level.txt +1 -0
zai/core/session.py
ADDED
|
@@ -0,0 +1,425 @@
|
|
|
1
|
+
import hashlib
|
|
2
|
+
import json
|
|
3
|
+
import re
|
|
4
|
+
import uuid
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
from .errors import FileError
|
|
9
|
+
from .security import resolve_project_path
|
|
10
|
+
from .storage import (
|
|
11
|
+
atomic_write_json,
|
|
12
|
+
atomic_write_text,
|
|
13
|
+
file_lock,
|
|
14
|
+
quarantine_corrupt_file,
|
|
15
|
+
read_json,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
# Directory holding every persisted session, one JSON file per session.
SESSION_DIR = Path.home().joinpath(".zai", "sessions")
# Schema version stamped on newly written payloads; files declaring a higher
# version are rejected when read.
SESSION_SCHEMA_VERSION = 4
# How many named (non "auto-") sessions are retained; older ones are pruned.
MAX_NAMED_SESSIONS = 100
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _safe_session_name(name: str) -> str:
    """Reduce *name* to a filesystem-safe session file stem.

    Each run of characters outside ``[a-zA-Z0-9_.-]`` becomes a single
    underscore, leading/trailing dots and underscores are trimmed, and an
    empty result falls back to ``"session"``.
    """
    cleaned = re.sub(r"[^a-zA-Z0-9_.-]+", "_", name)
    trimmed = cleaned.strip("._")
    return trimmed if trimmed else "session"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _valid_lookup(value: str) -> bool:
    """Tell whether *value* is acceptable as a session lookup key.

    Keys containing ``..``, a path separator (``/`` or ``\\``) or a glob
    metacharacter (``*``, ``?``, ``[``, ``]``) are rejected.
    """
    forbidden = set(r"/\*?[]")
    return ".." not in value and forbidden.isdisjoint(value)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _project_identity(project_path: str) -> tuple[str, str]:
    """Return ``(resolved_path, project_id)`` for a project directory.

    ``project_id`` is a slug of the directory name (non ``[a-zA-Z0-9_-]``
    runs become ``-``; ``project`` if nothing is left) followed by the first
    12 hex digits of the SHA-256 of the lower-cased resolved path, so
    directories sharing a name still get different identifiers.
    """
    resolved_path = Path(project_path).resolve()
    resolved = str(resolved_path)
    slug = re.sub(r"[^a-zA-Z0-9_-]+", "-", Path(resolved).name).strip("-")
    digest = hashlib.sha256(resolved.lower().encode("utf-8")).hexdigest()
    return resolved, f"{slug or 'project'}-{digest[:12]}"
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _message_payload(message) -> dict:
    """Serialize one message object into a JSON-ready dict.

    ``role`` and ``content`` are always written; ``pinned``, ``tool_call_id``,
    ``tool_name`` and ``tool_calls`` are only written when truthy, keeping
    stored sessions compact.
    """
    payload = {"role": message.role, "content": message.content}
    if getattr(message, "pinned", False):
        payload["pinned"] = True
    for attribute in ("tool_call_id", "tool_name"):
        value = getattr(message, attribute, "")
        if value:
            payload[attribute] = value
    calls = getattr(message, "tool_calls", [])
    if calls:
        payload["tool_calls"] = [
            {"id": call.id, "name": call.name, "arguments": call.arguments}
            for call in calls
        ]
    return payload
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _session_title(messages: list[dict]) -> str:
    """Derive a short title from the most recent non-blank user message.

    Whitespace is collapsed and the result truncated to 80 characters.
    Entries that are not dicts, or whose ``content`` is not a string (for
    example ``None`` or a list in a hand-edited or foreign file), are skipped
    instead of raising, so one malformed message cannot break listing.

    Returns:
        The title, or ``"Untitled session"`` when no message qualifies.
    """
    for message in reversed(messages):
        if not isinstance(message, dict) or message.get("role") != "user":
            continue
        content = message.get("content")
        if isinstance(content, str) and content.strip():
            return " ".join(content.split())[:80]
    return "Untitled session"
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _session_payload(
    history: list,
    project_path: str | None = None,
    existing: dict | None = None,
) -> dict:
    """Build the on-disk payload for a session that is being saved.

    ``id`` and ``created_at`` come from *existing* metadata when present (a
    fresh UUID / the current time otherwise). ``project_path`` is the
    resolved *project_path* if given, else the existing value.
    ``updated_at`` is always the current time.
    """
    previous = existing if existing else {}
    timestamp = datetime.now().isoformat(timespec="seconds")
    serialized = [_message_payload(entry) for entry in history]
    if project_path:
        resolved_project = str(Path(project_path).resolve())
    else:
        resolved_project = previous.get("project_path")
    return {
        "version": SESSION_SCHEMA_VERSION,
        "id": previous.get("id") or uuid.uuid4().hex,
        "title": _session_title(serialized),
        "project_path": resolved_project,
        "created_at": previous.get("created_at") or timestamp,
        "updated_at": timestamp,
        "messages": serialized,
    }
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _read_session(path: Path) -> tuple[list, dict]:
    """Load a session file and return ``(messages, metadata)``.

    Two layouts are accepted:

    * legacy version 1 -- a bare JSON list of message dicts;
    * later versions -- a JSON object with a ``messages`` list plus metadata.

    A missing, empty or non-string ``id`` is replaced by a hash of the file
    path, and a non-string ``title`` by one derived from the messages, so
    callers can rely on both being strings.

    Raises:
        ValueError: the file is missing or undecodable (``read_json``
            quarantines undecodable files itself); its ``version`` is not an
            integer or is newer than ``SESSION_SCHEMA_VERSION`` (file left in
            place); or its shape is invalid (file quarantined here).
    """
    data = read_json(path, None, expected_type=(list, dict))
    if data is None:
        raise ValueError("Corrupt session")
    fallback_id = hashlib.sha256(str(path).encode()).hexdigest()[:32]

    if isinstance(data, list):
        if not all(isinstance(item, dict) for item in data):
            quarantine_corrupt_file(path)
            raise ValueError("Invalid session format")
        return data, {
            "version": 1,
            "id": fallback_id,
            "title": _session_title(data),
            "project_path": None,
        }

    version = data.get("version", 1)
    # A non-integer version would otherwise raise TypeError on the comparison
    # below, which no caller catches.
    if not isinstance(version, int):
        raise ValueError(f"Invalid session version: {version!r}")
    if version > SESSION_SCHEMA_VERSION:
        raise ValueError(f"Unsupported session version: {version}")

    messages = data.get("messages")
    if not isinstance(messages, list) or not all(
        isinstance(item, dict) for item in messages
    ):
        quarantine_corrupt_file(path)
        raise ValueError("Invalid session format")
    if not isinstance(data.get("id"), str) or not data["id"]:
        data["id"] = fallback_id
    if not isinstance(data.get("title"), str):
        data["title"] = _session_title(messages)
    return messages, data
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def _session_info(path: Path) -> dict | None:
    """Build the listing summary for one session file.

    The summary holds the session ``id``, file ``name`` (stem), ``title``,
    message count, file ``path``, ``project_path``, the ``created_at`` /
    ``updated_at`` metadata, the file's ``modified_at`` mtime, and an
    ``auto`` flag for ``auto-`` files.

    Returns:
        The summary dict, or ``None`` when the file cannot be read or has a
        malformed shape.
    """
    try:
        messages, metadata = _read_session(path)
        modified_at = path.stat().st_mtime
        return {
            "id": metadata["id"],
            "name": path.stem,
            "title": metadata.get("title") or _session_title(messages),
            "messages": len(messages),
            "path": str(path),
            "project_path": metadata.get("project_path"),
            "created_at": metadata.get("created_at"),
            "updated_at": metadata.get("updated_at"),
            "modified_at": modified_at,
            "auto": path.stem.startswith("auto-"),
        }
    except (OSError, ValueError, KeyError, TypeError, AttributeError):
        # TypeError/AttributeError come from wrongly-typed values in a
        # hand-edited or foreign file (e.g. a non-numeric "version" or a
        # non-dict message); one such file must not abort a whole listing.
        return None
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def _all_session_paths() -> list[Path]:
    """Return every ``*.json`` file directly inside ``SESSION_DIR``.

    A session directory that does not exist yet yields an empty list.
    """
    return list(SESSION_DIR.glob("*.json")) if SESSION_DIR.exists() else []
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def _resolve_session_path(identifier: str) -> Path | None:
    """Map a user-supplied identifier to exactly one session file.

    Resolution order:

    1. a file whose sanitized name equals *identifier*;
    2. a session whose ID equals *identifier* (case-insensitive) -- an exact
       ID wins even if other sessions match the string partially;
    3. the single session whose ID starts with, or whose name contains,
       *identifier* (case-insensitive).

    Returns:
        The session path, or ``None`` for empty/unsafe identifiers and for
        missing or ambiguous matches.
    """
    if not identifier or not _valid_lookup(identifier):
        return None
    exact = SESSION_DIR / f"{_safe_session_name(identifier)}.json"
    if exact.is_file():
        return exact

    needle = identifier.lower()
    id_matches: list[Path] = []
    partial_matches: list[Path] = []
    for path in _all_session_paths():
        info = _session_info(path)
        if not info:
            continue
        # str() guards against non-string IDs in hand-edited files.
        session_id = str(info["id"]).lower()
        if session_id == needle:
            id_matches.append(path)
        elif session_id.startswith(needle) or needle in info["name"].lower():
            partial_matches.append(path)

    if id_matches:
        # Several files sharing one ID (e.g. a copied session) stay ambiguous.
        return id_matches[0] if len(id_matches) == 1 else None
    return partial_matches[0] if len(partial_matches) == 1 else None
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def _messages_to_objects(messages: list):
    """Rebuild provider ``Message`` objects from stored message dicts.

    ``role`` (and each tool call's ``id``/``name``) is required and raises
    ``KeyError`` when absent; every other field falls back to the empty
    default used when ``_message_payload`` omitted it on save.
    """
    from ..providers.base import Message, ToolCall

    def build_call(raw_call: dict) -> ToolCall:
        return ToolCall(
            id=raw_call["id"],
            name=raw_call["name"],
            arguments=raw_call.get("arguments", {}),
        )

    rebuilt = []
    for raw in messages:
        rebuilt.append(
            Message(
                role=raw["role"],
                content=raw.get("content", ""),
                tool_call_id=raw.get("tool_call_id", ""),
                tool_name=raw.get("tool_name", ""),
                pinned=raw.get("pinned", False),
                tool_calls=[build_call(call) for call in raw.get("tool_calls", [])],
            )
        )
    return rebuilt
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def _prune_named_sessions() -> None:
    """Keep only the ``MAX_NAMED_SESSIONS`` most recently modified named sessions.

    Automatic (``auto-``) sessions are never pruned. Each removed session's
    ``.lock`` sidecar file is deleted along with it.
    """
    dated: list[tuple[float, Path]] = []
    for path in _all_session_paths():
        if path.stem.startswith("auto-"):
            continue
        try:
            dated.append((path.stat().st_mtime, path))
        except OSError:
            # The file vanished between listing and stat (e.g. it was
            # quarantined by a concurrent reader); there is nothing to prune,
            # and raising here would fail a save that already succeeded.
            continue
    dated.sort(key=lambda entry: entry[0], reverse=True)
    for _, old in dated[MAX_NAMED_SESSIONS:]:
        old.unlink(missing_ok=True)
        old.with_name(f"{old.name}.lock").unlink(missing_ok=True)
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def save_history(
    history: list,
    name: str | None = None,
    project_path: str | None = None,
) -> str:
    """Create or overwrite a named session file and return its path.

    Without *name* a ``YYYYmmdd-HHMMSS`` timestamp is used. The name is
    sanitized, and one beginning with ``auto-`` gets a ``named-`` prefix so
    it cannot collide with per-project automatic sessions. When the file
    already exists and is readable, its ID and creation time are kept.
    Surplus named sessions are pruned after the write.
    """
    SESSION_DIR.mkdir(parents=True, exist_ok=True)
    raw_name = name if name else datetime.now().strftime("%Y%m%d-%H%M%S")
    stem = _safe_session_name(raw_name)
    if stem.startswith("auto-"):
        stem = "named-" + stem
    target = SESSION_DIR / f"{stem}.json"
    with file_lock(SESSION_DIR / ".sessions"):
        previous: dict = {}
        if target.exists():
            try:
                previous = _read_session(target)[1]
            except ValueError:
                previous = {}
        atomic_write_json(target, _session_payload(history, project_path, previous))
        _prune_named_sessions()
    return str(target)
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def save_auto_history(history: list, project_path: str) -> str:
    """Overwrite the single automatic session kept for *project_path*.

    The session lives in ``SESSION_DIR/auto-<project_id>.json``. Its ID and
    creation time carry over between saves when the existing file is
    readable. Returns the session file path.
    """
    SESSION_DIR.mkdir(parents=True, exist_ok=True)
    project_id = _project_identity(project_path)[1]
    target = SESSION_DIR / f"auto-{project_id}.json"
    with file_lock(SESSION_DIR / ".sessions"):
        previous: dict = {}
        if target.exists():
            try:
                _, previous = _read_session(target)
            except ValueError:
                previous = {}
        atomic_write_json(target, _session_payload(history, project_path, previous))
    return str(target)
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def load_history(identifier: str):
    """Load one unambiguous session by exact name, ID, or unique partial name.

    Returns:
        The session's provider ``Message`` objects, or ``None`` when the
        identifier does not resolve to exactly one session or the file
        cannot be decoded into messages.
    """
    path = _resolve_session_path(identifier)
    if not path:
        return None
    try:
        messages, _ = _read_session(path)
        return _messages_to_objects(messages)
    except (ValueError, KeyError, TypeError, AttributeError):
        # Wrongly-typed content (non-dict entries, odd value types) is treated
        # like any other unloadable session instead of crashing the caller.
        return None
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def load_latest_session(project_path: str | None = None):
    """Return ``(messages, metadata)`` for the latest loadable session.

    Sessions are tried newest first (limited to *project_path* when given,
    automatic sessions included). The first one that decodes successfully is
    returned, together with its ``list_sessions`` summary dict.

    Returns:
        ``(messages, info)``, or ``None`` when no session can be loaded.
    """
    expected_project = str(Path(project_path).resolve()) if project_path else None
    sessions = list_sessions(
        limit=None,
        project_path=expected_project,
        include_auto=True,
    )
    for info in sessions:
        # Read the listed file directly. Re-resolving by ID would rescan the
        # whole directory for every candidate, and could pick a different file
        # (one whose name equals the ID) or give up on a duplicated ID.
        try:
            messages, _ = _read_session(Path(info["path"]))
            return _messages_to_objects(messages), info
        except (ValueError, KeyError, TypeError, AttributeError):
            continue
    return None
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def load_latest_history(project_path: str | None = None):
    """Return just the messages of the latest session, or ``None``."""
    latest = load_latest_session(project_path)
    if not latest:
        return None
    messages, _info = latest
    return messages
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def list_sessions(
    limit: int | None = 50,
    project_path: str | None = None,
    include_auto: bool = True,
) -> list[dict]:
    """Return session summaries ordered by file mtime, newest first.

    Args:
        limit: maximum number of entries, or ``None`` for all of them.
        project_path: when given, keep only sessions whose recorded project
            equals this path after resolution.
        include_auto: whether per-project ``auto-`` sessions are included.

    Unreadable session files are skipped.
    """
    wanted_project = str(Path(project_path).resolve()) if project_path else None

    def keep(info: dict) -> bool:
        if info["auto"] and not include_auto:
            return False
        return not wanted_project or info["project_path"] == wanted_project

    summaries = (_session_info(path) for path in _all_session_paths())
    matching = [info for info in summaries if info and keep(info)]
    matching.sort(key=lambda info: info["modified_at"], reverse=True)
    return matching if limit is None else matching[:limit]
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
def search_sessions(
    query: str,
    project_path: str | None = None,
    limit: int | None = 20,
) -> list[dict]:
    """Search session names, titles, and message content.

    Matching is a case-insensitive substring test against the stripped
    *query*. Results keep ``list_sessions`` order (newest first).

    Args:
        query: text to look for; a blank query returns ``[]``.
        project_path: optional project scope, as for ``list_sessions``.
        limit: maximum number of results; ``None`` means unlimited (as in
            ``list_sessions``) and values ``<= 0`` return ``[]``.
    """
    needle = query.strip().lower()
    if not needle or (limit is not None and limit <= 0):
        return []
    results: list[dict] = []
    for info in list_sessions(
        limit=None,
        project_path=project_path,
        include_auto=True,
    ):
        try:
            messages, _ = _read_session(Path(info["path"]))
        except (ValueError, TypeError, AttributeError):
            continue
        texts = [info["name"], info["title"]]
        texts.extend(
            message.get("content") for message in messages if isinstance(message, dict)
        )
        # Only string fields are searchable; None or structured content would
        # otherwise make str.join raise TypeError.
        haystack = " ".join(text for text in texts if isinstance(text, str)).lower()
        if needle in haystack:
            results.append(info)
            if limit is not None and len(results) >= limit:
                break
    return results
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
def get_session_info(identifier: str) -> dict | None:
    """Return the listing summary for *identifier*, or ``None`` if it does not resolve."""
    path = _resolve_session_path(identifier)
    if path is None:
        return None
    return _session_info(path)
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
def rename_session(identifier: str, new_name: str) -> str | None:
    """Rename a named session. Automatic project sessions cannot be renamed.

    As in ``save_history``, a sanitized target name starting with ``auto-``
    gets a ``named-`` prefix. Otherwise a named session could be turned into
    one that is flagged ``auto``, is skipped by pruning, and can no longer
    be renamed.

    Returns:
        The new path, or ``None`` when the source does not resolve, is an
        automatic session, or the destination already exists.
    """
    source = _resolve_session_path(identifier)
    if not source or source.stem.startswith("auto-"):
        return None
    safe_name = _safe_session_name(new_name)
    if safe_name.startswith("auto-"):
        safe_name = f"named-{safe_name}"
    destination = SESSION_DIR / f"{safe_name}.json"
    with file_lock(SESSION_DIR / ".sessions"):
        # Check under the lock: another process may have removed the source
        # or created the destination since the source was resolved.
        if not source.exists() or destination.exists():
            return None
        source.replace(destination)
        source.with_name(f"{source.name}.lock").unlink(missing_ok=True)
    return str(destination)
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
def delete_session(identifier: str) -> bool:
    """Delete the session *identifier* resolves to and report whether one was removed.

    The session's ``.lock`` sidecar file is removed as well.
    """
    target = _resolve_session_path(identifier)
    if target is None:
        return False
    with file_lock(SESSION_DIR / ".sessions"):
        if target.exists():
            target.unlink()
            target.with_name(f"{target.name}.lock").unlink(missing_ok=True)
            return True
    return False
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
def _session_markdown(messages: list, metadata: dict) -> str:
    """Render a session as Markdown: a metadata header, then one section per message."""
    lines = [
        f"# {metadata.get('title') or _session_title(messages)}",
        "",
        f"- Session ID: `{metadata.get('id', '')}`",
        f"- Project: `{metadata.get('project_path') or '-'}`",
        f"- Updated: `{metadata.get('updated_at') or '-'}`",
        "",
    ]
    for message in messages:
        if not isinstance(message, dict):
            continue
        role = str(message.get("role", "message")).replace("_", " ").title()
        content = message.get("content", "")
        if content is None:
            content = ""
        elif not isinstance(content, str):
            # Structured content is dumped as JSON rather than making
            # "\n".join below raise TypeError.
            content = json.dumps(content, ensure_ascii=False)
        lines.extend([f"## {role}", "", content, ""])
    return "\n".join(lines).rstrip() + "\n"


def export_session(
    identifier: str,
    export_format: str = "md",
    output: str | None = None,
    project_path: str = ".",
) -> str | None:
    """Export one session as Markdown or JSON inside the selected project.

    Args:
        identifier: session name, ID, or unique partial name.
        export_format: ``"md"`` or ``"json"``; anything else returns ``None``.
        output: destination passed to ``resolve_project_path`` together with
            *project_path*; defaults to ``<session-stem>.<format>``.
        project_path: project root handed to ``resolve_project_path``.

    Returns:
        The written file's path, or ``None`` if the session does not resolve
        or cannot be read, the format is unsupported, or
        ``resolve_project_path`` raises ``FileError``.
    """
    path = _resolve_session_path(identifier)
    if not path or export_format not in {"md", "json"}:
        return None
    try:
        messages, metadata = _read_session(path)
    except (ValueError, TypeError, AttributeError):
        return None

    try:
        destination = resolve_project_path(
            project_path,
            output or f"{path.stem}.{export_format}",
        )
    except FileError:
        return None

    if export_format == "json":
        payload = {**metadata, "messages": messages}
        text = json.dumps(payload, indent=2, ensure_ascii=False) + "\n"
    else:
        text = _session_markdown(messages, metadata)
    atomic_write_text(destination, text, mode=0o644, lock=False)
    return str(destination)
|
zai/core/storage.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
"""Crash-safe local state storage with per-file process locking."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import json
|
|
5
|
+
import os
|
|
6
|
+
import tempfile
|
|
7
|
+
import threading
|
|
8
|
+
import time
|
|
9
|
+
from contextlib import contextmanager
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from typing import Any, Callable, Iterator
|
|
12
|
+
|
|
13
|
+
# In-process registry mapping a resolved lock-file path to its RLock, so every
# thread in this process contends on the same lock object for a given file.
_thread_locks: dict[str, threading.RLock] = {}
_thread_locks_guard = threading.Lock()


def _thread_lock(path: Path) -> threading.RLock:
    """Return the shared ``RLock`` for *path*, creating it on first use."""
    key = str(path.resolve())
    with _thread_locks_guard:
        lock = _thread_locks.get(key)
        if lock is None:
            lock = _thread_locks[key] = threading.RLock()
        return lock
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _seed_lock_file(handle) -> None:
    """Give an empty lock file one byte so there is something to lock."""
    handle.seek(0, os.SEEK_END)
    if handle.tell() == 0:
        handle.write(b"\0")
        handle.flush()


def _os_lock(handle) -> None:
    """Take an exclusive OS-level lock on *handle*'s file.

    POSIX uses ``fcntl.flock(LOCK_EX)``, which blocks until granted. Windows
    locks the first byte with ``msvcrt.locking(LK_LOCK)``, which retries once
    per second and raises ``OSError`` after 10 failed attempts.
    """
    handle.seek(0)
    if os.name == "nt":
        import msvcrt

        msvcrt.locking(  # type: ignore[attr-defined]
            handle.fileno(),
            msvcrt.LK_LOCK,  # type: ignore[attr-defined]
            1,
        )
        return
    import fcntl

    fcntl.flock(handle.fileno(), fcntl.LOCK_EX)  # type: ignore[attr-defined]


def _os_unlock(handle) -> None:
    """Release the OS-level lock taken by ``_os_lock``."""
    handle.seek(0)
    if os.name == "nt":
        import msvcrt

        msvcrt.locking(  # type: ignore[attr-defined]
            handle.fileno(),
            msvcrt.LK_UNLCK,  # type: ignore[attr-defined]
            1,
        )
        return
    import fcntl

    fcntl.flock(handle.fileno(), fcntl.LOCK_UN)  # type: ignore[attr-defined]


@contextmanager
def file_lock(path: str | Path) -> Iterator[None]:
    """Hold a cross-thread and cross-process lock for one state file.

    The lock is taken on a ``<name>.lock`` sidecar next to *path*, whose
    directory is created if needed. This process's per-file ``RLock`` is
    acquired first, then an exclusive OS-level lock on the sidecar. Both are
    released when the ``with`` body exits, including on error.
    """
    target = Path(path)
    sidecar = target.with_name(f"{target.name}.lock")
    sidecar.parent.mkdir(parents=True, exist_ok=True)
    with _thread_lock(sidecar), open(sidecar, "a+b") as handle:
        _seed_lock_file(handle)
        _os_lock(handle)
        try:
            yield
        finally:
            _os_unlock(handle)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def atomic_write_text(
|
|
74
|
+
path: str | Path,
|
|
75
|
+
content: str,
|
|
76
|
+
*,
|
|
77
|
+
encoding: str = "utf-8",
|
|
78
|
+
lock: bool = True,
|
|
79
|
+
mode: int | None = None,
|
|
80
|
+
) -> None:
|
|
81
|
+
target = Path(path)
|
|
82
|
+
target.parent.mkdir(parents=True, exist_ok=True)
|
|
83
|
+
|
|
84
|
+
def write() -> None:
|
|
85
|
+
existing_mode = None
|
|
86
|
+
try:
|
|
87
|
+
existing_mode = target.stat().st_mode & 0o777
|
|
88
|
+
except OSError:
|
|
89
|
+
pass
|
|
90
|
+
fd, temporary_name = tempfile.mkstemp(
|
|
91
|
+
prefix=f".{target.name}.",
|
|
92
|
+
suffix=".tmp",
|
|
93
|
+
dir=target.parent,
|
|
94
|
+
)
|
|
95
|
+
temporary = Path(temporary_name)
|
|
96
|
+
try:
|
|
97
|
+
with os.fdopen(fd, "w", encoding=encoding, newline="") as handle:
|
|
98
|
+
handle.write(content)
|
|
99
|
+
handle.flush()
|
|
100
|
+
os.fsync(handle.fileno())
|
|
101
|
+
if existing_mode is not None:
|
|
102
|
+
os.chmod(temporary, existing_mode)
|
|
103
|
+
elif mode is not None:
|
|
104
|
+
os.chmod(temporary, mode)
|
|
105
|
+
os.replace(temporary, target)
|
|
106
|
+
finally:
|
|
107
|
+
temporary.unlink(missing_ok=True)
|
|
108
|
+
|
|
109
|
+
if lock:
|
|
110
|
+
with file_lock(target):
|
|
111
|
+
write()
|
|
112
|
+
else:
|
|
113
|
+
write()
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def atomic_write_json(
    path: str | Path,
    data: Any,
    *,
    indent: int = 2,
    lock: bool = True,
) -> None:
    """Atomically write *data* as indented JSON followed by a newline.

    Non-ASCII characters are written as-is rather than escaped. *lock* is
    passed through to ``atomic_write_text``.
    """
    serialized = json.dumps(data, indent=indent, ensure_ascii=False)
    atomic_write_text(path, f"{serialized}\n", lock=lock)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def quarantine_corrupt_file(path: str | Path) -> Path | None:
|
|
131
|
+
target = Path(path)
|
|
132
|
+
if not target.exists():
|
|
133
|
+
return None
|
|
134
|
+
stamp = time.strftime("%Y%m%d-%H%M%S")
|
|
135
|
+
destination = target.with_name(f"{target.name}.corrupt-{stamp}")
|
|
136
|
+
counter = 1
|
|
137
|
+
while destination.exists():
|
|
138
|
+
destination = target.with_name(
|
|
139
|
+
f"{target.name}.corrupt-{stamp}-{counter}"
|
|
140
|
+
)
|
|
141
|
+
counter += 1
|
|
142
|
+
try:
|
|
143
|
+
os.replace(target, destination)
|
|
144
|
+
return destination
|
|
145
|
+
except OSError:
|
|
146
|
+
return None
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def read_json(
    path: str | Path,
    default: Any,
    *,
    expected_type: type | tuple[type, ...] | None = None,
    quarantine: bool = True,
) -> Any:
    """Read and decode a JSON file while holding its process lock.

    Returns *default* in any of these cases:

    * the file is missing or cannot be read;
    * the file is not valid UTF-8 JSON;
    * the JSON root is not an instance of *expected_type*.

    Only content corruption (decode, parse or root-type errors) leads to
    quarantine, and only when *quarantine* is true. An ``OSError`` while
    reading, such as a permission problem, a transient sharing violation, or
    the file vanishing after the existence check, says nothing about the
    content, so the file is left in place.
    """
    target = Path(path)
    if not target.exists():
        return default
    with file_lock(target):
        try:
            data = json.loads(target.read_text(encoding="utf-8"))
            if expected_type is not None and not isinstance(data, expected_type):
                raise ValueError("unexpected JSON root type")
        except OSError:
            return default
        except (UnicodeError, json.JSONDecodeError, ValueError):
            if quarantine:
                quarantine_corrupt_file(target)
            return default
        return data
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def update_json(
    path: str | Path,
    default: Any,
    updater: Callable[[Any], Any],
    *,
    expected_type: type | tuple[type, ...] | None = None,
) -> Any:
    """Read-modify-write one JSON file while holding its process lock.

    *updater* receives the decoded current value. It gets *default* instead
    when the file is missing, or when it cannot be read, decoded, or has a
    root that is not *expected_type* (such files are quarantined first).
    The updater's result is written back atomically and returned.
    """
    target = Path(path)
    with file_lock(target):
        current = default
        if target.exists():
            try:
                loaded = json.loads(target.read_text(encoding="utf-8"))
                if expected_type is not None and not isinstance(loaded, expected_type):
                    raise ValueError("unexpected JSON root type")
                current = loaded
            except (OSError, UnicodeError, json.JSONDecodeError, ValueError):
                quarantine_corrupt_file(target)
                current = default
        updated = updater(current)
        # lock=False: this function already holds the file's lock.
        atomic_write_json(target, updated, lock=False)
        return updated
|