esgpull-0.6.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80):
  1. esgpull/__init__.py +12 -0
  2. esgpull/auth.py +181 -0
  3. esgpull/cli/__init__.py +73 -0
  4. esgpull/cli/add.py +103 -0
  5. esgpull/cli/autoremove.py +38 -0
  6. esgpull/cli/config.py +116 -0
  7. esgpull/cli/convert.py +285 -0
  8. esgpull/cli/decorators.py +342 -0
  9. esgpull/cli/download.py +74 -0
  10. esgpull/cli/facet.py +23 -0
  11. esgpull/cli/get.py +28 -0
  12. esgpull/cli/install.py +85 -0
  13. esgpull/cli/link.py +105 -0
  14. esgpull/cli/login.py +56 -0
  15. esgpull/cli/remove.py +73 -0
  16. esgpull/cli/retry.py +43 -0
  17. esgpull/cli/search.py +201 -0
  18. esgpull/cli/self.py +238 -0
  19. esgpull/cli/show.py +66 -0
  20. esgpull/cli/status.py +67 -0
  21. esgpull/cli/track.py +87 -0
  22. esgpull/cli/update.py +184 -0
  23. esgpull/cli/utils.py +247 -0
  24. esgpull/config.py +410 -0
  25. esgpull/constants.py +56 -0
  26. esgpull/context.py +724 -0
  27. esgpull/database.py +161 -0
  28. esgpull/download.py +162 -0
  29. esgpull/esgpull.py +447 -0
  30. esgpull/exceptions.py +167 -0
  31. esgpull/fs.py +253 -0
  32. esgpull/graph.py +460 -0
  33. esgpull/install_config.py +185 -0
  34. esgpull/migrations/README +1 -0
  35. esgpull/migrations/env.py +82 -0
  36. esgpull/migrations/script.py.mako +24 -0
  37. esgpull/migrations/versions/0.3.0_update_tables.py +170 -0
  38. esgpull/migrations/versions/0.3.1_update_tables.py +25 -0
  39. esgpull/migrations/versions/0.3.2_update_tables.py +26 -0
  40. esgpull/migrations/versions/0.3.3_update_tables.py +25 -0
  41. esgpull/migrations/versions/0.3.4_update_tables.py +25 -0
  42. esgpull/migrations/versions/0.3.5_update_tables.py +25 -0
  43. esgpull/migrations/versions/0.3.6_update_tables.py +26 -0
  44. esgpull/migrations/versions/0.3.7_update_tables.py +26 -0
  45. esgpull/migrations/versions/0.3.8_update_tables.py +26 -0
  46. esgpull/migrations/versions/0.4.0_update_tables.py +25 -0
  47. esgpull/migrations/versions/0.5.0_update_tables.py +26 -0
  48. esgpull/migrations/versions/0.5.1_update_tables.py +26 -0
  49. esgpull/migrations/versions/0.5.2_update_tables.py +25 -0
  50. esgpull/migrations/versions/0.5.3_update_tables.py +26 -0
  51. esgpull/migrations/versions/0.5.4_update_tables.py +25 -0
  52. esgpull/migrations/versions/0.5.5_update_tables.py +25 -0
  53. esgpull/migrations/versions/0.6.0_update_tables.py +25 -0
  54. esgpull/migrations/versions/0.6.1_update_tables.py +25 -0
  55. esgpull/migrations/versions/0.6.2_update_tables.py +25 -0
  56. esgpull/migrations/versions/0.6.3_update_tables.py +25 -0
  57. esgpull/models/__init__.py +31 -0
  58. esgpull/models/base.py +50 -0
  59. esgpull/models/dataset.py +34 -0
  60. esgpull/models/facet.py +18 -0
  61. esgpull/models/file.py +65 -0
  62. esgpull/models/options.py +164 -0
  63. esgpull/models/query.py +481 -0
  64. esgpull/models/selection.py +201 -0
  65. esgpull/models/sql.py +258 -0
  66. esgpull/models/synda_file.py +85 -0
  67. esgpull/models/tag.py +19 -0
  68. esgpull/models/utils.py +54 -0
  69. esgpull/presets.py +13 -0
  70. esgpull/processor.py +172 -0
  71. esgpull/py.typed +0 -0
  72. esgpull/result.py +53 -0
  73. esgpull/tui.py +346 -0
  74. esgpull/utils.py +54 -0
  75. esgpull/version.py +1 -0
  76. esgpull-0.6.3.dist-info/METADATA +110 -0
  77. esgpull-0.6.3.dist-info/RECORD +80 -0
  78. esgpull-0.6.3.dist-info/WHEEL +4 -0
  79. esgpull-0.6.3.dist-info/entry_points.txt +3 -0
  80. esgpull-0.6.3.dist-info/licenses/LICENSE +28 -0
