datachain 0.1.12__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry; it is provided for informational purposes only.
- datachain/_version.py +2 -2
- datachain/asyn.py +3 -3
- datachain/catalog/__init__.py +3 -3
- datachain/catalog/catalog.py +6 -6
- datachain/catalog/loader.py +3 -3
- datachain/cli.py +2 -1
- datachain/client/azure.py +37 -1
- datachain/client/fsspec.py +1 -1
- datachain/client/local.py +1 -1
- datachain/data_storage/__init__.py +1 -1
- datachain/data_storage/metastore.py +11 -3
- datachain/data_storage/schema.py +2 -3
- datachain/data_storage/warehouse.py +31 -30
- datachain/dataset.py +1 -3
- datachain/lib/arrow.py +85 -0
- datachain/lib/dc.py +377 -178
- datachain/lib/feature.py +41 -90
- datachain/lib/feature_registry.py +3 -1
- datachain/lib/feature_utils.py +2 -2
- datachain/lib/file.py +20 -20
- datachain/lib/image.py +9 -2
- datachain/lib/meta_formats.py +66 -34
- datachain/lib/settings.py +5 -5
- datachain/lib/signal_schema.py +103 -105
- datachain/lib/udf.py +3 -12
- datachain/lib/udf_signature.py +11 -6
- datachain/lib/webdataset_laion.py +5 -22
- datachain/listing.py +8 -8
- datachain/node.py +1 -1
- datachain/progress.py +1 -1
- datachain/query/builtins.py +1 -1
- datachain/query/dataset.py +39 -110
- datachain/query/dispatch.py +1 -1
- datachain/query/metrics.py +19 -0
- datachain/query/schema.py +13 -3
- datachain/sql/__init__.py +1 -1
- datachain/utils.py +1 -122
- {datachain-0.1.12.dist-info → datachain-0.2.0.dist-info}/METADATA +10 -3
- {datachain-0.1.12.dist-info → datachain-0.2.0.dist-info}/RECORD +43 -42
- {datachain-0.1.12.dist-info → datachain-0.2.0.dist-info}/WHEEL +1 -1
- datachain/lib/parquet.py +0 -32
- {datachain-0.1.12.dist-info → datachain-0.2.0.dist-info}/LICENSE +0 -0
- {datachain-0.1.12.dist-info → datachain-0.2.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.1.12.dist-info → datachain-0.2.0.dist-info}/top_level.txt +0 -0
datachain/_version.py
CHANGED
datachain/asyn.py
CHANGED
@@ -82,13 +82,13 @@ class AsyncMapper(Generic[InputT, ResultT]):
         for _i in range(self.workers):
             self.start_task(self.worker())
         try:
-            done,
+            done, _pending = await asyncio.wait(
                 self._tasks, return_when=asyncio.FIRST_COMPLETED
             )
             self.gather_exceptions(done)
             assert producer.done()
             join = self.start_task(self.work_queue.join())
-            done,
+            done, _pending = await asyncio.wait(
                 self._tasks, return_when=asyncio.FIRST_COMPLETED
             )
             self.gather_exceptions(done)
@@ -208,7 +208,7 @@ class OrderedMapper(AsyncMapper[InputT, ResultT]):
 
     async def _pop_result(self) -> Optional[ResultT]:
         if self.heap and self.heap[0][0] == self._next_yield:
-
+            _i, out = heappop(self.heap)
         else:
             self._getters[self._next_yield] = get_value = self.loop.create_future()
             out = await get_value
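The asyn.py change binds the unused second result of asyncio.wait to _pending so linters stop flagging it. A minimal standalone sketch of the same pattern (the task set and delays are invented for illustration, not from datachain):

import asyncio

async def main():
    tasks = {asyncio.create_task(asyncio.sleep(d, result=d)) for d in (0.1, 0.5)}
    # asyncio.wait returns (done, pending); the pending set is unused here,
    # so it is assigned to _pending to signal that intentionally.
    done, _pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    for task in done:
        print(task.result())

asyncio.run(main())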
datachain/catalog/__init__.py
CHANGED
@@ -8,10 +8,10 @@ from .catalog import (
 from .loader import get_catalog
 
 __all__ = [
+    "QUERY_DATASET_PREFIX",
+    "QUERY_SCRIPT_CANCELED_EXIT_CODE",
+    "QUERY_SCRIPT_INVALID_LAST_STATEMENT_EXIT_CODE",
     "Catalog",
     "get_catalog",
     "parse_edatachain_file",
-    "QUERY_SCRIPT_INVALID_LAST_STATEMENT_EXIT_CODE",
-    "QUERY_SCRIPT_CANCELED_EXIT_CODE",
-    "QUERY_DATASET_PREFIX",
 ]
datachain/catalog/catalog.py
CHANGED
@@ -142,6 +142,7 @@ class QueryResult(NamedTuple):
     version: Optional[int]
     output: str
     preview: Optional[list[dict]]
+    metrics: dict[str, Any]
 
 
 class DatasetRowsFetcher(NodesThreadPool):
@@ -876,13 +877,11 @@ class Catalog:
             # so this is to improve performance
             return None
 
-        dsrc_all = []
+        dsrc_all: list[DataSource] = []
         for listing, file_path in enlisted_sources:
             nodes = listing.expand_path(file_path)
             dir_only = file_path.endswith("/")
-            for node in nodes
-                dsrc_all.append(DataSource(listing, node, dir_only))
-
+            dsrc_all.extend(DataSource(listing, node, dir_only) for node in nodes)
         return dsrc_all
 
     def enlist_sources_grouped(
@@ -1997,6 +1996,7 @@ class Catalog:
             version=version,
             output=output,
             preview=exec_result.preview,
+            metrics=exec_result.metrics,
         )
 
     def run_query(
@@ -2068,8 +2068,8 @@ class Catalog:
                 "DATACHAIN_JOB_ID": job_id or "",
             },
         )
-        with subprocess.Popen(
-            [python_executable, "-c", query_script_compiled],
+        with subprocess.Popen(  # noqa: S603
+            [python_executable, "-c", query_script_compiled],
             env=envs,
             stdout=subprocess.PIPE if capture_output else None,
             stderr=subprocess.STDOUT if capture_output else None,
datachain/catalog/loader.py
CHANGED
@@ -35,7 +35,7 @@ def get_id_generator() -> "AbstractIDGenerator":
     id_generator_obj = deserialize(id_generator_serialized)
     if not isinstance(id_generator_obj, AbstractIDGenerator):
         raise RuntimeError(
-
+            "Deserialized ID generator is not an instance of AbstractIDGenerator: "
             f"{id_generator_obj}"
         )
     return id_generator_obj
@@ -67,7 +67,7 @@ def get_metastore(id_generator: Optional["AbstractIDGenerator"]) -> "AbstractMetastore":
     metastore_obj = deserialize(metastore_serialized)
     if not isinstance(metastore_obj, AbstractMetastore):
         raise RuntimeError(
-
+            "Deserialized Metastore is not an instance of AbstractMetastore: "
             f"{metastore_obj}"
         )
     return metastore_obj
@@ -101,7 +101,7 @@ def get_warehouse(id_generator: Optional["AbstractIDGenerator"]) -> "AbstractWarehouse":
     warehouse_obj = deserialize(warehouse_serialized)
     if not isinstance(warehouse_obj, AbstractWarehouse):
         raise RuntimeError(
-
+            "Deserialized Warehouse is not an instance of AbstractWarehouse: "
            f"{warehouse_obj}"
         )
     return warehouse_obj
datachain/cli.py
CHANGED
@@ -845,6 +845,7 @@ def query(
         query=script_content,
         query_type=JobQueryType.PYTHON,
         python_version=python_version,
+        params=params,
     )
 
     try:
@@ -870,7 +871,7 @@ def query(
         )
         raise
 
-    catalog.metastore.set_job_status(job_id, JobStatus.COMPLETE)
+    catalog.metastore.set_job_status(job_id, JobStatus.COMPLETE, metrics=result.metrics)
 
     show_records(result.preview, collapse_columns=not no_collapse)
datachain/client/azure.py
CHANGED
@@ -1,10 +1,12 @@
+import posixpath
 from typing import Any
 
 from adlfs import AzureBlobFileSystem
+from tqdm import tqdm
 
 from datachain.node import Entry
 
-from .fsspec import DELIMITER, Client
+from .fsspec import DELIMITER, Client, ResultQueue
 
 
 class AzureClient(Client):
@@ -28,3 +30,37 @@ class AzureClient(Client):
             last_modified=v["last_modified"],
             size=v.get("size", ""),
         )
+
+    async def _fetch_flat(self, start_prefix: str, result_queue: ResultQueue) -> None:
+        prefix = start_prefix
+        if prefix:
+            prefix = prefix.lstrip(DELIMITER) + DELIMITER
+        found = False
+        try:
+            with tqdm(desc=f"Listing {self.uri}", unit=" objects") as pbar:
+                async with self.fs.service_client.get_container_client(
+                    container=self.name
+                ) as container_client:
+                    async for page in container_client.list_blobs(
+                        include=["metadata", "versions"], name_starts_with=prefix
+                    ).by_page():
+                        entries = []
+                        async for b in page:
+                            found = True
+                            if not self._is_valid_key(b["name"]):
+                                continue
+                            info = (await self.fs._details([b]))[0]
+                            full_path = info["name"]
+                            parent = posixpath.dirname(self.rel_path(full_path))
+                            entries.append(self.convert_info(info, parent))
+                        if entries:
+                            await result_queue.put(entries)
+                            pbar.update(len(entries))
+            if not found:
+                raise FileNotFoundError(
+                    f"Unable to resolve remote path: {prefix}"
+                )
+        finally:
+            result_queue.put_nowait(None)
+
+    _fetch_default = _fetch_flat
datachain/client/fsspec.py
CHANGED
@@ -202,7 +202,7 @@ class Client(ABC):
         try:
             impl = getattr(self, f"_fetch_{method}")
         except AttributeError:
-            raise ValueError("Unknown indexing method '{method}'") from None
+            raise ValueError(f"Unknown indexing method '{method}'") from None
         result_queue: ResultQueue = asyncio.Queue()
         loop = get_loop()
         main_task = loop.create_task(impl(start_prefix, result_queue))
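For context, the _fetch_{method} lookup above is a simple name-based dispatch, and the Azure client plugs into it by aliasing _fetch_default = _fetch_flat. A rough sketch of the pattern with made-up class and method names (not datachain code):

class Indexer:
    def fetch(self, method: str = "default"):
        # Resolve the implementation by name, mirroring getattr(self, f"_fetch_{method}")
        try:
            impl = getattr(self, f"_fetch_{method}")
        except AttributeError:
            raise ValueError(f"Unknown indexing method '{method}'") from None
        return impl()

    def _fetch_flat(self):
        return "flat listing"

    # "default" resolves to the flat strategy, as AzureClient now does
    _fetch_default = _fetch_flat

print(Indexer().fetch())        # flat listing
print(Indexer().fetch("flat"))  # flat listing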
datachain/client/local.py
CHANGED
@@ -135,7 +135,7 @@ class FileClient(Client):
         return posixpath.relpath(path, self.name)
 
     def get_full_path(self, rel_path):
-        full_path = Path(self.name, rel_path).
+        full_path = Path(self.name, rel_path).as_posix()
         if rel_path.endswith("/") or not rel_path:
             full_path += "/"
         return full_path
datachain/data_storage/metastore.py
CHANGED
@@ -385,6 +385,7 @@ class AbstractMetastore(ABC, Serializable):
         query_type: JobQueryType = JobQueryType.PYTHON,
         workers: int = 1,
         python_version: Optional[str] = None,
+        params: Optional[dict[str, str]] = None,
     ) -> str:
         """
         Creates a new job.
@@ -398,6 +399,7 @@ class AbstractMetastore(ABC, Serializable):
         status: JobStatus,
         error_message: Optional[str] = None,
         error_stack: Optional[str] = None,
+        metrics: Optional[dict[str, Any]] = None,
     ) -> None:
         """Set the status of the given job."""
 
@@ -1165,9 +1167,7 @@ class AbstractDBMetastore(AbstractMetastore):
         return dataset_version
 
     def _parse_dataset(self, rows) -> Optional[DatasetRecord]:
-        versions = []
-        for r in rows:
-            versions.append(self.dataset_class.parse(*r))
+        versions = [self.dataset_class.parse(*r) for r in rows]
         if not versions:
             return None
         return reduce(lambda ds, version: ds.merge_versions(version), versions)
@@ -1463,6 +1463,8 @@ class AbstractDBMetastore(AbstractMetastore):
             Column("python_version", Text, nullable=True),
             Column("error_message", Text, nullable=False, default=""),
             Column("error_stack", Text, nullable=False, default=""),
+            Column("params", JSON, nullable=False),
+            Column("metrics", JSON, nullable=False),
         ]
 
     @cached_property
@@ -1489,6 +1491,7 @@ class AbstractDBMetastore(AbstractMetastore):
         query_type: JobQueryType = JobQueryType.PYTHON,
         workers: int = 1,
         python_version: Optional[str] = None,
+        params: Optional[dict[str, str]] = None,
         conn: Optional[Any] = None,
     ) -> str:
         """
@@ -1508,6 +1511,8 @@ class AbstractDBMetastore(AbstractMetastore):
                 python_version=python_version,
                 error_message="",
                 error_stack="",
+                params=json.dumps(params or {}),
+                metrics=json.dumps({}),
             ),
             conn=conn,
         )
@@ -1519,6 +1524,7 @@ class AbstractDBMetastore(AbstractMetastore):
         status: JobStatus,
         error_message: Optional[str] = None,
         error_stack: Optional[str] = None,
+        metrics: Optional[dict[str, Any]] = None,
         conn: Optional[Any] = None,
     ) -> None:
         """Set the status of the given job."""
@@ -1529,6 +1535,8 @@ class AbstractDBMetastore(AbstractMetastore):
             values["error_message"] = error_message
         if error_stack:
             values["error_stack"] = error_stack
+        if metrics:
+            values["metrics"] = json.dumps(metrics)
         self.db.execute(
             self._jobs_update(self._jobs.c.id == job_id).values(**values),
             conn=conn,
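The new params and metrics job columns are JSON columns populated from plain dicts via json.dumps, and set_job_status only writes metrics when a non-empty dict is passed. A minimal round-trip sketch, independent of the metastore classes (the read path is an assumption, since this diff only shows writes):

import json

metrics = {"rows_processed": 1000, "errors": 0}

# write path, as in set_job_status: serialize only when metrics were supplied
values = {}
if metrics:
    values["metrics"] = json.dumps(metrics)

# hypothetical read path: decode the stored text back into a dict
assert json.loads(values.get("metrics", "{}")) == metrics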
datachain/data_storage/schema.py
CHANGED
@@ -34,8 +34,7 @@ def dedup_columns(columns: Iterable[sa.Column]) -> list[sa.Column]:
         if ec := c_set.get(c.name, None):
             if str(ec.type) != str(c.type):
                 raise ValueError(
-                    f"conflicting types for column {c.name}:"
-                    f"{c.type!s} and {ec.type!s}"
+                    f"conflicting types for column {c.name}:{c.type!s} and {ec.type!s}"
                 )
             continue
         c_set[c.name] = c
@@ -235,6 +234,7 @@ class DataTable:
     def file_columns(cls) -> list[sa.Column]:
         return [
             sa.Column("id", Int, primary_key=True),
+            sa.Column("random", Int64, nullable=False),
             sa.Column("vtype", String, nullable=False, index=True),
             sa.Column("dir_type", Int, index=True),
             sa.Column("parent", String, index=True),
@@ -246,7 +246,6 @@ class DataTable:
             sa.Column("size", Int64, nullable=False, index=True),
             sa.Column("owner_name", String),
             sa.Column("owner_id", String),
-            sa.Column("random", Int64, nullable=False),
             sa.Column("location", JSON),
             sa.Column("source", String, nullable=False),
         ]
datachain/data_storage/warehouse.py
CHANGED
@@ -95,14 +95,14 @@ class AbstractWarehouse(ABC, Serializable):
 
         exc = None
         try:
-            if col_python_type
+            if col_python_type is list and value_type in (list, tuple, set):
                 if len(val) == 0:
                     return []
                 item_python_type = self.python_type(col_type.item_type)
-                if item_python_type
+                if item_python_type is not list:
                     if isinstance(val[0], item_python_type):
                         return val
-                    if item_python_type
+                    if item_python_type is float and isinstance(val[0], int):
                         return [float(i) for i in val]
                 # Optimization: Reuse these values for each function call within the
                 # list comprehension.
@@ -114,18 +114,18 @@ class AbstractWarehouse(ABC, Serializable):
                 )
                 return [self.convert_type(i, *item_type_info) for i in val]
             # Special use case with JSON type as we save it as string
-            if col_python_type
-                if value_type
+            if col_python_type is dict or col_type_name == "JSON":
+                if value_type is str:
                     return val
                 if value_type in (dict, list):
                     return json.dumps(val)
                 raise ValueError(
-                    f"Cannot convert value {val!r} with type
+                    f"Cannot convert value {val!r} with type {value_type} to JSON"
                 )
 
             if isinstance(val, col_python_type):
                 return val
-            if col_python_type
+            if col_python_type is float and isinstance(val, int):
                 return float(val)
         except Exception as e:  # noqa: BLE001
             exc = e
@@ -335,6 +335,7 @@ class AbstractWarehouse(ABC, Serializable):
             return select_query
         if recursive:
             root = False
+            where = self.path_expr(dr).op("GLOB")(path)
             if not path or path == "/":
                 # root of the bucket, e.g s3://bucket/ -> getting all the nodes
                 # in the bucket
@@ -344,14 +345,18 @@ class AbstractWarehouse(ABC, Serializable):
                 # not a root and not a explicit glob, so it's pointing to some directory
                 # and we are adding a proper glob syntax for it
                 # e.g s3://bucket/dir1 -> s3://bucket/dir1/*
-
+                dir_path = path.rstrip("/") + "/*"
+                where = where | self.path_expr(dr).op("GLOB")(dir_path)
 
             if not root:
                 # not a root, so running glob query
-                select_query = select_query.where(
+                select_query = select_query.where(where)
+
         else:
             parent = self.get_node_by_path(dr, path.lstrip("/").rstrip("/*"))
-            select_query = select_query.where(
+            select_query = select_query.where(
+                (dr.c.parent == parent.path) | (self.path_expr(dr) == path)
+            )
         return select_query
 
     def rename_dataset_table(
@@ -493,7 +498,10 @@ class AbstractWarehouse(ABC, Serializable):
         This gets nodes based on the provided query, and should be used sparingly,
         as it will be slow on any OLAP database systems.
         """
-
+        columns = [c.name for c in query.columns]
+        for row in self.db.execute(query):
+            d = dict(zip(columns, row))
+            yield Node(**d)
 
     def get_dirs_by_parent_path(
         self,
@@ -570,14 +578,12 @@ class AbstractWarehouse(ABC, Serializable):
         matched_paths: list[list[str]] = [[]]
         for curr_name in path_list[:-1]:
             if glob.has_magic(curr_name):
-                new_paths = []
+                new_paths: list[list[str]] = []
                 for path in matched_paths:
                     nodes = self._get_nodes_by_glob_path_pattern(
                         dataset_rows, path, curr_name
                     )
-                    for
-                        if node.is_container:
-                            new_paths.append([*path, node.name or ""])
+                    new_paths.extend([*path, n.name] for n in nodes if n.is_container)
                 matched_paths = new_paths
             else:
                 for path in matched_paths:
@@ -772,7 +778,7 @@ class AbstractWarehouse(ABC, Serializable):
         self,
         dataset_rows: "DataTable",
         parent_path: str,
-        fields: Optional[
+        fields: Optional[Sequence[str]] = None,
         type: Optional[str] = None,
         conds=None,
         order_by: Optional[Union[str, list[str]]] = None,
@@ -794,9 +800,9 @@ class AbstractWarehouse(ABC, Serializable):
         else:
             conds.append(path != "")
 
-
-
-
+        columns = q.c
+        if fields:
+            columns = [getattr(columns, f) for f in fields]
 
         query = sa.select(*columns)
         query = query.where(*conds)
@@ -833,19 +839,16 @@ class AbstractWarehouse(ABC, Serializable):
 
         prefix_len = len(node.path)
 
-        def make_node_with_path(
-
-            return NodeWithPath(
-                sub_node, sub_node.path[prefix_len:].lstrip("/").split("/")
-            )
+        def make_node_with_path(node: Node) -> NodeWithPath:
+            return NodeWithPath(node, node.path[prefix_len:].lstrip("/").split("/"))
 
-        return map(make_node_with_path, self.
+        return map(make_node_with_path, self.get_nodes(query))
 
     def find(
         self,
         dataset_rows: "DataTable",
         node: Node,
-        fields:
+        fields: Sequence[str],
         type=None,
         conds=None,
         order_by=None,
@@ -890,11 +893,9 @@ class AbstractWarehouse(ABC, Serializable):
     def is_temp_table_name(self, name: str) -> bool:
        """Returns if the given table name refers to a temporary
        or no longer needed table."""
-
+        return name.startswith(
            (self.TMP_TABLE_NAME_PREFIX, self.UDF_TABLE_NAME_PREFIX, "ds_shadow_")
-        ) or name.endswith("_shadow")
-            return True
-        return False
+        ) or name.endswith("_shadow")
 
     def get_temp_table_names(self) -> list[str]:
         return [
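The recursive-listing rewrite above builds its filter with SQLAlchemy's generic operator API: column.op("GLOB")(pattern) produces a path GLOB :pattern expression (GLOB being SQLite syntax), and two such expressions are OR-ed with |. A small sketch of the expression construction, using a throwaway table definition rather than datachain's schema:

import sqlalchemy as sa

nodes = sa.table("nodes", sa.column("path"))
path = "dir1"

# match the directory itself or anything under it, as in the recursive branch
where = nodes.c.path.op("GLOB")(path)
where = where | nodes.c.path.op("GLOB")(path.rstrip("/") + "/*")

query = sa.select(nodes.c.path).where(where)
print(query)  # roughly: SELECT nodes.path FROM nodes WHERE nodes.path GLOB :path_1 OR nodes.path GLOB :path_2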
datachain/dataset.py
CHANGED
@@ -405,9 +405,7 @@ class DatasetRecord:
         Checks if a number can be a valid next latest version for dataset.
         The only rule is that it cannot be lower than current latest version
         """
-
-            return False
-        return True
+        return not (self.latest_version and self.latest_version >= version)
 
     def get_version(self, version: int) -> DatasetVersion:
         if not self.has_version(version):
datachain/lib/arrow.py
ADDED
@@ -0,0 +1,85 @@
+import re
+from typing import TYPE_CHECKING, Optional
+
+from pyarrow.dataset import dataset
+
+from datachain.lib.feature import Feature
+from datachain.lib.file import File
+
+if TYPE_CHECKING:
+    import pyarrow as pa
+
+
+class Source(Feature):
+    """File source info for tables."""
+
+    file: File
+    index: int
+
+
+class ArrowGenerator:
+    def __init__(self, schema: Optional["pa.Schema"] = None, **kwargs):
+        """
+        Generator for getting rows from tabular files.
+
+        Parameters:
+
+        schema : Optional pyarrow schema for validation.
+        kwargs: Parameters to pass to pyarrow.dataset.dataset.
+        """
+        self.schema = schema
+        self.kwargs = kwargs
+
+    def __call__(self, file: File):
+        path = file.get_path()
+        ds = dataset(path, filesystem=file.get_fs(), schema=self.schema, **self.kwargs)
+        index = 0
+        for record_batch in ds.to_batches():
+            for record in record_batch.to_pylist():
+                source = Source(file=file, index=index)
+                yield [source, *record.values()]
+                index += 1
+
+
+def schema_to_output(schema: "pa.Schema"):
+    """Generate UDF output schema from pyarrow schema."""
+    default_column = 0
+    output = {"source": Source}
+    for field in schema:
+        column = field.name.lower()
+        column = re.sub("[^0-9a-z_]+", "", column)
+        if not column:
+            column = f"c{default_column}"
+            default_column += 1
+        output[column] = _arrow_type_mapper(field.type)  # type: ignore[assignment]
+
+    return output
+
+
+def _arrow_type_mapper(col_type: "pa.DataType") -> type:  # noqa: PLR0911
+    """Convert pyarrow types to basic types."""
+    from datetime import datetime
+
+    import pyarrow as pa
+
+    if pa.types.is_timestamp(col_type):
+        return datetime
+    if pa.types.is_binary(col_type):
+        return bytes
+    if pa.types.is_floating(col_type):
+        return float
+    if pa.types.is_integer(col_type):
+        return int
+    if pa.types.is_boolean(col_type):
+        return bool
+    if pa.types.is_date(col_type):
+        return datetime
+    if pa.types.is_string(col_type) or pa.types.is_large_string(col_type):
+        return str
+    if pa.types.is_list(col_type):
+        return list[_arrow_type_mapper(col_type.value_type)]  # type: ignore[misc]
+    if pa.types.is_struct(col_type) or pa.types.is_map(col_type):
+        return dict
+    if isinstance(col_type, pa.lib.DictionaryType):
+        return _arrow_type_mapper(col_type.value_type)  # type: ignore[return-value]
+    raise TypeError(f"{col_type!r} datatypes not supported")
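A short usage sketch for the new helpers (the schema below is invented for illustration and assumes the module imports as datachain.lib.arrow): schema_to_output lowercases and strips column names, maps Arrow types to Python types, and always prepends the source field.

import pyarrow as pa

from datachain.lib.arrow import schema_to_output

schema = pa.schema(
    [
        pa.field("User Name!", pa.string()),        # sanitized to "username"
        pa.field("age", pa.int64()),
        pa.field("scores", pa.list_(pa.float32())),
    ]
)

# expected, per the code above:
# {"source": Source, "username": str, "age": int, "scores": list[float]}
print(schema_to_output(schema))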