datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl

This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their respective public registries.
Files changed (137)
  1. datachain/__init__.py +20 -0
  2. datachain/asyn.py +11 -12
  3. datachain/cache.py +7 -7
  4. datachain/catalog/__init__.py +2 -2
  5. datachain/catalog/catalog.py +621 -507
  6. datachain/catalog/dependency.py +164 -0
  7. datachain/catalog/loader.py +28 -18
  8. datachain/checkpoint.py +43 -0
  9. datachain/cli/__init__.py +24 -33
  10. datachain/cli/commands/__init__.py +1 -8
  11. datachain/cli/commands/datasets.py +83 -52
  12. datachain/cli/commands/ls.py +17 -17
  13. datachain/cli/commands/show.py +4 -4
  14. datachain/cli/parser/__init__.py +8 -74
  15. datachain/cli/parser/job.py +95 -3
  16. datachain/cli/parser/studio.py +11 -4
  17. datachain/cli/parser/utils.py +1 -2
  18. datachain/cli/utils.py +2 -15
  19. datachain/client/azure.py +4 -4
  20. datachain/client/fsspec.py +45 -28
  21. datachain/client/gcs.py +6 -6
  22. datachain/client/hf.py +29 -2
  23. datachain/client/http.py +157 -0
  24. datachain/client/local.py +15 -11
  25. datachain/client/s3.py +17 -9
  26. datachain/config.py +4 -8
  27. datachain/data_storage/db_engine.py +12 -6
  28. datachain/data_storage/job.py +5 -1
  29. datachain/data_storage/metastore.py +1252 -186
  30. datachain/data_storage/schema.py +58 -45
  31. datachain/data_storage/serializer.py +105 -15
  32. datachain/data_storage/sqlite.py +286 -127
  33. datachain/data_storage/warehouse.py +250 -113
  34. datachain/dataset.py +353 -148
  35. datachain/delta.py +391 -0
  36. datachain/diff/__init__.py +27 -29
  37. datachain/error.py +60 -0
  38. datachain/func/__init__.py +2 -1
  39. datachain/func/aggregate.py +66 -42
  40. datachain/func/array.py +242 -38
  41. datachain/func/base.py +7 -4
  42. datachain/func/conditional.py +110 -60
  43. datachain/func/func.py +96 -45
  44. datachain/func/numeric.py +55 -38
  45. datachain/func/path.py +32 -20
  46. datachain/func/random.py +2 -2
  47. datachain/func/string.py +67 -37
  48. datachain/func/window.py +7 -8
  49. datachain/hash_utils.py +123 -0
  50. datachain/job.py +11 -7
  51. datachain/json.py +138 -0
  52. datachain/lib/arrow.py +58 -22
  53. datachain/lib/audio.py +245 -0
  54. datachain/lib/clip.py +14 -13
  55. datachain/lib/convert/flatten.py +5 -3
  56. datachain/lib/convert/python_to_sql.py +6 -10
  57. datachain/lib/convert/sql_to_python.py +8 -0
  58. datachain/lib/convert/values_to_tuples.py +156 -51
  59. datachain/lib/data_model.py +42 -20
  60. datachain/lib/dataset_info.py +36 -8
  61. datachain/lib/dc/__init__.py +8 -2
  62. datachain/lib/dc/csv.py +25 -28
  63. datachain/lib/dc/database.py +398 -0
  64. datachain/lib/dc/datachain.py +1289 -425
  65. datachain/lib/dc/datasets.py +320 -38
  66. datachain/lib/dc/hf.py +38 -24
  67. datachain/lib/dc/json.py +29 -32
  68. datachain/lib/dc/listings.py +112 -8
  69. datachain/lib/dc/pandas.py +16 -12
  70. datachain/lib/dc/parquet.py +35 -23
  71. datachain/lib/dc/records.py +31 -23
  72. datachain/lib/dc/storage.py +154 -64
  73. datachain/lib/dc/storage_pattern.py +251 -0
  74. datachain/lib/dc/utils.py +24 -16
  75. datachain/lib/dc/values.py +8 -9
  76. datachain/lib/file.py +622 -89
  77. datachain/lib/hf.py +69 -39
  78. datachain/lib/image.py +14 -14
  79. datachain/lib/listing.py +14 -11
  80. datachain/lib/listing_info.py +1 -2
  81. datachain/lib/meta_formats.py +3 -4
  82. datachain/lib/model_store.py +39 -7
  83. datachain/lib/namespaces.py +125 -0
  84. datachain/lib/projects.py +130 -0
  85. datachain/lib/pytorch.py +32 -21
  86. datachain/lib/settings.py +192 -56
  87. datachain/lib/signal_schema.py +427 -104
  88. datachain/lib/tar.py +1 -2
  89. datachain/lib/text.py +8 -7
  90. datachain/lib/udf.py +164 -76
  91. datachain/lib/udf_signature.py +60 -35
  92. datachain/lib/utils.py +118 -4
  93. datachain/lib/video.py +17 -9
  94. datachain/lib/webdataset.py +61 -56
  95. datachain/lib/webdataset_laion.py +15 -16
  96. datachain/listing.py +22 -10
  97. datachain/model/bbox.py +3 -1
  98. datachain/model/ultralytics/bbox.py +16 -12
  99. datachain/model/ultralytics/pose.py +16 -12
  100. datachain/model/ultralytics/segment.py +16 -12
  101. datachain/namespace.py +84 -0
  102. datachain/node.py +6 -6
  103. datachain/nodes_thread_pool.py +0 -1
  104. datachain/plugins.py +24 -0
  105. datachain/project.py +78 -0
  106. datachain/query/batch.py +40 -41
  107. datachain/query/dataset.py +604 -322
  108. datachain/query/dispatch.py +261 -154
  109. datachain/query/metrics.py +4 -6
  110. datachain/query/params.py +2 -3
  111. datachain/query/queue.py +3 -12
  112. datachain/query/schema.py +11 -6
  113. datachain/query/session.py +200 -33
  114. datachain/query/udf.py +34 -2
  115. datachain/remote/studio.py +171 -69
  116. datachain/script_meta.py +12 -12
  117. datachain/semver.py +68 -0
  118. datachain/sql/__init__.py +2 -0
  119. datachain/sql/functions/array.py +33 -1
  120. datachain/sql/postgresql_dialect.py +9 -0
  121. datachain/sql/postgresql_types.py +21 -0
  122. datachain/sql/sqlite/__init__.py +5 -1
  123. datachain/sql/sqlite/base.py +102 -29
  124. datachain/sql/sqlite/types.py +8 -13
  125. datachain/sql/types.py +70 -15
  126. datachain/studio.py +223 -46
  127. datachain/toolkit/split.py +31 -10
  128. datachain/utils.py +101 -59
  129. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
  130. datachain-0.39.0.dist-info/RECORD +173 -0
  131. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
  132. datachain/cli/commands/query.py +0 -53
  133. datachain/query/utils.py +0 -42
  134. datachain-0.14.2.dist-info/RECORD +0 -158
  135. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
  136. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
  137. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
