datachain 0.7.11__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of datachain might be problematic.
- datachain/catalog/catalog.py +56 -45
- datachain/cli.py +25 -3
- datachain/client/gcs.py +9 -0
- datachain/data_storage/sqlite.py +20 -6
- datachain/data_storage/warehouse.py +0 -1
- datachain/lib/arrow.py +82 -58
- datachain/lib/dc.py +167 -166
- datachain/lib/diff.py +197 -0
- datachain/lib/file.py +3 -1
- datachain/lib/listing.py +44 -0
- datachain/lib/meta_formats.py +38 -42
- datachain/lib/udf.py +0 -1
- datachain/query/batch.py +32 -6
- datachain/query/dataset.py +18 -17
- datachain/query/dispatch.py +125 -125
- datachain/query/session.py +8 -5
- datachain/query/udf.py +20 -0
- datachain/query/utils.py +42 -0
- datachain/remote/studio.py +53 -1
- datachain/studio.py +47 -2
- datachain/utils.py +1 -1
- {datachain-0.7.11.dist-info → datachain-0.8.1.dist-info}/METADATA +4 -3
- {datachain-0.7.11.dist-info → datachain-0.8.1.dist-info}/RECORD +27 -24
- {datachain-0.7.11.dist-info → datachain-0.8.1.dist-info}/LICENSE +0 -0
- {datachain-0.7.11.dist-info → datachain-0.8.1.dist-info}/WHEEL +0 -0
- {datachain-0.7.11.dist-info → datachain-0.8.1.dist-info}/entry_points.txt +0 -0
- {datachain-0.7.11.dist-info → datachain-0.8.1.dist-info}/top_level.txt +0 -0
datachain/catalog/catalog.py
CHANGED
@@ -1,7 +1,6 @@
 import io
 import json
 import logging
-import math
 import os
 import os.path
 import posixpath
@@ -13,7 +12,6 @@ from collections.abc import Iterable, Iterator, Mapping, Sequence
 from copy import copy
 from dataclasses import dataclass
 from functools import cached_property, reduce
-from random import shuffle
 from threading import Thread
 from typing import (
     IO,
@@ -54,15 +52,12 @@ from datachain.error import (
     QueryScriptCancelError,
     QueryScriptRunError,
 )
+from datachain.lib.listing import get_listing
 from datachain.node import DirType, Node, NodeWithPath
 from datachain.nodes_thread_pool import NodesThreadPool
 from datachain.remote.studio import StudioClient
 from datachain.sql.types import DateTime, SQLType
-from datachain.utils import (
-    DataChainDir,
-    batched,
-    datachain_paths_join,
-)
+from datachain.utils import DataChainDir, datachain_paths_join
 
 from .datasource import DataSource
 
@@ -90,7 +85,7 @@ QUERY_SCRIPT_INVALID_LAST_STATEMENT_EXIT_CODE = 10
 QUERY_SCRIPT_CANCELED_EXIT_CODE = 11
 
 # dataset pull
-PULL_DATASET_MAX_THREADS =
+PULL_DATASET_MAX_THREADS = 5
 PULL_DATASET_CHUNK_TIMEOUT = 3600
 PULL_DATASET_SLEEP_INTERVAL = 0.1  # sleep time while waiting for chunk to be available
 PULL_DATASET_CHECK_STATUS_INTERVAL = 20  # interval to check export status in Studio
@@ -130,6 +125,7 @@ class DatasetRowsFetcher(NodesThreadPool):
         local_ds_version: int,
         schema: dict[str, Union[SQLType, type[SQLType]]],
         max_threads: int = PULL_DATASET_MAX_THREADS,
+        progress_bar=None,
     ):
         super().__init__(max_threads)
         self._check_dependencies()
@@ -142,6 +138,7 @@ class DatasetRowsFetcher(NodesThreadPool):
         self.schema = schema
         self.last_status_check: Optional[float] = None
         self.studio_client = StudioClient()
+        self.progress_bar = progress_bar
 
     def done_task(self, done):
         for task in done:
@@ -198,6 +195,20 @@ class DatasetRowsFetcher(NodesThreadPool):
         for c in [c for c, t in self.schema.items() if t == DateTime]:
             df[c] = pd.to_datetime(df[c], unit="s")
 
+        # id will be autogenerated in DB
+        return df.drop("sys__id", axis=1)
+
+    def get_parquet_content(self, url: str):
+        while True:
+            if self.should_check_for_status():
+                self.check_for_status()
+            r = requests.get(url, timeout=PULL_DATASET_CHUNK_TIMEOUT)
+            if r.status_code == 404:
+                time.sleep(PULL_DATASET_SLEEP_INTERVAL)
+                continue
+            r.raise_for_status()
+            return r.content
+
     def do_task(self, urls):
         import lz4.frame
         import pandas as pd
@@ -207,31 +218,22 @@ class DatasetRowsFetcher(NodesThreadPool):
             local_ds = metastore.get_dataset(self.local_ds_name)
 
             urls = list(urls)
-            while urls:
-                for url in urls:
-                    if self.should_check_for_status():
-                        self.check_for_status()
 
-
-
-
-                    # moving to the next url
-                    continue
+            for url in urls:
+                if self.should_check_for_status():
+                    self.check_for_status()
 
-
-
-
-
-                    self.fix_columns(df)
-
-                    # id will be autogenerated in DB
-                    df = df.drop("sys__id", axis=1)
+                df = pd.read_parquet(
+                    io.BytesIO(lz4.frame.decompress(self.get_parquet_content(url)))
+                )
+                df = self.fix_columns(df)
 
-
-
-
-
-
+                inserted = warehouse.insert_dataset_rows(
+                    df, local_ds, self.local_ds_version
+                )
+                self.increase_counter(inserted)  # type: ignore [arg-type]
+                # sometimes progress bar doesn't get updated so manually updating it
+                self.update_progress_bar(self.progress_bar)
 
 
 @dataclass
@@ -598,7 +600,7 @@ class Catalog:
             source, session=self.session, update=update, object_name=object_name
         )
 
-        list_ds_name, list_uri, list_path, _ =
+        list_ds_name, list_uri, list_path, _ = get_listing(
             source, self.session, update=update
         )
 
@@ -696,11 +698,9 @@ class Catalog:
         )
         indexed_sources = []
         for source in dataset_sources:
-            from datachain.lib.dc import DataChain
-
             client = self.get_client(source, **client_config)
             uri = client.uri
-            dataset_name, _, _, _ =
+            dataset_name, _, _, _ = get_listing(uri, self.session)
             listing = Listing(
                 self.metastore.clone(),
                 self.warehouse.clone(),
@@ -1291,13 +1291,13 @@ class Catalog:
         for source in data_sources:  # type: ignore [union-attr]
             yield source, source.ls(fields)
 
-    def pull_dataset(  # noqa: PLR0915
+    def pull_dataset(  # noqa: C901, PLR0915
        self,
        remote_ds_uri: str,
        output: Optional[str] = None,
        local_ds_name: Optional[str] = None,
        local_ds_version: Optional[int] = None,
-
+        cp: bool = False,
        force: bool = False,
        edatachain: bool = False,
        edatachain_file: Optional[str] = None,
@@ -1305,7 +1305,7 @@ class Catalog:
         client_config=None,
     ) -> None:
         def _instantiate(ds_uri: str) -> None:
-            if
+            if not cp:
                 return
             assert output
             self.cp(
@@ -1318,7 +1318,7 @@ class Catalog:
             )
             print(f"Dataset {ds_uri} instantiated locally to {output}")
 
-        if
+        if cp and not output:
             raise ValueError("Please provide output directory for instantiation")
 
         studio_client = StudioClient()
@@ -1417,12 +1417,26 @@ class Catalog:
         signed_urls = export_response.data
 
         if signed_urls:
-            shuffle(signed_urls)
-
             with (
                 self.metastore.clone() as metastore,
                 self.warehouse.clone() as warehouse,
             ):
+
+                def batch(urls):
+                    """
+                    Batching urls in a way that fetching is most efficient as
+                    urls with lower id will be created first. Because that, we
+                    are making sure all threads are pulling most recent urls
+                    from beginning
+                    """
+                    res = [[] for i in range(PULL_DATASET_MAX_THREADS)]
+                    current_worker = 0
+                    for url in signed_urls:
+                        res[current_worker].append(url)
+                        current_worker = (current_worker + 1) % PULL_DATASET_MAX_THREADS
+
+                    return res
+
                 rows_fetcher = DatasetRowsFetcher(
                     metastore,
                     warehouse,
@@ -1431,14 +1445,11 @@ class Catalog:
                     local_ds_name,
                     local_ds_version,
                     schema,
+                    progress_bar=dataset_save_progress_bar,
                 )
                 try:
                     rows_fetcher.run(
-
-                        signed_urls,
-                        math.ceil(len(signed_urls) / PULL_DATASET_MAX_THREADS),
-                        ),
-                        dataset_save_progress_bar,
+                        iter(batch(signed_urls)), dataset_save_progress_bar
                     )
                 except:
                     self.remove_dataset(local_ds_name, local_ds_version)
datachain/cli.py
CHANGED
@@ -294,6 +294,28 @@ def add_studio_parser(subparsers, parent_parser) -> None:
         help="Python package requirement. Can be specified multiple times.",
     )
 
+    studio_cancel_help = "Cancel a job in Studio"
+    studio_cancel_description = "This command cancels a job in Studio."
+
+    studio_cancel_parser = studio_subparser.add_parser(
+        "cancel",
+        parents=[parent_parser],
+        description=studio_cancel_description,
+        help=studio_cancel_help,
+    )
+
+    studio_cancel_parser.add_argument(
+        "job_id",
+        action="store",
+        help="The job ID to cancel.",
+    )
+    studio_cancel_parser.add_argument(
+        "--team",
+        action="store",
+        default=None,
+        help="The team to cancel a job for. By default, it will use team from config.",
+    )
+
 
 def get_parser() -> ArgumentParser:  # noqa: PLR0915
     try:
@@ -457,10 +479,10 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
         help="Copy directories recursively",
     )
     parse_pull.add_argument(
-        "--
+        "--cp",
         default=False,
         action="store_true",
-        help="
+        help="Copy actual files after pulling remote dataset into local DB",
     )
     parse_pull.add_argument(
         "--edatachain",
@@ -1300,7 +1322,7 @@ def main(argv: Optional[list[str]] = None) -> int:  # noqa: C901, PLR0912, PLR09
                args.output,
                local_ds_name=args.local_name,
                local_ds_version=args.local_version,
-
+                cp=args.cp,
                force=bool(args.force),
                edatachain=args.edatachain,
                edatachain_file=args.edatachain_file,
datachain/client/gcs.py
CHANGED
@@ -32,6 +32,15 @@ class GCSClient(Client):
 
         return cast(GCSFileSystem, super().create_fs(**kwargs))
 
+    def url(self, path: str, expires: int = 3600, **kwargs) -> str:
+        try:
+            return self.fs.sign(self.get_full_path(path), expiration=expires, **kwargs)
+        except AttributeError as exc:
+            is_anon = self.fs.storage_options.get("token") == "anon"
+            if is_anon and "you need a private key to sign credentials" in str(exc):
+                return f"https://storage.googleapis.com/{self.name}/{path}"
+            raise
+
     @staticmethod
     def parse_timestamp(timestamp: str) -> datetime:
         """
datachain/data_storage/sqlite.py
CHANGED
@@ -209,10 +209,12 @@ class SQLiteDatabaseEngine(DatabaseEngine):
 
     @retry_sqlite_locks
     def executemany(
-        self, query, params, cursor: Optional[sqlite3.Cursor] = None
+        self, query, params, cursor: Optional[sqlite3.Cursor] = None, conn=None
     ) -> sqlite3.Cursor:
         if cursor:
             return cursor.executemany(self.compile(query).string, params)
+        if conn:
+            return conn.executemany(self.compile(query).string, params)
         return self.db.executemany(self.compile(query).string, params)
 
     @retry_sqlite_locks
@@ -222,7 +224,14 @@ class SQLiteDatabaseEngine(DatabaseEngine):
         return self.db.execute(sql, parameters)
 
     def insert_dataframe(self, table_name: str, df) -> int:
-        return df.to_sql(
+        return df.to_sql(
+            table_name,
+            self.db,
+            if_exists="append",
+            index=False,
+            method="multi",
+            chunksize=1000,
+        )
 
     def cursor(self, factory=None):
         if factory is None:
@@ -545,10 +554,15 @@ class SQLiteWarehouse(AbstractWarehouse):
         rows = list(rows)
         if not rows:
             return
-
-
-
-
+
+        with self.db.transaction() as conn:
+            # transactions speeds up inserts significantly as there is no separate
+            # transaction created for each insert row
+            self.db.executemany(
+                table.insert().values({f: bindparam(f) for f in rows[0]}),
+                rows,
+                conn=conn,
+            )
 
     def insert_dataset_rows(self, df, dataset: DatasetRecord, version: int) -> int:
         dr = self.dataset_rows(dataset, version)
datachain/lib/arrow.py
CHANGED
@@ -1,9 +1,11 @@
 from collections.abc import Sequence
-from tempfile import NamedTemporaryFile
+from itertools import islice
 from typing import TYPE_CHECKING, Any, Optional
 
+import fsspec.implementations.reference
 import orjson
 import pyarrow as pa
+from fsspec.core import split_protocol
 from pyarrow.dataset import CsvFileFormat, dataset
 from tqdm import tqdm
 
@@ -25,7 +27,18 @@ if TYPE_CHECKING:
 DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY = b"DataChain SignalSchema"
 
 
+class ReferenceFileSystem(fsspec.implementations.reference.ReferenceFileSystem):
+    def _open(self, path, mode="rb", *args, **kwargs):
+        # overriding because `fsspec`'s `ReferenceFileSystem._open`
+        # reads the whole file in-memory.
+        (uri,) = self.references[path]
+        protocol, _ = split_protocol(uri)
+        return self.fss[protocol]._open(uri, mode, *args, **kwargs)
+
+
 class ArrowGenerator(Generator):
+    DEFAULT_BATCH_SIZE = 2**17  # same as `pyarrow._dataset._DEFAULT_BATCH_SIZE`
+
     def __init__(
         self,
         input_schema: Optional["pa.Schema"] = None,
@@ -55,57 +68,80 @@ class ArrowGenerator(Generator):
     def process(self, file: File):
         if file._caching_enabled:
             file.ensure_cached()
-
-
-
-            path = _nrows_file(file, self.nrows)
-            ds = dataset(path, schema=self.input_schema, **self.kwargs)
+            cache_path = file.get_local_path()
+            fs_path = file.path
+            fs = ReferenceFileSystem({fs_path: [cache_path]})
         else:
-
-
-
-
+            fs, fs_path = file.get_fs(), file.get_path()
+
+        ds = dataset(fs_path, schema=self.input_schema, filesystem=fs, **self.kwargs)
+
         hf_schema = _get_hf_schema(ds.schema)
         use_datachain_schema = (
             bool(ds.schema.metadata)
             and DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY in ds.schema.metadata
         )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+        kw = {}
+        if self.nrows:
+            kw = {"batch_size": min(self.DEFAULT_BATCH_SIZE, self.nrows)}
+
+        def iter_records():
+            for record_batch in ds.to_batches(**kw):
+                yield from record_batch.to_pylist()
+
+        it = islice(iter_records(), self.nrows)
+        with tqdm(it, desc="Parsed by pyarrow", unit="rows", total=self.nrows) as pbar:
+            for index, record in enumerate(pbar):
+                yield self._process_record(
+                    record, file, index, hf_schema, use_datachain_schema
+                )
+
+    def _process_record(
+        self,
+        record: dict[str, Any],
+        file: File,
+        index: int,
+        hf_schema: Optional[tuple["Features", dict[str, "DataType"]]],
+        use_datachain_schema: bool,
+    ):
+        if use_datachain_schema and self.output_schema:
+            vals = [_nested_model_instantiate(record, self.output_schema)]
+        else:
+            vals = self._process_non_datachain_record(record, hf_schema)
+
+        if self.source:
+            kwargs: dict = self.kwargs
+            # Can't serialize CsvFileFormat; may lose formatting options.
+            if isinstance(kwargs.get("format"), CsvFileFormat):
+                kwargs["format"] = "csv"
+            arrow_file = ArrowRow(file=file, index=index, kwargs=kwargs)
+            return [arrow_file, *vals]
+        return vals
+
+    def _process_non_datachain_record(
+        self,
+        record: dict[str, Any],
+        hf_schema: Optional[tuple["Features", dict[str, "DataType"]]],
+    ):
+        vals = list(record.values())
+        if not self.output_schema:
+            return vals
+
+        fields = self.output_schema.model_fields
+        vals_dict = {}
+        for i, ((field, field_info), val) in enumerate(zip(fields.items(), vals)):
+            anno = field_info.annotation
+            if hf_schema:
+                from datachain.lib.hf import convert_feature
+
+                feat = list(hf_schema[0].values())[i]
+                vals_dict[field] = convert_feature(val, feat, anno)
+            elif ModelStore.is_pydantic(anno):
+                vals_dict[field] = anno(**val)  # type: ignore[misc]
+            else:
+                vals_dict[field] = val
+        return [self.output_schema(**vals_dict)]
 
 
 def infer_schema(chain: "DataChain", **kwargs) -> pa.Schema:
@@ -190,18 +226,6 @@ def arrow_type_mapper(col_type: pa.DataType, column: str = "") -> type:  # noqa:
    raise TypeError(f"{col_type!r} datatypes not supported, column: {column}")
 
 
-def _nrows_file(file: File, nrows: int) -> str:
-    tf = NamedTemporaryFile(delete=False)  # noqa: SIM115
-    with file.open(mode="r") as reader:
-        with open(tf.name, "a") as writer:
-            for row, line in enumerate(reader):
-                if row >= nrows:
-                    break
-                writer.write(line)
-                writer.write("\n")
-    return tf.name
-
-
 def _get_hf_schema(
     schema: "pa.Schema",
 ) -> Optional[tuple["Features", dict[str, "DataType"]]]:
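
The reworked `ArrowGenerator.process` drops the temporary-file trick (`_nrows_file`) and instead streams record batches from the pyarrow dataset, capping `batch_size` when `nrows` is set and slicing the row iterator with `itertools.islice`. A small sketch of that pattern with plain pyarrow; the in-memory table and column names are made up:

from itertools import islice

import pyarrow as pa
import pyarrow.dataset as pa_ds

# Sketch of the nrows handling in the new ArrowGenerator.process: read record
# batches lazily and stop after `nrows` rows via islice, instead of writing a
# truncated copy of the input file to disk.
table = pa.table({"id": list(range(1000)), "value": [i * 2 for i in range(1000)]})
ds = pa_ds.dataset(table)

nrows = 10
batch_size = min(2**17, nrows)  # same capping idea as DEFAULT_BATCH_SIZE above

def iter_records():
    for record_batch in ds.to_batches(batch_size=batch_size):
        yield from record_batch.to_pylist()

for record in islice(iter_records(), nrows):
    print(record)  # {'id': 0, 'value': 0}, ... first 10 rows only
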