esgpull 0.7.3__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. esgpull/cli/__init__.py +2 -2
  2. esgpull/cli/add.py +7 -1
  3. esgpull/cli/config.py +5 -21
  4. esgpull/cli/plugins.py +398 -0
  5. esgpull/cli/show.py +29 -0
  6. esgpull/cli/status.py +6 -4
  7. esgpull/cli/update.py +72 -18
  8. esgpull/cli/utils.py +16 -1
  9. esgpull/config.py +83 -25
  10. esgpull/constants.py +3 -0
  11. esgpull/context.py +15 -15
  12. esgpull/database.py +8 -2
  13. esgpull/download.py +3 -0
  14. esgpull/esgpull.py +49 -5
  15. esgpull/graph.py +1 -1
  16. esgpull/migrations/versions/0.8.0_update_tables.py +28 -0
  17. esgpull/migrations/versions/0.9.0_update_tables.py +28 -0
  18. esgpull/migrations/versions/14c72daea083_query_add_column_updated_at.py +36 -0
  19. esgpull/migrations/versions/c7c8541fa741_query_add_column_added_at.py +37 -0
  20. esgpull/migrations/versions/d14f179e553c_file_add_composite_index_dataset_id_.py +32 -0
  21. esgpull/migrations/versions/e7edab5d4e4b_add_dataset_tracking.py +39 -0
  22. esgpull/models/__init__.py +2 -1
  23. esgpull/models/base.py +31 -14
  24. esgpull/models/dataset.py +48 -5
  25. esgpull/models/options.py +1 -1
  26. esgpull/models/query.py +98 -15
  27. esgpull/models/sql.py +40 -9
  28. esgpull/plugin.py +574 -0
  29. esgpull/processor.py +3 -3
  30. esgpull/tui.py +23 -1
  31. esgpull/utils.py +19 -3
  32. {esgpull-0.7.3.dist-info → esgpull-0.9.0.dist-info}/METADATA +11 -2
  33. {esgpull-0.7.3.dist-info → esgpull-0.9.0.dist-info}/RECORD +36 -29
  34. {esgpull-0.7.3.dist-info → esgpull-0.9.0.dist-info}/WHEEL +1 -1
  35. esgpull/cli/datasets.py +0 -78
  36. {esgpull-0.7.3.dist-info → esgpull-0.9.0.dist-info}/entry_points.txt +0 -0
  37. {esgpull-0.7.3.dist-info → esgpull-0.9.0.dist-info}/licenses/LICENSE +0 -0
esgpull/cli/update.py CHANGED
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from dataclasses import dataclass, field
+from datetime import datetime, timezone
 
 import click
 from click.exceptions import Abort, Exit
@@ -9,7 +10,7 @@ from esgpull.cli.decorators import args, opts
 from esgpull.cli.utils import get_queries, init_esgpull, valid_name_tag
 from esgpull.context import HintsDict, ResultSearch
 from esgpull.exceptions import UnsetOptionsError
-from esgpull.models import File, FileStatus, Query
+from esgpull.models import Dataset, File, FileStatus, Query
 from esgpull.tui import Verbosity, logger
 from esgpull.utils import format_size
 
@@ -19,10 +20,13 @@ class QueryFiles:
     query: Query
     expanded: Query
     skip: bool = False
+    datasets: list[Dataset] = field(default_factory=list)
     files: list[File] = field(default_factory=list)
+    dataset_hits: int = field(init=False)
     hits: int = field(init=False)
     hints: HintsDict = field(init=False)
     results: list[ResultSearch] = field(init=False)
+    dataset_results: list[ResultSearch] = field(init=False)
 
 
 @click.command()
@@ -79,20 +83,27 @@ def update(
         file=True,
         facets=["index_node"],
     )
-    for qf, qf_hints in zip(qfs, hints):
+    dataset_hits = esg.context.hits(
+        *[qf.expanded for qf in qfs],
+        file=False,
+    )
+    for qf, qf_hints, qf_dataset_hits in zip(qfs, hints, dataset_hits):
         qf.hits = sum(esg.context.hits_from_hints(qf_hints))
         if qf_hints:
             qf.hints = qf_hints
+            qf.dataset_hits = qf_dataset_hits
         else:
             qf.skip = True
     for qf in qfs:
         s = "s" if qf.hits > 1 else ""
-        esg.ui.print(f"{qf.query.rich_name} -> {qf.hits} file{s}.")
+        esg.ui.print(
+            f"{qf.query.rich_name} -> {qf.hits} file{s} (before replica de-duplication)."
+        )
     total_hits = sum([qf.hits for qf in qfs])
     if total_hits == 0:
         esg.ui.print("No files found.")
         esg.ui.raise_maybe_record(Exit(0))
-    else:
+    elif len(qfs) > 1:
         esg.ui.print(f"{total_hits} files found.")
     qfs = [qf for qf in qfs if not qf.skip]
     # Prepare optimally distributed requests to ESGF
@@ -100,13 +111,27 @@
     # It might be interesting for the special case where all files already
     # exist in db, then the detailed fetch could be skipped.
     for qf in qfs:
-        qf_results = esg.context.prepare_search_distributed(
+        qf_dataset_results = esg.context.prepare_search(
             qf.expanded,
-            file=True,
-            hints=[qf.hints],
+            file=False,
+            hits=[qf.dataset_hits],
             max_hits=None,
         )
-        nb_req = len(qf_results)
+        if esg.config.api.use_custom_distribution_algorithm:
+            qf_results = esg.context.prepare_search_distributed(
+                qf.expanded,
+                file=True,
+                hints=[qf.hints],
+                max_hits=None,
+            )
+        else:
+            qf_results = esg.context.prepare_search(
+                qf.expanded,
+                file=True,
+                hits=[qf.hits],
+                max_hits=None,
+            )
+        nb_req = len(qf_dataset_results) + len(qf_results)
         if nb_req > 50:
             msg = (
                 f"{nb_req} requests will be sent to ESGF to"
@@ -117,13 +142,32 @@ def update(
                 esg.ui.print(f"{qf.query.rich_name} is now untracked.")
                 qf.query.tracked = False
                 qf_results = []
+                qf_dataset_results = []
             case "n":
                 qf_results = []
+                qf_dataset_results = []
             case _:
                 ...
         qf.results = qf_results
+        qf.dataset_results = qf_dataset_results
     # Fetch files and update db
     # [?] TODO: dry_run to print urls here
+    with esg.ui.spinner("Fetching datasets"):
+        coros = []
+        for qf in qfs:
+            coro = esg.context._datasets(
+                *qf.dataset_results, keep_duplicates=False
+            )
+            coros.append(coro)
+        datasets = esg.context.sync_gather(*coros)
+    for qf, qf_datasets in zip(qfs, datasets):
+        qf.datasets = [
+            Dataset(
+                dataset_id=record.dataset_id,
+                total_files=record.number_of_files,
+            )
+            for record in qf_datasets
+        ]
     with esg.ui.spinner("Fetching files"):
         coros = []
         for qf in qfs:
@@ -133,23 +177,26 @@ def update(
     for qf, qf_files in zip(qfs, files):
         qf.files = qf_files
     for qf in qfs:
-        shas = {f.sha for f in qf.query.files}
-        new_files: list[File] = []
-        for file in qf.files:
-            if file.sha not in shas:
-                new_files.append(file)
+        new_files = [f for f in qf.files if f not in esg.db]
+        new_datasets = [d for d in qf.datasets if d not in esg.db]
+        nb_datasets = len(new_datasets)
         nb_files = len(new_files)
         if not qf.query.tracked:
             esg.db.add(qf.query)
             continue
-        elif nb_files == 0:
+        elif nb_datasets == nb_files == 0:
             esg.ui.print(f"{qf.query.rich_name} is already up-to-date.")
             continue
         size = sum([file.size for file in new_files])
+        if size > 0:
+            queue_msg = " and send new files to download queue"
+        else:
+            queue_msg = ""
         msg = (
-            f"\nUpdating {qf.query.rich_name} with {nb_files}"
-            f" new files ({format_size(size)})."
-            "\nSend to download queue?"
+            f"\n{qf.query.rich_name}: {nb_files} new"
+            f" files, {nb_datasets} new datasets"
+            f" ({format_size(size)})."
+            f"\nUpdate the database{queue_msg}?"
         )
         if yes:
             choice = "y"
@@ -166,9 +213,14 @@ def update(
         legacy = esg.legacy_query
         has_legacy = legacy.state.persistent
         with esg.db.commit_context():
+            for dataset in esg.ui.track(
+                new_datasets,
+                description=f"{qf.query.rich_name} (datasets)",
+            ):
+                esg.db.session.add(dataset)
             for file in esg.ui.track(
                 new_files,
-                description=qf.query.rich_name,
+                description=f"{qf.query.rich_name} (files)",
             ):
                 file_db = esg.db.get(File, file.sha)
                 if file_db is None:
@@ -184,4 +236,6 @@ def update(
                 elif has_legacy and legacy in file_db.queries:
                     esg.db.unlink(query=legacy, file=file_db)
                 esg.db.link(query=qf.query, file=file)
+            qf.query.updated_at = datetime.now(timezone.utc)
+            esg.db.session.add(qf.query)
     esg.ui.raise_maybe_record(Exit(0))
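
One small detail in the last hunk worth noting: the new `updated_at` stamp is timezone-aware UTC. A minimal illustration of what that buys over the naive `datetime.utcnow()`:

```python
from datetime import datetime, timezone

# Aware UTC timestamp, as stamped on query.updated_at above. Unlike
# datetime.utcnow() (naive, and deprecated since Python 3.12), the result
# carries tzinfo, so comparisons with other aware datetimes are unambiguous.
stamp = datetime.now(timezone.utc)
assert stamp.tzinfo is timezone.utc
```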
esgpull/cli/utils.py CHANGED
@@ -103,7 +103,7 @@ def totable(docs: list[OrderedDict[str, Any]]) -> Table:
     table = Table(box=MINIMAL_DOUBLE_HEAD, show_edge=False)
     for key in docs[0].keys():
         justify: Literal["left", "right", "center"]
-        if key in ["file", "dataset"]:
+        if key in ["file", "dataset", "plugin"]:
             justify = "left"
         else:
             justify = "right"
@@ -243,3 +243,18 @@ def get_queries(
         kids = graph.get_all_children(query.sha)
         queries.extend(kids)
     return queries
+
+
+def extract_subdict(doc: dict, key: str | None) -> dict:
+    if key is None:
+        return doc
+    for part in key.split("."):
+        if not part:
+            raise KeyError(key)
+        elif part in doc:
+            doc = doc[part]
+        else:
+            raise KeyError(part)
+    for part in key.split(".")[::-1]:
+        doc = {part: doc}
+    return doc
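
The new `extract_subdict` helper walks a dotted key down a nested dict, then re-wraps the subtree it found under the same path, so callers get back a dict with the original nesting intact. A behaviour sketch with invented data:

```python
# Invented example data for the extract_subdict helper added above.
doc = {"api": {"page_limit": 50, "index_node": "esgf-node.ipsl.upmc.fr"}}

extract_subdict(doc, None)              # -> doc, unchanged
extract_subdict(doc, "api.page_limit")  # -> {"api": {"page_limit": 50}}
extract_subdict(doc, "api.missing")     # raises KeyError("missing")
extract_subdict(doc, "api..limit")      # raises KeyError("api..limit"): empty part
```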
esgpull/config.py CHANGED
@@ -4,10 +4,10 @@ import logging
 from collections.abc import Iterator, Mapping
 from enum import Enum, auto
 from pathlib import Path
-from typing import Any, cast
+from typing import Any, TypeVar, Union, cast, overload
 
 import tomlkit
-from attrs import Factory, define, field, fields
+from attrs import Factory, define, field
 from attrs import has as attrs_has
 from cattrs import Converter
 from cattrs.gen import make_dict_unstructure_fn, override
@@ -20,6 +20,66 @@ from esgpull.models.options import Options
 
 logger = logging.getLogger("esgpull")
 
+T = TypeVar("T")
+
+
+@overload
+def cast_value(
+    target: str, value: Union[str, int, bool, float], key: str
+) -> str: ...
+
+
+@overload
+def cast_value(
+    target: bool, value: Union[str, int, bool, float], key: str
+) -> bool: ...
+
+
+@overload
+def cast_value(
+    target: int, value: Union[str, int, bool, float], key: str
+) -> int: ...
+
+
+@overload
+def cast_value(
+    target: float, value: Union[str, int, bool, float], key: str
+) -> float: ...
+
+
+def cast_value(
+    target: Any, value: Union[str, int, bool, float], key: str
+) -> Any:
+    if isinstance(value, type(target)):
+        return value
+    elif attrs_has(type(target)):
+        raise KeyError(key)
+    elif isinstance(target, str):
+        return str(value)
+    elif isinstance(target, float):
+        try:
+            return float(value)
+        except Exception:
+            raise ValueError(value)
+    elif isinstance(target, bool):
+        if isinstance(value, str):
+            if value.lower() in ["on", "true"]:
+                return True
+            elif value.lower() in ["off", "false"]:
+                return False
+            else:
+                raise ValueError(value)
+        else:
+            raise TypeError(value)
+    elif isinstance(target, int):
+        # int must be after bool, because isinstance(True, int) == True
+        try:
+            return int(value)
+        except Exception:
+            raise ValueError(value)
+    else:
+        raise TypeError(value)
+
 
 @define
 class Paths:
@@ -28,6 +88,7 @@ class Paths:
     db: Path = field(converter=Path)
     log: Path = field(converter=Path)
     tmp: Path = field(converter=Path)
+    plugins: Path = field(converter=Path)
 
     @auth.default
     def _auth_factory(self) -> Path:
@@ -69,12 +130,21 @@ class Paths:
             root = InstallConfig.default
         return root / "tmp"
 
+    @plugins.default
+    def _plugins_factory(self) -> Path:
+        if InstallConfig.current is not None:
+            root = InstallConfig.current.path
+        else:
+            root = InstallConfig.default
+        return root / "plugins"
+
     def __iter__(self) -> Iterator[Path]:
         yield self.auth
         yield self.data
         yield self.db
         yield self.log
         yield self.tmp
+        yield self.plugins
 
 
 @define
@@ -126,6 +196,7 @@ class API:
     page_limit: int = 50
     default_options: DefaultOptions = Factory(DefaultOptions)
     default_query_id: str = ""
+    use_custom_distribution_algorithm: bool = False
 
 
 def fix_rename_search_api(doc: TOMLDocument) -> TOMLDocument:
@@ -212,6 +283,13 @@ def iter_keys(
         yield local_path
 
 
+@define
+class Plugins:
+    """Configuration for the plugin system"""
+
+    enabled: bool = False
+
+
 @define
 class Config:
     paths: Paths = Factory(Paths)
@@ -220,6 +298,7 @@ class Config:
     db: Db = Factory(Db)
     download: Download = Factory(Download)
     api: API = Factory(API)
+    plugins: Plugins = Factory(Plugins)
     _raw: TOMLDocument | None = field(init=False, default=None)
     _config_file: Path | None = field(init=False, default=None)
 
@@ -286,7 +365,7 @@ class Config:
     def update_item(
         self,
         key: str,
-        value: int | str,
+        value: str | int | bool,
         empty_ok: bool = False,
     ) -> int | str | None:
         if self._raw is None and empty_ok:
@@ -301,29 +380,8 @@ class Config:
             doc.setdefault(part, {})
             doc = doc[part]
             obj = getattr(obj, part)
-        value_type = getattr(fields(type(obj)), last).type
         old_value = getattr(obj, last)
-        if attrs_has(value_type):
-            raise KeyError(key)
-        elif value_type is str:
-            ...
-        elif value_type is int:
-            try:
-                value = value_type(value)
-            except Exception:
-                ...
-        elif value_type is bool:
-            if isinstance(value, bool):
-                ...
-            elif isinstance(value, str):
-                if value.lower() in ["on", "true"]:
-                    value = True
-                elif value.lower() in ["off", "false"]:
-                    value = False
-                else:
-                    raise ValueError(value)
-            else:
-                raise TypeError(value)
+        value = cast_value(old_value, value, key)
         setattr(obj, last, value)
         doc[last] = value
         return old_value
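
`update_item` now delegates coercion to the new `cast_value`, which dispatches on the type of the current (old) value; note the bool branch deliberately precedes the int branch, since `isinstance(True, int)` is true in Python. A behaviour sketch (the key argument only labels errors; the float key is hypothetical):

```python
cast_value(50, "100", "api.page_limit")        # -> 100, str coerced to int
cast_value(False, "on", "plugins.enabled")     # -> True ("on"/"true" accepted)
cast_value(False, "maybe", "plugins.enabled")  # raises ValueError("maybe")
cast_value(1.5, "0.5", "some.float_key")       # -> 0.5 (hypothetical key)
```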
esgpull/constants.py CHANGED
@@ -1,6 +1,9 @@
+import os
+
 CONFIG_FILENAME = "config.toml"
 INSTALLS_PATH_ENV = "ESGPULL_INSTALLS_PATH"
 ROOT_ENV = "ESGPULL_CURRENT"
+ESGPULL_DEBUG = os.environ.get("ESGPULL_DEBUG", "0") == "1"
 
 IDP = "/esgf-idp/openid/"
 CEDA_IDP = "/OpenID/Provider/server/"
esgpull/context.py CHANGED
@@ -16,9 +16,9 @@ from rich.pretty import pretty_repr
 
 from esgpull.config import Config
 from esgpull.exceptions import SolrUnstableQueryError
-from esgpull.models import Dataset, File, Query
+from esgpull.models import DatasetRecord, File, Query
 from esgpull.tui import logger
-from esgpull.utils import format_date, index2url, sync
+from esgpull.utils import format_date_iso, index2url, sync
 
 # workaround for notebooks with running event loop
 if asyncio.get_event_loop().is_running():
@@ -77,9 +77,9 @@ class Result:
         else:
             params["fields"] = "instance_id"
         if date_from is not None:
-            params["from"] = format_date(date_from)
+            params["from"] = format_date_iso(date_from)
         if date_to is not None:
-            params["to"] = format_date(date_to)
+            params["to"] = format_date_iso(date_to)
         if facets_param is not None:
             if len(set(facets_param) & DangerousFacets) > 0:
                 raise SolrUnstableQueryError(pretty_repr(self.query))
@@ -90,9 +90,9 @@
         facets_star = False
         # [?]TODO: add nominal temporal constraints `to`
         # if "start" in facets:
-        #     query["start"] = format_date(str(facets.pop("start")))
+        #     query["start"] = format_date_iso(str(facets.pop("start")))
         # if "end" in facets:
-        #     query["end"] = format_date(str(facets.pop("end")))
+        #     query["end"] = format_date_iso(str(facets.pop("end")))
         solr_terms: list[str] = []
         for name, values in self.query.selection.items():
             value_term = " ".join(values)
@@ -151,7 +151,7 @@ class ResultHints(Result):
 
 @dataclass
 class ResultSearch(Result):
-    data: Sequence[File | Dataset] = field(init=False, repr=False)
+    data: Sequence[File | DatasetRecord] = field(init=False, repr=False)
 
     def process(self) -> None:
         raise NotImplementedError
@@ -159,14 +159,14 @@
 
 @dataclass
 class ResultDatasets(Result):
-    data: Sequence[Dataset] = field(init=False, repr=False)
+    data: Sequence[DatasetRecord] = field(init=False, repr=False)
 
     def process(self) -> None:
         self.data = []
         if self.success:
             for doc in self.json["response"]["docs"]:
                 try:
-                    dataset = Dataset.serialize(doc)
+                    dataset = DatasetRecord.serialize(doc)
                     self.data.append(dataset)
                 except KeyError as exc:
                     logger.exception(exc)
@@ -282,7 +282,7 @@ class Context:
     # # if since is None:
     # #     self.since = since
     # # else:
-    # #     self.since = format_date(since)
+    # #     self.since = format_date_iso(since)
 
     async def __aenter__(self) -> Context:
         if hasattr(self, "client"):
@@ -492,8 +492,8 @@ class Context:
         self,
         *results: ResultSearch,
         keep_duplicates: bool,
-    ) -> list[Dataset]:
-        datasets: list[Dataset] = []
+    ) -> list[DatasetRecord]:
+        datasets: list[DatasetRecord] = []
         ids: set[str] = set()
         async for result in self._fetch(*results):
             dataset_result = result.to(ResultDatasets)
@@ -501,7 +501,7 @@
             if dataset_result.processed:
                 for d in dataset_result.data:
                     if not keep_duplicates and d.dataset_id in ids:
-                        logger.warning(f"Duplicate dataset {d.dataset_id}")
+                        logger.debug(f"Duplicate dataset {d.dataset_id}")
                     else:
                         datasets.append(d)
                         ids.add(d.dataset_id)
@@ -520,7 +520,7 @@
             if files_result.processed:
                 for file in files_result.data:
                     if not keep_duplicates and file.sha in shas:
-                        logger.warning(f"Duplicate file {file.file_id}")
+                        logger.debug(f"Duplicate file {file.file_id}")
                     else:
                         files.append(file)
                         shas.add(file.sha)
@@ -627,7 +627,7 @@ class Context:
         date_from: datetime | None = None,
         date_to: datetime | None = None,
         keep_duplicates: bool = True,
-    ) -> list[Dataset]:
+    ) -> list[DatasetRecord]:
         if hits is None:
             hits = self.hits(*queries, file=False)
         results = self.prepare_search(
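
Two behavioural points in this file: the search-result model is renamed `Dataset` → `DatasetRecord`, freeing the `Dataset` name for the new database model in `esgpull/models/dataset.py`, and duplicate records, which are expected whenever replicas are searched, now log at DEBUG instead of WARNING. The filtering itself is unchanged; a standalone sketch of its logic:

```python
# Standalone sketch of the duplicate filtering in Context._datasets;
# "records" stands for any iterable of objects with a dataset_id attribute.
def dedup(records, keep_duplicates: bool = False) -> list:
    seen: set[str] = set()
    kept = []
    for record in records:
        if not keep_duplicates and record.dataset_id in seen:
            continue  # replica of an already-seen dataset: dropped, logged at DEBUG
        kept.append(record)
        seen.add(record.dataset_id)
    return kept
```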
esgpull/database.py CHANGED
@@ -12,11 +12,13 @@ import sqlalchemy.orm
 from alembic.config import Config as AlembicConfig
 from alembic.migration import MigrationContext
 from alembic.script import ScriptDirectory
+from sqlalchemy.inspection import inspect
 from sqlalchemy.orm import Session, joinedload, make_transient
 
 from esgpull import __file__
 from esgpull.config import Config
 from esgpull.models import File, Query, Table, sql
+from esgpull.models.base import Base, BaseNoSHA
 from esgpull.version import __version__
 
 # from esgpull.exceptions import NoClauseError
@@ -151,8 +153,12 @@ class Database:
     def unlink(self, query: Query, file: File):
         self.session.execute(sql.query_file.unlink(query, file))
 
-    def __contains__(self, item: Table) -> bool:
-        return self.scalars(sql.count(item))[0] > 0
+    def __contains__(self, item: Base | BaseNoSHA) -> bool:
+        mapper = inspect(item.__class__)
+        pk_col = mapper.primary_key[0]
+        pk_value = getattr(item, pk_col.name)
+        stmt = sa.exists().where(pk_col == pk_value)
+        return self.scalars(sa.select(stmt))[0]
 
     def has_file_id(self, file: File) -> bool:
         return len(self.scalars(sql.file.with_file_id(file.file_id))) == 1
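
`__contains__` previously ran a count query and only accepted `Table`; it now works for any mapped model by resolving the model's (single-column) primary key and issuing an EXISTS query, which is what enables the `f not in esg.db` / `d not in esg.db` checks in `update.py` above. The same pattern as a self-contained SQLAlchemy sketch:

```python
# Self-contained sketch of the EXISTS-by-primary-key pattern used above.
# Plain SQLAlchemy; assumes a single-column primary key, as the [0] does.
import sqlalchemy as sa
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import Session

def contains(session: Session, item: object) -> bool:
    mapper = inspect(item.__class__)       # Mapper of the model class
    pk_col = mapper.primary_key[0]         # its primary-key column
    pk_value = getattr(item, pk_col.name)  # the instance's key value
    stmt = sa.select(sa.exists().where(pk_col == pk_value))
    return bool(session.scalar(stmt))
```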
esgpull/download.py CHANGED
@@ -1,6 +1,7 @@
 # from math import ceil
 from collections.abc import AsyncGenerator
 from dataclasses import dataclass
+from datetime import datetime
 
 from httpx import AsyncClient
 
@@ -19,6 +20,7 @@ class DownloadCtx:
     completed: int = 0
     chunk: bytes | None = None
     digest: Digest | None = None
+    start_time: datetime | None = None
 
     @property
     def finished(self) -> bool:
@@ -54,6 +56,7 @@ class Simple(BaseDownloader):
         ctx: DownloadCtx,
         chunk_size: int,
     ) -> AsyncGenerator[DownloadCtx, None]:
+        ctx.start_time = datetime.now()
         async with client.stream("GET", ctx.file.url) as resp:
             resp.raise_for_status()
             async for chunk in resp.aiter_bytes(chunk_size=chunk_size):
esgpull/esgpull.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
 import logging
 from collections.abc import AsyncIterator
 from dataclasses import dataclass
+from datetime import datetime
 from functools import cached_property, partial
 from pathlib import Path
 from warnings import warn
@@ -25,6 +26,7 @@ from esgpull.auth import Auth, Credentials
 from esgpull.config import Config
 from esgpull.context import Context
 from esgpull.database import Database
+from esgpull.download import DownloadCtx
 from esgpull.exceptions import (
     DownloadCancelled,
     InvalidInstallPath,
@@ -44,6 +46,13 @@ from esgpull.models import (
     sql,
 )
 from esgpull.models.utils import short_sha
+from esgpull.plugin import (
+    Event,
+    PluginManager,
+    emit,
+    get_plugin_manager,
+    set_plugin_manager,
+)
 from esgpull.processor import Processor
 from esgpull.result import Err, Ok, Result
 from esgpull.tui import UI, DummyLive, Verbosity, logger
@@ -117,6 +126,18 @@ class Esgpull:
         if load_db:
             self.db = Database.from_config(self.config)
         self.graph = Graph(self.db)
+        # Initialize plugin system
+        plugin_config_path = self.config.paths.plugins / "plugins.toml"
+        try:
+            self.plugin_manager = get_plugin_manager()
+            self.plugin_manager.__init__(config_path=plugin_config_path)
+        except ValueError:
+            self.plugin_manager = PluginManager(config_path=plugin_config_path)
+            set_plugin_manager(self.plugin_manager)
+        if self.config.plugins.enabled:
+            self.plugin_manager.enabled = True
+            self.config.paths.plugins.mkdir(exist_ok=True, parents=True)
+            self.plugin_manager.discover_plugins(self.config.paths.plugins)
 
     def fetch_index_nodes(self) -> list[str]:
         """
@@ -309,7 +330,7 @@ class Esgpull:
         progress: Progress,
         task_ids: dict[str, TaskID],
         live: Live | DummyLive,
-    ) -> AsyncIterator[Result]:
+    ) -> AsyncIterator[Result[DownloadCtx]]:
         async for result in processor.process():
             task_idx = progress.task_ids.index(task_ids[result.data.file.sha])
             task = progress.tasks[task_idx]
@@ -348,7 +369,7 @@ class Esgpull:
                     yield result
                 case Err(_, err):
                     progress.remove_task(task.id)
-                    yield Err(result.data, err)
+                    yield Err(result.data, err=err)
                 case Err():
                     progress.remove_task(task.id)
                     yield result
@@ -438,15 +459,38 @@ class Esgpull:
                 match result:
                     case Ok():
                         main_progress.update(main_task_id, advance=1)
-                        result.data.file.status = FileStatus.Done
-                        files.append(result.data.file)
-                    case Err():
+                        file = result.data.file
+                        file.status = FileStatus.Done
+                        files.append(file)
+                        emit(
+                            Event.file_complete,
+                            file=file,
+                            destination=self.fs[file].drs,
+                            start_time=result.data.start_time,
+                            end_time=datetime.now(),
+                        )
+                        if file.dataset is not None:
+                            is_dataset_complete = self.db.scalars(
+                                sql.dataset.is_complete(file.dataset)
+                            )[0]
+                            if is_dataset_complete:
+                                emit(
+                                    Event.dataset_complete,
+                                    dataset=file.dataset,
+                                )
+                    case Err(_, err):
                         queue_size -= 1
                         main_progress.update(
                             main_task_id, total=queue_size
                         )
                         result.data.file.status = FileStatus.Error
                         errors.append(result)
+                        emit(
+                            Event.file_error,
+                            file=result.data.file,
+                            exception=err,
+                        )
+
                 if use_db:
                     self.db.add(result.data.file)
                     remaining_dict.pop(result.data.file.sha, None)
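
The `emit(...)` calls above are the hook points of the new plugin system: `Event.file_complete` (keywords `file`, `destination`, `start_time`, `end_time`), `Event.dataset_complete` (`dataset`), and `Event.file_error` (`file`, `exception`). The registration side lives in the new `esgpull/plugin.py`, which this diff listing does not expand, so the sketch below is an assumption: the `@on` decorator is hypothetical, and only the event names and keyword arguments are taken from this diff. Enabling it would go through the new `plugins.enabled` config flag seen in `config.py` above.

```python
# plugins/notify.py -- hypothetical plugin module dropped into the new
# plugins/ directory; "on" is an assumed registration API, not confirmed here.
from esgpull.plugin import Event, on

@on(Event.file_complete)
def report_file(file, destination, start_time, end_time):
    if start_time is not None:  # start_time may be None (DownloadCtx default)
        seconds = (end_time - start_time).total_seconds()
        print(f"{file.file_id} -> {destination} ({seconds:.1f}s)")

@on(Event.dataset_complete)
def report_dataset(dataset):
    print(f"dataset complete: {dataset.dataset_id}")
```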
esgpull/graph.py CHANGED
@@ -418,7 +418,7 @@ class Graph:
         if keep_require:
             query_tree = query._rich_tree()
         else:
-            query_tree = query.no_require()._rich_tree()
+            query_tree = query._rich_tree(hide_require=True)
         if query_tree is not None:
             tree.add(query_tree)
             self.fill_tree(query, query_tree)