@@ -1,28 +1,30 @@
 import glob
-import json
 import logging
 import posixpath
-import random
+import secrets
 import string
 from abc import ABC, abstractmethod
-from collections.abc import Generator, Iterable, Iterator, Sequence
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
+from typing import TYPE_CHECKING, Any, Union, cast
 from urllib.parse import urlparse

 import attrs
 import sqlalchemy as sa
-from sqlalchemy import Table, case, select
-from sqlalchemy.sql import func
 from sqlalchemy.sql.expression import true
-from tqdm.auto import tqdm

+from datachain import json
 from datachain.client import Client
 from datachain.data_storage.schema import convert_rows_custom_column_types
 from datachain.data_storage.serializer import Serializable
 from datachain.dataset import DatasetRecord, StorageURI
+from datachain.lib.file import File
+from datachain.lib.model_store import ModelStore
+from datachain.lib.signal_schema import SignalSchema
 from datachain.node import DirType, DirTypeGroup, Node, NodeWithPath, get_path
+from datachain.query.batch import RowsOutput
+from datachain.query.schema import ColumnMeta
 from datachain.sql.functions import path as pathfunc
-from datachain.sql.types import Int, SQLType
+from datachain.sql.types import SQLType
 from datachain.utils import sql_escape_like

 if TYPE_CHECKING:
