datachain 0.16.3__py3-none-any.whl → 0.16.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- datachain/catalog/catalog.py +5 -1
- datachain/cli/__init__.py +11 -9
- datachain/cli/commands/query.py +1 -0
- datachain/cli/parser/__init__.py +9 -1
- datachain/cli/parser/job.py +6 -1
- datachain/data_storage/job.py +1 -0
- datachain/data_storage/metastore.py +82 -71
- datachain/data_storage/warehouse.py +46 -34
- datachain/lib/arrow.py +23 -1
- datachain/lib/dc/csv.py +1 -0
- datachain/lib/dc/datachain.py +30 -13
- datachain/lib/listing.py +2 -0
- datachain/lib/udf.py +17 -5
- datachain/query/batch.py +40 -39
- datachain/query/dataset.py +33 -32
- datachain/query/dispatch.py +137 -75
- datachain/query/metrics.py +1 -2
- datachain/query/queue.py +1 -11
- datachain/query/udf.py +1 -1
- datachain/query/utils.py +8 -14
- datachain/remote/studio.py +2 -0
- datachain/studio.py +3 -0
- datachain/utils.py +3 -0
- {datachain-0.16.3.dist-info → datachain-0.16.5.dist-info}/METADATA +1 -1
- {datachain-0.16.3.dist-info → datachain-0.16.5.dist-info}/RECORD +29 -29
- {datachain-0.16.3.dist-info → datachain-0.16.5.dist-info}/WHEEL +1 -1
- {datachain-0.16.3.dist-info → datachain-0.16.5.dist-info}/entry_points.txt +0 -0
- {datachain-0.16.3.dist-info → datachain-0.16.5.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.16.3.dist-info → datachain-0.16.5.dist-info}/top_level.txt +0 -0
datachain/catalog/catalog.py
CHANGED
@@ -79,6 +79,7 @@ DATASET_INTERNAL_ERROR_MESSAGE = "Internal error on creating dataset"
 QUERY_SCRIPT_INVALID_LAST_STATEMENT_EXIT_CODE = 10
 # exit code we use if query script was canceled
 QUERY_SCRIPT_CANCELED_EXIT_CODE = 11
+QUERY_SCRIPT_SIGTERM_EXIT_CODE = -15  # if query script was terminated by SIGTERM
 
 # dataset pull
 PULL_DATASET_MAX_THREADS = 5
@@ -1645,7 +1646,10 @@ class Catalog:
         thread.join()  # wait for the reader thread
 
         logger.info("Process %s exited with return code %s", proc.pid, proc.returncode)
-        if proc.returncode == QUERY_SCRIPT_CANCELED_EXIT_CODE:
+        if proc.returncode in (
+            QUERY_SCRIPT_CANCELED_EXIT_CODE,
+            QUERY_SCRIPT_SIGTERM_EXIT_CODE,
+        ):
             raise QueryScriptCancelError(
                 "Query script was canceled by user",
                 return_code=proc.returncode,
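For context on the new constant: on POSIX systems, `subprocess` reports a child killed by a signal as a negative return code, so a query script terminated by SIGTERM (signal 15) surfaces as `returncode == -15`, which the updated check now treats the same as an explicit cancel. A minimal, self-contained illustration (not datachain code):

import signal
import subprocess
import sys

# Spawn a long-running child, send SIGTERM, and observe the return code.
proc = subprocess.Popen([sys.executable, "-c", "import time; time.sleep(60)"])
proc.send_signal(signal.SIGTERM)
proc.wait()
print(proc.returncode)  # -15 on POSIX, i.e. QUERY_SCRIPT_SIGTERM_EXIT_CODE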
datachain/cli/__init__.py
CHANGED
@@ -34,8 +34,10 @@ def main(argv: Optional[list[str]] = None) -> int:
     datachain_parser = get_parser()
     args = datachain_parser.parse_args(argv)
 
-    if args.command in ("internal-run-udf", "internal-run-udf-worker"):
-        return handle_udf(args.command)
+    if args.command == "internal-run-udf":
+        return handle_udf()
+    if args.command == "internal-run-udf-worker":
+        return handle_udf_runner(args.fd)
 
     if args.command is None:
         datachain_parser.print_help(sys.stderr)
@@ -303,13 +305,13 @@ def handle_general_exception(exc, args, logging_level):
     return error, 1
 
 
-def handle_udf(command: str) -> int:
-    if command == "internal-run-udf":
-        from datachain.query.dispatch import udf_entrypoint
+def handle_udf() -> int:
+    from datachain.query.dispatch import udf_entrypoint
 
-        return udf_entrypoint()
+    return udf_entrypoint()
 
-    if command == "internal-run-udf-worker":
-        from datachain.query.dispatch import udf_worker_entrypoint
 
-        return udf_worker_entrypoint()
+def handle_udf_runner(fd: Optional[int] = None) -> int:
+    from datachain.query.dispatch import udf_worker_entrypoint
+
+    return udf_worker_entrypoint(fd)
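The split into `handle_udf` and `handle_udf_runner` pairs with the new `--fd` option: the worker is evidently handed an inherited file descriptor to write results to. A hypothetical launch sketch (the real launch code lives in `datachain/query/dispatch.py`, which this diff also touches but does not show here; the `pass_fds` mechanics are an assumption):

import os
import subprocess

# Create a pipe and hand the write end to the worker subcommand.
read_fd, write_fd = os.pipe()
proc = subprocess.Popen(
    ["datachain", "internal-run-udf-worker", "--fd", str(write_fd)],
    pass_fds=(write_fd,),  # keep the descriptor open across exec
)
os.close(write_fd)  # the parent keeps only the read end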
datachain/cli/commands/query.py
CHANGED
datachain/cli/parser/__init__.py
CHANGED
@@ -549,7 +549,15 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
     add_anon_arg(parse_gc)
 
     subp.add_parser("internal-run-udf", parents=[parent_parser])
-    subp.add_parser("internal-run-udf-worker", parents=[parent_parser])
+    run_udf_worker = subp.add_parser("internal-run-udf-worker", parents=[parent_parser])
+    run_udf_worker.add_argument(
+        "--fd",
+        type=int,
+        action="store",
+        default=None,
+        help="File descriptor to write results to",
+    )
+
     add_completion_parser(subp, [parent_parser])
     return parser
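A quick sketch of how the new flag parses, using `get_parser()` from this module (the subcommand is internal, so this is illustration only):

from datachain.cli.parser import get_parser

args = get_parser().parse_args(["internal-run-udf-worker", "--fd", "3"])
assert args.fd == 3  # parsed as an int

args = get_parser().parse_args(["internal-run-udf-worker"])
assert args.fd is None  # defaults to None when the flag is omitted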
datachain/cli/parser/job.py
CHANGED
@@ -13,7 +13,7 @@ def add_jobs_parser(subparsers, parent_parser) -> None:
     )
     jobs_subparser = jobs_parser.add_subparsers(
         dest="cmd",
-        help="Use `datachain
+        help="Use `datachain job CMD --help` to display command-specific help",
     )
 
     studio_run_help = "Run a job in Studio"
@@ -66,6 +66,11 @@ def add_jobs_parser(subparsers, parent_parser) -> None:
         action="store",
         help="Python version for the job (e.g., 3.9, 3.10, 3.11)",
     )
+    studio_run_parser.add_argument(
+        "--repository",
+        action="store",
+        help="Repository URL to clone before running the job",
+    )
     studio_run_parser.add_argument(
         "--req-file",
         action="store",
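A hedged sketch of the new flag in use; the rest of the `datachain job run` signature is not part of this diff, so the subcommand spelling and the positional query-script argument here are assumptions:

from datachain.cli.parser import get_parser

args = get_parser().parse_args(
    ["job", "run", "query.py", "--repository", "https://example.com/org/repo.git"]
)
print(args.repository)  # https://example.com/org/repo.git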
datachain/data_storage/job.py
CHANGED
datachain/data_storage/metastore.py
CHANGED
@@ -254,6 +254,7 @@ class AbstractMetastore(ABC, Serializable):
         name: str,
         query: str,
         query_type: JobQueryType = JobQueryType.PYTHON,
+        status: JobStatus = JobStatus.CREATED,
         workers: int = 1,
         python_version: Optional[str] = None,
         params: Optional[dict[str, str]] = None,
@@ -264,33 +265,35 @@ class AbstractMetastore(ABC, Serializable):
         """
 
     @abstractmethod
-    def set_job_status(
+    def get_job(self, job_id: str) -> Optional[Job]:
+        """Returns the job with the given ID."""
+
+    @abstractmethod
+    def update_job(
         self,
         job_id: str,
-        status: JobStatus,
+        status: Optional[JobStatus] = None,
+        exit_code: Optional[int] = None,
         error_message: Optional[str] = None,
         error_stack: Optional[str] = None,
+        finished_at: Optional[datetime] = None,
         metrics: Optional[dict[str, Any]] = None,
-    ) -> None:
-        """Set the status of the given job."""
+    ) -> Optional["Job"]:
+        """Updates job fields."""
 
     @abstractmethod
-    def get_job_status(self, job_id: str) -> Optional[JobStatus]:
-        """Returns the status of the given job."""
-
-    @abstractmethod
-    def set_job_and_dataset_status(
+    def set_job_status(
         self,
         job_id: str,
-        job_status: JobStatus,
-        dataset_status: DatasetStatus,
+        status: JobStatus,
+        error_message: Optional[str] = None,
+        error_stack: Optional[str] = None,
     ) -> None:
-        """Set the status of the given job and dataset."""
+        """Set the status of the given job."""
 
     @abstractmethod
-    def get_job_dataset_versions(self, job_id: str) -> list[tuple[str, int]]:
-        """Returns dataset names and versions for the job."""
-        raise NotImplementedError
+    def get_job_status(self, job_id: str) -> Optional[JobStatus]:
+        """Returns the status of the given job."""
 
 
 class AbstractDBMetastore(AbstractMetastore):
@@ -651,30 +654,31 @@ class AbstractDBMetastore(AbstractMetastore):
         dataset_version = dataset.get_version(version)
 
         values = {}
+        version_values: dict = {}
         for field, value in kwargs.items():
             if field in self._dataset_version_fields[1:]:
                 if field == "schema":
-                    dataset_version.update(**{field: DatasetRecord.parse_schema(value)})
                     values[field] = json.dumps(value) if value else None
+                    version_values[field] = DatasetRecord.parse_schema(value)
                 elif field == "feature_schema":
                     values[field] = json.dumps(value) if value else None
+                    version_values[field] = value
                 elif field == "preview" and isinstance(value, list):
                     values[field] = json.dumps(value, cls=JSONSerialize)
+                    version_values[field] = value
                 else:
                     values[field] = value
-                    dataset_version.update(**{field: value})
-
-        if not values:
-            # Nothing to update
-            return dataset_version
+                    version_values[field] = value
 
-        dv = self._datasets_versions
-        self.db.execute(
-            self._datasets_versions_update()
-            .where(dv.c.dataset_id == dataset.id and dv.c.version == version)
-            .values(values),
-            conn=conn,
-        )  # type: ignore [attr-defined]
+        if values:
+            dv = self._datasets_versions
+            self.db.execute(
+                self._datasets_versions_update()
+                .where(dv.c.dataset_id == dataset.id and dv.c.version == version)
+                .values(values),
+                conn=conn,
+            )  # type: ignore [attr-defined]
+        dataset_version.update(**version_values)
 
         return dataset_version
 
@@ -702,7 +706,7 @@ class AbstractDBMetastore(AbstractMetastore):
         dataset_fields: list[str],
         dataset_version_fields: list[str],
         isouter: bool = True,
-    ):
+    ) -> "Select":
         if not (
             self.db.has_table(self._datasets.name)
             and self.db.has_table(self._datasets_versions.name)
@@ -719,12 +723,12 @@ class AbstractDBMetastore(AbstractMetastore):
         j = d.join(dv, d.c.id == dv.c.dataset_id, isouter=isouter)
         return query.select_from(j)
 
-    def _base_dataset_query(self):
+    def _base_dataset_query(self) -> "Select":
         return self._get_dataset_query(
             self._dataset_fields, self._dataset_version_fields
         )
 
-    def _base_list_datasets_query(self):
+    def _base_list_datasets_query(self) -> "Select":
        return self._get_dataset_query(
            self._dataset_list_fields, self._dataset_list_version_fields, isouter=False
        )
@@ -1018,6 +1022,7 @@ class AbstractDBMetastore(AbstractMetastore):
         name: str,
         query: str,
         query_type: JobQueryType = JobQueryType.PYTHON,
+        status: JobStatus = JobStatus.CREATED,
         workers: int = 1,
         python_version: Optional[str] = None,
         params: Optional[dict[str, str]] = None,
@@ -1032,7 +1037,7 @@ class AbstractDBMetastore(AbstractMetastore):
             self._jobs_insert().values(
                 id=job_id,
                 name=name,
-                status=JobStatus.CREATED,
+                status=status,
                 created_at=datetime.now(timezone.utc),
                 query=query,
                 query_type=query_type.value,
@@ -1047,25 +1052,65 @@ class AbstractDBMetastore(AbstractMetastore):
         )
         return job_id
 
+    def get_job(self, job_id: str, conn=None) -> Optional[Job]:
+        """Returns the job with the given ID."""
+        query = self._jobs_select(self._jobs).where(self._jobs.c.id == job_id)
+        results = list(self.db.execute(query, conn=conn))
+        if not results:
+            return None
+        return self._parse_job(results[0])
+
+    def update_job(
+        self,
+        job_id: str,
+        status: Optional[JobStatus] = None,
+        exit_code: Optional[int] = None,
+        error_message: Optional[str] = None,
+        error_stack: Optional[str] = None,
+        finished_at: Optional[datetime] = None,
+        metrics: Optional[dict[str, Any]] = None,
+        conn: Optional[Any] = None,
+    ) -> Optional["Job"]:
+        """Updates job fields."""
+        values: dict = {}
+        if status is not None:
+            values["status"] = status
+        if exit_code is not None:
+            values["exit_code"] = exit_code
+        if error_message is not None:
+            values["error_message"] = error_message
+        if error_stack is not None:
+            values["error_stack"] = error_stack
+        if finished_at is not None:
+            values["finished_at"] = finished_at
+        if metrics:
+            values["metrics"] = json.dumps(metrics)
+
+        if values:
+            j = self._jobs
+            self.db.execute(
+                self._jobs_update().where(j.c.id == job_id).values(**values),
+                conn=conn,
+            )  # type: ignore [attr-defined]
+
+        return self.get_job(job_id, conn=conn)
+
     def set_job_status(
         self,
         job_id: str,
         status: JobStatus,
         error_message: Optional[str] = None,
         error_stack: Optional[str] = None,
-        metrics: Optional[dict[str, Any]] = None,
         conn: Optional[Any] = None,
     ) -> None:
         """Set the status of the given job."""
-        values: dict = {"status": status
-        if status
+        values: dict = {"status": status}
+        if status in JobStatus.finished():
             values["finished_at"] = datetime.now(timezone.utc)
         if error_message:
             values["error_message"] = error_message
         if error_stack:
             values["error_stack"] = error_stack
-        if metrics:
-            values["metrics"] = json.dumps(metrics)
         self.db.execute(
             self._jobs_update(self._jobs.c.id == job_id).values(**values),
             conn=conn,
@@ -1086,37 +1131,3 @@ class AbstractDBMetastore(AbstractMetastore):
         if not results:
             return None
         return results[0][0]
-
-    def set_job_and_dataset_status(
-        self,
-        job_id: str,
-        job_status: JobStatus,
-        dataset_status: DatasetStatus,
-    ) -> None:
-        """Set the status of the given job and dataset."""
-        with self.db.transaction() as conn:
-            self.set_job_status(job_id, status=job_status, conn=conn)
-            dv = self._datasets_versions
-            query = (
-                self._datasets_versions_update()
-                .where(
-                    (dv.c.job_id == job_id) & (dv.c.status != DatasetStatus.COMPLETE)
-                )
-                .values(status=dataset_status)
-            )
-            self.db.execute(query, conn=conn)  # type: ignore[attr-defined]
-
-    def get_job_dataset_versions(self, job_id: str) -> list[tuple[str, int]]:
-        """Returns dataset names and versions for the job."""
-        dv = self._datasets_versions
-        ds = self._datasets
-
-        join_condition = dv.c.dataset_id == ds.c.id
-
-        query = (
-            self._datasets_versions_select(ds.c.name, dv.c.version)
-            .select_from(dv.join(ds, join_condition))
-            .where(dv.c.job_id == job_id)
-        )
-
-        return list(self.db.execute(query))
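Taken together, the metastore changes replace the coupled job-and-dataset status updates with `get_job` plus a generic partial `update_job`. A hedged usage sketch (assuming a concrete metastore instance; the `JobStatus` member names are assumptions, since the enum itself is not shown in this diff):

from datetime import datetime, timezone

from datachain.data_storage.job import JobStatus


def finish_job(metastore, job_id: str, exit_code: int) -> None:
    # Only explicitly passed fields are written; update_job returns the
    # refreshed Job record, or None if the id is unknown.
    job = metastore.update_job(
        job_id,
        status=JobStatus.COMPLETE if exit_code == 0 else JobStatus.FAILED,  # assumed members
        exit_code=exit_code,
        finished_at=datetime.now(timezone.utc),
    )
    if job is not None:
        print(job.status, job.finished_at)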
datachain/data_storage/warehouse.py
CHANGED
@@ -11,16 +11,15 @@ from urllib.parse import urlparse
 
 import attrs
 import sqlalchemy as sa
-from sqlalchemy import Table, case, select
-from sqlalchemy.sql import func
 from sqlalchemy.sql.expression import true
-from tqdm.auto import tqdm
 
 from datachain.client import Client
 from datachain.data_storage.schema import convert_rows_custom_column_types
 from datachain.data_storage.serializer import Serializable
 from datachain.dataset import DatasetRecord, StorageURI
 from datachain.node import DirType, DirTypeGroup, Node, NodeWithPath, get_path
+from datachain.query.batch import RowsOutput
+from datachain.query.utils import get_query_id_column
 from datachain.sql.functions import path as pathfunc
 from datachain.sql.types import Int, SQLType
 from datachain.utils import sql_escape_like
@@ -31,7 +30,6 @@ if TYPE_CHECKING:
         _FromClauseArgument,
         _OnClauseArgument,
     )
-    from sqlalchemy.sql.selectable import Select
     from sqlalchemy.types import TypeEngine
 
     from datachain.data_storage import schema
@@ -199,13 +197,13 @@ class AbstractWarehouse(ABC, Serializable):
     # Query Execution
     #
 
-    def query_count(self, query: sa.
+    def query_count(self, query: sa.Select) -> int:
         """Count the number of rows in a query."""
-        count_query = sa.select(func.count(1)).select_from(query.subquery())
+        count_query = sa.select(sa.func.count(1)).select_from(query.subquery())
         return next(self.db.execute(count_query))[0]
 
     def table_rows_count(self, table) -> int:
-        count_query = sa.select(func.count(1)).select_from(table)
+        count_query = sa.select(sa.func.count(1)).select_from(table)
         return next(self.db.execute(count_query))[0]
 
     def dataset_select_paginated(
@@ -278,7 +276,7 @@ class AbstractWarehouse(ABC, Serializable):
         name: str,
         columns: Sequence["sa.Column"] = (),
         if_not_exists: bool = True,
-    ) -> Table:
+    ) -> sa.Table:
         """Creates a dataset rows table for the given dataset name and columns"""
 
     def drop_dataset_rows_table(
@@ -289,7 +287,7 @@ class AbstractWarehouse(ABC, Serializable):
     ) -> None:
         """Drops a dataset rows table for the given dataset name."""
         table_name = self.dataset_table_name(dataset.name, version)
-        table = Table(table_name, self.db.metadata)
+        table = sa.Table(table_name, self.db.metadata)
         self.db.drop_table(table, if_exists=if_exists)
 
     @abstractmethod
@@ -309,7 +307,7 @@ class AbstractWarehouse(ABC, Serializable):
 
     def dataset_rows_select(
         self,
-        query: sa.
+        query: sa.Select,
         **kwargs,
     ) -> Iterator[tuple[Any, ...]]:
         """
@@ -320,6 +318,24 @@ class AbstractWarehouse(ABC, Serializable):
             query.selected_columns, rows, self.db.dialect
         )
 
+    def dataset_rows_select_from_ids(
+        self,
+        query: sa.Select,
+        ids: Iterable[RowsOutput],
+        is_batched: bool,
+    ) -> Iterator[RowsOutput]:
+        """
+        Fetch dataset rows from database using a list of IDs.
+        """
+        if (id_col := get_query_id_column(query)) is None:
+            raise RuntimeError("sys__id column not found in query")
+
+        if is_batched:
+            for batch in ids:
+                yield list(self.dataset_rows_select(query.where(id_col.in_(batch))))
+        else:
+            yield from self.dataset_rows_select(query.where(id_col.in_(ids)))
+
     @abstractmethod
     def get_dataset_sources(
         self, dataset: DatasetRecord, version: int
@@ -341,7 +357,7 @@ class AbstractWarehouse(ABC, Serializable):
         """Returns total number of rows in a dataset"""
         dr = self.dataset_rows(dataset, version)
         table = dr.get_table()
-        query = select(sa.func.count(table.c.sys__id))
+        query = sa.select(sa.func.count(table.c.sys__id))
         (res,) = self.db.execute(query)
         return res[0]
 
@@ -364,7 +380,7 @@ class AbstractWarehouse(ABC, Serializable):
         ]
         if size_columns:
             expressions = (*expressions, sa.func.sum(sum(size_columns)))
-        query = select(*expressions)
+        query = sa.select(*expressions)
         ((nrows, *rest),) = self.db.execute(query)
         return nrows, rest[0] if rest else 0
 
@@ -373,10 +389,10 @@ class AbstractWarehouse(ABC, Serializable):
         """Convert File entries so they can be passed on to `insert_rows()`"""
 
     @abstractmethod
-    def insert_rows(self, table: Table, rows: Iterable[dict[str, Any]]) -> None:
+    def insert_rows(self, table: sa.Table, rows: Iterable[dict[str, Any]]) -> None:
         """Does batch inserts of any kind of rows into table"""
 
-    def insert_rows_done(self, table: Table) -> None:
+    def insert_rows_done(self, table: sa.Table) -> None:
         """
         Only needed for certain implementations
        to signal when rows inserts are complete.
@@ -497,7 +513,7 @@ class AbstractWarehouse(ABC, Serializable):
         ).subquery()
         path_glob = "/".join([*path_list, glob_name])
         dirpath = path_glob[: -len(glob_name)]
-        relpath = func.substr(de.c(q, "path"), len(dirpath) + 1)
+        relpath = sa.func.substr(de.c(q, "path"), len(dirpath) + 1)
 
         return self.get_nodes(
             self.expand_query(de, q, dr)
@@ -584,13 +600,13 @@ class AbstractWarehouse(ABC, Serializable):
             default = getattr(
                 attrs.fields(Node), dr.without_object(column.name)
             ).default
-            return func.coalesce(column, default).label(column.name)
+            return sa.func.coalesce(column, default).label(column.name)
 
         return sa.select(
             q.c.sys__id,
-            case(
-                (de.c(q, "is_dir") == true(), DirType.DIR), else_=DirType.FILE
-            ),
+            sa.case(
+                (de.c(q, "is_dir") == true(), DirType.DIR), else_=DirType.FILE
+            ).label(dr.col_name("dir_type")),
             de.c(q, "path"),
             with_default(dr.c("etag")),
             de.c(q, "version"),
@@ -665,7 +681,7 @@ class AbstractWarehouse(ABC, Serializable):
             return de.c(inner_query, f)
 
         return self.db.execute(
-            select(*(field_to_expr(f) for f in fields)).order_by(
+            sa.select(*(field_to_expr(f) for f in fields)).order_by(
                 de.c(inner_query, "source"),
                 de.c(inner_query, "path"),
                 de.c(inner_query, "version"),
@@ -687,7 +703,7 @@ class AbstractWarehouse(ABC, Serializable):
             return dr.c(f)
 
         q = (
-            select(*(field_to_expr(f) for f in fields))
+            sa.select(*(field_to_expr(f) for f in fields))
             .where(
                 dr.c("path").like(f"{sql_escape_like(dirpath)}%"),
                 ~self.instr(pathfunc.name(dr.c("path")), "/"),
@@ -722,10 +738,10 @@ class AbstractWarehouse(ABC, Serializable):
         sub_glob = posixpath.join(path, "*")
         dr = dataset_rows
         selections: list[sa.ColumnElement] = [
-            func.sum(dr.c("size")),
+            sa.func.sum(dr.c("size")),
         ]
         if count_files:
-            selections.append(func.count())
+            selections.append(sa.func.count())
         results = next(
             self.db.execute(
                 dr.select(*selections).where(
@@ -842,7 +858,7 @@ class AbstractWarehouse(ABC, Serializable):
         self,
         columns: Sequence["sa.Column"] = (),
         name: Optional[str] = None,
-    ) -> Table:
+    ) -> sa.Table:
         """
         Create a temporary table for storing custom signals generated by a UDF.
         SQLite TEMPORARY tables cannot be directly used as they are process-specific,
@@ -860,8 +876,8 @@ class AbstractWarehouse(ABC, Serializable):
     @abstractmethod
     def copy_table(
         self,
-        table: Table,
-        query: "Select",
+        table: sa.Table,
+        query: sa.Select,
         progress_cb: Optional[Callable[[int], None]] = None,
     ) -> None:
         """
@@ -875,13 +891,13 @@ class AbstractWarehouse(ABC, Serializable):
         right: "_FromClauseArgument",
         onclause: "_OnClauseArgument",
         inner: bool = True,
-    ) -> "Select":
+    ) -> sa.Select:
         """
         Join two tables together.
         """
 
     @abstractmethod
-    def create_pre_udf_table(self, query: "Select") -> "Table":
+    def create_pre_udf_table(self, query: sa.Select) -> sa.Table:
         """
         Create a temporary table from a query for use in a UDF.
         """
@@ -906,12 +922,8 @@ class AbstractWarehouse(ABC, Serializable):
         are cleaned up as soon as they are no longer needed.
         """
         to_drop = set(names)
-        with tqdm(
-
-        ) as pbar:
-            for name in to_drop:
-                self.db.drop_table(Table(name, self.db.metadata), if_exists=True)
-                pbar.update(1)
+        for name in to_drop:
+            self.db.drop_table(sa.Table(name, self.db.metadata), if_exists=True)
 
 
 def _random_string(length: int) -> str:
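A sketch of the two modes of the new `dataset_rows_select_from_ids` helper; the warehouse and query objects are assumed to exist, and the id values are illustrative:

import sqlalchemy as sa


def fetch_rows(warehouse, query: sa.Select):
    # Flat mode: ids is one iterable of sys__id values; rows come back
    # one at a time.
    for row in warehouse.dataset_rows_select_from_ids(
        query, ids=[1, 2, 3], is_batched=False
    ):
        print(row)

    # Batched mode: ids yields id-batches, and each yielded item is the
    # full list of rows for one batch, preserving batch boundaries.
    for batch in warehouse.dataset_rows_select_from_ids(
        query, ids=[[1, 2], [3, 4]], is_batched=True
    ):
        print(len(batch))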
datachain/lib/arrow.py
CHANGED
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Optional
 
 import orjson
 import pyarrow as pa
+from pyarrow._csv import ParseOptions
 from pyarrow.dataset import CsvFileFormat, dataset
 from tqdm.auto import tqdm
 
@@ -26,6 +27,18 @@ if TYPE_CHECKING:
 DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY = b"DataChain SignalSchema"
 
 
+def fix_pyarrow_format(format, parse_options=None):
+    # Re-init invalid row handler: https://issues.apache.org/jira/browse/ARROW-17641
+    if (
+        format
+        and isinstance(format, CsvFileFormat)
+        and parse_options
+        and isinstance(parse_options, ParseOptions)
+    ):
+        format.parse_options = parse_options
+    return format
+
+
 class ArrowGenerator(Generator):
     DEFAULT_BATCH_SIZE = 2**17  # same as `pyarrow._dataset._DEFAULT_BATCH_SIZE`
 
@@ -53,6 +66,7 @@ class ArrowGenerator(Generator):
         self.output_schema = output_schema
         self.source = source
         self.nrows = nrows
+        self.parse_options = kwargs.pop("parse_options", None)
         self.kwargs = kwargs
 
     def process(self, file: File):
@@ -64,7 +78,11 @@ class ArrowGenerator(Generator):
         else:
             fs, fs_path = file.get_fs(), file.get_path()
 
-        ds = dataset(fs_path, schema=self.input_schema, filesystem=fs, **self.kwargs)
+        kwargs = self.kwargs
+        if format := kwargs.get("format"):
+            kwargs["format"] = fix_pyarrow_format(format, self.parse_options)
+
+        ds = dataset(fs_path, schema=self.input_schema, filesystem=fs, **kwargs)
 
         hf_schema = _get_hf_schema(ds.schema)
         use_datachain_schema = (
@@ -137,6 +155,10 @@ class ArrowGenerator(Generator):
 
 
 def infer_schema(chain: "DataChain", **kwargs) -> pa.Schema:
+    parse_options = kwargs.pop("parse_options", None)
+    if format := kwargs.get("format"):
+        kwargs["format"] = fix_pyarrow_format(format, parse_options)
+
     schemas = []
     for file in chain.collect("file"):
         ds = dataset(file.get_path(), filesystem=file.get_fs(), **kwargs)  # type: ignore[union-attr]