esgpull 0.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80) hide show
  1. esgpull/__init__.py +12 -0
  2. esgpull/auth.py +181 -0
  3. esgpull/cli/__init__.py +73 -0
  4. esgpull/cli/add.py +103 -0
  5. esgpull/cli/autoremove.py +38 -0
  6. esgpull/cli/config.py +116 -0
  7. esgpull/cli/convert.py +285 -0
  8. esgpull/cli/decorators.py +342 -0
  9. esgpull/cli/download.py +74 -0
  10. esgpull/cli/facet.py +23 -0
  11. esgpull/cli/get.py +28 -0
  12. esgpull/cli/install.py +85 -0
  13. esgpull/cli/link.py +105 -0
  14. esgpull/cli/login.py +56 -0
  15. esgpull/cli/remove.py +73 -0
  16. esgpull/cli/retry.py +43 -0
  17. esgpull/cli/search.py +201 -0
  18. esgpull/cli/self.py +238 -0
  19. esgpull/cli/show.py +66 -0
  20. esgpull/cli/status.py +67 -0
  21. esgpull/cli/track.py +87 -0
  22. esgpull/cli/update.py +184 -0
  23. esgpull/cli/utils.py +247 -0
  24. esgpull/config.py +410 -0
  25. esgpull/constants.py +56 -0
  26. esgpull/context.py +724 -0
  27. esgpull/database.py +161 -0
  28. esgpull/download.py +162 -0
  29. esgpull/esgpull.py +447 -0
  30. esgpull/exceptions.py +167 -0
  31. esgpull/fs.py +253 -0
  32. esgpull/graph.py +460 -0
  33. esgpull/install_config.py +185 -0
  34. esgpull/migrations/README +1 -0
  35. esgpull/migrations/env.py +82 -0
  36. esgpull/migrations/script.py.mako +24 -0
  37. esgpull/migrations/versions/0.3.0_update_tables.py +170 -0
  38. esgpull/migrations/versions/0.3.1_update_tables.py +25 -0
  39. esgpull/migrations/versions/0.3.2_update_tables.py +26 -0
  40. esgpull/migrations/versions/0.3.3_update_tables.py +25 -0
  41. esgpull/migrations/versions/0.3.4_update_tables.py +25 -0
  42. esgpull/migrations/versions/0.3.5_update_tables.py +25 -0
  43. esgpull/migrations/versions/0.3.6_update_tables.py +26 -0
  44. esgpull/migrations/versions/0.3.7_update_tables.py +26 -0
  45. esgpull/migrations/versions/0.3.8_update_tables.py +26 -0
  46. esgpull/migrations/versions/0.4.0_update_tables.py +25 -0
  47. esgpull/migrations/versions/0.5.0_update_tables.py +26 -0
  48. esgpull/migrations/versions/0.5.1_update_tables.py +26 -0
  49. esgpull/migrations/versions/0.5.2_update_tables.py +25 -0
  50. esgpull/migrations/versions/0.5.3_update_tables.py +26 -0
  51. esgpull/migrations/versions/0.5.4_update_tables.py +25 -0
  52. esgpull/migrations/versions/0.5.5_update_tables.py +25 -0
  53. esgpull/migrations/versions/0.6.0_update_tables.py +25 -0
  54. esgpull/migrations/versions/0.6.1_update_tables.py +25 -0
  55. esgpull/migrations/versions/0.6.2_update_tables.py +25 -0
  56. esgpull/migrations/versions/0.6.3_update_tables.py +25 -0
  57. esgpull/models/__init__.py +31 -0
  58. esgpull/models/base.py +50 -0
  59. esgpull/models/dataset.py +34 -0
  60. esgpull/models/facet.py +18 -0
  61. esgpull/models/file.py +65 -0
  62. esgpull/models/options.py +164 -0
  63. esgpull/models/query.py +481 -0
  64. esgpull/models/selection.py +201 -0
  65. esgpull/models/sql.py +258 -0
  66. esgpull/models/synda_file.py +85 -0
  67. esgpull/models/tag.py +19 -0
  68. esgpull/models/utils.py +54 -0
  69. esgpull/presets.py +13 -0
  70. esgpull/processor.py +172 -0
  71. esgpull/py.typed +0 -0
  72. esgpull/result.py +53 -0
  73. esgpull/tui.py +346 -0
  74. esgpull/utils.py +54 -0
  75. esgpull/version.py +1 -0
  76. esgpull-0.6.3.dist-info/METADATA +110 -0
  77. esgpull-0.6.3.dist-info/RECORD +80 -0
  78. esgpull-0.6.3.dist-info/WHEEL +4 -0
  79. esgpull-0.6.3.dist-info/entry_points.txt +3 -0
  80. esgpull-0.6.3.dist-info/licenses/LICENSE +28 -0
