datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects each version exactly as it appears in its public registry.
- datachain/__init__.py +20 -0
- datachain/asyn.py +11 -12
- datachain/cache.py +7 -7
- datachain/catalog/__init__.py +2 -2
- datachain/catalog/catalog.py +621 -507
- datachain/catalog/dependency.py +164 -0
- datachain/catalog/loader.py +28 -18
- datachain/checkpoint.py +43 -0
- datachain/cli/__init__.py +24 -33
- datachain/cli/commands/__init__.py +1 -8
- datachain/cli/commands/datasets.py +83 -52
- datachain/cli/commands/ls.py +17 -17
- datachain/cli/commands/show.py +4 -4
- datachain/cli/parser/__init__.py +8 -74
- datachain/cli/parser/job.py +95 -3
- datachain/cli/parser/studio.py +11 -4
- datachain/cli/parser/utils.py +1 -2
- datachain/cli/utils.py +2 -15
- datachain/client/azure.py +4 -4
- datachain/client/fsspec.py +45 -28
- datachain/client/gcs.py +6 -6
- datachain/client/hf.py +29 -2
- datachain/client/http.py +157 -0
- datachain/client/local.py +15 -11
- datachain/client/s3.py +17 -9
- datachain/config.py +4 -8
- datachain/data_storage/db_engine.py +12 -6
- datachain/data_storage/job.py +5 -1
- datachain/data_storage/metastore.py +1252 -186
- datachain/data_storage/schema.py +58 -45
- datachain/data_storage/serializer.py +105 -15
- datachain/data_storage/sqlite.py +286 -127
- datachain/data_storage/warehouse.py +250 -113
- datachain/dataset.py +353 -148
- datachain/delta.py +391 -0
- datachain/diff/__init__.py +27 -29
- datachain/error.py +60 -0
- datachain/func/__init__.py +2 -1
- datachain/func/aggregate.py +66 -42
- datachain/func/array.py +242 -38
- datachain/func/base.py +7 -4
- datachain/func/conditional.py +110 -60
- datachain/func/func.py +96 -45
- datachain/func/numeric.py +55 -38
- datachain/func/path.py +32 -20
- datachain/func/random.py +2 -2
- datachain/func/string.py +67 -37
- datachain/func/window.py +7 -8
- datachain/hash_utils.py +123 -0
- datachain/job.py +11 -7
- datachain/json.py +138 -0
- datachain/lib/arrow.py +58 -22
- datachain/lib/audio.py +245 -0
- datachain/lib/clip.py +14 -13
- datachain/lib/convert/flatten.py +5 -3
- datachain/lib/convert/python_to_sql.py +6 -10
- datachain/lib/convert/sql_to_python.py +8 -0
- datachain/lib/convert/values_to_tuples.py +156 -51
- datachain/lib/data_model.py +42 -20
- datachain/lib/dataset_info.py +36 -8
- datachain/lib/dc/__init__.py +8 -2
- datachain/lib/dc/csv.py +25 -28
- datachain/lib/dc/database.py +398 -0
- datachain/lib/dc/datachain.py +1289 -425
- datachain/lib/dc/datasets.py +320 -38
- datachain/lib/dc/hf.py +38 -24
- datachain/lib/dc/json.py +29 -32
- datachain/lib/dc/listings.py +112 -8
- datachain/lib/dc/pandas.py +16 -12
- datachain/lib/dc/parquet.py +35 -23
- datachain/lib/dc/records.py +31 -23
- datachain/lib/dc/storage.py +154 -64
- datachain/lib/dc/storage_pattern.py +251 -0
- datachain/lib/dc/utils.py +24 -16
- datachain/lib/dc/values.py +8 -9
- datachain/lib/file.py +622 -89
- datachain/lib/hf.py +69 -39
- datachain/lib/image.py +14 -14
- datachain/lib/listing.py +14 -11
- datachain/lib/listing_info.py +1 -2
- datachain/lib/meta_formats.py +3 -4
- datachain/lib/model_store.py +39 -7
- datachain/lib/namespaces.py +125 -0
- datachain/lib/projects.py +130 -0
- datachain/lib/pytorch.py +32 -21
- datachain/lib/settings.py +192 -56
- datachain/lib/signal_schema.py +427 -104
- datachain/lib/tar.py +1 -2
- datachain/lib/text.py +8 -7
- datachain/lib/udf.py +164 -76
- datachain/lib/udf_signature.py +60 -35
- datachain/lib/utils.py +118 -4
- datachain/lib/video.py +17 -9
- datachain/lib/webdataset.py +61 -56
- datachain/lib/webdataset_laion.py +15 -16
- datachain/listing.py +22 -10
- datachain/model/bbox.py +3 -1
- datachain/model/ultralytics/bbox.py +16 -12
- datachain/model/ultralytics/pose.py +16 -12
- datachain/model/ultralytics/segment.py +16 -12
- datachain/namespace.py +84 -0
- datachain/node.py +6 -6
- datachain/nodes_thread_pool.py +0 -1
- datachain/plugins.py +24 -0
- datachain/project.py +78 -0
- datachain/query/batch.py +40 -41
- datachain/query/dataset.py +604 -322
- datachain/query/dispatch.py +261 -154
- datachain/query/metrics.py +4 -6
- datachain/query/params.py +2 -3
- datachain/query/queue.py +3 -12
- datachain/query/schema.py +11 -6
- datachain/query/session.py +200 -33
- datachain/query/udf.py +34 -2
- datachain/remote/studio.py +171 -69
- datachain/script_meta.py +12 -12
- datachain/semver.py +68 -0
- datachain/sql/__init__.py +2 -0
- datachain/sql/functions/array.py +33 -1
- datachain/sql/postgresql_dialect.py +9 -0
- datachain/sql/postgresql_types.py +21 -0
- datachain/sql/sqlite/__init__.py +5 -1
- datachain/sql/sqlite/base.py +102 -29
- datachain/sql/sqlite/types.py +8 -13
- datachain/sql/types.py +70 -15
- datachain/studio.py +223 -46
- datachain/toolkit/split.py +31 -10
- datachain/utils.py +101 -59
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
- datachain-0.39.0.dist-info/RECORD +173 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
- datachain/cli/commands/query.py +0 -53
- datachain/query/utils.py +0 -42
- datachain-0.14.2.dist-info/RECORD +0 -158
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/data_storage/warehouse.py +250 -113

```diff
@@ -1,28 +1,30 @@
 import glob
-import json
 import logging
 import posixpath
-import
+import secrets
 import string
 from abc import ABC, abstractmethod
-from collections.abc import Generator, Iterable, Iterator, Sequence
-from typing import TYPE_CHECKING, Any,
+from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
+from typing import TYPE_CHECKING, Any, Union, cast
 from urllib.parse import urlparse
 
 import attrs
 import sqlalchemy as sa
-from sqlalchemy import Table, case, select
-from sqlalchemy.sql import func
 from sqlalchemy.sql.expression import true
-from tqdm.auto import tqdm
 
+from datachain import json
 from datachain.client import Client
 from datachain.data_storage.schema import convert_rows_custom_column_types
 from datachain.data_storage.serializer import Serializable
 from datachain.dataset import DatasetRecord, StorageURI
+from datachain.lib.file import File
+from datachain.lib.model_store import ModelStore
+from datachain.lib.signal_schema import SignalSchema
 from datachain.node import DirType, DirTypeGroup, Node, NodeWithPath, get_path
+from datachain.query.batch import RowsOutput
+from datachain.query.schema import ColumnMeta
 from datachain.sql.functions import path as pathfunc
-from datachain.sql.types import
+from datachain.sql.types import SQLType
 from datachain.utils import sql_escape_like
 
 if TYPE_CHECKING:
@@ -31,18 +33,18 @@ if TYPE_CHECKING:
         _FromClauseArgument,
         _OnClauseArgument,
     )
-    from sqlalchemy.sql.selectable import
+    from sqlalchemy.sql.selectable import FromClause
     from sqlalchemy.types import TypeEngine
 
     from datachain.data_storage import schema
     from datachain.data_storage.db_engine import DatabaseEngine
     from datachain.data_storage.schema import DataTable
-    from datachain.lib.file import File
 
 
 logger = logging.getLogger("datachain")
 
 SELECT_BATCH_SIZE = 100_000  # number of rows to fetch at a time
+INSERT_BATCH_SIZE = 10_000  # number of rows to insert at a time
 
 
 class AbstractWarehouse(ABC, Serializable):
@@ -69,12 +71,36 @@ class AbstractWarehouse(ABC, Serializable):
         return self
 
     def __exit__(self, exc_type, exc_value, traceback) -> None:
-
-        pass
+        """Default behavior is to do nothing, as connections may be shared."""
 
     def cleanup_for_tests(self):
         """Cleanup for tests."""
 
+    def _to_jsonable(self, obj: Any) -> Any:
+        """Recursively convert Python/Pydantic structures into JSON-serializable
+        objects.
+        """
+
+        if ModelStore.is_pydantic(type(obj)):
+            # Use Pydantic's JSON mode to ensure datetime and other non-JSON
+            # native types are serialized in a compatible way.
+            return obj.model_dump(mode="json")
+
+        if isinstance(obj, dict):
+            out: dict[str, Any] = {}
+            for k, v in obj.items():
+                if not isinstance(k, str):
+                    key_str = json.dumps(self._to_jsonable(k), ensure_ascii=False)
+                else:
+                    key_str = k
+                out[key_str] = self._to_jsonable(v)
+            return out
+
+        if isinstance(obj, (list, tuple, set)):
+            return [self._to_jsonable(i) for i in obj]
+
+        return obj
+
     def convert_type(  # noqa: PLR0911
         self,
         val: Any,
@@ -121,11 +147,13 @@ class AbstractWarehouse(ABC, Serializable):
         if col_python_type is dict or col_type_name == "JSON":
             if value_type is str:
                 return val
-
-
-
-
-
+            try:
+                json_ready = self._to_jsonable(val)
+                return json.dumps(json_ready, ensure_ascii=False)
+            except Exception as e:
+                raise ValueError(
+                    f"Cannot convert value {val!r} with type {value_type} to JSON"
+                ) from e
 
         if isinstance(val, col_python_type):
             return val
```
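`convert_type` now routes dict/JSON values through the new `_to_jsonable` helper before serializing, so Pydantic models, non-string dict keys, and sets/tuples all serialize cleanly, and anything unserializable raises a `ValueError` naming the value. A minimal standalone sketch of that normalization; stdlib `json` and the `Label` model are stand-ins for the vendored `datachain.json` and real signal types:

```python
import json
from datetime import datetime

from pydantic import BaseModel


class Label(BaseModel):
    name: str
    created: datetime


def to_jsonable(obj):
    if isinstance(obj, BaseModel):
        # mode="json" renders datetimes etc. as JSON-native values
        return obj.model_dump(mode="json")
    if isinstance(obj, dict):
        return {
            k if isinstance(k, str) else json.dumps(to_jsonable(k)): to_jsonable(v)
            for k, v in obj.items()
        }
    if isinstance(obj, (list, tuple, set)):
        return [to_jsonable(i) for i in obj]
    return obj


val = {(1, 2): Label(name="cat", created=datetime(2024, 1, 1))}
print(json.dumps(to_jsonable(val), ensure_ascii=False))
# {"[1, 2]": {"name": "cat", "created": "2024-01-01T00:00:00"}}
```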
```diff
@@ -173,22 +201,22 @@ class AbstractWarehouse(ABC, Serializable):
     #
 
     @abstractmethod
-    def is_ready(self, timeout:
+    def is_ready(self, timeout: int | None = None) -> bool: ...
 
     def dataset_rows(
         self,
         dataset: DatasetRecord,
-        version:
-
+        version: str | None = None,
+        column: str = "file",
     ):
         version = version or dataset.latest_version
 
-        table_name = self.dataset_table_name(dataset
+        table_name = self.dataset_table_name(dataset, version)
         return self.schema.dataset_row_cls(
             table_name,
             self.db,
             dataset.get_schema(version),
-
+            column=column,
         )
 
     @property
@@ -199,6 +227,15 @@ class AbstractWarehouse(ABC, Serializable):
     # Query Execution
     #
 
+    def query_count(self, query: sa.Select) -> int:
+        """Count the number of rows in a query."""
+        count_query = sa.select(sa.func.count(1)).select_from(query.subquery())
+        return next(self.db.execute(count_query))[0]
+
+    def table_rows_count(self, table) -> int:
+        count_query = sa.select(sa.func.count(1)).select_from(table)
+        return next(self.db.execute(count_query))[0]
+
     def dataset_select_paginated(
         self,
         query,
```
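The new `query_count` helper counts over a subquery rather than rewriting the SELECT's column list, so modifiers such as LIMIT, DISTINCT, or GROUP BY on the inner query still apply. A self-contained illustration of the pattern; the engine and `items` table are hypothetical:

```python
import sqlalchemy as sa

engine = sa.create_engine("sqlite://")
meta = sa.MetaData()
items = sa.Table("items", meta, sa.Column("id", sa.Integer, primary_key=True))
meta.create_all(engine)

with engine.begin() as conn:
    conn.execute(sa.insert(items), [{"id": i} for i in range(5)])
    query = sa.select(items).limit(3)
    # COUNT over the wrapped query, not over the raw table
    count_query = sa.select(sa.func.count(1)).select_from(query.subquery())
    print(conn.execute(count_query).scalar())  # 3, not 5
```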
```diff
@@ -210,7 +247,7 @@ class AbstractWarehouse(ABC, Serializable):
         limit = query._limit
         paginated_query = query.limit(page_size)
 
-        offset = 0
+        offset = query._offset or 0
         num_yielded = 0
 
         # Ensure we're using a thread-local connection
@@ -218,7 +255,8 @@ class AbstractWarehouse(ABC, Serializable):
             while True:
                 if limit is not None:
                     limit -= num_yielded
-
+                    num_yielded = 0
+                    if limit <= 0:
                         break
                     if limit < page_size:
                         paginated_query = paginated_query.limit(None).limit(limit)
@@ -226,16 +264,81 @@ class AbstractWarehouse(ABC, Serializable):
                 # Cursor results are not thread-safe, so we convert them to a list
                 results = list(wh.dataset_rows_select(paginated_query.offset(offset)))
 
-                processed =
+                processed = 0
                 for row in results:
-                    processed
+                    processed += 1
                     yield row
                     num_yielded += 1
 
-                if
+                if processed < page_size:
                     break  # no more results
                 offset += page_size
 
+    def _regenerate_system_columns(
+        self,
+        selectable: sa.Select,
+        keep_existing_columns: bool = False,
+        regenerate_columns: Iterable[str] | None = None,
+    ) -> sa.Select:
+        """
+        Return a SELECT that regenerates system columns deterministically.
+
+        If keep_existing_columns is True, existing system columns will be kept as-is
+        even when they are listed in ``regenerate_columns``.
+
+        Args:
+            selectable: Base SELECT
+            keep_existing_columns: When True, reuse existing system columns even if
+                they are part of the regeneration set.
+            regenerate_columns: Names of system columns to regenerate. Defaults to
+                {"sys__id", "sys__rand"}. Columns not listed are left untouched.
+        """
+        system_columns = {
+            sys_col.name: sys_col.type
+            for sys_col in self.schema.dataset_row_cls.sys_columns()
+        }
+        regenerate = set(regenerate_columns or system_columns)
+        generators = {
+            "sys__id": self._system_row_number_expr,
+            "sys__rand": self._system_random_expr,
+        }
+
+        base = cast("FromClause", selectable.subquery())
+
+        def build(name: str) -> sa.ColumnElement:
+            expr = generators[name]()
+            return sa.cast(expr, system_columns[name]).label(name)
+
+        columns: list[sa.ColumnElement] = []
+        present: set[str] = set()
+        changed = False
+
+        for col in base.c:
+            present.add(col.name)
+            regen = col.name in regenerate and not keep_existing_columns
+            columns.append(build(col.name) if regen else col)
+            changed |= regen
+
+        for name in regenerate - present:
+            columns.append(build(name))
+            changed = True
+
+        if not changed:
+            return selectable
+
+        inner = sa.select(*columns).select_from(base).subquery()
+        return sa.select(*inner.c).select_from(inner)
+
+    def _system_row_number_expr(self):
+        """Return an expression that produces deterministic row numbers."""
+
+        raise NotImplementedError
+
+    def _system_random_expr(self):
+        """Return an expression that produces deterministic random values."""
+
+        raise NotImplementedError
+
     #
     # Table Name Internal Functions
     #
```
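`_regenerate_system_columns` wraps the incoming SELECT in a subquery and swaps `sys__id`/`sys__rand` for freshly generated expressions, but the two generator hooks are left for concrete warehouses to implement. A hypothetical sketch of what an implementation might return; the real SQLite/PostgreSQL expressions live in the concrete warehouse classes and may differ:

```python
import sqlalchemy as sa


def _system_row_number_expr() -> sa.ColumnElement:
    # sequential ids for the rewritten table
    return sa.func.row_number().over()


def _system_random_expr() -> sa.ColumnElement:
    # a non-negative random value; exact expression is dialect-specific
    return sa.func.abs(sa.func.random())
```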
```diff
@@ -246,12 +349,24 @@ class AbstractWarehouse(ABC, Serializable):
         name = parsed.path if parsed.scheme == "file" else parsed.netloc
         return parsed.scheme, name
 
-    def dataset_table_name(self,
+    def dataset_table_name(self, dataset: DatasetRecord, version: str) -> str:
+        return self._construct_dataset_table_name(
+            dataset.project.namespace.name,
+            dataset.project.name,
+            dataset.name,
+            version,
+        )
+
+    def _construct_dataset_table_name(
+        self, namespace: str, project: str, dataset_name: str, version: str
+    ) -> str:
         prefix = self.DATASET_TABLE_PREFIX
         if Client.is_data_source_uri(dataset_name):
             # for datasets that are created for bucket listing we use different prefix
             prefix = self.DATASET_SOURCE_TABLE_PREFIX
-        return
+        return (
+            f"{prefix}{namespace}_{project}_{dataset_name}_{version.replace('.', '_')}"
+        )
 
     def temp_table_name(self) -> str:
         return self.TMP_TABLE_NAME_PREFIX + _random_string(6)
```
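Dataset table names now embed the namespace and project, and `version` is a semver string with dots mapped to underscores; this is also why `rename_dataset_tables` appears further down, since renaming a namespace or project changes every version's table name. The scheme evaluated inline with illustrative values (`ds_` stands in for `DATASET_TABLE_PREFIX`):

```python
prefix, namespace, project = "ds_", "dev", "animals"
dataset_name, version = "cats", "1.2.3"

name = f"{prefix}{namespace}_{project}_{dataset_name}_{version.replace('.', '_')}"
print(name)  # ds_dev_animals_cats_1_2_3
```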
```diff
@@ -269,38 +384,26 @@ class AbstractWarehouse(ABC, Serializable):
         name: str,
         columns: Sequence["sa.Column"] = (),
         if_not_exists: bool = True,
-    ) -> Table:
+    ) -> sa.Table:
         """Creates a dataset rows table for the given dataset name and columns"""
 
     def drop_dataset_rows_table(
         self,
         dataset: DatasetRecord,
-        version:
+        version: str,
         if_exists: bool = True,
     ) -> None:
         """Drops a dataset rows table for the given dataset name."""
-        table_name = self.dataset_table_name(dataset
-        table = Table(table_name, self.db.metadata)
+        table_name = self.dataset_table_name(dataset, version)
+        table = sa.Table(table_name, self.db.metadata)
         self.db.drop_table(table, if_exists=if_exists)
-
-
-
-        self,
-        src: "DatasetRecord",
-        dst: "DatasetRecord",
-        src_version: int,
-        dst_version: int,
-    ) -> None:
-        """
-        Merges source dataset rows and current latest destination dataset rows
-        into a new rows table created for new destination dataset version.
-        Note that table for new destination version must be created upfront.
-        Merge results should not contain duplicates.
-        """
+        # Remove from metadata cache to allow recreation
+        if table_name in self.db.metadata.tables:
+            self.db.metadata.remove(self.db.metadata.tables[table_name])
 
     def dataset_rows_select(
         self,
-        query: sa.
+        query: sa.Select,
         **kwargs,
     ) -> Iterator[tuple[Any, ...]]:
         """
@@ -311,51 +414,81 @@ class AbstractWarehouse(ABC, Serializable):
             query.selected_columns, rows, self.db.dialect
         )
 
+    def dataset_rows_select_from_ids(
+        self,
+        query: sa.Select,
+        ids: Iterable[RowsOutput],
+        is_batched: bool,
+    ) -> Iterator[RowsOutput]:
+        """
+        Fetch dataset rows from database using a list of IDs.
+        """
+        if (id_col := query.selected_columns.get("sys__id")) is None:
+            raise RuntimeError("sys__id column not found in query")
+
+        query = query._clone().offset(None).limit(None).order_by(None)
+
+        if is_batched:
+            for batch in ids:
+                yield list(self.dataset_rows_select(query.where(id_col.in_(batch))))
+        else:
+            yield from self.dataset_rows_select(query.where(id_col.in_(ids)))
+
     @abstractmethod
     def get_dataset_sources(
-        self, dataset: DatasetRecord, version:
+        self, dataset: DatasetRecord, version: str
     ) -> list[StorageURI]: ...
 
-    def
-        self,
-        old_name: str,
-        new_name: str,
-        old_version: int,
-        new_version: int,
+    def rename_dataset_tables(
+        self, dataset: DatasetRecord, dataset_updated: DatasetRecord
     ) -> None:
-
-
-
-
+        """
+        Renames all dataset version tables when parts of the dataset that
+        are used in constructing table name are updated.
+        If nothing important is changed, nothing will be renamed (no DB calls
+        will be made at all).
+        """
+        for version in [v.version for v in dataset_updated.versions]:
+            if not dataset.has_version(version):
+                continue
+            src = self.dataset_table_name(dataset, version)
+            dest = self.dataset_table_name(dataset_updated, version)
+            if src == dest:
+                continue
+            self.db.rename_table(src, dest)
 
     def dataset_rows_count(self, dataset: DatasetRecord, version=None) -> int:
         """Returns total number of rows in a dataset"""
         dr = self.dataset_rows(dataset, version)
         table = dr.get_table()
-        query = select(sa.func.count(table.c.sys__id))
+        query = sa.select(sa.func.count(table.c.sys__id))
         (res,) = self.db.execute(query)
         return res[0]
 
     def dataset_stats(
-        self, dataset: DatasetRecord, version:
-    ) -> tuple[
+        self, dataset: DatasetRecord, version: str
+    ) -> tuple[int | None, int | None]:
         """
         Returns tuple with dataset stats: total number of rows and total dataset size.
         """
-        if not (self.db.has_table(self.dataset_table_name(dataset
+        if not (self.db.has_table(self.dataset_table_name(dataset, version))):
             return None, None
 
+        file_signals = list(
+            SignalSchema.deserialize(dataset.feature_schema).get_signals(File)
+        )
+
         dr = self.dataset_rows(dataset, version)
         table = dr.get_table()
         expressions: tuple[_ColumnsClauseArgument[Any], ...] = (
             sa.func.count(table.c.sys__id),
         )
-
-
-
+        size_column_names = [ColumnMeta.to_db_name(s) + "__size" for s in file_signals]
+        size_columns = [c for c in table.columns if c.name in size_column_names]
+
         if size_columns:
             expressions = (*expressions, sa.func.sum(sum(size_columns)))
-        query = select(*expressions)
+        query = sa.select(*expressions)
         ((nrows, *rest),) = self.db.execute(query)
         return nrows, rest[0] if rest else 0
 
```
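`dataset_rows_select_from_ids` strips ordering and paging from the base query, then re-filters it by `sys__id`, optionally one batch at a time. A runnable sketch of that chunked `IN`-clause pattern; the engine, table, and batches here are stand-ins, not warehouse internals:

```python
import sqlalchemy as sa

engine = sa.create_engine("sqlite://")
meta = sa.MetaData()
rows = sa.Table(
    "rows",
    meta,
    sa.Column("sys__id", sa.Integer, primary_key=True),
    sa.Column("path", sa.String),
)
meta.create_all(engine)

with engine.begin() as conn:
    conn.execute(sa.insert(rows), [{"sys__id": i, "path": f"f{i}"} for i in range(10)])
    # mirror the method: drop ordering/paging, then re-filter by id batch
    base = sa.select(rows).offset(None).limit(None).order_by(None)
    for batch in [[1, 2, 3], [7, 8]]:
        fetched = conn.execute(base.where(rows.c.sys__id.in_(batch))).all()
        print([r.path for r in fetched])  # ['f1', 'f2', 'f3'] then ['f7', 'f8']
```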
```diff
@@ -364,17 +497,22 @@ class AbstractWarehouse(ABC, Serializable):
         """Convert File entries so they can be passed on to `insert_rows()`"""
 
     @abstractmethod
-    def insert_rows(
+    def insert_rows(
+        self,
+        table: sa.Table,
+        rows: Iterable[dict[str, Any]],
+        batch_size: int = INSERT_BATCH_SIZE,
+    ) -> None:
         """Does batch inserts of any kind of rows into table"""
 
-    def insert_rows_done(self, table: Table) -> None:
+    def insert_rows_done(self, table: sa.Table) -> None:
         """
         Only needed for certain implementations
         to signal when rows inserts are complete.
         """
 
     @abstractmethod
-    def insert_dataset_rows(self, df, dataset: DatasetRecord, version:
+    def insert_dataset_rows(self, df, dataset: DatasetRecord, version: str) -> int:
         """Inserts dataset rows directly into dataset table"""
 
     @abstractmethod
@@ -393,7 +531,7 @@ class AbstractWarehouse(ABC, Serializable):
 
     @abstractmethod
     def dataset_table_export_file_names(
-        self, dataset: DatasetRecord, version:
+        self, dataset: DatasetRecord, version: str
     ) -> list[str]:
         """
         Returns list of file names that will be created when user runs dataset export
@@ -404,7 +542,7 @@ class AbstractWarehouse(ABC, Serializable):
         self,
         bucket_uri: str,
         dataset: DatasetRecord,
-        version:
+        version: str,
         client_config=None,
     ) -> list[str]:
         """
@@ -454,7 +592,7 @@ class AbstractWarehouse(ABC, Serializable):
         dr = dataset_rows
         columns = [c.name for c in query.selected_columns]
         for row in self.db.execute(query):
-            d = dict(zip(columns, row))
+            d = dict(zip(columns, row, strict=False))
             yield Node(**{dr.without_object(k): v for k, v in d.items()})
 
     def get_dirs_by_parent_path(
@@ -478,7 +616,7 @@ class AbstractWarehouse(ABC, Serializable):
         dataset_rows: "DataTable",
         path_list: list[str],
         glob_name: str,
-
+        column="file",
     ) -> Iterator[Node]:
         """Finds all Nodes that correspond to GLOB like path pattern."""
         dr = dataset_rows
@@ -488,7 +626,7 @@ class AbstractWarehouse(ABC, Serializable):
         ).subquery()
         path_glob = "/".join([*path_list, glob_name])
         dirpath = path_glob[: -len(glob_name)]
-        relpath = func.substr(de.c(q, "path"), len(dirpath) + 1)
+        relpath = sa.func.substr(de.c(q, "path"), len(dirpath) + 1)
 
         return self.get_nodes(
             self.expand_query(de, q, dr)
@@ -512,7 +650,7 @@ class AbstractWarehouse(ABC, Serializable):
         de = dr.dir_expansion()
         q = de.query(
             dr.select().where(dr.c("is_latest") == true()).subquery(),
-
+            column=dr.column,
         ).subquery()
         q = self.expand_query(de, q, dr)
 
@@ -575,25 +713,23 @@ class AbstractWarehouse(ABC, Serializable):
             default = getattr(
                 attrs.fields(Node), dr.without_object(column.name)
             ).default
-            return func.coalesce(column, default).label(column.name)
+            return sa.func.coalesce(column, default).label(column.name)
 
         return sa.select(
             q.c.sys__id,
-            case(
-
-            ),
+            sa.case(
+                (de.c(q, "is_dir") == true(), DirType.DIR), else_=DirType.FILE
+            ).label(dr.col_name("dir_type")),
             de.c(q, "path"),
             with_default(dr.c("etag")),
             de.c(q, "version"),
             with_default(dr.c("is_latest")),
             dr.c("last_modified"),
             with_default(dr.c("size")),
-            with_default(dr.c("rand",
+            with_default(dr.c("rand", column="sys")),
             dr.c("location"),
             de.c(q, "source"),
-        ).select_from(
-            q.outerjoin(dr.table, q.c.sys__id == dr.c("id", object_name="sys"))
-        )
+        ).select_from(q.outerjoin(dr.table, q.c.sys__id == dr.c("id", column="sys")))
 
     def get_node_by_path(self, dataset_rows: "DataTable", path: str) -> Node:
         """Gets node that corresponds to some path"""
@@ -658,7 +794,7 @@ class AbstractWarehouse(ABC, Serializable):
             return de.c(inner_query, f)
 
         return self.db.execute(
-            select(*(field_to_expr(f) for f in fields)).order_by(
+            sa.select(*(field_to_expr(f) for f in fields)).order_by(
                 de.c(inner_query, "source"),
                 de.c(inner_query, "path"),
                 de.c(inner_query, "version"),
@@ -680,7 +816,7 @@ class AbstractWarehouse(ABC, Serializable):
             return dr.c(f)
 
         q = (
-            select(*(field_to_expr(f) for f in fields))
+            sa.select(*(field_to_expr(f) for f in fields))
             .where(
                 dr.c("path").like(f"{sql_escape_like(dirpath)}%"),
                 ~self.instr(pathfunc.name(dr.c("path")), "/"),
@@ -693,7 +829,7 @@ class AbstractWarehouse(ABC, Serializable):
     def size(
         self,
         dataset_rows: "DataTable",
-        node:
+        node: Node | dict[str, Any],
         count_files: bool = False,
     ) -> tuple[int, int]:
         """
@@ -715,10 +851,10 @@ class AbstractWarehouse(ABC, Serializable):
         sub_glob = posixpath.join(path, "*")
         dr = dataset_rows
         selections: list[sa.ColumnElement] = [
-            func.sum(dr.c("size")),
+            sa.func.sum(dr.c("size")),
         ]
         if count_files:
-            selections.append(func.count())
+            selections.append(sa.func.count())
         results = next(
             self.db.execute(
                 dr.select(*selections).where(
@@ -735,10 +871,10 @@ class AbstractWarehouse(ABC, Serializable):
         self,
         dataset_rows: "DataTable",
         parent_path: str,
-        fields:
-        type:
+        fields: Sequence[str] | None = None,
+        type: str | None = None,
         conds=None,
-        order_by:
+        order_by: str | list[str] | None = None,
         include_subobjects: bool = True,
     ) -> sa.Select:
         if not conds:
@@ -776,7 +912,7 @@ class AbstractWarehouse(ABC, Serializable):
         self,
         dataset_rows: "DataTable",
         node: Node,
-        sort:
+        sort: list[str] | str | None = None,
         include_subobjects: bool = True,
     ) -> Iterator[NodeWithPath]:
         """
@@ -834,28 +970,33 @@ class AbstractWarehouse(ABC, Serializable):
     def create_udf_table(
         self,
         columns: Sequence["sa.Column"] = (),
-        name:
-    ) ->
+        name: str | None = None,
+    ) -> sa.Table:
         """
         Create a temporary table for storing custom signals generated by a UDF.
         SQLite TEMPORARY tables cannot be directly used as they are process-specific,
         and UDFs are run in other processes when run in parallel.
         """
+        columns = [
+            c
+            for c in columns
+            if c.name not in [col.name for col in self.dataset_row_cls.sys_columns()]
+        ]
         tbl = sa.Table(
             name or self.udf_table_name(),
             sa.MetaData(),
-
+            *self.dataset_row_cls.sys_columns(),
             *columns,
         )
-        self.db.create_table(tbl, if_not_exists=True)
+        self.db.create_table(tbl, if_not_exists=True, kind="udf")
         return tbl
 
     @abstractmethod
     def copy_table(
         self,
-        table: Table,
-        query:
-        progress_cb:
+        table: sa.Table,
+        query: sa.Select,
+        progress_cb: Callable[[int], None] | None = None,
     ) -> None:
         """
         Copy the results of a query into a table.
@@ -868,13 +1009,15 @@ class AbstractWarehouse(ABC, Serializable):
         right: "_FromClauseArgument",
         onclause: "_OnClauseArgument",
         inner: bool = True,
-
+        full: bool = False,
+        columns=None,
+    ) -> sa.Select:
         """
         Join two tables together.
         """
 
     @abstractmethod
-    def create_pre_udf_table(self, query:
+    def create_pre_udf_table(self, query: sa.Select) -> sa.Table:
         """
         Create a temporary table from a query for use in a UDF.
         """
@@ -899,16 +1042,10 @@ class AbstractWarehouse(ABC, Serializable):
         are cleaned up as soon as they are no longer needed.
         """
        to_drop = set(names)
-
-
-        ) as pbar:
-            for name in to_drop:
-                self.db.drop_table(Table(name, self.db.metadata), if_exists=True)
-                pbar.update(1)
+        for name in to_drop:
+            self.db.drop_table(sa.Table(name, self.db.metadata), if_exists=True)
 
 
 def _random_string(length: int) -> str:
-
-
-        for i in range(length)
-    )
+    alphabet = string.ascii_letters + string.digits
+    return "".join(secrets.choice(alphabet) for _ in range(length))
```