esgpull 0.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- esgpull/__init__.py +12 -0
- esgpull/auth.py +181 -0
- esgpull/cli/__init__.py +73 -0
- esgpull/cli/add.py +103 -0
- esgpull/cli/autoremove.py +38 -0
- esgpull/cli/config.py +116 -0
- esgpull/cli/convert.py +285 -0
- esgpull/cli/decorators.py +342 -0
- esgpull/cli/download.py +74 -0
- esgpull/cli/facet.py +23 -0
- esgpull/cli/get.py +28 -0
- esgpull/cli/install.py +85 -0
- esgpull/cli/link.py +105 -0
- esgpull/cli/login.py +56 -0
- esgpull/cli/remove.py +73 -0
- esgpull/cli/retry.py +43 -0
- esgpull/cli/search.py +201 -0
- esgpull/cli/self.py +238 -0
- esgpull/cli/show.py +66 -0
- esgpull/cli/status.py +67 -0
- esgpull/cli/track.py +87 -0
- esgpull/cli/update.py +184 -0
- esgpull/cli/utils.py +247 -0
- esgpull/config.py +410 -0
- esgpull/constants.py +56 -0
- esgpull/context.py +724 -0
- esgpull/database.py +161 -0
- esgpull/download.py +162 -0
- esgpull/esgpull.py +447 -0
- esgpull/exceptions.py +167 -0
- esgpull/fs.py +253 -0
- esgpull/graph.py +460 -0
- esgpull/install_config.py +185 -0
- esgpull/migrations/README +1 -0
- esgpull/migrations/env.py +82 -0
- esgpull/migrations/script.py.mako +24 -0
- esgpull/migrations/versions/0.3.0_update_tables.py +170 -0
- esgpull/migrations/versions/0.3.1_update_tables.py +25 -0
- esgpull/migrations/versions/0.3.2_update_tables.py +26 -0
- esgpull/migrations/versions/0.3.3_update_tables.py +25 -0
- esgpull/migrations/versions/0.3.4_update_tables.py +25 -0
- esgpull/migrations/versions/0.3.5_update_tables.py +25 -0
- esgpull/migrations/versions/0.3.6_update_tables.py +26 -0
- esgpull/migrations/versions/0.3.7_update_tables.py +26 -0
- esgpull/migrations/versions/0.3.8_update_tables.py +26 -0
- esgpull/migrations/versions/0.4.0_update_tables.py +25 -0
- esgpull/migrations/versions/0.5.0_update_tables.py +26 -0
- esgpull/migrations/versions/0.5.1_update_tables.py +26 -0
- esgpull/migrations/versions/0.5.2_update_tables.py +25 -0
- esgpull/migrations/versions/0.5.3_update_tables.py +26 -0
- esgpull/migrations/versions/0.5.4_update_tables.py +25 -0
- esgpull/migrations/versions/0.5.5_update_tables.py +25 -0
- esgpull/migrations/versions/0.6.0_update_tables.py +25 -0
- esgpull/migrations/versions/0.6.1_update_tables.py +25 -0
- esgpull/migrations/versions/0.6.2_update_tables.py +25 -0
- esgpull/migrations/versions/0.6.3_update_tables.py +25 -0
- esgpull/models/__init__.py +31 -0
- esgpull/models/base.py +50 -0
- esgpull/models/dataset.py +34 -0
- esgpull/models/facet.py +18 -0
- esgpull/models/file.py +65 -0
- esgpull/models/options.py +164 -0
- esgpull/models/query.py +481 -0
- esgpull/models/selection.py +201 -0
- esgpull/models/sql.py +258 -0
- esgpull/models/synda_file.py +85 -0
- esgpull/models/tag.py +19 -0
- esgpull/models/utils.py +54 -0
- esgpull/presets.py +13 -0
- esgpull/processor.py +172 -0
- esgpull/py.typed +0 -0
- esgpull/result.py +53 -0
- esgpull/tui.py +346 -0
- esgpull/utils.py +54 -0
- esgpull/version.py +1 -0
- esgpull-0.6.3.dist-info/METADATA +110 -0
- esgpull-0.6.3.dist-info/RECORD +80 -0
- esgpull-0.6.3.dist-info/WHEEL +4 -0
- esgpull-0.6.3.dist-info/entry_points.txt +3 -0
- esgpull-0.6.3.dist-info/licenses/LICENSE +28 -0
esgpull/models/sql.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
|
|
3
|
+
import sqlalchemy as sa
|
|
4
|
+
|
|
5
|
+
from esgpull.models import Table
|
|
6
|
+
from esgpull.models.facet import Facet
|
|
7
|
+
from esgpull.models.file import FileStatus
|
|
8
|
+
from esgpull.models.query import File, Query, query_file_proxy, query_tag_proxy
|
|
9
|
+
from esgpull.models.selection import Selection, selection_facet_proxy
|
|
10
|
+
from esgpull.models.synda_file import SyndaFile
|
|
11
|
+
from esgpull.models.tag import Tag
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def count(item: Table) -> sa.Select[tuple[int]]:
    """Build a statement counting rows that share ``item``'s sha.

    The rows are counted in the table mapped by ``item``'s own class and
    are matched with ``sha == item.sha``.
    """
    model = item.__class__
    counting = sa.select(sa.func.count("*")).select_from(model)
    return counting.filter_by(sha=item.sha)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def count_table(table: type[Table]) -> sa.Select[tuple[int]]:
    """Build a statement counting every row of ``table``."""
    row_count = sa.func.count("*")
    return sa.select(row_count).select_from(table)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class facet:
    """Namespace of ``SELECT`` statement builders for the ``Facet`` model.

    Builders that take no argument are wrapped in ``functools.cache`` so the
    statement object is only constructed on the first call.
    """

    @staticmethod
    @functools.cache
    def all() -> sa.Select[tuple[Facet]]:
        """Select every ``Facet``."""
        stmt = sa.select(Facet)
        return stmt

    @staticmethod
    @functools.cache
    def shas() -> sa.Select[tuple[str]]:
        """Select the ``sha`` column of every ``Facet``."""
        stmt = sa.select(Facet.sha)
        return stmt

    @staticmethod
    @functools.cache
    def name_count() -> sa.Select[tuple[str, int]]:
        """Select each facet name with the number of rows carrying it."""
        per_name = sa.select(Facet.name, sa.func.count("*"))
        return per_name.group_by(Facet.name)

    @staticmethod
    @functools.cache
    def usage() -> sa.Select[tuple[Facet, int]]:
        """Select each ``Facet`` joined to ``selection_facet_proxy`` with a
        row count, grouped by ``Facet.sha``."""
        counted = sa.select(Facet, sa.func.count("*"))
        return counted.join(selection_facet_proxy).group_by(Facet.sha)

    @staticmethod
    def known_shas(shas: list[str]) -> sa.Select[tuple[str]]:
        """Select the facet shas that are members of ``shas``."""
        is_known = Facet.sha.in_(shas)
        return sa.select(Facet.sha).where(is_known)

    @staticmethod
    @functools.cache
    def names() -> sa.Select[tuple[str]]:
        """Select the distinct facet names."""
        stmt = sa.select(Facet.name)
        return stmt.distinct()

    @staticmethod
    def values(name: str) -> sa.Select[tuple[str]]:
        """Select the ``value`` of every facet whose name equals ``name``."""
        has_name = Facet.name == name
        return sa.select(Facet.value).where(has_name)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class file:
    """Namespace of ``SELECT`` statement builders for the ``File`` model.

    Builders that take no argument (and ``status_count_size``) are wrapped
    in ``functools.cache`` so each distinct statement is built only once.
    """

    @staticmethod
    @functools.cache
    def all() -> sa.Select[tuple[File]]:
        """Select every ``File``."""
        return sa.select(File)

    @staticmethod
    @functools.cache
    def shas() -> sa.Select[tuple[str]]:
        """Select the ``sha`` column of every ``File``."""
        return sa.select(File.sha)

    @staticmethod
    @functools.cache
    def orphans() -> sa.Select[tuple[File]]:
        """Select files with no matching row in ``query_file_proxy``."""
        # The outer join leaves ``file_sha`` NULL for files that are not
        # referenced by the association table.
        return (
            sa.select(File)
            .outerjoin(query_file_proxy)
            .filter_by(file_sha=None)
        )

    @staticmethod
    @functools.cache
    def linked() -> sa.Select[tuple[str]]:
        """Select the distinct ``file_sha`` values of ``query_file_proxy``.

        The statement yields sha strings, not ``File`` objects, which is
        what the return annotation now reflects.
        """
        return sa.select(query_file_proxy.c.file_sha).distinct()

    # ``master_id`` values that are carried by more than one file row.
    __dups_cte: sa.CTE = (
        sa.select(File.master_id)
        .group_by(File.master_id)
        .having(sa.func.count("*") > 1)
        .cte()
    )

    @staticmethod
    @functools.cache
    def duplicates() -> sa.Select[tuple[File]]:
        """Select files whose ``master_id`` is shared with another file."""
        return sa.select(File).join(
            file.__dups_cte,
            File.master_id == file.__dups_cte.c.master_id,
        )

    @staticmethod
    def shas_from_query(query_sha: str) -> sa.Select[tuple[str]]:
        """Select the file shas associated with the query ``query_sha``."""
        return sa.select(query_file_proxy.c.file_sha).filter_by(
            query_sha=query_sha
        )

    @staticmethod
    def with_status(*status: FileStatus) -> sa.Select[tuple[File]]:
        """Select files whose status is one of ``status``."""
        return sa.select(File).where(File.status.in_(status))

    @staticmethod
    def with_file_id(file_id: str) -> sa.Select[tuple[str]]:
        """Select the sha of at most one file whose ``file_id`` matches."""
        return sa.select(File.sha).where(File.file_id == file_id).limit(1)

    @staticmethod
    def total_size_with_status(
        *status: FileStatus,
        query_sha: str | None = None,
    ) -> sa.Select[tuple[int]]:
        """Select the summed ``size`` of files having one of ``status``.

        When ``query_sha`` is given, only files associated with that query
        through ``query_file_proxy`` are summed.

        This is re-implemented in Query.files_count_size because
        of cyclic import between query.py and the current file.
        """
        # ``where`` belongs to the SELECT statement: a SQL function element
        # such as ``func.sum(...)`` has no ``where`` method, and the
        # ``join_from`` below requires a ``Select``.
        stmt = sa.select(sa.func.sum(File.size)).where(
            File.status.in_(status)
        )
        if query_sha is not None:
            stmt = stmt.join_from(query_file_proxy, File).where(
                query_file_proxy.c.query_sha == query_sha
            )
        return stmt

    @staticmethod
    @functools.cache
    def status_count_size(
        all_: bool = False,
    ) -> sa.Select[tuple[FileStatus, int, int]]:
        """Select, per status, the number of files and their summed size.

        Files with status ``FileStatus.Done`` are left out unless ``all_``
        is true.
        """
        stmt = sa.select(
            File.status,
            sa.func.count("*"),
            sa.func.sum(File.size),
        ).group_by(File.status)
        if not all_:
            stmt = stmt.where(File.status != FileStatus.Done)
        return stmt
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class query:
    """Namespace of ``SELECT`` statement builders for the ``Query`` model.

    Builders that take no argument are wrapped in ``functools.cache`` so the
    statement object is only constructed on the first call.
    """

    @staticmethod
    @functools.cache
    def all() -> sa.Select[tuple[Query]]:
        """Select every ``Query``."""
        stmt = sa.select(Query)
        return stmt

    @staticmethod
    @functools.cache
    def shas() -> sa.Select[tuple[str]]:
        """Select the ``sha`` column of every ``Query``."""
        stmt = sa.select(Query.sha)
        return stmt

    # One (tag name, query sha) row per ``query_tag_proxy`` association.
    __tag_query_cte: sa.CTE = (
        sa.select(Tag.name, query_tag_proxy.c.query_sha)
        .join_from(query_tag_proxy, Tag)
        .cte("tag_query_cte")
    )

    # Rows of ``tag_query_cte`` whose tag name occurs exactly once.
    __name_cte: sa.CTE = (
        sa.select(__tag_query_cte)
        .group_by(__tag_query_cte.c.name)
        .having(sa.func.count("*") == 1)
        .cte("name_cte")
    )

    # Rows of ``tag_query_cte`` whose query sha occurs exactly once.
    __sha_cte: sa.CTE = (
        sa.select(__tag_query_cte)
        .group_by(__tag_query_cte.c.query_sha)
        .having(sa.func.count("*") == 1)
        .cte("sha_cte")
    )

    @staticmethod
    @functools.cache
    def name_sha() -> sa.Select[tuple[str, str]]:
        """Select ``name_cte`` rows joined to ``sha_cte`` on ``query_sha``."""
        by_name = query.__name_cte
        by_sha = query.__sha_cte
        same_query = by_name.c.query_sha == by_sha.c.query_sha
        return sa.select(by_name).join(by_sha, same_query)

    @staticmethod
    def with_shas(*shas: str) -> sa.Select[tuple[Query]]:
        """Select the queries whose sha is one of ``shas``.

        Raises:
            ValueError: if no sha is given.
        """
        if not shas:
            raise ValueError(shas)
        wanted = Query.sha.in_(shas)
        return sa.select(Query).where(wanted)

    @staticmethod
    def with_tag(tag: str) -> sa.Select[tuple[Query]]:
        """Select the queries associated with the tag named ``tag``."""
        tagged = sa.select(Query).join_from(query_tag_proxy, Tag)
        tagged = tagged.join_from(query_tag_proxy, Query)
        return tagged.where(Tag.name == tag)

    @staticmethod
    def children(sha: str) -> sa.Select[tuple[Query]]:
        """Select the queries whose ``require`` column equals ``sha``."""
        requires_sha = Query.require == sha
        return sa.select(Query).where(requires_sha)
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
class selection:
    """Namespace of ``SELECT`` statement builders for ``Selection``."""

    @staticmethod
    @functools.cache
    def all() -> sa.Select[tuple[Selection]]:
        """Select every ``Selection``."""
        stmt = sa.select(Selection)
        return stmt

    @staticmethod
    @functools.cache
    def orphans() -> sa.Select[tuple[Selection]]:
        """Select the selections that no ``Query`` row joins to."""
        # After the outer join, unmatched selections have a NULL query sha.
        joined = sa.select(Selection).outerjoin(Query)
        return joined.where(Query.sha.is_(None))
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
class tag:
    """Namespace of ``SELECT`` statement builders for the ``Tag`` model."""

    @staticmethod
    @functools.cache
    def all() -> sa.Select[tuple[Tag]]:
        """Select every ``Tag``."""
        stmt = sa.select(Tag)
        return stmt

    @staticmethod
    @functools.cache
    def shas() -> sa.Select[tuple[str]]:
        """Select the ``sha`` column of every ``Tag``."""
        stmt = sa.select(Tag.sha)
        return stmt

    @staticmethod
    @functools.cache
    def orphans() -> sa.Select[tuple[Tag]]:
        """Select tags with no matching row in ``query_tag_proxy``."""
        # The outer join leaves ``tag_sha`` NULL for unreferenced tags.
        joined = sa.select(Tag).outerjoin(query_tag_proxy)
        return joined.filter_by(tag_sha=None)
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
class synda_file:
    """Namespace of ``SELECT`` statement builders for ``SyndaFile``."""

    @staticmethod
    @functools.cache
    def all() -> sa.Select[tuple[SyndaFile]]:
        """Select every ``SyndaFile``."""
        stmt = sa.select(SyndaFile)
        return stmt

    @staticmethod
    @functools.cache
    def ids() -> sa.Select[tuple[int]]:
        """Select the ``file_id`` column of every ``SyndaFile``."""
        stmt = sa.select(SyndaFile.file_id)
        return stmt

    @staticmethod
    def with_ids(*ids: int) -> sa.Select[tuple[SyndaFile]]:
        """Select the ``SyndaFile`` rows whose ``file_id`` is in ``ids``."""
        wanted = SyndaFile.file_id.in_(ids)
        return sa.select(SyndaFile).where(wanted)