esgpull/fs.py ADDED
@@ -0,0 +1,253 @@
1
+ from __future__ import annotations
2
+
3
+ import hashlib
4
+ from collections.abc import Iterator
5
+ from dataclasses import InitVar, dataclass, field
6
+ from enum import Enum, auto
7
+ from pathlib import Path
8
+ from shutil import copyfile
9
+
10
+ import aiofiles
11
+ from aiofiles.threadpool.binary import AsyncBufferedIOBase
12
+
13
+ from esgpull.config import Config
14
+ from esgpull.models import File
15
+ from esgpull.result import Err, Ok, Result
16
+ from esgpull.tui import logger
17
+
18
+
19
class FileCheck(Enum):
    """Lifecycle states of a downloaded file, from absent to installed.

    Each state corresponds to the presence (and validity) of one of the
    paths managed by `Filesystem`: the in-progress ``{sha}.part`` file,
    the completed ``{sha}.done`` file, or the final location in the DRS
    (data) tree.
    """

    Missing = auto()  # not in any known paths
    Part = auto()  # {file.sha}.part exists
    BadSize = auto()  # {file.sha}.done exists AND has wrong size
    BadChecksum = auto()  # {file.sha}.done exists AND has wrong checksum
    Done = auto()  # {file.sha}.done exists AND is ready to be moved
    Ok = auto()  # file is in drs with everything ok

    def as_err(self, file: File) -> Exception:
        """Wrap this check result into an ad-hoc Exception subclass.

        The exception class is created on the fly, named after the check
        (via ``str(self)``), and carries `file` as its argument.
        """
        err_cls = type(str(self), (Exception,), {})
        return err_cls(file)
30
+
31
+
32
@dataclass
class Digest:
    """Incremental checksum helper bound to a file's declared algorithm."""

    file: InitVar[File]
    alg: hashlib._Hash = field(init=False)

    def __post_init__(self, file: File) -> None:
        # Only SHA256 is supported for now; any other algorithm is rejected.
        if file.checksum_type == "SHA256":
            self.alg = hashlib.sha256()
        else:
            raise NotImplementedError

    @classmethod
    def from_path(cls, file: File, path: Path) -> Digest:
        """Digest the full content of `path`, one filesystem block at a time."""
        digest = cls(file)
        chunk_size = path.stat().st_blksize
        with path.open("rb") as stream:
            while chunk := stream.read(chunk_size):
                digest.update(chunk)
        return digest

    def update(self, chunk: bytes) -> None:
        """Feed `chunk` into the underlying hash object."""
        self.alg.update(chunk)

    def hexdigest(self) -> str:
        """Return the digest accumulated so far as a hex string."""
        return self.alg.hexdigest()
