datachain 0.2.9__py3-none-any.whl → 0.2.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (51)
  1. datachain/__init__.py +17 -8
  2. datachain/catalog/catalog.py +5 -5
  3. datachain/cli.py +0 -2
  4. datachain/data_storage/schema.py +5 -5
  5. datachain/data_storage/sqlite.py +1 -1
  6. datachain/data_storage/warehouse.py +7 -7
  7. datachain/lib/arrow.py +25 -8
  8. datachain/lib/clip.py +6 -11
  9. datachain/lib/convert/__init__.py +0 -0
  10. datachain/lib/convert/flatten.py +67 -0
  11. datachain/lib/convert/type_converter.py +96 -0
  12. datachain/lib/convert/unflatten.py +69 -0
  13. datachain/lib/convert/values_to_tuples.py +85 -0
  14. datachain/lib/data_model.py +74 -0
  15. datachain/lib/dc.py +225 -168
  16. datachain/lib/file.py +41 -41
  17. datachain/lib/gpt4_vision.py +1 -9
  18. datachain/lib/hf_image_to_text.py +9 -17
  19. datachain/lib/hf_pipeline.py +4 -12
  20. datachain/lib/image.py +2 -18
  21. datachain/lib/image_transform.py +0 -1
  22. datachain/lib/iptc_exif_xmp.py +8 -15
  23. datachain/lib/meta_formats.py +1 -5
  24. datachain/lib/model_store.py +77 -0
  25. datachain/lib/pytorch.py +9 -21
  26. datachain/lib/signal_schema.py +139 -60
  27. datachain/lib/text.py +5 -16
  28. datachain/lib/udf.py +114 -30
  29. datachain/lib/udf_signature.py +5 -5
  30. datachain/lib/webdataset.py +3 -3
  31. datachain/lib/webdataset_laion.py +2 -3
  32. datachain/node.py +4 -4
  33. datachain/query/batch.py +1 -1
  34. datachain/query/dataset.py +51 -178
  35. datachain/query/dispatch.py +43 -30
  36. datachain/query/udf.py +46 -26
  37. datachain/remote/studio.py +1 -9
  38. datachain/torch/__init__.py +21 -0
  39. datachain/utils.py +39 -0
  40. {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/METADATA +14 -12
  41. {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/RECORD +45 -43
  42. {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/WHEEL +1 -1
  43. datachain/image/__init__.py +0 -3
  44. datachain/lib/cached_stream.py +0 -38
  45. datachain/lib/claude.py +0 -69
  46. datachain/lib/feature.py +0 -412
  47. datachain/lib/feature_registry.py +0 -51
  48. datachain/lib/feature_utils.py +0 -154
  49. {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/LICENSE +0 -0
  50. {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/entry_points.txt +0 -0
  51. {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/top_level.txt +0 -0
datachain/lib/dc.py CHANGED
@@ -1,3 +1,4 @@
+import copy
 import re
 from collections.abc import Iterator, Sequence
 from typing import (
@@ -10,12 +11,16 @@ from typing import (
     Union,
 )
 
+import pandas as pd
 import sqlalchemy
+from pydantic import BaseModel, create_model
 
-from datachain.lib.feature import Feature, FeatureType
-from datachain.lib.feature_utils import features_to_tuples
+from datachain import DataModel
+from datachain.lib.convert.values_to_tuples import values_to_tuples
+from datachain.lib.data_model import DataType
 from datachain.lib.file import File, IndexedFile, get_file
 from datachain.lib.meta_formats import read_meta, read_schema
+from datachain.lib.model_store import ModelStore
 from datachain.lib.settings import Settings
 from datachain.lib.signal_schema import SignalSchema
 from datachain.lib.udf import (
@@ -34,13 +39,11 @@ from datachain.query.dataset import (
     detach,
 )
 from datachain.query.schema import Column, DatasetRow
+from datachain.utils import inside_notebook
 
 if TYPE_CHECKING:
-    import pandas as pd
     from typing_extensions import Self
 
-    from datachain.catalog import Catalog
-
 C = Column
 
 
@@ -51,10 +54,10 @@ class DatasetPrepareError(DataChainParamsError):
         super().__init__(f"Dataset{name}{output} processing prepare error: {msg}")
 
 
-class DatasetFromFeatureError(DataChainParamsError):
+class DatasetFromValuesError(DataChainParamsError):
    def __init__(self, name, msg):
        name = f" '{name}'" if name else ""
-        super().__init__(f"Dataset {name} from feature error: {msg}")
+        super().__init__(f"Dataset {name} from values error: {msg}")
 
 
 class DatasetMergeError(DataChainParamsError):
@@ -68,6 +71,14 @@ class DatasetMergeError(DataChainParamsError):
         super().__init__(f"Merge error on='{on_str}'{right_on_str}: {msg}")
 
 
+OutputType = Union[None, DataType, Sequence[str], dict[str, DataType]]
+
+
+class Sys(DataModel):
+    id: int
+    rand: int
+
+
 class DataChain(DatasetQuery):
     """AI 🔗 DataChain - a data structure for batch data processing and evaluation.
 
@@ -120,12 +131,10 @@ class DataChain(DatasetQuery):
     """
 
     DEFAULT_FILE_RECORD: ClassVar[dict] = {
-        "id": 0,
         "source": "",
         "name": "",
         "vtype": "",
         "size": 0,
-        "random": 0,
     }
 
     def __init__(self, *args, **kwargs):
@@ -151,11 +160,19 @@ class DataChain(DatasetQuery):
     def print_schema(self):
         self.signals_schema.print_tree()
 
-    def create_model(self, name: str) -> type[Feature]:
-        return self.signals_schema.create_model(name)
+    def clone(self, new_table: bool = True) -> "Self":
+        obj = super().clone(new_table=new_table)
+        obj.signals_schema = copy.deepcopy(self.signals_schema)
+        return obj
 
     def settings(
-        self, cache=None, batch=None, parallel=None, workers=None, min_task_size=None
+        self,
+        cache=None,
+        batch=None,
+        parallel=None,
+        workers=None,
+        min_task_size=None,
+        include_sys: Optional[bool] = None,
     ) -> "Self":
         """Change settings for chain.
 
@@ -179,8 +196,13 @@ class DataChain(DatasetQuery):
             )
             ```
         """
-        self._settings.add(Settings(cache, batch, parallel, workers, min_task_size))
-        return self
+        chain = self.clone()
+        if include_sys is True:
+            chain.signals_schema = SignalSchema({"sys": Sys}) | chain.signals_schema
+        elif include_sys is False and "sys" in chain.signals_schema:
+            chain.signals_schema.remove("sys")
+        chain._settings.add(Settings(cache, batch, parallel, workers, min_task_size))
+        return chain
 
     def reset_settings(self, settings: Optional[Settings] = None) -> "Self":
         """Reset all settings to default values."""
@@ -192,12 +214,11 @@ class DataChain(DatasetQuery):
         return self
 
     def add_schema(self, signals_schema: SignalSchema) -> "Self":
-        union = self.signals_schema.values | signals_schema.values
-        self.signals_schema = SignalSchema(union)
+        self.signals_schema |= signals_schema
         return self
 
     def get_file_signals(self) -> list[str]:
-        return self.signals_schema.get_file_signals()
+        return list(self.signals_schema.get_file_signals())
 
     @classmethod
     def from_storage(
@@ -205,9 +226,10 @@ class DataChain(DatasetQuery):
         path,
         *,
         type: Literal["binary", "text", "image"] = "binary",
-        catalog: Optional["Catalog"] = None,
+        session: Optional[Session] = None,
         recursive: Optional[bool] = True,
-        anon: bool = False,
+        object_name: str = "file",
+        **kwargs,
     ) -> "Self":
         """Get data from a storage as a list of file with all file attributes. It
         returns the chain itself as usual.
@@ -217,7 +239,7 @@ class DataChain(DatasetQuery):
                 as `s3://`, `gs://`, `az://` or "file:///"
             type : read file as "binary", "text", or "image" data. Default is "binary".
             recursive : search recursively for the given path.
-            anon : use anonymous mode to access the storage.
+            object_name : Created object column name.
 
         Example:
             ```py
@@ -225,7 +247,9 @@ class DataChain(DatasetQuery):
             ```
         """
         func = get_file(type)
-        return cls(path, catalog=catalog, recursive=recursive, anon=anon).map(file=func)
+        return cls(path, session=session, recursive=recursive, **kwargs).map(
+            **{object_name: func}
+        )
 
     @classmethod
     def from_dataset(cls, name: str, version: Optional[int] = None) -> "DataChain":
@@ -240,66 +264,19 @@ class DataChain(DatasetQuery):
         """
         return DataChain(name=name, version=version)
 
-    @classmethod
-    def from_csv(
-        cls,
-        path,
-        type: Literal["binary", "text", "image"] = "text",
-        anon: bool = False,
-        spec: Optional[FeatureType] = None,
-        schema_from: Optional[str] = "auto",
-        object_name: Optional[str] = "csv",
-        model_name: Optional[str] = None,
-        show_schema: Optional[bool] = False,
-    ) -> "DataChain":
-        """Get data from CSV. It returns the chain itself.
-
-        Parameters:
-            path : storage URI with directory. URI must start with storage prefix such
-                as `s3://`, `gs://`, `az://` or "file:///"
-            type : read file as "binary", "text", or "image" data. Default is "text".
-            anon : use anonymous mode to access the storage.
-            spec : Data Model for CSV file
-            object_name : generated object column name
-            model_name : generated model name
-            schema_from : path to sample to infer spec from
-            show_schema : print auto-generated schema
-
-        Examples:
-            infer model from the first two lines (header + data)
-            >>> chain = DataChain.from_csv("gs://csv")
-
-            use a particular data model
-            >>> chain = DataChain.from_csv("gs://csv"i, spec=MyModel)
-        """
-        if schema_from == "auto":
-            schema_from = path
-
-        chain = DataChain.from_storage(path=path, type=type, anon=anon)
-        signal_dict = {
-            object_name: read_meta(
-                schema_from=schema_from,
-                meta_type="csv",
-                spec=spec,
-                model_name=model_name,
-                show_schema=show_schema,
-            )
-        }
-        return chain.gen(**signal_dict)  # type: ignore[misc, arg-type]
-
     @classmethod
     def from_json(
         cls,
         path,
         type: Literal["binary", "text", "image"] = "text",
-        anon: bool = False,
-        spec: Optional[FeatureType] = None,
+        spec: Optional[DataType] = None,
         schema_from: Optional[str] = "auto",
         jmespath: Optional[str] = None,
-        object_name: Optional[str] = None,
+        object_name: str = "",
         model_name: Optional[str] = None,
         show_schema: Optional[bool] = False,
         meta_type: Optional[str] = "json",
+        **kwargs,
     ) -> "DataChain":
         """Get data from JSON. It returns the chain itself.
 
@@ -307,7 +284,6 @@ class DataChain(DatasetQuery):
             path : storage URI with directory. URI must start with storage prefix such
                 as `s3://`, `gs://`, `az://` or "file:///"
             type : read file as "binary", "text", or "image" data. Default is "binary".
-            anon : use anonymous mode to access the storage.
             spec : optional Data Model
             schema_from : path to sample to infer spec from
             object_name : generated object column name
@@ -333,7 +309,7 @@ class DataChain(DatasetQuery):
            object_name = jmespath_to_name(jmespath)
         if not object_name:
             object_name = "json"
-        chain = DataChain.from_storage(path=path, type=type, anon=anon)
+        chain = DataChain.from_storage(path=path, type=type, **kwargs)
         signal_dict = {
             object_name: read_meta(
                 schema_from=schema_from,
@@ -396,6 +372,7 @@ class DataChain(DatasetQuery):
             version : version of a dataset. Default - the last version that exist.
         """
         schema = self.signals_schema.serialize()
+        schema.pop("sys", None)
         return super().save(name=name, version=version, feature_schema=schema)
 
     def apply(self, func, *args, **kwargs):
@@ -405,7 +382,7 @@ class DataChain(DatasetQuery):
         self,
         func: Optional[Callable] = None,
         params: Union[None, str, Sequence[str]] = None,
-        output: Union[None, FeatureType, Sequence[str], dict[str, FeatureType]] = None,
+        output: OutputType = None,
         **signal_map,
     ) -> "Self":
         """Apply a function to each row to create new signals. The function should
@@ -449,7 +426,7 @@ class DataChain(DatasetQuery):
         self,
         func: Optional[Callable] = None,
         params: Union[None, str, Sequence[str]] = None,
-        output: Union[None, FeatureType, Sequence[str], dict[str, FeatureType]] = None,
+        output: OutputType = None,
         **signal_map,
     ) -> "Self":
         """Apply a function to each row to create new rows (with potentially new
@@ -478,7 +455,7 @@ class DataChain(DatasetQuery):
         func: Optional[Callable] = None,
         partition_by: Optional[PartitionByType] = None,
         params: Union[None, str, Sequence[str]] = None,
-        output: Union[None, FeatureType, Sequence[str], dict[str, FeatureType]] = None,
+        output: OutputType = None,
         **signal_map,
     ) -> "Self":
         """Aggregate rows using `partition_by` statement and apply a function to the
@@ -508,7 +485,7 @@ class DataChain(DatasetQuery):
         self,
         func: Optional[Callable] = None,
         params: Union[None, str, Sequence[str]] = None,
-        output: Union[None, FeatureType, Sequence[str], dict[str, FeatureType]] = None,
+        output: OutputType = None,
         **signal_map,
     ) -> "Self":
         """This is a batch version of map().
@@ -530,18 +507,20 @@ class DataChain(DatasetQuery):
         target_class: type[UDFBase],
         func: Optional[Callable],
         params: Union[None, str, Sequence[str]],
-        output: Union[None, FeatureType, Sequence[str], dict[str, FeatureType]],
+        output: OutputType,
         signal_map,
     ) -> UDFBase:
         is_generator = target_class.is_output_batched
         name = self.name or ""
 
         sign = UdfSignature.parse(name, signal_map, func, params, output, is_generator)
+        DataModel.register(list(sign.output_schema.values.values()))
+
         params_schema = self.signals_schema.slice(sign.params, self._setup)
 
-        return UDFBase._create(target_class, sign, params_schema)
+        return target_class._create(sign, params_schema)
 
-    def _extend_features(self, method_name, *args, **kwargs):
+    def _extend_to_data_model(self, method_name, *args, **kwargs):
         super_func = getattr(super(), method_name)
 
         new_schema = self.signals_schema.resolve(*args)
@@ -570,39 +549,46 @@ class DataChain(DatasetQuery):
         chain.signals_schema = new_schema
         return chain
 
-    def iterate(self, *cols: str) -> Iterator[list[FeatureType]]:
+    def iterate_flatten(self) -> Iterator[tuple[Any]]:
+        db_signals = self.signals_schema.db_signals()
+        with super().select(*db_signals).as_iterable() as rows:
+            yield from rows
+
+    def results(
+        self, row_factory: Optional[Callable] = None, **kwargs
+    ) -> list[tuple[Any, ...]]:
+        rows = self.iterate_flatten()
+        if row_factory:
+            db_signals = self.signals_schema.db_signals()
+            rows = (row_factory(db_signals, r) for r in rows)
+        return list(rows)
+
+    def iterate(self, *cols: str) -> Iterator[list[DataType]]:
         """Iterate over rows.
 
         If columns are specified - limit them to specified
         columns.
         """
         chain = self.select(*cols) if cols else self
+        for row in chain.iterate_flatten():
+            yield chain.signals_schema.row_to_features(
+                row, catalog=chain.session.catalog, cache=chain._settings.cache
+            )
 
-        db_signals = chain.signals_schema.db_signals()
-        with super().select(*db_signals).as_iterable() as rows_iter:
-            for row in rows_iter:
-                yield chain.signals_schema.row_to_features(row, chain.session.catalog)
-
-    def iterate_one(self, col: str) -> Iterator[FeatureType]:
+    def iterate_one(self, col: str) -> Iterator[DataType]:
         for item in self.iterate(col):
             yield item[0]
 
-    def collect(self, *cols: str) -> list[list[FeatureType]]:
+    def collect(self, *cols: str) -> list[list[DataType]]:
         return list(self.iterate(*cols))
 
-    def collect_one(self, col: str) -> list[FeatureType]:
+    def collect_one(self, col: str) -> list[DataType]:
         return list(self.iterate_one(col))
 
     def to_pytorch(self, **kwargs):
         """Convert to pytorch dataset format."""
 
-        try:
-            import torch  # noqa: F401
-        except ImportError as exc:
-            raise ImportError(
-                "Missing required dependency 'torch' for Dataset.to_pytorch()"
-            ) from exc
-        from datachain.lib.pytorch import PytorchDataset
+        from datachain.torch import PytorchDataset
 
         if self.attached:
             chain = self
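
Rough sketch of how the reworked iteration helpers above fit together (the dataset name is hypothetical):

```py
from datachain.lib.dc import DataChain

dc = DataChain.from_dataset("my-dataset")

# flat tuples of database signals, one per row
for row in dc.iterate_flatten():
    print(row)

# results() materializes those tuples, optionally through a row_factory
rows = dc.results()

# iterate()/collect() reconstruct feature objects from the signal schema
for (file,) in dc.iterate("file"):
    print(file.name)
files = dc.collect_one("file")
```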
@@ -701,25 +687,32 @@ class DataChain(DatasetQuery):
         return ds
 
     @classmethod
-    def from_features(
+    def from_values(
         cls,
         ds_name: str = "",
         session: Optional[Session] = None,
-        output: Union[None, FeatureType, Sequence[str], dict[str, FeatureType]] = None,
+        output: OutputType = None,
+        object_name: str = "",
         **fr_map,
     ) -> "DataChain":
-        """Generate chain from list of features."""
-        tuple_type, output, tuples = features_to_tuples(ds_name, output, **fr_map)
+        """Generate chain from list of values."""
+        tuple_type, output, tuples = values_to_tuples(ds_name, output, **fr_map)
 
         def _func_fr() -> Iterator[tuple_type]:  # type: ignore[valid-type]
             yield from tuples
 
         chain = DataChain.create_empty(DataChain.DEFAULT_FILE_RECORD, session=session)
+        if object_name:
+            output = {object_name: DataChain._dict_to_data_model(object_name, output)}  # type: ignore[arg-type]
         return chain.gen(_func_fr, output=output)
 
     @classmethod
     def from_pandas(  # type: ignore[override]
-        cls, df: "pd.DataFrame", name: str = "", session: Optional[Session] = None
+        cls,
+        df: "pd.DataFrame",
+        name: str = "",
+        session: Optional[Session] = None,
+        object_name: str = "",
     ) -> "DataChain":
         """Generate chain from pandas data-frame."""
         fr_map = {col.lower(): df[col].tolist() for col in df.columns}
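
For context, a small sketch of the renamed constructor (`from_features` → `from_values`) and the new `object_name` argument on `from_pandas`; the column names and values are made up for illustration:

```py
import pandas as pd

from datachain.lib.dc import DataChain

# each keyword becomes a column; types are derived from the values
dc = DataChain.from_values(
    fname=["img1.jpg", "img2.jpg"],
    score=[0.91, 0.47],
)

# object_name wraps the columns into one generated model, e.g. a "pred" signal
df = pd.DataFrame({"fname": ["img1.jpg"], "score": [0.91]})
dc2 = DataChain.from_pandas(df, object_name="pred")
```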
@@ -737,17 +730,54 @@ class DataChain(DatasetQuery):
                     f"import from pandas error - '{column}' cannot be a column name",
                 )
 
-        return cls.from_features(name, session, **fr_map)
+        return cls.from_values(name, session, object_name=object_name, **fr_map)
+
+    def to_pandas(self, flatten=False) -> "pd.DataFrame":
+        headers, max_length = self.signals_schema.get_headers_with_length()
+        if flatten or max_length < 2:
+            df = pd.DataFrame.from_records(self.to_records())
+            if headers:
+                df.columns = [".".join(filter(None, header)) for header in headers]
+            return df
+
+        transposed_result = list(map(list, zip(*self.results())))
+        data = {tuple(n): val for n, val in zip(headers, transposed_result)}
+        return pd.DataFrame(data)
+
+    def show(self, limit: int = 20, flatten=False, transpose=False) -> None:
+        dc = self.limit(limit) if limit > 0 else self
+        df = dc.to_pandas(flatten)
+        if transpose:
+            df = df.T
+
+        with pd.option_context(
+            "display.max_columns", None, "display.multi_sparse", False
+        ):
+            if inside_notebook():
+                from IPython.display import display
+
+                display(df)
+            else:
+                print(df)
+
+        if len(df) == limit:
+            print(f"\n[Limited by {len(df)} rows]")
 
     def parse_tabular(
         self,
-        output: Optional[dict[str, FeatureType]] = None,
+        output: OutputType = None,
+        object_name: str = "",
+        model_name: str = "",
         **kwargs,
     ) -> "DataChain":
         """Generate chain from list of tabular files.
 
         Parameters:
-            output : Dictionary defining column names and their corresponding types.
+            output : Dictionary or feature class defining column names and their
+                corresponding types. List of column names is also accepted, in which
+                case types will be inferred.
+            object_name : Generated object column name.
+            model_name : Generated model name.
             kwargs : Parameters to pass to pyarrow.dataset.dataset.
 
         Examples:
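
The `to_pandas()` and `show()` helpers added in the hunk above can be exercised roughly as follows (limits and flags are examples only):

```py
from datachain.lib.dc import DataChain

dc = DataChain.from_dataset("my-dataset")

df = dc.to_pandas()                 # nested signals become multi-level columns
flat = dc.to_pandas(flatten=True)   # dotted "file.name"-style column names

dc.show(limit=5, transpose=True)    # renders via IPython.display inside notebooks
```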
@@ -760,107 +790,134 @@ class DataChain(DatasetQuery):
             >>> dc = dc.filter(C("file.name").glob("*.jsonl"))
             >>> dc = dc.parse_tabular(format="json")
         """
-        from pyarrow import unify_schemas
-        from pyarrow.dataset import dataset
 
-        from datachain.lib.arrow import ArrowGenerator, schema_to_output
+        from datachain.lib.arrow import ArrowGenerator, infer_schema, schema_to_output
 
         schema = None
-        if output:
-            output = {"source": IndexedFile} | output
-        else:
-            schemas = []
-            for row in self.select("file").iterate():
-                file = row[0]
-                ds = dataset(file.get_path(), filesystem=file.get_fs(), **kwargs)  # type: ignore[union-attr]
-                schemas.append(ds.schema)
-            if not schemas:
-                msg = "error parsing tabular data schema - found no files to parse"
-                raise DatasetPrepareError(self.name, msg)
-            schema = unify_schemas(schemas)
+        col_names = output if isinstance(output, Sequence) else None
+        if col_names or not output:
             try:
-                output = schema_to_output(schema)
+                schema = infer_schema(self, **kwargs)
+                output = schema_to_output(schema, col_names)
             except ValueError as e:
                 raise DatasetPrepareError(self.name, e) from e
 
+        if object_name:
+            if isinstance(output, dict):
+                model_name = model_name or object_name
+                output = DataChain._dict_to_data_model(model_name, output)
+            output = {object_name: output}  # type: ignore[dict-item]
+        elif isinstance(output, type(BaseModel)):
+            output = {
+                name: info.annotation  # type: ignore[misc]
+                for name, info in output.model_fields.items()
+            }
+        output = {"source": IndexedFile} | output  # type: ignore[assignment,operator]
         return self.gen(ArrowGenerator(schema, **kwargs), output=output)
 
-    def parse_csv(
-        self,
+    @staticmethod
+    def _dict_to_data_model(
+        name: str, data_dict: dict[str, DataType]
+    ) -> type[BaseModel]:
+        fields = {name: (anno, ...) for name, anno in data_dict.items()}
+        return create_model(
+            name,
+            __base__=(DataModel,),  # type: ignore[call-overload]
+            **fields,
+        )  # type: ignore[call-overload]
+
+    @classmethod
+    def from_csv(
+        cls,
+        path,
         delimiter: str = ",",
         header: bool = True,
         column_names: Optional[list[str]] = None,
-        output: Optional[dict[str, FeatureType]] = None,
+        output: OutputType = None,
+        object_name: str = "",
+        model_name: str = "",
+        **kwargs,
     ) -> "DataChain":
-        """Generate chain from list of csv files.
+        """Generate chain from csv files.
 
         Parameters:
+            path : Storage URI with directory. URI must start with storage prefix such
+                as `s3://`, `gs://`, `az://` or "file:///".
             delimiter : Character for delimiting columns.
             header : Whether the files include a header row.
-            column_names : Column names if no header. Implies `header = False`.
-            output : Dictionary defining column names and their corresponding types.
+            output : Dictionary or feature class defining column names and their
+                corresponding types. List of column names is also accepted, in which
+                case types will be inferred.
+            object_name : Created object column name.
+            model_name : Generated model name.
 
         Examples:
             Reading a csv file:
-            >>> dc = DataChain.from_storage("s3://mybucket/file.csv")
-            >>> dc = dc.parse_tabular(format="csv")
+            >>> dc = DataChain.from_csv("s3://mybucket/file.csv")
 
-            Reading a filtered list of csv files as a dataset:
-            >>> dc = DataChain.from_storage("s3://mybucket")
-            >>> dc = dc.filter(C("file.name").glob("*.csv"))
-            >>> dc = dc.parse_tabular()
+            Reading csv files from a directory as a combined dataset:
+            >>> dc = DataChain.from_csv("s3://mybucket/dir")
         """
         from pyarrow.csv import ParseOptions, ReadOptions
         from pyarrow.dataset import CsvFileFormat
 
-        if column_names and output:
-            msg = "error parsing csv - only one of column_names or output is allowed"
-            raise DatasetPrepareError(self.name, msg)
+        chain = DataChain.from_storage(path, **kwargs)
 
-        if not header and not column_names:
-            if output:
+        if not header:
+            if not output:
+                msg = "error parsing csv - provide output if no header"
+                raise DatasetPrepareError(chain.name, msg)
+            if isinstance(output, Sequence):
+                column_names = output  # type: ignore[assignment]
+            elif isinstance(output, dict):
                 column_names = list(output.keys())
+            elif (fr := ModelStore.to_pydantic(output)) is not None:
+                column_names = list(fr.model_fields.keys())
             else:
-                msg = "error parsing csv - provide column_names or output if no header"
-                raise DatasetPrepareError(self.name, msg)
+                msg = f"error parsing csv - incompatible output type {type(output)}"
+                raise DatasetPrepareError(chain.name, msg)
 
         parse_options = ParseOptions(delimiter=delimiter)
         read_options = ReadOptions(column_names=column_names)
         format = CsvFileFormat(parse_options=parse_options, read_options=read_options)
-        return self.parse_tabular(output=output, format=format)
+        return chain.parse_tabular(
+            output=output, object_name=object_name, model_name=model_name, format=format
+        )
 
-    def parse_parquet(
-        self,
+    @classmethod
+    def from_parquet(
+        cls,
+        path,
         partitioning: Any = "hive",
-        output: Optional[dict[str, FeatureType]] = None,
+        output: Optional[dict[str, DataType]] = None,
+        object_name: str = "",
+        model_name: str = "",
+        **kwargs,
     ) -> "DataChain":
-        """Generate chain from list of parquet files.
+        """Generate chain from parquet files.
 
         Parameters:
+            path : Storage URI with directory. URI must start with storage prefix such
+                as `s3://`, `gs://`, `az://` or "file:///".
             partitioning : Any pyarrow partitioning schema.
             output : Dictionary defining column names and their corresponding types.
+            object_name : Created object column name.
+            model_name : Generated model name.
 
         Examples:
             Reading a single file:
-            >>> dc = DataChain.from_storage("s3://mybucket/file.parquet")
-            >>> dc = dc.parse_tabular()
+            >>> dc = DataChain.from_parquet("s3://mybucket/file.parquet")
 
             Reading a partitioned dataset from a directory:
-            >>> dc = DataChain.from_storage("path/to/dir")
-            >>> dc = dc.parse_tabular()
-
-            Reading a filtered list of files as a dataset:
-            >>> dc = DataChain.from_storage("s3://mybucket")
-            >>> dc = dc.filter(C("file.name").glob("*.parquet"))
-            >>> dc = dc.parse_tabular()
-
-            Reading a filtered list of partitions as a dataset:
-            >>> dc = DataChain.from_storage("s3://mybucket")
-            >>> dc = dc.filter(C("file.parent").glob("*month=1*"))
-            >>> dc = dc.parse_tabular()
+            >>> dc = DataChain.from_parquet("s3://mybucket/dir")
         """
-        return self.parse_tabular(
-            output=output, format="parquet", partitioning=partitioning
+        chain = DataChain.from_storage(path, **kwargs)
+        return chain.parse_tabular(
+            output=output,
+            object_name=object_name,
+            model_name=model_name,
+            format="parquet",
+            partitioning=partitioning,
         )
 
     @classmethod
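
A brief sketch of the new class-level readers (`from_csv`, `from_parquet`) introduced above; bucket paths and the output spec are illustrative:

```py
from datachain.lib.dc import DataChain

# CSV with a header row; column types are inferred from the files
csv_chain = DataChain.from_csv("s3://mybucket/file.csv")

# headerless CSV: pass output so column names (and types) are known up front
raw_chain = DataChain.from_csv(
    "s3://mybucket/raw.csv",
    header=False,
    output={"name": str, "size": int},
)

# partitioned parquet directory, wrapped into a generated "pq" object column
pq_chain = DataChain.from_parquet("s3://mybucket/dir", object_name="pq")
```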
@@ -903,17 +960,17 @@ class DataChain(DatasetQuery):
            db.execute(insert_q.values(**record))
         return DataChain(name=dsr.name)
 
-    def sum(self, fr: FeatureType):  # type: ignore[override]
-        return self._extend_features("sum", fr)
+    def sum(self, fr: DataType):  # type: ignore[override]
+        return self._extend_to_data_model("sum", fr)
 
-    def avg(self, fr: FeatureType):  # type: ignore[override]
-        return self._extend_features("avg", fr)
+    def avg(self, fr: DataType):  # type: ignore[override]
+        return self._extend_to_data_model("avg", fr)
 
-    def min(self, fr: FeatureType):  # type: ignore[override]
-        return self._extend_features("min", fr)
+    def min(self, fr: DataType):  # type: ignore[override]
+        return self._extend_to_data_model("min", fr)
 
-    def max(self, fr: FeatureType):  # type: ignore[override]
-        return self._extend_features("max", fr)
+    def max(self, fr: DataType):  # type: ignore[override]
+        return self._extend_to_data_model("max", fr)
 
     def setup(self, **kwargs) -> "Self":
         intersection = set(self._setup.keys()) & set(kwargs.keys())