|
esgpull/models/synda_file.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
from sqlalchemy.orm import (
|
|
2
|
+
DeclarativeBase,
|
|
3
|
+
Mapped,
|
|
4
|
+
MappedAsDataclass,
|
|
5
|
+
mapped_column,
|
|
6
|
+
)
|
|
7
|
+
|
|
8
|
+
from esgpull.models.file import FileStatus
|
|
9
|
+
from esgpull.models.query import File
|
|
10
|
+
|
|
11
|
+
# Status names with no direct ``FileStatus`` value, mapped to the member to
# use instead; ``SyndaFile.get_status`` consults this as a fallback.
SyndaStatusMap: dict[str, FileStatus] = dict(
    running=FileStatus.Started,
    waiting=FileStatus.Queued,
)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class SyndaBase(MappedAsDataclass, DeclarativeBase):
    """Dataclass-style declarative base for the models of this module."""
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class SyndaFile(SyndaBase):
    """Mapped row of the ``file`` table, convertible to an esgpull ``File``."""

    __tablename__ = "file"

    # The declaration order below is significant: ``SyndaBase`` maps this
    # class as a dataclass, so it fixes the generated ``__init__`` signature.
    url: Mapped[str]
    file_functional_id: Mapped[str]
    filename: Mapped[str]
    local_path: Mapped[str]
    data_node: Mapped[str]
    checksum: Mapped[str]
    checksum_type: Mapped[str]
    duration: Mapped[int]
    size: Mapped[int]
    rate: Mapped[int]
    start_date: Mapped[str]
    end_date: Mapped[str]
    crea_date: Mapped[str]
    status: Mapped[str]
    error_msg: Mapped[str]
    sdget_status: Mapped[str]
    sdget_error_msg: Mapped[str]
    priority: Mapped[int]
    tracking_id: Mapped[str]
    model: Mapped[str]
    project: Mapped[str]
    variable: Mapped[str]
    last_access_date: Mapped[str]
    dataset_id: Mapped[int]
    insertion_group_id: Mapped[int]
    timestamp: Mapped[str]
    # Primary key; excluded from the generated ``__init__``.
    file_id: Mapped[int] = mapped_column(init=False, primary_key=True)

    def get_status(self) -> FileStatus:
        """Translate the ``status`` column into a ``FileStatus``.

        The lower-cased status is used directly when ``FileStatus.contains``
        accepts it, otherwise it is looked up in ``SyndaStatusMap``.

        Raises:
            ValueError: if the status is known to neither.
        """
        name = self.status.lower()
        if FileStatus.contains(name):
            return FileStatus(name)
        if name in SyndaStatusMap:
            return SyndaStatusMap[name]
        raise ValueError(name)

    def to_file(self) -> File:
        """Build a ``File`` from this row and compute its sha.

        The dataset id is ``file_functional_id`` without the trailing
        filename; its last dot-separated component is taken as the version
        and the remainder, joined with the filename, as the master id.
        """
        functional_id = self.file_functional_id
        # Drop the filename suffix, then any dots left at either end.
        dataset_id = functional_id.removesuffix(self.filename).strip(".")
        dataset_master, version = dataset_id.rsplit(".", 1)
        # Directory part only, without leading/trailing slashes.
        directory = self.local_path.removesuffix(self.filename).strip("/")
        converted = File(
            file_id=functional_id,
            dataset_id=dataset_id,
            master_id=".".join([dataset_master, self.filename]),
            url=self.url.replace("http://", "https://"),
            version=version,
            filename=self.filename,
            local_path=directory,
            data_node=self.data_node,
            checksum=self.checksum,
            checksum_type=self.checksum_type.upper(),
            size=self.size,
            status=self.get_status(),
        )
        converted.compute_sha()
        return converted
