esgpull 0.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- esgpull/__init__.py +12 -0
- esgpull/auth.py +181 -0
- esgpull/cli/__init__.py +73 -0
- esgpull/cli/add.py +103 -0
- esgpull/cli/autoremove.py +38 -0
- esgpull/cli/config.py +116 -0
- esgpull/cli/convert.py +285 -0
- esgpull/cli/decorators.py +342 -0
- esgpull/cli/download.py +74 -0
- esgpull/cli/facet.py +23 -0
- esgpull/cli/get.py +28 -0
- esgpull/cli/install.py +85 -0
- esgpull/cli/link.py +105 -0
- esgpull/cli/login.py +56 -0
- esgpull/cli/remove.py +73 -0
- esgpull/cli/retry.py +43 -0
- esgpull/cli/search.py +201 -0
- esgpull/cli/self.py +238 -0
- esgpull/cli/show.py +66 -0
- esgpull/cli/status.py +67 -0
- esgpull/cli/track.py +87 -0
- esgpull/cli/update.py +184 -0
- esgpull/cli/utils.py +247 -0
- esgpull/config.py +410 -0
- esgpull/constants.py +56 -0
- esgpull/context.py +724 -0
- esgpull/database.py +161 -0
- esgpull/download.py +162 -0
- esgpull/esgpull.py +447 -0
- esgpull/exceptions.py +167 -0
- esgpull/fs.py +253 -0
- esgpull/graph.py +460 -0
- esgpull/install_config.py +185 -0
- esgpull/migrations/README +1 -0
- esgpull/migrations/env.py +82 -0
- esgpull/migrations/script.py.mako +24 -0
- esgpull/migrations/versions/0.3.0_update_tables.py +170 -0
- esgpull/migrations/versions/0.3.1_update_tables.py +25 -0
- esgpull/migrations/versions/0.3.2_update_tables.py +26 -0
- esgpull/migrations/versions/0.3.3_update_tables.py +25 -0
- esgpull/migrations/versions/0.3.4_update_tables.py +25 -0
- esgpull/migrations/versions/0.3.5_update_tables.py +25 -0
- esgpull/migrations/versions/0.3.6_update_tables.py +26 -0
- esgpull/migrations/versions/0.3.7_update_tables.py +26 -0
- esgpull/migrations/versions/0.3.8_update_tables.py +26 -0
- esgpull/migrations/versions/0.4.0_update_tables.py +25 -0
- esgpull/migrations/versions/0.5.0_update_tables.py +26 -0
- esgpull/migrations/versions/0.5.1_update_tables.py +26 -0
- esgpull/migrations/versions/0.5.2_update_tables.py +25 -0
- esgpull/migrations/versions/0.5.3_update_tables.py +26 -0
- esgpull/migrations/versions/0.5.4_update_tables.py +25 -0
- esgpull/migrations/versions/0.5.5_update_tables.py +25 -0
- esgpull/migrations/versions/0.6.0_update_tables.py +25 -0
- esgpull/migrations/versions/0.6.1_update_tables.py +25 -0
- esgpull/migrations/versions/0.6.2_update_tables.py +25 -0
- esgpull/migrations/versions/0.6.3_update_tables.py +25 -0
- esgpull/models/__init__.py +31 -0
- esgpull/models/base.py +50 -0
- esgpull/models/dataset.py +34 -0
- esgpull/models/facet.py +18 -0
- esgpull/models/file.py +65 -0
- esgpull/models/options.py +164 -0
- esgpull/models/query.py +481 -0
- esgpull/models/selection.py +201 -0
- esgpull/models/sql.py +258 -0
- esgpull/models/synda_file.py +85 -0
- esgpull/models/tag.py +19 -0
- esgpull/models/utils.py +54 -0
- esgpull/presets.py +13 -0
- esgpull/processor.py +172 -0
- esgpull/py.typed +0 -0
- esgpull/result.py +53 -0
- esgpull/tui.py +346 -0
- esgpull/utils.py +54 -0
- esgpull/version.py +1 -0
- esgpull-0.6.3.dist-info/METADATA +110 -0
- esgpull-0.6.3.dist-info/RECORD +80 -0
- esgpull-0.6.3.dist-info/WHEEL +4 -0
- esgpull-0.6.3.dist-info/entry_points.txt +3 -0
- esgpull-0.6.3.dist-info/licenses/LICENSE +28 -0
esgpull/fs.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import hashlib
|
|
4
|
+
from collections.abc import Iterator
|
|
5
|
+
from dataclasses import InitVar, dataclass, field
|
|
6
|
+
from enum import Enum, auto
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from shutil import copyfile
|
|
9
|
+
|
|
10
|
+
import aiofiles
|
|
11
|
+
from aiofiles.threadpool.binary import AsyncBufferedIOBase
|
|
12
|
+
|
|
13
|
+
from esgpull.config import Config
|
|
14
|
+
from esgpull.models import File
|
|
15
|
+
from esgpull.result import Err, Ok, Result
|
|
16
|
+
from esgpull.tui import logger
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class FileCheck(Enum):
    """Status of a `File` on disk, as determined by `Filesystem.check`."""

    Missing = auto()  # not in any known paths
    Part = auto()  # {file.sha}.part exists
    BadSize = auto()  # {file.sha}.done exists AND has wrong size
    BadChecksum = auto()  # {file.sha}.done exists AND has wrong checksum
    Done = auto()  # {file.sha}.done exists AND is ready to be moved
    Ok = auto()  # file is in drs with everything ok

    def as_err(self, file: File) -> Exception:
        """Wrap this check status as an exception carrying `file`.

        A new Exception subclass named after the enum member is created
        dynamically, so tracebacks/logs show e.g. `FileCheck.BadSize(<file>)`.
        """
        err_cls = type(str(self), (Exception,), {})
        return err_cls(file)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass
