datachain 0.2.11__py3-none-any.whl → 0.2.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datachain/__init__.py +3 -4
- datachain/cache.py +10 -4
- datachain/catalog/catalog.py +35 -15
- datachain/cli.py +37 -32
- datachain/data_storage/metastore.py +24 -0
- datachain/data_storage/warehouse.py +3 -1
- datachain/job.py +56 -0
- datachain/lib/arrow.py +19 -7
- datachain/lib/clip.py +89 -66
- datachain/lib/convert/{type_converter.py → python_to_sql.py} +6 -6
- datachain/lib/convert/sql_to_python.py +23 -0
- datachain/lib/convert/values_to_tuples.py +51 -33
- datachain/lib/data_model.py +6 -27
- datachain/lib/dataset_info.py +70 -0
- datachain/lib/dc.py +618 -156
- datachain/lib/file.py +117 -15
- datachain/lib/image.py +1 -1
- datachain/lib/meta_formats.py +14 -2
- datachain/lib/model_store.py +3 -2
- datachain/lib/pytorch.py +10 -7
- datachain/lib/signal_schema.py +19 -11
- datachain/lib/text.py +2 -1
- datachain/lib/udf.py +56 -5
- datachain/lib/udf_signature.py +1 -1
- datachain/node.py +11 -8
- datachain/query/dataset.py +52 -26
- datachain/query/schema.py +2 -0
- datachain/query/session.py +4 -4
- datachain/sql/functions/array.py +12 -0
- datachain/sql/functions/string.py +8 -0
- datachain/torch/__init__.py +1 -1
- datachain/utils.py +6 -0
- datachain-0.2.12.dist-info/METADATA +412 -0
- {datachain-0.2.11.dist-info → datachain-0.2.12.dist-info}/RECORD +38 -42
- {datachain-0.2.11.dist-info → datachain-0.2.12.dist-info}/WHEEL +1 -1
- datachain/lib/gpt4_vision.py +0 -97
- datachain/lib/hf_image_to_text.py +0 -97
- datachain/lib/hf_pipeline.py +0 -90
- datachain/lib/image_transform.py +0 -103
- datachain/lib/iptc_exif_xmp.py +0 -76
- datachain/lib/unstructured.py +0 -41
- datachain/text/__init__.py +0 -3
- datachain-0.2.11.dist-info/METADATA +0 -431
- {datachain-0.2.11.dist-info → datachain-0.2.12.dist-info}/LICENSE +0 -0
- {datachain-0.2.11.dist-info → datachain-0.2.12.dist-info}/entry_points.txt +0 -0
- {datachain-0.2.11.dist-info → datachain-0.2.12.dist-info}/top_level.txt +0 -0
datachain/query/dataset.py
CHANGED
@@ -54,6 +54,7 @@ from datachain.utils import (
     batched,
     determine_processes,
     filtered_cloudpickle_dumps,
+    get_datachain_executable,
 )

 from .metrics import metrics
@@ -426,7 +427,7 @@ def get_generated_callback(is_generator: bool = False) -> Callback:


 @frozen
-class UDF(Step, ABC):
+class UDFStep(Step, ABC):
     udf: UDFType
     catalog: "Catalog"
     partition_by: Optional[PartitionByType] = None
@@ -507,13 +508,12 @@ class UDF(Step, ABC):

         # Run the UDFDispatcher in another process to avoid needing
         # if __name__ == '__main__': in user scripts
-
-
+        exec_cmd = get_datachain_executable()
         envs = dict(os.environ)
         envs.update({"PYTHONPATH": os.getcwd()})
         process_data = filtered_cloudpickle_dumps(udf_info)
         result = subprocess.run(  # noqa: S603
-            [
+            [*exec_cmd, "internal-run-udf"],
             input=process_data,
             check=False,
             env=envs,
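The dispatch above pickles the UDF payload and pipes it to a fresh "datachain internal-run-udf" process, so user scripts don't need an if __name__ == "__main__": guard. A minimal sketch of the same flow, using only names introduced in this diff (the run_udf_in_subprocess wrapper itself is illustrative, not datachain API):

import os
import subprocess

from datachain.utils import filtered_cloudpickle_dumps, get_datachain_executable

def run_udf_in_subprocess(udf_info) -> int:
    # The child process imports user code by path, so expose the CWD.
    envs = {**os.environ, "PYTHONPATH": os.getcwd()}
    proc = subprocess.run(  # noqa: S603
        [*get_datachain_executable(), "internal-run-udf"],
        input=filtered_cloudpickle_dumps(udf_info),
        env=envs,
        check=False,
    )
    return proc.returncode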
@@ -635,7 +635,7 @@ class UDF(Step, ABC):


 @frozen
-class UDFSignal(UDF):
+class UDFSignal(UDFStep):
     is_generator = False

     def create_udf_table(self, query: Select) -> "Table":
@@ -730,7 +730,7 @@ class UDFSignal(UDF):


 @frozen
-class RowGenerator(UDF):
+class RowGenerator(UDFStep):
     """Extend dataset with new rows."""

     is_generator = True
@@ -865,6 +865,18 @@ class SQLCount(SQLClause):
         return sqlalchemy.select(f.count(1)).select_from(query.subquery())


+@frozen
+class SQLDistinct(SQLClause):
+    args: tuple[ColumnElement, ...]
+    dialect: str
+
+    def apply_sql_clause(self, query):
+        if self.dialect == "sqlite":
+            return query.group_by(*self.args)
+
+        return query.distinct(*self.args)
+
+
 @frozen
 class SQLUnion(Step):
     query1: "DatasetQuery"
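SQLDistinct branches on dialect because SQLite has no DISTINCT ON clause; grouping by the target columns yields one row per distinct combination instead. A standalone SQLAlchemy sketch of the two paths (table and columns are illustrative):

import sqlalchemy as sa

metadata = sa.MetaData()
files = sa.Table(
    "files",
    metadata,
    sa.Column("parent", sa.String),
    sa.Column("name", sa.String),
)

query = sa.select(files.c.parent)
print(query.group_by(files.c.parent))  # SQLite path: GROUP BY parent
print(query.distinct())                # other dialects: SELECT DISTINCT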
@@ -946,12 +958,15 @@ class SQLJoin(Step):

         q1_columns = list(q1.c)
         q1_column_names = {c.name for c in q1_columns}
-
-
-
-
-
-
+
+        q2_columns = []
+        for c in q2.c:
+            if c.name.startswith("sys__"):
+                continue
+
+            if c.name in q1_column_names:
+                c = c.label(self.rname.format(name=c.name))
+            q2_columns.append(c)

         res_columns = q1_columns + q2_columns
         predicates = (
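The new SQLJoin logic drops right-side sys__* columns (system columns come from the left side only) and relabels name collisions with the rname template. A small sketch of the relabeling, assuming rname is a format string such as "{name}_right" (the actual default is not shown in this diff):

import sqlalchemy as sa

left_names = {"name", "size"}
right_cols = [sa.column("name"), sa.column("etag"), sa.column("sys__id")]

rname = "{name}_right"  # assumed template; SQLJoin.rname in the real code
result = []
for c in right_cols:
    if c.name.startswith("sys__"):
        continue  # skip right-side system columns
    if c.name in left_names:
        c = c.label(rname.format(name=c.name))
    result.append(c)
print([c.name for c in result])  # ['name_right', 'etag']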
@@ -1058,6 +1073,7 @@ class DatasetQuery:
         anon: bool = False,
         indexing_feature_schema: Optional[dict] = None,
         indexing_column_types: Optional[dict[str, Any]] = None,
+        update: Optional[bool] = False,
     ):
         if client_config is None:
             client_config = {}
@@ -1080,10 +1096,12 @@ class DatasetQuery:
         self.session = Session.get(session, catalog=catalog)

         if path:
-
+            kwargs = {"update": True} if update else {}
+            self.starting_step = IndexingStep(path, self.catalog, kwargs, recursive)
             self.feature_schema = indexing_feature_schema
             self.column_types = indexing_column_types
         elif name:
+            self.name = name
             ds = self.catalog.get_dataset(name)
             self.version = version or ds.latest_version
             self.feature_schema = ds.get_version(self.version).feature_schema
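The new update flag is threaded into IndexingStep as {"update": True}, forcing a re-listing of the source path instead of reusing a previously indexed listing. Hypothetical usage (the bucket path is a placeholder):

from datachain.query.dataset import DatasetQuery

dq = DatasetQuery(path="s3://example-bucket/images/", update=True)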
@@ -1091,9 +1109,6 @@ class DatasetQuery:
             if "sys__id" in self.column_types:
                 self.column_types.pop("sys__id")
             self.starting_step = QueryStep(self.catalog, name, self.version)
-            # attaching to specific dataset
-            self.name = name
-            self.version = version
         else:
             raise ValueError("must provide path or name")

@@ -1102,7 +1117,7 @@ class DatasetQuery:
         return bool(re.compile(r"^[a-zA-Z0-9]+://").match(path))

     def __iter__(self):
-        return iter(self.results())
+        return iter(self.db_results())

     def __or__(self, other):
         return self.union(other)
@@ -1223,13 +1238,16 @@ class DatasetQuery:
             warehouse.close()
         self.temp_table_names = []

-    def results(self, row_factory=None, **kwargs):
+    def db_results(self, row_factory=None, **kwargs):
         with self.as_iterable(**kwargs) as result:
             if row_factory:
                 cols = result.columns
                 return [row_factory(cols, r) for r in result]
             return list(result)

+    def to_db_records(self) -> list[dict[str, Any]]:
+        return self.db_results(lambda cols, row: dict(zip(cols, row)))
+
     @contextlib.contextmanager
     def as_iterable(self, **kwargs) -> Iterator[ResultIter]:
         try:
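Together with the removal of to_records below, this renames results/to_records to db_results/to_db_records. A sketch of the renamed helpers, assuming a dataset by this name already exists:

from datachain.query.dataset import DatasetQuery

dq = DatasetQuery(name="my_dataset")   # "my_dataset" is a placeholder
rows = dq.db_results()                 # list of raw result rows
records = dq.to_db_records()           # list of {column_name: value} dicts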
@@ -1289,9 +1307,6 @@ class DatasetQuery:
         finally:
             self.cleanup()

-    def to_records(self) -> list[dict[str, Any]]:
-        return self.results(lambda cols, row: dict(zip(cols, row)))
-
     def shuffle(self) -> "Self":
         # ToDo: implement shaffle based on seed and/or generating random column
         return self.order_by(C.sys__rand)
@@ -1407,6 +1422,14 @@ class DatasetQuery:
         query.steps.append(SQLOffset(offset))
         return query

+    @detach
+    def distinct(self, *args) -> "Self":
+        query = self.clone()
+        query.steps.append(
+            SQLDistinct(args, dialect=self.catalog.warehouse.db.dialect.name)
+        )
+        return query
+
     def as_scalar(self) -> Any:
         with self.as_iterable() as rows:
             row = next(iter(rows))
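distinct() clones the query and appends the dialect-aware SQLDistinct step defined earlier. Hypothetical usage, assuming C is the column helper exported from datachain.query.schema and the dataset name is a placeholder:

from datachain.query.dataset import DatasetQuery
from datachain.query.schema import C

# One row per distinct value of the given column.
dq = DatasetQuery(name="my_dataset").distinct(C.parent)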
@@ -1705,10 +1728,13 @@ def _send_result(dataset_query: DatasetQuery) -> None:

     columns = preview_args.get("columns") or []

-
-        dataset_query.select(*columns)
-
-        .
+    if type(dataset_query) is DatasetQuery:
+        preview_query = dataset_query.select(*columns)
+    else:
+        preview_query = dataset_query.select(*columns, _sys=False)
+
+    preview_query = preview_query.limit(preview_args.get("limit", 10)).offset(
+        preview_args.get("offset", 0)
     )

     dataset: Optional[tuple[str, int]] = None
@@ -1717,7 +1743,7 @@ def _send_result(dataset_query: DatasetQuery) -> None:
         assert dataset_query.version, "Dataset version should be provided"
         dataset = dataset_query.name, dataset_query.version

-    preview = preview_query.to_records()
+    preview = preview_query.to_db_records()
     result = ExecutionResult(preview, dataset, metrics)
     data = attrs.asdict(result)
datachain/query/schema.py
CHANGED
@@ -32,6 +32,7 @@ class Column(sa.ColumnClause, metaclass=ColumnMeta):
     inherit_cache: Optional[bool] = True

     def __init__(self, text, type_=None, is_literal=False, _selectable=None):
+        """Dataset column."""
         self.name = ColumnMeta.to_db_name(text)
         super().__init__(
             self.name, type_=type_, is_literal=is_literal, _selectable=_selectable
@@ -41,6 +42,7 @@ class Column(sa.ColumnClause, metaclass=ColumnMeta):
         return Column(self.name + DEFAULT_DELIMITER + name)

     def glob(self, glob_str):
+        """Search for matches using glob pattern matching."""
         return self.op("GLOB")(glob_str)


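Column.glob wraps the GLOB operator, so path-style wildcards can be used in filters. A small sketch, assuming C is the column helper from this module (the column name is illustrative):

from datachain.query.schema import C

expr = C.path.glob("*.jpg")  # renders as: path GLOB '*.jpg'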
datachain/query/session.py
CHANGED
@@ -28,9 +28,9 @@ class Session:

     Parameters:

-
+    name (str): The name of the session. Only latters and numbers are supported.
         It can be empty.
-
+    catalog (Catalog): Catalog object.
     """

     GLOBAL_SESSION_CTX: Optional["Session"] = None
@@ -80,9 +80,9 @@ class Session:
         """Creates a Session() object from a catalog.

         Parameters:
-
+            session (Session): Optional Session(). If not provided a new session will
                 be created. It's needed mostly for simplie API purposes.
-
+            catalog (Catalog): Optional catalog. By default a new catalog is created.
         """
         if session:
             return session
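A sketch of Session.get as documented above: passing an existing session returns it unchanged, while omitting it creates one (the session name is illustrative; constructor defaults beyond name are assumed):

from datachain.query.session import Session

session = Session("MySession1")   # letters and digits only; may be empty
assert Session.get(session) is session
fresh = Session.get(None)         # a new session is created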
datachain/sql/functions/array.py
CHANGED
@@ -5,6 +5,10 @@ from datachain.sql.utils import compiler_not_implemented


 class cosine_distance(GenericFunction):  # noqa: N801
+    """
+    Takes a column and array and returns the cosine distance between them.
+    """
+
     type = Float()
     package = "array"
     name = "cosine_distance"
@@ -12,6 +16,10 @@ class cosine_distance(GenericFunction):  # noqa: N801


 class euclidean_distance(GenericFunction):  # noqa: N801
+    """
+    Takes a column and array and returns the Euclidean distance between them.
+    """
+
     type = Float()
     package = "array"
     name = "euclidean_distance"
@@ -19,6 +27,10 @@ class euclidean_distance(GenericFunction):  # noqa: N801


 class length(GenericFunction):  # noqa: N801
+    """
+    Returns the length of the array.
+    """
+
     type = Int64()
     package = "array"
     name = "length"
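These GenericFunction subclasses are callable as SQL expressions. Hedged usage, assuming C is the column helper and that a plain Python list is accepted as the vector argument:

from datachain.query.schema import C
from datachain.sql.functions.array import cosine_distance, length

dist = cosine_distance(C.embedding, [0.1, 0.2, 0.3])
n_items = length(C.labels)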
datachain/sql/functions/string.py
CHANGED

@@ -5,6 +5,10 @@ from datachain.sql.utils import compiler_not_implemented


 class length(GenericFunction):  # noqa: N801
+    """
+    Returns the length of the string.
+    """
+
     type = Int64()
     package = "string"
     name = "length"
@@ -12,6 +16,10 @@ class length(GenericFunction):  # noqa: N801


 class split(GenericFunction):  # noqa: N801
+    """
+    Takes a column and split character and returns an array of the parts.
+    """
+
     type = Array(String())
     package = "string"
     name = "split"
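Likewise for the string functions; the (column, separator) signature for split is an assumption consistent with its docstring:

from datachain.query.schema import C
from datachain.sql.functions.string import length, split

name_len = length(C.name)   # string length as an Int64 expression
parts = split(C.name, ".")  # array of "." separated parts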
datachain/torch/__init__.py
CHANGED
@@ -1,5 +1,5 @@
 try:
-    from datachain.lib.clip import
+    from datachain.lib.clip import clip_similarity_scores
     from datachain.lib.image import convert_image, convert_images
     from datachain.lib.pytorch import PytorchDataset, label_to_int
     from datachain.lib.text import convert_text
datachain/utils.py
CHANGED
@@ -427,3 +427,9 @@ def filtered_cloudpickle_dumps(obj: Any) -> bytes:
     for model_class, namespace in model_namespaces.items():
         # Restore original __pydantic_parent_namespace__ locally.
         model_class.__pydantic_parent_namespace__ = namespace
+
+
+def get_datachain_executable() -> list[str]:
+    if datachain_exec_path := os.getenv("DATACHAIN_EXEC_PATH"):
+        return [datachain_exec_path]
+    return [sys.executable, "-m", "datachain"]
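get_datachain_executable lets deployments swap in a standalone binary via DATACHAIN_EXEC_PATH; by default the current interpreter runs datachain as a module. For example (the override path is illustrative):

import os
from datachain.utils import get_datachain_executable

print(get_datachain_executable())  # e.g. ['/usr/bin/python3', '-m', 'datachain']
os.environ["DATACHAIN_EXEC_PATH"] = "/opt/datachain/bin/datachain"
print(get_datachain_executable())  # ['/opt/datachain/bin/datachain']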