|
esgpull/models/tag.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import sqlalchemy as sa
|
|
4
|
+
from sqlalchemy.orm import Mapped, mapped_column
|
|
5
|
+
|
|
6
|
+
from esgpull.models.base import Base
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Tag(Base):
    """Mapped row of the ``tag`` table: a name and an optional description."""

    __tablename__ = "tag"

    name: Mapped[str] = mapped_column(sa.String(255))
    description: Mapped[str | None] = mapped_column(sa.Text, default=None)

    def _as_bytes(self) -> bytes:
        """Return the tag name encoded as UTF-8 bytes."""
        return self.name.encode("utf-8")

    def __hash__(self) -> int:
        """Hash the tag on its encoded name only."""
        encoded_name = self._as_bytes()
        return hash(encoded_name)
|
esgpull/models/utils.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
from rich.console import Console, ConsoleOptions
|
|
2
|
+
from rich.measure import Measurement, measure_renderables
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def short_sha(sha: str) -> str:
    """Return the first six characters of ``sha`` wrapped in ``<`` ``>``."""
    prefix = sha[:6]
    return f"<{prefix}>"
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def rich_measure_impl(
    self,
    console: Console,
    options: ConsoleOptions,
) -> Measurement:
    """Measure ``self`` from the renderables its ``__rich_console__`` yields.

    Written with a ``self`` parameter so it can be used as a method
    implementation on any object defining ``__rich_console__``.
    """
    rendered = [*self.__rich_console__(console, options)]
    return measure_renderables(console, options, rendered)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def find_str(container: list | str) -> str:
