datachain 0.1.13__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release.
Files changed (49)
  1. datachain/__init__.py +0 -4
  2. datachain/asyn.py +3 -3
  3. datachain/catalog/__init__.py +3 -3
  4. datachain/catalog/catalog.py +6 -6
  5. datachain/catalog/loader.py +3 -3
  6. datachain/cli.py +10 -2
  7. datachain/client/azure.py +37 -1
  8. datachain/client/fsspec.py +1 -1
  9. datachain/client/local.py +1 -1
  10. datachain/data_storage/__init__.py +1 -1
  11. datachain/data_storage/metastore.py +11 -3
  12. datachain/data_storage/schema.py +12 -7
  13. datachain/data_storage/sqlite.py +3 -0
  14. datachain/data_storage/warehouse.py +31 -30
  15. datachain/dataset.py +1 -3
  16. datachain/lib/arrow.py +85 -0
  17. datachain/lib/cached_stream.py +3 -85
  18. datachain/lib/dc.py +382 -179
  19. datachain/lib/feature.py +46 -91
  20. datachain/lib/feature_registry.py +4 -1
  21. datachain/lib/feature_utils.py +2 -2
  22. datachain/lib/file.py +30 -44
  23. datachain/lib/image.py +9 -2
  24. datachain/lib/meta_formats.py +66 -34
  25. datachain/lib/settings.py +5 -5
  26. datachain/lib/signal_schema.py +103 -105
  27. datachain/lib/udf.py +10 -38
  28. datachain/lib/udf_signature.py +11 -6
  29. datachain/lib/webdataset_laion.py +5 -22
  30. datachain/listing.py +8 -8
  31. datachain/node.py +1 -1
  32. datachain/progress.py +1 -1
  33. datachain/query/builtins.py +1 -1
  34. datachain/query/dataset.py +42 -119
  35. datachain/query/dispatch.py +1 -1
  36. datachain/query/metrics.py +19 -0
  37. datachain/query/schema.py +13 -3
  38. datachain/sql/__init__.py +1 -1
  39. datachain/sql/sqlite/base.py +34 -2
  40. datachain/sql/sqlite/vector.py +13 -5
  41. datachain/utils.py +1 -122
  42. {datachain-0.1.13.dist-info → datachain-0.2.1.dist-info}/METADATA +11 -4
  43. {datachain-0.1.13.dist-info → datachain-0.2.1.dist-info}/RECORD +47 -47
  44. {datachain-0.1.13.dist-info → datachain-0.2.1.dist-info}/WHEEL +1 -1
  45. datachain/_version.py +0 -16
  46. datachain/lib/parquet.py +0 -32
  47. {datachain-0.1.13.dist-info → datachain-0.2.1.dist-info}/LICENSE +0 -0
  48. {datachain-0.1.13.dist-info → datachain-0.2.1.dist-info}/entry_points.txt +0 -0
  49. {datachain-0.1.13.dist-info → datachain-0.2.1.dist-info}/top_level.txt +0 -0
datachain/__init__.py CHANGED
@@ -1,4 +0,0 @@
- try:
-     from ._version import version as __version__
- except ImportError:
-     __version__ = "UNKNOWN"
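
With the static _version module removed, the version string is now resolved from installed package metadata at runtime, as the cli.py change further below does. A minimal sketch of that stdlib pattern; the "unknown" fallback mirrors the CLI change:

import importlib.metadata

try:
    __version__ = importlib.metadata.version("datachain")
except importlib.metadata.PackageNotFoundError:
    # Package metadata is unavailable, e.g. when running from a source checkout.
    __version__ = "unknown"
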
datachain/asyn.py CHANGED
@@ -82,13 +82,13 @@ class AsyncMapper(Generic[InputT, ResultT]):
          for _i in range(self.workers):
              self.start_task(self.worker())
          try:
-             done, pending = await asyncio.wait(
+             done, _pending = await asyncio.wait(
                  self._tasks, return_when=asyncio.FIRST_COMPLETED
              )
              self.gather_exceptions(done)
              assert producer.done()
              join = self.start_task(self.work_queue.join())
-             done, pending = await asyncio.wait(
+             done, _pending = await asyncio.wait(
                  self._tasks, return_when=asyncio.FIRST_COMPLETED
              )
              self.gather_exceptions(done)
@@ -208,7 +208,7 @@ class OrderedMapper(AsyncMapper[InputT, ResultT]):

      async def _pop_result(self) -> Optional[ResultT]:
          if self.heap and self.heap[0][0] == self._next_yield:
-             i, out = heappop(self.heap)
+             _i, out = heappop(self.heap)
          else:
              self._getters[self._next_yield] = get_value = self.loop.create_future()
              out = await get_value
datachain/catalog/__init__.py CHANGED
@@ -8,10 +8,10 @@ from .catalog import (
  from .loader import get_catalog

  __all__ = [
+     "QUERY_DATASET_PREFIX",
+     "QUERY_SCRIPT_CANCELED_EXIT_CODE",
+     "QUERY_SCRIPT_INVALID_LAST_STATEMENT_EXIT_CODE",
      "Catalog",
      "get_catalog",
      "parse_edatachain_file",
-     "QUERY_SCRIPT_INVALID_LAST_STATEMENT_EXIT_CODE",
-     "QUERY_SCRIPT_CANCELED_EXIT_CODE",
-     "QUERY_DATASET_PREFIX",
  ]
datachain/catalog/catalog.py CHANGED
@@ -142,6 +142,7 @@ class QueryResult(NamedTuple):
      version: Optional[int]
      output: str
      preview: Optional[list[dict]]
+     metrics: dict[str, Any]


  class DatasetRowsFetcher(NodesThreadPool):
@@ -876,13 +877,11 @@ class Catalog:
              # so this is to improve performance
              return None

-         dsrc_all = []
+         dsrc_all: list[DataSource] = []
          for listing, file_path in enlisted_sources:
              nodes = listing.expand_path(file_path)
              dir_only = file_path.endswith("/")
-             for node in nodes:
-                 dsrc_all.append(DataSource(listing, node, dir_only))
-
+             dsrc_all.extend(DataSource(listing, node, dir_only) for node in nodes)
          return dsrc_all

      def enlist_sources_grouped(
@@ -1997,6 +1996,7 @@ class Catalog:
              version=version,
              output=output,
              preview=exec_result.preview,
+             metrics=exec_result.metrics,
          )

      def run_query(
@@ -2068,8 +2068,8 @@ class Catalog:
                  "DATACHAIN_JOB_ID": job_id or "",
              },
          )
-         with subprocess.Popen(
-             [python_executable, "-c", query_script_compiled],  # noqa: S603
+         with subprocess.Popen(  # noqa: S603
+             [python_executable, "-c", query_script_compiled],
              env=envs,
              stdout=subprocess.PIPE if capture_output else None,
              stderr=subprocess.STDOUT if capture_output else None,
datachain/catalog/loader.py CHANGED
@@ -35,7 +35,7 @@ def get_id_generator() -> "AbstractIDGenerator":
      id_generator_obj = deserialize(id_generator_serialized)
      if not isinstance(id_generator_obj, AbstractIDGenerator):
          raise RuntimeError(
-             f"Deserialized ID generator is not an instance of AbstractIDGenerator: "
+             "Deserialized ID generator is not an instance of AbstractIDGenerator: "
              f"{id_generator_obj}"
          )
      return id_generator_obj
@@ -67,7 +67,7 @@ def get_metastore(id_generator: Optional["AbstractIDGenerator"]) -> "AbstractMet
      metastore_obj = deserialize(metastore_serialized)
      if not isinstance(metastore_obj, AbstractMetastore):
          raise RuntimeError(
-             f"Deserialized Metastore is not an instance of AbstractMetastore: "
+             "Deserialized Metastore is not an instance of AbstractMetastore: "
              f"{metastore_obj}"
          )
      return metastore_obj
@@ -101,7 +101,7 @@ def get_warehouse(id_generator: Optional["AbstractIDGenerator"]) -> "AbstractWar
      warehouse_obj = deserialize(warehouse_serialized)
      if not isinstance(warehouse_obj, AbstractWarehouse):
          raise RuntimeError(
-             f"Deserialized Warehouse is not an instance of AbstractWarehouse: "
+             "Deserialized Warehouse is not an instance of AbstractWarehouse: "
              f"{warehouse_obj}"
          )
      return warehouse_obj
datachain/cli.py CHANGED
@@ -5,13 +5,14 @@ import sys
  import traceback
  from argparse import SUPPRESS, Action, ArgumentParser, ArgumentTypeError, Namespace
  from collections.abc import Iterable, Iterator, Mapping, Sequence
+ from importlib.metadata import PackageNotFoundError, version
  from itertools import chain
  from multiprocessing import freeze_support
  from typing import TYPE_CHECKING, Optional, Union

  import shtab

- from datachain import __version__, utils
+ from datachain import utils
  from datachain.cli_utils import BooleanOptionalAction, CommaSeparatedArgs, KeyValueArgs
  from datachain.utils import DataChainDir

@@ -96,6 +97,12 @@ def add_show_args(parser: ArgumentParser) -> None:


  def get_parser() -> ArgumentParser:  # noqa: PLR0915
+     try:
+         __version__ = version("datachain")
+     except PackageNotFoundError:
+         # package is not installed
+         __version__ = "unknown"
+
      parser = ArgumentParser(
          description="DataChain: Wrangle unstructured AI data at scale", prog="datachain"
      )
@@ -845,6 +852,7 @@ def query(
          query=script_content,
          query_type=JobQueryType.PYTHON,
          python_version=python_version,
+         params=params,
      )

      try:
@@ -870,7 +878,7 @@ def query(
          )
          raise

-     catalog.metastore.set_job_status(job_id, JobStatus.COMPLETE)
+     catalog.metastore.set_job_status(job_id, JobStatus.COMPLETE, metrics=result.metrics)

      show_records(result.preview, collapse_columns=not no_collapse)
datachain/client/azure.py CHANGED
@@ -1,10 +1,12 @@
+ import posixpath
  from typing import Any

  from adlfs import AzureBlobFileSystem
+ from tqdm import tqdm

  from datachain.node import Entry

- from .fsspec import DELIMITER, Client
+ from .fsspec import DELIMITER, Client, ResultQueue


  class AzureClient(Client):
@@ -28,3 +30,37 @@ class AzureClient(Client):
              last_modified=v["last_modified"],
              size=v.get("size", ""),
          )
+
+     async def _fetch_flat(self, start_prefix: str, result_queue: ResultQueue) -> None:
+         prefix = start_prefix
+         if prefix:
+             prefix = prefix.lstrip(DELIMITER) + DELIMITER
+         found = False
+         try:
+             with tqdm(desc=f"Listing {self.uri}", unit=" objects") as pbar:
+                 async with self.fs.service_client.get_container_client(
+                     container=self.name
+                 ) as container_client:
+                     async for page in container_client.list_blobs(
+                         include=["metadata", "versions"], name_starts_with=prefix
+                     ).by_page():
+                         entries = []
+                         async for b in page:
+                             found = True
+                             if not self._is_valid_key(b["name"]):
+                                 continue
+                             info = (await self.fs._details([b]))[0]
+                             full_path = info["name"]
+                             parent = posixpath.dirname(self.rel_path(full_path))
+                             entries.append(self.convert_info(info, parent))
+                         if entries:
+                             await result_queue.put(entries)
+                             pbar.update(len(entries))
+             if not found:
+                 raise FileNotFoundError(
+                     f"Unable to resolve remote path: {prefix}"
+                 )
+         finally:
+             result_queue.put_nowait(None)
+
+     _fetch_default = _fetch_flat
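
The new _fetch_flat listing pushes each page of entries onto an asyncio queue and always finishes with a None sentinel, so the consumer knows the listing is done even when it fails. A simplified, generic sketch of that producer/consumer shape (plain lists stand in for Azure's async blob pages; names here are illustrative only):

import asyncio

async def produce_pages(pages, result_queue: asyncio.Queue) -> None:
    # Push each non-empty batch of entries; always emit the None sentinel,
    # even if listing fails, so the consumer can stop cleanly.
    try:
        for page in pages:
            if page:
                await result_queue.put(page)
    finally:
        result_queue.put_nowait(None)

async def consume(result_queue: asyncio.Queue) -> list:
    entries = []
    while (batch := await result_queue.get()) is not None:
        entries.extend(batch)
    return entries

async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    producer = asyncio.create_task(produce_pages([["a.txt"], [], ["b.txt", "c.txt"]], queue))
    print(await consume(queue))  # ['a.txt', 'b.txt', 'c.txt']
    await producer

asyncio.run(main())
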
datachain/client/fsspec.py CHANGED
@@ -202,7 +202,7 @@ class Client(ABC):
          try:
              impl = getattr(self, f"_fetch_{method}")
          except AttributeError:
-             raise ValueError("Unknown indexing method '{method}'") from None
+             raise ValueError(f"Unknown indexing method '{method}'") from None
          result_queue: ResultQueue = asyncio.Queue()
          loop = get_loop()
          main_task = loop.create_task(impl(start_prefix, result_queue))
datachain/client/local.py CHANGED
@@ -135,7 +135,7 @@ class FileClient(Client):
          return posixpath.relpath(path, self.name)

      def get_full_path(self, rel_path):
-         full_path = Path(self.name, rel_path).as_uri()
+         full_path = Path(self.name, rel_path).as_posix()
          if rel_path.endswith("/") or not rel_path:
              full_path += "/"
          return full_path
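
get_full_path now returns a plain POSIX path instead of a file:// URI. The difference, shown with the standard library (an absolute POSIX-style path is assumed, since as_uri rejects relative paths):

from pathlib import Path

p = Path("/data/images", "cat.jpg")
print(p.as_uri())    # file:///data/images/cat.jpg
print(p.as_posix())  # /data/images/cat.jpg
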
datachain/data_storage/__init__.py CHANGED
@@ -5,8 +5,8 @@ from .warehouse import AbstractWarehouse

  __all__ = [
      "AbstractDBIDGenerator",
-     "AbstractIDGenerator",
      "AbstractDBMetastore",
+     "AbstractIDGenerator",
      "AbstractMetastore",
      "AbstractWarehouse",
      "JobQueryType",
datachain/data_storage/metastore.py CHANGED
@@ -385,6 +385,7 @@ class AbstractMetastore(ABC, Serializable):
          query_type: JobQueryType = JobQueryType.PYTHON,
          workers: int = 1,
          python_version: Optional[str] = None,
+         params: Optional[dict[str, str]] = None,
      ) -> str:
          """
          Creates a new job.
@@ -398,6 +399,7 @@ class AbstractMetastore(ABC, Serializable):
          status: JobStatus,
          error_message: Optional[str] = None,
          error_stack: Optional[str] = None,
+         metrics: Optional[dict[str, Any]] = None,
      ) -> None:
          """Set the status of the given job."""

@@ -1165,9 +1167,7 @@ class AbstractDBMetastore(AbstractMetastore):
          return dataset_version

      def _parse_dataset(self, rows) -> Optional[DatasetRecord]:
-         versions = []
-         for r in rows:
-             versions.append(self.dataset_class.parse(*r))
+         versions = [self.dataset_class.parse(*r) for r in rows]
          if not versions:
              return None
          return reduce(lambda ds, version: ds.merge_versions(version), versions)
@@ -1463,6 +1463,8 @@ class AbstractDBMetastore(AbstractMetastore):
              Column("python_version", Text, nullable=True),
              Column("error_message", Text, nullable=False, default=""),
              Column("error_stack", Text, nullable=False, default=""),
+             Column("params", JSON, nullable=False),
+             Column("metrics", JSON, nullable=False),
          ]

      @cached_property
@@ -1489,6 +1491,7 @@ class AbstractDBMetastore(AbstractMetastore):
          query_type: JobQueryType = JobQueryType.PYTHON,
          workers: int = 1,
          python_version: Optional[str] = None,
+         params: Optional[dict[str, str]] = None,
          conn: Optional[Any] = None,
      ) -> str:
          """
@@ -1508,6 +1511,8 @@ class AbstractDBMetastore(AbstractMetastore):
                  python_version=python_version,
                  error_message="",
                  error_stack="",
+                 params=json.dumps(params or {}),
+                 metrics=json.dumps({}),
              ),
              conn=conn,
          )
@@ -1519,6 +1524,7 @@ class AbstractDBMetastore(AbstractMetastore):
          status: JobStatus,
          error_message: Optional[str] = None,
          error_stack: Optional[str] = None,
+         metrics: Optional[dict[str, Any]] = None,
          conn: Optional[Any] = None,
      ) -> None:
          """Set the status of the given job."""
@@ -1529,6 +1535,8 @@ class AbstractDBMetastore(AbstractMetastore):
              values["error_message"] = error_message
          if error_stack:
              values["error_stack"] = error_stack
+         if metrics:
+             values["metrics"] = json.dumps(metrics)
          self.db.execute(
              self._jobs_update(self._jobs.c.id == job_id).values(**values),
              conn=conn,
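
Both new jobs columns hold serialized JSON, so writes go through json.dumps and reads need a matching json.loads. A rough sketch of that round trip with SQLAlchemy Core; the table below is an illustrative stand-in rather than the real jobs schema, and Text columns are used in place of datachain's JSON type:

import json
import sqlalchemy as sa

metadata = sa.MetaData()
# Illustrative stand-in for the jobs table; the real schema has many more columns.
jobs = sa.Table(
    "jobs_demo",
    metadata,
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column("params", sa.Text, nullable=False),
    sa.Column("metrics", sa.Text, nullable=False),
)

engine = sa.create_engine("sqlite://")
metadata.create_all(engine)

with engine.begin() as conn:
    conn.execute(jobs.insert().values(id=1, params=json.dumps({"limit": "100"}), metrics=json.dumps({})))
    conn.execute(jobs.update().where(jobs.c.id == 1).values(metrics=json.dumps({"rows_processed": 42})))
    raw = conn.execute(sa.select(jobs.c.metrics).where(jobs.c.id == 1)).scalar_one()
    print(json.loads(raw))  # {'rows_processed': 42}
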
datachain/data_storage/schema.py CHANGED
@@ -31,11 +31,10 @@ def dedup_columns(columns: Iterable[sa.Column]) -> list[sa.Column]:
      """
      c_set: dict[str, sa.Column] = {}
      for c in columns:
-         if ec := c_set.get(c.name, None):
+         if (ec := c_set.get(c.name, None)) is not None:
              if str(ec.type) != str(c.type):
                  raise ValueError(
-                     f"conflicting types for column {c.name}:"
-                     f"{c.type!s} and {ec.type!s}"
+                     f"conflicting types for column {c.name}:{c.type!s} and {ec.type!s}"
                  )
              continue
          c_set[c.name] = c
@@ -172,8 +171,8 @@ class DataTable:
      ):
          # copy columns, since re-using the same objects from another table
          # may raise an error
-         columns = [cls.copy_column(c) for c in columns if c.name != "id"]
-         columns = [sa.Column("id", Int, primary_key=True), *columns]
+         columns = cls.sys_columns() + [cls.copy_column(c) for c in columns]
+         columns = dedup_columns(columns)

          if metadata is None:
              metadata = sa.MetaData()
@@ -231,10 +230,17 @@ class DataTable:
      def delete(self):
          return self.apply_conditions(self.table.delete())

+     @staticmethod
+     def sys_columns():
+         return [
+             sa.Column("id", Int, primary_key=True),
+             sa.Column("random", Int64, nullable=False, default=f.random()),
+         ]
+
      @classmethod
      def file_columns(cls) -> list[sa.Column]:
          return [
-             sa.Column("id", Int, primary_key=True),
+             *cls.sys_columns(),
              sa.Column("vtype", String, nullable=False, index=True),
              sa.Column("dir_type", Int, index=True),
              sa.Column("parent", String, index=True),
@@ -246,7 +252,6 @@ class DataTable:
              sa.Column("size", Int64, nullable=False, index=True),
              sa.Column("owner_name", String),
              sa.Column("owner_id", String),
-             sa.Column("random", Int64, nullable=False),
              sa.Column("location", JSON),
              sa.Column("source", String, nullable=False),
          ]
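
DataTable now builds every table from sys_columns() (an id primary key plus a random column with a SQL-side default) followed by the caller's columns, with dedup_columns dropping name collisions. A sketch of the idea using plain SQLAlchemy types in place of datachain's Int/Int64/String wrappers; table and column names are made up:

import sqlalchemy as sa

def sys_columns() -> list[sa.Column]:
    # System columns every data table gets, mirroring the new DataTable.sys_columns().
    return [
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column("random", sa.BigInteger, nullable=False, default=sa.func.random()),
    ]

def dedup_columns(columns) -> list[sa.Column]:
    # Keep the first column seen for each name, as the real dedup_columns does.
    seen: dict[str, sa.Column] = {}
    for c in columns:
        seen.setdefault(c.name, c)
    return list(seen.values())

user_columns = [sa.Column("id", sa.Integer), sa.Column("path", sa.String)]
columns = dedup_columns(sys_columns() + user_columns)
table = sa.Table("rows_demo", sa.MetaData(), *columns)
print([c.name for c in table.columns])  # ['id', 'random', 'path']
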
datachain/data_storage/sqlite.py CHANGED
@@ -33,6 +33,7 @@ from datachain.data_storage.schema import (
  from datachain.dataset import DatasetRecord
  from datachain.error import DataChainError
  from datachain.sql.sqlite import create_user_defined_sql_functions, sqlite_dialect
+ from datachain.sql.sqlite.base import load_usearch_extension
  from datachain.sql.types import SQLType
  from datachain.storage import StorageURI
  from datachain.utils import DataChainDir
@@ -114,6 +115,8 @@ class SQLiteDatabaseEngine(DatabaseEngine):
              if os.environ.get("DEBUG_SHOW_SQL_QUERIES"):
                  db.set_trace_callback(print)

+             load_usearch_extension(db)
+
              return cls(engine, MetaData(), db, db_file)
          except RuntimeError:
              raise DataChainError("Can't connect to SQLite DB") from None
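
load_usearch_extension itself lives in datachain/sql/sqlite/base.py (changed in this release but not shown here). As a generic illustration only, not the datachain implementation, loading an optional SQLite extension with the standard library usually looks like this; extension_path is a hypothetical placeholder:

import sqlite3

def try_load_extension(db: sqlite3.Connection, extension_path: str) -> bool:
    # Attempt to load a native SQLite extension (e.g. a vector-search extension);
    # return False instead of failing if the build or the extension is unavailable.
    try:
        db.enable_load_extension(True)
        db.load_extension(extension_path)
        return True
    except (AttributeError, sqlite3.OperationalError):
        # AttributeError: Python built without extension support;
        # OperationalError: extension file missing or incompatible.
        return False
    finally:
        try:
            db.enable_load_extension(False)
        except AttributeError:
            pass
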
datachain/data_storage/warehouse.py CHANGED
@@ -95,14 +95,14 @@ class AbstractWarehouse(ABC, Serializable):

          exc = None
          try:
-             if col_python_type == list and value_type in (list, tuple, set):
+             if col_python_type is list and value_type in (list, tuple, set):
                  if len(val) == 0:
                      return []
                  item_python_type = self.python_type(col_type.item_type)
-                 if item_python_type != list:
+                 if item_python_type is not list:
                      if isinstance(val[0], item_python_type):
                          return val
-                     if item_python_type == float and isinstance(val[0], int):
+                     if item_python_type is float and isinstance(val[0], int):
                          return [float(i) for i in val]
                  # Optimization: Reuse these values for each function call within the
                  # list comprehension.
@@ -114,18 +114,18 @@ class AbstractWarehouse(ABC, Serializable):
                  )
                  return [self.convert_type(i, *item_type_info) for i in val]
              # Special use case with JSON type as we save it as string
-             if col_python_type == dict or col_type_name == "JSON":
-                 if value_type == str:
+             if col_python_type is dict or col_type_name == "JSON":
+                 if value_type is str:
                      return val
                  if value_type in (dict, list):
                      return json.dumps(val)
                  raise ValueError(
-                     f"Cannot convert value {val!r} with type" f"{value_type} to JSON"
+                     f"Cannot convert value {val!r} with type {value_type} to JSON"
                  )

              if isinstance(val, col_python_type):
                  return val
-             if col_python_type == float and isinstance(val, int):
+             if col_python_type is float and isinstance(val, int):
                  return float(val)
          except Exception as e:  # noqa: BLE001
              exc = e
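
The == to is switch is the idiomatic way to compare against type objects: types are effectively singletons, and identity checks avoid dispatching to a potentially overridden __eq__. A quick illustration:

col_python_type = float

# Identity is the canonical test for "is exactly this type".
print(col_python_type is float)   # True
print(col_python_type is int)     # False

# Equality also works for builtins, but it goes through __eq__,
# which type-like objects from other libraries may override.
print(col_python_type == float)   # True
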
@@ -335,6 +335,7 @@ class AbstractWarehouse(ABC, Serializable):
              return select_query
          if recursive:
              root = False
+             where = self.path_expr(dr).op("GLOB")(path)
              if not path or path == "/":
                  # root of the bucket, e.g s3://bucket/ -> getting all the nodes
                  # in the bucket
@@ -344,14 +345,18 @@ class AbstractWarehouse(ABC, Serializable):
                  # not a root and not a explicit glob, so it's pointing to some directory
                  # and we are adding a proper glob syntax for it
                  # e.g s3://bucket/dir1 -> s3://bucket/dir1/*
-                 path = path.rstrip("/") + "/*"
+                 dir_path = path.rstrip("/") + "/*"
+                 where = where | self.path_expr(dr).op("GLOB")(dir_path)

              if not root:
                  # not a root, so running glob query
-                 select_query = select_query.where(self.path_expr(dr).op("GLOB")(path))
+                 select_query = select_query.where(where)
+
          else:
              parent = self.get_node_by_path(dr, path.lstrip("/").rstrip("/*"))
-             select_query = select_query.where(dr.c.parent == parent.path)
+             select_query = select_query.where(
+                 (dr.c.parent == parent.path) | (self.path_expr(dr) == path)
+             )
          return select_query

      def rename_dataset_table(
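
With the extra GLOB clause, a recursive query now matches the named path itself as well as everything underneath it. A small self-contained sketch of the OR-ed GLOB filters with SQLAlchemy and SQLite; the table and column names are illustrative:

import sqlalchemy as sa

metadata = sa.MetaData()
nodes = sa.Table("nodes_demo", metadata, sa.Column("path", sa.String))

engine = sa.create_engine("sqlite://")
metadata.create_all(engine)
with engine.begin() as conn:
    conn.execute(nodes.insert(), [{"path": "dir1"}, {"path": "dir1/a.txt"}, {"path": "dir2/b.txt"}])

    path = "dir1"
    # Match the node itself OR anything under it, as the updated query builder does.
    where = nodes.c.path.op("GLOB")(path) | nodes.c.path.op("GLOB")(path.rstrip("/") + "/*")
    rows = conn.execute(sa.select(nodes.c.path).where(where)).scalars().all()
    print(rows)  # ['dir1', 'dir1/a.txt']
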
@@ -493,7 +498,10 @@ class AbstractWarehouse(ABC, Serializable):
          This gets nodes based on the provided query, and should be used sparingly,
          as it will be slow on any OLAP database systems.
          """
-         return (Node(*row) for row in self.db.execute(query))
+         columns = [c.name for c in query.columns]
+         for row in self.db.execute(query):
+             d = dict(zip(columns, row))
+             yield Node(**d)

      def get_dirs_by_parent_path(
          self,
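
get_nodes now builds each Node from a column-name mapping instead of positional unpacking, so the result stays correct if the query selects fewer columns, or a different column order, than the Node definition. A tiny illustration with a named tuple standing in for datachain's Node:

from typing import NamedTuple, Optional

class NodeDemo(NamedTuple):
    # Illustrative stand-in for datachain.node.Node.
    id: int
    path: str = ""
    size: Optional[int] = None

columns = ["size", "id", "path"]           # order comes from the SELECT, not the class
row = (1024, 7, "dir1/a.txt")

print(NodeDemo(**dict(zip(columns, row))))  # NodeDemo(id=7, path='dir1/a.txt', size=1024)
print(NodeDemo(*row))                       # wrong: NodeDemo(id=1024, path=7, size='dir1/a.txt')
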
@@ -570,14 +578,12 @@ class AbstractWarehouse(ABC, Serializable):
          matched_paths: list[list[str]] = [[]]
          for curr_name in path_list[:-1]:
              if glob.has_magic(curr_name):
-                 new_paths = []
+                 new_paths: list[list[str]] = []
                  for path in matched_paths:
                      nodes = self._get_nodes_by_glob_path_pattern(
                          dataset_rows, path, curr_name
                      )
-                     for node in nodes:
-                         if node.is_container:
-                             new_paths.append([*path, node.name or ""])
+                     new_paths.extend([*path, n.name] for n in nodes if n.is_container)
                  matched_paths = new_paths
              else:
                  for path in matched_paths:
@@ -772,7 +778,7 @@ class AbstractWarehouse(ABC, Serializable):
          self,
          dataset_rows: "DataTable",
          parent_path: str,
-         fields: Optional[Iterable[str]] = None,
+         fields: Optional[Sequence[str]] = None,
          type: Optional[str] = None,
          conds=None,
          order_by: Optional[Union[str, list[str]]] = None,
@@ -794,9 +800,9 @@ class AbstractWarehouse(ABC, Serializable):
          else:
              conds.append(path != "")

-         if fields is None:
-             fields = [c.name for c in dr.file_columns()]
-         columns = [getattr(q.c, f) for f in fields]
+         columns = q.c
+         if fields:
+             columns = [getattr(columns, f) for f in fields]

          query = sa.select(*columns)
          query = query.where(*conds)
@@ -833,19 +839,16 @@ class AbstractWarehouse(ABC, Serializable):

          prefix_len = len(node.path)

-         def make_node_with_path(row):
-             sub_node = Node(*row)
-             return NodeWithPath(
-                 sub_node, sub_node.path[prefix_len:].lstrip("/").split("/")
-             )
+         def make_node_with_path(node: Node) -> NodeWithPath:
+             return NodeWithPath(node, node.path[prefix_len:].lstrip("/").split("/"))

-         return map(make_node_with_path, self.db.execute(query))
+         return map(make_node_with_path, self.get_nodes(query))

      def find(
          self,
          dataset_rows: "DataTable",
          node: Node,
-         fields: Iterable[str],
+         fields: Sequence[str],
          type=None,
          conds=None,
          order_by=None,
@@ -890,11 +893,9 @@ class AbstractWarehouse(ABC, Serializable):
      def is_temp_table_name(self, name: str) -> bool:
          """Returns if the given table name refers to a temporary
          or no longer needed table."""
-         if name.startswith(
+         return name.startswith(
              (self.TMP_TABLE_NAME_PREFIX, self.UDF_TABLE_NAME_PREFIX, "ds_shadow_")
-         ) or name.endswith("_shadow"):
-             return True
-         return False
+         ) or name.endswith("_shadow")

      def get_temp_table_names(self) -> list[str]:
          return [
datachain/dataset.py CHANGED
@@ -405,9 +405,7 @@ class DatasetRecord:
          Checks if a number can be a valid next latest version for dataset.
          The only rule is that it cannot be lower than current latest version
          """
-         if self.latest_version and self.latest_version >= version:
-             return False
-         return True
+         return not (self.latest_version and self.latest_version >= version)

      def get_version(self, version: int) -> DatasetVersion:
          if not self.has_version(version):
datachain/lib/arrow.py ADDED
@@ -0,0 +1,85 @@
+ import re
+ from typing import TYPE_CHECKING, Optional
+
+ from pyarrow.dataset import dataset
+
+ from datachain.lib.feature import Feature
+ from datachain.lib.file import File
+
+ if TYPE_CHECKING:
+     import pyarrow as pa
+
+
+ class Source(Feature):
+     """File source info for tables."""
+
+     file: File
+     index: int
+
+
+ class ArrowGenerator:
+     def __init__(self, schema: Optional["pa.Schema"] = None, **kwargs):
+         """
+         Generator for getting rows from tabular files.
+
+         Parameters:
+
+         schema : Optional pyarrow schema for validation.
+         kwargs: Parameters to pass to pyarrow.dataset.dataset.
+         """
+         self.schema = schema
+         self.kwargs = kwargs
+
+     def __call__(self, file: File):
+         path = file.get_path()
+         ds = dataset(path, filesystem=file.get_fs(), schema=self.schema, **self.kwargs)
+         index = 0
+         for record_batch in ds.to_batches():
+             for record in record_batch.to_pylist():
+                 source = Source(file=file, index=index)
+                 yield [source, *record.values()]
+                 index += 1
+
+
+ def schema_to_output(schema: "pa.Schema"):
+     """Generate UDF output schema from pyarrow schema."""
+     default_column = 0
+     output = {"source": Source}
+     for field in schema:
+         column = field.name.lower()
+         column = re.sub("[^0-9a-z_]+", "", column)
+         if not column:
+             column = f"c{default_column}"
+             default_column += 1
+         output[column] = _arrow_type_mapper(field.type)  # type: ignore[assignment]
+
+     return output
+
+
+ def _arrow_type_mapper(col_type: "pa.DataType") -> type:  # noqa: PLR0911
+     """Convert pyarrow types to basic types."""
+     from datetime import datetime
+
+     import pyarrow as pa
+
+     if pa.types.is_timestamp(col_type):
+         return datetime
+     if pa.types.is_binary(col_type):
+         return bytes
+     if pa.types.is_floating(col_type):
+         return float
+     if pa.types.is_integer(col_type):
+         return int
+     if pa.types.is_boolean(col_type):
+         return bool
+     if pa.types.is_date(col_type):
+         return datetime
+     if pa.types.is_string(col_type) or pa.types.is_large_string(col_type):
+         return str
+     if pa.types.is_list(col_type):
+         return list[_arrow_type_mapper(col_type.value_type)]  # type: ignore[misc]
+     if pa.types.is_struct(col_type) or pa.types.is_map(col_type):
+         return dict
+     if isinstance(col_type, pa.lib.DictionaryType):
+         return _arrow_type_mapper(col_type.value_type)  # type: ignore[return-value]
+     raise TypeError(f"{col_type!r} datatypes not supported")
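
The new module is a thin layer over pyarrow.dataset. Outside datachain, the row-iteration part of ArrowGenerator reduces to the following self-contained sketch; the Parquet file and its columns are made up for the example:

import pyarrow as pa
import pyarrow.parquet as pq
from pyarrow.dataset import dataset

# Write a tiny Parquet file so the example is self-contained.
table = pa.table({"name": ["cat.jpg", "dog.jpg"], "score": [0.9, 0.7]})
pq.write_table(table, "demo.parquet")

ds = dataset("demo.parquet")
print(ds.schema)                  # name: string, score: double

# Iterate row by row, the way ArrowGenerator does, tracking a per-file row index.
index = 0
for batch in ds.to_batches():
    for record in batch.to_pylist():
        print(index, record)      # 0 {'name': 'cat.jpg', 'score': 0.9} ...
        index += 1
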