datachain 0.2.9__py3-none-any.whl → 0.2.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.
Files changed (51)
  1. datachain/__init__.py +17 -8
  2. datachain/catalog/catalog.py +5 -5
  3. datachain/cli.py +0 -2
  4. datachain/data_storage/schema.py +5 -5
  5. datachain/data_storage/sqlite.py +1 -1
  6. datachain/data_storage/warehouse.py +7 -7
  7. datachain/lib/arrow.py +25 -8
  8. datachain/lib/clip.py +6 -11
  9. datachain/lib/convert/__init__.py +0 -0
  10. datachain/lib/convert/flatten.py +67 -0
  11. datachain/lib/convert/type_converter.py +96 -0
  12. datachain/lib/convert/unflatten.py +69 -0
  13. datachain/lib/convert/values_to_tuples.py +85 -0
  14. datachain/lib/data_model.py +74 -0
  15. datachain/lib/dc.py +225 -168
  16. datachain/lib/file.py +41 -41
  17. datachain/lib/gpt4_vision.py +1 -9
  18. datachain/lib/hf_image_to_text.py +9 -17
  19. datachain/lib/hf_pipeline.py +4 -12
  20. datachain/lib/image.py +2 -18
  21. datachain/lib/image_transform.py +0 -1
  22. datachain/lib/iptc_exif_xmp.py +8 -15
  23. datachain/lib/meta_formats.py +1 -5
  24. datachain/lib/model_store.py +77 -0
  25. datachain/lib/pytorch.py +9 -21
  26. datachain/lib/signal_schema.py +139 -60
  27. datachain/lib/text.py +5 -16
  28. datachain/lib/udf.py +114 -30
  29. datachain/lib/udf_signature.py +5 -5
  30. datachain/lib/webdataset.py +3 -3
  31. datachain/lib/webdataset_laion.py +2 -3
  32. datachain/node.py +4 -4
  33. datachain/query/batch.py +1 -1
  34. datachain/query/dataset.py +51 -178
  35. datachain/query/dispatch.py +43 -30
  36. datachain/query/udf.py +46 -26
  37. datachain/remote/studio.py +1 -9
  38. datachain/torch/__init__.py +21 -0
  39. datachain/utils.py +39 -0
  40. {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/METADATA +14 -12
  41. {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/RECORD +45 -43
  42. {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/WHEEL +1 -1
  43. datachain/image/__init__.py +0 -3
  44. datachain/lib/cached_stream.py +0 -38
  45. datachain/lib/claude.py +0 -69
  46. datachain/lib/feature.py +0 -412
  47. datachain/lib/feature_registry.py +0 -51
  48. datachain/lib/feature_utils.py +0 -154
  49. {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/LICENSE +0 -0
  50. {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/entry_points.txt +0 -0
  51. {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/top_level.txt +0 -0
datachain/__init__.py CHANGED
@@ -1,11 +1,16 @@
- from datachain.lib.dc import C, DataChain
- from datachain.lib.feature import Feature
- from datachain.lib.feature_utils import pydantic_to_feature
- from datachain.lib.file import File, FileError, FileFeature, IndexedFile, TarVFile
+ from datachain.lib.data_model import DataModel, DataType, FileBasic, is_chain_type
+ from datachain.lib.dc import C, Column, DataChain, Sys
+ from datachain.lib.file import (
+     File,
+     FileError,
+     ImageFile,
+     IndexedFile,
+     TarVFile,
+     TextFile,
+ )
  from datachain.lib.udf import Aggregator, Generator, Mapper
  from datachain.lib.utils import AbstractUDF, DataChainError
  from datachain.query.dataset import UDF as BaseUDF  # noqa: N811
- from datachain.query.schema import Column
  from datachain.query.session import Session

  __all__ = [
@@ -16,14 +21,18 @@ __all__ = [
      "Column",
      "DataChain",
      "DataChainError",
-     "Feature",
+     "DataModel",
+     "DataType",
      "File",
+     "FileBasic",
      "FileError",
-     "FileFeature",
      "Generator",
+     "ImageFile",
      "IndexedFile",
      "Mapper",
      "Session",
+     "Sys",
      "TarVFile",
-     "pydantic_to_feature",
+     "TextFile",
+     "is_chain_type",
  ]
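
For orientation, a minimal sketch of the reworked top-level API (illustrative only; the Player model is made up, and the snippet assumes DataChain.from_values accepts per-signal value lists, as values_to_tuples further below suggests):

# Sketch only: custom signals now derive from DataModel instead of the removed Feature.
from datachain import DataChain, DataModel


class Player(DataModel):  # hypothetical user model
    name: str = ""
    score: float = 0.0


chain = DataChain.from_values(
    player=[Player(name="alice", score=1.5), Player(name="bob", score=0.5)]
)
# The chain can then be materialized, e.g. with chain.to_pandas().
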
datachain/catalog/catalog.py CHANGED
@@ -256,7 +256,7 @@ class DatasetRowsFetcher(NodesThreadPool):
          self.fix_columns(df)

          # id will be autogenerated in DB
-         df = df.drop("id", axis=1)
+         df = df.drop("sys__id", axis=1)

          inserted = warehouse.insert_dataset_rows(
              df, dataset, self.dataset_version
@@ -1041,7 +1041,7 @@ class Catalog:
          If version is None, then next unused version is created.
          If version is given, then it must be an unused version number.
          """
-         assert [c.name for c in columns if c.name != "id"], f"got {columns=}"
+         assert [c.name for c in columns if c.name != "sys__id"], f"got {columns=}"
          if not listing and Client.is_data_source_uri(name):
              raise RuntimeError(
                  "Cannot create dataset that starts with source prefix, e.g s3://"
@@ -1103,7 +1103,7 @@ class Catalog:
          Creates dataset version if it doesn't exist.
          If create_rows is False, dataset rows table will not be created
          """
-         assert [c.name for c in columns if c.name != "id"], f"got {columns=}"
+         assert [c.name for c in columns if c.name != "sys__id"], f"got {columns=}"
          schema = {
              c.name: c.type.to_dict() for c in columns if isinstance(c.type, SQLType)
          }
@@ -1433,7 +1433,7 @@ class Catalog:
          if offset:
              q = q.offset(offset)

-         q = q.order_by("id")
+         q = q.order_by("sys__id")

          return q.to_records()

@@ -1786,7 +1786,7 @@ class Catalog:
          schema = DatasetRecord.parse_schema(remote_dataset_version.schema)

          columns = tuple(
-             sa.Column(name, typ) for name, typ in schema.items() if name != "id"
+             sa.Column(name, typ) for name, typ in schema.items() if name != "sys__id"
          )
          # creating new dataset (version) locally
          dataset = self.create_dataset(
datachain/cli.py CHANGED
@@ -811,8 +811,6 @@ def show(
      from datachain.query import DatasetQuery
      from datachain.utils import show_records

-     if columns:
-         columns = ("id", *columns)
      query = (
          DatasetQuery(name=name, version=version, catalog=catalog)
          .select(*columns)
datachain/data_storage/schema.py CHANGED
@@ -72,7 +72,7 @@ class DirExpansion:
      @staticmethod
      def base_select(q):
          return sa.select(
-             q.c.id,
+             q.c.sys__id,
              q.c.vtype,
              (q.c.dir_type == DirType.DIR).label("is_dir"),
              q.c.source,
@@ -86,7 +86,7 @@ class DirExpansion:
      def apply_group_by(q):
          return (
              sa.select(
-                 f.min(q.c.id).label("id"),
+                 f.min(q.c.sys__id).label("sys__id"),
                  q.c.vtype,
                  q.c.is_dir,
                  q.c.source,
@@ -111,7 +111,7 @@ class DirExpansion:
          parent_name = path.name(q.c.parent)
          q = q.union_all(
              sa.select(
-                 sa.literal(-1).label("id"),
+                 sa.literal(-1).label("sys__id"),
                  sa.literal("").label("vtype"),
                  true().label("is_dir"),
                  q.c.source,
@@ -233,9 +233,9 @@ class DataTable:
      @staticmethod
      def sys_columns():
          return [
-             sa.Column("id", Int, primary_key=True),
+             sa.Column("sys__id", Int, primary_key=True),
              sa.Column(
-                 "random", UInt64, nullable=False, server_default=f.abs(f.random())
+                 "sys__rand", UInt64, nullable=False, server_default=f.abs(f.random())
              ),
          ]

datachain/data_storage/sqlite.py CHANGED
@@ -631,7 +631,7 @@ class SQLiteWarehouse(AbstractWarehouse):
          dst_empty = True

          dst_dr = self.dataset_rows(dst, dst_version).table
-         merge_fields = [c.name for c in src_dr.c if c.name != "id"]
+         merge_fields = [c.name for c in src_dr.c if c.name != "sys__id"]
          select_src = select(*(getattr(src_dr.c, f) for f in merge_fields))

          if dst_empty:
datachain/data_storage/warehouse.py CHANGED
@@ -195,7 +195,7 @@ class AbstractWarehouse(ABC, Serializable):
          cols_names = [c.name for c in cols]

          if not order_by:
-             ordering = [cols.id]
+             ordering = [cols.sys__id]
          else:
              ordering = order_by  # type: ignore[assignment]

@@ -372,7 +372,7 @@ class AbstractWarehouse(ABC, Serializable):
          """Returns total number of rows in a dataset"""
          dr = self.dataset_rows(dataset, version)
          table = dr.get_table()
-         query = select(sa.func.count(table.c.id))
+         query = select(sa.func.count(table.c.sys__id))
          (res,) = self.db.execute(query)
          return res[0]

@@ -388,7 +388,7 @@ class AbstractWarehouse(ABC, Serializable):
          dr = self.dataset_rows(dataset, version)
          table = dr.get_table()
          expressions: tuple[_ColumnsClauseArgument[Any], ...] = (
-             sa.func.count(table.c.id),
+             sa.func.count(table.c.sys__id),
          )
          if "size" in table.columns:
              expressions = (*expressions, sa.func.sum(table.c.size))
@@ -607,7 +607,7 @@ class AbstractWarehouse(ABC, Serializable):
              return func.coalesce(column, default).label(column.name)

          return sa.select(
-             de.c.id,
+             de.c.sys__id,
              with_default(dr.c.vtype),
              case((de.c.is_dir == true(), DirType.DIR), else_=dr.c.dir_type).label(
                  "dir_type"
@@ -621,10 +621,10 @@ class AbstractWarehouse(ABC, Serializable):
              with_default(dr.c.size),
              with_default(dr.c.owner_name),
              with_default(dr.c.owner_id),
-             with_default(dr.c.random),
+             with_default(dr.c.sys__rand),
              dr.c.location,
              de.c.source,
-         ).select_from(de.outerjoin(dr.table, de.c.id == dr.c.id))
+         ).select_from(de.outerjoin(dr.table, de.c.sys__id == dr.c.sys__id))

      def get_node_by_path(self, dataset_rows: "DataTable", path: str) -> Node:
          """Gets node that corresponds to some path"""
@@ -878,7 +878,7 @@ class AbstractWarehouse(ABC, Serializable):
          tbl = sa.Table(
              name,
              sa.MetaData(),
-             sa.Column("id", Int, primary_key=True),
+             sa.Column("sys__id", Int, primary_key=True),
              *columns,
          )
          self.db.create_table(tbl, if_not_exists=True)
datachain/lib/arrow.py CHANGED
@@ -1,13 +1,15 @@
  import re
+ from collections.abc import Sequence
  from typing import TYPE_CHECKING, Optional

+ import pyarrow as pa
  from pyarrow.dataset import dataset

  from datachain.lib.file import File, IndexedFile
  from datachain.lib.udf import Generator

  if TYPE_CHECKING:
-     import pyarrow as pa
+     from datachain.lib.dc import DataChain


  class ArrowGenerator(Generator):
@@ -35,12 +37,29 @@ class ArrowGenerator(Generator):
              index += 1


- def schema_to_output(schema: "pa.Schema"):
+ def infer_schema(chain: "DataChain", **kwargs) -> pa.Schema:
+     schemas = []
+     for file in chain.iterate_one("file"):
+         ds = dataset(file.get_path(), filesystem=file.get_fs(), **kwargs)  # type: ignore[union-attr]
+         schemas.append(ds.schema)
+     return pa.unify_schemas(schemas)
+
+
+ def schema_to_output(schema: pa.Schema, col_names: Optional[Sequence[str]] = None):
      """Generate UDF output schema from pyarrow schema."""
+     if col_names and (len(schema) != len(col_names)):
+         raise ValueError(
+             "Error generating output from Arrow schema - "
+             f"Schema has {len(schema)} columns but got {len(col_names)} column names."
+         )
      default_column = 0
-     output = {"source": IndexedFile}
-     for field in schema:
-         column = field.name.lower()
+     output = {}
+     for i, field in enumerate(schema):
+         if col_names:
+             column = col_names[i]
+         else:
+             column = field.name
+         column = column.lower()
          column = re.sub("[^0-9a-z_]+", "", column)
          if not column:
              column = f"c{default_column}"
@@ -50,12 +69,10 @@ def schema_to_output(schema: "pa.Schema"):
      return output


- def _arrow_type_mapper(col_type: "pa.DataType") -> type:  # noqa: PLR0911
+ def _arrow_type_mapper(col_type: pa.DataType) -> type:  # noqa: PLR0911
      """Convert pyarrow types to basic types."""
      from datetime import datetime

-     import pyarrow as pa
-
      if pa.types.is_timestamp(col_type):
          return datetime
      if pa.types.is_binary(col_type):
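
A short sketch of the extended schema_to_output() signature (illustrative; assumes pyarrow is installed):

import pyarrow as pa

from datachain.lib.arrow import schema_to_output

schema = pa.schema([("Name", pa.string()), ("Age (Years)", pa.int64())])

# Derived names are lowercased and stripped of characters outside [0-9a-z_].
print(schema_to_output(schema))

# col_names overrides the derived names; its length must match the schema,
# otherwise the new ValueError shown above is raised.
print(schema_to_output(schema, col_names=["name", "age"]))
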
datachain/lib/clip.py CHANGED
@@ -1,19 +1,14 @@
  import inspect
- from typing import Any, Callable, Literal, Union
+ from typing import TYPE_CHECKING, Any, Callable, Literal, Union
+
+ import torch
+ from transformers.modeling_utils import PreTrainedModel

  from datachain.lib.image import convert_images
  from datachain.lib.text import convert_text

- try:
-     import torch
+ if TYPE_CHECKING:
      from PIL import Image
-     from transformers.modeling_utils import PreTrainedModel
- except ImportError as exc:
-     raise ImportError(
-         "Missing dependencies for computer vision:\n"
-         "To install run:\n\n"
-         " pip install 'datachain[cv]'\n"
-     ) from exc


  def _get_encoder(model: Any, type: Literal["image", "text"]) -> Callable:
@@ -37,7 +32,7 @@ def _get_encoder(model: Any, type: Literal["image", "text"]) -> Callable:


  def similarity_scores(
-     images: Union[None, Image.Image, list[Image.Image]],
+     images: Union[None, "Image.Image", list["Image.Image"]],
      text: Union[None, str, list[str]],
      model: Any,
      preprocess: Callable,
datachain/lib/convert/__init__.py ADDED
File without changes
datachain/lib/convert/flatten.py ADDED
@@ -0,0 +1,67 @@
+ from datetime import datetime
+
+ from pydantic import BaseModel
+
+ from datachain.lib.model_store import ModelStore
+ from datachain.sql.types import (
+     JSON,
+     Array,
+     Binary,
+     Boolean,
+     DateTime,
+     Float,
+     Int,
+     Int32,
+     Int64,
+     NullType,
+     String,
+ )
+
+ DATACHAIN_TO_TYPE = {
+     Int: int,
+     Int32: int,
+     Int64: int,
+     String: str,
+     Float: float,
+     Boolean: bool,
+     DateTime: datetime,
+     Binary: bytes,
+     Array(NullType): list,
+     JSON: dict,
+ }
+
+
+ def flatten(obj: BaseModel):
+     return tuple(_flatten_fields_values(obj.model_fields, obj))
+
+
+ def flatten_list(obj_list):
+     return tuple(
+         val for obj in obj_list for val in _flatten_fields_values(obj.model_fields, obj)
+     )
+
+
+ def _flatten_fields_values(fields, obj: BaseModel):
+     for name, f_info in fields.items():
+         anno = f_info.annotation
+         # Optimization: Access attributes directly to skip the model_dump() call.
+         value = getattr(obj, name)
+
+         if isinstance(value, list):
+             yield [
+                 val.model_dump() if ModelStore.is_pydantic(type(val)) else val
+                 for val in value
+             ]
+         elif isinstance(value, dict):
+             yield {
+                 key: val.model_dump() if ModelStore.is_pydantic(type(val)) else val
+                 for key, val in value.items()
+             }
+         elif ModelStore.is_pydantic(anno):
+             yield from _flatten_fields_values(anno.model_fields, value)
+         else:
+             yield value
+
+
+ def _flatten(obj):
+     return tuple(_flatten_fields_values(obj.model_fields, obj))
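
A small illustration of the flattening behavior (the Inner/Outer models are made up for the example):

from pydantic import BaseModel

from datachain.lib.convert.flatten import flatten


class Inner(BaseModel):
    x: int = 0
    y: str = ""


class Outer(BaseModel):
    name: str = ""
    inner: Inner = Inner()


# Nested pydantic fields are walked depth-first, so one Outer instance
# becomes a single flat tuple of leaf values.
print(flatten(Outer(name="a", inner=Inner(x=1, y="b"))))  # ('a', 1, 'b')
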
datachain/lib/convert/type_converter.py ADDED
@@ -0,0 +1,96 @@
+ import inspect
+ from datetime import datetime
+ from enum import Enum
+ from typing import Annotated, Literal, Union, get_args, get_origin
+
+ from pydantic import BaseModel
+ from typing_extensions import Literal as LiteralEx
+
+ from datachain.lib.model_store import ModelStore
+ from datachain.sql.types import (
+     JSON,
+     Array,
+     Binary,
+     Boolean,
+     DateTime,
+     Float,
+     Int64,
+     SQLType,
+     String,
+ )
+
+ TYPE_TO_DATACHAIN = {
+     int: Int64,
+     str: String,
+     Literal: String,
+     LiteralEx: String,
+     Enum: String,
+     float: Float,
+     bool: Boolean,
+     datetime: DateTime,  # Note, list of datetime is not supported yet
+     bytes: Binary,  # Note, list of bytes is not supported yet
+     list: Array,
+     dict: JSON,
+ }
+
+
+ def convert_to_db_type(typ):  # noqa: PLR0911
+     if inspect.isclass(typ):
+         if issubclass(typ, SQLType):
+             return typ
+         if issubclass(typ, Enum):
+             return str
+
+     res = TYPE_TO_DATACHAIN.get(typ)
+     if res:
+         return res
+
+     orig = get_origin(typ)
+
+     if orig in (Literal, LiteralEx):
+         return String
+
+     args = get_args(typ)
+     if inspect.isclass(orig) and (issubclass(list, orig) or issubclass(tuple, orig)):
+         if args is None or len(args) != 1:
+             raise TypeError(f"Cannot resolve type '{typ}' for flattening features")
+
+         args0 = args[0]
+         if ModelStore.is_pydantic(args0):
+             return Array(JSON())
+
+         next_type = convert_to_db_type(args0)
+         return Array(next_type)
+
+     if orig is Annotated:
+         # Ignoring annotations
+         return convert_to_db_type(args[0])
+
+     if inspect.isclass(orig) and issubclass(dict, orig):
+         return JSON
+
+     if orig == Union:
+         if len(args) == 2 and (type(None) in args):
+             return convert_to_db_type(args[0])
+
+         if _is_json_inside_union(orig, args):
+             return JSON
+
+     raise TypeError(f"Cannot recognize type {typ}")
+
+
+ def _is_json_inside_union(orig, args) -> bool:
+     if orig == Union and len(args) >= 2:
+         # List in JSON: Union[dict, list[dict]]
+         args_no_nones = [arg for arg in args if arg != type(None)]
+         if len(args_no_nones) == 2:
+             args_no_dicts = [arg for arg in args_no_nones if arg is not dict]
+             if len(args_no_dicts) == 1 and get_origin(args_no_dicts[0]) is list:
+                 arg = get_args(args_no_dicts[0])
+                 if len(arg) == 1 and arg[0] is dict:
+                     return True
+
+         # List of objects: Union[MyClass, OtherClass]
+         if any(inspect.isclass(arg) and issubclass(arg, BaseModel) for arg in args):
+             return True
+     return False
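
Roughly how convert_to_db_type resolves annotations (a sketch based on the mapping table above):

from typing import Optional

from datachain.lib.convert.type_converter import convert_to_db_type
from datachain.sql.types import JSON, Int64, String

assert convert_to_db_type(str) is String
assert convert_to_db_type(Optional[int]) is Int64  # Optional[X] unwraps to X
assert convert_to_db_type(dict) is JSON
# list[X] / tuple[X] resolve to Array(<converted X>); unrecognized types raise TypeError.
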
datachain/lib/convert/unflatten.py ADDED
@@ -0,0 +1,69 @@
+ import copy
+ import inspect
+ import re
+ from collections.abc import Sequence
+ from typing import Any, get_origin
+
+ from pydantic import BaseModel
+
+ from datachain.query.schema import DEFAULT_DELIMITER
+
+
+ def unflatten_to_json(model: type[BaseModel], row: Sequence[Any], pos=0) -> dict:
+     return unflatten_to_json_pos(model, row, pos)[0]
+
+
+ def unflatten_to_json_pos(
+     model: type[BaseModel], row: Sequence[Any], pos=0
+ ) -> tuple[dict, int]:
+     res = {}
+     for name, f_info in model.model_fields.items():
+         anno = f_info.annotation
+         origin = get_origin(anno)
+         if (
+             origin not in (list, dict)
+             and inspect.isclass(anno)
+             and issubclass(anno, BaseModel)
+         ):
+             res[name], pos = unflatten_to_json_pos(anno, row, pos)
+         else:
+             res[name] = row[pos]
+             pos += 1
+     return res, pos
+
+
+ def _normalize(name: str) -> str:
+     if DEFAULT_DELIMITER in name:
+         raise RuntimeError(
+             f"variable '{name}' cannot be used "
+             f"because it contains {DEFAULT_DELIMITER}"
+         )
+     return _to_snake_case(name)
+
+
+ def _to_snake_case(name: str) -> str:
+     """Convert a CamelCase name to snake_case."""
+     s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
+     return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
+
+
+ def _unflatten_with_path(model: type[BaseModel], dump, name_path: list[str]):
+     res = {}
+     for name, f_info in model.model_fields.items():
+         anno = f_info.annotation
+         name_norm = _normalize(name)
+         lst = copy.copy(name_path)
+
+         if inspect.isclass(anno) and issubclass(anno, BaseModel):
+             lst.append(name_norm)
+             val = _unflatten_with_path(anno, dump, lst)
+             res[name] = val
+         else:
+             lst.append(name_norm)
+             curr_path = DEFAULT_DELIMITER.join(lst)
+             res[name] = dump[curr_path]
+     return model(**res)
+
+
+ def unflatten(model: type[BaseModel], dump):
+     return _unflatten_with_path(model, dump, [])
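
The reverse direction, sketched with made-up models (DEFAULT_DELIMITER is the double-underscore separator also seen in column names such as sys__id):

from pydantic import BaseModel

from datachain.lib.convert.unflatten import unflatten, unflatten_to_json


class Inner(BaseModel):
    x: int = 0


class Outer(BaseModel):
    name: str = ""
    inner: Inner = Inner()


# From a flat row in field order...
print(unflatten_to_json(Outer, ("a", 1)))              # {'name': 'a', 'inner': {'x': 1}}
# ...or from a delimiter-joined column mapping, back to a model instance.
print(unflatten(Outer, {"name": "a", "inner__x": 1}))  # Outer(name='a', inner=Inner(x=1))
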
datachain/lib/convert/values_to_tuples.py ADDED
@@ -0,0 +1,85 @@
+ from collections.abc import Sequence
+ from typing import Any, Union
+
+ from datachain.lib.data_model import DataType, DataTypeNames, is_chain_type
+ from datachain.lib.utils import DataChainParamsError
+
+
+ class ValuesToTupleError(DataChainParamsError):
+     def __init__(self, ds_name, msg):
+         if ds_name:
+             ds_name = f"' {ds_name}'"
+         super().__init__(f"Cannot convert features for dataset{ds_name}: {msg}")
+
+
+ def values_to_tuples(
+     ds_name: str = "",
+     output: Union[None, DataType, Sequence[str], dict[str, DataType]] = None,
+     **fr_map,
+ ) -> tuple[Any, Any, Any]:
+     types_map = {}
+     length = -1
+     for k, v in fr_map.items():
+         if not isinstance(v, Sequence) or isinstance(v, str):
+             raise ValuesToTupleError(ds_name, f"features '{k}' is not a sequence")
+         len_ = len(v)
+
+         if len_ == 0:
+             raise ValuesToTupleError(ds_name, f"feature '{k}' is empty list")
+
+         if length < 0:
+             length = len_
+         elif length != len_:
+             raise ValuesToTupleError(
+                 ds_name,
+                 f"feature '{k}' should have length {length} while {len_} is given",
+             )
+         typ = type(v[0])
+         if not is_chain_type(typ):
+             raise ValuesToTupleError(
+                 ds_name,
+                 f"feature '{k}' has unsupported type '{typ.__name__}'."
+                 f" Please use Feature types: {DataTypeNames}",
+             )
+         types_map[k] = typ
+     if output:
+         if not isinstance(output, Sequence) and not isinstance(output, str):
+             if len(fr_map) != 1:
+                 raise ValuesToTupleError(
+                     ds_name,
+                     f"only one output type was specified, {len(fr_map)} expected",
+                 )
+             if not isinstance(output, type):
+                 raise ValuesToTupleError(
+                     ds_name,
+                     f"output must specify a type while '{output}' was given",
+                 )
+
+             key: str = next(iter(fr_map.keys()))
+             output = {key: output}  # type: ignore[dict-item]
+
+         if len(output) != len(fr_map):
+             raise ValuesToTupleError(
+                 ds_name,
+                 f"number of outputs '{len(output)}' should match"
+                 f" number of features '{len(fr_map)}'",
+             )
+         if not isinstance(output, dict):
+             raise ValuesToTupleError(
+                 ds_name,
+                 "output type must be dict[str, FeatureType] while "
+                 f"'{type(output).__name__}' is given",
+             )
+     else:
+         output = types_map  # type: ignore[assignment]
+
+     output_types: list[type] = list(output.values())  # type: ignore[union-attr,call-arg,arg-type]
+     if len(output) > 1:  # type: ignore[arg-type]
+         tuple_type = tuple(output_types)
+         res_type = tuple[tuple_type]  # type: ignore[valid-type]
+         res_values = list(zip(*fr_map.values()))
+     else:
+         res_type = output_types[0]  # type: ignore[misc]
+         res_values = next(iter(fr_map.values()))
+
+     return res_type, output, res_values
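
Finally, a sketch of values_to_tuples, which turns per-signal value lists into row tuples plus an inferred output schema (this appears to be the backing logic for value-based constructors such as DataChain.from_values):

from datachain.lib.convert.values_to_tuples import values_to_tuples

res_type, output, rows = values_to_tuples(
    "demo", fname=["a.jpg", "b.jpg"], size=[100, 200]
)
print(output)  # {'fname': <class 'str'>, 'size': <class 'int'>}
print(rows)    # [('a.jpg', 100), ('b.jpg', 200)]
# Mismatched lengths, empty lists, or unsupported element types raise ValuesToTupleError.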