esgpull 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- esgpull/cli/__init__.py +2 -2
- esgpull/cli/add.py +7 -1
- esgpull/cli/config.py +5 -21
- esgpull/cli/plugins.py +398 -0
- esgpull/cli/update.py +58 -15
- esgpull/cli/utils.py +16 -1
- esgpull/config.py +82 -25
- esgpull/constants.py +3 -0
- esgpull/context.py +9 -9
- esgpull/database.py +8 -2
- esgpull/download.py +3 -0
- esgpull/esgpull.py +49 -5
- esgpull/graph.py +1 -1
- esgpull/migrations/versions/0.9.0_update_tables.py +28 -0
- esgpull/migrations/versions/d14f179e553c_file_add_composite_index_dataset_id_.py +32 -0
- esgpull/migrations/versions/e7edab5d4e4b_add_dataset_tracking.py +39 -0
- esgpull/models/__init__.py +2 -1
- esgpull/models/base.py +31 -14
- esgpull/models/dataset.py +48 -5
- esgpull/models/query.py +58 -14
- esgpull/models/sql.py +40 -9
- esgpull/plugin.py +574 -0
- esgpull/processor.py +3 -3
- esgpull/tui.py +23 -1
- esgpull/utils.py +5 -1
- {esgpull-0.8.0.dist-info → esgpull-0.9.0.dist-info}/METADATA +2 -1
- {esgpull-0.8.0.dist-info → esgpull-0.9.0.dist-info}/RECORD +30 -26
- esgpull/cli/datasets.py +0 -78
- {esgpull-0.8.0.dist-info → esgpull-0.9.0.dist-info}/WHEEL +0 -0
- {esgpull-0.8.0.dist-info → esgpull-0.9.0.dist-info}/entry_points.txt +0 -0
- {esgpull-0.8.0.dist-info → esgpull-0.9.0.dist-info}/licenses/LICENSE +0 -0
esgpull/config.py
CHANGED
@@ -4,10 +4,10 @@ import logging
 from collections.abc import Iterator, Mapping
 from enum import Enum, auto
 from pathlib import Path
-from typing import Any, cast
+from typing import Any, TypeVar, Union, cast, overload
 
 import tomlkit
-from attrs import Factory, define, field, fields
+from attrs import Factory, define, field
 from attrs import has as attrs_has
 from cattrs import Converter
 from cattrs.gen import make_dict_unstructure_fn, override
@@ -20,6 +20,66 @@ from esgpull.models.options import Options
 
 logger = logging.getLogger("esgpull")
 
+T = TypeVar("T")
+
+
+@overload
+def cast_value(
+    target: str, value: Union[str, int, bool, float], key: str
+) -> str: ...
+
+
+@overload
+def cast_value(
+    target: bool, value: Union[str, int, bool, float], key: str
+) -> bool: ...
+
+
+@overload
+def cast_value(
+    target: int, value: Union[str, int, bool, float], key: str
+) -> int: ...
+
+
+@overload
+def cast_value(
+    target: float, value: Union[str, int, bool, float], key: str
+) -> float: ...
+
+
+def cast_value(
+    target: Any, value: Union[str, int, bool, float], key: str
+) -> Any:
+    if isinstance(value, type(target)):
+        return value
+    elif attrs_has(type(target)):
+        raise KeyError(key)
+    elif isinstance(target, str):
+        return str(value)
+    elif isinstance(target, float):
+        try:
+            return float(value)
+        except Exception:
+            raise ValueError(value)
+    elif isinstance(target, bool):
+        if isinstance(value, str):
+            if value.lower() in ["on", "true"]:
+                return True
+            elif value.lower() in ["off", "false"]:
+                return False
+            else:
+                raise ValueError(value)
+        else:
+            raise TypeError(value)
+    elif isinstance(target, int):
+        # int must be after bool, because isinstance(True, int) == True
+        try:
+            return int(value)
+        except Exception:
+            raise ValueError(value)
+    else:
+        raise TypeError(value)
+
 
 @define
 class Paths:
@@ -28,6 +88,7 @@ class Paths:
     db: Path = field(converter=Path)
     log: Path = field(converter=Path)
     tmp: Path = field(converter=Path)
+    plugins: Path = field(converter=Path)
 
     @auth.default
     def _auth_factory(self) -> Path:
@@ -69,12 +130,21 @@ class Paths:
             root = InstallConfig.default
         return root / "tmp"
 
+    @plugins.default
+    def _plugins_factory(self) -> Path:
+        if InstallConfig.current is not None:
+            root = InstallConfig.current.path
+        else:
+            root = InstallConfig.default
+        return root / "plugins"
+
     def __iter__(self) -> Iterator[Path]:
         yield self.auth
         yield self.data
         yield self.db
         yield self.log
         yield self.tmp
+        yield self.plugins
 
 
 @define
@@ -213,6 +283,13 @@ def iter_keys(
     yield local_path
 
 
+@define
+class Plugins:
+    """Configuration for the plugin system"""
+
+    enabled: bool = False
+
+
 @define
 class Config:
     paths: Paths = Factory(Paths)
@@ -221,6 +298,7 @@ class Config:
     db: Db = Factory(Db)
     download: Download = Factory(Download)
     api: API = Factory(API)
+    plugins: Plugins = Factory(Plugins)
     _raw: TOMLDocument | None = field(init=False, default=None)
     _config_file: Path | None = field(init=False, default=None)
 
@@ -287,7 +365,7 @@ class Config:
     def update_item(
         self,
         key: str,
-        value: int |
+        value: str | int | bool,
         empty_ok: bool = False,
     ) -> int | str | None:
         if self._raw is None and empty_ok:
@@ -302,29 +380,8 @@ class Config:
             doc.setdefault(part, {})
             doc = doc[part]
             obj = getattr(obj, part)
-        value_type = getattr(fields(type(obj)), last).type
         old_value = getattr(obj, last)
-
-            raise KeyError(key)
-        elif value_type is str:
-            ...
-        elif value_type is int:
-            try:
-                value = value_type(value)
-            except Exception:
-                ...
-        elif value_type is bool:
-            if isinstance(value, bool):
-                ...
-            elif isinstance(value, str):
-                if value.lower() in ["on", "true"]:
-                    value = True
-                elif value.lower() in ["off", "false"]:
-                    value = False
-                else:
-                    raise ValueError(value)
-            else:
-                raise TypeError(value)
+        value = cast_value(old_value, value, key)
         setattr(obj, last, value)
         doc[last] = value
         return old_value
esgpull/constants.py
CHANGED
esgpull/context.py
CHANGED
@@ -16,7 +16,7 @@ from rich.pretty import pretty_repr
 
 from esgpull.config import Config
 from esgpull.exceptions import SolrUnstableQueryError
-from esgpull.models import Dataset, File, Query
+from esgpull.models import DatasetRecord, File, Query
 from esgpull.tui import logger
 from esgpull.utils import format_date_iso, index2url, sync
 
@@ -151,7 +151,7 @@ class ResultHints(Result):
 
 @dataclass
 class ResultSearch(Result):
-    data: Sequence[File | Dataset] = field(init=False, repr=False)
+    data: Sequence[File | DatasetRecord] = field(init=False, repr=False)
 
     def process(self) -> None:
         raise NotImplementedError
@@ -159,14 +159,14 @@ class ResultSearch(Result):
 
 @dataclass
 class ResultDatasets(Result):
-    data: Sequence[Dataset] = field(init=False, repr=False)
+    data: Sequence[DatasetRecord] = field(init=False, repr=False)
 
     def process(self) -> None:
         self.data = []
         if self.success:
             for doc in self.json["response"]["docs"]:
                 try:
-                    dataset = Dataset.serialize(doc)
+                    dataset = DatasetRecord.serialize(doc)
                     self.data.append(dataset)
                 except KeyError as exc:
                     logger.exception(exc)
@@ -492,8 +492,8 @@ class Context:
         self,
         *results: ResultSearch,
         keep_duplicates: bool,
-    ) -> list[Dataset]:
-        datasets: list[Dataset] = []
+    ) -> list[DatasetRecord]:
+        datasets: list[DatasetRecord] = []
         ids: set[str] = set()
         async for result in self._fetch(*results):
             dataset_result = result.to(ResultDatasets)
@@ -501,7 +501,7 @@ class Context:
             if dataset_result.processed:
                 for d in dataset_result.data:
                     if not keep_duplicates and d.dataset_id in ids:
-                        logger.
+                        logger.debug(f"Duplicate dataset {d.dataset_id}")
                     else:
                         datasets.append(d)
                         ids.add(d.dataset_id)
@@ -520,7 +520,7 @@ class Context:
             if files_result.processed:
                 for file in files_result.data:
                     if not keep_duplicates and file.sha in shas:
-                        logger.
+                        logger.debug(f"Duplicate file {file.file_id}")
                    else:
                         files.append(file)
                         shas.add(file.sha)
@@ -627,7 +627,7 @@ class Context:
         date_from: datetime | None = None,
         date_to: datetime | None = None,
         keep_duplicates: bool = True,
-    ) -> list[Dataset]:
+    ) -> list[DatasetRecord]:
         if hits is None:
             hits = self.hits(*queries, file=False)
         results = self.prepare_search(
esgpull/database.py
CHANGED
@@ -12,11 +12,13 @@ import sqlalchemy.orm
 from alembic.config import Config as AlembicConfig
 from alembic.migration import MigrationContext
 from alembic.script import ScriptDirectory
+from sqlalchemy.inspection import inspect
 from sqlalchemy.orm import Session, joinedload, make_transient
 
 from esgpull import __file__
 from esgpull.config import Config
 from esgpull.models import File, Query, Table, sql
+from esgpull.models.base import Base, BaseNoSHA
 from esgpull.version import __version__
 
 # from esgpull.exceptions import NoClauseError
@@ -151,8 +153,12 @@ class Database:
     def unlink(self, query: Query, file: File):
         self.session.execute(sql.query_file.unlink(query, file))
 
-    def __contains__(self, item:
-
+    def __contains__(self, item: Base | BaseNoSHA) -> bool:
+        mapper = inspect(item.__class__)
+        pk_col = mapper.primary_key[0]
+        pk_value = getattr(item, pk_col.name)
+        stmt = sa.exists().where(pk_col == pk_value)
+        return self.scalars(sa.select(stmt))[0]
 
     def has_file_id(self, file: File) -> bool:
         return len(self.scalars(sql.file.with_file_id(file.file_id))) == 1
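The rewritten __contains__ no longer assumes a sha primary key: it inspects the instance's mapper, reads the (single-column) primary-key value, and runs an EXISTS query, so it works for both Base and BaseNoSHA models. A minimal standalone sketch of the same pattern — the Item model and in-memory engine are illustrative, not part of esgpull:

import sqlalchemy as sa
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class ToyBase(DeclarativeBase):
    pass


class Item(ToyBase):  # hypothetical model with a single-column primary key
    __tablename__ = "item"
    name: Mapped[str] = mapped_column(primary_key=True)


def contains(session: Session, item: Item) -> bool:
    mapper = inspect(item.__class__)       # class-level inspection returns the Mapper
    pk_col = mapper.primary_key[0]         # first (only) primary-key column
    pk_value = getattr(item, pk_col.name)  # value carried by this instance
    return bool(session.scalar(sa.select(sa.exists().where(pk_col == pk_value))))


engine = sa.create_engine("sqlite://")
ToyBase.metadata.create_all(engine)
with Session(engine) as session:
    session.add(Item(name="a"))
    session.commit()
    assert contains(session, Item(name="a"))
    assert not contains(session, Item(name="b"))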
esgpull/download.py
CHANGED
@@ -1,6 +1,7 @@
 # from math import ceil
 from collections.abc import AsyncGenerator
 from dataclasses import dataclass
+from datetime import datetime
 
 from httpx import AsyncClient
 
@@ -19,6 +20,7 @@ class DownloadCtx:
     completed: int = 0
     chunk: bytes | None = None
     digest: Digest | None = None
+    start_time: datetime | None = None
 
     @property
     def finished(self) -> bool:
@@ -54,6 +56,7 @@ class Simple(BaseDownloader):
         ctx: DownloadCtx,
         chunk_size: int,
     ) -> AsyncGenerator[DownloadCtx, None]:
+        ctx.start_time = datetime.now()
         async with client.stream("GET", ctx.file.url) as resp:
             resp.raise_for_status()
             async for chunk in resp.aiter_bytes(chunk_size=chunk_size):
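start_time is stamped just before the GET stream opens, and the new file_complete plugin event (next section) forwards it together with an end_time taken at completion, so consumers can derive wall-clock duration and an average rate. A small sketch, assuming ctx is a finished DownloadCtx:

from datetime import datetime

def average_rate(ctx) -> float | None:
    # completed counts bytes received; start_time stays None if the
    # download never began streaming
    if ctx.start_time is None:
        return None
    elapsed = (datetime.now() - ctx.start_time).total_seconds()
    return ctx.completed / elapsed if elapsed > 0 else None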
esgpull/esgpull.py
CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
 import logging
 from collections.abc import AsyncIterator
 from dataclasses import dataclass
+from datetime import datetime
 from functools import cached_property, partial
 from pathlib import Path
 from warnings import warn
@@ -25,6 +26,7 @@ from esgpull.auth import Auth, Credentials
 from esgpull.config import Config
 from esgpull.context import Context
 from esgpull.database import Database
+from esgpull.download import DownloadCtx
 from esgpull.exceptions import (
     DownloadCancelled,
     InvalidInstallPath,
@@ -44,6 +46,13 @@ from esgpull.models import (
     sql,
 )
 from esgpull.models.utils import short_sha
+from esgpull.plugin import (
+    Event,
+    PluginManager,
+    emit,
+    get_plugin_manager,
+    set_plugin_manager,
+)
 from esgpull.processor import Processor
 from esgpull.result import Err, Ok, Result
 from esgpull.tui import UI, DummyLive, Verbosity, logger
@@ -117,6 +126,18 @@ class Esgpull:
         if load_db:
             self.db = Database.from_config(self.config)
         self.graph = Graph(self.db)
+        # Initialize plugin system
+        plugin_config_path = self.config.paths.plugins / "plugins.toml"
+        try:
+            self.plugin_manager = get_plugin_manager()
+            self.plugin_manager.__init__(config_path=plugin_config_path)
+        except ValueError:
+            self.plugin_manager = PluginManager(config_path=plugin_config_path)
+            set_plugin_manager(self.plugin_manager)
+        if self.config.plugins.enabled:
+            self.plugin_manager.enabled = True
+            self.config.paths.plugins.mkdir(exist_ok=True, parents=True)
+            self.plugin_manager.discover_plugins(self.config.paths.plugins)
 
     def fetch_index_nodes(self) -> list[str]:
         """
@@ -309,7 +330,7 @@ class Esgpull:
         progress: Progress,
         task_ids: dict[str, TaskID],
         live: Live | DummyLive,
-    ) -> AsyncIterator[Result]:
+    ) -> AsyncIterator[Result[DownloadCtx]]:
         async for result in processor.process():
             task_idx = progress.task_ids.index(task_ids[result.data.file.sha])
             task = progress.tasks[task_idx]
@@ -348,7 +369,7 @@ class Esgpull:
                     yield result
                 case Err(_, err):
                     progress.remove_task(task.id)
-                    yield Err(result.data, err)
+                    yield Err(result.data, err=err)
                 case Err():
                     progress.remove_task(task.id)
                     yield result
@@ -438,15 +459,38 @@ class Esgpull:
                 match result:
                     case Ok():
                         main_progress.update(main_task_id, advance=1)
-                        result.data.file
-
-
+                        file = result.data.file
+                        file.status = FileStatus.Done
+                        files.append(file)
+                        emit(
+                            Event.file_complete,
+                            file=file,
+                            destination=self.fs[file].drs,
+                            start_time=result.data.start_time,
+                            end_time=datetime.now(),
+                        )
+                        if file.dataset is not None:
+                            is_dataset_complete = self.db.scalars(
+                                sql.dataset.is_complete(file.dataset)
+                            )[0]
+                            if is_dataset_complete:
+                                emit(
+                                    Event.dataset_complete,
+                                    dataset=file.dataset,
+                                )
+                    case Err(_, err):
                         queue_size -= 1
                         main_progress.update(
                             main_task_id, total=queue_size
                         )
                         result.data.file.status = FileStatus.Error
                         errors.append(result)
+                        emit(
+                            Event.file_error,
+                            file=result.data.file,
+                            exception=err,
+                        )
+
 
                 if use_db:
                     self.db.add(result.data.file)
                     remaining_dict.pop(result.data.file.sha, None)
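The new plugin.py module (574 added lines, not shown in this diff) supplies Event, emit, and the PluginManager that discover_plugins loads from the plugins/ directory. The emit calls above fix the event payloads: file_complete carries file, destination, start_time, end_time; dataset_complete carries dataset; file_error carries file, exception. The registration side is not visible here, so the "on" decorator below is an assumed name — a hypothetical plugin dropped into plugins/ might look like:

# hypothetical plugins/notify.py; Event and the payloads come from the diff,
# but "on" is an assumed registration API, not confirmed by this diff
from esgpull.plugin import Event, on


@on(Event.file_complete)
def log_duration(file, destination, start_time, end_time):
    if start_time is not None:  # start_time is Optional on DownloadCtx
        seconds = (end_time - start_time).total_seconds()
        print(f"{file.file_id} -> {destination} in {seconds:.1f}s")


@on(Event.dataset_complete)
def announce(dataset):
    print(f"dataset complete: {dataset.dataset_id} ({dataset.total_files} files)")


@on(Event.file_error)
def report(file, exception):
    print(f"failed: {file.file_id}: {exception!r}")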
esgpull/graph.py
CHANGED
@@ -418,7 +418,7 @@ class Graph:
         if keep_require:
             query_tree = query._rich_tree()
         else:
-            query_tree = query.
+            query_tree = query._rich_tree(hide_require=True)
         if query_tree is not None:
             tree.add(query_tree)
             self.fill_tree(query, query_tree)
esgpull/migrations/versions/0.9.0_update_tables.py
ADDED
@@ -0,0 +1,28 @@
+"""update tables
+
+Revision ID: 0.9.0
+Revises: d14f179e553c
+Create Date: 2025-07-07 14:54:58.433022
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '0.9.0'
+down_revision = 'd14f179e553c'
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    pass
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    pass
+    # ### end Alembic commands ###
esgpull/migrations/versions/d14f179e553c_file_add_composite_index_dataset_id_.py
ADDED
@@ -0,0 +1,32 @@
+"""file_add_composite_index_dataset_id_status
+
+Revision ID: d14f179e553c
+Revises: e7edab5d4e4b
+Create Date: 2025-06-18 16:05:35.721085
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = 'd14f179e553c'
+down_revision = 'e7edab5d4e4b'
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('file', schema=None) as batch_op:
+        batch_op.create_index('ix_file_dataset_status', ['dataset_id', 'status'], unique=False)
+
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('file', schema=None) as batch_op:
+        batch_op.drop_index('ix_file_dataset_status')
+
+    # ### end Alembic commands ###
esgpull/migrations/versions/e7edab5d4e4b_add_dataset_tracking.py
ADDED
@@ -0,0 +1,39 @@
+"""add_dataset_tracking
+
+Revision ID: e7edab5d4e4b
+Revises: 0.8.0
+Create Date: 2025-05-23 17:38:22.066153
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = 'e7edab5d4e4b'
+down_revision = '0.8.0'
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table('dataset',
+        sa.Column('dataset_id', sa.String(length=255), nullable=False),
+        sa.Column('total_files', sa.Integer(), nullable=False),
+        sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False),
+        sa.Column('updated_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False),
+        sa.PrimaryKeyConstraint('dataset_id')
+    )
+    with op.batch_alter_table('file', schema=None) as batch_op:
+        batch_op.create_foreign_key('fk_file_dataset', 'dataset', ['dataset_id'], ['dataset_id'])
+
+    # ### end Alembic commands ###
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('file', schema=None) as batch_op:
+        batch_op.drop_constraint('fk_file_dataset', type_='foreignkey')
+
+    op.drop_table('dataset')
+    # ### end Alembic commands ###
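Together, the three migrations add a dataset table keyed by dataset_id, a file.dataset_id foreign key onto it, and a composite (dataset_id, status) index on file; batch_alter_table is used because SQLite cannot alter constraints in place. The sql.dataset.is_complete query called from esgpull.py lives in esgpull/models/sql.py, whose diff is not shown — one plausible shape, which the composite index would serve:

# a sketch only: the real builder is in esgpull/models/sql.py, not shown here
import sqlalchemy as sa

from esgpull.models import Dataset, File, FileStatus


def is_complete(dataset: Dataset) -> sa.Select:
    done = (
        sa.select(sa.func.count())
        .select_from(File)
        .where(File.dataset_id == dataset.dataset_id)  # served by ix_file_dataset_status
        .where(File.status == FileStatus.Done)
        .scalar_subquery()
    )
    return sa.select(done >= dataset.total_files)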
esgpull/models/__init__.py
CHANGED
@@ -1,7 +1,7 @@
 from typing import TypeVar
 
 from esgpull.models.base import Base
-from esgpull.models.dataset import Dataset
+from esgpull.models.dataset import Dataset, DatasetRecord
 from esgpull.models.facet import Facet
 from esgpull.models.file import FastFile, FileStatus
 from esgpull.models.options import Option, Options
@@ -15,6 +15,7 @@ Table = TypeVar("Table", bound=Base)
 __all__ = [
     "Base",
     "Dataset",
+    "DatasetRecord",
     "Facet",
     "FastFile",
     "File",
esgpull/models/base.py
CHANGED
@@ -16,16 +16,10 @@
 Sha = sa.String(40)
 
 
-class Base(MappedAsDataclass, DeclarativeBase):
+# Base class for all models - provides core SQLAlchemy functionality
+class _BaseModel(MappedAsDataclass, DeclarativeBase):
     __dataclass_fields__: ClassVar[dict[str, Field]]
-    __sql_attrs__ = ("id", "sha", "_sa_instance_state", "__dataclass_fields__")
-
-    sha: Mapped[str] = mapped_column(
-        Sha,
-        init=False,
-        repr=False,
-        primary_key=True,
-    )
+    __sql_attrs__ = ("id", "_sa_instance_state", "__dataclass_fields__")
 
     @property
     def _names(self) -> tuple[str, ...]:
@@ -36,15 +30,38 @@ class Base(MappedAsDataclass, DeclarativeBase):
             result += (name,)
         return result
 
+    @property
+    def state(self) -> InstanceState:
+        return cast(InstanceState, sa.inspect(self))
+
+    def asdict(self) -> Mapping[str, Any]:
+        raise NotImplementedError
+
+
+# Base class for models that use SHA as primary key
+class Base(_BaseModel):
+    __abstract__ = True
+    __sql_attrs__ = ("id", "sha", "_sa_instance_state", "__dataclass_fields__")
+
+    sha: Mapped[str] = mapped_column(
+        Sha,
+        init=False,
+        repr=False,
+        primary_key=True,
+    )
+
     def _as_bytes(self) -> bytes:
         raise NotImplementedError
 
     def compute_sha(self) -> None:
         self.sha = sha1(self._as_bytes()).hexdigest()
 
-    @property
-    def state(self) -> InstanceState:
-        return cast(InstanceState, sa.inspect(self))
 
-
-
+# Base class for models that don't use SHA (e.g., Dataset)
+class BaseNoSHA(_BaseModel):
+    __abstract__ = True
+    __sql_attrs__ = ("id", "_sa_instance_state", "__dataclass_fields__")
+
+
+# Keep SHAKeyMixin for backward compatibility if needed
+SHAKeyMixin = Base
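The old single Base is split into an abstract _BaseModel plus two abstract roots: Base keeps the content-addressed sha primary key, while BaseNoSHA leaves the key to the subclass (the new Dataset table uses its natural dataset_id). A minimal sketch of the split in use — both example tables are hypothetical, not part of esgpull:

import sqlalchemy as sa
from sqlalchemy.orm import Mapped, mapped_column

from esgpull.models.base import Base, BaseNoSHA


class Tag(Base):  # hypothetical content-addressed table: sha is the primary key
    __tablename__ = "tag_example"
    name: Mapped[str] = mapped_column(sa.String(255))

    def _as_bytes(self) -> bytes:
        return self.name.encode()


class Counter(BaseNoSHA):  # hypothetical plain table: ordinary primary key, no sha
    __tablename__ = "counter_example"
    key: Mapped[str] = mapped_column(sa.String(64), primary_key=True)
    value: Mapped[int] = mapped_column(sa.Integer, default=0)


tag = Tag(name="cmip6")
tag.compute_sha()  # sha1 of _as_bytes() becomes the primary key, per Base above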
esgpull/models/dataset.py
CHANGED
@@ -1,12 +1,22 @@
 from __future__ import annotations
 
-from dataclasses import dataclass
+from collections.abc import Mapping
+from dataclasses import dataclass
+from datetime import datetime, timezone
+from typing import TYPE_CHECKING, Any
 
+import sqlalchemy as sa
+from sqlalchemy.orm import Mapped, mapped_column, relationship
+
+from esgpull.models.base import BaseNoSHA
 from esgpull.models.utils import find_int, find_str
 
+if TYPE_CHECKING:
+    from esgpull.models.query import File
+
 
 @dataclass
-class Dataset:
+class DatasetRecord:
     dataset_id: str
     master_id: str
     version: str
@@ -15,7 +25,7 @@ class Dataset:
     number_of_files: int
 
     @classmethod
-    def serialize(cls, source: dict) -> Dataset:
+    def serialize(cls, source: dict) -> DatasetRecord:
         dataset_id = find_str(source["instance_id"]).partition("|")[0]
         master_id, version = dataset_id.rsplit(".", 1)
         data_node = find_str(source["data_node"])
@@ -30,5 +40,38 @@ class Dataset:
             number_of_files=number_of_files,
         )
 
-
-
+
+class Dataset(BaseNoSHA):
+    __tablename__ = "dataset"
+
+    dataset_id: Mapped[str] = mapped_column(sa.String(255), primary_key=True)
+    total_files: Mapped[int] = mapped_column(sa.Integer)
+    created_at: Mapped[datetime] = mapped_column(
+        server_default=sa.func.now(),
+        default_factory=lambda: datetime.now(timezone.utc),
+        init=False,
+    )
+    updated_at: Mapped[datetime] = mapped_column(
+        server_default=sa.func.now(),
+        default_factory=lambda: datetime.now(timezone.utc),
+        init=False,
+    )
+    files: Mapped[list[File]] = relationship(
+        back_populates="dataset",
+        foreign_keys="[File.dataset_id]",
+        primaryjoin="Dataset.dataset_id==File.dataset_id",
+        default_factory=list,
+        init=False,
+        repr=False,
+    )
+
+    def asdict(self) -> Mapping[str, Any]:
+        return {
+            "dataset_id": self.dataset_id,
+            "total_files": self.total_files,
+            "created_at": self.created_at.isoformat(),
+            "updated_at": self.updated_at.isoformat(),
+        }
+
+    def __hash__(self) -> int:
+        return hash(self.dataset_id)