|
|
19
|
+
if isinstance(container, list):
|
|
20
|
+
return find_str(container[0])
|
|
21
|
+
elif isinstance(container, str):
|
|
22
|
+
return container
|
|
23
|
+
else:
|
|
24
|
+
raise ValueError(container)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def find_int(container: list | int) -> int:
|
|
28
|
+
if isinstance(container, list):
|
|
29
|
+
return find_int(container[0])
|
|
30
|
+
elif isinstance(container, int):
|
|
31
|
+
return container
|
|
32
|
+
else:
|
|
33
|
+
raise ValueError(container)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def get_local_path(source: dict, version: str) -> str:
    """Format the ``directory_format_template_`` of ``source`` into a path.

    Single-element list values of ``source`` are unwrapped first. The
    template uses ``%(name)s`` placeholders, which are rewritten to
    ``{name}`` and filled from the flattened values; ``version`` always
    comes from the argument, never from ``source``.
    """
    flat = {
        key: (
            value[0] if isinstance(value, list) and len(value) == 1 else value
        )
        for key, value in source.items()
    }
    # format: "%(a)/%(b)/%(c)/..."
    template = (
        find_str(flat["directory_format_template_"])
        .removeprefix("%(root)s/")
        .replace("%(", "{")
        .replace(")s", "}")
    )
    # Avoid passing ``version`` twice to ``str.format`` below.
    flat.pop("version", None)
    if "rcm_name" in flat:  # cordex special case
        flat["rcm_model"] = flat["institute"] + "-" + flat["rcm_name"]
    return template.format(version=version, **flat)
