esgpull 0.7.3__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- esgpull/cli/__init__.py +2 -2
- esgpull/cli/add.py +7 -1
- esgpull/cli/config.py +5 -21
- esgpull/cli/plugins.py +398 -0
- esgpull/cli/show.py +29 -0
- esgpull/cli/status.py +6 -4
- esgpull/cli/update.py +72 -18
- esgpull/cli/utils.py +16 -1
- esgpull/config.py +83 -25
- esgpull/constants.py +3 -0
- esgpull/context.py +15 -15
- esgpull/database.py +8 -2
- esgpull/download.py +3 -0
- esgpull/esgpull.py +49 -5
- esgpull/graph.py +1 -1
- esgpull/migrations/versions/0.8.0_update_tables.py +28 -0
- esgpull/migrations/versions/0.9.0_update_tables.py +28 -0
- esgpull/migrations/versions/14c72daea083_query_add_column_updated_at.py +36 -0
- esgpull/migrations/versions/c7c8541fa741_query_add_column_added_at.py +37 -0
- esgpull/migrations/versions/d14f179e553c_file_add_composite_index_dataset_id_.py +32 -0
- esgpull/migrations/versions/e7edab5d4e4b_add_dataset_tracking.py +39 -0
- esgpull/models/__init__.py +2 -1
- esgpull/models/base.py +31 -14
- esgpull/models/dataset.py +48 -5
- esgpull/models/options.py +1 -1
- esgpull/models/query.py +98 -15
- esgpull/models/sql.py +40 -9
- esgpull/plugin.py +574 -0
- esgpull/processor.py +3 -3
- esgpull/tui.py +23 -1
- esgpull/utils.py +19 -3
- {esgpull-0.7.3.dist-info → esgpull-0.9.0.dist-info}/METADATA +11 -2
- {esgpull-0.7.3.dist-info → esgpull-0.9.0.dist-info}/RECORD +36 -29
- {esgpull-0.7.3.dist-info → esgpull-0.9.0.dist-info}/WHEEL +1 -1
- esgpull/cli/datasets.py +0 -78
- {esgpull-0.7.3.dist-info → esgpull-0.9.0.dist-info}/entry_points.txt +0 -0
- {esgpull-0.7.3.dist-info → esgpull-0.9.0.dist-info}/licenses/LICENSE +0 -0
esgpull/cli/update.py
CHANGED
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from dataclasses import dataclass, field
+from datetime import datetime, timezone
 
 import click
 from click.exceptions import Abort, Exit
@@ -9,7 +10,7 @@ from esgpull.cli.decorators import args, opts
 from esgpull.cli.utils import get_queries, init_esgpull, valid_name_tag
 from esgpull.context import HintsDict, ResultSearch
 from esgpull.exceptions import UnsetOptionsError
-from esgpull.models import File, FileStatus, Query
+from esgpull.models import Dataset, File, FileStatus, Query
 from esgpull.tui import Verbosity, logger
 from esgpull.utils import format_size
 
@@ -19,10 +20,13 @@ class QueryFiles:
     query: Query
     expanded: Query
     skip: bool = False
+    datasets: list[Dataset] = field(default_factory=list)
     files: list[File] = field(default_factory=list)
+    dataset_hits: int = field(init=False)
     hits: int = field(init=False)
     hints: HintsDict = field(init=False)
     results: list[ResultSearch] = field(init=False)
+    dataset_results: list[ResultSearch] = field(init=False)
 
 
 @click.command()
@@ -79,20 +83,27 @@ def update(
         file=True,
         facets=["index_node"],
     )
-
+    dataset_hits = esg.context.hits(
+        *[qf.expanded for qf in qfs],
+        file=False,
+    )
+    for qf, qf_hints, qf_dataset_hits in zip(qfs, hints, dataset_hits):
         qf.hits = sum(esg.context.hits_from_hints(qf_hints))
         if qf_hints:
             qf.hints = qf_hints
+            qf.dataset_hits = qf_dataset_hits
         else:
             qf.skip = True
     for qf in qfs:
         s = "s" if qf.hits > 1 else ""