esgpull/models/sql.py ADDED
@@ -0,0 +1,258 @@
1
+ import functools
2
+
3
+ import sqlalchemy as sa
4
+
5
+ from esgpull.models import Table
6
+ from esgpull.models.facet import Facet
7
+ from esgpull.models.file import FileStatus
8
+ from esgpull.models.query import File, Query, query_file_proxy, query_tag_proxy
9
+ from esgpull.models.selection import Selection, selection_facet_proxy
10
+ from esgpull.models.synda_file import SyndaFile
11
+ from esgpull.models.tag import Tag
12
+
13
+
14
def count(item: Table) -> sa.Select[tuple[int]]:
    """Build a statement counting the rows of ``item``'s table with ``item``'s sha.

    A non-zero count means a row with that sha is already stored.
    """
    mapped_cls = item.__class__
    counter = sa.func.count("*")
    stmt = sa.select(counter).select_from(mapped_cls)
    return stmt.filter_by(sha=item.sha)
21
+
22
+
23
def count_table(table: type[Table]) -> sa.Select[tuple[int]]:
    """Build a statement counting every row of ``table``."""
    counter = sa.func.count("*")
    stmt = sa.select(counter)
    return stmt.select_from(table)
25
+
26
+
27
class facet:
    """Prebuilt SELECT statements over the ``Facet`` table."""

    @staticmethod
    @functools.cache
    def all() -> sa.Select[tuple[Facet]]:
        """Select every facet."""
        stmt = sa.select(Facet)
        return stmt

    @staticmethod
    @functools.cache
    def shas() -> sa.Select[tuple[str]]:
        """Select the sha of every facet."""
        stmt = sa.select(Facet.sha)
        return stmt

    @staticmethod
    @functools.cache
    def name_count() -> sa.Select[tuple[str, int]]:
        """Select each facet name together with its number of facet rows."""
        n_rows = sa.func.count("*")
        stmt = sa.select(Facet.name, n_rows)
        return stmt.group_by(Facet.name)

    @staticmethod
    @functools.cache
    def usage() -> sa.Select[tuple[Facet, int]]:
        """Select each facet with the number of selection links it has."""
        n_links = sa.func.count("*")
        stmt = sa.select(Facet, n_links).join(selection_facet_proxy)
        return stmt.group_by(Facet.sha)

    @staticmethod
    def known_shas(shas: list[str]) -> sa.Select[tuple[str]]:
        """Select which of ``shas`` already exist in the facet table."""
        # Not cached: a list argument is unhashable for functools.cache.
        wanted = Facet.sha.in_(shas)
        return sa.select(Facet.sha).where(wanted)

    @staticmethod
    @functools.cache
    def names() -> sa.Select[tuple[str]]:
        """Select the distinct facet names."""
        stmt = sa.select(Facet.name)
        return stmt.distinct()

    @staticmethod
    def values(name: str) -> sa.Select[tuple[str]]:
        """Select the values of every facet called ``name``."""
        same_name = Facet.name == name
        return sa.select(Facet.value).where(same_name)
