datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datachain/__init__.py +20 -0
- datachain/asyn.py +11 -12
- datachain/cache.py +7 -7
- datachain/catalog/__init__.py +2 -2
- datachain/catalog/catalog.py +621 -507
- datachain/catalog/dependency.py +164 -0
- datachain/catalog/loader.py +28 -18
- datachain/checkpoint.py +43 -0
- datachain/cli/__init__.py +24 -33
- datachain/cli/commands/__init__.py +1 -8
- datachain/cli/commands/datasets.py +83 -52
- datachain/cli/commands/ls.py +17 -17
- datachain/cli/commands/show.py +4 -4
- datachain/cli/parser/__init__.py +8 -74
- datachain/cli/parser/job.py +95 -3
- datachain/cli/parser/studio.py +11 -4
- datachain/cli/parser/utils.py +1 -2
- datachain/cli/utils.py +2 -15
- datachain/client/azure.py +4 -4
- datachain/client/fsspec.py +45 -28
- datachain/client/gcs.py +6 -6
- datachain/client/hf.py +29 -2
- datachain/client/http.py +157 -0
- datachain/client/local.py +15 -11
- datachain/client/s3.py +17 -9
- datachain/config.py +4 -8
- datachain/data_storage/db_engine.py +12 -6
- datachain/data_storage/job.py +5 -1
- datachain/data_storage/metastore.py +1252 -186
- datachain/data_storage/schema.py +58 -45
- datachain/data_storage/serializer.py +105 -15
- datachain/data_storage/sqlite.py +286 -127
- datachain/data_storage/warehouse.py +250 -113
- datachain/dataset.py +353 -148
- datachain/delta.py +391 -0
- datachain/diff/__init__.py +27 -29
- datachain/error.py +60 -0
- datachain/func/__init__.py +2 -1
- datachain/func/aggregate.py +66 -42
- datachain/func/array.py +242 -38
- datachain/func/base.py +7 -4
- datachain/func/conditional.py +110 -60
- datachain/func/func.py +96 -45
- datachain/func/numeric.py +55 -38
- datachain/func/path.py +32 -20
- datachain/func/random.py +2 -2
- datachain/func/string.py +67 -37
- datachain/func/window.py +7 -8
- datachain/hash_utils.py +123 -0
- datachain/job.py +11 -7
- datachain/json.py +138 -0
- datachain/lib/arrow.py +58 -22
- datachain/lib/audio.py +245 -0
- datachain/lib/clip.py +14 -13
- datachain/lib/convert/flatten.py +5 -3
- datachain/lib/convert/python_to_sql.py +6 -10
- datachain/lib/convert/sql_to_python.py +8 -0
- datachain/lib/convert/values_to_tuples.py +156 -51
- datachain/lib/data_model.py +42 -20
- datachain/lib/dataset_info.py +36 -8
- datachain/lib/dc/__init__.py +8 -2
- datachain/lib/dc/csv.py +25 -28
- datachain/lib/dc/database.py +398 -0
- datachain/lib/dc/datachain.py +1289 -425
- datachain/lib/dc/datasets.py +320 -38
- datachain/lib/dc/hf.py +38 -24
- datachain/lib/dc/json.py +29 -32
- datachain/lib/dc/listings.py +112 -8
- datachain/lib/dc/pandas.py +16 -12
- datachain/lib/dc/parquet.py +35 -23
- datachain/lib/dc/records.py +31 -23
- datachain/lib/dc/storage.py +154 -64
- datachain/lib/dc/storage_pattern.py +251 -0
- datachain/lib/dc/utils.py +24 -16
- datachain/lib/dc/values.py +8 -9
- datachain/lib/file.py +622 -89
- datachain/lib/hf.py +69 -39
- datachain/lib/image.py +14 -14
- datachain/lib/listing.py +14 -11
- datachain/lib/listing_info.py +1 -2
- datachain/lib/meta_formats.py +3 -4
- datachain/lib/model_store.py +39 -7
- datachain/lib/namespaces.py +125 -0
- datachain/lib/projects.py +130 -0
- datachain/lib/pytorch.py +32 -21
- datachain/lib/settings.py +192 -56
- datachain/lib/signal_schema.py +427 -104
- datachain/lib/tar.py +1 -2
- datachain/lib/text.py +8 -7
- datachain/lib/udf.py +164 -76
- datachain/lib/udf_signature.py +60 -35
- datachain/lib/utils.py +118 -4
- datachain/lib/video.py +17 -9
- datachain/lib/webdataset.py +61 -56
- datachain/lib/webdataset_laion.py +15 -16
- datachain/listing.py +22 -10
- datachain/model/bbox.py +3 -1
- datachain/model/ultralytics/bbox.py +16 -12
- datachain/model/ultralytics/pose.py +16 -12
- datachain/model/ultralytics/segment.py +16 -12
- datachain/namespace.py +84 -0
- datachain/node.py +6 -6
- datachain/nodes_thread_pool.py +0 -1
- datachain/plugins.py +24 -0
- datachain/project.py +78 -0
- datachain/query/batch.py +40 -41
- datachain/query/dataset.py +604 -322
- datachain/query/dispatch.py +261 -154
- datachain/query/metrics.py +4 -6
- datachain/query/params.py +2 -3
- datachain/query/queue.py +3 -12
- datachain/query/schema.py +11 -6
- datachain/query/session.py +200 -33
- datachain/query/udf.py +34 -2
- datachain/remote/studio.py +171 -69
- datachain/script_meta.py +12 -12
- datachain/semver.py +68 -0
- datachain/sql/__init__.py +2 -0
- datachain/sql/functions/array.py +33 -1
- datachain/sql/postgresql_dialect.py +9 -0
- datachain/sql/postgresql_types.py +21 -0
- datachain/sql/sqlite/__init__.py +5 -1
- datachain/sql/sqlite/base.py +102 -29
- datachain/sql/sqlite/types.py +8 -13
- datachain/sql/types.py +70 -15
- datachain/studio.py +223 -46
- datachain/toolkit/split.py +31 -10
- datachain/utils.py +101 -59
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
- datachain-0.39.0.dist-info/RECORD +173 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
- datachain/cli/commands/query.py +0 -53
- datachain/query/utils.py +0 -42
- datachain-0.14.2.dist-info/RECORD +0 -158
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/json.py
ADDED
@@ -0,0 +1,138 @@
+"""DataChain JSON utilities.
+
+This module wraps :mod:`ujson` so we can guarantee consistent handling
+of values that the encoder does not support out of the box (for example
+``datetime`` objects or ``bytes``).
+All code inside DataChain should import this module instead of using
+:mod:`ujson` directly.
+"""
+
+import datetime as _dt
+import json as _json
+import uuid as _uuid
+from collections.abc import Callable
+from typing import Any
+
+import ujson as _ujson
+
+__all__ = [
+    "JSONDecodeError",
+    "dump",
+    "dumps",
+    "load",
+    "loads",
+]
+
+JSONDecodeError = (_ujson.JSONDecodeError, _json.JSONDecodeError)
+
+_SENTINEL = object()
+_Default = Callable[[Any], Any]
+DEFAULT_PREVIEW_BYTES = 1024
+
+
+# To make it looks like Pydantic's ISO format with 'Z' for UTC
+# It is minor but nice to have consistency
+def _format_datetime(value: _dt.datetime) -> str:
+    iso = value.isoformat()
+
+    offset = value.utcoffset()
+    if value.tzinfo is None or offset is None:
+        return iso
+
+    if offset == _dt.timedelta(0) and iso.endswith(("+00:00", "-00:00")):
+        return iso[:-6] + "Z"
+
+    return iso
+
+
+def _format_time(value: _dt.time) -> str:
+    iso = value.isoformat()
+
+    offset = value.utcoffset()
+    if value.tzinfo is None or offset is None:
+        return iso
+
+    if offset == _dt.timedelta(0) and iso.endswith(("+00:00", "-00:00")):
+        return iso[:-6] + "Z"
+
+    return iso
+
+
+def _coerce(value: Any, serialize_bytes: bool) -> Any:
+    """Return a JSON-serializable representation for supported extra types."""
+
+    if isinstance(value, _dt.datetime):
+        return _format_datetime(value)
+    if isinstance(value, _dt.date):
+        return value.isoformat()
+    if isinstance(value, _dt.time):
+        return _format_time(value)
+    if isinstance(value, _uuid.UUID):
+        return str(value)
+    if serialize_bytes and isinstance(value, (bytes, bytearray)):
+        return list(bytes(value)[:DEFAULT_PREVIEW_BYTES])
+    return _SENTINEL
+
+
+def _base_default(value: Any, serialize_bytes: bool) -> Any:
+    converted = _coerce(value, serialize_bytes)
+    if converted is not _SENTINEL:
+        return converted
+    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")
+
+
+def _build_default(user_default: _Default | None, serialize_bytes: bool) -> _Default:
+    if user_default is None:
+        return lambda value: _base_default(value, serialize_bytes)
+
+    def combined(value: Any) -> Any:
+        converted = _coerce(value, serialize_bytes)
+        if converted is not _SENTINEL:
+            return converted
+        return user_default(value)
+
+    return combined
+
+
+def dumps(
+    obj: Any,
+    *,
+    default: _Default | None = None,
+    serialize_bytes: bool = False,
+    **kwargs: Any,
+) -> str:
+    """Serialize *obj* to a JSON-formatted ``str``."""
+
+    if serialize_bytes:
+        return _json.dumps(obj, default=_build_default(default, True), **kwargs)
+
+    return _ujson.dumps(obj, default=_build_default(default, False), **kwargs)
+
+
+def dump(
+    obj: Any,
+    fp,
+    *,
+    default: _Default | None = None,
+    serialize_bytes: bool = False,
+    **kwargs: Any,
+) -> None:
+    """Serialize *obj* as a JSON formatted stream to *fp*."""
+
+    if serialize_bytes:
+        _json.dump(obj, fp, default=_build_default(default, True), **kwargs)
+        return
+
+    _ujson.dump(obj, fp, default=_build_default(default, False), **kwargs)
+
+
+def loads(s: str | bytes | bytearray, **kwargs: Any) -> Any:
+    """Deserialize *s* to a Python object."""
+
+    return _ujson.loads(s, **kwargs)
+
+
+def load(fp, **kwargs: Any) -> Any:
+    """Deserialize JSON content from *fp* to a Python object."""
+
+    return loads(fp.read(), **kwargs)
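A minimal usage sketch of the new wrapper (the record below is made up for illustration; it is not part of the package):

import datetime
import uuid

from datachain import json

record = {
    "id": uuid.uuid4(),
    "created": datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc),
    "payload": b"\x00\x01\x02",
}

# UUIDs become strings and UTC datetimes get a trailing "Z" (matching Pydantic);
# bytes are only encoded (as a list of ints, capped at DEFAULT_PREVIEW_BYTES)
# when serialize_bytes=True, which routes through the stdlib json encoder.
print(json.dumps(record, serialize_bytes=True))

# loads/load delegate to ujson.
print(json.loads('{"a": 1}'))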
datachain/lib/arrow.py
CHANGED
@@ -1,12 +1,13 @@
 from collections.abc import Sequence
 from itertools import islice
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any
 
-import orjson
 import pyarrow as pa
+from pyarrow._csv import ParseOptions
 from pyarrow.dataset import CsvFileFormat, dataset
 from tqdm.auto import tqdm
 
+from datachain import json
 from datachain.fs.reference import ReferenceFileSystem
 from datachain.lib.data_model import dict_to_data_model
 from datachain.lib.file import ArrowRow, File
@@ -26,15 +27,27 @@ if TYPE_CHECKING:
 DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY = b"DataChain SignalSchema"
 
 
+def fix_pyarrow_format(format, parse_options=None):
+    # Re-init invalid row handler: https://issues.apache.org/jira/browse/ARROW-17641
+    if (
+        format
+        and isinstance(format, CsvFileFormat)
+        and parse_options
+        and isinstance(parse_options, ParseOptions)
+    ):
+        format.parse_options = parse_options
+    return format
+
+
 class ArrowGenerator(Generator):
     DEFAULT_BATCH_SIZE = 2**17  # same as `pyarrow._dataset._DEFAULT_BATCH_SIZE`
 
     def __init__(
         self,
-        input_schema:
-        output_schema:
+        input_schema: pa.Schema | None = None,
+        output_schema: type["BaseModel"] | None = None,
         source: bool = True,
-        nrows:
+        nrows: int | None = None,
         **kwargs,
     ):
         """
@@ -53,6 +66,7 @@ class ArrowGenerator(Generator):
         self.output_schema = output_schema
         self.source = source
         self.nrows = nrows
+        self.parse_options = kwargs.pop("parse_options", None)
         self.kwargs = kwargs
 
     def process(self, file: File):
@@ -62,9 +76,13 @@ class ArrowGenerator(Generator):
             fs_path = file.path
             fs = ReferenceFileSystem({fs_path: [cache_path]})
         else:
-            fs, fs_path = file.get_fs(), file.
+            fs, fs_path = file.get_fs(), file.get_fs_path()
+
+        kwargs = self.kwargs
+        if format := kwargs.get("format"):
+            kwargs["format"] = fix_pyarrow_format(format, self.parse_options)
 
-        ds = dataset(fs_path, schema=self.input_schema, filesystem=fs, **
+        ds = dataset(fs_path, schema=self.input_schema, filesystem=fs, **kwargs)
 
         hf_schema = _get_hf_schema(ds.schema)
         use_datachain_schema = (
@@ -94,7 +112,7 @@ class ArrowGenerator(Generator):
         record: dict[str, Any],
         file: File,
         index: int,
-        hf_schema:
+        hf_schema: tuple["Features", dict[str, "DataType"]] | None,
         use_datachain_schema: bool,
     ):
         if use_datachain_schema and self.output_schema:
@@ -108,13 +126,22 @@ class ArrowGenerator(Generator):
             if isinstance(kwargs.get("format"), CsvFileFormat):
                 kwargs["format"] = "csv"
             arrow_file = ArrowRow(file=file, index=index, kwargs=kwargs)
+
+            if self.output_schema and hasattr(vals[0], "source"):
+                # if we are reading parquet file written by datachain it might have
+                # source inside of it already, so we should not duplicate it, instead
+                # we are re-creating it of the self.source flag
+                vals[0].source = arrow_file  # type: ignore[attr-defined]
+
+                return vals
             return [arrow_file, *vals]
+
         return vals
 
     def _process_non_datachain_record(
         self,
         record: dict[str, Any],
-        hf_schema:
+        hf_schema: tuple["Features", dict[str, "DataType"]] | None,
     ):
         vals = list(record.values())
         if not self.output_schema:
@@ -122,7 +149,9 @@ class ArrowGenerator(Generator):
 
         fields = self.output_schema.model_fields
         vals_dict = {}
-        for i, ((field, field_info), val) in enumerate(
+        for i, ((field, field_info), val) in enumerate(
+            zip(fields.items(), vals, strict=False)
+        ):
             anno = field_info.annotation
             if hf_schema:
                 from datachain.lib.hf import convert_feature
@@ -137,9 +166,13 @@ class ArrowGenerator(Generator):
 
 
 def infer_schema(chain: "DataChain", **kwargs) -> pa.Schema:
+    parse_options = kwargs.pop("parse_options", None)
+    if format := kwargs.get("format"):
+        kwargs["format"] = fix_pyarrow_format(format, parse_options)
+
     schemas = []
-    for file in chain.
-    ds = dataset(file.
+    for (file,) in chain.to_iter("file"):
+        ds = dataset(file.get_fs_path(), filesystem=file.get_fs(), **kwargs)  # type: ignore[union-attr]
         schemas.append(ds.schema)
     if not schemas:
         raise ValueError(
@@ -149,7 +182,7 @@ def infer_schema(chain: "DataChain", **kwargs) -> pa.Schema:
 
 
 def schema_to_output(
-    schema: pa.Schema, col_names:
+    schema: pa.Schema, col_names: Sequence[str] | None = None
 ) -> tuple[dict[str, type], list[str]]:
     """
     Generate UDF output schema from pyarrow schema.
@@ -174,14 +207,15 @@ def schema_to_output(
     hf_schema = _get_hf_schema(schema)
     if hf_schema:
         return {
-            column: hf_type
+            column: hf_type
+            for hf_type, column in zip(hf_schema[1].values(), col_names, strict=False)
         }, list(normalized_col_dict.values())
 
     output = {}
-    for field, column in zip(schema, col_names):
+    for field, column in zip(schema, col_names, strict=False):
         dtype = arrow_type_mapper(field.type, column)
         if field.nullable and not ModelStore.is_pydantic(dtype):
-            dtype =
+            dtype = dtype | None  # type: ignore[assignment]
         output[column] = dtype
 
     return output, list(normalized_col_dict.values())
@@ -212,31 +246,33 @@ def arrow_type_mapper(col_type: pa.DataType, column: str = "") -> type:  # noqa:
         for field in col_type:
             dtype = arrow_type_mapper(field.type, field.name)
             if field.nullable and not ModelStore.is_pydantic(dtype):
-                dtype =
+                dtype = dtype | None  # type: ignore[assignment]
             type_dict[field.name] = dtype
-        return dict_to_data_model(column, type_dict)
+        return dict_to_data_model(f"ArrowDataModel_{column}", type_dict)
     if pa.types.is_map(col_type):
         return dict
    if isinstance(col_type, pa.lib.DictionaryType):
         return arrow_type_mapper(col_type.value_type)  # type: ignore[return-value]
+    if pa.types.is_null(col_type):
+        return str  # use strings for null columns
     raise TypeError(f"{col_type!r} datatypes not supported, column: {column}")
 
 
 def _get_hf_schema(
     schema: "pa.Schema",
-) ->
+) -> tuple["Features", dict[str, "DataType"]] | None:
     if schema.metadata and b"huggingface" in schema.metadata:
         from datachain.lib.hf import get_output_schema, schema_from_arrow
 
         features = schema_from_arrow(schema)
-        return features, get_output_schema(features)
+        return features, get_output_schema(features)[0]
     return None
 
 
-def _get_datachain_schema(schema: "pa.Schema") ->
+def _get_datachain_schema(schema: "pa.Schema") -> SignalSchema | None:
     """Return a restored SignalSchema from parquet metadata, if any is found."""
     if schema.metadata and DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY in schema.metadata:
-        serialized_signal_schema =
+        serialized_signal_schema = json.loads(
             schema.metadata[DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY]
         )
         return SignalSchema.deserialize(serialized_signal_schema)
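For context, a small sketch of the CSV case that fix_pyarrow_format() handles: per the ARROW-17641 comment above, callable parse options (such as an invalid row handler) can be lost when the format object is serialized, so DataChain threads parse_options separately and re-assigns them before opening the dataset. The file name and delimiter below are placeholders:

from pyarrow.csv import ParseOptions
from pyarrow.dataset import CsvFileFormat, dataset

parse_options = ParseOptions(delimiter=";")
fmt = CsvFileFormat(parse_options=parse_options)

# Re-attach the options in case the format object crossed a process boundary
# (e.g. was sent to a UDF worker) and dropped its callable settings.
fmt.parse_options = parse_options

ds = dataset("table.csv", format=fmt)
print(ds.schema)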
datachain/lib/audio.py
ADDED
@@ -0,0 +1,245 @@
+import posixpath
+import re
+from typing import TYPE_CHECKING
+
+from datachain.lib.file import FileError
+
+if TYPE_CHECKING:
+    from numpy import ndarray
+
+    from datachain.lib.file import Audio, AudioFile, File
+
+try:
+    import soundfile as sf
+except ImportError as exc:
+    raise ImportError(
+        "Missing dependencies for processing audio.\n"
+        "To install run:\n\n"
+        " pip install 'datachain[audio]'\n"
+    ) from exc
+
+
+def audio_info(file: "File | AudioFile") -> "Audio":
+    """Extract metadata like sample rate, channels, duration, and format."""
+    from datachain.lib.file import Audio
+
+    file = file.as_audio_file()
+
+    try:
+        with file.open() as f:
+            info = sf.info(f)
+
+            sample_rate = int(info.samplerate)
+            channels = int(info.channels)
+            frames = int(info.frames)
+            duration = float(info.duration)
+
+            # soundfile provides format and subtype
+            if info.format:
+                format_name = info.format.lower()
+            else:
+                format_name = file.get_file_ext().lower()
+
+            if not format_name:
+                format_name = "unknown"
+            codec_name = info.subtype if info.subtype else ""
+
+            # Calculate bit rate from subtype
+            bits_per_sample = _get_bits_per_sample(info.subtype)
+            bit_rate = (
+                bits_per_sample * sample_rate * channels if bits_per_sample > 0 else -1
+            )
+
+    except Exception as exc:
+        raise FileError(
+            "unable to extract metadata from audio file", file.source, file.path
+        ) from exc
+
+    return Audio(
+        sample_rate=sample_rate,
+        channels=channels,
+        duration=duration,
+        samples=frames,
+        format=format_name,
+        codec=codec_name,
+        bit_rate=bit_rate,
+    )
+
+
+def _get_bits_per_sample(subtype: str) -> int:
+    """
+    Map soundfile subtype to bits per sample.
+
+    Args:
+        subtype: The subtype string from soundfile
+
+    Returns:
+        Bits per sample, or 0 if unknown
+    """
+    if not subtype:
+        return 0
+
+    # Common PCM and floating-point subtypes
+    pcm_bits = {
+        "PCM_16": 16,
+        "PCM_24": 24,
+        "PCM_32": 32,
+        "PCM_S8": 8,
+        "PCM_U8": 8,
+        "FLOAT": 32,
+        "DOUBLE": 64,
+    }
+
+    if subtype in pcm_bits:
+        return pcm_bits[subtype]
+
+    # Handle variants such as PCM_S16LE, PCM_F32LE, etc.
+    match = re.search(r"PCM_(?:[A-Z]*?)(\d+)", subtype)
+    if match:
+        return int(match.group(1))
+
+    return 0
+
+
+def audio_to_np(
+    audio: "AudioFile", start: float = 0, duration: float | None = None
+) -> "tuple[ndarray, int]":
+    """Load audio fragment as numpy array.
+    Multi-channel audio is transposed to (samples, channels)."""
+    if start < 0:
+        raise ValueError("start must be a non-negative float")
+
+    if duration is not None and duration <= 0:
+        raise ValueError("duration must be a positive float")
+
+    if hasattr(audio, "as_audio_file"):
+        audio = audio.as_audio_file()
+
+    try:
+        with audio.open() as f:
+            info = sf.info(f)
+            sample_rate = info.samplerate
+
+            frame_offset = int(start * sample_rate)
+            num_frames = int(duration * sample_rate) if duration is not None else -1
+
+            # Reset file pointer to the beginning
+            f.seek(0)
+
+            # Read audio data with offset and frame count
+            audio_np, sr = sf.read(
+                f,
+                start=frame_offset,
+                frames=num_frames,
+                always_2d=False,
+                dtype="float32",
+            )
+
+            # soundfile returns shape (frames,) for mono or
+            # (frames, channels) for multi-channel
+            # We keep this format as it matches expected output
+            return audio_np, int(sr)
+    except Exception as exc:
+        raise FileError(
+            "unable to read audio fragment", audio.source, audio.path
+        ) from exc
+
+
+def audio_to_bytes(
+    audio: "AudioFile",
+    format: str = "wav",
+    start: float = 0,
+    duration: float | None = None,
+) -> bytes:
+    """Convert audio to bytes using soundfile.
+
+    If duration is None, converts from start to end of file.
+    If start is 0 and duration is None, converts entire file."""
+    import io
+
+    y, sr = audio_to_np(audio, start, duration)
+
+    buffer = io.BytesIO()
+    sf.write(buffer, y, sr, format=format)
+    return buffer.getvalue()
+
+
+def save_audio(
+    audio: "AudioFile",
+    output: str,
+    format: str | None = None,
+    start: float = 0,
+    end: float | None = None,
+) -> "AudioFile":
+    """Save audio file or extract fragment to specified format.
+
+    Args:
+        audio: Source AudioFile object
+        output: Output directory path
+        format: Output format ('wav', 'mp3', etc). Defaults to source format
+        start: Start time in seconds (>= 0). Defaults to 0
+        end: End time in seconds. If None, extracts to end of file
+
+    Returns:
+        AudioFile: New audio file with format conversion/extraction applied
+
+    Examples:
+        save_audio(audio, "/path", "mp3")  # Entire file to MP3
+        save_audio(audio, "s3://bucket/path", "wav", start=2.5)  # From 2.5s to end
+        save_audio(audio, "/path", "flac", start=1, end=3)  # Extract 1-3s fragment
+    """
+    if format is None:
+        format = audio.get_file_ext()
+
+    # Validate start time
+    if start < 0:
+        raise ValueError(
+            f"Can't save audio for '{audio.path}', "
+            f"start time must be non-negative: {start:.3f}"
+        )
+
+    # Handle full file conversion when end is None and start is 0
+    if end is None and start == 0:
+        output_file = posixpath.join(output, f"{audio.get_file_stem()}.{format}")
+        try:
+            audio_bytes = audio_to_bytes(audio, format, start=0, duration=None)
+        except Exception as exc:
+            raise FileError(
+                "unable to convert audio file", audio.source, audio.path
+            ) from exc
+    elif end is None:
+        # Extract from start to end of file
+        output_file = posixpath.join(
+            output, f"{audio.get_file_stem()}_{int(start * 1000):06d}_end.{format}"
+        )
+        try:
+            audio_bytes = audio_to_bytes(audio, format, start=start, duration=None)
+        except Exception as exc:
+            raise FileError(
+                "unable to save audio fragment", audio.source, audio.path
+            ) from exc
+    else:
+        # Fragment extraction mode with specific end time
+        if end < 0 or start >= end:
+            raise ValueError(
+                f"Can't save audio for '{audio.path}', "
+                f"invalid time range: ({start:.3f}, {end:.3f})"
+            )
+
+        duration = end - start
+        start_ms = int(start * 1000)
+        end_ms = int(end * 1000)
+        output_file = posixpath.join(
+            output, f"{audio.get_file_stem()}_{start_ms:06d}_{end_ms:06d}.{format}"
+        )
+
+        try:
+            audio_bytes = audio_to_bytes(audio, format, start, duration)
+        except Exception as exc:
+            raise FileError(
+                "unable to save audio fragment", audio.source, audio.path
+            ) from exc
+
+    from datachain.lib.file import AudioFile
+
+    return AudioFile.upload(audio_bytes, output_file, catalog=audio._catalog)
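One way the new helpers might be wired into a chain; this is a sketch, not code from the diff, and the bucket, output prefix, and wrapper functions are assumptions:

import datachain as dc
from datachain.lib.audio import audio_info, save_audio
from datachain.lib.file import Audio, AudioFile


def metadata(file: AudioFile) -> Audio:
    # Sample rate, channels, duration, codec, etc. via soundfile.
    return audio_info(file)


def first_two_seconds(file: AudioFile) -> AudioFile:
    # Writes "<stem>_000000_002000.wav" under the given output prefix.
    return save_audio(file, "s3://my-bucket/clips", format="wav", start=0, end=2)


chain = (
    dc.read_storage("s3://my-bucket/audio/")  # placeholder bucket
    .map(meta=metadata)
    .map(clip=first_two_seconds)
)
chain.show()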
datachain/lib/clip.py
CHANGED
@@ -1,5 +1,6 @@
 import inspect
-from
+from collections.abc import Callable
+from typing import TYPE_CHECKING, Any, Literal, Union
 
 import torch
 from transformers.modeling_utils import PreTrainedModel
@@ -32,28 +33,28 @@ def _get_encoder(model: Any, type: Literal["image", "text"]) -> Callable:
 
 
 def clip_similarity_scores(
-    images: Union[
-    text:
+    images: Union["Image.Image", list["Image.Image"]] | None,
+    text: str | list[str] | None,
     model: Any,
     preprocess: Callable,
     tokenizer: Callable,
     prob: bool = False,
     image_to_text: bool = True,
-    device:
+    device: str | torch.device | None = None,
 ) -> list[list[float]]:
     """
     Calculate CLIP similarity scores between one or more images and/or text.
 
     Parameters:
-        images
-        text
-        model
-        preprocess
-        tokenizer
-        prob
-        image_to_text
-            if only one of images or text provided.
-        device
+        images: Images to use as inputs.
+        text: Text to use as inputs.
+        model: Model from clip or open_clip packages.
+        preprocess: Image preprocessor to apply.
+        tokenizer: Text tokenizer.
+        prob: Compute softmax probabilities.
+        image_to_text: Whether to compute for image-to-text or text-to-image. Ignored
+            if only one of the images or text provided.
+        device: Device to use. Default is None - use model's device.
 
 
     Example:
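For reference, a hedged sketch of calling the updated function with open_clip (the model name, weights tag, image path, and prompts are placeholders):

import open_clip
from PIL import Image

from datachain.lib.clip import clip_similarity_scores

model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="laion2b_s34b_b79k"
)
tokenizer = open_clip.get_tokenizer("ViT-B-32")

image = Image.open("cat.jpg")
scores = clip_similarity_scores(
    image, ["a cat", "a dog"], model, preprocess, tokenizer, prob=True
)
print(scores)  # one row of probabilities per image, e.g. [[0.97, 0.03]]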
datachain/lib/convert/flatten.py
CHANGED
@@ -6,12 +6,14 @@ from datachain.lib.model_store import ModelStore
 
 
 def flatten(obj: BaseModel) -> tuple:
-    return tuple(_flatten_fields_values(obj.model_fields, obj))
+    return tuple(_flatten_fields_values(type(obj).model_fields, obj))
 
 
 def flatten_list(obj_list: list[BaseModel]) -> tuple:
     return tuple(
-        val
+        val
+        for obj in obj_list
+        for val in _flatten_fields_values(type(obj).model_fields, obj)
     )
 
 
@@ -43,4 +45,4 @@ def _flatten_fields_values(fields: dict, obj: BaseModel) -> Generator:
 
 
 def _flatten(obj: BaseModel) -> tuple:
-    return tuple(_flatten_fields_values(obj.model_fields, obj))
+    return tuple(_flatten_fields_values(type(obj).model_fields, obj))
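The switch from obj.model_fields to type(obj).model_fields likely tracks Pydantic's deprecation of instance-level model_fields access (it warns as of Pydantic 2.11); a minimal illustration with a made-up model:

from pydantic import BaseModel


class Point(BaseModel):  # hypothetical example model
    x: int = 0
    y: int = 0


p = Point(x=1, y=2)
# Field definitions come from the class; values come from the instance.
print(tuple(getattr(p, name) for name in type(p).model_fields))  # (1, 2)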
datachain/lib/convert/python_to_sql.py
CHANGED
@@ -1,6 +1,7 @@
 import inspect
 from datetime import datetime
 from enum import Enum
+from types import UnionType
 from typing import Annotated, Literal, Union, get_args, get_origin
 
 from pydantic import BaseModel
@@ -69,11 +70,12 @@ def python_to_sql(typ):  # noqa: PLR0911
     if inspect.isclass(orig) and issubclass(dict, orig):
         return JSON
 
-    if orig
+    if orig in (Union, UnionType):
         if len(args) == 2 and (type(None) in args):
-
+            non_none_arg = args[0] if args[0] is not type(None) else args[1]
+            return python_to_sql(non_none_arg)
 
-        if
+        if all(arg is str or get_origin(arg) in (Literal, LiteralEx) for arg in args):
             return String
 
         if _is_json_inside_union(orig, args):
@@ -95,7 +97,7 @@ def list_of_args_to_type(args) -> SQLType:
 
 
 def _is_json_inside_union(orig, args) -> bool:
-    if orig
+    if orig in (Union, UnionType) and len(args) >= 2:
         # List in JSON: Union[dict, list[dict]]
         args_no_nones = [arg for arg in args if arg != type(None)]  # noqa: E721
         if len(args_no_nones) == 2:
@@ -109,9 +111,3 @@ def _is_json_inside_union(orig, args) -> bool:
     if any(inspect.isclass(arg) and issubclass(arg, BaseModel) for arg in args):
         return True
     return False
-
-
-def _is_union_str_literal(orig, args) -> bool:
-    if orig != Union:
-        return False
-    return all(arg is str or get_origin(arg) in (Literal, LiteralEx) for arg in args)