datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. datachain/__init__.py +20 -0
  2. datachain/asyn.py +11 -12
  3. datachain/cache.py +7 -7
  4. datachain/catalog/__init__.py +2 -2
  5. datachain/catalog/catalog.py +621 -507
  6. datachain/catalog/dependency.py +164 -0
  7. datachain/catalog/loader.py +28 -18
  8. datachain/checkpoint.py +43 -0
  9. datachain/cli/__init__.py +24 -33
  10. datachain/cli/commands/__init__.py +1 -8
  11. datachain/cli/commands/datasets.py +83 -52
  12. datachain/cli/commands/ls.py +17 -17
  13. datachain/cli/commands/show.py +4 -4
  14. datachain/cli/parser/__init__.py +8 -74
  15. datachain/cli/parser/job.py +95 -3
  16. datachain/cli/parser/studio.py +11 -4
  17. datachain/cli/parser/utils.py +1 -2
  18. datachain/cli/utils.py +2 -15
  19. datachain/client/azure.py +4 -4
  20. datachain/client/fsspec.py +45 -28
  21. datachain/client/gcs.py +6 -6
  22. datachain/client/hf.py +29 -2
  23. datachain/client/http.py +157 -0
  24. datachain/client/local.py +15 -11
  25. datachain/client/s3.py +17 -9
  26. datachain/config.py +4 -8
  27. datachain/data_storage/db_engine.py +12 -6
  28. datachain/data_storage/job.py +5 -1
  29. datachain/data_storage/metastore.py +1252 -186
  30. datachain/data_storage/schema.py +58 -45
  31. datachain/data_storage/serializer.py +105 -15
  32. datachain/data_storage/sqlite.py +286 -127
  33. datachain/data_storage/warehouse.py +250 -113
  34. datachain/dataset.py +353 -148
  35. datachain/delta.py +391 -0
  36. datachain/diff/__init__.py +27 -29
  37. datachain/error.py +60 -0
  38. datachain/func/__init__.py +2 -1
  39. datachain/func/aggregate.py +66 -42
  40. datachain/func/array.py +242 -38
  41. datachain/func/base.py +7 -4
  42. datachain/func/conditional.py +110 -60
  43. datachain/func/func.py +96 -45
  44. datachain/func/numeric.py +55 -38
  45. datachain/func/path.py +32 -20
  46. datachain/func/random.py +2 -2
  47. datachain/func/string.py +67 -37
  48. datachain/func/window.py +7 -8
  49. datachain/hash_utils.py +123 -0
  50. datachain/job.py +11 -7
  51. datachain/json.py +138 -0
  52. datachain/lib/arrow.py +58 -22
  53. datachain/lib/audio.py +245 -0
  54. datachain/lib/clip.py +14 -13
  55. datachain/lib/convert/flatten.py +5 -3
  56. datachain/lib/convert/python_to_sql.py +6 -10
  57. datachain/lib/convert/sql_to_python.py +8 -0
  58. datachain/lib/convert/values_to_tuples.py +156 -51
  59. datachain/lib/data_model.py +42 -20
  60. datachain/lib/dataset_info.py +36 -8
  61. datachain/lib/dc/__init__.py +8 -2
  62. datachain/lib/dc/csv.py +25 -28
  63. datachain/lib/dc/database.py +398 -0
  64. datachain/lib/dc/datachain.py +1289 -425
  65. datachain/lib/dc/datasets.py +320 -38
  66. datachain/lib/dc/hf.py +38 -24
  67. datachain/lib/dc/json.py +29 -32
  68. datachain/lib/dc/listings.py +112 -8
  69. datachain/lib/dc/pandas.py +16 -12
  70. datachain/lib/dc/parquet.py +35 -23
  71. datachain/lib/dc/records.py +31 -23
  72. datachain/lib/dc/storage.py +154 -64
  73. datachain/lib/dc/storage_pattern.py +251 -0
  74. datachain/lib/dc/utils.py +24 -16
  75. datachain/lib/dc/values.py +8 -9
  76. datachain/lib/file.py +622 -89
  77. datachain/lib/hf.py +69 -39
  78. datachain/lib/image.py +14 -14
  79. datachain/lib/listing.py +14 -11
  80. datachain/lib/listing_info.py +1 -2
  81. datachain/lib/meta_formats.py +3 -4
  82. datachain/lib/model_store.py +39 -7
  83. datachain/lib/namespaces.py +125 -0
  84. datachain/lib/projects.py +130 -0
  85. datachain/lib/pytorch.py +32 -21
  86. datachain/lib/settings.py +192 -56
  87. datachain/lib/signal_schema.py +427 -104
  88. datachain/lib/tar.py +1 -2
  89. datachain/lib/text.py +8 -7
  90. datachain/lib/udf.py +164 -76
  91. datachain/lib/udf_signature.py +60 -35
  92. datachain/lib/utils.py +118 -4
  93. datachain/lib/video.py +17 -9
  94. datachain/lib/webdataset.py +61 -56
  95. datachain/lib/webdataset_laion.py +15 -16
  96. datachain/listing.py +22 -10
  97. datachain/model/bbox.py +3 -1
  98. datachain/model/ultralytics/bbox.py +16 -12
  99. datachain/model/ultralytics/pose.py +16 -12
  100. datachain/model/ultralytics/segment.py +16 -12
  101. datachain/namespace.py +84 -0
  102. datachain/node.py +6 -6
  103. datachain/nodes_thread_pool.py +0 -1
  104. datachain/plugins.py +24 -0
  105. datachain/project.py +78 -0
  106. datachain/query/batch.py +40 -41
  107. datachain/query/dataset.py +604 -322
  108. datachain/query/dispatch.py +261 -154
  109. datachain/query/metrics.py +4 -6
  110. datachain/query/params.py +2 -3
  111. datachain/query/queue.py +3 -12
  112. datachain/query/schema.py +11 -6
  113. datachain/query/session.py +200 -33
  114. datachain/query/udf.py +34 -2
  115. datachain/remote/studio.py +171 -69
  116. datachain/script_meta.py +12 -12
  117. datachain/semver.py +68 -0
  118. datachain/sql/__init__.py +2 -0
  119. datachain/sql/functions/array.py +33 -1
  120. datachain/sql/postgresql_dialect.py +9 -0
  121. datachain/sql/postgresql_types.py +21 -0
  122. datachain/sql/sqlite/__init__.py +5 -1
  123. datachain/sql/sqlite/base.py +102 -29
  124. datachain/sql/sqlite/types.py +8 -13
  125. datachain/sql/types.py +70 -15
  126. datachain/studio.py +223 -46
  127. datachain/toolkit/split.py +31 -10
  128. datachain/utils.py +101 -59
  129. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
  130. datachain-0.39.0.dist-info/RECORD +173 -0
  131. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
  132. datachain/cli/commands/query.py +0 -53
  133. datachain/query/utils.py +0 -42
  134. datachain-0.14.2.dist-info/RECORD +0 -158
  135. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
  136. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
  137. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
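The expanded diff that follows is the metastore module, datachain/data_storage/metastore.py, whose main change is a new namespace → project → dataset hierarchy. The following sketch is illustrative only and is not part of the package: method names and signatures are taken from the diff below, while the `metastore` object is assumed to be any concrete AbstractDBMetastore implementation.

# Hypothetical usage sketch; `metastore` is an assumed AbstractDBMetastore instance.
namespace = metastore.create_namespace("research", description="team namespace")
project = metastore.create_project("research", "vision", description="image datasets")

# Datasets are now scoped to a project; without an explicit project they fall back
# to the default namespace/project.
ds = metastore.create_dataset("cats", project_id=project.id)
same_ds = metastore.get_dataset("cats", namespace_name="research", project_name="vision")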
@@ -1,28 +1,37 @@
 import copy
-import json
 import logging
 import os
 from abc import ABC, abstractmethod
 from collections.abc import Iterator
+from contextlib import contextmanager, suppress
 from datetime import datetime, timezone
 from functools import cached_property, reduce
 from itertools import groupby
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any
 from uuid import uuid4

 from sqlalchemy import (
     JSON,
     BigInteger,
+    Boolean,
     Column,
     DateTime,
     ForeignKey,
+    Index,
     Integer,
     Table,
     Text,
     UniqueConstraint,
+    cast,
+    desc,
+    literal,
     select,
 )
+from sqlalchemy.sql import func as f

+from datachain import json
+from datachain.catalog.dependency import DatasetDependencyNode
+from datachain.checkpoint import Checkpoint
 from datachain.data_storage import JobQueryType, JobStatus
 from datachain.data_storage.serializer import Serializable
 from datachain.dataset import (
@@ -33,22 +42,34 @@ from datachain.dataset import (
     DatasetStatus,
     DatasetVersion,
     StorageURI,
+    parse_schema,
 )
 from datachain.error import (
+    CheckpointNotFoundError,
+    DataChainError,
     DatasetNotFoundError,
+    DatasetVersionNotFoundError,
+    NamespaceDeleteNotAllowedError,
+    NamespaceNotFoundError,
+    ProjectDeleteNotAllowedError,
+    ProjectNotFoundError,
     TableMissingError,
 )
 from datachain.job import Job
-from datachain.utils import JSONSerialize
+from datachain.namespace import Namespace
+from datachain.project import Project

 if TYPE_CHECKING:
-    from sqlalchemy import Delete, Insert, Select, Update
+    from sqlalchemy import CTE, Delete, Insert, Select, Subquery, Update
     from sqlalchemy.schema import SchemaItem
+    from sqlalchemy.sql.elements import ColumnElement

     from datachain.data_storage import schema
     from datachain.data_storage.db_engine import DatabaseEngine

 logger = logging.getLogger("datachain")
+DEPTH_LIMIT_DEFAULT = 100
+JOB_ANCESTRY_MAX_DEPTH = 100


 class AbstractMetastore(ABC, Serializable):
@@ -60,15 +81,20 @@ class AbstractMetastore(ABC, Serializable):
     uri: StorageURI

     schema: "schema.Schema"
+    namespace_class: type[Namespace] = Namespace
+    project_class: type[Project] = Project
     dataset_class: type[DatasetRecord] = DatasetRecord
+    dataset_version_class: type[DatasetVersion] = DatasetVersion
     dataset_list_class: type[DatasetListRecord] = DatasetListRecord
     dataset_list_version_class: type[DatasetListVersion] = DatasetListVersion
     dependency_class: type[DatasetDependency] = DatasetDependency
+    dependency_node_class: type[DatasetDependencyNode] = DatasetDependencyNode
     job_class: type[Job] = Job
+    checkpoint_class: type[Checkpoint] = Checkpoint

     def __init__(
         self,
-        uri: Optional[StorageURI] = None,
+        uri: StorageURI | None = None,
     ):
         self.uri = uri or StorageURI("")

@@ -82,7 +108,7 @@ class AbstractMetastore(ABC, Serializable):
     @abstractmethod
     def clone(
         self,
-        uri: Optional[StorageURI] = None,
+        uri: StorageURI | None = None,
         use_new_connection: bool = False,
     ) -> "AbstractMetastore":
         """Clones AbstractMetastore implementation for some Storage input.
@@ -99,6 +125,16 @@ class AbstractMetastore(ABC, Serializable):
         differently."""
         self.close()

+    @contextmanager
+    def _init_guard(self):
+        """Ensure resources acquired during __init__ are released on failure."""
+        try:
+            yield
+        except Exception:
+            with suppress(Exception):
+                self.close_on_exit()
+            raise
+
     def cleanup_tables(self, temp_table_names: list[str]) -> None:
         """Cleanup temp tables."""

@@ -106,21 +142,131 @@ class AbstractMetastore(ABC, Serializable):
         """Cleanup for tests."""

     #
-    # Datasets
+    # Namespaces
     #

+    @property
+    @abstractmethod
+    def default_namespace_name(self):
+        """Gets default namespace name"""
+
+    @property
+    def system_namespace_name(self):
+        return Namespace.system()
+
+    @abstractmethod
+    def create_namespace(
+        self,
+        name: str,
+        description: str | None = None,
+        uuid: str | None = None,
+        ignore_if_exists: bool = True,
+        validate: bool = True,
+        **kwargs,
+    ) -> Namespace:
+        """Creates new namespace"""
+
+    @abstractmethod
+    def get_namespace(self, name: str, conn=None) -> Namespace:
+        """Gets a single namespace by name"""
+
+    @abstractmethod
+    def remove_namespace(self, namespace_id: int, conn=None) -> None:
+        """Removes a single namespace by id"""
+
+    @abstractmethod
+    def list_namespaces(self, conn=None) -> list[Namespace]:
+        """Gets a list of all namespaces"""
+
+    #
+    # Projects
+    #
+
+    @property
+    @abstractmethod
+    def default_project_name(self):
+        """Gets default project name"""
+
+    @property
+    def listing_project_name(self):
+        return Project.listing()
+
+    @cached_property
+    def default_project(self) -> Project:
+        return self.get_project(
+            self.default_project_name, self.default_namespace_name, create=True
+        )
+
+    @cached_property
+    def listing_project(self) -> Project:
+        return self.get_project(self.listing_project_name, self.system_namespace_name)
+
+    @abstractmethod
+    def create_project(
+        self,
+        namespace_name: str,
+        name: str,
+        description: str | None = None,
+        uuid: str | None = None,
+        ignore_if_exists: bool = True,
+        validate: bool = True,
+        **kwargs,
+    ) -> Project:
+        """Creates new project in specific namespace"""
+
+    @abstractmethod
+    def get_project(
+        self, name: str, namespace_name: str, create: bool = False, conn=None
+    ) -> Project:
+        """
+        Gets a single project inside some namespace by name.
+        It also creates project if not found and create flag is set to True.
+        """
+
+    def is_default_project(self, project_name: str, namespace_name: str) -> bool:
+        return (
+            project_name == self.default_project_name
+            and namespace_name == self.default_namespace_name
+        )
+
+    def is_listing_project(self, project_name: str, namespace_name: str) -> bool:
+        return (
+            project_name == self.listing_project_name
+            and namespace_name == self.system_namespace_name
+        )
+
+    @abstractmethod
+    def get_project_by_id(self, project_id: int, conn=None) -> Project:
+        """Gets a single project by id"""
+
+    @abstractmethod
+    def count_projects(self, namespace_id: int | None = None) -> int:
+        """Counts projects in some namespace or in general."""
+
+    @abstractmethod
+    def remove_project(self, project_id: int, conn=None) -> None:
+        """Removes a single project by id"""
+
+    @abstractmethod
+    def list_projects(self, namespace_id: int | None, conn=None) -> list[Project]:
+        """Gets list of projects in some namespace or in general (in all namespaces)"""
+
+    #
+    # Datasets
+    #
     @abstractmethod
     def create_dataset(
         self,
         name: str,
+        project_id: int | None = None,
         status: int = DatasetStatus.CREATED,
-        sources: Optional[list[str]] = None,
-        feature_schema: Optional[dict] = None,
+        sources: list[str] | None = None,
+        feature_schema: dict | None = None,
         query_script: str = "",
-        schema: Optional[dict[str, Any]] = None,
+        schema: dict[str, Any] | None = None,
         ignore_if_exists: bool = False,
-        description: Optional[str] = None,
-        labels: Optional[list[str]] = None,
+        description: str | None = None,
+        attrs: list[str] | None = None,
     ) -> DatasetRecord:
         """Creates new dataset."""

@@ -128,23 +274,23 @@ class AbstractMetastore(ABC, Serializable):
     def create_dataset_version(  # noqa: PLR0913
         self,
         dataset: DatasetRecord,
-        version: int,
+        version: str,
         status: int,
         sources: str = "",
-        feature_schema: Optional[dict] = None,
+        feature_schema: dict | None = None,
         query_script: str = "",
         error_message: str = "",
         error_stack: str = "",
         script_output: str = "",
-        created_at: Optional[datetime] = None,
-        finished_at: Optional[datetime] = None,
-        schema: Optional[dict[str, Any]] = None,
+        created_at: datetime | None = None,
+        finished_at: datetime | None = None,
+        schema: dict[str, Any] | None = None,
         ignore_if_exists: bool = False,
-        num_objects: Optional[int] = None,
-        size: Optional[int] = None,
-        preview: Optional[list[dict]] = None,
-        job_id: Optional[str] = None,
-        uuid: Optional[str] = None,
+        num_objects: int | None = None,
+        size: int | None = None,
+        preview: list[dict] | None = None,
+        job_id: str | None = None,
+        uuid: str | None = None,
     ) -> DatasetRecord:
         """Creates new dataset version."""

@@ -158,13 +304,13 @@ class AbstractMetastore(ABC, Serializable):

     @abstractmethod
     def update_dataset_version(
-        self, dataset: DatasetRecord, version: int, **kwargs
+        self, dataset: DatasetRecord, version: str, **kwargs
     ) -> DatasetVersion:
         """Updates dataset version fields."""

     @abstractmethod
     def remove_dataset_version(
-        self, dataset: DatasetRecord, version: int
+        self, dataset: DatasetRecord, version: str
     ) -> DatasetRecord:
         """
         Deletes one single dataset version.
@@ -172,15 +318,32 @@ class AbstractMetastore(ABC, Serializable):
         """

     @abstractmethod
-    def list_datasets(self) -> Iterator[DatasetListRecord]:
-        """Lists all datasets."""
+    def list_datasets(
+        self, project_id: int | None = None
+    ) -> Iterator[DatasetListRecord]:
+        """Lists all datasets in some project or in all projects."""

     @abstractmethod
-    def list_datasets_by_prefix(self, prefix: str) -> Iterator["DatasetListRecord"]:
-        """Lists all datasets which names start with prefix."""
+    def count_datasets(self, project_id: int | None = None) -> int:
+        """Counts datasets in some project or in all projects."""

     @abstractmethod
-    def get_dataset(self, name: str) -> DatasetRecord:
+    def list_datasets_by_prefix(
+        self, prefix: str, project_id: int | None = None
+    ) -> Iterator["DatasetListRecord"]:
+        """
+        Lists all datasets which names start with prefix in some project or in all
+        projects.
+        """
+
+    @abstractmethod
+    def get_dataset(
+        self,
+        name: str,  # normal, not full dataset name
+        namespace_name: str | None = None,
+        project_name: str | None = None,
+        conn=None,
+    ) -> DatasetRecord:
         """Gets a single dataset by name."""

     @abstractmethod
@@ -188,7 +351,7 @@ class AbstractMetastore(ABC, Serializable):
         self,
         dataset: DatasetRecord,
         status: int,
-        version: Optional[int] = None,
+        version: str | None = None,
         error_message="",
         error_stack="",
         script_output="",
@@ -201,10 +364,10 @@ class AbstractMetastore(ABC, Serializable):
     @abstractmethod
     def add_dataset_dependency(
         self,
-        source_dataset_name: str,
-        source_dataset_version: int,
-        dataset_name: str,
-        dataset_version: int,
+        source_dataset: "DatasetRecord",
+        source_dataset_version: str,
+        dep_dataset: "DatasetRecord",
+        dep_dataset_version: str,
     ) -> None:
         """Adds dataset dependency to dataset."""

@@ -212,21 +375,27 @@ class AbstractMetastore(ABC, Serializable):
     def update_dataset_dependency_source(
         self,
         source_dataset: DatasetRecord,
-        source_dataset_version: int,
-        new_source_dataset: Optional[DatasetRecord] = None,
-        new_source_dataset_version: Optional[int] = None,
+        source_dataset_version: str,
+        new_source_dataset: DatasetRecord | None = None,
+        new_source_dataset_version: str | None = None,
     ) -> None:
         """Updates dataset dependency source."""

     @abstractmethod
     def get_direct_dataset_dependencies(
-        self, dataset: DatasetRecord, version: int
-    ) -> list[Optional[DatasetDependency]]:
+        self, dataset: DatasetRecord, version: str
+    ) -> list[DatasetDependency | None]:
         """Gets direct dataset dependencies."""

+    @abstractmethod
+    def get_dataset_dependency_nodes(
+        self, dataset_id: int, version_id: int
+    ) -> list[DatasetDependencyNode | None]:
+        """Gets dataset dependency node from database."""
+
     @abstractmethod
     def remove_dataset_dependencies(
-        self, dataset: DatasetRecord, version: Optional[int] = None
+        self, dataset: DatasetRecord, version: str | None = None
     ) -> None:
         """
         When we remove dataset, we need to clean up it's dependencies as well.
@@ -234,7 +403,7 @@ class AbstractMetastore(ABC, Serializable):

     @abstractmethod
     def remove_dataset_dependants(
-        self, dataset: DatasetRecord, version: Optional[int] = None
+        self, dataset: DatasetRecord, version: str | None = None
     ) -> None:
         """
         When we remove dataset, we need to clear its references in other dataset
@@ -254,43 +423,121 @@ class AbstractMetastore(ABC, Serializable):
         name: str,
         query: str,
         query_type: JobQueryType = JobQueryType.PYTHON,
+        status: JobStatus = JobStatus.CREATED,
         workers: int = 1,
-        python_version: Optional[str] = None,
-        params: Optional[dict[str, str]] = None,
+        python_version: str | None = None,
+        params: dict[str, str] | None = None,
+        parent_job_id: str | None = None,
     ) -> str:
         """
         Creates a new job.
         Returns the job id.
         """

+    @abstractmethod
+    def get_job(self, job_id: str) -> Job | None:
+        """Returns the job with the given ID."""
+
+    @abstractmethod
+    def update_job(
+        self,
+        job_id: str,
+        status: JobStatus | None = None,
+        error_message: str | None = None,
+        error_stack: str | None = None,
+        finished_at: datetime | None = None,
+        metrics: dict[str, Any] | None = None,
+    ) -> Job | None:
+        """Updates job fields."""
+
     @abstractmethod
     def set_job_status(
         self,
         job_id: str,
         status: JobStatus,
-        error_message: Optional[str] = None,
-        error_stack: Optional[str] = None,
-        metrics: Optional[dict[str, Any]] = None,
+        error_message: str | None = None,
+        error_stack: str | None = None,
     ) -> None:
         """Set the status of the given job."""

     @abstractmethod
-    def get_job_status(self, job_id: str) -> Optional[JobStatus]:
+    def get_job_status(self, job_id: str) -> JobStatus | None:
         """Returns the status of the given job."""

     @abstractmethod
-    def set_job_and_dataset_status(
+    def get_last_job_by_name(self, name: str, conn=None) -> "Job | None":
+        """Returns the last job with the given name, ordered by created_at."""
+
+    #
+    # Checkpoints
+    #
+
+    @abstractmethod
+    def list_checkpoints(self, job_id: str, conn=None) -> Iterator[Checkpoint]:
+        """Returns all checkpoints related to some job"""
+
+    @abstractmethod
+    def get_last_checkpoint(self, job_id: str, conn=None) -> Checkpoint | None:
+        """Get last created checkpoint for some job."""
+
+    @abstractmethod
+    def get_checkpoint_by_id(self, checkpoint_id: str, conn=None) -> Checkpoint:
+        """Gets single checkpoint by id"""
+
+    def find_checkpoint(
+        self, job_id: str, _hash: str, partial: bool = False, conn=None
+    ) -> Checkpoint | None:
+        """
+        Tries to find checkpoint for a job with specific hash and optionally partial
+        """
+
+    @abstractmethod
+    def create_checkpoint(
         self,
         job_id: str,
-        job_status: JobStatus,
-        dataset_status: DatasetStatus,
+        _hash: str,
+        partial: bool = False,
+        conn: Any | None = None,
+    ) -> Checkpoint:
+        """Creates new checkpoint"""
+
+    #
+    # Dataset Version Jobs (many-to-many)
+    #
+
+    @abstractmethod
+    def link_dataset_version_to_job(
+        self,
+        dataset_version_id: int,
+        job_id: str,
+        is_creator: bool = False,
+        conn=None,
     ) -> None:
-        """Set the status of the given job and dataset."""
+        """
+        Link dataset version to job.
+
+        This atomically:
+        1. Creates a link in the dataset_version_jobs junction table
+        2. Updates dataset_version.job_id to point to this job
+        """

     @abstractmethod
-    def get_job_dataset_versions(self, job_id: str) -> list[tuple[str, int]]:
-        """Returns dataset names and versions for the job."""
-        raise NotImplementedError
+    def get_ancestor_job_ids(self, job_id: str, conn=None) -> list[str]:
+        """Get all ancestor job IDs for a given job."""
+
+    @abstractmethod
+    def get_dataset_version_for_job_ancestry(
+        self,
+        dataset_name: str,
+        namespace_name: str,
+        project_name: str,
+        job_id: str,
+        conn=None,
+    ) -> DatasetVersion | None:
+        """
+        Find the dataset version that was created by any job in the ancestry.
+        Returns the most recently linked version from these jobs.
+        """


 class AbstractDBMetastore(AbstractMetastore):
@@ -301,14 +548,18 @@ class AbstractDBMetastore(AbstractMetastore):
     and has shared logic for all database systems currently in use.
     """

+    NAMESPACE_TABLE = "namespaces"
+    PROJECT_TABLE = "projects"
     DATASET_TABLE = "datasets"
     DATASET_VERSION_TABLE = "datasets_versions"
     DATASET_DEPENDENCY_TABLE = "datasets_dependencies"
+    DATASET_VERSION_JOBS_TABLE = "dataset_version_jobs"
     JOBS_TABLE = "jobs"
+    CHECKPOINTS_TABLE = "checkpoints"

     db: "DatabaseEngine"

-    def __init__(self, uri: Optional[StorageURI] = None):
+    def __init__(self, uri: StorageURI | None = None):
         uri = uri or StorageURI("")
         super().__init__(uri)

@@ -319,14 +570,65 @@ class AbstractDBMetastore(AbstractMetastore):
     def cleanup_tables(self, temp_table_names: list[str]) -> None:
         """Cleanup temp tables."""

+    @classmethod
+    def _namespaces_columns(cls) -> list["SchemaItem"]:
+        """Namespace table columns."""
+        return [
+            Column("id", Integer, primary_key=True),
+            Column("uuid", Text, nullable=False, default=uuid4()),
+            Column("name", Text, nullable=False),
+            Column("description", Text),
+            Column("created_at", DateTime(timezone=True)),
+        ]
+
+    @cached_property
+    def _namespaces_fields(self) -> list[str]:
+        return [
+            c.name  # type: ignore [attr-defined]
+            for c in self._namespaces_columns()
+            if c.name  # type: ignore [attr-defined]
+        ]
+
+    @classmethod
+    def _projects_columns(cls) -> list["SchemaItem"]:
+        """Project table columns."""
+        return [
+            Column("id", Integer, primary_key=True),
+            Column("uuid", Text, nullable=False, default=uuid4()),
+            Column("name", Text, nullable=False),
+            Column("description", Text),
+            Column("created_at", DateTime(timezone=True)),
+            Column(
+                "namespace_id",
+                Integer,
+                ForeignKey(f"{cls.NAMESPACE_TABLE}.id", ondelete="CASCADE"),
+                nullable=False,
+            ),
+            UniqueConstraint("namespace_id", "name"),
+        ]
+
+    @cached_property
+    def _projects_fields(self) -> list[str]:
+        return [
+            c.name  # type: ignore [attr-defined]
+            for c in self._projects_columns()
+            if c.name  # type: ignore [attr-defined]
+        ]
+
     @classmethod
     def _datasets_columns(cls) -> list["SchemaItem"]:
         """Datasets table columns."""
         return [
             Column("id", Integer, primary_key=True),
+            Column(
+                "project_id",
+                Integer,
+                ForeignKey(f"{cls.PROJECT_TABLE}.id", ondelete="CASCADE"),
+                nullable=False,
+            ),
             Column("name", Text, nullable=False),
             Column("description", Text),
-            Column("labels", JSON, nullable=True),
+            Column("attrs", JSON, nullable=True),
             Column("status", Integer, nullable=False),
             Column("feature_schema", JSON, nullable=True),
             Column("created_at", DateTime(timezone=True)),
@@ -367,7 +669,7 @@ class AbstractDBMetastore(AbstractMetastore):
                 ForeignKey(f"{cls.DATASET_TABLE}.id", ondelete="CASCADE"),
                 nullable=False,
             ),
-            Column("version", Integer, nullable=False),
+            Column("version", Text, nullable=False, default="1.0.0"),
             Column(
                 "status",
                 Integer,
@@ -442,6 +744,16 @@ class AbstractDBMetastore(AbstractMetastore):
     #
     # Query Tables
     #
+    @cached_property
+    def _namespaces(self) -> Table:
+        return Table(
+            self.NAMESPACE_TABLE, self.db.metadata, *self._namespaces_columns()
+        )
+
+    @cached_property
+    def _projects(self) -> Table:
+        return Table(self.PROJECT_TABLE, self.db.metadata, *self._projects_columns())
+
     @cached_property
     def _datasets(self) -> Table:
         return Table(self.DATASET_TABLE, self.db.metadata, *self._datasets_columns())
@@ -465,6 +777,31 @@ class AbstractDBMetastore(AbstractMetastore):
     #
     # Query Starters (These can be overridden by subclasses)
     #
+    @abstractmethod
+    def _namespaces_insert(self) -> "Insert": ...
+
+    def _namespaces_select(self, *columns) -> "Select":
+        if not columns:
+            return self._namespaces.select()
+        return select(*columns)
+
+    def _namespaces_update(self) -> "Update":
+        return self._namespaces.update()
+
+    def _namespaces_delete(self) -> "Delete":
+        return self._namespaces.delete()
+
+    @abstractmethod
+    def _projects_insert(self) -> "Insert": ...
+
+    def _projects_select(self, *columns) -> "Select":
+        if not columns:
+            return self._projects.select()
+        return select(*columns)
+
+    def _projects_delete(self) -> "Delete":
+        return self._projects.delete()
+
     @abstractmethod
     def _datasets_insert(self) -> "Insert": ...

@@ -507,6 +844,197 @@ class AbstractDBMetastore(AbstractMetastore):
     def _datasets_dependencies_delete(self) -> "Delete":
         return self._datasets_dependencies.delete()

+    #
+    # Namespaces
+    #
+
+    def create_namespace(
+        self,
+        name: str,
+        description: str | None = None,
+        uuid: str | None = None,
+        ignore_if_exists: bool = True,
+        validate: bool = True,
+        **kwargs,
+    ) -> Namespace:
+        if validate:
+            Namespace.validate_name(name)
+        query = self._namespaces_insert().values(
+            name=name,
+            uuid=uuid or str(uuid4()),
+            created_at=datetime.now(timezone.utc),
+            description=description,
+        )
+        if ignore_if_exists and hasattr(query, "on_conflict_do_nothing"):
+            # SQLite and PostgreSQL both support 'on_conflict_do_nothing',
+            # but generic SQL does not
+            query = query.on_conflict_do_nothing(index_elements=["name"])
+        self.db.execute(query)
+
+        return self.get_namespace(name)
+
+    def remove_namespace(self, namespace_id: int, conn=None) -> None:
+        num_projects = self.count_projects(namespace_id)
+        if num_projects > 0:
+            raise NamespaceDeleteNotAllowedError(
+                f"Namespace cannot be removed. It contains {num_projects} project(s). "
+                "Please remove the project(s) first."
+            )
+
+        n = self._namespaces
+        with self.db.transaction():
+            self.db.execute(self._namespaces_delete().where(n.c.id == namespace_id))
+
+    def get_namespace(self, name: str, conn=None) -> Namespace:
+        """Gets a single namespace by name"""
+        n = self._namespaces
+
+        query = self._namespaces_select(
+            *(getattr(n.c, f) for f in self._namespaces_fields),
+        ).where(n.c.name == name)
+        rows = list(self.db.execute(query, conn=conn))
+        if not rows:
+            raise NamespaceNotFoundError(f"Namespace {name} not found.")
+        return self.namespace_class.parse(*rows[0])
+
+    def list_namespaces(self, conn=None) -> list[Namespace]:
+        """Gets a list of all namespaces"""
+        n = self._namespaces
+
+        query = self._namespaces_select(
+            *(getattr(n.c, f) for f in self._namespaces_fields),
+        )
+        rows = list(self.db.execute(query, conn=conn))
+
+        return [self.namespace_class.parse(*r) for r in rows]
+
+    #
+    # Projects
+    #
+
+    def create_project(
+        self,
+        namespace_name: str,
+        name: str,
+        description: str | None = None,
+        uuid: str | None = None,
+        ignore_if_exists: bool = True,
+        validate: bool = True,
+        **kwargs,
+    ) -> Project:
+        if validate:
+            Project.validate_name(name)
+        try:
+            namespace = self.get_namespace(namespace_name)
+        except NamespaceNotFoundError:
+            namespace = self.create_namespace(namespace_name, validate=validate)
+
+        query = self._projects_insert().values(
+            namespace_id=namespace.id,
+            uuid=uuid or str(uuid4()),
+            name=name,
+            created_at=datetime.now(timezone.utc),
+            description=description,
+        )
+        if ignore_if_exists and hasattr(query, "on_conflict_do_nothing"):
+            # SQLite and PostgreSQL both support 'on_conflict_do_nothing',
+            # but generic SQL does not
+            query = query.on_conflict_do_nothing(
+                index_elements=["namespace_id", "name"]
+            )
+        self.db.execute(query)
+
+        return self.get_project(name, namespace.name)
+
+    def _projects_base_query(self) -> "Select":
+        n = self._namespaces
+        p = self._projects
+
+        query = self._projects_select(
+            *(getattr(n.c, f) for f in self._namespaces_fields),
+            *(getattr(p.c, f) for f in self._projects_fields),
+        )
+        return query.select_from(n.join(p, n.c.id == p.c.namespace_id))
+
+    def get_project(
+        self, name: str, namespace_name: str, create: bool = False, conn=None
+    ) -> Project:
+        """Gets a single project inside some namespace by name"""
+        n = self._namespaces
+        p = self._projects
+        validate = True
+
+        if self.is_listing_project(name, namespace_name) or self.is_default_project(
+            name, namespace_name
+        ):
+            # we are always creating default and listing projects if they don't exist
+            create = True
+            validate = False
+
+        query = self._projects_base_query().where(
+            p.c.name == name, n.c.name == namespace_name
+        )
+
+        rows = list(self.db.execute(query, conn=conn))
+        if not rows:
+            if create:
+                return self.create_project(namespace_name, name, validate=validate)
+            raise ProjectNotFoundError(
+                f"Project {name} in namespace {namespace_name} not found."
+            )
+        return self.project_class.parse(*rows[0])
+
+    def get_project_by_id(self, project_id: int, conn=None) -> Project:
+        """Gets a single project by id"""
+        p = self._projects
+
+        query = self._projects_base_query().where(p.c.id == project_id)
+
+        rows = list(self.db.execute(query, conn=conn))
+        if not rows:
+            raise ProjectNotFoundError(f"Project with id {project_id} not found.")
+        return self.project_class.parse(*rows[0])
+
+    def count_projects(self, namespace_id: int | None = None) -> int:
+        p = self._projects
+
+        query = self._projects_base_query()
+        if namespace_id:
+            query = query.where(p.c.namespace_id == namespace_id)
+
+        query = select(f.count(1)).select_from(query.subquery())
+
+        return next(self.db.execute(query))[0]
+
+    def remove_project(self, project_id: int, conn=None) -> None:
+        num_datasets = self.count_datasets(project_id)
+        if num_datasets > 0:
+            raise ProjectDeleteNotAllowedError(
+                f"Project cannot be removed. It contains {num_datasets} dataset(s). "
+                "Please remove the dataset(s) first."
+            )
+
+        p = self._projects
+        with self.db.transaction():
+            self.db.execute(self._projects_delete().where(p.c.id == project_id))
+
+    def list_projects(
+        self, namespace_id: int | None = None, conn=None
+    ) -> list[Project]:
+        """
+        Gets a list of projects inside some namespace, or in all namespaces
+        """
+        p = self._projects
+
+        query = self._projects_base_query()
+
+        if namespace_id:
+            query = query.where(p.c.namespace_id == namespace_id)
+
+        rows = list(self.db.execute(query, conn=conn))
+
+        return [self.project_class.parse(*r) for r in rows]
+
     #
     # Datasets
     #
@@ -514,20 +1042,26 @@ class AbstractDBMetastore(AbstractMetastore):
     def create_dataset(
         self,
         name: str,
+        project_id: int | None = None,
         status: int = DatasetStatus.CREATED,
-        sources: Optional[list[str]] = None,
-        feature_schema: Optional[dict] = None,
+        sources: list[str] | None = None,
+        feature_schema: dict | None = None,
         query_script: str = "",
-        schema: Optional[dict[str, Any]] = None,
+        schema: dict[str, Any] | None = None,
         ignore_if_exists: bool = False,
-        description: Optional[str] = None,
-        labels: Optional[list[str]] = None,
+        description: str | None = None,
+        attrs: list[str] | None = None,
         **kwargs,  # TODO registered = True / False
     ) -> DatasetRecord:
         """Creates new dataset."""
-        # TODO abstract this method and add registered = True based on kwargs
+        if not project_id:
+            project = self.default_project
+        else:
+            project = self.get_project_by_id(project_id)
+
         query = self._datasets_insert().values(
             name=name,
+            project_id=project.id,
             status=status,
             feature_schema=json.dumps(feature_schema or {}),
             created_at=datetime.now(timezone.utc),
@@ -538,36 +1072,38 @@ class AbstractDBMetastore(AbstractMetastore):
             query_script=query_script,
             schema=json.dumps(schema or {}),
             description=description,
-            labels=json.dumps(labels or []),
+            attrs=json.dumps(attrs or []),
         )
         if ignore_if_exists and hasattr(query, "on_conflict_do_nothing"):
             # SQLite and PostgreSQL both support 'on_conflict_do_nothing',
             # but generic SQL does not
-            query = query.on_conflict_do_nothing(index_elements=["name"])
+            query = query.on_conflict_do_nothing(index_elements=["project_id", "name"])
         self.db.execute(query)

-        return self.get_dataset(name)
+        return self.get_dataset(
+            name, namespace_name=project.namespace.name, project_name=project.name
+        )

     def create_dataset_version(  # noqa: PLR0913
         self,
         dataset: DatasetRecord,
-        version: int,
+        version: str,
         status: int,
         sources: str = "",
-        feature_schema: Optional[dict] = None,
+        feature_schema: dict | None = None,
         query_script: str = "",
         error_message: str = "",
         error_stack: str = "",
         script_output: str = "",
-        created_at: Optional[datetime] = None,
-        finished_at: Optional[datetime] = None,
-        schema: Optional[dict[str, Any]] = None,
+        created_at: datetime | None = None,
+        finished_at: datetime | None = None,
+        schema: dict[str, Any] | None = None,
         ignore_if_exists: bool = False,
-        num_objects: Optional[int] = None,
-        size: Optional[int] = None,
-        preview: Optional[list[dict]] = None,
-        job_id: Optional[str] = None,
-        uuid: Optional[str] = None,
+        num_objects: int | None = None,
+        size: int | None = None,
+        preview: list[dict] | None = None,
+        job_id: str | None = None,
+        uuid: str | None = None,
         conn=None,
     ) -> DatasetRecord:
         """Creates new dataset version."""
@@ -603,7 +1139,12 @@ class AbstractDBMetastore(AbstractMetastore):
         )
         self.db.execute(query, conn=conn)

-        return self.get_dataset(dataset.name, conn=conn)
+        return self.get_dataset(
+            dataset.name,
+            namespace_name=dataset.project.namespace.name,
+            project_name=dataset.project.name,
+            conn=conn,
+        )

     def remove_dataset(self, dataset: DatasetRecord) -> None:
         """Removes dataset."""
@@ -617,26 +1158,47 @@ class AbstractDBMetastore(AbstractMetastore):
         self, dataset: DatasetRecord, conn=None, **kwargs
     ) -> DatasetRecord:
         """Updates dataset fields."""
-        values = {}
-        dataset_values = {}
+        values: dict[str, Any] = {}
+        dataset_values: dict[str, Any] = {}
         for field, value in kwargs.items():
-            if field in self._dataset_fields[1:]:
-                if field in ["labels", "schema"]:
-                    values[field] = json.dumps(value) if value else None
+            if field in ("id", "created_at") or field not in self._dataset_fields:
+                continue  # these fields are read-only or not applicable
+
+            if value is None and field in ("name", "status", "sources", "query_script"):
+                raise ValueError(f"Field {field} cannot be None")
+            if field == "name" and not value:
+                raise ValueError("name cannot be empty")
+
+            if field == "attrs":
+                if value is None:
+                    values[field] = None
                 else:
-                    values[field] = value
-                if field == "schema":
-                    dataset_values[field] = DatasetRecord.parse_schema(value)
+                    values[field] = json.dumps(value)
+                    dataset_values[field] = value
+            elif field == "schema":
+                if value is None:
+                    values[field] = None
+                    dataset_values[field] = None
                 else:
-                    dataset_values[field] = value
+                    values[field] = json.dumps(value)
+                    dataset_values[field] = parse_schema(value)
+            elif field == "project_id":
+                if not value:
+                    raise ValueError("Cannot set empty project_id for dataset")
+                dataset_values["project"] = self.get_project_by_id(value)
+                values[field] = value
+            else:
+                values[field] = value
+                dataset_values[field] = value

         if not values:
-            # Nothing to update
-            return dataset
+            return dataset  # nothing to update

         d = self._datasets
         self.db.execute(
-            self._datasets_update().where(d.c.name == dataset.name).values(values),
+            self._datasets_update()
+            .where(d.c.name == dataset.name, d.c.project_id == dataset.project.id)
+            .values(values),
             conn=conn,
         )  # type: ignore [attr-defined]

@@ -645,46 +1207,79 @@ class AbstractDBMetastore(AbstractMetastore):
         return result_ds

     def update_dataset_version(
-        self, dataset: DatasetRecord, version: int, conn=None, **kwargs
+        self, dataset: DatasetRecord, version: str, conn=None, **kwargs
     ) -> DatasetVersion:
         """Updates dataset fields."""
-        dataset_version = dataset.get_version(version)
-
-        values = {}
+        values: dict[str, Any] = {}
+        version_values: dict[str, Any] = {}
         for field, value in kwargs.items():
-            if field in self._dataset_version_fields[1:]:
-                if field == "schema":
-                    dataset_version.update(**{field: DatasetRecord.parse_schema(value)})
-                    values[field] = json.dumps(value) if value else None
-                elif field == "feature_schema":
-                    values[field] = json.dumps(value) if value else None
-                elif field == "preview" and isinstance(value, list):
-                    values[field] = json.dumps(value, cls=JSONSerialize)
+            if (
+                field in ("id", "created_at")
+                or field not in self._dataset_version_fields
+            ):
+                continue  # these fields are read-only or not applicable
+
+            if value is None and field in (
+                "status",
+                "sources",
+                "query_script",
+                "error_message",
+                "error_stack",
+                "script_output",
+                "uuid",
+            ):
+                raise ValueError(f"Field {field} cannot be None")
+
+            if field == "schema":
+                values[field] = json.dumps(value) if value else None
+                version_values[field] = parse_schema(value) if value else None
+            elif field == "feature_schema":
+                if value is None:
+                    values[field] = None
+                else:
+                    values[field] = json.dumps(value)
+                    version_values[field] = value
+            elif field == "preview":
+                if value is None:
+                    values[field] = None
+                elif not isinstance(value, list):
+                    raise ValueError(
+                        f"Field '{field}' must be a list, got {type(value).__name__}"
+                    )
                 else:
-                    values[field] = value
-                    dataset_version.update(**{field: value})
+                    values[field] = json.dumps(value, serialize_bytes=True)
+                    version_values["_preview_data"] = value
+            else:
+                values[field] = value
+                version_values[field] = value

         if not values:
-            # Nothing to update
-            return dataset_version
+            return dataset.get_version(version)

         dv = self._datasets_versions
         self.db.execute(
             self._datasets_versions_update()
-            .where(dv.c.id == dataset_version.id)
+            .where(dv.c.dataset_id == dataset.id, dv.c.version == version)
             .values(values),
             conn=conn,
         )  # type: ignore [attr-defined]

-        return dataset_version
+        for v in dataset.versions:
+            if v.version == version:
+                v.update(**version_values)
+                return v

-    def _parse_dataset(self, rows) -> Optional[DatasetRecord]:
+        raise DatasetVersionNotFoundError(
+            f"Dataset {dataset.name} does not have version {version}"
+        )
+
+    def _parse_dataset(self, rows) -> DatasetRecord | None:
         versions = [self.dataset_class.parse(*r) for r in rows]
         if not versions:
             return None
         return reduce(lambda ds, version: ds.merge_versions(version), versions)

-    def _parse_list_dataset(self, rows) -> Optional[DatasetListRecord]:
+    def _parse_list_dataset(self, rows) -> DatasetListRecord | None:
         versions = [self.dataset_list_class.parse(*r) for r in rows]
         if not versions:
             return None
@@ -692,69 +1287,124 @@ class AbstractDBMetastore(AbstractMetastore):
     def _parse_dataset_list(self, rows) -> Iterator["DatasetListRecord"]:
         # grouping rows by dataset id
-        for _, g in groupby(rows, lambda r: r[0]):
+        for _, g in groupby(rows, lambda r: r[11]):
             dataset = self._parse_list_dataset(list(g))
             if dataset:
                 yield dataset

     def _get_dataset_query(
         self,
+        namespace_fields: list[str],
+        project_fields: list[str],
         dataset_fields: list[str],
         dataset_version_fields: list[str],
         isouter: bool = True,
-    ):
+    ) -> "Select":
         if not (
             self.db.has_table(self._datasets.name)
             and self.db.has_table(self._datasets_versions.name)
         ):
             raise TableMissingError

+        n = self._namespaces
+        p = self._projects
         d = self._datasets
         dv = self._datasets_versions

         query = self._datasets_select(
+            *(getattr(n.c, f) for f in namespace_fields),
+            *(getattr(p.c, f) for f in project_fields),
             *(getattr(d.c, f) for f in dataset_fields),
             *(getattr(dv.c, f) for f in dataset_version_fields),
         )
-        j = d.join(dv, d.c.id == dv.c.dataset_id, isouter=isouter)
+        j = (
+            n.join(p, n.c.id == p.c.namespace_id)
+            .join(d, p.c.id == d.c.project_id)
+            .join(dv, d.c.id == dv.c.dataset_id, isouter=isouter)
+        )
         return query.select_from(j)

-    def _base_dataset_query(self):
+    def _base_dataset_query(self) -> "Select":
         return self._get_dataset_query(
-            self._dataset_fields, self._dataset_version_fields
+            self._namespaces_fields,
+            self._projects_fields,
+            self._dataset_fields,
+            self._dataset_version_fields,
         )

-    def _base_list_datasets_query(self):
+    def _base_list_datasets_query(self) -> "Select":
         return self._get_dataset_query(
-            self._dataset_list_fields, self._dataset_list_version_fields, isouter=False
+            self._namespaces_fields,
+            self._projects_fields,
+            self._dataset_list_fields,
+            self._dataset_list_version_fields,
+            isouter=False,
         )

-    def list_datasets(self) -> Iterator["DatasetListRecord"]:
-        """Lists all datasets."""
+    def list_datasets(
+        self, project_id: int | None = None
+    ) -> Iterator["DatasetListRecord"]:
+        d = self._datasets
         query = self._base_list_datasets_query().order_by(
             self._datasets.c.name, self._datasets_versions.c.version
         )
+        if project_id:
+            query = query.where(d.c.project_id == project_id)
         yield from self._parse_dataset_list(self.db.execute(query))

+    def count_datasets(self, project_id: int | None = None) -> int:
+        d = self._datasets
+        query = self._datasets_select()
+        if project_id:
+            query = query.where(d.c.project_id == project_id)
+
+        query = select(f.count(1)).select_from(query.subquery())
+
+        return next(self.db.execute(query))[0]
+
     def list_datasets_by_prefix(
-        self, prefix: str, conn=None
+        self, prefix: str, project_id: int | None = None, conn=None
     ) -> Iterator["DatasetListRecord"]:
+        d = self._datasets
         query = self._base_list_datasets_query()
+        if project_id:
+            query = query.where(d.c.project_id == project_id)
         query = query.where(self._datasets.c.name.startswith(prefix))
         yield from self._parse_dataset_list(self.db.execute(query))

-    def get_dataset(self, name: str, conn=None) -> DatasetRecord:
-        """Gets a single dataset by name"""
+    def get_dataset(
+        self,
+        name: str,  # normal, not full dataset name
+        namespace_name: str | None = None,
+        project_name: str | None = None,
+        conn=None,
+    ) -> DatasetRecord:
+        """
+        Gets a single dataset in project by dataset name.
+        """
+        namespace_name = namespace_name or self.default_namespace_name
+        project_name = project_name or self.default_project_name
+
         d = self._datasets
+        n = self._namespaces
+        p = self._projects
         query = self._base_dataset_query()
-        query = query.where(d.c.name == name)  # type: ignore [attr-defined]
+        query = query.where(
+            d.c.name == name,
+            n.c.name == namespace_name,
+            p.c.name == project_name,
+        )  # type: ignore [attr-defined]
         ds = self._parse_dataset(self.db.execute(query, conn=conn))
         if not ds:
-            raise DatasetNotFoundError(f"Dataset {name} not found.")
+            raise DatasetNotFoundError(
+                f"Dataset {name} not found in namespace {namespace_name}"
+                f" and project {project_name}"
+            )
+
         return ds

     def remove_dataset_version(
-        self, dataset: DatasetRecord, version: int
+        self, dataset: DatasetRecord, version: str
     ) -> DatasetRecord:
         """
         Deletes one single dataset version.
@@ -787,7 +1437,7 @@ class AbstractDBMetastore(AbstractMetastore):
         self,
         dataset: DatasetRecord,
         status: int,
-        version: Optional[int] = None,
+        version: str | None = None,
         error_message="",
         error_stack="",
         script_output="",
@@ -808,7 +1458,7 @@ class AbstractDBMetastore(AbstractMetastore):
             update_data["error_message"] = error_message
             update_data["error_stack"] = error_stack

-        self.update_dataset(dataset, conn=conn, **update_data)
+        dataset = self.update_dataset(dataset, conn=conn, **update_data)

         if version:
             self.update_dataset_version(dataset, version, conn=conn, **update_data)
@@ -820,32 +1470,29 @@ class AbstractDBMetastore(AbstractMetastore):
     #
     def add_dataset_dependency(
         self,
-        source_dataset_name: str,
-        source_dataset_version: int,
-        dataset_name: str,
-        dataset_version: int,
+        source_dataset: "DatasetRecord",
+        source_dataset_version: str,
+        dep_dataset: "DatasetRecord",
+        dep_dataset_version: str,
     ) -> None:
         """Adds dataset dependency to dataset."""
-        source_dataset = self.get_dataset(source_dataset_name)
-        dataset = self.get_dataset(dataset_name)
-
         self.db.execute(
             self._datasets_dependencies_insert().values(
                 source_dataset_id=source_dataset.id,
                 source_dataset_version_id=(
                     source_dataset.get_version(source_dataset_version).id
                 ),
-                dataset_id=dataset.id,
-                dataset_version_id=dataset.get_version(dataset_version).id,
+                dataset_id=dep_dataset.id,
+                dataset_version_id=dep_dataset.get_version(dep_dataset_version).id,
             )
         )

     def update_dataset_dependency_source(
         self,
         source_dataset: DatasetRecord,
-        source_dataset_version: int,
-        new_source_dataset: Optional[DatasetRecord] = None,
-        new_source_dataset_version: Optional[int] = None,
+        source_dataset_version: str,
+        new_source_dataset: DatasetRecord | None = None,
+        new_source_dataset_version: str | None = None,
     ) -> None:
         dd = self._datasets_dependencies

@@ -875,9 +1522,23 @@ class AbstractDBMetastore(AbstractMetastore):
         Returns a list of columns to select in a query for fetching dataset dependencies
         """

+    @abstractmethod
+    def _dataset_dependency_nodes_select_columns(
+        self,
+        namespaces_subquery: "Subquery",
+        dependency_tree_cte: "CTE",
+        datasets_subquery: "Subquery",
+    ) -> list["ColumnElement"]:
+        """
+        Returns a list of columns to select in a query for fetching
+        dataset dependency nodes.
+        """
+
     def get_direct_dataset_dependencies(
-        self, dataset: DatasetRecord, version: int
-    ) -> list[Optional[DatasetDependency]]:
+        self, dataset: DatasetRecord, version: str
+    ) -> list[DatasetDependency | None]:
+        n = self._namespaces
+        p = self._projects
         d = self._datasets
         dd = self._datasets_dependencies
         dv = self._datasets_versions
@@ -889,23 +1550,90 @@ class AbstractDBMetastore(AbstractMetastore):
889
1550
  query = (
890
1551
  self._datasets_dependencies_select(*select_cols)
891
1552
  .select_from(
892
- dd.join(d, dd.c.dataset_id == d.c.id, isouter=True).join(
893
- dv, dd.c.dataset_version_id == dv.c.id, isouter=True
894
- )
1553
+ dd.join(d, dd.c.dataset_id == d.c.id, isouter=True)
1554
+ .join(dv, dd.c.dataset_version_id == dv.c.id, isouter=True)
1555
+ .join(p, d.c.project_id == p.c.id, isouter=True)
1556
+ .join(n, p.c.namespace_id == n.c.id, isouter=True)
895
1557
  )
896
1558
  .where(
897
1559
  (dd.c.source_dataset_id == dataset.id)
898
1560
  & (dd.c.source_dataset_version_id == dataset_version.id)
899
1561
  )
900
1562
  )
901
- if version:
902
- dataset_version = dataset.get_version(version)
903
- query = query.where(dd.c.source_dataset_version_id == dataset_version.id)
904
1563
 
905
1564
  return [self.dependency_class.parse(*r) for r in self.db.execute(query)]

+ def get_dataset_dependency_nodes(
+ self, dataset_id: int, version_id: int, depth_limit: int = DEPTH_LIMIT_DEFAULT
+ ) -> list[DatasetDependencyNode | None]:
+ n = self._namespaces_select().subquery()
+ p = self._projects
+ d = self._datasets_select().subquery()
+ dd = self._datasets_dependencies
+ dv = self._datasets_versions
+
+ # Common dependency fields for CTE
+ dep_fields = [
+ dd.c.id,
+ dd.c.source_dataset_id,
+ dd.c.source_dataset_version_id,
+ dd.c.dataset_id,
+ dd.c.dataset_version_id,
+ ]
+
+ # Base case: direct dependencies
+ base_query = select(
+ *dep_fields,
+ literal(0).label("depth"),
+ ).where(
+ (dd.c.source_dataset_id == dataset_id)
+ & (dd.c.source_dataset_version_id == version_id)
+ )
+
+ cte = base_query.cte(name="dependency_tree", recursive=True)
+
+ # Recursive case: dependencies of dependencies
+ # Limit depth to depth_limit to prevent infinite loops in case of
+ # circular dependencies
+ recursive_query = (
+ select(
+ *dep_fields,
+ (cte.c.depth + 1).label("depth"),
+ )
+ .select_from(
+ cte.join(
+ dd,
+ (cte.c.dataset_id == dd.c.source_dataset_id)
+ & (cte.c.dataset_version_id == dd.c.source_dataset_version_id),
+ )
+ )
+ .where(cte.c.depth < depth_limit)
+ )
+
+ cte = cte.union(recursive_query)
+
+ # Fetch all with full details
+ select_cols = self._dataset_dependency_nodes_select_columns(
+ namespaces_subquery=n,
+ dependency_tree_cte=cte,
+ datasets_subquery=d,
+ )
+ final_query = self._datasets_dependencies_select(*select_cols).select_from(
+ # Use outer joins to handle cases where dependent datasets have been
+ # physically deleted. This allows us to return dependency records with
+ # None values instead of silently omitting them, making broken
+ # dependencies visible to callers.
+ cte.join(d, cte.c.dataset_id == d.c.id, isouter=True)
+ .join(dv, cte.c.dataset_version_id == dv.c.id, isouter=True)
+ .join(p, d.c.project_id == p.c.id, isouter=True)
+ .join(n, p.c.namespace_id == n.c.id, isouter=True)
+ )
+
+ return [
+ self.dependency_node_class.parse(*r) for r in self.db.execute(final_query)
+ ]
+
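A minimal, self-contained sketch of the recursive-CTE pattern used by get_dataset_dependency_nodes above: a base SELECT seeded at depth 0, union'ed with a depth-incrementing recursive SELECT and capped by a depth limit. This is plain SQLAlchemy against in-memory SQLite; the table and column names are illustrative, not datachain's schema:

    import sqlalchemy as sa

    metadata = sa.MetaData()
    deps = sa.Table(
        "deps",
        metadata,
        sa.Column("source_id", sa.Integer),
        sa.Column("target_id", sa.Integer),
    )

    engine = sa.create_engine("sqlite://")
    metadata.create_all(engine)
    with engine.begin() as conn:
        conn.execute(
            deps.insert(),
            [{"source_id": 1, "target_id": 2}, {"source_id": 2, "target_id": 3}],
        )

    depth_limit = 100  # mirrors the DEPTH_LIMIT_DEFAULT idea

    # Base case: direct dependencies of node 1, at depth 0.
    base = sa.select(deps.c.target_id, sa.literal(0).label("depth")).where(
        deps.c.source_id == 1
    )
    cte = base.cte(name="dependency_tree", recursive=True)

    # Recursive case: dependencies of dependencies, capped by depth_limit.
    recursive = (
        sa.select(deps.c.target_id, (cte.c.depth + 1).label("depth"))
        .select_from(cte.join(deps, cte.c.target_id == deps.c.source_id))
        .where(cte.c.depth < depth_limit)
    )
    cte = cte.union(recursive)

    with engine.connect() as conn:
        print(conn.execute(sa.select(cte)).all())  # [(2, 0), (3, 1)] (order may vary)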
  def remove_dataset_dependencies(
- self, dataset: DatasetRecord, version: Optional[int] = None
+ self, dataset: DatasetRecord, version: str | None = None
  ) -> None:
  """
  When we remove a dataset, we need to clean up its dependencies as well
@@ -924,7 +1652,7 @@ class AbstractDBMetastore(AbstractMetastore):
  self.db.execute(q)

  def remove_dataset_dependants(
- self, dataset: DatasetRecord, version: Optional[int] = None
+ self, dataset: DatasetRecord, version: str | None = None
  ) -> None:
  """
  When we remove a dataset, we need to clear its references in other dataset
@@ -975,11 +1703,13 @@ class AbstractDBMetastore(AbstractMetastore):
  Column("error_stack", Text, nullable=False, default=""),
  Column("params", JSON, nullable=False),
  Column("metrics", JSON, nullable=False),
+ Column("parent_job_id", Text, nullable=True),
+ Index("idx_jobs_parent_job_id", "parent_job_id"),
  ]

  @cached_property
  def _job_fields(self) -> list[str]:
- return [c.name for c in self._jobs_columns() if c.name] # type: ignore[attr-defined]
+ return [c.name for c in self._jobs_columns() if isinstance(c, Column)] # type: ignore[attr-defined]

  @cached_property
  def _jobs(self) -> "Table":
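The _job_fields change in the hunk above matters because _jobs_columns() now also returns an Index, which has a truthy .name of its own; filtering on isinstance(c, Column) keeps non-column schema items out of the field list. A minimal illustration in plain SQLAlchemy (names illustrative):

    import sqlalchemy as sa

    items = [
        sa.Column("id", sa.Text),
        sa.Column("parent_job_id", sa.Text),
        sa.Index("idx_jobs_parent_job_id", "parent_job_id"),
    ]
    # the old filter would also pick up the index name
    print([c.name for c in items if c.name])
    # the new filter keeps only real columns: ['id', 'parent_job_id']
    print([c.name for c in items if isinstance(c, sa.Column)])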
@@ -1013,15 +1743,29 @@ class AbstractDBMetastore(AbstractMetastore):
  query = self._jobs_query().where(self._jobs.c.id.in_(ids))
  yield from self._parse_jobs(self.db.execute(query, conn=conn))

+ def get_last_job_by_name(self, name: str, conn=None) -> "Job | None":
+ query = (
+ self._jobs_query()
+ .where(self._jobs.c.name == name)
+ .order_by(self._jobs.c.created_at.desc())
+ .limit(1)
+ )
+ results = list(self.db.execute(query, conn=conn))
+ if not results:
+ return None
+ return self._parse_job(results[0])
+
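A hypothetical call to the new helper, assuming an existing `metastore` (illustrative only, not a documented API): it returns the most recently created job with the given name, or None.

    last = metastore.get_last_job_by_name("my-script")
    if last is not None:
        print(last.id, last.status)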
  def create_job(
  self,
  name: str,
  query: str,
  query_type: JobQueryType = JobQueryType.PYTHON,
+ status: JobStatus = JobStatus.CREATED,
  workers: int = 1,
- python_version: Optional[str] = None,
- params: Optional[dict[str, str]] = None,
- conn: Optional[Any] = None,
+ python_version: str | None = None,
+ params: dict[str, str] | None = None,
+ parent_job_id: str | None = None,
+ conn: Any = None,
  ) -> str:
  """
  Creates a new job.
@@ -1032,7 +1776,7 @@ class AbstractDBMetastore(AbstractMetastore):
  self._jobs_insert().values(
  id=job_id,
  name=name,
- status=JobStatus.CREATED,
+ status=status,
  created_at=datetime.now(timezone.utc),
  query=query,
  query_type=query_type.value,
@@ -1042,30 +1786,68 @@ class AbstractDBMetastore(AbstractMetastore):
  error_stack="",
  params=json.dumps(params or {}),
  metrics=json.dumps({}),
+ parent_job_id=parent_job_id,
  ),
  conn=conn,
  )
  return job_id
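The new status and parent_job_id arguments let a continuation or retry be recorded as a child of an earlier job. A hypothetical usage, assuming an existing `metastore` (illustrative only, including the specific status value):

    parent_id = metastore.create_job(name="train", query="train.py")
    child_id = metastore.create_job(
        name="train",
        query="train.py",
        status=JobStatus.RUNNING,   # defaults to JobStatus.CREATED
        parent_job_id=parent_id,
    )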

+ def get_job(self, job_id: str, conn=None) -> Job | None:
+ """Returns the job with the given ID."""
+ query = self._jobs_select(self._jobs).where(self._jobs.c.id == job_id)
+ results = list(self.db.execute(query, conn=conn))
+ if not results:
+ return None
+ return self._parse_job(results[0])
+
+ def update_job(
+ self,
+ job_id: str,
+ status: JobStatus | None = None,
+ error_message: str | None = None,
+ error_stack: str | None = None,
+ finished_at: datetime | None = None,
+ metrics: dict[str, Any] | None = None,
+ conn: Any | None = None,
+ ) -> Job | None:
+ """Updates job fields."""
+ values: dict = {}
+ if status is not None:
+ values["status"] = status
+ if error_message is not None:
+ values["error_message"] = error_message
+ if error_stack is not None:
+ values["error_stack"] = error_stack
+ if finished_at is not None:
+ values["finished_at"] = finished_at
+ if metrics:
+ values["metrics"] = json.dumps(metrics)
+
+ if values:
+ j = self._jobs
+ self.db.execute(
+ self._jobs_update().where(j.c.id == job_id).values(**values),
+ conn=conn,
+ ) # type: ignore [attr-defined]
+
+ return self.get_job(job_id, conn=conn)
+
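update_job only touches the fields that are actually passed and returns the refreshed Job, or None if the ID is unknown. A hypothetical usage, assuming `metastore` and `job_id` exist (illustrative only, including the status value):

    job = metastore.update_job(
        job_id,
        status=JobStatus.COMPLETE,
        finished_at=datetime.now(timezone.utc),
        metrics={"rows_processed": 1000},
    )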
  def set_job_status(
  self,
  job_id: str,
  status: JobStatus,
- error_message: Optional[str] = None,
- error_stack: Optional[str] = None,
- metrics: Optional[dict[str, Any]] = None,
- conn: Optional[Any] = None,
+ error_message: str | None = None,
+ error_stack: str | None = None,
+ conn: Any | None = None,
  ) -> None:
  """Set the status of the given job."""
- values: dict = {"status": status.value}
- if status.value in JobStatus.finished():
+ values: dict = {"status": status}
+ if status in JobStatus.finished():
  values["finished_at"] = datetime.now(timezone.utc)
  if error_message:
  values["error_message"] = error_message
  if error_stack:
  values["error_stack"] = error_stack
- if metrics:
- values["metrics"] = json.dumps(metrics)
  self.db.execute(
  self._jobs_update(self._jobs.c.id == job_id).values(**values),
  conn=conn,
@@ -1074,8 +1856,8 @@ class AbstractDBMetastore(AbstractMetastore):
  def get_job_status(
  self,
  job_id: str,
- conn: Optional[Any] = None,
- ) -> Optional[JobStatus]:
+ conn: Any | None = None,
+ ) -> JobStatus | None:
  """Returns the status of the given job."""
  results = list(
  self.db.execute(
@@ -1087,36 +1869,320 @@ class AbstractDBMetastore(AbstractMetastore):
  return None
  return results[0][0]

- def set_job_and_dataset_status(
+ #
+ # Checkpoints
+ #
+
+ @staticmethod
+ def _checkpoints_columns() -> "list[SchemaItem]":
+ return [
+ Column(
+ "id",
+ Text,
+ default=uuid4,
+ primary_key=True,
+ nullable=False,
+ ),
+ Column("job_id", Text, nullable=True),
+ Column("hash", Text, nullable=False),
+ Column("partial", Boolean, default=False),
+ Column("created_at", DateTime(timezone=True), nullable=False),
+ UniqueConstraint("job_id", "hash"),
+ ]
+
+ @cached_property
+ def _checkpoints_fields(self) -> list[str]:
+ return [c.name for c in self._checkpoints_columns() if c.name] # type: ignore[attr-defined]
+
+ @cached_property
+ def _checkpoints(self) -> "Table":
+ return Table(
+ self.CHECKPOINTS_TABLE,
+ self.db.metadata,
+ *self._checkpoints_columns(),
+ )
+
+ @abstractmethod
+ def _checkpoints_insert(self) -> "Insert": ...
+
+ @classmethod
+ def _dataset_version_jobs_columns(cls) -> "list[SchemaItem]":
+ """Junction table for dataset versions and jobs many-to-many relationship."""
+ return [
+ Column("id", Integer, primary_key=True),
+ Column(
+ "dataset_version_id",
+ Integer,
+ ForeignKey(f"{cls.DATASET_VERSION_TABLE}.id", ondelete="CASCADE"),
+ nullable=False,
+ ),
+ Column("job_id", Text, nullable=False),
+ Column("is_creator", Boolean, nullable=False, default=False),
+ Column("created_at", DateTime(timezone=True)),
+ UniqueConstraint("dataset_version_id", "job_id"),
+ Index("dc_idx_dvj_query", "job_id", "is_creator", "created_at"),
+ ]
+
+ @cached_property
+ def _dataset_version_jobs_fields(self) -> list[str]:
+ return [c.name for c in self._dataset_version_jobs_columns() if c.name] # type: ignore[attr-defined]
+
+ @cached_property
+ def _dataset_version_jobs(self) -> "Table":
+ return Table(
+ self.DATASET_VERSION_JOBS_TABLE,
+ self.db.metadata,
+ *self._dataset_version_jobs_columns(),
+ )
+
+ @abstractmethod
+ def _dataset_version_jobs_insert(self) -> "Insert": ...
+
+ def _dataset_version_jobs_select(self, *columns) -> "Select":
+ if not columns:
+ return self._dataset_version_jobs.select()
+ return select(*columns)
+
+ def _dataset_version_jobs_delete(self) -> "Delete":
+ return self._dataset_version_jobs.delete()
+
+ def _checkpoints_select(self, *columns) -> "Select":
+ if not columns:
+ return self._checkpoints.select()
+ return select(*columns)
+
+ def _checkpoints_delete(self) -> "Delete":
+ return self._checkpoints.delete()
+
+ def _checkpoints_query(self):
+ return self._checkpoints_select(
+ *[getattr(self._checkpoints.c, f) for f in self._checkpoints_fields]
+ )
+
+ def create_checkpoint(
+ self,
+ job_id: str,
+ _hash: str,
+ partial: bool = False,
+ conn: Any | None = None,
+ ) -> Checkpoint:
+ """
+ Creates a new checkpoint for the given job.
+ """
+ checkpoint_id = str(uuid4())
+ self.db.execute(
+ self._checkpoints_insert().values(
+ id=checkpoint_id,
+ job_id=job_id,
+ hash=_hash,
+ partial=partial,
+ created_at=datetime.now(timezone.utc),
+ ),
+ conn=conn,
+ )
+ return self.get_checkpoint_by_id(checkpoint_id)
+
+ def list_checkpoints(self, job_id: str, conn=None) -> Iterator[Checkpoint]:
+ """List checkpoints by job id."""
+ query = self._checkpoints_query().where(self._checkpoints.c.job_id == job_id)
+ rows = list(self.db.execute(query, conn=conn))
+
+ yield from [self.checkpoint_class.parse(*r) for r in rows]
+
+ def get_checkpoint_by_id(self, checkpoint_id: str, conn=None) -> Checkpoint:
+ """Returns the checkpoint with the given ID."""
+ ch = self._checkpoints
+ query = self._checkpoints_select(ch).where(ch.c.id == checkpoint_id)
+ rows = list(self.db.execute(query, conn=conn))
+ if not rows:
+ raise CheckpointNotFoundError(f"Checkpoint {checkpoint_id} not found")
+ return self.checkpoint_class.parse(*rows[0])
+
+ def find_checkpoint(
+ self, job_id: str, _hash: str, partial: bool = False, conn=None
+ ) -> Checkpoint | None:
+ """
+ Tries to find a checkpoint for a job with the given hash and partial flag
+ """
+ ch = self._checkpoints
+ query = self._checkpoints_select(ch).where(
+ ch.c.job_id == job_id, ch.c.hash == _hash, ch.c.partial == partial
+ )
+ rows = list(self.db.execute(query, conn=conn))
+ if not rows:
+ return None
+ return self.checkpoint_class.parse(*rows[0])
+
+ def get_last_checkpoint(self, job_id: str, conn=None) -> Checkpoint | None:
+ query = (
+ self._checkpoints_query()
+ .where(self._checkpoints.c.job_id == job_id)
+ .order_by(desc(self._checkpoints.c.created_at))
+ .limit(1)
+ )
+ rows = list(self.db.execute(query, conn=conn))
+ if not rows:
+ return None
+ return self.checkpoint_class.parse(*rows[0])
+
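Taken together, the checkpoint helpers above support a simple "reuse or create" flow keyed by (job_id, hash). A hypothetical usage, assuming `metastore` and `job_id` exist and `step_hash` identifies a unit of work (illustrative only):

    checkpoint = metastore.find_checkpoint(job_id, step_hash)
    if checkpoint is None:
        # ... do the work, then record that this step finished
        checkpoint = metastore.create_checkpoint(job_id, step_hash, partial=False)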
+ def link_dataset_version_to_job(
  self,
+ dataset_version_id: int,
  job_id: str,
- job_status: JobStatus,
- dataset_status: DatasetStatus,
+ is_creator: bool = False,
+ conn=None,
  ) -> None:
- """Set the status of the given job and dataset."""
- with self.db.transaction() as conn:
- self.set_job_status(job_id, status=job_status, conn=conn)
- dv = self._datasets_versions
- query = (
- self._datasets_versions_update()
- .where(
- (dv.c.job_id == job_id) & (dv.c.status != DatasetStatus.COMPLETE)
+ # Use transaction to atomically:
+ # 1. Link dataset version to job in junction table
+ # 2. Update dataset_version.job_id to point to this job
+ with self.db.transaction() as tx_conn:
+ conn = conn or tx_conn
+
+ # Insert into junction table
+ query = self._dataset_version_jobs_insert().values(
+ dataset_version_id=dataset_version_id,
+ job_id=job_id,
+ is_creator=is_creator,
+ created_at=datetime.now(timezone.utc),
+ )
+ if hasattr(query, "on_conflict_do_nothing"):
+ query = query.on_conflict_do_nothing(
+ index_elements=["dataset_version_id", "job_id"]
  )
- .values(status=dataset_status)
+ self.db.execute(query, conn=conn)
+
+ # Also update dataset_version.job_id to point to this job
+ update_query = (
+ self._datasets_versions.update()
+ .where(self._datasets_versions.c.id == dataset_version_id)
+ .values(job_id=job_id)
+ )
+ self.db.execute(update_query, conn=conn)
+
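The hasattr(query, "on_conflict_do_nothing") guard above exists because the generic SQLAlchemy Insert has no such method, while the SQLite and PostgreSQL dialect inserts do. A minimal, self-contained sketch of the same "insert or ignore" pattern against in-memory SQLite (table and column names illustrative, not datachain's schema):

    import sqlalchemy as sa
    from sqlalchemy.dialects.sqlite import insert as sqlite_insert

    metadata = sa.MetaData()
    links = sa.Table(
        "links",
        metadata,
        sa.Column("dataset_version_id", sa.Integer),
        sa.Column("job_id", sa.Text),
        sa.UniqueConstraint("dataset_version_id", "job_id"),
    )

    engine = sa.create_engine("sqlite://")
    metadata.create_all(engine)

    stmt = sqlite_insert(links).values(dataset_version_id=1, job_id="job-1")
    # Dialect-specific inserts expose on_conflict_do_nothing; the generic
    # Insert does not, which is what the hasattr() guard checks for.
    stmt = stmt.on_conflict_do_nothing(index_elements=["dataset_version_id", "job_id"])

    with engine.begin() as conn:
        conn.execute(stmt)
        conn.execute(stmt)  # duplicate link is silently ignored
        count = conn.execute(sa.select(sa.func.count()).select_from(links)).scalar_one()
        print(count)  # 1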
+ def get_ancestor_job_ids(self, job_id: str, conn=None) -> list[str]:
+ # Use recursive CTE to walk up the parent chain
+ # Format: WITH RECURSIVE ancestors(id, parent_job_id, depth) AS (...)
+ # Include depth tracking to prevent infinite recursion in case of
+ # circular dependencies
+ ancestors_cte = (
+ self._jobs_select(
+ self._jobs.c.id.label("id"),
+ self._jobs.c.parent_job_id.label("parent_job_id"),
+ literal(0).label("depth"),
  )
- self.db.execute(query, conn=conn) # type: ignore[attr-defined]
+ .where(self._jobs.c.id == job_id)
+ .cte(name="ancestors", recursive=True)
+ )

- def get_job_dataset_versions(self, job_id: str) -> list[tuple[str, int]]:
- """Returns dataset names and versions for the job."""
- dv = self._datasets_versions
- ds = self._datasets
+ # Recursive part: join with parent jobs, incrementing depth and checking limit
+ ancestors_recursive = ancestors_cte.union_all(
+ self._jobs_select(
+ self._jobs.c.id.label("id"),
+ self._jobs.c.parent_job_id.label("parent_job_id"),
+ (ancestors_cte.c.depth + 1).label("depth"),
+ ).select_from(
+ self._jobs.join(
+ ancestors_cte,
+ (
+ self._jobs.c.id
+ == cast(ancestors_cte.c.parent_job_id, self._jobs.c.id.type)
+ )
+ & (ancestors_cte.c.parent_job_id.isnot(None)) # Stop at root jobs
+ & (ancestors_cte.c.depth < JOB_ANCESTRY_MAX_DEPTH),
+ )
+ )
+ )
+
+ # Select all ancestor IDs and depths except the starting job itself
+ query = select(ancestors_recursive.c.id, ancestors_recursive.c.depth).where(
+ ancestors_recursive.c.id != job_id
+ )

- join_condition = dv.c.dataset_id == ds.c.id
+ results = list(self.db.execute(query, conn=conn))

- query = (
- self._datasets_versions_select(ds.c.name, dv.c.version)
- .select_from(dv.join(ds, join_condition))
- .where(dv.c.job_id == job_id)
+ # Check if we hit the depth limit
+ if results:
+ max_found_depth = max(row[1] for row in results)
+ if max_found_depth >= JOB_ANCESTRY_MAX_DEPTH:
+ from datachain.error import JobAncestryDepthExceededError
+
+ raise JobAncestryDepthExceededError(
+ f"Job ancestry chain exceeds maximum depth of "
+ f"{JOB_ANCESTRY_MAX_DEPTH}. Job ID: {job_id}"
+ )
+
+ return [str(row[0]) for row in results]
+
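A hypothetical walk of the resulting ancestry, assuming an existing `metastore` and jobs chained c -> b -> a via parent_job_id (illustrative only):

    a = metastore.create_job(name="run", query="run.py")
    b = metastore.create_job(name="run", query="run.py", parent_job_id=a)
    c = metastore.create_job(name="run", query="run.py", parent_job_id=b)

    print(metastore.get_ancestor_job_ids(c))  # [b, a] (order not guaranteed)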
+ def _get_dataset_version_for_job_ancestry_query(
+ self,
+ dataset_name: str,
+ namespace_name: str,
+ project_name: str,
+ job_ancestry: list[str],
+ ) -> "Select":
+ """Find most recent dataset version created by any job in ancestry.
+
+ Searches job ancestry (current + parents) for the newest version of
+ the dataset where is_creator=True. Returns newest by created_at, or
+ None if no version was created by any job in the ancestry chain.
+
+ Used for checkpoint resolution to find which version to reuse when
+ continuing from a parent job.
+ """
+ return (
+ self._datasets_versions_select()
+ .select_from(
+ self._dataset_version_jobs.join(
+ self._datasets_versions,
+ self._dataset_version_jobs.c.dataset_version_id
+ == self._datasets_versions.c.id,
+ )
+ .join(
+ self._datasets,
+ self._datasets_versions.c.dataset_id == self._datasets.c.id,
+ )
+ .join(
+ self._projects,
+ self._datasets.c.project_id == self._projects.c.id,
+ )
+ .join(
+ self._namespaces,
+ self._projects.c.namespace_id == self._namespaces.c.id,
+ )
+ )
+ .where(
+ self._datasets.c.name == dataset_name,
+ self._namespaces.c.name == namespace_name,
+ self._projects.c.name == project_name,
+ self._dataset_version_jobs.c.job_id.in_(job_ancestry),
+ self._dataset_version_jobs.c.is_creator.is_(True),
+ )
+ .order_by(desc(self._dataset_version_jobs.c.created_at))
+ .limit(1)
  )

- return list(self.db.execute(query))
+ def get_dataset_version_for_job_ancestry(
+ self,
+ dataset_name: str,
+ namespace_name: str,
+ project_name: str,
+ job_id: str,
+ conn=None,
+ ) -> DatasetVersion | None:
+ # Get job ancestry (current job + all ancestors)
+ job_ancestry = [job_id, *self.get_ancestor_job_ids(job_id, conn=conn)]
+
+ query = self._get_dataset_version_for_job_ancestry_query(
+ dataset_name, namespace_name, project_name, job_ancestry
+ )
+
+ results = list(self.db.execute(query, conn=conn))
+ if not results:
+ return None
+
+ if len(results) > 1:
+ raise DataChainError(
+ f"Expected at most 1 dataset version, found {len(results)}"
+ )
+
+ return self.dataset_version_class.parse(*results[0])
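In combination, the two methods above answer the checkpoint-resolution question: did this job, or any job in its parent chain, already create a version of the dataset that can be reused? A hypothetical usage, assuming `metastore` and `job_id` exist (names and values illustrative, not a documented API):

    version = metastore.get_dataset_version_for_job_ancestry(
        dataset_name="embeddings",
        namespace_name="local",
        project_name="default",
        job_id=job_id,
    )
    if version is None:
        ...  # nothing to reuse: compute the dataset and link it with is_creator=True
    else:
        print(version.version)  # reuse this version instead of recomputing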