@@ -31,18 +33,18 @@ if TYPE_CHECKING:
         _FromClauseArgument,
         _OnClauseArgument,
     )
-    from sqlalchemy.sql.selectable import Select
+    from sqlalchemy.sql.selectable import FromClause
     from sqlalchemy.types import TypeEngine

     from datachain.data_storage import schema
     from datachain.data_storage.db_engine import DatabaseEngine
     from datachain.data_storage.schema import DataTable
-    from datachain.lib.file import File


 logger = logging.getLogger("datachain")

 SELECT_BATCH_SIZE = 100_000  # number of rows to fetch at a time
+INSERT_BATCH_SIZE = 10_000  # number of rows to insert at a time


 class AbstractWarehouse(ABC, Serializable):
@@ -69,12 +71,36 @@ class AbstractWarehouse(ABC, Serializable):
         return self

     def __exit__(self, exc_type, exc_value, traceback) -> None:
-        # Default behavior is to do nothing, as connections may be shared.
-        pass
+        """Default behavior is to do nothing, as connections may be shared."""

     def cleanup_for_tests(self):
         """Cleanup for tests."""

+    def _to_jsonable(self, obj: Any) -> Any:
+        """Recursively convert Python/Pydantic structures into JSON-serializable
+        objects.
+        """
+
+        if ModelStore.is_pydantic(type(obj)):
+            # Use Pydantic's JSON mode to ensure datetime and other non-JSON
+            # native types are serialized in a compatible way.
+            return obj.model_dump(mode="json")
+
+        if isinstance(obj, dict):
+            out: dict[str, Any] = {}
+            for k, v in obj.items():
+                if not isinstance(k, str):
+                    key_str = json.dumps(self._to_jsonable(k), ensure_ascii=False)
+                else:
+                    key_str = k
+                out[key_str] = self._to_jsonable(v)
+            return out
+
+        if isinstance(obj, (list, tuple, set)):
+            return [self._to_jsonable(i) for i in obj]
+
+        return obj
+
     def convert_type(  # noqa: PLR0911
         self,
         val: Any,
@@ -121,11 +147,13 @@ class AbstractWarehouse(ABC, Serializable):
         if col_python_type is dict or col_type_name == "JSON":
             if value_type is str:
                 return val
-            if value_type in (dict, list):
-                return json.dumps(val)
-            raise ValueError(
-                f"Cannot convert value {val!r} with type {value_type} to JSON"
-            )
+            try:
+                json_ready = self._to_jsonable(val)
+                return json.dumps(json_ready, ensure_ascii=False)
+            except Exception as e:
+                raise ValueError(
+                    f"Cannot convert value {val!r} with type {value_type} to JSON"
+                ) from e

         if isinstance(val, col_python_type):
             return val
@@ -173,22 +201,22 @@ class AbstractWarehouse(ABC, Serializable):
     #

     @abstractmethod
-    def is_ready(self, timeout: Optional[int] = None) -> bool: ...
+    def is_ready(self, timeout: int | None = None) -> bool: ...

     def dataset_rows(
         self,
         dataset: DatasetRecord,
-        version: Optional[int] = None,
-        object_name: str = "file",
+        version: str | None = None,
+        column: str = "file",
     ):
         version = version or dataset.latest_version

-        table_name = self.dataset_table_name(dataset.name, version)
+        table_name = self.dataset_table_name(dataset, version)
         return self.schema.dataset_row_cls(
             table_name,
             self.db,
             dataset.get_schema(version),
-            object_name=object_name,
+            column=column,
         )

     @property
@@ -199,6 +227,15 @@ class AbstractWarehouse(ABC, Serializable):
     # Query Execution
     #

+    def query_count(self, query: sa.Select) -> int:
+        """Count the number of rows in a query."""
+        count_query = sa.select(sa.func.count(1)).select_from(query.subquery())
+        return next(self.db.execute(count_query))[0]
+
+    def table_rows_count(self, table) -> int:
+        count_query = sa.select(sa.func.count(1)).select_from(table)
+        return next(self.db.execute(count_query))[0]
+
     def dataset_select_paginated(
         self,
         query,
@@ -210,7 +247,7 @@ class AbstractWarehouse(ABC, Serializable):
         limit = query._limit
         paginated_query = query.limit(page_size)

-        offset = 0
+        offset = query._offset or 0
         num_yielded = 0

         # Ensure we're using a thread-local connection
@@ -218,7 +255,8 @@ class AbstractWarehouse(ABC, Serializable):
         while True:
             if limit is not None:
                 limit -= num_yielded
-                if limit == 0:
+                num_yielded = 0
+                if limit <= 0:
                     break
                 if limit < page_size:
                     paginated_query = paginated_query.limit(None).limit(limit)
@@ -226,16 +264,81 @@ class AbstractWarehouse(ABC, Serializable):
             # Cursor results are not thread-safe, so we convert them to a list
             results = list(wh.dataset_rows_select(paginated_query.offset(offset)))

-            processed = False
+            processed = 0
             for row in results:
-                processed = True
+                processed += 1
                 yield row
                 num_yielded += 1

-            if not processed:
+            if processed < page_size:
                 break  # no more results
             offset += page_size

+    def _regenerate_system_columns(
+        self,
+        selectable: sa.Select,
+        keep_existing_columns: bool = False,
+        regenerate_columns: Iterable[str] | None = None,
+    ) -> sa.Select:
+        """
+        Return a SELECT that regenerates system columns deterministically.
+
+        If keep_existing_columns is True, existing system columns will be kept as-is
+        even when they are listed in ``regenerate_columns``.
+
+        Args:
+            selectable: Base SELECT
+            keep_existing_columns: When True, reuse existing system columns even if
+                they are part of the regeneration set.
+            regenerate_columns: Names of system columns to regenerate. Defaults to
+                {"sys__id", "sys__rand"}. Columns not listed are left untouched.
+        """
+        system_columns = {
+            sys_col.name: sys_col.type
+            for sys_col in self.schema.dataset_row_cls.sys_columns()
+        }
+        regenerate = set(regenerate_columns or system_columns)
+        generators = {
+            "sys__id": self._system_row_number_expr,
+            "sys__rand": self._system_random_expr,
+        }
+
+        base = cast("FromClause", selectable.subquery())
+
+        def build(name: str) -> sa.ColumnElement:
+            expr = generators[name]()
+            return sa.cast(expr, system_columns[name]).label(name)
+
+        columns: list[sa.ColumnElement] = []
+        present: set[str] = set()
+        changed = False
+
+        for col in base.c:
+            present.add(col.name)
+            regen = col.name in regenerate and not keep_existing_columns
+            columns.append(build(col.name) if regen else col)
+            changed |= regen
+
+        for name in regenerate - present:
+            columns.append(build(name))
+            changed = True
+
+        if not changed:
+            return selectable
+
+        inner = sa.select(*columns).select_from(base).subquery()
+        return sa.select(*inner.c).select_from(inner)
+
+    def _system_row_number_expr(self):
+        """Return an expression that produces deterministic row numbers."""
+
+        raise NotImplementedError
+
+    def _system_random_expr(self):
+        """Return an expression that produces deterministic random values."""
+
+        raise NotImplementedError
+
     #
     # Table Name Internal Functions
     #
@@ -246,12 +349,24 @@ class AbstractWarehouse(ABC, Serializable):
         name = parsed.path if parsed.scheme == "file" else parsed.netloc
         return parsed.scheme, name

-    def dataset_table_name(self, dataset_name: str, version: int) -> str:
+    def dataset_table_name(self, dataset: DatasetRecord, version: str) -> str:
+        return self._construct_dataset_table_name(
+            dataset.project.namespace.name,
+            dataset.project.name,
+            dataset.name,
+            version,
+        )
+
+    def _construct_dataset_table_name(
+        self, namespace: str, project: str, dataset_name: str, version: str
+    ) -> str:
         prefix = self.DATASET_TABLE_PREFIX
         if Client.is_data_source_uri(dataset_name):
             # for datasets that are created for bucket listing we use different prefix
             prefix = self.DATASET_SOURCE_TABLE_PREFIX
-        return f"{prefix}{dataset_name}_{version}"
+        return (
+            f"{prefix}{namespace}_{project}_{dataset_name}_{version.replace('.', '_')}"
+        )

     def temp_table_name(self) -> str:
         return self.TMP_TABLE_NAME_PREFIX + _random_string(6)
@@ -269,38 +384,26 @@ class AbstractWarehouse(ABC, Serializable):
         name: str,
         columns: Sequence["sa.Column"] = (),
         if_not_exists: bool = True,
-    ) -> Table:
+    ) -> sa.Table:
         """Creates a dataset rows table for the given dataset name and columns"""

     def drop_dataset_rows_table(
         self,
         dataset: DatasetRecord,
-        version: int,
+        version: str,
         if_exists: bool = True,
     ) -> None:
         """Drops a dataset rows table for the given dataset name."""