64
+
65
+
66
class file:
    """Prebuilt SELECT statements over the ``File`` table."""

    @staticmethod
    @functools.cache
    def all() -> sa.Select[tuple[File]]:
        """Select every file."""
        return sa.select(File)

    @staticmethod
    @functools.cache
    def shas() -> sa.Select[tuple[str]]:
        """Select the sha of every file."""
        return sa.select(File.sha)

    @staticmethod
    @functools.cache
    def orphans() -> sa.Select[tuple[File]]:
        """Select files that have no row in ``query_file_proxy``.

        ``filter_by`` applies to the last joined entity (the proxy table), so
        ``file_sha=None`` keeps the files the outer join could not match.
        """
        return (
            sa.select(File)
            .outerjoin(query_file_proxy)
            .filter_by(file_sha=None)
        )

    @staticmethod
    @functools.cache
    def linked() -> sa.Select[tuple[str]]:
        """Select the distinct shas of files linked to at least one query."""
        # This selects file shas (strings), not File rows.
        return sa.select(query_file_proxy.c.file_sha).distinct()

    # master_id values shared by more than one file row.
    __dups_cte: sa.CTE = (
        sa.select(File.master_id)
        .group_by(File.master_id)
        .having(sa.func.count("*") > 1)
        .cte()
    )

    @staticmethod
    @functools.cache
    def duplicates() -> sa.Select[tuple[File]]:
        """Select every file whose master_id appears on more than one row."""
        return sa.select(File).join(
            file.__dups_cte,
            File.master_id == file.__dups_cte.c.master_id,
        )

    @staticmethod
    def shas_from_query(query_sha: str) -> sa.Select[tuple[str]]:
        """Select the shas of files linked to the query ``query_sha``."""
        return sa.select(query_file_proxy.c.file_sha).filter_by(
            query_sha=query_sha
        )

    @staticmethod
    def with_status(*status: FileStatus) -> sa.Select[tuple[File]]:
        """Select files whose status is one of ``status``."""
        return sa.select(File).where(File.status.in_(status))

    @staticmethod
    def with_file_id(file_id: str) -> sa.Select[tuple[str]]:
        """Select the sha of at most one file with the given ``file_id``."""
        return sa.select(File.sha).where(File.file_id == file_id).limit(1)

    @staticmethod
    def total_size_with_status(
        *status: FileStatus,
        query_sha: str | None = None,
    ) -> sa.Select[tuple[int]]:
        """Select the summed size of files whose status is one of ``status``.

        When ``query_sha`` is given, only files linked to that query count.

        This is re-implemented in Query.files_count_size because
        of cyclic import between query.py and the current file.
        """
        # The status filter belongs to the SELECT statement; SQL function
        # elements such as ``sum(...)`` have no ``.where()`` method.
        stmt = sa.select(sa.func.sum(File.size)).where(
            File.status.in_(status)
        )
        if query_sha is not None:
            stmt = stmt.join_from(query_file_proxy, File).where(
                query_file_proxy.c.query_sha == query_sha
            )
        return stmt

    @staticmethod
    @functools.cache
    def status_count_size(
        all_: bool = False,
    ) -> sa.Select[tuple[FileStatus, int, int]]:
        """Select, per status, the number of files and their total size.

        Files with status ``Done`` are excluded unless ``all_`` is true.
        """
        stmt = sa.select(
            File.status,
            sa.func.count("*"),
            sa.func.sum(File.size),
        ).group_by(File.status)
        if not all_:
            stmt = stmt.where(File.status != FileStatus.Done)
        return stmt
