datachain 0.7.5__py3-none-any.whl → 0.7.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of datachain might be problematic.
- datachain/catalog/catalog.py +99 -113
- datachain/catalog/loader.py +8 -65
- datachain/cli.py +148 -57
- datachain/data_storage/__init__.py +0 -3
- datachain/data_storage/metastore.py +2 -9
- datachain/data_storage/sqlite.py +7 -145
- datachain/data_storage/warehouse.py +1 -5
- datachain/dataset.py +15 -0
- datachain/func/__init__.py +2 -1
- datachain/func/func.py +7 -2
- datachain/lib/dc.py +4 -4
- datachain/lib/pytorch.py +1 -4
- datachain/query/dataset.py +0 -5
- datachain/query/dispatch.py +1 -13
- datachain/query/session.py +0 -1
- datachain/remote/studio.py +33 -1
- datachain/studio.py +80 -0
- {datachain-0.7.5.dist-info → datachain-0.7.7.dist-info}/METADATA +1 -1
- {datachain-0.7.5.dist-info → datachain-0.7.7.dist-info}/RECORD +23 -24
- datachain/data_storage/id_generator.py +0 -136
- {datachain-0.7.5.dist-info → datachain-0.7.7.dist-info}/LICENSE +0 -0
- {datachain-0.7.5.dist-info → datachain-0.7.7.dist-info}/WHEEL +0 -0
- {datachain-0.7.5.dist-info → datachain-0.7.7.dist-info}/entry_points.txt +0 -0
- {datachain-0.7.5.dist-info → datachain-0.7.7.dist-info}/top_level.txt +0 -0
datachain/cli.py
CHANGED

@@ -233,6 +233,67 @@ def add_studio_parser(subparsers, parent_parser) -> None:
         help="The team to list datasets for. By default, it will use team from config.",
     )
 
+    studio_run_help = "Run a job in Studio"
+    studio_run_description = "This command runs a job in Studio."
+
+    studio_run_parser = studio_subparser.add_parser(
+        "run",
+        parents=[parent_parser],
+        description=studio_run_description,
+        help=studio_run_help,
+    )
+
+    studio_run_parser.add_argument(
+        "query_file",
+        action="store",
+        help="The query file to run.",
+    )
+
+    studio_run_parser.add_argument(
+        "--team",
+        action="store",
+        default=None,
+        help="The team to run a job for. By default, it will use team from config.",
+    )
+    studio_run_parser.add_argument(
+        "--env-file",
+        action="store",
+        help="File containing environment variables to set for the job.",
+    )
+
+    studio_run_parser.add_argument(
+        "--env",
+        nargs="+",
+        help="Environment variable. Can be specified multiple times. Format: KEY=VALUE",
+    )
+
+    studio_run_parser.add_argument(
+        "--workers",
+        type=int,
+        help="Number of workers to use for the job.",
+    )
+    studio_run_parser.add_argument(
+        "--files",
+        nargs="+",
+        help="Files to include in the job.",
+    )
+    studio_run_parser.add_argument(
+        "--python-version",
+        action="store",
+        help="Python version to use for the job (e.g. '3.9', '3.10', '3.11').",
+    )
+    studio_run_parser.add_argument(
+        "--req-file",
+        action="store",
+        help="File containing Python package requirements.",
+    )
+
+    studio_run_parser.add_argument(
+        "--req",
+        nargs="+",
+        help="Python package requirement. Can be specified multiple times.",
+    )
+
 
 def get_parser() -> ArgumentParser:  # noqa: PLR0915
     try:
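The parser definitions above fully determine how a `studio run` invocation is tokenized. A minimal sketch, assuming only those definitions (the query file name, option values, and the top-level `studio` command name are illustrative assumptions, not part of this diff):

from datachain.cli import get_parser

# Build the CLI parser and parse a `studio run` invocation without executing it.
args = get_parser().parse_args(
    ["studio", "run", "query.py", "--workers", "2", "--env", "MY_VAR=1", "--req", "pandas"]
)
print(args.query_file, args.workers, args.env, args.req)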
@@ -358,7 +419,18 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
 
     add_studio_parser(subp, parent_parser)
 
-
+    datasets_parser = subp.add_parser(
+        "datasets",
+        aliases=["ds"],
+        parents=[parent_parser],
+        description="Commands for managing datasers",
+    )
+    datasets_subparser = datasets_parser.add_subparsers(
+        dest="datasets_cmd",
+        help="Use `datachain datasets CMD --help` to display command specific help",
+    )
+
+    parse_pull = datasets_subparser.add_parser(
         "pull",
         parents=[parent_parser],
         description="Pull specific dataset version from SaaS",

@@ -400,9 +472,21 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
         "--edatachain-file",
         help="Use a different filename for the resulting .edatachain file",
     )
+    parse_pull.add_argument(
+        "--local-name",
+        action="store",
+        default=None,
+        help="Name of the local dataset",
+    )
+    parse_pull.add_argument(
+        "--local-version",
+        action="store",
+        default=None,
+        help="Version of the local dataset",
+    )
 
-    parse_edit_dataset =
-        "edit
+    parse_edit_dataset = datasets_subparser.add_parser(
+        "edit", parents=[parent_parser], description="Edit dataset metadata"
     )
     parse_edit_dataset.add_argument("name", type=str, help="Dataset name")
     parse_edit_dataset.add_argument(

@@ -447,8 +531,8 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
         help="The team to edit a dataset. By default, it will use team from config.",
     )
 
-    datasets_parser =
-        "
+    datasets_parser = datasets_subparser.add_parser(
+        "ls", parents=[parent_parser], description="List datasets"
     )
     datasets_parser.add_argument(
         "--studio",

@@ -477,8 +561,8 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
         help="The team to list datasets for. By default, it will use team from config.",
     )
 
-    rm_dataset_parser =
-        "rm
+    rm_dataset_parser = datasets_subparser.add_parser(
+        "rm", parents=[parent_parser], description="Removes dataset", aliases=["remove"]
     )
     rm_dataset_parser.add_argument("name", type=str, help="Dataset name")
     rm_dataset_parser.add_argument(

@@ -521,8 +605,8 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
         help="The team to delete a dataset. By default, it will use team from config.",
     )
 
-    dataset_stats_parser =
-        "
+    dataset_stats_parser = datasets_subparser.add_parser(
+        "stats",
         parents=[parent_parser],
         description="Shows basic dataset stats",
     )
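A similar sketch for the regrouped dataset commands: the `datasets` group, its `ds` alias, and the `datasets_cmd` destination all come from the definitions above, and `--studio` is the flag shown on the `ls` subparser; nothing else is assumed.

from datachain.cli import get_parser

parser = get_parser()
# `datachain datasets ls --studio` and its `ds` alias resolve to the same subcommand.
for argv in (["datasets", "ls", "--studio"], ["ds", "ls", "--studio"]):
    args = parser.parse_args(argv)
    assert args.datasets_cmd == "ls"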
@@ -1203,27 +1287,59 @@ def main(argv: Optional[list[str]] = None) -> int:  # noqa: C901, PLR0912, PLR09
             edatachain=args.edatachain,
             edatachain_file=args.edatachain_file,
         )
-    elif args.command
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    elif args.command in ("datasets", "ds"):
+        if args.datasets_cmd == "pull":
+            catalog.pull_dataset(
+                args.dataset,
+                args.output,
+                local_ds_name=args.local_name,
+                local_ds_version=args.local_version,
+                no_cp=args.no_cp,
+                force=bool(args.force),
+                edatachain=args.edatachain,
+                edatachain_file=args.edatachain_file,
+            )
+        elif args.datasets_cmd == "edit":
+            edit_dataset(
+                catalog,
+                args.name,
+                new_name=args.new_name,
+                description=args.description,
+                labels=args.labels,
+                studio=args.studio,
+                local=args.local,
+                all=args.all,
+                team=args.team,
+            )
+        elif args.datasets_cmd == "ls":
+            datasets(
+                catalog=catalog,
+                studio=args.studio,
+                local=args.local,
+                all=args.all,
+                team=args.team,
+            )
+        elif args.datasets_cmd in ("rm", "remove"):
+            rm_dataset(
+                catalog,
+                args.name,
+                version=args.version,
+                force=args.force,
+                studio=args.studio,
+                local=args.local,
+                all=args.all,
+                team=args.team,
+            )
+        elif args.datasets_cmd == "stats":
+            dataset_stats(
+                catalog,
+                args.name,
+                args.version,
+                show_bytes=args.bytes,
+                si=args.si,
+            )
+        else:
+            raise Exception(f"Unexpected command {args.datasets_cmd}")
     elif args.command == "ls":
         ls(
             args.sources,

@@ -1235,14 +1351,7 @@ def main(argv: Optional[list[str]] = None) -> int:  # noqa: C901, PLR0912, PLR09
             update=bool(args.update),
             client_config=client_config,
         )
-
-        datasets(
-            catalog=catalog,
-            studio=args.studio,
-            local=args.local,
-            all=args.all,
-            team=args.team,
-        )
+
     elif args.command == "show":
         show(
             catalog,

@@ -1254,25 +1363,7 @@ def main(argv: Optional[list[str]] = None) -> int:  # noqa: C901, PLR0912, PLR09
             no_collapse=args.no_collapse,
             schema=args.schema,
         )
-
-        rm_dataset(
-            catalog,
-            args.name,
-            version=args.version,
-            force=args.force,
-            studio=args.studio,
-            local=args.local,
-            all=args.all,
-            team=args.team,
-        )
-    elif args.command == "dataset-stats":
-        dataset_stats(
-            catalog,
-            args.name,
-            args.version,
-            show_bytes=args.bytes,
-            si=args.si,
-        )
+
     elif args.command == "du":
         du(
             catalog,
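The dispatch above routes the former top-level dataset commands through the new `datasets` group. A hedged sketch of driving the entry point programmatically; it needs a working local catalog, so treat it as illustrative only:

from datachain.cli import main

# Equivalent of `datachain datasets ls` from a shell; returns the exit code.
exit_code = main(["datasets", "ls"])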
datachain/data_storage/__init__.py
CHANGED

@@ -1,12 +1,9 @@
-from .id_generator import AbstractDBIDGenerator, AbstractIDGenerator
 from .job import JobQueryType, JobStatus
 from .metastore import AbstractDBMetastore, AbstractMetastore
 from .warehouse import AbstractWarehouse
 
 __all__ = [
-    "AbstractDBIDGenerator",
     "AbstractDBMetastore",
-    "AbstractIDGenerator",
     "AbstractMetastore",
     "AbstractWarehouse",
     "JobQueryType",
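For code that imports from this package, the practical effect is that the ID-generator classes are gone from the public surface:

# Still exported after 0.7.7:
from datachain.data_storage import AbstractDBMetastore, AbstractMetastore, AbstractWarehouse

# Removed in 0.7.7; these imports now fail:
# from datachain.data_storage import AbstractDBIDGenerator, AbstractIDGenerator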
datachain/data_storage/metastore.py
CHANGED

@@ -45,7 +45,7 @@ if TYPE_CHECKING:
     from sqlalchemy import Delete, Insert, Select, Update
     from sqlalchemy.schema import SchemaItem
 
-    from datachain.data_storage import
+    from datachain.data_storage import schema
     from datachain.data_storage.db_engine import DatabaseEngine
 
 logger = logging.getLogger("datachain")

@@ -304,16 +304,10 @@ class AbstractDBMetastore(AbstractMetastore):
     DATASET_DEPENDENCY_TABLE = "datasets_dependencies"
     JOBS_TABLE = "jobs"
 
-    id_generator: "AbstractIDGenerator"
     db: "DatabaseEngine"
 
-    def __init__(
-        self,
-        id_generator: "AbstractIDGenerator",
-        uri: Optional[StorageURI] = None,
-    ):
+    def __init__(self, uri: Optional[StorageURI] = None):
         uri = uri or StorageURI("")
-        self.id_generator = id_generator
         super().__init__(uri)
 
     def close(self) -> None:

@@ -322,7 +316,6 @@ class AbstractDBMetastore(AbstractMetastore):
 
     def cleanup_tables(self, temp_table_names: list[str]) -> None:
         """Cleanup temp tables."""
-        self.id_generator.delete_uris(temp_table_names)
 
     @classmethod
     def _datasets_columns(cls) -> list["SchemaItem"]:
datachain/data_storage/sqlite.py
CHANGED

@@ -15,7 +15,6 @@ from typing import (
 )
 
 import sqlalchemy
-from packaging import version
 from sqlalchemy import MetaData, Table, UniqueConstraint, exists, select
 from sqlalchemy.dialects import sqlite
 from sqlalchemy.schema import CreateIndex, CreateTable, DropTable

@@ -27,7 +26,6 @@ from tqdm import tqdm
 import datachain.sql.sqlite
 from datachain.data_storage import AbstractDBMetastore, AbstractWarehouse
 from datachain.data_storage.db_engine import DatabaseEngine
-from datachain.data_storage.id_generator import AbstractDBIDGenerator
 from datachain.data_storage.schema import DefaultSchema
 from datachain.dataset import DatasetRecord, StorageURI
 from datachain.error import DataChainError

@@ -275,123 +273,16 @@ class SQLiteDatabaseEngine(DatabaseEngine):
         self.execute_str(f"ALTER TABLE {comp_old_name} RENAME TO {comp_new_name}")
 
 
-class SQLiteIDGenerator(AbstractDBIDGenerator):
-    _db: "SQLiteDatabaseEngine"
-
-    def __init__(
-        self,
-        db: Optional["SQLiteDatabaseEngine"] = None,
-        table_prefix: Optional[str] = None,
-        skip_db_init: bool = False,
-        db_file: Optional[str] = None,
-        in_memory: bool = False,
-    ):
-        db_file = get_db_file_in_memory(db_file, in_memory)
-
-        db = db or SQLiteDatabaseEngine.from_db_file(db_file)
-
-        super().__init__(db, table_prefix, skip_db_init)
-
-    def clone(self) -> "SQLiteIDGenerator":
-        """Clones SQLiteIDGenerator implementation."""
-        return SQLiteIDGenerator(
-            self._db.clone(), self._table_prefix, skip_db_init=True
-        )
-
-    def clone_params(self) -> tuple[Callable[..., Any], list[Any], dict[str, Any]]:
-        """
-        Returns the function, args, and kwargs needed to instantiate a cloned copy
-        of this SQLiteIDGenerator implementation, for use in separate processes
-        or machines.
-        """
-        return (
-            SQLiteIDGenerator.init_after_clone,
-            [],
-            {
-                "db_clone_params": self._db.clone_params(),
-                "table_prefix": self._table_prefix,
-            },
-        )
-
-    @classmethod
-    def init_after_clone(
-        cls,
-        *,
-        db_clone_params: tuple[Callable, list, dict[str, Any]],
-        table_prefix: Optional[str] = None,
-    ) -> "SQLiteIDGenerator":
-        """
-        Initializes a new instance of this SQLiteIDGenerator implementation
-        using the given parameters, which were obtained from a call to clone_params.
-        """
-        (db_class, db_args, db_kwargs) = db_clone_params
-        return cls(
-            db=db_class(*db_args, **db_kwargs),
-            table_prefix=table_prefix,
-            skip_db_init=True,
-        )
-
-    @property
-    def db(self) -> "SQLiteDatabaseEngine":
-        return self._db
-
-    def init_id(self, uri: str) -> None:
-        """Initializes the ID generator for the given URI with zero last_id."""
-        self._db.execute(
-            sqlite.insert(self._table)
-            .values(uri=uri, last_id=0)
-            .on_conflict_do_nothing()
-        )
-
-    def get_next_ids(self, uri: str, count: int) -> range:
-        """Returns a range of IDs for the given URI."""
-
-        sqlite_version = version.parse(sqlite3.sqlite_version)
-        is_returning_supported = sqlite_version >= version.parse("3.35.0")
-        if is_returning_supported:
-            stmt = (
-                sqlite.insert(self._table)
-                .values(uri=uri, last_id=count)
-                .on_conflict_do_update(
-                    index_elements=["uri"],
-                    set_={"last_id": self._table.c.last_id + count},
-                )
-                .returning(self._table.c.last_id)
-            )
-            last_id = self._db.execute(stmt).fetchone()[0]
-        else:
-            # Older versions of SQLite are still the default under Ubuntu LTS,
-            # e.g. Ubuntu 20.04 LTS (Focal Fossa) uses 3.31.1
-            # Transactions ensure no concurrency conflicts
-            with self._db.transaction() as conn:
-                stmt_ins = (
-                    sqlite.insert(self._table)
-                    .values(uri=uri, last_id=count)
-                    .on_conflict_do_update(
-                        index_elements=["uri"],
-                        set_={"last_id": self._table.c.last_id + count},
-                    )
-                )
-                self._db.execute(stmt_ins, conn=conn)
-
-                stmt_sel = select(self._table.c.last_id).where(self._table.c.uri == uri)
-                last_id = self._db.execute(stmt_sel, conn=conn).fetchone()[0]
-
-        return range(last_id - count + 1, last_id + 1)
-
-
 class SQLiteMetastore(AbstractDBMetastore):
     """
     SQLite Metastore uses SQLite3 for storing indexed data locally.
     This is currently used for the local cli.
     """
 
-    id_generator: "SQLiteIDGenerator"
     db: "SQLiteDatabaseEngine"
 
     def __init__(
         self,
-        id_generator: "SQLiteIDGenerator",
         uri: Optional[StorageURI] = None,
         db: Optional["SQLiteDatabaseEngine"] = None,
         db_file: Optional[str] = None,

@@ -399,7 +290,7 @@ class SQLiteMetastore(AbstractDBMetastore):
     ):
         uri = uri or StorageURI("")
         self.schema: DefaultSchema = DefaultSchema()
-        super().__init__(
+        super().__init__(uri)
 
         # needed for dropping tables in correct order for tests because of
         # foreign keys

@@ -424,11 +315,7 @@ class SQLiteMetastore(AbstractDBMetastore):
         if not uri and self.uri:
            uri = self.uri
 
-        return SQLiteMetastore(
-            self.id_generator.clone(),
-            uri=uri,
-            db=self.db.clone(),
-        )
+        return SQLiteMetastore(uri=uri, db=self.db.clone())
 
     def clone_params(self) -> tuple[Callable[..., Any], list[Any], dict[str, Any]]:
         """

@@ -439,7 +326,6 @@ class SQLiteMetastore(AbstractDBMetastore):
             SQLiteMetastore.init_after_clone,
             [],
             {
-                "id_generator_clone_params": self.id_generator.clone_params(),
                 "uri": self.uri,
                 "db_clone_params": self.db.clone_params(),
             },

@@ -449,21 +335,11 @@ class SQLiteMetastore(AbstractDBMetastore):
     def init_after_clone(
         cls,
         *,
-        id_generator_clone_params: tuple[Callable, list, dict[str, Any]],
         uri: StorageURI,
         db_clone_params: tuple[Callable, list, dict[str, Any]],
     ) -> "SQLiteMetastore":
-        (
-            id_generator_class,
-            id_generator_args,
-            id_generator_kwargs,
-        ) = id_generator_clone_params
         (db_class, db_args, db_kwargs) = db_clone_params
-        return cls(
-            id_generator=id_generator_class(*id_generator_args, **id_generator_kwargs),
-            uri=uri,
-            db=db_class(*db_args, **db_kwargs),
-        )
+        return cls(uri=uri, db=db_class(*db_args, **db_kwargs))
 
     def _init_tables(self) -> None:
         """Initialize tables."""

@@ -518,7 +394,6 @@ class SQLiteWarehouse(AbstractWarehouse):
     This is currently used for the local cli.
     """
 
-    id_generator: "SQLiteIDGenerator"
     db: "SQLiteDatabaseEngine"
 
     # Cache for our defined column types to dialect specific TypeEngine relations

@@ -526,13 +401,12 @@ class SQLiteWarehouse(AbstractWarehouse):
 
     def __init__(
         self,
-        id_generator: "SQLiteIDGenerator",
         db: Optional["SQLiteDatabaseEngine"] = None,
         db_file: Optional[str] = None,
         in_memory: bool = False,
     ):
         self.schema: DefaultSchema = DefaultSchema()
-        super().__init__(
+        super().__init__()
 
         db_file = get_db_file_in_memory(db_file, in_memory)
 

@@ -543,7 +417,7 @@ class SQLiteWarehouse(AbstractWarehouse):
         self.close()
 
     def clone(self, use_new_connection: bool = False) -> "SQLiteWarehouse":
-        return SQLiteWarehouse(
+        return SQLiteWarehouse(db=self.db.clone())
 
     def clone_params(self) -> tuple[Callable[..., Any], list[Any], dict[str, Any]]:
         """

@@ -553,29 +427,17 @@ class SQLiteWarehouse(AbstractWarehouse):
         return (
             SQLiteWarehouse.init_after_clone,
             [],
-            {
-                "id_generator_clone_params": self.id_generator.clone_params(),
-                "db_clone_params": self.db.clone_params(),
-            },
+            {"db_clone_params": self.db.clone_params()},
         )
 
     @classmethod
     def init_after_clone(
         cls,
         *,
-        id_generator_clone_params: tuple[Callable, list, dict[str, Any]],
         db_clone_params: tuple[Callable, list, dict[str, Any]],
     ) -> "SQLiteWarehouse":
-        (
-            id_generator_class,
-            id_generator_args,
-            id_generator_kwargs,
-        ) = id_generator_clone_params
         (db_class, db_args, db_kwargs) = db_clone_params
-        return cls(
-            id_generator=id_generator_class(*id_generator_args, **id_generator_kwargs),
-            db=db_class(*db_args, **db_kwargs),
-        )
+        return cls(db=db_class(*db_args, **db_kwargs))
 
     def _reflect_tables(self, filter_tables=None):
         """
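A minimal sketch of the new constructor shapes, assuming nothing beyond the signatures visible in this diff; previously both classes required a SQLiteIDGenerator as their first argument:

from datachain.data_storage.sqlite import SQLiteMetastore, SQLiteWarehouse

# 0.7.7: no id_generator argument; the engine comes from db / db_file / in_memory.
metastore = SQLiteMetastore()                # uses the default local DB file (illustrative)
warehouse = SQLiteWarehouse(in_memory=True)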
datachain/data_storage/warehouse.py
CHANGED

@@ -34,7 +34,7 @@ if TYPE_CHECKING:
     from sqlalchemy.sql.selectable import Join, Select
     from sqlalchemy.types import TypeEngine
 
-    from datachain.data_storage import
+    from datachain.data_storage import schema
     from datachain.data_storage.db_engine import DatabaseEngine
     from datachain.data_storage.schema import DataTable
     from datachain.lib.file import File

@@ -69,13 +69,9 @@ class AbstractWarehouse(ABC, Serializable):
     UDF_TABLE_NAME_PREFIX = "udf_"
     TMP_TABLE_NAME_PREFIX = "tmp_"
 
-    id_generator: "AbstractIDGenerator"
     schema: "schema.Schema"
     db: "DatabaseEngine"
 
-    def __init__(self, id_generator: "AbstractIDGenerator"):
-        self.id_generator = id_generator
-
     def __enter__(self) -> "AbstractWarehouse":
         return self
 
datachain/dataset.py
CHANGED

@@ -488,6 +488,18 @@ class DatasetRecord:
             if v.version == version
         )
 
+    def get_version_by_uuid(self, uuid: str) -> DatasetVersion:
+        try:
+            return next(
+                v
+                for v in self.versions  # type: ignore [union-attr]
+                if v.uuid == uuid
+            )
+        except StopIteration:
+            raise DatasetVersionNotFoundError(
+                f"Dataset {self.name} does not have version with uuid {uuid}"
+            ) from None
+
     def remove_version(self, version: int) -> None:
         if not self.versions or not self.has_version(version):
             return

@@ -635,6 +647,9 @@ class DatasetListRecord:
             LISTING_PREFIX
         )
 
+    def has_version_with_uuid(self, uuid: str) -> bool:
+        return any(v.uuid == uuid for v in self.versions)
+
 
 class RowDict(dict):
     pass
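A short sketch of the new uuid lookups; `record` and `listing` stand in for an existing DatasetRecord and DatasetListRecord obtained elsewhere (e.g. from the catalog), and the uuid value is made up:

uuid = "2f9d4c1e-6b7a-4f08-9c3d-5a1e2b3c4d5e"  # illustrative

version = record.get_version_by_uuid(uuid)   # DatasetVersion; raises DatasetVersionNotFoundError if absent
has_it = listing.has_version_with_uuid(uuid)  # bool, no exception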
datachain/func/__init__.py
CHANGED

@@ -1,4 +1,4 @@
-from sqlalchemy import literal
+from sqlalchemy import case, literal
 
 from . import array, path, random, string
 from .aggregate import (

@@ -24,6 +24,7 @@ __all__ = [
     "any_value",
     "array",
     "avg",
+    "case",
     "collect",
     "concat",
     "cosine_distance",
datachain/func/func.py
CHANGED

@@ -2,9 +2,11 @@ import inspect
 from collections.abc import Sequence
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
-from sqlalchemy import BindParameter, ColumnElement, desc
+from sqlalchemy import BindParameter, Case, ColumnElement, desc
+from sqlalchemy.ext.hybrid import Comparator
 
 from datachain.lib.convert.python_to_sql import python_to_sql
+from datachain.lib.convert.sql_to_python import sql_to_python
 from datachain.lib.utils import DataChainColumnError, DataChainParamsError
 from datachain.query.schema import Column, ColumnMeta
 

@@ -71,7 +73,7 @@ class Func(Function):
         return (
             [
                 col
-                if isinstance(col, (Func, BindParameter))
+                if isinstance(col, (Func, BindParameter, Case, Comparator))
                 else ColumnMeta.to_db_name(
                     col.name if isinstance(col, ColumnElement) else col
                 )

@@ -273,6 +275,9 @@ def get_db_col_type(signals_schema: "SignalSchema", col: ColT) -> "DataType":
    if isinstance(col, Func):
        return col.get_result_type(signals_schema)
 
+    if isinstance(col, ColumnElement) and not hasattr(col, "name"):
+        return sql_to_python(col)
+
    return signals_schema.get_column_type(
        col.name if isinstance(col, ColumnElement) else col
    )
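The re-exported `case` (previous file) is SQLAlchemy's construct; the changes above let such expressions, and hybrid comparators, pass through Func.cols and have a result type inferred via sql_to_python. A hedged sketch with an illustrative column name:

from datachain.func import case
from datachain.query.schema import Column

# 1 for rows whose "size" column exceeds 1 MB, 0 otherwise; usable wherever a
# column expression is accepted.
is_large = case((Column("size") > 1_000_000, 1), else_=0)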
datachain/lib/dc.py
CHANGED

@@ -1150,7 +1150,7 @@ class DataChain:
     def group_by(
         self,
         *,
-        partition_by: Union[str, Func, Sequence[Union[str, Func]]],
+        partition_by: Optional[Union[str, Func, Sequence[Union[str, Func]]]] = None,
         **kwargs: Func,
     ) -> "Self":
         """Group rows by specified set of signals and return new signals

@@ -1167,10 +1167,10 @@ class DataChain:
         )
         ```
         """
-        if
+        if partition_by is None:
+            partition_by = []
+        elif isinstance(partition_by, (str, Func)):
             partition_by = [partition_by]
-        if not partition_by:
-            raise ValueError("At least one column should be provided for partition_by")
 
         partition_by_columns: list[Column] = []
         signal_columns: list[Column] = []