datachain 0.16.4__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datachain/catalog/catalog.py +25 -92
- datachain/cli/__init__.py +11 -9
- datachain/cli/commands/datasets.py +1 -1
- datachain/cli/commands/query.py +1 -0
- datachain/cli/commands/show.py +1 -1
- datachain/cli/parser/__init__.py +11 -3
- datachain/data_storage/job.py +1 -0
- datachain/data_storage/metastore.py +105 -94
- datachain/data_storage/sqlite.py +8 -7
- datachain/data_storage/warehouse.py +58 -46
- datachain/dataset.py +88 -45
- datachain/lib/arrow.py +23 -1
- datachain/lib/dataset_info.py +2 -1
- datachain/lib/dc/csv.py +1 -0
- datachain/lib/dc/datachain.py +38 -16
- datachain/lib/dc/datasets.py +28 -7
- datachain/lib/dc/storage.py +10 -2
- datachain/lib/listing.py +2 -0
- datachain/lib/pytorch.py +2 -2
- datachain/lib/udf.py +17 -5
- datachain/listing.py +1 -1
- datachain/query/batch.py +40 -39
- datachain/query/dataset.py +42 -41
- datachain/query/dispatch.py +137 -75
- datachain/query/metrics.py +1 -2
- datachain/query/queue.py +1 -11
- datachain/query/session.py +2 -2
- datachain/query/udf.py +1 -1
- datachain/query/utils.py +8 -14
- datachain/remote/studio.py +4 -4
- datachain/semver.py +58 -0
- datachain/studio.py +1 -1
- datachain/utils.py +3 -0
- {datachain-0.16.4.dist-info → datachain-0.17.0.dist-info}/METADATA +1 -1
- {datachain-0.16.4.dist-info → datachain-0.17.0.dist-info}/RECORD +39 -38
- {datachain-0.16.4.dist-info → datachain-0.17.0.dist-info}/WHEEL +1 -1
- {datachain-0.16.4.dist-info → datachain-0.17.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.16.4.dist-info → datachain-0.17.0.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.16.4.dist-info → datachain-0.17.0.dist-info}/top_level.txt +0 -0
datachain/data_storage/warehouse.py
CHANGED

@@ -11,16 +11,15 @@ from urllib.parse import urlparse
 
 import attrs
 import sqlalchemy as sa
-from sqlalchemy import Table, case, select
-from sqlalchemy.sql import func
 from sqlalchemy.sql.expression import true
-from tqdm.auto import tqdm
 
 from datachain.client import Client
 from datachain.data_storage.schema import convert_rows_custom_column_types
 from datachain.data_storage.serializer import Serializable
 from datachain.dataset import DatasetRecord, StorageURI
 from datachain.node import DirType, DirTypeGroup, Node, NodeWithPath, get_path
+from datachain.query.batch import RowsOutput
+from datachain.query.utils import get_query_id_column
 from datachain.sql.functions import path as pathfunc
 from datachain.sql.types import Int, SQLType
 from datachain.utils import sql_escape_like

@@ -31,7 +30,6 @@ if TYPE_CHECKING:
         _FromClauseArgument,
         _OnClauseArgument,
     )
-    from sqlalchemy.sql.selectable import Select
     from sqlalchemy.types import TypeEngine
 
     from datachain.data_storage import schema

@@ -178,7 +176,7 @@ class AbstractWarehouse(ABC, Serializable):
     def dataset_rows(
         self,
         dataset: DatasetRecord,
-        version: Optional[
+        version: Optional[str] = None,
         column: str = "file",
     ):
         version = version or dataset.latest_version

@@ -199,13 +197,13 @@ class AbstractWarehouse(ABC, Serializable):
     # Query Execution
     #
 
-    def query_count(self, query: sa.
+    def query_count(self, query: sa.Select) -> int:
         """Count the number of rows in a query."""
-        count_query = sa.select(func.count(1)).select_from(query.subquery())
+        count_query = sa.select(sa.func.count(1)).select_from(query.subquery())
         return next(self.db.execute(count_query))[0]
 
     def table_rows_count(self, table) -> int:
-        count_query = sa.select(func.count(1)).select_from(table)
+        count_query = sa.select(sa.func.count(1)).select_from(table)
         return next(self.db.execute(count_query))[0]
 
     def dataset_select_paginated(

@@ -255,7 +253,7 @@ class AbstractWarehouse(ABC, Serializable):
         name = parsed.path if parsed.scheme == "file" else parsed.netloc
         return parsed.scheme, name
 
-    def dataset_table_name(self, dataset_name: str, version:
+    def dataset_table_name(self, dataset_name: str, version: str) -> str:
         prefix = self.DATASET_TABLE_PREFIX
         if Client.is_data_source_uri(dataset_name):
             # for datasets that are created for bucket listing we use different prefix

@@ -278,18 +276,18 @@ class AbstractWarehouse(ABC, Serializable):
         name: str,
         columns: Sequence["sa.Column"] = (),
         if_not_exists: bool = True,
-    ) -> Table:
+    ) -> sa.Table:
         """Creates a dataset rows table for the given dataset name and columns"""
 
     def drop_dataset_rows_table(
         self,
         dataset: DatasetRecord,
-        version:
+        version: str,
         if_exists: bool = True,
     ) -> None:
         """Drops a dataset rows table for the given dataset name."""
         table_name = self.dataset_table_name(dataset.name, version)
-        table = Table(table_name, self.db.metadata)
+        table = sa.Table(table_name, self.db.metadata)
         self.db.drop_table(table, if_exists=if_exists)
 
     @abstractmethod

@@ -297,8 +295,8 @@ class AbstractWarehouse(ABC, Serializable):
         self,
         src: "DatasetRecord",
         dst: "DatasetRecord",
-        src_version:
-        dst_version:
+        src_version: str,
+        dst_version: str,
     ) -> None:
         """
         Merges source dataset rows and current latest destination dataset rows

@@ -309,7 +307,7 @@ class AbstractWarehouse(ABC, Serializable):
 
     def dataset_rows_select(
         self,
-        query: sa.
+        query: sa.Select,
         **kwargs,
     ) -> Iterator[tuple[Any, ...]]:
         """

@@ -320,17 +318,35 @@ class AbstractWarehouse(ABC, Serializable):
             query.selected_columns, rows, self.db.dialect
         )
 
+    def dataset_rows_select_from_ids(
+        self,
+        query: sa.Select,
+        ids: Iterable[RowsOutput],
+        is_batched: bool,
+    ) -> Iterator[RowsOutput]:
+        """
+        Fetch dataset rows from database using a list of IDs.
+        """
+        if (id_col := get_query_id_column(query)) is None:
+            raise RuntimeError("sys__id column not found in query")
+
+        if is_batched:
+            for batch in ids:
+                yield list(self.dataset_rows_select(query.where(id_col.in_(batch))))
+        else:
+            yield from self.dataset_rows_select(query.where(id_col.in_(ids)))
+
     @abstractmethod
     def get_dataset_sources(
-        self, dataset: DatasetRecord, version:
+        self, dataset: DatasetRecord, version: str
     ) -> list[StorageURI]: ...
 
     def rename_dataset_table(
         self,
         old_name: str,
         new_name: str,
-        old_version:
-        new_version:
+        old_version: str,
+        new_version: str,
     ) -> None:
         old_ds_table_name = self.dataset_table_name(old_name, old_version)
         new_ds_table_name = self.dataset_table_name(new_name, new_version)
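A minimal sketch of how the new dataset_rows_select_from_ids helper behaves, based only on the method body above; the warehouse, dataset, and process() names are placeholders, not part of this diff:

# Hypothetical caller of the new helper.
rows_query = warehouse.dataset_rows(dataset, version="1.0.0").select()

# With is_batched=True each yielded item is a list of row tuples for one batch
# of sys__id values; with is_batched=False rows would be yielded one by one.
for batch in warehouse.dataset_rows_select_from_ids(
    rows_query, ids=[[1, 2, 3], [4, 5]], is_batched=True
):
    process(batch)  # placeholder for per-batch UDF work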
@@ -341,12 +357,12 @@ class AbstractWarehouse(ABC, Serializable):
         """Returns total number of rows in a dataset"""
         dr = self.dataset_rows(dataset, version)
         table = dr.get_table()
-        query = select(sa.func.count(table.c.sys__id))
+        query = sa.select(sa.func.count(table.c.sys__id))
         (res,) = self.db.execute(query)
         return res[0]
 
     def dataset_stats(
-        self, dataset: DatasetRecord, version:
+        self, dataset: DatasetRecord, version: str
     ) -> tuple[Optional[int], Optional[int]]:
         """
         Returns tuple with dataset stats: total number of rows and total dataset size.

@@ -364,7 +380,7 @@ class AbstractWarehouse(ABC, Serializable):
         ]
         if size_columns:
             expressions = (*expressions, sa.func.sum(sum(size_columns)))
-        query = select(*expressions)
+        query = sa.select(*expressions)
         ((nrows, *rest),) = self.db.execute(query)
         return nrows, rest[0] if rest else 0
 

@@ -373,17 +389,17 @@ class AbstractWarehouse(ABC, Serializable):
         """Convert File entries so they can be passed on to `insert_rows()`"""
 
     @abstractmethod
-    def insert_rows(self, table: Table, rows: Iterable[dict[str, Any]]) -> None:
+    def insert_rows(self, table: sa.Table, rows: Iterable[dict[str, Any]]) -> None:
         """Does batch inserts of any kind of rows into table"""
 
-    def insert_rows_done(self, table: Table) -> None:
+    def insert_rows_done(self, table: sa.Table) -> None:
         """
         Only needed for certain implementations
         to signal when rows inserts are complete.
         """
 
     @abstractmethod
-    def insert_dataset_rows(self, df, dataset: DatasetRecord, version:
+    def insert_dataset_rows(self, df, dataset: DatasetRecord, version: str) -> int:
         """Inserts dataset rows directly into dataset table"""
 
     @abstractmethod

@@ -402,7 +418,7 @@ class AbstractWarehouse(ABC, Serializable):
 
     @abstractmethod
     def dataset_table_export_file_names(
-        self, dataset: DatasetRecord, version:
+        self, dataset: DatasetRecord, version: str
    ) -> list[str]:
         """
         Returns list of file names that will be created when user runs dataset export

@@ -413,7 +429,7 @@ class AbstractWarehouse(ABC, Serializable):
         self,
         bucket_uri: str,
         dataset: DatasetRecord,
-        version:
+        version: str,
         client_config=None,
     ) -> list[str]:
         """

@@ -497,7 +513,7 @@ class AbstractWarehouse(ABC, Serializable):
         ).subquery()
         path_glob = "/".join([*path_list, glob_name])
         dirpath = path_glob[: -len(glob_name)]
-        relpath = func.substr(de.c(q, "path"), len(dirpath) + 1)
+        relpath = sa.func.substr(de.c(q, "path"), len(dirpath) + 1)
 
         return self.get_nodes(
             self.expand_query(de, q, dr)

@@ -584,13 +600,13 @@ class AbstractWarehouse(ABC, Serializable):
             default = getattr(
                 attrs.fields(Node), dr.without_object(column.name)
             ).default
-            return func.coalesce(column, default).label(column.name)
+            return sa.func.coalesce(column, default).label(column.name)
 
         return sa.select(
             q.c.sys__id,
-            case(
-
-            ),
+            sa.case(
+                (de.c(q, "is_dir") == true(), DirType.DIR), else_=DirType.FILE
+            ).label(dr.col_name("dir_type")),
             de.c(q, "path"),
             with_default(dr.c("etag")),
             de.c(q, "version"),

@@ -665,7 +681,7 @@ class AbstractWarehouse(ABC, Serializable):
             return de.c(inner_query, f)
 
         return self.db.execute(
-            select(*(field_to_expr(f) for f in fields)).order_by(
+            sa.select(*(field_to_expr(f) for f in fields)).order_by(
                 de.c(inner_query, "source"),
                 de.c(inner_query, "path"),
                 de.c(inner_query, "version"),

@@ -687,7 +703,7 @@ class AbstractWarehouse(ABC, Serializable):
             return dr.c(f)
 
         q = (
-            select(*(field_to_expr(f) for f in fields))
+            sa.select(*(field_to_expr(f) for f in fields))
             .where(
                 dr.c("path").like(f"{sql_escape_like(dirpath)}%"),
                 ~self.instr(pathfunc.name(dr.c("path")), "/"),

@@ -722,10 +738,10 @@ class AbstractWarehouse(ABC, Serializable):
         sub_glob = posixpath.join(path, "*")
         dr = dataset_rows
         selections: list[sa.ColumnElement] = [
-            func.sum(dr.c("size")),
+            sa.func.sum(dr.c("size")),
         ]
         if count_files:
-            selections.append(func.count())
+            selections.append(sa.func.count())
         results = next(
             self.db.execute(
                 dr.select(*selections).where(

@@ -842,7 +858,7 @@ class AbstractWarehouse(ABC, Serializable):
         self,
         columns: Sequence["sa.Column"] = (),
         name: Optional[str] = None,
-    ) ->
+    ) -> sa.Table:
         """
         Create a temporary table for storing custom signals generated by a UDF.
         SQLite TEMPORARY tables cannot be directly used as they are process-specific,

@@ -860,8 +876,8 @@ class AbstractWarehouse(ABC, Serializable):
     @abstractmethod
     def copy_table(
         self,
-        table: Table,
-        query:
+        table: sa.Table,
+        query: sa.Select,
         progress_cb: Optional[Callable[[int], None]] = None,
     ) -> None:
         """

@@ -875,13 +891,13 @@ class AbstractWarehouse(ABC, Serializable):
         right: "_FromClauseArgument",
         onclause: "_OnClauseArgument",
         inner: bool = True,
-    ) ->
+    ) -> sa.Select:
         """
         Join two tables together.
         """
 
     @abstractmethod
-    def create_pre_udf_table(self, query:
+    def create_pre_udf_table(self, query: sa.Select) -> sa.Table:
         """
         Create a temporary table from a query for use in a UDF.
         """

@@ -906,12 +922,8 @@ class AbstractWarehouse(ABC, Serializable):
         are cleaned up as soon as they are no longer needed.
         """
         to_drop = set(names)
-
-
-        ) as pbar:
-            for name in to_drop:
-                self.db.drop_table(Table(name, self.db.metadata), if_exists=True)
-                pbar.update(1)
+        for name in to_drop:
+            self.db.drop_table(sa.Table(name, self.db.metadata), if_exists=True)
 
 
 def _random_string(length: int) -> str:
datachain/dataset.py
CHANGED
@@ -12,6 +12,7 @@ from typing import (
 )
 from urllib.parse import urlparse
 
+from datachain import semver
 from datachain.error import DatasetVersionNotFoundError
 from datachain.sql.types import NAME_TYPES_MAPPING, SQLType
 

@@ -25,6 +26,8 @@ DATASET_PREFIX = "ds://"
 QUERY_DATASET_PREFIX = "ds_query_"
 LISTING_PREFIX = "lst__"
 
+DEFAULT_DATASET_VERSION = "1.0.0"
+
 
 # StorageURI represents a normalised URI to a valid storage location (full bucket or
 # absolute local path).

@@ -33,12 +36,12 @@ LISTING_PREFIX = "lst__"
 StorageURI = NewType("StorageURI", str)
 
 
-def parse_dataset_uri(uri: str) -> tuple[str, Optional[
+def parse_dataset_uri(uri: str) -> tuple[str, Optional[str]]:
     """
     Parse dataser uri to extract name and version out of it (if version is defined)
     Example:
-        Input: ds://zalando@v3
-        Output: (zalando, 3)
+        Input: ds://zalando@v3.0.1
+        Output: (zalando, 3.0.1)
     """
     p = urlparse(uri)
     if p.scheme != "ds":

@@ -51,16 +54,15 @@ def parse_dataset_uri(uri: str) -> tuple[str, Optional[int]]:
         raise Exception(
             "Wrong dataset uri format, it should be: ds://<name>@v<version>"
         )
-
-    return name, version
+    return name, s[1]
 
 
-def create_dataset_uri(name: str, version: Optional[
+def create_dataset_uri(name: str, version: Optional[str] = None) -> str:
     """
     Creates a dataset uri based on dataset name and optionally version
     Example:
-        Input: zalando, 3
-        Output: ds//zalando@v3
+        Input: zalando, 3.0.1
+        Output: ds//zalando@v3.0.1
     """
     uri = f"{DATASET_PREFIX}{name}"
     if version:
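With versions now carried as semver strings, the URI helpers round-trip full "major.minor.patch" values. A small illustration based on the docstrings above:

from datachain.dataset import create_dataset_uri, parse_dataset_uri

uri = create_dataset_uri("zalando", "3.0.1")  # "ds://zalando@v3.0.1"
name, version = parse_dataset_uri(uri)        # ("zalando", "3.0.1")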
@@ -79,7 +81,7 @@ class DatasetDependency:
     id: int
     type: str
     name: str
-    version: str
+    version: str
     created_at: datetime
     dependencies: list[Optional["DatasetDependency"]]

@@ -102,7 +104,7 @@ class DatasetDependency:
         dataset_id: Optional[int],
         dataset_version_id: Optional[int],
         dataset_name: Optional[str],
-        dataset_version: Optional[
+        dataset_version: Optional[str],
         dataset_version_created_at: Optional[datetime],
     ) -> Optional["DatasetDependency"]:
         from datachain.client import Client

@@ -124,7 +126,7 @@ class DatasetDependency:
             dependency_type,
             dependency_name,
             (
-
+                dataset_version  # type: ignore[arg-type]
                 if dataset_version
                 else None
             ),

@@ -163,7 +165,7 @@ class DatasetVersion:
    id: int
    uuid: str
    dataset_id: int
-    version:
+    version: str
    status: int
    feature_schema: dict
    created_at: datetime

@@ -185,7 +187,7 @@ class DatasetVersion:
         id: int,
         uuid: str,
         dataset_id: int,
-        version:
+        version: str,
         status: int,
         feature_schema: Optional[str],
         created_at: datetime,

@@ -222,6 +224,10 @@ class DatasetVersion:
             job_id,
         )
 
+    @property
+    def version_value(self) -> int:
+        return semver.value(self.version)
+
     def __eq__(self, other):
         if not isinstance(other, DatasetVersion):
             return False

@@ -230,7 +236,7 @@ class DatasetVersion:
     def __lt__(self, other):
         if not isinstance(other, DatasetVersion):
             return False
-        return self.
+        return self.version_value < other.version_value
 
     def __hash__(self):
         return hash(f"{self.dataset_id}_{self.version}")

@@ -275,7 +281,7 @@ class DatasetListVersion:
    id: int
    uuid: str
    dataset_id: int
-    version:
+    version: str
    status: int
    created_at: datetime
    finished_at: Optional[datetime]

@@ -292,7 +298,7 @@ class DatasetListVersion:
         id: int,
         uuid: str,
         dataset_id: int,
-        version:
+        version: str,
         status: int,
         created_at: datetime,
         finished_at: Optional[datetime],

@@ -323,6 +329,10 @@ class DatasetListVersion:
     def __hash__(self):
         return hash(f"{self.dataset_id}_{self.version}")
 
+    @property
+    def version_value(self) -> int:
+        return semver.value(self.version)
+
 
 @dataclass
 class DatasetRecord:

@@ -371,7 +381,7 @@ class DatasetRecord:
         version_id: int,
         version_uuid: str,
         version_dataset_id: int,
-        version:
+        version: str,
         version_status: int,
         version_feature_schema: Optional[str],
         version_created_at: datetime,

@@ -441,7 +451,7 @@ class DatasetRecord:
             for c_name, c_type in self.schema.items()
         }
 
-    def get_schema(self, version:
+    def get_schema(self, version: str) -> dict[str, Union[SQLType, type[SQLType]]]:
         return self.get_version(version).schema if version else self.schema
 
     def update(self, **kwargs):

@@ -460,20 +470,23 @@ class DatasetRecord:
             self.versions = []
 
         self.versions = list(set(self.versions + other.versions))
-        self.versions.sort(key=lambda v: v.
+        self.versions.sort(key=lambda v: v.version_value)
         return self
 
-    def has_version(self, version:
-        return version in self.
+    def has_version(self, version: str) -> bool:
+        return version in [v.version for v in self.versions]
 
-    def is_valid_next_version(self, version:
+    def is_valid_next_version(self, version: str) -> bool:
         """
         Checks if a number can be a valid next latest version for dataset.
         The only rule is that it cannot be lower than current latest version
         """
-        return not (
+        return not (
+            self.latest_version
+            and semver.value(self.latest_version) >= semver.value(version)
+        )
 
-    def get_version(self, version:
+    def get_version(self, version: str) -> DatasetVersion:
         if not self.has_version(version):
             raise DatasetVersionNotFoundError(
                 f"Dataset {self.name} does not have version {version}"

@@ -496,15 +509,15 @@ class DatasetRecord:
                 f"Dataset {self.name} does not have version with uuid {uuid}"
             ) from None
 
-    def remove_version(self, version:
+    def remove_version(self, version: str) -> None:
         if not self.versions or not self.has_version(version):
             return
 
         self.versions = [v for v in self.versions if v.version != version]
 
-    def identifier(self, version:
+    def identifier(self, version: str) -> str:
         """
-        Get identifier in the form my-dataset@v3
+        Get identifier in the form my-dataset@v3.0.1
         """
         if not self.has_version(version):
             raise DatasetVersionNotFoundError(

@@ -512,43 +525,73 @@ class DatasetRecord:
             )
         return f"{self.name}@v{version}"
 
-    def uri(self, version:
+    def uri(self, version: str) -> str:
         """
-        Dataset uri example: ds://dogs@v3
+        Dataset uri example: ds://dogs@v3.0.1
         """
         identifier = self.identifier(version)
         return f"{DATASET_PREFIX}{identifier}"
 
     @property
-    def
+    def next_version_major(self) -> str:
         """
-
-        in self.versions attribute
+        Returns the next auto-incremented version if the major part is being bumped.
         """
         if not self.versions:
-            return
+            return "1.0.0"
 
-
+        major, minor, patch = semver.parse(self.latest_version)
+        return semver.create(major + 1, 0, 0)
 
     @property
-    def
-        """
+    def next_version_minor(self) -> str:
+        """
+        Returns the next auto-incremented version if the minor part is being bumped.
+        """
         if not self.versions:
-            return 1
-
+            return "1.0.0"
+
+        major, minor, patch = semver.parse(self.latest_version)
+        return semver.create(major, minor + 1, 0)
 
     @property
-    def
+    def next_version_patch(self) -> str:
+        """
+        Returns the next auto-incremented version if the patch part is being bumped.
+        """
+        if not self.versions:
+            return "1.0.0"
+
+        major, minor, patch = semver.parse(self.latest_version)
+        return semver.create(major, minor, patch + 1)
+
+    @property
+    def latest_version(self) -> str:
         """Returns latest version of a dataset"""
-        return max(self.
+        return max(self.versions).version
+
+    def latest_major_version(self, major: int) -> Optional[str]:
+        """
+        Returns latest specific major version, e.g if dataset has versions:
+        - 1.4.1
+        - 2.0.1
+        - 2.1.1
+        - 2.4.0
+        and we call `.latest_major_version(2)` it will return: "2.4.0".
+        If no major version is find with input value, None will be returned
+        """
+        versions = [v for v in self.versions if semver.parse(v.version)[0] == major]
+        if not versions:
+            return None
+        return max(versions).version
 
     @property
-    def prev_version(self) -> Optional[
+    def prev_version(self) -> Optional[str]:
         """Returns previous version of a dataset"""
         if len(self.versions) == 1:
             return None
 
-        return sorted(self.
+        return sorted(self.versions)[-2].version
 
     @classmethod
     def from_dict(cls, d: dict[str, Any]) -> "DatasetRecord":
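The version bumping above relies on the new datachain/semver.py module (listed in the summary, +58 lines). Its exact API is not shown in this diff, so the signatures below are inferred from usage: parse() splits a version string into (major, minor, patch), create() builds one, and value() maps a version to an integer used for ordering. A rough sketch for a dataset whose latest version is "2.4.0":

from datachain import semver

major, minor, patch = semver.parse("2.4.0")  # (2, 4, 0), inferred from usage above

semver.create(major + 1, 0, 0)          # "3.0.0" -> what next_version_major returns
semver.create(major, minor + 1, 0)      # "2.5.0" -> next_version_minor
semver.create(major, minor, patch + 1)  # "2.4.1" -> next_version_patch

# value() yields an orderable integer, used by DatasetVersion.version_value,
# __lt__ and is_valid_next_version for comparisons.
assert semver.value("2.4.0") > semver.value("1.9.9")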
@@ -577,7 +620,7 @@ class DatasetListRecord:
         version_id: int,
         version_uuid: str,
         version_dataset_id: int,
-        version:
+        version: str,
         version_status: int,
         version_created_at: datetime,
         version_finished_at: Optional[datetime],

@@ -626,11 +669,11 @@ class DatasetListRecord:
             self.versions = []
 
         self.versions = list(set(self.versions + other.versions))
-        self.versions.sort(key=lambda v: v.
+        self.versions.sort(key=lambda v: v.version_value)
         return self
 
     def latest_version(self) -> DatasetListVersion:
-        return max(self.versions, key=lambda v: v.
+        return max(self.versions, key=lambda v: v.version_value)
 
     @property
     def is_bucket_listing(self) -> bool:
datachain/lib/arrow.py
CHANGED
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Optional
 
 import orjson
 import pyarrow as pa
+from pyarrow._csv import ParseOptions
 from pyarrow.dataset import CsvFileFormat, dataset
 from tqdm.auto import tqdm
 

@@ -26,6 +27,18 @@ if TYPE_CHECKING:
 DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY = b"DataChain SignalSchema"
 
 
+def fix_pyarrow_format(format, parse_options=None):
+    # Re-init invalid row handler: https://issues.apache.org/jira/browse/ARROW-17641
+    if (
+        format
+        and isinstance(format, CsvFileFormat)
+        and parse_options
+        and isinstance(parse_options, ParseOptions)
+    ):
+        format.parse_options = parse_options
+    return format
+
+
 class ArrowGenerator(Generator):
     DEFAULT_BATCH_SIZE = 2**17  # same as `pyarrow._dataset._DEFAULT_BATCH_SIZE`
 

@@ -53,6 +66,7 @@ class ArrowGenerator(Generator):
         self.output_schema = output_schema
         self.source = source
         self.nrows = nrows
+        self.parse_options = kwargs.pop("parse_options", None)
         self.kwargs = kwargs
 
     def process(self, file: File):

@@ -64,7 +78,11 @@ class ArrowGenerator(Generator):
         else:
             fs, fs_path = file.get_fs(), file.get_path()
 
-
+        kwargs = self.kwargs
+        if format := kwargs.get("format"):
+            kwargs["format"] = fix_pyarrow_format(format, self.parse_options)
+
+        ds = dataset(fs_path, schema=self.input_schema, filesystem=fs, **kwargs)
 
         hf_schema = _get_hf_schema(ds.schema)
         use_datachain_schema = (

@@ -137,6 +155,10 @@ class ArrowGenerator(Generator):
 
 
 def infer_schema(chain: "DataChain", **kwargs) -> pa.Schema:
+    parse_options = kwargs.pop("parse_options", None)
+    if format := kwargs.get("format"):
+        kwargs["format"] = fix_pyarrow_format(format, parse_options)
+
     schemas = []
     for file in chain.collect("file"):
         ds = dataset(file.get_path(), filesystem=file.get_fs(), **kwargs)  # type: ignore[union-attr]