149
+
150
+
151
class query:
    """Prebuilt SELECT statements over the ``Query`` table."""

    @staticmethod
    @functools.cache
    def all() -> sa.Select[tuple[Query]]:
        """Select every query."""
        stmt = sa.select(Query)
        return stmt

    @staticmethod
    @functools.cache
    def shas() -> sa.Select[tuple[str]]:
        """Select the sha of every query."""
        stmt = sa.select(Query.sha)
        return stmt

    # One (tag name, query sha) row per tag/query link.
    __tag_query_cte: sa.CTE = (
        sa.select(Tag.name, query_tag_proxy.c.query_sha)
        .join_from(query_tag_proxy, Tag)
        .cte("tag_query_cte")
    )

    # Links whose tag name is attached to exactly one query.
    __name_cte: sa.CTE = (
        sa.select(__tag_query_cte)
        .group_by(__tag_query_cte.c.name)
        .having(sa.func.count("*") == 1)
        .cte("name_cte")
    )

    # Links whose query carries exactly one tag.
    __sha_cte: sa.CTE = (
        sa.select(__tag_query_cte)
        .group_by(__tag_query_cte.c.query_sha)
        .having(sa.func.count("*") == 1)
        .cte("sha_cte")
    )

    @staticmethod
    @functools.cache
    def name_sha() -> sa.Select[tuple[str, str]]:
        """Select (tag name, query sha) pairs where tag and query match 1:1."""
        by_name = query.__name_cte
        by_sha = query.__sha_cte
        same_query = by_name.c.query_sha == by_sha.c.query_sha
        return sa.select(by_name).join(by_sha, same_query)

    @staticmethod
    def with_shas(*shas: str) -> sa.Select[tuple[Query]]:
        """Select the queries whose sha is among ``shas``.

        Raises:
            ValueError: if no sha is given.
        """
        if len(shas) == 0:
            raise ValueError(shas)
        wanted = Query.sha.in_(shas)
        return sa.select(Query).where(wanted)

    @staticmethod
    def with_tag(tag: str) -> sa.Select[tuple[Query]]:
        """Select the queries linked to a tag named ``tag``."""
        matches_tag = Tag.name == tag
        stmt = sa.select(Query).join_from(query_tag_proxy, Tag)
        stmt = stmt.join_from(query_tag_proxy, Query)
        return stmt.where(matches_tag)

    @staticmethod
    def children(sha: str) -> sa.Select[tuple[Query]]:
        """Select the queries whose ``require`` is ``sha``."""
        requires_parent = Query.require == sha
        return sa.select(Query).where(requires_parent)
208
+
209
+
210
class selection:
    """Prebuilt SELECT statements over the ``Selection`` table."""

    @staticmethod
    @functools.cache
    def all() -> sa.Select[tuple[Selection]]:
        """Select every selection."""
        return sa.select(Selection)

    @staticmethod
    @functools.cache
    def orphans() -> sa.Select[tuple[Selection]]:
        """Select selections for which the outer join finds no ``Query`` row."""
        return (
            sa.select(Selection)
            .outerjoin(Query)
            # Explicit IS NULL instead of `== None` (needed no lint suppression).
            .where(Query.sha.is_(None))
        )
224
+
225
+
226
class tag:
    """Prebuilt SELECT statements over the ``Tag`` table."""

    @staticmethod
    @functools.cache
    def all() -> sa.Select[tuple[Tag]]:
        """Select every tag."""
        stmt = sa.select(Tag)
        return stmt

    @staticmethod
    @functools.cache
    def shas() -> sa.Select[tuple[str]]:
        """Select the sha of every tag."""
        stmt = sa.select(Tag.sha)
        return stmt

    @staticmethod
    @functools.cache
    def orphans() -> sa.Select[tuple[Tag]]:
        """Select tags that have no row in ``query_tag_proxy``."""
        joined = sa.select(Tag).outerjoin(query_tag_proxy)
        # filter_by targets the last joined entity, i.e. the proxy table.
        return joined.filter_by(tag_sha=None)
243
+
244
+
245
class synda_file:
    """Prebuilt SELECT statements over synda's ``file`` table."""

    @staticmethod
    @functools.cache
    def all() -> sa.Select[tuple[SyndaFile]]:
        """Select every synda file row."""
        stmt = sa.select(SyndaFile)
        return stmt

    @staticmethod
    @functools.cache
    def ids() -> sa.Select[tuple[int]]:
        """Select the id of every synda file row."""
        stmt = sa.select(SyndaFile.file_id)
        return stmt

    @staticmethod
    def with_ids(*ids: int) -> sa.Select[tuple[SyndaFile]]:
        """Select the synda file rows whose id is among ``ids``."""
        wanted = SyndaFile.file_id.in_(ids)
        stmt = sa.select(SyndaFile)
        return stmt.where(wanted)
