contree-mcp 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- contree_mcp/__init__.py +0 -0
- contree_mcp/__main__.py +25 -0
- contree_mcp/app.py +240 -0
- contree_mcp/arguments.py +35 -0
- contree_mcp/auth/__init__.py +2 -0
- contree_mcp/auth/registry.py +236 -0
- contree_mcp/backend_types.py +301 -0
- contree_mcp/cache.py +208 -0
- contree_mcp/client.py +711 -0
- contree_mcp/context.py +53 -0
- contree_mcp/docs.py +1203 -0
- contree_mcp/file_cache.py +381 -0
- contree_mcp/prompts.py +238 -0
- contree_mcp/py.typed +0 -0
- contree_mcp/resources/__init__.py +17 -0
- contree_mcp/resources/guide.py +715 -0
- contree_mcp/resources/image_lineage.py +46 -0
- contree_mcp/resources/image_ls.py +32 -0
- contree_mcp/resources/import_operation.py +52 -0
- contree_mcp/resources/instance_operation.py +52 -0
- contree_mcp/resources/read_file.py +33 -0
- contree_mcp/resources/static.py +12 -0
- contree_mcp/server.py +77 -0
- contree_mcp/tools/__init__.py +39 -0
- contree_mcp/tools/cancel_operation.py +36 -0
- contree_mcp/tools/download.py +128 -0
- contree_mcp/tools/get_guide.py +54 -0
- contree_mcp/tools/get_image.py +30 -0
- contree_mcp/tools/get_operation.py +26 -0
- contree_mcp/tools/import_image.py +99 -0
- contree_mcp/tools/list_files.py +80 -0
- contree_mcp/tools/list_images.py +50 -0
- contree_mcp/tools/list_operations.py +46 -0
- contree_mcp/tools/read_file.py +47 -0
- contree_mcp/tools/registry_auth.py +71 -0
- contree_mcp/tools/registry_token_obtain.py +80 -0
- contree_mcp/tools/rsync.py +46 -0
- contree_mcp/tools/run.py +97 -0
- contree_mcp/tools/set_tag.py +31 -0
- contree_mcp/tools/upload.py +50 -0
- contree_mcp/tools/wait_operations.py +79 -0
- contree_mcp-0.1.0.dist-info/METADATA +450 -0
- contree_mcp-0.1.0.dist-info/RECORD +46 -0
- contree_mcp-0.1.0.dist-info/WHEEL +4 -0
- contree_mcp-0.1.0.dist-info/entry_points.txt +2 -0
- contree_mcp-0.1.0.dist-info/licenses/LICENSE +176 -0
|
@@ -0,0 +1,301 @@
|
|
|
1
|
+
"""Backend types for Contree MCP DO NOT EDIT THIS FILE UNLESS YOU KNOW WHAT YOU ARE DOING"""
|
|
2
|
+
|
|
3
|
+
from base64 import b64decode, b64encode
|
|
4
|
+
from enum import Enum
|
|
5
|
+
from typing import Any, Literal, TypeVar
|
|
6
|
+
from uuid import UUID
|
|
7
|
+
|
|
8
|
+
from pydantic import BaseModel, ByteSize, Field, PositiveInt, model_validator
|
|
9
|
+
from typing_extensions import Self
|
|
10
|
+
|
|
11
|
+
# The nil UUID rendered as text: "00000000-0000-0000-0000-000000000000".
# NOTE(review): the name suggests it marks public/shared resources — confirm with callers.
PublicUUID: str = str(UUID(bytes=bytes(16)))

