datachain 0.2.9__py3-none-any.whl → 0.2.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- datachain/__init__.py +17 -8
- datachain/catalog/catalog.py +5 -5
- datachain/cli.py +0 -2
- datachain/data_storage/schema.py +5 -5
- datachain/data_storage/sqlite.py +1 -1
- datachain/data_storage/warehouse.py +7 -7
- datachain/lib/arrow.py +25 -8
- datachain/lib/clip.py +6 -11
- datachain/lib/convert/__init__.py +0 -0
- datachain/lib/convert/flatten.py +67 -0
- datachain/lib/convert/type_converter.py +96 -0
- datachain/lib/convert/unflatten.py +69 -0
- datachain/lib/convert/values_to_tuples.py +85 -0
- datachain/lib/data_model.py +74 -0
- datachain/lib/dc.py +225 -168
- datachain/lib/file.py +41 -41
- datachain/lib/gpt4_vision.py +1 -9
- datachain/lib/hf_image_to_text.py +9 -17
- datachain/lib/hf_pipeline.py +4 -12
- datachain/lib/image.py +2 -18
- datachain/lib/image_transform.py +0 -1
- datachain/lib/iptc_exif_xmp.py +8 -15
- datachain/lib/meta_formats.py +1 -5
- datachain/lib/model_store.py +77 -0
- datachain/lib/pytorch.py +9 -21
- datachain/lib/signal_schema.py +139 -60
- datachain/lib/text.py +5 -16
- datachain/lib/udf.py +114 -30
- datachain/lib/udf_signature.py +5 -5
- datachain/lib/webdataset.py +3 -3
- datachain/lib/webdataset_laion.py +2 -3
- datachain/node.py +4 -4
- datachain/query/batch.py +1 -1
- datachain/query/dataset.py +51 -178
- datachain/query/dispatch.py +43 -30
- datachain/query/udf.py +46 -26
- datachain/remote/studio.py +1 -9
- datachain/torch/__init__.py +21 -0
- datachain/utils.py +39 -0
- {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/METADATA +14 -12
- {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/RECORD +45 -43
- {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/WHEEL +1 -1
- datachain/image/__init__.py +0 -3
- datachain/lib/cached_stream.py +0 -38
- datachain/lib/claude.py +0 -69
- datachain/lib/feature.py +0 -412
- datachain/lib/feature_registry.py +0 -51
- datachain/lib/feature_utils.py +0 -154
- {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/LICENSE +0 -0
- {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/entry_points.txt +0 -0
- {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/top_level.txt +0 -0
datachain/__init__.py
CHANGED
@@ -1,11 +1,16 @@
-from datachain.lib.
-from datachain.lib.
-from datachain.lib.
-
+from datachain.lib.data_model import DataModel, DataType, FileBasic, is_chain_type
+from datachain.lib.dc import C, Column, DataChain, Sys
+from datachain.lib.file import (
+    File,
+    FileError,
+    ImageFile,
+    IndexedFile,
+    TarVFile,
+    TextFile,
+)
 from datachain.lib.udf import Aggregator, Generator, Mapper
 from datachain.lib.utils import AbstractUDF, DataChainError
 from datachain.query.dataset import UDF as BaseUDF  # noqa: N811
-from datachain.query.schema import Column
 from datachain.query.session import Session

 __all__ = [
@@ -16,14 +21,18 @@ __all__ = [
     "Column",
     "DataChain",
     "DataChainError",
-    "
+    "DataModel",
+    "DataType",
     "File",
+    "FileBasic",
     "FileError",
-    "FileFeature",
     "Generator",
+    "ImageFile",
     "IndexedFile",
     "Mapper",
     "Session",
+    "Sys",
     "TarVFile",
-    "
+    "TextFile",
+    "is_chain_type",
 ]
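For orientation, a minimal sketch of how the reshaped top-level API reads after this change. It assumes DataModel behaves as a pydantic-style base class for user-defined signals (consistent with the new datachain/lib/data_model.py); the Annotation model below is made up for the example.

# Illustrative only: imports mirror the new __all__; Annotation is a made-up model.
from datachain import DataModel, ImageFile, is_chain_type


class Annotation(DataModel):  # assumption: DataModel is a pydantic-style base class
    label: str
    score: float


# is_chain_type() reports whether a type can be carried through a chain
print(is_chain_type(ImageFile), is_chain_type(Annotation))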
datachain/catalog/catalog.py
CHANGED
@@ -256,7 +256,7 @@ class DatasetRowsFetcher(NodesThreadPool):
         self.fix_columns(df)

         # id will be autogenerated in DB
-        df = df.drop("id", axis=1)
+        df = df.drop("sys__id", axis=1)

         inserted = warehouse.insert_dataset_rows(
             df, dataset, self.dataset_version
@@ -1041,7 +1041,7 @@ class Catalog:
         If version is None, then next unused version is created.
         If version is given, then it must be an unused version number.
         """
-        assert [c.name for c in columns if c.name != "id"], f"got {columns=}"
+        assert [c.name for c in columns if c.name != "sys__id"], f"got {columns=}"
         if not listing and Client.is_data_source_uri(name):
             raise RuntimeError(
                 "Cannot create dataset that starts with source prefix, e.g s3://"
@@ -1103,7 +1103,7 @@ class Catalog:
         Creates dataset version if it doesn't exist.
         If create_rows is False, dataset rows table will not be created
         """
-        assert [c.name for c in columns if c.name != "id"], f"got {columns=}"
+        assert [c.name for c in columns if c.name != "sys__id"], f"got {columns=}"
         schema = {
             c.name: c.type.to_dict() for c in columns if isinstance(c.type, SQLType)
         }
@@ -1433,7 +1433,7 @@ class Catalog:
         if offset:
             q = q.offset(offset)

-        q = q.order_by("id")
+        q = q.order_by("sys__id")

         return q.to_records()

@@ -1786,7 +1786,7 @@ class Catalog:
         schema = DatasetRecord.parse_schema(remote_dataset_version.schema)

         columns = tuple(
-            sa.Column(name, typ) for name, typ in schema.items() if name != "id"
+            sa.Column(name, typ) for name, typ in schema.items() if name != "sys__id"
         )
         # creating new dataset (version) locally
         dataset = self.create_dataset(
datachain/cli.py
CHANGED
datachain/data_storage/schema.py
CHANGED
@@ -72,7 +72,7 @@ class DirExpansion:
     @staticmethod
     def base_select(q):
         return sa.select(
-            q.c.id,
+            q.c.sys__id,
             q.c.vtype,
             (q.c.dir_type == DirType.DIR).label("is_dir"),
             q.c.source,
@@ -86,7 +86,7 @@ class DirExpansion:
     def apply_group_by(q):
         return (
             sa.select(
-                f.min(q.c.id).label("id"),
+                f.min(q.c.sys__id).label("sys__id"),
                 q.c.vtype,
                 q.c.is_dir,
                 q.c.source,
@@ -111,7 +111,7 @@ class DirExpansion:
         parent_name = path.name(q.c.parent)
         q = q.union_all(
             sa.select(
-                sa.literal(-1).label("id"),
+                sa.literal(-1).label("sys__id"),
                 sa.literal("").label("vtype"),
                 true().label("is_dir"),
                 q.c.source,
@@ -233,9 +233,9 @@ class DataTable:
     @staticmethod
     def sys_columns():
         return [
-            sa.Column("id", Int, primary_key=True),
+            sa.Column("sys__id", Int, primary_key=True),
             sa.Column(
-                "random", UInt64, nullable=False, server_default=f.abs(f.random())
+                "sys__rand", UInt64, nullable=False, server_default=f.abs(f.random())
             ),
         ]

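As a side note, the sys__ prefix above is the new naming convention for system columns, and the catalog/warehouse changes below simply follow it. A minimal SQLAlchemy sketch of the resulting table shape, with a made-up user column and generic SQLAlchemy types standing in for datachain's Int/UInt64:

# Illustrative only: generic SQLAlchemy types replace datachain's Int/UInt64.
import sqlalchemy as sa

rows = sa.Table(
    "example_rows",
    sa.MetaData(),
    sa.Column("sys__id", sa.Integer, primary_key=True),
    sa.Column("sys__rand", sa.BigInteger, nullable=False),
    sa.Column("path", sa.Text),  # made-up user column
)

# Downstream code (catalog, warehouse) now orders and counts on the prefixed name:
print(sa.select(rows.c.path).order_by(rows.c.sys__id))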
datachain/data_storage/sqlite.py
CHANGED
@@ -631,7 +631,7 @@ class SQLiteWarehouse(AbstractWarehouse):
         dst_empty = True

         dst_dr = self.dataset_rows(dst, dst_version).table
-        merge_fields = [c.name for c in src_dr.c if c.name != "id"]
+        merge_fields = [c.name for c in src_dr.c if c.name != "sys__id"]
         select_src = select(*(getattr(src_dr.c, f) for f in merge_fields))

         if dst_empty:
datachain/data_storage/warehouse.py
CHANGED

@@ -195,7 +195,7 @@ class AbstractWarehouse(ABC, Serializable):
         cols_names = [c.name for c in cols]

         if not order_by:
-            ordering = [cols.id]
+            ordering = [cols.sys__id]
         else:
             ordering = order_by  # type: ignore[assignment]

@@ -372,7 +372,7 @@ class AbstractWarehouse(ABC, Serializable):
         """Returns total number of rows in a dataset"""
         dr = self.dataset_rows(dataset, version)
         table = dr.get_table()
-        query = select(sa.func.count(table.c.id))
+        query = select(sa.func.count(table.c.sys__id))
         (res,) = self.db.execute(query)
         return res[0]

@@ -388,7 +388,7 @@ class AbstractWarehouse(ABC, Serializable):
         dr = self.dataset_rows(dataset, version)
         table = dr.get_table()
         expressions: tuple[_ColumnsClauseArgument[Any], ...] = (
-            sa.func.count(table.c.id),
+            sa.func.count(table.c.sys__id),
         )
         if "size" in table.columns:
             expressions = (*expressions, sa.func.sum(table.c.size))
@@ -607,7 +607,7 @@ class AbstractWarehouse(ABC, Serializable):
             return func.coalesce(column, default).label(column.name)

         return sa.select(
-            de.c.id,
+            de.c.sys__id,
             with_default(dr.c.vtype),
             case((de.c.is_dir == true(), DirType.DIR), else_=dr.c.dir_type).label(
                 "dir_type"
@@ -621,10 +621,10 @@ class AbstractWarehouse(ABC, Serializable):
             with_default(dr.c.size),
             with_default(dr.c.owner_name),
             with_default(dr.c.owner_id),
-            with_default(dr.c.random),
+            with_default(dr.c.sys__rand),
             dr.c.location,
             de.c.source,
-        ).select_from(de.outerjoin(dr.table, de.c.id == dr.c.id))
+        ).select_from(de.outerjoin(dr.table, de.c.sys__id == dr.c.sys__id))

     def get_node_by_path(self, dataset_rows: "DataTable", path: str) -> Node:
         """Gets node that corresponds to some path"""
@@ -878,7 +878,7 @@ class AbstractWarehouse(ABC, Serializable):
         tbl = sa.Table(
             name,
             sa.MetaData(),
-            sa.Column("id", Int, primary_key=True),
+            sa.Column("sys__id", Int, primary_key=True),
             *columns,
         )
         self.db.create_table(tbl, if_not_exists=True)
datachain/lib/arrow.py
CHANGED
@@ -1,13 +1,15 @@
 import re
+from collections.abc import Sequence
 from typing import TYPE_CHECKING, Optional

+import pyarrow as pa
 from pyarrow.dataset import dataset

 from datachain.lib.file import File, IndexedFile
 from datachain.lib.udf import Generator

 if TYPE_CHECKING:
-
+    from datachain.lib.dc import DataChain


 class ArrowGenerator(Generator):
@@ -35,12 +37,29 @@ class ArrowGenerator(Generator):
            index += 1


-def schema_to_output(schema: "pa.Schema"):
+def infer_schema(chain: "DataChain", **kwargs) -> pa.Schema:
+    schemas = []
+    for file in chain.iterate_one("file"):
+        ds = dataset(file.get_path(), filesystem=file.get_fs(), **kwargs)  # type: ignore[union-attr]
+        schemas.append(ds.schema)
+    return pa.unify_schemas(schemas)
+
+
+def schema_to_output(schema: pa.Schema, col_names: Optional[Sequence[str]] = None):
     """Generate UDF output schema from pyarrow schema."""
+    if col_names and (len(schema) != len(col_names)):
+        raise ValueError(
+            "Error generating output from Arrow schema - "
+            f"Schema has {len(schema)} columns but got {len(col_names)} column names."
+        )
     default_column = 0
-    output = {
-    for field in schema:
-
+    output = {}
+    for i, field in enumerate(schema):
+        if col_names:
+            column = col_names[i]
+        else:
+            column = field.name
+        column = column.lower()
         column = re.sub("[^0-9a-z_]+", "", column)
         if not column:
             column = f"c{default_column}"
@@ -50,12 +69,10 @@ def schema_to_output(schema: "pa.Schema"):
     return output


-def _arrow_type_mapper(col_type:
+def _arrow_type_mapper(col_type: pa.DataType) -> type:  # noqa: PLR0911
     """Convert pyarrow types to basic types."""
     from datetime import datetime

-    import pyarrow as pa
-
     if pa.types.is_timestamp(col_type):
         return datetime
     if pa.types.is_binary(col_type):
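The new infer_schema/schema_to_output pair unifies per-file Arrow schemas and sanitizes column names (lowercase, strip characters outside [0-9a-z_], fall back to generated c<N> names). A standalone sketch of that name handling on a made-up pair of schemas:

# Illustrative only: reproduces the unify + name-sanitization idea outside datachain.
import re

import pyarrow as pa

schema = pa.unify_schemas([
    pa.schema([("User Name", pa.string())]),
    pa.schema([("User Name", pa.string()), ("Score %", pa.float32())]),
])

default_column = 0
for field in schema:
    column = re.sub("[^0-9a-z_]+", "", field.name.lower())
    if not column:
        column = f"c{default_column}"
        default_column += 1
    print(column)  # -> "username", then "score"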
datachain/lib/clip.py
CHANGED
@@ -1,19 +1,14 @@
 import inspect
-from typing import Any, Callable, Literal, Union
+from typing import TYPE_CHECKING, Any, Callable, Literal, Union
+
+import torch
+from transformers.modeling_utils import PreTrainedModel

 from datachain.lib.image import convert_images
 from datachain.lib.text import convert_text

-try:
-    import torch
+if TYPE_CHECKING:
     from PIL import Image
-    from transformers.modeling_utils import PreTrainedModel
-except ImportError as exc:
-    raise ImportError(
-        "Missing dependencies for computer vision:\n"
-        "To install run:\n\n"
-        "  pip install 'datachain[cv]'\n"
-    ) from exc


 def _get_encoder(model: Any, type: Literal["image", "text"]) -> Callable:
@@ -37,7 +32,7 @@ def _get_encoder(model: Any, type: Literal["image", "text"]) -> Callable:


 def similarity_scores(
-    images: Union[None, Image.Image, list[Image.Image]],
+    images: Union[None, "Image.Image", list["Image.Image"]],
     text: Union[None, str, list[str]],
     model: Any,
     preprocess: Callable,
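The change drops the try/except import guard in favor of plain module-level imports, keeping PIL behind TYPE_CHECKING and referencing it via quoted annotations. The pattern in isolation (the function below is made up):

# Illustrative only: heavy deps import eagerly; annotation-only deps stay lazy.
from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:
    from PIL import Image  # only needed by type checkers, never at runtime


def count_images(images: Union[None, "Image.Image", list["Image.Image"]]) -> int:
    if images is None:
        return 0
    return len(images) if isinstance(images, list) else 1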
datachain/lib/convert/__init__.py
ADDED
File without changes
datachain/lib/convert/flatten.py
ADDED
@@ -0,0 +1,67 @@
+from datetime import datetime
+
+from pydantic import BaseModel
+
+from datachain.lib.model_store import ModelStore
+from datachain.sql.types import (
+    JSON,
+    Array,
+    Binary,
+    Boolean,
+    DateTime,
+    Float,
+    Int,
+    Int32,
+    Int64,
+    NullType,
+    String,
+)
+
+DATACHAIN_TO_TYPE = {
+    Int: int,
+    Int32: int,
+    Int64: int,
+    String: str,
+    Float: float,
+    Boolean: bool,
+    DateTime: datetime,
+    Binary: bytes,
+    Array(NullType): list,
+    JSON: dict,
+}
+
+
+def flatten(obj: BaseModel):
+    return tuple(_flatten_fields_values(obj.model_fields, obj))
+
+
+def flatten_list(obj_list):
+    return tuple(
+        val for obj in obj_list for val in _flatten_fields_values(obj.model_fields, obj)
+    )
+
+
+def _flatten_fields_values(fields, obj: BaseModel):
+    for name, f_info in fields.items():
+        anno = f_info.annotation
+        # Optimization: Access attributes directly to skip the model_dump() call.
+        value = getattr(obj, name)
+
+        if isinstance(value, list):
+            yield [
+                val.model_dump() if ModelStore.is_pydantic(type(val)) else val
+                for val in value
+            ]
+        elif isinstance(value, dict):
+            yield {
+                key: val.model_dump() if ModelStore.is_pydantic(type(val)) else val
+                for key, val in value.items()
+            }
+        elif ModelStore.is_pydantic(anno):
+            yield from _flatten_fields_values(anno.model_fields, value)
+        else:
+            yield value
+
+
+def _flatten(obj):
+    return tuple(_flatten_fields_values(obj.model_fields, obj))
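A usage sketch for the new flatten helpers (models below are made up): nested pydantic models are walked depth-first into a flat tuple of leaf values, while list and dict fields are dumped in place.

# Illustrative only; expected output is annotated, not asserted.
from pydantic import BaseModel

from datachain.lib.convert.flatten import flatten


class Location(BaseModel):
    lat: float
    lon: float


class Place(BaseModel):
    name: str
    loc: Location
    tags: list[str]


print(flatten(Place(name="pier", loc=Location(lat=1.0, lon=2.0), tags=["sea"])))
# expected: ('pier', 1.0, 2.0, ['sea'])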
datachain/lib/convert/type_converter.py
ADDED

@@ -0,0 +1,96 @@
+import inspect
+from datetime import datetime
+from enum import Enum
+from typing import Annotated, Literal, Union, get_args, get_origin
+
+from pydantic import BaseModel
+from typing_extensions import Literal as LiteralEx
+
+from datachain.lib.model_store import ModelStore
+from datachain.sql.types import (
+    JSON,
+    Array,
+    Binary,
+    Boolean,
+    DateTime,
+    Float,
+    Int64,
+    SQLType,
+    String,
+)
+
+TYPE_TO_DATACHAIN = {
+    int: Int64,
+    str: String,
+    Literal: String,
+    LiteralEx: String,
+    Enum: String,
+    float: Float,
+    bool: Boolean,
+    datetime: DateTime,  # Note, list of datetime is not supported yet
+    bytes: Binary,  # Note, list of bytes is not supported yet
+    list: Array,
+    dict: JSON,
+}
+
+
+def convert_to_db_type(typ):  # noqa: PLR0911
+    if inspect.isclass(typ):
+        if issubclass(typ, SQLType):
+            return typ
+        if issubclass(typ, Enum):
+            return str
+
+    res = TYPE_TO_DATACHAIN.get(typ)
+    if res:
+        return res
+
+    orig = get_origin(typ)
+
+    if orig in (Literal, LiteralEx):
+        return String
+
+    args = get_args(typ)
+    if inspect.isclass(orig) and (issubclass(list, orig) or issubclass(tuple, orig)):
+        if args is None or len(args) != 1:
+            raise TypeError(f"Cannot resolve type '{typ}' for flattening features")
+
+        args0 = args[0]
+        if ModelStore.is_pydantic(args0):
+            return Array(JSON())
+
+        next_type = convert_to_db_type(args0)
+        return Array(next_type)
+
+    if orig is Annotated:
+        # Ignoring annotations
+        return convert_to_db_type(args[0])
+
+    if inspect.isclass(orig) and issubclass(dict, orig):
+        return JSON
+
+    if orig == Union:
+        if len(args) == 2 and (type(None) in args):
+            return convert_to_db_type(args[0])
+
+        if _is_json_inside_union(orig, args):
+            return JSON
+
+    raise TypeError(f"Cannot recognize type {typ}")
+
+
+def _is_json_inside_union(orig, args) -> bool:
+    if orig == Union and len(args) >= 2:
+        # List in JSON: Union[dict, list[dict]]
+        args_no_nones = [arg for arg in args if arg != type(None)]
+        if len(args_no_nones) == 2:
+            args_no_dicts = [arg for arg in args_no_nones if arg is not dict]
+            if len(args_no_dicts) == 1 and get_origin(args_no_dicts[0]) is list:
+                arg = get_args(args_no_dicts[0])
+                if len(arg) == 1 and arg[0] is dict:
+                    return True
+
+        # List of objects: Union[MyClass, OtherClass]
+        if any(inspect.isclass(arg) and issubclass(arg, BaseModel) for arg in args):
+            return True
+    return False
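A few example mappings for convert_to_db_type, derived from the table and branches above (illustrative, not exhaustive):

# Illustrative only.
from typing import Optional, Union

from datachain.lib.convert.type_converter import convert_to_db_type
from datachain.sql.types import JSON, Float, String

print(convert_to_db_type(str) is String)                    # plain types via TYPE_TO_DATACHAIN
print(convert_to_db_type(Optional[float]) is Float)         # Optional[X] unwraps to X
print(convert_to_db_type(Union[dict, list[dict]]) is JSON)  # JSON-in-union case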
datachain/lib/convert/unflatten.py
ADDED

@@ -0,0 +1,69 @@
+import copy
+import inspect
+import re
+from collections.abc import Sequence
+from typing import Any, get_origin
+
+from pydantic import BaseModel
+
+from datachain.query.schema import DEFAULT_DELIMITER
+
+
+def unflatten_to_json(model: type[BaseModel], row: Sequence[Any], pos=0) -> dict:
+    return unflatten_to_json_pos(model, row, pos)[0]
+
+
+def unflatten_to_json_pos(
+    model: type[BaseModel], row: Sequence[Any], pos=0
+) -> tuple[dict, int]:
+    res = {}
+    for name, f_info in model.model_fields.items():
+        anno = f_info.annotation
+        origin = get_origin(anno)
+        if (
+            origin not in (list, dict)
+            and inspect.isclass(anno)
+            and issubclass(anno, BaseModel)
+        ):
+            res[name], pos = unflatten_to_json_pos(anno, row, pos)
+        else:
+            res[name] = row[pos]
+            pos += 1
+    return res, pos
+
+
+def _normalize(name: str) -> str:
+    if DEFAULT_DELIMITER in name:
+        raise RuntimeError(
+            f"variable '{name}' cannot be used "
+            f"because it contains {DEFAULT_DELIMITER}"
+        )
+    return _to_snake_case(name)
+
+
+def _to_snake_case(name: str) -> str:
+    """Convert a CamelCase name to snake_case."""
+    s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
+    return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
+
+
+def _unflatten_with_path(model: type[BaseModel], dump, name_path: list[str]):
+    res = {}
+    for name, f_info in model.model_fields.items():
+        anno = f_info.annotation
+        name_norm = _normalize(name)
+        lst = copy.copy(name_path)
+
+        if inspect.isclass(anno) and issubclass(anno, BaseModel):
+            lst.append(name_norm)
+            val = _unflatten_with_path(anno, dump, lst)
+            res[name] = val
+        else:
+            lst.append(name_norm)
+            curr_path = DEFAULT_DELIMITER.join(lst)
+            res[name] = dump[curr_path]
+    return model(**res)
+
+
+def unflatten(model: type[BaseModel], dump):
+    return _unflatten_with_path(model, dump, [])
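A round-trip sketch for the unflatten helpers (models made up; this assumes DEFAULT_DELIMITER is the double underscore seen in the sys__id-style names elsewhere in this release):

# Illustrative only.
from pydantic import BaseModel

from datachain.lib.convert.unflatten import unflatten, unflatten_to_json


class Location(BaseModel):
    lat: float
    lon: float


class Place(BaseModel):
    name: str
    loc: Location


print(unflatten(Place, {"name": "pier", "loc__lat": 1.0, "loc__lon": 2.0}))
# expected: Place(name='pier', loc=Location(lat=1.0, lon=2.0))

print(unflatten_to_json(Place, ["pier", 1.0, 2.0]))
# expected: {'name': 'pier', 'loc': {'lat': 1.0, 'lon': 2.0}}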
datachain/lib/convert/values_to_tuples.py
ADDED

@@ -0,0 +1,85 @@
+from collections.abc import Sequence
+from typing import Any, Union
+
+from datachain.lib.data_model import DataType, DataTypeNames, is_chain_type
+from datachain.lib.utils import DataChainParamsError
+
+
+class ValuesToTupleError(DataChainParamsError):
+    def __init__(self, ds_name, msg):
+        if ds_name:
+            ds_name = f"' {ds_name}'"
+        super().__init__(f"Cannot convert features for dataset{ds_name}: {msg}")
+
+
+def values_to_tuples(
+    ds_name: str = "",
+    output: Union[None, DataType, Sequence[str], dict[str, DataType]] = None,
+    **fr_map,
+) -> tuple[Any, Any, Any]:
+    types_map = {}
+    length = -1
+    for k, v in fr_map.items():
+        if not isinstance(v, Sequence) or isinstance(v, str):
+            raise ValuesToTupleError(ds_name, f"features '{k}' is not a sequence")
+        len_ = len(v)
+
+        if len_ == 0:
+            raise ValuesToTupleError(ds_name, f"feature '{k}' is empty list")
+
+        if length < 0:
+            length = len_
+        elif length != len_:
+            raise ValuesToTupleError(
+                ds_name,
+                f"feature '{k}' should have length {length} while {len_} is given",
+            )
+        typ = type(v[0])
+        if not is_chain_type(typ):
+            raise ValuesToTupleError(
+                ds_name,
+                f"feature '{k}' has unsupported type '{typ.__name__}'."
+                f" Please use Feature types: {DataTypeNames}",
+            )
+        types_map[k] = typ
+    if output:
+        if not isinstance(output, Sequence) and not isinstance(output, str):
+            if len(fr_map) != 1:
+                raise ValuesToTupleError(
+                    ds_name,
+                    f"only one output type was specified, {len(fr_map)} expected",
+                )
+            if not isinstance(output, type):
+                raise ValuesToTupleError(
+                    ds_name,
+                    f"output must specify a type while '{output}' was given",
+                )
+
+            key: str = next(iter(fr_map.keys()))
+            output = {key: output}  # type: ignore[dict-item]
+
+        if len(output) != len(fr_map):
+            raise ValuesToTupleError(
+                ds_name,
+                f"number of outputs '{len(output)}' should match"
+                f" number of features '{len(fr_map)}'",
+            )
+        if isinstance(output, dict):
+            raise ValuesToTupleError(
+                ds_name,
+                "output type must be dict[str, FeatureType] while "
+                f"'{type(output).__name__}' is given",
+            )
+    else:
+        output = types_map  # type: ignore[assignment]
+
+    output_types: list[type] = list(output.values())  # type: ignore[union-attr,call-arg,arg-type]
+    if len(output) > 1:  # type: ignore[arg-type]
+        tuple_type = tuple(output_types)
+        res_type = tuple[tuple_type]  # type: ignore[valid-type]
+        res_values = list(zip(*fr_map.values()))
+    else:
+        res_type = output_types[0]  # type: ignore[misc]
+        res_values = next(iter(fr_map.values()))
+
+    return res_type, output, res_values
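Finally, a sketch of what values_to_tuples returns for parallel value lists (assuming plain str/float pass is_chain_type); the dataset name and values are made up:

# Illustrative only.
from datachain.lib.convert.values_to_tuples import values_to_tuples

item_type, output, rows = values_to_tuples(
    "my-ds",
    caption=["a cat", "a dog"],
    score=[0.9, 0.7],
)
print(item_type)  # expected: tuple[str, float]
print(output)     # expected: {'caption': <class 'str'>, 'score': <class 'float'>}
print(rows)       # expected: [('a cat', 0.9), ('a dog', 0.7)]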