@@ -0,0 +1,85 @@
1
+ from sqlalchemy.orm import (
2
+ DeclarativeBase,
3
+ Mapped,
4
+ MappedAsDataclass,
5
+ mapped_column,
6
+ )
7
+
8
+ from esgpull.models.file import FileStatus
9
+ from esgpull.models.query import File
10
+
11
# Synda status strings (lowercase) mapped to esgpull file statuses.
# SyndaFile.get_status consults this only when the status is not itself a
# FileStatus value.
SyndaStatusMap = dict(
    running=FileStatus.Started,
    waiting=FileStatus.Queued,
)
15
+
16
+
17
class SyndaBase(MappedAsDataclass, DeclarativeBase):
    """Declarative dataclass base for models mapped onto synda's database."""
19
+
20
+
21
class SyndaFile(SyndaBase):
    """A row of synda's ``file`` table, convertible to an esgpull ``File``."""

    __tablename__ = "file"

    url: Mapped[str]
    file_functional_id: Mapped[str]
    filename: Mapped[str]
    local_path: Mapped[str]
    data_node: Mapped[str]
    checksum: Mapped[str]
    checksum_type: Mapped[str]
    duration: Mapped[int]
    size: Mapped[int]
    rate: Mapped[int]
    start_date: Mapped[str]
    end_date: Mapped[str]
    crea_date: Mapped[str]
    status: Mapped[str]
    error_msg: Mapped[str]
    sdget_status: Mapped[str]
    sdget_error_msg: Mapped[str]
    priority: Mapped[int]
    tracking_id: Mapped[str]
    model: Mapped[str]
    project: Mapped[str]
    variable: Mapped[str]
    last_access_date: Mapped[str]
    dataset_id: Mapped[int]
    insertion_group_id: Mapped[int]
    timestamp: Mapped[str]
    file_id: Mapped[int] = mapped_column(init=False, primary_key=True)

    def get_status(self) -> FileStatus:
        """Translate synda's status string into a ``FileStatus``.

        The lowercased status is used directly if it is a ``FileStatus``
        value, otherwise it is looked up in ``SyndaStatusMap``.

        Raises:
            ValueError: if the status matches neither.
        """
        s = self.status.lower()
        result: FileStatus
        if FileStatus.contains(s):
            result = FileStatus(s)
        elif s in SyndaStatusMap:
            result = SyndaStatusMap[s]
        else:
            raise ValueError(s)
        return result

    def to_file(self) -> File:
        """Build an esgpull ``File`` (with its sha computed) from this row.

        The code expects ``file_functional_id`` to be
        ``<dataset master>.<version>.<filename>``.
        """
        file_id = self.file_functional_id
        dataset_id = file_id.removesuffix(self.filename).strip(".")
        dataset_master, version = dataset_id.rsplit(".", 1)
        master_id = ".".join([dataset_master, self.filename])
        url = self.url
        # Upgrade only the URL scheme; "http://" appearing elsewhere in the
        # URL (e.g. inside a query string) must not be rewritten.
        if url.startswith("http://"):
            url = "https://" + url.removeprefix("http://")
        local_path = self.local_path.removesuffix(self.filename).strip("/")
        result = File(
            file_id=file_id,
            dataset_id=dataset_id,
            master_id=master_id,
            url=url,
            version=version,
            filename=self.filename,
            local_path=local_path,
            data_node=self.data_node,
            checksum=self.checksum,
            checksum_type=self.checksum_type.upper(),
            size=self.size,
            status=self.get_status(),
        )
        result.compute_sha()
        return result
esgpull/models/tag.py ADDED
@@ -0,0 +1,19 @@
1
+ from __future__ import annotations
2
+
3
+ import sqlalchemy as sa
4
+ from sqlalchemy.orm import Mapped, mapped_column
5
+
6
+ from esgpull.models.base import Base
7
+
8
+
9
class Tag(Base):
    """A named label (with optional description), hashed by its name."""

    __tablename__ = "tag"

    name: Mapped[str] = mapped_column(sa.String(255))
    description: Mapped[str | None] = mapped_column(sa.Text, default=None)

    def _as_bytes(self) -> bytes:
        """Return the tag's name encoded as UTF-8 bytes."""
        return bytes(self.name, "utf-8")

    def __hash__(self) -> int:
        payload = self._as_bytes()
        return hash(payload)
