datachain 0.6.0__py3-none-any.whl → 0.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of datachain might be problematic.
- datachain/__init__.py +2 -0
- datachain/catalog/catalog.py +62 -228
- datachain/cli.py +136 -22
- datachain/client/fsspec.py +9 -0
- datachain/client/local.py +11 -32
- datachain/config.py +126 -51
- datachain/data_storage/schema.py +66 -33
- datachain/data_storage/sqlite.py +12 -4
- datachain/data_storage/warehouse.py +101 -129
- datachain/lib/convert/sql_to_python.py +8 -12
- datachain/lib/dc.py +275 -80
- datachain/lib/func/__init__.py +32 -0
- datachain/lib/func/aggregate.py +353 -0
- datachain/lib/func/func.py +152 -0
- datachain/lib/listing.py +6 -21
- datachain/lib/listing_info.py +4 -0
- datachain/lib/signal_schema.py +17 -8
- datachain/lib/udf.py +3 -3
- datachain/lib/utils.py +5 -0
- datachain/listing.py +22 -48
- datachain/query/__init__.py +1 -2
- datachain/query/batch.py +0 -1
- datachain/query/dataset.py +33 -46
- datachain/query/schema.py +1 -61
- datachain/query/session.py +33 -25
- datachain/remote/studio.py +63 -14
- datachain/sql/functions/__init__.py +1 -1
- datachain/sql/functions/aggregate.py +47 -0
- datachain/sql/functions/array.py +0 -8
- datachain/sql/sqlite/base.py +20 -2
- datachain/studio.py +129 -0
- datachain/utils.py +58 -0
- {datachain-0.6.0.dist-info → datachain-0.6.2.dist-info}/METADATA +7 -6
- {datachain-0.6.0.dist-info → datachain-0.6.2.dist-info}/RECORD +38 -33
- {datachain-0.6.0.dist-info → datachain-0.6.2.dist-info}/WHEEL +1 -1
- {datachain-0.6.0.dist-info → datachain-0.6.2.dist-info}/LICENSE +0 -0
- {datachain-0.6.0.dist-info → datachain-0.6.2.dist-info}/entry_points.txt +0 -0
- {datachain-0.6.0.dist-info → datachain-0.6.2.dist-info}/top_level.txt +0 -0
datachain/lib/dc.py
CHANGED
@@ -1,6 +1,8 @@
 import copy
 import os
+import os.path
 import re
+import sys
 from collections.abc import Iterator, Sequence
 from functools import wraps
 from typing import (
@@ -23,16 +25,17 @@ from pydantic import BaseModel
 from sqlalchemy.sql.functions import GenericFunction
 from sqlalchemy.sql.sqltypes import NullType
 
+from datachain.client import Client
+from datachain.client.local import FileClient
+from datachain.dataset import DatasetRecord
 from datachain.lib.convert.python_to_sql import python_to_sql
 from datachain.lib.convert.values_to_tuples import values_to_tuples
 from datachain.lib.data_model import DataModel, DataType, dict_to_data_model
 from datachain.lib.dataset_info import DatasetInfo
 from datachain.lib.file import ArrowRow, File, get_file_type
 from datachain.lib.file import ExportPlacement as FileExportPlacement
+from datachain.lib.func import Func
 from datachain.lib.listing import (
-    is_listing_dataset,
-    is_listing_expired,
-    is_listing_subset,
     list_bucket,
     ls,
     parse_listing_uri,
@@ -42,24 +45,15 @@ from datachain.lib.meta_formats import read_meta, read_schema
 from datachain.lib.model_store import ModelStore
 from datachain.lib.settings import Settings
 from datachain.lib.signal_schema import SignalSchema
-from datachain.lib.udf import (
-    Aggregator,
-    BatchMapper,
-    Generator,
-    Mapper,
-    UDFBase,
-)
+from datachain.lib.udf import Aggregator, BatchMapper, Generator, Mapper, UDFBase
 from datachain.lib.udf_signature import UdfSignature
-from datachain.lib.utils import DataChainParamsError
+from datachain.lib.utils import DataChainColumnError, DataChainParamsError
 from datachain.query import Session
-from datachain.query.dataset import (
-    DatasetQuery,
-    PartitionByType,
-)
-from datachain.query.schema import DEFAULT_DELIMITER, Column, DatasetRow
+from datachain.query.dataset import DatasetQuery, PartitionByType
+from datachain.query.schema import DEFAULT_DELIMITER, Column, ColumnMeta
 from datachain.sql.functions import path as pathfunc
 from datachain.telemetry import telemetry
-from datachain.utils import batched_it, inside_notebook
+from datachain.utils import batched_it, inside_notebook, row_to_nested_dict
 
 if TYPE_CHECKING:
     from pyarrow import DataType as ArrowDataType
@@ -149,11 +143,6 @@ class DatasetMergeError(DataChainParamsError):  # noqa: D101
         super().__init__(f"Merge error on='{on_str}'{right_on_str}: {msg}")
 
 
-class DataChainColumnError(DataChainParamsError):  # noqa: D101
-    def __init__(self, col_name, msg):  # noqa: D107
-        super().__init__(f"Error for column {col_name}: {msg}")
-
-
 OutputType = Union[None, DataType, Sequence[str], dict[str, DataType]]
 
 
@@ -300,6 +289,13 @@ class DataChain:
         """Version of the underlying dataset, if there is one."""
         return self._query.version
 
+    @property
+    def dataset(self) -> Optional[DatasetRecord]:
+        """Underlying dataset, if there is one."""
+        if not self.name:
+            return None
+        return self.session.catalog.get_dataset(self.name)
+
     def __or__(self, other: "Self") -> "Self":
         """Return `self.union(other)`."""
         return self.union(other)
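The new `dataset` property above exposes the underlying `DatasetRecord` when the chain is backed by a saved dataset. A minimal sketch (the dataset name is hypothetical):

```py
from datachain.lib.dc import DataChain

chain = DataChain.from_dataset("my-dataset")  # hypothetical saved dataset name
record = chain.dataset  # DatasetRecord, or None for an unsaved chain
if record is not None:
    print(record.name, record.latest_version)
```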
@@ -380,6 +376,47 @@ class DataChain:
         self.signals_schema |= signals_schema
         return self
 
+    @classmethod
+    def parse_uri(
+        cls, uri: str, session: Session, update: bool = False
+    ) -> tuple[str, str, str, bool]:
+        """Returns correct listing dataset name that must be used for saving listing
+        operation. It takes into account existing listings and reusability of those.
+        It also returns boolean saying if returned dataset name is reused / already
+        exists or not, and it returns correct listing path that should be used to find
+        rows based on uri.
+        """
+        catalog = session.catalog
+        cache = catalog.cache
+        client_config = catalog.client_config
+
+        client = Client.get_client(uri, cache, **client_config)
+        ds_name, list_uri, list_path = parse_listing_uri(uri, cache, client_config)
+        listing = None
+
+        listings = [
+            ls
+            for ls in catalog.listings()
+            if not ls.is_expired and ls.contains(ds_name)
+        ]
+
+        if listings:
+            if update:
+                # choosing the smallest possible one to minimize update time
+                listing = sorted(listings, key=lambda ls: len(ls.name))[0]
+            else:
+                # no need to update, choosing the most recent one
+                listing = sorted(listings, key=lambda ls: ls.created_at)[-1]
+
+        if isinstance(client, FileClient) and listing and listing.name != ds_name:
+            # For local file system we need to fix listing path / prefix
+            # if we are reusing existing listing
+            list_path = f'{ds_name.strip("/").removeprefix(listing.name)}/{list_path}'
+
+        ds_name = listing.name if listing else ds_name
+
+        return ds_name, list_uri, list_path, bool(listing)
+
     @classmethod
     def from_storage(
         cls,
@@ -414,25 +451,15 @@ class DataChain:
         file_type = get_file_type(type)
 
         client_config = {"anon": True} if anon else None
-
         session = Session.get(session, client_config=client_config, in_memory=in_memory)
+        cache = session.catalog.cache
+        client_config = session.catalog.client_config
 
-
-            uri, session
+        list_ds_name, list_uri, list_path, list_ds_exists = cls.parse_uri(
+            uri, session, update=update
         )
-        need_listing = True
 
-
-        if (
-            not is_listing_expired(ds.created_at)  # type: ignore[union-attr]
-            and is_listing_subset(ds.name, list_dataset_name)  # type: ignore[union-attr]
-            and not update
-        ):
-            need_listing = False
-            list_dataset_name = ds.name  # type: ignore[union-attr]
-
-        if need_listing:
-            # caching new listing to special listing dataset
+        if update or not list_ds_exists:
             (
                 cls.from_records(
                     DataChain.DEFAULT_FILE_RECORD,
@@ -441,17 +468,13 @@ class DataChain:
                     in_memory=in_memory,
                 )
                 .gen(
-                    list_bucket(
-                        list_uri,
-                        session.catalog.cache,
-                        client_config=session.catalog.client_config,
-                    ),
+                    list_bucket(list_uri, cache, client_config=client_config),
                     output={f"{object_name}": File},
                 )
-                .save(
+                .save(list_ds_name, listing=True)
             )
 
-        dc = cls.from_dataset(
+        dc = cls.from_dataset(list_ds_name, session=session, settings=settings)
         dc.signals_schema = dc.signals_schema.mutate({f"{object_name}": file_type})
 
         return ls(dc, list_path, recursive=recursive, object_name=object_name)
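With `parse_uri` and the reworked `from_storage` above, repeated reads of a URI can reuse a previously saved listing dataset instead of re-listing the storage. A hedged usage sketch (the bucket URI is a placeholder):

```py
from datachain.lib.dc import DataChain

# First call lists the bucket and saves the listing as a dataset.
dogs = DataChain.from_storage("s3://example-bucket/dogs/")

# A later call over the same prefix reuses that listing while it has not expired;
# update=True would force a fresh listing instead.
dogs_again = DataChain.from_storage("s3://example-bucket/dogs/", update=False)
```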
@@ -678,19 +701,11 @@ class DataChain:
         session = Session.get(session, in_memory=in_memory)
         catalog = kwargs.get("catalog") or session.catalog
 
-        listings = [
-            ListingInfo.from_models(d, v, j)
-            for d, v, j in catalog.list_datasets_versions(
-                include_listing=True, **kwargs
-            )
-            if is_listing_dataset(d.name)
-        ]
-
         return cls.from_values(
             session=session,
             in_memory=in_memory,
             output={object_name: ListingInfo},
-            **{object_name: listings},  # type: ignore[arg-type]
+            **{object_name: catalog.listings()},  # type: ignore[arg-type]
         )
 
     def print_json_schema(  # type: ignore[override]
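The simplified `listings()` above now just wraps `catalog.listings()`; it still yields one `ListingInfo` row per cached listing dataset, for example:

```py
from datachain.lib.dc import DataChain

DataChain.listings().show()  # one row per cached listing dataset
```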
@@ -982,10 +997,9 @@ class DataChain:
         row is left in the result set.
 
         Example:
-
-
-
-            ```
+            ```py
+            dc.distinct("file.parent", "file.name")
+            ```
         """
         return self._evolve(
             query=self._query.distinct(
@@ -1011,6 +1025,63 @@ class DataChain:
             query=self._query.select(*columns), signal_schema=new_schema
         )
 
+    def group_by(
+        self,
+        *,
+        partition_by: Union[str, Sequence[str]],
+        **kwargs: Func,
+    ) -> "Self":
+        """Group rows by specified set of signals and return new signals
+        with aggregated values.
+
+        The supported functions:
+            count(), sum(), avg(), min(), max(), any_value(), collect(), concat()
+
+        Example:
+            ```py
+            chain = chain.group_by(
+                cnt=func.count(),
+                partition_by=("file_source", "file_ext"),
+            )
+            ```
+        """
+        if isinstance(partition_by, str):
+            partition_by = [partition_by]
+        if not partition_by:
+            raise ValueError("At least one column should be provided for partition_by")
+
+        if not kwargs:
+            raise ValueError("At least one column should be provided for group_by")
+        for col_name, func in kwargs.items():
+            if not isinstance(func, Func):
+                raise DataChainColumnError(
+                    col_name,
+                    f"Column {col_name} has type {type(func)} but expected Func object",
+                )
+
+        partition_by_columns: list[Column] = []
+        signal_columns: list[Column] = []
+        schema_fields: dict[str, DataType] = {}
+
+        # validate partition_by columns and add them to the schema
+        for col_name in partition_by:
+            col_db_name = ColumnMeta.to_db_name(col_name)
+            col_type = self.signals_schema.get_column_type(col_db_name)
+            col = Column(col_db_name, python_to_sql(col_type))
+            partition_by_columns.append(col)
+            schema_fields[col_db_name] = col_type
+
+        # validate signal columns and add them to the schema
+        for col_name, func in kwargs.items():
+            col = func.get_column(self.signals_schema, label=col_name)
+            signal_columns.append(col)
+            schema_fields[col_name] = func.get_result_type(self.signals_schema)
+
+        return self._evolve(
+            query=self._query.group_by(signal_columns, partition_by_columns),
+            signal_schema=SignalSchema(schema_fields),
+        )
+
     def mutate(self, **kwargs) -> "Self":
         """Create new signals based on existing signals.
 
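Building on the docstring above, a hedged sketch of `group_by` combined with the `datachain.lib.func` helpers added in this release (the storage URI and signal names such as `file.parent` / `file.size` are illustrative):

```py
from datachain.lib import func
from datachain.lib.dc import DataChain

summary = DataChain.from_storage("s3://example-bucket/data/").group_by(
    cnt=func.count(),
    total_size=func.sum("file.size"),
    partition_by="file.parent",
)
summary.show()
```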
@@ -1029,13 +1100,22 @@ class DataChain:
             Filename: name(), parent(), file_stem(), file_ext()
             Array: length(), sip_hash_64(), euclidean_distance(),
                 cosine_distance()
+            Window: row_number(), rank(), dense_rank(), first()
 
         Example:
             ```py
             dc.mutate(
-
-
-
+                area=Column("image.height") * Column("image.width"),
+                extension=file_ext(Column("file.name")),
+                dist=cosine_distance(embedding_text, embedding_image)
+            )
+            ```
+
+        Window function example:
+            ```py
+            window = func.window(partition_by="file.parent", order_by="file.size")
+            dc.mutate(
+                row_number=func.row_number().over(window),
             )
             ```
 
@@ -1046,20 +1126,12 @@ class DataChain:
         Example:
             ```py
             dc.mutate(
-
+                newkey=Column("oldkey")
             )
             ```
         """
-        existing_columns = set(self.signals_schema.values.keys())
-        for col_name in kwargs:
-            if col_name in existing_columns:
-                raise DataChainColumnError(
-                    col_name,
-                    "Cannot modify existing column with mutate(). "
-                    "Use a different name for the new column.",
-                )
         for col_name, expr in kwargs.items():
-            if not isinstance(expr, Column) and isinstance(expr.type, NullType):
+            if not isinstance(expr, (Column, Func)) and isinstance(expr.type, NullType):
                 raise DataChainColumnError(
                     col_name, f"Cannot infer type with expression {expr}"
                 )
@@ -1071,6 +1143,9 @@ class DataChain:
                 # renaming existing column
                 for signal in schema.db_signals(name=value.name, as_columns=True):
                     mutated[signal.name.replace(value.name, name, 1)] = signal  # type: ignore[union-attr]
+            elif isinstance(value, Func):
+                # adding new signal
+                mutated[name] = value.get_column(schema)
             else:
                 # adding new signal
                 mutated[name] = value
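The new `elif isinstance(value, Func)` branch is what lets `mutate()` accept `Func` objects directly, including window functions. A sketch mirroring the docstring example, assuming an existing chain `dc` (signal names are illustrative):

```py
from datachain.lib import func

# rank files within each directory by size
window = func.window(partition_by="file.parent", order_by="file.size")
ranked = dc.mutate(row_number=func.row_number().over(window))
```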
@@ -1477,12 +1552,6 @@ class DataChain:
         fr_map = {col.lower(): df[col].tolist() for col in df.columns}
 
         for column in fr_map:
-            if column in DatasetRow.schema:
-                raise DatasetPrepareError(
-                    name,
-                    f"import from pandas error - column '{column}' conflicts with"
-                    " default schema",
-                )
             if not column.isidentifier():
                 raise DatasetPrepareError(
                     name,
|
|
|
1853
1922
|
path: Union[str, os.PathLike[str], BinaryIO],
|
|
1854
1923
|
partition_cols: Optional[Sequence[str]] = None,
|
|
1855
1924
|
chunk_size: int = DEFAULT_PARQUET_CHUNK_SIZE,
|
|
1925
|
+
fs_kwargs: Optional[dict[str, Any]] = None,
|
|
1856
1926
|
**kwargs,
|
|
1857
1927
|
) -> None:
|
|
1858
1928
|
"""Save chain to parquet file with SignalSchema metadata.
|
|
1859
1929
|
|
|
1860
1930
|
Parameters:
|
|
1861
|
-
path : Path or a file-like binary object to save the file.
|
|
1931
|
+
path : Path or a file-like binary object to save the file. This supports
|
|
1932
|
+
local paths as well as remote paths, such as s3:// or hf:// with fsspec.
|
|
1862
1933
|
partition_cols : Column names by which to partition the dataset.
|
|
1863
1934
|
chunk_size : The chunk size of results to read and convert to columnar
|
|
1864
1935
|
data, to avoid running out of memory.
|
|
1936
|
+
fs_kwargs : Optional kwargs to pass to the fsspec filesystem, used only for
|
|
1937
|
+
write, for fsspec-type URLs, such as s3:// or hf:// when
|
|
1938
|
+
provided as the destination path.
|
|
1865
1939
|
"""
|
|
1866
1940
|
import pyarrow as pa
|
|
1867
1941
|
import pyarrow.parquet as pq
|
|
1868
1942
|
|
|
1869
1943
|
from datachain.lib.arrow import DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY
|
|
1870
1944
|
|
|
1945
|
+
fsspec_fs = None
|
|
1946
|
+
|
|
1947
|
+
if isinstance(path, str) and "://" in path:
|
|
1948
|
+
from datachain.client.fsspec import Client
|
|
1949
|
+
|
|
1950
|
+
fs_kwargs = {
|
|
1951
|
+
**self._query.catalog.client_config,
|
|
1952
|
+
**(fs_kwargs or {}),
|
|
1953
|
+
}
|
|
1954
|
+
|
|
1955
|
+
client = Client.get_implementation(path)
|
|
1956
|
+
|
|
1957
|
+
if path.startswith("file://"):
|
|
1958
|
+
# pyarrow does not handle file:// uris, and needs a direct path instead.
|
|
1959
|
+
from urllib.parse import urlparse
|
|
1960
|
+
|
|
1961
|
+
path = urlparse(path).path
|
|
1962
|
+
if sys.platform == "win32":
|
|
1963
|
+
path = os.path.normpath(path.lstrip("/"))
|
|
1964
|
+
|
|
1965
|
+
fsspec_fs = client.create_fs(**fs_kwargs)
|
|
1966
|
+
|
|
1871
1967
|
_partition_cols = list(partition_cols) if partition_cols else None
|
|
1872
1968
|
signal_schema_metadata = orjson.dumps(
|
|
1873
1969
|
self._effective_signals_schema.serialize()
|
|
@@ -1902,12 +1998,15 @@ class DataChain:
                 table,
                 root_path=path,
                 partition_cols=_partition_cols,
+                filesystem=fsspec_fs,
                 **kwargs,
             )
         else:
             if first_chunk:
                 # Write to a single parquet file.
-                parquet_writer = pq.ParquetWriter(
+                parquet_writer = pq.ParquetWriter(
+                    path, parquet_schema, filesystem=fsspec_fs, **kwargs
+                )
                 first_chunk = False
 
             assert parquet_writer
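Because `to_parquet` now resolves an fsspec filesystem for URL-style destinations, filesystem options can be passed through `fs_kwargs`. A sketch, assuming an existing chain `dc` (the bucket and options are placeholders):

```py
dc.to_parquet(
    "s3://example-bucket/exports/files.parquet",
    fs_kwargs={"anon": False},  # placeholder fsspec/s3fs options
)
```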
@@ -1920,28 +2019,122 @@ class DataChain:
         self,
         path: Union[str, os.PathLike[str]],
         delimiter: str = ",",
+        fs_kwargs: Optional[dict[str, Any]] = None,
         **kwargs,
     ) -> None:
         """Save chain to a csv (comma-separated values) file.
 
         Parameters:
-            path : Path to save the file.
+            path : Path to save the file. This supports local paths as well as
+                remote paths, such as s3:// or hf:// with fsspec.
             delimiter : Delimiter to use for the resulting file.
+            fs_kwargs : Optional kwargs to pass to the fsspec filesystem, used only for
+                write, for fsspec-type URLs, such as s3:// or hf:// when
+                provided as the destination path.
         """
         import csv
 
+        opener = open
+
+        if isinstance(path, str) and "://" in path:
+            from datachain.client.fsspec import Client
+
+            fs_kwargs = {
+                **self._query.catalog.client_config,
+                **(fs_kwargs or {}),
+            }
+
+            client = Client.get_implementation(path)
+
+            fsspec_fs = client.create_fs(**fs_kwargs)
+
+            opener = fsspec_fs.open
+
         headers, _ = self._effective_signals_schema.get_headers_with_length()
         column_names = [".".join(filter(None, header)) for header in headers]
 
         results_iter = self.collect_flatten()
 
-        with open(path, "w", newline="") as f:
+        with opener(path, "w", newline="") as f:
             writer = csv.writer(f, delimiter=delimiter, **kwargs)
             writer.writerow(column_names)
 
             for row in results_iter:
                 writer.writerow(row)
 
+    def to_json(
+        self,
+        path: Union[str, os.PathLike[str]],
+        fs_kwargs: Optional[dict[str, Any]] = None,
+        include_outer_list: bool = True,
+    ) -> None:
+        """Save chain to a JSON file.
+
+        Parameters:
+            path : Path to save the file. This supports local paths as well as
+                remote paths, such as s3:// or hf:// with fsspec.
+            fs_kwargs : Optional kwargs to pass to the fsspec filesystem, used only for
+                write, for fsspec-type URLs, such as s3:// or hf:// when
+                provided as the destination path.
+            include_outer_list : Sets whether to include an outer list for all rows.
+                Setting this to True makes the file valid JSON, while False instead
+                writes in the JSON lines format.
+        """
+        opener = open
+
+        if isinstance(path, str) and "://" in path:
+            from datachain.client.fsspec import Client
+
+            fs_kwargs = {
+                **self._query.catalog.client_config,
+                **(fs_kwargs or {}),
+            }
+
+            client = Client.get_implementation(path)
+
+            fsspec_fs = client.create_fs(**fs_kwargs)
+
+            opener = fsspec_fs.open
+
+        headers, _ = self._effective_signals_schema.get_headers_with_length()
+        headers = [list(filter(None, header)) for header in headers]
+
+        is_first = True
+
+        with opener(path, "wb") as f:
+            if include_outer_list:
+                # This makes the file JSON instead of JSON lines.
+                f.write(b"[\n")
+            for row in self.collect_flatten():
+                if not is_first:
+                    if include_outer_list:
+                        # This makes the file JSON instead of JSON lines.
+                        f.write(b",\n")
+                    else:
+                        f.write(b"\n")
+                else:
+                    is_first = False
+                f.write(orjson.dumps(row_to_nested_dict(headers, row)))
+            if include_outer_list:
+                # This makes the file JSON instead of JSON lines.
+                f.write(b"\n]\n")
+
+    def to_jsonl(
+        self,
+        path: Union[str, os.PathLike[str]],
+        fs_kwargs: Optional[dict[str, Any]] = None,
+    ) -> None:
+        """Save chain to a JSON lines file.
+
+        Parameters:
+            path : Path to save the file. This supports local paths as well as
+                remote paths, such as s3:// or hf:// with fsspec.
+            fs_kwargs : Optional kwargs to pass to the fsspec filesystem, used only for
+                write, for fsspec-type URLs, such as s3:// or hf:// when
+                provided as the destination path.
+        """
+        self.to_json(path, fs_kwargs, include_outer_list=False)
+
     @classmethod
     def from_records(
         cls,
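The new `to_json` and `to_jsonl` writers follow the same pattern as `to_csv`, including fsspec destinations. A short sketch, assuming an existing chain `dc` (paths are placeholders):

```py
dc.to_csv("results.csv")                        # local CSV
dc.to_json("s3://example-bucket/results.json")  # valid JSON: rows wrapped in an outer list
dc.to_jsonl("results.jsonl")                    # JSON lines: one object per row
```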
@@ -1994,6 +2187,8 @@ class DataChain:
             ),
         )
 
+        session.add_dataset_version(dsr, dsr.latest_version)
+
         if isinstance(to_insert, dict):
             to_insert = [to_insert]
         elif not to_insert:
datachain/lib/func/__init__.py
ADDED
@@ -0,0 +1,32 @@
+from .aggregate import (
+    any_value,
+    avg,
+    collect,
+    concat,
+    count,
+    dense_rank,
+    first,
+    max,
+    min,
+    rank,
+    row_number,
+    sum,
+)
+from .func import Func, window
+
+__all__ = [
+    "Func",
+    "any_value",
+    "avg",
+    "collect",
+    "concat",
+    "count",
+    "dense_rank",
+    "first",
+    "max",
+    "min",
+    "rank",
+    "row_number",
+    "sum",
+    "window",
+]