datachain 0.30.5__py3-none-any.whl → 0.39.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. datachain/__init__.py +4 -0
  2. datachain/asyn.py +11 -12
  3. datachain/cache.py +5 -5
  4. datachain/catalog/__init__.py +0 -2
  5. datachain/catalog/catalog.py +276 -354
  6. datachain/catalog/dependency.py +164 -0
  7. datachain/catalog/loader.py +8 -3
  8. datachain/checkpoint.py +43 -0
  9. datachain/cli/__init__.py +10 -17
  10. datachain/cli/commands/__init__.py +1 -8
  11. datachain/cli/commands/datasets.py +42 -27
  12. datachain/cli/commands/ls.py +15 -15
  13. datachain/cli/commands/show.py +2 -2
  14. datachain/cli/parser/__init__.py +3 -43
  15. datachain/cli/parser/job.py +1 -1
  16. datachain/cli/parser/utils.py +1 -2
  17. datachain/cli/utils.py +2 -15
  18. datachain/client/azure.py +2 -2
  19. datachain/client/fsspec.py +34 -23
  20. datachain/client/gcs.py +3 -3
  21. datachain/client/http.py +157 -0
  22. datachain/client/local.py +11 -7
  23. datachain/client/s3.py +3 -3
  24. datachain/config.py +4 -8
  25. datachain/data_storage/db_engine.py +12 -6
  26. datachain/data_storage/job.py +2 -0
  27. datachain/data_storage/metastore.py +716 -137
  28. datachain/data_storage/schema.py +20 -27
  29. datachain/data_storage/serializer.py +105 -15
  30. datachain/data_storage/sqlite.py +114 -114
  31. datachain/data_storage/warehouse.py +140 -48
  32. datachain/dataset.py +109 -89
  33. datachain/delta.py +117 -42
  34. datachain/diff/__init__.py +25 -33
  35. datachain/error.py +24 -0
  36. datachain/func/aggregate.py +9 -11
  37. datachain/func/array.py +12 -12
  38. datachain/func/base.py +7 -4
  39. datachain/func/conditional.py +9 -13
  40. datachain/func/func.py +63 -45
  41. datachain/func/numeric.py +5 -7
  42. datachain/func/string.py +2 -2
  43. datachain/hash_utils.py +123 -0
  44. datachain/job.py +11 -7
  45. datachain/json.py +138 -0
  46. datachain/lib/arrow.py +18 -15
  47. datachain/lib/audio.py +60 -59
  48. datachain/lib/clip.py +14 -13
  49. datachain/lib/convert/python_to_sql.py +6 -10
  50. datachain/lib/convert/values_to_tuples.py +151 -53
  51. datachain/lib/data_model.py +23 -19
  52. datachain/lib/dataset_info.py +7 -7
  53. datachain/lib/dc/__init__.py +2 -1
  54. datachain/lib/dc/csv.py +22 -26
  55. datachain/lib/dc/database.py +37 -34
  56. datachain/lib/dc/datachain.py +518 -324
  57. datachain/lib/dc/datasets.py +38 -30
  58. datachain/lib/dc/hf.py +16 -20
  59. datachain/lib/dc/json.py +17 -18
  60. datachain/lib/dc/listings.py +5 -8
  61. datachain/lib/dc/pandas.py +3 -6
  62. datachain/lib/dc/parquet.py +33 -21
  63. datachain/lib/dc/records.py +9 -13
  64. datachain/lib/dc/storage.py +103 -65
  65. datachain/lib/dc/storage_pattern.py +251 -0
  66. datachain/lib/dc/utils.py +17 -14
  67. datachain/lib/dc/values.py +3 -6
  68. datachain/lib/file.py +187 -50
  69. datachain/lib/hf.py +7 -5
  70. datachain/lib/image.py +13 -13
  71. datachain/lib/listing.py +5 -5
  72. datachain/lib/listing_info.py +1 -2
  73. datachain/lib/meta_formats.py +2 -3
  74. datachain/lib/model_store.py +20 -8
  75. datachain/lib/namespaces.py +59 -7
  76. datachain/lib/projects.py +51 -9
  77. datachain/lib/pytorch.py +31 -23
  78. datachain/lib/settings.py +188 -85
  79. datachain/lib/signal_schema.py +302 -64
  80. datachain/lib/text.py +8 -7
  81. datachain/lib/udf.py +103 -63
  82. datachain/lib/udf_signature.py +59 -34
  83. datachain/lib/utils.py +20 -0
  84. datachain/lib/video.py +3 -4
  85. datachain/lib/webdataset.py +31 -36
  86. datachain/lib/webdataset_laion.py +15 -16
  87. datachain/listing.py +12 -5
  88. datachain/model/bbox.py +3 -1
  89. datachain/namespace.py +22 -3
  90. datachain/node.py +6 -6
  91. datachain/nodes_thread_pool.py +0 -1
  92. datachain/plugins.py +24 -0
  93. datachain/project.py +4 -4
  94. datachain/query/batch.py +10 -12
  95. datachain/query/dataset.py +376 -194
  96. datachain/query/dispatch.py +112 -84
  97. datachain/query/metrics.py +3 -4
  98. datachain/query/params.py +2 -3
  99. datachain/query/queue.py +2 -1
  100. datachain/query/schema.py +7 -6
  101. datachain/query/session.py +190 -33
  102. datachain/query/udf.py +9 -6
  103. datachain/remote/studio.py +90 -53
  104. datachain/script_meta.py +12 -12
  105. datachain/sql/sqlite/base.py +37 -25
  106. datachain/sql/sqlite/types.py +1 -1
  107. datachain/sql/types.py +36 -5
  108. datachain/studio.py +49 -40
  109. datachain/toolkit/split.py +31 -10
  110. datachain/utils.py +39 -48
  111. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/METADATA +26 -38
  112. datachain-0.39.0.dist-info/RECORD +173 -0
  113. datachain/cli/commands/query.py +0 -54
  114. datachain/query/utils.py +0 -36
  115. datachain-0.30.5.dist-info/RECORD +0 -168
  116. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/WHEEL +0 -0
  117. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
  118. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
  119. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/data_storage/metastore.py

@@ -1,28 +1,37 @@
  import copy
- import json
  import logging
  import os
  from abc import ABC, abstractmethod
  from collections.abc import Iterator
+ from contextlib import contextmanager, suppress
  from datetime import datetime, timezone
  from functools import cached_property, reduce
  from itertools import groupby
- from typing import TYPE_CHECKING, Any, Optional
+ from typing import TYPE_CHECKING, Any
  from uuid import uuid4

  from sqlalchemy import (
  JSON,
  BigInteger,
+ Boolean,
  Column,
  DateTime,
  ForeignKey,
+ Index,
  Integer,
  Table,
  Text,
  UniqueConstraint,
+ cast,
+ desc,
+ literal,
  select,
  )
+ from sqlalchemy.sql import func as f

+ from datachain import json
+ from datachain.catalog.dependency import DatasetDependencyNode
+ from datachain.checkpoint import Checkpoint
  from datachain.data_storage import JobQueryType, JobStatus
  from datachain.data_storage.serializer import Serializable
  from datachain.dataset import (
@@ -33,27 +42,34 @@ from datachain.dataset import (
  DatasetStatus,
  DatasetVersion,
  StorageURI,
+ parse_schema,
  )
  from datachain.error import (
+ CheckpointNotFoundError,
+ DataChainError,
  DatasetNotFoundError,
  DatasetVersionNotFoundError,
+ NamespaceDeleteNotAllowedError,
  NamespaceNotFoundError,
+ ProjectDeleteNotAllowedError,
  ProjectNotFoundError,
  TableMissingError,
  )
  from datachain.job import Job
  from datachain.namespace import Namespace
  from datachain.project import Project
- from datachain.utils import JSONSerialize

  if TYPE_CHECKING:
- from sqlalchemy import Delete, Insert, Select, Update
+ from sqlalchemy import CTE, Delete, Insert, Select, Subquery, Update
  from sqlalchemy.schema import SchemaItem
+ from sqlalchemy.sql.elements import ColumnElement

  from datachain.data_storage import schema
  from datachain.data_storage.db_engine import DatabaseEngine

  logger = logging.getLogger("datachain")
+ DEPTH_LIMIT_DEFAULT = 100
+ JOB_ANCESTRY_MAX_DEPTH = 100


  class AbstractMetastore(ABC, Serializable):
@@ -68,14 +84,17 @@ class AbstractMetastore(ABC, Serializable):
  namespace_class: type[Namespace] = Namespace
  project_class: type[Project] = Project
  dataset_class: type[DatasetRecord] = DatasetRecord
+ dataset_version_class: type[DatasetVersion] = DatasetVersion
  dataset_list_class: type[DatasetListRecord] = DatasetListRecord
  dataset_list_version_class: type[DatasetListVersion] = DatasetListVersion
  dependency_class: type[DatasetDependency] = DatasetDependency
+ dependency_node_class: type[DatasetDependencyNode] = DatasetDependencyNode
  job_class: type[Job] = Job
+ checkpoint_class: type[Checkpoint] = Checkpoint

  def __init__(
  self,
- uri: Optional[StorageURI] = None,
+ uri: StorageURI | None = None,
  ):
  self.uri = uri or StorageURI("")

@@ -89,7 +108,7 @@ class AbstractMetastore(ABC, Serializable):
  @abstractmethod
  def clone(
  self,
- uri: Optional[StorageURI] = None,
+ uri: StorageURI | None = None,
  use_new_connection: bool = False,
  ) -> "AbstractMetastore":
  """Clones AbstractMetastore implementation for some Storage input.
@@ -106,6 +125,16 @@ class AbstractMetastore(ABC, Serializable):
  differently."""
  self.close()

+ @contextmanager
+ def _init_guard(self):
+ """Ensure resources acquired during __init__ are released on failure."""
+ try:
+ yield
+ except Exception:
+ with suppress(Exception):
+ self.close_on_exit()
+ raise
+
  def cleanup_tables(self, temp_table_names: list[str]) -> None:
  """Cleanup temp tables."""

@@ -129,8 +158,8 @@ class AbstractMetastore(ABC, Serializable):
  def create_namespace(
  self,
  name: str,
- description: Optional[str] = None,
- uuid: Optional[str] = None,
+ description: str | None = None,
+ uuid: str | None = None,
  ignore_if_exists: bool = True,
  validate: bool = True,
  **kwargs,
@@ -141,6 +170,10 @@ class AbstractMetastore(ABC, Serializable):
  def get_namespace(self, name: str, conn=None) -> Namespace:
  """Gets a single namespace by name"""

+ @abstractmethod
+ def remove_namespace(self, namespace_id: int, conn=None) -> None:
+ """Removes a single namespace by id"""
+
  @abstractmethod
  def list_namespaces(self, conn=None) -> list[Namespace]:
  """Gets a list of all namespaces"""
@@ -173,8 +206,8 @@ class AbstractMetastore(ABC, Serializable):
  self,
  namespace_name: str,
  name: str,
- description: Optional[str] = None,
- uuid: Optional[str] = None,
+ description: str | None = None,
+ uuid: str | None = None,
  ignore_if_exists: bool = True,
  validate: bool = True,
  **kwargs,
@@ -190,12 +223,32 @@ class AbstractMetastore(ABC, Serializable):
  It also creates project if not found and create flag is set to True.
  """

+ def is_default_project(self, project_name: str, namespace_name: str) -> bool:
+ return (
+ project_name == self.default_project_name
+ and namespace_name == self.default_namespace_name
+ )
+
+ def is_listing_project(self, project_name: str, namespace_name: str) -> bool:
+ return (
+ project_name == self.listing_project_name
+ and namespace_name == self.system_namespace_name
+ )
+
  @abstractmethod
  def get_project_by_id(self, project_id: int, conn=None) -> Project:
  """Gets a single project by id"""

  @abstractmethod
- def list_projects(self, namespace_id: Optional[int], conn=None) -> list[Project]:
+ def count_projects(self, namespace_id: int | None = None) -> int:
+ """Counts projects in some namespace or in general."""
+
+ @abstractmethod
+ def remove_project(self, project_id: int, conn=None) -> None:
+ """Removes a single project by id"""
+
+ @abstractmethod
+ def list_projects(self, namespace_id: int | None, conn=None) -> list[Project]:
  """Gets list of projects in some namespace or in general (in all namespaces)"""

  #
@@ -205,15 +258,15 @@ class AbstractMetastore(ABC, Serializable):
  def create_dataset(
  self,
  name: str,
- project_id: Optional[int] = None,
+ project_id: int | None = None,
  status: int = DatasetStatus.CREATED,
- sources: Optional[list[str]] = None,
- feature_schema: Optional[dict] = None,
+ sources: list[str] | None = None,
+ feature_schema: dict | None = None,
  query_script: str = "",
- schema: Optional[dict[str, Any]] = None,
+ schema: dict[str, Any] | None = None,
  ignore_if_exists: bool = False,
- description: Optional[str] = None,
- attrs: Optional[list[str]] = None,
+ description: str | None = None,
+ attrs: list[str] | None = None,
  ) -> DatasetRecord:
  """Creates new dataset."""

@@ -224,20 +277,20 @@ class AbstractMetastore(ABC, Serializable):
  version: str,
  status: int,
  sources: str = "",
- feature_schema: Optional[dict] = None,
+ feature_schema: dict | None = None,
  query_script: str = "",
  error_message: str = "",
  error_stack: str = "",
  script_output: str = "",
- created_at: Optional[datetime] = None,
- finished_at: Optional[datetime] = None,
- schema: Optional[dict[str, Any]] = None,
+ created_at: datetime | None = None,
+ finished_at: datetime | None = None,
+ schema: dict[str, Any] | None = None,
  ignore_if_exists: bool = False,
- num_objects: Optional[int] = None,
- size: Optional[int] = None,
- preview: Optional[list[dict]] = None,
- job_id: Optional[str] = None,
- uuid: Optional[str] = None,
+ num_objects: int | None = None,
+ size: int | None = None,
+ preview: list[dict] | None = None,
+ job_id: str | None = None,
+ uuid: str | None = None,
  ) -> DatasetRecord:
  """Creates new dataset version."""

@@ -266,13 +319,17 @@ class AbstractMetastore(ABC, Serializable):

  @abstractmethod
  def list_datasets(
- self, project_id: Optional[int] = None
+ self, project_id: int | None = None
  ) -> Iterator[DatasetListRecord]:
  """Lists all datasets in some project or in all projects."""

+ @abstractmethod
+ def count_datasets(self, project_id: int | None = None) -> int:
+ """Counts datasets in some project or in all projects."""
+
  @abstractmethod
  def list_datasets_by_prefix(
- self, prefix: str, project_id: Optional[int] = None
+ self, prefix: str, project_id: int | None = None
  ) -> Iterator["DatasetListRecord"]:
  """
  Lists all datasets which names start with prefix in some project or in all
@@ -283,8 +340,8 @@ class AbstractMetastore(ABC, Serializable):
  def get_dataset(
  self,
  name: str, # normal, not full dataset name
- namespace_name: Optional[str] = None,
- project_name: Optional[str] = None,
+ namespace_name: str | None = None,
+ project_name: str | None = None,
  conn=None,
  ) -> DatasetRecord:
  """Gets a single dataset by name."""
@@ -294,7 +351,7 @@ class AbstractMetastore(ABC, Serializable):
  self,
  dataset: DatasetRecord,
  status: int,
- version: Optional[str] = None,
+ version: str | None = None,
  error_message="",
  error_stack="",
  script_output="",
@@ -319,20 +376,26 @@ class AbstractMetastore(ABC, Serializable):
  self,
  source_dataset: DatasetRecord,
  source_dataset_version: str,
- new_source_dataset: Optional[DatasetRecord] = None,
- new_source_dataset_version: Optional[str] = None,
+ new_source_dataset: DatasetRecord | None = None,
+ new_source_dataset_version: str | None = None,
  ) -> None:
  """Updates dataset dependency source."""

  @abstractmethod
  def get_direct_dataset_dependencies(
  self, dataset: DatasetRecord, version: str
- ) -> list[Optional[DatasetDependency]]:
+ ) -> list[DatasetDependency | None]:
  """Gets direct dataset dependencies."""

+ @abstractmethod
+ def get_dataset_dependency_nodes(
+ self, dataset_id: int, version_id: int
+ ) -> list[DatasetDependencyNode | None]:
+ """Gets dataset dependency node from database."""
+
  @abstractmethod
  def remove_dataset_dependencies(
- self, dataset: DatasetRecord, version: Optional[str] = None
+ self, dataset: DatasetRecord, version: str | None = None
  ) -> None:
  """
  When we remove dataset, we need to clean up it's dependencies as well.
@@ -340,7 +403,7 @@ class AbstractMetastore(ABC, Serializable):

  @abstractmethod
  def remove_dataset_dependants(
- self, dataset: DatasetRecord, version: Optional[str] = None
+ self, dataset: DatasetRecord, version: str | None = None
  ) -> None:
  """
  When we remove dataset, we need to clear its references in other dataset
@@ -362,8 +425,9 @@ class AbstractMetastore(ABC, Serializable):
  query_type: JobQueryType = JobQueryType.PYTHON,
  status: JobStatus = JobStatus.CREATED,
  workers: int = 1,
- python_version: Optional[str] = None,
- params: Optional[dict[str, str]] = None,
+ python_version: str | None = None,
+ params: dict[str, str] | None = None,
+ parent_job_id: str | None = None,
  ) -> str:
  """
  Creates a new job.
@@ -371,19 +435,19 @@ class AbstractMetastore(ABC, Serializable):
  """

  @abstractmethod
- def get_job(self, job_id: str) -> Optional[Job]:
+ def get_job(self, job_id: str) -> Job | None:
  """Returns the job with the given ID."""

  @abstractmethod
  def update_job(
  self,
  job_id: str,
- status: Optional[JobStatus] = None,
- error_message: Optional[str] = None,
- error_stack: Optional[str] = None,
- finished_at: Optional[datetime] = None,
- metrics: Optional[dict[str, Any]] = None,
- ) -> Optional["Job"]:
+ status: JobStatus | None = None,
+ error_message: str | None = None,
+ error_stack: str | None = None,
+ finished_at: datetime | None = None,
+ metrics: dict[str, Any] | None = None,
+ ) -> Job | None:
  """Updates job fields."""

  @abstractmethod
@@ -391,15 +455,90 @@ class AbstractMetastore(ABC, Serializable):
  self,
  job_id: str,
  status: JobStatus,
- error_message: Optional[str] = None,
- error_stack: Optional[str] = None,
+ error_message: str | None = None,
+ error_stack: str | None = None,
  ) -> None:
  """Set the status of the given job."""

  @abstractmethod
- def get_job_status(self, job_id: str) -> Optional[JobStatus]:
+ def get_job_status(self, job_id: str) -> JobStatus | None:
  """Returns the status of the given job."""

+ @abstractmethod
+ def get_last_job_by_name(self, name: str, conn=None) -> "Job | None":
+ """Returns the last job with the given name, ordered by created_at."""
+
+ #
+ # Checkpoints
+ #
+
+ @abstractmethod
+ def list_checkpoints(self, job_id: str, conn=None) -> Iterator[Checkpoint]:
+ """Returns all checkpoints related to some job"""
+
+ @abstractmethod
+ def get_last_checkpoint(self, job_id: str, conn=None) -> Checkpoint | None:
+ """Get last created checkpoint for some job."""
+
+ @abstractmethod
+ def get_checkpoint_by_id(self, checkpoint_id: str, conn=None) -> Checkpoint:
+ """Gets single checkpoint by id"""
+
+ def find_checkpoint(
+ self, job_id: str, _hash: str, partial: bool = False, conn=None
+ ) -> Checkpoint | None:
+ """
+ Tries to find checkpoint for a job with specific hash and optionally partial
+ """
+
+ @abstractmethod
+ def create_checkpoint(
+ self,
+ job_id: str,
+ _hash: str,
+ partial: bool = False,
+ conn: Any | None = None,
+ ) -> Checkpoint:
+ """Creates new checkpoint"""
+
+ #
+ # Dataset Version Jobs (many-to-many)
+ #
+
+ @abstractmethod
+ def link_dataset_version_to_job(
+ self,
+ dataset_version_id: int,
+ job_id: str,
+ is_creator: bool = False,
+ conn=None,
+ ) -> None:
+ """
+ Link dataset version to job.
+
+ This atomically:
+ 1. Creates a link in the dataset_version_jobs junction table
+ 2. Updates dataset_version.job_id to point to this job
+ """
+
+ @abstractmethod
+ def get_ancestor_job_ids(self, job_id: str, conn=None) -> list[str]:
+ """Get all ancestor job IDs for a given job."""
+
+ @abstractmethod
+ def get_dataset_version_for_job_ancestry(
+ self,
+ dataset_name: str,
+ namespace_name: str,
+ project_name: str,
+ job_id: str,
+ conn=None,
+ ) -> DatasetVersion | None:
+ """
+ Find the dataset version that was created by any job in the ancestry.
+ Returns the most recently linked version from these jobs.
+ """
+

  class AbstractDBMetastore(AbstractMetastore):
  """
@@ -414,11 +553,13 @@ class AbstractDBMetastore(AbstractMetastore):
  DATASET_TABLE = "datasets"
  DATASET_VERSION_TABLE = "datasets_versions"
  DATASET_DEPENDENCY_TABLE = "datasets_dependencies"
+ DATASET_VERSION_JOBS_TABLE = "dataset_version_jobs"
  JOBS_TABLE = "jobs"
+ CHECKPOINTS_TABLE = "checkpoints"

  db: "DatabaseEngine"

- def __init__(self, uri: Optional[StorageURI] = None):
+ def __init__(self, uri: StorageURI | None = None):
  uri = uri or StorageURI("")
  super().__init__(uri)

@@ -658,9 +799,6 @@ class AbstractDBMetastore(AbstractMetastore):
  return self._projects.select()
  return select(*columns)

- def _projects_update(self) -> "Update":
- return self._projects.update()
-
  def _projects_delete(self) -> "Delete":
  return self._projects.delete()

@@ -713,8 +851,8 @@ class AbstractDBMetastore(AbstractMetastore):
  def create_namespace(
  self,
  name: str,
- description: Optional[str] = None,
- uuid: Optional[str] = None,
+ description: str | None = None,
+ uuid: str | None = None,
  ignore_if_exists: bool = True,
  validate: bool = True,
  **kwargs,
@@ -735,6 +873,18 @@ class AbstractDBMetastore(AbstractMetastore):

  return self.get_namespace(name)

+ def remove_namespace(self, namespace_id: int, conn=None) -> None:
+ num_projects = self.count_projects(namespace_id)
+ if num_projects > 0:
+ raise NamespaceDeleteNotAllowedError(
+ f"Namespace cannot be removed. It contains {num_projects} project(s). "
+ "Please remove the project(s) first."
+ )
+
+ n = self._namespaces
+ with self.db.transaction():
+ self.db.execute(self._namespaces_delete().where(n.c.id == namespace_id))
+
  def get_namespace(self, name: str, conn=None) -> Namespace:
  """Gets a single namespace by name"""
  n = self._namespaces
@@ -766,8 +916,8 @@ class AbstractDBMetastore(AbstractMetastore):
  self,
  namespace_name: str,
  name: str,
- description: Optional[str] = None,
- uuid: Optional[str] = None,
+ description: str | None = None,
+ uuid: str | None = None,
  ignore_if_exists: bool = True,
  validate: bool = True,
  **kwargs,
@@ -796,17 +946,15 @@ class AbstractDBMetastore(AbstractMetastore):

  return self.get_project(name, namespace.name)

- def _is_listing_project(self, project_name: str, namespace_name: str) -> bool:
- return (
- project_name == self.listing_project_name
- and namespace_name == self.system_namespace_name
- )
+ def _projects_base_query(self) -> "Select":
+ n = self._namespaces
+ p = self._projects

- def _is_default_project(self, project_name: str, namespace_name: str) -> bool:
- return (
- project_name == self.default_project_name
- and namespace_name == self.default_namespace_name
+ query = self._projects_select(
+ *(getattr(n.c, f) for f in self._namespaces_fields),
+ *(getattr(p.c, f) for f in self._projects_fields),
  )
+ return query.select_from(n.join(p, n.c.id == p.c.namespace_id))

  def get_project(
  self, name: str, namespace_name: str, create: bool = False, conn=None
@@ -816,18 +964,14 @@ class AbstractDBMetastore(AbstractMetastore):
  p = self._projects
  validate = True

- if self._is_listing_project(name, namespace_name) or self._is_default_project(
+ if self.is_listing_project(name, namespace_name) or self.is_default_project(
  name, namespace_name
  ):
  # we are always creating default and listing projects if they don't exist
  create = True
  validate = False

- query = self._projects_select(
- *(getattr(n.c, f) for f in self._namespaces_fields),
- *(getattr(p.c, f) for f in self._projects_fields),
- )
- query = query.select_from(n.join(p, n.c.id == p.c.namespace_id)).where(
+ query = self._projects_base_query().where(
  p.c.name == name, n.c.name == namespace_name
  )

@@ -842,37 +986,50 @@ class AbstractDBMetastore(AbstractMetastore):

  def get_project_by_id(self, project_id: int, conn=None) -> Project:
  """Gets a single project by id"""
- n = self._namespaces
  p = self._projects

- query = self._projects_select(
- *(getattr(n.c, f) for f in self._namespaces_fields),
- *(getattr(p.c, f) for f in self._projects_fields),
- )
- query = query.select_from(n.join(p, n.c.id == p.c.namespace_id)).where(
- p.c.id == project_id
- )
+ query = self._projects_base_query().where(p.c.id == project_id)

  rows = list(self.db.execute(query, conn=conn))
  if not rows:
  raise ProjectNotFoundError(f"Project with id {project_id} not found.")
  return self.project_class.parse(*rows[0])

- def list_projects(self, namespace_id: Optional[int], conn=None) -> list[Project]:
+ def count_projects(self, namespace_id: int | None = None) -> int:
+ p = self._projects
+
+ query = self._projects_base_query()
+ if namespace_id:
+ query = query.where(p.c.namespace_id == namespace_id)
+
+ query = select(f.count(1)).select_from(query.subquery())
+
+ return next(self.db.execute(query))[0]
+
+ def remove_project(self, project_id: int, conn=None) -> None:
+ num_datasets = self.count_datasets(project_id)
+ if num_datasets > 0:
+ raise ProjectDeleteNotAllowedError(
+ f"Project cannot be removed. It contains {num_datasets} dataset(s). "
+ "Please remove the dataset(s) first."
+ )
+
+ p = self._projects
+ with self.db.transaction():
+ self.db.execute(self._projects_delete().where(p.c.id == project_id))
+
+ def list_projects(
+ self, namespace_id: int | None = None, conn=None
+ ) -> list[Project]:
  """
  Gets a list of projects inside some namespace, or in all namespaces
  """
- n = self._namespaces
  p = self._projects

- query = self._projects_select(
- *(getattr(n.c, f) for f in self._namespaces_fields),
- *(getattr(p.c, f) for f in self._projects_fields),
- )
- query = query.select_from(n.join(p, n.c.id == p.c.namespace_id))
+ query = self._projects_base_query()

  if namespace_id:
- query = query.where(n.c.id == namespace_id)
+ query = query.where(p.c.namespace_id == namespace_id)

  rows = list(self.db.execute(query, conn=conn))

@@ -885,15 +1042,15 @@ class AbstractDBMetastore(AbstractMetastore):
  def create_dataset(
  self,
  name: str,
- project_id: Optional[int] = None,
+ project_id: int | None = None,
  status: int = DatasetStatus.CREATED,
- sources: Optional[list[str]] = None,
- feature_schema: Optional[dict] = None,
+ sources: list[str] | None = None,
+ feature_schema: dict | None = None,
  query_script: str = "",
- schema: Optional[dict[str, Any]] = None,
+ schema: dict[str, Any] | None = None,
  ignore_if_exists: bool = False,
- description: Optional[str] = None,
- attrs: Optional[list[str]] = None,
+ description: str | None = None,
+ attrs: list[str] | None = None,
  **kwargs, # TODO registered = True / False
  ) -> DatasetRecord:
  """Creates new dataset."""
@@ -933,20 +1090,20 @@ class AbstractDBMetastore(AbstractMetastore):
  version: str,
  status: int,
  sources: str = "",
- feature_schema: Optional[dict] = None,
+ feature_schema: dict | None = None,
  query_script: str = "",
  error_message: str = "",
  error_stack: str = "",
  script_output: str = "",
- created_at: Optional[datetime] = None,
- finished_at: Optional[datetime] = None,
- schema: Optional[dict[str, Any]] = None,
+ created_at: datetime | None = None,
+ finished_at: datetime | None = None,
+ schema: dict[str, Any] | None = None,
  ignore_if_exists: bool = False,
- num_objects: Optional[int] = None,
- size: Optional[int] = None,
- preview: Optional[list[dict]] = None,
- job_id: Optional[str] = None,
- uuid: Optional[str] = None,
+ num_objects: int | None = None,
+ size: int | None = None,
+ preview: list[dict] | None = None,
+ job_id: str | None = None,
+ uuid: str | None = None,
  conn=None,
  ) -> DatasetRecord:
  """Creates new dataset version."""
@@ -1024,7 +1181,7 @@ class AbstractDBMetastore(AbstractMetastore):
  dataset_values[field] = None
  else:
  values[field] = json.dumps(value)
- dataset_values[field] = DatasetRecord.parse_schema(value)
+ dataset_values[field] = parse_schema(value)
  elif field == "project_id":
  if not value:
  raise ValueError("Cannot set empty project_id for dataset")
@@ -1075,9 +1232,7 @@ class AbstractDBMetastore(AbstractMetastore):

  if field == "schema":
  values[field] = json.dumps(value) if value else None
- version_values[field] = (
- DatasetRecord.parse_schema(value) if value else None
- )
+ version_values[field] = parse_schema(value) if value else None
  elif field == "feature_schema":
  if value is None:
  values[field] = None
@@ -1092,7 +1247,7 @@ class AbstractDBMetastore(AbstractMetastore):
  f"Field '{field}' must be a list, got {type(value).__name__}"
  )
  else:
- values[field] = json.dumps(value, cls=JSONSerialize)
+ values[field] = json.dumps(value, serialize_bytes=True)
  version_values["_preview_data"] = value
  else:
  values[field] = value
@@ -1118,13 +1273,13 @@ class AbstractDBMetastore(AbstractMetastore):
  f"Dataset {dataset.name} does not have version {version}"
  )

- def _parse_dataset(self, rows) -> Optional[DatasetRecord]:
+ def _parse_dataset(self, rows) -> DatasetRecord | None:
  versions = [self.dataset_class.parse(*r) for r in rows]
  if not versions:
  return None
  return reduce(lambda ds, version: ds.merge_versions(version), versions)

- def _parse_list_dataset(self, rows) -> Optional[DatasetListRecord]:
+ def _parse_list_dataset(self, rows) -> DatasetListRecord | None:
  versions = [self.dataset_list_class.parse(*r) for r in rows]
  if not versions:
  return None
@@ -1187,9 +1342,8 @@ class AbstractDBMetastore(AbstractMetastore):
  )

  def list_datasets(
- self, project_id: Optional[int] = None
+ self, project_id: int | None = None
  ) -> Iterator["DatasetListRecord"]:
- """Lists all datasets."""
  d = self._datasets
  query = self._base_list_datasets_query().order_by(
  self._datasets.c.name, self._datasets_versions.c.version
@@ -1198,8 +1352,18 @@ class AbstractDBMetastore(AbstractMetastore):
  query = query.where(d.c.project_id == project_id)
  yield from self._parse_dataset_list(self.db.execute(query))

+ def count_datasets(self, project_id: int | None = None) -> int:
+ d = self._datasets
+ query = self._datasets_select()
+ if project_id:
+ query = query.where(d.c.project_id == project_id)
+
+ query = select(f.count(1)).select_from(query.subquery())
+
+ return next(self.db.execute(query))[0]
+
  def list_datasets_by_prefix(
- self, prefix: str, project_id: Optional[int] = None, conn=None
+ self, prefix: str, project_id: int | None = None, conn=None
  ) -> Iterator["DatasetListRecord"]:
  d = self._datasets
  query = self._base_list_datasets_query()
@@ -1211,8 +1375,8 @@ class AbstractDBMetastore(AbstractMetastore):
  def get_dataset(
  self,
  name: str, # normal, not full dataset name
- namespace_name: Optional[str] = None,
- project_name: Optional[str] = None,
+ namespace_name: str | None = None,
+ project_name: str | None = None,
  conn=None,
  ) -> DatasetRecord:
  """
@@ -1273,7 +1437,7 @@ class AbstractDBMetastore(AbstractMetastore):
  self,
  dataset: DatasetRecord,
  status: int,
- version: Optional[str] = None,
+ version: str | None = None,
  error_message="",
  error_stack="",
  script_output="",
@@ -1327,8 +1491,8 @@ class AbstractDBMetastore(AbstractMetastore):
  self,
  source_dataset: DatasetRecord,
  source_dataset_version: str,
- new_source_dataset: Optional[DatasetRecord] = None,
- new_source_dataset_version: Optional[str] = None,
+ new_source_dataset: DatasetRecord | None = None,
+ new_source_dataset_version: str | None = None,
  ) -> None:
  dd = self._datasets_dependencies

@@ -1358,9 +1522,21 @@ class AbstractDBMetastore(AbstractMetastore):
  Returns a list of columns to select in a query for fetching dataset dependencies
  """

+ @abstractmethod
+ def _dataset_dependency_nodes_select_columns(
+ self,
+ namespaces_subquery: "Subquery",
+ dependency_tree_cte: "CTE",
+ datasets_subquery: "Subquery",
+ ) -> list["ColumnElement"]:
+ """
+ Returns a list of columns to select in a query for fetching
+ dataset dependency nodes.
+ """
+
  def get_direct_dataset_dependencies(
  self, dataset: DatasetRecord, version: str
- ) -> list[Optional[DatasetDependency]]:
+ ) -> list[DatasetDependency | None]:
  n = self._namespaces
  p = self._projects
  d = self._datasets
@@ -1387,8 +1563,77 @@ class AbstractDBMetastore(AbstractMetastore):

  return [self.dependency_class.parse(*r) for r in self.db.execute(query)]

+ def get_dataset_dependency_nodes(
+ self, dataset_id: int, version_id: int, depth_limit: int = DEPTH_LIMIT_DEFAULT
+ ) -> list[DatasetDependencyNode | None]:
+ n = self._namespaces_select().subquery()
+ p = self._projects
+ d = self._datasets_select().subquery()
+ dd = self._datasets_dependencies
+ dv = self._datasets_versions
+
+ # Common dependency fields for CTE
+ dep_fields = [
+ dd.c.id,
+ dd.c.source_dataset_id,
+ dd.c.source_dataset_version_id,
+ dd.c.dataset_id,
+ dd.c.dataset_version_id,
+ ]
+
+ # Base case: direct dependencies
+ base_query = select(
+ *dep_fields,
+ literal(0).label("depth"),
+ ).where(
+ (dd.c.source_dataset_id == dataset_id)
+ & (dd.c.source_dataset_version_id == version_id)
+ )
+
+ cte = base_query.cte(name="dependency_tree", recursive=True)
+
+ # Recursive case: dependencies of dependencies
+ # Limit depth to 100 to prevent infinite loops in case of circular dependencies
+ recursive_query = (
+ select(
+ *dep_fields,
+ (cte.c.depth + 1).label("depth"),
+ )
+ .select_from(
+ cte.join(
+ dd,
+ (cte.c.dataset_id == dd.c.source_dataset_id)
+ & (cte.c.dataset_version_id == dd.c.source_dataset_version_id),
+ )
+ )
+ .where(cte.c.depth < depth_limit)
+ )
+
+ cte = cte.union(recursive_query)
+
+ # Fetch all with full details
+ select_cols = self._dataset_dependency_nodes_select_columns(
+ namespaces_subquery=n,
+ dependency_tree_cte=cte,
+ datasets_subquery=d,
+ )
+ final_query = self._datasets_dependencies_select(*select_cols).select_from(
+ # Use outer joins to handle cases where dependent datasets have been
+ # physically deleted. This allows us to return dependency records with
+ # None values instead of silently omitting them, making broken
+ # dependencies visible to callers.
+ cte.join(d, cte.c.dataset_id == d.c.id, isouter=True)
+ .join(dv, cte.c.dataset_version_id == dv.c.id, isouter=True)
+ .join(p, d.c.project_id == p.c.id, isouter=True)
+ .join(n, p.c.namespace_id == n.c.id, isouter=True)
+ )
+
+ return [
+ self.dependency_node_class.parse(*r) for r in self.db.execute(final_query)
+ ]
+
  def remove_dataset_dependencies(
- self, dataset: DatasetRecord, version: Optional[str] = None
+ self, dataset: DatasetRecord, version: str | None = None
  ) -> None:
  """
  When we remove dataset, we need to clean up it's dependencies as well
@@ -1407,7 +1652,7 @@ class AbstractDBMetastore(AbstractMetastore):
  self.db.execute(q)

  def remove_dataset_dependants(
- self, dataset: DatasetRecord, version: Optional[str] = None
+ self, dataset: DatasetRecord, version: str | None = None
  ) -> None:
  """
  When we remove dataset, we need to clear its references in other dataset
@@ -1458,11 +1703,13 @@ class AbstractDBMetastore(AbstractMetastore):
  Column("error_stack", Text, nullable=False, default=""),
  Column("params", JSON, nullable=False),
  Column("metrics", JSON, nullable=False),
+ Column("parent_job_id", Text, nullable=True),
+ Index("idx_jobs_parent_job_id", "parent_job_id"),
  ]

  @cached_property
  def _job_fields(self) -> list[str]:
- return [c.name for c in self._jobs_columns() if c.name] # type: ignore[attr-defined]
+ return [c.name for c in self._jobs_columns() if isinstance(c, Column)] # type: ignore[attr-defined]

  @cached_property
  def _jobs(self) -> "Table":
@@ -1496,6 +1743,18 @@ class AbstractDBMetastore(AbstractMetastore):
  query = self._jobs_query().where(self._jobs.c.id.in_(ids))
  yield from self._parse_jobs(self.db.execute(query, conn=conn))

+ def get_last_job_by_name(self, name: str, conn=None) -> "Job | None":
+ query = (
+ self._jobs_query()
+ .where(self._jobs.c.name == name)
+ .order_by(self._jobs.c.created_at.desc())
+ .limit(1)
+ )
+ results = list(self.db.execute(query, conn=conn))
+ if not results:
+ return None
+ return self._parse_job(results[0])
+
  def create_job(
  self,
  name: str,
@@ -1503,9 +1762,10 @@ class AbstractDBMetastore(AbstractMetastore):
  query_type: JobQueryType = JobQueryType.PYTHON,
  status: JobStatus = JobStatus.CREATED,
  workers: int = 1,
- python_version: Optional[str] = None,
- params: Optional[dict[str, str]] = None,
- conn: Optional[Any] = None,
+ python_version: str | None = None,
+ params: dict[str, str] | None = None,
+ parent_job_id: str | None = None,
+ conn: Any = None,
  ) -> str:
  """
  Creates a new job.
@@ -1526,12 +1786,13 @@ class AbstractDBMetastore(AbstractMetastore):
  error_stack="",
  params=json.dumps(params or {}),
  metrics=json.dumps({}),
+ parent_job_id=parent_job_id,
  ),
  conn=conn,
  )
  return job_id

- def get_job(self, job_id: str, conn=None) -> Optional[Job]:
+ def get_job(self, job_id: str, conn=None) -> Job | None:
  """Returns the job with the given ID."""
  query = self._jobs_select(self._jobs).where(self._jobs.c.id == job_id)
  results = list(self.db.execute(query, conn=conn))
@@ -1542,13 +1803,13 @@ class AbstractDBMetastore(AbstractMetastore):
  def update_job(
  self,
  job_id: str,
- status: Optional[JobStatus] = None,
- error_message: Optional[str] = None,
- error_stack: Optional[str] = None,
- finished_at: Optional[datetime] = None,
- metrics: Optional[dict[str, Any]] = None,
- conn: Optional[Any] = None,
- ) -> Optional["Job"]:
+ status: JobStatus | None = None,
+ error_message: str | None = None,
+ error_stack: str | None = None,
+ finished_at: datetime | None = None,
+ metrics: dict[str, Any] | None = None,
+ conn: Any | None = None,
+ ) -> Job | None:
  """Updates job fields."""
  values: dict = {}
  if status is not None:
@@ -1575,9 +1836,9 @@ class AbstractDBMetastore(AbstractMetastore):
  self,
  job_id: str,
  status: JobStatus,
- error_message: Optional[str] = None,
- error_stack: Optional[str] = None,
- conn: Optional[Any] = None,
+ error_message: str | None = None,
+ error_stack: str | None = None,
+ conn: Any | None = None,
  ) -> None:
  """Set the status of the given job."""
  values: dict = {"status": status}
@@ -1595,8 +1856,8 @@ class AbstractDBMetastore(AbstractMetastore):
  def get_job_status(
  self,
  job_id: str,
- conn: Optional[Any] = None,
- ) -> Optional[JobStatus]:
+ conn: Any | None = None,
+ ) -> JobStatus | None:
  """Returns the status of the given job."""
  results = list(
  self.db.execute(
@@ -1607,3 +1868,321 @@ class AbstractDBMetastore(AbstractMetastore):
  if not results:
  return None
  return results[0][0]
+
+ #
+ # Checkpoints
+ #
+
+ @staticmethod
+ def _checkpoints_columns() -> "list[SchemaItem]":
+ return [
+ Column(
+ "id",
+ Text,
+ default=uuid4,
+ primary_key=True,
+ nullable=False,
+ ),
+ Column("job_id", Text, nullable=True),
+ Column("hash", Text, nullable=False),
+ Column("partial", Boolean, default=False),
+ Column("created_at", DateTime(timezone=True), nullable=False),
+ UniqueConstraint("job_id", "hash"),
+ ]
+
+ @cached_property
+ def _checkpoints_fields(self) -> list[str]:
+ return [c.name for c in self._checkpoints_columns() if c.name] # type: ignore[attr-defined]
+
+ @cached_property
+ def _checkpoints(self) -> "Table":
+ return Table(
+ self.CHECKPOINTS_TABLE,
+ self.db.metadata,
+ *self._checkpoints_columns(),
+ )
+
+ @abstractmethod
+ def _checkpoints_insert(self) -> "Insert": ...
+
+ @classmethod
+ def _dataset_version_jobs_columns(cls) -> "list[SchemaItem]":
+ """Junction table for dataset versions and jobs many-to-many relationship."""
+ return [
+ Column("id", Integer, primary_key=True),
+ Column(
+ "dataset_version_id",
+ Integer,
+ ForeignKey(f"{cls.DATASET_VERSION_TABLE}.id", ondelete="CASCADE"),
+ nullable=False,
+ ),
+ Column("job_id", Text, nullable=False),
+ Column("is_creator", Boolean, nullable=False, default=False),
+ Column("created_at", DateTime(timezone=True)),
+ UniqueConstraint("dataset_version_id", "job_id"),
+ Index("dc_idx_dvj_query", "job_id", "is_creator", "created_at"),
+ ]
+
+ @cached_property
+ def _dataset_version_jobs_fields(self) -> list[str]:
+ return [c.name for c in self._dataset_version_jobs_columns() if c.name] # type: ignore[attr-defined]
+
+ @cached_property
+ def _dataset_version_jobs(self) -> "Table":
+ return Table(
+ self.DATASET_VERSION_JOBS_TABLE,
+ self.db.metadata,
+ *self._dataset_version_jobs_columns(),
+ )
+
+ @abstractmethod
+ def _dataset_version_jobs_insert(self) -> "Insert": ...
+
+ def _dataset_version_jobs_select(self, *columns) -> "Select":
+ if not columns:
+ return self._dataset_version_jobs.select()
+ return select(*columns)
+
+ def _dataset_version_jobs_delete(self) -> "Delete":
+ return self._dataset_version_jobs.delete()
+
+ def _checkpoints_select(self, *columns) -> "Select":
+ if not columns:
+ return self._checkpoints.select()
+ return select(*columns)
+
+ def _checkpoints_delete(self) -> "Delete":
+ return self._checkpoints.delete()
+
+ def _checkpoints_query(self):
+ return self._checkpoints_select(
+ *[getattr(self._checkpoints.c, f) for f in self._checkpoints_fields]
+ )
+
+ def create_checkpoint(
+ self,
+ job_id: str,
+ _hash: str,
+ partial: bool = False,
+ conn: Any | None = None,
+ ) -> Checkpoint:
+ """
+ Creates a new job query step.
+ """
+ checkpoint_id = str(uuid4())
+ self.db.execute(
+ self._checkpoints_insert().values(
+ id=checkpoint_id,
+ job_id=job_id,
+ hash=_hash,
+ partial=partial,
+ created_at=datetime.now(timezone.utc),
+ ),
+ conn=conn,
+ )
+ return self.get_checkpoint_by_id(checkpoint_id)
+
+ def list_checkpoints(self, job_id: str, conn=None) -> Iterator[Checkpoint]:
+ """List checkpoints by job id."""
+ query = self._checkpoints_query().where(self._checkpoints.c.job_id == job_id)
+ rows = list(self.db.execute(query, conn=conn))
+
+ yield from [self.checkpoint_class.parse(*r) for r in rows]
+
+ def get_checkpoint_by_id(self, checkpoint_id: str, conn=None) -> Checkpoint:
+ """Returns the checkpoint with the given ID."""
+ ch = self._checkpoints
+ query = self._checkpoints_select(ch).where(ch.c.id == checkpoint_id)
+ rows = list(self.db.execute(query, conn=conn))
+ if not rows:
+ raise CheckpointNotFoundError(f"Checkpoint {checkpoint_id} not found")
+ return self.checkpoint_class.parse(*rows[0])
+
+ def find_checkpoint(
+ self, job_id: str, _hash: str, partial: bool = False, conn=None
+ ) -> Checkpoint | None:
+ """
+ Tries to find checkpoint for a job with specific hash and optionally partial
+ """
+ ch = self._checkpoints
+ query = self._checkpoints_select(ch).where(
+ ch.c.job_id == job_id, ch.c.hash == _hash, ch.c.partial == partial
+ )
+ rows = list(self.db.execute(query, conn=conn))
+ if not rows:
+ return None
+ return self.checkpoint_class.parse(*rows[0])
+
+ def get_last_checkpoint(self, job_id: str, conn=None) -> Checkpoint | None:
+ query = (
+ self._checkpoints_query()
+ .where(self._checkpoints.c.job_id == job_id)
+ .order_by(desc(self._checkpoints.c.created_at))
+ .limit(1)
+ )
+ rows = list(self.db.execute(query, conn=conn))
+ if not rows:
+ return None
+ return self.checkpoint_class.parse(*rows[0])
+
+ def link_dataset_version_to_job(
+ self,
+ dataset_version_id: int,
+ job_id: str,
+ is_creator: bool = False,
+ conn=None,
+ ) -> None:
+ # Use transaction to atomically:
+ # 1. Link dataset version to job in junction table
+ # 2. Update dataset_version.job_id to point to this job
+ with self.db.transaction() as tx_conn:
+ conn = conn or tx_conn
+
+ # Insert into junction table
+ query = self._dataset_version_jobs_insert().values(
+ dataset_version_id=dataset_version_id,
+ job_id=job_id,
+ is_creator=is_creator,
+ created_at=datetime.now(timezone.utc),
+ )
+ if hasattr(query, "on_conflict_do_nothing"):
+ query = query.on_conflict_do_nothing(
+ index_elements=["dataset_version_id", "job_id"]
+ )
+ self.db.execute(query, conn=conn)
+
+ # Also update dataset_version.job_id to point to this job
+ update_query = (
+ self._datasets_versions.update()
+ .where(self._datasets_versions.c.id == dataset_version_id)
+ .values(job_id=job_id)
+ )
+ self.db.execute(update_query, conn=conn)
+
+ def get_ancestor_job_ids(self, job_id: str, conn=None) -> list[str]:
+ # Use recursive CTE to walk up the parent chain
+ # Format: WITH RECURSIVE ancestors(id, parent_job_id, depth) AS (...)
+ # Include depth tracking to prevent infinite recursion in case of
+ # circular dependencies
+ ancestors_cte = (
+ self._jobs_select(
+ self._jobs.c.id.label("id"),
+ self._jobs.c.parent_job_id.label("parent_job_id"),
+ literal(0).label("depth"),
+ )
+ .where(self._jobs.c.id == job_id)
+ .cte(name="ancestors", recursive=True)
+ )
+
+ # Recursive part: join with parent jobs, incrementing depth and checking limit
+ ancestors_recursive = ancestors_cte.union_all(
+ self._jobs_select(
+ self._jobs.c.id.label("id"),
+ self._jobs.c.parent_job_id.label("parent_job_id"),
+ (ancestors_cte.c.depth + 1).label("depth"),
+ ).select_from(
+ self._jobs.join(
+ ancestors_cte,
+ (
+ self._jobs.c.id
+ == cast(ancestors_cte.c.parent_job_id, self._jobs.c.id.type)
+ )
+ & (ancestors_cte.c.parent_job_id.isnot(None)) # Stop at root jobs
+ & (ancestors_cte.c.depth < JOB_ANCESTRY_MAX_DEPTH),
+ )
+ )
+ )
+
+ # Select all ancestor IDs and depths except the starting job itself
+ query = select(ancestors_recursive.c.id, ancestors_recursive.c.depth).where(
+ ancestors_recursive.c.id != job_id
+ )
+
+ results = list(self.db.execute(query, conn=conn))
+
+ # Check if we hit the depth limit
+ if results:
+ max_found_depth = max(row[1] for row in results)
+ if max_found_depth >= JOB_ANCESTRY_MAX_DEPTH:
+ from datachain.error import JobAncestryDepthExceededError
+
+ raise JobAncestryDepthExceededError(
+ f"Job ancestry chain exceeds maximum depth of "
+ f"{JOB_ANCESTRY_MAX_DEPTH}. Job ID: {job_id}"
+ )
+
+ return [str(row[0]) for row in results]
+
+ def _get_dataset_version_for_job_ancestry_query(
+ self,
+ dataset_name: str,
+ namespace_name: str,
+ project_name: str,
+ job_ancestry: list[str],
+ ) -> "Select":
+ """Find most recent dataset version created by any job in ancestry.
+
+ Searches job ancestry (current + parents) for the newest version of
+ the dataset where is_creator=True. Returns newest by created_at, or
+ None if no version was created by any job in the ancestry chain.
+
+ Used for checkpoint resolution to find which version to reuse when
+ continuing from a parent job.
+ """
+ return (
+ self._datasets_versions_select()
+ .select_from(
+ self._dataset_version_jobs.join(
+ self._datasets_versions,
+ self._dataset_version_jobs.c.dataset_version_id
+ == self._datasets_versions.c.id,
+ )
+ .join(
+ self._datasets,
+ self._datasets_versions.c.dataset_id == self._datasets.c.id,
+ )
+ .join(
+ self._projects,
+ self._datasets.c.project_id == self._projects.c.id,
+ )
+ .join(
+ self._namespaces,
+ self._projects.c.namespace_id == self._namespaces.c.id,
+ )
+ )
+ .where(
+ self._datasets.c.name == dataset_name,
+ self._namespaces.c.name == namespace_name,
+ self._projects.c.name == project_name,
+ self._dataset_version_jobs.c.job_id.in_(job_ancestry),
+ self._dataset_version_jobs.c.is_creator.is_(True),
+ )
+ .order_by(desc(self._dataset_version_jobs.c.created_at))
+ .limit(1)
+ )
+
+ def get_dataset_version_for_job_ancestry(
+ self,
+ dataset_name: str,
+ namespace_name: str,
+ project_name: str,
+ job_id: str,
+ conn=None,
+ ) -> DatasetVersion | None:
+ # Get job ancestry (current job + all ancestors)
+ job_ancestry = [job_id, *self.get_ancestor_job_ids(job_id, conn=conn)]
+
+ query = self._get_dataset_version_for_job_ancestry_query(
+ dataset_name, namespace_name, project_name, job_ancestry
+ )
+
+ results = list(self.db.execute(query, conn=conn))
+ if not results:
+ return None
+
+ if len(results) > 1:
+ raise DataChainError(
+ f"Expected at most 1 dataset version, found {len(results)}"
+ )
+
+ return self.dataset_version_class.parse(*results[0])
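For orientation, the sketch below (not part of the released package or of this diff) shows how the job-ancestry and checkpoint methods added to AbstractDBMetastore above might be called. It assumes `metastore` is some concrete AbstractDBMetastore implementation, that `parent_job_id`, `child_job_id` and `dataset_version_id` were obtained elsewhere (for example from create_job, which now accepts a parent_job_id), and that the dataset, namespace and project names are placeholders; only method names and signatures visible in the diff are used.

# Hypothetical usage sketch; "metastore", the job ids and dataset_version_id
# are assumed to already exist.

# Walk the parent chain (a recursive CTE, capped at JOB_ANCESTRY_MAX_DEPTH).
assert metastore.get_ancestor_job_ids(child_job_id) == [parent_job_id]

# Checkpoints are keyed by (job_id, hash); "partial" marks unfinished steps.
metastore.create_checkpoint(child_job_id, "step-1-hash")
assert metastore.find_checkpoint(child_job_id, "step-1-hash", partial=False) is not None
last = metastore.get_last_checkpoint(child_job_id)  # newest by created_at

# Record which job produced a dataset version (junction table + job_id update).
metastore.link_dataset_version_to_job(dataset_version_id, child_job_id, is_creator=True)

# On a re-run, find the newest version of the dataset created anywhere in the
# job ancestry (returns None if no job in the chain created one).
version = metastore.get_dataset_version_for_job_ancestry(
    "my_dataset", "my_namespace", "my_project", child_job_id
)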