@@ -0,0 +1,54 @@
1
+ from rich.console import Console, ConsoleOptions
2
+ from rich.measure import Measurement, measure_renderables
3
+
4
+
5
def short_sha(sha: str) -> str:
    """Abbreviate ``sha`` to its first 6 characters, wrapped in angle brackets."""
    prefix = sha[:6]
    return "<{}>".format(prefix)
7
+
8
+
9
def rich_measure_impl(
    self,
    console: Console,
    options: ConsoleOptions,
) -> Measurement:
    """Measure ``self`` by measuring everything its ``__rich_console__`` yields.

    Takes ``self`` so it can presumably be attached to a class as its
    ``__rich_measure__`` method.
    """
    rendered = [*self.__rich_console__(console, options)]
    return measure_renderables(console, options, rendered)
16
+
17
+
18
+ def find_str(container: list | str) -> str:
19
+ if isinstance(container, list):
20
+ return find_str(container[0])
21
+ elif isinstance(container, str):
22
+ return container
23
+ else:
24
+ raise ValueError(container)
25
+
26
+
27
+ def find_int(container: list | int) -> int:
28
+ if isinstance(container, list):
29
+ return find_int(container[0])
30
+ elif isinstance(container, int):
31
+ return container
32
+ else:
33
+ raise ValueError(container)
34
+
35
+
36
def get_local_path(source: dict, version: str) -> str:
    """Render a file's local directory path from its metadata.

    ``source`` is flattened (single-element lists become their element).
    Its ``directory_format_template_`` (a ``%(name)s/...`` pattern) has the
    ``%(root)s/`` prefix removed and is filled with the flattened fields plus
    ``version``.
    """
    flat_raw = {
        key: (value[0] if isinstance(value, list) and len(value) == 1 else value)
        for key, value in source.items()
    }
    raw_template = find_str(flat_raw["directory_format_template_"])
    # "%(a)s/%(b)s/..." -> "{a}/{b}/..." so that str.format can fill it in.
    template = (
        raw_template.removeprefix("%(root)s/")
        .replace("%(", "{")
        .replace(")s", "}")
    )
    # ``version`` is passed explicitly below; drop any metadata copy so the
    # keyword is not supplied twice.
    flat_raw.pop("version", None)
    if "rcm_name" in flat_raw:
        # CORDEX special case: rcm_model is "<institute>-<rcm_name>".
        flat_raw["rcm_model"] = flat_raw["institute"] + "-" + flat_raw["rcm_name"]
    return template.format(version=version, **flat_raw)
esgpull/presets.py ADDED
@@ -0,0 +1,13 @@
1
+ # from esgpull.models import Query
2
+
3
+
4
+ # IPCC_SCENARIOS = Query(
5
+ # select=dict(
6
+ # query=(
7
+ # "experiment:(rcp26 rcp45 rcp60 rcp85) OR "
8
+ # "experiment_id:(ssp119 ssp126 ssp245 ssp370 ssp460 ssp585)"
9
+ # )
10
+ # )
11
+ # )
12
+
13
+ # TEMPERATURE = Query(select=dict(variable=["tas", "tos", "tasmin", "tasmax"]))
esgpull/processor.py ADDED
@@ -0,0 +1,172 @@
1
+ import asyncio
2
+ import ssl
3
+ from collections.abc import AsyncIterator
4
+ from functools import partial
5
+ from typing import TypeAlias
6
+
7
+ from aiostream.stream import merge
8
+ from httpx import AsyncClient, HTTPError
9
+
10
+ from esgpull.auth import Auth
11
+ from esgpull.config import Config
12
+ from esgpull.download import DownloadCtx, Simple
13
+ from esgpull.exceptions import DownloadSizeError
14
+ from esgpull.fs import Digest, Filesystem
15
+ from esgpull.models import File
16
+ from esgpull.result import Err, Ok, Result
17
+ from esgpull.tui import logger
18
+
19
# Type of the callbacks a Task runs when its download starts.
# A broader alternative was considered: Callable[[], None] | partial[None]
Callback: TypeAlias = partial[None]