|
esgpull/presets.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# from esgpull.models import Query
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
# IPCC_SCENARIOS = Query(
|
|
5
|
+
# select=dict(
|
|
6
|
+
# query=(
|
|
7
|
+
# "experiment:(rcp26 rcp45 rcp60 rcp85) OR "
|
|
8
|
+
# "experiment_id:(ssp119 ssp126 ssp245 ssp370 ssp460 ssp585)"
|
|
9
|
+
# )
|
|
10
|
+
# )
|
|
11
|
+
# )
|
|
12
|
+
|
|
13
|
+
# TEMPERATURE = Query(select=dict(variable=["tas", "tos", "tasmin", "tasmax"]))
|
esgpull/processor.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import ssl
|
|
3
|
+
from collections.abc import AsyncIterator
|
|
4
|
+
from functools import partial
|
|
5
|
+
from typing import TypeAlias
|
|
6
|
+
|
|
7
|
+
from aiostream.stream import merge
|
|
8
|
+
from httpx import AsyncClient, HTTPError
|
|
9
|
+
|
|
10
|
+
from esgpull.auth import Auth
|
|
11
|
+
from esgpull.config import Config
|
|
12
|
+
from esgpull.download import DownloadCtx, Simple
|
|
13
|
+
from esgpull.exceptions import DownloadSizeError
|
|
14
|
+
from esgpull.fs import Digest, Filesystem
|
|
15
|
+
from esgpull.models import File
|
|
16
|
+
from esgpull.result import Err, Ok, Result
|
|
17
|
+
from esgpull.tui import logger
|
|
18
|
+
|
|
19
|
+
# Callback: TypeAlias = Callable[[], None] | partial[None]
|
|
20
|
+
Callback: TypeAlias = partial[None]
|
|
21
|
+
|
|
22
|
+
default_ssl_context: ssl.SSLContext | bool = False
|
|
23
|
+
default_ssl_context_loaded = False
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def load_default_ssl_context() -> str:
|
|
27
|
+
global default_ssl_context
|
|
28
|
+
global default_ssl_context_loaded
|
|
29
|
+
if ssl.OPENSSL_VERSION_INFO[0] >= 3:
|
|
30
|
+
default_ssl_context = ssl.create_default_context()
|
|
31
|
+
default_ssl_context.options |= 0x4
|
|
32
|
+
msg = "Using openssl 3 or higher"
|
|
33
|
+
else:
|
|
34
|
+
default_ssl_context = True
|
|
35
|
+
msg = "Using openssl 1"
|
|
36
|
+
default_ssl_context_loaded = True
|
|
37
|
+
return msg
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class Task:
    """Download of one ``File``, exposed as an async stream of ``Result``."""

    def __init__(
        self,
        config: Config,
        auth: Auth,
        fs: Filesystem,
        file: File,
        start_callbacks: list[Callback] | None = None,
    ) -> None:
        """Prepare the download context, downloader and SSL setting.

        Args:
            config: configuration; its ``download`` section is read here
                and in ``stream``.
            auth: provides the client certificate used in ``stream``.
            fs: filesystem used to open the destination of ``file``.
            file: the file to download.
            start_callbacks: callables invoked, in order, right before the
                download starts; defaults to none.
        """
        self.config = config
        self.auth = auth
        self.fs = fs
        self.ctx = DownloadCtx(file)
        if not self.config.download.disable_checksum:
            self.ctx.digest = Digest(file)
        self.downloader = Simple()
        # The shared SSL context is loaded by whichever task is created
        # first, even when SSL is disabled for this configuration.
        ssl_msg: str | None = None
        if not default_ssl_context_loaded:
            ssl_msg = load_default_ssl_context()
        self.ssl_context: ssl.SSLContext | bool = False
        if not self.config.download.disable_ssl:
            if ssl_msg is not None:
                logger.info(ssl_msg)
            self.ssl_context = default_ssl_context
        self.start_callbacks = (
            [] if start_callbacks is None else start_callbacks
        )

    @property
    def file(self) -> File:
        """The file held by the download context."""
        return self.ctx.file

    async def stream(
        self, semaphore: asyncio.Semaphore
    ) -> AsyncIterator[Result]:
        """Download the file, yielding a ``Result`` per downloader update.

        The semaphore is held for the whole download. Each update yields
        ``Ok``, except an update flagged as an error, which yields ``Err``
        with a ``DownloadSizeError`` and ends the stream. The exceptions
        listed in the handler below are reported as a final ``Err``.
        """
        ctx = self.ctx
        try:
            async with (
                semaphore,
                self.fs.open(ctx.file) as file_obj,
                AsyncClient(
                    follow_redirects=True,
                    cert=self.auth.cert,
                    verify=self.ssl_context,
                    timeout=self.config.download.http_timeout,
                ) as client,
            ):
                for notify_start in self.start_callbacks:
                    notify_start()
                updates = self.downloader.stream(
                    client,
                    ctx,
                    self.config.download.chunk_size,
                )
                # ``ctx`` is rebound on purpose: the handler below reports
                # the most recent context.
                async for ctx in updates:
                    if ctx.chunk is not None:
                        await file_obj.write(ctx.chunk)
                    if ctx.error:
                        size_err = DownloadSizeError(
                            ctx.completed, ctx.file.size
                        )
                        yield Err(ctx, size_err)
                        await updates.aclose()
                        break
                    if ctx.finished:
                        await file_obj.to_done()
                    yield Ok(ctx)
        # NOTE(review): yielding after ``GeneratorExit`` makes closing an
        # async generator raise ``RuntimeError`` -- confirm that catching it
        # here is intended.
        except (
            HTTPError,
            DownloadSizeError,
            GeneratorExit,
            ssl.SSLError,
        ) as err:
            yield Err(ctx, err)
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class Processor:
    """Run the downloads of several files concurrently."""

    def __init__(
        self,
        config: Config,
        auth: Auth,
        fs: Filesystem,
        files: list[File],
        start_callbacks: dict[str, list[Callback]],
    ) -> None:
        """Create one ``Task`` per file that still has to be downloaded.

        Args:
            config: configuration; ``download.max_concurrent`` bounds the
                number of simultaneous downloads in ``process``.
            auth: passed on to every task.
            fs: filesystem used to decide which files need downloading and
                passed on to every task.
            files: candidate files.
            start_callbacks: start callbacks of each file, keyed by the
                file's ``sha``.
        """
        self.config = config
        self.fs = fs
        self.files = [file for file in files if self.should_download(file)]
        # Tasks are built from the filtered list: iterating the raw
        # ``files`` argument would make ``should_download`` ineffective and
        # download again files that are already in place.
        self.tasks = [
            Task(
                config=config,
                auth=auth,
                fs=fs,
                file=file,
                start_callbacks=start_callbacks[file.sha],
            )
            for file in self.files
        ]

    def should_download(self, file: File) -> bool:
        """Tell whether ``file`` is absent from its ``drs`` location."""
        return not self.fs[file].drs.is_file()

    async def process(self) -> AsyncIterator[Result]:
        """Yield the results of all tasks, merged as they are produced."""
        semaphore = asyncio.Semaphore(self.config.download.max_concurrent)
        streams = [task.stream(semaphore) for task in self.tasks]
        async with merge(*streams).stream() as stream:
            async for result in stream:
                yield result
