esgpull 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
esgpull/config.py CHANGED
@@ -4,10 +4,10 @@ import logging
  from collections.abc import Iterator, Mapping
  from enum import Enum, auto
  from pathlib import Path
- from typing import Any, cast
+ from typing import Any, TypeVar, Union, cast, overload

  import tomlkit
- from attrs import Factory, define, field, fields
+ from attrs import Factory, define, field
  from attrs import has as attrs_has
  from cattrs import Converter
  from cattrs.gen import make_dict_unstructure_fn, override
@@ -20,6 +20,66 @@ from esgpull.models.options import Options

  logger = logging.getLogger("esgpull")

+ T = TypeVar("T")
+
+
+ @overload
+ def cast_value(
+     target: str, value: Union[str, int, bool, float], key: str
+ ) -> str: ...
+
+
+ @overload
+ def cast_value(
+     target: bool, value: Union[str, int, bool, float], key: str
+ ) -> bool: ...
+
+
+ @overload
+ def cast_value(
+     target: int, value: Union[str, int, bool, float], key: str
+ ) -> int: ...
+
+
+ @overload
+ def cast_value(
+     target: float, value: Union[str, int, bool, float], key: str
+ ) -> float: ...
+
+
+ def cast_value(
+     target: Any, value: Union[str, int, bool, float], key: str
+ ) -> Any:
+     if isinstance(value, type(target)):
+         return value
+     elif attrs_has(type(target)):
+         raise KeyError(key)
+     elif isinstance(target, str):
+         return str(value)
+     elif isinstance(target, float):
+         try:
+             return float(value)
+         except Exception:
+             raise ValueError(value)
+     elif isinstance(target, bool):
+         if isinstance(value, str):
+             if value.lower() in ["on", "true"]:
+                 return True
+             elif value.lower() in ["off", "false"]:
+                 return False
+             else:
+                 raise ValueError(value)
+         else:
+             raise TypeError(value)
+     elif isinstance(target, int):
+         # int must be after bool, because isinstance(True, int) == True
+         try:
+             return int(value)
+         except Exception:
+             raise ValueError(value)
+     else:
+         raise TypeError(value)
+

  @define
  class Paths:
@@ -28,6 +88,7 @@ class Paths:
      db: Path = field(converter=Path)
      log: Path = field(converter=Path)
      tmp: Path = field(converter=Path)
+     plugins: Path = field(converter=Path)

      @auth.default
      def _auth_factory(self) -> Path:
@@ -69,12 +130,21 @@ class Paths:
              root = InstallConfig.default
          return root / "tmp"

+     @plugins.default
+     def _plugins_factory(self) -> Path:
+         if InstallConfig.current is not None:
+             root = InstallConfig.current.path
+         else:
+             root = InstallConfig.default
+         return root / "plugins"
+
      def __iter__(self) -> Iterator[Path]:
          yield self.auth
          yield self.data
          yield self.db
          yield self.log
          yield self.tmp
+         yield self.plugins


  @define
@@ -213,6 +283,13 @@ def iter_keys(
     yield local_path


+ @define
+ class Plugins:
+     """Configuration for the plugin system"""
+
+     enabled: bool = False
+
+
  @define
  class Config:
      paths: Paths = Factory(Paths)
@@ -221,6 +298,7 @@ class Config:
      db: Db = Factory(Db)
      download: Download = Factory(Download)
      api: API = Factory(API)
+     plugins: Plugins = Factory(Plugins)
      _raw: TOMLDocument | None = field(init=False, default=None)
      _config_file: Path | None = field(init=False, default=None)

@@ -287,7 +365,7 @@ class Config:
      def update_item(
          self,
          key: str,
-         value: int | str,
+         value: str | int | bool,
          empty_ok: bool = False,
      ) -> int | str | None:
          if self._raw is None and empty_ok:
@@ -302,29 +380,8 @@ class Config:
              doc.setdefault(part, {})
              doc = doc[part]
              obj = getattr(obj, part)
-         value_type = getattr(fields(type(obj)), last).type
          old_value = getattr(obj, last)
-         if attrs_has(value_type):
-             raise KeyError(key)
-         elif value_type is str:
-             ...
-         elif value_type is int:
-             try:
-                 value = value_type(value)
-             except Exception:
-                 ...
-         elif value_type is bool:
-             if isinstance(value, bool):
-                 ...
-             elif isinstance(value, str):
-                 if value.lower() in ["on", "true"]:
-                     value = True
-                 elif value.lower() in ["off", "false"]:
-                     value = False
-                 else:
-                     raise ValueError(value)
-             else:
-                 raise TypeError(value)
+         value = cast_value(old_value, value, key)
          setattr(obj, last, value)
          doc[last] = value
          return old_value
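
Note: the net effect of this refactor is that the type-dispatch logic formerly inlined in Config.update_item now lives in the reusable cast_value helper, with @overload signatures for type-checkers. A minimal sketch of its behavior, not part of the diff (the key argument is only used in error messages, so the dotted paths below are illustrative):

    from esgpull.config import cast_value

    # bool targets accept "on"/"true"/"off"/"false" in any case
    assert cast_value(True, "off", "plugins.enabled") is False
    # int targets coerce numeric strings; the bool branch is checked first,
    # since isinstance(True, int) == True
    assert cast_value(5, "8", "api.page_limit") == 8
    # incompatible combinations raise KeyError/ValueError/TypeError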
esgpull/constants.py CHANGED
@@ -1,6 +1,9 @@
+ import os
+
  CONFIG_FILENAME = "config.toml"
  INSTALLS_PATH_ENV = "ESGPULL_INSTALLS_PATH"
  ROOT_ENV = "ESGPULL_CURRENT"
+ ESGPULL_DEBUG = os.environ.get("ESGPULL_DEBUG", "0") == "1"

  IDP = "/esgf-idp/openid/"
  CEDA_IDP = "/OpenID/Provider/server/"
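
Note: ESGPULL_DEBUG is evaluated once at import time, so the environment variable must be set before esgpull.constants is first imported. Illustrative usage, not part of the diff:

    import os

    os.environ["ESGPULL_DEBUG"] = "1"  # must be set before the first import
    from esgpull.constants import ESGPULL_DEBUG

    assert ESGPULL_DEBUG is True  # only the exact string "1" enables it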
esgpull/context.py CHANGED
@@ -16,7 +16,7 @@ from rich.pretty import pretty_repr

  from esgpull.config import Config
  from esgpull.exceptions import SolrUnstableQueryError
- from esgpull.models import Dataset, File, Query
+ from esgpull.models import DatasetRecord, File, Query
  from esgpull.tui import logger
  from esgpull.utils import format_date_iso, index2url, sync

@@ -151,7 +151,7 @@ class ResultHints(Result):

  @dataclass
  class ResultSearch(Result):
-     data: Sequence[File | Dataset] = field(init=False, repr=False)
+     data: Sequence[File | DatasetRecord] = field(init=False, repr=False)

      def process(self) -> None:
          raise NotImplementedError
@@ -159,14 +159,14 @@ class ResultSearch(Result):

  @dataclass
  class ResultDatasets(Result):
-     data: Sequence[Dataset] = field(init=False, repr=False)
+     data: Sequence[DatasetRecord] = field(init=False, repr=False)

      def process(self) -> None:
          self.data = []
          if self.success:
              for doc in self.json["response"]["docs"]:
                  try:
-                     dataset = Dataset.serialize(doc)
+                     dataset = DatasetRecord.serialize(doc)
                      self.data.append(dataset)
                  except KeyError as exc:
                      logger.exception(exc)
@@ -492,8 +492,8 @@ class Context:
          self,
          *results: ResultSearch,
          keep_duplicates: bool,
-     ) -> list[Dataset]:
-         datasets: list[Dataset] = []
+     ) -> list[DatasetRecord]:
+         datasets: list[DatasetRecord] = []
          ids: set[str] = set()
          async for result in self._fetch(*results):
              dataset_result = result.to(ResultDatasets)
@@ -501,7 +501,7 @@ class Context:
              if dataset_result.processed:
                  for d in dataset_result.data:
                      if not keep_duplicates and d.dataset_id in ids:
-                         logger.warning(f"Duplicate dataset {d.dataset_id}")
+                         logger.debug(f"Duplicate dataset {d.dataset_id}")
                      else:
                          datasets.append(d)
                          ids.add(d.dataset_id)
@@ -520,7 +520,7 @@ class Context:
              if files_result.processed:
                  for file in files_result.data:
                      if not keep_duplicates and file.sha in shas:
-                         logger.warning(f"Duplicate file {file.file_id}")
+                         logger.debug(f"Duplicate file {file.file_id}")
                      else:
                          files.append(file)
                          shas.add(file.sha)
@@ -627,7 +627,7 @@ class Context:
          date_from: datetime | None = None,
          date_to: datetime | None = None,
          keep_duplicates: bool = True,
-     ) -> list[Dataset]:
+     ) -> list[DatasetRecord]:
          if hits is None:
              hits = self.hits(*queries, file=False)
          results = self.prepare_search(
esgpull/database.py CHANGED
@@ -12,11 +12,13 @@ import sqlalchemy.orm
  from alembic.config import Config as AlembicConfig
  from alembic.migration import MigrationContext
  from alembic.script import ScriptDirectory
+ from sqlalchemy.inspection import inspect
  from sqlalchemy.orm import Session, joinedload, make_transient

  from esgpull import __file__
  from esgpull.config import Config
  from esgpull.models import File, Query, Table, sql
+ from esgpull.models.base import Base, BaseNoSHA
  from esgpull.version import __version__

  # from esgpull.exceptions import NoClauseError
@@ -151,8 +153,12 @@ class Database:
      def unlink(self, query: Query, file: File):
          self.session.execute(sql.query_file.unlink(query, file))

-     def __contains__(self, item: Table) -> bool:
-         return self.scalars(sql.count(item))[0] > 0
+     def __contains__(self, item: Base | BaseNoSHA) -> bool:
+         mapper = inspect(item.__class__)
+         pk_col = mapper.primary_key[0]
+         pk_value = getattr(item, pk_col.name)
+         stmt = sa.exists().where(pk_col == pk_value)
+         return self.scalars(sa.select(stmt))[0]

      def has_file_id(self, file: File) -> bool:
          return len(self.scalars(sql.file.with_file_id(file.file_id))) == 1
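
Note: the rewritten __contains__ swaps a row count for an EXISTS probe on the primary key, which also works for the new BaseNoSHA models whose key is not a sha; it does assume a single-column primary key. A standalone sketch of the same pattern with a throwaway model (DemoBase and Item are invented for the demo):

    import sqlalchemy as sa
    from sqlalchemy.inspection import inspect
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class DemoBase(DeclarativeBase):
        pass

    class Item(DemoBase):
        __tablename__ = "item"
        id: Mapped[int] = mapped_column(primary_key=True)

    engine = sa.create_engine("sqlite://")
    DemoBase.metadata.create_all(engine)
    with Session(engine) as session:
        session.add(Item(id=1))
        session.commit()
        pk_col = inspect(Item).primary_key[0]  # first (single) PK column
        print(session.scalar(sa.select(sa.exists().where(pk_col == 1))))  # True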
esgpull/download.py CHANGED
@@ -1,6 +1,7 @@
  # from math import ceil
  from collections.abc import AsyncGenerator
  from dataclasses import dataclass
+ from datetime import datetime

  from httpx import AsyncClient

@@ -19,6 +20,7 @@ class DownloadCtx:
      completed: int = 0
      chunk: bytes | None = None
      digest: Digest | None = None
+     start_time: datetime | None = None

      @property
      def finished(self) -> bool:
@@ -54,6 +56,7 @@ class Simple(BaseDownloader):
          ctx: DownloadCtx,
          chunk_size: int,
      ) -> AsyncGenerator[DownloadCtx, None]:
+         ctx.start_time = datetime.now()
          async with client.stream("GET", ctx.file.url) as resp:
              resp.raise_for_status()
              async for chunk in resp.aiter_bytes(chunk_size=chunk_size):
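
Note: with start_time stamped just before the HTTP stream opens, consumers (such as the file_complete plugin event introduced in esgpull.py below) can derive elapsed time and mean throughput. A hypothetical helper, not part of the diff:

    from datetime import datetime

    def mean_throughput(completed: int, start_time: datetime) -> float:
        """Bytes per second since the download started (illustrative)."""
        elapsed = (datetime.now() - start_time).total_seconds()
        return completed / elapsed if elapsed > 0 else 0.0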
esgpull/esgpull.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
  import logging
  from collections.abc import AsyncIterator
  from dataclasses import dataclass
+ from datetime import datetime
  from functools import cached_property, partial
  from pathlib import Path
  from warnings import warn
@@ -25,6 +26,7 @@ from esgpull.auth import Auth, Credentials
  from esgpull.config import Config
  from esgpull.context import Context
  from esgpull.database import Database
+ from esgpull.download import DownloadCtx
  from esgpull.exceptions import (
      DownloadCancelled,
      InvalidInstallPath,
@@ -44,6 +46,13 @@ from esgpull.models import (
      sql,
  )
  from esgpull.models.utils import short_sha
+ from esgpull.plugin import (
+     Event,
+     PluginManager,
+     emit,
+     get_plugin_manager,
+     set_plugin_manager,
+ )
  from esgpull.processor import Processor
  from esgpull.result import Err, Ok, Result
  from esgpull.tui import UI, DummyLive, Verbosity, logger
@@ -117,6 +126,18 @@ class Esgpull:
          if load_db:
              self.db = Database.from_config(self.config)
          self.graph = Graph(self.db)
+         # Initialize plugin system
+         plugin_config_path = self.config.paths.plugins / "plugins.toml"
+         try:
+             self.plugin_manager = get_plugin_manager()
+             self.plugin_manager.__init__(config_path=plugin_config_path)
+         except ValueError:
+             self.plugin_manager = PluginManager(config_path=plugin_config_path)
+             set_plugin_manager(self.plugin_manager)
+         if self.config.plugins.enabled:
+             self.plugin_manager.enabled = True
+             self.config.paths.plugins.mkdir(exist_ok=True, parents=True)
+             self.plugin_manager.discover_plugins(self.config.paths.plugins)

      def fetch_index_nodes(self) -> list[str]:
          """
@@ -309,7 +330,7 @@ class Esgpull:
          progress: Progress,
          task_ids: dict[str, TaskID],
          live: Live | DummyLive,
-     ) -> AsyncIterator[Result]:
+     ) -> AsyncIterator[Result[DownloadCtx]]:
          async for result in processor.process():
              task_idx = progress.task_ids.index(task_ids[result.data.file.sha])
              task = progress.tasks[task_idx]
@@ -348,7 +369,7 @@ class Esgpull:
                      yield result
                  case Err(_, err):
                      progress.remove_task(task.id)
-                     yield Err(result.data, err)
+                     yield Err(result.data, err=err)
                  case Err():
                      progress.remove_task(task.id)
                      yield result
@@ -438,15 +459,38 @@ class Esgpull:
                  match result:
                      case Ok():
                          main_progress.update(main_task_id, advance=1)
-                         result.data.file.status = FileStatus.Done
-                         files.append(result.data.file)
-                     case Err():
+                         file = result.data.file
+                         file.status = FileStatus.Done
+                         files.append(file)
+                         emit(
+                             Event.file_complete,
+                             file=file,
+                             destination=self.fs[file].drs,
+                             start_time=result.data.start_time,
+                             end_time=datetime.now(),
+                         )
+                         if file.dataset is not None:
+                             is_dataset_complete = self.db.scalars(
+                                 sql.dataset.is_complete(file.dataset)
+                             )[0]
+                             if is_dataset_complete:
+                                 emit(
+                                     Event.dataset_complete,
+                                     dataset=file.dataset,
+                                 )
+                     case Err(_, err):
                          queue_size -= 1
                          main_progress.update(
                              main_task_id, total=queue_size
                          )
                          result.data.file.status = FileStatus.Error
                          errors.append(result)
+                         emit(
+                             Event.file_error,
+                             file=result.data.file,
+                             exception=err,
+                         )
+
              if use_db:
                  self.db.add(result.data.file)
                  remaining_dict.pop(result.data.file.sha, None)
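
Note: a hypothetical plugin module dropped into the new <root>/plugins directory. The event names (file_complete, file_error, dataset_complete) and their keyword arguments come from the emit() calls above; the @on registration decorator is an assumption, since the plugin API itself is not part of this diff:

    from esgpull.plugin import Event, on  # `on` is an assumed helper

    @on(Event.file_complete)
    def log_completion(file, destination, start_time, end_time, **kwargs):
        # start_time may be None when the downloader did not stamp it
        elapsed = (end_time - start_time).total_seconds() if start_time else None
        print(f"{file.file_id} -> {destination} in {elapsed}s")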
esgpull/graph.py CHANGED
@@ -418,7 +418,7 @@ class Graph:
          if keep_require:
              query_tree = query._rich_tree()
          else:
-             query_tree = query.no_require()._rich_tree()
+             query_tree = query._rich_tree(hide_require=True)
          if query_tree is not None:
              tree.add(query_tree)
              self.fill_tree(query, query_tree)
(new migration file, path not shown in this diff) ADDED
@@ -0,0 +1,28 @@
+ """update tables
+
+ Revision ID: 0.9.0
+ Revises: d14f179e553c
+ Create Date: 2025-07-07 14:54:58.433022
+
+ """
+ from alembic import op
+ import sqlalchemy as sa
+
+
+ # revision identifiers, used by Alembic.
+ revision = '0.9.0'
+ down_revision = 'd14f179e553c'
+ branch_labels = None
+ depends_on = None
+
+
+ def upgrade() -> None:
+     # ### commands auto generated by Alembic - please adjust! ###
+     pass
+     # ### end Alembic commands ###
+
+
+ def downgrade() -> None:
+     # ### commands auto generated by Alembic - please adjust! ###
+     pass
+     # ### end Alembic commands ###
(new migration file, path not shown in this diff) ADDED
@@ -0,0 +1,32 @@
+ """file_add_composite_index_dataset_id_status
+
+ Revision ID: d14f179e553c
+ Revises: e7edab5d4e4b
+ Create Date: 2025-06-18 16:05:35.721085
+
+ """
+ from alembic import op
+ import sqlalchemy as sa
+
+
+ # revision identifiers, used by Alembic.
+ revision = 'd14f179e553c'
+ down_revision = 'e7edab5d4e4b'
+ branch_labels = None
+ depends_on = None
+
+
+ def upgrade() -> None:
+     # ### commands auto generated by Alembic - please adjust! ###
+     with op.batch_alter_table('file', schema=None) as batch_op:
+         batch_op.create_index('ix_file_dataset_status', ['dataset_id', 'status'], unique=False)
+
+     # ### end Alembic commands ###
+
+
+ def downgrade() -> None:
+     # ### commands auto generated by Alembic - please adjust! ###
+     with op.batch_alter_table('file', schema=None) as batch_op:
+         batch_op.drop_index('ix_file_dataset_status')
+
+     # ### end Alembic commands ###
(new migration file, path not shown in this diff) ADDED
@@ -0,0 +1,39 @@
+ """add_dataset_tracking
+
+ Revision ID: e7edab5d4e4b
+ Revises: 0.8.0
+ Create Date: 2025-05-23 17:38:22.066153
+
+ """
+ from alembic import op
+ import sqlalchemy as sa
+
+
+ # revision identifiers, used by Alembic.
+ revision = 'e7edab5d4e4b'
+ down_revision = '0.8.0'
+ branch_labels = None
+ depends_on = None
+
+
+ def upgrade() -> None:
+     # ### commands auto generated by Alembic - please adjust! ###
+     op.create_table('dataset',
+     sa.Column('dataset_id', sa.String(length=255), nullable=False),
+     sa.Column('total_files', sa.Integer(), nullable=False),
+     sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False),
+     sa.Column('updated_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False),
+     sa.PrimaryKeyConstraint('dataset_id')
+     )
+     with op.batch_alter_table('file', schema=None) as batch_op:
+         batch_op.create_foreign_key('fk_file_dataset', 'dataset', ['dataset_id'], ['dataset_id'])
+
+     # ### end Alembic commands ###
+
+ def downgrade() -> None:
+     # ### commands auto generated by Alembic - please adjust! ###
+     with op.batch_alter_table('file', schema=None) as batch_op:
+         batch_op.drop_constraint('fk_file_dataset', type_='foreignkey')
+
+     op.drop_table('dataset')
+     # ### end Alembic commands ###
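
Note: the three new migration files form the revision chain 0.8.0 → e7edab5d4e4b → d14f179e553c → 0.9.0. A sketch for inspecting that chain with Alembic's API, not part of the diff (the script_location path is an assumption):

    from alembic.config import Config as AlembicConfig
    from alembic.script import ScriptDirectory

    cfg = AlembicConfig()
    cfg.set_main_option("script_location", "esgpull/migrations")  # assumed path
    scripts = ScriptDirectory.from_config(cfg)
    for rev in scripts.walk_revisions():  # newest first: 0.9.0 ... 0.8.0
        print(rev.revision, "<-", rev.down_revision)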
esgpull/models/__init__.py CHANGED
@@ -1,7 +1,7 @@
  from typing import TypeVar

  from esgpull.models.base import Base
- from esgpull.models.dataset import Dataset
+ from esgpull.models.dataset import Dataset, DatasetRecord
  from esgpull.models.facet import Facet
  from esgpull.models.file import FastFile, FileStatus
  from esgpull.models.options import Option, Options
@@ -15,6 +15,7 @@ Table = TypeVar("Table", bound=Base)
  __all__ = [
      "Base",
      "Dataset",
+     "DatasetRecord",
      "Facet",
      "FastFile",
      "File",
esgpull/models/base.py CHANGED
@@ -16,16 +16,10 @@ T = TypeVar("T")
  Sha = sa.String(40)


- class Base(MappedAsDataclass, DeclarativeBase):
+ # Base class for all models - provides core SQLAlchemy functionality
+ class _BaseModel(MappedAsDataclass, DeclarativeBase):
      __dataclass_fields__: ClassVar[dict[str, Field]]
-     __sql_attrs__ = ("id", "sha", "_sa_instance_state", "__dataclass_fields__")
-
-     sha: Mapped[str] = mapped_column(
-         Sha,
-         init=False,
-         repr=False,
-         primary_key=True,
-     )
+     __sql_attrs__ = ("id", "_sa_instance_state", "__dataclass_fields__")

      @property
      def _names(self) -> tuple[str, ...]:
@@ -36,15 +30,38 @@ class Base(MappedAsDataclass, DeclarativeBase):
              result += (name,)
          return result

+     @property
+     def state(self) -> InstanceState:
+         return cast(InstanceState, sa.inspect(self))
+
+     def asdict(self) -> Mapping[str, Any]:
+         raise NotImplementedError
+
+
+ # Base class for models that use SHA as primary key
+ class Base(_BaseModel):
+     __abstract__ = True
+     __sql_attrs__ = ("id", "sha", "_sa_instance_state", "__dataclass_fields__")
+
+     sha: Mapped[str] = mapped_column(
+         Sha,
+         init=False,
+         repr=False,
+         primary_key=True,
+     )
+
      def _as_bytes(self) -> bytes:
          raise NotImplementedError

      def compute_sha(self) -> None:
          self.sha = sha1(self._as_bytes()).hexdigest()

-     @property
-     def state(self) -> InstanceState:
-         return cast(InstanceState, sa.inspect(self))

-     def asdict(self) -> Mapping[str, Any]:
-         raise NotImplementedError
+ # Base class for models that don't use SHA (e.g., Dataset)
+ class BaseNoSHA(_BaseModel):
+     __abstract__ = True
+     __sql_attrs__ = ("id", "_sa_instance_state", "__dataclass_fields__")
+
+
+ # Keep SHAKeyMixin for backward compatibility if needed
+ SHAKeyMixin = Base
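
Note: the hierarchy is now _BaseModel → Base (sha primary key, unchanged for existing models) and _BaseModel → BaseNoSHA (natural primary key). A minimal sketch of declaring a natural-key table on the new base, not part of the diff (DemoTag is invented):

    from sqlalchemy.orm import Mapped, mapped_column

    from esgpull.models.base import BaseNoSHA

    class DemoTag(BaseNoSHA):
        __tablename__ = "demo_tag"

        name: Mapped[str] = mapped_column(primary_key=True)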
esgpull/models/dataset.py CHANGED
@@ -1,12 +1,22 @@
  from __future__ import annotations

- from dataclasses import asdict, dataclass
+ from collections.abc import Mapping
+ from dataclasses import dataclass
+ from datetime import datetime, timezone
+ from typing import TYPE_CHECKING, Any

+ import sqlalchemy as sa
+ from sqlalchemy.orm import Mapped, mapped_column, relationship
+
+ from esgpull.models.base import BaseNoSHA
  from esgpull.models.utils import find_int, find_str

+ if TYPE_CHECKING:
+     from esgpull.models.query import File
+

  @dataclass
- class Dataset:
+ class DatasetRecord:
      dataset_id: str
      master_id: str
      version: str
@@ -15,7 +25,7 @@ class Dataset:
      number_of_files: int

      @classmethod
-     def serialize(cls, source: dict) -> Dataset:
+     def serialize(cls, source: dict) -> DatasetRecord:
          dataset_id = find_str(source["instance_id"]).partition("|")[0]
          master_id, version = dataset_id.rsplit(".", 1)
          data_node = find_str(source["data_node"])
@@ -30,5 +40,38 @@ class Dataset:
              number_of_files=number_of_files,
          )

-     def asdict(self) -> dict:
-         return asdict(self)
+
+ class Dataset(BaseNoSHA):
+     __tablename__ = "dataset"
+
+     dataset_id: Mapped[str] = mapped_column(sa.String(255), primary_key=True)
+     total_files: Mapped[int] = mapped_column(sa.Integer)
+     created_at: Mapped[datetime] = mapped_column(
+         server_default=sa.func.now(),
+         default_factory=lambda: datetime.now(timezone.utc),
+         init=False,
+     )
+     updated_at: Mapped[datetime] = mapped_column(
+         server_default=sa.func.now(),
+         default_factory=lambda: datetime.now(timezone.utc),
+         init=False,
+     )
+     files: Mapped[list[File]] = relationship(
+         back_populates="dataset",
+         foreign_keys="[File.dataset_id]",
+         primaryjoin="Dataset.dataset_id==File.dataset_id",
+         default_factory=list,
+         init=False,
+         repr=False,
+     )
+
+     def asdict(self) -> Mapping[str, Any]:
+         return {
+             "dataset_id": self.dataset_id,
+             "total_files": self.total_files,
+             "created_at": self.created_at.isoformat(),
+             "updated_at": self.updated_at.isoformat(),
+         }
+
+     def __hash__(self) -> int:
+         return hash(self.dataset_id)
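
Note: the old dataclass keeps its role as a search-result record under the name DatasetRecord, while Dataset becomes the ORM model backing the dataset table created by the add_dataset_tracking migration. Constructing a row, as a sketch (the dataset_id value is illustrative; created_at/updated_at are stamped by the default_factory):

    from esgpull.models import Dataset

    ds = Dataset(
        dataset_id="CMIP6.CMIP.IPSL.IPSL-CM6A-LR.historical.v20180803",
        total_files=3,
    )
    print(ds.asdict())  # timestamps serialized via isoformat()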