datachain 0.2.9__py3-none-any.whl → 0.2.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of datachain might be problematic.
- datachain/__init__.py +17 -8
- datachain/catalog/catalog.py +5 -5
- datachain/cli.py +0 -2
- datachain/data_storage/schema.py +5 -5
- datachain/data_storage/sqlite.py +1 -1
- datachain/data_storage/warehouse.py +7 -7
- datachain/lib/arrow.py +25 -8
- datachain/lib/clip.py +6 -11
- datachain/lib/convert/__init__.py +0 -0
- datachain/lib/convert/flatten.py +67 -0
- datachain/lib/convert/type_converter.py +96 -0
- datachain/lib/convert/unflatten.py +69 -0
- datachain/lib/convert/values_to_tuples.py +85 -0
- datachain/lib/data_model.py +74 -0
- datachain/lib/dc.py +225 -168
- datachain/lib/file.py +41 -41
- datachain/lib/gpt4_vision.py +1 -9
- datachain/lib/hf_image_to_text.py +9 -17
- datachain/lib/hf_pipeline.py +4 -12
- datachain/lib/image.py +2 -18
- datachain/lib/image_transform.py +0 -1
- datachain/lib/iptc_exif_xmp.py +8 -15
- datachain/lib/meta_formats.py +1 -5
- datachain/lib/model_store.py +77 -0
- datachain/lib/pytorch.py +9 -21
- datachain/lib/signal_schema.py +139 -60
- datachain/lib/text.py +5 -16
- datachain/lib/udf.py +114 -30
- datachain/lib/udf_signature.py +5 -5
- datachain/lib/webdataset.py +3 -3
- datachain/lib/webdataset_laion.py +2 -3
- datachain/node.py +4 -4
- datachain/query/batch.py +1 -1
- datachain/query/dataset.py +51 -178
- datachain/query/dispatch.py +43 -30
- datachain/query/udf.py +46 -26
- datachain/remote/studio.py +1 -9
- datachain/torch/__init__.py +21 -0
- datachain/utils.py +39 -0
- {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/METADATA +14 -12
- {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/RECORD +45 -43
- {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/WHEEL +1 -1
- datachain/image/__init__.py +0 -3
- datachain/lib/cached_stream.py +0 -38
- datachain/lib/claude.py +0 -69
- datachain/lib/feature.py +0 -412
- datachain/lib/feature_registry.py +0 -51
- datachain/lib/feature_utils.py +0 -154
- {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/LICENSE +0 -0
- {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/entry_points.txt +0 -0
- {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/top_level.txt +0 -0
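For orientation, the moved files above amount to a module restructuring: feature.py, feature_registry.py and feature_utils.py are removed, replaced by data_model.py, model_store.py and the new convert/ package. A minimal import sketch of the new layout, taken from the datachain/lib/dc.py diff below (whether these names are also re-exported elsewhere is not visible in this diff):

from datachain import DataModel                        # pydantic-based base class for signal models (e.g. Sys below)
from datachain.lib.data_model import DataType          # type alias accepted in UDF `output` specs
from datachain.lib.model_store import ModelStore       # model registry (used below via ModelStore.to_pydantic)
from datachain.lib.convert.values_to_tuples import values_to_tuples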
datachain/lib/dc.py
CHANGED
@@ -1,3 +1,4 @@
+import copy
 import re
 from collections.abc import Iterator, Sequence
 from typing import (
@@ -10,12 +11,16 @@ from typing import (
     Union,
 )

+import pandas as pd
 import sqlalchemy
+from pydantic import BaseModel, create_model

-from datachain
-from datachain.lib.
+from datachain import DataModel
+from datachain.lib.convert.values_to_tuples import values_to_tuples
+from datachain.lib.data_model import DataType
 from datachain.lib.file import File, IndexedFile, get_file
 from datachain.lib.meta_formats import read_meta, read_schema
+from datachain.lib.model_store import ModelStore
 from datachain.lib.settings import Settings
 from datachain.lib.signal_schema import SignalSchema
 from datachain.lib.udf import (
@@ -34,13 +39,11 @@ from datachain.query.dataset import (
     detach,
 )
 from datachain.query.schema import Column, DatasetRow
+from datachain.utils import inside_notebook

 if TYPE_CHECKING:
-    import pandas as pd
     from typing_extensions import Self

-    from datachain.catalog import Catalog
-
 C = Column


@@ -51,10 +54,10 @@ class DatasetPrepareError(DataChainParamsError):
         super().__init__(f"Dataset{name}{output} processing prepare error: {msg}")


-class
+class DatasetFromValuesError(DataChainParamsError):
     def __init__(self, name, msg):
         name = f" '{name}'" if name else ""
-        super().__init__(f"Dataset {name} from
+        super().__init__(f"Dataset {name} from values error: {msg}")


 class DatasetMergeError(DataChainParamsError):
@@ -68,6 +71,14 @@ class DatasetMergeError(DataChainParamsError):
         super().__init__(f"Merge error on='{on_str}'{right_on_str}: {msg}")


+OutputType = Union[None, DataType, Sequence[str], dict[str, DataType]]
+
+
+class Sys(DataModel):
+    id: int
+    rand: int
+
+
 class DataChain(DatasetQuery):
     """AI 🔗 DataChain - a data structure for batch data processing and evaluation.

@@ -120,12 +131,10 @@ class DataChain(DatasetQuery):
     """

     DEFAULT_FILE_RECORD: ClassVar[dict] = {
-        "id": 0,
         "source": "",
         "name": "",
         "vtype": "",
         "size": 0,
-        "random": 0,
     }

     def __init__(self, *args, **kwargs):
@@ -151,11 +160,19 @@ class DataChain(DatasetQuery):
     def print_schema(self):
         self.signals_schema.print_tree()

-    def
-
+    def clone(self, new_table: bool = True) -> "Self":
+        obj = super().clone(new_table=new_table)
+        obj.signals_schema = copy.deepcopy(self.signals_schema)
+        return obj

     def settings(
-        self,
+        self,
+        cache=None,
+        batch=None,
+        parallel=None,
+        workers=None,
+        min_task_size=None,
+        include_sys: Optional[bool] = None,
     ) -> "Self":
         """Change settings for chain.

@@ -179,8 +196,13 @@ class DataChain(DatasetQuery):
            )
            ```
         """
-        self.
-
+        chain = self.clone()
+        if include_sys is True:
+            chain.signals_schema = SignalSchema({"sys": Sys}) | chain.signals_schema
+        elif include_sys is False and "sys" in chain.signals_schema:
+            chain.signals_schema.remove("sys")
+        chain._settings.add(Settings(cache, batch, parallel, workers, min_task_size))
+        return chain

     def reset_settings(self, settings: Optional[Settings] = None) -> "Self":
         """Reset all settings to default values."""
@@ -192,12 +214,11 @@ class DataChain(DatasetQuery):
         return self

     def add_schema(self, signals_schema: SignalSchema) -> "Self":
-
-        self.signals_schema = SignalSchema(union)
+        self.signals_schema |= signals_schema
         return self

     def get_file_signals(self) -> list[str]:
-        return self.signals_schema.get_file_signals()
+        return list(self.signals_schema.get_file_signals())

     @classmethod
     def from_storage(
@@ -205,9 +226,10 @@ class DataChain(DatasetQuery):
         path,
         *,
         type: Literal["binary", "text", "image"] = "binary",
-
+        session: Optional[Session] = None,
         recursive: Optional[bool] = True,
-
+        object_name: str = "file",
+        **kwargs,
     ) -> "Self":
         """Get data from a storage as a list of file with all file attributes. It
         returns the chain itself as usual.
@@ -217,7 +239,7 @@ class DataChain(DatasetQuery):
                as `s3://`, `gs://`, `az://` or "file:///"
            type : read file as "binary", "text", or "image" data. Default is "binary".
            recursive : search recursively for the given path.
-
+            object_name : Created object column name.

         Example:
            ```py
@@ -225,7 +247,9 @@ class DataChain(DatasetQuery):
            ```
         """
         func = get_file(type)
-        return cls(path,
+        return cls(path, session=session, recursive=recursive, **kwargs).map(
+            **{object_name: func}
+        )

     @classmethod
     def from_dataset(cls, name: str, version: Optional[int] = None) -> "DataChain":
@@ -240,66 +264,19 @@ class DataChain(DatasetQuery):
         """
         return DataChain(name=name, version=version)

-    @classmethod
-    def from_csv(
-        cls,
-        path,
-        type: Literal["binary", "text", "image"] = "text",
-        anon: bool = False,
-        spec: Optional[FeatureType] = None,
-        schema_from: Optional[str] = "auto",
-        object_name: Optional[str] = "csv",
-        model_name: Optional[str] = None,
-        show_schema: Optional[bool] = False,
-    ) -> "DataChain":
-        """Get data from CSV. It returns the chain itself.
-
-        Parameters:
-            path : storage URI with directory. URI must start with storage prefix such
-                as `s3://`, `gs://`, `az://` or "file:///"
-            type : read file as "binary", "text", or "image" data. Default is "text".
-            anon : use anonymous mode to access the storage.
-            spec : Data Model for CSV file
-            object_name : generated object column name
-            model_name : generated model name
-            schema_from : path to sample to infer spec from
-            show_schema : print auto-generated schema
-
-        Examples:
-            infer model from the first two lines (header + data)
-            >>> chain = DataChain.from_csv("gs://csv")
-
-            use a particular data model
-            >>> chain = DataChain.from_csv("gs://csv"i, spec=MyModel)
-        """
-        if schema_from == "auto":
-            schema_from = path
-
-        chain = DataChain.from_storage(path=path, type=type, anon=anon)
-        signal_dict = {
-            object_name: read_meta(
-                schema_from=schema_from,
-                meta_type="csv",
-                spec=spec,
-                model_name=model_name,
-                show_schema=show_schema,
-            )
-        }
-        return chain.gen(**signal_dict)  # type: ignore[misc, arg-type]
-
     @classmethod
     def from_json(
         cls,
         path,
         type: Literal["binary", "text", "image"] = "text",
-
-        spec: Optional[FeatureType] = None,
+        spec: Optional[DataType] = None,
         schema_from: Optional[str] = "auto",
         jmespath: Optional[str] = None,
-        object_name:
+        object_name: str = "",
         model_name: Optional[str] = None,
         show_schema: Optional[bool] = False,
         meta_type: Optional[str] = "json",
+        **kwargs,
     ) -> "DataChain":
         """Get data from JSON. It returns the chain itself.

@@ -307,7 +284,6 @@ class DataChain(DatasetQuery):
            path : storage URI with directory. URI must start with storage prefix such
                as `s3://`, `gs://`, `az://` or "file:///"
            type : read file as "binary", "text", or "image" data. Default is "binary".
-            anon : use anonymous mode to access the storage.
            spec : optional Data Model
            schema_from : path to sample to infer spec from
            object_name : generated object column name
@@ -333,7 +309,7 @@ class DataChain(DatasetQuery):
            object_name = jmespath_to_name(jmespath)
         if not object_name:
            object_name = "json"
-        chain = DataChain.from_storage(path=path, type=type,
+        chain = DataChain.from_storage(path=path, type=type, **kwargs)
         signal_dict = {
            object_name: read_meta(
                schema_from=schema_from,
@@ -396,6 +372,7 @@ class DataChain(DatasetQuery):
            version : version of a dataset. Default - the last version that exist.
         """
         schema = self.signals_schema.serialize()
+        schema.pop("sys", None)
         return super().save(name=name, version=version, feature_schema=schema)

     def apply(self, func, *args, **kwargs):
@@ -405,7 +382,7 @@ class DataChain(DatasetQuery):
         self,
         func: Optional[Callable] = None,
         params: Union[None, str, Sequence[str]] = None,
-        output:
+        output: OutputType = None,
         **signal_map,
     ) -> "Self":
         """Apply a function to each row to create new signals. The function should
@@ -449,7 +426,7 @@ class DataChain(DatasetQuery):
         self,
         func: Optional[Callable] = None,
         params: Union[None, str, Sequence[str]] = None,
-        output:
+        output: OutputType = None,
         **signal_map,
     ) -> "Self":
         """Apply a function to each row to create new rows (with potentially new
@@ -478,7 +455,7 @@ class DataChain(DatasetQuery):
         func: Optional[Callable] = None,
         partition_by: Optional[PartitionByType] = None,
         params: Union[None, str, Sequence[str]] = None,
-        output:
+        output: OutputType = None,
         **signal_map,
     ) -> "Self":
         """Aggregate rows using `partition_by` statement and apply a function to the
@@ -508,7 +485,7 @@ class DataChain(DatasetQuery):
         self,
         func: Optional[Callable] = None,
         params: Union[None, str, Sequence[str]] = None,
-        output:
+        output: OutputType = None,
         **signal_map,
     ) -> "Self":
         """This is a batch version of map().
@@ -530,18 +507,20 @@ class DataChain(DatasetQuery):
         target_class: type[UDFBase],
         func: Optional[Callable],
         params: Union[None, str, Sequence[str]],
-        output:
+        output: OutputType,
         signal_map,
     ) -> UDFBase:
         is_generator = target_class.is_output_batched
         name = self.name or ""

         sign = UdfSignature.parse(name, signal_map, func, params, output, is_generator)
+        DataModel.register(list(sign.output_schema.values.values()))
+
         params_schema = self.signals_schema.slice(sign.params, self._setup)

-        return
+        return target_class._create(sign, params_schema)

-    def
+    def _extend_to_data_model(self, method_name, *args, **kwargs):
         super_func = getattr(super(), method_name)

         new_schema = self.signals_schema.resolve(*args)
@@ -570,39 +549,46 @@ class DataChain(DatasetQuery):
         chain.signals_schema = new_schema
         return chain

-    def
+    def iterate_flatten(self) -> Iterator[tuple[Any]]:
+        db_signals = self.signals_schema.db_signals()
+        with super().select(*db_signals).as_iterable() as rows:
+            yield from rows
+
+    def results(
+        self, row_factory: Optional[Callable] = None, **kwargs
+    ) -> list[tuple[Any, ...]]:
+        rows = self.iterate_flatten()
+        if row_factory:
+            db_signals = self.signals_schema.db_signals()
+            rows = (row_factory(db_signals, r) for r in rows)
+        return list(rows)
+
+    def iterate(self, *cols: str) -> Iterator[list[DataType]]:
         """Iterate over rows.

         If columns are specified - limit them to specified
         columns.
         """
         chain = self.select(*cols) if cols else self
+        for row in chain.iterate_flatten():
+            yield chain.signals_schema.row_to_features(
+                row, catalog=chain.session.catalog, cache=chain._settings.cache
+            )

-
-        with super().select(*db_signals).as_iterable() as rows_iter:
-            for row in rows_iter:
-                yield chain.signals_schema.row_to_features(row, chain.session.catalog)
-
-    def iterate_one(self, col: str) -> Iterator[FeatureType]:
+    def iterate_one(self, col: str) -> Iterator[DataType]:
         for item in self.iterate(col):
             yield item[0]

-    def collect(self, *cols: str) -> list[list[
+    def collect(self, *cols: str) -> list[list[DataType]]:
         return list(self.iterate(*cols))

-    def collect_one(self, col: str) -> list[
+    def collect_one(self, col: str) -> list[DataType]:
         return list(self.iterate_one(col))

     def to_pytorch(self, **kwargs):
         """Convert to pytorch dataset format."""

-        try:
-            import torch  # noqa: F401
-        except ImportError as exc:
-            raise ImportError(
-                "Missing required dependency 'torch' for Dataset.to_pytorch()"
-            ) from exc
-        from datachain.lib.pytorch import PytorchDataset
+        from datachain.torch import PytorchDataset

         if self.attached:
             chain = self
@@ -701,25 +687,32 @@ class DataChain(DatasetQuery):
         return ds

     @classmethod
-    def
+    def from_values(
         cls,
         ds_name: str = "",
         session: Optional[Session] = None,
-        output:
+        output: OutputType = None,
+        object_name: str = "",
         **fr_map,
     ) -> "DataChain":
-        """Generate chain from list of
-        tuple_type, output, tuples =
+        """Generate chain from list of values."""
+        tuple_type, output, tuples = values_to_tuples(ds_name, output, **fr_map)

         def _func_fr() -> Iterator[tuple_type]:  # type: ignore[valid-type]
             yield from tuples

         chain = DataChain.create_empty(DataChain.DEFAULT_FILE_RECORD, session=session)
+        if object_name:
+            output = {object_name: DataChain._dict_to_data_model(object_name, output)}  # type: ignore[arg-type]
         return chain.gen(_func_fr, output=output)

     @classmethod
     def from_pandas(  # type: ignore[override]
-        cls,
+        cls,
+        df: "pd.DataFrame",
+        name: str = "",
+        session: Optional[Session] = None,
+        object_name: str = "",
     ) -> "DataChain":
         """Generate chain from pandas data-frame."""
         fr_map = {col.lower(): df[col].tolist() for col in df.columns}
@@ -737,17 +730,54 @@ class DataChain(DatasetQuery):
                    f"import from pandas error - '{column}' cannot be a column name",
                )

-        return cls.
+        return cls.from_values(name, session, object_name=object_name, **fr_map)
+
+    def to_pandas(self, flatten=False) -> "pd.DataFrame":
+        headers, max_length = self.signals_schema.get_headers_with_length()
+        if flatten or max_length < 2:
+            df = pd.DataFrame.from_records(self.to_records())
+            if headers:
+                df.columns = [".".join(filter(None, header)) for header in headers]
+            return df
+
+        transposed_result = list(map(list, zip(*self.results())))
+        data = {tuple(n): val for n, val in zip(headers, transposed_result)}
+        return pd.DataFrame(data)
+
+    def show(self, limit: int = 20, flatten=False, transpose=False) -> None:
+        dc = self.limit(limit) if limit > 0 else self
+        df = dc.to_pandas(flatten)
+        if transpose:
+            df = df.T
+
+        with pd.option_context(
+            "display.max_columns", None, "display.multi_sparse", False
+        ):
+            if inside_notebook():
+                from IPython.display import display
+
+                display(df)
+            else:
+                print(df)
+
+        if len(df) == limit:
+            print(f"\n[Limited by {len(df)} rows]")

     def parse_tabular(
         self,
-        output:
+        output: OutputType = None,
+        object_name: str = "",
+        model_name: str = "",
         **kwargs,
     ) -> "DataChain":
         """Generate chain from list of tabular files.

         Parameters:
-            output : Dictionary defining column names and their
+            output : Dictionary or feature class defining column names and their
+                corresponding types. List of column names is also accepted, in which
+                case types will be inferred.
+            object_name : Generated object column name.
+            model_name : Generated model name.
            kwargs : Parameters to pass to pyarrow.dataset.dataset.

         Examples:
@@ -760,107 +790,134 @@ class DataChain(DatasetQuery):
            >>> dc = dc.filter(C("file.name").glob("*.jsonl"))
            >>> dc = dc.parse_tabular(format="json")
         """
-        from pyarrow import unify_schemas
-        from pyarrow.dataset import dataset

-        from datachain.lib.arrow import ArrowGenerator, schema_to_output
+        from datachain.lib.arrow import ArrowGenerator, infer_schema, schema_to_output

         schema = None
-        if output
-
-        else:
-            schemas = []
-            for row in self.select("file").iterate():
-                file = row[0]
-                ds = dataset(file.get_path(), filesystem=file.get_fs(), **kwargs)  # type: ignore[union-attr]
-                schemas.append(ds.schema)
-            if not schemas:
-                msg = "error parsing tabular data schema - found no files to parse"
-                raise DatasetPrepareError(self.name, msg)
-            schema = unify_schemas(schemas)
+        col_names = output if isinstance(output, Sequence) else None
+        if col_names or not output:
            try:
-
+                schema = infer_schema(self, **kwargs)
+                output = schema_to_output(schema, col_names)
            except ValueError as e:
                raise DatasetPrepareError(self.name, e) from e

+        if object_name:
+            if isinstance(output, dict):
+                model_name = model_name or object_name
+                output = DataChain._dict_to_data_model(model_name, output)
+            output = {object_name: output}  # type: ignore[dict-item]
+        elif isinstance(output, type(BaseModel)):
+            output = {
+                name: info.annotation  # type: ignore[misc]
+                for name, info in output.model_fields.items()
+            }
+        output = {"source": IndexedFile} | output  # type: ignore[assignment,operator]
         return self.gen(ArrowGenerator(schema, **kwargs), output=output)

-
-
+    @staticmethod
+    def _dict_to_data_model(
+        name: str, data_dict: dict[str, DataType]
+    ) -> type[BaseModel]:
+        fields = {name: (anno, ...) for name, anno in data_dict.items()}
+        return create_model(
+            name,
+            __base__=(DataModel,),  # type: ignore[call-overload]
+            **fields,
+        )  # type: ignore[call-overload]
+
+    @classmethod
+    def from_csv(
+        cls,
+        path,
         delimiter: str = ",",
         header: bool = True,
         column_names: Optional[list[str]] = None,
-        output:
+        output: OutputType = None,
+        object_name: str = "",
+        model_name: str = "",
+        **kwargs,
     ) -> "DataChain":
-        """Generate chain from
+        """Generate chain from csv files.

         Parameters:
+            path : Storage URI with directory. URI must start with storage prefix such
+                as `s3://`, `gs://`, `az://` or "file:///".
            delimiter : Character for delimiting columns.
            header : Whether the files include a header row.
-
-
+            output : Dictionary or feature class defining column names and their
+                corresponding types. List of column names is also accepted, in which
+                case types will be inferred.
+            object_name : Created object column name.
+            model_name : Generated model name.

         Examples:
            Reading a csv file:
-            >>> dc = DataChain.
-            >>> dc = dc.parse_tabular(format="csv")
+            >>> dc = DataChain.from_csv("s3://mybucket/file.csv")

-            Reading
-            >>> dc = DataChain.
-            >>> dc = dc.filter(C("file.name").glob("*.csv"))
-            >>> dc = dc.parse_tabular()
+            Reading csv files from a directory as a combined dataset:
+            >>> dc = DataChain.from_csv("s3://mybucket/dir")
         """
         from pyarrow.csv import ParseOptions, ReadOptions
         from pyarrow.dataset import CsvFileFormat

-
-            msg = "error parsing csv - only one of column_names or output is allowed"
-            raise DatasetPrepareError(self.name, msg)
+        chain = DataChain.from_storage(path, **kwargs)

-        if not header
-            if output:
+        if not header:
+            if not output:
+                msg = "error parsing csv - provide output if no header"
+                raise DatasetPrepareError(chain.name, msg)
+            if isinstance(output, Sequence):
+                column_names = output  # type: ignore[assignment]
+            elif isinstance(output, dict):
                column_names = list(output.keys())
+            elif (fr := ModelStore.to_pydantic(output)) is not None:
+                column_names = list(fr.model_fields.keys())
            else:
-                msg = "error parsing csv -
-                raise DatasetPrepareError(
+                msg = f"error parsing csv - incompatible output type {type(output)}"
+                raise DatasetPrepareError(chain.name, msg)

         parse_options = ParseOptions(delimiter=delimiter)
         read_options = ReadOptions(column_names=column_names)
         format = CsvFileFormat(parse_options=parse_options, read_options=read_options)
-        return
+        return chain.parse_tabular(
+            output=output, object_name=object_name, model_name=model_name, format=format
+        )

-
-
+    @classmethod
+    def from_parquet(
+        cls,
+        path,
         partitioning: Any = "hive",
-        output: Optional[dict[str,
+        output: Optional[dict[str, DataType]] = None,
+        object_name: str = "",
+        model_name: str = "",
+        **kwargs,
     ) -> "DataChain":
-        """Generate chain from
+        """Generate chain from parquet files.

         Parameters:
+            path : Storage URI with directory. URI must start with storage prefix such
+                as `s3://`, `gs://`, `az://` or "file:///".
            partitioning : Any pyarrow partitioning schema.
            output : Dictionary defining column names and their corresponding types.
+            object_name : Created object column name.
+            model_name : Generated model name.

         Examples:
            Reading a single file:
-            >>> dc = DataChain.
-            >>> dc = dc.parse_tabular()
+            >>> dc = DataChain.from_parquet("s3://mybucket/file.parquet")

            Reading a partitioned dataset from a directory:
-            >>> dc = DataChain.
-            >>> dc = dc.parse_tabular()
-
-            Reading a filtered list of files as a dataset:
-            >>> dc = DataChain.from_storage("s3://mybucket")
-            >>> dc = dc.filter(C("file.name").glob("*.parquet"))
-            >>> dc = dc.parse_tabular()
-
-            Reading a filtered list of partitions as a dataset:
-            >>> dc = DataChain.from_storage("s3://mybucket")
-            >>> dc = dc.filter(C("file.parent").glob("*month=1*"))
-            >>> dc = dc.parse_tabular()
+            >>> dc = DataChain.from_parquet("s3://mybucket/dir")
         """
-
-
+        chain = DataChain.from_storage(path, **kwargs)
+        return chain.parse_tabular(
+            output=output,
+            object_name=object_name,
+            model_name=model_name,
+            format="parquet",
+            partitioning=partitioning,
        )

     @classmethod
@@ -903,17 +960,17 @@ class DataChain(DatasetQuery):
            db.execute(insert_q.values(**record))
         return DataChain(name=dsr.name)

-    def sum(self, fr:
-        return self.
+    def sum(self, fr: DataType):  # type: ignore[override]
+        return self._extend_to_data_model("sum", fr)

-    def avg(self, fr:
-        return self.
+    def avg(self, fr: DataType):  # type: ignore[override]
+        return self._extend_to_data_model("avg", fr)

-    def min(self, fr:
-        return self.
+    def min(self, fr: DataType):  # type: ignore[override]
+        return self._extend_to_data_model("min", fr)

-    def max(self, fr:
-        return self.
+    def max(self, fr: DataType):  # type: ignore[override]
+        return self._extend_to_data_model("max", fr)

     def setup(self, **kwargs) -> "Self":
         intersection = set(self._setup.keys()) & set(kwargs.keys())
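Taken together, the dc.py changes consolidate the user-facing API: from_csv and from_parquet now accept storage paths directly and delegate to from_storage plus parse_tabular, from_values and from_pandas gain an object_name argument, and settings() gains an include_sys flag backed by the new Sys model. A minimal usage sketch assembled from the signatures and docstring examples in this diff; the bucket paths are illustrative, and anything outside this excerpt (catalog/session setup, top-level re-exports) is an assumption:

from datachain.lib.dc import DataChain  # module path from this diff; a top-level re-export may also exist

# Parquet files under a prefix combined into one chain (hive partitioning by default).
chain = DataChain.from_parquet("s3://mybucket/dir")

# CSV without a header row: pass column names (or a dict / feature class) via `output`.
chain = DataChain.from_csv("s3://mybucket/file.csv", header=False, output=["id", "text"])

# Expose the hidden `sys` signals (id, rand) and preview a few rows.
chain.settings(include_sys=True).show(limit=5)

# Nested signals become a multi-level pandas header unless flatten=True.
df = chain.to_pandas(flatten=True)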