class Digest:
    """Incremental checksum computation for a `File`.

    The hash algorithm is chosen from `file.checksum_type`; only SHA256
    is supported, any other type raises NotImplementedError.
    """

    file: InitVar[File]
    alg: hashlib._Hash = field(init=False)

    def __post_init__(self, file: File) -> None:
        # Select the hashing backend from the file's declared checksum type.
        if file.checksum_type == "SHA256":
            self.alg = hashlib.sha256()
        else:
            raise NotImplementedError

    @classmethod
    def from_path(cls, file: File, path: Path) -> Digest:
        """Hash the full content of `path`, read in filesystem-block-sized chunks."""
        chunk_size = path.stat().st_blksize
        digest = cls(file)
        with path.open("rb") as stream:
            # Empty read marks end-of-file and terminates the loop.
            while chunk := stream.read(chunk_size):
                digest.update(chunk)
        return digest

    def update(self, chunk: bytes) -> None:
        """Feed `chunk` into the running hash."""
        self.alg.update(chunk)

    def hexdigest(self) -> str:
        """Return the current hex-encoded digest."""
        return self.alg.hexdigest()
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@dataclass
class Filesystem:
    """Layout of esgpull's on-disk directories and file lifecycle operations.

    Downloads land in `tmp` as `{sha}.part`, are renamed to `{sha}.done`
    once complete, then moved into the DRS hierarchy under `data`.
    """

    auth: Path  # credentials directory
    data: Path  # DRS root for fully-installed files
    db: Path  # database directory
    log: Path  # log directory
    tmp: Path  # staging area for in-progress downloads
    disable_checksum: bool = False  # skip checksum verification when True
    install: InitVar[bool] = True  # create directories on construction

    @staticmethod
    def from_config(config: Config, install: bool = False) -> Filesystem:
        """Build a Filesystem from the paths/download sections of `config`."""
        return Filesystem(
            auth=config.paths.auth,
            data=config.paths.data,
            db=config.paths.db,
            log=config.paths.log,
            tmp=config.paths.tmp,
            disable_checksum=config.download.disable_checksum,
            install=install,
        )

    def __post_init__(self, install: bool = True) -> None:
        # Create every managed directory (idempotent) when installing.
        if install:
            self.auth.mkdir(parents=True, exist_ok=True)
            self.data.mkdir(parents=True, exist_ok=True)
            self.db.mkdir(parents=True, exist_ok=True)
            self.log.mkdir(parents=True, exist_ok=True)
            self.tmp.mkdir(parents=True, exist_ok=True)

    def __getitem__(self, file: File) -> FilePath:
        """Map a `File` to its pair of paths (final DRS path + tmp part path)."""
        if not isinstance(file, File):
            raise TypeError(file)
        return FilePath(
            drs=self.data / file.local_path / file.filename,
            tmp=self.tmp / f"{file.sha}.part",
        )

    def glob_netcdf(self) -> Iterator[Path]:
        """Yield every `.nc` file under `data`, relative to `data`."""
        for path in self.data.glob("**/*.nc"):
            yield path.relative_to(self.data)

    def open(self, file: File) -> FileObject:
        """Return an async-writable object targeting `file`'s tmp path."""
        return FileObject(self[file])

    def isempty(self, path: Path) -> bool:
        """True if directory `path` contains no entries at all."""
        if next(path.iterdir(), None) is None:
            return True
        else:
            return False

    def iter_empty_parents(self, path: Path) -> Iterator[Path]:
        """Yield `path` then successive parents while they hold no `.nc` files
        and are completely empty; stops at the first non-empty ancestor."""
        sample: Path | None
        for _ in range(10):  # arbitrary 10 to avoid infinite loop
            sample = next(path.glob("**/*.nc"), None)
            if sample is None and self.isempty(path):
                yield path
                path = path.parent
            else:
                return

    def move_to_drs(self, file: File) -> None:
        """Move `file`'s `.done` artifact into its final DRS location.

        `rename` fails across filesystems (EXDEV among other OSErrors);
        fall back to a copy and warn the user about the performance cost.
        """
        path = self[file]
        path.drs.parent.mkdir(parents=True, exist_ok=True)
        try:
            path.done.rename(path.drs)
        except OSError as err:
            logger.error(err)
            copyfile(path.done, path.drs)
            msg = """
File rename error, shutil.copyfile was used instead.
For large files, download times might be impacted.
To address this issue, you may consider setting your `tmp` directory to the same filesystem as your `data` directory:

$ esgpull config path.tmp <some/path/on/data/filesystem>
""".strip()
            logger.error(msg)

    def delete(self, *files: File) -> None:
        """Remove each file from the DRS, then prune newly-empty parent dirs."""
        for file in files:
            path = self[file].drs
            if not path.is_file():
                continue
            path.unlink()
            logger.info(f"Deleted file {path}")
            for subpath in self.iter_empty_parents(path.parent):
                subpath.rmdir()
                logger.info(f"Deleted empty folder {subpath}")

    def compute_checksum(
        self,
        file: File,
        path: Path,
        disable_checksum: bool | None = None,
    ) -> str:
        """Return the checksum of `path`.

        When checksumming is disabled, the file's expected checksum is
        returned as-is, so downstream comparisons trivially succeed.
        `disable_checksum=None` defers to the instance-level setting.
        """
        if disable_checksum is None:
            disable_checksum = self.disable_checksum
        if disable_checksum:
            return file.checksum
        else:
            return Digest.from_path(file, path).hexdigest()

    def check_impl(
        self,
        file: File,
        path: Path,
        digest: Digest | None = None,
    ) -> FileCheck:
        """Validate size then checksum of `path` against `file`'s metadata.

        A caller-provided `digest` (accumulated during download) avoids
        re-reading the file from disk.
        """
        if path.stat().st_size != file.size:
            return FileCheck.BadSize
        if digest is None:
            checksum = self.compute_checksum(file, path)
        else:
            checksum = digest.hexdigest()
        if checksum == file.checksum:
            return FileCheck.Ok
        else:
            return FileCheck.BadChecksum

    def check(
        self,
        file: File,
        digest: Digest | None = None,
    ) -> FileCheck:
        """Classify `file`'s on-disk state, probing drs, .done, then .part."""
        path = self[file]
        if path.drs.is_file():
            return self.check_impl(file, path.drs, digest)
        elif path.done.is_file():
            match self.check_impl(file, path.done, digest):
                case FileCheck.Ok:
                    # .done file fully validated -> ready to move into drs.
                    return FileCheck.Done
                case check:
                    return check
        elif path.tmp.is_file():
            match self.check_impl(file, path.tmp, digest):
                case FileCheck.BadSize:
                    # A .part with a size mismatch is an incomplete download.
                    return FileCheck.Part
                case _:
                    # NOTE(review): a .part file with the *expected* size hits
                    # this branch and raises — confirm this is intentional
                    # (e.g. "should have been renamed to .done already").
                    raise ValueError()
        else:
            return FileCheck.Missing

    def finalize(
        self,
        file: File,
        digest: Digest | None = None,
    ) -> Result[FileCheck]:
        """Validate a downloaded file and install it into the DRS.

        Returns Ok(FileCheck.Ok) on success (moving the .done file into
        place if needed), otherwise Err wrapping the failed check.
        """
        match self.check(file, digest=digest):
            case FileCheck.Ok:
                return Ok(FileCheck.Ok)
            case FileCheck.Done:
                self.move_to_drs(file)
                return Ok(FileCheck.Ok)
            case check:
                return Err(check, check.as_err(file))
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
@dataclass
class FilePath:
    """Pair of locations for a tracked file.

    `drs` is the definitive location under the data directory, `tmp` is
    the staging location of the in-progress download (`{sha}.part`).
    """

    drs: Path
    tmp: Path

    @property
    def done(self) -> Path:
        """Staging path of a completed download: `{sha}.done` next to `tmp`."""
        part_path = self.tmp
        return part_path.with_suffix(".done")

    def __str__(self) -> str:
        # The DRS path is the canonical textual identity of the file.
        return f"{self.drs}"
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
@dataclass
class FileObject:
    """Async write handle over a `FilePath`'s tmp location.

    Used as an async context manager: entering opens `{sha}.part` for
    binary writing; exiting closes the buffer. `to_done()` promotes the
    part file to `{sha}.done` once the download is complete.
    """

    path: FilePath
    # Set in __aenter__; calling write()/to_done() before entering the
    # context raises AttributeError.
    buffer: AsyncBufferedIOBase = field(init=False)

    async def __aenter__(self) -> FileObject:
        self.buffer = await aiofiles.open(self.path.tmp, "wb")
        return self

    async def __aexit__(self, exc_type, exc_value, exc_traceback) -> None:
        # Close on exit regardless of success; exceptions propagate.
        if not self.buffer.closed:
            await self.buffer.close()

    async def write(self, chunk: bytes) -> None:
        """Append `chunk` to the open part file."""
        await self.buffer.write(chunk)

    async def to_done(self) -> None:
        """Close the buffer (if still open) and rename `.part` -> `.done`."""
        if not self.buffer.closed:
            await self.buffer.close()
        self.path.tmp.rename(self.path.done)
