datachain 0.16.4__py3-none-any.whl → 0.16.5__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of datachain might be problematic.

@@ -79,6 +79,7 @@ DATASET_INTERNAL_ERROR_MESSAGE = "Internal error on creating dataset"
 QUERY_SCRIPT_INVALID_LAST_STATEMENT_EXIT_CODE = 10
 # exit code we use if query script was canceled
 QUERY_SCRIPT_CANCELED_EXIT_CODE = 11
+QUERY_SCRIPT_SIGTERM_EXIT_CODE = -15  # if query script was terminated by SIGTERM
 
 # dataset pull
 PULL_DATASET_MAX_THREADS = 5
@@ -1645,7 +1646,10 @@ class Catalog:
         thread.join()  # wait for the reader thread
 
         logger.info("Process %s exited with return code %s", proc.pid, proc.returncode)
-        if proc.returncode == QUERY_SCRIPT_CANCELED_EXIT_CODE:
+        if proc.returncode in (
+            QUERY_SCRIPT_CANCELED_EXIT_CODE,
+            QUERY_SCRIPT_SIGTERM_EXIT_CODE,
+        ):
             raise QueryScriptCancelError(
                 "Query script was canceled by user",
                 return_code=proc.returncode,
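
Note on the new constant: on POSIX, subprocess reports a child killed by a signal as a negative return code (-signum), so a query script terminated by SIGTERM surfaces as -15. A minimal standalone sketch of the check added above (not datachain code; assumes a POSIX `sleep` binary is available):

import subprocess

QUERY_SCRIPT_CANCELED_EXIT_CODE = 11   # values taken from the diff above
QUERY_SCRIPT_SIGTERM_EXIT_CODE = -15

proc = subprocess.Popen(["sleep", "60"])
proc.terminate()  # sends SIGTERM
proc.wait()

# Both explicit cancellation and SIGTERM are now treated as "canceled by user".
if proc.returncode in (QUERY_SCRIPT_CANCELED_EXIT_CODE, QUERY_SCRIPT_SIGTERM_EXIT_CODE):
    print("query script canceled, return code:", proc.returncode)
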
datachain/cli/__init__.py CHANGED
@@ -34,8 +34,10 @@ def main(argv: Optional[list[str]] = None) -> int:
     datachain_parser = get_parser()
     args = datachain_parser.parse_args(argv)
 
-    if args.command in ("internal-run-udf", "internal-run-udf-worker"):
-        return handle_udf(args.command)
+    if args.command == "internal-run-udf":
+        return handle_udf()
+    if args.command == "internal-run-udf-worker":
+        return handle_udf_runner(args.fd)
 
     if args.command is None:
         datachain_parser.print_help(sys.stderr)
@@ -303,13 +305,13 @@ def handle_general_exception(exc, args, logging_level):
     return error, 1
 
 
-def handle_udf(command):
-    if command == "internal-run-udf":
-        from datachain.query.dispatch import udf_entrypoint
+def handle_udf() -> int:
+    from datachain.query.dispatch import udf_entrypoint
 
-        return udf_entrypoint()
+    return udf_entrypoint()
 
-    if command == "internal-run-udf-worker":
-        from datachain.query.dispatch import udf_worker_entrypoint
 
-        return udf_worker_entrypoint()
+def handle_udf_runner(fd: Optional[int] = None) -> int:
+    from datachain.query.dispatch import udf_worker_entrypoint
+
+    return udf_worker_entrypoint(fd)
@@ -29,6 +29,7 @@ def query(
         name=os.path.basename(script),
         query=script_content,
         query_type=JobQueryType.PYTHON,
+        status=JobStatus.RUNNING,
         python_version=python_version,
         params=params,
    )
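
The `status=JobStatus.RUNNING` addition means a locally started query script records its job as running immediately instead of starting at CREATED. A runnable stand-in sketch of the new `create_job` signature; the enums and function below are stubs mirroring the diff, not the real datachain imports:

from enum import Enum
from typing import Optional


class JobQueryType(int, Enum):  # stand-in mirroring the diff
    PYTHON = 1


class JobStatus(int, Enum):     # subset of the enum shown further down in this diff
    CREATED = 1
    RUNNING = 4


def create_job(
    name: str,
    query: str,
    query_type: JobQueryType = JobQueryType.PYTHON,
    status: JobStatus = JobStatus.CREATED,
    python_version: Optional[str] = None,
    params: Optional[dict] = None,
) -> dict:
    # Stub of AbstractMetastore.create_job's new signature: the caller can now
    # choose the initial status instead of always starting at CREATED.
    return {"name": name, "status": status, "query_type": query_type}


job = create_job("my_script.py", "print('hi')", status=JobStatus.RUNNING)
print(job["status"])  # JobStatus.RUNNING
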
@@ -549,7 +549,15 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
     add_anon_arg(parse_gc)
 
     subp.add_parser("internal-run-udf", parents=[parent_parser])
-    subp.add_parser("internal-run-udf-worker", parents=[parent_parser])
+    run_udf_worker = subp.add_parser("internal-run-udf-worker", parents=[parent_parser])
+    run_udf_worker.add_argument(
+        "--fd",
+        type=int,
+        action="store",
+        default=None,
+        help="File descriptor to write results to",
+    )
+
     add_completion_parser(subp, [parent_parser])
     return parser
 
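
The new `--fd` option gives the worker subcommand a file descriptor to write results to, which `main()` forwards via `handle_udf_runner(args.fd)` to `udf_worker_entrypoint(fd)`. A simplified standalone sketch of the wiring (not the real parser module):

import argparse

parser = argparse.ArgumentParser(prog="datachain")
subp = parser.add_subparsers(dest="command")
subp.add_parser("internal-run-udf")
worker = subp.add_parser("internal-run-udf-worker")
worker.add_argument("--fd", type=int, default=None,
                    help="File descriptor to write results to")

args = parser.parse_args(["internal-run-udf-worker", "--fd", "7"])
assert args.command == "internal-run-udf-worker" and args.fd == 7
# main() would now call handle_udf_runner(args.fd), which forwards the fd
# to udf_worker_entrypoint(fd).
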
@@ -3,6 +3,7 @@ from enum import Enum
 
 class JobStatus(int, Enum):
     CREATED = 1
+    SCHEDULED = 10
     QUEUED = 2
     INIT = 3
     RUNNING = 4
@@ -254,6 +254,7 @@ class AbstractMetastore(ABC, Serializable):
         name: str,
         query: str,
         query_type: JobQueryType = JobQueryType.PYTHON,
+        status: JobStatus = JobStatus.CREATED,
         workers: int = 1,
         python_version: Optional[str] = None,
         params: Optional[dict[str, str]] = None,
@@ -264,33 +265,35 @@ class AbstractMetastore(ABC, Serializable):
         """
 
     @abstractmethod
-    def set_job_status(
+    def get_job(self, job_id: str) -> Optional[Job]:
+        """Returns the job with the given ID."""
+
+    @abstractmethod
+    def update_job(
         self,
         job_id: str,
-        status: JobStatus,
+        status: Optional[JobStatus] = None,
+        exit_code: Optional[int] = None,
         error_message: Optional[str] = None,
         error_stack: Optional[str] = None,
+        finished_at: Optional[datetime] = None,
         metrics: Optional[dict[str, Any]] = None,
-    ) -> None:
-        """Set the status of the given job."""
+    ) -> Optional["Job"]:
+        """Updates job fields."""
 
     @abstractmethod
-    def get_job_status(self, job_id: str) -> Optional[JobStatus]:
-        """Returns the status of the given job."""
-
-    @abstractmethod
-    def set_job_and_dataset_status(
+    def set_job_status(
         self,
         job_id: str,
-        job_status: JobStatus,
-        dataset_status: DatasetStatus,
+        status: JobStatus,
+        error_message: Optional[str] = None,
+        error_stack: Optional[str] = None,
     ) -> None:
-        """Set the status of the given job and dataset."""
+        """Set the status of the given job."""
 
     @abstractmethod
-    def get_job_dataset_versions(self, job_id: str) -> list[tuple[str, int]]:
-        """Returns dataset names and versions for the job."""
-        raise NotImplementedError
+    def get_job_status(self, job_id: str) -> Optional[JobStatus]:
+        """Returns the status of the given job."""
 
 
 class AbstractDBMetastore(AbstractMetastore):
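
Taken together, the abstract interface replaces the metrics-capable `set_job_status` and the dataset-coupled helpers with `get_job` plus a general-purpose `update_job`. The sketch below illustrates the partial-update pattern `update_job` uses, where only explicitly supplied fields are written and metrics are JSON-encoded; it mirrors the logic shown further down in this diff rather than calling the real metastore:

import json
from datetime import datetime, timezone


def build_job_update(status=None, exit_code=None, error_message=None,
                     error_stack=None, finished_at=None, metrics=None) -> dict:
    # Corresponds to the `values` dict assembled inside update_job below.
    values = {}
    if status is not None:
        values["status"] = status
    if exit_code is not None:
        values["exit_code"] = exit_code
    if error_message is not None:
        values["error_message"] = error_message
    if error_stack is not None:
        values["error_stack"] = error_stack
    if finished_at is not None:
        values["finished_at"] = finished_at
    if metrics:
        values["metrics"] = json.dumps(metrics)
    return values


print(build_job_update(exit_code=0,
                       finished_at=datetime.now(timezone.utc),
                       metrics={"rows": 1000}))
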
@@ -651,30 +654,31 @@ class AbstractDBMetastore(AbstractMetastore):
         dataset_version = dataset.get_version(version)
 
         values = {}
+        version_values: dict = {}
         for field, value in kwargs.items():
             if field in self._dataset_version_fields[1:]:
                 if field == "schema":
-                    dataset_version.update(**{field: DatasetRecord.parse_schema(value)})
                     values[field] = json.dumps(value) if value else None
+                    version_values[field] = DatasetRecord.parse_schema(value)
                 elif field == "feature_schema":
                     values[field] = json.dumps(value) if value else None
+                    version_values[field] = value
                 elif field == "preview" and isinstance(value, list):
                     values[field] = json.dumps(value, cls=JSONSerialize)
+                    version_values[field] = value
                 else:
                     values[field] = value
-                    dataset_version.update(**{field: value})
-
-        if not values:
-            # Nothing to update
-            return dataset_version
+                    version_values[field] = value
 
-        dv = self._datasets_versions
-        self.db.execute(
-            self._datasets_versions_update()
-            .where(dv.c.id == dataset_version.id)
-            .values(values),
-            conn=conn,
-        )  # type: ignore [attr-defined]
+        if values:
+            dv = self._datasets_versions
+            self.db.execute(
+                self._datasets_versions_update()
+                .where(dv.c.dataset_id == dataset.id and dv.c.version == version)
+                .values(values),
+                conn=conn,
+            )  # type: ignore [attr-defined]
+            dataset_version.update(**version_values)
 
         return dataset_version
 
@@ -702,7 +706,7 @@ class AbstractDBMetastore(AbstractMetastore):
         dataset_fields: list[str],
         dataset_version_fields: list[str],
         isouter: bool = True,
-    ):
+    ) -> "Select":
         if not (
             self.db.has_table(self._datasets.name)
             and self.db.has_table(self._datasets_versions.name)
@@ -719,12 +723,12 @@ class AbstractDBMetastore(AbstractMetastore):
         j = d.join(dv, d.c.id == dv.c.dataset_id, isouter=isouter)
         return query.select_from(j)
 
-    def _base_dataset_query(self):
+    def _base_dataset_query(self) -> "Select":
         return self._get_dataset_query(
             self._dataset_fields, self._dataset_version_fields
         )
 
-    def _base_list_datasets_query(self):
+    def _base_list_datasets_query(self) -> "Select":
         return self._get_dataset_query(
             self._dataset_list_fields, self._dataset_list_version_fields, isouter=False
         )
@@ -1018,6 +1022,7 @@ class AbstractDBMetastore(AbstractMetastore):
         name: str,
         query: str,
         query_type: JobQueryType = JobQueryType.PYTHON,
+        status: JobStatus = JobStatus.CREATED,
         workers: int = 1,
         python_version: Optional[str] = None,
         params: Optional[dict[str, str]] = None,
@@ -1032,7 +1037,7 @@ class AbstractDBMetastore(AbstractMetastore):
             self._jobs_insert().values(
                 id=job_id,
                 name=name,
-                status=JobStatus.CREATED,
+                status=status,
                 created_at=datetime.now(timezone.utc),
                 query=query,
                 query_type=query_type.value,
@@ -1047,25 +1052,65 @@ class AbstractDBMetastore(AbstractMetastore):
         )
         return job_id
 
+    def get_job(self, job_id: str, conn=None) -> Optional[Job]:
+        """Returns the job with the given ID."""
+        query = self._jobs_select(self._jobs).where(self._jobs.c.id == job_id)
+        results = list(self.db.execute(query, conn=conn))
+        if not results:
+            return None
+        return self._parse_job(results[0])
+
+    def update_job(
+        self,
+        job_id: str,
+        status: Optional[JobStatus] = None,
+        exit_code: Optional[int] = None,
+        error_message: Optional[str] = None,
+        error_stack: Optional[str] = None,
+        finished_at: Optional[datetime] = None,
+        metrics: Optional[dict[str, Any]] = None,
+        conn: Optional[Any] = None,
+    ) -> Optional["Job"]:
+        """Updates job fields."""
+        values: dict = {}
+        if status is not None:
+            values["status"] = status
+        if exit_code is not None:
+            values["exit_code"] = exit_code
+        if error_message is not None:
+            values["error_message"] = error_message
+        if error_stack is not None:
+            values["error_stack"] = error_stack
+        if finished_at is not None:
+            values["finished_at"] = finished_at
+        if metrics:
+            values["metrics"] = json.dumps(metrics)
+
+        if values:
+            j = self._jobs
+            self.db.execute(
+                self._jobs_update().where(j.c.id == job_id).values(**values),
+                conn=conn,
+            )  # type: ignore [attr-defined]
+
+        return self.get_job(job_id, conn=conn)
+
     def set_job_status(
         self,
         job_id: str,
         status: JobStatus,
         error_message: Optional[str] = None,
         error_stack: Optional[str] = None,
-        metrics: Optional[dict[str, Any]] = None,
         conn: Optional[Any] = None,
     ) -> None:
         """Set the status of the given job."""
-        values: dict = {"status": status.value}
-        if status.value in JobStatus.finished():
+        values: dict = {"status": status}
+        if status in JobStatus.finished():
             values["finished_at"] = datetime.now(timezone.utc)
         if error_message:
             values["error_message"] = error_message
         if error_stack:
             values["error_stack"] = error_stack
-        if metrics:
-            values["metrics"] = json.dumps(metrics)
         self.db.execute(
             self._jobs_update(self._jobs.c.id == job_id).values(**values),
             conn=conn,
@@ -1086,37 +1131,3 @@ class AbstractDBMetastore(AbstractMetastore):
         if not results:
             return None
         return results[0][0]
-
-    def set_job_and_dataset_status(
-        self,
-        job_id: str,
-        job_status: JobStatus,
-        dataset_status: DatasetStatus,
-    ) -> None:
-        """Set the status of the given job and dataset."""
-        with self.db.transaction() as conn:
-            self.set_job_status(job_id, status=job_status, conn=conn)
-            dv = self._datasets_versions
-            query = (
-                self._datasets_versions_update()
-                .where(
-                    (dv.c.job_id == job_id) & (dv.c.status != DatasetStatus.COMPLETE)
-                )
-                .values(status=dataset_status)
-            )
-            self.db.execute(query, conn=conn)  # type: ignore[attr-defined]
-
-    def get_job_dataset_versions(self, job_id: str) -> list[tuple[str, int]]:
-        """Returns dataset names and versions for the job."""
-        dv = self._datasets_versions
-        ds = self._datasets
-
-        join_condition = dv.c.dataset_id == ds.c.id
-
-        query = (
-            self._datasets_versions_select(ds.c.name, dv.c.version)
-            .select_from(dv.join(ds, join_condition))
-            .where(dv.c.job_id == job_id)
-        )
-
-        return list(self.db.execute(query))
@@ -11,16 +11,15 @@ from urllib.parse import urlparse
 
 import attrs
 import sqlalchemy as sa
-from sqlalchemy import Table, case, select
-from sqlalchemy.sql import func
 from sqlalchemy.sql.expression import true
-from tqdm.auto import tqdm
 
 from datachain.client import Client
 from datachain.data_storage.schema import convert_rows_custom_column_types
 from datachain.data_storage.serializer import Serializable
 from datachain.dataset import DatasetRecord, StorageURI
 from datachain.node import DirType, DirTypeGroup, Node, NodeWithPath, get_path
+from datachain.query.batch import RowsOutput
+from datachain.query.utils import get_query_id_column
 from datachain.sql.functions import path as pathfunc
 from datachain.sql.types import Int, SQLType
 from datachain.utils import sql_escape_like
@@ -31,7 +30,6 @@ if TYPE_CHECKING:
         _FromClauseArgument,
         _OnClauseArgument,
     )
-    from sqlalchemy.sql.selectable import Select
     from sqlalchemy.types import TypeEngine
 
     from datachain.data_storage import schema
@@ -199,13 +197,13 @@ class AbstractWarehouse(ABC, Serializable):
     # Query Execution
     #
 
-    def query_count(self, query: sa.sql.selectable.Select) -> int:
+    def query_count(self, query: sa.Select) -> int:
         """Count the number of rows in a query."""
-        count_query = sa.select(func.count(1)).select_from(query.subquery())
+        count_query = sa.select(sa.func.count(1)).select_from(query.subquery())
         return next(self.db.execute(count_query))[0]
 
     def table_rows_count(self, table) -> int:
-        count_query = sa.select(func.count(1)).select_from(table)
+        count_query = sa.select(sa.func.count(1)).select_from(table)
         return next(self.db.execute(count_query))[0]
 
     def dataset_select_paginated(
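
The `sa.select(sa.func.count(1)).select_from(query.subquery())` pattern counts the rows of an arbitrary SELECT by wrapping it in a subquery. A self-contained SQLAlchemy illustration against an in-memory SQLite database (hypothetical table, not datachain code):

import sqlalchemy as sa

engine = sa.create_engine("sqlite://")
meta = sa.MetaData()
items = sa.Table("items", meta, sa.Column("id", sa.Integer, primary_key=True))
meta.create_all(engine)

with engine.begin() as conn:
    conn.execute(sa.insert(items), [{"id": 1}, {"id": 2}, {"id": 3}])
    query = sa.select(items).where(items.c.id > 1)
    # Wrap the SELECT in a subquery and count its rows.
    count_query = sa.select(sa.func.count(1)).select_from(query.subquery())
    print(conn.execute(count_query).scalar())  # 2
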
@@ -278,7 +276,7 @@ class AbstractWarehouse(ABC, Serializable):
         name: str,
         columns: Sequence["sa.Column"] = (),
         if_not_exists: bool = True,
-    ) -> Table:
+    ) -> sa.Table:
         """Creates a dataset rows table for the given dataset name and columns"""
 
     def drop_dataset_rows_table(
@@ -289,7 +287,7 @@ class AbstractWarehouse(ABC, Serializable):
     ) -> None:
         """Drops a dataset rows table for the given dataset name."""
         table_name = self.dataset_table_name(dataset.name, version)
-        table = Table(table_name, self.db.metadata)
+        table = sa.Table(table_name, self.db.metadata)
         self.db.drop_table(table, if_exists=if_exists)
 
     @abstractmethod
@@ -309,7 +307,7 @@ class AbstractWarehouse(ABC, Serializable):
 
     def dataset_rows_select(
         self,
-        query: sa.sql.selectable.Select,
+        query: sa.Select,
         **kwargs,
     ) -> Iterator[tuple[Any, ...]]:
         """
@@ -320,6 +318,24 @@ class AbstractWarehouse(ABC, Serializable):
             query.selected_columns, rows, self.db.dialect
         )
 
+    def dataset_rows_select_from_ids(
+        self,
+        query: sa.Select,
+        ids: Iterable[RowsOutput],
+        is_batched: bool,
+    ) -> Iterator[RowsOutput]:
+        """
+        Fetch dataset rows from database using a list of IDs.
+        """
+        if (id_col := get_query_id_column(query)) is None:
+            raise RuntimeError("sys__id column not found in query")
+
+        if is_batched:
+            for batch in ids:
+                yield list(self.dataset_rows_select(query.where(id_col.in_(batch))))
+        else:
+            yield from self.dataset_rows_select(query.where(id_col.in_(ids)))
+
     @abstractmethod
     def get_dataset_sources(
         self, dataset: DatasetRecord, version: int
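
The new `dataset_rows_select_from_ids` narrows a prepared SELECT to a set of `sys__id` values, optionally batch by batch. A self-contained illustration of that filtering pattern with SQLAlchemy (hypothetical table and engine; `get_query_id_column` is approximated by a direct column lookup):

import sqlalchemy as sa

engine = sa.create_engine("sqlite://")
meta = sa.MetaData()
rows = sa.Table(
    "rows", meta,
    sa.Column("sys__id", sa.Integer, primary_key=True),
    sa.Column("value", sa.String),
)
meta.create_all(engine)

with engine.begin() as conn:
    conn.execute(sa.insert(rows), [{"sys__id": i, "value": f"v{i}"} for i in range(10)])
    query = sa.select(rows)
    id_col = query.selected_columns["sys__id"]   # what get_query_id_column resolves
    for batch in ([1, 2, 3], [7, 8]):            # batched ids, as in the is_batched path
        fetched = conn.execute(query.where(id_col.in_(batch))).all()
        print([r.sys__id for r in fetched])
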
@@ -341,7 +357,7 @@ class AbstractWarehouse(ABC, Serializable):
         """Returns total number of rows in a dataset"""
         dr = self.dataset_rows(dataset, version)
         table = dr.get_table()
-        query = select(sa.func.count(table.c.sys__id))
+        query = sa.select(sa.func.count(table.c.sys__id))
         (res,) = self.db.execute(query)
         return res[0]
 
@@ -364,7 +380,7 @@ class AbstractWarehouse(ABC, Serializable):
         ]
         if size_columns:
             expressions = (*expressions, sa.func.sum(sum(size_columns)))
-        query = select(*expressions)
+        query = sa.select(*expressions)
         ((nrows, *rest),) = self.db.execute(query)
         return nrows, rest[0] if rest else 0
 
@@ -373,10 +389,10 @@ class AbstractWarehouse(ABC, Serializable):
         """Convert File entries so they can be passed on to `insert_rows()`"""
 
     @abstractmethod
-    def insert_rows(self, table: Table, rows: Iterable[dict[str, Any]]) -> None:
+    def insert_rows(self, table: sa.Table, rows: Iterable[dict[str, Any]]) -> None:
         """Does batch inserts of any kind of rows into table"""
 
-    def insert_rows_done(self, table: Table) -> None:
+    def insert_rows_done(self, table: sa.Table) -> None:
         """
         Only needed for certain implementations
         to signal when rows inserts are complete.
@@ -497,7 +513,7 @@ class AbstractWarehouse(ABC, Serializable):
         ).subquery()
         path_glob = "/".join([*path_list, glob_name])
         dirpath = path_glob[: -len(glob_name)]
-        relpath = func.substr(de.c(q, "path"), len(dirpath) + 1)
+        relpath = sa.func.substr(de.c(q, "path"), len(dirpath) + 1)
 
         return self.get_nodes(
             self.expand_query(de, q, dr)
@@ -584,13 +600,13 @@ class AbstractWarehouse(ABC, Serializable):
             default = getattr(
                 attrs.fields(Node), dr.without_object(column.name)
             ).default
-            return func.coalesce(column, default).label(column.name)
+            return sa.func.coalesce(column, default).label(column.name)
 
         return sa.select(
             q.c.sys__id,
-            case((de.c(q, "is_dir") == true(), DirType.DIR), else_=DirType.FILE).label(
-                dr.col_name("dir_type")
-            ),
+            sa.case(
+                (de.c(q, "is_dir") == true(), DirType.DIR), else_=DirType.FILE
+            ).label(dr.col_name("dir_type")),
             de.c(q, "path"),
             with_default(dr.c("etag")),
             de.c(q, "version"),
@@ -665,7 +681,7 @@ class AbstractWarehouse(ABC, Serializable):
             return de.c(inner_query, f)
 
         return self.db.execute(
-            select(*(field_to_expr(f) for f in fields)).order_by(
+            sa.select(*(field_to_expr(f) for f in fields)).order_by(
                 de.c(inner_query, "source"),
                 de.c(inner_query, "path"),
                 de.c(inner_query, "version"),
@@ -687,7 +703,7 @@ class AbstractWarehouse(ABC, Serializable):
             return dr.c(f)
 
         q = (
-            select(*(field_to_expr(f) for f in fields))
+            sa.select(*(field_to_expr(f) for f in fields))
            .where(
                 dr.c("path").like(f"{sql_escape_like(dirpath)}%"),
                 ~self.instr(pathfunc.name(dr.c("path")), "/"),
@@ -722,10 +738,10 @@ class AbstractWarehouse(ABC, Serializable):
         sub_glob = posixpath.join(path, "*")
         dr = dataset_rows
         selections: list[sa.ColumnElement] = [
-            func.sum(dr.c("size")),
+            sa.func.sum(dr.c("size")),
         ]
         if count_files:
-            selections.append(func.count())
+            selections.append(sa.func.count())
         results = next(
             self.db.execute(
                 dr.select(*selections).where(
@@ -842,7 +858,7 @@ class AbstractWarehouse(ABC, Serializable):
         self,
         columns: Sequence["sa.Column"] = (),
         name: Optional[str] = None,
-    ) -> "sa.Table":
+    ) -> sa.Table:
         """
         Create a temporary table for storing custom signals generated by a UDF.
         SQLite TEMPORARY tables cannot be directly used as they are process-specific,
@@ -860,8 +876,8 @@ class AbstractWarehouse(ABC, Serializable):
     @abstractmethod
     def copy_table(
         self,
-        table: Table,
-        query: "Select",
+        table: sa.Table,
+        query: sa.Select,
         progress_cb: Optional[Callable[[int], None]] = None,
     ) -> None:
         """
@@ -875,13 +891,13 @@ class AbstractWarehouse(ABC, Serializable):
         right: "_FromClauseArgument",
         onclause: "_OnClauseArgument",
         inner: bool = True,
-    ) -> "Select":
+    ) -> sa.Select:
         """
         Join two tables together.
         """
 
     @abstractmethod
-    def create_pre_udf_table(self, query: "Select") -> "Table":
+    def create_pre_udf_table(self, query: sa.Select) -> sa.Table:
         """
         Create a temporary table from a query for use in a UDF.
         """
@@ -906,12 +922,8 @@ class AbstractWarehouse(ABC, Serializable):
         are cleaned up as soon as they are no longer needed.
         """
         to_drop = set(names)
-        with tqdm(
-            desc="Cleanup", unit=" tables", total=len(to_drop), leave=False
-        ) as pbar:
-            for name in to_drop:
-                self.db.drop_table(Table(name, self.db.metadata), if_exists=True)
-                pbar.update(1)
+        for name in to_drop:
+            self.db.drop_table(sa.Table(name, self.db.metadata), if_exists=True)
 
 
 def _random_string(length: int) -> str:
datachain/lib/arrow.py CHANGED
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Optional
 
 import orjson
 import pyarrow as pa
+from pyarrow._csv import ParseOptions
 from pyarrow.dataset import CsvFileFormat, dataset
 from tqdm.auto import tqdm
 
@@ -26,6 +27,18 @@ if TYPE_CHECKING:
 DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY = b"DataChain SignalSchema"
 
 
+def fix_pyarrow_format(format, parse_options=None):
+    # Re-init invalid row handler: https://issues.apache.org/jira/browse/ARROW-17641
+    if (
+        format
+        and isinstance(format, CsvFileFormat)
+        and parse_options
+        and isinstance(parse_options, ParseOptions)
+    ):
+        format.parse_options = parse_options
+    return format
+
+
 class ArrowGenerator(Generator):
     DEFAULT_BATCH_SIZE = 2**17  # same as `pyarrow._dataset._DEFAULT_BATCH_SIZE`
 
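
`fix_pyarrow_format` works around ARROW-17641: options such as `invalid_row_handler` do not survive the CSV format object being copied or serialized, so the original `ParseOptions` are re-attached just before the dataset is built. A short pyarrow sketch of the same idea (hypothetical file path, so the final `dataset()` call is left commented):

import pyarrow.csv as pacsv
import pyarrow.dataset as pads

parse_options = pacsv.ParseOptions(
    delimiter=";",
    invalid_row_handler=lambda row: "skip",  # skip malformed rows instead of erroring
)
fmt = pads.CsvFileFormat(parse_options=parse_options)

# Re-attach the options, as fix_pyarrow_format() does before calling dataset():
fmt.parse_options = parse_options

# ds = pads.dataset("data.csv", format=fmt)  # hypothetical local CSV file
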
@@ -53,6 +66,7 @@ class ArrowGenerator(Generator):
         self.output_schema = output_schema
         self.source = source
         self.nrows = nrows
+        self.parse_options = kwargs.pop("parse_options", None)
         self.kwargs = kwargs
 
     def process(self, file: File):
@@ -64,7 +78,11 @@ class ArrowGenerator(Generator):
         else:
             fs, fs_path = file.get_fs(), file.get_path()
 
-        ds = dataset(fs_path, schema=self.input_schema, filesystem=fs, **self.kwargs)
+        kwargs = self.kwargs
+        if format := kwargs.get("format"):
+            kwargs["format"] = fix_pyarrow_format(format, self.parse_options)
+
+        ds = dataset(fs_path, schema=self.input_schema, filesystem=fs, **kwargs)
 
         hf_schema = _get_hf_schema(ds.schema)
         use_datachain_schema = (
@@ -137,6 +155,10 @@ class ArrowGenerator(Generator):
 
 
 def infer_schema(chain: "DataChain", **kwargs) -> pa.Schema:
+    parse_options = kwargs.pop("parse_options", None)
+    if format := kwargs.get("format"):
+        kwargs["format"] = fix_pyarrow_format(format, parse_options)
+
     schemas = []
     for file in chain.collect("file"):
         ds = dataset(file.get_path(), filesystem=file.get_fs(), **kwargs)  # type: ignore[union-attr]
datachain/lib/dc/csv.py CHANGED
@@ -124,4 +124,5 @@ def read_csv(
         source=source,
         nrows=nrows,
         format=format,
+        parse_options=parse_options,
     )
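
With `parse_options` now forwarded from `read_csv` down to the Arrow reader, custom CSV parsing options (delimiters, invalid-row handling) take effect again. A hedged usage sketch, assuming the top-level `read_csv` export and a hypothetical source path:

from pyarrow.csv import ParseOptions

import datachain as dc

chain = dc.read_csv(
    "gs://bucket/data.csv",  # hypothetical source
    parse_options=ParseOptions(delimiter=";", invalid_row_handler=lambda row: "skip"),
)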