datachain 0.3.7__py3-none-any.whl → 0.3.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datachain/catalog/catalog.py +2 -92
- datachain/cli.py +0 -37
- datachain/lib/arrow.py +5 -5
- datachain/lib/clip.py +14 -3
- datachain/lib/convert/python_to_sql.py +9 -0
- datachain/lib/data_model.py +10 -1
- datachain/lib/dc.py +135 -39
- datachain/lib/hf.py +166 -0
- datachain/lib/image.py +9 -1
- datachain/lib/pytorch.py +1 -2
- datachain/lib/signal_schema.py +124 -20
- datachain/lib/text.py +4 -0
- datachain/lib/udf.py +14 -20
- datachain/lib/webdataset.py +1 -1
- datachain/query/dataset.py +24 -9
- datachain/query/session.py +5 -3
- {datachain-0.3.7.dist-info → datachain-0.3.9.dist-info}/METADATA +19 -15
- {datachain-0.3.7.dist-info → datachain-0.3.9.dist-info}/RECORD +22 -21
- {datachain-0.3.7.dist-info → datachain-0.3.9.dist-info}/WHEEL +1 -1
- {datachain-0.3.7.dist-info → datachain-0.3.9.dist-info}/LICENSE +0 -0
- {datachain-0.3.7.dist-info → datachain-0.3.9.dist-info}/entry_points.txt +0 -0
- {datachain-0.3.7.dist-info → datachain-0.3.9.dist-info}/top_level.txt +0 -0
datachain/catalog/catalog.py
CHANGED
@@ -1540,87 +1540,6 @@ class Catalog:
         dataset = self.get_dataset(name)
         return self.update_dataset(dataset, **update_data)

-    def merge_datasets(
-        self,
-        src: DatasetRecord,
-        dst: DatasetRecord,
-        src_version: int,
-        dst_version: Optional[int] = None,
-    ) -> DatasetRecord:
-        """
-        Merges records from source to destination dataset.
-        It will create new version
-        of a dataset with records merged from old version and the source, unless
-        existing version is specified for destination in which case it must
-        be in non final status as datasets are immutable
-        """
-        if (
-            dst_version
-            and not dst.is_valid_next_version(dst_version)
-            and dst.get_version(dst_version).is_final_status()
-        ):
-            raise DatasetInvalidVersionError(
-                f"Version {dst_version} must be higher than the current latest one"
-            )
-
-        src_dep = self.get_dataset_dependencies(src.name, src_version)
-        dst_dep = self.get_dataset_dependencies(
-            dst.name,
-            dst.latest_version,  # type: ignore[arg-type]
-        )
-
-        if dst.has_version(dst_version):  # type: ignore[arg-type]
-            # case where we don't create new version, but append to the existing one
-            self.warehouse.merge_dataset_rows(
-                src,
-                dst,
-                src_version,
-                dst_version=dst_version,  # type: ignore[arg-type]
-            )
-            merged_schema = src.serialized_schema | dst.serialized_schema
-            self.update_dataset(dst, schema=merged_schema)
-            self.update_dataset_version_with_warehouse_info(
-                dst,
-                dst_version,  # type: ignore[arg-type]
-                schema=merged_schema,
-            )
-            for dep in src_dep:
-                if dep and dep not in dst_dep:
-                    self.metastore.add_dependency(
-                        dep,
-                        dst.name,
-                        dst_version,  # type: ignore[arg-type]
-                    )
-        else:
-            # case where we create new version of merged results
-            src_dr = self.warehouse.dataset_rows(src, src_version)
-            dst_dr = self.warehouse.dataset_rows(dst)
-
-            merge_result_columns = list(
-                {
-                    c.name: c for c in list(src_dr.table.c) + list(dst_dr.table.c)
-                }.values()
-            )
-
-            dst_version = dst_version or dst.next_version
-            dst = self.create_new_dataset_version(
-                dst,
-                dst_version,
-                columns=merge_result_columns,
-            )
-            self.warehouse.merge_dataset_rows(
-                src,
-                dst,
-                src_version,
-                dst_version,
-            )
-            self.update_dataset_version_with_warehouse_info(dst, dst_version)
-            for dep in set(src_dep + dst_dep):
-                if dep:
-                    self.metastore.add_dependency(dep, dst.name, dst_version)
-
-        return dst
-
     def get_file_signals(
         self, dataset_name: str, dataset_version: int, row: RowDict
     ) -> Optional[dict]:
@@ -1641,17 +1560,8 @@ class Catalog:
         version = self.get_dataset(dataset_name).get_version(dataset_version)

         file_signals_values = {}
-        file_schemas = {}
-        # TODO: To remove after we properly fix deserialization
-        for signal, type_name in version.feature_schema.items():
-            from datachain.lib.model_store import ModelStore
-
-            type_name_parsed, v = ModelStore.parse_name_version(type_name)
-            fr = ModelStore.get(type_name_parsed, v)
-            if fr and issubclass(fr, File):
-                file_schemas[signal] = type_name

-        schema = SignalSchema.deserialize(
+        schema = SignalSchema.deserialize(version.feature_schema)
         for file_signals in schema.get_signals(File):
             prefix = file_signals.replace(".", DEFAULT_DELIMITER) + DEFAULT_DELIMITER
             file_signals_values[file_signals] = {
@@ -1997,7 +1907,7 @@ class Catalog:
         """
         from datachain.query.dataset import ExecutionResult

-        feature_file = tempfile.NamedTemporaryFile(
+        feature_file = tempfile.NamedTemporaryFile(  # noqa: SIM115
            dir=os.getcwd(), suffix=".py", delete=False
        )
        _, feature_module = os.path.split(feature_file.name)
datachain/cli.py
CHANGED
@@ -336,36 +336,6 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
         help="Display size using powers of 1000 not 1024",
     )

-    parse_merge_datasets = subp.add_parser(
-        "merge-datasets", parents=[parent_parser], description="Merges datasets"
-    )
-    parse_merge_datasets.add_argument(
-        "--src",
-        action="store",
-        default=None,
-        help="Source dataset name",
-    )
-    parse_merge_datasets.add_argument(
-        "--dst",
-        action="store",
-        default=None,
-        help="Destination dataset name",
-    )
-    parse_merge_datasets.add_argument(
-        "--src-version",
-        action="store",
-        default=None,
-        type=int,
-        help="Source dataset version",
-    )
-    parse_merge_datasets.add_argument(
-        "--dst-version",
-        action="store",
-        default=None,
-        type=int,
-        help="Destination dataset version",
-    )
-
     parse_ls = subp.add_parser(
         "ls", parents=[parent_parser], description="List storage contents"
     )
@@ -996,13 +966,6 @@ def main(argv: Optional[list[str]] = None) -> int:  # noqa: C901, PLR0912, PLR09
             new_name=args.new_name,
             labels=args.labels,
         )
-    elif args.command == "merge-datasets":
-        catalog.merge_datasets(
-            catalog.get_dataset(args.src),
-            catalog.get_dataset(args.dst),
-            args.src_version,
-            dst_version=args.dst_version,
-        )
     elif args.command == "ls":
         ls(
             args.sources,
datachain/lib/arrow.py
CHANGED
@@ -95,7 +95,7 @@ def schema_to_output(schema: pa.Schema, col_names: Optional[Sequence[str]] = Non
         if not column:
             column = f"c{default_column}"
             default_column += 1
-        dtype =
+        dtype = arrow_type_mapper(field.type)  # type: ignore[assignment]
         if field.nullable:
             dtype = Optional[dtype]  # type: ignore[assignment]
         output[column] = dtype
@@ -103,7 +103,7 @@ def schema_to_output(schema: pa.Schema, col_names: Optional[Sequence[str]] = Non
     return output


-def
+def arrow_type_mapper(col_type: pa.DataType) -> type:  # noqa: PLR0911
     """Convert pyarrow types to basic types."""
     from datetime import datetime

@@ -122,16 +122,16 @@ def _arrow_type_mapper(col_type: pa.DataType) -> type:  # noqa: PLR0911
     if pa.types.is_string(col_type) or pa.types.is_large_string(col_type):
         return str
     if pa.types.is_list(col_type):
-        return list[
+        return list[arrow_type_mapper(col_type.value_type)]  # type: ignore[return-value, misc]
     if pa.types.is_struct(col_type) or pa.types.is_map(col_type):
         return dict
     if isinstance(col_type, pa.lib.DictionaryType):
-        return
+        return arrow_type_mapper(col_type.value_type)  # type: ignore[return-value]
     raise TypeError(f"{col_type!r} datatypes not supported")


 def _nrows_file(file: File, nrows: int) -> str:
-    tf = NamedTemporaryFile(delete=False)
+    tf = NamedTemporaryFile(delete=False)  # noqa: SIM115
     with file.open(mode="r") as reader:
         with open(tf.name, "a") as writer:
             for row, line in enumerate(reader):
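For reference, a minimal sketch of how the now-public `arrow_type_mapper` could be used on its own to translate a pyarrow schema into Python annotations, mirroring the `schema_to_output` loop in the hunk above. The example schema and field names are illustrative, and only types visible in this diff (string, list, struct) are exercised:

```py
# Sketch only: map pyarrow field types to Python annotations.
from typing import Optional

import pyarrow as pa

from datachain.lib.arrow import arrow_type_mapper  # made public in this release

schema = pa.schema(
    [
        pa.field("name", pa.string()),
        pa.field("tags", pa.list_(pa.string())),
        pa.field("meta", pa.struct([("width", pa.int32())])),
    ]
)

for field in schema:
    dtype = arrow_type_mapper(field.type)  # str, list[str], dict for the fields above
    if field.nullable:
        dtype = Optional[dtype]  # same nullable handling as schema_to_output
    print(field.name, dtype)
```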
datachain/lib/clip.py
CHANGED
@@ -1,5 +1,5 @@
 import inspect
-from typing import TYPE_CHECKING, Any, Callable, Literal, Union
+from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union

 import torch
 from transformers.modeling_utils import PreTrainedModel
@@ -39,6 +39,7 @@ def clip_similarity_scores(
     tokenizer: Callable,
     prob: bool = False,
     image_to_text: bool = True,
+    device: Optional[Union[str, torch.device]] = None,
 ) -> list[list[float]]:
     """
     Calculate CLIP similarity scores between one or more images and/or text.
@@ -52,6 +53,7 @@ def clip_similarity_scores(
         prob : Compute softmax probabilities.
         image_to_text : Whether to compute for image-to-text or text-to-image. Ignored
             if only one of images or text provided.
+        device : Device to use. Defaults is None - use model's device.


     Example:
@@ -130,17 +132,26 @@ def clip_similarity_scores(
         ```
     """

+    if device is None:
+        if hasattr(model, "device"):
+            device = model.device
+        else:
+            device = next(model.parameters()).device
+    else:
+        model = model.to(device)
     with torch.no_grad():
         if images is not None:
             encoder = _get_encoder(model, "image")
             image_features = convert_images(
-                images, transform=preprocess, encoder=encoder
+                images, transform=preprocess, encoder=encoder, device=device
             )
             image_features /= image_features.norm(dim=-1, keepdim=True)  # type: ignore[union-attr]

         if text is not None:
             encoder = _get_encoder(model, "text")
-            text_features = convert_text(
+            text_features = convert_text(
+                text, tokenizer, encoder=encoder, device=device
+            )
             text_features /= text_features.norm(dim=-1, keepdim=True)  # type: ignore[union-attr]

         if images is not None and text is not None:
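A hedged sketch of the new `device` argument added to `clip_similarity_scores` above. The keyword names (`images`, `text`, `model`, `preprocess`, `tokenizer`) appear in the hunks, but the open_clip model setup is an assumption and not part of this diff:

```py
# Sketch only: passing the new `device` argument explicitly (new in 0.3.9).
import open_clip
import torch
from PIL import Image

from datachain.lib.clip import clip_similarity_scores

# Assumed setup; any CLIP-style model/preprocess/tokenizer trio should work.
model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="laion2b_s34b_b79k"
)
tokenizer = open_clip.get_tokenizer("ViT-B-32")

scores = clip_similarity_scores(
    images=Image.open("cat.jpg"),
    text="a photo of a cat",
    model=model,
    preprocess=preprocess,
    tokenizer=tokenizer,
    prob=True,
    # When omitted, the model's own device is used, per the added code above.
    device="cuda" if torch.cuda.is_available() else "cpu",
)
print(scores)
```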
datachain/lib/convert/python_to_sql.py
CHANGED
@@ -73,6 +73,9 @@ def python_to_sql(typ):  # noqa: PLR0911
     if len(args) == 2 and (type(None) in args):
         return python_to_sql(args[0])

+    if _is_union_str_literal(orig, args):
+        return String
+
     if _is_json_inside_union(orig, args):
         return JSON

@@ -94,3 +97,9 @@ def _is_json_inside_union(orig, args) -> bool:
     if any(inspect.isclass(arg) and issubclass(arg, BaseModel) for arg in args):
         return True
     return False
+
+
+def _is_union_str_literal(orig, args) -> bool:
+    if orig != Union:
+        return False
+    return all(arg is str or get_origin(arg) in (Literal, LiteralEx) for arg in args)
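Based on the added helper, a union of `str` and `Literal` members should now map to a SQL `String` column instead of falling through to the JSON handling. A minimal sketch; the printed return values are expectations inferred from this hunk, not verified output:

```py
# Sketch only: illustrates the new Union[str, Literal[...]] handling.
from typing import Literal, Optional, Union

from datachain.lib.convert.python_to_sql import python_to_sql

# A union of str and Literal values is now treated as a plain string column.
print(python_to_sql(Union[str, Literal["train", "test"]]))  # expected: String

# Optional[str] still unwraps to the underlying type first.
print(python_to_sql(Optional[str]))
```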
datachain/lib/data_model.py
CHANGED
@@ -2,7 +2,7 @@ from collections.abc import Sequence
 from datetime import datetime
 from typing import ClassVar, Union, get_args, get_origin

-from pydantic import BaseModel
+from pydantic import BaseModel, create_model

 from datachain.lib.model_store import ModelStore

@@ -57,3 +57,12 @@ def is_chain_type(t: type) -> bool:
         return is_chain_type(args[0])

     return False
+
+
+def dict_to_data_model(name: str, data_dict: dict[str, DataType]) -> type[BaseModel]:
+    fields = {name: (anno, ...) for name, anno in data_dict.items()}
+    return create_model(
+        name,
+        __base__=(DataModel,),  # type: ignore[call-overload]
+        **fields,
+    )  # type: ignore[call-overload]
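`dict_to_data_model` (the public replacement for `DataChain._dict_to_data_model`, removed later in this diff) builds a `DataModel` subclass from a plain name-to-type mapping. A small sketch; the model name and fields are illustrative:

```py
# Sketch only: building a pydantic-based DataModel dynamically from a dict.
from datachain.lib.data_model import dict_to_data_model

Review = dict_to_data_model(
    "review",
    {"text": str, "stars": int, "tags": list[str]},
)

row = Review(text="great", stars=5, tags=["short"])
print(row.stars)  # 5
```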
datachain/lib/dc.py
CHANGED
@@ -18,14 +18,13 @@ from typing import (

 import pandas as pd
 import sqlalchemy
-from pydantic import BaseModel
+from pydantic import BaseModel
 from sqlalchemy.sql.functions import GenericFunction
 from sqlalchemy.sql.sqltypes import NullType

-from datachain import DataModel
 from datachain.lib.convert.python_to_sql import python_to_sql
 from datachain.lib.convert.values_to_tuples import values_to_tuples
-from datachain.lib.data_model import DataType
+from datachain.lib.data_model import DataModel, DataType, dict_to_data_model
 from datachain.lib.dataset_info import DatasetInfo
 from datachain.lib.file import ExportPlacement as FileExportPlacement
 from datachain.lib.file import File, IndexedFile, get_file
@@ -55,6 +54,8 @@ from datachain.utils import inside_notebook
 if TYPE_CHECKING:
     from typing_extensions import Concatenate, ParamSpec, Self

+    from datachain.lib.hf import HFDatasetType
+
     P = ParamSpec("P")

 C = Column
@@ -77,12 +78,12 @@ def resolve_columns(
     @wraps(method)
     def _inner(self: D, *args: "P.args", **kwargs: "P.kwargs") -> D:
         resolved_args = self.signals_schema.resolve(
-            *[arg for arg in args if not isinstance(arg, GenericFunction)]
+            *[arg for arg in args if not isinstance(arg, GenericFunction)]  # type: ignore[arg-type]
         ).db_signals()

         for idx, arg in enumerate(args):
             if isinstance(arg, GenericFunction):
-                resolved_args.insert(idx, arg)
+                resolved_args.insert(idx, arg)  # type: ignore[arg-type]

         return method(self, *resolved_args, **kwargs)

@@ -208,23 +209,28 @@ class DataChain(DatasetQuery):
         "size": 0,
     }

-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, settings: Optional[dict] = None, **kwargs):
         """This method needs to be redefined as a part of Dataset and DataChain
         decoupling.
         """
-        super().__init__(
+        super().__init__(  # type: ignore[misc]
             *args,
             **kwargs,
             indexing_column_types=File._datachain_column_types,
         )
-
-
+        if settings:
+            self._settings = Settings(**settings)
+        else:
+            self._settings = Settings()
+        self._setup: dict = {}

         self.signals_schema = SignalSchema({"sys": Sys})
         if self.feature_schema:
             self.signals_schema |= SignalSchema.deserialize(self.feature_schema)
         else:
-            self.signals_schema |= SignalSchema.from_column_types(
+            self.signals_schema |= SignalSchema.from_column_types(
+                self.column_types or {}
+            )

         self._sys = False

@@ -309,6 +315,7 @@ class DataChain(DatasetQuery):
         *,
         type: Literal["binary", "text", "image"] = "binary",
         session: Optional[Session] = None,
+        settings: Optional[dict] = None,
         in_memory: bool = False,
         recursive: Optional[bool] = True,
         object_name: str = "file",
@@ -336,6 +343,7 @@ class DataChain(DatasetQuery):
         cls(
             path,
             session=session,
+            settings=settings,
             recursive=recursive,
             update=update,
             in_memory=in_memory,
@@ -489,6 +497,7 @@ class DataChain(DatasetQuery):
     def datasets(
         cls,
         session: Optional[Session] = None,
+        settings: Optional[dict] = None,
         in_memory: bool = False,
         object_name: str = "dataset",
     ) -> "DataChain":
@@ -513,6 +522,7 @@ class DataChain(DatasetQuery):

         return cls.from_values(
             session=session,
+            settings=settings,
             in_memory=in_memory,
             output={object_name: DatasetInfo},
             **{object_name: datasets},  # type: ignore[arg-type]
@@ -895,7 +905,7 @@ class DataChain(DatasetQuery):
         if isinstance(value, Column):
             # renaming existing column
             for signal in schema.db_signals(name=value.name, as_columns=True):
-                mutated[signal.name.replace(value.name, name, 1)] = signal
+                mutated[signal.name.replace(value.name, name, 1)] = signal  # type: ignore[union-attr]
         else:
             # adding new signal
             mutated[name] = value
@@ -1086,7 +1096,7 @@ class DataChain(DatasetQuery):
         )

         signals_schema = self.signals_schema.clone_without_sys_signals()
-        on_columns = signals_schema.resolve(*on).db_signals()
+        on_columns: list[str] = signals_schema.resolve(*on).db_signals()  # type: ignore[assignment]

         right_signals_schema = right_ds.signals_schema.clone_without_sys_signals()
         if right_on is not None:
@@ -1105,7 +1115,9 @@ class DataChain(DatasetQuery):
                 on, right_on, "'on' and 'right_on' must have the same length'"
             )

-        right_on_columns = right_signals_schema.resolve(
+        right_on_columns: list[str] = right_signals_schema.resolve(
+            *right_on
+        ).db_signals()  # type: ignore[assignment]

         if len(right_on_columns) != len(on_columns):
             on_str = ", ".join(right_on_columns)
@@ -1141,17 +1153,35 @@ class DataChain(DatasetQuery):
         self,
         other: "DataChain",
         on: Optional[Union[str, Sequence[str]]] = None,
+        right_on: Optional[Union[str, Sequence[str]]] = None,
     ) -> "Self":
         """Remove rows that appear in another chain.

         Parameters:
             other: chain whose rows will be removed from `self`
-            on: columns to consider for determining row equality
-                defaults to all common columns
+            on: columns to consider for determining row equality in `self`.
+                If unspecified, defaults to all common columns
+                between `self` and `other`.
+            right_on: columns to consider for determining row equality in `other`.
+                If unspecified, defaults to the same values as `on`.
         """
         if isinstance(on, str):
+            if not on:
+                raise DataChainParamsError("'on' cannot be an empty string")
             on = [on]
-
+        elif isinstance(on, Sequence):
+            if not on or any(not col for col in on):
+                raise DataChainParamsError("'on' cannot contain empty strings")
+
+        if isinstance(right_on, str):
+            if not right_on:
+                raise DataChainParamsError("'right_on' cannot be an empty string")
+            right_on = [right_on]
+        elif isinstance(right_on, Sequence):
+            if not right_on or any(not col for col in right_on):
+                raise DataChainParamsError("'right_on' cannot contain empty strings")
+
+        if on is None and right_on is None:
             other_columns = set(other._effective_signals_schema.db_signals())
             signals = [
                 c
@@ -1160,16 +1190,29 @@ class DataChain(DatasetQuery):
             ]
             if not signals:
                 raise DataChainParamsError("subtract(): no common columns")
-        elif not
-
-
-
-        elif not on:
+        elif on is not None and right_on is None:
+            right_on = on
+            signals = list(self.signals_schema.resolve(*on).db_signals())
+        elif on is None and right_on is not None:
             raise DataChainParamsError(
-                "'on'
+                "'on' must be specified when 'right_on' is provided"
             )
         else:
-
+            if not isinstance(on, Sequence) or not isinstance(right_on, Sequence):
+                raise TypeError(
+                    "'on' and 'right_on' must be 'str' or 'Sequence' object"
+                )
+            if len(on) != len(right_on):
+                raise DataChainParamsError(
+                    "'on' and 'right_on' must have the same length"
+                )
+            signals = list(
+                zip(
+                    self.signals_schema.resolve(*on).db_signals(),
+                    other.signals_schema.resolve(*right_on).db_signals(),
+                )  # type: ignore[arg-type]
+            )
+
         return super()._subtract(other, signals)  # type: ignore[arg-type]

     @classmethod
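The reworked `subtract` above now takes a separate `right_on` for columns in the other chain and rejects empty column names. A hedged usage sketch; `from_dataset` and the dataset/column names here are illustrative, not part of this diff:

```py
# Sketch only: subtract rows matched on differently named columns (new right_on).
from datachain.lib.dc import DataChain

processed = DataChain.from_dataset("processed_images")  # assumed to have "file_id"
incoming = DataChain.from_dataset("incoming_images")    # assumed to have "id"

# Keep only incoming rows whose "id" does not appear as "file_id" in processed.
remaining = incoming.subtract(processed, on="id", right_on="file_id")
```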
@@ -1177,6 +1220,7 @@ class DataChain(DatasetQuery):
         cls,
         ds_name: str = "",
         session: Optional[Session] = None,
+        settings: Optional[dict] = None,
         in_memory: bool = False,
         output: OutputType = None,
         object_name: str = "",
@@ -1195,10 +1239,13 @@ class DataChain(DatasetQuery):
             yield from tuples

         chain = DataChain.from_records(
-            DataChain.DEFAULT_FILE_RECORD,
+            DataChain.DEFAULT_FILE_RECORD,
+            session=session,
+            settings=settings,
+            in_memory=in_memory,
         )
         if object_name:
-            output = {object_name:
+            output = {object_name: dict_to_data_model(object_name, output)}  # type: ignore[arg-type]
         return chain.gen(_func_fr, output=output)

     @classmethod
@@ -1207,6 +1254,7 @@ class DataChain(DatasetQuery):
         df: "pd.DataFrame",
         name: str = "",
         session: Optional[Session] = None,
+        settings: Optional[dict] = None,
         in_memory: bool = False,
         object_name: str = "",
     ) -> "DataChain":
@@ -1236,7 +1284,12 @@ class DataChain(DatasetQuery):
         )

         return cls.from_values(
-            name,
+            name,
+            session,
+            settings=settings,
+            object_name=object_name,
+            in_memory=in_memory,
+            **fr_map,
         )

     def to_pandas(self, flatten=False) -> "pd.DataFrame":
@@ -1306,6 +1359,59 @@ class DataChain(DatasetQuery):
         if len(df) == limit:
             print(f"\n[Limited by {len(df)} rows]")

+    @classmethod
+    def from_hf(
+        cls,
+        dataset: Union[str, "HFDatasetType"],
+        *args,
+        session: Optional[Session] = None,
+        settings: Optional[dict] = None,
+        object_name: str = "",
+        model_name: str = "",
+        **kwargs,
+    ) -> "DataChain":
+        """Generate chain from huggingface hub dataset.
+
+        Parameters:
+            dataset : Path or name of the dataset to read from Hugging Face Hub,
+                or an instance of `datasets.Dataset`-like object.
+            session : Session to use for the chain.
+            settings : Settings to use for the chain.
+            object_name : Generated object column name.
+            model_name : Generated model name.
+            kwargs : Parameters to pass to datasets.load_dataset.
+
+        Example:
+            Load from Hugging Face Hub:
+            ```py
+            DataChain.from_hf("beans", split="train")
+            ```
+
+            Generate chain from loaded dataset:
+            ```py
+            from datasets import load_dataset
+            ds = load_dataset("beans", split="train")
+            DataChain.from_hf(ds)
+            ```
+        """
+        from datachain.lib.hf import HFGenerator, get_output_schema, stream_splits
+
+        output: dict[str, DataType] = {}
+        ds_dict = stream_splits(dataset, *args, **kwargs)
+        if len(ds_dict) > 1:
+            output = {"split": str}
+
+        model_name = model_name or object_name or ""
+        output = output | get_output_schema(next(iter(ds_dict.values())), model_name)
+        model = dict_to_data_model(model_name, output)
+        if object_name:
+            output = {object_name: model}
+
+        chain = DataChain.from_values(
+            split=list(ds_dict.keys()), session=session, settings=settings
+        )
+        return chain.gen(HFGenerator(dataset, model, *args, **kwargs), output=output)
+
     def parse_tabular(
         self,
         output: OutputType = None,
@@ -1367,7 +1473,7 @@ class DataChain(DatasetQuery):

         if isinstance(output, dict):
             model_name = model_name or object_name or ""
-            model =
+            model = dict_to_data_model(model_name, output)
         else:
             model = output  # type: ignore[assignment]

@@ -1384,17 +1490,6 @@ class DataChain(DatasetQuery):
             ArrowGenerator(schema, model, source, nrows, **kwargs), output=output
         )

-    @staticmethod
-    def _dict_to_data_model(
-        name: str, data_dict: dict[str, DataType]
-    ) -> type[BaseModel]:
-        fields = {name: (anno, ...) for name, anno in data_dict.items()}
-        return create_model(
-            name,
-            __base__=(DataModel,),  # type: ignore[call-overload]
-            **fields,
-        )  # type: ignore[call-overload]
-
     @classmethod
     def from_csv(
         cls,
@@ -1543,6 +1638,7 @@ class DataChain(DatasetQuery):
         cls,
         to_insert: Optional[Union[dict, list[dict]]],
         session: Optional[Session] = None,
+        settings: Optional[dict] = None,
         in_memory: bool = False,
         schema: Optional[dict[str, DataType]] = None,
     ) -> "DataChain":
@@ -1597,7 +1693,7 @@ class DataChain(DatasetQuery):
         insert_q = dr.get_table().insert()
         for record in to_insert:
             db.execute(insert_q.values(**record))
-        return DataChain(name=dsr.name)
+        return DataChain(name=dsr.name, settings=settings)

     def sum(self, fr: DataType):  # type: ignore[override]
         """Compute the sum of a column."""