datachain 0.34.6__py3-none-any.whl → 0.35.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datachain/asyn.py +11 -12
- datachain/cache.py +5 -5
- datachain/catalog/catalog.py +75 -83
- datachain/catalog/loader.py +3 -3
- datachain/checkpoint.py +1 -2
- datachain/cli/__init__.py +2 -4
- datachain/cli/commands/datasets.py +13 -13
- datachain/cli/commands/ls.py +4 -4
- datachain/cli/commands/query.py +3 -3
- datachain/cli/commands/show.py +2 -2
- datachain/cli/parser/job.py +1 -1
- datachain/cli/parser/utils.py +1 -2
- datachain/cli/utils.py +1 -2
- datachain/client/azure.py +2 -2
- datachain/client/fsspec.py +11 -21
- datachain/client/gcs.py +3 -3
- datachain/client/http.py +4 -4
- datachain/client/local.py +4 -4
- datachain/client/s3.py +3 -3
- datachain/config.py +4 -8
- datachain/data_storage/db_engine.py +5 -5
- datachain/data_storage/metastore.py +107 -107
- datachain/data_storage/schema.py +18 -24
- datachain/data_storage/sqlite.py +21 -28
- datachain/data_storage/warehouse.py +13 -13
- datachain/dataset.py +64 -70
- datachain/delta.py +21 -18
- datachain/diff/__init__.py +13 -13
- datachain/func/aggregate.py +9 -11
- datachain/func/array.py +12 -12
- datachain/func/base.py +7 -4
- datachain/func/conditional.py +9 -13
- datachain/func/func.py +45 -42
- datachain/func/numeric.py +5 -7
- datachain/func/string.py +2 -2
- datachain/hash_utils.py +54 -81
- datachain/job.py +8 -8
- datachain/lib/arrow.py +17 -14
- datachain/lib/audio.py +6 -6
- datachain/lib/clip.py +5 -4
- datachain/lib/convert/python_to_sql.py +4 -22
- datachain/lib/convert/values_to_tuples.py +4 -9
- datachain/lib/data_model.py +20 -19
- datachain/lib/dataset_info.py +6 -6
- datachain/lib/dc/csv.py +10 -10
- datachain/lib/dc/database.py +28 -29
- datachain/lib/dc/datachain.py +98 -97
- datachain/lib/dc/datasets.py +22 -22
- datachain/lib/dc/hf.py +4 -4
- datachain/lib/dc/json.py +9 -10
- datachain/lib/dc/listings.py +5 -8
- datachain/lib/dc/pandas.py +3 -6
- datachain/lib/dc/parquet.py +5 -5
- datachain/lib/dc/records.py +5 -5
- datachain/lib/dc/storage.py +12 -12
- datachain/lib/dc/storage_pattern.py +2 -2
- datachain/lib/dc/utils.py +11 -14
- datachain/lib/dc/values.py +3 -6
- datachain/lib/file.py +32 -28
- datachain/lib/hf.py +7 -5
- datachain/lib/image.py +13 -13
- datachain/lib/listing.py +5 -5
- datachain/lib/listing_info.py +1 -2
- datachain/lib/meta_formats.py +1 -2
- datachain/lib/model_store.py +3 -3
- datachain/lib/namespaces.py +4 -6
- datachain/lib/projects.py +5 -9
- datachain/lib/pytorch.py +10 -10
- datachain/lib/settings.py +23 -23
- datachain/lib/signal_schema.py +52 -44
- datachain/lib/text.py +8 -7
- datachain/lib/udf.py +25 -17
- datachain/lib/udf_signature.py +11 -11
- datachain/lib/video.py +3 -4
- datachain/lib/webdataset.py +30 -35
- datachain/lib/webdataset_laion.py +15 -16
- datachain/listing.py +4 -4
- datachain/model/bbox.py +3 -1
- datachain/namespace.py +4 -4
- datachain/node.py +6 -6
- datachain/nodes_thread_pool.py +0 -1
- datachain/plugins.py +1 -7
- datachain/project.py +4 -4
- datachain/query/batch.py +7 -8
- datachain/query/dataset.py +80 -87
- datachain/query/dispatch.py +7 -7
- datachain/query/metrics.py +3 -4
- datachain/query/params.py +2 -3
- datachain/query/schema.py +7 -6
- datachain/query/session.py +7 -7
- datachain/query/udf.py +8 -7
- datachain/query/utils.py +3 -5
- datachain/remote/studio.py +33 -39
- datachain/script_meta.py +12 -12
- datachain/sql/sqlite/base.py +6 -9
- datachain/studio.py +30 -30
- datachain/toolkit/split.py +1 -2
- datachain/utils.py +21 -21
- {datachain-0.34.6.dist-info → datachain-0.35.0.dist-info}/METADATA +2 -3
- datachain-0.35.0.dist-info/RECORD +173 -0
- datachain-0.34.6.dist-info/RECORD +0 -173
- {datachain-0.34.6.dist-info → datachain-0.35.0.dist-info}/WHEEL +0 -0
- {datachain-0.34.6.dist-info → datachain-0.35.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.34.6.dist-info → datachain-0.35.0.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.34.6.dist-info → datachain-0.35.0.dist-info}/top_level.txt +0 -0
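Nearly every change below follows the same pattern: `typing.Optional[X]` and `typing.Union[X, Y]` annotations are rewritten as PEP 604 unions (`X | None`, `X | Y`), `zip()` calls gain an explicit `strict=False`, and the Python 3.9 compatibility shims are deleted. PEP 604 unions are only valid at runtime on Python 3.10+, so together with the METADATA change this points to a raised Python floor. A minimal sketch of the equivalence, assuming Python 3.10+:

```python
# Sketch: the two annotation spellings produce equal runtime objects on 3.10+.
from typing import Optional, Union

def old_style(x: Optional[int], y: Union[str, int]) -> None: ...
def new_style(x: int | None, y: str | int) -> None: ...

assert old_style.__annotations__["x"] == new_style.__annotations__["x"]
assert old_style.__annotations__["y"] == new_style.__annotations__["y"]
```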
datachain/func/numeric.py
CHANGED
```diff
@@ -1,12 +1,10 @@
-from typing import Union
-
 from datachain.query.schema import Column
 from datachain.sql.functions import numeric

 from .func import Func


-def bit_and(*args: Union[str, Column, Func, int]) -> Func:
+def bit_and(*args: str | Column | Func | int) -> Func:
     """
     Returns a function that computes the bitwise AND operation between two values.

@@ -51,7 +49,7 @@ def bit_and(*args: Union[str, Column, Func, int]) -> Func:
     )


-def bit_or(*args: Union[str, Column, Func, int]) -> Func:
+def bit_or(*args: str | Column | Func | int) -> Func:
     """
     Returns a function that computes the bitwise OR operation between two values.

@@ -96,7 +94,7 @@ def bit_or(*args: Union[str, Column, Func, int]) -> Func:
     )


-def bit_xor(*args: Union[str, Column, Func, int]) -> Func:
+def bit_xor(*args: str | Column | Func | int) -> Func:
     """
     Returns a function that computes the bitwise XOR operation between two values.

@@ -141,7 +139,7 @@ def bit_xor(*args: Union[str, Column, Func, int]) -> Func:
     )


-def int_hash_64(col: Union[str, Column, Func, int]) -> Func:
+def int_hash_64(col: str | Column | Func | int) -> Func:
     """
     Returns a function that computes the 64-bit hash of an integer.

@@ -177,7 +175,7 @@ def int_hash_64(col: Union[str, Column, Func, int]) -> Func:
     )


-def bit_hamming_distance(*args: Union[str, Column, Func, int]) -> Func:
+def bit_hamming_distance(*args: str | Column | Func | int) -> Func:
     """
     Returns a function that computes the Hamming distance between two integers.
```
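The five signatures above change spelling only; behavior is identical. A usage sketch, assuming the current top-level API (`read_values`/`mutate`) and hypothetical example columns:

```python
# Sketch: column names, Func expressions, and int literals are all accepted.
import datachain as dc
from datachain.func.numeric import bit_and, bit_xor

chain = dc.read_values(a=[0b1010, 0b1100], b=[0b0110, 0b0101])
chain = chain.mutate(
    both=bit_and("a", "b"),      # str operands name columns
    flipped=bit_xor("a", 0xFF),  # int operands are literals
)
```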
datachain/func/string.py
CHANGED
```diff
@@ -1,4 +1,4 @@
-from typing import Optional, get_origin
+from typing import get_origin

 from sqlalchemy import literal

@@ -44,7 +44,7 @@ def length(col: ColT) -> Func:
     return Func("length", inner=string.length, cols=[col], result_type=int)


-def split(col: ColT, sep: str, limit: Optional[int] = None) -> Func:
+def split(col: ColT, sep: str, limit: int | None = None) -> Func:
     """
     Takes a column and split character and returns an array of the parts.
```
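`split` likewise only changes its `limit` annotation. A hedged usage sketch with hypothetical data:

```python
import datachain as dc
from datachain.func.string import split

chain = dc.read_values(path=["a/b/c", "x/y"])
# limit is now annotated int | None; when given, it bounds the number of parts.
chain = chain.mutate(parts=split("path", "/", limit=2))
```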
datachain/hash_utils.py
CHANGED
```diff
@@ -3,101 +3,74 @@ import inspect
 import json
 import textwrap
 from collections.abc import Sequence
-from typing import TypeVar, Union
-
-from sqlalchemy.sql.elements import (
-    BinaryExpression,
-    BindParameter,
-    ColumnElement,
-    Label,
-    Over,
-    UnaryExpression,
-)
-from sqlalchemy.sql.functions import Function
+from typing import TypeAlias, TypeVar
+
+from sqlalchemy.sql.elements import ClauseElement, ColumnElement

 T = TypeVar("T", bound=ColumnElement)
-ColumnLike = Union[str, T]
+ColumnLike: TypeAlias = str | T


+def _serialize_value(val):  # noqa: PLR0911
+    """Helper to serialize arbitrary values recursively."""
+    if val is None:
+        return None
+    if isinstance(val, (str, int, float, bool)):
+        return val
+    if isinstance(val, ClauseElement):
+        return serialize_column_element(val)
+    if isinstance(val, dict):
+        # Sort dict keys for deterministic serialization
+        return {k: _serialize_value(v) for k, v in sorted(val.items())}
+    if isinstance(val, (list, tuple)):
+        return [_serialize_value(v) for v in val]
+    if callable(val):
+        return val.__name__ if hasattr(val, "__name__") else str(val)
+    return str(val)
+
+
-def serialize_column_element(expr: Union[str, ColumnElement]) -> dict:
+def serialize_column_element(expr: str | ColumnElement) -> dict:
     """
     Recursively serialize a SQLAlchemy ColumnElement into a deterministic structure.
+    Uses SQLAlchemy's _traverse_internals to automatically handle all expression types.
     """
+    from sqlalchemy.sql.elements import BindParameter

-    # …
-    if isinstance(expr, BinaryExpression):
-        op = (
-            expr.operator.__name__
-            if hasattr(expr.operator, "__name__")
-            else str(expr.operator)
-        )
-        return {
-            "type": "binary",
-            "op": op,
-            "left": serialize_column_element(expr.left),
-            "right": serialize_column_element(expr.right),
-        }
-
-    # Unary operations: -col, NOT col, etc.
-    if isinstance(expr, UnaryExpression):
-        op = (
-            expr.operator.__name__
-            if expr.operator is not None and hasattr(expr.operator, "__name__")
-            else str(expr.operator)
-        )
-
-        return {
-            "type": "unary",
-            "op": op,
-            "element": serialize_column_element(expr.element),  # type: ignore[arg-type]
-        }
-
-    # Function calls: func.lower(col), func.count(col), etc.
-    if isinstance(expr, Function):
-        return {
-            "type": "function",
-            "name": expr.name,
-            "clauses": [serialize_column_element(c) for c in expr.clauses],
-        }
-
-    # Window functions: func.row_number().over(partition_by=..., order_by=...)
-    if isinstance(expr, Over):
-        return {
-            "type": "window",
-            "function": serialize_column_element(expr.element),
-            "partition_by": [
-                serialize_column_element(p) for p in getattr(expr, "partition_by", [])
-            ],
-            "order_by": [
-                serialize_column_element(o) for o in getattr(expr, "order_by", [])
-            ],
-        }
-
-    # Labeled expressions: col.label("alias")
-    if isinstance(expr, Label):
-        return {
-            "type": "label",
-            "name": expr.name,
-            "element": serialize_column_element(expr.element),
-        }
-
-    # Bound values (constants)
+    # Special case: BindParameter has non-deterministic 'key' attribute, only use value
     if isinstance(expr, BindParameter):
-        return {"type": "bind", "value": expr.value}
-
-    # …
-    if …
-
+        return {"type": "bind", "value": _serialize_value(expr.value)}
+
+    # Generic handling for all ClauseElement types using SQLAlchemy's internals
+    if isinstance(expr, ClauseElement):
+        # All standard SQLAlchemy types have _traverse_internals
+        if hasattr(expr, "_traverse_internals"):
+            result = {"type": expr.__class__.__name__}
+            for attr_name, _ in expr._traverse_internals:
+                # Skip 'table' attribute - table names can be auto-generated/random
+                # and are not semantically important for hashing
+                if attr_name == "table":
+                    continue
+                if hasattr(expr, attr_name):
+                    val = getattr(expr, attr_name)
+                    result[attr_name] = _serialize_value(val)
+            return result
+        # Rare case: custom user-defined ClauseElement without _traverse_internals
+        # We don't know its structure, so just stringify it
+        return {"type": expr.__class__.__name__, "repr": str(expr)}
+
+    # Absolute fallback: stringify completely unknown types
     return {"type": "other", "repr": str(expr)}


-def hash_column_elements(columns: Sequence[ColumnLike]) -> str:
+def hash_column_elements(columns: ColumnLike | Sequence[ColumnLike]) -> str:
     """
     Hash a list of ColumnElements deterministically, dialect agnostic.
     Only accepts ordered iterables (like list or tuple).
     """
+    # Handle case where a single ColumnElement is passed instead of a sequence
+    if isinstance(columns, (ColumnElement, str)):
+        columns = (columns,)
+
     serialized = [serialize_column_element(c) for c in columns]
     json_str = json.dumps(serialized, sort_keys=True)  # stable JSON
     return hashlib.sha256(json_str.encode("utf-8")).hexdigest()
```
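This is the one substantive rewrite in the release: instead of hand-written branches for BinaryExpression, UnaryExpression, Function, Over, and Label, serialization now walks SQLAlchemy's `_traverse_internals` metadata, so any standard ClauseElement serializes generically; BindParameter stays special-cased because its auto-generated key differs between otherwise identical expressions. A sketch of the resulting determinism:

```python
# Sketch: two structurally identical expressions hash the same even though
# SQLAlchemy assigns each literal a fresh, unique BindParameter key.
import sqlalchemy as sa
from datachain.hash_utils import hash_column_elements

expr1 = sa.column("age") > 21
expr2 = sa.column("age") > 21  # new bind key, same value

assert hash_column_elements([expr1]) == hash_column_elements([expr2])
# 0.35.0 also accepts a bare element instead of a sequence:
assert hash_column_elements(expr1) == hash_column_elements([expr1])
```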
datachain/job.py
CHANGED
```diff
@@ -2,7 +2,7 @@ import json
 import uuid
 from dataclasses import dataclass
 from datetime import datetime
-from typing import Any, Optional, TypeVar, Union
+from typing import Any, TypeVar

 J = TypeVar("J", bound="Job")

@@ -18,29 +18,29 @@ class Job:
     workers: int
     params: dict[str, str]
     metrics: dict[str, Any]
-    finished_at: Optional[datetime] = None
-    python_version: Optional[str] = None
+    finished_at: datetime | None = None
+    python_version: str | None = None
     error_message: str = ""
     error_stack: str = ""
-    parent_job_id: Optional[str] = None
+    parent_job_id: str | None = None

     @classmethod
     def parse(
         cls,
-        id: Union[str, uuid.UUID],
+        id: str | uuid.UUID,
         name: str,
         status: int,
         created_at: datetime,
-        finished_at: Optional[datetime],
+        finished_at: datetime | None,
         query: str,
         query_type: int,
         workers: int,
-        python_version: Optional[str],
+        python_version: str | None,
         error_message: str,
         error_stack: str,
         params: str,
         metrics: str,
-        parent_job_id: Optional[str],
+        parent_job_id: str | None,
     ) -> "Job":
         return cls(
             str(id),
```
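The dataclass fields above show why the release needs Python 3.10: PEP 604 annotations are evaluated when the class body runs, unlike the `Optional[...]` forms that 3.9 could import. A quick sketch:

```python
# Sketch: this class definition raises TypeError on Python 3.9
# ("unsupported operand type(s) for |") but is fine on 3.10+.
from dataclasses import dataclass
from datetime import datetime

@dataclass
class Example:
    finished_at: datetime | None = None

print(Example())  # Example(finished_at=None)
```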
datachain/lib/arrow.py
CHANGED
```diff
@@ -1,6 +1,6 @@
 from collections.abc import Sequence
 from itertools import islice
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any

 import pyarrow as pa
 import ujson as json
@@ -44,10 +44,10 @@ class ArrowGenerator(Generator):

     def __init__(
         self,
-        input_schema: Optional[pa.Schema] = None,
-        output_schema: Optional[type["BaseModel"]] = None,
+        input_schema: pa.Schema | None = None,
+        output_schema: type["BaseModel"] | None = None,
         source: bool = True,
-        nrows: Optional[int] = None,
+        nrows: int | None = None,
         **kwargs,
     ):
         """
@@ -112,7 +112,7 @@ class ArrowGenerator(Generator):
         record: dict[str, Any],
         file: File,
         index: int,
-        hf_schema: Optional[tuple["Features", dict[str, "DataType"]]],
+        hf_schema: tuple["Features", dict[str, "DataType"]] | None,
         use_datachain_schema: bool,
     ):
         if use_datachain_schema and self.output_schema:
@@ -141,7 +141,7 @@ class ArrowGenerator(Generator):
     def _process_non_datachain_record(
         self,
         record: dict[str, Any],
-        hf_schema: Optional[tuple["Features", dict[str, "DataType"]]],
+        hf_schema: tuple["Features", dict[str, "DataType"]] | None,
     ):
         vals = list(record.values())
         if not self.output_schema:
@@ -149,7 +149,9 @@ class ArrowGenerator(Generator):

         fields = self.output_schema.model_fields
         vals_dict = {}
-        for i, ((field, field_info), val) in enumerate(zip(fields.items(), vals)):
+        for i, ((field, field_info), val) in enumerate(
+            zip(fields.items(), vals, strict=False)
+        ):
             anno = field_info.annotation
             if hf_schema:
                 from datachain.lib.hf import convert_feature
@@ -180,7 +182,7 @@ def infer_schema(chain: "DataChain", **kwargs) -> pa.Schema:


 def schema_to_output(
-    schema: pa.Schema, col_names: Optional[Sequence[str]] = None
+    schema: pa.Schema, col_names: Sequence[str] | None = None
 ) -> tuple[dict[str, type], list[str]]:
     """
     Generate UDF output schema from pyarrow schema.
@@ -205,14 +207,15 @@ def schema_to_output(
     hf_schema = _get_hf_schema(schema)
     if hf_schema:
         return {
-            column: hf_type for hf_type, column in zip(hf_schema[1].values(), col_names)
+            column: hf_type
+            for hf_type, column in zip(hf_schema[1].values(), col_names, strict=False)
         }, list(normalized_col_dict.values())

     output = {}
-    for field, column in zip(schema, col_names):
+    for field, column in zip(schema, col_names, strict=False):
         dtype = arrow_type_mapper(field.type, column)
         if field.nullable and not ModelStore.is_pydantic(dtype):
-            dtype = Optional[dtype]  # type: ignore[assignment]
+            dtype = dtype | None  # type: ignore[assignment]
         output[column] = dtype

     return output, list(normalized_col_dict.values())
@@ -243,7 +246,7 @@ def arrow_type_mapper(col_type: pa.DataType, column: str = "") -> type:  # noqa:
         for field in col_type:
             dtype = arrow_type_mapper(field.type, field.name)
             if field.nullable and not ModelStore.is_pydantic(dtype):
-                dtype = Optional[dtype]  # type: ignore[assignment]
+                dtype = dtype | None  # type: ignore[assignment]
             type_dict[field.name] = dtype
         return dict_to_data_model(f"ArrowDataModel_{column}", type_dict)
     if pa.types.is_map(col_type):
@@ -257,7 +260,7 @@ def arrow_type_mapper(col_type: pa.DataType, column: str = "") -> type:  # noqa:

 def _get_hf_schema(
     schema: "pa.Schema",
-) -> Optional[tuple["Features", dict[str, "DataType"]]]:
+) -> tuple["Features", dict[str, "DataType"]] | None:
     if schema.metadata and b"huggingface" in schema.metadata:
         from datachain.lib.hf import get_output_schema, schema_from_arrow

@@ -266,7 +269,7 @@ def _get_hf_schema(
     return None


-def _get_datachain_schema(schema: "pa.Schema") -> Optional[SignalSchema]:
+def _get_datachain_schema(schema: "pa.Schema") -> SignalSchema | None:
     """Return a restored SignalSchema from parquet metadata, if any is found."""
     if schema.metadata and DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY in schema.metadata:
         serialized_signal_schema = json.loads(
```
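The `strict=False` arguments pin down the truncating behavior `zip()` always had, making the choice explicit now that `strict=True` (raise on length mismatch) exists; the same edit appears in values_to_tuples.py below, where `zip(*fr_map.values())` transposes column lists into row tuples. A sketch of both:

```python
# Sketch: strict=False keeps silent truncation; strict=True would raise
# ValueError when the iterables disagree in length (both need Python 3.10+).
fields = ["a", "b", "c"]
cols = ["x", "y"]
print(list(zip(fields, cols, strict=False)))  # [('a', 'x'), ('b', 'y')]

# The values_to_tuples.py usage: transpose column-wise lists into rows.
fr_map = {"x": [1, 2], "y": ["a", "b"]}
print(list(zip(*fr_map.values(), strict=False)))  # [(1, 'a'), (2, 'b')]
```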
datachain/lib/audio.py
CHANGED
```diff
@@ -1,5 +1,5 @@
 import posixpath
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING

 from datachain.lib.file import FileError

@@ -18,7 +18,7 @@ except ImportError as exc:
     ) from exc


-def audio_info(file: "Union[File, AudioFile]") -> "Audio":
+def audio_info(file: "File | AudioFile") -> "Audio":
     """Extract metadata like sample rate, channels, duration, and format."""
     from datachain.lib.file import Audio

@@ -99,7 +99,7 @@ def _encoding_to_format(encoding: str, file_ext: str) -> str:


 def audio_to_np(
-    audio: "AudioFile", start: float = 0, duration: Optional[float] = None
+    audio: "AudioFile", start: float = 0, duration: float | None = None
 ) -> "tuple[ndarray, int]":
     """Load audio fragment as numpy array.
     Multi-channel audio is transposed to (samples, channels)."""
@@ -146,7 +146,7 @@ def audio_to_bytes(
     audio: "AudioFile",
     format: str = "wav",
     start: float = 0,
-    duration: Optional[float] = None,
+    duration: float | None = None,
 ) -> bytes:
     """Convert audio to bytes using soundfile.

@@ -166,9 +166,9 @@ def audio_to_bytes(
 def save_audio(
     audio: "AudioFile",
     output: str,
-    format: Optional[str] = None,
+    format: str | None = None,
     start: float = 0,
-    end: Optional[float] = None,
+    end: float | None = None,
 ) -> "AudioFile":
     """Save audio file or extract fragment to specified format.
```
datachain/lib/clip.py
CHANGED
```diff
@@ -1,5 +1,6 @@
 import inspect
-from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union
+from collections.abc import Callable
+from typing import TYPE_CHECKING, Any, Literal, Union

 import torch
 from transformers.modeling_utils import PreTrainedModel
@@ -32,14 +33,14 @@ def _get_encoder(model: Any, type: Literal["image", "text"]) -> Callable:


 def clip_similarity_scores(
-    images: Union[None, "Image.Image", list["Image.Image"]],
-    text: Optional[Union[str, list[str]]],
+    images: Union["Image.Image", list["Image.Image"]] | None,
+    text: str | list[str] | None,
     model: Any,
     preprocess: Callable,
     tokenizer: Callable,
     prob: bool = False,
     image_to_text: bool = True,
-    device: Optional[Union[str, torch.device]] = None,
+    device: str | torch.device | None = None,
 ) -> list[list[float]]:
     """
     Calculate CLIP similarity scores between one or more images and/or text.
```
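Moving `Callable` from `typing` to `collections.abc` matches its deprecation in `typing` since Python 3.9; the abc version works the same in annotations:

```python
# Sketch: collections.abc.Callable is subscriptable and usable in
# annotations exactly like the deprecated typing.Callable.
from collections.abc import Callable

def apply(fn: Callable[[int], int], x: int) -> int:
    return fn(x)

print(apply(lambda v: v * 2, 21))  # 42
```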
datachain/lib/convert/python_to_sql.py
CHANGED
```diff
@@ -1,14 +1,9 @@
 import inspect
-import sys
 from datetime import datetime
 from enum import Enum
+from types import UnionType
 from typing import Annotated, Literal, Union, get_args, get_origin

-if sys.version_info >= (3, 10):
-    from types import UnionType
-else:
-    UnionType = None
-
 from pydantic import BaseModel
 from typing_extensions import Literal as LiteralEx

@@ -40,13 +35,6 @@ PYTHON_TO_SQL = {
 }


-def _is_union(orig) -> bool:
-    if orig == Union:
-        return True
-    # some code is unreachab in python<3.10
-    return UnionType is not None and orig is UnionType  # type: ignore[unreachable]
-
-
 def python_to_sql(typ):  # noqa: PLR0911
     if inspect.isclass(typ):
         if issubclass(typ, SQLType):
@@ -82,12 +70,12 @@ def python_to_sql(typ):  # noqa: PLR0911
     if inspect.isclass(orig) and issubclass(dict, orig):
         return JSON

-    if _is_union(orig):
+    if orig in (Union, UnionType):
         if len(args) == 2 and (type(None) in args):
             non_none_arg = args[0] if args[0] is not type(None) else args[1]
             return python_to_sql(non_none_arg)

-    if _is_union_str_literal(orig, args):
+    if all(arg is str or get_origin(arg) in (Literal, LiteralEx) for arg in args):
         return String

     if _is_json_inside_union(orig, args):
@@ -109,7 +97,7 @@ def list_of_args_to_type(args) -> SQLType:


 def _is_json_inside_union(orig, args) -> bool:
-    if _is_union(orig) and len(args) >= 2:
+    if orig in (Union, UnionType) and len(args) >= 2:
         # List in JSON: Union[dict, list[dict]]
         args_no_nones = [arg for arg in args if arg != type(None)]  # noqa: E721
         if len(args_no_nones) == 2:
@@ -123,9 +111,3 @@ def _is_json_inside_union(orig, args) -> bool:
     if any(inspect.isclass(arg) and issubclass(arg, BaseModel) for arg in args):
         return True
     return False
-
-
-def _is_union_str_literal(orig, args) -> bool:
-    if not _is_union(orig):
-        return False
-    return all(arg is str or get_origin(arg) in (Literal, LiteralEx) for arg in args)
```
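The inlined `orig in (Union, UnionType)` check is needed because the two union spellings report different origins; the deleted `sys.version_info` shim existed only to keep `UnionType` importable on 3.9. A sketch of what `get_origin` returns on Python 3.10-3.13:

```python
# Sketch: Optional/Union annotations resolve to typing.Union, while the
# PEP 604 form resolves to types.UnionType - the check must cover both.
import types
from typing import Optional, Union, get_origin

print(get_origin(Optional[int]) is Union)        # True
print(get_origin(Union[int, str]) is Union)      # True
print(get_origin(int | str) is types.UnionType)  # True
```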
datachain/lib/convert/values_to_tuples.py
CHANGED
```diff
@@ -1,13 +1,8 @@
 import itertools
 from collections.abc import Sequence
-from typing import Any, Optional, Union
+from typing import Any

-from datachain.lib.data_model import (
-    DataType,
-    DataTypeNames,
-    DataValue,
-    is_chain_type,
-)
+from datachain.lib.data_model import DataType, DataTypeNames, DataValue, is_chain_type
 from datachain.lib.utils import DataChainParamsError


@@ -20,7 +15,7 @@ class ValuesToTupleError(DataChainParamsError):

 def values_to_tuples(  # noqa: C901, PLR0912
     ds_name: str = "",
-    output: Optional[Union[DataType, Sequence[str], dict[str, DataType]]] = None,
+    output: DataType | Sequence[str] | dict[str, DataType] | None = None,
     **fr_map: Sequence[DataValue],
 ) -> tuple[Any, Any, Any]:
     if output:
@@ -111,7 +106,7 @@ def values_to_tuples(  # noqa: C901, PLR0912
     if len(output) > 1:  # type: ignore[arg-type]
         tuple_type = tuple(output_types)
         res_type = tuple[tuple_type]  # type: ignore[valid-type]
-        res_values: Sequence[Any] = list(zip(*fr_map.values()))
+        res_values: Sequence[Any] = list(zip(*fr_map.values(), strict=False))
     else:
         res_type = output_types[0]  # type: ignore[misc]
         res_values = next(iter(fr_map.values()))
```
datachain/lib/data_model.py
CHANGED
```diff
@@ -1,8 +1,9 @@
 import inspect
+import types
 import uuid
 from collections.abc import Sequence
 from datetime import datetime
-from typing import ClassVar, Optional, Union, get_args, get_origin
+from typing import ClassVar, Union, get_args, get_origin

 from pydantic import AliasChoices, BaseModel, Field, create_model
 from pydantic.fields import FieldInfo
@@ -10,19 +11,19 @@ from pydantic.fields import FieldInfo
 from datachain.lib.model_store import ModelStore
 from datachain.lib.utils import normalize_col_names

-StandardType = Union[
-    type[int],
-    type[str],
-    type[float],
-    type[bool],
-    type[list],
-    type[dict],
-    type[bytes],
-    type[datetime],
-]
-DataType = Union[type[BaseModel], StandardType]
+StandardType = (
+    type[int]
+    | type[str]
+    | type[float]
+    | type[bool]
+    | type[list]
+    | type[dict]
+    | type[bytes]
+    | type[datetime]
+)
+DataType = type[BaseModel] | StandardType
 DataTypeNames = "BaseModel, int, str, float, bool, list, dict, bytes, datetime"
-DataValue = Union[BaseModel, int, str, float, bool, list, dict, bytes, datetime]
+DataValue = BaseModel | int | str | float | bool | list | dict | bytes | datetime


 class DataModel(BaseModel):
@@ -37,7 +38,7 @@ class DataModel(BaseModel):
         ModelStore.register(cls)

     @staticmethod
-    def register(models: Union[DataType, Sequence[DataType]]):
+    def register(models: DataType | Sequence[DataType]):
         """For registering classes manually. It accepts a single class or a sequence of
         classes."""
         if not isinstance(models, Sequence):
@@ -63,8 +64,8 @@ def is_chain_type(t: type) -> bool:
     if orig is list and len(args) == 1:
         return is_chain_type(get_args(t)[0])

-    if orig == Union and len(args) == 2 and (type(None) in args):
-        return is_chain_type(args[0])
+    if orig in (Union, types.UnionType) and len(args) == 2 and (type(None) in args):
+        return is_chain_type(args[0] if args[1] is type(None) else args[1])

     return False

@@ -72,19 +73,19 @@ def is_chain_type(t: type) -> bool:
 def dict_to_data_model(
     name: str,
     data_dict: dict[str, DataType],
-    original_names: Optional[list[str]] = None,
+    original_names: list[str] | None = None,
 ) -> type[BaseModel]:
     if not original_names:
         # Gets a map of a normalized_name -> original_name
         columns = normalize_col_names(list(data_dict))
-        data_dict = dict(zip(columns.keys(), data_dict.values()))
+        data_dict = dict(zip(columns.keys(), data_dict.values(), strict=False))
         original_names = list(columns.values())

     fields = {
         name: (
             anno
             if inspect.isclass(anno) and issubclass(anno, BaseModel)
-            else Optional[anno],
+            else anno | None,
             Field(
                 validation_alias=AliasChoices(name, original_names[idx] or name),
                 default=None,
```
datachain/lib/dataset_info.py
CHANGED
```diff
@@ -1,6 +1,6 @@
 import json
 from datetime import datetime
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any
 from uuid import uuid4

 from pydantic import Field, field_validator
@@ -28,9 +28,9 @@ class DatasetInfo(DataModel):
     version: str = Field(default=DEFAULT_DATASET_VERSION)
     status: int = Field(default=DatasetStatus.CREATED)
     created_at: datetime = Field(default=TIME_ZERO)
-    finished_at: Optional[datetime] = Field(default=None)
-    num_objects: Optional[int] = Field(default=None)
-    size: Optional[int] = Field(default=None)
+    finished_at: datetime | None = Field(default=None)
+    num_objects: int | None = Field(default=None)
+    size: int | None = Field(default=None)
     params: dict[str, str] = Field(default={})
     metrics: dict[str, Any] = Field(default={})
     error_message: str = Field(default="")
@@ -59,7 +59,7 @@ class DatasetInfo(DataModel):

     @staticmethod
     def _validate_dict(
-        v: Optional[Union[str, dict]],
+        v: str | dict | None,
     ) -> dict:
         if v is None or v == "":
             return {}
@@ -88,7 +88,7 @@ class DatasetInfo(DataModel):
         cls,
         dataset: DatasetListRecord,
         version: DatasetListVersion,
-        job: Optional[Job],
+        job: Job | None,
     ) -> "Self":
         return cls(
             uuid=version.uuid,
```
|