datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl
This diff shows the contents of the two publicly released package versions as they appear in their public registries. It is provided for informational purposes only.
- datachain/__init__.py +20 -0
- datachain/asyn.py +11 -12
- datachain/cache.py +7 -7
- datachain/catalog/__init__.py +2 -2
- datachain/catalog/catalog.py +621 -507
- datachain/catalog/dependency.py +164 -0
- datachain/catalog/loader.py +28 -18
- datachain/checkpoint.py +43 -0
- datachain/cli/__init__.py +24 -33
- datachain/cli/commands/__init__.py +1 -8
- datachain/cli/commands/datasets.py +83 -52
- datachain/cli/commands/ls.py +17 -17
- datachain/cli/commands/show.py +4 -4
- datachain/cli/parser/__init__.py +8 -74
- datachain/cli/parser/job.py +95 -3
- datachain/cli/parser/studio.py +11 -4
- datachain/cli/parser/utils.py +1 -2
- datachain/cli/utils.py +2 -15
- datachain/client/azure.py +4 -4
- datachain/client/fsspec.py +45 -28
- datachain/client/gcs.py +6 -6
- datachain/client/hf.py +29 -2
- datachain/client/http.py +157 -0
- datachain/client/local.py +15 -11
- datachain/client/s3.py +17 -9
- datachain/config.py +4 -8
- datachain/data_storage/db_engine.py +12 -6
- datachain/data_storage/job.py +5 -1
- datachain/data_storage/metastore.py +1252 -186
- datachain/data_storage/schema.py +58 -45
- datachain/data_storage/serializer.py +105 -15
- datachain/data_storage/sqlite.py +286 -127
- datachain/data_storage/warehouse.py +250 -113
- datachain/dataset.py +353 -148
- datachain/delta.py +391 -0
- datachain/diff/__init__.py +27 -29
- datachain/error.py +60 -0
- datachain/func/__init__.py +2 -1
- datachain/func/aggregate.py +66 -42
- datachain/func/array.py +242 -38
- datachain/func/base.py +7 -4
- datachain/func/conditional.py +110 -60
- datachain/func/func.py +96 -45
- datachain/func/numeric.py +55 -38
- datachain/func/path.py +32 -20
- datachain/func/random.py +2 -2
- datachain/func/string.py +67 -37
- datachain/func/window.py +7 -8
- datachain/hash_utils.py +123 -0
- datachain/job.py +11 -7
- datachain/json.py +138 -0
- datachain/lib/arrow.py +58 -22
- datachain/lib/audio.py +245 -0
- datachain/lib/clip.py +14 -13
- datachain/lib/convert/flatten.py +5 -3
- datachain/lib/convert/python_to_sql.py +6 -10
- datachain/lib/convert/sql_to_python.py +8 -0
- datachain/lib/convert/values_to_tuples.py +156 -51
- datachain/lib/data_model.py +42 -20
- datachain/lib/dataset_info.py +36 -8
- datachain/lib/dc/__init__.py +8 -2
- datachain/lib/dc/csv.py +25 -28
- datachain/lib/dc/database.py +398 -0
- datachain/lib/dc/datachain.py +1289 -425
- datachain/lib/dc/datasets.py +320 -38
- datachain/lib/dc/hf.py +38 -24
- datachain/lib/dc/json.py +29 -32
- datachain/lib/dc/listings.py +112 -8
- datachain/lib/dc/pandas.py +16 -12
- datachain/lib/dc/parquet.py +35 -23
- datachain/lib/dc/records.py +31 -23
- datachain/lib/dc/storage.py +154 -64
- datachain/lib/dc/storage_pattern.py +251 -0
- datachain/lib/dc/utils.py +24 -16
- datachain/lib/dc/values.py +8 -9
- datachain/lib/file.py +622 -89
- datachain/lib/hf.py +69 -39
- datachain/lib/image.py +14 -14
- datachain/lib/listing.py +14 -11
- datachain/lib/listing_info.py +1 -2
- datachain/lib/meta_formats.py +3 -4
- datachain/lib/model_store.py +39 -7
- datachain/lib/namespaces.py +125 -0
- datachain/lib/projects.py +130 -0
- datachain/lib/pytorch.py +32 -21
- datachain/lib/settings.py +192 -56
- datachain/lib/signal_schema.py +427 -104
- datachain/lib/tar.py +1 -2
- datachain/lib/text.py +8 -7
- datachain/lib/udf.py +164 -76
- datachain/lib/udf_signature.py +60 -35
- datachain/lib/utils.py +118 -4
- datachain/lib/video.py +17 -9
- datachain/lib/webdataset.py +61 -56
- datachain/lib/webdataset_laion.py +15 -16
- datachain/listing.py +22 -10
- datachain/model/bbox.py +3 -1
- datachain/model/ultralytics/bbox.py +16 -12
- datachain/model/ultralytics/pose.py +16 -12
- datachain/model/ultralytics/segment.py +16 -12
- datachain/namespace.py +84 -0
- datachain/node.py +6 -6
- datachain/nodes_thread_pool.py +0 -1
- datachain/plugins.py +24 -0
- datachain/project.py +78 -0
- datachain/query/batch.py +40 -41
- datachain/query/dataset.py +604 -322
- datachain/query/dispatch.py +261 -154
- datachain/query/metrics.py +4 -6
- datachain/query/params.py +2 -3
- datachain/query/queue.py +3 -12
- datachain/query/schema.py +11 -6
- datachain/query/session.py +200 -33
- datachain/query/udf.py +34 -2
- datachain/remote/studio.py +171 -69
- datachain/script_meta.py +12 -12
- datachain/semver.py +68 -0
- datachain/sql/__init__.py +2 -0
- datachain/sql/functions/array.py +33 -1
- datachain/sql/postgresql_dialect.py +9 -0
- datachain/sql/postgresql_types.py +21 -0
- datachain/sql/sqlite/__init__.py +5 -1
- datachain/sql/sqlite/base.py +102 -29
- datachain/sql/sqlite/types.py +8 -13
- datachain/sql/types.py +70 -15
- datachain/studio.py +223 -46
- datachain/toolkit/split.py +31 -10
- datachain/utils.py +101 -59
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
- datachain-0.39.0.dist-info/RECORD +173 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
- datachain/cli/commands/query.py +0 -53
- datachain/query/utils.py +0 -42
- datachain-0.14.2.dist-info/RECORD +0 -158
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
|
@@ -1,28 +1,37 @@
|
|
|
1
1
|
import copy
|
|
2
|
-
import json
|
|
3
2
|
import logging
|
|
4
3
|
import os
|
|
5
4
|
from abc import ABC, abstractmethod
|
|
6
5
|
from collections.abc import Iterator
|
|
6
|
+
from contextlib import contextmanager, suppress
|
|
7
7
|
from datetime import datetime, timezone
|
|
8
8
|
from functools import cached_property, reduce
|
|
9
9
|
from itertools import groupby
|
|
10
|
-
from typing import TYPE_CHECKING, Any
|
|
10
|
+
from typing import TYPE_CHECKING, Any
|
|
11
11
|
from uuid import uuid4
|
|
12
12
|
|
|
13
13
|
from sqlalchemy import (
|
|
14
14
|
JSON,
|
|
15
15
|
BigInteger,
|
|
16
|
+
Boolean,
|
|
16
17
|
Column,
|
|
17
18
|
DateTime,
|
|
18
19
|
ForeignKey,
|
|
20
|
+
Index,
|
|
19
21
|
Integer,
|
|
20
22
|
Table,
|
|
21
23
|
Text,
|
|
22
24
|
UniqueConstraint,
|
|
25
|
+
cast,
|
|
26
|
+
desc,
|
|
27
|
+
literal,
|
|
23
28
|
select,
|
|
24
29
|
)
|
|
30
|
+
from sqlalchemy.sql import func as f
|
|
25
31
|
|
|
32
|
+
from datachain import json
|
|
33
|
+
from datachain.catalog.dependency import DatasetDependencyNode
|
|
34
|
+
from datachain.checkpoint import Checkpoint
|
|
26
35
|
from datachain.data_storage import JobQueryType, JobStatus
|
|
27
36
|
from datachain.data_storage.serializer import Serializable
|
|
28
37
|
from datachain.dataset import (
|
|
@@ -33,22 +42,34 @@ from datachain.dataset import (
|
|
|
33
42
|
DatasetStatus,
|
|
34
43
|
DatasetVersion,
|
|
35
44
|
StorageURI,
|
|
45
|
+
parse_schema,
|
|
36
46
|
)
|
|
37
47
|
from datachain.error import (
|
|
48
|
+
CheckpointNotFoundError,
|
|
49
|
+
DataChainError,
|
|
38
50
|
DatasetNotFoundError,
|
|
51
|
+
DatasetVersionNotFoundError,
|
|
52
|
+
NamespaceDeleteNotAllowedError,
|
|
53
|
+
NamespaceNotFoundError,
|
|
54
|
+
ProjectDeleteNotAllowedError,
|
|
55
|
+
ProjectNotFoundError,
|
|
39
56
|
TableMissingError,
|
|
40
57
|
)
|
|
41
58
|
from datachain.job import Job
|
|
42
|
-
from datachain.
|
|
59
|
+
from datachain.namespace import Namespace
|
|
60
|
+
from datachain.project import Project
|
|
43
61
|
|
|
44
62
|
if TYPE_CHECKING:
|
|
45
|
-
from sqlalchemy import Delete, Insert, Select, Update
|
|
63
|
+
from sqlalchemy import CTE, Delete, Insert, Select, Subquery, Update
|
|
46
64
|
from sqlalchemy.schema import SchemaItem
|
|
65
|
+
from sqlalchemy.sql.elements import ColumnElement
|
|
47
66
|
|
|
48
67
|
from datachain.data_storage import schema
|
|
49
68
|
from datachain.data_storage.db_engine import DatabaseEngine
|
|
50
69
|
|
|
51
70
|
logger = logging.getLogger("datachain")
|
|
71
|
+
DEPTH_LIMIT_DEFAULT = 100
|
|
72
|
+
JOB_ANCESTRY_MAX_DEPTH = 100
|
|
52
73
|
|
|
53
74
|
|
|
54
75
|
class AbstractMetastore(ABC, Serializable):
|
|
@@ -60,15 +81,20 @@ class AbstractMetastore(ABC, Serializable):
|
|
|
60
81
|
uri: StorageURI
|
|
61
82
|
|
|
62
83
|
schema: "schema.Schema"
|
|
84
|
+
namespace_class: type[Namespace] = Namespace
|
|
85
|
+
project_class: type[Project] = Project
|
|
63
86
|
dataset_class: type[DatasetRecord] = DatasetRecord
|
|
87
|
+
dataset_version_class: type[DatasetVersion] = DatasetVersion
|
|
64
88
|
dataset_list_class: type[DatasetListRecord] = DatasetListRecord
|
|
65
89
|
dataset_list_version_class: type[DatasetListVersion] = DatasetListVersion
|
|
66
90
|
dependency_class: type[DatasetDependency] = DatasetDependency
|
|
91
|
+
dependency_node_class: type[DatasetDependencyNode] = DatasetDependencyNode
|
|
67
92
|
job_class: type[Job] = Job
|
|
93
|
+
checkpoint_class: type[Checkpoint] = Checkpoint
|
|
68
94
|
|
|
69
95
|
def __init__(
|
|
70
96
|
self,
|
|
71
|
-
uri:
|
|
97
|
+
uri: StorageURI | None = None,
|
|
72
98
|
):
|
|
73
99
|
self.uri = uri or StorageURI("")
|
|
74
100
|
|
|
@@ -82,7 +108,7 @@ class AbstractMetastore(ABC, Serializable):
|
|
|
82
108
|
@abstractmethod
|
|
83
109
|
def clone(
|
|
84
110
|
self,
|
|
85
|
-
uri:
|
|
111
|
+
uri: StorageURI | None = None,
|
|
86
112
|
use_new_connection: bool = False,
|
|
87
113
|
) -> "AbstractMetastore":
|
|
88
114
|
"""Clones AbstractMetastore implementation for some Storage input.
|
|
@@ -99,6 +125,16 @@ class AbstractMetastore(ABC, Serializable):
|
|
|
99
125
|
differently."""
|
|
100
126
|
self.close()
|
|
101
127
|
|
|
128
|
+
@contextmanager
|
|
129
|
+
def _init_guard(self):
|
|
130
|
+
"""Ensure resources acquired during __init__ are released on failure."""
|
|
131
|
+
try:
|
|
132
|
+
yield
|
|
133
|
+
except Exception:
|
|
134
|
+
with suppress(Exception):
|
|
135
|
+
self.close_on_exit()
|
|
136
|
+
raise
|
|
137
|
+
|
|
102
138
|
def cleanup_tables(self, temp_table_names: list[str]) -> None:
|
|
103
139
|
"""Cleanup temp tables."""
|
|
104
140
|
|
|
@@ -106,21 +142,131 @@ class AbstractMetastore(ABC, Serializable):
|
|
|
106
142
|
"""Cleanup for tests."""
|
|
107
143
|
|
|
108
144
|
#
|
|
109
|
-
#
|
|
145
|
+
# Namespaces
|
|
110
146
|
#
|
|
111
147
|
|
|
148
|
+
@property
|
|
149
|
+
@abstractmethod
|
|
150
|
+
def default_namespace_name(self):
|
|
151
|
+
"""Gets default namespace name"""
|
|
152
|
+
|
|
153
|
+
@property
|
|
154
|
+
def system_namespace_name(self):
|
|
155
|
+
return Namespace.system()
|
|
156
|
+
|
|
157
|
+
@abstractmethod
|
|
158
|
+
def create_namespace(
|
|
159
|
+
self,
|
|
160
|
+
name: str,
|
|
161
|
+
description: str | None = None,
|
|
162
|
+
uuid: str | None = None,
|
|
163
|
+
ignore_if_exists: bool = True,
|
|
164
|
+
validate: bool = True,
|
|
165
|
+
**kwargs,
|
|
166
|
+
) -> Namespace:
|
|
167
|
+
"""Creates new namespace"""
|
|
168
|
+
|
|
169
|
+
@abstractmethod
|
|
170
|
+
def get_namespace(self, name: str, conn=None) -> Namespace:
|
|
171
|
+
"""Gets a single namespace by name"""
|
|
172
|
+
|
|
173
|
+
@abstractmethod
|
|
174
|
+
def remove_namespace(self, namespace_id: int, conn=None) -> None:
|
|
175
|
+
"""Removes a single namespace by id"""
|
|
176
|
+
|
|
177
|
+
@abstractmethod
|
|
178
|
+
def list_namespaces(self, conn=None) -> list[Namespace]:
|
|
179
|
+
"""Gets a list of all namespaces"""
|
|
180
|
+
|
|
181
|
+
#
|
|
182
|
+
# Projects
|
|
183
|
+
#
|
|
184
|
+
|
|
185
|
+
@property
|
|
186
|
+
@abstractmethod
|
|
187
|
+
def default_project_name(self):
|
|
188
|
+
"""Gets default project name"""
|
|
189
|
+
|
|
190
|
+
@property
|
|
191
|
+
def listing_project_name(self):
|
|
192
|
+
return Project.listing()
|
|
193
|
+
|
|
194
|
+
@cached_property
|
|
195
|
+
def default_project(self) -> Project:
|
|
196
|
+
return self.get_project(
|
|
197
|
+
self.default_project_name, self.default_namespace_name, create=True
|
|
198
|
+
)
|
|
199
|
+
|
|
200
|
+
@cached_property
|
|
201
|
+
def listing_project(self) -> Project:
|
|
202
|
+
return self.get_project(self.listing_project_name, self.system_namespace_name)
|
|
203
|
+
|
|
204
|
+
@abstractmethod
|
|
205
|
+
def create_project(
|
|
206
|
+
self,
|
|
207
|
+
namespace_name: str,
|
|
208
|
+
name: str,
|
|
209
|
+
description: str | None = None,
|
|
210
|
+
uuid: str | None = None,
|
|
211
|
+
ignore_if_exists: bool = True,
|
|
212
|
+
validate: bool = True,
|
|
213
|
+
**kwargs,
|
|
214
|
+
) -> Project:
|
|
215
|
+
"""Creates new project in specific namespace"""
|
|
216
|
+
|
|
217
|
+
@abstractmethod
|
|
218
|
+
def get_project(
|
|
219
|
+
self, name: str, namespace_name: str, create: bool = False, conn=None
|
|
220
|
+
) -> Project:
|
|
221
|
+
"""
|
|
222
|
+
Gets a single project inside some namespace by name.
|
|
223
|
+
It also creates project if not found and create flag is set to True.
|
|
224
|
+
"""
|
|
225
|
+
|
|
226
|
+
def is_default_project(self, project_name: str, namespace_name: str) -> bool:
|
|
227
|
+
return (
|
|
228
|
+
project_name == self.default_project_name
|
|
229
|
+
and namespace_name == self.default_namespace_name
|
|
230
|
+
)
|
|
231
|
+
|
|
232
|
+
def is_listing_project(self, project_name: str, namespace_name: str) -> bool:
|
|
233
|
+
return (
|
|
234
|
+
project_name == self.listing_project_name
|
|
235
|
+
and namespace_name == self.system_namespace_name
|
|
236
|
+
)
|
|
237
|
+
|
|
238
|
+
@abstractmethod
|
|
239
|
+
def get_project_by_id(self, project_id: int, conn=None) -> Project:
|
|
240
|
+
"""Gets a single project by id"""
|
|
241
|
+
|
|
242
|
+
@abstractmethod
|
|
243
|
+
def count_projects(self, namespace_id: int | None = None) -> int:
|
|
244
|
+
"""Counts projects in some namespace or in general."""
|
|
245
|
+
|
|
246
|
+
@abstractmethod
|
|
247
|
+
def remove_project(self, project_id: int, conn=None) -> None:
|
|
248
|
+
"""Removes a single project by id"""
|
|
249
|
+
|
|
250
|
+
@abstractmethod
|
|
251
|
+
def list_projects(self, namespace_id: int | None, conn=None) -> list[Project]:
|
|
252
|
+
"""Gets list of projects in some namespace or in general (in all namespaces)"""
|
|
253
|
+
|
|
254
|
+
#
|
|
255
|
+
# Datasets
|
|
256
|
+
#
|
|
112
257
|
@abstractmethod
|
|
113
258
|
def create_dataset(
|
|
114
259
|
self,
|
|
115
260
|
name: str,
|
|
261
|
+
project_id: int | None = None,
|
|
116
262
|
status: int = DatasetStatus.CREATED,
|
|
117
|
-
sources:
|
|
118
|
-
feature_schema:
|
|
263
|
+
sources: list[str] | None = None,
|
|
264
|
+
feature_schema: dict | None = None,
|
|
119
265
|
query_script: str = "",
|
|
120
|
-
schema:
|
|
266
|
+
schema: dict[str, Any] | None = None,
|
|
121
267
|
ignore_if_exists: bool = False,
|
|
122
|
-
description:
|
|
123
|
-
|
|
268
|
+
description: str | None = None,
|
|
269
|
+
attrs: list[str] | None = None,
|
|
124
270
|
) -> DatasetRecord:
|
|
125
271
|
"""Creates new dataset."""
|
|
126
272
|
|
|
@@ -128,23 +274,23 @@ class AbstractMetastore(ABC, Serializable):
|
|
|
128
274
|
def create_dataset_version( # noqa: PLR0913
|
|
129
275
|
self,
|
|
130
276
|
dataset: DatasetRecord,
|
|
131
|
-
version:
|
|
277
|
+
version: str,
|
|
132
278
|
status: int,
|
|
133
279
|
sources: str = "",
|
|
134
|
-
feature_schema:
|
|
280
|
+
feature_schema: dict | None = None,
|
|
135
281
|
query_script: str = "",
|
|
136
282
|
error_message: str = "",
|
|
137
283
|
error_stack: str = "",
|
|
138
284
|
script_output: str = "",
|
|
139
|
-
created_at:
|
|
140
|
-
finished_at:
|
|
141
|
-
schema:
|
|
285
|
+
created_at: datetime | None = None,
|
|
286
|
+
finished_at: datetime | None = None,
|
|
287
|
+
schema: dict[str, Any] | None = None,
|
|
142
288
|
ignore_if_exists: bool = False,
|
|
143
|
-
num_objects:
|
|
144
|
-
size:
|
|
145
|
-
preview:
|
|
146
|
-
job_id:
|
|
147
|
-
uuid:
|
|
289
|
+
num_objects: int | None = None,
|
|
290
|
+
size: int | None = None,
|
|
291
|
+
preview: list[dict] | None = None,
|
|
292
|
+
job_id: str | None = None,
|
|
293
|
+
uuid: str | None = None,
|
|
148
294
|
) -> DatasetRecord:
|
|
149
295
|
"""Creates new dataset version."""
|
|
150
296
|
|
|
@@ -158,13 +304,13 @@ class AbstractMetastore(ABC, Serializable):
|
|
|
158
304
|
|
|
159
305
|
@abstractmethod
|
|
160
306
|
def update_dataset_version(
|
|
161
|
-
self, dataset: DatasetRecord, version:
|
|
307
|
+
self, dataset: DatasetRecord, version: str, **kwargs
|
|
162
308
|
) -> DatasetVersion:
|
|
163
309
|
"""Updates dataset version fields."""
|
|
164
310
|
|
|
165
311
|
@abstractmethod
|
|
166
312
|
def remove_dataset_version(
|
|
167
|
-
self, dataset: DatasetRecord, version:
|
|
313
|
+
self, dataset: DatasetRecord, version: str
|
|
168
314
|
) -> DatasetRecord:
|
|
169
315
|
"""
|
|
170
316
|
Deletes one single dataset version.
|
|
@@ -172,15 +318,32 @@ class AbstractMetastore(ABC, Serializable):
|
|
|
172
318
|
"""
|
|
173
319
|
|
|
174
320
|
@abstractmethod
|
|
175
|
-
def list_datasets(
|
|
176
|
-
|
|
321
|
+
def list_datasets(
|
|
322
|
+
self, project_id: int | None = None
|
|
323
|
+
) -> Iterator[DatasetListRecord]:
|
|
324
|
+
"""Lists all datasets in some project or in all projects."""
|
|
177
325
|
|
|
178
326
|
@abstractmethod
|
|
179
|
-
def
|
|
180
|
-
"""
|
|
327
|
+
def count_datasets(self, project_id: int | None = None) -> int:
|
|
328
|
+
"""Counts datasets in some project or in all projects."""
|
|
181
329
|
|
|
182
330
|
@abstractmethod
|
|
183
|
-
def
|
|
331
|
+
def list_datasets_by_prefix(
|
|
332
|
+
self, prefix: str, project_id: int | None = None
|
|
333
|
+
) -> Iterator["DatasetListRecord"]:
|
|
334
|
+
"""
|
|
335
|
+
Lists all datasets which names start with prefix in some project or in all
|
|
336
|
+
projects.
|
|
337
|
+
"""
|
|
338
|
+
|
|
339
|
+
@abstractmethod
|
|
340
|
+
def get_dataset(
|
|
341
|
+
self,
|
|
342
|
+
name: str, # normal, not full dataset name
|
|
343
|
+
namespace_name: str | None = None,
|
|
344
|
+
project_name: str | None = None,
|
|
345
|
+
conn=None,
|
|
346
|
+
) -> DatasetRecord:
|
|
184
347
|
"""Gets a single dataset by name."""
|
|
185
348
|
|
|
186
349
|
@abstractmethod
|
|
@@ -188,7 +351,7 @@ class AbstractMetastore(ABC, Serializable):
|
|
|
188
351
|
self,
|
|
189
352
|
dataset: DatasetRecord,
|
|
190
353
|
status: int,
|
|
191
|
-
version:
|
|
354
|
+
version: str | None = None,
|
|
192
355
|
error_message="",
|
|
193
356
|
error_stack="",
|
|
194
357
|
script_output="",
|
|
@@ -201,10 +364,10 @@ class AbstractMetastore(ABC, Serializable):
|
|
|
201
364
|
@abstractmethod
|
|
202
365
|
def add_dataset_dependency(
|
|
203
366
|
self,
|
|
204
|
-
|
|
205
|
-
source_dataset_version:
|
|
206
|
-
|
|
207
|
-
|
|
367
|
+
source_dataset: "DatasetRecord",
|
|
368
|
+
source_dataset_version: str,
|
|
369
|
+
dep_dataset: "DatasetRecord",
|
|
370
|
+
dep_dataset_version: str,
|
|
208
371
|
) -> None:
|
|
209
372
|
"""Adds dataset dependency to dataset."""
|
|
210
373
|
|
|
@@ -212,21 +375,27 @@ class AbstractMetastore(ABC, Serializable):
|
|
|
212
375
|
def update_dataset_dependency_source(
|
|
213
376
|
self,
|
|
214
377
|
source_dataset: DatasetRecord,
|
|
215
|
-
source_dataset_version:
|
|
216
|
-
new_source_dataset:
|
|
217
|
-
new_source_dataset_version:
|
|
378
|
+
source_dataset_version: str,
|
|
379
|
+
new_source_dataset: DatasetRecord | None = None,
|
|
380
|
+
new_source_dataset_version: str | None = None,
|
|
218
381
|
) -> None:
|
|
219
382
|
"""Updates dataset dependency source."""
|
|
220
383
|
|
|
221
384
|
@abstractmethod
|
|
222
385
|
def get_direct_dataset_dependencies(
|
|
223
|
-
self, dataset: DatasetRecord, version:
|
|
224
|
-
) -> list[
|
|
386
|
+
self, dataset: DatasetRecord, version: str
|
|
387
|
+
) -> list[DatasetDependency | None]:
|
|
225
388
|
"""Gets direct dataset dependencies."""
|
|
226
389
|
|
|
390
|
+
@abstractmethod
|
|
391
|
+
def get_dataset_dependency_nodes(
|
|
392
|
+
self, dataset_id: int, version_id: int
|
|
393
|
+
) -> list[DatasetDependencyNode | None]:
|
|
394
|
+
"""Gets dataset dependency node from database."""
|
|
395
|
+
|
|
227
396
|
@abstractmethod
|
|
228
397
|
def remove_dataset_dependencies(
|
|
229
|
-
self, dataset: DatasetRecord, version:
|
|
398
|
+
self, dataset: DatasetRecord, version: str | None = None
|
|
230
399
|
) -> None:
|
|
231
400
|
"""
|
|
232
401
|
When we remove dataset, we need to clean up it's dependencies as well.
|
|
@@ -234,7 +403,7 @@ class AbstractMetastore(ABC, Serializable):
|
|
|
234
403
|
|
|
235
404
|
@abstractmethod
|
|
236
405
|
def remove_dataset_dependants(
|
|
237
|
-
self, dataset: DatasetRecord, version:
|
|
406
|
+
self, dataset: DatasetRecord, version: str | None = None
|
|
238
407
|
) -> None:
|
|
239
408
|
"""
|
|
240
409
|
When we remove dataset, we need to clear its references in other dataset
|
|
@@ -254,43 +423,121 @@ class AbstractMetastore(ABC, Serializable):
|
|
|
254
423
|
name: str,
|
|
255
424
|
query: str,
|
|
256
425
|
query_type: JobQueryType = JobQueryType.PYTHON,
|
|
426
|
+
status: JobStatus = JobStatus.CREATED,
|
|
257
427
|
workers: int = 1,
|
|
258
|
-
python_version:
|
|
259
|
-
params:
|
|
428
|
+
python_version: str | None = None,
|
|
429
|
+
params: dict[str, str] | None = None,
|
|
430
|
+
parent_job_id: str | None = None,
|
|
260
431
|
) -> str:
|
|
261
432
|
"""
|
|
262
433
|
Creates a new job.
|
|
263
434
|
Returns the job id.
|
|
264
435
|
"""
|
|
265
436
|
|
|
437
|
+
@abstractmethod
|
|
438
|
+
def get_job(self, job_id: str) -> Job | None:
|
|
439
|
+
"""Returns the job with the given ID."""
|
|
440
|
+
|
|
441
|
+
@abstractmethod
|
|
442
|
+
def update_job(
|
|
443
|
+
self,
|
|
444
|
+
job_id: str,
|
|
445
|
+
status: JobStatus | None = None,
|
|
446
|
+
error_message: str | None = None,
|
|
447
|
+
error_stack: str | None = None,
|
|
448
|
+
finished_at: datetime | None = None,
|
|
449
|
+
metrics: dict[str, Any] | None = None,
|
|
450
|
+
) -> Job | None:
|
|
451
|
+
"""Updates job fields."""
|
|
452
|
+
|
|
266
453
|
@abstractmethod
|
|
267
454
|
def set_job_status(
|
|
268
455
|
self,
|
|
269
456
|
job_id: str,
|
|
270
457
|
status: JobStatus,
|
|
271
|
-
error_message:
|
|
272
|
-
error_stack:
|
|
273
|
-
metrics: Optional[dict[str, Any]] = None,
|
|
458
|
+
error_message: str | None = None,
|
|
459
|
+
error_stack: str | None = None,
|
|
274
460
|
) -> None:
|
|
275
461
|
"""Set the status of the given job."""
|
|
276
462
|
|
|
277
463
|
@abstractmethod
|
|
278
|
-
def get_job_status(self, job_id: str) ->
|
|
464
|
+
def get_job_status(self, job_id: str) -> JobStatus | None:
|
|
279
465
|
"""Returns the status of the given job."""
|
|
280
466
|
|
|
281
467
|
@abstractmethod
|
|
282
|
-
def
|
|
468
|
+
def get_last_job_by_name(self, name: str, conn=None) -> "Job | None":
|
|
469
|
+
"""Returns the last job with the given name, ordered by created_at."""
|
|
470
|
+
|
|
471
|
+
#
|
|
472
|
+
# Checkpoints
|
|
473
|
+
#
|
|
474
|
+
|
|
475
|
+
@abstractmethod
|
|
476
|
+
def list_checkpoints(self, job_id: str, conn=None) -> Iterator[Checkpoint]:
|
|
477
|
+
"""Returns all checkpoints related to some job"""
|
|
478
|
+
|
|
479
|
+
@abstractmethod
|
|
480
|
+
def get_last_checkpoint(self, job_id: str, conn=None) -> Checkpoint | None:
|
|
481
|
+
"""Get last created checkpoint for some job."""
|
|
482
|
+
|
|
483
|
+
@abstractmethod
|
|
484
|
+
def get_checkpoint_by_id(self, checkpoint_id: str, conn=None) -> Checkpoint:
|
|
485
|
+
"""Gets single checkpoint by id"""
|
|
486
|
+
|
|
487
|
+
def find_checkpoint(
|
|
488
|
+
self, job_id: str, _hash: str, partial: bool = False, conn=None
|
|
489
|
+
) -> Checkpoint | None:
|
|
490
|
+
"""
|
|
491
|
+
Tries to find checkpoint for a job with specific hash and optionally partial
|
|
492
|
+
"""
|
|
493
|
+
|
|
494
|
+
@abstractmethod
|
|
495
|
+
def create_checkpoint(
|
|
283
496
|
self,
|
|
284
497
|
job_id: str,
|
|
285
|
-
|
|
286
|
-
|
|
498
|
+
_hash: str,
|
|
499
|
+
partial: bool = False,
|
|
500
|
+
conn: Any | None = None,
|
|
501
|
+
) -> Checkpoint:
|
|
502
|
+
"""Creates new checkpoint"""
|
|
503
|
+
|
|
504
|
+
#
|
|
505
|
+
# Dataset Version Jobs (many-to-many)
|
|
506
|
+
#
|
|
507
|
+
|
|
508
|
+
@abstractmethod
|
|
509
|
+
def link_dataset_version_to_job(
|
|
510
|
+
self,
|
|
511
|
+
dataset_version_id: int,
|
|
512
|
+
job_id: str,
|
|
513
|
+
is_creator: bool = False,
|
|
514
|
+
conn=None,
|
|
287
515
|
) -> None:
|
|
288
|
-
"""
|
|
516
|
+
"""
|
|
517
|
+
Link dataset version to job.
|
|
518
|
+
|
|
519
|
+
This atomically:
|
|
520
|
+
1. Creates a link in the dataset_version_jobs junction table
|
|
521
|
+
2. Updates dataset_version.job_id to point to this job
|
|
522
|
+
"""
|
|
289
523
|
|
|
290
524
|
@abstractmethod
|
|
291
|
-
def
|
|
292
|
-
"""
|
|
293
|
-
|
|
525
|
+
def get_ancestor_job_ids(self, job_id: str, conn=None) -> list[str]:
|
|
526
|
+
"""Get all ancestor job IDs for a given job."""
|
|
527
|
+
|
|
528
|
+
@abstractmethod
|
|
529
|
+
def get_dataset_version_for_job_ancestry(
|
|
530
|
+
self,
|
|
531
|
+
dataset_name: str,
|
|
532
|
+
namespace_name: str,
|
|
533
|
+
project_name: str,
|
|
534
|
+
job_id: str,
|
|
535
|
+
conn=None,
|
|
536
|
+
) -> DatasetVersion | None:
|
|
537
|
+
"""
|
|
538
|
+
Find the dataset version that was created by any job in the ancestry.
|
|
539
|
+
Returns the most recently linked version from these jobs.
|
|
540
|
+
"""
|
|
294
541
|
|
|
295
542
|
|
|
296
543
|
class AbstractDBMetastore(AbstractMetastore):
|
|
@@ -301,14 +548,18 @@ class AbstractDBMetastore(AbstractMetastore):
|
|
|
301
548
|
and has shared logic for all database systems currently in use.
|
|
302
549
|
"""
|
|
303
550
|
|
|
551
|
+
NAMESPACE_TABLE = "namespaces"
|
|
552
|
+
PROJECT_TABLE = "projects"
|
|
304
553
|
DATASET_TABLE = "datasets"
|
|
305
554
|
DATASET_VERSION_TABLE = "datasets_versions"
|
|
306
555
|
DATASET_DEPENDENCY_TABLE = "datasets_dependencies"
|
|
556
|
+
DATASET_VERSION_JOBS_TABLE = "dataset_version_jobs"
|
|
307
557
|
JOBS_TABLE = "jobs"
|
|
558
|
+
CHECKPOINTS_TABLE = "checkpoints"
|
|
308
559
|
|
|
309
560
|
db: "DatabaseEngine"
|
|
310
561
|
|
|
311
|
-
def __init__(self, uri:
|
|
562
|
+
def __init__(self, uri: StorageURI | None = None):
|
|
312
563
|
uri = uri or StorageURI("")
|
|
313
564
|
super().__init__(uri)
|
|
314
565
|
|
|
@@ -319,14 +570,65 @@ class AbstractDBMetastore(AbstractMetastore):
|
|
|
319
570
|
def cleanup_tables(self, temp_table_names: list[str]) -> None:
|
|
320
571
|
"""Cleanup temp tables."""
|
|
321
572
|
|
|
573
|
+
@classmethod
|
|
574
|
+
def _namespaces_columns(cls) -> list["SchemaItem"]:
|
|
575
|
+
"""Namespace table columns."""
|
|
576
|
+
return [
|
|
577
|
+
Column("id", Integer, primary_key=True),
|
|
578
|
+
Column("uuid", Text, nullable=False, default=uuid4()),
|
|
579
|
+
Column("name", Text, nullable=False),
|
|
580
|
+
Column("description", Text),
|
|
581
|
+
Column("created_at", DateTime(timezone=True)),
|
|
582
|
+
]
|
|
583
|
+
|
|
584
|
+
@cached_property
|
|
585
|
+
def _namespaces_fields(self) -> list[str]:
|
|
586
|
+
return [
|
|
587
|
+
c.name # type: ignore [attr-defined]
|
|
588
|
+
for c in self._namespaces_columns()
|
|
589
|
+
if c.name # type: ignore [attr-defined]
|
|
590
|
+
]
|
|
591
|
+
|
|
592
|
+
@classmethod
|
|
593
|
+
def _projects_columns(cls) -> list["SchemaItem"]:
|
|
594
|
+
"""Project table columns."""
|
|
595
|
+
return [
|
|
596
|
+
Column("id", Integer, primary_key=True),
|
|
597
|
+
Column("uuid", Text, nullable=False, default=uuid4()),
|
|
598
|
+
Column("name", Text, nullable=False),
|
|
599
|
+
Column("description", Text),
|
|
600
|
+
Column("created_at", DateTime(timezone=True)),
|
|
601
|
+
Column(
|
|
602
|
+
"namespace_id",
|
|
603
|
+
Integer,
|
|
604
|
+
ForeignKey(f"{cls.NAMESPACE_TABLE}.id", ondelete="CASCADE"),
|
|
605
|
+
nullable=False,
|
|
606
|
+
),
|
|
607
|
+
UniqueConstraint("namespace_id", "name"),
|
|
608
|
+
]
|
|
609
|
+
|
|
610
|
+
@cached_property
|
|
611
|
+
def _projects_fields(self) -> list[str]:
|
|
612
|
+
return [
|
|
613
|
+
c.name # type: ignore [attr-defined]
|
|
614
|
+
for c in self._projects_columns()
|
|
615
|
+
if c.name # type: ignore [attr-defined]
|
|
616
|
+
]
|
|
617
|
+
|
|
322
618
|
@classmethod
|
|
323
619
|
def _datasets_columns(cls) -> list["SchemaItem"]:
|
|
324
620
|
"""Datasets table columns."""
|
|
325
621
|
return [
|
|
326
622
|
Column("id", Integer, primary_key=True),
|
|
623
|
+
Column(
|
|
624
|
+
"project_id",
|
|
625
|
+
Integer,
|
|
626
|
+
ForeignKey(f"{cls.PROJECT_TABLE}.id", ondelete="CASCADE"),
|
|
627
|
+
nullable=False,
|
|
628
|
+
),
|
|
327
629
|
Column("name", Text, nullable=False),
|
|
328
630
|
Column("description", Text),
|
|
329
|
-
Column("
|
|
631
|
+
Column("attrs", JSON, nullable=True),
|
|
330
632
|
Column("status", Integer, nullable=False),
|
|
331
633
|
Column("feature_schema", JSON, nullable=True),
|
|
332
634
|
Column("created_at", DateTime(timezone=True)),
|
|
@@ -367,7 +669,7 @@ class AbstractDBMetastore(AbstractMetastore):
|
|
|
367
669
|
ForeignKey(f"{cls.DATASET_TABLE}.id", ondelete="CASCADE"),
|
|
368
670
|
nullable=False,
|
|
369
671
|
),
|
|
370
|
-
Column("version",
|
|
672
|
+
Column("version", Text, nullable=False, default="1.0.0"),
|
|
371
673
|
Column(
|
|
372
674
|
"status",
|
|
373
675
|
Integer,
|
|
@@ -442,6 +744,16 @@ class AbstractDBMetastore(AbstractMetastore):
|
|
|
442
744
|
#
|
|
443
745
|
# Query Tables
|
|
444
746
|
#
|
|
747
|
+
@cached_property
|
|
748
|
+
def _namespaces(self) -> Table:
|
|
749
|
+
return Table(
|
|
750
|
+
self.NAMESPACE_TABLE, self.db.metadata, *self._namespaces_columns()
|
|
751
|
+
)
|
|
752
|
+
|
|
753
|
+
@cached_property
|
|
754
|
+
def _projects(self) -> Table:
|
|
755
|
+
return Table(self.PROJECT_TABLE, self.db.metadata, *self._projects_columns())
|
|
756
|
+
|
|
445
757
|
@cached_property
|
|
446
758
|
def _datasets(self) -> Table:
|
|
447
759
|
return Table(self.DATASET_TABLE, self.db.metadata, *self._datasets_columns())
|
|
@@ -465,6 +777,31 @@ class AbstractDBMetastore(AbstractMetastore):
|
|
|
465
777
|
#
|
|
466
778
|
# Query Starters (These can be overridden by subclasses)
|
|
467
779
|
#
|
|
780
|
+
@abstractmethod
|
|
781
|
+
def _namespaces_insert(self) -> "Insert": ...
|
|
782
|
+
|
|
783
|
+
def _namespaces_select(self, *columns) -> "Select":
|
|
784
|
+
if not columns:
|
|
785
|
+
return self._namespaces.select()
|
|
786
|
+
return select(*columns)
|
|
787
|
+
|
|
788
|
+
def _namespaces_update(self) -> "Update":
|
|
789
|
+
return self._namespaces.update()
|
|
790
|
+
|
|
791
|
+
def _namespaces_delete(self) -> "Delete":
|
|
792
|
+
return self._namespaces.delete()
|
|
793
|
+
|
|
794
|
+
@abstractmethod
|
|
795
|
+
def _projects_insert(self) -> "Insert": ...
|
|
796
|
+
|
|
797
|
+
def _projects_select(self, *columns) -> "Select":
|
|
798
|
+
if not columns:
|
|
799
|
+
return self._projects.select()
|
|
800
|
+
return select(*columns)
|
|
801
|
+
|
|
802
|
+
def _projects_delete(self) -> "Delete":
|
|
803
|
+
return self._projects.delete()
|
|
804
|
+
|
|
468
805
|
@abstractmethod
|
|
469
806
|
def _datasets_insert(self) -> "Insert": ...
|
|
470
807
|
|
|
@@ -507,6 +844,197 @@ class AbstractDBMetastore(AbstractMetastore):
|
|
|
507
844
|
def _datasets_dependencies_delete(self) -> "Delete":
|
|
508
845
|
return self._datasets_dependencies.delete()
|
|
509
846
|
|
|
847
|
+
#
|
|
848
|
+
# Namespaces
|
|
849
|
+
#
|
|
850
|
+
|
|
851
|
+
def create_namespace(
|
|
852
|
+
self,
|
|
853
|
+
name: str,
|
|
854
|
+
description: str | None = None,
|
|
855
|
+
uuid: str | None = None,
|
|
856
|
+
ignore_if_exists: bool = True,
|
|
857
|
+
validate: bool = True,
|
|
858
|
+
**kwargs,
|
|
859
|
+
) -> Namespace:
|
|
860
|
+
if validate:
|
|
861
|
+
Namespace.validate_name(name)
|
|
862
|
+
query = self._namespaces_insert().values(
|
|
863
|
+
name=name,
|
|
864
|
+
uuid=uuid or str(uuid4()),
|
|
865
|
+
created_at=datetime.now(timezone.utc),
|
|
866
|
+
description=description,
|
|
867
|
+
)
|
|
868
|
+
if ignore_if_exists and hasattr(query, "on_conflict_do_nothing"):
|
|
869
|
+
# SQLite and PostgreSQL both support 'on_conflict_do_nothing',
|
|
870
|
+
# but generic SQL does not
|
|
871
|
+
query = query.on_conflict_do_nothing(index_elements=["name"])
|
|
872
|
+
self.db.execute(query)
|
|
873
|
+
|
|
874
|
+
return self.get_namespace(name)
|
|
875
|
+
|
|
876
|
+
def remove_namespace(self, namespace_id: int, conn=None) -> None:
|
|
877
|
+
num_projects = self.count_projects(namespace_id)
|
|
878
|
+
if num_projects > 0:
|
|
879
|
+
raise NamespaceDeleteNotAllowedError(
|
|
880
|
+
f"Namespace cannot be removed. It contains {num_projects} project(s). "
|
|
881
|
+
"Please remove the project(s) first."
|
|
882
|
+
)
|
|
883
|
+
|
|
884
|
+
n = self._namespaces
|
|
885
|
+
with self.db.transaction():
|
|
886
|
+
self.db.execute(self._namespaces_delete().where(n.c.id == namespace_id))
|
|
887
|
+
|
|
888
|
+
def get_namespace(self, name: str, conn=None) -> Namespace:
|
|
889
|
+
"""Gets a single namespace by name"""
|
|
890
|
+
n = self._namespaces
|
|
891
|
+
|
|
892
|
+
query = self._namespaces_select(
|
|
893
|
+
*(getattr(n.c, f) for f in self._namespaces_fields),
|
|
894
|
+
).where(n.c.name == name)
|
|
895
|
+
rows = list(self.db.execute(query, conn=conn))
|
|
896
|
+
if not rows:
|
|
897
|
+
raise NamespaceNotFoundError(f"Namespace {name} not found.")
|
|
898
|
+
return self.namespace_class.parse(*rows[0])
|
|
899
|
+
|
|
900
|
+
def list_namespaces(self, conn=None) -> list[Namespace]:
|
|
901
|
+
"""Gets a list of all namespaces"""
|
|
902
|
+
n = self._namespaces
|
|
903
|
+
|
|
904
|
+
query = self._namespaces_select(
|
|
905
|
+
*(getattr(n.c, f) for f in self._namespaces_fields),
|
|
906
|
+
)
|
|
907
|
+
rows = list(self.db.execute(query, conn=conn))
|
|
908
|
+
|
|
909
|
+
return [self.namespace_class.parse(*r) for r in rows]
|
|
910
|
+
|
|
911
|
+
#
|
|
912
|
+
# Projects
|
|
913
|
+
#
|
|
914
|
+
|
|
915
|
+
def create_project(
|
|
916
|
+
self,
|
|
917
|
+
namespace_name: str,
|
|
918
|
+
name: str,
|
|
919
|
+
description: str | None = None,
|
|
920
|
+
uuid: str | None = None,
|
|
921
|
+
ignore_if_exists: bool = True,
|
|
922
|
+
validate: bool = True,
|
|
923
|
+
**kwargs,
|
|
924
|
+
) -> Project:
|
|
925
|
+
if validate:
|
|
926
|
+
Project.validate_name(name)
|
|
927
|
+
try:
|
|
928
|
+
namespace = self.get_namespace(namespace_name)
|
|
929
|
+
except NamespaceNotFoundError:
|
|
930
|
+
namespace = self.create_namespace(namespace_name, validate=validate)
|
|
931
|
+
|
|
932
|
+
query = self._projects_insert().values(
|
|
933
|
+
namespace_id=namespace.id,
|
|
934
|
+
uuid=uuid or str(uuid4()),
|
|
935
|
+
name=name,
|
|
936
|
+
created_at=datetime.now(timezone.utc),
|
|
937
|
+
description=description,
|
|
938
|
+
)
|
|
939
|
+
if ignore_if_exists and hasattr(query, "on_conflict_do_nothing"):
|
|
940
|
+
# SQLite and PostgreSQL both support 'on_conflict_do_nothing',
|
|
941
|
+
# but generic SQL does not
|
|
942
|
+
query = query.on_conflict_do_nothing(
|
|
943
|
+
index_elements=["namespace_id", "name"]
|
|
944
|
+
)
|
|
945
|
+
self.db.execute(query)
|
|
946
|
+
|
|
947
|
+
return self.get_project(name, namespace.name)
|
|
948
|
+
|
|
949
|
+
def _projects_base_query(self) -> "Select":
|
|
950
|
+
n = self._namespaces
|
|
951
|
+
p = self._projects
|
|
952
|
+
|
|
953
|
+
query = self._projects_select(
|
|
954
|
+
*(getattr(n.c, f) for f in self._namespaces_fields),
|
|
955
|
+
*(getattr(p.c, f) for f in self._projects_fields),
|
|
956
|
+
)
|
|
957
|
+
return query.select_from(n.join(p, n.c.id == p.c.namespace_id))
|
|
958
|
+
|
|
959
|
+
def get_project(
|
|
960
|
+
self, name: str, namespace_name: str, create: bool = False, conn=None
|
|
961
|
+
) -> Project:
|
|
962
|
+
"""Gets a single project inside some namespace by name"""
|
|
963
|
+
n = self._namespaces
|
|
964
|
+
p = self._projects
|
|
965
|
+
validate = True
|
|
966
|
+
|
|
967
|
+
if self.is_listing_project(name, namespace_name) or self.is_default_project(
|
|
968
|
+
name, namespace_name
|
|
969
|
+
):
|
|
970
|
+
# we are always creating default and listing projects if they don't exist
|
|
971
|
+
create = True
|
|
972
|
+
validate = False
|
|
973
|
+
|
|
974
|
+
query = self._projects_base_query().where(
|
|
975
|
+
p.c.name == name, n.c.name == namespace_name
|
|
976
|
+
)
|
|
977
|
+
|
|
978
|
+
rows = list(self.db.execute(query, conn=conn))
|
|
979
|
+
if not rows:
|
|
980
|
+
if create:
|
|
981
|
+
return self.create_project(namespace_name, name, validate=validate)
|
|
982
|
+
raise ProjectNotFoundError(
|
|
983
|
+
f"Project {name} in namespace {namespace_name} not found."
|
|
984
|
+
)
|
|
985
|
+
return self.project_class.parse(*rows[0])
|
|
986
|
+
|
|
987
|
+
def get_project_by_id(self, project_id: int, conn=None) -> Project:
|
|
988
|
+
"""Gets a single project by id"""
|
|
989
|
+
p = self._projects
|
|
990
|
+
|
|
991
|
+
query = self._projects_base_query().where(p.c.id == project_id)
|
|
992
|
+
|
|
993
|
+
rows = list(self.db.execute(query, conn=conn))
|
|
994
|
+
if not rows:
|
|
995
|
+
raise ProjectNotFoundError(f"Project with id {project_id} not found.")
|
|
996
|
+
return self.project_class.parse(*rows[0])
|
|
997
|
+
|
|
998
|
+
def count_projects(self, namespace_id: int | None = None) -> int:
|
|
999
|
+
p = self._projects
|
|
1000
|
+
|
|
1001
|
+
query = self._projects_base_query()
|
|
1002
|
+
if namespace_id:
|
|
1003
|
+
query = query.where(p.c.namespace_id == namespace_id)
|
|
1004
|
+
|
|
1005
|
+
query = select(f.count(1)).select_from(query.subquery())
|
|
1006
|
+
|
|
1007
|
+
return next(self.db.execute(query))[0]
|
|
1008
|
+
|
|
1009
|
+
def remove_project(self, project_id: int, conn=None) -> None:
|
|
1010
|
+
num_datasets = self.count_datasets(project_id)
|
|
1011
|
+
if num_datasets > 0:
|
|
1012
|
+
raise ProjectDeleteNotAllowedError(
|
|
1013
|
+
f"Project cannot be removed. It contains {num_datasets} dataset(s). "
|
|
1014
|
+
"Please remove the dataset(s) first."
|
|
1015
|
+
)
|
|
1016
|
+
|
|
1017
|
+
p = self._projects
|
|
1018
|
+
with self.db.transaction():
|
|
1019
|
+
self.db.execute(self._projects_delete().where(p.c.id == project_id))
|
|
1020
|
+
|
|
1021
|
+
def list_projects(
|
|
1022
|
+
self, namespace_id: int | None = None, conn=None
|
|
1023
|
+
) -> list[Project]:
|
|
1024
|
+
"""
|
|
1025
|
+
Gets a list of projects inside some namespace, or in all namespaces
|
|
1026
|
+
"""
|
|
1027
|
+
p = self._projects
|
|
1028
|
+
|
|
1029
|
+
query = self._projects_base_query()
|
|
1030
|
+
|
|
1031
|
+
if namespace_id:
|
|
1032
|
+
query = query.where(p.c.namespace_id == namespace_id)
|
|
1033
|
+
|
|
1034
|
+
rows = list(self.db.execute(query, conn=conn))
|
|
1035
|
+
|
|
1036
|
+
return [self.project_class.parse(*r) for r in rows]
|
|
1037
|
+
|
|
510
1038
|
#
|
|
511
1039
|
# Datasets
|
|
512
1040
|
#
|
|
@@ -514,20 +1042,26 @@ class AbstractDBMetastore(AbstractMetastore):
|
|
|
514
1042
|
def create_dataset(
|
|
515
1043
|
self,
|
|
516
1044
|
name: str,
|
|
1045
|
+
project_id: int | None = None,
|
|
517
1046
|
status: int = DatasetStatus.CREATED,
|
|
518
|
-
sources:
|
|
519
|
-
feature_schema:
|
|
1047
|
+
sources: list[str] | None = None,
|
|
1048
|
+
feature_schema: dict | None = None,
|
|
520
1049
|
query_script: str = "",
|
|
521
|
-
schema:
|
|
1050
|
+
schema: dict[str, Any] | None = None,
|
|
522
1051
|
ignore_if_exists: bool = False,
|
|
523
|
-
description:
|
|
524
|
-
|
|
1052
|
+
description: str | None = None,
|
|
1053
|
+
attrs: list[str] | None = None,
|
|
525
1054
|
**kwargs, # TODO registered = True / False
|
|
526
1055
|
) -> DatasetRecord:
|
|
527
1056
|
"""Creates new dataset."""
|
|
528
|
-
|
|
1057
|
+
if not project_id:
|
|
1058
|
+
project = self.default_project
|
|
1059
|
+
else:
|
|
1060
|
+
project = self.get_project_by_id(project_id)
|
|
1061
|
+
|
|
529
1062
|
query = self._datasets_insert().values(
|
|
530
1063
|
name=name,
|
|
1064
|
+
project_id=project.id,
|
|
531
1065
|
status=status,
|
|
532
1066
|
feature_schema=json.dumps(feature_schema or {}),
|
|
533
1067
|
created_at=datetime.now(timezone.utc),
|
|
@@ -538,36 +1072,38 @@ class AbstractDBMetastore(AbstractMetastore):
|
|
|
538
1072
|
query_script=query_script,
|
|
539
1073
|
schema=json.dumps(schema or {}),
|
|
540
1074
|
description=description,
|
|
541
|
-
|
|
1075
|
+
attrs=json.dumps(attrs or []),
|
|
542
1076
|
)
|
|
543
1077
|
if ignore_if_exists and hasattr(query, "on_conflict_do_nothing"):
|
|
544
1078
|
# SQLite and PostgreSQL both support 'on_conflict_do_nothing',
|
|
545
1079
|
# but generic SQL does not
|
|
546
|
-
query = query.on_conflict_do_nothing(index_elements=["name"])
|
|
1080
|
+
query = query.on_conflict_do_nothing(index_elements=["project_id", "name"])
|
|
547
1081
|
self.db.execute(query)
|
|
548
1082
|
|
|
549
|
-
return self.get_dataset(
|
|
1083
|
+
return self.get_dataset(
|
|
1084
|
+
name, namespace_name=project.namespace.name, project_name=project.name
|
|
1085
|
+
)
|
|
550
1086
|
|
|
551
1087
|
def create_dataset_version( # noqa: PLR0913
|
|
552
1088
|
self,
|
|
553
1089
|
dataset: DatasetRecord,
|
|
554
|
-
version:
|
|
1090
|
+
version: str,
|
|
555
1091
|
status: int,
|
|
556
1092
|
sources: str = "",
|
|
557
|
-
feature_schema:
|
|
1093
|
+
feature_schema: dict | None = None,
|
|
558
1094
|
query_script: str = "",
|
|
559
1095
|
error_message: str = "",
|
|
560
1096
|
error_stack: str = "",
|
|
561
1097
|
script_output: str = "",
|
|
562
|
-
created_at:
|
|
563
|
-
finished_at:
|
|
564
|
-
schema:
|
|
1098
|
+
created_at: datetime | None = None,
|
|
1099
|
+
finished_at: datetime | None = None,
|
|
1100
|
+
schema: dict[str, Any] | None = None,
|
|
565
1101
|
ignore_if_exists: bool = False,
|
|
566
|
-
num_objects:
|
|
567
|
-
size:
|
|
568
|
-
preview:
|
|
569
|
-
job_id:
|
|
570
|
-
uuid:
|
|
1102
|
+
num_objects: int | None = None,
|
|
1103
|
+
size: int | None = None,
|
|
1104
|
+
preview: list[dict] | None = None,
|
|
1105
|
+
job_id: str | None = None,
|
|
1106
|
+
uuid: str | None = None,
|
|
571
1107
|
conn=None,
|
|
572
1108
|
) -> DatasetRecord:
|
|
573
1109
|
"""Creates new dataset version."""
|
|
@@ -603,7 +1139,12 @@ class AbstractDBMetastore(AbstractMetastore):
|
|
|
603
1139
|
)
|
|
604
1140
|
self.db.execute(query, conn=conn)
|
|
605
1141
|
|
|
606
|
-
return self.get_dataset(
|
|
1142
|
+
return self.get_dataset(
|
|
1143
|
+
dataset.name,
|
|
1144
|
+
namespace_name=dataset.project.namespace.name,
|
|
1145
|
+
project_name=dataset.project.name,
|
|
1146
|
+
conn=conn,
|
|
1147
|
+
)
|
|
607
1148
|
|
|
608
1149
|
def remove_dataset(self, dataset: DatasetRecord) -> None:
|
|
609
1150
|
"""Removes dataset."""
|
|
@@ -617,26 +1158,47 @@ class AbstractDBMetastore(AbstractMetastore):
|
|
|
617
1158
|
self, dataset: DatasetRecord, conn=None, **kwargs
|
|
618
1159
|
) -> DatasetRecord:
|
|
619
1160
|
"""Updates dataset fields."""
|
|
620
|
-
values = {}
|
|
621
|
-
dataset_values = {}
|
|
1161
|
+
values: dict[str, Any] = {}
|
|
1162
|
+
dataset_values: dict[str, Any] = {}
|
|
622
1163
|
for field, value in kwargs.items():
|
|
623
|
-
if field in self._dataset_fields
|
|
624
|
-
|
|
625
|
-
|
|
1164
|
+
if field in ("id", "created_at") or field not in self._dataset_fields:
|
|
1165
|
+
continue # these fields are read-only or not applicable
|
|
1166
|
+
|
|
1167
|
+
if value is None and field in ("name", "status", "sources", "query_script"):
|
|
1168
|
+
raise ValueError(f"Field {field} cannot be None")
|
|
1169
|
+
if field == "name" and not value:
|
|
1170
|
+
raise ValueError("name cannot be empty")
|
|
1171
|
+
|
|
1172
|
+
if field == "attrs":
|
|
1173
|
+
if value is None:
|
|
1174
|
+
values[field] = None
|
|
626
1175
|
else:
|
|
627
|
-
values[field] = value
|
|
628
|
-
|
|
629
|
-
|
|
1176
|
+
values[field] = json.dumps(value)
|
|
1177
|
+
dataset_values[field] = value
|
|
1178
|
+
elif field == "schema":
|
|
1179
|
+
if value is None:
|
|
1180
|
+
values[field] = None
|
|
1181
|
+
dataset_values[field] = None
|
|
630
1182
|
else:
|
|
631
|
-
|
|
1183
|
+
values[field] = json.dumps(value)
|
|
1184
|
+
dataset_values[field] = parse_schema(value)
|
|
1185
|
+
elif field == "project_id":
|
|
1186
|
+
if not value:
|
|
1187
|
+
raise ValueError("Cannot set empty project_id for dataset")
|
|
1188
|
+
dataset_values["project"] = self.get_project_by_id(value)
|
|
1189
|
+
values[field] = value
|
|
1190
|
+
else:
|
|
1191
|
+
values[field] = value
|
|
1192
|
+
dataset_values[field] = value
|
|
632
1193
|
|
|
633
1194
|
if not values:
|
|
634
|
-
#
|
|
635
|
-
return dataset
|
|
1195
|
+
return dataset # nothing to update
|
|
636
1196
|
|
|
637
1197
|
d = self._datasets
|
|
638
1198
|
self.db.execute(
|
|
639
|
-
self._datasets_update()
|
|
1199
|
+
self._datasets_update()
|
|
1200
|
+
.where(d.c.name == dataset.name, d.c.project_id == dataset.project.id)
|
|
1201
|
+
.values(values),
|
|
640
1202
|
conn=conn,
|
|
641
1203
|
) # type: ignore [attr-defined]
|
|
642
1204
|
|
|
@@ -645,46 +1207,79 @@ class AbstractDBMetastore(AbstractMetastore):
|
|
|
645
1207
|
return result_ds
|
|
646
1208
|
|
|
647
1209
|
def update_dataset_version(
|
|
648
|
-
self, dataset: DatasetRecord, version:
|
|
1210
|
+
self, dataset: DatasetRecord, version: str, conn=None, **kwargs
|
|
649
1211
|
) -> DatasetVersion:
|
|
650
1212
|
"""Updates dataset fields."""
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
values = {}
|
|
1213
|
+
values: dict[str, Any] = {}
|
|
1214
|
+
version_values: dict[str, Any] = {}
|
|
654
1215
|
for field, value in kwargs.items():
|
|
655
|
-
if
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
1216
|
+
if (
|
|
1217
|
+
field in ("id", "created_at")
|
|
1218
|
+
or field not in self._dataset_version_fields
|
|
1219
|
+
):
|
|
1220
|
+
continue # these fields are read-only or not applicable
|
|
1221
|
+
|
|
1222
|
+
if value is None and field in (
|
|
1223
|
+
"status",
|
|
1224
|
+
"sources",
|
|
1225
|
+
"query_script",
|
|
1226
|
+
"error_message",
|
|
1227
|
+
"error_stack",
|
|
1228
|
+
"script_output",
|
|
1229
|
+
"uuid",
|
|
1230
|
+
):
|
|
1231
|
+
raise ValueError(f"Field {field} cannot be None")
|
|
1232
|
+
|
|
1233
|
+
if field == "schema":
|
|
1234
|
+
values[field] = json.dumps(value) if value else None
|
|
1235
|
+
version_values[field] = parse_schema(value) if value else None
|
|
1236
|
+
elif field == "feature_schema":
|
|
1237
|
+
if value is None:
|
|
1238
|
+
values[field] = None
|
|
1239
|
+
else:
|
|
1240
|
+
values[field] = json.dumps(value)
|
|
1241
|
+
version_values[field] = value
|
|
1242
|
+
elif field == "preview":
|
|
1243
|
+
if value is None:
|
|
1244
|
+
values[field] = None
|
|
1245
|
+
elif not isinstance(value, list):
|
|
1246
|
+
raise ValueError(
|
|
1247
|
+
f"Field '{field}' must be a list, got {type(value).__name__}"
|
|
1248
|
+
)
|
|
663
1249
|
else:
|
|
664
|
-
values[field] = value
|
|
665
|
-
|
|
1250
|
+
values[field] = json.dumps(value, serialize_bytes=True)
|
|
1251
|
+
version_values["_preview_data"] = value
|
|
1252
|
+
else:
|
|
1253
|
+
values[field] = value
|
|
1254
|
+
version_values[field] = value
|
|
666
1255
|
|
|
667
1256
|
if not values:
|
|
668
|
-
|
|
669
|
-
return dataset_version
|
|
1257
|
+
return dataset.get_version(version)
|
|
670
1258
|
|
|
671
1259
|
dv = self._datasets_versions
|
|
672
1260
|
self.db.execute(
|
|
673
1261
|
self._datasets_versions_update()
|
|
674
|
-
.where(dv.c.
|
|
1262
|
+
.where(dv.c.dataset_id == dataset.id, dv.c.version == version)
|
|
675
1263
|
.values(values),
|
|
676
1264
|
conn=conn,
|
|
677
1265
|
) # type: ignore [attr-defined]
|
|
678
1266
|
|
|
679
|
-
|
|
1267
|
+
for v in dataset.versions:
|
|
1268
|
+
if v.version == version:
|
|
1269
|
+
v.update(**version_values)
|
|
1270
|
+
return v
|
|
680
1271
|
|
|
681
|
-
|
|
1272
|
+
raise DatasetVersionNotFoundError(
|
|
1273
|
+
f"Dataset {dataset.name} does not have version {version}"
|
|
1274
|
+
)
|
|
1275
|
+
|
|
1276
|
+
def _parse_dataset(self, rows) -> DatasetRecord | None:
|
|
682
1277
|
versions = [self.dataset_class.parse(*r) for r in rows]
|
|
683
1278
|
if not versions:
|
|
684
1279
|
return None
|
|
685
1280
|
return reduce(lambda ds, version: ds.merge_versions(version), versions)
|
|
686
1281
|
|
|
687
|
-
def _parse_list_dataset(self, rows) ->
|
|
1282
|
+
def _parse_list_dataset(self, rows) -> DatasetListRecord | None:
|
|
688
1283
|
versions = [self.dataset_list_class.parse(*r) for r in rows]
|
|
689
1284
|
if not versions:
|
|
690
1285
|
return None
|
|
@@ -692,69 +1287,124 @@ class AbstractDBMetastore(AbstractMetastore):
|
|
|
692
1287
|
|
|
693
1288
|
def _parse_dataset_list(self, rows) -> Iterator["DatasetListRecord"]:
|
|
694
1289
|
# grouping rows by dataset id
|
|
695
|
-
for _, g in groupby(rows, lambda r: r[
|
|
1290
|
+
for _, g in groupby(rows, lambda r: r[11]):
|
|
696
1291
|
dataset = self._parse_list_dataset(list(g))
|
|
697
1292
|
if dataset:
|
|
698
1293
|
yield dataset
|
|
699
1294
|
|
|
700
1295
|
def _get_dataset_query(
|
|
701
1296
|
self,
|
|
1297
|
+
namespace_fields: list[str],
|
|
1298
|
+
project_fields: list[str],
|
|
702
1299
|
dataset_fields: list[str],
|
|
703
1300
|
dataset_version_fields: list[str],
|
|
704
1301
|
isouter: bool = True,
|
|
705
|
-
):
|
|
1302
|
+
) -> "Select":
|
|
706
1303
|
if not (
|
|
707
1304
|
self.db.has_table(self._datasets.name)
|
|
708
1305
|
and self.db.has_table(self._datasets_versions.name)
|
|
709
1306
|
):
|
|
710
1307
|
raise TableMissingError
|
|
711
1308
|
|
|
1309
|
+
n = self._namespaces
|
|
1310
|
+
p = self._projects
|
|
712
1311
|
d = self._datasets
|
|
713
1312
|
dv = self._datasets_versions
|
|
714
1313
|
|
|
715
1314
|
query = self._datasets_select(
|
|
1315
|
+
*(getattr(n.c, f) for f in namespace_fields),
|
|
1316
|
+
*(getattr(p.c, f) for f in project_fields),
|
|
716
1317
|
*(getattr(d.c, f) for f in dataset_fields),
|
|
717
1318
|
*(getattr(dv.c, f) for f in dataset_version_fields),
|
|
718
1319
|
)
|
|
719
|
-
j =
|
|
1320
|
+
j = (
|
|
1321
|
+
n.join(p, n.c.id == p.c.namespace_id)
|
|
1322
|
+
.join(d, p.c.id == d.c.project_id)
|
|
1323
|
+
.join(dv, d.c.id == dv.c.dataset_id, isouter=isouter)
|
|
1324
|
+
)
|
|
720
1325
|
return query.select_from(j)
|
|
721
1326
|
|
|
722
|
-
def _base_dataset_query(self):
|
|
1327
|
+
def _base_dataset_query(self) -> "Select":
|
|
723
1328
|
return self._get_dataset_query(
|
|
724
|
-
self.
|
|
1329
|
+
self._namespaces_fields,
|
|
1330
|
+
self._projects_fields,
|
|
1331
|
+
self._dataset_fields,
|
|
1332
|
+
self._dataset_version_fields,
|
|
725
1333
|
)
|
|
726
1334
|
|
|
727
|
-
def _base_list_datasets_query(self):
|
|
1335
|
+
def _base_list_datasets_query(self) -> "Select":
|
|
728
1336
|
return self._get_dataset_query(
|
|
729
|
-
self.
|
|
1337
|
+
self._namespaces_fields,
|
|
1338
|
+
self._projects_fields,
|
|
1339
|
+
self._dataset_list_fields,
|
|
1340
|
+
self._dataset_list_version_fields,
|
|
1341
|
+
isouter=False,
|
|
730
1342
|
)
|
|
731
1343
|
|
|
732
|
-
def list_datasets(
|
|
733
|
-
|
|
1344
|
+
def list_datasets(
|
|
1345
|
+
self, project_id: int | None = None
|
|
1346
|
+
) -> Iterator["DatasetListRecord"]:
|
|
1347
|
+
d = self._datasets
|
|
734
1348
|
query = self._base_list_datasets_query().order_by(
|
|
735
1349
|
self._datasets.c.name, self._datasets_versions.c.version
|
|
736
1350
|
)
|
|
1351
|
+
if project_id:
|
|
1352
|
+
query = query.where(d.c.project_id == project_id)
|
|
737
1353
|
yield from self._parse_dataset_list(self.db.execute(query))
|
|
738
1354
|
|
|
1355
|
+
def count_datasets(self, project_id: int | None = None) -> int:
|
|
1356
|
+
d = self._datasets
|
|
1357
|
+
query = self._datasets_select()
|
|
1358
|
+
if project_id:
|
|
1359
|
+
query = query.where(d.c.project_id == project_id)
|
|
1360
|
+
|
|
1361
|
+
query = select(f.count(1)).select_from(query.subquery())
|
|
1362
|
+
|
|
1363
|
+
return next(self.db.execute(query))[0]
|
|
1364
|
+
|
|
739
1365
|
def list_datasets_by_prefix(
|
|
740
|
-
self, prefix: str, conn=None
|
|
1366
|
+
self, prefix: str, project_id: int | None = None, conn=None
|
|
741
1367
|
) -> Iterator["DatasetListRecord"]:
|
|
1368
|
+
d = self._datasets
|
|
742
1369
|
query = self._base_list_datasets_query()
|
|
1370
|
+
if project_id:
|
|
1371
|
+
query = query.where(d.c.project_id == project_id)
|
|
743
1372
|
query = query.where(self._datasets.c.name.startswith(prefix))
|
|
744
1373
|
yield from self._parse_dataset_list(self.db.execute(query))
|
|
745
1374
|
|
|
746
|
-
def get_dataset(
|
|
747
|
-
|
|
1375
|
+
def get_dataset(
|
|
1376
|
+
self,
|
|
1377
|
+
name: str, # normal, not full dataset name
|
|
1378
|
+
namespace_name: str | None = None,
|
|
1379
|
+
project_name: str | None = None,
|
|
1380
|
+
conn=None,
|
|
1381
|
+
) -> DatasetRecord:
|
|
1382
|
+
"""
|
|
1383
|
+
Gets a single dataset in project by dataset name.
|
|
1384
|
+
"""
|
|
1385
|
+
namespace_name = namespace_name or self.default_namespace_name
|
|
1386
|
+
project_name = project_name or self.default_project_name
|
|
1387
|
+
|
|
748
1388
|
d = self._datasets
|
|
1389
|
+
n = self._namespaces
|
|
1390
|
+
p = self._projects
|
|
749
1391
|
query = self._base_dataset_query()
|
|
750
|
-
query = query.where(
|
|
1392
|
+
query = query.where(
|
|
1393
|
+
d.c.name == name,
|
|
1394
|
+
n.c.name == namespace_name,
|
|
1395
|
+
p.c.name == project_name,
|
|
1396
|
+
) # type: ignore [attr-defined]
|
|
751
1397
|
ds = self._parse_dataset(self.db.execute(query, conn=conn))
|
|
752
1398
|
if not ds:
|
|
753
|
-
raise DatasetNotFoundError(
|
|
1399
|
+
raise DatasetNotFoundError(
|
|
1400
|
+
f"Dataset {name} not found in namespace {namespace_name}"
|
|
1401
|
+
f" and project {project_name}"
|
|
1402
|
+
)
|
|
1403
|
+
|
|
754
1404
|
return ds
|
|
755
1405
|
|
|
756
1406
|
def remove_dataset_version(
|
|
757
|
-
self, dataset: DatasetRecord, version:
|
|
1407
|
+
self, dataset: DatasetRecord, version: str
|
|
758
1408
|
) -> DatasetRecord:
|
|
759
1409
|
"""
|
|
760
1410
|
Deletes one single dataset version.
|
|
@@ -787,7 +1437,7 @@ class AbstractDBMetastore(AbstractMetastore):
|
|
|
787
1437
|
self,
|
|
788
1438
|
dataset: DatasetRecord,
|
|
789
1439
|
status: int,
|
|
790
|
-
version:
|
|
1440
|
+
version: str | None = None,
|
|
791
1441
|
error_message="",
|
|
792
1442
|
error_stack="",
|
|
793
1443
|
script_output="",
|
|
@@ -808,7 +1458,7 @@ class AbstractDBMetastore(AbstractMetastore):
|
|
|
808
1458
|
update_data["error_message"] = error_message
|
|
809
1459
|
update_data["error_stack"] = error_stack
|
|
810
1460
|
|
|
811
|
-
self.update_dataset(dataset, conn=conn, **update_data)
|
|
1461
|
+
dataset = self.update_dataset(dataset, conn=conn, **update_data)
|
|
812
1462
|
|
|
813
1463
|
if version:
|
|
814
1464
|
self.update_dataset_version(dataset, version, conn=conn, **update_data)
|
|
@@ -820,32 +1470,29 @@ class AbstractDBMetastore(AbstractMetastore):
|
|
|
820
1470
|
#
|
|
821
1471
|
def add_dataset_dependency(
|
|
822
1472
|
self,
|
|
823
|
-
|
|
824
|
-
source_dataset_version:
|
|
825
|
-
|
|
826
|
-
|
|
1473
|
+
source_dataset: "DatasetRecord",
|
|
1474
|
+
source_dataset_version: str,
|
|
1475
|
+
dep_dataset: "DatasetRecord",
|
|
1476
|
+
dep_dataset_version: str,
|
|
827
1477
|
) -> None:
|
|
828
1478
|
"""Adds dataset dependency to dataset."""
|
|
829
|
-
source_dataset = self.get_dataset(source_dataset_name)
|
|
830
|
-
dataset = self.get_dataset(dataset_name)
|
|
831
|
-
|
|
832
1479
|
self.db.execute(
|
|
833
1480
|
self._datasets_dependencies_insert().values(
|
|
834
1481
|
source_dataset_id=source_dataset.id,
|
|
835
1482
|
source_dataset_version_id=(
|
|
836
1483
|
source_dataset.get_version(source_dataset_version).id
|
|
837
1484
|
),
|
|
838
|
-
dataset_id=
|
|
839
|
-
dataset_version_id=
|
|
1485
|
+
dataset_id=dep_dataset.id,
|
|
1486
|
+
dataset_version_id=dep_dataset.get_version(dep_dataset_version).id,
|
|
840
1487
|
)
|
|
841
1488
|
)
|
|
842
1489
|
|
|
843
1490
|
def update_dataset_dependency_source(
|
|
844
1491
|
self,
|
|
845
1492
|
source_dataset: DatasetRecord,
|
|
846
|
-
source_dataset_version:
|
|
847
|
-
new_source_dataset:
|
|
848
|
-
new_source_dataset_version:
|
|
1493
|
+
source_dataset_version: str,
|
|
1494
|
+
new_source_dataset: DatasetRecord | None = None,
|
|
1495
|
+
new_source_dataset_version: str | None = None,
|
|
849
1496
|
) -> None:
|
|
850
1497
|
dd = self._datasets_dependencies
|
|
851
1498
|
|
|
@@ -875,9 +1522,23 @@ class AbstractDBMetastore(AbstractMetastore):
         Returns a list of columns to select in a query for fetching dataset dependencies
         """

+    @abstractmethod
+    def _dataset_dependency_nodes_select_columns(
+        self,
+        namespaces_subquery: "Subquery",
+        dependency_tree_cte: "CTE",
+        datasets_subquery: "Subquery",
+    ) -> list["ColumnElement"]:
+        """
+        Returns a list of columns to select in a query for fetching
+        dataset dependency nodes.
+        """
+
     def get_direct_dataset_dependencies(
-        self, dataset: DatasetRecord, version:
-    ) -> list[
+        self, dataset: DatasetRecord, version: str
+    ) -> list[DatasetDependency | None]:
+        n = self._namespaces
+        p = self._projects
         d = self._datasets
         dd = self._datasets_dependencies
         dv = self._datasets_versions
@@ -889,23 +1550,90 @@ class AbstractDBMetastore(AbstractMetastore):
         query = (
             self._datasets_dependencies_select(*select_cols)
             .select_from(
-                dd.join(d, dd.c.dataset_id == d.c.id, isouter=True)
-
-            )
+                dd.join(d, dd.c.dataset_id == d.c.id, isouter=True)
+                .join(dv, dd.c.dataset_version_id == dv.c.id, isouter=True)
+                .join(p, d.c.project_id == p.c.id, isouter=True)
+                .join(n, p.c.namespace_id == n.c.id, isouter=True)
             )
             .where(
                 (dd.c.source_dataset_id == dataset.id)
                 & (dd.c.source_dataset_version_id == dataset_version.id)
             )
         )
-        if version:
-            dataset_version = dataset.get_version(version)
-            query = query.where(dd.c.source_dataset_version_id == dataset_version.id)

         return [self.dependency_class.parse(*r) for r in self.db.execute(query)]

+    def get_dataset_dependency_nodes(
+        self, dataset_id: int, version_id: int, depth_limit: int = DEPTH_LIMIT_DEFAULT
+    ) -> list[DatasetDependencyNode | None]:
+        n = self._namespaces_select().subquery()
+        p = self._projects
+        d = self._datasets_select().subquery()
+        dd = self._datasets_dependencies
+        dv = self._datasets_versions
+
+        # Common dependency fields for CTE
+        dep_fields = [
+            dd.c.id,
+            dd.c.source_dataset_id,
+            dd.c.source_dataset_version_id,
+            dd.c.dataset_id,
+            dd.c.dataset_version_id,
+        ]
+
+        # Base case: direct dependencies
+        base_query = select(
+            *dep_fields,
+            literal(0).label("depth"),
+        ).where(
+            (dd.c.source_dataset_id == dataset_id)
+            & (dd.c.source_dataset_version_id == version_id)
+        )
+
+        cte = base_query.cte(name="dependency_tree", recursive=True)
+
+        # Recursive case: dependencies of dependencies
+        # Limit depth to 100 to prevent infinite loops in case of circular dependencies
+        recursive_query = (
+            select(
+                *dep_fields,
+                (cte.c.depth + 1).label("depth"),
+            )
+            .select_from(
+                cte.join(
+                    dd,
+                    (cte.c.dataset_id == dd.c.source_dataset_id)
+                    & (cte.c.dataset_version_id == dd.c.source_dataset_version_id),
+                )
+            )
+            .where(cte.c.depth < depth_limit)
+        )
+
+        cte = cte.union(recursive_query)
+
+        # Fetch all with full details
+        select_cols = self._dataset_dependency_nodes_select_columns(
+            namespaces_subquery=n,
+            dependency_tree_cte=cte,
+            datasets_subquery=d,
+        )
+        final_query = self._datasets_dependencies_select(*select_cols).select_from(
+            # Use outer joins to handle cases where dependent datasets have been
+            # physically deleted. This allows us to return dependency records with
+            # None values instead of silently omitting them, making broken
+            # dependencies visible to callers.
+            cte.join(d, cte.c.dataset_id == d.c.id, isouter=True)
+            .join(dv, cte.c.dataset_version_id == dv.c.id, isouter=True)
+            .join(p, d.c.project_id == p.c.id, isouter=True)
+            .join(n, p.c.namespace_id == n.c.id, isouter=True)
+        )
+
+        return [
+            self.dependency_node_class.parse(*r) for r in self.db.execute(final_query)
+        ]
+
     def remove_dataset_dependencies(
-        self, dataset: DatasetRecord, version:
+        self, dataset: DatasetRecord, version: str | None = None
     ) -> None:
         """
         When we remove dataset, we need to clean up it's dependencies as well
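The new `get_dataset_dependency_nodes` builds a recursive CTE: a base query for direct dependencies, a recursive member that follows dependency-of-dependency edges, and a `depth` counter that caps traversal of circular graphs. The following self-contained sketch (in-memory SQLite, made-up table, column names, and IDs, not the package's actual schema) shows the same SQLAlchemy pattern in isolation:

```python
# Standalone illustration of the base-query + recursive-member + depth-limit pattern.
import sqlalchemy as sa

engine = sa.create_engine("sqlite://")
meta = sa.MetaData()
deps = sa.Table(
    "dataset_dependencies",
    meta,
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column("source_dataset_id", sa.Integer),
    sa.Column("dataset_id", sa.Integer),
)
meta.create_all(engine)

with engine.begin() as conn:
    conn.execute(deps.insert(), [
        {"id": 1, "source_dataset_id": 10, "dataset_id": 20},  # 10 depends on 20
        {"id": 2, "source_dataset_id": 20, "dataset_id": 30},  # 20 depends on 30
    ])

DEPTH_LIMIT = 100  # guards against circular dependency graphs

# Base case: direct dependencies of dataset 10, at depth 0.
base = sa.select(
    deps.c.dataset_id, sa.literal(0).label("depth")
).where(deps.c.source_dataset_id == 10)
tree = base.cte(name="dependency_tree", recursive=True)

# Recursive case: dependencies of whatever is already in the tree, one level deeper.
recursive = (
    sa.select(deps.c.dataset_id, (tree.c.depth + 1).label("depth"))
    .select_from(tree.join(deps, tree.c.dataset_id == deps.c.source_dataset_id))
    .where(tree.c.depth < DEPTH_LIMIT)
)
tree = tree.union(recursive)

with engine.connect() as conn:
    print(conn.execute(sa.select(tree)).all())  # typically [(20, 0), (30, 1)]
```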
@@ -924,7 +1652,7 @@ class AbstractDBMetastore(AbstractMetastore):
         self.db.execute(q)

     def remove_dataset_dependants(
-        self, dataset: DatasetRecord, version:
+        self, dataset: DatasetRecord, version: str | None = None
     ) -> None:
         """
         When we remove dataset, we need to clear its references in other dataset
@@ -975,11 +1703,13 @@ class AbstractDBMetastore(AbstractMetastore):
             Column("error_stack", Text, nullable=False, default=""),
             Column("params", JSON, nullable=False),
             Column("metrics", JSON, nullable=False),
+            Column("parent_job_id", Text, nullable=True),
+            Index("idx_jobs_parent_job_id", "parent_job_id"),
         ]

     @cached_property
     def _job_fields(self) -> list[str]:
-        return [c.name for c in self._jobs_columns() if c
+        return [c.name for c in self._jobs_columns() if isinstance(c, Column)] # type: ignore[attr-defined]

     @cached_property
     def _jobs(self) -> "Table":
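The two added schema items give each job an optional `parent_job_id` plus an index over it, which the ancestry queries further down in this file rely on. A standalone sketch of the resulting shape (illustrative table name and a pared-down column subset, not the package's actual definition):

```python
# Illustrative only: a reduced jobs table showing the newly added column and index.
from sqlalchemy import Column, Index, MetaData, Table, Text

metadata = MetaData()
jobs = Table(
    "jobs",
    metadata,
    Column("id", Text, primary_key=True),
    Column("name", Text, nullable=False),
    Column("parent_job_id", Text, nullable=True),      # new: points at the parent job, if any
    Index("idx_jobs_parent_job_id", "parent_job_id"),  # new: keeps ancestry walks cheap
)
```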
@@ -1013,15 +1743,29 @@ class AbstractDBMetastore(AbstractMetastore):
         query = self._jobs_query().where(self._jobs.c.id.in_(ids))
         yield from self._parse_jobs(self.db.execute(query, conn=conn))

+    def get_last_job_by_name(self, name: str, conn=None) -> "Job | None":
+        query = (
+            self._jobs_query()
+            .where(self._jobs.c.name == name)
+            .order_by(self._jobs.c.created_at.desc())
+            .limit(1)
+        )
+        results = list(self.db.execute(query, conn=conn))
+        if not results:
+            return None
+        return self._parse_job(results[0])
+
     def create_job(
         self,
         name: str,
         query: str,
         query_type: JobQueryType = JobQueryType.PYTHON,
+        status: JobStatus = JobStatus.CREATED,
         workers: int = 1,
-        python_version:
-        params:
-
+        python_version: str | None = None,
+        params: dict[str, str] | None = None,
+        parent_job_id: str | None = None,
+        conn: Any = None,
     ) -> str:
         """
         Creates a new job.
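With `parent_job_id` accepted by `create_job` and `get_last_job_by_name` available, a re-run can be chained to the most recent job of the same name. A hedged usage sketch (`metastore` is an assumed concrete instance; the job name and script are placeholders):

```python
# Hypothetical sketch: chain a new run to the previous run of the same name.
previous = metastore.get_last_job_by_name("nightly-etl")

job_id = metastore.create_job(
    name="nightly-etl",
    query="python etl.py",
    parent_job_id=previous.id if previous else None,
)
```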
@@ -1032,7 +1776,7 @@ class AbstractDBMetastore(AbstractMetastore):
             self._jobs_insert().values(
                 id=job_id,
                 name=name,
-                status=
+                status=status,
                 created_at=datetime.now(timezone.utc),
                 query=query,
                 query_type=query_type.value,
@@ -1042,30 +1786,68 @@ class AbstractDBMetastore(AbstractMetastore):
                 error_stack="",
                 params=json.dumps(params or {}),
                 metrics=json.dumps({}),
+                parent_job_id=parent_job_id,
             ),
             conn=conn,
         )
         return job_id

+    def get_job(self, job_id: str, conn=None) -> Job | None:
+        """Returns the job with the given ID."""
+        query = self._jobs_select(self._jobs).where(self._jobs.c.id == job_id)
+        results = list(self.db.execute(query, conn=conn))
+        if not results:
+            return None
+        return self._parse_job(results[0])
+
+    def update_job(
+        self,
+        job_id: str,
+        status: JobStatus | None = None,
+        error_message: str | None = None,
+        error_stack: str | None = None,
+        finished_at: datetime | None = None,
+        metrics: dict[str, Any] | None = None,
+        conn: Any | None = None,
+    ) -> Job | None:
+        """Updates job fields."""
+        values: dict = {}
+        if status is not None:
+            values["status"] = status
+        if error_message is not None:
+            values["error_message"] = error_message
+        if error_stack is not None:
+            values["error_stack"] = error_stack
+        if finished_at is not None:
+            values["finished_at"] = finished_at
+        if metrics:
+            values["metrics"] = json.dumps(metrics)
+
+        if values:
+            j = self._jobs
+            self.db.execute(
+                self._jobs_update().where(j.c.id == job_id).values(**values),
+                conn=conn,
+            ) # type: ignore [attr-defined]
+
+        return self.get_job(job_id, conn=conn)
+
     def set_job_status(
         self,
         job_id: str,
         status: JobStatus,
-        error_message:
-        error_stack:
-
-        conn: Optional[Any] = None,
+        error_message: str | None = None,
+        error_stack: str | None = None,
+        conn: Any | None = None,
     ) -> None:
         """Set the status of the given job."""
-        values: dict = {"status": status
-        if status
+        values: dict = {"status": status}
+        if status in JobStatus.finished():
             values["finished_at"] = datetime.now(timezone.utc)
         if error_message:
             values["error_message"] = error_message
         if error_stack:
             values["error_stack"] = error_stack
-        if metrics:
-            values["metrics"] = json.dumps(metrics)
         self.db.execute(
             self._jobs_update(self._jobs.c.id == job_id).values(**values),
             conn=conn,
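`update_job` only writes the fields that are actually passed (serializing `metrics` to JSON) and then re-reads the job, while `set_job_status` no longer accepts metrics. A hedged sketch of marking a run finished (`JobStatus.COMPLETE` is an assumption standing in for whichever finished state the enum defines; the metric values are placeholders):

```python
from datetime import datetime, timezone

# Hypothetical sketch: only the passed fields are updated; metrics are stored as JSON.
job = metastore.update_job(
    job_id,
    status=JobStatus.COMPLETE,  # assumption: substitute the enum's actual finished state
    finished_at=datetime.now(timezone.utc),
    metrics={"rows_processed": 1_000_000},
)
```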
@@ -1074,8 +1856,8 @@ class AbstractDBMetastore(AbstractMetastore):
     def get_job_status(
         self,
         job_id: str,
-        conn:
-    ) ->
+        conn: Any | None = None,
+    ) -> JobStatus | None:
         """Returns the status of the given job."""
         results = list(
             self.db.execute(
@@ -1087,36 +1869,320 @@ class AbstractDBMetastore(AbstractMetastore):
             return None
         return results[0][0]

-
+    #
+    # Checkpoints
+    #
+
+    @staticmethod
+    def _checkpoints_columns() -> "list[SchemaItem]":
+        return [
+            Column(
+                "id",
+                Text,
+                default=uuid4,
+                primary_key=True,
+                nullable=False,
+            ),
+            Column("job_id", Text, nullable=True),
+            Column("hash", Text, nullable=False),
+            Column("partial", Boolean, default=False),
+            Column("created_at", DateTime(timezone=True), nullable=False),
+            UniqueConstraint("job_id", "hash"),
+        ]
+
+    @cached_property
+    def _checkpoints_fields(self) -> list[str]:
+        return [c.name for c in self._checkpoints_columns() if c.name]  # type: ignore[attr-defined]
+
+    @cached_property
+    def _checkpoints(self) -> "Table":
+        return Table(
+            self.CHECKPOINTS_TABLE,
+            self.db.metadata,
+            *self._checkpoints_columns(),
+        )
+
+    @abstractmethod
+    def _checkpoints_insert(self) -> "Insert": ...
+
+    @classmethod
+    def _dataset_version_jobs_columns(cls) -> "list[SchemaItem]":
+        """Junction table for dataset versions and jobs many-to-many relationship."""
+        return [
+            Column("id", Integer, primary_key=True),
+            Column(
+                "dataset_version_id",
+                Integer,
+                ForeignKey(f"{cls.DATASET_VERSION_TABLE}.id", ondelete="CASCADE"),
+                nullable=False,
+            ),
+            Column("job_id", Text, nullable=False),
+            Column("is_creator", Boolean, nullable=False, default=False),
+            Column("created_at", DateTime(timezone=True)),
+            UniqueConstraint("dataset_version_id", "job_id"),
+            Index("dc_idx_dvj_query", "job_id", "is_creator", "created_at"),
+        ]
+
+    @cached_property
+    def _dataset_version_jobs_fields(self) -> list[str]:
+        return [c.name for c in self._dataset_version_jobs_columns() if c.name]  # type: ignore[attr-defined]
+
+    @cached_property
+    def _dataset_version_jobs(self) -> "Table":
+        return Table(
+            self.DATASET_VERSION_JOBS_TABLE,
+            self.db.metadata,
+            *self._dataset_version_jobs_columns(),
+        )
+
+    @abstractmethod
+    def _dataset_version_jobs_insert(self) -> "Insert": ...
+
+    def _dataset_version_jobs_select(self, *columns) -> "Select":
+        if not columns:
+            return self._dataset_version_jobs.select()
+        return select(*columns)
+
+    def _dataset_version_jobs_delete(self) -> "Delete":
+        return self._dataset_version_jobs.delete()
+
+    def _checkpoints_select(self, *columns) -> "Select":
+        if not columns:
+            return self._checkpoints.select()
+        return select(*columns)
+
+    def _checkpoints_delete(self) -> "Delete":
+        return self._checkpoints.delete()
+
+    def _checkpoints_query(self):
+        return self._checkpoints_select(
+            *[getattr(self._checkpoints.c, f) for f in self._checkpoints_fields]
+        )
+
+    def create_checkpoint(
+        self,
+        job_id: str,
+        _hash: str,
+        partial: bool = False,
+        conn: Any | None = None,
+    ) -> Checkpoint:
+        """
+        Creates a new job query step.
+        """
+        checkpoint_id = str(uuid4())
+        self.db.execute(
+            self._checkpoints_insert().values(
+                id=checkpoint_id,
+                job_id=job_id,
+                hash=_hash,
+                partial=partial,
+                created_at=datetime.now(timezone.utc),
+            ),
+            conn=conn,
+        )
+        return self.get_checkpoint_by_id(checkpoint_id)
+
+    def list_checkpoints(self, job_id: str, conn=None) -> Iterator[Checkpoint]:
+        """List checkpoints by job id."""
+        query = self._checkpoints_query().where(self._checkpoints.c.job_id == job_id)
+        rows = list(self.db.execute(query, conn=conn))
+
+        yield from [self.checkpoint_class.parse(*r) for r in rows]
+
+    def get_checkpoint_by_id(self, checkpoint_id: str, conn=None) -> Checkpoint:
+        """Returns the checkpoint with the given ID."""
+        ch = self._checkpoints
+        query = self._checkpoints_select(ch).where(ch.c.id == checkpoint_id)
+        rows = list(self.db.execute(query, conn=conn))
+        if not rows:
+            raise CheckpointNotFoundError(f"Checkpoint {checkpoint_id} not found")
+        return self.checkpoint_class.parse(*rows[0])
+
+    def find_checkpoint(
+        self, job_id: str, _hash: str, partial: bool = False, conn=None
+    ) -> Checkpoint | None:
+        """
+        Tries to find checkpoint for a job with specific hash and optionally partial
+        """
+        ch = self._checkpoints
+        query = self._checkpoints_select(ch).where(
+            ch.c.job_id == job_id, ch.c.hash == _hash, ch.c.partial == partial
+        )
+        rows = list(self.db.execute(query, conn=conn))
+        if not rows:
+            return None
+        return self.checkpoint_class.parse(*rows[0])
+
+    def get_last_checkpoint(self, job_id: str, conn=None) -> Checkpoint | None:
+        query = (
+            self._checkpoints_query()
+            .where(self._checkpoints.c.job_id == job_id)
+            .order_by(desc(self._checkpoints.c.created_at))
+            .limit(1)
+        )
+        rows = list(self.db.execute(query, conn=conn))
+        if not rows:
+            return None
+        return self.checkpoint_class.parse(*rows[0])
+
+    def link_dataset_version_to_job(
         self,
+        dataset_version_id: int,
         job_id: str,
-
-
+        is_creator: bool = False,
+        conn=None,
     ) -> None:
-
-
-
-
-
-
-
-
+        # Use transaction to atomically:
+        # 1. Link dataset version to job in junction table
+        # 2. Update dataset_version.job_id to point to this job
+        with self.db.transaction() as tx_conn:
+            conn = conn or tx_conn
+
+            # Insert into junction table
+            query = self._dataset_version_jobs_insert().values(
+                dataset_version_id=dataset_version_id,
+                job_id=job_id,
+                is_creator=is_creator,
+                created_at=datetime.now(timezone.utc),
+            )
+            if hasattr(query, "on_conflict_do_nothing"):
+                query = query.on_conflict_do_nothing(
+                    index_elements=["dataset_version_id", "job_id"]
                )
-
+            self.db.execute(query, conn=conn)
+
+            # Also update dataset_version.job_id to point to this job
+            update_query = (
+                self._datasets_versions.update()
+                .where(self._datasets_versions.c.id == dataset_version_id)
+                .values(job_id=job_id)
+            )
+            self.db.execute(update_query, conn=conn)
+
+    def get_ancestor_job_ids(self, job_id: str, conn=None) -> list[str]:
+        # Use recursive CTE to walk up the parent chain
+        # Format: WITH RECURSIVE ancestors(id, parent_job_id, depth) AS (...)
+        # Include depth tracking to prevent infinite recursion in case of
+        # circular dependencies
+        ancestors_cte = (
+            self._jobs_select(
+                self._jobs.c.id.label("id"),
+                self._jobs.c.parent_job_id.label("parent_job_id"),
+                literal(0).label("depth"),
            )
-        self.
+            .where(self._jobs.c.id == job_id)
+            .cte(name="ancestors", recursive=True)
+        )

-
-
-
-
+        # Recursive part: join with parent jobs, incrementing depth and checking limit
+        ancestors_recursive = ancestors_cte.union_all(
+            self._jobs_select(
+                self._jobs.c.id.label("id"),
+                self._jobs.c.parent_job_id.label("parent_job_id"),
+                (ancestors_cte.c.depth + 1).label("depth"),
+            ).select_from(
+                self._jobs.join(
+                    ancestors_cte,
+                    (
+                        self._jobs.c.id
+                        == cast(ancestors_cte.c.parent_job_id, self._jobs.c.id.type)
+                    )
+                    & (ancestors_cte.c.parent_job_id.isnot(None))  # Stop at root jobs
+                    & (ancestors_cte.c.depth < JOB_ANCESTRY_MAX_DEPTH),
+                )
+            )
+        )
+
+        # Select all ancestor IDs and depths except the starting job itself
+        query = select(ancestors_recursive.c.id, ancestors_recursive.c.depth).where(
+            ancestors_recursive.c.id != job_id
+        )

-
+        results = list(self.db.execute(query, conn=conn))

-
-
-
-
+        # Check if we hit the depth limit
+        if results:
+            max_found_depth = max(row[1] for row in results)
+            if max_found_depth >= JOB_ANCESTRY_MAX_DEPTH:
+                from datachain.error import JobAncestryDepthExceededError
+
+                raise JobAncestryDepthExceededError(
+                    f"Job ancestry chain exceeds maximum depth of "
+                    f"{JOB_ANCESTRY_MAX_DEPTH}. Job ID: {job_id}"
+                )
+
+        return [str(row[0]) for row in results]
+
+    def _get_dataset_version_for_job_ancestry_query(
+        self,
+        dataset_name: str,
+        namespace_name: str,
+        project_name: str,
+        job_ancestry: list[str],
+    ) -> "Select":
+        """Find most recent dataset version created by any job in ancestry.
+
+        Searches job ancestry (current + parents) for the newest version of
+        the dataset where is_creator=True. Returns newest by created_at, or
+        None if no version was created by any job in the ancestry chain.
+
+        Used for checkpoint resolution to find which version to reuse when
+        continuing from a parent job.
+        """
+        return (
+            self._datasets_versions_select()
+            .select_from(
+                self._dataset_version_jobs.join(
+                    self._datasets_versions,
+                    self._dataset_version_jobs.c.dataset_version_id
+                    == self._datasets_versions.c.id,
+                )
+                .join(
+                    self._datasets,
+                    self._datasets_versions.c.dataset_id == self._datasets.c.id,
+                )
+                .join(
+                    self._projects,
+                    self._datasets.c.project_id == self._projects.c.id,
+                )
+                .join(
+                    self._namespaces,
+                    self._projects.c.namespace_id == self._namespaces.c.id,
+                )
+            )
+            .where(
+                self._datasets.c.name == dataset_name,
+                self._namespaces.c.name == namespace_name,
+                self._projects.c.name == project_name,
+                self._dataset_version_jobs.c.job_id.in_(job_ancestry),
+                self._dataset_version_jobs.c.is_creator.is_(True),
+            )
+            .order_by(desc(self._dataset_version_jobs.c.created_at))
+            .limit(1)
         )

-
+    def get_dataset_version_for_job_ancestry(
+        self,
+        dataset_name: str,
+        namespace_name: str,
+        project_name: str,
+        job_id: str,
+        conn=None,
+    ) -> DatasetVersion | None:
+        # Get job ancestry (current job + all ancestors)
+        job_ancestry = [job_id, *self.get_ancestor_job_ids(job_id, conn=conn)]
+
+        query = self._get_dataset_version_for_job_ancestry_query(
+            dataset_name, namespace_name, project_name, job_ancestry
+        )
+
+        results = list(self.db.execute(query, conn=conn))
+        if not results:
+            return None
+
+        if len(results) > 1:
+            raise DataChainError(
+                f"Expected at most 1 dataset version, found {len(results)}"
+            )
+
+        return self.dataset_version_class.parse(*results[0])