datachain 0.7.1__py3-none-any.whl → 0.7.3__py3-none-any.whl

This diff reflects the changes between publicly released versions of the package as published to its public registry, and is provided for informational purposes only.

Potentially problematic release.



Files changed (46)
  1. datachain/__init__.py +0 -2
  2. datachain/catalog/catalog.py +12 -9
  3. datachain/cli.py +109 -9
  4. datachain/client/fsspec.py +9 -9
  5. datachain/data_storage/metastore.py +63 -11
  6. datachain/data_storage/schema.py +2 -2
  7. datachain/data_storage/sqlite.py +5 -4
  8. datachain/data_storage/warehouse.py +18 -18
  9. datachain/dataset.py +142 -14
  10. datachain/func/__init__.py +49 -0
  11. datachain/{lib/func → func}/aggregate.py +13 -11
  12. datachain/func/array.py +176 -0
  13. datachain/func/base.py +23 -0
  14. datachain/func/conditional.py +81 -0
  15. datachain/func/func.py +384 -0
  16. datachain/func/path.py +110 -0
  17. datachain/func/random.py +23 -0
  18. datachain/func/string.py +154 -0
  19. datachain/func/window.py +49 -0
  20. datachain/lib/arrow.py +24 -12
  21. datachain/lib/data_model.py +25 -9
  22. datachain/lib/dataset_info.py +9 -5
  23. datachain/lib/dc.py +94 -56
  24. datachain/lib/hf.py +1 -1
  25. datachain/lib/signal_schema.py +1 -1
  26. datachain/lib/utils.py +1 -0
  27. datachain/lib/webdataset_laion.py +5 -5
  28. datachain/model/bbox.py +2 -2
  29. datachain/model/pose.py +5 -5
  30. datachain/model/segment.py +2 -2
  31. datachain/nodes_fetcher.py +2 -2
  32. datachain/query/dataset.py +57 -34
  33. datachain/remote/studio.py +40 -8
  34. datachain/sql/__init__.py +0 -2
  35. datachain/sql/functions/__init__.py +0 -26
  36. datachain/sql/selectable.py +11 -5
  37. datachain/sql/sqlite/base.py +11 -2
  38. datachain/studio.py +29 -0
  39. {datachain-0.7.1.dist-info → datachain-0.7.3.dist-info}/METADATA +2 -2
  40. {datachain-0.7.1.dist-info → datachain-0.7.3.dist-info}/RECORD +44 -37
  41. datachain/lib/func/__init__.py +0 -32
  42. datachain/lib/func/func.py +0 -152
  43. {datachain-0.7.1.dist-info → datachain-0.7.3.dist-info}/LICENSE +0 -0
  44. {datachain-0.7.1.dist-info → datachain-0.7.3.dist-info}/WHEEL +0 -0
  45. {datachain-0.7.1.dist-info → datachain-0.7.3.dist-info}/entry_points.txt +0 -0
  46. {datachain-0.7.1.dist-info → datachain-0.7.3.dist-info}/top_level.txt +0 -0
datachain/__init__.py CHANGED
@@ -1,4 +1,3 @@
- from datachain.lib import func
  from datachain.lib.data_model import DataModel, DataType, is_chain_type
  from datachain.lib.dc import C, Column, DataChain, Sys
  from datachain.lib.file import (
@@ -35,7 +34,6 @@ __all__ = [
  "Sys",
  "TarVFile",
  "TextFile",
- "func",
  "is_chain_type",
  "metrics",
  "param",
datachain/catalog/catalog.py CHANGED
@@ -38,6 +38,7 @@ from datachain.dataset import (
  DATASET_PREFIX,
  QUERY_DATASET_PREFIX,
  DatasetDependency,
+ DatasetListRecord,
  DatasetRecord,
  DatasetStats,
  DatasetStatus,
@@ -54,7 +55,6 @@ from datachain.error import (
  QueryScriptCancelError,
  QueryScriptRunError,
  )
- from datachain.listing import Listing
  from datachain.node import DirType, Node, NodeWithPath
  from datachain.nodes_thread_pool import NodesThreadPool
  from datachain.remote.studio import StudioClient
@@ -73,9 +73,10 @@ if TYPE_CHECKING:
  AbstractMetastore,
  AbstractWarehouse,
  )
- from datachain.dataset import DatasetVersion
+ from datachain.dataset import DatasetListVersion
  from datachain.job import Job
  from datachain.lib.file import File
+ from datachain.listing import Listing

  logger = logging.getLogger("datachain")

@@ -236,7 +237,7 @@ class DatasetRowsFetcher(NodesThreadPool):
  class NodeGroup:
  """Class for a group of nodes from the same source"""

- listing: Listing
+ listing: "Listing"
  sources: list[DataSource]

  # The source path within the bucket
@@ -591,8 +592,9 @@ class Catalog:
  client_config=None,
  object_name="file",
  skip_indexing=False,
- ) -> tuple[Listing, str]:
+ ) -> tuple["Listing", str]:
  from datachain.lib.dc import DataChain
+ from datachain.listing import Listing

  DataChain.from_storage(
  source, session=self.session, update=update, object_name=object_name
@@ -660,7 +662,8 @@ class Catalog:
  no_glob: bool = False,
  client_config=None,
  ) -> list[NodeGroup]:
- from datachain.query import DatasetQuery
+ from datachain.listing import Listing
+ from datachain.query.dataset import DatasetQuery

  def _row_to_node(d: dict[str, Any]) -> Node:
  del d["file__source"]
@@ -876,7 +879,7 @@ class Catalog:
  def update_dataset_version_with_warehouse_info(
  self, dataset: DatasetRecord, version: int, rows_dropped=False, **kwargs
  ) -> None:
- from datachain.query import DatasetQuery
+ from datachain.query.dataset import DatasetQuery

  dataset_version = dataset.get_version(version)

@@ -1133,7 +1136,7 @@ class Catalog:

  return direct_dependencies

- def ls_datasets(self, include_listing: bool = False) -> Iterator[DatasetRecord]:
+ def ls_datasets(self, include_listing: bool = False) -> Iterator[DatasetListRecord]:
  datasets = self.metastore.list_datasets()
  for d in datasets:
  if not d.is_bucket_listing or include_listing:
@@ -1142,7 +1145,7 @@ class Catalog:
  def list_datasets_versions(
  self,
  include_listing: bool = False,
- ) -> Iterator[tuple[DatasetRecord, "DatasetVersion", Optional["Job"]]]:
+ ) -> Iterator[tuple[DatasetListRecord, "DatasetListVersion", Optional["Job"]]]:
  """Iterate over all dataset versions with related jobs."""
  datasets = list(self.ls_datasets(include_listing=include_listing))

@@ -1177,7 +1180,7 @@ class Catalog:
  def ls_dataset_rows(
  self, name: str, version: int, offset=None, limit=None
  ) -> list[dict]:
- from datachain.query import DatasetQuery
+ from datachain.query.dataset import DatasetQuery

  dataset = self.get_dataset(name)

datachain/cli.py CHANGED
@@ -18,7 +18,12 @@ from datachain.cli_utils import BooleanOptionalAction, CommaSeparatedArgs, KeyVa
  from datachain.config import Config
  from datachain.error import DataChainError
  from datachain.lib.dc import DataChain
- from datachain.studio import list_datasets, process_studio_cli_args
+ from datachain.studio import (
+ edit_studio_dataset,
+ list_datasets,
+ process_studio_cli_args,
+ remove_studio_dataset,
+ )
  from datachain.telemetry import telemetry

  if TYPE_CHECKING:
@@ -403,21 +408,44 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
  parse_edit_dataset.add_argument(
  "--new-name",
  action="store",
- default="",
  help="Dataset new name",
  )
  parse_edit_dataset.add_argument(
  "--description",
  action="store",
- default="",
  help="Dataset description",
  )
  parse_edit_dataset.add_argument(
  "--labels",
- default=[],
  nargs="+",
  help="Dataset labels",
  )
+ parse_edit_dataset.add_argument(
+ "--studio",
+ action="store_true",
+ default=False,
+ help="Edit dataset from Studio",
+ )
+ parse_edit_dataset.add_argument(
+ "-L",
+ "--local",
+ action="store_true",
+ default=False,
+ help="Edit local dataset only",
+ )
+ parse_edit_dataset.add_argument(
+ "-a",
+ "--all",
+ action="store_true",
+ default=True,
+ help="Edit both datasets from studio and local",
+ )
+ parse_edit_dataset.add_argument(
+ "--team",
+ action="store",
+ default=None,
+ help="The team to edit a dataset. By default, it will use team from config.",
+ )

  datasets_parser = subp.add_parser(
  "datasets", parents=[parent_parser], description="List datasets"
@@ -466,6 +494,32 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
  action=BooleanOptionalAction,
  help="Force delete registered dataset with all of it's versions",
  )
+ rm_dataset_parser.add_argument(
+ "--studio",
+ action="store_true",
+ default=False,
+ help="Remove dataset from Studio",
+ )
+ rm_dataset_parser.add_argument(
+ "-L",
+ "--local",
+ action="store_true",
+ default=False,
+ help="Remove local datasets only",
+ )
+ rm_dataset_parser.add_argument(
+ "-a",
+ "--all",
+ action="store_true",
+ default=True,
+ help="Remove both local and studio",
+ )
+ rm_dataset_parser.add_argument(
+ "--team",
+ action="store",
+ default=None,
+ help="The team to delete a dataset. By default, it will use team from config.",
+ )

  dataset_stats_parser = subp.add_parser(
  "dataset-stats",
@@ -909,8 +963,40 @@ def rm_dataset(
  name: str,
  version: Optional[int] = None,
  force: Optional[bool] = False,
+ studio: bool = False,
+ local: bool = False,
+ all: bool = True,
+ team: Optional[str] = None,
+ ):
+ token = Config().read().get("studio", {}).get("token")
+ all, local, studio = _determine_flavors(studio, local, all, token)
+
+ if all or local:
+ catalog.remove_dataset(name, version=version, force=force)
+
+ if (all or studio) and token:
+ remove_studio_dataset(team, name, version, force)
+
+
+ def edit_dataset(
+ catalog: "Catalog",
+ name: str,
+ new_name: Optional[str] = None,
+ description: Optional[str] = None,
+ labels: Optional[list[str]] = None,
+ studio: bool = False,
+ local: bool = False,
+ all: bool = True,
+ team: Optional[str] = None,
  ):
- catalog.remove_dataset(name, version=version, force=force)
+ token = Config().read().get("studio", {}).get("token")
+ all, local, studio = _determine_flavors(studio, local, all, token)
+
+ if all or local:
+ catalog.edit_dataset(name, new_name, description, labels)
+
+ if (all or studio) and token:
+ edit_studio_dataset(team, name, new_name, description, labels)


  def dataset_stats(
@@ -957,7 +1043,7 @@ def show(
  schema: bool = False,
  ) -> None:
  from datachain.lib.dc import DataChain
- from datachain.query import DatasetQuery
+ from datachain.query.dataset import DatasetQuery
  from datachain.utils import show_records

  dataset = catalog.get_dataset(name)
@@ -1127,11 +1213,16 @@ def main(argv: Optional[list[str]] = None) -> int: # noqa: C901, PLR0912, PLR09
  edatachain_file=args.edatachain_file,
  )
  elif args.command == "edit-dataset":
- catalog.edit_dataset(
+ edit_dataset(
+ catalog,
  args.name,
- description=args.description,
  new_name=args.new_name,
+ description=args.description,
  labels=args.labels,
+ studio=args.studio,
+ local=args.local,
+ all=args.all,
+ team=args.team,
  )
  elif args.command == "ls":
  ls(
@@ -1164,7 +1255,16 @@ def main(argv: Optional[list[str]] = None) -> int: # noqa: C901, PLR0912, PLR09
  schema=args.schema,
  )
  elif args.command == "rm-dataset":
- rm_dataset(catalog, args.name, version=args.version, force=args.force)
+ rm_dataset(
+ catalog,
+ args.name,
+ version=args.version,
+ force=args.force,
+ studio=args.studio,
+ local=args.local,
+ all=args.all,
+ team=args.team,
+ )
  elif args.command == "dataset-stats":
  dataset_stats(
  catalog,
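
Both new CLI helpers delegate to `_determine_flavors(studio, local, all, token)`, which this diff references but does not show. A hypothetical sketch of the flag resolution they appear to rely on, inferred only from how the result is used above (the helper name and body here are assumptions, not the shipped implementation):

# Hypothetical: explicit --studio or --local opts out of "all", and Studio is
# only targeted when a token is configured. Names are placeholders.
from typing import Optional

def determine_flavors_sketch(
    studio: bool, local: bool, all: bool, token: Optional[str]
) -> tuple[bool, bool, bool]:
    if studio or local:
        all = False
    studio = studio and bool(token)
    return all, local, studio

# Defaults (--all) with a configured token: both local and Studio copies are affected.
print(determine_flavors_sketch(False, False, True, "token"))  # (True, False, False)
# -L/--local restricts the command to the local catalog only.
print(determine_flavors_sketch(False, True, True, "token"))   # (False, True, False)
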
datachain/client/fsspec.py CHANGED
@@ -28,7 +28,6 @@ from tqdm import tqdm
  from datachain.cache import DataChainCache
  from datachain.client.fileslice import FileWrapper
  from datachain.error import ClientError as DataChainClientError
- from datachain.lib.file import File
  from datachain.nodes_fetcher import NodesFetcher
  from datachain.nodes_thread_pool import NodeChunk

@@ -36,6 +35,7 @@ if TYPE_CHECKING:
  from fsspec.spec import AbstractFileSystem

  from datachain.dataset import StorageURI
+ from datachain.lib.file import File


  logger = logging.getLogger("datachain")
@@ -45,7 +45,7 @@ DELIMITER = "/" # Path delimiter.

  DATA_SOURCE_URI_PATTERN = re.compile(r"^[\w]+:\/\/.*$")

- ResultQueue = asyncio.Queue[Optional[Sequence[File]]]
+ ResultQueue = asyncio.Queue[Optional[Sequence["File"]]]


  def _is_win_local_path(uri: str) -> bool:
@@ -212,7 +212,7 @@ class Client(ABC):

  async def scandir(
  self, start_prefix: str, method: str = "default"
- ) -> AsyncIterator[Sequence[File]]:
+ ) -> AsyncIterator[Sequence["File"]]:
  try:
  impl = getattr(self, f"_fetch_{method}")
  except AttributeError:
@@ -317,7 +317,7 @@ class Client(ABC):
  return f"{self.PREFIX}{self.name}/{rel_path}"

  @abstractmethod
- def info_to_file(self, v: dict[str, Any], parent: str) -> File: ...
+ def info_to_file(self, v: dict[str, Any], parent: str) -> "File": ...

  def fetch_nodes(
  self,
@@ -354,7 +354,7 @@ class Client(ABC):
  copy2(src, dst)

  def open_object(
- self, file: File, use_cache: bool = True, cb: Callback = DEFAULT_CALLBACK
+ self, file: "File", use_cache: bool = True, cb: Callback = DEFAULT_CALLBACK
  ) -> BinaryIO:
  """Open a file, including files in tar archives."""
  if use_cache and (cache_path := self.cache.get_path(file)):
@@ -362,19 +362,19 @@ class Client(ABC):
  assert not file.location
  return FileWrapper(self.fs.open(self.get_full_path(file.path)), cb) # type: ignore[return-value]

- def download(self, file: File, *, callback: Callback = DEFAULT_CALLBACK) -> None:
+ def download(self, file: "File", *, callback: Callback = DEFAULT_CALLBACK) -> None:
  sync(get_loop(), functools.partial(self._download, file, callback=callback))

- async def _download(self, file: File, *, callback: "Callback" = None) -> None:
+ async def _download(self, file: "File", *, callback: "Callback" = None) -> None:
  if self.cache.contains(file):
  # Already in cache, so there's nothing to do.
  return
  await self._put_in_cache(file, callback=callback)

- def put_in_cache(self, file: File, *, callback: "Callback" = None) -> None:
+ def put_in_cache(self, file: "File", *, callback: "Callback" = None) -> None:
  sync(get_loop(), functools.partial(self._put_in_cache, file, callback=callback))

- async def _put_in_cache(self, file: File, *, callback: "Callback" = None) -> None:
+ async def _put_in_cache(self, file: "File", *, callback: "Callback" = None) -> None:
  assert not file.location
  if file.etag:
  etag = await self.get_current_etag(file)
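
This file, like the catalog hunks above, defers an import (`File`, `Listing`, `DatasetQuery`) to `TYPE_CHECKING` or to the call site and quotes the annotations that mention it. A generic, self-contained illustration of that pattern using a stdlib type; datachain applies the same shape to its own modules:

# Deferred-import pattern: import only for type checkers, quote the annotation,
# and re-import at runtime inside the function that actually needs the type.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from decimal import Decimal  # not imported at module load time

def to_decimal(value: str) -> "Decimal":  # quoted, so resolved lazily
    from decimal import Decimal  # runtime import deferred to the first call
    return Decimal(value)

print(to_decimal("1.5"))
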
datachain/data_storage/metastore.py CHANGED
@@ -27,6 +27,8 @@ from datachain.data_storage import JobQueryType, JobStatus
  from datachain.data_storage.serializer import Serializable
  from datachain.dataset import (
  DatasetDependency,
+ DatasetListRecord,
+ DatasetListVersion,
  DatasetRecord,
  DatasetStatus,
  DatasetVersion,
@@ -59,6 +61,8 @@ class AbstractMetastore(ABC, Serializable):

  schema: "schema.Schema"
  dataset_class: type[DatasetRecord] = DatasetRecord
+ dataset_list_class: type[DatasetListRecord] = DatasetListRecord
+ dataset_list_version_class: type[DatasetListVersion] = DatasetListVersion
  dependency_class: type[DatasetDependency] = DatasetDependency
  job_class: type[Job] = Job

@@ -166,11 +170,11 @@ class AbstractMetastore(ABC, Serializable):
  """

  @abstractmethod
- def list_datasets(self) -> Iterator[DatasetRecord]:
+ def list_datasets(self) -> Iterator[DatasetListRecord]:
  """Lists all datasets."""

  @abstractmethod
- def list_datasets_by_prefix(self, prefix: str) -> Iterator["DatasetRecord"]:
+ def list_datasets_by_prefix(self, prefix: str) -> Iterator["DatasetListRecord"]:
  """Lists all datasets which names start with prefix."""

  @abstractmethod
@@ -348,6 +352,14 @@ class AbstractDBMetastore(AbstractMetastore):
  if c.name # type: ignore [attr-defined]
  ]

+ @cached_property
+ def _dataset_list_fields(self) -> list[str]:
+ return [
+ c.name # type: ignore [attr-defined]
+ for c in self._datasets_columns()
+ if c.name in self.dataset_list_class.__dataclass_fields__ # type: ignore [attr-defined]
+ ]
+
  @classmethod
  def _datasets_versions_columns(cls) -> list["SchemaItem"]:
  """Datasets versions table columns."""
@@ -390,6 +402,15 @@ class AbstractDBMetastore(AbstractMetastore):
  if c.name # type: ignore [attr-defined]
  ]

+ @cached_property
+ def _dataset_list_version_fields(self) -> list[str]:
+ return [
+ c.name # type: ignore [attr-defined]
+ for c in self._datasets_versions_columns()
+ if c.name # type: ignore [attr-defined]
+ in self.dataset_list_version_class.__dataclass_fields__
+ ]
+
  @classmethod
  def _datasets_dependencies_columns(cls) -> list["SchemaItem"]:
  """Datasets dependencies table columns."""
@@ -671,7 +692,25 @@ class AbstractDBMetastore(AbstractMetastore):
  if dataset:
  yield dataset

- def _base_dataset_query(self):
+ def _parse_list_dataset(self, rows) -> Optional[DatasetListRecord]:
+ versions = [self.dataset_list_class.parse(*r) for r in rows]
+ if not versions:
+ return None
+ return reduce(lambda ds, version: ds.merge_versions(version), versions)
+
+ def _parse_dataset_list(self, rows) -> Iterator["DatasetListRecord"]:
+ # grouping rows by dataset id
+ for _, g in groupby(rows, lambda r: r[0]):
+ dataset = self._parse_list_dataset(list(g))
+ if dataset:
+ yield dataset
+
+ def _get_dataset_query(
+ self,
+ dataset_fields: list[str],
+ dataset_version_fields: list[str],
+ isouter: bool = True,
+ ):
  if not (
  self.db.has_table(self._datasets.name)
  and self.db.has_table(self._datasets_versions.name)
@@ -680,23 +719,36 @@ class AbstractDBMetastore(AbstractMetastore):

  d = self._datasets
  dv = self._datasets_versions
+
  query = self._datasets_select(
- *(getattr(d.c, f) for f in self._dataset_fields),
- *(getattr(dv.c, f) for f in self._dataset_version_fields),
+ *(getattr(d.c, f) for f in dataset_fields),
+ *(getattr(dv.c, f) for f in dataset_version_fields),
  )
- j = d.join(dv, d.c.id == dv.c.dataset_id, isouter=True)
+ j = d.join(dv, d.c.id == dv.c.dataset_id, isouter=isouter)
  return query.select_from(j)

- def list_datasets(self) -> Iterator["DatasetRecord"]:
+ def _base_dataset_query(self):
+ return self._get_dataset_query(
+ self._dataset_fields, self._dataset_version_fields
+ )
+
+ def _base_list_datasets_query(self):
+ return self._get_dataset_query(
+ self._dataset_list_fields, self._dataset_list_version_fields, isouter=False
+ )
+
+ def list_datasets(self) -> Iterator["DatasetListRecord"]:
  """Lists all datasets."""
- yield from self._parse_datasets(self.db.execute(self._base_dataset_query()))
+ yield from self._parse_dataset_list(
+ self.db.execute(self._base_list_datasets_query())
+ )

  def list_datasets_by_prefix(
  self, prefix: str, conn=None
- ) -> Iterator["DatasetRecord"]:
- query = self._base_dataset_query()
+ ) -> Iterator["DatasetListRecord"]:
+ query = self._base_list_datasets_query()
  query = query.where(self._datasets.c.name.startswith(prefix))
- yield from self._parse_datasets(self.db.execute(query))
+ yield from self._parse_dataset_list(self.db.execute(query))

  def get_dataset(self, name: str, conn=None) -> DatasetRecord:
  """Gets a single dataset by name"""
datachain/data_storage/schema.py CHANGED
@@ -12,7 +12,7 @@ import sqlalchemy as sa
  from sqlalchemy.sql import func as f
  from sqlalchemy.sql.expression import false, null, true

- from datachain.sql.functions import path
+ from datachain.sql.functions import path as pathfunc
  from datachain.sql.types import Int, SQLType, UInt64

  if TYPE_CHECKING:
@@ -130,7 +130,7 @@ class DirExpansion:

  def query(self, q):
  q = self.base_select(q).cte(recursive=True)
- parent = path.parent(self.c(q, "path"))
+ parent = pathfunc.parent(self.c(q, "path"))
  q = q.union_all(
  sa.select(
  sa.literal(-1).label("sys__id"),
datachain/data_storage/sqlite.py CHANGED
@@ -122,7 +122,9 @@ class SQLiteDatabaseEngine(DatabaseEngine):
  return cls(*cls._connect(db_file=db_file))

  @staticmethod
- def _connect(db_file: Optional[str] = None):
+ def _connect(
+ db_file: Optional[str] = None,
+ ) -> tuple["Engine", "MetaData", sqlite3.Connection, str]:
  try:
  if db_file == ":memory:":
  # Enable multithreaded usage of the same in-memory db
@@ -130,9 +132,8 @@ class SQLiteDatabaseEngine(DatabaseEngine):
  _get_in_memory_uri(), uri=True, detect_types=DETECT_TYPES
  )
  else:
- db = sqlite3.connect(
- db_file or DataChainDir.find().db, detect_types=DETECT_TYPES
- )
+ db_file = db_file or DataChainDir.find().db
+ db = sqlite3.connect(db_file, detect_types=DETECT_TYPES)
  create_user_defined_sql_functions(db)
  engine = sqlalchemy.create_engine(
  "sqlite+pysqlite:///", creator=lambda: db, future=True
datachain/data_storage/warehouse.py CHANGED
@@ -224,28 +224,28 @@ class AbstractWarehouse(ABC, Serializable):
  offset = 0
  num_yielded = 0

- while True:
- if limit is not None:
- limit -= num_yielded
- if limit == 0:
- break
- if limit < page_size:
- paginated_query = paginated_query.limit(None).limit(limit)
-
- # Ensure we're using a thread-local connection
- with self.clone() as wh:
+ # Ensure we're using a thread-local connection
+ with self.clone() as wh:
+ while True:
+ if limit is not None:
+ limit -= num_yielded
+ if limit == 0:
+ break
+ if limit < page_size:
+ paginated_query = paginated_query.limit(None).limit(limit)
+
  # Cursor results are not thread-safe, so we convert them to a list
  results = list(wh.dataset_rows_select(paginated_query.offset(offset)))

- processed = False
- for row in results:
- processed = True
- yield row
- num_yielded += 1
+ processed = False
+ for row in results:
+ processed = True
+ yield row
+ num_yielded += 1

- if not processed:
- break # no more results
- offset += page_size
+ if not processed:
+ break # no more results
+ offset += page_size

  #
  # Table Name Internal Functions
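
The only change here is structural: the thread-local `clone()` context now wraps the whole pagination loop instead of being re-entered for every page. A simplified, generic sketch of the resulting shape; `open_connection` and `select_page` are illustrative placeholders, not the warehouse API:

# One connection is opened before pagination starts and reused for every page.
from contextlib import contextmanager

DATA = list(range(7))  # stand-in for dataset rows

@contextmanager
def open_connection():
    # Placeholder for the thread-local connection obtained via clone().
    yield DATA

def select_page(conn, offset, page_size):
    return conn[offset : offset + page_size]

def iter_rows(page_size=3):
    with open_connection() as conn:  # opened once, outside the loop
        offset = 0
        while True:
            results = list(select_page(conn, offset, page_size))
            if not results:
                break  # no more results
            yield from results
            offset += page_size

print(list(iter_rows()))  # [0, 1, 2, 3, 4, 5, 6]
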