-        table_name = self.dataset_table_name(dataset.name, version)
-        table = Table(table_name, self.db.metadata)
+        table_name = self.dataset_table_name(dataset, version)
+        table = sa.Table(table_name, self.db.metadata)
         self.db.drop_table(table, if_exists=if_exists)
-
-    @abstractmethod
-    def merge_dataset_rows(
-        self,
-        src: "DatasetRecord",
-        dst: "DatasetRecord",
-        src_version: int,
-        dst_version: int,
-    ) -> None:
-        """
-        Merges source dataset rows and current latest destination dataset rows
-        into a new rows table created for new destination dataset version.
-        Note that table for new destination version must be created upfront.
-        Merge results should not contain duplicates.
-        """
+        # Remove from metadata cache to allow recreation
+        if table_name in self.db.metadata.tables:
+            self.db.metadata.remove(self.db.metadata.tables[table_name])

     def dataset_rows_select(
         self,
-        query: sa.sql.selectable.Select,
+        query: sa.Select,
         **kwargs,
     ) -> Iterator[tuple[Any, ...]]:
         """
@@ -311,51 +414,81 @@ class AbstractWarehouse(ABC, Serializable):
             query.selected_columns, rows, self.db.dialect
         )

+    def dataset_rows_select_from_ids(
+        self,
+        query: sa.Select,
+        ids: Iterable[RowsOutput],
+        is_batched: bool,
+    ) -> Iterator[RowsOutput]:
+        """
+        Fetch dataset rows from database using a list of IDs.
+        """
+        if (id_col := query.selected_columns.get("sys__id")) is None:
+            raise RuntimeError("sys__id column not found in query")
+
+        query = query._clone().offset(None).limit(None).order_by(None)
+
+        if is_batched:
+            for batch in ids:
+                yield list(self.dataset_rows_select(query.where(id_col.in_(batch))))
+        else:
+            yield from self.dataset_rows_select(query.where(id_col.in_(ids)))
+
     @abstractmethod
     def get_dataset_sources(
-        self, dataset: DatasetRecord, version: int
+        self, dataset: DatasetRecord, version: str
     ) -> list[StorageURI]: ...

-    def rename_dataset_table(
-        self,
-        old_name: str,
-        new_name: str,
-        old_version: int,
-        new_version: int,
+    def rename_dataset_tables(
+        self, dataset: DatasetRecord, dataset_updated: DatasetRecord
     ) -> None:
-        old_ds_table_name = self.dataset_table_name(old_name, old_version)
-        new_ds_table_name = self.dataset_table_name(new_name, new_version)
-
-        self.db.rename_table(old_ds_table_name, new_ds_table_name)
+        """
+        Renames all dataset version tables when parts of the dataset that
+        are used in constructing table name are updated.
+        If nothing important is changed, nothing will be renamed (no DB calls
+        will be made at all).
+        """
+        for version in [v.version for v in dataset_updated.versions]:
+            if not dataset.has_version(version):
+                continue
+            src = self.dataset_table_name(dataset, version)
+            dest = self.dataset_table_name(dataset_updated, version)
+            if src == dest:
+                continue
+            self.db.rename_table(src, dest)

     def dataset_rows_count(self, dataset: DatasetRecord, version=None) -> int:
         """Returns total number of rows in a dataset"""
         dr = self.dataset_rows(dataset, version)
         table = dr.get_table()
-        query = select(sa.func.count(table.c.sys__id))
+        query = sa.select(sa.func.count(table.c.sys__id))
         (res,) = self.db.execute(query)
         return res[0]

     def dataset_stats(
-        self, dataset: DatasetRecord, version: int
-    ) -> tuple[Optional[int], Optional[int]]:
+        self, dataset: DatasetRecord, version: str
+    ) -> tuple[int | None, int | None]:
         """
         Returns tuple with dataset stats: total number of rows and total dataset size.
         """
-        if not (self.db.has_table(self.dataset_table_name(dataset.name, version))):
+        if not (self.db.has_table(self.dataset_table_name(dataset, version))):
             return None, None

+        file_signals = list(
+            SignalSchema.deserialize(dataset.feature_schema).get_signals(File)
+        )
+
         dr = self.dataset_rows(dataset, version)
         table = dr.get_table()
         expressions: tuple[_ColumnsClauseArgument[Any], ...] = (
             sa.func.count(table.c.sys__id),
         )
-        size_columns = [
-            c for c in table.columns if c.name == "size" or c.name.endswith("__size")
-        ]
+        size_column_names = [ColumnMeta.to_db_name(s) + "__size" for s in file_signals]
+        size_columns = [c for c in table.columns if c.name in size_column_names]
+
         if size_columns:
             expressions = (*expressions, sa.func.sum(sum(size_columns)))
-        query = select(*expressions)
+        query = sa.select(*expressions)
         ((nrows, *rest),) = self.db.execute(query)
         return nrows, rest[0] if rest else 0

@@ -364,17 +497,22 @@ class AbstractWarehouse(ABC, Serializable):
         """Convert File entries so they can be passed on to `insert_rows()`"""

     @abstractmethod
-    def insert_rows(self, table: Table, rows: Iterable[dict[str, Any]]) -> None:
+    def insert_rows(
+        self,
+        table: sa.Table,
+        rows: Iterable[dict[str, Any]],
+        batch_size: int = INSERT_BATCH_SIZE,
+    ) -> None:
         """Does batch inserts of any kind of rows into table"""

-    def insert_rows_done(self, table: Table) -> None:
+    def insert_rows_done(self, table: sa.Table) -> None:
         """
         Only needed for certain implementations
         to signal when rows inserts are complete.
         """

     @abstractmethod
-    def insert_dataset_rows(self, df, dataset: DatasetRecord, version: int) -> int:
+    def insert_dataset_rows(self, df, dataset: DatasetRecord, version: str) -> int:
         """Inserts dataset rows directly into dataset table"""

     @abstractmethod
@@ -393,7 +531,7 @@ class AbstractWarehouse(ABC, Serializable):

     @abstractmethod
     def dataset_table_export_file_names(
-        self, dataset: DatasetRecord, version: int
+        self, dataset: DatasetRecord, version: str
     ) -> list[str]:
         """
         Returns list of file names that will be created when user runs dataset export
