datachain 0.2.8__py3-none-any.whl → 0.2.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of datachain might be problematic.
- datachain/__init__.py +17 -8
- datachain/catalog/catalog.py +5 -5
- datachain/cli.py +0 -2
- datachain/data_storage/schema.py +5 -5
- datachain/data_storage/sqlite.py +1 -1
- datachain/data_storage/warehouse.py +7 -7
- datachain/lib/arrow.py +25 -8
- datachain/lib/clip.py +6 -11
- datachain/lib/convert/__init__.py +0 -0
- datachain/lib/convert/flatten.py +67 -0
- datachain/lib/convert/type_converter.py +96 -0
- datachain/lib/convert/unflatten.py +69 -0
- datachain/lib/convert/values_to_tuples.py +85 -0
- datachain/lib/data_model.py +74 -0
- datachain/lib/dc.py +192 -167
- datachain/lib/feature_registry.py +36 -10
- datachain/lib/file.py +41 -41
- datachain/lib/gpt4_vision.py +1 -9
- datachain/lib/hf_image_to_text.py +9 -17
- datachain/lib/hf_pipeline.py +4 -12
- datachain/lib/image.py +2 -18
- datachain/lib/image_transform.py +0 -1
- datachain/lib/iptc_exif_xmp.py +8 -15
- datachain/lib/meta_formats.py +1 -5
- datachain/lib/model_store.py +77 -0
- datachain/lib/pytorch.py +9 -21
- datachain/lib/signal_schema.py +120 -58
- datachain/lib/text.py +5 -16
- datachain/lib/udf.py +114 -30
- datachain/lib/udf_signature.py +5 -5
- datachain/lib/webdataset.py +3 -4
- datachain/lib/webdataset_laion.py +2 -3
- datachain/node.py +4 -4
- datachain/query/batch.py +1 -1
- datachain/query/dataset.py +40 -60
- datachain/query/dispatch.py +28 -17
- datachain/query/udf.py +46 -26
- datachain/remote/studio.py +1 -9
- datachain/torch/__init__.py +21 -0
- {datachain-0.2.8.dist-info → datachain-0.2.10.dist-info}/METADATA +13 -12
- {datachain-0.2.8.dist-info → datachain-0.2.10.dist-info}/RECORD +45 -42
- datachain/image/__init__.py +0 -3
- datachain/lib/cached_stream.py +0 -38
- datachain/lib/claude.py +0 -69
- datachain/lib/feature.py +0 -412
- datachain/lib/feature_utils.py +0 -154
- {datachain-0.2.8.dist-info → datachain-0.2.10.dist-info}/LICENSE +0 -0
- {datachain-0.2.8.dist-info → datachain-0.2.10.dist-info}/WHEEL +0 -0
- {datachain-0.2.8.dist-info → datachain-0.2.10.dist-info}/entry_points.txt +0 -0
- {datachain-0.2.8.dist-info → datachain-0.2.10.dist-info}/top_level.txt +0 -0
datachain/lib/dc.py
CHANGED
@@ -1,3 +1,4 @@
+import copy
 import re
 from collections.abc import Iterator, Sequence
 from typing import (
@@ -11,11 +12,14 @@ from typing import (
 )
 
 import sqlalchemy
+from pydantic import BaseModel, create_model
 
-from datachain
-from datachain.lib.
+from datachain import DataModel
+from datachain.lib.convert.values_to_tuples import values_to_tuples
+from datachain.lib.data_model import DataType
 from datachain.lib.file import File, IndexedFile, get_file
 from datachain.lib.meta_formats import read_meta, read_schema
+from datachain.lib.model_store import ModelStore
 from datachain.lib.settings import Settings
 from datachain.lib.signal_schema import SignalSchema
 from datachain.lib.udf import (
@@ -39,8 +43,6 @@ if TYPE_CHECKING:
     import pandas as pd
     from typing_extensions import Self
 
-    from datachain.catalog import Catalog
-
 C = Column
 
 
@@ -51,10 +53,10 @@ class DatasetPrepareError(DataChainParamsError):
         super().__init__(f"Dataset{name}{output} processing prepare error: {msg}")
 
 
-class
+class DatasetFromValuesError(DataChainParamsError):
     def __init__(self, name, msg):
         name = f" '{name}'" if name else ""
-        super().__init__(f"Dataset {name} from
+        super().__init__(f"Dataset {name} from values error: {msg}")
 
 
 class DatasetMergeError(DataChainParamsError):
@@ -68,6 +70,14 @@ class DatasetMergeError(DataChainParamsError):
         super().__init__(f"Merge error on='{on_str}'{right_on_str}: {msg}")
 
 
+OutputType = Union[None, DataType, Sequence[str], dict[str, DataType]]
+
+
+class Sys(DataModel):
+    id: int
+    rand: int
+
+
 class DataChain(DatasetQuery):
     """AI 🔗 DataChain - a data structure for batch data processing and evaluation.
 
@@ -120,12 +130,10 @@ class DataChain(DatasetQuery):
     """
 
     DEFAULT_FILE_RECORD: ClassVar[dict] = {
-        "id": 0,
         "source": "",
         "name": "",
         "vtype": "",
         "size": 0,
-        "random": 0,
     }
 
     def __init__(self, *args, **kwargs):
@@ -151,11 +159,19 @@ class DataChain(DatasetQuery):
     def print_schema(self):
         self.signals_schema.print_tree()
 
-    def
-
+    def clone(self, new_table: bool = True) -> "Self":
+        obj = super().clone(new_table=new_table)
+        obj.signals_schema = copy.deepcopy(self.signals_schema)
+        return obj
 
     def settings(
-        self,
+        self,
+        cache=None,
+        batch=None,
+        parallel=None,
+        workers=None,
+        min_task_size=None,
+        include_sys: Optional[bool] = None,
     ) -> "Self":
         """Change settings for chain.
 
@@ -179,8 +195,13 @@ class DataChain(DatasetQuery):
         )
         ```
         """
-        self.
-
+        chain = self.clone()
+        if include_sys is True:
+            chain.signals_schema = SignalSchema({"sys": Sys}) | chain.signals_schema
+        elif include_sys is False and "sys" in chain.signals_schema:
+            chain.signals_schema.remove("sys")
+        chain._settings.add(Settings(cache, batch, parallel, workers, min_task_size))
+        return chain
 
     def reset_settings(self, settings: Optional[Settings] = None) -> "Self":
         """Reset all settings to default values."""
@@ -192,12 +213,11 @@ class DataChain(DatasetQuery):
         return self
 
     def add_schema(self, signals_schema: SignalSchema) -> "Self":
-
-        self.signals_schema = SignalSchema(union)
+        self.signals_schema |= signals_schema
         return self
 
     def get_file_signals(self) -> list[str]:
-        return self.signals_schema.get_file_signals()
+        return list(self.signals_schema.get_file_signals())
 
     @classmethod
     def from_storage(
@@ -205,9 +225,10 @@ class DataChain(DatasetQuery):
         path,
         *,
         type: Literal["binary", "text", "image"] = "binary",
-
+        session: Optional[Session] = None,
         recursive: Optional[bool] = True,
-
+        object_name: str = "file",
+        **kwargs,
    ) -> "Self":
         """Get data from a storage as a list of file with all file attributes. It
         returns the chain itself as usual.
@@ -217,7 +238,7 @@ class DataChain(DatasetQuery):
                 as `s3://`, `gs://`, `az://` or "file:///"
             type : read file as "binary", "text", or "image" data. Default is "binary".
             recursive : search recursively for the given path.
-
+            object_name : Created object column name.
 
         Example:
             ```py
@@ -225,7 +246,9 @@ class DataChain(DatasetQuery):
             ```
         """
         func = get_file(type)
-        return cls(path,
+        return cls(path, session=session, recursive=recursive, **kwargs).map(
+            **{object_name: func}
+        )
 
     @classmethod
     def from_dataset(cls, name: str, version: Optional[int] = None) -> "DataChain":
@@ -240,66 +263,19 @@ class DataChain(DatasetQuery):
         """
         return DataChain(name=name, version=version)
 
-    @classmethod
-    def from_csv(
-        cls,
-        path,
-        type: Literal["binary", "text", "image"] = "text",
-        anon: bool = False,
-        spec: Optional[FeatureType] = None,
-        schema_from: Optional[str] = "auto",
-        object_name: Optional[str] = "csv",
-        model_name: Optional[str] = None,
-        show_schema: Optional[bool] = False,
-    ) -> "DataChain":
-        """Get data from CSV. It returns the chain itself.
-
-        Parameters:
-            path : storage URI with directory. URI must start with storage prefix such
-                as `s3://`, `gs://`, `az://` or "file:///"
-            type : read file as "binary", "text", or "image" data. Default is "text".
-            anon : use anonymous mode to access the storage.
-            spec : Data Model for CSV file
-            object_name : generated object column name
-            model_name : generated model name
-            schema_from : path to sample to infer spec from
-            show_schema : print auto-generated schema
-
-        Examples:
-            infer model from the first two lines (header + data)
-            >>> chain = DataChain.from_csv("gs://csv")
-
-            use a particular data model
-            >>> chain = DataChain.from_csv("gs://csv"i, spec=MyModel)
-        """
-        if schema_from == "auto":
-            schema_from = path
-
-        chain = DataChain.from_storage(path=path, type=type, anon=anon)
-        signal_dict = {
-            object_name: read_meta(
-                schema_from=schema_from,
-                meta_type="csv",
-                spec=spec,
-                model_name=model_name,
-                show_schema=show_schema,
-            )
-        }
-        return chain.gen(**signal_dict)  # type: ignore[misc, arg-type]
-
     @classmethod
     def from_json(
         cls,
         path,
         type: Literal["binary", "text", "image"] = "text",
-
-        spec: Optional[FeatureType] = None,
+        spec: Optional[DataType] = None,
         schema_from: Optional[str] = "auto",
         jmespath: Optional[str] = None,
-        object_name:
+        object_name: str = "",
         model_name: Optional[str] = None,
         show_schema: Optional[bool] = False,
         meta_type: Optional[str] = "json",
+        **kwargs,
     ) -> "DataChain":
         """Get data from JSON. It returns the chain itself.
 
@@ -307,7 +283,6 @@ class DataChain(DatasetQuery):
             path : storage URI with directory. URI must start with storage prefix such
                 as `s3://`, `gs://`, `az://` or "file:///"
             type : read file as "binary", "text", or "image" data. Default is "binary".
-            anon : use anonymous mode to access the storage.
             spec : optional Data Model
             schema_from : path to sample to infer spec from
             object_name : generated object column name
@@ -333,7 +308,7 @@ class DataChain(DatasetQuery):
             object_name = jmespath_to_name(jmespath)
         if not object_name:
             object_name = "json"
-        chain = DataChain.from_storage(path=path, type=type,
+        chain = DataChain.from_storage(path=path, type=type, **kwargs)
         signal_dict = {
             object_name: read_meta(
                 schema_from=schema_from,
@@ -396,6 +371,7 @@ class DataChain(DatasetQuery):
             version : version of a dataset. Default - the last version that exist.
         """
         schema = self.signals_schema.serialize()
+        schema.pop("sys", None)
         return super().save(name=name, version=version, feature_schema=schema)
 
     def apply(self, func, *args, **kwargs):
@@ -405,7 +381,7 @@ class DataChain(DatasetQuery):
         self,
         func: Optional[Callable] = None,
         params: Union[None, str, Sequence[str]] = None,
-        output:
+        output: OutputType = None,
         **signal_map,
     ) -> "Self":
         """Apply a function to each row to create new signals. The function should
@@ -449,7 +425,7 @@ class DataChain(DatasetQuery):
         self,
         func: Optional[Callable] = None,
         params: Union[None, str, Sequence[str]] = None,
-        output:
+        output: OutputType = None,
         **signal_map,
     ) -> "Self":
         """Apply a function to each row to create new rows (with potentially new
@@ -478,7 +454,7 @@ class DataChain(DatasetQuery):
         func: Optional[Callable] = None,
         partition_by: Optional[PartitionByType] = None,
         params: Union[None, str, Sequence[str]] = None,
-        output:
+        output: OutputType = None,
         **signal_map,
     ) -> "Self":
         """Aggregate rows using `partition_by` statement and apply a function to the
@@ -508,7 +484,7 @@ class DataChain(DatasetQuery):
         self,
         func: Optional[Callable] = None,
         params: Union[None, str, Sequence[str]] = None,
-        output:
+        output: OutputType = None,
         **signal_map,
     ) -> "Self":
         """This is a batch version of map().
@@ -530,18 +506,20 @@ class DataChain(DatasetQuery):
         target_class: type[UDFBase],
         func: Optional[Callable],
         params: Union[None, str, Sequence[str]],
-        output:
+        output: OutputType,
         signal_map,
     ) -> UDFBase:
         is_generator = target_class.is_output_batched
         name = self.name or ""
 
         sign = UdfSignature.parse(name, signal_map, func, params, output, is_generator)
+        DataModel.register(list(sign.output_schema.values.values()))
+
         params_schema = self.signals_schema.slice(sign.params, self._setup)
 
-        return
+        return target_class._create(sign, params_schema)
 
-    def
+    def _extend_to_data_model(self, method_name, *args, **kwargs):
         super_func = getattr(super(), method_name)
 
         new_schema = self.signals_schema.resolve(*args)
@@ -570,39 +548,46 @@ class DataChain(DatasetQuery):
         chain.signals_schema = new_schema
         return chain
 
-    def
+    def iterate_flatten(self) -> Iterator[tuple[Any]]:
+        db_signals = self.signals_schema.db_signals()
+        with super().select(*db_signals).as_iterable() as rows:
+            yield from rows
+
+    def results(
+        self, row_factory: Optional[Callable] = None, **kwargs
+    ) -> list[tuple[Any, ...]]:
+        rows = self.iterate_flatten()
+        if row_factory:
+            db_signals = self.signals_schema.db_signals()
+            rows = (row_factory(db_signals, r) for r in rows)
+        return list(rows)
+
+    def iterate(self, *cols: str) -> Iterator[list[DataType]]:
         """Iterate over rows.
 
         If columns are specified - limit them to specified
         columns.
         """
         chain = self.select(*cols) if cols else self
+        for row in chain.iterate_flatten():
+            yield chain.signals_schema.row_to_features(
+                row, catalog=chain.session.catalog, cache=chain._settings.cache
+            )
 
-
-        with super().select(*db_signals).as_iterable() as rows_iter:
-            for row in rows_iter:
-                yield chain.signals_schema.row_to_features(row, chain.session.catalog)
-
-    def iterate_one(self, col: str) -> Iterator[FeatureType]:
+    def iterate_one(self, col: str) -> Iterator[DataType]:
         for item in self.iterate(col):
             yield item[0]
 
-    def collect(self, *cols: str) -> list[list[
+    def collect(self, *cols: str) -> list[list[DataType]]:
         return list(self.iterate(*cols))
 
-    def collect_one(self, col: str) -> list[
+    def collect_one(self, col: str) -> list[DataType]:
         return list(self.iterate_one(col))
 
     def to_pytorch(self, **kwargs):
         """Convert to pytorch dataset format."""
 
-
-            import torch  # noqa: F401
-        except ImportError as exc:
-            raise ImportError(
-                "Missing required dependency 'torch' for Dataset.to_pytorch()"
-            ) from exc
-        from datachain.lib.pytorch import PytorchDataset
+        from datachain.torch import PytorchDataset
 
         if self.attached:
             chain = self
@@ -701,25 +686,32 @@ class DataChain(DatasetQuery):
         return ds
 
     @classmethod
-    def
+    def from_values(
         cls,
         ds_name: str = "",
         session: Optional[Session] = None,
-        output:
+        output: OutputType = None,
+        object_name: str = "",
         **fr_map,
     ) -> "DataChain":
-        """Generate chain from list of
-        tuple_type, output, tuples =
+        """Generate chain from list of values."""
+        tuple_type, output, tuples = values_to_tuples(ds_name, output, **fr_map)
 
         def _func_fr() -> Iterator[tuple_type]:  # type: ignore[valid-type]
             yield from tuples
 
         chain = DataChain.create_empty(DataChain.DEFAULT_FILE_RECORD, session=session)
+        if object_name:
+            output = {object_name: DataChain._dict_to_data_model(object_name, output)}  # type: ignore[arg-type]
         return chain.gen(_func_fr, output=output)
 
     @classmethod
     def from_pandas(  # type: ignore[override]
-        cls,
+        cls,
+        df: "pd.DataFrame",
+        name: str = "",
+        session: Optional[Session] = None,
+        object_name: str = "",
     ) -> "DataChain":
         """Generate chain from pandas data-frame."""
         fr_map = {col.lower(): df[col].tolist() for col in df.columns}
@@ -737,17 +729,23 @@ class DataChain(DatasetQuery):
                 f"import from pandas error - '{column}' cannot be a column name",
             )
 
-        return cls.
+        return cls.from_values(name, session, object_name=object_name, **fr_map)
 
     def parse_tabular(
         self,
-        output:
+        output: OutputType = None,
+        object_name: str = "",
+        model_name: str = "",
        **kwargs,
     ) -> "DataChain":
         """Generate chain from list of tabular files.
 
         Parameters:
-            output : Dictionary defining column names and their
+            output : Dictionary or feature class defining column names and their
+                corresponding types. List of column names is also accepted, in which
+                case types will be inferred.
+            object_name : Generated object column name.
+            model_name : Generated model name.
             kwargs : Parameters to pass to pyarrow.dataset.dataset.
 
         Examples:
@@ -760,107 +758,134 @@ class DataChain(DatasetQuery):
            >>> dc = dc.filter(C("file.name").glob("*.jsonl"))
            >>> dc = dc.parse_tabular(format="json")
         """
-        from pyarrow import unify_schemas
-        from pyarrow.dataset import dataset
 
-        from datachain.lib.arrow import ArrowGenerator, schema_to_output
+        from datachain.lib.arrow import ArrowGenerator, infer_schema, schema_to_output
 
         schema = None
-        if output
-
-        else:
-            schemas = []
-            for row in self.select("file").iterate():
-                file = row[0]
-                ds = dataset(file.get_path(), filesystem=file.get_fs(), **kwargs)  # type: ignore[union-attr]
-                schemas.append(ds.schema)
-            if not schemas:
-                msg = "error parsing tabular data schema - found no files to parse"
-                raise DatasetPrepareError(self.name, msg)
-            schema = unify_schemas(schemas)
+        col_names = output if isinstance(output, Sequence) else None
+        if col_names or not output:
             try:
-
+                schema = infer_schema(self, **kwargs)
+                output = schema_to_output(schema, col_names)
             except ValueError as e:
                 raise DatasetPrepareError(self.name, e) from e
 
+        if object_name:
+            if isinstance(output, dict):
+                model_name = model_name or object_name
+                output = DataChain._dict_to_data_model(model_name, output)
+            output = {object_name: output}  # type: ignore[dict-item]
+        elif isinstance(output, type(BaseModel)):
+            output = {
+                name: info.annotation  # type: ignore[misc]
+                for name, info in output.model_fields.items()
+            }
+        output = {"source": IndexedFile} | output  # type: ignore[assignment,operator]
         return self.gen(ArrowGenerator(schema, **kwargs), output=output)
 
-
-
+    @staticmethod
+    def _dict_to_data_model(
+        name: str, data_dict: dict[str, DataType]
+    ) -> type[BaseModel]:
+        fields = {name: (anno, ...) for name, anno in data_dict.items()}
+        return create_model(
+            name,
+            __base__=(DataModel,),  # type: ignore[call-overload]
+            **fields,
+        )  # type: ignore[call-overload]
+
+    @classmethod
+    def from_csv(
+        cls,
+        path,
         delimiter: str = ",",
         header: bool = True,
         column_names: Optional[list[str]] = None,
-        output:
+        output: OutputType = None,
+        object_name: str = "",
+        model_name: str = "",
+        **kwargs,
    ) -> "DataChain":
-        """Generate chain from
+        """Generate chain from csv files.
 
         Parameters:
+            path : Storage URI with directory. URI must start with storage prefix such
+                as `s3://`, `gs://`, `az://` or "file:///".
             delimiter : Character for delimiting columns.
             header : Whether the files include a header row.
-
-
+            output : Dictionary or feature class defining column names and their
+                corresponding types. List of column names is also accepted, in which
+                case types will be inferred.
+            object_name : Created object column name.
+            model_name : Generated model name.
 
         Examples:
             Reading a csv file:
-            >>> dc = DataChain.
-            >>> dc = dc.parse_tabular(format="csv")
+            >>> dc = DataChain.from_csv("s3://mybucket/file.csv")
 
-            Reading
-            >>> dc = DataChain.
-            >>> dc = dc.filter(C("file.name").glob("*.csv"))
-            >>> dc = dc.parse_tabular()
+            Reading csv files from a directory as a combined dataset:
+            >>> dc = DataChain.from_csv("s3://mybucket/dir")
         """
         from pyarrow.csv import ParseOptions, ReadOptions
         from pyarrow.dataset import CsvFileFormat
 
-
-            msg = "error parsing csv - only one of column_names or output is allowed"
-            raise DatasetPrepareError(self.name, msg)
+        chain = DataChain.from_storage(path, **kwargs)
 
-        if not header
-            if output:
+        if not header:
+            if not output:
+                msg = "error parsing csv - provide output if no header"
+                raise DatasetPrepareError(chain.name, msg)
+        if isinstance(output, Sequence):
+            column_names = output  # type: ignore[assignment]
+        elif isinstance(output, dict):
             column_names = list(output.keys())
+        elif (fr := ModelStore.to_pydantic(output)) is not None:
+            column_names = list(fr.model_fields.keys())
         else:
-            msg = "error parsing csv -
-            raise DatasetPrepareError(
+            msg = f"error parsing csv - incompatible output type {type(output)}"
+            raise DatasetPrepareError(chain.name, msg)
 
         parse_options = ParseOptions(delimiter=delimiter)
         read_options = ReadOptions(column_names=column_names)
         format = CsvFileFormat(parse_options=parse_options, read_options=read_options)
-        return
+        return chain.parse_tabular(
+            output=output, object_name=object_name, model_name=model_name, format=format
+        )
 
-
-
+    @classmethod
+    def from_parquet(
+        cls,
+        path,
         partitioning: Any = "hive",
-        output: Optional[dict[str,
+        output: Optional[dict[str, DataType]] = None,
+        object_name: str = "",
+        model_name: str = "",
+        **kwargs,
     ) -> "DataChain":
-        """Generate chain from
+        """Generate chain from parquet files.
 
         Parameters:
+            path : Storage URI with directory. URI must start with storage prefix such
+                as `s3://`, `gs://`, `az://` or "file:///".
             partitioning : Any pyarrow partitioning schema.
             output : Dictionary defining column names and their corresponding types.
+            object_name : Created object column name.
+            model_name : Generated model name.
 
         Examples:
             Reading a single file:
-            >>> dc = DataChain.
-            >>> dc = dc.parse_tabular()
+            >>> dc = DataChain.from_parquet("s3://mybucket/file.parquet")
 
             Reading a partitioned dataset from a directory:
-            >>> dc = DataChain.
-            >>> dc = dc.parse_tabular()
-
-            Reading a filtered list of files as a dataset:
-            >>> dc = DataChain.from_storage("s3://mybucket")
-            >>> dc = dc.filter(C("file.name").glob("*.parquet"))
-            >>> dc = dc.parse_tabular()
-
-            Reading a filtered list of partitions as a dataset:
-            >>> dc = DataChain.from_storage("s3://mybucket")
-            >>> dc = dc.filter(C("file.parent").glob("*month=1*"))
-            >>> dc = dc.parse_tabular()
+            >>> dc = DataChain.from_parquet("s3://mybucket/dir")
         """
-
-
+        chain = DataChain.from_storage(path, **kwargs)
+        return chain.parse_tabular(
+            output=output,
+            object_name=object_name,
+            model_name=model_name,
+            format="parquet",
+            partitioning=partitioning,
         )
 
     @classmethod
@@ -903,17 +928,17 @@ class DataChain(DatasetQuery):
             db.execute(insert_q.values(**record))
         return DataChain(name=dsr.name)
 
-    def sum(self, fr:
-        return self.
+    def sum(self, fr: DataType):  # type: ignore[override]
+        return self._extend_to_data_model("sum", fr)
 
-    def avg(self, fr:
-        return self.
+    def avg(self, fr: DataType):  # type: ignore[override]
+        return self._extend_to_data_model("avg", fr)
 
-    def min(self, fr:
-        return self.
+    def min(self, fr: DataType):  # type: ignore[override]
+        return self._extend_to_data_model("min", fr)
 
-    def max(self, fr:
-        return self.
+    def max(self, fr: DataType):  # type: ignore[override]
+        return self._extend_to_data_model("max", fr)
 
     def setup(self, **kwargs) -> "Self":
         intersection = set(self._setup.keys()) & set(kwargs.keys())