datachain 0.2.8__py3-none-any.whl → 0.2.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of datachain might be problematic.

Files changed (50)
  1. datachain/__init__.py +17 -8
  2. datachain/catalog/catalog.py +5 -5
  3. datachain/cli.py +0 -2
  4. datachain/data_storage/schema.py +5 -5
  5. datachain/data_storage/sqlite.py +1 -1
  6. datachain/data_storage/warehouse.py +7 -7
  7. datachain/lib/arrow.py +25 -8
  8. datachain/lib/clip.py +6 -11
  9. datachain/lib/convert/__init__.py +0 -0
  10. datachain/lib/convert/flatten.py +67 -0
  11. datachain/lib/convert/type_converter.py +96 -0
  12. datachain/lib/convert/unflatten.py +69 -0
  13. datachain/lib/convert/values_to_tuples.py +85 -0
  14. datachain/lib/data_model.py +74 -0
  15. datachain/lib/dc.py +192 -167
  16. datachain/lib/feature_registry.py +36 -10
  17. datachain/lib/file.py +41 -41
  18. datachain/lib/gpt4_vision.py +1 -9
  19. datachain/lib/hf_image_to_text.py +9 -17
  20. datachain/lib/hf_pipeline.py +4 -12
  21. datachain/lib/image.py +2 -18
  22. datachain/lib/image_transform.py +0 -1
  23. datachain/lib/iptc_exif_xmp.py +8 -15
  24. datachain/lib/meta_formats.py +1 -5
  25. datachain/lib/model_store.py +77 -0
  26. datachain/lib/pytorch.py +9 -21
  27. datachain/lib/signal_schema.py +120 -58
  28. datachain/lib/text.py +5 -16
  29. datachain/lib/udf.py +114 -30
  30. datachain/lib/udf_signature.py +5 -5
  31. datachain/lib/webdataset.py +3 -4
  32. datachain/lib/webdataset_laion.py +2 -3
  33. datachain/node.py +4 -4
  34. datachain/query/batch.py +1 -1
  35. datachain/query/dataset.py +40 -60
  36. datachain/query/dispatch.py +28 -17
  37. datachain/query/udf.py +46 -26
  38. datachain/remote/studio.py +1 -9
  39. datachain/torch/__init__.py +21 -0
  40. {datachain-0.2.8.dist-info → datachain-0.2.10.dist-info}/METADATA +13 -12
  41. {datachain-0.2.8.dist-info → datachain-0.2.10.dist-info}/RECORD +45 -42
  42. datachain/image/__init__.py +0 -3
  43. datachain/lib/cached_stream.py +0 -38
  44. datachain/lib/claude.py +0 -69
  45. datachain/lib/feature.py +0 -412
  46. datachain/lib/feature_utils.py +0 -154
  47. {datachain-0.2.8.dist-info → datachain-0.2.10.dist-info}/LICENSE +0 -0
  48. {datachain-0.2.8.dist-info → datachain-0.2.10.dist-info}/WHEEL +0 -0
  49. {datachain-0.2.8.dist-info → datachain-0.2.10.dist-info}/entry_points.txt +0 -0
  50. {datachain-0.2.8.dist-info → datachain-0.2.10.dist-info}/top_level.txt +0 -0
datachain/lib/dc.py CHANGED
@@ -1,3 +1,4 @@
+import copy
 import re
 from collections.abc import Iterator, Sequence
 from typing import (
@@ -11,11 +12,14 @@ from typing import (
 )
 
 import sqlalchemy
+from pydantic import BaseModel, create_model
 
-from datachain.lib.feature import Feature, FeatureType
-from datachain.lib.feature_utils import features_to_tuples
+from datachain import DataModel
+from datachain.lib.convert.values_to_tuples import values_to_tuples
+from datachain.lib.data_model import DataType
 from datachain.lib.file import File, IndexedFile, get_file
 from datachain.lib.meta_formats import read_meta, read_schema
+from datachain.lib.model_store import ModelStore
 from datachain.lib.settings import Settings
 from datachain.lib.signal_schema import SignalSchema
 from datachain.lib.udf import (
@@ -39,8 +43,6 @@ if TYPE_CHECKING:
     import pandas as pd
     from typing_extensions import Self
 
-    from datachain.catalog import Catalog
-
 C = Column
 
 
@@ -51,10 +53,10 @@ class DatasetPrepareError(DataChainParamsError):
         super().__init__(f"Dataset{name}{output} processing prepare error: {msg}")
 
 
-class DatasetFromFeatureError(DataChainParamsError):
+class DatasetFromValuesError(DataChainParamsError):
     def __init__(self, name, msg):
         name = f" '{name}'" if name else ""
-        super().__init__(f"Dataset {name} from feature error: {msg}")
+        super().__init__(f"Dataset {name} from values error: {msg}")
 
 
 class DatasetMergeError(DataChainParamsError):
@@ -68,6 +70,14 @@ class DatasetMergeError(DataChainParamsError):
         super().__init__(f"Merge error on='{on_str}'{right_on_str}: {msg}")
 
 
+OutputType = Union[None, DataType, Sequence[str], dict[str, DataType]]
+
+
+class Sys(DataModel):
+    id: int
+    rand: int
+
+
 class DataChain(DatasetQuery):
     """AI 🔗 DataChain - a data structure for batch data processing and evaluation.
 
@@ -120,12 +130,10 @@ class DataChain(DatasetQuery):
     """
 
     DEFAULT_FILE_RECORD: ClassVar[dict] = {
-        "id": 0,
         "source": "",
        "name": "",
        "vtype": "",
        "size": 0,
-        "random": 0,
     }
 
     def __init__(self, *args, **kwargs):
@@ -151,11 +159,19 @@ class DataChain(DatasetQuery):
     def print_schema(self):
         self.signals_schema.print_tree()
 
-    def create_model(self, name: str) -> type[Feature]:
-        return self.signals_schema.create_model(name)
+    def clone(self, new_table: bool = True) -> "Self":
+        obj = super().clone(new_table=new_table)
+        obj.signals_schema = copy.deepcopy(self.signals_schema)
+        return obj
 
     def settings(
-        self, cache=None, batch=None, parallel=None, workers=None, min_task_size=None
+        self,
+        cache=None,
+        batch=None,
+        parallel=None,
+        workers=None,
+        min_task_size=None,
+        include_sys: Optional[bool] = None,
     ) -> "Self":
         """Change settings for chain.
 
@@ -179,8 +195,13 @@ class DataChain(DatasetQuery):
             )
             ```
         """
-        self._settings.add(Settings(cache, batch, parallel, workers, min_task_size))
-        return self
+        chain = self.clone()
+        if include_sys is True:
+            chain.signals_schema = SignalSchema({"sys": Sys}) | chain.signals_schema
+        elif include_sys is False and "sys" in chain.signals_schema:
+            chain.signals_schema.remove("sys")
+        chain._settings.add(Settings(cache, batch, parallel, workers, min_task_size))
+        return chain
 
     def reset_settings(self, settings: Optional[Settings] = None) -> "Self":
         """Reset all settings to default values."""
@@ -192,12 +213,11 @@ class DataChain(DatasetQuery):
         return self
 
     def add_schema(self, signals_schema: SignalSchema) -> "Self":
-        union = self.signals_schema.values | signals_schema.values
-        self.signals_schema = SignalSchema(union)
+        self.signals_schema |= signals_schema
         return self
 
     def get_file_signals(self) -> list[str]:
-        return self.signals_schema.get_file_signals()
+        return list(self.signals_schema.get_file_signals())
 
     @classmethod
     def from_storage(
@@ -205,9 +225,10 @@ class DataChain(DatasetQuery):
         path,
         *,
         type: Literal["binary", "text", "image"] = "binary",
-        catalog: Optional["Catalog"] = None,
+        session: Optional[Session] = None,
        recursive: Optional[bool] = True,
-        anon: bool = False,
+        object_name: str = "file",
+        **kwargs,
     ) -> "Self":
         """Get data from a storage as a list of file with all file attributes. It
        returns the chain itself as usual.
@@ -217,7 +238,7 @@ class DataChain(DatasetQuery):
                as `s3://`, `gs://`, `az://` or "file:///"
            type : read file as "binary", "text", or "image" data. Default is "binary".
            recursive : search recursively for the given path.
-            anon : use anonymous mode to access the storage.
+            object_name : Created object column name.
 
         Example:
             ```py
@@ -225,7 +246,9 @@ class DataChain(DatasetQuery):
             ```
         """
         func = get_file(type)
-        return cls(path, catalog=catalog, recursive=recursive, anon=anon).map(file=func)
+        return cls(path, session=session, recursive=recursive, **kwargs).map(
+            **{object_name: func}
+        )
 
     @classmethod
     def from_dataset(cls, name: str, version: Optional[int] = None) -> "DataChain":
@@ -240,66 +263,19 @@ class DataChain(DatasetQuery):
         """
         return DataChain(name=name, version=version)
 
-    @classmethod
-    def from_csv(
-        cls,
-        path,
-        type: Literal["binary", "text", "image"] = "text",
-        anon: bool = False,
-        spec: Optional[FeatureType] = None,
-        schema_from: Optional[str] = "auto",
-        object_name: Optional[str] = "csv",
-        model_name: Optional[str] = None,
-        show_schema: Optional[bool] = False,
-    ) -> "DataChain":
-        """Get data from CSV. It returns the chain itself.
-
-        Parameters:
-            path : storage URI with directory. URI must start with storage prefix such
-                as `s3://`, `gs://`, `az://` or "file:///"
-            type : read file as "binary", "text", or "image" data. Default is "text".
-            anon : use anonymous mode to access the storage.
-            spec : Data Model for CSV file
-            object_name : generated object column name
-            model_name : generated model name
-            schema_from : path to sample to infer spec from
-            show_schema : print auto-generated schema
-
-        Examples:
-            infer model from the first two lines (header + data)
-            >>> chain = DataChain.from_csv("gs://csv")
-
-            use a particular data model
-            >>> chain = DataChain.from_csv("gs://csv"i, spec=MyModel)
-        """
-        if schema_from == "auto":
-            schema_from = path
-
-        chain = DataChain.from_storage(path=path, type=type, anon=anon)
-        signal_dict = {
-            object_name: read_meta(
-                schema_from=schema_from,
-                meta_type="csv",
-                spec=spec,
-                model_name=model_name,
-                show_schema=show_schema,
-            )
-        }
-        return chain.gen(**signal_dict)  # type: ignore[misc, arg-type]
-
     @classmethod
     def from_json(
         cls,
         path,
         type: Literal["binary", "text", "image"] = "text",
-        anon: bool = False,
-        spec: Optional[FeatureType] = None,
+        spec: Optional[DataType] = None,
        schema_from: Optional[str] = "auto",
        jmespath: Optional[str] = None,
-        object_name: Optional[str] = None,
+        object_name: str = "",
        model_name: Optional[str] = None,
        show_schema: Optional[bool] = False,
        meta_type: Optional[str] = "json",
+        **kwargs,
    ) -> "DataChain":
        """Get data from JSON. It returns the chain itself.
 
@@ -307,7 +283,6 @@ class DataChain(DatasetQuery):
            path : storage URI with directory. URI must start with storage prefix such
                as `s3://`, `gs://`, `az://` or "file:///"
            type : read file as "binary", "text", or "image" data. Default is "binary".
-            anon : use anonymous mode to access the storage.
            spec : optional Data Model
            schema_from : path to sample to infer spec from
            object_name : generated object column name
@@ -333,7 +308,7 @@ class DataChain(DatasetQuery):
            object_name = jmespath_to_name(jmespath)
        if not object_name:
            object_name = "json"
-        chain = DataChain.from_storage(path=path, type=type, anon=anon)
+        chain = DataChain.from_storage(path=path, type=type, **kwargs)
        signal_dict = {
            object_name: read_meta(
                schema_from=schema_from,
@@ -396,6 +371,7 @@ class DataChain(DatasetQuery):
            version : version of a dataset. Default - the last version that exist.
        """
        schema = self.signals_schema.serialize()
+        schema.pop("sys", None)
        return super().save(name=name, version=version, feature_schema=schema)
 
    def apply(self, func, *args, **kwargs):
@@ -405,7 +381,7 @@ class DataChain(DatasetQuery):
        self,
        func: Optional[Callable] = None,
        params: Union[None, str, Sequence[str]] = None,
-        output: Union[None, FeatureType, Sequence[str], dict[str, FeatureType]] = None,
+        output: OutputType = None,
        **signal_map,
    ) -> "Self":
        """Apply a function to each row to create new signals. The function should
@@ -449,7 +425,7 @@ class DataChain(DatasetQuery):
        self,
        func: Optional[Callable] = None,
        params: Union[None, str, Sequence[str]] = None,
-        output: Union[None, FeatureType, Sequence[str], dict[str, FeatureType]] = None,
+        output: OutputType = None,
        **signal_map,
    ) -> "Self":
        """Apply a function to each row to create new rows (with potentially new
@@ -478,7 +454,7 @@ class DataChain(DatasetQuery):
        func: Optional[Callable] = None,
        partition_by: Optional[PartitionByType] = None,
        params: Union[None, str, Sequence[str]] = None,
-        output: Union[None, FeatureType, Sequence[str], dict[str, FeatureType]] = None,
+        output: OutputType = None,
        **signal_map,
    ) -> "Self":
        """Aggregate rows using `partition_by` statement and apply a function to the
@@ -508,7 +484,7 @@ class DataChain(DatasetQuery):
        self,
        func: Optional[Callable] = None,
        params: Union[None, str, Sequence[str]] = None,
-        output: Union[None, FeatureType, Sequence[str], dict[str, FeatureType]] = None,
+        output: OutputType = None,
        **signal_map,
    ) -> "Self":
        """This is a batch version of map().
@@ -530,18 +506,20 @@ class DataChain(DatasetQuery):
        target_class: type[UDFBase],
        func: Optional[Callable],
        params: Union[None, str, Sequence[str]],
-        output: Union[None, FeatureType, Sequence[str], dict[str, FeatureType]],
+        output: OutputType,
        signal_map,
    ) -> UDFBase:
        is_generator = target_class.is_output_batched
        name = self.name or ""
 
        sign = UdfSignature.parse(name, signal_map, func, params, output, is_generator)
+        DataModel.register(list(sign.output_schema.values.values()))
+
        params_schema = self.signals_schema.slice(sign.params, self._setup)
 
-        return UDFBase._create(target_class, sign, params_schema)
+        return target_class._create(sign, params_schema)
 
-    def _extend_features(self, method_name, *args, **kwargs):
+    def _extend_to_data_model(self, method_name, *args, **kwargs):
        super_func = getattr(super(), method_name)
 
        new_schema = self.signals_schema.resolve(*args)
@@ -570,39 +548,46 @@ class DataChain(DatasetQuery):
        chain.signals_schema = new_schema
        return chain
 
-    def iterate(self, *cols: str) -> Iterator[list[FeatureType]]:
+    def iterate_flatten(self) -> Iterator[tuple[Any]]:
+        db_signals = self.signals_schema.db_signals()
+        with super().select(*db_signals).as_iterable() as rows:
+            yield from rows
+
+    def results(
+        self, row_factory: Optional[Callable] = None, **kwargs
+    ) -> list[tuple[Any, ...]]:
+        rows = self.iterate_flatten()
+        if row_factory:
+            db_signals = self.signals_schema.db_signals()
+            rows = (row_factory(db_signals, r) for r in rows)
+        return list(rows)
+
+    def iterate(self, *cols: str) -> Iterator[list[DataType]]:
        """Iterate over rows.
 
        If columns are specified - limit them to specified
        columns.
        """
        chain = self.select(*cols) if cols else self
+        for row in chain.iterate_flatten():
+            yield chain.signals_schema.row_to_features(
+                row, catalog=chain.session.catalog, cache=chain._settings.cache
+            )
 
-        db_signals = chain.signals_schema.db_signals()
-        with super().select(*db_signals).as_iterable() as rows_iter:
-            for row in rows_iter:
-                yield chain.signals_schema.row_to_features(row, chain.session.catalog)
-
-    def iterate_one(self, col: str) -> Iterator[FeatureType]:
+    def iterate_one(self, col: str) -> Iterator[DataType]:
        for item in self.iterate(col):
            yield item[0]
 
-    def collect(self, *cols: str) -> list[list[FeatureType]]:
+    def collect(self, *cols: str) -> list[list[DataType]]:
        return list(self.iterate(*cols))
 
-    def collect_one(self, col: str) -> list[FeatureType]:
+    def collect_one(self, col: str) -> list[DataType]:
        return list(self.iterate_one(col))
 
    def to_pytorch(self, **kwargs):
        """Convert to pytorch dataset format."""
 
-        try:
-            import torch  # noqa: F401
-        except ImportError as exc:
-            raise ImportError(
-                "Missing required dependency 'torch' for Dataset.to_pytorch()"
-            ) from exc
-        from datachain.lib.pytorch import PytorchDataset
+        from datachain.torch import PytorchDataset
 
        if self.attached:
            chain = self
@@ -701,25 +686,32 @@ class DataChain(DatasetQuery):
        return ds
 
    @classmethod
-    def from_features(
+    def from_values(
        cls,
        ds_name: str = "",
        session: Optional[Session] = None,
-        output: Union[None, FeatureType, Sequence[str], dict[str, FeatureType]] = None,
+        output: OutputType = None,
+        object_name: str = "",
        **fr_map,
    ) -> "DataChain":
-        """Generate chain from list of features."""
-        tuple_type, output, tuples = features_to_tuples(ds_name, output, **fr_map)
+        """Generate chain from list of values."""
+        tuple_type, output, tuples = values_to_tuples(ds_name, output, **fr_map)
 
        def _func_fr() -> Iterator[tuple_type]:  # type: ignore[valid-type]
            yield from tuples
 
        chain = DataChain.create_empty(DataChain.DEFAULT_FILE_RECORD, session=session)
+        if object_name:
+            output = {object_name: DataChain._dict_to_data_model(object_name, output)}  # type: ignore[arg-type]
        return chain.gen(_func_fr, output=output)
 
    @classmethod
    def from_pandas(  # type: ignore[override]
-        cls, df: "pd.DataFrame", name: str = "", session: Optional[Session] = None
+        cls,
+        df: "pd.DataFrame",
+        name: str = "",
+        session: Optional[Session] = None,
+        object_name: str = "",
    ) -> "DataChain":
        """Generate chain from pandas data-frame."""
        fr_map = {col.lower(): df[col].tolist() for col in df.columns}
@@ -737,17 +729,23 @@ class DataChain(DatasetQuery):
                    f"import from pandas error - '{column}' cannot be a column name",
                )
 
-        return cls.from_features(name, session, **fr_map)
+        return cls.from_values(name, session, object_name=object_name, **fr_map)
 
    def parse_tabular(
        self,
-        output: Optional[dict[str, FeatureType]] = None,
+        output: OutputType = None,
+        object_name: str = "",
+        model_name: str = "",
        **kwargs,
    ) -> "DataChain":
        """Generate chain from list of tabular files.
 
        Parameters:
-            output : Dictionary defining column names and their corresponding types.
+            output : Dictionary or feature class defining column names and their
+                corresponding types. List of column names is also accepted, in which
+                case types will be inferred.
+            object_name : Generated object column name.
+            model_name : Generated model name.
            kwargs : Parameters to pass to pyarrow.dataset.dataset.
 
        Examples:
@@ -760,107 +758,134 @@ class DataChain(DatasetQuery):
            >>> dc = dc.filter(C("file.name").glob("*.jsonl"))
            >>> dc = dc.parse_tabular(format="json")
        """
-        from pyarrow import unify_schemas
-        from pyarrow.dataset import dataset
 
-        from datachain.lib.arrow import ArrowGenerator, schema_to_output
+        from datachain.lib.arrow import ArrowGenerator, infer_schema, schema_to_output
 
        schema = None
-        if output:
-            output = {"source": IndexedFile} | output
-        else:
-            schemas = []
-            for row in self.select("file").iterate():
-                file = row[0]
-                ds = dataset(file.get_path(), filesystem=file.get_fs(), **kwargs)  # type: ignore[union-attr]
-                schemas.append(ds.schema)
-            if not schemas:
-                msg = "error parsing tabular data schema - found no files to parse"
-                raise DatasetPrepareError(self.name, msg)
-            schema = unify_schemas(schemas)
+        col_names = output if isinstance(output, Sequence) else None
+        if col_names or not output:
            try:
-                output = schema_to_output(schema)
+                schema = infer_schema(self, **kwargs)
+                output = schema_to_output(schema, col_names)
            except ValueError as e:
                raise DatasetPrepareError(self.name, e) from e
 
+        if object_name:
+            if isinstance(output, dict):
+                model_name = model_name or object_name
+                output = DataChain._dict_to_data_model(model_name, output)
+            output = {object_name: output}  # type: ignore[dict-item]
+        elif isinstance(output, type(BaseModel)):
+            output = {
+                name: info.annotation  # type: ignore[misc]
+                for name, info in output.model_fields.items()
+            }
+        output = {"source": IndexedFile} | output  # type: ignore[assignment,operator]
        return self.gen(ArrowGenerator(schema, **kwargs), output=output)
 
-    def parse_csv(
-        self,
+    @staticmethod
+    def _dict_to_data_model(
+        name: str, data_dict: dict[str, DataType]
+    ) -> type[BaseModel]:
+        fields = {name: (anno, ...) for name, anno in data_dict.items()}
+        return create_model(
+            name,
+            __base__=(DataModel,),  # type: ignore[call-overload]
+            **fields,
+        )  # type: ignore[call-overload]
+
+    @classmethod
+    def from_csv(
+        cls,
+        path,
        delimiter: str = ",",
        header: bool = True,
        column_names: Optional[list[str]] = None,
-        output: Optional[dict[str, FeatureType]] = None,
+        output: OutputType = None,
+        object_name: str = "",
+        model_name: str = "",
+        **kwargs,
    ) -> "DataChain":
-        """Generate chain from list of csv files.
+        """Generate chain from csv files.
 
        Parameters:
+            path : Storage URI with directory. URI must start with storage prefix such
+                as `s3://`, `gs://`, `az://` or "file:///".
            delimiter : Character for delimiting columns.
            header : Whether the files include a header row.
-            column_names : Column names if no header. Implies `header = False`.
-            output : Dictionary defining column names and their corresponding types.
+            output : Dictionary or feature class defining column names and their
+                corresponding types. List of column names is also accepted, in which
+                case types will be inferred.
+            object_name : Created object column name.
+            model_name : Generated model name.
 
        Examples:
            Reading a csv file:
-            >>> dc = DataChain.from_storage("s3://mybucket/file.csv")
-            >>> dc = dc.parse_tabular(format="csv")
+            >>> dc = DataChain.from_csv("s3://mybucket/file.csv")
 
-            Reading a filtered list of csv files as a dataset:
-            >>> dc = DataChain.from_storage("s3://mybucket")
-            >>> dc = dc.filter(C("file.name").glob("*.csv"))
-            >>> dc = dc.parse_tabular()
+            Reading csv files from a directory as a combined dataset:
+            >>> dc = DataChain.from_csv("s3://mybucket/dir")
        """
        from pyarrow.csv import ParseOptions, ReadOptions
        from pyarrow.dataset import CsvFileFormat
 
-        if column_names and output:
-            msg = "error parsing csv - only one of column_names or output is allowed"
-            raise DatasetPrepareError(self.name, msg)
+        chain = DataChain.from_storage(path, **kwargs)
 
-        if not header and not column_names:
-            if output:
+        if not header:
+            if not output:
+                msg = "error parsing csv - provide output if no header"
+                raise DatasetPrepareError(chain.name, msg)
+            if isinstance(output, Sequence):
+                column_names = output  # type: ignore[assignment]
+            elif isinstance(output, dict):
                column_names = list(output.keys())
+            elif (fr := ModelStore.to_pydantic(output)) is not None:
+                column_names = list(fr.model_fields.keys())
            else:
-                msg = "error parsing csv - provide column_names or output if no header"
-                raise DatasetPrepareError(self.name, msg)
+                msg = f"error parsing csv - incompatible output type {type(output)}"
+                raise DatasetPrepareError(chain.name, msg)
 
        parse_options = ParseOptions(delimiter=delimiter)
        read_options = ReadOptions(column_names=column_names)
        format = CsvFileFormat(parse_options=parse_options, read_options=read_options)
-        return self.parse_tabular(output=output, format=format)
+        return chain.parse_tabular(
+            output=output, object_name=object_name, model_name=model_name, format=format
+        )
 
-    def parse_parquet(
-        self,
+    @classmethod
+    def from_parquet(
+        cls,
+        path,
        partitioning: Any = "hive",
-        output: Optional[dict[str, FeatureType]] = None,
+        output: Optional[dict[str, DataType]] = None,
+        object_name: str = "",
+        model_name: str = "",
+        **kwargs,
    ) -> "DataChain":
-        """Generate chain from list of parquet files.
+        """Generate chain from parquet files.
 
        Parameters:
+            path : Storage URI with directory. URI must start with storage prefix such
+                as `s3://`, `gs://`, `az://` or "file:///".
            partitioning : Any pyarrow partitioning schema.
            output : Dictionary defining column names and their corresponding types.
+            object_name : Created object column name.
+            model_name : Generated model name.
 
        Examples:
            Reading a single file:
-            >>> dc = DataChain.from_storage("s3://mybucket/file.parquet")
-            >>> dc = dc.parse_tabular()
+            >>> dc = DataChain.from_parquet("s3://mybucket/file.parquet")
 
            Reading a partitioned dataset from a directory:
-            >>> dc = DataChain.from_storage("path/to/dir")
-            >>> dc = dc.parse_tabular()
-
-            Reading a filtered list of files as a dataset:
-            >>> dc = DataChain.from_storage("s3://mybucket")
-            >>> dc = dc.filter(C("file.name").glob("*.parquet"))
-            >>> dc = dc.parse_tabular()
-
-            Reading a filtered list of partitions as a dataset:
-            >>> dc = DataChain.from_storage("s3://mybucket")
-            >>> dc = dc.filter(C("file.parent").glob("*month=1*"))
-            >>> dc = dc.parse_tabular()
+            >>> dc = DataChain.from_parquet("s3://mybucket/dir")
        """
-        return self.parse_tabular(
-            output=output, format="parquet", partitioning=partitioning
+        chain = DataChain.from_storage(path, **kwargs)
+        return chain.parse_tabular(
+            output=output,
+            object_name=object_name,
+            model_name=model_name,
+            format="parquet",
+            partitioning=partitioning,
        )
 
    @classmethod
@@ -903,17 +928,17 @@ class DataChain(DatasetQuery):
        db.execute(insert_q.values(**record))
        return DataChain(name=dsr.name)
 
-    def sum(self, fr: FeatureType):  # type: ignore[override]
-        return self._extend_features("sum", fr)
+    def sum(self, fr: DataType):  # type: ignore[override]
+        return self._extend_to_data_model("sum", fr)
 
-    def avg(self, fr: FeatureType):  # type: ignore[override]
-        return self._extend_features("avg", fr)
+    def avg(self, fr: DataType):  # type: ignore[override]
+        return self._extend_to_data_model("avg", fr)
 
-    def min(self, fr: FeatureType):  # type: ignore[override]
-        return self._extend_features("min", fr)
+    def min(self, fr: DataType):  # type: ignore[override]
+        return self._extend_to_data_model("min", fr)
 
-    def max(self, fr: FeatureType):  # type: ignore[override]
-        return self._extend_features("max", fr)
+    def max(self, fr: DataType):  # type: ignore[override]
+        return self._extend_to_data_model("max", fr)
 
    def setup(self, **kwargs) -> "Self":
        intersection = set(self._setup.keys()) & set(kwargs.keys())
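
A rough sketch of how the reworked API reads after this change, using only the signatures visible in the diff above; the bucket paths and column values are hypothetical, and other files in the release (e.g. datachain/lib/arrow.py, datachain/lib/udf.py) are not reflected here:

```py
from datachain.lib.dc import DataChain

# parse_csv()/parse_parquet() instance methods become from_csv()/from_parquet() classmethods.
csv_chain = DataChain.from_csv("s3://mybucket/dir")                  # was from_storage(...).parse_csv(...)
pq_chain = DataChain.from_parquet("s3://mybucket/file.parquet")      # was from_storage(...).parse_tabular()

# from_features() is renamed to from_values(); settings() now clones the chain
# and accepts include_sys to expose the new Sys signal (id, rand).
values = (
    DataChain.from_values(num=[1, 2, 3])
    .settings(parallel=4, include_sys=True)
)
```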