@@ -404,7 +542,7 @@ class AbstractWarehouse(ABC, Serializable):
         self,
         bucket_uri: str,
         dataset: DatasetRecord,
-        version: int,
+        version: str,
         client_config=None,
     ) -> list[str]:
         """
@@ -454,7 +592,7 @@ class AbstractWarehouse(ABC, Serializable):
         dr = dataset_rows
         columns = [c.name for c in query.selected_columns]
         for row in self.db.execute(query):
-            d = dict(zip(columns, row))
+            d = dict(zip(columns, row, strict=False))
             yield Node(**{dr.without_object(k): v for k, v in d.items()})

     def get_dirs_by_parent_path(
@@ -478,7 +616,7 @@ class AbstractWarehouse(ABC, Serializable):
         dataset_rows: "DataTable",
         path_list: list[str],
         glob_name: str,
-        object_name="file",
+        column="file",
     ) -> Iterator[Node]:
         """Finds all Nodes that correspond to GLOB like path pattern."""
         dr = dataset_rows
@@ -488,7 +626,7 @@ class AbstractWarehouse(ABC, Serializable):
         ).subquery()
         path_glob = "/".join([*path_list, glob_name])
         dirpath = path_glob[: -len(glob_name)]
-        relpath = func.substr(de.c(q, "path"), len(dirpath) + 1)
+        relpath = sa.func.substr(de.c(q, "path"), len(dirpath) + 1)

         return self.get_nodes(
             self.expand_query(de, q, dr)
@@ -512,7 +650,7 @@ class AbstractWarehouse(ABC, Serializable):
         de = dr.dir_expansion()
         q = de.query(
             dr.select().where(dr.c("is_latest") == true()).subquery(),
-            object_name=dr.object_name,
+            column=dr.column,
         ).subquery()
         q = self.expand_query(de, q, dr)

@@ -575,25 +713,23 @@ class AbstractWarehouse(ABC, Serializable):
             default = getattr(
                 attrs.fields(Node), dr.without_object(column.name)
             ).default
-            return func.coalesce(column, default).label(column.name)
+            return sa.func.coalesce(column, default).label(column.name)

         return sa.select(
             q.c.sys__id,
-            case((de.c(q, "is_dir") == true(), DirType.DIR), else_=DirType.FILE).label(
-                dr.col_name("dir_type")
-            ),
+            sa.case(
+                (de.c(q, "is_dir") == true(), DirType.DIR), else_=DirType.FILE
+            ).label(dr.col_name("dir_type")),
             de.c(q, "path"),
             with_default(dr.c("etag")),
             de.c(q, "version"),
             with_default(dr.c("is_latest")),
             dr.c("last_modified"),
             with_default(dr.c("size")),
-            with_default(dr.c("rand", object_name="sys")),
+            with_default(dr.c("rand", column="sys")),
             dr.c("location"),
             de.c(q, "source"),
-        ).select_from(
-            q.outerjoin(dr.table, q.c.sys__id == dr.c("id", object_name="sys"))
-        )
+        ).select_from(q.outerjoin(dr.table, q.c.sys__id == dr.c("id", column="sys")))

     def get_node_by_path(self, dataset_rows: "DataTable", path: str) -> Node:
         """Gets node that corresponds to some path"""
@@ -658,7 +794,7 @@ class AbstractWarehouse(ABC, Serializable):
             return de.c(inner_query, f)

         return self.db.execute(
-            select(*(field_to_expr(f) for f in fields)).order_by(
+            sa.select(*(field_to_expr(f) for f in fields)).order_by(
                 de.c(inner_query, "source"),
                 de.c(inner_query, "path"),
                 de.c(inner_query, "version"),
@@ -680,7 +816,7 @@ class AbstractWarehouse(ABC, Serializable):
             return dr.c(f)

         q = (
-            select(*(field_to_expr(f) for f in fields))
+            sa.select(*(field_to_expr(f) for f in fields))
             .where(
                 dr.c("path").like(f"{sql_escape_like(dirpath)}%"),
                 ~self.instr(pathfunc.name(dr.c("path")), "/"),
@@ -693,7 +829,7 @@ class AbstractWarehouse(ABC, Serializable):
     def size(
         self,
         dataset_rows: "DataTable",
-        node: Union[Node, dict[str, Any]],
+        node: Node | dict[str, Any],
         count_files: bool = False,
     ) -> tuple[int, int]:
         """
