datachain 0.7.0__py3-none-any.whl → 0.7.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of datachain might be problematic.
- datachain/__init__.py +0 -3
- datachain/catalog/catalog.py +8 -6
- datachain/cli.py +1 -1
- datachain/client/fsspec.py +9 -9
- datachain/data_storage/schema.py +2 -2
- datachain/data_storage/sqlite.py +5 -4
- datachain/data_storage/warehouse.py +18 -18
- datachain/func/__init__.py +49 -0
- datachain/{lib/func → func}/aggregate.py +13 -11
- datachain/func/array.py +176 -0
- datachain/func/base.py +23 -0
- datachain/func/conditional.py +81 -0
- datachain/func/func.py +384 -0
- datachain/func/path.py +110 -0
- datachain/func/random.py +23 -0
- datachain/func/string.py +154 -0
- datachain/func/window.py +49 -0
- datachain/lib/arrow.py +24 -12
- datachain/lib/data_model.py +25 -9
- datachain/lib/dataset_info.py +2 -2
- datachain/lib/dc.py +94 -56
- datachain/lib/hf.py +1 -1
- datachain/lib/signal_schema.py +1 -1
- datachain/lib/utils.py +1 -0
- datachain/lib/webdataset_laion.py +5 -5
- datachain/model/__init__.py +6 -0
- datachain/model/bbox.py +102 -0
- datachain/model/pose.py +88 -0
- datachain/model/segment.py +47 -0
- datachain/model/ultralytics/__init__.py +27 -0
- datachain/model/ultralytics/bbox.py +147 -0
- datachain/model/ultralytics/pose.py +113 -0
- datachain/model/ultralytics/segment.py +91 -0
- datachain/nodes_fetcher.py +2 -2
- datachain/query/dataset.py +57 -34
- datachain/sql/__init__.py +0 -2
- datachain/sql/functions/__init__.py +0 -26
- datachain/sql/selectable.py +11 -5
- datachain/sql/sqlite/base.py +11 -2
- datachain/toolkit/split.py +6 -2
- {datachain-0.7.0.dist-info → datachain-0.7.2.dist-info}/METADATA +72 -71
- {datachain-0.7.0.dist-info → datachain-0.7.2.dist-info}/RECORD +46 -35
- {datachain-0.7.0.dist-info → datachain-0.7.2.dist-info}/WHEEL +1 -1
- datachain/lib/func/__init__.py +0 -32
- datachain/lib/func/func.py +0 -152
- datachain/lib/models/__init__.py +0 -5
- datachain/lib/models/bbox.py +0 -45
- datachain/lib/models/pose.py +0 -37
- datachain/lib/models/yolo.py +0 -39
- {datachain-0.7.0.dist-info → datachain-0.7.2.dist-info}/LICENSE +0 -0
- {datachain-0.7.0.dist-info → datachain-0.7.2.dist-info}/entry_points.txt +0 -0
- {datachain-0.7.0.dist-info → datachain-0.7.2.dist-info}/top_level.txt +0 -0
datachain/func/window.py
ADDED
```diff
@@ -0,0 +1,49 @@
+from dataclasses import dataclass
+
+from datachain.query.schema import ColumnMeta
+
+
+@dataclass
+class Window:
+    """Represents a window specification for SQL window functions."""
+
+    partition_by: str
+    order_by: str
+    desc: bool = False
+
+
+def window(partition_by: str, order_by: str, desc: bool = False) -> Window:
+    """
+    Defines a window specification for SQL window functions.
+
+    The `window` function specifies how to partition and order the result set
+    for the associated window function. It is used to define the scope of the rows
+    that the window function will operate on.
+
+    Args:
+        partition_by (str): The column name by which to partition the result set.
+            Rows with the same value in the partition column
+            will be grouped together for the window function.
+        order_by (str): The column name by which to order the rows
+            within each partition. This determines the sequence in which
+            the window function is applied.
+        desc (bool, optional): If True, the rows will be ordered in descending order.
+            Defaults to False, which orders the rows
+            in ascending order.
+
+    Returns:
+        Window: A Window object representing the window specification.
+
+    Example:
+        ```py
+        window = func.window(partition_by="signal.category", order_by="created_at")
+        dc.mutate(
+            row_number=func.row_number().over(window),
+        )
+        ```
+    """
+    return Window(
+        ColumnMeta.to_db_name(partition_by),
+        ColumnMeta.to_db_name(order_by),
+        desc,
+    )
```
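For context, a minimal usage sketch following the docstring example above (`dc` is an assumed existing chain, and the signal names are illustrative):

```py
from datachain import func

# Number rows within each category, ordered by creation time.
window = func.window(partition_by="signal.category", order_by="created_at")
dc = dc.mutate(row_number=func.row_number().over(window))
```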
datachain/lib/arrow.py
CHANGED
```diff
@@ -116,31 +116,43 @@ def infer_schema(chain: "DataChain", **kwargs) -> pa.Schema:
     return pa.unify_schemas(schemas)
 
 
-def schema_to_output(
-
+def schema_to_output(
+    schema: pa.Schema, col_names: Optional[Sequence[str]] = None
+) -> tuple[dict[str, type], list[str]]:
+    """
+    Generate UDF output schema from pyarrow schema.
+    Returns a tuple of output schema and original column names (since they may be
+    normalized in the output dict).
+    """
+    signal_schema = _get_datachain_schema(schema)
+    if signal_schema:
+        return signal_schema.values, list(signal_schema.values)
+
     if col_names and (len(schema) != len(col_names)):
         raise ValueError(
             "Error generating output from Arrow schema - "
             f"Schema has {len(schema)} columns but got {len(col_names)} column names."
         )
     if not col_names:
-        col_names = schema.names
-
-
-
+        col_names = schema.names or []
+
+    normalized_col_dict = normalize_col_names(col_names)
+    col_names = list(normalized_col_dict)
+
     hf_schema = _get_hf_schema(schema)
     if hf_schema:
         return {
-            column: hf_type for hf_type, column in zip(hf_schema[1].values(),
-        }
+            column: hf_type for hf_type, column in zip(hf_schema[1].values(), col_names)
+        }, list(normalized_col_dict.values())
+
     output = {}
-    for field, column in zip(schema,
-        dtype = arrow_type_mapper(field.type, column)
+    for field, column in zip(schema, col_names):
+        dtype = arrow_type_mapper(field.type, column)
         if field.nullable and not ModelStore.is_pydantic(dtype):
             dtype = Optional[dtype]  # type: ignore[assignment]
         output[column] = dtype
-
+
+    return output, list(normalized_col_dict.values())
 
 
 def arrow_type_mapper(col_type: pa.DataType, column: str = "") -> type:  # noqa: PLR0911
```
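A short sketch of the new return contract (column names below are illustrative): `schema_to_output` now also returns the original names, since keys in the output dict may be normalized.

```py
import pyarrow as pa

from datachain.lib.arrow import schema_to_output

schema = pa.schema([("My Col", pa.int64()), ("other-col", pa.string())])
output, original_names = schema_to_output(schema)
# `output` maps normalized (db-safe) column names to Python types;
# `original_names` keeps the source spellings in the same order.
```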
datachain/lib/data_model.py
CHANGED
```diff
@@ -1,8 +1,8 @@
 from collections.abc import Sequence
 from datetime import datetime
-from typing import ClassVar, Union, get_args, get_origin
+from typing import ClassVar, Optional, Union, get_args, get_origin
 
-from pydantic import BaseModel, Field, create_model
+from pydantic import AliasChoices, BaseModel, Field, create_model
 
 from datachain.lib.model_store import ModelStore
 from datachain.lib.utils import normalize_col_names
@@ -60,17 +60,33 @@ def is_chain_type(t: type) -> bool:
     return False
 
 
-def dict_to_data_model(
-
-
-
-
+def dict_to_data_model(
+    name: str,
+    data_dict: dict[str, DataType],
+    original_names: Optional[list[str]] = None,
+) -> type[BaseModel]:
+    if not original_names:
+        # Gets a map of a normalized_name -> original_name
+        columns = normalize_col_names(list(data_dict))
+        data_dict = dict(zip(columns.keys(), data_dict.values()))
+        original_names = list(columns.values())
 
     fields = {
-
+        name: (
+            anno,
+            Field(
+                validation_alias=AliasChoices(name, original_names[idx] or name),
+                default=None,
+            ),
+        )
+        for idx, (name, anno) in enumerate(data_dict.items())
     }
+
+    class _DataModelStrict(BaseModel, extra="forbid"):
+        pass
+
     return create_model(
         name,
-        __base__=
+        __base__=_DataModelStrict,
         **fields,
     )  # type: ignore[call-overload]
```
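A hedged sketch of the new alias behavior: normalized names become the model fields, while pydantic's `AliasChoices` keeps the original spellings valid at validation time (names below are illustrative; the exact normalized forms depend on `normalize_col_names`):

```py
from datachain.lib.data_model import dict_to_data_model

# Keys are normalized internally; original spellings become validation aliases.
Model = dict_to_data_model("Example", {"my col": str, "count": int})

# Input using the original column name still validates via the alias.
row = Model.model_validate({"my col": "a", "count": 1})
```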
datachain/lib/dataset_info.py
CHANGED
```diff
@@ -23,8 +23,8 @@ class DatasetInfo(DataModel):
     finished_at: Optional[datetime] = Field(default=None)
     num_objects: Optional[int] = Field(default=None)
     size: Optional[int] = Field(default=None)
-    params: dict[str, str] = Field(default=
-    metrics: dict[str, Any] = Field(default=
+    params: dict[str, str] = Field(default={})
+    metrics: dict[str, Any] = Field(default={})
     error_message: str = Field(default="")
     error_stack: str = Field(default="")
```
datachain/lib/dc.py
CHANGED
```diff
@@ -28,13 +28,14 @@ from sqlalchemy.sql.sqltypes import NullType
 from datachain.client import Client
 from datachain.client.local import FileClient
 from datachain.dataset import DatasetRecord
+from datachain.func.base import Function
+from datachain.func.func import Func
 from datachain.lib.convert.python_to_sql import python_to_sql
 from datachain.lib.convert.values_to_tuples import values_to_tuples
 from datachain.lib.data_model import DataModel, DataType, DataValue, dict_to_data_model
 from datachain.lib.dataset_info import DatasetInfo
 from datachain.lib.file import ArrowRow, File, get_file_type
 from datachain.lib.file import ExportPlacement as FileExportPlacement
-from datachain.lib.func import Func
 from datachain.lib.listing import (
     list_bucket,
     ls,
@@ -112,9 +113,29 @@ class DatasetFromValuesError(DataChainParamsError):  # noqa: D101
         super().__init__(f"Dataset{name} from values error: {msg}")
 
 
-def _get_merge_error_str(col: Union[str, sqlalchemy.ColumnElement]) -> str:
+MergeColType = Union[str, Function, sqlalchemy.ColumnElement]
+
+
+def _validate_merge_on(
+    on: Union[MergeColType, Sequence[MergeColType]],
+    ds: "DataChain",
+) -> Sequence[MergeColType]:
+    if isinstance(on, (str, sqlalchemy.ColumnElement)):
+        return [on]
+    if isinstance(on, Function):
+        return [on.get_column(table=ds._query.table)]
+    if isinstance(on, Sequence):
+        return [
+            c.get_column(table=ds._query.table) if isinstance(c, Function) else c
+            for c in on
+        ]
+
+
+def _get_merge_error_str(col: MergeColType) -> str:
     if isinstance(col, str):
         return col
+    if isinstance(col, Function):
+        return f"{col.name}()"
     if isinstance(col, sqlalchemy.Column):
         return col.name.replace(DEFAULT_DELIMITER, ".")
     if isinstance(col, sqlalchemy.ColumnElement) and hasattr(col, "name"):
@@ -125,11 +146,13 @@ def _get_merge_error_str(col: Union[str, sqlalchemy.ColumnElement]) -> str:
 class DatasetMergeError(DataChainParamsError):  # noqa: D101
     def __init__(  # noqa: D107
         self,
-        on:
-        right_on: Optional[
+        on: Union[MergeColType, Sequence[MergeColType]],
+        right_on: Optional[Union[MergeColType, Sequence[MergeColType]]],
         msg: str,
     ):
-        def _get_str(
+        def _get_str(
+            on: Union[MergeColType, Sequence[MergeColType]],
+        ) -> str:
             if not isinstance(on, Sequence):
                 return str(on)  # type: ignore[unreachable]
             return ", ".join([_get_merge_error_str(col) for col in on])
@@ -348,6 +371,9 @@ class DataChain:
             enable all available CPUs (default=1)
         workers : number of distributed workers. Only for Studio mode. (default=1)
         min_task_size : minimum number of tasks (default=1)
+        prefetch: number of workers to use for downloading files in advance.
+                  This is enabled by default and uses 2 workers.
+                  To disable prefetching, set it to 0.
 
     Example:
         ```py
```
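The new `prefetch` option appears in the settings docstring above; a minimal sketch, assuming it is passed through `DataChain.settings` like the neighboring options (`chain` is an assumed existing chain):

```py
# Prefetching is on by default (2 workers); set it to 0 to download
# files strictly on demand instead.
chain = chain.settings(prefetch=0)
```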
```diff
@@ -648,6 +674,7 @@
         col: str,
         model_name: Optional[str] = None,
         object_name: Optional[str] = None,
+        schema_sample_size: int = 1,
     ) -> "DataChain":
         """Explodes a column containing JSON objects (dict or str DataChain type) into
         individual columns based on the schema of the JSON. Schema is inferred from
@@ -659,6 +686,9 @@
             automatically.
             object_name: optional generated object column name. By default generates the
                 name automatically.
+            schema_sample_size: the number of rows to use for inferring the schema of
+                the JSON (in case some fields are optional and it's not enough to
+                analyze a single row).
 
         Returns:
             DataChain: A new DataChain instance with the new set of columns.
@@ -669,21 +699,22 @@
 
         from datachain.lib.arrow import schema_to_output
 
-
-
+        json_values = list(self.limit(schema_sample_size).collect(col))
+        json_dicts = [
             json.loads(json_value) if isinstance(json_value, str) else json_value
-
+            for json_value in json_values
+        ]
 
-        if not isinstance(json_dict, dict):
+        if any(not isinstance(json_dict, dict) for json_dict in json_dicts):
             raise TypeError(f"Column {col} should be a string or dict type with JSON")
 
-        schema = pa.Table.from_pylist(
-        output = schema_to_output(schema, None)
+        schema = pa.Table.from_pylist(json_dicts).schema
+        output, original_names = schema_to_output(schema, None)
 
         if not model_name:
             model_name = f"{col.title()}ExplodedModel"
 
-        model = dict_to_data_model(model_name, output)
+        model = dict_to_data_model(model_name, output, original_names)
 
         def json_to_model(json_value: Union[str, dict]):
             json_dict = (
```
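A short sketch of the new `schema_sample_size` knob, assuming this hunk belongs to the `explode` method described in its docstring (chain and column names are illustrative): sampling more rows helps when some JSON fields are optional and absent from the first row.

```py
# Infer the exploded schema from the first 10 rows instead of just one.
exploded = chain.explode("meta", schema_sample_size=10)
```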
```diff
@@ -776,7 +807,7 @@
         ```py
         uri = "gs://datachain-demo/coco2017/annotations_captions/"
         chain = DataChain.from_storage(uri)
-        chain = chain.
+        chain = chain.print_json_schema()
         chain.save()
         ```
         """
```
```diff
@@ -1119,7 +1150,7 @@
     def group_by(
         self,
         *,
-        partition_by: Union[str, Sequence[str]],
+        partition_by: Union[str, Func, Sequence[Union[str, Func]]],
         **kwargs: Func,
     ) -> "Self":
         """Group rows by specified set of signals and return new signals
@@ -1136,36 +1167,47 @@
             )
             ```
         """
-        if isinstance(partition_by, str):
+        if isinstance(partition_by, (str, Func)):
             partition_by = [partition_by]
         if not partition_by:
             raise ValueError("At least one column should be provided for partition_by")
 
-        if not kwargs:
-            raise ValueError("At least one column should be provided for group_by")
-        for col_name, func in kwargs.items():
-            if not isinstance(func, Func):
-                raise DataChainColumnError(
-                    col_name,
-                    f"Column {col_name} has type {type(func)} but expected Func object",
-                )
-
         partition_by_columns: list[Column] = []
         signal_columns: list[Column] = []
         schema_fields: dict[str, DataType] = {}
 
         # validate partition_by columns and add them to the schema
-        for
-
-
-
-
+        for col in partition_by:
+            if isinstance(col, str):
+                col_db_name = ColumnMeta.to_db_name(col)
+                col_type = self.signals_schema.get_column_type(col_db_name)
+                column = Column(col_db_name, python_to_sql(col_type))
+            elif isinstance(col, Function):
+                column = col.get_column(self.signals_schema)
+                col_db_name = column.name
+                col_type = column.type.python_type
+            else:
+                raise DataChainColumnError(
+                    col,
+                    (
+                        f"partition_by column {col} has type {type(col)}"
+                        " but expected str or Function"
+                    ),
+                )
+            partition_by_columns.append(column)
             schema_fields[col_db_name] = col_type
 
         # validate signal columns and add them to the schema
+        if not kwargs:
+            raise ValueError("At least one column should be provided for group_by")
         for col_name, func in kwargs.items():
-
-
+            if not isinstance(func, Func):
+                raise DataChainColumnError(
+                    col_name,
+                    f"Column {col_name} has type {type(func)} but expected Func object",
+                )
+            column = func.get_column(self.signals_schema, label=col_name)
+            signal_columns.append(column)
             schema_fields[col_name] = func.get_result_type(self.signals_schema)
 
         return self._evolve(
```
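With `partition_by` now accepting `Func` objects, grouping can key on a computed value. A hedged sketch using helpers from the new `datachain.func` package added in this release (`chain` and the signal names are illustrative):

```py
from datachain import func

grouped = chain.group_by(
    cnt=func.count(),
    total=func.sum("file.size"),
    partition_by=func.path.file_ext("file.path"),
)
```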
```diff
@@ -1413,25 +1455,16 @@
     def merge(
         self,
         right_ds: "DataChain",
-        on: Union[
-            str,
-            sqlalchemy.ColumnElement,
-            Sequence[Union[str, sqlalchemy.ColumnElement]],
-        ],
-        right_on: Union[
-            str,
-            sqlalchemy.ColumnElement,
-            Sequence[Union[str, sqlalchemy.ColumnElement]],
-            None,
-        ] = None,
+        on: Union[MergeColType, Sequence[MergeColType]],
+        right_on: Optional[Union[MergeColType, Sequence[MergeColType]]] = None,
         inner=False,
         rname="right_",
     ) -> "Self":
         """Merge two chains based on the specified criteria.
 
         Parameters:
-            right_ds
-            on
+            right_ds: Chain to join with.
+            on: Predicate or list of Predicates to join on. If both chains have the
                 same predicates then this predicate is enough for the join. Otherwise,
                 `right_on` parameter has to specify the predicates for the other chain.
             right_on: Optional predicate or list of Predicates
@@ -1448,23 +1481,24 @@
         if on is None:
             raise DatasetMergeError(["None"], None, "'on' must be specified")
 
-
-
-        elif not isinstance(on, Sequence):
+        on = _validate_merge_on(on, self)
+        if not on:
             raise DatasetMergeError(
                 on,
                 right_on,
-
+                (
+                    "'on' must be 'str', 'Func' or 'Sequence' object "
+                    f"but got type '{type(on)}'"
+                ),
             )
 
         if right_on is not None:
-
-
-            elif not isinstance(right_on, Sequence):
+            right_on = _validate_merge_on(right_on, right_ds)
+            if not right_on:
                 raise DatasetMergeError(
                     on,
                     right_on,
-                    "'right_on' must be 'str' or 'Sequence' object"
+                    "'right_on' must be 'str', 'Func' or 'Sequence' object"
                     f" but got type '{type(right_on)}'",
                 )
 
@@ -1480,10 +1514,12 @@
 
         def _resolve(
             ds: DataChain,
-            col: Union[str, sqlalchemy.ColumnElement],
+            col: Union[str, Function, sqlalchemy.ColumnElement],
             side: Union[str, None],
         ):
             try:
+                if isinstance(col, Function):
+                    return ds.c(col.get_column())
                 return ds.c(col) if isinstance(col, (str, C)) else col
             except ValueError:
                 if side:
```
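Since `on`/`right_on` now accept the widened `MergeColType`, a join key can be computed with a `Func`. A hedged sketch (the chains and signal names are illustrative):

```py
from datachain import func

# Join two chains on the file stem rather than the full path.
merged = images.merge(
    annotations,
    on=func.path.file_stem("file.path"),
    right_on=func.path.file_stem("file.path"),
)
```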
```diff
@@ -1834,13 +1870,14 @@
         if col_names or not output:
             try:
                 schema = infer_schema(self, **kwargs)
-                output = schema_to_output(schema, col_names)
+                output, _ = schema_to_output(schema, col_names)
             except ValueError as e:
                 raise DatasetPrepareError(self.name, e) from e
 
         if isinstance(output, dict):
             model_name = model_name or object_name or ""
             model = dict_to_data_model(model_name, output)
+            output = model
         else:
             model = output  # type: ignore[assignment]
@@ -1851,6 +1888,7 @@
                 name: info.annotation  # type: ignore[misc]
                 for name, info in output.model_fields.items()
             }
+
         if source:
             output = {"source": ArrowRow} | output  # type: ignore[assignment,operator]
         return self.gen(
```
```diff
@@ -2389,9 +2427,9 @@
             dc.filter(C("file.name").glob("*.jpg"))
             ```
 
-        Using `datachain.
+        Using `datachain.func`
             ```py
-            from datachain.
+            from datachain.func import string
             dc.filter(string.length(C("file.name")) > 5)
             ```
 
```
datachain/lib/hf.py
CHANGED
```diff
@@ -98,7 +98,7 @@ class HFGenerator(Generator):
         with tqdm(desc=desc, unit=" rows") as pbar:
             for row in ds:
                 output_dict = {}
-                if split:
+                if split and "split" in self.output_schema.model_fields:
                     output_dict["split"] = split
                 for name, feat in ds.features.items():
                     anno = self.output_schema.model_fields[name].annotation
```
datachain/lib/signal_schema.py
CHANGED
```diff
@@ -23,12 +23,12 @@ from pydantic import BaseModel, create_model
 from sqlalchemy import ColumnElement
 from typing_extensions import Literal as LiteralEx
 
+from datachain.func.func import Func
 from datachain.lib.convert.python_to_sql import python_to_sql
 from datachain.lib.convert.sql_to_python import sql_to_python
 from datachain.lib.convert.unflatten import unflatten_to_json_pos
 from datachain.lib.data_model import DataModel, DataType, DataValue
 from datachain.lib.file import File
-from datachain.lib.func import Func
 from datachain.lib.model_store import ModelStore
 from datachain.lib.utils import DataChainParamsError
 from datachain.query.schema import DEFAULT_DELIMITER, Column
```
datachain/lib/utils.py
CHANGED
```diff
@@ -33,6 +33,7 @@ class DataChainColumnError(DataChainParamsError):
 
 
 def normalize_col_names(col_names: Sequence[str]) -> dict[str, str]:
+    """Returns normalized_name -> original_name dict."""
    gen_col_counter = 0
    new_col_names = {}
    org_col_names = set(col_names)
```
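For reference, a small sketch of the documented contract (the exact normalization rules are defined by the implementation, so no specific normalized forms are assumed here):

```py
from datachain.lib.utils import normalize_col_names

# Keys are normalized (db-safe) names, values are the original spellings.
mapping = normalize_col_names(["My Col", "count"])
for normalized, original in mapping.items():
    print(normalized, "<-", original)
```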
datachain/lib/webdataset_laion.py
CHANGED
```diff
@@ -49,11 +49,11 @@ class WDSLaion(WDSBasic):
 class LaionMeta(BaseModel):
     file: File
     index: Optional[int] = Field(default=None)
-    b32_img: list[float] = Field(default=
-    b32_txt: list[float] = Field(default=
-    l14_img: list[float] = Field(default=
-    l14_txt: list[float] = Field(default=
-    dedup: list[float] = Field(default=
+    b32_img: list[float] = Field(default=[])
+    b32_txt: list[float] = Field(default=[])
+    l14_img: list[float] = Field(default=[])
+    l14_txt: list[float] = Field(default=[])
+    dedup: list[float] = Field(default=[])
 
 
 def process_laion_meta(file: File) -> Iterator[LaionMeta]:
```
datachain/model/bbox.py
ADDED
```diff
@@ -0,0 +1,102 @@
+from pydantic import Field
+
+from datachain.lib.data_model import DataModel
+
+
+class BBox(DataModel):
+    """
+    A data model for representing bounding box.
+
+    Attributes:
+        title (str): The title of the bounding box.
+        coords (list[int]): The coordinates of the bounding box.
+
+    The bounding box is defined by two points:
+        - (x1, y1): The top-left corner of the box.
+        - (x2, y2): The bottom-right corner of the box.
+    """
+
+    title: str = Field(default="")
+    coords: list[int] = Field(default=[])
+
+    @staticmethod
+    def from_list(coords: list[float], title: str = "") -> "BBox":
+        assert len(coords) == 4, "Bounding box must be a list of 4 coordinates."
+        assert all(
+            isinstance(value, (int, float)) for value in coords
+        ), "Bounding box coordinates must be floats or integers."
+        return BBox(
+            title=title,
+            coords=[round(c) for c in coords],
+        )
+
+    @staticmethod
+    def from_dict(coords: dict[str, float], title: str = "") -> "BBox":
+        assert isinstance(coords, dict) and set(coords) == {
+            "x1",
+            "y1",
+            "x2",
+            "y2",
+        }, "Bounding box must be a dictionary with keys 'x1', 'y1', 'x2' and 'y2'."
+        return BBox.from_list(
+            [coords["x1"], coords["y1"], coords["x2"], coords["y2"]],
+            title=title,
+        )
+
+
+class OBBox(DataModel):
+    """
+    A data model for representing oriented bounding boxes.
+
+    Attributes:
+        title (str): The title of the oriented bounding box.
+        coords (list[int]): The coordinates of the oriented bounding box.
+
+    The oriented bounding box is defined by four points:
+        - (x1, y1): The first corner of the box.
+        - (x2, y2): The second corner of the box.
+        - (x3, y3): The third corner of the box.
+        - (x4, y4): The fourth corner of the box.
+    """
+
+    title: str = Field(default="")
+    coords: list[int] = Field(default=[])
+
+    @staticmethod
+    def from_list(coords: list[float], title: str = "") -> "OBBox":
+        assert (
+            len(coords) == 8
+        ), "Oriented bounding box must be a list of 8 coordinates."
+        assert all(
+            isinstance(value, (int, float)) for value in coords
+        ), "Oriented bounding box coordinates must be floats or integers."
+        return OBBox(
+            title=title,
+            coords=[round(c) for c in coords],
+        )
+
+    @staticmethod
+    def from_dict(coords: dict[str, float], title: str = "") -> "OBBox":
+        assert isinstance(coords, dict) and set(coords) == {
+            "x1",
+            "y1",
+            "x2",
+            "y2",
+            "x3",
+            "y3",
+            "x4",
+            "y4",
+        }, "Oriented bounding box must be a dictionary with coordinates."
+        return OBBox.from_list(
+            [
+                coords["x1"],
+                coords["y1"],
+                coords["x2"],
+                coords["y2"],
+                coords["x3"],
+                coords["y3"],
+                coords["x4"],
+                coords["y4"],
+            ],
+            title=title,
+        )
```