datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. datachain/__init__.py +20 -0
  2. datachain/asyn.py +11 -12
  3. datachain/cache.py +7 -7
  4. datachain/catalog/__init__.py +2 -2
  5. datachain/catalog/catalog.py +621 -507
  6. datachain/catalog/dependency.py +164 -0
  7. datachain/catalog/loader.py +28 -18
  8. datachain/checkpoint.py +43 -0
  9. datachain/cli/__init__.py +24 -33
  10. datachain/cli/commands/__init__.py +1 -8
  11. datachain/cli/commands/datasets.py +83 -52
  12. datachain/cli/commands/ls.py +17 -17
  13. datachain/cli/commands/show.py +4 -4
  14. datachain/cli/parser/__init__.py +8 -74
  15. datachain/cli/parser/job.py +95 -3
  16. datachain/cli/parser/studio.py +11 -4
  17. datachain/cli/parser/utils.py +1 -2
  18. datachain/cli/utils.py +2 -15
  19. datachain/client/azure.py +4 -4
  20. datachain/client/fsspec.py +45 -28
  21. datachain/client/gcs.py +6 -6
  22. datachain/client/hf.py +29 -2
  23. datachain/client/http.py +157 -0
  24. datachain/client/local.py +15 -11
  25. datachain/client/s3.py +17 -9
  26. datachain/config.py +4 -8
  27. datachain/data_storage/db_engine.py +12 -6
  28. datachain/data_storage/job.py +5 -1
  29. datachain/data_storage/metastore.py +1252 -186
  30. datachain/data_storage/schema.py +58 -45
  31. datachain/data_storage/serializer.py +105 -15
  32. datachain/data_storage/sqlite.py +286 -127
  33. datachain/data_storage/warehouse.py +250 -113
  34. datachain/dataset.py +353 -148
  35. datachain/delta.py +391 -0
  36. datachain/diff/__init__.py +27 -29
  37. datachain/error.py +60 -0
  38. datachain/func/__init__.py +2 -1
  39. datachain/func/aggregate.py +66 -42
  40. datachain/func/array.py +242 -38
  41. datachain/func/base.py +7 -4
  42. datachain/func/conditional.py +110 -60
  43. datachain/func/func.py +96 -45
  44. datachain/func/numeric.py +55 -38
  45. datachain/func/path.py +32 -20
  46. datachain/func/random.py +2 -2
  47. datachain/func/string.py +67 -37
  48. datachain/func/window.py +7 -8
  49. datachain/hash_utils.py +123 -0
  50. datachain/job.py +11 -7
  51. datachain/json.py +138 -0
  52. datachain/lib/arrow.py +58 -22
  53. datachain/lib/audio.py +245 -0
  54. datachain/lib/clip.py +14 -13
  55. datachain/lib/convert/flatten.py +5 -3
  56. datachain/lib/convert/python_to_sql.py +6 -10
  57. datachain/lib/convert/sql_to_python.py +8 -0
  58. datachain/lib/convert/values_to_tuples.py +156 -51
  59. datachain/lib/data_model.py +42 -20
  60. datachain/lib/dataset_info.py +36 -8
  61. datachain/lib/dc/__init__.py +8 -2
  62. datachain/lib/dc/csv.py +25 -28
  63. datachain/lib/dc/database.py +398 -0
  64. datachain/lib/dc/datachain.py +1289 -425
  65. datachain/lib/dc/datasets.py +320 -38
  66. datachain/lib/dc/hf.py +38 -24
  67. datachain/lib/dc/json.py +29 -32
  68. datachain/lib/dc/listings.py +112 -8
  69. datachain/lib/dc/pandas.py +16 -12
  70. datachain/lib/dc/parquet.py +35 -23
  71. datachain/lib/dc/records.py +31 -23
  72. datachain/lib/dc/storage.py +154 -64
  73. datachain/lib/dc/storage_pattern.py +251 -0
  74. datachain/lib/dc/utils.py +24 -16
  75. datachain/lib/dc/values.py +8 -9
  76. datachain/lib/file.py +622 -89
  77. datachain/lib/hf.py +69 -39
  78. datachain/lib/image.py +14 -14
  79. datachain/lib/listing.py +14 -11
  80. datachain/lib/listing_info.py +1 -2
  81. datachain/lib/meta_formats.py +3 -4
  82. datachain/lib/model_store.py +39 -7
  83. datachain/lib/namespaces.py +125 -0
  84. datachain/lib/projects.py +130 -0
  85. datachain/lib/pytorch.py +32 -21
  86. datachain/lib/settings.py +192 -56
  87. datachain/lib/signal_schema.py +427 -104
  88. datachain/lib/tar.py +1 -2
  89. datachain/lib/text.py +8 -7
  90. datachain/lib/udf.py +164 -76
  91. datachain/lib/udf_signature.py +60 -35
  92. datachain/lib/utils.py +118 -4
  93. datachain/lib/video.py +17 -9
  94. datachain/lib/webdataset.py +61 -56
  95. datachain/lib/webdataset_laion.py +15 -16
  96. datachain/listing.py +22 -10
  97. datachain/model/bbox.py +3 -1
  98. datachain/model/ultralytics/bbox.py +16 -12
  99. datachain/model/ultralytics/pose.py +16 -12
  100. datachain/model/ultralytics/segment.py +16 -12
  101. datachain/namespace.py +84 -0
  102. datachain/node.py +6 -6
  103. datachain/nodes_thread_pool.py +0 -1
  104. datachain/plugins.py +24 -0
  105. datachain/project.py +78 -0
  106. datachain/query/batch.py +40 -41
  107. datachain/query/dataset.py +604 -322
  108. datachain/query/dispatch.py +261 -154
  109. datachain/query/metrics.py +4 -6
  110. datachain/query/params.py +2 -3
  111. datachain/query/queue.py +3 -12
  112. datachain/query/schema.py +11 -6
  113. datachain/query/session.py +200 -33
  114. datachain/query/udf.py +34 -2
  115. datachain/remote/studio.py +171 -69
  116. datachain/script_meta.py +12 -12
  117. datachain/semver.py +68 -0
  118. datachain/sql/__init__.py +2 -0
  119. datachain/sql/functions/array.py +33 -1
  120. datachain/sql/postgresql_dialect.py +9 -0
  121. datachain/sql/postgresql_types.py +21 -0
  122. datachain/sql/sqlite/__init__.py +5 -1
  123. datachain/sql/sqlite/base.py +102 -29
  124. datachain/sql/sqlite/types.py +8 -13
  125. datachain/sql/types.py +70 -15
  126. datachain/studio.py +223 -46
  127. datachain/toolkit/split.py +31 -10
  128. datachain/utils.py +101 -59
  129. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
  130. datachain-0.39.0.dist-info/RECORD +173 -0
  131. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
  132. datachain/cli/commands/query.py +0 -53
  133. datachain/query/utils.py +0 -42
  134. datachain-0.14.2.dist-info/RECORD +0 -158
  135. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
  136. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
  137. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/data_storage/sqlite.py

@@ -1,25 +1,29 @@
 import logging
 import os
 import sqlite3
-from collections.abc import Iterable, Sequence
+from collections.abc import Callable, Iterable, Sequence
 from contextlib import contextmanager
-from functools import wraps
+from functools import cached_property, wraps
 from time import sleep
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Callable,
-    ClassVar,
-    Optional,
-    Union,
-)
+from typing import TYPE_CHECKING, Any, ClassVar, Union
 
 import sqlalchemy
-from sqlalchemy import MetaData, Table, UniqueConstraint, exists, select
+from sqlalchemy import (
+    Column,
+    Integer,
+    MetaData,
+    Table,
+    UniqueConstraint,
+    exists,
+    select,
+)
 from sqlalchemy.dialects import sqlite
 from sqlalchemy.schema import CreateIndex, CreateTable, DropTable
 from sqlalchemy.sql import func
-from sqlalchemy.sql.elements import BinaryExpression, BooleanClauseList
+from sqlalchemy.sql.elements import (
+    BinaryExpression,
+    BooleanClauseList,
+)
 from sqlalchemy.sql.expression import bindparam, cast
 from sqlalchemy.sql.selectable import Select
 from tqdm.auto import tqdm
@@ -28,14 +32,18 @@ import datachain.sql.sqlite
 from datachain.data_storage import AbstractDBMetastore, AbstractWarehouse
 from datachain.data_storage.db_engine import DatabaseEngine
 from datachain.data_storage.schema import DefaultSchema
+from datachain.data_storage.warehouse import INSERT_BATCH_SIZE
 from datachain.dataset import DatasetRecord, StorageURI
-from datachain.error import DataChainError
+from datachain.error import DataChainError, OutdatedDatabaseSchemaError
+from datachain.namespace import Namespace
+from datachain.project import Project
 from datachain.sql.sqlite import create_user_defined_sql_functions, sqlite_dialect
 from datachain.sql.sqlite.base import load_usearch_extension
 from datachain.sql.types import SQLType
-from datachain.utils import DataChainDir, batched_it
+from datachain.utils import DataChainDir, batched, batched_it
 
 if TYPE_CHECKING:
+    from sqlalchemy import CTE, Subquery
     from sqlalchemy.dialects.sqlite import Insert
     from sqlalchemy.engine.base import Engine
     from sqlalchemy.schema import SchemaItem
@@ -59,6 +67,14 @@ datachain.sql.sqlite.setup()
 quote_schema = sqlite_dialect.identifier_preparer.quote_schema
 quote = sqlite_dialect.identifier_preparer.quote
 
+# NOTE! This should be manually increased when we change our DB schema in codebase
+SCHEMA_VERSION = 1
+
+OUTDATED_SCHEMA_ERROR_MESSAGE = (
+    "You have an old version of the database schema. Please refer to the documentation"
+    " for more information."
+)
+
 
 def _get_in_memory_uri():
     return "file::memory:?cache=shared"
@@ -85,8 +101,8 @@ def retry_sqlite_locks(func):
 
 
 def get_db_file_in_memory(
-    db_file: Optional[str] = None, in_memory: bool = False
-) -> Optional[str]:
+    db_file: str | None = None, in_memory: bool = False
+) -> str | None:
    """Get in-memory db_file and check that conflicting arguments are not provided."""
    if in_memory:
        if db_file and db_file != ":memory:":
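One change recurs through the rest of the file and is worth noting once: every `Optional[X]` annotation is rewritten as `X | None`, the PEP 604 union syntax. The two spellings describe the same type; a minimal illustration with placeholder names:

```python
from typing import Optional

# Equivalent annotations; the second form needs Python 3.10+
# (or `from __future__ import annotations`).
def old_style(db_file: Optional[str] = None) -> Optional[str]: ...
def new_style(db_file: str | None = None) -> str | None: ...
```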
@@ -99,7 +115,7 @@ class SQLiteDatabaseEngine(DatabaseEngine):
     dialect = sqlite_dialect
 
     db: sqlite3.Connection
-    db_file: Optional[str]
+    db_file: str | None
     is_closed: bool
 
     def __init__(
@@ -107,22 +123,24 @@ class SQLiteDatabaseEngine(DatabaseEngine):
         engine: "Engine",
         metadata: "MetaData",
         db: sqlite3.Connection,
-        db_file: Optional[str] = None,
+        db_file: str | None = None,
+        max_variable_number: int | None = 999,
     ):
         self.engine = engine
         self.metadata = metadata
         self.db = db
         self.db_file = db_file
         self.is_closed = False
+        self.max_variable_number = max_variable_number
 
     @classmethod
-    def from_db_file(cls, db_file: Optional[str] = None) -> "SQLiteDatabaseEngine":
+    def from_db_file(cls, db_file: str | None = None) -> "SQLiteDatabaseEngine":
         return cls(*cls._connect(db_file=db_file))
 
     @staticmethod
     def _connect(
-        db_file: Optional[str] = None,
-    ) -> tuple["Engine", "MetaData", sqlite3.Connection, str]:
+        db_file: str | None = None,
+    ) -> tuple["Engine", "MetaData", sqlite3.Connection, str, int]:
         try:
             if db_file == ":memory:":
                 # Enable multithreaded usage of the same in-memory db
@@ -149,6 +167,13 @@ class SQLiteDatabaseEngine(DatabaseEngine):
             db.execute("PRAGMA journal_mode = WAL")
             db.execute("PRAGMA synchronous = NORMAL")
             db.execute("PRAGMA case_sensitive_like = ON")
+
+            max_variable_number = 999  # minimum in old SQLite versions
+            for row in db.execute("PRAGMA compile_options;").fetchall():
+                option = row[0]
+                if option.startswith("MAX_VARIABLE_NUMBER="):
+                    max_variable_number = int(option.split("=")[1])
+
             if os.environ.get("DEBUG_SHOW_SQL_QUERIES"):
                 import sys
 
@@ -156,7 +181,7 @@ class SQLiteDatabaseEngine(DatabaseEngine):
 
             load_usearch_extension(db)
 
-            return engine, MetaData(), db, db_file
+            return engine, MetaData(), db, db_file, max_variable_number
         except RuntimeError:
             raise DataChainError("Can't connect to SQLite DB") from None
 
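The `_connect` changes above probe the SQLite build for its bind-variable limit. A minimal standalone sketch of the same probe, using only the stdlib: `PRAGMA compile_options` lists one option per row, and `MAX_VARIABLE_NUMBER=...` appears only when it was set at compile time, hence the fallback to 999 (the default limit before SQLite 3.32).

```python
import sqlite3

db = sqlite3.connect(":memory:")
max_variable_number = 999  # default limit in older SQLite builds
for (option,) in db.execute("PRAGMA compile_options;"):
    if option.startswith("MAX_VARIABLE_NUMBER="):
        max_variable_number = int(option.split("=")[1])
print(max_variable_number)
```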
@@ -172,18 +197,25 @@ class SQLiteDatabaseEngine(DatabaseEngine):
         """
         return (
             SQLiteDatabaseEngine.from_db_file,
-            [self.db_file],
+            [str(self.db_file)],
             {},
         )
 
+    @classmethod
+    def serialize_callable_name(cls) -> str:
+        return "sqlite.from_db_file"
+
     def _reconnect(self) -> None:
         if not self.is_closed:
             raise RuntimeError("Cannot reconnect on still-open DB!")
-        engine, metadata, db, db_file = self._connect(db_file=self.db_file)
+        engine, metadata, db, db_file, max_variable_number = self._connect(
+            db_file=self.db_file
+        )
         self.engine = engine
         self.metadata = metadata
         self.db = db
         self.db_file = db_file
+        self.max_variable_number = max_variable_number
         self.is_closed = False
 
     def get_table(self, name: str) -> Table:
@@ -196,7 +228,7 @@ class SQLiteDatabaseEngine(DatabaseEngine):
     def execute(
         self,
         query,
-        cursor: Optional[sqlite3.Cursor] = None,
+        cursor: sqlite3.Cursor | None = None,
         conn=None,
     ) -> sqlite3.Cursor:
         if self.is_closed:
@@ -215,7 +247,7 @@ class SQLiteDatabaseEngine(DatabaseEngine):
 
     @retry_sqlite_locks
     def executemany(
-        self, query, params, cursor: Optional[sqlite3.Cursor] = None, conn=None
+        self, query, params, cursor: sqlite3.Cursor | None = None, conn=None
     ) -> sqlite3.Cursor:
         if cursor:
             return cursor.executemany(self.compile(query).string, params)
@@ -230,13 +262,27 @@ class SQLiteDatabaseEngine(DatabaseEngine):
         return self.db.execute(sql, parameters)
 
     def insert_dataframe(self, table_name: str, df) -> int:
+        # Dynamically calculate the chunksize by dividing the max variable limit
+        # for a single SQL insert by the number of columns in the dataframe.
+        # This avoids: sqlite3.OperationalError: too many SQL variables
+        num_columns = df.shape[1]
+        if num_columns == 0:
+            num_columns = 1
+
+        if self.max_variable_number < num_columns:
+            raise RuntimeError(
+                "Number of columns exceeds DB maximum variables when inserting data"
+            )
+
+        chunksize = self.max_variable_number // num_columns
+
         return df.to_sql(
             table_name,
             self.db,
             if_exists="append",
             index=False,
             method="multi",
-            chunksize=1000,
+            chunksize=chunksize,
         )
 
     def cursor(self, factory=None):
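To see why the fixed `chunksize=1000` had to go: pandas' `method="multi"` emits one bind variable per cell, so 1000 rows of a 12-column dataframe would need 12,000 variables, well past a 999-variable build. The new calculation caps rows per statement accordingly; a worked example of the arithmetic:

```python
# chunksize = max_variable_number // num_columns
for max_vars, num_columns in [(999, 12), (250_000, 12)]:
    print(f"{max_vars} vars, {num_columns} cols -> {max_vars // num_columns} rows per INSERT")
# 999 vars, 12 cols -> 83 rows per INSERT
# 250000 vars, 12 cols -> 20833 rows per INSERT
```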
@@ -245,6 +291,8 @@ class SQLiteDatabaseEngine(DatabaseEngine):
         return self.db.cursor(factory)
 
     def close(self) -> None:
+        if self.is_closed:
+            return
         self.db.close()
         self.is_closed = True
 
@@ -276,7 +324,18 @@ class SQLiteDatabaseEngine(DatabaseEngine):
         )
         return bool(next(self.execute(query))[0])
 
-    def create_table(self, table: "Table", if_not_exists: bool = True) -> None:
+    @property
+    def table_names(self) -> list[str]:
+        query = "SELECT name FROM sqlite_master WHERE type='table';"
+        return [r[0] for r in self.execute_str(query).fetchall()]
+
+    def create_table(
+        self,
+        table: "Table",
+        if_not_exists: bool = True,
+        *,
+        kind: str | None = None,
+    ) -> None:
         self.execute(CreateTable(table, if_not_exists=if_not_exists))
 
     def drop_table(self, table: "Table", if_exists: bool = False) -> None:
@@ -294,13 +353,15 @@ class SQLiteMetastore(AbstractDBMetastore):
     This is currently used for the local cli.
     """
 
-    db: "SQLiteDatabaseEngine"
+    META_TABLE = "meta"
+
+    db: SQLiteDatabaseEngine
 
     def __init__(
         self,
-        uri: Optional[StorageURI] = None,
-        db: Optional["SQLiteDatabaseEngine"] = None,
-        db_file: Optional[str] = None,
+        uri: StorageURI | None = None,
+        db: SQLiteDatabaseEngine | None = None,
+        db_file: str | None = None,
         in_memory: bool = False,
     ):
         uri = uri or StorageURI("")
@@ -315,7 +376,12 @@ class SQLiteMetastore(AbstractDBMetastore):
 
         self.db = db or SQLiteDatabaseEngine.from_db_file(db_file)
 
-        self._init_tables()
+        with self._init_guard():
+            self._init_meta_table()
+            self._init_meta_schema_value()
+            self._check_schema_version()
+            self._init_tables()
+            self._init_namespaces_projects()
 
     def __exit__(self, exc_type, exc_value, traceback) -> None:
         """Close connection upon exit from context manager."""
@@ -323,7 +389,7 @@ class SQLiteMetastore(AbstractDBMetastore):
 
     def clone(
         self,
-        uri: Optional[StorageURI] = None,
+        uri: StorageURI | None = None,
         use_new_connection: bool = False,
     ) -> "SQLiteMetastore":
         uri = uri or StorageURI("")
@@ -346,6 +412,10 @@ class SQLiteMetastore(AbstractDBMetastore):
             },
         )
 
+    @classmethod
+    def serialize_callable_name(cls) -> str:
+        return "sqlite.metastore.init_after_clone"
+
     @classmethod
     def init_after_clone(
         cls,
@@ -356,8 +426,44 @@ class SQLiteMetastore(AbstractDBMetastore):
         (db_class, db_args, db_kwargs) = db_clone_params
         return cls(uri=uri, db=db_class(*db_args, **db_kwargs))
 
+    @cached_property
+    def _meta(self) -> Table:
+        return Table(self.META_TABLE, self.db.metadata, *self._meta_columns())
+
+    def _meta_select(self, *columns) -> "Select":
+        if not columns:
+            return self._meta.select()
+        return select(*columns)
+
+    def _meta_insert(self) -> "Insert":
+        return sqlite.insert(self._meta)
+
+    def _init_meta_table(self) -> None:
+        """Initializes meta table"""
+        # NOTE! needs to be called before _init_tables()
+        table_names = self.db.table_names
+        if table_names and self.META_TABLE not in table_names:
+            # existing tables but no meta table: the DB predates schema versioning
+            raise OutdatedDatabaseSchemaError(OUTDATED_SCHEMA_ERROR_MESSAGE)
+
+        self.db.create_table(self._meta, if_not_exists=True)
+        self.default_table_names.append(self._meta.name)
+
+    def _init_meta_schema_value(self) -> None:
+        """Inserts current schema version value if not present in meta table yet"""
+        stmt = (
+            self._meta_insert()
+            .values(id=1, schema_version=SCHEMA_VERSION)
+            .on_conflict_do_nothing(index_elements=["id"])
+        )
+        self.db.execute(stmt)
+
     def _init_tables(self) -> None:
         """Initialize tables."""
+        self.db.create_table(self._namespaces, if_not_exists=True)
+        self.default_table_names.append(self._namespaces.name)
+        self.db.create_table(self._projects, if_not_exists=True)
+        self.default_table_names.append(self._projects.name)
         self.db.create_table(self._datasets, if_not_exists=True)
         self.default_table_names.append(self._datasets.name)
         self.db.create_table(self._datasets_versions, if_not_exists=True)
@@ -366,11 +472,61 @@ class SQLiteMetastore(AbstractDBMetastore):
         self.default_table_names.append(self._datasets_dependencies.name)
         self.db.create_table(self._jobs, if_not_exists=True)
         self.default_table_names.append(self._jobs.name)
+        self.db.create_table(self._checkpoints, if_not_exists=True)
+        self.default_table_names.append(self._checkpoints.name)
+        self.db.create_table(self._dataset_version_jobs, if_not_exists=True)
+        self.default_table_names.append(self._dataset_version_jobs.name)
+
+    def _init_namespaces_projects(self) -> None:
+        """
+        Creates the local namespace and the local project connected to it.
+        In the local environment the user cannot explicitly create other
+        namespaces and projects; all datasets the user creates are stored there.
+        When pulling a dataset from Studio, other namespaces and projects are
+        created implicitly, to keep the same fully qualified name as in Studio.
+        """
+        system_namespace = self.create_namespace(
+            Namespace.system(), "System namespace", validate=False
+        )
+        self.create_project(
+            system_namespace.name, Project.listing(), "Listing project", validate=False
+        )
+
+    def _check_schema_version(self) -> None:
+        """
+        Checks if current DB schema is up to date with latest DB model and schema
+        version. If not, OutdatedDatabaseSchemaError is raised.
+        """
+        schema_version = next(self.db.execute(self._meta_select()))[1]
+        if schema_version < SCHEMA_VERSION:
+            raise OutdatedDatabaseSchemaError(OUTDATED_SCHEMA_ERROR_MESSAGE)
+
+    #
+    # Dataset dependencies
+    #
+    @classmethod
+    def _meta_columns(cls) -> list["SchemaItem"]:
+        return [
+            Column("id", Integer, primary_key=True),
+            Column("schema_version", Integer, default=SCHEMA_VERSION),
+        ]
 
     @classmethod
     def _datasets_columns(cls) -> list["SchemaItem"]:
         """Datasets table columns."""
-        return [*super()._datasets_columns(), UniqueConstraint("name")]
+        return [*super()._datasets_columns(), UniqueConstraint("project_id", "name")]
+
+    @classmethod
+    def _namespaces_columns(cls) -> list["SchemaItem"]:
+        """Namespaces table columns."""
+        return [*super()._namespaces_columns(), UniqueConstraint("name")]
+
+    def _namespaces_insert(self) -> "Insert":
+        return sqlite.insert(self._namespaces)
+
+    def _projects_insert(self) -> "Insert":
+        return sqlite.insert(self._projects)
 
     def _datasets_insert(self) -> "Insert":
         return sqlite.insert(self._datasets)
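The meta-table machinery added above is small enough to demonstrate end to end. A self-contained sketch of the same gate; the table and column names mirror the diff, the rest is illustrative:

```python
import sqlalchemy as sa
from sqlalchemy.dialects import sqlite

SCHEMA_VERSION = 1
metadata = sa.MetaData()
meta = sa.Table(
    "meta",
    metadata,
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column("schema_version", sa.Integer, default=SCHEMA_VERSION),
)

engine = sa.create_engine("sqlite://")
with engine.begin() as conn:
    metadata.create_all(conn)
    # The first run inserts the current version; later runs hit the conflict
    # and do nothing, so an older stored value survives to be detected below.
    conn.execute(
        sqlite.insert(meta)
        .values(id=1, schema_version=SCHEMA_VERSION)
        .on_conflict_do_nothing(index_elements=["id"])
    )
    stored = conn.execute(sa.select(meta.c.schema_version)).scalar_one()
    if stored < SCHEMA_VERSION:
        raise RuntimeError("outdated database schema")
```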
@@ -387,6 +543,8 @@ class SQLiteMetastore(AbstractDBMetastore):
 
     def _dataset_dependencies_select_columns(self) -> list["SchemaItem"]:
         return [
+            self._namespaces.c.name,
+            self._projects.c.name,
             self._datasets_dependencies.c.id,
             self._datasets_dependencies.c.dataset_id,
             self._datasets_dependencies.c.dataset_version_id,
@@ -395,6 +553,26 @@ class SQLiteMetastore(AbstractDBMetastore):
             self._datasets_versions.c.version,
             self._datasets_versions.c.created_at,
         ]
+
+    def _dataset_dependency_nodes_select_columns(
+        self,
+        namespaces_subquery: "Subquery",
+        dependency_tree_cte: "CTE",
+        datasets_subquery: "Subquery",
+    ) -> list["ColumnElement"]:
+        return [
+            namespaces_subquery.c.name,
+            self._projects.c.name,
+            dependency_tree_cte.c.id,
+            dependency_tree_cte.c.dataset_id,
+            dependency_tree_cte.c.dataset_version_id,
+            datasets_subquery.c.name,
+            self._datasets_versions.c.version,
+            self._datasets_versions.c.created_at,
+            dependency_tree_cte.c.source_dataset_id,
+            dependency_tree_cte.c.source_dataset_version_id,
+            dependency_tree_cte.c.depth,
+        ]
 
     #
     # Jobs
@@ -402,6 +580,31 @@ class SQLiteMetastore(AbstractDBMetastore):
     def _jobs_insert(self) -> "Insert":
         return sqlite.insert(self._jobs)
 
+    #
+    # Checkpoints
+    #
+    def _checkpoints_insert(self) -> "Insert":
+        return sqlite.insert(self._checkpoints)
+
+    def _dataset_version_jobs_insert(self) -> "Insert":
+        return sqlite.insert(self._dataset_version_jobs)
+
+    #
+    # Namespaces
+    #
+
+    @property
+    def default_namespace_name(self):
+        return Namespace.default()
+
+    #
+    # Projects
+    #
+
+    @property
+    def default_project_name(self):
+        return Project.default()
+
 
 class SQLiteWarehouse(AbstractWarehouse):
     """
@@ -409,15 +612,15 @@ class SQLiteWarehouse(AbstractWarehouse):
     This is currently used for the local cli.
     """
 
-    db: "SQLiteDatabaseEngine"
+    db: SQLiteDatabaseEngine
 
     # Cache for our defined column types to dialect specific TypeEngine relations
     _col_python_type: ClassVar[dict[type, "TypeEngine"]] = {}
 
     def __init__(
         self,
-        db: Optional["SQLiteDatabaseEngine"] = None,
-        db_file: Optional[str] = None,
+        db: SQLiteDatabaseEngine | None = None,
+        db_file: str | None = None,
         in_memory: bool = False,
     ):
         self.schema: DefaultSchema = DefaultSchema()
@@ -445,6 +648,10 @@ class SQLiteWarehouse(AbstractWarehouse):
             {"db_clone_params": self.db.clone_params()},
         )
 
+    @classmethod
+    def serialize_callable_name(cls) -> str:
+        return "sqlite.warehouse.init_after_clone"
+
     @classmethod
     def init_after_clone(
         cls,
@@ -468,7 +675,7 @@ class SQLiteWarehouse(AbstractWarehouse):
             only=filter_tables,
         )
 
-    def is_ready(self, timeout: Optional[int] = None) -> bool:
+    def is_ready(self, timeout: int | None = None) -> bool:
         return True
 
     def create_dataset_rows_table(
@@ -486,10 +693,10 @@ class SQLiteWarehouse(AbstractWarehouse):
         return table
 
     def get_dataset_sources(
-        self, dataset: DatasetRecord, version: int
+        self, dataset: DatasetRecord, version: str
     ) -> list[StorageURI]:
         dr = self.dataset_rows(dataset, version)
-        query = dr.select(dr.c("source", object_name="file")).distinct()
+        query = dr.select(dr.c("source", column="file")).distinct()
         cur = self.db.cursor()
         cur.row_factory = sqlite3.Row  # type: ignore[assignment]
 
@@ -498,79 +705,26 @@ class SQLiteWarehouse(AbstractWarehouse):
             for row in self.db.execute(query, cursor=cur)
         ]
 
-    def merge_dataset_rows(
-        self,
-        src: DatasetRecord,
-        dst: DatasetRecord,
-        src_version: int,
-        dst_version: int,
-    ) -> None:
-        dst_empty = False
-
-        if not self.db.has_table(self.dataset_table_name(src.name, src_version)):
-            # source table doesn't exist, nothing to do
-            return
-
-        src_dr = self.dataset_rows(src, src_version).table
-
-        if not self.db.has_table(self.dataset_table_name(dst.name, dst_version)):
-            # destination table doesn't exist, create it
-            self.create_dataset_rows_table(
-                self.dataset_table_name(dst.name, dst_version),
-                columns=src_dr.columns,
-            )
-            dst_empty = True
-
-        dst_dr = self.dataset_rows(dst, dst_version).table
-        merge_fields = [c.name for c in src_dr.columns if c.name != "sys__id"]
-        select_src = select(*(getattr(src_dr.columns, f) for f in merge_fields))
-
-        if dst_empty:
-            # we don't need union, but just select from source to destination
-            insert_query = sqlite.insert(dst_dr).from_select(merge_fields, select_src)
-        else:
-            dst_version_latest = None
-            # find the previous version of the destination dataset
-            dst_previous_versions = [
-                v.version
-                for v in dst.versions  # type: ignore [union-attr]
-                if v.version < dst_version
-            ]
-            if dst_previous_versions:
-                dst_version_latest = max(dst_previous_versions)
-
-            dst_dr_latest = self.dataset_rows(dst, dst_version_latest).table
-
-            select_dst_latest = select(
-                *(getattr(dst_dr_latest.c, f) for f in merge_fields)
-            )
-            union_query = sqlalchemy.union(select_src, select_dst_latest)
-            insert_query = (
-                sqlite.insert(dst_dr)
-                .from_select(merge_fields, union_query)
-                .prefix_with("OR IGNORE")
-            )
-
-        self.db.execute(insert_query)
-
     def prepare_entries(self, entries: "Iterable[File]") -> Iterable[dict[str, Any]]:
         return (e.model_dump() for e in entries)
 
-    def insert_rows(self, table: Table, rows: Iterable[dict[str, Any]]) -> None:
-        rows = list(rows)
-        if not rows:
-            return
-
-        with self.db.transaction() as conn:
-            # transactions speeds up inserts significantly as there is no separate
-            # transaction created for each insert row
-            self.db.executemany(
-                table.insert().values({f: bindparam(f) for f in rows[0]}),
-                rows,
-                conn=conn,
-            )
+    def insert_rows(
+        self,
+        table: Table,
+        rows: Iterable[dict[str, Any]],
+        batch_size: int = INSERT_BATCH_SIZE,
+    ) -> None:
+        for row_chunk in batched(rows, batch_size):
+            with self.db.transaction() as conn:
+                # transactions speed up inserts significantly as there is no
+                # separate transaction created for each inserted row
+                self.db.executemany(
+                    table.insert().values({f: bindparam(f) for f in row_chunk[0]}),
+                    row_chunk,
+                    conn=conn,
+                )
 
-    def insert_dataset_rows(self, df, dataset: DatasetRecord, version: int) -> int:
+    def insert_dataset_rows(self, df, dataset: DatasetRecord, version: str) -> int:
         dr = self.dataset_rows(dataset, version)
         return self.db.insert_dataframe(dr.table.name, df)
 
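`insert_rows` now streams its input in fixed-size chunks instead of materializing the whole list. `batched` comes from `datachain.utils` (see the import hunk above); judging by the `row_chunk[0]` indexing it behaves like `itertools.batched`, yielding indexable chunks. A stand-in sketch under that assumption:

```python
from collections.abc import Iterable, Iterator
from itertools import islice

def batched(iterable: Iterable, n: int) -> Iterator[list]:
    # Yield successive chunks of at most n items each.
    it = iter(iterable)
    while chunk := list(islice(it, n)):
        yield chunk

print(list(batched(range(7), 3)))  # [[0, 1, 2], [3, 4, 5], [6]]
```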
@@ -595,7 +749,7 @@ class SQLiteWarehouse(AbstractWarehouse):
         return col_type.python_type
 
     def dataset_table_export_file_names(
-        self, dataset: DatasetRecord, version: int
+        self, dataset: DatasetRecord, version: str
     ) -> list[str]:
         raise NotImplementedError("Exporting dataset table not implemented for SQLite")
 
@@ -603,7 +757,7 @@ class SQLiteWarehouse(AbstractWarehouse):
         self,
         bucket_uri: str,
         dataset: DatasetRecord,
-        version: int,
+        version: str,
         client_config=None,
     ) -> list[str]:
         raise NotImplementedError("Exporting dataset table not implemented for SQLite")
@@ -612,9 +766,17 @@ class SQLiteWarehouse(AbstractWarehouse):
         self,
         table: Table,
         query: Select,
-        progress_cb: Optional[Callable[[int], None]] = None,
+        progress_cb: Callable[[int], None] | None = None,
     ) -> None:
-        if len(query._group_by_clause) > 0:
+        col_id = (
+            query.selected_columns.sys__id
+            if "sys__id" in query.selected_columns
+            else None
+        )
+
+        # If there is no sys__id column, we cannot copy the table in batches,
+        # and we need to copy all rows at once. Same if there is a group by clause.
+        if col_id is None or len(query._group_by_clause) > 0:
             select_q = query.with_only_columns(
                 *[c for c in query.selected_columns if c.name != "sys__id"]
             )
@@ -622,12 +784,7 @@ class SQLiteWarehouse(AbstractWarehouse):
             self.db.execute(q)
             return
 
-        if "sys__id" in query.selected_columns:
-            col_id = query.selected_columns.sys__id
-        else:
-            col_id = sqlalchemy.column("sys__id")
         select_ids = query.with_only_columns(col_id)
-
         ids = self.db.execute(select_ids).fetchall()
 
         select_q = (
@@ -638,7 +795,7 @@ class SQLiteWarehouse(AbstractWarehouse):
             .limit(None)
         )
 
-        for batch in batched_it(ids, 10_000):
+        for batch in batched_it(ids, INSERT_BATCH_SIZE):
             batch_ids = [row[0] for row in batch]
             select_q._where_criteria = (col_id.in_(batch_ids),)
             q = table.insert().from_select(list(select_q.selected_columns), select_q)
@@ -693,18 +850,20 @@ class SQLiteWarehouse(AbstractWarehouse):
        if isinstance(c, BinaryExpression):
            right_left_join = add_left_rows_filter(c)
 
-        union = sqlalchemy.union(left_right_join, right_left_join).subquery()
-        return sqlalchemy.select(*union.c).select_from(union)
+        union_cte = sqlalchemy.union(left_right_join, right_left_join).cte()
+        return sqlalchemy.select(*union_cte.c).select_from(union_cte)
+
+    def _system_row_number_expr(self):
+        return func.row_number().over()
+
+    def _system_random_expr(self):
+        return self._system_row_number_expr() * 1103515245 + 12345
 
     def create_pre_udf_table(self, query: "Select") -> "Table":
         """
         Create a temporary table from a query for use in a UDF.
         """
-        columns = [
-            sqlalchemy.Column(c.name, c.type)
-            for c in query.selected_columns
-            if c.name != "sys__id"
-        ]
+        columns = [sqlalchemy.Column(c.name, c.type) for c in query.selected_columns]
         table = self.create_udf_table(columns)
 
         with tqdm(desc="Preparing", unit=" rows", leave=False) as pbar:
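The new `_system_random_expr` derives a per-row pseudo-random value by running the window row number through one linear congruential step; 1103515245 and 12345 are the multiplier and increment from the classic ANSI C `rand()` example. A quick look at the SQL this expression compiles to (plain SQLAlchemy; output shown approximately):

```python
import sqlalchemy as sa
from sqlalchemy.dialects import sqlite

expr = sa.func.row_number().over() * 1103515245 + 12345
print(expr.compile(dialect=sqlite.dialect(), compile_kwargs={"literal_binds": True}))
# row_number() OVER () * 1103515245 + 12345
```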