@@ -715,10 +851,10 @@ class AbstractWarehouse(ABC, Serializable):
         sub_glob = posixpath.join(path, "*")
         dr = dataset_rows
         selections: list[sa.ColumnElement] = [
-            func.sum(dr.c("size")),
+            sa.func.sum(dr.c("size")),
         ]
         if count_files:
-            selections.append(func.count())
+            selections.append(sa.func.count())
         results = next(
             self.db.execute(
                 dr.select(*selections).where(
@@ -735,10 +871,10 @@ class AbstractWarehouse(ABC, Serializable):
         self,
         dataset_rows: "DataTable",
         parent_path: str,
-        fields: Optional[Sequence[str]] = None,
-        type: Optional[str] = None,
+        fields: Sequence[str] | None = None,
+        type: str | None = None,
         conds=None,
-        order_by: Optional[Union[str, list[str]]] = None,
+        order_by: str | list[str] | None = None,
         include_subobjects: bool = True,
     ) -> sa.Select:
         if not conds:
@@ -776,7 +912,7 @@ class AbstractWarehouse(ABC, Serializable):
         self,
         dataset_rows: "DataTable",
         node: Node,
-        sort: Union[list[str], str, None] = None,
+        sort: list[str] | str | None = None,
         include_subobjects: bool = True,
     ) -> Iterator[NodeWithPath]:
         """
@@ -834,28 +970,33 @@ class AbstractWarehouse(ABC, Serializable):
     def create_udf_table(
         self,
         columns: Sequence["sa.Column"] = (),
-        name: Optional[str] = None,
-    ) -> "sa.Table":
+        name: str | None = None,
+    ) -> sa.Table:
         """
         Create a temporary table for storing custom signals generated by a UDF.
         SQLite TEMPORARY tables cannot be directly used as they are process-specific,
         and UDFs are run in other processes when run in parallel.
         """
+        columns = [
+            c
+            for c in columns
+            if c.name not in [col.name for col in self.dataset_row_cls.sys_columns()]
+        ]
         tbl = sa.Table(
             name or self.udf_table_name(),
             sa.MetaData(),
-            sa.Column("sys__id", Int, primary_key=True),
+            *self.dataset_row_cls.sys_columns(),
             *columns,
         )
-        self.db.create_table(tbl, if_not_exists=True)
+        self.db.create_table(tbl, if_not_exists=True, kind="udf")
         return tbl

     @abstractmethod
     def copy_table(
         self,
-        table: Table,
-        query: "Select",
-        progress_cb: Optional[Callable[[int], None]] = None,
+        table: sa.Table,
+        query: sa.Select,
+        progress_cb: Callable[[int], None] | None = None,
     ) -> None:
         """
         Copy the results of a query into a table.
@@ -868,13 +1009,15 @@ class AbstractWarehouse(ABC, Serializable):
         right: "_FromClauseArgument",
         onclause: "_OnClauseArgument",
         inner: bool = True,
-    ) -> "Select":
+        full: bool = False,
+        columns=None,
+    ) -> sa.Select:
         """
         Join two tables together.
         """

     @abstractmethod
-    def create_pre_udf_table(self, query: "Select") -> "Table":
+    def create_pre_udf_table(self, query: sa.Select) -> sa.Table:
         """
         Create a temporary table from a query for use in a UDF.
         """
@@ -899,16 +1042,10 @@ class AbstractWarehouse(ABC, Serializable):
         are cleaned up as soon as they are no longer needed.
         """
         to_drop = set(names)
-        with tqdm(
-            desc="Cleanup", unit=" tables", total=len(to_drop), leave=False
-        ) as pbar:
-            for name in to_drop:
-                self.db.drop_table(Table(name, self.db.metadata), if_exists=True)
-                pbar.update(1)
+        for name in to_drop:
+            self.db.drop_table(sa.Table(name, self.db.metadata), if_exists=True)


 def _random_string(length: int) -> str:
-    return "".join(
-        random.choice(string.ascii_letters + string.digits)  # noqa: S311
-        for i in range(length)
-    )
+    alphabet = string.ascii_letters + string.digits
+    return "".join(secrets.choice(alphabet) for _ in range(length))