-        esg.ui.print(
+        esg.ui.print(
+            f"{qf.query.rich_name} -> {qf.hits} file{s} (before replica de-duplication)."
+        )
     total_hits = sum([qf.hits for qf in qfs])
     if total_hits == 0:
         esg.ui.print("No files found.")
         esg.ui.raise_maybe_record(Exit(0))
-
+    elif len(qfs) > 1:
         esg.ui.print(f"{total_hits} files found.")
     qfs = [qf for qf in qfs if not qf.skip]
     # Prepare optimally distributed requests to ESGF
@@ -100,13 +111,27 @@ def update(
     # It might be interesting for the special case where all files already
     # exist in db, then the detailed fetch could be skipped.
     for qf in qfs:
-
+        qf_dataset_results = esg.context.prepare_search(
             qf.expanded,
-            file=
-
+            file=False,
+            hits=[qf.dataset_hits],
             max_hits=None,
         )
-
+        if esg.config.api.use_custom_distribution_algorithm:
+            qf_results = esg.context.prepare_search_distributed(
+                qf.expanded,
+                file=True,
+                hints=[qf.hints],
+                max_hits=None,
+            )
+        else:
+            qf_results = esg.context.prepare_search(
+                qf.expanded,
+                file=True,
+                hits=[qf.hits],
+                max_hits=None,
+            )
+        nb_req = len(qf_dataset_results) + len(qf_results)
         if nb_req > 50:
             msg = (
                 f"{nb_req} requests will be sent to ESGF to"
@@ -117,13 +142,32 @@ def update(
                     esg.ui.print(f"{qf.query.rich_name} is now untracked.")
                    qf.query.tracked = False
                    qf_results = []
+                    qf_dataset_results = []
                 case "n":
                     qf_results = []
+                    qf_dataset_results = []
                 case _:
                     ...
         qf.results = qf_results
+        qf.dataset_results = qf_dataset_results
     # Fetch files and update db
     # [?] TODO: dry_run to print urls here
+    with esg.ui.spinner("Fetching datasets"):
+        coros = []
+        for qf in qfs:
+            coro = esg.context._datasets(
+                *qf.dataset_results, keep_duplicates=False
+            )
+            coros.append(coro)
+        datasets = esg.context.sync_gather(*coros)
+        for qf, qf_datasets in zip(qfs, datasets):
+            qf.datasets = [
+                Dataset(
+                    dataset_id=record.dataset_id,
+                    total_files=record.number_of_files,
+                )
+                for record in qf_datasets
+            ]
     with esg.ui.spinner("Fetching files"):
         coros = []
         for qf in qfs:
@@ -133,23 +177,26 @@ def update(
         for qf, qf_files in zip(qfs, files):
             qf.files = qf_files
     for qf in qfs:
-
-
-
-            if file.sha not in shas:
-                new_files.append(file)
+        new_files = [f for f in qf.files if f not in esg.db]
+        new_datasets = [d for d in qf.datasets if d not in esg.db]
+        nb_datasets = len(new_datasets)
         nb_files = len(new_files)
         if not qf.query.tracked:
             esg.db.add(qf.query)
             continue
-        elif nb_files == 0:
+        elif nb_datasets == nb_files == 0:
             esg.ui.print(f"{qf.query.rich_name} is already up-to-date.")
             continue
         size = sum([file.size for file in new_files])
+        if size > 0:
+            queue_msg = " and send new files to download queue"
+        else:
+            queue_msg = ""
         msg = (
-            f"\
-            f"
-            "
+            f"\n{qf.query.rich_name}: {nb_files} new"
+            f" files, {nb_datasets} new datasets"
+            f" ({format_size(size)})."
+            f"\nUpdate the database{queue_msg}?"
         )
         if yes:
             choice = "y"
@@ -166,9 +213,14 @@ def update(
         legacy = esg.legacy_query
         has_legacy = legacy.state.persistent
         with esg.db.commit_context():
+            for dataset in esg.ui.track(
+                new_datasets,
+                description=f"{qf.query.rich_name} (datasets)",
+            ):
+                esg.db.session.add(dataset)
             for file in esg.ui.track(
                 new_files,
-                description=qf.query.rich_name,
+                description=f"{qf.query.rich_name} (files)",
             ):
                 file_db = esg.db.get(File, file.sha)
                 if file_db is None:
@@ -184,4 +236,6 @@ def update(
                 elif has_legacy and legacy in file_db.queries:
                     esg.db.unlink(query=legacy, file=file_db)
                 esg.db.link(query=qf.query, file=file)
+            qf.query.updated_at = datetime.now(timezone.utc)
+            esg.db.session.add(qf.query)
     esg.ui.raise_maybe_record(Exit(0))
esgpull/cli/utils.py
CHANGED
@@ -103,7 +103,7 @@ def totable(docs: list[OrderedDict[str, Any]]) -> Table:
     table = Table(box=MINIMAL_DOUBLE_HEAD, show_edge=False)
     for key in docs[0].keys():
         justify: Literal["left", "right", "center"]
-        if key in ["file", "dataset"]:
+        if key in ["file", "dataset", "plugin"]:
             justify = "left"
         else:
             justify = "right"
@@ -243,3 +243,18 @@ def get_queries(
         kids = graph.get_all_children(query.sha)
         queries.extend(kids)
     return queries
+
+
+def extract_subdict(doc: dict, key: str | None) -> dict:
+    if key is None:
+        return doc
+    for part in key.split("."):
+        if not part:
+            raise KeyError(key)
+        elif part in doc:
+            doc = doc[part]
+        else:
+            raise KeyError(part)
+    for part in key.split(".")[::-1]:
+        doc = {part: doc}
+    return doc
esgpull/config.py
CHANGED
@@ -4,10 +4,10 @@ import logging
 from collections.abc import Iterator, Mapping
 from enum import Enum, auto
 from pathlib import Path
-from typing import Any, cast
+from typing import Any, TypeVar, Union, cast, overload
 
 import tomlkit
-from attrs import Factory, define, field
+from attrs import Factory, define, field
 from attrs import has as attrs_has
 from cattrs import Converter
 from cattrs.gen import make_dict_unstructure_fn, override
@@ -20,6 +20,66 @@ from esgpull.models.options import Options
 
 logger = logging.getLogger("esgpull")
 
+T = TypeVar("T")
+
+
+@overload
+def cast_value(
+    target: str, value: Union[str, int, bool, float], key: str
+) -> str: ...
+
+
+@overload
+def cast_value(
+    target: bool, value: Union[str, int, bool, float], key: str
+) -> bool: ...
+
+
+@overload
+def cast_value(
+    target: int, value: Union[str, int, bool, float], key: str
+) -> int: ...
+
+
+@overload
+def cast_value(
+    target: float, value: Union[str, int, bool, float], key: str
+) -> float: ...
+
+
+def cast_value(
+    target: Any, value: Union[str, int, bool, float], key: str
+) -> Any:
+    if isinstance(value, type(target)):
+        return value
+    elif attrs_has(type(target)):
+        raise KeyError(key)
+    elif isinstance(target, str):
+        return str(value)
+    elif isinstance(target, float):
+        try:
+            return float(value)
+        except Exception:
+            raise ValueError(value)
+    elif isinstance(target, bool):
+        if isinstance(value, str):
+            if value.lower() in ["on", "true"]:
+                return True
+            elif value.lower() in ["off", "false"]:
+                return False
+            else:
+                raise ValueError(value)
+        else:
+            raise TypeError(value)
+    elif isinstance(target, int):
+        # int must be after bool, because isinstance(True, int) == True
+        try:
+            return int(value)
+        except Exception:
+            raise ValueError(value)
+    else:
+        raise TypeError(value)
+
 
 @define
 class Paths:
@@ -28,6 +88,7 @@ class Paths:
     db: Path = field(converter=Path)
     log: Path = field(converter=Path)
     tmp: Path = field(converter=Path)
+    plugins: Path = field(converter=Path)
 
     @auth.default
     def _auth_factory(self) -> Path:
@@ -69,12 +130,21 @@ class Paths:
             root = InstallConfig.default
         return root / "tmp"
 
+    @plugins.default
+    def _plugins_factory(self) -> Path:
+        if InstallConfig.current is not None:
+            root = InstallConfig.current.path
+        else:
+            root = InstallConfig.default
+        return root / "plugins"
+
     def __iter__(self) -> Iterator[Path]:
         yield self.auth
         yield self.data
         yield self.db
         yield self.log
         yield self.tmp
+        yield self.plugins
 
 
 @define
@@ -126,6 +196,7 @@ class API:
     page_limit: int = 50
     default_options: DefaultOptions = Factory(DefaultOptions)
     default_query_id: str = ""
+    use_custom_distribution_algorithm: bool = False
 
 
 def fix_rename_search_api(doc: TOMLDocument) -> TOMLDocument:
@@ -212,6 +283,13 @@ def iter_keys(
         yield local_path
 
 
+@define
+class Plugins:
+    """Configuration for the plugin system"""
+
+    enabled: bool = False
+
+
 @define
 class Config:
     paths: Paths = Factory(Paths)
@@ -220,6 +298,7 @@ class Config:
     db: Db = Factory(Db)
     download: Download = Factory(Download)
     api: API = Factory(API)
+    plugins: Plugins = Factory(Plugins)
     _raw: TOMLDocument | None = field(init=False, default=None)
     _config_file: Path | None = field(init=False, default=None)
 
@@ -286,7 +365,7 @@ class Config:
     def update_item(
         self,
         key: str,
-        value: int |
+        value: str | int | bool,
         empty_ok: bool = False,
     ) -> int | str | None:
         if self._raw is None and empty_ok:
@@ -301,29 +380,8 @@ class Config:
             doc.setdefault(part, {})
             doc = doc[part]
             obj = getattr(obj, part)
-        value_type = getattr(fields(type(obj)), last).type
         old_value = getattr(obj, last)
-
-            raise KeyError(key)
-        elif value_type is str:
-            ...
-        elif value_type is int:
-            try:
-                value = value_type(value)
-            except Exception:
-                ...
-        elif value_type is bool:
-            if isinstance(value, bool):
-                ...
-            elif isinstance(value, str):
-                if value.lower() in ["on", "true"]:
-                    value = True
-                elif value.lower() in ["off", "false"]:
-                    value = False
-                else:
-                    raise ValueError(value)
-            else:
-                raise TypeError(value)
+        value = cast_value(old_value, value, key)
         setattr(obj, last, value)
         doc[last] = value
         return old_value
esgpull/constants.py
CHANGED
esgpull/context.py
CHANGED
@@ -16,9 +16,9 @@ from rich.pretty import pretty_repr
 
 from esgpull.config import Config
 from esgpull.exceptions import SolrUnstableQueryError
-from esgpull.models import
+from esgpull.models import DatasetRecord, File, Query
 from esgpull.tui import logger
-from esgpull.utils import
+from esgpull.utils import format_date_iso, index2url, sync
 
 # workaround for notebooks with running event loop
 if asyncio.get_event_loop().is_running():
@@ -77,9 +77,9 @@ class Result:
         else:
             params["fields"] = "instance_id"
         if date_from is not None:
-            params["from"] =
+            params["from"] = format_date_iso(date_from)
         if date_to is not None:
-            params["to"] =
+            params["to"] = format_date_iso(date_to)
         if facets_param is not None:
             if len(set(facets_param) & DangerousFacets) > 0:
                 raise SolrUnstableQueryError(pretty_repr(self.query))
@@ -90,9 +90,9 @@ class Result:
         facets_star = False
         # [?]TODO: add nominal temporal constraints `to`
         # if "start" in facets:
-        #     query["start"] =
+        #     query["start"] = format_date_iso(str(facets.pop("start")))
         # if "end" in facets:
-        #     query["end"] =
+        #     query["end"] = format_date_iso(str(facets.pop("end")))
         solr_terms: list[str] = []
         for name, values in self.query.selection.items():
             value_term = " ".join(values)
@@ -151,7 +151,7 @@ class ResultHints(Result):
 
 @dataclass
 class ResultSearch(Result):
-    data: Sequence[File |
+    data: Sequence[File | DatasetRecord] = field(init=False, repr=False)
 
     def process(self) -> None:
         raise NotImplementedError
@@ -159,14 +159,14 @@
 
 @dataclass
 class ResultDatasets(Result):
-    data: Sequence[
+    data: Sequence[DatasetRecord] = field(init=False, repr=False)
 
     def process(self) -> None:
         self.data = []
         if self.success:
             for doc in self.json["response"]["docs"]:
                 try:
-                    dataset =
+                    dataset = DatasetRecord.serialize(doc)
                     self.data.append(dataset)
                 except KeyError as exc:
                     logger.exception(exc)
@@ -282,7 +282,7 @@ class Context:
         # # if since is None:
         # #     self.since = since
         # # else:
-        # #     self.since =
+        # #     self.since = format_date_iso(since)
 
     async def __aenter__(self) -> Context:
         if hasattr(self, "client"):
@@ -492,8 +492,8 @@ class Context:
         self,
         *results: ResultSearch,
         keep_duplicates: bool,
-    ) -> list[
-        datasets: list[
+    ) -> list[DatasetRecord]:
+        datasets: list[DatasetRecord] = []
         ids: set[str] = set()
         async for result in self._fetch(*results):
             dataset_result = result.to(ResultDatasets)
@@ -501,7 +501,7 @@ class Context:
             if dataset_result.processed:
                 for d in dataset_result.data:
                     if not keep_duplicates and d.dataset_id in ids:
-                        logger.
+                        logger.debug(f"Duplicate dataset {d.dataset_id}")
                     else:
                         datasets.append(d)
                         ids.add(d.dataset_id)
@@ -520,7 +520,7 @@ class Context:
             if files_result.processed:
                 for file in files_result.data:
                     if not keep_duplicates and file.sha in shas:
-                        logger.
+                        logger.debug(f"Duplicate file {file.file_id}")
                     else:
                         files.append(file)
                         shas.add(file.sha)
@@ -627,7 +627,7 @@ class Context:
         date_from: datetime | None = None,
         date_to: datetime | None = None,
         keep_duplicates: bool = True,
-    ) -> list[
+    ) -> list[DatasetRecord]:
         if hits is None:
             hits = self.hits(*queries, file=False)
         results = self.prepare_search(
esgpull/database.py
CHANGED
@@ -12,11 +12,13 @@ import sqlalchemy.orm
 from alembic.config import Config as AlembicConfig
 from alembic.migration import MigrationContext
 from alembic.script import ScriptDirectory
+from sqlalchemy.inspection import inspect
 from sqlalchemy.orm import Session, joinedload, make_transient
 
 from esgpull import __file__
 from esgpull.config import Config
 from esgpull.models import File, Query, Table, sql
+from esgpull.models.base import Base, BaseNoSHA
 from esgpull.version import __version__
 
 # from esgpull.exceptions import NoClauseError
@@ -151,8 +153,12 @@ class Database:
     def unlink(self, query: Query, file: File):
         self.session.execute(sql.query_file.unlink(query, file))
 
-    def __contains__(self, item:
-
+    def __contains__(self, item: Base | BaseNoSHA) -> bool:
+        mapper = inspect(item.__class__)
+        pk_col = mapper.primary_key[0]
+        pk_value = getattr(item, pk_col.name)
+        stmt = sa.exists().where(pk_col == pk_value)
+        return self.scalars(sa.select(stmt))[0]
 
     def has_file_id(self, file: File) -> bool:
         return len(self.scalars(sql.file.with_file_id(file.file_id))) == 1
esgpull/download.py
CHANGED
@@ -1,6 +1,7 @@
 # from math import ceil
 from collections.abc import AsyncGenerator
 from dataclasses import dataclass
+from datetime import datetime
 
 from httpx import AsyncClient
 
@@ -19,6 +20,7 @@ class DownloadCtx:
     completed: int = 0
     chunk: bytes | None = None
     digest: Digest | None = None
+    start_time: datetime | None = None
 
     @property
     def finished(self) -> bool:
@@ -54,6 +56,7 @@ class Simple(BaseDownloader):
         ctx: DownloadCtx,
         chunk_size: int,
     ) -> AsyncGenerator[DownloadCtx, None]:
+        ctx.start_time = datetime.now()
         async with client.stream("GET", ctx.file.url) as resp:
             resp.raise_for_status()
             async for chunk in resp.aiter_bytes(chunk_size=chunk_size):
esgpull/esgpull.py
CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
 import logging
 from collections.abc import AsyncIterator
 from dataclasses import dataclass
+from datetime import datetime
 from functools import cached_property, partial
 from pathlib import Path
 from warnings import warn
@@ -25,6 +26,7 @@ from esgpull.auth import Auth, Credentials
 from esgpull.config import Config
 from esgpull.context import Context
 from esgpull.database import Database
+from esgpull.download import DownloadCtx
 from esgpull.exceptions import (
     DownloadCancelled,
     InvalidInstallPath,
@@ -44,6 +46,13 @@ from esgpull.models import (
     sql,
 )
 from esgpull.models.utils import short_sha
+from esgpull.plugin import (
+    Event,
+    PluginManager,
+    emit,
+    get_plugin_manager,
+    set_plugin_manager,
+)
 from esgpull.processor import Processor
 from esgpull.result import Err, Ok, Result
 from esgpull.tui import UI, DummyLive, Verbosity, logger
@@ -117,6 +126,18 @@ class Esgpull:
         if load_db:
             self.db = Database.from_config(self.config)
             self.graph = Graph(self.db)
+        # Initialize plugin system
+        plugin_config_path = self.config.paths.plugins / "plugins.toml"
+        try:
+            self.plugin_manager = get_plugin_manager()
+            self.plugin_manager.__init__(config_path=plugin_config_path)
+        except ValueError:
+            self.plugin_manager = PluginManager(config_path=plugin_config_path)
+            set_plugin_manager(self.plugin_manager)
+        if self.config.plugins.enabled:
+            self.plugin_manager.enabled = True
+            self.config.paths.plugins.mkdir(exist_ok=True, parents=True)
+            self.plugin_manager.discover_plugins(self.config.paths.plugins)
 
     def fetch_index_nodes(self) -> list[str]:
         """
@@ -309,7 +330,7 @@ class Esgpull:
         progress: Progress,
         task_ids: dict[str, TaskID],
         live: Live | DummyLive,
-    ) -> AsyncIterator[Result]:
+    ) -> AsyncIterator[Result[DownloadCtx]]:
         async for result in processor.process():
             task_idx = progress.task_ids.index(task_ids[result.data.file.sha])
             task = progress.tasks[task_idx]
@@ -348,7 +369,7 @@ class Esgpull:
                     yield result
                 case Err(_, err):
                     progress.remove_task(task.id)
-                    yield Err(result.data, err)
+                    yield Err(result.data, err=err)
                 case Err():
                     progress.remove_task(task.id)
                     yield result
@@ -438,15 +459,38 @@ class Esgpull:
                 match result:
                     case Ok():
                         main_progress.update(main_task_id, advance=1)
-                        result.data.file
-
-
+                        file = result.data.file
+                        file.status = FileStatus.Done
+                        files.append(file)
+                        emit(
+                            Event.file_complete,
+                            file=file,
+                            destination=self.fs[file].drs,
+                            start_time=result.data.start_time,
+                            end_time=datetime.now(),
+                        )
+                        if file.dataset is not None:
+                            is_dataset_complete = self.db.scalars(
+                                sql.dataset.is_complete(file.dataset)
+                            )[0]
+                            if is_dataset_complete:
+                                emit(
+                                    Event.dataset_complete,
+                                    dataset=file.dataset,
+                                )
+                    case Err(_, err):
                         queue_size -= 1
                         main_progress.update(
                             main_task_id, total=queue_size
                         )
                         result.data.file.status = FileStatus.Error
                         errors.append(result)
+                        emit(
+                            Event.file_error,
+                            file=result.data.file,
+                            exception=err,
+                        )
+
                 if use_db:
                     self.db.add(result.data.file)
                     remaining_dict.pop(result.data.file.sha, None)
esgpull/graph.py
CHANGED
@@ -418,7 +418,7 @@ class Graph:
         if keep_require:
             query_tree = query._rich_tree()
         else:
-            query_tree = query.
+            query_tree = query._rich_tree(hide_require=True)
         if query_tree is not None:
             tree.add(query_tree)
             self.fill_tree(query, query_tree)