# Module-wide SSL setting shared by every Task. It holds this placeholder
# until load_default_ssl_context() replaces it (with an SSLContext or True)
# and sets default_ssl_context_loaded.
default_ssl_context: ssl.SSLContext | bool = False
default_ssl_context_loaded = False
24
+
25
+
26
def load_default_ssl_context() -> str:
    """Initialise the module-wide default SSL setting used by download tasks.

    With OpenSSL 3 or newer, a default verifying ``SSLContext`` is created
    with legacy server connect re-enabled. Otherwise ``True`` is stored so
    the HTTP client falls back to its own default verification. In both
    cases ``default_ssl_context_loaded`` is set.

    Returns:
        A short message describing which OpenSSL branch was taken.
    """
    global default_ssl_context
    global default_ssl_context_loaded
    if ssl.OPENSSL_VERSION_INFO[0] >= 3:
        default_ssl_context = ssl.create_default_context()
        # OpenSSL 3 no longer sets SSL_OP_LEGACY_SERVER_CONNECT by default,
        # which refuses servers lacking RFC 5746 secure renegotiation.
        # ssl.OP_LEGACY_SERVER_CONNECT exists only on Python >= 3.12, so fall
        # back to the raw OpenSSL flag value (0x4) on older interpreters.
        legacy_server_connect = getattr(ssl, "OP_LEGACY_SERVER_CONNECT", 0x4)
        default_ssl_context.options |= legacy_server_connect
        msg = "Using openssl 3 or higher"
    else:
        default_ssl_context = True
        msg = "Using openssl 1"
    default_ssl_context_loaded = True
    return msg
38
+
39
+
40
class Task:
    """Download of a single file, streamed chunk by chunk into the filesystem."""

    def __init__(
        self,
        config: Config,
        auth: Auth,
        fs: Filesystem,
        file: File,
        start_callbacks: list[Callback] | None = None,
    ) -> None:
        self.config = config
        self.auth = auth
        self.fs = fs
        self.ctx = DownloadCtx(file)
        if not self.config.download.disable_checksum:
            self.ctx.digest = Digest(file)
        self.downloader = Simple()
        # Load the shared SSL setting on first use. Its message is logged only
        # when SSL verification is actually enabled for this task.
        msg: str | None = None
        if not default_ssl_context_loaded:
            msg = load_default_ssl_context()
        self.ssl_context: ssl.SSLContext | bool
        if self.config.download.disable_ssl:
            self.ssl_context = False
        else:
            if msg is not None:
                logger.info(msg)
            self.ssl_context = default_ssl_context
        if start_callbacks is None:
            self.start_callbacks = []
        else:
            self.start_callbacks = start_callbacks

    @property
    def file(self) -> File:
        """The file being downloaded (held by the download context)."""
        return self.ctx.file

    async def stream(
        self, semaphore: asyncio.Semaphore
    ) -> AsyncIterator[Result]:
        """Download the file, yielding ``Ok``/``Err`` results.

        ``semaphore`` is held for the whole download. The start callbacks run
        once the destination file is open and the HTTP client exists. Chunks
        are written as they arrive.
        - When the context reports an error, a ``DownloadSizeError`` is
          yielded as ``Err`` and the download stops.
        - When it reports completion, the file is finalised with
          ``to_done()`` and ``Ok`` is yielded.
        HTTP, SSL and size errors raised along the way are yielded as ``Err``.
        """
        ctx = self.ctx
        try:
            async with (
                semaphore,
                self.fs.open(ctx.file) as file_obj,
                AsyncClient(
                    follow_redirects=True,
                    cert=self.auth.cert,
                    verify=self.ssl_context,
                    timeout=self.config.download.http_timeout,
                ) as client,
            ):
                for callback in self.start_callbacks:
                    callback()
                stream = self.downloader.stream(
                    client,
                    ctx,
                    self.config.download.chunk_size,
                )
                async for ctx in stream:
                    if ctx.chunk is not None:
                        await file_obj.write(ctx.chunk)
                    if ctx.error:
                        err = DownloadSizeError(ctx.completed, ctx.file.size)
                        yield Err(ctx, err)
                        await stream.aclose()
                        break
                    elif ctx.finished:
                        await file_obj.to_done()
                        yield Ok(ctx)
        except (
            HTTPError,
            DownloadSizeError,
            ssl.SSLError,
        ) as err:
            # GeneratorExit is deliberately not caught. It means the consumer
            # closed this generator, and yielding after it would make Python
            # raise RuntimeError("async generator ignored GeneratorExit").
            yield Err(ctx, err)