62
+
63
+
64
@dataclass
class Filesystem:
    """Manages esgpull's on-disk layout and the lifecycle of downloaded files.

    A file moves through three locations:
      1. ``tmp/{sha}.part``  while it is being downloaded,
      2. ``tmp/{sha}.done``  once the download has completed,
      3. ``data/<drs path>`` after size/checksum validation (see `finalize`).
    """

    # Root directories, usually taken from `Config.paths`.
    auth: Path
    data: Path
    db: Path
    log: Path
    tmp: Path
    # When True, `compute_checksum` returns the expected checksum as-is,
    # skipping the (potentially expensive) content hashing.
    disable_checksum: bool = False
    # When True, create all directories at construction time.
    install: InitVar[bool] = True

    @staticmethod
    def from_config(config: Config, install: bool = False) -> Filesystem:
        """Build a Filesystem from a `Config`'s paths and download options."""
        return Filesystem(
            auth=config.paths.auth,
            data=config.paths.data,
            db=config.paths.db,
            log=config.paths.log,
            tmp=config.paths.tmp,
            disable_checksum=config.download.disable_checksum,
            install=install,
        )

    def __post_init__(self, install: bool = True) -> None:
        # Create the directory tree only on demand (installation time).
        if install:
            self.auth.mkdir(parents=True, exist_ok=True)
            self.data.mkdir(parents=True, exist_ok=True)
            self.db.mkdir(parents=True, exist_ok=True)
            self.log.mkdir(parents=True, exist_ok=True)
            self.tmp.mkdir(parents=True, exist_ok=True)

    def __getitem__(self, file: File) -> FilePath:
        """Return the (final DRS path, tmp part-file path) pair for `file`."""
        if not isinstance(file, File):
            raise TypeError(file)
        return FilePath(
            drs=self.data / file.local_path / file.filename,
            tmp=self.tmp / f"{file.sha}.part",
        )

    def glob_netcdf(self) -> Iterator[Path]:
        """Yield every ``*.nc`` file under `data`, relative to `data`."""
        for path in self.data.glob("**/*.nc"):
            yield path.relative_to(self.data)

    def open(self, file: File) -> FileObject:
        """Return an async writer targeting `file`'s tmp part-file."""
        return FileObject(self[file])

    def isempty(self, path: Path) -> bool:
        """Return True if directory `path` contains no entry at all."""
        if next(path.iterdir(), None) is None:
            return True
        else:
            return False

    def iter_empty_parents(self, path: Path) -> Iterator[Path]:
        """Yield `path` then successive parents while they hold no ``.nc`` file.

        Used after deletion to prune now-empty DRS folders bottom-up.
        """
        sample: Path | None
        for _ in range(10):  # arbitrary 10 to avoid infinite loop
            sample = next(path.glob("**/*.nc"), None)
            if sample is None and self.isempty(path):
                yield path
                path = path.parent
            else:
                return

    def move_to_drs(self, file: File) -> None:
        """Move `file`'s completed ``.done`` file to its final DRS location.

        Falls back to `shutil.copyfile` when `rename` fails — typically when
        `tmp` and `data` live on different filesystems.
        """
        path = self[file]
        path.drs.parent.mkdir(parents=True, exist_ok=True)
        try:
            path.done.rename(path.drs)
        except OSError as err:
            logger.error(err)
            # NOTE(review): after the copy, the `.done` file is left behind
            # in `tmp` — confirm whether it should be unlinked here.
            copyfile(path.done, path.drs)
            msg = """
File rename error, shutil.copyfile was used instead.
For large files, download times might be impacted.
To address this issue, you may consider setting your `tmp` directory to the same filesystem as your `data` directory:

$ esgpull config path.tmp <some/path/on/data/filesystem>
""".strip()
            logger.error(msg)

    def delete(self, *files: File) -> None:
        """Remove files from the DRS tree and prune emptied parent folders."""
        for file in files:
            path = self[file].drs
            if not path.is_file():
                continue
            path.unlink()
            logger.info(f"Deleted file {path}")
            for subpath in self.iter_empty_parents(path.parent):
                subpath.rmdir()
                logger.info(f"Deleted empty folder {subpath}")

    def compute_checksum(
        self,
        file: File,
        path: Path,
        disable_checksum: bool | None = None,
    ) -> str:
        """Return the checksum for `path`'s content.

        When checksums are disabled (argument or instance default), the
        expected `file.checksum` is returned as-is, skipping verification.
        """
        if disable_checksum is None:
            disable_checksum = self.disable_checksum
        if disable_checksum:
            return file.checksum
        else:
            return Digest.from_path(file, path).hexdigest()

    def check_impl(
        self,
        file: File,
        path: Path,
        digest: Digest | None = None,
    ) -> FileCheck:
        """Validate `path` against `file`'s expected size and checksum.

        A partially-fed `digest` may be supplied by the caller to avoid
        re-reading the file from disk.
        """
        if path.stat().st_size != file.size:
            return FileCheck.BadSize
        if digest is None:
            checksum = self.compute_checksum(file, path)
        else:
            checksum = digest.hexdigest()
        if checksum == file.checksum:
            return FileCheck.Ok
        else:
            return FileCheck.BadChecksum

    def check(
        self,
        file: File,
        digest: Digest | None = None,
    ) -> FileCheck:
        """Determine `file`'s current state across drs/done/tmp locations."""
        path = self[file]
        if path.drs.is_file():
            return self.check_impl(file, path.drs, digest)
        elif path.done.is_file():
            # A fully valid `.done` file is reported as Done (ready to move).
            match self.check_impl(file, path.done, digest):
                case FileCheck.Ok:
                    return FileCheck.Done
                case check:
                    return check
        elif path.tmp.is_file():
            # A `.part` file with a size mismatch is simply incomplete.
            match self.check_impl(file, path.tmp, digest):
                case FileCheck.BadSize:
                    return FileCheck.Part
                case _:
                    # NOTE(review): a `.part` file that already has the full
                    # expected size is unexpected here (it should have been
                    # renamed to `.done`) — confirm this is meant to be fatal.
                    raise ValueError()
        else:
            return FileCheck.Missing

    def finalize(
        self,
        file: File,
        digest: Digest | None = None,
    ) -> Result[FileCheck]:
        """Validate `file` and move it into the DRS tree when complete.

        Returns ``Ok(FileCheck.Ok)`` on success, or an ``Err`` wrapping the
        failing check otherwise.
        """
        match self.check(file, digest=digest):
            case FileCheck.Ok:
                return Ok(FileCheck.Ok)
            case FileCheck.Done:
                self.move_to_drs(file)
                return Ok(FileCheck.Ok)
            case check:
                return Err(check, check.as_err(file))
