datachain 0.4.0__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of datachain might be problematic.
- datachain/catalog/catalog.py +8 -0
- datachain/cli.py +3 -2
- datachain/data_storage/metastore.py +28 -9
- datachain/data_storage/sqlite.py +24 -32
- datachain/data_storage/warehouse.py +1 -3
- datachain/dataset.py +0 -3
- datachain/lib/arrow.py +64 -19
- datachain/lib/dc.py +310 -123
- datachain/lib/listing.py +5 -3
- datachain/lib/pytorch.py +5 -1
- datachain/lib/udf.py +100 -78
- datachain/lib/udf_signature.py +8 -6
- datachain/query/dataset.py +7 -7
- datachain/query/dispatch.py +2 -2
- datachain/query/session.py +42 -0
- {datachain-0.4.0.dist-info → datachain-0.5.1.dist-info}/METADATA +1 -1
- {datachain-0.4.0.dist-info → datachain-0.5.1.dist-info}/RECORD +21 -22
- datachain/query/udf.py +0 -126
- {datachain-0.4.0.dist-info → datachain-0.5.1.dist-info}/LICENSE +0 -0
- {datachain-0.4.0.dist-info → datachain-0.5.1.dist-info}/WHEEL +0 -0
- {datachain-0.4.0.dist-info → datachain-0.5.1.dist-info}/entry_points.txt +0 -0
- {datachain-0.4.0.dist-info → datachain-0.5.1.dist-info}/top_level.txt +0 -0
datachain/lib/dc.py
CHANGED
@@ -16,6 +16,7 @@ from typing import (
     overload,
 )
 
+import orjson
 import pandas as pd
 import sqlalchemy
 from pydantic import BaseModel
@@ -54,12 +55,11 @@ from datachain.query import Session
 from datachain.query.dataset import (
     DatasetQuery,
     PartitionByType,
-    detach,
 )
 from datachain.query.schema import DEFAULT_DELIMITER, Column, DatasetRow
 from datachain.sql.functions import path as pathfunc
 from datachain.telemetry import telemetry
-from datachain.utils import inside_notebook
+from datachain.utils import batched_it, inside_notebook
 
 if TYPE_CHECKING:
     from typing_extensions import Concatenate, ParamSpec, Self
@@ -72,6 +72,10 @@ C = Column
 
 _T = TypeVar("_T")
 D = TypeVar("D", bound="DataChain")
+UDFObjT = TypeVar("UDFObjT", bound=UDFBase)
+
+
+DEFAULT_PARQUET_CHUNK_SIZE = 100_000
 
 
 def resolve_columns(
@@ -159,7 +163,7 @@ class Sys(DataModel):
     rand: int
 
 
-class DataChain(DatasetQuery):
+class DataChain:
     """DataChain - a data structure for batch data processing and evaluation.
 
     It represents a sequence of data manipulation steps such as reading data from
@@ -238,33 +242,20 @@ class DataChain(DatasetQuery):
         "size": 0,
     }
 
-    def __init__(
[… 13 removed lines (old 242-254) are not rendered in this diff view …]
-        else:
-            self._settings = Settings()
-        self._setup: dict = {}
-
-        self.signals_schema = SignalSchema({"sys": Sys})
-        if self.feature_schema:
-            self.signals_schema |= SignalSchema.deserialize(self.feature_schema)
-        else:
-            self.signals_schema |= SignalSchema.from_column_types(
-                self.column_types or {}
-            )
-
-        self._sys = False
+    def __init__(
+        self,
+        query: DatasetQuery,
+        settings: Settings,
+        signal_schema: SignalSchema,
+        setup: Optional[dict] = None,
+        _sys: bool = False,
+    ) -> None:
+        """Don't instantiate this directly, use one of the from_XXX constructors."""
+        self._query = query
+        self._settings = settings
+        self.signals_schema = signal_schema
+        self._setup: dict = setup or {}
+        self._sys = _sys
 
     @property
     def schema(self) -> dict[str, DataType]:
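The refactored constructor above is no longer meant to be called directly: per its new docstring, chains are assembled by the `from_*` classmethods, which build the `DatasetQuery`, `Settings`, and `SignalSchema` and pass them in. A minimal sketch of what that looks like from user code (the dataset name is taken from the `from_dataset` docstring example further down; the settings values are illustrative):

```py
from datachain.lib.dc import DataChain

# Chains are created through the from_* constructors, never via __init__.
chain = DataChain.from_dataset("my_cats")

# Execution settings are layered on afterwards through the fluent API and are
# carried into every evolved chain.
chain = chain.settings(cache=True, parallel=4)
```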
@@ -290,18 +281,55 @@ class DataChain(DatasetQuery):
     def c(self, column: Union[str, Column]) -> Column:
         """Returns Column instance attached to the current chain."""
         c = self.column(column) if isinstance(column, str) else self.column(column.name)
-        c.table = self.table
+        c.table = self._query.table
         return c
 
+    @property
+    def session(self) -> Session:
+        """Session of the chain."""
+        return self._query.session
+
+    @property
+    def name(self) -> Optional[str]:
+        """Name of the underlying dataset, if there is one."""
+        return self._query.name
+
+    @property
+    def version(self) -> Optional[int]:
+        """Version of the underlying dataset, if there is one."""
+        return self._query.version
+
+    def __or__(self, other: "Self") -> "Self":
+        """Return `self.union(other)`."""
+        return self.union(other)
+
     def print_schema(self) -> None:
         """Print schema of the chain."""
         self._effective_signals_schema.print_tree()
 
-    def clone(self
+    def clone(self) -> "Self":
         """Make a copy of the chain in a new table."""
[… 3 removed lines (old 302-304) are not rendered in this diff view …]
+        return self._evolve(query=self._query.clone(new_table=True))
+
+    def _evolve(
+        self,
+        *,
+        query: Optional[DatasetQuery] = None,
+        settings: Optional[Settings] = None,
+        signal_schema=None,
+        _sys=None,
+    ) -> "Self":
+        if query is None:
+            query = self._query.clone(new_table=False)
+        if settings is None:
+            settings = self._settings
+        if signal_schema is None:
+            signal_schema = copy.deepcopy(self.signals_schema)
+        if _sys is None:
+            _sys = self._sys
+        return type(self)(
+            query, settings, signal_schema=signal_schema, setup=self._setup, _sys=_sys
+        )
 
     def settings(
         self,
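`_evolve` replaces the old subclass-and-mutate approach: every operation now wraps a (possibly new) `DatasetQuery` in a fresh `DataChain`, carrying over settings, signal schema, and setup unless they are explicitly overridden. A rough illustration of the resulting behaviour, assuming a chain built as in the earlier sketch; this is a sketch, not code from the package:

```py
# Operations return new chains; the original is never mutated.
top10 = chain.limit(10)
assert top10 is not chain

# Dataset-level metadata is reached through the new read-only properties,
# which simply delegate to the wrapped DatasetQuery.
print(chain.name, chain.version, chain.session)

# __or__ is sugar for union(), added later in this diff.
combined = chain | top10
```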
@@ -332,11 +360,11 @@ class DataChain(DatasetQuery):
         )
         ```
         """
[… 4 removed lines (old 335-338) are not rendered in this diff view …]
-        return
+        if sys is None:
+            sys = self._sys
+        settings = copy.copy(self._settings)
+        settings.add(Settings(cache, parallel, workers, min_task_size))
+        return self._evolve(settings=settings, _sys=sys)
 
     def reset_settings(self, settings: Optional[Settings] = None) -> "Self":
         """Reset all settings to default values."""
@@ -434,7 +462,7 @@ class DataChain(DatasetQuery):
         version: Optional[int] = None,
         session: Optional[Session] = None,
         settings: Optional[dict] = None,
-    ) -> "
+    ) -> "Self":
         """Get data from a saved Dataset. It returns the chain itself.
 
         Parameters:
@@ -446,7 +474,24 @@ class DataChain(DatasetQuery):
             chain = DataChain.from_dataset("my_cats")
             ```
         """
[… removed line (old 449) is not rendered in this diff view …]
+        query = DatasetQuery(
+            name=name,
+            version=version,
+            session=session,
+            indexing_column_types=File._datachain_column_types,
+        )
+        telemetry.send_event_once("class", "datachain_init", name=name, version=version)
+        if settings:
+            _settings = Settings(**settings)
+        else:
+            _settings = Settings()
+
+        signals_schema = SignalSchema({"sys": Sys})
+        if query.feature_schema:
+            signals_schema |= SignalSchema.deserialize(query.feature_schema)
+        else:
+            signals_schema |= SignalSchema.from_column_types(query.column_types or {})
+        return cls(query, _settings, signals_schema)
 
     @classmethod
     def from_json(
@@ -699,7 +744,11 @@ class DataChain(DatasetQuery):
             version : version of a dataset. Default - the last version that exist.
         """
         schema = self.signals_schema.clone_without_sys_signals().serialize()
-        return
+        return self._evolve(
+            query=self._query.save(
+                name=name, version=version, feature_schema=schema, **kwargs
+            )
+        )
 
     def apply(self, func, *args, **kwargs):
         """Apply any function to the chain.
@@ -765,16 +814,17 @@ class DataChain(DatasetQuery):
         """
         udf_obj = self._udf_to_obj(Mapper, func, params, output, signal_map)
 
[… 3 removed lines (old 768-770) are not rendered in this diff view …]
+        return self._evolve(
+            query=self._query.add_signals(
+                udf_obj.to_udf_wrapper(),
+                **self._settings.to_dict(),
+            ),
+            signal_schema=self.signals_schema | udf_obj.output,
         )
 
-        return chain.add_schema(udf_obj.output).reset_settings(self._settings)
-
     def gen(
         self,
-        func: Optional[Callable] = None,
+        func: Optional[Union[Callable, Generator]] = None,
         params: Union[None, str, Sequence[str]] = None,
         output: OutputType = None,
         **signal_map,
@@ -800,14 +850,14 @@ class DataChain(DatasetQuery):
         ```
         """
         udf_obj = self._udf_to_obj(Generator, func, params, output, signal_map)
[… removed line (old 803) is not rendered in this diff view …]
-            self
[… 2 removed lines (old 805-806) are not rendered in this diff view …]
+        return self._evolve(
+            query=self._query.generate(
+                udf_obj.to_udf_wrapper(),
+                **self._settings.to_dict(),
+            ),
+            signal_schema=udf_obj.output,
         )
 
-        return chain.reset_schema(udf_obj.output).reset_settings(self._settings)
-
     def agg(
         self,
         func: Optional[Callable] = None,
@@ -840,15 +890,15 @@ class DataChain(DatasetQuery):
         ```
         """
         udf_obj = self._udf_to_obj(Aggregator, func, params, output, signal_map)
[… removed line (old 843) is not rendered in this diff view …]
-            self
[… 3 removed lines (old 845-847) are not rendered in this diff view …]
+        return self._evolve(
+            query=self._query.generate(
+                udf_obj.to_udf_wrapper(),
+                partition_by=partition_by,
+                **self._settings.to_dict(),
+            ),
+            signal_schema=udf_obj.output,
         )
 
-        return chain.reset_schema(udf_obj.output).reset_settings(self._settings)
-
     def batch_map(
         self,
         func: Optional[Callable] = None,
@@ -876,22 +926,22 @@ class DataChain(DatasetQuery):
         ```
         """
         udf_obj = self._udf_to_obj(BatchMapper, func, params, output, signal_map)
[… removed line (old 879) is not rendered in this diff view …]
-            self
[… 2 removed lines (old 881-882) are not rendered in this diff view …]
+        return self._evolve(
+            query=self._query.add_signals(
+                udf_obj.to_udf_wrapper(batch),
+                **self._settings.to_dict(),
+            ),
+            signal_schema=self.signals_schema | udf_obj.output,
         )
 
-        return chain.add_schema(udf_obj.output).reset_settings(self._settings)
-
     def _udf_to_obj(
         self,
-        target_class: type[
-        func: Optional[Callable],
+        target_class: type[UDFObjT],
+        func: Optional[Union[Callable, UDFObjT]],
         params: Union[None, str, Sequence[str]],
         output: OutputType,
         signal_map,
-    ) ->
+    ) -> UDFObjT:
         is_generator = target_class.is_output_batched
         name = self.name or ""
 
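All four UDF entry points (`map`, `batch_map`, `gen`, `agg`) now share one shape: build the UDF object, hand it to the wrapped query via `add_signals` or `generate` together with the chain's settings, and return an evolved chain that either extends the signal schema (`map`, `batch_map`) or replaces it (`gen`, `agg`). A hedged usage sketch, assuming a chain whose rows carry a `file` signal; the signal name and lambda are illustrative only:

```py
# map() computes one new value per row, so the evolved chain's schema is the
# existing schema plus the UDF output.
with_len = chain.map(path_len=lambda file: len(file.path), output=int)

# The added signal shows up in the schema of the new chain only; the original
# chain is left untouched.
assert "path_len" in with_len.schema
assert "path_len" not in chain.schema
```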
@@ -907,17 +957,12 @@ class DataChain(DatasetQuery):
         return target_class._create(sign, params_schema)
 
     def _extend_to_data_model(self, method_name, *args, **kwargs):
[… removed line (old 910) is not rendered in this diff view …]
+        query_func = getattr(self._query, method_name)
 
         new_schema = self.signals_schema.resolve(*args)
         columns = [C(col) for col in new_schema.db_signals()]
[… removed line (old 914) is not rendered in this diff view …]
-        if isinstance(res, DataChain):
-            res.signals_schema = new_schema
-
-        return res
+        return query_func(*columns, **kwargs)
 
-    @detach
     @resolve_columns
     def order_by(self, *args, descending: bool = False) -> "Self":
         """Orders by specified set of signals.
@@ -928,9 +973,8 @@ class DataChain(DatasetQuery):
         if descending:
             args = tuple(sqlalchemy.desc(a) for a in args)
 
-        return
+        return self._evolve(query=self._query.order_by(*args))
 
-    @detach
     def distinct(self, arg: str, *args: str) -> "Self":  # type: ignore[override]
         """Removes duplicate rows based on uniqueness of some input column(s)
         i.e if rows are found with the same value of input column(s), only one
@@ -942,29 +986,30 @@ class DataChain(DatasetQuery):
             )
             ```
         """
-        return
+        return self._evolve(
+            query=self._query.distinct(
+                *self.signals_schema.resolve(arg, *args).db_signals()
+            )
+        )
 
-    @detach
     def select(self, *args: str, _sys: bool = True) -> "Self":
         """Select only a specified set of signals."""
         new_schema = self.signals_schema.resolve(*args)
         if _sys:
             new_schema = SignalSchema({"sys": Sys}) | new_schema
         columns = new_schema.db_signals()
[… 3 removed lines (old 954-956) are not rendered in this diff view …]
+        return self._evolve(
+            query=self._query.select(*columns), signal_schema=new_schema
+        )
 
-    @detach
     def select_except(self, *args: str) -> "Self":
         """Select all the signals expect the specified signals."""
         new_schema = self.signals_schema.select_except_signals(*args)
         columns = new_schema.db_signals()
[… 3 removed lines (old 963-965) are not rendered in this diff view …]
+        return self._evolve(
+            query=self._query.select(*columns), signal_schema=new_schema
+        )
 
-    @detach
     def mutate(self, **kwargs) -> "Self":
         """Create new signals based on existing signals.
 
@@ -1029,9 +1074,9 @@ class DataChain(DatasetQuery):
             # adding new signal
             mutated[name] = value
 
[… 3 removed lines (old 1032-1034) are not rendered in this diff view …]
+        return self._evolve(
+            query=self._query.mutate(**mutated), signal_schema=schema.mutate(kwargs)
+        )
 
     @property
     def _effective_signals_schema(self) -> "SignalSchema":
@@ -1058,11 +1103,34 @@ class DataChain(DatasetQuery):
         a tuple of row values.
         """
         db_signals = self._effective_signals_schema.db_signals()
-        with
+        with self._query.select(*db_signals).as_iterable() as rows:
             if row_factory:
                 rows = (row_factory(db_signals, r) for r in rows)
             yield from rows
 
+    def to_columnar_data_with_names(
+        self, chunk_size: int = DEFAULT_PARQUET_CHUNK_SIZE
+    ) -> tuple[list[str], Iterator[list[list[Any]]]]:
+        """Returns column names and the results as an iterator that provides chunks,
+        with each chunk containing a list of columns, where each column contains a
+        list of the row values for that column in that chunk. Useful for columnar data
+        formats, such as parquet or other OLAP databases.
+        """
+        headers, _ = self._effective_signals_schema.get_headers_with_length()
+        column_names = [".".join(filter(None, header)) for header in headers]
+
+        results_iter = self.collect_flatten()
+
+        def column_chunks() -> Iterator[list[list[Any]]]:
+            for chunk_iter in batched_it(results_iter, chunk_size):
+                columns: list[list[Any]] = [[] for _ in column_names]
+                for row in chunk_iter:
+                    for i, col in enumerate(columns):
+                        col.append(row[i])
+                yield columns
+
+        return column_names, column_chunks()
+
     @overload
     def results(self) -> list[tuple[Any, ...]]: ...
 
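`to_columnar_data_with_names` is the building block for the streaming parquet export added later in this diff: it flattens the effective signal schema into dotted column names and yields column-oriented chunks of at most `chunk_size` rows via the new `batched_it` helper. A sketch of consuming it directly, assuming `chain` is a DataChain built as in the earlier sketches:

```py
column_names, chunks = chain.to_columnar_data_with_names(chunk_size=10_000)

for chunk in chunks:
    # Each chunk is a list of columns; chunk[i] holds up to 10_000 values for
    # column_names[i], ready for a columnar writer such as pyarrow.
    for name, values in zip(column_names, chunk):
        print(name, len(values))
```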
@@ -1126,7 +1194,7 @@ class DataChain(DatasetQuery):
         chain = self.select(*cols) if cols else self
         signals_schema = chain._effective_signals_schema
         db_signals = signals_schema.db_signals()
-        with
+        with self._query.select(*db_signals).as_iterable() as rows:
             for row in rows:
                 ret = signals_schema.row_to_features(
                     row, catalog=chain.session.catalog, cache=chain._settings.cache
@@ -1156,7 +1224,7 @@ class DataChain(DatasetQuery):
         """
         from datachain.torch import PytorchDataset
 
-        if self.attached:
+        if self._query.attached:
             chain = self
         else:
             chain = self.save()
@@ -1164,7 +1232,7 @@ class DataChain(DatasetQuery):
         return PytorchDataset(
             chain.name,
             chain.version,
-            catalog=self.catalog,
+            catalog=self.session.catalog,
             transform=transform,
             tokenizer=tokenizer,
             tokenizer_kwargs=tokenizer_kwargs,
@@ -1175,7 +1243,6 @@ class DataChain(DatasetQuery):
         schema = self.signals_schema.clone_without_file_signals()
         return self.select(*schema.values.keys())
 
-    @detach
     def merge(
         self,
         right_ds: "DataChain",
@@ -1240,7 +1307,7 @@ class DataChain(DatasetQuery):
         )
 
         if self == right_ds:
-            right_ds = right_ds.clone(
+            right_ds = right_ds.clone()
 
         errors = []
 
@@ -1266,9 +1333,11 @@ class DataChain(DatasetQuery):
                 on, right_on, f"Could not resolve {', '.join(errors)}"
             )
 
[… 3 removed lines (old 1269-1271) are not rendered in this diff view …]
+        query = self._query.join(
+            right_ds._query, sqlalchemy.and_(*ops), inner, rname + "{name}"
+        )
+        query.feature_schema = None
+        ds = self._evolve(query=query)
 
         signals_schema = self.signals_schema.clone_without_sys_signals()
         right_signals_schema = right_ds.signals_schema.clone_without_sys_signals()
@@ -1278,6 +1347,14 @@ class DataChain(DatasetQuery):
 
         return ds
 
+    def union(self, other: "Self") -> "Self":
+        """Return the set union of the two datasets.
+
+        Parameters:
+            other: chain whose rows will be added to `self`.
+        """
+        return self._evolve(query=self._query.union(other._query))
+
     def subtract(  # type: ignore[override]
         self,
         other: "DataChain",
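`union` is new public API on the chain itself (previously it was inherited from `DatasetQuery`); together with the `__or__` operator added earlier in this diff it concatenates the rows of two chains over the evolved query. A small sketch; the dataset names are hypothetical:

```py
cats = DataChain.from_dataset("cats")   # hypothetical dataset names
dogs = DataChain.from_dataset("dogs")

# Both spellings return a new evolved chain containing rows from both inputs.
pets = cats.union(dogs)
pets_too = cats | dogs
assert pets.count() == pets_too.count()
```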
@@ -1341,7 +1418,7 @@ class DataChain(DatasetQuery):
                 other.signals_schema.resolve(*right_on).db_signals(),
             )  # type: ignore[arg-type]
         )
-        return
+        return self._evolve(query=self._query.subtract(other._query, signals))  # type: ignore[arg-type]
 
     @classmethod
     def from_values(
@@ -1449,7 +1526,7 @@ class DataChain(DatasetQuery):
             transpose : Whether to transpose rows and columns.
             truncate : Whether or not to truncate the contents of columns.
         """
-        dc = self.limit(limit) if limit > 0 else self
+        dc = self.limit(limit) if limit > 0 else self  # type: ignore[misc]
         df = dc.to_pandas(flatten)
 
         if df.empty:
@@ -1759,21 +1836,96 @@ class DataChain(DatasetQuery):
         self,
         path: Union[str, os.PathLike[str], BinaryIO],
         partition_cols: Optional[Sequence[str]] = None,
+        chunk_size: int = DEFAULT_PARQUET_CHUNK_SIZE,
         **kwargs,
     ) -> None:
-        """Save chain to parquet file.
+        """Save chain to parquet file with SignalSchema metadata.
 
         Parameters:
             path : Path or a file-like binary object to save the file.
             partition_cols : Column names by which to partition the dataset.
+            chunk_size : The chunk size of results to read and convert to columnar
+                data, to avoid running out of memory.
         """
+        import pyarrow as pa
+        import pyarrow.parquet as pq
+
+        from datachain.lib.arrow import DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY
+
         _partition_cols = list(partition_cols) if partition_cols else None
[… 2 removed lines (old 1771-1772) are not rendered in this diff view …]
-            partition_cols=_partition_cols,
-            **kwargs,
+        signal_schema_metadata = orjson.dumps(
+            self._effective_signals_schema.serialize()
         )
 
+        column_names, column_chunks = self.to_columnar_data_with_names(chunk_size)
+
+        parquet_schema = None
+        parquet_writer = None
+        first_chunk = True
+
+        for chunk in column_chunks:
+            # pyarrow infers the best parquet schema from the python types of
+            # the input data.
+            table = pa.Table.from_pydict(
+                dict(zip(column_names, chunk)),
+                schema=parquet_schema,
+            )
+
+            # Preserve any existing metadata, and add the DataChain SignalSchema.
+            existing_metadata = table.schema.metadata or {}
+            merged_metadata = {
+                **existing_metadata,
+                DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY: signal_schema_metadata,
+            }
+            table = table.replace_schema_metadata(merged_metadata)
+            parquet_schema = table.schema
+
+            if _partition_cols:
+                # Write to a partitioned parquet dataset.
+                pq.write_to_dataset(
+                    table,
+                    root_path=path,
+                    partition_cols=_partition_cols,
+                    **kwargs,
+                )
+            else:
+                if first_chunk:
+                    # Write to a single parquet file.
+                    parquet_writer = pq.ParquetWriter(path, parquet_schema, **kwargs)
+                    first_chunk = False
+
+                assert parquet_writer
+                parquet_writer.write_table(table)
+
+        if parquet_writer:
+            parquet_writer.close()
+
+    def to_csv(
+        self,
+        path: Union[str, os.PathLike[str]],
+        delimiter: str = ",",
+        **kwargs,
+    ) -> None:
+        """Save chain to a csv (comma-separated values) file.
+
+        Parameters:
+            path : Path to save the file.
+            delimiter : Delimiter to use for the resulting file.
+        """
+        import csv
+
+        headers, _ = self._effective_signals_schema.get_headers_with_length()
+        column_names = [".".join(filter(None, header)) for header in headers]
+
+        results_iter = self.collect_flatten()
+
+        with open(path, "w", newline="") as f:
+            writer = csv.writer(f, delimiter=delimiter, **kwargs)
+            writer.writerow(column_names)
+
+            for row in results_iter:
+                writer.writerow(row)
+
     @classmethod
     def from_records(
         cls,
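`to_parquet` now streams the chain in chunks through `to_columnar_data_with_names` and stamps the serialized `SignalSchema` into the parquet key/value metadata under `DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY`, while `to_csv` is a new plain-text export. A hedged sketch of both, assuming a chain `chain` as before; file names and the partition column are illustrative, not part of the package:

```py
# Single-file export: rows are read in 50_000-row chunks so the whole chain
# never has to fit in memory; the SignalSchema travels in the file metadata.
chain.to_parquet("cats.parquet", chunk_size=50_000)

# With partition_cols the chunks go through pq.write_to_dataset() instead,
# producing a directory-per-value parquet dataset.
chain.to_parquet("cats_by_breed", partition_cols=["breed"])

# CSV export: a header row of dotted column names, then one row per record.
chain.to_csv("cats.csv", delimiter=";")
```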
@@ -1782,7 +1934,7 @@ class DataChain(DatasetQuery):
         settings: Optional[dict] = None,
         in_memory: bool = False,
         schema: Optional[dict[str, DataType]] = None,
-    ) -> "
+    ) -> "Self":
         """Create a DataChain from the provided records. This method can be used for
         programmatically generating a chain in contrast of reading data from storages
         or other sources.
@@ -1837,7 +1989,7 @@ class DataChain(DatasetQuery):
         insert_q = dr.get_table().insert()
         for record in to_insert:
             db.execute(insert_q.values(**record))
-        return
+        return cls.from_dataset(name=dsr.name, session=session, settings=settings)
 
     def sum(self, fr: DataType):  # type: ignore[override]
         """Compute the sum of a column."""
@@ -1898,8 +2050,8 @@ class DataChain(DatasetQuery):
     ) -> None:
         """Method that exports all files from chain to some folder."""
         if placement == "filename" and (
[… removed line (old 1901) is not rendered in this diff view …]
-            != self.count()
+            self._query.distinct(pathfunc.name(C(f"{signal}__path"))).count()
+            != self._query.count()
         ):
             raise ValueError("Files with the same name found")
 
@@ -1919,10 +2071,9 @@ class DataChain(DatasetQuery):
         NOTE: Samples are not deterministic, and streamed/paginated queries or
         multiple workers will draw samples with replacement.
         """
-        return
+        return self._evolve(query=self._query.sample(n))
 
-
-    def filter(self, *args) -> "Self":
+    def filter(self, *args: Any) -> "Self":
         """Filter the chain according to conditions.
 
         Example:
@@ -1955,14 +2106,50 @@ class DataChain(DatasetQuery):
             )
             ```
         """
-        return
+        return self._evolve(query=self._query.filter(*args))
 
-    @detach
     def limit(self, n: int) -> "Self":
-        """Return the first n rows of the chain.
-
+        """Return the first `n` rows of the chain.
+
+        If the chain is unordered, which rows are returned is undefined.
+        If the chain has less than `n` rows, the whole chain is returned.
+
+        Parameters:
+            n (int): Number of rows to return.
+        """
+        return self._evolve(query=self._query.limit(n))
 
-    @detach
     def offset(self, offset: int) -> "Self":
-        """Return the results starting with the offset row.
-
+        """Return the results starting with the offset row.
+
+        If the chain is unordered, which rows are skipped in undefined.
+        If the chain has less than `offset` rows, the result is an empty chain.
+
+        Parameters:
+            offset (int): Number of rows to skip.
+        """
+        return self._evolve(query=self._query.offset(offset))
+
+    def count(self) -> int:
+        """Return the number of rows in the chain."""
+        return self._query.count()
+
+    def exec(self) -> "Self":
+        """Execute the chain."""
+        return self._evolve(query=self._query.exec())
+
+    def chunk(self, index: int, total: int) -> "Self":
+        """Split a chain into smaller chunks for e.g. parallelization.
+
+        Example:
+            ```py
+            chain = DataChain.from_storage(...)
+            chunk_1 = query._chunk(0, 2)
+            chunk_2 = query._chunk(1, 2)
+            ```
+
+        Note:
+            Bear in mind that `index` is 0-indexed but `total` isn't.
+            Use 0/3, 1/3 and 2/3, not 1/3, 2/3 and 3/3.
+        """
+        return self._evolve(query=self._query.chunk(index, total))