esgpull 0.8.0__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
esgpull/cli/update.py CHANGED
@@ -10,8 +10,8 @@ from esgpull.cli.decorators import args, opts
  from esgpull.cli.utils import get_queries, init_esgpull, valid_name_tag
  from esgpull.context import HintsDict, ResultSearch
  from esgpull.exceptions import UnsetOptionsError
- from esgpull.models import File, FileStatus, Query
- from esgpull.tui import Verbosity, logger
+ from esgpull.models import Dataset, File, FileStatus, Query
+ from esgpull.tui import Verbosity
  from esgpull.utils import format_size


@@ -20,10 +20,13 @@ class QueryFiles:
  query: Query
  expanded: Query
  skip: bool = False
+ datasets: list[Dataset] = field(default_factory=list)
  files: list[File] = field(default_factory=list)
+ dataset_hits: int = field(init=False)
  hits: int = field(init=False)
  hints: HintsDict = field(init=False)
  results: list[ResultSearch] = field(init=False)
+ dataset_results: list[ResultSearch] = field(init=False)


  @click.command()
@@ -80,20 +83,27 @@ def update(
  file=True,
  facets=["index_node"],
  )
- for qf, qf_hints in zip(qfs, hints):
+ dataset_hits = esg.context.hits(
+ *[qf.expanded for qf in qfs],
+ file=False,
+ )
+ for qf, qf_hints, qf_dataset_hits in zip(qfs, hints, dataset_hits):
  qf.hits = sum(esg.context.hits_from_hints(qf_hints))
  if qf_hints:
  qf.hints = qf_hints
+ qf.dataset_hits = qf_dataset_hits
  else:
  qf.skip = True
  for qf in qfs:
  s = "s" if qf.hits > 1 else ""
- esg.ui.print(f"{qf.query.rich_name} -> {qf.hits} file{s}.")
+ esg.ui.print(
+ f"{qf.query.rich_name} -> {qf.hits} file{s} (before replica de-duplication)."
+ )
  total_hits = sum([qf.hits for qf in qfs])
  if total_hits == 0:
  esg.ui.print("No files found.")
  esg.ui.raise_maybe_record(Exit(0))
- else:
+ elif len(qfs) > 1:
  esg.ui.print(f"{total_hits} files found.")
  qfs = [qf for qf in qfs if not qf.skip]
  # Prepare optimally distributed requests to ESGF
@@ -101,6 +111,12 @@ def update(
  # It might be interesting for the special case where all files already
  # exist in db, then the detailed fetch could be skipped.
  for qf in qfs:
+ qf_dataset_results = esg.context.prepare_search(
+ qf.expanded,
+ file=False,
+ hits=[qf.dataset_hits],
+ max_hits=None,
+ )
  if esg.config.api.use_custom_distribution_algorithm:
  qf_results = esg.context.prepare_search_distributed(
  qf.expanded,
@@ -115,7 +131,7 @@ def update(
  hits=[qf.hits],
  max_hits=None,
  )
- nb_req = len(qf_results)
+ nb_req = len(qf_dataset_results) + len(qf_results)
  if nb_req > 50:
  msg = (
  f"{nb_req} requests will be sent to ESGF to"
@@ -126,13 +142,32 @@ def update(
  esg.ui.print(f"{qf.query.rich_name} is now untracked.")
  qf.query.tracked = False
  qf_results = []
+ qf_dataset_results = []
  case "n":
  qf_results = []
+ qf_dataset_results = []
  case _:
  ...
  qf.results = qf_results
+ qf.dataset_results = qf_dataset_results
  # Fetch files and update db
  # [?] TODO: dry_run to print urls here
+ with esg.ui.spinner("Fetching datasets"):
+ coros = []
+ for qf in qfs:
+ coro = esg.context._datasets(
+ *qf.dataset_results, keep_duplicates=False
+ )
+ coros.append(coro)
+ datasets = esg.context.sync_gather(*coros)
+ for qf, qf_datasets in zip(qfs, datasets):
+ qf.datasets = [
+ Dataset(
+ dataset_id=record.dataset_id,
+ total_files=record.number_of_files,
+ )
+ for record in qf_datasets
+ ]
  with esg.ui.spinner("Fetching files"):
  coros = []
  for qf in qfs:
@@ -142,23 +177,37 @@ def update(
  for qf, qf_files in zip(qfs, files):
  qf.files = qf_files
  for qf in qfs:
- shas = {f.sha for f in qf.query.files}
- new_files: list[File] = []
- for file in qf.files:
- if file.sha not in shas:
- new_files.append(file)
- nb_files = len(new_files)
  if not qf.query.tracked:
  esg.db.add(qf.query)
  continue
- elif nb_files == 0:
- esg.ui.print(f"{qf.query.rich_name} is already up-to-date.")
- continue
- size = sum([file.size for file in new_files])
+ with esg.db.commit_context():
+ unregistered_datasets = [
+ f for f in qf.datasets if f not in esg.db
+ ]
+ if len(unregistered_datasets) > 0:
+ esg.ui.print(
+ f"Adding {len(unregistered_datasets)} new datasets to database."
+ )
+ esg.db.session.add_all(unregistered_datasets)
+ files_from_db = [
+ esg.db.get(File, f.sha) for f in qf.files if f in esg.db
+ ]
+ registered_files = [f for f in files_from_db if f is not None]
+ unregistered_files = [f for f in qf.files if f not in esg.db]
+ if len(unregistered_files) > 0:
+ esg.ui.print(
+ f"Adding {len(unregistered_files)} new files to database."
+ )
+ esg.db.session.add_all(unregistered_files)
+ files = registered_files + unregistered_files
+ not_done_files = [
+ file for file in files if file.status != FileStatus.Done
+ ]
+ download_size = sum(file.size for file in not_done_files)
  msg = (
- f"\nUpdating {qf.query.rich_name} with {nb_files}"
- f" new files ({format_size(size)})."
- "\nSend to download queue?"
+ f"\n{qf.query.rich_name}: {len(not_done_files)} "
+ f" files ({format_size(download_size)}) to download."
+ f"\nLink to query and send to download queue?"
  )
  if yes:
  choice = "y"
@@ -174,25 +223,19 @@ def update(
  if choice == "y":
  legacy = esg.legacy_query
  has_legacy = legacy.state.persistent
+ applied_changes = False
  with esg.db.commit_context():
  for file in esg.ui.track(
- new_files,
- description=qf.query.rich_name,
+ files,
+ description=f"{qf.query.rich_name}",
  ):
- file_db = esg.db.get(File, file.sha)
- if file_db is None:
- if esg.db.has_file_id(file):
- logger.error(
- "File id already exists in database, "
- "there might be an error with its checksum"
- f"\n{file}"
- )
- continue
+ if file.status != FileStatus.Done:
  file.status = FileStatus.Queued
- esg.db.session.add(file)
- elif has_legacy and legacy in file_db.queries:
- esg.db.unlink(query=legacy, file=file_db)
- esg.db.link(query=qf.query, file=file)
- qf.query.updated_at = datetime.now(timezone.utc)
+ if has_legacy and legacy in file.queries:
+ _ = esg.db.unlink(query=legacy, file=file)
+ changed = esg.db.link(query=qf.query, file=file)
+ applied_changes = applied_changes or changed
+ if applied_changes:
+ qf.query.updated_at = datetime.now(timezone.utc)
  esg.db.session.add(qf.query)
  esg.ui.raise_maybe_record(Exit(0))
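
Note on the `update` changes above: dataset records are now fetched alongside file records and both are registered in the database before linking, and `updated_at` is only bumped when `Database.link` reports that a new query/file link was actually created. A minimal sketch of the registration step (it restates the block added inside `for qf in qfs:`; `esg` and `qf` are assumed to be set up as in the command):

    # Sketch only: mirrors the registration logic added in update().
    with esg.db.commit_context():
        # Dataset rows were not registered by update() before; only unknown ones are inserted.
        esg.db.session.add_all([d for d in qf.datasets if d not in esg.db])
        # Files already in the db are re-read so linking acts on persistent rows.
        known = [esg.db.get(File, f.sha) for f in qf.files if f in esg.db]
        new = [f for f in qf.files if f not in esg.db]
        esg.db.session.add_all(new)
        files = [f for f in known if f is not None] + new
        to_download = [f for f in files if f.status != FileStatus.Done]
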
esgpull/cli/utils.py CHANGED
@@ -103,7 +103,7 @@ def totable(docs: list[OrderedDict[str, Any]]) -> Table:
  table = Table(box=MINIMAL_DOUBLE_HEAD, show_edge=False)
  for key in docs[0].keys():
  justify: Literal["left", "right", "center"]
- if key in ["file", "dataset"]:
+ if key in ["file", "dataset", "plugin"]:
  justify = "left"
  else:
  justify = "right"
@@ -243,3 +243,18 @@ def get_queries(
  kids = graph.get_all_children(query.sha)
  queries.extend(kids)
  return queries
+
+
+ def extract_subdict(doc: dict, key: str | None) -> dict:
+ if key is None:
+ return doc
+ for part in key.split("."):
+ if not part:
+ raise KeyError(key)
+ elif part in doc:
+ doc = doc[part]
+ else:
+ raise KeyError(part)
+ for part in key.split(".")[::-1]:
+ doc = {part: doc}
+ return doc
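
The new `extract_subdict` helper keeps only the subtree addressed by a dotted key and re-nests it under the same path. A quick usage sketch (the dictionary values are illustrative):

    extract_subdict({"api": {"page_limit": 50, "http_timeout": 20}}, "api.page_limit")
    # -> {"api": {"page_limit": 50}}
    extract_subdict({"api": {"page_limit": 50}}, None)
    # -> returned unchanged
    extract_subdict({"api": {}}, "api.missing")
    # -> raises KeyError("missing")
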
esgpull/config.py CHANGED
@@ -4,10 +4,10 @@ import logging
  from collections.abc import Iterator, Mapping
  from enum import Enum, auto
  from pathlib import Path
- from typing import Any, cast
+ from typing import Any, TypeVar, Union, cast, overload

  import tomlkit
- from attrs import Factory, define, field, fields
+ from attrs import Factory, define, field
  from attrs import has as attrs_has
  from cattrs import Converter
  from cattrs.gen import make_dict_unstructure_fn, override
@@ -20,6 +20,66 @@ from esgpull.models.options import Options

  logger = logging.getLogger("esgpull")

+ T = TypeVar("T")
+
+
+ @overload
+ def cast_value(
+ target: str, value: Union[str, int, bool, float], key: str
+ ) -> str: ...
+
+
+ @overload
+ def cast_value(
+ target: bool, value: Union[str, int, bool, float], key: str
+ ) -> bool: ...
+
+
+ @overload
+ def cast_value(
+ target: int, value: Union[str, int, bool, float], key: str
+ ) -> int: ...
+
+
+ @overload
+ def cast_value(
+ target: float, value: Union[str, int, bool, float], key: str
+ ) -> float: ...
+
+
+ def cast_value(
+ target: Any, value: Union[str, int, bool, float], key: str
+ ) -> Any:
+ if isinstance(value, type(target)):
+ return value
+ elif attrs_has(type(target)):
+ raise KeyError(key)
+ elif isinstance(target, str):
+ return str(value)
+ elif isinstance(target, float):
+ try:
+ return float(value)
+ except Exception:
+ raise ValueError(value)
+ elif isinstance(target, bool):
+ if isinstance(value, str):
+ if value.lower() in ["on", "true"]:
+ return True
+ elif value.lower() in ["off", "false"]:
+ return False
+ else:
+ raise ValueError(value)
+ else:
+ raise TypeError(value)
+ elif isinstance(target, int):
+ # int must be after bool, because isinstance(True, int) == True
+ try:
+ return int(value)
+ except Exception:
+ raise ValueError(value)
+ else:
+ raise TypeError(value)
+

  @define
  class Paths:
@@ -28,6 +88,7 @@ class Paths:
  db: Path = field(converter=Path)
  log: Path = field(converter=Path)
  tmp: Path = field(converter=Path)
+ plugins: Path = field(converter=Path)

  @auth.default
  def _auth_factory(self) -> Path:
@@ -69,12 +130,21 @@ class Paths:
  root = InstallConfig.default
  return root / "tmp"

- def __iter__(self) -> Iterator[Path]:
+ @plugins.default
+ def _plugins_factory(self) -> Path:
+ if InstallConfig.current is not None:
+ root = InstallConfig.current.path
+ else:
+ root = InstallConfig.default
+ return root / "plugins"
+
+ def values(self) -> Iterator[Path]:
  yield self.auth
  yield self.data
  yield self.db
  yield self.log
  yield self.tmp
+ yield self.plugins


  @define
@@ -213,6 +283,13 @@ def iter_keys(
  yield local_path


+ @define
+ class Plugins:
+ """Configuration for the plugin system"""
+
+ enabled: bool = False
+
+
  @define
  class Config:
  paths: Paths = Factory(Paths)
@@ -221,6 +298,7 @@ class Config:
  db: Db = Factory(Db)
  download: Download = Factory(Download)
  api: API = Factory(API)
+ plugins: Plugins = Factory(Plugins)
  _raw: TOMLDocument | None = field(init=False, default=None)
  _config_file: Path | None = field(init=False, default=None)

@@ -287,7 +365,7 @@ class Config:
  def update_item(
  self,
  key: str,
- value: int | str,
+ value: str | int | bool,
  empty_ok: bool = False,
  ) -> int | str | None:
  if self._raw is None and empty_ok:
@@ -302,29 +380,8 @@ class Config:
  doc.setdefault(part, {})
  doc = doc[part]
  obj = getattr(obj, part)
- value_type = getattr(fields(type(obj)), last).type
  old_value = getattr(obj, last)
- if attrs_has(value_type):
- raise KeyError(key)
- elif value_type is str:
- ...
- elif value_type is int:
- try:
- value = value_type(value)
- except Exception:
- ...
- elif value_type is bool:
- if isinstance(value, bool):
- ...
- elif isinstance(value, str):
- if value.lower() in ["on", "true"]:
- value = True
- elif value.lower() in ["off", "false"]:
- value = False
- else:
- raise ValueError(value)
- else:
- raise TypeError(value)
+ value = cast_value(old_value, value, key)
  setattr(obj, last, value)
  doc[last] = value
  return old_value
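
`Config.update_item` now delegates type coercion to the new module-level `cast_value`, which casts the incoming value to the type of the current setting instead of the previous inline if/elif chain. Expected behaviour, following the code above (the config keys are illustrative):

    cast_value(True, "off", "plugins.enabled")             # -> False ("on"/"true"/"off"/"false" accepted)
    cast_value(5, "8", "download.chunk_size")               # -> 8 (int target, str value)
    cast_value(5, "not a number", "download.chunk_size")    # -> raises ValueError
    cast_value("default", 42, "api.index_node")             # -> "42" (str target)
    cast_value(Paths(), "x", "paths")                       # -> raises KeyError (nested sections are not assignable)
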
esgpull/constants.py CHANGED
@@ -1,6 +1,9 @@
+ import os
+
  CONFIG_FILENAME = "config.toml"
  INSTALLS_PATH_ENV = "ESGPULL_INSTALLS_PATH"
  ROOT_ENV = "ESGPULL_CURRENT"
+ ESGPULL_DEBUG = os.environ.get("ESGPULL_DEBUG", "0") == "1"

  IDP = "/esgf-idp/openid/"
  CEDA_IDP = "/OpenID/Provider/server/"
esgpull/context.py CHANGED
@@ -16,7 +16,7 @@ from rich.pretty import pretty_repr

  from esgpull.config import Config
  from esgpull.exceptions import SolrUnstableQueryError
- from esgpull.models import Dataset, File, Query
+ from esgpull.models import DatasetRecord, File, Query
  from esgpull.tui import logger
  from esgpull.utils import format_date_iso, index2url, sync

@@ -151,7 +151,7 @@ class ResultHints(Result):

  @dataclass
  class ResultSearch(Result):
- data: Sequence[File | Dataset] = field(init=False, repr=False)
+ data: Sequence[File | DatasetRecord] = field(init=False, repr=False)

  def process(self) -> None:
  raise NotImplementedError
@@ -159,14 +159,14 @@ class ResultSearch(Result):

  @dataclass
  class ResultDatasets(Result):
- data: Sequence[Dataset] = field(init=False, repr=False)
+ data: Sequence[DatasetRecord] = field(init=False, repr=False)

  def process(self) -> None:
  self.data = []
  if self.success:
  for doc in self.json["response"]["docs"]:
  try:
- dataset = Dataset.serialize(doc)
+ dataset = DatasetRecord.serialize(doc)
  self.data.append(dataset)
  except KeyError as exc:
  logger.exception(exc)
@@ -492,8 +492,8 @@ class Context:
  self,
  *results: ResultSearch,
  keep_duplicates: bool,
- ) -> list[Dataset]:
- datasets: list[Dataset] = []
+ ) -> list[DatasetRecord]:
+ datasets: list[DatasetRecord] = []
  ids: set[str] = set()
  async for result in self._fetch(*results):
  dataset_result = result.to(ResultDatasets)
@@ -501,7 +501,7 @@ class Context:
  if dataset_result.processed:
  for d in dataset_result.data:
  if not keep_duplicates and d.dataset_id in ids:
- logger.warning(f"Duplicate dataset {d.dataset_id}")
+ logger.debug(f"Duplicate dataset {d.dataset_id}")
  else:
  datasets.append(d)
  ids.add(d.dataset_id)
@@ -520,7 +520,7 @@ class Context:
  if files_result.processed:
  for file in files_result.data:
  if not keep_duplicates and file.sha in shas:
- logger.warning(f"Duplicate file {file.file_id}")
+ logger.debug(f"Duplicate file {file.file_id}")
  else:
  files.append(file)
  shas.add(file.sha)
@@ -627,7 +627,7 @@ class Context:
  date_from: datetime | None = None,
  date_to: datetime | None = None,
  keep_duplicates: bool = True,
- ) -> list[Dataset]:
+ ) -> list[DatasetRecord]:
  if hits is None:
  hits = self.hits(*queries, file=False)
  results = self.prepare_search(
esgpull/database.py CHANGED
@@ -12,11 +12,13 @@ import sqlalchemy.orm
  from alembic.config import Config as AlembicConfig
  from alembic.migration import MigrationContext
  from alembic.script import ScriptDirectory
+ from sqlalchemy.inspection import inspect
  from sqlalchemy.orm import Session, joinedload, make_transient

  from esgpull import __file__
  from esgpull.config import Config
  from esgpull.models import File, Query, Table, sql
+ from esgpull.models.base import Base, BaseNoSHA
  from esgpull.version import __version__

  # from esgpull.exceptions import NoClauseError
@@ -145,14 +147,26 @@ class Database:
  for item in items:
  make_transient(item)

- def link(self, query: Query, file: File):
- self.session.execute(sql.query_file.link(query, file))
-
- def unlink(self, query: Query, file: File):
- self.session.execute(sql.query_file.unlink(query, file))
+ def link(self, query: Query, file: File) -> bool:
+ if not self.session.scalar(sql.query_file.is_linked(query, file)):
+ self.session.execute(sql.query_file.link(query, file))
+ return True
+ else:
+ return False

- def __contains__(self, item: Table) -> bool:
- return self.scalars(sql.count(item))[0] > 0
+ def unlink(self, query: Query, file: File) -> bool:
+ if self.session.scalar(sql.query_file.is_linked(query, file)):
+ self.session.execute(sql.query_file.unlink(query, file))
+ return True
+ else:
+ return False
+
+ def __contains__(self, item: Base | BaseNoSHA) -> bool:
+ mapper = inspect(item.__class__)
+ pk_col = mapper.primary_key[0]
+ pk_value = getattr(item, pk_col.name)
+ stmt = sa.exists().where(pk_col == pk_value)
+ return self.scalars(sa.select(stmt))[0]

  def has_file_id(self, file: File) -> bool:
  return len(self.scalars(sql.file.with_file_id(file.file_id))) == 1
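
`Database.__contains__` no longer counts rows through `sql.count`: it inspects the item's SQLAlchemy mapper, takes the first primary-key column and issues an EXISTS query, so membership tests work for any mapped model (files, queries, and the new dataset rows). Together with `link`/`unlink` now reporting whether anything changed, callers can write idempotent updates; a minimal sketch, assuming `db` is a `Database` instance and `file`/`query` are mapped objects:

    if file not in db:                          # EXISTS on the primary key (File.sha)
        db.session.add(file)
    changed = db.link(query=query, file=file)   # False when the link already existed
    if changed:
        query.updated_at = datetime.now(timezone.utc)
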
esgpull/download.py CHANGED
@@ -1,6 +1,7 @@
  # from math import ceil
  from collections.abc import AsyncGenerator
  from dataclasses import dataclass
+ from datetime import datetime

  from httpx import AsyncClient

@@ -19,6 +20,7 @@ class DownloadCtx:
  completed: int = 0
  chunk: bytes | None = None
  digest: Digest | None = None
+ start_time: datetime | None = None

  @property
  def finished(self) -> bool:
@@ -54,6 +56,7 @@ class Simple(BaseDownloader):
  ctx: DownloadCtx,
  chunk_size: int,
  ) -> AsyncGenerator[DownloadCtx, None]:
+ ctx.start_time = datetime.now()
  async with client.stream("GET", ctx.file.url) as resp:
  resp.raise_for_status()
  async for chunk in resp.aiter_bytes(chunk_size=chunk_size):
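
`DownloadCtx.start_time` is stamped when streaming begins and is later forwarded to the `file_complete` plugin event. A consumer could derive elapsed time and throughput from it, for instance (sketch, assuming `ctx` is a finished `DownloadCtx`):

    if ctx.start_time is not None:
        elapsed = (datetime.now() - ctx.start_time).total_seconds()
        rate = ctx.completed / elapsed if elapsed > 0 else 0.0  # bytes per second
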
esgpull/esgpull.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
  import logging
  from collections.abc import AsyncIterator
  from dataclasses import dataclass
+ from datetime import datetime
  from functools import cached_property, partial
  from pathlib import Path
  from warnings import warn
@@ -25,6 +26,7 @@ from esgpull.auth import Auth, Credentials
  from esgpull.config import Config
  from esgpull.context import Context
  from esgpull.database import Database
+ from esgpull.download import DownloadCtx
  from esgpull.exceptions import (
  DownloadCancelled,
  InvalidInstallPath,
@@ -44,6 +46,13 @@ from esgpull.models import (
  sql,
  )
  from esgpull.models.utils import short_sha
+ from esgpull.plugin import (
+ Event,
+ PluginManager,
+ emit,
+ get_plugin_manager,
+ set_plugin_manager,
+ )
  from esgpull.processor import Processor
  from esgpull.result import Err, Ok, Result
  from esgpull.tui import UI, DummyLive, Verbosity, logger
@@ -117,6 +126,18 @@ class Esgpull:
  if load_db:
  self.db = Database.from_config(self.config)
  self.graph = Graph(self.db)
+ # Initialize plugin system
+ plugin_config_path = self.config.paths.plugins / "plugins.toml"
+ try:
+ self.plugin_manager = get_plugin_manager()
+ self.plugin_manager.__init__(config_path=plugin_config_path)
+ except ValueError:
+ self.plugin_manager = PluginManager(config_path=plugin_config_path)
+ set_plugin_manager(self.plugin_manager)
+ if self.config.plugins.enabled:
+ self.plugin_manager.enabled = True
+ self.config.paths.plugins.mkdir(exist_ok=True, parents=True)
+ self.plugin_manager.discover_plugins(self.config.paths.plugins)

  def fetch_index_nodes(self) -> list[str]:
  """
@@ -309,7 +330,7 @@ class Esgpull:
  progress: Progress,
  task_ids: dict[str, TaskID],
  live: Live | DummyLive,
- ) -> AsyncIterator[Result]:
+ ) -> AsyncIterator[Result[DownloadCtx]]:
  async for result in processor.process():
  task_idx = progress.task_ids.index(task_ids[result.data.file.sha])
  task = progress.tasks[task_idx]
@@ -348,7 +369,7 @@ class Esgpull:
  yield result
  case Err(_, err):
  progress.remove_task(task.id)
- yield Err(result.data, err)
+ yield Err(result.data, err=err)
  case Err():
  progress.remove_task(task.id)
  yield result
@@ -438,15 +459,38 @@ class Esgpull:
  match result:
  case Ok():
  main_progress.update(main_task_id, advance=1)
- result.data.file.status = FileStatus.Done
- files.append(result.data.file)
- case Err():
+ file = result.data.file
+ file.status = FileStatus.Done
+ files.append(file)
+ emit(
+ Event.file_complete,
+ file=file,
+ destination=self.fs[file].drs,
+ start_time=result.data.start_time,
+ end_time=datetime.now(),
+ )
+ if file.dataset is not None:
+ is_dataset_complete = self.db.scalars(
+ sql.dataset.is_complete(file.dataset)
+ )[0]
+ if is_dataset_complete:
+ emit(
+ Event.dataset_complete,
+ dataset=file.dataset,
+ )
+ case Err(_, err):
  queue_size -= 1
  main_progress.update(
  main_task_id, total=queue_size
  )
  result.data.file.status = FileStatus.Error
  errors.append(result)
+ emit(
+ Event.file_error,
+ file=result.data.file,
+ exception=err,
+ )
+
  if use_db:
  self.db.add(result.data.file)
  remaining_dict.pop(result.data.file.sha, None)
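
The download loop now emits three plugin events: `Event.file_complete` (with `file`, `destination`, `start_time`, `end_time`), `Event.dataset_complete` (with `dataset`) and `Event.file_error` (with `file`, `exception`). The handler registration API is not part of this diff; the sketch below only illustrates the payload shape a handler would receive, and the `on` decorator is a hypothetical name, not a documented esgpull API:

    # Hypothetical plugin module placed under config.paths.plugins.
    # Only the event payloads are taken from the diff; `on` is assumed.
    from esgpull.plugin import Event, on

    @on(Event.file_complete)
    def report(file, destination, start_time, end_time):
        print(f"{file.file_id} -> {destination} ({end_time - start_time})")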