219
+
220
+
221
@dataclass
class FilePath:
    """Pair of locations for a single file: final DRS path and temp path."""

    drs: Path  # final destination inside the data (DRS) tree
    tmp: Path  # in-progress download target: {sha}.part

    @property
    def done(self) -> Path:
        """Completed-download path: `tmp` with `.done` in place of `.part`."""
        completed = self.tmp.with_suffix(".done")
        return completed

    def __str__(self) -> str:
        return f"{self.drs}"
232
+
233
+
234
@dataclass
class FileObject:
    """Async write handle for a download, targeting its ``.part`` file."""

    path: FilePath
    buffer: AsyncBufferedIOBase = field(init=False)  # set on __aenter__

    async def _ensure_closed(self) -> None:
        # Close the buffer if it is still open; safe to call repeatedly.
        if not self.buffer.closed:
            await self.buffer.close()

    async def __aenter__(self) -> FileObject:
        self.buffer = await aiofiles.open(self.path.tmp, "wb")
        return self

    async def __aexit__(self, exc_type, exc_value, exc_traceback) -> None:
        await self._ensure_closed()

    async def write(self, chunk: bytes) -> None:
        """Append `chunk` to the temporary file."""
        await self.buffer.write(chunk)

    async def to_done(self) -> None:
        """Close the buffer and promote ``{sha}.part`` to ``{sha}.done``."""
        await self._ensure_closed()
        self.path.tmp.rename(self.path.done)
