esgpull 0.8.0__py3-none-any.whl → 0.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- esgpull/cli/__init__.py +2 -2
- esgpull/cli/add.py +7 -1
- esgpull/cli/config.py +5 -21
- esgpull/cli/plugins.py +398 -0
- esgpull/cli/remove.py +9 -3
- esgpull/cli/self.py +1 -1
- esgpull/cli/update.py +78 -35
- esgpull/cli/utils.py +16 -1
- esgpull/config.py +83 -26
- esgpull/constants.py +3 -0
- esgpull/context.py +9 -9
- esgpull/database.py +21 -7
- esgpull/download.py +3 -0
- esgpull/esgpull.py +49 -5
- esgpull/fs.py +9 -20
- esgpull/graph.py +1 -1
- esgpull/migrations/versions/0.9.0_update_tables.py +28 -0
- esgpull/migrations/versions/0.9.1_update_tables.py +28 -0
- esgpull/migrations/versions/d14f179e553c_file_add_composite_index_dataset_id_.py +32 -0
- esgpull/migrations/versions/e7edab5d4e4b_add_dataset_tracking.py +39 -0
- esgpull/models/__init__.py +2 -1
- esgpull/models/base.py +31 -14
- esgpull/models/dataset.py +48 -5
- esgpull/models/query.py +58 -14
- esgpull/models/sql.py +48 -9
- esgpull/plugin.py +574 -0
- esgpull/processor.py +3 -3
- esgpull/tui.py +23 -1
- esgpull/utils.py +5 -1
- {esgpull-0.8.0.dist-info → esgpull-0.9.1.dist-info}/METADATA +19 -3
- {esgpull-0.8.0.dist-info → esgpull-0.9.1.dist-info}/RECORD +34 -29
- esgpull/cli/datasets.py +0 -78
- {esgpull-0.8.0.dist-info → esgpull-0.9.1.dist-info}/WHEEL +0 -0
- {esgpull-0.8.0.dist-info → esgpull-0.9.1.dist-info}/entry_points.txt +0 -0
- {esgpull-0.8.0.dist-info → esgpull-0.9.1.dist-info}/licenses/LICENSE +0 -0
esgpull/cli/update.py
CHANGED
@@ -10,8 +10,8 @@ from esgpull.cli.decorators import args, opts
 from esgpull.cli.utils import get_queries, init_esgpull, valid_name_tag
 from esgpull.context import HintsDict, ResultSearch
 from esgpull.exceptions import UnsetOptionsError
-from esgpull.models import File, FileStatus, Query
-from esgpull.tui import Verbosity
+from esgpull.models import Dataset, File, FileStatus, Query
+from esgpull.tui import Verbosity
 from esgpull.utils import format_size


@@ -20,10 +20,13 @@ class QueryFiles:
     query: Query
     expanded: Query
     skip: bool = False
+    datasets: list[Dataset] = field(default_factory=list)
     files: list[File] = field(default_factory=list)
+    dataset_hits: int = field(init=False)
     hits: int = field(init=False)
     hints: HintsDict = field(init=False)
     results: list[ResultSearch] = field(init=False)
+    dataset_results: list[ResultSearch] = field(init=False)


 @click.command()
@@ -80,20 +83,27 @@ def update(
         file=True,
         facets=["index_node"],
     )
-    for qf, qf_hints in zip(qfs, hints):
+    dataset_hits = esg.context.hits(
+        *[qf.expanded for qf in qfs],
+        file=False,
+    )
+    for qf, qf_hints, qf_dataset_hits in zip(qfs, hints, dataset_hits):
         qf.hits = sum(esg.context.hits_from_hints(qf_hints))
         if qf_hints:
             qf.hints = qf_hints
+            qf.dataset_hits = qf_dataset_hits
         else:
             qf.skip = True
     for qf in qfs:
         s = "s" if qf.hits > 1 else ""
-        esg.ui.print(
+        esg.ui.print(
+            f"{qf.query.rich_name} -> {qf.hits} file{s} (before replica de-duplication)."
+        )
     total_hits = sum([qf.hits for qf in qfs])
     if total_hits == 0:
         esg.ui.print("No files found.")
         esg.ui.raise_maybe_record(Exit(0))
-
+    elif len(qfs) > 1:
         esg.ui.print(f"{total_hits} files found.")
     qfs = [qf for qf in qfs if not qf.skip]
     # Prepare optimally distributed requests to ESGF
@@ -101,6 +111,12 @@ def update(
     # It might be interesting for the special case where all files already
     # exist in db, then the detailed fetch could be skipped.
     for qf in qfs:
+        qf_dataset_results = esg.context.prepare_search(
+            qf.expanded,
+            file=False,
+            hits=[qf.dataset_hits],
+            max_hits=None,
+        )
         if esg.config.api.use_custom_distribution_algorithm:
             qf_results = esg.context.prepare_search_distributed(
                 qf.expanded,
@@ -115,7 +131,7 @@ def update(
                 hits=[qf.hits],
                 max_hits=None,
             )
-        nb_req = len(qf_results)
+        nb_req = len(qf_dataset_results) + len(qf_results)
         if nb_req > 50:
             msg = (
                 f"{nb_req} requests will be sent to ESGF to"
@@ -126,13 +142,32 @@ def update(
                     esg.ui.print(f"{qf.query.rich_name} is now untracked.")
                     qf.query.tracked = False
                     qf_results = []
+                    qf_dataset_results = []
                 case "n":
                     qf_results = []
+                    qf_dataset_results = []
                 case _:
                     ...
         qf.results = qf_results
+        qf.dataset_results = qf_dataset_results
     # Fetch files and update db
     # [?] TODO: dry_run to print urls here
+    with esg.ui.spinner("Fetching datasets"):
+        coros = []
+        for qf in qfs:
+            coro = esg.context._datasets(
+                *qf.dataset_results, keep_duplicates=False
+            )
+            coros.append(coro)
+        datasets = esg.context.sync_gather(*coros)
+        for qf, qf_datasets in zip(qfs, datasets):
+            qf.datasets = [
+                Dataset(
+                    dataset_id=record.dataset_id,
+                    total_files=record.number_of_files,
+                )
+                for record in qf_datasets
+            ]
     with esg.ui.spinner("Fetching files"):
         coros = []
         for qf in qfs:
@@ -142,23 +177,37 @@ def update(
     for qf, qf_files in zip(qfs, files):
         qf.files = qf_files
     for qf in qfs:
-        shas = {f.sha for f in qf.query.files}
-        new_files: list[File] = []
-        for file in qf.files:
-            if file.sha not in shas:
-                new_files.append(file)
-        nb_files = len(new_files)
         if not qf.query.tracked:
             esg.db.add(qf.query)
             continue
-
-
-
-
+        with esg.db.commit_context():
+            unregistered_datasets = [
+                f for f in qf.datasets if f not in esg.db
+            ]
+            if len(unregistered_datasets) > 0:
+                esg.ui.print(
+                    f"Adding {len(unregistered_datasets)} new datasets to database."
+                )
+                esg.db.session.add_all(unregistered_datasets)
+            files_from_db = [
+                esg.db.get(File, f.sha) for f in qf.files if f in esg.db
+            ]
+            registered_files = [f for f in files_from_db if f is not None]
+            unregistered_files = [f for f in qf.files if f not in esg.db]
+            if len(unregistered_files) > 0:
+                esg.ui.print(
+                    f"Adding {len(unregistered_files)} new files to database."
+                )
+                esg.db.session.add_all(unregistered_files)
+            files = registered_files + unregistered_files
+            not_done_files = [
+                file for file in files if file.status != FileStatus.Done
+            ]
+            download_size = sum(file.size for file in not_done_files)
             msg = (
-                f"\
-                f"
-                "\
+                f"\n{qf.query.rich_name}: {len(not_done_files)} "
+                f" files ({format_size(download_size)}) to download."
+                f"\nLink to query and send to download queue?"
             )
             if yes:
                 choice = "y"
@@ -174,25 +223,19 @@ def update(
             if choice == "y":
                 legacy = esg.legacy_query
                 has_legacy = legacy.state.persistent
+                applied_changes = False
                 with esg.db.commit_context():
                     for file in esg.ui.track(
-
-                        description=qf.query.rich_name,
+                        files,
+                        description=f"{qf.query.rich_name}",
                     ):
-
-                        if file_db is None:
-                            if esg.db.has_file_id(file):
-                                logger.error(
-                                    "File id already exists in database, "
-                                    "there might be an error with its checksum"
-                                    f"\n{file}"
-                                )
-                            continue
+                        if file.status != FileStatus.Done:
                             file.status = FileStatus.Queued
-
-
-
-
-
+                        if has_legacy and legacy in file.queries:
+                            _ = esg.db.unlink(query=legacy, file=file)
+                        changed = esg.db.link(query=qf.query, file=file)
+                        applied_changes = applied_changes or changed
+                    if applied_changes:
+                        qf.query.updated_at = datetime.now(timezone.utc)
                     esg.db.session.add(qf.query)
     esg.ui.raise_maybe_record(Exit(0))
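Note on the queueing loop above: db.link() and db.unlink() now report whether a row was actually inserted or removed (see esgpull/database.py below), so updated_at is only refreshed when a link really changed. Stripped of diff markers, the new loop reads:

    applied_changes = False
    with esg.db.commit_context():
        for file in esg.ui.track(files, description=f"{qf.query.rich_name}"):
            if file.status != FileStatus.Done:
                file.status = FileStatus.Queued
            if has_legacy and legacy in file.queries:
                _ = esg.db.unlink(query=legacy, file=file)
            changed = esg.db.link(query=qf.query, file=file)
            applied_changes = applied_changes or changed
        if applied_changes:
            qf.query.updated_at = datetime.now(timezone.utc)
        esg.db.session.add(qf.query)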
esgpull/cli/utils.py
CHANGED
@@ -103,7 +103,7 @@ def totable(docs: list[OrderedDict[str, Any]]) -> Table:
     table = Table(box=MINIMAL_DOUBLE_HEAD, show_edge=False)
     for key in docs[0].keys():
         justify: Literal["left", "right", "center"]
-        if key in ["file", "dataset"]:
+        if key in ["file", "dataset", "plugin"]:
             justify = "left"
         else:
             justify = "right"
@@ -243,3 +243,18 @@ def get_queries(
         kids = graph.get_all_children(query.sha)
         queries.extend(kids)
     return queries
+
+
+def extract_subdict(doc: dict, key: str | None) -> dict:
+    if key is None:
+        return doc
+    for part in key.split("."):
+        if not part:
+            raise KeyError(key)
+        elif part in doc:
+            doc = doc[part]
+        else:
+            raise KeyError(part)
+    for part in key.split(".")[::-1]:
+        doc = {part: doc}
+    return doc
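The new extract_subdict walks a dotted key down a nested dict and returns only that branch, re-nested under the full path; a quick illustration (the values are made up):

    doc = {"query": {"tracked": True, "files": 12}}
    extract_subdict(doc, None)            # whole doc, unchanged
    extract_subdict(doc, "query.files")   # {"query": {"files": 12}}
    extract_subdict(doc, "query.nope")    # raises KeyError('nope')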
esgpull/config.py
CHANGED
@@ -4,10 +4,10 @@ import logging
 from collections.abc import Iterator, Mapping
 from enum import Enum, auto
 from pathlib import Path
-from typing import Any, cast
+from typing import Any, TypeVar, Union, cast, overload

 import tomlkit
-from attrs import Factory, define, field, fields
+from attrs import Factory, define, field
 from attrs import has as attrs_has
 from cattrs import Converter
 from cattrs.gen import make_dict_unstructure_fn, override
@@ -20,6 +20,66 @@ from esgpull.models.options import Options

 logger = logging.getLogger("esgpull")

+T = TypeVar("T")
+
+
+@overload
+def cast_value(
+    target: str, value: Union[str, int, bool, float], key: str
+) -> str: ...
+
+
+@overload
+def cast_value(
+    target: bool, value: Union[str, int, bool, float], key: str
+) -> bool: ...
+
+
+@overload
+def cast_value(
+    target: int, value: Union[str, int, bool, float], key: str
+) -> int: ...
+
+
+@overload
+def cast_value(
+    target: float, value: Union[str, int, bool, float], key: str
+) -> float: ...
+
+
+def cast_value(
+    target: Any, value: Union[str, int, bool, float], key: str
+) -> Any:
+    if isinstance(value, type(target)):
+        return value
+    elif attrs_has(type(target)):
+        raise KeyError(key)
+    elif isinstance(target, str):
+        return str(value)
+    elif isinstance(target, float):
+        try:
+            return float(value)
+        except Exception:
+            raise ValueError(value)
+    elif isinstance(target, bool):
+        if isinstance(value, str):
+            if value.lower() in ["on", "true"]:
+                return True
+            elif value.lower() in ["off", "false"]:
+                return False
+            else:
+                raise ValueError(value)
+        else:
+            raise TypeError(value)
+    elif isinstance(target, int):
+        # int must be after bool, because isinstance(True, int) == True
+        try:
+            return int(value)
+        except Exception:
+            raise ValueError(value)
+    else:
+        raise TypeError(value)
+

 @define
 class Paths:
@@ -28,6 +88,7 @@ class Paths:
     db: Path = field(converter=Path)
     log: Path = field(converter=Path)
     tmp: Path = field(converter=Path)
+    plugins: Path = field(converter=Path)

     @auth.default
     def _auth_factory(self) -> Path:
@@ -69,12 +130,21 @@ class Paths:
             root = InstallConfig.default
         return root / "tmp"

-    def values(self) -> Iterator[Path]:
+    @plugins.default
+    def _plugins_factory(self) -> Path:
+        if InstallConfig.current is not None:
+            root = InstallConfig.current.path
+        else:
+            root = InstallConfig.default
+        return root / "plugins"
+
+    def values(self) -> Iterator[Path]:
         yield self.auth
         yield self.data
         yield self.db
         yield self.log
         yield self.tmp
+        yield self.plugins


 @define
@@ -213,6 +283,13 @@ def iter_keys(
         yield local_path


+@define
+class Plugins:
+    """Configuration for the plugin system"""
+
+    enabled: bool = False
+
+
 @define
 class Config:
     paths: Paths = Factory(Paths)
@@ -221,6 +298,7 @@ class Config:
     db: Db = Factory(Db)
     download: Download = Factory(Download)
     api: API = Factory(API)
+    plugins: Plugins = Factory(Plugins)
     _raw: TOMLDocument | None = field(init=False, default=None)
     _config_file: Path | None = field(init=False, default=None)

@@ -287,7 +365,7 @@ class Config:
     def update_item(
         self,
         key: str,
-        value: int |
+        value: str | int | bool,
         empty_ok: bool = False,
     ) -> int | str | None:
         if self._raw is None and empty_ok:
@@ -302,29 +380,8 @@ class Config:
             doc.setdefault(part, {})
             doc = doc[part]
             obj = getattr(obj, part)
-        value_type = getattr(fields(type(obj)), last).type
         old_value = getattr(obj, last)
-
-            raise KeyError(key)
-        elif value_type is str:
-            ...
-        elif value_type is int:
-            try:
-                value = value_type(value)
-            except Exception:
-                ...
-        elif value_type is bool:
-            if isinstance(value, bool):
-                ...
-            elif isinstance(value, str):
-                if value.lower() in ["on", "true"]:
-                    value = True
-                elif value.lower() in ["off", "false"]:
-                    value = False
-                else:
-                    raise ValueError(value)
-            else:
-                raise TypeError(value)
+        value = cast_value(old_value, value, key)
         setattr(obj, last, value)
         doc[last] = value
         return old_value
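cast_value replaces the inline type-switch that update_item previously carried: it coerces a raw CLI value to the type of the current config value, or raises. A few representative calls (the config keys here are illustrative):

    cast_value(target=4, value="8", key="download.max")           # -> 8
    cast_value(target=False, value="on", key="plugins.enabled")   # -> True
    cast_value(target=False, value="no?", key="plugins.enabled")  # ValueError
    cast_value(target="node", value=42, key="api.index_node")     # -> "42"

Because bool is a subclass of int in Python, the bool branch must run before the int branch, as the comment in the hunk notes.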
esgpull/constants.py
CHANGED
esgpull/context.py
CHANGED
@@ -16,7 +16,7 @@ from rich.pretty import pretty_repr

 from esgpull.config import Config
 from esgpull.exceptions import SolrUnstableQueryError
-from esgpull.models import
+from esgpull.models import DatasetRecord, File, Query
 from esgpull.tui import logger
 from esgpull.utils import format_date_iso, index2url, sync

@@ -151,7 +151,7 @@ class ResultHints(Result):

 @dataclass
 class ResultSearch(Result):
-    data: Sequence[File |
+    data: Sequence[File | DatasetRecord] = field(init=False, repr=False)

     def process(self) -> None:
         raise NotImplementedError
@@ -159,14 +159,14 @@

 @dataclass
 class ResultDatasets(Result):
-    data: Sequence[
+    data: Sequence[DatasetRecord] = field(init=False, repr=False)

     def process(self) -> None:
         self.data = []
         if self.success:
             for doc in self.json["response"]["docs"]:
                 try:
-                    dataset =
+                    dataset = DatasetRecord.serialize(doc)
                     self.data.append(dataset)
                 except KeyError as exc:
                     logger.exception(exc)
@@ -492,8 +492,8 @@ class Context:
         self,
         *results: ResultSearch,
         keep_duplicates: bool,
-    ) -> list[
-        datasets: list[
+    ) -> list[DatasetRecord]:
+        datasets: list[DatasetRecord] = []
         ids: set[str] = set()
         async for result in self._fetch(*results):
             dataset_result = result.to(ResultDatasets)
@@ -501,7 +501,7 @@
             if dataset_result.processed:
                 for d in dataset_result.data:
                     if not keep_duplicates and d.dataset_id in ids:
-                        logger.
+                        logger.debug(f"Duplicate dataset {d.dataset_id}")
                     else:
                         datasets.append(d)
                         ids.add(d.dataset_id)
@@ -520,7 +520,7 @@
             if files_result.processed:
                 for file in files_result.data:
                     if not keep_duplicates and file.sha in shas:
-                        logger.
+                        logger.debug(f"Duplicate file {file.file_id}")
                     else:
                         files.append(file)
                         shas.add(file.sha)
@@ -627,7 +627,7 @@
         date_from: datetime | None = None,
         date_to: datetime | None = None,
         keep_duplicates: bool = True,
-    ) -> list[
+    ) -> list[DatasetRecord]:
         if hits is None:
             hits = self.hits(*queries, file=False)
         results = self.prepare_search(
esgpull/database.py
CHANGED
@@ -12,11 +12,13 @@ import sqlalchemy.orm
 from alembic.config import Config as AlembicConfig
 from alembic.migration import MigrationContext
 from alembic.script import ScriptDirectory
+from sqlalchemy.inspection import inspect
 from sqlalchemy.orm import Session, joinedload, make_transient

 from esgpull import __file__
 from esgpull.config import Config
 from esgpull.models import File, Query, Table, sql
+from esgpull.models.base import Base, BaseNoSHA
 from esgpull.version import __version__

 # from esgpull.exceptions import NoClauseError
@@ -145,14 +147,26 @@ class Database:
         for item in items:
             make_transient(item)

-    def link(self, query: Query, file: File):
-        self.session.
-
-
-
+    def link(self, query: Query, file: File) -> bool:
+        if not self.session.scalar(sql.query_file.is_linked(query, file)):
+            self.session.execute(sql.query_file.link(query, file))
+            return True
+        else:
+            return False

-    def
-
+    def unlink(self, query: Query, file: File) -> bool:
+        if self.session.scalar(sql.query_file.is_linked(query, file)):
+            self.session.execute(sql.query_file.unlink(query, file))
+            return True
+        else:
+            return False
+
+    def __contains__(self, item: Base | BaseNoSHA) -> bool:
+        mapper = inspect(item.__class__)
+        pk_col = mapper.primary_key[0]
+        pk_value = getattr(item, pk_col.name)
+        stmt = sa.exists().where(pk_col == pk_value)
+        return self.scalars(sa.select(stmt))[0]

     def has_file_id(self, file: File) -> bool:
         return len(self.scalars(sql.file.with_file_id(file.file_id))) == 1
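With __contains__ in place, a primary-key existence check becomes a plain membership test, which is exactly how esgpull update decides what to insert (sketch, assuming an Esgpull instance named esg):

    if file not in esg.db:   # no row with this primary key yet
        esg.db.session.add(file)
    known = [d for d in qf.datasets if d in esg.db]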
esgpull/download.py
CHANGED
@@ -1,6 +1,7 @@
 # from math import ceil
 from collections.abc import AsyncGenerator
 from dataclasses import dataclass
+from datetime import datetime

 from httpx import AsyncClient

@@ -19,6 +20,7 @@ class DownloadCtx:
     completed: int = 0
     chunk: bytes | None = None
     digest: Digest | None = None
+    start_time: datetime | None = None

     @property
     def finished(self) -> bool:
@@ -54,6 +56,7 @@ class Simple(BaseDownloader):
         ctx: DownloadCtx,
         chunk_size: int,
     ) -> AsyncGenerator[DownloadCtx, None]:
+        ctx.start_time = datetime.now()
         async with client.stream("GET", ctx.file.url) as resp:
             resp.raise_for_status()
             async for chunk in resp.aiter_bytes(chunk_size=chunk_size):
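start_time is stamped once, just before the HTTP stream opens, and later travels to plugins via the file_complete event (see esgpull/esgpull.py below), so a consumer can derive wall-clock duration or an average rate, e.g. (illustrative):

    elapsed = (end_time - start_time).total_seconds()
    avg_bytes_per_s = file.size / elapsed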
esgpull/esgpull.py
CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
 import logging
 from collections.abc import AsyncIterator
 from dataclasses import dataclass
+from datetime import datetime
 from functools import cached_property, partial
 from pathlib import Path
 from warnings import warn
@@ -25,6 +26,7 @@ from esgpull.auth import Auth, Credentials
 from esgpull.config import Config
 from esgpull.context import Context
 from esgpull.database import Database
+from esgpull.download import DownloadCtx
 from esgpull.exceptions import (
     DownloadCancelled,
     InvalidInstallPath,
@@ -44,6 +46,13 @@ from esgpull.models import (
     sql,
 )
 from esgpull.models.utils import short_sha
+from esgpull.plugin import (
+    Event,
+    PluginManager,
+    emit,
+    get_plugin_manager,
+    set_plugin_manager,
+)
 from esgpull.processor import Processor
 from esgpull.result import Err, Ok, Result
 from esgpull.tui import UI, DummyLive, Verbosity, logger
@@ -117,6 +126,18 @@ class Esgpull:
         if load_db:
             self.db = Database.from_config(self.config)
             self.graph = Graph(self.db)
+        # Initialize plugin system
+        plugin_config_path = self.config.paths.plugins / "plugins.toml"
+        try:
+            self.plugin_manager = get_plugin_manager()
+            self.plugin_manager.__init__(config_path=plugin_config_path)
+        except ValueError:
+            self.plugin_manager = PluginManager(config_path=plugin_config_path)
+            set_plugin_manager(self.plugin_manager)
+        if self.config.plugins.enabled:
+            self.plugin_manager.enabled = True
+            self.config.paths.plugins.mkdir(exist_ok=True, parents=True)
+            self.plugin_manager.discover_plugins(self.config.paths.plugins)

     def fetch_index_nodes(self) -> list[str]:
         """
@@ -309,7 +330,7 @@ class Esgpull:
         progress: Progress,
         task_ids: dict[str, TaskID],
         live: Live | DummyLive,
-    ) -> AsyncIterator[Result]:
+    ) -> AsyncIterator[Result[DownloadCtx]]:
         async for result in processor.process():
             task_idx = progress.task_ids.index(task_ids[result.data.file.sha])
             task = progress.tasks[task_idx]
@@ -348,7 +369,7 @@
                     yield result
                 case Err(_, err):
                     progress.remove_task(task.id)
-                    yield Err(result.data, err)
+                    yield Err(result.data, err=err)
                 case Err():
                     progress.remove_task(task.id)
                     yield result
@@ -438,15 +459,38 @@
                 match result:
                     case Ok():
                         main_progress.update(main_task_id, advance=1)
-                        result.data.file
-
-                    case Err(_, err):
+                        file = result.data.file
+                        file.status = FileStatus.Done
+                        files.append(file)
+                        emit(
+                            Event.file_complete,
+                            file=file,
+                            destination=self.fs[file].drs,
+                            start_time=result.data.start_time,
+                            end_time=datetime.now(),
+                        )
+                        if file.dataset is not None:
+                            is_dataset_complete = self.db.scalars(
+                                sql.dataset.is_complete(file.dataset)
+                            )[0]
+                            if is_dataset_complete:
+                                emit(
+                                    Event.dataset_complete,
+                                    dataset=file.dataset,
+                                )
+                    case Err(_, err):
                         queue_size -= 1
                         main_progress.update(
                             main_task_id, total=queue_size
                         )
                         result.data.file.status = FileStatus.Error
                         errors.append(result)
+                        emit(
+                            Event.file_error,
+                            file=result.data.file,
+                            exception=err,
+                        )
+
                 if use_db:
                     self.db.add(result.data.file)
                     remaining_dict.pop(result.data.file.sha, None)