|
esgpull/graph.py
ADDED
|
@@ -0,0 +1,460 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Iterator, Mapping, Sequence
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
|
|
6
|
+
from rich.console import Console, ConsoleOptions
|
|
7
|
+
from rich.pretty import pretty_repr
|
|
8
|
+
from rich.tree import Tree
|
|
9
|
+
|
|
10
|
+
from esgpull.config import Config
|
|
11
|
+
from esgpull.database import Database
|
|
12
|
+
from esgpull.exceptions import (
|
|
13
|
+
GraphWithoutDatabase,
|
|
14
|
+
QueryDuplicate,
|
|
15
|
+
TooShortKeyError,
|
|
16
|
+
)
|
|
17
|
+
from esgpull.models import Facet, Query, QueryDict, Tag, sql
|
|
18
|
+
from esgpull.models.utils import rich_measure_impl
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass(init=False, repr=False)
class Graph:
    """In-memory view of the query graph, optionally backed by a database.

    Queries are indexed by their sha; `require` links form the edges.
    Human-friendly names (tag names, sha prefixes, `#`-prefixed names)
    are resolved to shas through `_expand_name`.
    """

    queries: dict[str, Query]  # sha -> loaded Query instances
    _db: Database | None  # None for detached (e.g. subgraph) instances
    _shas: set[str]  # all known shas (loaded or still in db only)
    _name_sha: dict[str, str]  # unambiguous name -> sha index
    _rendered: set[str]  # transient state used by rich rendering
    _deleted_shas: set[str]  # shas marked for deletion at next merge()

    @classmethod
    def from_config(cls, config: Config) -> Graph:
        """Alternate constructor: open the database described by `config`."""
        db = Database.from_config(config)
        return Graph(db)

    def __init__(self, db: Database | None) -> None:
        self._db = db
        self.queries = {}
        self._shas = set()
        self._name_sha = {}
        self._deleted_shas = set()
        if db is not None:
            self._load_db_shas()

    @property
    def db(self) -> Database:
        """The backing database; raises if this graph is detached."""
        if self._db is None:
            raise GraphWithoutDatabase()
        else:
            return self._db

    @staticmethod
    def matching_shas(name: str, shas: set[str]) -> list[str]:
        """Return the shas in `shas` that start with `name` (prefix match).

        NOTE(review): `shas_copy[idx][pos]` assumes every sha is at least
        as long as `name`; a longer `name` would raise IndexError — confirm
        sha length is always >= user-supplied prefixes.
        """
        shas_copy = list(shas)
        for pos, c in enumerate(name):
            idx = 0
            while idx < len(shas_copy):
                if shas_copy[idx][pos] != c:
                    shas_copy.pop(idx)
                else:
                    idx += 1
        return shas_copy

    @staticmethod
    def _expand_name(
        name: str, shas: set[str], name_sha: dict[str, str]
    ) -> str:
        """Resolve `name` to a sha: exact name, exact sha, or unique sha
        prefix (optionally `#`-prefixed). Falls back to `name` unchanged
        when nothing matches; raises on an ambiguous prefix."""
        if name in name_sha:
            sha = name_sha[name]
        elif name in shas:
            sha = name
        else:
            short_name = name
            if short_name.startswith("#"):
                short_name = short_name[1:]
            matching_shas = Graph.matching_shas(short_name, shas)
            if len(matching_shas) > 1:
                raise TooShortKeyError(name)
            elif len(matching_shas) == 1:
                sha = matching_shas[0]
            else:
                sha = name
        return sha

    def __contains__(self, item: Query | str) -> bool:
        """Membership test by Query (sha recomputed first) or by name/sha."""
        match item:
            case Query():
                item.compute_sha()
                sha = item.sha
            case str():
                sha = self._expand_name(item, self._shas, self._name_sha)
            case _:
                raise TypeError(item)
        return sha in self._shas

    def get(self, name: str) -> Query:
        """Fetch a query by name/sha, lazily loading it from the database."""
        sha = self._expand_name(name, self._shas, self._name_sha)
        if sha in self.queries:
            ...
        elif sha in self._shas:
            query_db = self.db.get(Query, sha)
            if query_db is not None:
                self.queries[sha] = query_db
            else:
                # NOTE(review): bare `raise` with no active exception raises
                # RuntimeError — likely meant to be a domain error for a sha
                # known in _shas but missing from the db.
                raise
        else:
            raise KeyError(name)
        return self.queries[sha]

    def get_mutable(self, name: str) -> Query:
        """Like `get`, but loads a detached, fully-loaded (non-lazy) instance
        that can be modified without autoflush side effects."""
        sha = self._expand_name(name, self._shas, self._name_sha)
        if sha in self._shas:
            query_db = self.db.get(
                Query,
                sha,
                lazy=False,
                detached=True,
            )
            if query_db is not None:
                self.queries[sha] = query_db
            else:
                # NOTE(review): same bare-`raise` issue as in `get`.
                raise
        else:
            raise KeyError(name)
        return self.queries[sha]

    def get_children(self, sha: str | None) -> Sequence[Query]:
        """Direct children of `sha` (queries whose `require` equals it).

        Without a database, children are searched among loaded queries;
        with one, the lookup is delegated to SQL (empty for sha=None).
        """
        if self._db is None:
            children: list[Query] = []
            for query in self.queries.values():
                if query.require == sha:
                    children.append(query)
        elif sha is None:
            return []
        else:
            return self.db.scalars(sql.query.children(sha))
        return children

    def get_all_children(self, sha: str) -> Sequence[Query]:
        """All descendants of `sha`, gathered depth-first."""
        children: list[Query] = []
        for query in self.get_children(sha):
            children.append(query)
            children.extend(self.get_all_children(query.sha))
        return children

    def get_parent(self, query: Query) -> Query | None:
        """The query that `query.require` points to, or None for roots."""
        if query.require is not None:
            return self.get(query.require)
        else:
            return None

    def get_parents(self, query: Query) -> list[Query]:
        """Chain of ancestors from direct parent up to the root."""
        result: list[Query] = []
        parent = self.get_parent(query)
        while parent is not None:
            result.append(parent)
            parent = self.get_parent(parent)
        return result

    def get_tags(self) -> list[Tag]:
        """All tags stored in the database."""
        return list(self.db.scalars(sql.tag.all()))

    def get_tag(self, name: str) -> Tag | None:
        """First tag with the given name, or None."""
        result: Tag | None = None
        for tag in self.get_tags():
            if tag.name == name:
                result = tag
                break
        return result

    def with_tag(self, tag_name: str) -> list[Query]:
        """All queries carrying `tag_name`, from the db (if any) plus any
        additional known queries not returned by the SQL lookup."""
        queries: list[Query] = []
        shas: set[str] = set()
        try:
            db_queries = self.db.scalars(sql.query.with_tag(tag_name))
            for query in db_queries:
                queries.append(query)
                shas.add(query.sha)
        except GraphWithoutDatabase:
            # Detached graph: fall through to the in-memory scan below.
            pass
        for sha in self._shas - shas:
            query = self.get(sha)
            if query.get_tag(tag_name) is not None:
                queries.append(query)
        return queries

    def subgraph(
        self,
        *queries: Query,
        children: bool = True,
        parents: bool = False,
        keep_db: bool = False,
    ) -> Graph:
        """Build a new Graph restricted to `queries` and (optionally) their
        descendants and/or ancestors. `keep_db` keeps the db attached and
        loads instances from it; otherwise instances are added detached."""
        if not queries:
            raise ValueError("Cannot subgraph from nothing")
        if keep_db:
            graph = Graph(self.db)
            queries_shas = [q.sha for q in queries]
            graph.load_db(*queries_shas)
        else:
            graph = Graph(None)
            graph.add(*queries, force=True, clone=False)
        if children:
            for query in queries:
                query_children = self.get_all_children(query.sha)
                if len(query_children) == 0:
                    continue
                if keep_db:
                    children_shas = [q.sha for q in query_children]
                    graph.load_db(*children_shas)
                else:
                    graph.add(*query_children, force=True, clone=False)
        if parents:
            for query in queries:
                query_parents = self.get_parents(query)
                if len(query_parents) == 0:
                    continue
                if keep_db:
                    parents_shas = [q.sha for q in query_parents]
                    graph.load_db(*parents_shas)
                else:
                    graph.add(*query_parents, force=True, clone=False)
        return graph

    def _load_db_shas(self, full: bool = False) -> None:
        """Populate the sha set and name->sha index from the database.

        NOTE(review): the `full` parameter is unused here — confirm whether
        it is vestigial.
        """
        name_sha: dict[str, str] = {}
        self._shas = set(self.db.scalars(sql.query.shas()))
        for name, sha in self.db.rows(sql.query.name_sha()):
            name_sha[name] = sha
        self._name_sha = name_sha

    def load_db(self, *shas: str) -> None:
        """Load the given shas (or every not-yet-loaded sha) from the db
        into `self.queries`."""
        if shas:
            unloaded_shas = set(shas)
        else:
            unloaded_shas = set(self._shas) - set(self.queries.keys())
        if unloaded_shas:
            queries = self.db.scalars(sql.query.with_shas(*unloaded_shas))
            for query in queries:
                self.queries[query.sha] = query

    def validate(self, *queries: Query, noraise: bool = False) -> set[str]:
        """Return names of `queries` that clash with existing names,
        raising QueryDuplicate unless `noraise` is set."""
        names = set(self._name_sha.keys())
        duplicates = {q.name: q for q in queries if q.name in names}
        if duplicates and not noraise:
            raise QueryDuplicate(pretty_repr(duplicates))
        else:
            return set(duplicates.keys())

    def resolve_require(self, query: Query) -> None:
        """Normalize `query.require` to a full sha if it can be resolved,
        otherwise flag the query as having an unknown requirement."""
        if query.require is None or query.require in self._shas:
            ...
        elif query.require in self:  # self.has(sha=query.require):
            parent = self.get(query.require)
            query.require = parent.sha
            query.compute_sha()
        else:
            query._unknown_require = True  # type: ignore [attr-defined]

    def add(
        self,
        *queries: Query,
        force: bool = False,
        clone: bool = True,
        noraise: bool = False,
    ) -> Mapping[str, Query]:
        """
        Add new query to the graph.

        - (re)compute sha for each query
        - validate query.name against existing queries
        - populate graph._name_sha to enable `graph[query.name]` indexing
        - replace query.require with full sha
        """
        # Work on copies of all indexes; commit them atomically at the end.
        new_shas: set[str] = set(self._shas)
        new_deleted_shas: set[str] = set(self._deleted_shas)
        new_queries: dict[str, Query] = dict(self.queries.items())
        name_shas: dict[str, list[str]] = {
            name: [sha] for name, sha in self._name_sha.items()
        }
        queue: list[Query] = [
            query.clone(compute_sha=True) if clone else query
            for query in queries
        ]
        # duplicate_names = self.validate(*queue, noraise=noraise or force)
        replaced: dict[str, Query] = {}
        for query in queue:
            if query.sha in new_shas:
                if force:
                    # Keep a clone of the overwritten query for the caller.
                    if query.sha in new_queries:
                        old = new_queries[query.sha]
                    else:
                        old = self.get(query.sha)
                    replaced[query.sha] = old.clone(compute_sha=False)  # True?
                else:
                    raise QueryDuplicate(pretty_repr(query))
            new_shas.add(query.sha)
            if query.sha in new_deleted_shas:
                # Re-adding a deleted query cancels the pending deletion.
                new_deleted_shas.remove(query.sha)
            new_queries[query.sha] = query
        # Tag names shared by more than one query are ambiguous and are
        # dropped from the name->sha index.
        skip_tags: set[str] = set()
        for sha, query in self.queries.items():
            tag_name = query.tag_name
            if tag_name is not None and tag_name not in skip_tags:
                name_shas.setdefault(tag_name, [])
                name_shas[tag_name].append(query.sha)
                if len(name_shas[tag_name]) > 1:
                    skip_tags.add(tag_name)
        new_name_sha = {
            name: shas[0]
            for name, shas in name_shas.items()
            if name not in skip_tags
        }
        if not force:
            # Every require must already resolve to a full sha at this point.
            for sha, query in new_queries.items():
                if query.require is not None:
                    sha = self._expand_name(
                        query.require, new_shas, new_name_sha
                    )
                    if sha != query.require:
                        raise ValueError("case change require")
        self.queries = new_queries
        self._shas = new_shas
        self._deleted_shas = new_deleted_shas
        self._name_sha = new_name_sha
        return replaced

    def get_unknown_facets(self) -> set[Facet]:
        """
        Why was this implemented?
        Maybe useful to enable adding facets (e.g. `table_id:*day*`)
        """
        facets: dict[str, Facet] = {}
        for query in self.queries.values():
            for facet in query.selection._facets:
                if facet.sha not in facets:
                    facets[facet.sha] = facet
        shas = list(facets.keys())
        known_shas = self.db.scalars(sql.facet.known_shas(shas))
        unknown_shas = set(shas) - set(known_shas)
        unknown_facets = {facets[sha] for sha in unknown_shas}
        return unknown_facets

    def merge(self) -> Mapping[str, Query]:
        """
        Try to load instances from database into self.db.

        Start with tags, since they are not part of query.sha,
        and there could be new tags to add to an existing query.
        Those new tags need to be merged before adding them to an
        existing query instance from database (autoflush mess).

        Only load options/selection/facets if query is not in db,
        and updated options/selection/facets should change sha value.
        """
        updated_shas: set[str] = set()
        for sha, query in self.queries.items():
            query_db = self.db.merge(query, commit=True)
            if query is query_db:
                ...
            else:
                # merge returned a db-bound instance: swap it in.
                updated_shas.add(sha)
                self.queries[sha] = query_db
        for sha in self._deleted_shas:
            query_to_delete = self.db.get(Query, sha)
            if query_to_delete is not None:
                self.db.delete(query_to_delete)
        return {sha: self.queries[sha] for sha in updated_shas}

    def expand(self, name: str) -> Query:
        """
        Expand/unpack `query.requires`, using `query.name` index.
        """
        query = self.get(name)
        while query.require is not None:
            query = self.get(query.require) << query
        return query

    def dump(self) -> list[QueryDict]:
        """
        Dump full graph as list of dicts (yaml selection syntax).
        """
        return [q.asdict() for q in self.queries.values()]

    def asdict(self, files: bool = False) -> Mapping[str, QueryDict]:
        """
        Dump full graph as dict of dict, indexed by each query's sha.
        """
        result = {}
        for sha, query in self.queries.items():
            result[sha] = query.asdict()
            if files:
                result[sha]["files"] = [f.asdict() for f in query.files]
        return result

    def fill_tree(
        self,
        root: Query | None,
        tree: Tree,
        keep_require: bool = False,
    ) -> None:
        """
        Recursive method to add branches starting from queries with either:
        - require is None
        - require is not in self.queries
        """
        for sha, query in self.queries.items():
            query_tree: Tree | None = None
            if sha in self._rendered:
                ...
            elif root is None:
                if (
                    query.require is None or query.require not in self
                ):  # self.has(sha=query.require):
                    self._rendered.add(sha)
                    query_tree = query._rich_tree()
            elif query.require == root.sha:
                self._rendered.add(sha)
                if keep_require:
                    query_tree = query._rich_tree()
                else:
                    query_tree = query.no_require()._rich_tree()
            if query_tree is not None:
                tree.add(query_tree)
                # Recurse with this query as root to render its children.
                self.fill_tree(query, query_tree)

    __rich_measure__ = rich_measure_impl

    def __rich_console__(
        self,
        console: Console,
        options: ConsoleOptions,
    ) -> Iterator[Tree]:
        """
        Returns a `rich.tree.Tree` representing queries and their `require`.
        """
        tree = Tree("", hide_root=True, guide_style="dim")
        self._rendered = set()
        self.fill_tree(None, tree)
        # Queries missed by the root pass (parent loaded but not rendered)
        # are rendered with their require kept visible.
        unrendered = set(self.queries.keys()) - self._rendered
        for sha in unrendered:
            query = self.get(sha)
            if query.require is None:
                continue
            parent = self.get(query.require)
            self.fill_tree(parent, tree, keep_require=True)
        del self._rendered
        yield tree

    def delete(self, query: Query) -> None:
        """Forget `query` in memory and mark its sha for deletion at merge()."""
        self._shas.remove(query.sha)
        self.queries.pop(query.sha, None)
        self._deleted_shas.add(query.sha)

    def replace(self, original: Query, new: Query) -> None:
        """Swap `original` for `new` (delete then add)."""
        # if original not in self.db:
        #     raise ValueError(f"{original.name} not found in the database.")
        # elif new in self.db:
        #     raise ValueError(f"{new.name} already in the database.")
        self.delete(original)
        self.add(new)
|