esgpull/graph.py ADDED
@@ -0,0 +1,460 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Iterator, Mapping, Sequence
4
+ from dataclasses import dataclass
5
+
6
+ from rich.console import Console, ConsoleOptions
7
+ from rich.pretty import pretty_repr
8
+ from rich.tree import Tree
9
+
10
+ from esgpull.config import Config
11
+ from esgpull.database import Database
12
+ from esgpull.exceptions import (
13
+ GraphWithoutDatabase,
14
+ QueryDuplicate,
15
+ TooShortKeyError,
16
+ )
17
+ from esgpull.models import Facet, Query, QueryDict, Tag, sql
18
+ from esgpull.models.utils import rich_measure_impl
19
+
20
+
21
@dataclass(init=False, repr=False)
class Graph:
    """In-memory view of the query graph, optionally backed by the database.

    Queries are indexed by their sha; `query.require` links (child -> parent)
    form the graph's edges. Names, tags and sha prefixes are all resolved to
    full shas through `_expand_name`.
    """

    # sha -> Query instances loaded so far
    queries: dict[str, Query]
    # Backing database; None for a detached instance (e.g. `subgraph`).
    _db: Database | None
    # All shas known to exist (whether loaded in `queries` or not).
    _shas: set[str]
    # name/tag -> sha index used to expand user-provided names.
    _name_sha: dict[str, str]
    # Shas already drawn during rich tree rendering (transient state).
    _rendered: set[str]
    # Shas marked for deletion, applied against the db by `merge()`.
    _deleted_shas: set[str]

    @classmethod
    def from_config(cls, config: Config) -> Graph:
        """Build a database-backed Graph from a `Config`."""
        db = Database.from_config(config)
        return Graph(db)

    def __init__(self, db: Database | None) -> None:
        self._db = db
        self.queries = {}
        self._shas = set()
        self._name_sha = {}
        self._deleted_shas = set()
        if db is not None:
            self._load_db_shas()

    @property
    def db(self) -> Database:
        """The backing database; raises when this graph is detached."""
        if self._db is None:
            raise GraphWithoutDatabase()
        else:
            return self._db

    @staticmethod
    def matching_shas(name: str, shas: set[str]) -> list[str]:
        """Return every sha in `shas` that has `name` as a prefix."""
        shas_copy = list(shas)
        # Filter column by column, dropping shas as soon as they diverge.
        for pos, c in enumerate(name):
            idx = 0
            while idx < len(shas_copy):
                if shas_copy[idx][pos] != c:
                    shas_copy.pop(idx)
                else:
                    idx += 1
        return shas_copy

    @staticmethod
    def _expand_name(
        name: str, shas: set[str], name_sha: dict[str, str]
    ) -> str:
        """Resolve `name` (tag name, full sha, or unique sha prefix) to a sha.

        Raises `TooShortKeyError` when the prefix matches several shas;
        returns `name` unchanged when nothing matches at all.
        """
        if name in name_sha:
            sha = name_sha[name]
        elif name in shas:
            sha = name
        else:
            short_name = name
            # Allow git-style `#prefix` spelling for sha prefixes.
            if short_name.startswith("#"):
                short_name = short_name[1:]
            matching_shas = Graph.matching_shas(short_name, shas)
            if len(matching_shas) > 1:
                raise TooShortKeyError(name)
            elif len(matching_shas) == 1:
                sha = matching_shas[0]
            else:
                sha = name
        return sha

    def __contains__(self, item: Query | str) -> bool:
        """True when `item` (query, name, sha or sha prefix) is in the graph."""
        match item:
            case Query():
                item.compute_sha()
                sha = item.sha
            case str():
                sha = self._expand_name(item, self._shas, self._name_sha)
            case _:
                raise TypeError(item)
        return sha in self._shas

    def get(self, name: str) -> Query:
        """Return the query for `name`, loading from the database if needed."""
        sha = self._expand_name(name, self._shas, self._name_sha)
        if sha in self.queries:
            ...
        elif sha in self._shas:
            query_db = self.db.get(Query, sha)
            if query_db is not None:
                self.queries[sha] = query_db
            else:
                # NOTE(review): bare `raise` with no active exception raises
                # RuntimeError; a dedicated error would be clearer — confirm.
                raise
        else:
            raise KeyError(name)
        return self.queries[sha]

    def get_mutable(self, name: str) -> Query:
        """Like `get`, but always reload an eager, detached instance from db.

        The returned query can then be modified without session autoflush
        side effects.
        """
        sha = self._expand_name(name, self._shas, self._name_sha)
        if sha in self._shas:
            query_db = self.db.get(
                Query,
                sha,
                lazy=False,
                detached=True,
            )
            if query_db is not None:
                self.queries[sha] = query_db
            else:
                # NOTE(review): same bare `raise` caveat as in `get`.
                raise
        else:
            raise KeyError(name)
        return self.queries[sha]

    def get_children(self, sha: str) -> Sequence[Query]:
        """Return the direct children of `sha` (queries requiring it)."""
        if self._db is None:
            # Detached graph: scan the loaded queries.
            children: list[Query] = []
            for query in self.queries.values():
                if query.require == sha:
                    children.append(query)
        elif sha is None:
            # NOTE(review): only reached when a db exists; a detached graph
            # with sha=None falls into the scan above instead — confirm
            # whether that asymmetry is intended.
            return []
        else:
            return self.db.scalars(sql.query.children(sha))
        return children

    def get_all_children(self, sha: str) -> Sequence[Query]:
        """Return every descendant of `sha`, depth-first."""
        children: list[Query] = []
        for query in self.get_children(sha):
            children.append(query)
            children.extend(self.get_all_children(query.sha))
        return children

    def get_parent(self, query: Query) -> Query | None:
        """Return the query required by `query`, or None for a root."""
        if query.require is not None:
            return self.get(query.require)
        else:
            return None

    def get_parents(self, query: Query) -> list[Query]:
        """Return the chain of ancestors, from direct parent up to the root."""
        result: list[Query] = []
        parent = self.get_parent(query)
        while parent is not None:
            result.append(parent)
            parent = self.get_parent(parent)
        return result

    def get_tags(self) -> list[Tag]:
        """Return every tag stored in the database."""
        return list(self.db.scalars(sql.tag.all()))

    def get_tag(self, name: str) -> Tag | None:
        """Return the tag called `name`, or None if unknown."""
        result: Tag | None = None
        for tag in self.get_tags():
            if tag.name == name:
                result = tag
                break
        return result

    def with_tag(self, tag_name: str) -> list[Query]:
        """Return every query carrying `tag_name` (db first, then in-memory)."""
        queries: list[Query] = []
        shas: set[str] = set()
        try:
            db_queries = self.db.scalars(sql.query.with_tag(tag_name))
            for query in db_queries:
                queries.append(query)
                shas.add(query.sha)
        except GraphWithoutDatabase:
            pass
        # Complete with in-memory queries not returned by the database.
        for sha in self._shas - shas:
            query = self.get(sha)
            if query.get_tag(tag_name) is not None:
                queries.append(query)
        return queries

    def subgraph(
        self,
        *queries: Query,
        children: bool = True,
        parents: bool = False,
        keep_db: bool = False,
    ) -> Graph:
        """Return a new Graph restricted to `queries` and their relatives.

        `children`/`parents` control which relatives are pulled in; with
        `keep_db` the result stays database-backed, otherwise it holds
        detached clones.
        """
        if not queries:
            raise ValueError("Cannot subgraph from nothing")
        if keep_db:
            graph = Graph(self.db)
            queries_shas = [q.sha for q in queries]
            graph.load_db(*queries_shas)
        else:
            graph = Graph(None)
            graph.add(*queries, force=True, clone=False)
        if children:
            for query in queries:
                query_children = self.get_all_children(query.sha)
                if len(query_children) == 0:
                    continue
                if keep_db:
                    children_shas = [q.sha for q in query_children]
                    graph.load_db(*children_shas)
                else:
                    graph.add(*query_children, force=True, clone=False)
        if parents:
            for query in queries:
                query_parents = self.get_parents(query)
                if len(query_parents) == 0:
                    continue
                if keep_db:
                    parents_shas = [q.sha for q in query_parents]
                    graph.load_db(*parents_shas)
                else:
                    graph.add(*query_parents, force=True, clone=False)
        return graph

    def _load_db_shas(self, full: bool = False) -> None:
        """Populate the sha and name indexes from the database.

        NOTE(review): the `full` parameter is currently unused — confirm
        whether it is a leftover or a planned extension.
        """
        name_sha: dict[str, str] = {}
        self._shas = set(self.db.scalars(sql.query.shas()))
        for name, sha in self.db.rows(sql.query.name_sha()):
            name_sha[name] = sha
        self._name_sha = name_sha

    def load_db(self, *shas: str) -> None:
        """Load queries for `shas` (or all not-yet-loaded shas) from db."""
        if shas:
            unloaded_shas = set(shas)
        else:
            unloaded_shas = set(self._shas) - set(self.queries.keys())
        if unloaded_shas:
            queries = self.db.scalars(sql.query.with_shas(*unloaded_shas))
            for query in queries:
                self.queries[query.sha] = query

    def validate(self, *queries: Query, noraise: bool = False) -> set[str]:
        """Return names of `queries` clashing with existing query names.

        Raises `QueryDuplicate` unless `noraise` is set.
        """
        names = set(self._name_sha.keys())
        duplicates = {q.name: q for q in queries if q.name in names}
        if duplicates and not noraise:
            raise QueryDuplicate(pretty_repr(duplicates))
        else:
            return set(duplicates.keys())

    def resolve_require(self, query: Query) -> None:
        """Replace `query.require` with the full sha of its parent, if known.

        Unknown parents are flagged via the dynamic `_unknown_require`
        attribute instead of raising.
        """
        if query.require is None or query.require in self._shas:
            ...
        elif query.require in self:  # self.has(sha=query.require):
            parent = self.get(query.require)
            query.require = parent.sha
            query.compute_sha()
        else:
            query._unknown_require = True  # type: ignore [attr-defined]

    def add(
        self,
        *queries: Query,
        force: bool = False,
        clone: bool = True,
        noraise: bool = False,
    ) -> Mapping[str, Query]:
        """
        Add new query to the graph.

        - (re)compute sha for each query
        - validate query.name against existing queries
        - populate graph._name_sha to enable `graph[query.name]` indexing
        - replace query.require with full sha

        Returns a mapping of replaced queries (sha -> previous instance)
        when `force` overwrites existing entries.
        """
        # Work on copies of all indexes so a raised error leaves the graph
        # untouched; they are committed to `self` only at the very end.
        new_shas: set[str] = set(self._shas)
        new_deleted_shas: set[str] = set(self._deleted_shas)
        new_queries: dict[str, Query] = dict(self.queries.items())
        name_shas: dict[str, list[str]] = {
            name: [sha] for name, sha in self._name_sha.items()
        }
        queue: list[Query] = [
            query.clone(compute_sha=True) if clone else query
            for query in queries
        ]
        # duplicate_names = self.validate(*queue, noraise=noraise or force)
        replaced: dict[str, Query] = {}
        for query in queue:
            if query.sha in new_shas:
                if force:
                    if query.sha in new_queries:
                        old = new_queries[query.sha]
                    else:
                        old = self.get(query.sha)
                    replaced[query.sha] = old.clone(compute_sha=False)  # True?
                else:
                    raise QueryDuplicate(pretty_repr(query))
            new_shas.add(query.sha)
            # Re-adding a query cancels a pending deletion.
            if query.sha in new_deleted_shas:
                new_deleted_shas.remove(query.sha)
            new_queries[query.sha] = query
        # Rebuild the name index, dropping tags shared by several queries
        # since those are ambiguous as lookup keys.
        skip_tags: set[str] = set()
        for sha, query in self.queries.items():
            tag_name = query.tag_name
            if tag_name is not None and tag_name not in skip_tags:
                name_shas.setdefault(tag_name, [])
                name_shas[tag_name].append(query.sha)
                if len(name_shas[tag_name]) > 1:
                    skip_tags.add(tag_name)
        new_name_sha = {
            name: shas[0]
            for name, shas in name_shas.items()
            if name not in skip_tags
        }
        if not force:
            # Ensure every `require` already resolves to a full sha.
            for sha, query in new_queries.items():
                if query.require is not None:
                    sha = self._expand_name(
                        query.require, new_shas, new_name_sha
                    )
                    if sha != query.require:
                        raise ValueError("case change require")
        self.queries = new_queries
        self._shas = new_shas
        self._deleted_shas = new_deleted_shas
        self._name_sha = new_name_sha
        return replaced

    def get_unknown_facets(self) -> set[Facet]:
        """Return facets used by loaded queries but absent from the database.

        Why was this implemented?
        Maybe useful to enable adding facets (e.g. `table_id:*day*`)
        """
        facets: dict[str, Facet] = {}
        for query in self.queries.values():
            for facet in query.selection._facets:
                if facet.sha not in facets:
                    facets[facet.sha] = facet
        shas = list(facets.keys())
        known_shas = self.db.scalars(sql.facet.known_shas(shas))
        unknown_shas = set(shas) - set(known_shas)
        unknown_facets = {facets[sha] for sha in unknown_shas}
        return unknown_facets

    def merge(self) -> Mapping[str, Query]:
        """
        Try to load instances from database into self.db.

        Start with tags, since they are not part of query.sha,
        and there could be new tags to add to an existing query.
        Those new tags need to be merged before adding them to an
        existing query instance from database (autoflush mess).

        Only load options/selection/facets if query is not in db,
        and updated options/selection/facets should change sha value.

        Returns the queries whose instance was replaced by the merge.
        """
        updated_shas: set[str] = set()
        for sha, query in self.queries.items():
            query_db = self.db.merge(query, commit=True)
            if query is query_db:
                ...
            else:
                updated_shas.add(sha)
                self.queries[sha] = query_db
        # Apply deletions recorded via `delete()`.
        for sha in self._deleted_shas:
            query_to_delete = self.db.get(Query, sha)
            if query_to_delete is not None:
                self.db.delete(query_to_delete)
        return {sha: self.queries[sha] for sha in updated_shas}

    def expand(self, name: str) -> Query:
        """
        Expand/unpack `query.requires`, using `query.name` index.
        """
        query = self.get(name)
        # `<<` combines a parent query with its child, child taking priority.
        while query.require is not None:
            query = self.get(query.require) << query
        return query

    def dump(self) -> list[QueryDict]:
        """
        Dump full graph as list of dicts (yaml selection syntax).
        """
        return [q.asdict() for q in self.queries.values()]

    def asdict(self, files: bool = False) -> Mapping[str, QueryDict]:
        """
        Dump full graph as dict of dict, indexed by each query's sha.

        With `files`, each entry also carries its query's files.
        """
        result = {}
        for sha, query in self.queries.items():
            result[sha] = query.asdict()
            if files:
                result[sha]["files"] = [f.asdict() for f in query.files]
        return result

    def fill_tree(
        self,
        root: Query | None,
        tree: Tree,
        keep_require: bool = False,
    ) -> None:
        """
        Recursive method to add branches starting from queries with either:
        - require is None
        - require is not in self.queries
        """
        for sha, query in self.queries.items():
            query_tree: Tree | None = None
            if sha in self._rendered:
                ...
            elif root is None:
                if (
                    query.require is None or query.require not in self
                ):  # self.has(sha=query.require):
                    self._rendered.add(sha)
                    query_tree = query._rich_tree()
            elif query.require == root.sha:
                self._rendered.add(sha)
                if keep_require:
                    query_tree = query._rich_tree()
                else:
                    # Hide the require line: the tree edge already shows it.
                    query_tree = query.no_require()._rich_tree()
            if query_tree is not None:
                tree.add(query_tree)
                self.fill_tree(query, query_tree)

    # Shared rich measurement implementation (see esgpull.models.utils).
    __rich_measure__ = rich_measure_impl

    def __rich_console__(
        self,
        console: Console,
        options: ConsoleOptions,
    ) -> Iterator[Tree]:
        """
        Returns a `rich.tree.Tree` representing queries and their `require`.
        """
        tree = Tree("", hide_root=True, guide_style="dim")
        self._rendered = set()
        self.fill_tree(None, tree)
        # Render orphans whose parent was not drawn in the first pass.
        unrendered = set(self.queries.keys()) - self._rendered
        for sha in unrendered:
            query = self.get(sha)
            if query.require is None:
                continue
            parent = self.get(query.require)
            self.fill_tree(parent, tree, keep_require=True)
        del self._rendered
        yield tree

    def delete(self, query: Query) -> None:
        """Remove `query` from the in-memory graph and mark it for db deletion.

        The database row is actually removed on the next `merge()`.
        """
        self._shas.remove(query.sha)
        self.queries.pop(query.sha, None)
        self._deleted_shas.add(query.sha)

    def replace(self, original: Query, new: Query) -> None:
        """Swap `original` for `new` in the graph (delete then add)."""
        # if original not in self.db:
        #     raise ValueError(f"{original.name} not found in the database.")
        # elif new in self.db:
        #     raise ValueError(f"{new.name} already in the database.")
        self.delete(original)
        self.add(new)