136
+
137
+
138
class Processor:
    """Run one download ``Task`` per file and merge their result streams."""

    def __init__(
        self,
        config: Config,
        auth: Auth,
        fs: Filesystem,
        files: list[File],
        start_callbacks: dict[str, list[Callback]],
    ) -> None:
        self.config = config
        self.fs = fs
        # Subset of ``files`` not yet present at their DRS path.
        self.files = [item for item in files if self.should_download(item)]
        # NOTE(review): tasks are built from every input file, not only from
        # the filtered ``self.files`` above. Confirm whether files already on
        # disk are meant to be downloaded again here.
        self.tasks = [
            Task(
                config=config,
                auth=auth,
                fs=fs,
                file=item,
                start_callbacks=start_callbacks[item.sha],
            )
            for item in files
        ]

    def should_download(self, file: File) -> bool:
        """Return True unless the file already exists at its DRS path."""
        already_there = self.fs[file].drs.is_file()
        return not already_there

    async def process(self) -> AsyncIterator[Result]:
        """Yield the results of all tasks in the order they arrive.

        All tasks share one semaphore sized by
        ``config.download.max_concurrent``.
        """
        limiter = asyncio.Semaphore(self.config.download.max_concurrent)
        merged = merge(*[task.stream(limiter) for task in self.tasks])
        async with merged.stream() as results:
            async for result in results:
                yield result
esgpull/py.typed ADDED
File without changes
esgpull/result.py ADDED
@@ -0,0 +1,53 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Generic, TypeVar
3
+
4
+ T = TypeVar("T")
5
+
6
+
7
@dataclass
class Result(Generic[T]):
    """Outcome of an operation: its payload plus success/error information.

    The ``Ok`` and ``Err`` subclasses fill in ``ok`` (and, for ``Ok``,
    ``err``) automatically.
    """

    data: T  # payload carried with the outcome (e.g. a download context)
    ok: bool  # True for success, False for failure
    err: BaseException | None  # the exception on failure, None on success
12
+
13
+
14
@dataclass
class Ok(Result[T]):
    """Successful ``Result``; construct as ``Ok(data)``."""

    ok: bool = field(default=True, init=False)  # always True; not an __init__ argument
    err: None = field(default=None, init=False)  # always None; not an __init__ argument
18
+
19
+
20
@dataclass
class Err(Result[T]):
    """Failed ``Result``; construct as ``Err(data, err)``."""

    ok: bool = field(default=False, init=False)  # always False; not an __init__ argument
    # Required exception describing the failure. A field without a default may
    # follow ``ok`` because ``ok`` is excluded from the generated __init__.
    err: BaseException = field()
24
+
25
+
26
+ # class Result:
27
+ # ok: bool
28
+ # file: File
29
+ # completed: int
30
+ # err: BaseException | None
31
+
32
+
33
+ # class Ok(Result):
34
+ # ok = True
35
+ # file: File
36
+ # completed: int
37
+ # err = None
38
+
39
+ # def __init__(self, file: File, completed: int) -> None:
40
+ # self.file = file
41
+ # self.completed = completed
42
+
43
+
44
+ # class Err(Result):
45
+ # ok = False
46
+ # file: File
47
+ # completed: int
48
+ # err: BaseException
49
+
50
+ # def __init__(self, file: File, completed: int, err: BaseException) -> None:
51
+ # self.file = file
52
+ # self.completed = completed
53
+ # self.err = err