datachain 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of datachain might be problematic.
- datachain/catalog/catalog.py +8 -0
- datachain/data_storage/metastore.py +20 -1
- datachain/data_storage/sqlite.py +24 -32
- datachain/lib/arrow.py +64 -19
- datachain/lib/convert/values_to_tuples.py +2 -2
- datachain/lib/data_model.py +1 -1
- datachain/lib/dc.py +131 -12
- datachain/lib/signal_schema.py +6 -6
- datachain/lib/udf.py +208 -160
- datachain/lib/udf_signature.py +8 -6
- datachain/query/batch.py +0 -10
- datachain/query/dataset.py +7 -7
- datachain/query/dispatch.py +2 -14
- datachain/query/session.py +42 -0
- datachain/sql/functions/string.py +12 -0
- datachain/sql/sqlite/base.py +10 -5
- {datachain-0.5.0.dist-info → datachain-0.6.0.dist-info}/METADATA +1 -1
- {datachain-0.5.0.dist-info → datachain-0.6.0.dist-info}/RECORD +22 -23
- datachain/query/udf.py +0 -126
- {datachain-0.5.0.dist-info → datachain-0.6.0.dist-info}/LICENSE +0 -0
- {datachain-0.5.0.dist-info → datachain-0.6.0.dist-info}/WHEEL +0 -0
- {datachain-0.5.0.dist-info → datachain-0.6.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.5.0.dist-info → datachain-0.6.0.dist-info}/top_level.txt +0 -0
datachain/catalog/catalog.py
CHANGED

@@ -988,6 +988,14 @@ class Catalog:
         schema = {
             c.name: c.type.to_dict() for c in columns if isinstance(c.type, SQLType)
         }
+
+        job_id = job_id or os.getenv("DATACHAIN_JOB_ID")
+        if not job_id:
+            from datachain.query.session import Session
+
+            session = Session.get(catalog=self)
+            job_id = session.job_id
+
         dataset = self.metastore.create_dataset_version(
             dataset,
             version,
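The added block resolves the job id in three steps: an explicit argument, then the DATACHAIN_JOB_ID environment variable, then the id of the current Session. A minimal sketch of that fallback order (the resolve_job_id helper and the session object are illustrative, not part of the package):

    import os
    from typing import Optional

    def resolve_job_id(job_id: Optional[str], session) -> Optional[str]:
        # 1. an explicit job_id argument wins
        if job_id:
            return job_id
        # 2. otherwise use the id injected by the runner via the environment
        env_job_id = os.getenv("DATACHAIN_JOB_ID")
        if env_job_id:
            return env_job_id
        # 3. finally fall back to the job id attached to the current session
        return session.job_id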
datachain/data_storage/metastore.py
CHANGED

@@ -50,7 +50,6 @@ if TYPE_CHECKING:
     from datachain.data_storage import AbstractIDGenerator, schema
     from datachain.data_storage.db_engine import DatabaseEngine

-
 logger = logging.getLogger("datachain")


@@ -384,6 +383,11 @@ class AbstractMetastore(ABC, Serializable):
     ) -> None:
         """Set the status of the given job and dataset."""

+    @abstractmethod
+    def get_job_dataset_versions(self, job_id: str) -> list[tuple[str, int]]:
+        """Returns dataset names and versions for the job."""
+        raise NotImplementedError
+

 class AbstractDBMetastore(AbstractMetastore):
     """
@@ -1519,3 +1523,18 @@ class AbstractDBMetastore(AbstractMetastore):
             .values(status=dataset_status)
         )
         self.db.execute(query, conn=conn)  # type: ignore[attr-defined]
+
+    def get_job_dataset_versions(self, job_id: str) -> list[tuple[str, int]]:
+        """Returns dataset names and versions for the job."""
+        dv = self._datasets_versions
+        ds = self._datasets
+
+        join_condition = dv.c.dataset_id == ds.c.id
+
+        query = (
+            self._datasets_versions_select(ds.c.name, dv.c.version)
+            .select_from(dv.join(ds, join_condition))
+            .where(dv.c.job_id == job_id)
+        )
+
+        return list(self.db.execute(query))
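The new get_job_dataset_versions query is a plain SQLAlchemy Core join-and-filter. A self-contained sketch of the same pattern against hypothetical datasets and datasets_versions tables (the table layout is illustrative, not the package's actual schema):

    import sqlalchemy as sa

    metadata = sa.MetaData()
    datasets = sa.Table(
        "datasets", metadata,
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column("name", sa.Text),
    )
    versions = sa.Table(
        "datasets_versions", metadata,
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column("dataset_id", sa.Integer, sa.ForeignKey("datasets.id")),
        sa.Column("version", sa.Integer),
        sa.Column("job_id", sa.Text),
    )

    def job_dataset_versions(conn, job_id: str) -> list[tuple[str, int]]:
        # select (name, version) pairs for all dataset versions created by the job
        query = (
            sa.select(datasets.c.name, versions.c.version)
            .select_from(versions.join(datasets, versions.c.dataset_id == datasets.c.id))
            .where(versions.c.job_id == job_id)
        )
        return [tuple(row) for row in conn.execute(query)]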
datachain/data_storage/sqlite.py
CHANGED

@@ -15,6 +15,7 @@ from typing import (
 )

 import sqlalchemy
+from packaging import version
 from sqlalchemy import MetaData, Table, UniqueConstraint, exists, select
 from sqlalchemy.dialects import sqlite
 from sqlalchemy.schema import CreateIndex, CreateTable, DropTable
@@ -153,7 +154,7 @@ class SQLiteDatabaseEngine(DatabaseEngine):
         if os.environ.get("DEBUG_SHOW_SQL_QUERIES"):
             import sys

-            db.set_trace_callback(sys.stderr
+            db.set_trace_callback(lambda stmt: print(stmt, file=sys.stderr))

         load_usearch_extension(db)

@@ -345,45 +346,36 @@ class SQLiteIDGenerator(AbstractDBIDGenerator):
     def get_next_ids(self, uri: str, count: int) -> range:
         """Returns a range of IDs for the given URI."""

-
-
-
-
-        # leaving fallback to the current implementation for older versions of SQLite,
-        # which is still supported, for example, in Ubuntu 20.04 LTS (Focal Fossa),
-        # where SQLite version 3.31.1 is used.
-
-        # sqlite_version = version.parse(sqlite3.sqlite_version)
-        # if sqlite_version >= version.parse("3.35.0"):
-        #     # RETURNING is supported on SQLite 3.35.0 (2021-03-12) or newer
-        #     stmt = (
-        #         sqlite.insert(self._table)
-        #         .values(uri=uri, last_id=count)
-        #         .on_conflict_do_update(
-        #             index_elements=["uri"],
-        #             set_={"last_id": self._table.c.last_id + count},
-        #         )
-        #         .returning(self._table.c.last_id)
-        #     )
-        #     last_id = self._db.execute(stmt).fetchone()[0]
-        # else:
-        #     (fallback to the current implementation with a transaction)
-
-        # Transactions ensure no concurrency conflicts
-        with self._db.transaction() as conn:
-            # UPSERT syntax was added to SQLite with version 3.24.0 (2018-06-04).
-            stmt_ins = (
+        sqlite_version = version.parse(sqlite3.sqlite_version)
+        is_returning_supported = sqlite_version >= version.parse("3.35.0")
+        if is_returning_supported:
+            stmt = (
                 sqlite.insert(self._table)
                 .values(uri=uri, last_id=count)
                 .on_conflict_do_update(
                     index_elements=["uri"],
                     set_={"last_id": self._table.c.last_id + count},
                 )
+                .returning(self._table.c.last_id)
             )
-            self._db.execute(
+            last_id = self._db.execute(stmt).fetchone()[0]
+        else:
+            # Older versions of SQLite are still the default under Ubuntu LTS,
+            # e.g. Ubuntu 20.04 LTS (Focal Fossa) uses 3.31.1
+            # Transactions ensure no concurrency conflicts
+            with self._db.transaction() as conn:
+                stmt_ins = (
+                    sqlite.insert(self._table)
+                    .values(uri=uri, last_id=count)
+                    .on_conflict_do_update(
+                        index_elements=["uri"],
+                        set_={"last_id": self._table.c.last_id + count},
+                    )
+                )
+                self._db.execute(stmt_ins, conn=conn)

-
-
+                stmt_sel = select(self._table.c.last_id).where(self._table.c.uri == uri)
+                last_id = self._db.execute(stmt_sel, conn=conn).fetchone()[0]

         return range(last_id - count + 1, last_id + 1)

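The get_next_ids change gates on the runtime SQLite version: RETURNING (added in SQLite 3.35.0) lets the upsert return the new counter in one statement, while older versions fall back to an upsert plus a SELECT inside one transaction. A standalone sketch of the same gate using the stdlib sqlite3 module and a hypothetical ids(uri, last_id) table:

    import sqlite3
    from packaging import version

    RETURNING_MIN = version.parse("3.35.0")  # RETURNING landed in SQLite 3.35.0

    def next_ids(db: sqlite3.Connection, uri: str, count: int) -> range:
        if version.parse(sqlite3.sqlite_version) >= RETURNING_MIN:
            # one statement: bump the counter and read the new value back
            (last_id,) = db.execute(
                "INSERT INTO ids (uri, last_id) VALUES (?, ?) "
                "ON CONFLICT(uri) DO UPDATE SET last_id = last_id + ? "
                "RETURNING last_id",
                (uri, count, count),
            ).fetchone()
            db.commit()
        else:
            # older SQLite: bump the counter, then read it in the same transaction
            with db:
                db.execute(
                    "INSERT INTO ids (uri, last_id) VALUES (?, ?) "
                    "ON CONFLICT(uri) DO UPDATE SET last_id = last_id + ?",
                    (uri, count, count),
                )
                (last_id,) = db.execute(
                    "SELECT last_id FROM ids WHERE uri = ?", (uri,)
                ).fetchone()
        return range(last_id - count + 1, last_id + 1)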
datachain/lib/arrow.py
CHANGED

@@ -1,8 +1,9 @@
 import re
 from collections.abc import Sequence
 from tempfile import NamedTemporaryFile
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Optional

+import orjson
 import pyarrow as pa
 from pyarrow.dataset import CsvFileFormat, dataset
 from tqdm import tqdm
@@ -10,6 +11,7 @@ from tqdm import tqdm
 from datachain.lib.data_model import dict_to_data_model
 from datachain.lib.file import ArrowRow, File
 from datachain.lib.model_store import ModelStore
+from datachain.lib.signal_schema import SignalSchema
 from datachain.lib.udf import Generator

 if TYPE_CHECKING:
@@ -20,6 +22,9 @@ if TYPE_CHECKING:
     from datachain.lib.dc import DataChain


+DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY = b"DataChain SignalSchema"
+
+
 class ArrowGenerator(Generator):
     def __init__(
         self,
@@ -61,28 +66,35 @@ class ArrowGenerator(Generator):
             path, filesystem=file.get_fs(), schema=self.input_schema, **self.kwargs
         )
         hf_schema = _get_hf_schema(ds.schema)
+        use_datachain_schema = (
+            bool(ds.schema.metadata)
+            and DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY in ds.schema.metadata
+        )
         index = 0
         with tqdm(desc="Parsed by pyarrow", unit=" rows") as pbar:
             for record_batch in ds.to_batches():
                 for record in record_batch.to_pylist():
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                    if use_datachain_schema and self.output_schema:
+                        vals = [_nested_model_instantiate(record, self.output_schema)]
+                    else:
+                        vals = list(record.values())
+                        if self.output_schema:
+                            fields = self.output_schema.model_fields
+                            vals_dict = {}
+                            for i, ((field, field_info), val) in enumerate(
+                                zip(fields.items(), vals)
+                            ):
+                                anno = field_info.annotation
+                                if hf_schema:
+                                    from datachain.lib.hf import convert_feature
+
+                                    feat = list(hf_schema[0].values())[i]
+                                    vals_dict[field] = convert_feature(val, feat, anno)
+                                elif ModelStore.is_pydantic(anno):
+                                    vals_dict[field] = anno(**val)  # type: ignore[misc]
+                                else:
+                                    vals_dict[field] = val
+                            vals = [self.output_schema(**vals_dict)]
                     if self.source:
                         kwargs: dict = self.kwargs
                         # Can't serialize CsvFileFormat; may lose formatting options.
@@ -113,6 +125,9 @@ def schema_to_output(schema: pa.Schema, col_names: Optional[Sequence[str]] = Non
     )
     if not col_names:
         col_names = schema.names
+    signal_schema = _get_datachain_schema(schema)
+    if signal_schema:
+        return signal_schema.values
     columns = _convert_col_names(col_names)  # type: ignore[arg-type]
     hf_schema = _get_hf_schema(schema)
     if hf_schema:
@@ -197,3 +212,33 @@ def _get_hf_schema(
         features = schema_from_arrow(schema)
         return features, get_output_schema(features)
     return None
+
+
+def _get_datachain_schema(schema: "pa.Schema") -> Optional[SignalSchema]:
+    """Return a restored SignalSchema from parquet metadata, if any is found."""
+    if schema.metadata and DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY in schema.metadata:
+        serialized_signal_schema = orjson.loads(
+            schema.metadata[DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY]
+        )
+        return SignalSchema.deserialize(serialized_signal_schema)
+    return None
+
+
+def _nested_model_instantiate(
+    column_values: dict[str, Any], model: type["BaseModel"], prefix: str = ""
+) -> "BaseModel":
+    """Instantiate the given model and all sub-models/fields based on the provided
+    column values."""
+    vals_dict = {}
+    for field, field_info in model.model_fields.items():
+        anno = field_info.annotation
+        cur_path = f"{prefix}.{field}" if prefix else field
+        if ModelStore.is_pydantic(anno):
+            vals_dict[field] = _nested_model_instantiate(
+                column_values,
+                anno,  # type: ignore[arg-type]
+                prefix=cur_path,
+            )
+        elif cur_path in column_values:
+            vals_dict[field] = column_values[cur_path]
+    return model(**vals_dict)
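Files written with the new to_parquet carry the serialized SignalSchema under the b"DataChain SignalSchema" metadata key, and _get_datachain_schema restores it when reading. A rough sketch of inspecting that key on a parquet file with pyarrow (the file name is illustrative):

    import orjson
    import pyarrow.parquet as pq

    DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY = b"DataChain SignalSchema"

    schema = pq.read_schema("chain.parquet")  # illustrative file name
    metadata = schema.metadata or {}
    if DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY in metadata:
        # the value is the orjson-serialized SignalSchema mapping
        serialized = orjson.loads(metadata[DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY])
        print(sorted(serialized))
    else:
        print("no DataChain schema metadata found")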
datachain/lib/convert/values_to_tuples.py
CHANGED

@@ -4,7 +4,7 @@ from typing import Any, Union
 from datachain.lib.data_model import (
     DataType,
     DataTypeNames,
-
+    DataValue,
     is_chain_type,
 )
 from datachain.lib.utils import DataChainParamsError
@@ -20,7 +20,7 @@ class ValuesToTupleError(DataChainParamsError):
 def values_to_tuples(  # noqa: C901, PLR0912
     ds_name: str = "",
     output: Union[None, DataType, Sequence[str], dict[str, DataType]] = None,
-    **fr_map: Sequence[
+    **fr_map: Sequence[DataValue],
 ) -> tuple[Any, Any, Any]:
     if output:
         if not isinstance(output, (Sequence, str, dict)):
datachain/lib/data_model.py
CHANGED

@@ -18,7 +18,7 @@ StandardType = Union[
 ]
 DataType = Union[type[BaseModel], StandardType]
 DataTypeNames = "BaseModel, int, str, float, bool, list, dict, bytes, datetime"
-
+DataValue = Union[BaseModel, int, str, float, bool, list, dict, bytes, datetime]


 class DataModel(BaseModel):
datachain/lib/dc.py
CHANGED

@@ -16,6 +16,7 @@ from typing import (
     overload,
 )

+import orjson
 import pandas as pd
 import sqlalchemy
 from pydantic import BaseModel
@@ -58,9 +59,10 @@ from datachain.query.dataset import (
 from datachain.query.schema import DEFAULT_DELIMITER, Column, DatasetRow
 from datachain.sql.functions import path as pathfunc
 from datachain.telemetry import telemetry
-from datachain.utils import inside_notebook
+from datachain.utils import batched_it, inside_notebook

 if TYPE_CHECKING:
+    from pyarrow import DataType as ArrowDataType
     from typing_extensions import Concatenate, ParamSpec, Self

     from datachain.lib.hf import HFDatasetType
@@ -71,6 +73,10 @@ C = Column

 _T = TypeVar("_T")
 D = TypeVar("D", bound="DataChain")
+UDFObjT = TypeVar("UDFObjT", bound=UDFBase)
+
+
+DEFAULT_PARQUET_CHUNK_SIZE = 100_000


 def resolve_columns(
@@ -819,7 +825,7 @@ class DataChain:

     def gen(
         self,
-        func: Optional[Callable] = None,
+        func: Optional[Union[Callable, Generator]] = None,
         params: Union[None, str, Sequence[str]] = None,
         output: OutputType = None,
         **signal_map,
@@ -931,12 +937,12 @@

     def _udf_to_obj(
         self,
-        target_class: type[
-        func: Optional[Callable],
+        target_class: type[UDFObjT],
+        func: Optional[Union[Callable, UDFObjT]],
         params: Union[None, str, Sequence[str]],
         output: OutputType,
         signal_map,
-    ) ->
+    ) -> UDFObjT:
         is_generator = target_class.is_output_batched
         name = self.name or ""

@@ -1019,7 +1025,7 @@
         The supported functions:
             Numerical: +, -, *, /, rand(), avg(), count(), func(),
                 greatest(), least(), max(), min(), sum()
-            String: length(), split()
+            String: length(), split(), replace(), regexp_replace()
             Filename: name(), parent(), file_stem(), file_ext()
             Array: length(), sip_hash_64(), euclidean_distance(),
                 cosine_distance()
@@ -1103,6 +1109,29 @@
                 rows = (row_factory(db_signals, r) for r in rows)
             yield from rows

+    def to_columnar_data_with_names(
+        self, chunk_size: int = DEFAULT_PARQUET_CHUNK_SIZE
+    ) -> tuple[list[str], Iterator[list[list[Any]]]]:
+        """Returns column names and the results as an iterator that provides chunks,
+        with each chunk containing a list of columns, where each column contains a
+        list of the row values for that column in that chunk. Useful for columnar data
+        formats, such as parquet or other OLAP databases.
+        """
+        headers, _ = self._effective_signals_schema.get_headers_with_length()
+        column_names = [".".join(filter(None, header)) for header in headers]
+
+        results_iter = self.collect_flatten()
+
+        def column_chunks() -> Iterator[list[list[Any]]]:
+            for chunk_iter in batched_it(results_iter, chunk_size):
+                columns: list[list[Any]] = [[] for _ in column_names]
+                for row in chunk_iter:
+                    for i, col in enumerate(columns):
+                        col.append(row[i])
+                yield columns
+
+        return column_names, column_chunks()
+
     @overload
     def results(self) -> list[tuple[Any, ...]]: ...

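A hedged usage sketch of the new helper, assuming a small chain built with DataChain.from_values (the values are illustrative):

    from datachain.lib.dc import DataChain

    chain = DataChain.from_values(num=[1, 2, 3], name=["a", "b", "c"])
    column_names, chunks = chain.to_columnar_data_with_names(chunk_size=2)

    for chunk in chunks:
        # each chunk is a list of columns aligned with column_names
        for col_name, column in zip(column_names, chunk):
            print(col_name, column)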
@@ -1681,6 +1710,7 @@
         nrows=None,
         session: Optional[Session] = None,
         settings: Optional[dict] = None,
+        column_types: Optional[dict[str, "Union[str, ArrowDataType]"]] = None,
         **kwargs,
     ) -> "DataChain":
         """Generate chain from csv files.
@@ -1699,6 +1729,9 @@
             nrows : Optional row limit.
             session : Session to use for the chain.
             settings : Settings to use for the chain.
+            column_types : Dictionary of column names and their corresponding types.
+                It is passed to CSV reader and for each column specified type auto
+                inference is disabled.

         Example:
             Reading a csv file:
@@ -1714,6 +1747,15 @@
         from pandas.io.parsers.readers import STR_NA_VALUES
         from pyarrow.csv import ConvertOptions, ParseOptions, ReadOptions
         from pyarrow.dataset import CsvFileFormat
+        from pyarrow.lib import type_for_alias
+
+        if column_types:
+            column_types = {
+                name: type_for_alias(typ) if isinstance(typ, str) else typ
+                for name, typ in column_types.items()
+            }
+        else:
+            column_types = {}

         chain = DataChain.from_storage(
             path, session=session, settings=settings, **kwargs
@@ -1739,7 +1781,9 @@
         parse_options = ParseOptions(delimiter=delimiter)
         read_options = ReadOptions(column_names=column_names)
         convert_options = ConvertOptions(
-            strings_can_be_null=True,
+            strings_can_be_null=True,
+            null_values=STR_NA_VALUES,
+            column_types=column_types,
         )
         format = CsvFileFormat(
             parse_options=parse_options,
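With the new column_types argument, type auto-inference can be switched off per column by passing either a pyarrow type or one of its string aliases (resolved through type_for_alias). A usage sketch with an illustrative path and columns:

    import pyarrow as pa

    from datachain.lib.dc import DataChain

    chain = DataChain.from_csv(
        "gs://bucket/path/data.csv",  # illustrative location
        column_types={
            "id": "int64",            # string alias
            "price": pa.float64(),    # or an explicit pyarrow DataType
        },
    )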
@@ -1808,21 +1852,96 @@
         self,
         path: Union[str, os.PathLike[str], BinaryIO],
         partition_cols: Optional[Sequence[str]] = None,
+        chunk_size: int = DEFAULT_PARQUET_CHUNK_SIZE,
         **kwargs,
     ) -> None:
-        """Save chain to parquet file.
+        """Save chain to parquet file with SignalSchema metadata.

         Parameters:
             path : Path or a file-like binary object to save the file.
             partition_cols : Column names by which to partition the dataset.
+            chunk_size : The chunk size of results to read and convert to columnar
+                data, to avoid running out of memory.
         """
+        import pyarrow as pa
+        import pyarrow.parquet as pq
+
+        from datachain.lib.arrow import DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY
+
         _partition_cols = list(partition_cols) if partition_cols else None
-
-
-            partition_cols=_partition_cols,
-            **kwargs,
+        signal_schema_metadata = orjson.dumps(
+            self._effective_signals_schema.serialize()
         )

+        column_names, column_chunks = self.to_columnar_data_with_names(chunk_size)
+
+        parquet_schema = None
+        parquet_writer = None
+        first_chunk = True
+
+        for chunk in column_chunks:
+            # pyarrow infers the best parquet schema from the python types of
+            # the input data.
+            table = pa.Table.from_pydict(
+                dict(zip(column_names, chunk)),
+                schema=parquet_schema,
+            )
+
+            # Preserve any existing metadata, and add the DataChain SignalSchema.
+            existing_metadata = table.schema.metadata or {}
+            merged_metadata = {
+                **existing_metadata,
+                DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY: signal_schema_metadata,
+            }
+            table = table.replace_schema_metadata(merged_metadata)
+            parquet_schema = table.schema
+
+            if _partition_cols:
+                # Write to a partitioned parquet dataset.
+                pq.write_to_dataset(
+                    table,
+                    root_path=path,
+                    partition_cols=_partition_cols,
+                    **kwargs,
+                )
+            else:
+                if first_chunk:
+                    # Write to a single parquet file.
+                    parquet_writer = pq.ParquetWriter(path, parquet_schema, **kwargs)
+                    first_chunk = False
+
+                assert parquet_writer
+                parquet_writer.write_table(table)
+
+        if parquet_writer:
+            parquet_writer.close()
+
+    def to_csv(
+        self,
+        path: Union[str, os.PathLike[str]],
+        delimiter: str = ",",
+        **kwargs,
+    ) -> None:
+        """Save chain to a csv (comma-separated values) file.
+
+        Parameters:
+            path : Path to save the file.
+            delimiter : Delimiter to use for the resulting file.
+        """
+        import csv
+
+        headers, _ = self._effective_signals_schema.get_headers_with_length()
+        column_names = [".".join(filter(None, header)) for header in headers]
+
+        results_iter = self.collect_flatten()
+
+        with open(path, "w", newline="") as f:
+            writer = csv.writer(f, delimiter=delimiter, **kwargs)
+            writer.writerow(column_names)
+
+            for row in results_iter:
+                writer.writerow(row)
+
     @classmethod
     def from_records(
         cls,
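Taken together, the new export methods allow a simple round trip; a sketch with illustrative values (reading the schema back via from_parquet relies on the metadata written above):

    from datachain.lib.dc import DataChain

    chain = DataChain.from_values(num=[1, 2, 3], name=["a", "b", "c"])

    # parquet export embeds the SignalSchema, so a re-read restores the signal types
    chain.to_parquet("chain.parquet")
    restored = DataChain.from_parquet("chain.parquet")

    # csv export writes a header of flattened signal names, then one row per record
    chain.to_csv("chain.csv", delimiter=";")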
datachain/lib/signal_schema.py
CHANGED

@@ -25,7 +25,7 @@ from typing_extensions import Literal as LiteralEx
 from datachain.lib.convert.python_to_sql import python_to_sql
 from datachain.lib.convert.sql_to_python import sql_to_python
 from datachain.lib.convert.unflatten import unflatten_to_json_pos
-from datachain.lib.data_model import DataModel, DataType
+from datachain.lib.data_model import DataModel, DataType, DataValue
 from datachain.lib.file import File
 from datachain.lib.model_store import ModelStore
 from datachain.lib.utils import DataChainParamsError
@@ -110,7 +110,7 @@ class SignalSchema:
     values: dict[str, DataType]
     tree: dict[str, Any]
     setup_func: dict[str, Callable]
-    setup_values: Optional[dict[str,
+    setup_values: Optional[dict[str, Any]]

     def __init__(
         self,
@@ -333,21 +333,21 @@
             res[db_name] = python_to_sql(type_)
         return res

-    def row_to_objs(self, row: Sequence[Any]) -> list[
+    def row_to_objs(self, row: Sequence[Any]) -> list[DataValue]:
         self._init_setup_values()

-        objs = []
+        objs: list[DataValue] = []
         pos = 0
         for name, fr_type in self.values.items():
             if self.setup_values and (val := self.setup_values.get(name, None)):
                 objs.append(val)
             elif (fr := ModelStore.to_pydantic(fr_type)) is not None:
                 j, pos = unflatten_to_json_pos(fr, row, pos)
-                objs.append(fr(**j))
+                objs.append(fr(**j))
             else:
                 objs.append(row[pos])
                 pos += 1
-        return objs
+        return objs

     def contains_file(self) -> bool:
         for type_ in self.values.values():