# Type variable constrained to Enum subclasses (presumably used by generic helpers elsewhere).
E = TypeVar("E", bound=Enum)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ImageCredentials(BaseModel):
    # Registry login used by ImageRegistry; both fields default to empty strings.
    username: str = ""
    # repr=False keeps the secret out of repr()/str() output, and therefore out of
    # logs or tracebacks that format any model nesting these credentials.
    # It is still emitted by model_dump()/model_dump_json() and in the JSON schema
    # exactly as before, so API payloads are unchanged.
    password: str = Field(default="", repr=False)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ImageRegistry(BaseModel):
    # Registry source referenced by ImportImageMetadata.registry.
    # NOTE(review): exact URL format expected by the backend is not shown here — confirm.
    url: str
    # Defaults to an ImageCredentials with empty username and password.
    credentials: ImageCredentials = ImageCredentials()
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class ImportImageMetadata(BaseModel):
    # Metadata of an image-import operation; one of the variants accepted by
    # OperationResponse.metadata.
    registry: ImageRegistry
    # Optional tag; presumably applied to the imported image — confirm.
    tag: str | None = None
    # Must be a positive integer; unit presumably seconds — TODO confirm.
    timeout: PositiveInt = 300
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class ImageSize(BaseModel):
    # Two size figures for an image; -1 is the default when the value is not provided.
    # NOTE(review): presumably bytes, "physical" = stored size and "logical" = apparent
    # size — confirm against the backend.
    physical: int = -1
    logical: int = -1
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class Image(BaseModel):
    """Response model for image endpoints.

    API handlers:
    - GET /inspect/{image_uuid}/ -> Image
    - GET /inspect/?tag={tag} -> Image (redirect)
    - PATCH /images/{image_uuid}/tag -> Image
    - DELETE /images/{image_uuid}/tag -> Image
    """

    # Required field.
    uuid: str = Field(description="Image UUID")
    # None when the image is untagged.
    tag: str | None = Field(default=None, description="Image tag or null")
    # Kept as the raw string from the backend (not parsed); "" when absent.
    created_at: str = Field(default="", description="ISO 8601 creation timestamp")
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class ImageListResponse(BaseModel):
    """Response from GET /images.

    API handlers:
    - GET /images -> ImageListResponse
    """

    # A fresh empty list when the payload has no "images" key.
    images: list[Image] = Field(default_factory=list)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class FileItem(BaseModel):
    """File info in directory listing."""

    # All fields except symlink_to are required.
    path: str = Field(description="File name relative to directory")
    size: int = Field(description="File size in bytes")
    # The backend may send either a numeric id or a name, hence the union.
    owner: int | str = Field(description="User ID or name of owner")
    group: int | str = Field(description="Group ID or name")
    # NOTE(review): presumably a raw st_mode-style integer — confirm whether type bits are included.
    mode: int = Field(description="File permissions as integer")
    mtime: int = Field(description="Last modification Unix timestamp")
    # File-type flags reported individually by the backend.
    is_dir: bool = Field(description="Directory indicator")
    is_regular: bool = Field(description="Regular file indicator")
    is_symlink: bool = Field(description="Symbolic link indicator")
    is_socket: bool = Field(description="Socket indicator")
    is_fifo: bool = Field(description="FIFO/named pipe indicator")
    # "" for entries that are not symlinks (or when the target is not reported).
    symlink_to: str = Field(default="", description="Target path for symlinks")
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class DirectoryList(BaseModel):
    """Response from GET /inspect/{uuid}/list.

    API handlers:
    - GET /inspect/{image_uuid}/list -> DirectoryList
    """

    # Required; echoes the directory that was listed.
    path: str = Field(description="Directory path listed")
    # Entries of that directory; a fresh empty list when omitted.
    files: list[FileItem] = Field(default_factory=list)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class FileResponse(BaseModel):
    """Response from file endpoints.

    API handlers:
    - POST /files -> FileResponse
    - GET /files?sha256={hash} -> FileResponse
    """

    # Both fields are required. The uuid is presumably what InstanceFileSpec.uuid
    # refers to — confirm.
    uuid: str = Field(description="File UUID")
    # Hex digest expected (not validated here).
    sha256: str = Field(description="SHA256 hash of file content")
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class InstanceSpawnResponse(BaseModel):
    """Response from POST /instances.

    API handlers:
    - POST /instances -> InstanceSpawnResponse
    - POST /images/import -> InstanceSpawnResponse (same format)
    """

    # UUID of the asynchronous operation created by the request (not of an
    # instance or image); presumably polled via GET /operations/{operation_id}.
    uuid: str = Field(description="Operation UUID")
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
class Stream(BaseModel):
|
|
114
|
+
value: str
|
|
115
|
+
encoding: Literal["ascii", "base64"] = "ascii"
|
|
116
|
+
truncated: bool = False
|
|
117
|
+
|
|
118
|
+
def text(self) -> str:
|
|
119
|
+
if self.encoding == "ascii":
|
|
120
|
+
return self.value
|
|
121
|
+
elif self.encoding == "base64":
|
|
122
|
+
return b64decode(self.value).decode("utf-8", errors="replace")
|
|
123
|
+
raise ValueError(f"Unsupported encoding: {self.encoding}")
|
|
124
|
+
|
|
125
|
+
@classmethod
|
|
126
|
+
def from_bytes(cls, data: bytes, max_size: int = -1) -> Self:
|
|
127
|
+
encoding: Literal["ascii", "base64"]
|
|
128
|
+
truncated = False
|
|
129
|
+
if 0 < max_size < len(data):
|
|
130
|
+
data = data[:max_size]
|
|
131
|
+
truncated = True
|
|
132
|
+
try:
|
|
133
|
+
encoding = "ascii"
|
|
134
|
+
value = data.decode("ascii")
|
|
135
|
+
except UnicodeDecodeError:
|
|
136
|
+
encoding = "base64"
|
|
137
|
+
value = b64encode(data).decode("ascii")
|
|
138
|
+
return cls(value=value, encoding=encoding, truncated=truncated)
|
|
139
|
+
|
|
140
|
+
def to_bytes(self) -> bytes:
|
|
141
|
+
match self.encoding:
|
|
142
|
+
case "ascii":
|
|
143
|
+
return self.value.encode("ascii")
|
|
144
|
+
case "base64":
|
|
145
|
+
return b64decode(self.value)
|
|
146
|
+
case _:
|
|
147
|
+
raise ValueError(f"Unsupported encoding: {self.encoding}")
|
|
148
|
+
|
|
149
|
+
def __bool__(self) -> bool:
|
|
150
|
+
return bool(self.value)
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
class ConsumedResources(BaseModel):
    # Resource-usage counters of an instance run (InstanceResult.resources).
    # Every field defaults to -1, presumably meaning "not reported".
    # NOTE(review): names resemble getrusage(2) fields; units (seconds, KiB, ...)
    # are not established here — confirm against the backend.
    block_input: int = -1
    block_output: int = -1
    cost: float = -1.0
    elapsed_time: float = -1
    involuntary_switches: int = -1
    max_rss: int = -1
    monotonic_time: float = -1
    page_faults: int = -1
    page_faults_io: int = -1
    shared_memory: int = -1
    signals: int = -1
    swaps: int = -1
    system_cpu_time: float = -1.0
    unshared_memory: int = -1
    user_cpu_time: float = -1.0
    voluntary_switches: int = -1
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
class ProcessExitState(BaseModel):
    # Exit status of the process run inside an instance (InstanceResult.state).
    # Flags default to False and numbers to 0, except `signal`, whose -1 default
    # presumably means "not terminated by a signal" — confirm.
    continued: bool = False
    core_dump: bool = False
    exit_code: int = 0
    pid: int = 0
    signal: int = -1
    stopped: bool = False
    timed_out: bool = False
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
class InstanceResult(BaseModel):
    # Outcome of an instance run; stored in InstanceMetadata.result.
    # resources/state fall back to their all-default models when omitted.
    resources: ConsumedResources = ConsumedResources()
    state: ProcessExitState = ProcessExitState()
    # Captured output streams; both are required.
    stdout: Stream
    stderr: Stream
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
class InstanceFileSpec(BaseModel):
    # One entry of InstanceMetadata.files (keyed there by path, presumably the
    # destination path inside the instance — confirm).
    # NOTE(review): uuid presumably references an uploaded file (FileResponse.uuid).
    uuid: str
    # Permission bits as an octal-looking string, default "0644".
    mode: str = "0644"
    # Owner user/group ids, both default 0.
    uid: int = 0
    gid: int = 0
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
class InstanceMetadata(BaseModel):
    """Metadata for instance execution operations"""

    # Required: what to run, and in which image.
    command: str = Field(description="Command to run")
    image: str = Field(description="Image UUID or string starts with 'tag:'")
    hostname: str = "linuxkit"
    # `args` and `shell` are mutually exclusive per their descriptions; this
    # model does not enforce that itself.
    args: list[str] = Field(default_factory=list, description="Command arguments, must be used with shell is false")
    shell: bool = Field(default=False, description="In this mode command is a shell expression and args must be empty")
    env: dict[str, str] = Field(default_factory=dict)
    cwd: str = Field(default="/root", description="Path to the working directory, must be absolute")
    # NOTE(review): presumably "discard the resulting image after the run" — confirm.
    disposable: bool = False
    # Empty stdin by default.
    stdin: Stream = Stream(value="")
    # Must be a positive integer; unit presumably seconds — TODO confirm.
    timeout: PositiveInt = 60
    # Output size cap, 64 KiB by default (ByteSize also accepts strings like "1MiB").
    truncate_output_at: ByteSize = ByteSize(64 * 1024)
    files: dict[str, InstanceFileSpec] = Field(default_factory=dict, description="Files to add to the image")
    # None until the backend reports a result (presumably once the run finishes).
    result: InstanceResult | None = None
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
class OperationStatus(str, Enum):
    # Status values reported by the backend for an operation. Members compare
    # equal to their string values because of the `str` mixin.
    PENDING = "PENDING"
    EXECUTING = "EXECUTING"
    SUCCESS = "SUCCESS"
    FAILED = "FAILED"
    CANCELLED = "CANCELLED"
    ASSIGNED = "ASSIGNED"

    def is_terminal(self) -> bool:
        """Return True when the operation has reached a final status.

        SUCCESS, FAILED and CANCELLED are final; every other status is
        treated as non-final.
        """
        return self.value in ("SUCCESS", "FAILED", "CANCELLED")
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
class OperationKind(str, Enum):
    # Kind of backend operation. Canonical values are lowercase.
    # OperationResponse lowercases "kind" before validating, which indicates the
    # backend may send other casings. `_missing_` makes value lookup
    # case-insensitive, so every consumer (not just OperationResponse) accepts them.
    INSTANCE = "instance"
    IMAGE_IMPORT = "image_import"

    @classmethod
    def _missing_(cls, value: object) -> "OperationKind | None":
        """Resolve *value* case-insensitively, e.g. ``"INSTANCE"`` -> ``INSTANCE``.

        Returning None lets Enum raise its usual ``ValueError`` for values that
        match no member in any casing (or are not strings at all).
        """
        if isinstance(value, str):
            lowered = value.lower()
            for member in cls:
                if member.value == lowered:
                    return member
        return None
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
class OperationResult(BaseModel):
    # Outcome of a finished operation (OperationResponse.result).
    # Both fields are None when the backend does not report them.
    image: str | None = Field(default=None, description="Result image UUID or null")
    tag: str | None = Field(default=None, description="Assigned tag or null")
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
class OperationSummary(BaseModel):
    """Summary model for operations in list.

    API handlers:
    - GET /operations -> OperationListResponse.operations (list of OperationSummary)
    """

    uuid: str = Field(description="Operation UUID")
    kind: OperationKind = Field(description="Operation kind")
    status: OperationStatus = Field(description="Operation status")
    error: str | None = Field(default=None, description="Error message if failed")
    # Kept as the raw string from the backend (not parsed); "" when absent.
    created_at: str = Field(default="", description="ISO 8601 creation timestamp")

    @model_validator(mode="before")
    @classmethod
    def normalize_kind(cls, data: Any) -> Any:
        """Lowercase a string ``kind`` so it matches OperationKind's values.

        This mirrors the normalization OperationResponse applies to the same
        field. Without it, a listing whose kinds arrive uppercase would fail
        validation here while the detail endpoint succeeds. Non-string values
        are left for pydantic to reject with a ValidationError.
        """
        if isinstance(data, dict) and isinstance(data.get("kind"), str):
            # Copy so the caller's mapping is never mutated.
            data = {**data, "kind": data["kind"].lower()}
        return data
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
class OperationListResponse(BaseModel):
    """Response from GET /operations.

    API handlers:
    - GET /operations -> OperationListResponse
    """

    operations: list[OperationSummary] = Field(default_factory=list)

    @model_validator(mode="before")
    @classmethod
    def wrap_list(cls, data: Any) -> Any:
        """Accept a bare JSON array by treating it as ``{"operations": [...]}``.

        Any other payload shape is passed through to normal validation.
        """
        return {"operations": data} if isinstance(data, list) else data
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
class OperationResponse(BaseModel):
    """Response model for operation detail endpoint.

    API handlers:
    - GET /operations/{operation_id} -> OperationResponse
    """

    uuid: str = Field(default="", description="Operation UUID")
    status: OperationStatus = Field(description="Operation status")
    kind: OperationKind = Field(description="Operation kind")
    error: str | None = Field(default=None, description="Error message if any")
    # No discriminator is declared, so pydantic's default union handling picks
    # whichever metadata model the payload fits.
    metadata: InstanceMetadata | ImportImageMetadata | None = Field(default=None, description="Operation metadata")
    result: OperationResult | None = Field(default=None, description="Operation result")
    duration: float = Field(default=0.0, description="Operation duration")

    @model_validator(mode="before")
    @classmethod
    def parse_metadata(cls, data: Any) -> Any:
        """Normalize the raw payload before field validation.

        A string ``kind`` is lowercased so it matches OperationKind's values.
        Non-string values are passed through untouched so pydantic reports them
        as a ValidationError. Previously ``.lower()`` was called unconditionally,
        and e.g. ``kind=None`` escaped as a bare AttributeError.
        """
        if isinstance(data, dict) and isinstance(data.get("kind"), str):
            data = dict(data)  # Create mutable copy; never mutate the caller's dict
            data["kind"] = data["kind"].lower()
        return data
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
class CancelOperationResponse(BaseModel):
    """Response from cancel operation endpoint.

    API handlers:
    - DELETE /operations/{operation_id} -> CancelOperationResponse
    """

    # Both fields are optional; "" when the backend omits the uuid.
    uuid: str = Field(default="", description="Operation UUID")
    # Defaults to CANCELLED when the backend omits the status.
    status: OperationStatus = Field(default=OperationStatus.CANCELLED)
