datachain 0.1.12__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of datachain might be problematic.

Files changed (44)
  1. datachain/_version.py +2 -2
  2. datachain/asyn.py +3 -3
  3. datachain/catalog/__init__.py +3 -3
  4. datachain/catalog/catalog.py +6 -6
  5. datachain/catalog/loader.py +3 -3
  6. datachain/cli.py +2 -1
  7. datachain/client/azure.py +37 -1
  8. datachain/client/fsspec.py +1 -1
  9. datachain/client/local.py +1 -1
  10. datachain/data_storage/__init__.py +1 -1
  11. datachain/data_storage/metastore.py +11 -3
  12. datachain/data_storage/schema.py +2 -3
  13. datachain/data_storage/warehouse.py +31 -30
  14. datachain/dataset.py +1 -3
  15. datachain/lib/arrow.py +85 -0
  16. datachain/lib/dc.py +377 -178
  17. datachain/lib/feature.py +41 -90
  18. datachain/lib/feature_registry.py +3 -1
  19. datachain/lib/feature_utils.py +2 -2
  20. datachain/lib/file.py +20 -20
  21. datachain/lib/image.py +9 -2
  22. datachain/lib/meta_formats.py +66 -34
  23. datachain/lib/settings.py +5 -5
  24. datachain/lib/signal_schema.py +103 -105
  25. datachain/lib/udf.py +3 -12
  26. datachain/lib/udf_signature.py +11 -6
  27. datachain/lib/webdataset_laion.py +5 -22
  28. datachain/listing.py +8 -8
  29. datachain/node.py +1 -1
  30. datachain/progress.py +1 -1
  31. datachain/query/builtins.py +1 -1
  32. datachain/query/dataset.py +39 -110
  33. datachain/query/dispatch.py +1 -1
  34. datachain/query/metrics.py +19 -0
  35. datachain/query/schema.py +13 -3
  36. datachain/sql/__init__.py +1 -1
  37. datachain/utils.py +1 -122
  38. {datachain-0.1.12.dist-info → datachain-0.2.0.dist-info}/METADATA +10 -3
  39. {datachain-0.1.12.dist-info → datachain-0.2.0.dist-info}/RECORD +43 -42
  40. {datachain-0.1.12.dist-info → datachain-0.2.0.dist-info}/WHEEL +1 -1
  41. datachain/lib/parquet.py +0 -32
  42. {datachain-0.1.12.dist-info → datachain-0.2.0.dist-info}/LICENSE +0 -0
  43. {datachain-0.1.12.dist-info → datachain-0.2.0.dist-info}/entry_points.txt +0 -0
  44. {datachain-0.1.12.dist-info → datachain-0.2.0.dist-info}/top_level.txt +0 -0
datachain/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE
 
- __version__ = version = '0.1.12'
- __version_tuple__ = version_tuple = (0, 1, 12)
+ __version__ = version = '0.2.0'
+ __version_tuple__ = version_tuple = (0, 2, 0)
datachain/asyn.py CHANGED
@@ -82,13 +82,13 @@ class AsyncMapper(Generic[InputT, ResultT]):
          for _i in range(self.workers):
              self.start_task(self.worker())
          try:
-             done, pending = await asyncio.wait(
+             done, _pending = await asyncio.wait(
                  self._tasks, return_when=asyncio.FIRST_COMPLETED
              )
              self.gather_exceptions(done)
              assert producer.done()
              join = self.start_task(self.work_queue.join())
-             done, pending = await asyncio.wait(
+             done, _pending = await asyncio.wait(
                  self._tasks, return_when=asyncio.FIRST_COMPLETED
              )
              self.gather_exceptions(done)
@@ -208,7 +208,7 @@ class OrderedMapper(AsyncMapper[InputT, ResultT]):
 
      async def _pop_result(self) -> Optional[ResultT]:
          if self.heap and self.heap[0][0] == self._next_yield:
-             i, out = heappop(self.heap)
+             _i, out = heappop(self.heap)
          else:
              self._getters[self._next_yield] = get_value = self.loop.create_future()
              out = await get_value
datachain/catalog/__init__.py CHANGED
@@ -8,10 +8,10 @@ from .catalog import (
  from .loader import get_catalog
 
  __all__ = [
+     "QUERY_DATASET_PREFIX",
+     "QUERY_SCRIPT_CANCELED_EXIT_CODE",
+     "QUERY_SCRIPT_INVALID_LAST_STATEMENT_EXIT_CODE",
      "Catalog",
      "get_catalog",
      "parse_edatachain_file",
-     "QUERY_SCRIPT_INVALID_LAST_STATEMENT_EXIT_CODE",
-     "QUERY_SCRIPT_CANCELED_EXIT_CODE",
-     "QUERY_DATASET_PREFIX",
  ]
datachain/catalog/catalog.py CHANGED
@@ -142,6 +142,7 @@ class QueryResult(NamedTuple):
      version: Optional[int]
      output: str
      preview: Optional[list[dict]]
+     metrics: dict[str, Any]
 
 
  class DatasetRowsFetcher(NodesThreadPool):
@@ -876,13 +877,11 @@ class Catalog:
              # so this is to improve performance
              return None
 
-         dsrc_all = []
+         dsrc_all: list[DataSource] = []
          for listing, file_path in enlisted_sources:
              nodes = listing.expand_path(file_path)
              dir_only = file_path.endswith("/")
-             for node in nodes:
-                 dsrc_all.append(DataSource(listing, node, dir_only))
-
+             dsrc_all.extend(DataSource(listing, node, dir_only) for node in nodes)
          return dsrc_all
 
      def enlist_sources_grouped(
@@ -1997,6 +1996,7 @@ class Catalog:
              version=version,
              output=output,
              preview=exec_result.preview,
+             metrics=exec_result.metrics,
          )
 
      def run_query(
@@ -2068,8 +2068,8 @@ class Catalog:
                  "DATACHAIN_JOB_ID": job_id or "",
              },
          )
-         with subprocess.Popen(
-             [python_executable, "-c", query_script_compiled],  # noqa: S603
+         with subprocess.Popen(  # noqa: S603
+             [python_executable, "-c", query_script_compiled],
              env=envs,
              stdout=subprocess.PIPE if capture_output else None,
              stderr=subprocess.STDOUT if capture_output else None,
datachain/catalog/loader.py CHANGED
@@ -35,7 +35,7 @@ def get_id_generator() -> "AbstractIDGenerator":
      id_generator_obj = deserialize(id_generator_serialized)
      if not isinstance(id_generator_obj, AbstractIDGenerator):
          raise RuntimeError(
-             f"Deserialized ID generator is not an instance of AbstractIDGenerator: "
+             "Deserialized ID generator is not an instance of AbstractIDGenerator: "
              f"{id_generator_obj}"
          )
      return id_generator_obj
@@ -67,7 +67,7 @@ def get_metastore(id_generator: Optional["AbstractIDGenerator"]) -> "AbstractMet
      metastore_obj = deserialize(metastore_serialized)
      if not isinstance(metastore_obj, AbstractMetastore):
          raise RuntimeError(
-             f"Deserialized Metastore is not an instance of AbstractMetastore: "
+             "Deserialized Metastore is not an instance of AbstractMetastore: "
              f"{metastore_obj}"
          )
      return metastore_obj
@@ -101,7 +101,7 @@ def get_warehouse(id_generator: Optional["AbstractIDGenerator"]) -> "AbstractWar
      warehouse_obj = deserialize(warehouse_serialized)
      if not isinstance(warehouse_obj, AbstractWarehouse):
          raise RuntimeError(
-             f"Deserialized Warehouse is not an instance of AbstractWarehouse: "
+             "Deserialized Warehouse is not an instance of AbstractWarehouse: "
              f"{warehouse_obj}"
          )
      return warehouse_obj
datachain/cli.py CHANGED
@@ -845,6 +845,7 @@ def query(
          query=script_content,
          query_type=JobQueryType.PYTHON,
          python_version=python_version,
+         params=params,
      )
 
      try:
@@ -870,7 +871,7 @@ def query(
          )
          raise
 
-     catalog.metastore.set_job_status(job_id, JobStatus.COMPLETE)
+     catalog.metastore.set_job_status(job_id, JobStatus.COMPLETE, metrics=result.metrics)
 
      show_records(result.preview, collapse_columns=not no_collapse)
datachain/client/azure.py CHANGED
@@ -1,10 +1,12 @@
+ import posixpath
  from typing import Any
 
  from adlfs import AzureBlobFileSystem
+ from tqdm import tqdm
 
  from datachain.node import Entry
 
- from .fsspec import DELIMITER, Client
+ from .fsspec import DELIMITER, Client, ResultQueue
 
 
  class AzureClient(Client):
@@ -28,3 +30,37 @@ class AzureClient(Client):
              last_modified=v["last_modified"],
              size=v.get("size", ""),
          )
+
+     async def _fetch_flat(self, start_prefix: str, result_queue: ResultQueue) -> None:
+         prefix = start_prefix
+         if prefix:
+             prefix = prefix.lstrip(DELIMITER) + DELIMITER
+         found = False
+         try:
+             with tqdm(desc=f"Listing {self.uri}", unit=" objects") as pbar:
+                 async with self.fs.service_client.get_container_client(
+                     container=self.name
+                 ) as container_client:
+                     async for page in container_client.list_blobs(
+                         include=["metadata", "versions"], name_starts_with=prefix
+                     ).by_page():
+                         entries = []
+                         async for b in page:
+                             found = True
+                             if not self._is_valid_key(b["name"]):
+                                 continue
+                             info = (await self.fs._details([b]))[0]
+                             full_path = info["name"]
+                             parent = posixpath.dirname(self.rel_path(full_path))
+                             entries.append(self.convert_info(info, parent))
+                         if entries:
+                             await result_queue.put(entries)
+                             pbar.update(len(entries))
+             if not found:
+                 raise FileNotFoundError(
+                     f"Unable to resolve remote path: {prefix}"
+                 )
+         finally:
+             result_queue.put_nowait(None)
+
+     _fetch_default = _fetch_flat
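
The new listing method can, in principle, be driven on its own. Below is a minimal sketch, assuming an already-constructed AzureClient instance (here called client) and relying only on the queue protocol visible in the diff above: _fetch_flat enqueues lists of entries and always finishes with a None sentinel. The helper name drain_flat_listing is hypothetical.

import asyncio

async def drain_flat_listing(client, prefix: str = "") -> list:
    """Collect every entry batch that _fetch_flat pushes onto its result queue."""
    queue: asyncio.Queue = asyncio.Queue()
    task = asyncio.create_task(client._fetch_flat(prefix, queue))
    entries = []
    # Batches arrive as lists; a final None marks the end of the listing.
    while (batch := await queue.get()) is not None:
        entries.extend(batch)
    await task  # re-raises FileNotFoundError etc. if the listing failed
    return entries

# e.g. entries = asyncio.run(drain_flat_listing(client))
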
datachain/client/fsspec.py CHANGED
@@ -202,7 +202,7 @@ class Client(ABC):
          try:
              impl = getattr(self, f"_fetch_{method}")
          except AttributeError:
-             raise ValueError("Unknown indexing method '{method}'") from None
+             raise ValueError(f"Unknown indexing method '{method}'") from None
          result_queue: ResultQueue = asyncio.Queue()
          loop = get_loop()
          main_task = loop.create_task(impl(start_prefix, result_queue))
datachain/client/local.py CHANGED
@@ -135,7 +135,7 @@ class FileClient(Client):
          return posixpath.relpath(path, self.name)
 
      def get_full_path(self, rel_path):
-         full_path = Path(self.name, rel_path).as_uri()
+         full_path = Path(self.name, rel_path).as_posix()
          if rel_path.endswith("/") or not rel_path:
              full_path += "/"
          return full_path
datachain/data_storage/__init__.py CHANGED
@@ -5,8 +5,8 @@ from .warehouse import AbstractWarehouse
 
  __all__ = [
      "AbstractDBIDGenerator",
-     "AbstractIDGenerator",
      "AbstractDBMetastore",
+     "AbstractIDGenerator",
      "AbstractMetastore",
      "AbstractWarehouse",
      "JobQueryType",
datachain/data_storage/metastore.py CHANGED
@@ -385,6 +385,7 @@ class AbstractMetastore(ABC, Serializable):
          query_type: JobQueryType = JobQueryType.PYTHON,
          workers: int = 1,
          python_version: Optional[str] = None,
+         params: Optional[dict[str, str]] = None,
      ) -> str:
          """
          Creates a new job.
@@ -398,6 +399,7 @@ class AbstractMetastore(ABC, Serializable):
          status: JobStatus,
          error_message: Optional[str] = None,
          error_stack: Optional[str] = None,
+         metrics: Optional[dict[str, Any]] = None,
      ) -> None:
          """Set the status of the given job."""
 
@@ -1165,9 +1167,7 @@ class AbstractDBMetastore(AbstractMetastore):
          return dataset_version
 
      def _parse_dataset(self, rows) -> Optional[DatasetRecord]:
-         versions = []
-         for r in rows:
-             versions.append(self.dataset_class.parse(*r))
+         versions = [self.dataset_class.parse(*r) for r in rows]
          if not versions:
              return None
          return reduce(lambda ds, version: ds.merge_versions(version), versions)
@@ -1463,6 +1463,8 @@ class AbstractDBMetastore(AbstractMetastore):
              Column("python_version", Text, nullable=True),
              Column("error_message", Text, nullable=False, default=""),
              Column("error_stack", Text, nullable=False, default=""),
+             Column("params", JSON, nullable=False),
+             Column("metrics", JSON, nullable=False),
          ]
 
      @cached_property
@@ -1489,6 +1491,7 @@ class AbstractDBMetastore(AbstractMetastore):
          query_type: JobQueryType = JobQueryType.PYTHON,
          workers: int = 1,
          python_version: Optional[str] = None,
+         params: Optional[dict[str, str]] = None,
          conn: Optional[Any] = None,
      ) -> str:
          """
@@ -1508,6 +1511,8 @@ class AbstractDBMetastore(AbstractMetastore):
              python_version=python_version,
              error_message="",
              error_stack="",
+             params=json.dumps(params or {}),
+             metrics=json.dumps({}),
          ),
          conn=conn,
      )
@@ -1519,6 +1524,7 @@ class AbstractDBMetastore(AbstractMetastore):
          status: JobStatus,
          error_message: Optional[str] = None,
          error_stack: Optional[str] = None,
+         metrics: Optional[dict[str, Any]] = None,
          conn: Optional[Any] = None,
      ) -> None:
          """Set the status of the given job."""
@@ -1529,6 +1535,8 @@ class AbstractDBMetastore(AbstractMetastore):
              values["error_message"] = error_message
          if error_stack:
              values["error_stack"] = error_stack
+         if metrics:
+             values["metrics"] = json.dumps(metrics)
          self.db.execute(
              self._jobs_update(self._jobs.c.id == job_id).values(**values),
              conn=conn,
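
The new params and metrics columns are stored as JSON strings (json.dumps above). A small illustrative round-trip under that assumption; the reading side is not part of this diff, and decode_job_row is a hypothetical helper:

import json

def decode_job_row(row: dict) -> dict:
    """Decode the JSON-encoded job fields of a raw DB row (illustrative only)."""
    return {
        **row,
        "params": json.loads(row.get("params") or "{}"),
        "metrics": json.loads(row.get("metrics") or "{}"),
    }

row = {"id": "job-1", "params": json.dumps({"limit": "10"}), "metrics": json.dumps({})}
assert decode_job_row(row)["params"] == {"limit": "10"}
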
datachain/data_storage/schema.py CHANGED
@@ -34,8 +34,7 @@ def dedup_columns(columns: Iterable[sa.Column]) -> list[sa.Column]:
          if ec := c_set.get(c.name, None):
              if str(ec.type) != str(c.type):
                  raise ValueError(
-                     f"conflicting types for column {c.name}:"
-                     f"{c.type!s} and {ec.type!s}"
+                     f"conflicting types for column {c.name}:{c.type!s} and {ec.type!s}"
                  )
              continue
          c_set[c.name] = c
@@ -235,6 +234,7 @@ class DataTable:
      def file_columns(cls) -> list[sa.Column]:
          return [
              sa.Column("id", Int, primary_key=True),
+             sa.Column("random", Int64, nullable=False),
              sa.Column("vtype", String, nullable=False, index=True),
              sa.Column("dir_type", Int, index=True),
              sa.Column("parent", String, index=True),
@@ -246,7 +246,6 @@ class DataTable:
              sa.Column("size", Int64, nullable=False, index=True),
              sa.Column("owner_name", String),
              sa.Column("owner_id", String),
-             sa.Column("random", Int64, nullable=False),
              sa.Column("location", JSON),
              sa.Column("source", String, nullable=False),
          ]
datachain/data_storage/warehouse.py CHANGED
@@ -95,14 +95,14 @@ class AbstractWarehouse(ABC, Serializable):
 
          exc = None
          try:
-             if col_python_type == list and value_type in (list, tuple, set):
+             if col_python_type is list and value_type in (list, tuple, set):
                  if len(val) == 0:
                      return []
                  item_python_type = self.python_type(col_type.item_type)
-                 if item_python_type != list:
+                 if item_python_type is not list:
                      if isinstance(val[0], item_python_type):
                          return val
-                     if item_python_type == float and isinstance(val[0], int):
+                     if item_python_type is float and isinstance(val[0], int):
                          return [float(i) for i in val]
                  # Optimization: Reuse these values for each function call within the
                  # list comprehension.
@@ -114,18 +114,18 @@ class AbstractWarehouse(ABC, Serializable):
                  )
                  return [self.convert_type(i, *item_type_info) for i in val]
              # Special use case with JSON type as we save it as string
-             if col_python_type == dict or col_type_name == "JSON":
-                 if value_type == str:
+             if col_python_type is dict or col_type_name == "JSON":
+                 if value_type is str:
                      return val
                  if value_type in (dict, list):
                      return json.dumps(val)
                  raise ValueError(
-                     f"Cannot convert value {val!r} with type" f"{value_type} to JSON"
+                     f"Cannot convert value {val!r} with type {value_type} to JSON"
                  )
 
              if isinstance(val, col_python_type):
                  return val
-             if col_python_type == float and isinstance(val, int):
+             if col_python_type is float and isinstance(val, int):
                  return float(val)
          except Exception as e:  # noqa: BLE001
              exc = e
@@ -335,6 +335,7 @@ class AbstractWarehouse(ABC, Serializable):
              return select_query
          if recursive:
              root = False
+             where = self.path_expr(dr).op("GLOB")(path)
              if not path or path == "/":
                  # root of the bucket, e.g s3://bucket/ -> getting all the nodes
                  # in the bucket
@@ -344,14 +345,18 @@ class AbstractWarehouse(ABC, Serializable):
                  # not a root and not a explicit glob, so it's pointing to some directory
                  # and we are adding a proper glob syntax for it
                  # e.g s3://bucket/dir1 -> s3://bucket/dir1/*
-                 path = path.rstrip("/") + "/*"
+                 dir_path = path.rstrip("/") + "/*"
+                 where = where | self.path_expr(dr).op("GLOB")(dir_path)
 
              if not root:
                  # not a root, so running glob query
-                 select_query = select_query.where(self.path_expr(dr).op("GLOB")(path))
+                 select_query = select_query.where(where)
+
          else:
              parent = self.get_node_by_path(dr, path.lstrip("/").rstrip("/*"))
-             select_query = select_query.where(dr.c.parent == parent.path)
+             select_query = select_query.where(
+                 (dr.c.parent == parent.path) | (self.path_expr(dr) == path)
+             )
          return select_query
 
      def rename_dataset_table(
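
An aside on the GLOB change just above: a recursive listing of a directory now matches both the node at the path itself and everything under it, rather than only the subtree. A standalone SQLAlchemy sketch of the same combined condition (the nodes table and path column here are illustrative; GLOB is the SQLite operator used via .op("GLOB")):

import sqlalchemy as sa

nodes = sa.table("nodes", sa.column("path"))
path = "dir1"
where = nodes.c.path.op("GLOB")(path)                              # the node itself
where = where | nodes.c.path.op("GLOB")(path.rstrip("/") + "/*")   # its subtree
print(sa.select(nodes.c.path).where(where))
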
@@ -493,7 +498,10 @@ class AbstractWarehouse(ABC, Serializable):
          This gets nodes based on the provided query, and should be used sparingly,
          as it will be slow on any OLAP database systems.
          """
-         return (Node(*row) for row in self.db.execute(query))
+         columns = [c.name for c in query.columns]
+         for row in self.db.execute(query):
+             d = dict(zip(columns, row))
+             yield Node(**d)
 
      def get_dirs_by_parent_path(
          self,
@@ -570,14 +578,12 @@ class AbstractWarehouse(ABC, Serializable):
          matched_paths: list[list[str]] = [[]]
          for curr_name in path_list[:-1]:
              if glob.has_magic(curr_name):
-                 new_paths = []
+                 new_paths: list[list[str]] = []
                  for path in matched_paths:
                      nodes = self._get_nodes_by_glob_path_pattern(
                          dataset_rows, path, curr_name
                      )
-                     for node in nodes:
-                         if node.is_container:
-                             new_paths.append([*path, node.name or ""])
+                     new_paths.extend([*path, n.name] for n in nodes if n.is_container)
                  matched_paths = new_paths
              else:
                  for path in matched_paths:
@@ -772,7 +778,7 @@ class AbstractWarehouse(ABC, Serializable):
          self,
          dataset_rows: "DataTable",
          parent_path: str,
-         fields: Optional[Iterable[str]] = None,
+         fields: Optional[Sequence[str]] = None,
          type: Optional[str] = None,
          conds=None,
          order_by: Optional[Union[str, list[str]]] = None,
@@ -794,9 +800,9 @@ class AbstractWarehouse(ABC, Serializable):
          else:
              conds.append(path != "")
 
-         if fields is None:
-             fields = [c.name for c in dr.file_columns()]
-         columns = [getattr(q.c, f) for f in fields]
+         columns = q.c
+         if fields:
+             columns = [getattr(columns, f) for f in fields]
 
          query = sa.select(*columns)
          query = query.where(*conds)
@@ -833,19 +839,16 @@ class AbstractWarehouse(ABC, Serializable):
 
          prefix_len = len(node.path)
 
-         def make_node_with_path(row):
-             sub_node = Node(*row)
-             return NodeWithPath(
-                 sub_node, sub_node.path[prefix_len:].lstrip("/").split("/")
-             )
+         def make_node_with_path(node: Node) -> NodeWithPath:
+             return NodeWithPath(node, node.path[prefix_len:].lstrip("/").split("/"))
 
-         return map(make_node_with_path, self.db.execute(query))
+         return map(make_node_with_path, self.get_nodes(query))
 
      def find(
          self,
          dataset_rows: "DataTable",
          node: Node,
-         fields: Iterable[str],
+         fields: Sequence[str],
          type=None,
          conds=None,
          order_by=None,
@@ -890,11 +893,9 @@ class AbstractWarehouse(ABC, Serializable):
      def is_temp_table_name(self, name: str) -> bool:
          """Returns if the given table name refers to a temporary
          or no longer needed table."""
-         if name.startswith(
+         return name.startswith(
              (self.TMP_TABLE_NAME_PREFIX, self.UDF_TABLE_NAME_PREFIX, "ds_shadow_")
-         ) or name.endswith("_shadow"):
-             return True
-         return False
+         ) or name.endswith("_shadow")
 
      def get_temp_table_names(self) -> list[str]:
          return [
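
One effect of switching get_nodes from positional Node(*row) to name-based construction (reused by the expand-path code above) is that node hydration no longer depends on the physical column order, which matters now that the random column moved in schema.py. A toy illustration with a stand-in NamedTuple; the real Node in datachain.node has more fields:

from typing import NamedTuple

class Node(NamedTuple):  # stand-in for datachain.node.Node
    id: int
    parent: str
    name: str

columns = ["name", "id", "parent"]   # selected column order can differ from field order
row = ("file.csv", 7, "dir1")
node = Node(**dict(zip(columns, row)))
assert node.id == 7 and node.name == "file.csv"
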
datachain/dataset.py CHANGED
@@ -405,9 +405,7 @@ class DatasetRecord:
          Checks if a number can be a valid next latest version for dataset.
          The only rule is that it cannot be lower than current latest version
          """
-         if self.latest_version and self.latest_version >= version:
-             return False
-         return True
+         return not (self.latest_version and self.latest_version >= version)
 
      def get_version(self, version: int) -> DatasetVersion:
          if not self.has_version(version):
datachain/lib/arrow.py ADDED
@@ -0,0 +1,85 @@
+ import re
+ from typing import TYPE_CHECKING, Optional
+
+ from pyarrow.dataset import dataset
+
+ from datachain.lib.feature import Feature
+ from datachain.lib.file import File
+
+ if TYPE_CHECKING:
+     import pyarrow as pa
+
+
+ class Source(Feature):
+     """File source info for tables."""
+
+     file: File
+     index: int
+
+
+ class ArrowGenerator:
+     def __init__(self, schema: Optional["pa.Schema"] = None, **kwargs):
+         """
+         Generator for getting rows from tabular files.
+
+         Parameters:
+
+         schema : Optional pyarrow schema for validation.
+         kwargs: Parameters to pass to pyarrow.dataset.dataset.
+         """
+         self.schema = schema
+         self.kwargs = kwargs
+
+     def __call__(self, file: File):
+         path = file.get_path()
+         ds = dataset(path, filesystem=file.get_fs(), schema=self.schema, **self.kwargs)
+         index = 0
+         for record_batch in ds.to_batches():
+             for record in record_batch.to_pylist():
+                 source = Source(file=file, index=index)
+                 yield [source, *record.values()]
+                 index += 1
+
+
+ def schema_to_output(schema: "pa.Schema"):
+     """Generate UDF output schema from pyarrow schema."""
+     default_column = 0
+     output = {"source": Source}
+     for field in schema:
+         column = field.name.lower()
+         column = re.sub("[^0-9a-z_]+", "", column)
+         if not column:
+             column = f"c{default_column}"
+             default_column += 1
+         output[column] = _arrow_type_mapper(field.type)  # type: ignore[assignment]
+
+     return output
+
+
+ def _arrow_type_mapper(col_type: "pa.DataType") -> type:  # noqa: PLR0911
+     """Convert pyarrow types to basic types."""
+     from datetime import datetime
+
+     import pyarrow as pa
+
+     if pa.types.is_timestamp(col_type):
+         return datetime
+     if pa.types.is_binary(col_type):
+         return bytes
+     if pa.types.is_floating(col_type):
+         return float
+     if pa.types.is_integer(col_type):
+         return int
+     if pa.types.is_boolean(col_type):
+         return bool
+     if pa.types.is_date(col_type):
+         return datetime
+     if pa.types.is_string(col_type) or pa.types.is_large_string(col_type):
+         return str
+     if pa.types.is_list(col_type):
+         return list[_arrow_type_mapper(col_type.value_type)]  # type: ignore[misc]
+     if pa.types.is_struct(col_type) or pa.types.is_map(col_type):
+         return dict
+     if isinstance(col_type, pa.lib.DictionaryType):
+         return _arrow_type_mapper(col_type.value_type)  # type: ignore[return-value]
+     raise TypeError(f"{col_type!r} datatypes not supported")
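
A quick sketch of the new helper on a small pyarrow schema, assuming datachain 0.2.0 and pyarrow are installed; only schema_to_output from the file above is used, and the mapping it returns is what would feed a UDF output spec:

import pyarrow as pa
from datachain.lib.arrow import schema_to_output

schema = pa.schema(
    [
        ("Name", pa.string()),
        ("score", pa.float32()),
        ("tags", pa.list_(pa.string())),
    ]
)
print(schema_to_output(schema))
# expected shape: {'source': Source, 'name': str, 'score': float, 'tags': list[str]}
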