|
esgpull/py.typed
ADDED
|
File without changes
|
esgpull/result.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
from typing import Generic, TypeVar
|
|
3
|
+
|
|
4
|
+
T = TypeVar("T")
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclass
|
|
8
|
+
class Result(Generic[T]):
|
|
9
|
+
data: T
|
|
10
|
+
ok: bool
|
|
11
|
+
err: BaseException | None
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class Ok(Result[T]):
    """Successful ``Result``; only the payload is passed at construction."""

    # Both fields are fixed and kept out of ``__init__``.
    ok: bool = field(init=False, default=True)
    err: None = field(init=False, default=None)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class Err(Result[T]):
    """Failed ``Result``; built from the payload and the error."""

    # Fixed to ``False`` and kept out of ``__init__``.
    ok: bool = field(init=False, default=False)
    # Required: the error describing the failure.
    err: BaseException = field()
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# class Result:
|
|
27
|
+
# ok: bool
|
|
28
|
+
# file: File
|
|
29
|
+
# completed: int
|
|
30
|
+
# err: BaseException | None
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# class Ok(Result):
|
|
34
|
+
# ok = True
|
|
35
|
+
# file: File
|
|
36
|
+
# completed: int
|
|
37
|
+
# err = None
|
|
38
|
+
|
|
39
|
+
# def __init__(self, file: File, completed: int) -> None:
|
|
40
|
+
# self.file = file
|
|
41
|
+
# self.completed = completed
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# class Err(Result):
|
|
45
|
+
# ok = False
|
|
46
|
+
# file: File
|
|
47
|
+
# completed: int
|
|
48
|
+
# err: BaseException
|
|
49
|
+
|
|
50
|
+
# def __init__(self, file: File, completed: int, err: BaseException) -> None:
|
|
51
|
+
# self.file = file
|
|
52
|
+
# self.completed = completed
|
|
53
|
+
# self.err = err
|