|
contree_mcp/cache.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import json
|
|
5
|
+
import re
|
|
6
|
+
import sqlite3
|
|
7
|
+
from collections.abc import Mapping
|
|
8
|
+
from contextlib import suppress
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
from datetime import datetime, timedelta, timezone
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from types import MappingProxyType
|
|
13
|
+
from typing import Any
|
|
14
|
+
|
|
15
|
+
import aiosqlite
|
|
16
|
+
from pydantic import BaseModel
|
|
17
|
+
|
|
18
|
+
# Allowed filter field names for Cache.list_entries: a letter or underscore,
# followed by letters, digits, underscores or dots. Dots address nested JSON
# paths such as "user.name".
# Field names are interpolated into a JSON path inside SQL text, so anything
# else is rejected. `\Z` (not `$`) anchors at the true end of the string,
# because `$` also matches just before a trailing newline (e.g. "name\n").
SAFE_FIELD_PATTERN = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_.]*\Z")
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass(frozen=True)
class CacheEntry:
    """Immutable snapshot of one row of the ``cache`` table.

    ``from_row`` is installed as the connection's ``row_factory``, so queries
    on the cache connection yield ``CacheEntry`` objects instead of tuples.
    """

    id: int
    kind: str
    key: str
    parent_id: int | None  # Reference to parent entry's id
    data: Mapping[str, Any]  # read-only view of the JSON-decoded ``data`` column
    created_at: datetime  # timezone-aware (UTC) when parsed from the database
    updated_at: datetime  # timezone-aware (UTC) when parsed from the database

    @staticmethod
    def _parse_timestamp(value: Any) -> Any:
        """Turn a SQLite TIMESTAMP text value into an aware UTC datetime.

        sqlite3 returns TIMESTAMP columns as ``str`` unless type detection is
        enabled. Both CURRENT_TIMESTAMP text ("YYYY-MM-DD HH:MM:SS", UTC) and
        ``datetime.isoformat(" ")`` text are accepted; naive values are taken as
        UTC. Unparseable or non-text values are returned unchanged.
        """
        if isinstance(value, datetime):
            parsed = value
        elif isinstance(value, str):
            try:
                parsed = datetime.fromisoformat(value)
            except ValueError:
                return value
        else:
            return value
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=timezone.utc)
        return parsed

    @classmethod
    def from_row(cls, cursor: sqlite3.Cursor, row: tuple[Any, ...]) -> CacheEntry:
        """Build an entry from a raw row, using ``cursor.description`` for column names.

        Without timestamp parsing, ``created_at``/``updated_at`` would hold
        strings despite their ``datetime`` annotation, and datetime arithmetic
        on them (e.g. TTL checks) would raise TypeError.
        """
        row_dict: dict[str, Any] = {col[0]: row[idx] for idx, col in enumerate(cursor.description or [])}
        row_dict["data"] = MappingProxyType(json.loads(row_dict["data"]))
        for column in ("created_at", "updated_at"):
            if column in row_dict:
                row_dict[column] = cls._parse_timestamp(row_dict[column])
        return cls(**row_dict)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class Cache:
    """SQLite-backed cache of JSON documents keyed by ``(kind, key)``.

    Rows may point at a parent row via ``parent_id``; ``get_ancestors`` and
    ``get_children`` walk those links recursively. Use the instance as an async
    context manager. Entering opens the connection, creates the schema and
    starts a background task that deletes rows older than ``retention_days``.
    Exiting stops that task and closes the connection.
    """

    DEFAULT_CACHE_DIR = Path.home() / ".cache" / "contree_mcp"
    DEFAULT_CACHE_DB_PATH = DEFAULT_CACHE_DIR / "cache.db"

    SCHEMA = """
    CREATE TABLE IF NOT EXISTS cache (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        kind TEXT NOT NULL,
        key TEXT NOT NULL,
        parent_id INTEGER,
        data TEXT NOT NULL,
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
        updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
        UNIQUE(kind, key),
        FOREIGN KEY (parent_id) REFERENCES cache(id)
    );
    CREATE INDEX IF NOT EXISTS idx_cache_kind ON cache(kind);
    CREATE INDEX IF NOT EXISTS idx_cache_parent ON cache(parent_id);
    CREATE INDEX IF NOT EXISTS idx_cache_created ON cache(created_at);
    """

    # Text layout of SQLite's CURRENT_TIMESTAMP (always UTC), which fills created_at.
    _SQLITE_TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S"

    def __init__(self, db_path: Path = DEFAULT_CACHE_DB_PATH, retention_days: int = 120) -> None:
        """Prepare the cache without opening the database.

        Args:
            db_path: SQLite database file; its parent directory is created eagerly.
            retention_days: rows whose ``created_at`` is older than this many days
                are deleted by the retention task; ``<= 0`` disables pruning.
        """
        self.db_path = db_path
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self.retention_days = retention_days
        self.__conn: aiosqlite.Connection | None = None
        self.__lock = asyncio.Lock()
        self.__retention_task: asyncio.Task[None] | None = None

    @property
    def conn(self) -> aiosqlite.Connection:
        """Return the open connection.

        Raises:
            RuntimeError: if the cache has not been entered with ``async with``.
        """
        if self.__conn is None:
            raise RuntimeError("Database connection not initialized. Use 'async with' context.")
        return self.__conn

    async def _init_db(self) -> None:
        """Open the database, apply the schema and start the retention task.

        Raises:
            RuntimeError: if the cache is already open.
        """
        async with self.__lock:
            if self.__conn is not None:
                raise RuntimeError("Database already initialized.")

            conn = await aiosqlite.connect(str(self.db_path))
            try:
                await conn.execute("PRAGMA journal_mode=WAL")
                conn.row_factory = CacheEntry.from_row  # type: ignore[assignment]
                await conn.executescript(self.SCHEMA)
                await conn.commit()
            except BaseException:
                # Don't leak the connection if setup fails (or is cancelled).
                await conn.close()
                raise
            self.__conn = conn
            self.__retention_task = asyncio.create_task(self.retain_periodically())

    async def retain_periodically(self, interval_hours: int = 24) -> None:
        """Prune expired rows now and then every ``interval_hours`` hours.

        A failing pass is ignored so later passes still run; this includes the
        first one, which previously ended the task on error. Cancellation is
        not an ``Exception`` and still stops the loop.
        """
        while True:
            with suppress(Exception):
                await self._retain()
            await asyncio.sleep(interval_hours * 3600)

    async def _retain(self) -> None:
        """Delete rows whose ``created_at`` is older than ``retention_days`` days."""
        if self.retention_days <= 0:
            return
        # created_at is compared as text, so the cutoff must use CURRENT_TIMESTAMP's
        # layout. An ISO "T" separator sorts after the space and would make
        # every row from the cutoff's calendar day look expired.
        cutoff = (datetime.now(timezone.utc) - timedelta(days=self.retention_days)).strftime(
            self._SQLITE_TIMESTAMP_FORMAT
        )
        await self.conn.execute("DELETE FROM cache WHERE created_at < ?", (cutoff,))
        await self.conn.commit()

    async def close(self) -> None:
        """Stop the retention task and close the connection; a no-op if not open."""
        async with self.__lock:
            if self.__conn is None:
                return
            # Stop the background task first so it cannot run against a closed connection.
            if self.__retention_task is not None:
                self.__retention_task.cancel()
                await asyncio.gather(self.__retention_task, return_exceptions=True)
                self.__retention_task = None
            try:
                await self.__conn.close()
            finally:
                self.__conn = None

    async def __aenter__(self) -> Cache:
        await self._init_db()
        return self

    async def __aexit__(self, *args: object) -> None:
        await self.close()

    @staticmethod
    def _age_seconds(timestamp: datetime | str) -> float:
        """Return seconds elapsed since ``timestamp``; naive values are treated as UTC.

        Text is accepted as well, because sqlite3 returns TIMESTAMP columns as
        ``str`` unless they are converted when the row is built.
        """
        moment = datetime.fromisoformat(timestamp) if isinstance(timestamp, str) else timestamp
        if moment.tzinfo is None:
            moment = moment.replace(tzinfo=timezone.utc)
        return (datetime.now(timezone.utc) - moment).total_seconds()

    async def get(self, kind: str, key: str, ttl: int | float = -1) -> CacheEntry | None:
        """Return the entry for ``(kind, key)``, or None if there is none.

        When ``ttl`` is positive, an entry whose ``updated_at`` is more than
        ``ttl`` seconds old is reported as missing (it is not deleted).
        """
        async with self.conn.execute("""SELECT * FROM cache WHERE kind = ? AND key = ?""", (kind, key)) as cursor:
            result = await cursor.fetchone()

        if result is None:
            return None

        if ttl > 0 and isinstance(result, CacheEntry):
            if self._age_seconds(result.updated_at) > ttl:
                return None
        return result  # type: ignore[return-value]

    async def put(
        self, kind: str, key: str, data: dict[str, Any] | BaseModel, parent_id: int | None = None
    ) -> CacheEntry:
        """Insert or update the entry for ``(kind, key)`` and return the stored row.

        On conflict only ``parent_id``, ``data`` and ``updated_at`` are
        overwritten; the original ``created_at`` is kept.

        Raises:
            RuntimeError: if the row cannot be read back after writing.
        """
        if isinstance(data, BaseModel):
            data = data.model_dump(mode="json")

        # Pass the timestamp as text explicitly. This is the same form sqlite3's
        # implicit datetime adapter produced, but that adapter is deprecated
        # since Python 3.12.
        now = datetime.now(timezone.utc).isoformat(" ")
        await self.conn.execute(
            """
            INSERT INTO cache (kind, key, parent_id, data, updated_at)
            VALUES (?, ?, ?, ?, ?)
            ON CONFLICT(kind, key) DO UPDATE
            SET parent_id=excluded.parent_id, data=excluded.data, updated_at=excluded.updated_at
            """,
            (kind, key, parent_id, json.dumps(data), now),
        )
        await self.conn.commit()
        entry = await self.get(kind, key)
        if entry is None:
            raise RuntimeError("Failed to retrieve cache entry after insertion.")
        return entry

    async def delete(self, kind: str, key: str) -> bool:
        """Delete the entry for ``(kind, key)``; return True if a row was removed."""
        cursor = await self.conn.execute("DELETE FROM cache WHERE kind = ? AND key = ?", (kind, key))
        await self.conn.commit()
        return cursor.rowcount > 0

    async def list_entries(self, kind: str, limit: int = 100, **field_filter: Any) -> list[CacheEntry]:
        """Return up to ``limit`` entries of ``kind``, newest ``created_at`` first.

        Each keyword in ``field_filter`` requires ``json_extract(data, '$.<name>')``
        to equal the given value. Dotted names address nested JSON fields.

        Raises:
            ValueError: if a filter name does not match SAFE_FIELD_PATTERN.
        """
        # Names are interpolated into SQL text, so validate them first; values are bound.
        for key in field_filter:
            if not SAFE_FIELD_PATTERN.match(key):
                raise ValueError(f"Invalid filter field name: {key!r}")

        json_filter = list(field_filter.items())
        json_query = " AND ".join([f"json_extract(data, '$.{k}') = ?" for k, _ in json_filter])
        if json_query:
            json_query = f" AND {json_query}"
        json_params = tuple(v for _, v in json_filter)
        query = f"""SELECT * FROM cache WHERE kind = ? {json_query} ORDER BY created_at DESC LIMIT ?"""
        params = (kind, *json_params, limit)
        async with self.conn.execute(query, params) as cursor:
            return await cursor.fetchall()  # type: ignore[return-value]

    async def get_by_id(self, entry_id: int) -> CacheEntry | None:
        """Return the entry with primary key ``entry_id``, or None."""
        async with self.conn.execute("""SELECT * FROM cache WHERE id = ?""", (entry_id,)) as cursor:
            return await cursor.fetchone()  # type: ignore[return-value]

    async def get_ancestors(self, kind: str, key: str, limit: int = 50) -> list[CacheEntry]:
        """Return the parent chain of ``(kind, key)``, nearest parent first.

        The entry itself is excluded, and at most ``limit`` levels are followed.
        """
        query = """
        WITH RECURSIVE ancestor_chain(id, kind, key, parent_id, data, created_at, updated_at, depth) AS
        (
            SELECT *, 0 FROM cache WHERE kind = ? AND key = ?
            UNION ALL
            SELECT c.*, ac.depth + 1 FROM cache c
            INNER JOIN ancestor_chain ac ON c.id = ac.parent_id WHERE ac.depth < ?
        )
        SELECT id, kind, key, parent_id, data, created_at, updated_at
        FROM ancestor_chain WHERE depth > 0 ORDER BY depth
        """
        async with self.conn.execute(query, (kind, key, limit)) as cursor:
            return await cursor.fetchall()  # type: ignore[return-value]

    async def get_children(self, kind: str, parent_key: str, limit: int = 50) -> list[CacheEntry]:
        """Return up to ``limit`` descendants (all depths) of ``(kind, parent_key)``.

        Result order is not specified. An empty list is returned when the
        parent entry does not exist.
        """
        parent = await self.get(kind, parent_key)
        if parent is None:
            return []

        query = """
        WITH RECURSIVE child_chain(id, kind, key, parent_id, data, created_at, updated_at) AS (
            SELECT * FROM cache WHERE parent_id = ?
            UNION ALL
            SELECT c.* FROM cache c INNER JOIN child_chain cc ON c.parent_id = cc.id
        )
        SELECT * FROM child_chain LIMIT ?
        """

        async with self.conn.execute(query, (parent.id, limit)) as cursor:
            return await cursor.fetchall()  # type: ignore[return-value]
|