datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. datachain/__init__.py +20 -0
  2. datachain/asyn.py +11 -12
  3. datachain/cache.py +7 -7
  4. datachain/catalog/__init__.py +2 -2
  5. datachain/catalog/catalog.py +621 -507
  6. datachain/catalog/dependency.py +164 -0
  7. datachain/catalog/loader.py +28 -18
  8. datachain/checkpoint.py +43 -0
  9. datachain/cli/__init__.py +24 -33
  10. datachain/cli/commands/__init__.py +1 -8
  11. datachain/cli/commands/datasets.py +83 -52
  12. datachain/cli/commands/ls.py +17 -17
  13. datachain/cli/commands/show.py +4 -4
  14. datachain/cli/parser/__init__.py +8 -74
  15. datachain/cli/parser/job.py +95 -3
  16. datachain/cli/parser/studio.py +11 -4
  17. datachain/cli/parser/utils.py +1 -2
  18. datachain/cli/utils.py +2 -15
  19. datachain/client/azure.py +4 -4
  20. datachain/client/fsspec.py +45 -28
  21. datachain/client/gcs.py +6 -6
  22. datachain/client/hf.py +29 -2
  23. datachain/client/http.py +157 -0
  24. datachain/client/local.py +15 -11
  25. datachain/client/s3.py +17 -9
  26. datachain/config.py +4 -8
  27. datachain/data_storage/db_engine.py +12 -6
  28. datachain/data_storage/job.py +5 -1
  29. datachain/data_storage/metastore.py +1252 -186
  30. datachain/data_storage/schema.py +58 -45
  31. datachain/data_storage/serializer.py +105 -15
  32. datachain/data_storage/sqlite.py +286 -127
  33. datachain/data_storage/warehouse.py +250 -113
  34. datachain/dataset.py +353 -148
  35. datachain/delta.py +391 -0
  36. datachain/diff/__init__.py +27 -29
  37. datachain/error.py +60 -0
  38. datachain/func/__init__.py +2 -1
  39. datachain/func/aggregate.py +66 -42
  40. datachain/func/array.py +242 -38
  41. datachain/func/base.py +7 -4
  42. datachain/func/conditional.py +110 -60
  43. datachain/func/func.py +96 -45
  44. datachain/func/numeric.py +55 -38
  45. datachain/func/path.py +32 -20
  46. datachain/func/random.py +2 -2
  47. datachain/func/string.py +67 -37
  48. datachain/func/window.py +7 -8
  49. datachain/hash_utils.py +123 -0
  50. datachain/job.py +11 -7
  51. datachain/json.py +138 -0
  52. datachain/lib/arrow.py +58 -22
  53. datachain/lib/audio.py +245 -0
  54. datachain/lib/clip.py +14 -13
  55. datachain/lib/convert/flatten.py +5 -3
  56. datachain/lib/convert/python_to_sql.py +6 -10
  57. datachain/lib/convert/sql_to_python.py +8 -0
  58. datachain/lib/convert/values_to_tuples.py +156 -51
  59. datachain/lib/data_model.py +42 -20
  60. datachain/lib/dataset_info.py +36 -8
  61. datachain/lib/dc/__init__.py +8 -2
  62. datachain/lib/dc/csv.py +25 -28
  63. datachain/lib/dc/database.py +398 -0
  64. datachain/lib/dc/datachain.py +1289 -425
  65. datachain/lib/dc/datasets.py +320 -38
  66. datachain/lib/dc/hf.py +38 -24
  67. datachain/lib/dc/json.py +29 -32
  68. datachain/lib/dc/listings.py +112 -8
  69. datachain/lib/dc/pandas.py +16 -12
  70. datachain/lib/dc/parquet.py +35 -23
  71. datachain/lib/dc/records.py +31 -23
  72. datachain/lib/dc/storage.py +154 -64
  73. datachain/lib/dc/storage_pattern.py +251 -0
  74. datachain/lib/dc/utils.py +24 -16
  75. datachain/lib/dc/values.py +8 -9
  76. datachain/lib/file.py +622 -89
  77. datachain/lib/hf.py +69 -39
  78. datachain/lib/image.py +14 -14
  79. datachain/lib/listing.py +14 -11
  80. datachain/lib/listing_info.py +1 -2
  81. datachain/lib/meta_formats.py +3 -4
  82. datachain/lib/model_store.py +39 -7
  83. datachain/lib/namespaces.py +125 -0
  84. datachain/lib/projects.py +130 -0
  85. datachain/lib/pytorch.py +32 -21
  86. datachain/lib/settings.py +192 -56
  87. datachain/lib/signal_schema.py +427 -104
  88. datachain/lib/tar.py +1 -2
  89. datachain/lib/text.py +8 -7
  90. datachain/lib/udf.py +164 -76
  91. datachain/lib/udf_signature.py +60 -35
  92. datachain/lib/utils.py +118 -4
  93. datachain/lib/video.py +17 -9
  94. datachain/lib/webdataset.py +61 -56
  95. datachain/lib/webdataset_laion.py +15 -16
  96. datachain/listing.py +22 -10
  97. datachain/model/bbox.py +3 -1
  98. datachain/model/ultralytics/bbox.py +16 -12
  99. datachain/model/ultralytics/pose.py +16 -12
  100. datachain/model/ultralytics/segment.py +16 -12
  101. datachain/namespace.py +84 -0
  102. datachain/node.py +6 -6
  103. datachain/nodes_thread_pool.py +0 -1
  104. datachain/plugins.py +24 -0
  105. datachain/project.py +78 -0
  106. datachain/query/batch.py +40 -41
  107. datachain/query/dataset.py +604 -322
  108. datachain/query/dispatch.py +261 -154
  109. datachain/query/metrics.py +4 -6
  110. datachain/query/params.py +2 -3
  111. datachain/query/queue.py +3 -12
  112. datachain/query/schema.py +11 -6
  113. datachain/query/session.py +200 -33
  114. datachain/query/udf.py +34 -2
  115. datachain/remote/studio.py +171 -69
  116. datachain/script_meta.py +12 -12
  117. datachain/semver.py +68 -0
  118. datachain/sql/__init__.py +2 -0
  119. datachain/sql/functions/array.py +33 -1
  120. datachain/sql/postgresql_dialect.py +9 -0
  121. datachain/sql/postgresql_types.py +21 -0
  122. datachain/sql/sqlite/__init__.py +5 -1
  123. datachain/sql/sqlite/base.py +102 -29
  124. datachain/sql/sqlite/types.py +8 -13
  125. datachain/sql/types.py +70 -15
  126. datachain/studio.py +223 -46
  127. datachain/toolkit/split.py +31 -10
  128. datachain/utils.py +101 -59
  129. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
  130. datachain-0.39.0.dist-info/RECORD +173 -0
  131. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
  132. datachain/cli/commands/query.py +0 -53
  133. datachain/query/utils.py +0 -42
  134. datachain-0.14.2.dist-info/RECORD +0 -158
  135. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
  136. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
  137. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/json.py ADDED
@@ -0,0 +1,138 @@
+ """DataChain JSON utilities.
+
+ This module wraps :mod:`ujson` so we can guarantee consistent handling
+ of values that the encoder does not support out of the box (for example
+ ``datetime`` objects or ``bytes``).
+ All code inside DataChain should import this module instead of using
+ :mod:`ujson` directly.
+ """
+
+ import datetime as _dt
+ import json as _json
+ import uuid as _uuid
+ from collections.abc import Callable
+ from typing import Any
+
+ import ujson as _ujson
+
+ __all__ = [
+     "JSONDecodeError",
+     "dump",
+     "dumps",
+     "load",
+     "loads",
+ ]
+
+ JSONDecodeError = (_ujson.JSONDecodeError, _json.JSONDecodeError)
+
+ _SENTINEL = object()
+ _Default = Callable[[Any], Any]
+ DEFAULT_PREVIEW_BYTES = 1024
+
+
+ # To make it looks like Pydantic's ISO format with 'Z' for UTC
+ # It is minor but nice to have consistency
+ def _format_datetime(value: _dt.datetime) -> str:
+     iso = value.isoformat()
+
+     offset = value.utcoffset()
+     if value.tzinfo is None or offset is None:
+         return iso
+
+     if offset == _dt.timedelta(0) and iso.endswith(("+00:00", "-00:00")):
+         return iso[:-6] + "Z"
+
+     return iso
+
+
+ def _format_time(value: _dt.time) -> str:
+     iso = value.isoformat()
+
+     offset = value.utcoffset()
+     if value.tzinfo is None or offset is None:
+         return iso
+
+     if offset == _dt.timedelta(0) and iso.endswith(("+00:00", "-00:00")):
+         return iso[:-6] + "Z"
+
+     return iso
+
+
+ def _coerce(value: Any, serialize_bytes: bool) -> Any:
+     """Return a JSON-serializable representation for supported extra types."""
+
+     if isinstance(value, _dt.datetime):
+         return _format_datetime(value)
+     if isinstance(value, _dt.date):
+         return value.isoformat()
+     if isinstance(value, _dt.time):
+         return _format_time(value)
+     if isinstance(value, _uuid.UUID):
+         return str(value)
+     if serialize_bytes and isinstance(value, (bytes, bytearray)):
+         return list(bytes(value)[:DEFAULT_PREVIEW_BYTES])
+     return _SENTINEL
+
+
+ def _base_default(value: Any, serialize_bytes: bool) -> Any:
+     converted = _coerce(value, serialize_bytes)
+     if converted is not _SENTINEL:
+         return converted
+     raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")
+
+
+ def _build_default(user_default: _Default | None, serialize_bytes: bool) -> _Default:
+     if user_default is None:
+         return lambda value: _base_default(value, serialize_bytes)
+
+     def combined(value: Any) -> Any:
+         converted = _coerce(value, serialize_bytes)
+         if converted is not _SENTINEL:
+             return converted
+         return user_default(value)
+
+     return combined
+
+
+ def dumps(
+     obj: Any,
+     *,
+     default: _Default | None = None,
+     serialize_bytes: bool = False,
+     **kwargs: Any,
+ ) -> str:
+     """Serialize *obj* to a JSON-formatted ``str``."""
+
+     if serialize_bytes:
+         return _json.dumps(obj, default=_build_default(default, True), **kwargs)
+
+     return _ujson.dumps(obj, default=_build_default(default, False), **kwargs)
+
+
+ def dump(
+     obj: Any,
+     fp,
+     *,
+     default: _Default | None = None,
+     serialize_bytes: bool = False,
+     **kwargs: Any,
+ ) -> None:
+     """Serialize *obj* as a JSON formatted stream to *fp*."""
+
+     if serialize_bytes:
+         _json.dump(obj, fp, default=_build_default(default, True), **kwargs)
+         return
+
+     _ujson.dump(obj, fp, default=_build_default(default, False), **kwargs)
+
+
+ def loads(s: str | bytes | bytearray, **kwargs: Any) -> Any:
+     """Deserialize *s* to a Python object."""
+
+     return _ujson.loads(s, **kwargs)
+
+
+ def load(fp, **kwargs: Any) -> Any:
+     """Deserialize JSON content from *fp* to a Python object."""
+
+     return loads(fp.read(), **kwargs)
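
The new module is small enough that its behavior follows directly from the code above. A minimal usage sketch (assuming datachain 0.39.0 is installed; exact output formatting depends on the installed ujson version):

    from datetime import datetime, timezone
    from uuid import UUID

    from datachain import json

    # Timezone-aware UTC datetimes are rendered with a trailing "Z",
    # matching Pydantic's ISO format (see _format_datetime above).
    json.dumps({"ts": datetime(2024, 1, 1, tzinfo=timezone.utc)})
    # e.g. '{"ts":"2024-01-01T00:00:00Z"}'

    # UUIDs become plain strings.
    json.dumps({"id": UUID(int=1)})

    # bytes are refused by default; with serialize_bytes=True they are encoded
    # as a list of ints, truncated to DEFAULT_PREVIEW_BYTES (1024) bytes.
    json.dumps({"raw": b"\x00\x01\x02"}, serialize_bytes=True)
    # e.g. '{"raw": [0, 1, 2]}'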
datachain/lib/arrow.py CHANGED
@@ -1,12 +1,13 @@
  from collections.abc import Sequence
  from itertools import islice
- from typing import TYPE_CHECKING, Any, Optional
+ from typing import TYPE_CHECKING, Any

- import orjson
  import pyarrow as pa
+ from pyarrow._csv import ParseOptions
  from pyarrow.dataset import CsvFileFormat, dataset
  from tqdm.auto import tqdm

+ from datachain import json
  from datachain.fs.reference import ReferenceFileSystem
  from datachain.lib.data_model import dict_to_data_model
  from datachain.lib.file import ArrowRow, File
@@ -26,15 +27,27 @@ if TYPE_CHECKING:
  DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY = b"DataChain SignalSchema"


+ def fix_pyarrow_format(format, parse_options=None):
+     # Re-init invalid row handler: https://issues.apache.org/jira/browse/ARROW-17641
+     if (
+         format
+         and isinstance(format, CsvFileFormat)
+         and parse_options
+         and isinstance(parse_options, ParseOptions)
+     ):
+         format.parse_options = parse_options
+     return format
+
+
  class ArrowGenerator(Generator):
      DEFAULT_BATCH_SIZE = 2**17 # same as `pyarrow._dataset._DEFAULT_BATCH_SIZE`

      def __init__(
          self,
-         input_schema: Optional["pa.Schema"] = None,
-         output_schema: Optional[type["BaseModel"]] = None,
+         input_schema: pa.Schema | None = None,
+         output_schema: type["BaseModel"] | None = None,
          source: bool = True,
-         nrows: Optional[int] = None,
+         nrows: int | None = None,
          **kwargs,
      ):
          """
@@ -53,6 +66,7 @@ class ArrowGenerator(Generator):
          self.output_schema = output_schema
          self.source = source
          self.nrows = nrows
+         self.parse_options = kwargs.pop("parse_options", None)
          self.kwargs = kwargs

      def process(self, file: File):
@@ -62,9 +76,13 @@ class ArrowGenerator(Generator):
              fs_path = file.path
              fs = ReferenceFileSystem({fs_path: [cache_path]})
          else:
-             fs, fs_path = file.get_fs(), file.get_path()
+             fs, fs_path = file.get_fs(), file.get_fs_path()
+
+         kwargs = self.kwargs
+         if format := kwargs.get("format"):
+             kwargs["format"] = fix_pyarrow_format(format, self.parse_options)

-         ds = dataset(fs_path, schema=self.input_schema, filesystem=fs, **self.kwargs)
+         ds = dataset(fs_path, schema=self.input_schema, filesystem=fs, **kwargs)

          hf_schema = _get_hf_schema(ds.schema)
          use_datachain_schema = (
@@ -94,7 +112,7 @@ class ArrowGenerator(Generator):
          record: dict[str, Any],
          file: File,
          index: int,
-         hf_schema: Optional[tuple["Features", dict[str, "DataType"]]],
+         hf_schema: tuple["Features", dict[str, "DataType"]] | None,
          use_datachain_schema: bool,
      ):
          if use_datachain_schema and self.output_schema:
@@ -108,13 +126,22 @@ class ArrowGenerator(Generator):
              if isinstance(kwargs.get("format"), CsvFileFormat):
                  kwargs["format"] = "csv"
              arrow_file = ArrowRow(file=file, index=index, kwargs=kwargs)
+
+             if self.output_schema and hasattr(vals[0], "source"):
+                 # if we are reading parquet file written by datachain it might have
+                 # source inside of it already, so we should not duplicate it, instead
+                 # we are re-creating it of the self.source flag
+                 vals[0].source = arrow_file # type: ignore[attr-defined]
+
+                 return vals
              return [arrow_file, *vals]
+
          return vals

      def _process_non_datachain_record(
          self,
          record: dict[str, Any],
-         hf_schema: Optional[tuple["Features", dict[str, "DataType"]]],
+         hf_schema: tuple["Features", dict[str, "DataType"]] | None,
      ):
          vals = list(record.values())
          if not self.output_schema:
@@ -122,7 +149,9 @@ class ArrowGenerator(Generator):

          fields = self.output_schema.model_fields
          vals_dict = {}
-         for i, ((field, field_info), val) in enumerate(zip(fields.items(), vals)):
+         for i, ((field, field_info), val) in enumerate(
+             zip(fields.items(), vals, strict=False)
+         ):
              anno = field_info.annotation
              if hf_schema:
                  from datachain.lib.hf import convert_feature
@@ -137,9 +166,13 @@


  def infer_schema(chain: "DataChain", **kwargs) -> pa.Schema:
+     parse_options = kwargs.pop("parse_options", None)
+     if format := kwargs.get("format"):
+         kwargs["format"] = fix_pyarrow_format(format, parse_options)
+
      schemas = []
-     for file in chain.collect("file"):
-         ds = dataset(file.get_path(), filesystem=file.get_fs(), **kwargs) # type: ignore[union-attr]
+     for (file,) in chain.to_iter("file"):
+         ds = dataset(file.get_fs_path(), filesystem=file.get_fs(), **kwargs) # type: ignore[union-attr]
          schemas.append(ds.schema)
      if not schemas:
          raise ValueError(
@@ -149,7 +182,7 @@ def infer_schema(chain: "DataChain", **kwargs) -> pa.Schema:


  def schema_to_output(
-     schema: pa.Schema, col_names: Optional[Sequence[str]] = None
+     schema: pa.Schema, col_names: Sequence[str] | None = None
  ) -> tuple[dict[str, type], list[str]]:
      """
      Generate UDF output schema from pyarrow schema.
@@ -174,14 +207,15 @@
      hf_schema = _get_hf_schema(schema)
      if hf_schema:
          return {
-             column: hf_type for hf_type, column in zip(hf_schema[1].values(), col_names)
+             column: hf_type
+             for hf_type, column in zip(hf_schema[1].values(), col_names, strict=False)
          }, list(normalized_col_dict.values())

      output = {}
-     for field, column in zip(schema, col_names):
+     for field, column in zip(schema, col_names, strict=False):
          dtype = arrow_type_mapper(field.type, column)
          if field.nullable and not ModelStore.is_pydantic(dtype):
-             dtype = Optional[dtype] # type: ignore[assignment]
+             dtype = dtype | None # type: ignore[assignment]
          output[column] = dtype

      return output, list(normalized_col_dict.values())
@@ -212,31 +246,33 @@ def arrow_type_mapper(col_type: pa.DataType, column: str = "") -> type: # noqa:
          for field in col_type:
              dtype = arrow_type_mapper(field.type, field.name)
              if field.nullable and not ModelStore.is_pydantic(dtype):
-                 dtype = Optional[dtype] # type: ignore[assignment]
+                 dtype = dtype | None # type: ignore[assignment]
              type_dict[field.name] = dtype
-         return dict_to_data_model(column, type_dict)
+         return dict_to_data_model(f"ArrowDataModel_{column}", type_dict)
      if pa.types.is_map(col_type):
          return dict
      if isinstance(col_type, pa.lib.DictionaryType):
          return arrow_type_mapper(col_type.value_type) # type: ignore[return-value]
+     if pa.types.is_null(col_type):
+         return str # use strings for null columns
      raise TypeError(f"{col_type!r} datatypes not supported, column: {column}")


  def _get_hf_schema(
      schema: "pa.Schema",
- ) -> Optional[tuple["Features", dict[str, "DataType"]]]:
+ ) -> tuple["Features", dict[str, "DataType"]] | None:
      if schema.metadata and b"huggingface" in schema.metadata:
          from datachain.lib.hf import get_output_schema, schema_from_arrow

          features = schema_from_arrow(schema)
-         return features, get_output_schema(features)
+         return features, get_output_schema(features)[0]
      return None


- def _get_datachain_schema(schema: "pa.Schema") -> Optional[SignalSchema]:
+ def _get_datachain_schema(schema: "pa.Schema") -> SignalSchema | None:
      """Return a restored SignalSchema from parquet metadata, if any is found."""
      if schema.metadata and DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY in schema.metadata:
-         serialized_signal_schema = orjson.loads(
+         serialized_signal_schema = json.loads(
              schema.metadata[DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY]
          )
          return SignalSchema.deserialize(serialized_signal_schema)
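
For context on the new fix_pyarrow_format helper: the Python-level ParseOptions attached to a CsvFileFormat (notably its invalid_row_handler callback) can be lost when the format object is re-used, which appears to be the ARROW-17641 issue linked in the code, so the options are re-assigned right before pyarrow.dataset.dataset() is called. A rough standalone illustration of the same workaround, outside datachain (the file name and handler are made up):

    import pyarrow.dataset as ds
    from pyarrow.csv import ParseOptions
    from pyarrow.dataset import CsvFileFormat

    # Skip malformed CSV rows instead of failing the whole read.
    parse_options = ParseOptions(invalid_row_handler=lambda row: "skip")
    fmt = CsvFileFormat(parse_options=parse_options)

    # If `fmt` has crossed a pickle boundary (e.g. was sent to a worker process),
    # the Python callback can be lost; re-assigning parse_options restores it.
    fmt.parse_options = parse_options

    table = ds.dataset("data.csv", format=fmt).to_table()  # "data.csv" is hypothetical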
datachain/lib/audio.py ADDED
@@ -0,0 +1,245 @@
+ import posixpath
+ import re
+ from typing import TYPE_CHECKING
+
+ from datachain.lib.file import FileError
+
+ if TYPE_CHECKING:
+     from numpy import ndarray
+
+     from datachain.lib.file import Audio, AudioFile, File
+
+ try:
+     import soundfile as sf
+ except ImportError as exc:
+     raise ImportError(
+         "Missing dependencies for processing audio.\n"
+         "To install run:\n\n"
+         " pip install 'datachain[audio]'\n"
+     ) from exc
+
+
+ def audio_info(file: "File | AudioFile") -> "Audio":
+     """Extract metadata like sample rate, channels, duration, and format."""
+     from datachain.lib.file import Audio
+
+     file = file.as_audio_file()
+
+     try:
+         with file.open() as f:
+             info = sf.info(f)
+
+             sample_rate = int(info.samplerate)
+             channels = int(info.channels)
+             frames = int(info.frames)
+             duration = float(info.duration)
+
+             # soundfile provides format and subtype
+             if info.format:
+                 format_name = info.format.lower()
+             else:
+                 format_name = file.get_file_ext().lower()
+
+             if not format_name:
+                 format_name = "unknown"
+             codec_name = info.subtype if info.subtype else ""
+
+             # Calculate bit rate from subtype
+             bits_per_sample = _get_bits_per_sample(info.subtype)
+             bit_rate = (
+                 bits_per_sample * sample_rate * channels if bits_per_sample > 0 else -1
+             )
+
+     except Exception as exc:
+         raise FileError(
+             "unable to extract metadata from audio file", file.source, file.path
+         ) from exc
+
+     return Audio(
+         sample_rate=sample_rate,
+         channels=channels,
+         duration=duration,
+         samples=frames,
+         format=format_name,
+         codec=codec_name,
+         bit_rate=bit_rate,
+     )
+
+
+ def _get_bits_per_sample(subtype: str) -> int:
+     """
+     Map soundfile subtype to bits per sample.
+
+     Args:
+         subtype: The subtype string from soundfile
+
+     Returns:
+         Bits per sample, or 0 if unknown
+     """
+     if not subtype:
+         return 0
+
+     # Common PCM and floating-point subtypes
+     pcm_bits = {
+         "PCM_16": 16,
+         "PCM_24": 24,
+         "PCM_32": 32,
+         "PCM_S8": 8,
+         "PCM_U8": 8,
+         "FLOAT": 32,
+         "DOUBLE": 64,
+     }
+
+     if subtype in pcm_bits:
+         return pcm_bits[subtype]
+
+     # Handle variants such as PCM_S16LE, PCM_F32LE, etc.
+     match = re.search(r"PCM_(?:[A-Z]*?)(\d+)", subtype)
+     if match:
+         return int(match.group(1))
+
+     return 0
+
+
+ def audio_to_np(
+     audio: "AudioFile", start: float = 0, duration: float | None = None
+ ) -> "tuple[ndarray, int]":
+     """Load audio fragment as numpy array.
+     Multi-channel audio is transposed to (samples, channels)."""
+     if start < 0:
+         raise ValueError("start must be a non-negative float")
+
+     if duration is not None and duration <= 0:
+         raise ValueError("duration must be a positive float")
+
+     if hasattr(audio, "as_audio_file"):
+         audio = audio.as_audio_file()
+
+     try:
+         with audio.open() as f:
+             info = sf.info(f)
+             sample_rate = info.samplerate
+
+             frame_offset = int(start * sample_rate)
+             num_frames = int(duration * sample_rate) if duration is not None else -1
+
+             # Reset file pointer to the beginning
+             f.seek(0)
+
+             # Read audio data with offset and frame count
+             audio_np, sr = sf.read(
+                 f,
+                 start=frame_offset,
+                 frames=num_frames,
+                 always_2d=False,
+                 dtype="float32",
+             )
+
+             # soundfile returns shape (frames,) for mono or
+             # (frames, channels) for multi-channel
+             # We keep this format as it matches expected output
+             return audio_np, int(sr)
+     except Exception as exc:
+         raise FileError(
+             "unable to read audio fragment", audio.source, audio.path
+         ) from exc
+
+
+ def audio_to_bytes(
+     audio: "AudioFile",
+     format: str = "wav",
+     start: float = 0,
+     duration: float | None = None,
+ ) -> bytes:
+     """Convert audio to bytes using soundfile.
+
+     If duration is None, converts from start to end of file.
+     If start is 0 and duration is None, converts entire file."""
+     import io
+
+     y, sr = audio_to_np(audio, start, duration)
+
+     buffer = io.BytesIO()
+     sf.write(buffer, y, sr, format=format)
+     return buffer.getvalue()
+
+
+ def save_audio(
+     audio: "AudioFile",
+     output: str,
+     format: str | None = None,
+     start: float = 0,
+     end: float | None = None,
+ ) -> "AudioFile":
+     """Save audio file or extract fragment to specified format.
+
+     Args:
+         audio: Source AudioFile object
+         output: Output directory path
+         format: Output format ('wav', 'mp3', etc). Defaults to source format
+         start: Start time in seconds (>= 0). Defaults to 0
+         end: End time in seconds. If None, extracts to end of file
+
+     Returns:
+         AudioFile: New audio file with format conversion/extraction applied
+
+     Examples:
+         save_audio(audio, "/path", "mp3") # Entire file to MP3
+         save_audio(audio, "s3://bucket/path", "wav", start=2.5) # From 2.5s to end
+         save_audio(audio, "/path", "flac", start=1, end=3) # Extract 1-3s fragment
+     """
+     if format is None:
+         format = audio.get_file_ext()
+
+     # Validate start time
+     if start < 0:
+         raise ValueError(
+             f"Can't save audio for '{audio.path}', "
+             f"start time must be non-negative: {start:.3f}"
+         )
+
+     # Handle full file conversion when end is None and start is 0
+     if end is None and start == 0:
+         output_file = posixpath.join(output, f"{audio.get_file_stem()}.{format}")
+         try:
+             audio_bytes = audio_to_bytes(audio, format, start=0, duration=None)
+         except Exception as exc:
+             raise FileError(
+                 "unable to convert audio file", audio.source, audio.path
+             ) from exc
+     elif end is None:
+         # Extract from start to end of file
+         output_file = posixpath.join(
+             output, f"{audio.get_file_stem()}_{int(start * 1000):06d}_end.{format}"
+         )
+         try:
+             audio_bytes = audio_to_bytes(audio, format, start=start, duration=None)
+         except Exception as exc:
+             raise FileError(
+                 "unable to save audio fragment", audio.source, audio.path
+             ) from exc
+     else:
+         # Fragment extraction mode with specific end time
+         if end < 0 or start >= end:
+             raise ValueError(
+                 f"Can't save audio for '{audio.path}', "
+                 f"invalid time range: ({start:.3f}, {end:.3f})"
+             )
+
+         duration = end - start
+         start_ms = int(start * 1000)
+         end_ms = int(end * 1000)
+         output_file = posixpath.join(
+             output, f"{audio.get_file_stem()}_{start_ms:06d}_{end_ms:06d}.{format}"
+         )
+
+         try:
+             audio_bytes = audio_to_bytes(audio, format, start, duration)
+         except Exception as exc:
+             raise FileError(
+                 "unable to save audio fragment", audio.source, audio.path
+             ) from exc
+
+     from datachain.lib.file import AudioFile
+
+     return AudioFile.upload(audio_bytes, output_file, catalog=audio._catalog)
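
The helpers above are plain functions over AudioFile objects, so they slot into a chain as UDFs. A rough sketch of how audio_info might be applied over a bucket of audio files (the read_storage entry point, the type="audio" hint, and the bucket path are assumptions about the 0.39 API rather than something shown in this diff):

    import datachain as dc
    from datachain.lib.audio import audio_info

    # Attach an Audio record (sample_rate, channels, duration, codec, ...)
    # to every audio file found under the (hypothetical) bucket prefix.
    chain = (
        dc.read_storage("s3://my-bucket/podcasts/", type="audio")
        .map(info=audio_info)
    )
    chain.show()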
datachain/lib/clip.py CHANGED
@@ -1,5 +1,6 @@
  import inspect
- from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union
+ from collections.abc import Callable
+ from typing import TYPE_CHECKING, Any, Literal, Union

  import torch
  from transformers.modeling_utils import PreTrainedModel
@@ -32,28 +33,28 @@ def _get_encoder(model: Any, type: Literal["image", "text"]) -> Callable:


  def clip_similarity_scores(
-     images: Union[None, "Image.Image", list["Image.Image"]],
-     text: Union[None, str, list[str]],
+     images: Union["Image.Image", list["Image.Image"]] | None,
+     text: str | list[str] | None,
      model: Any,
      preprocess: Callable,
      tokenizer: Callable,
      prob: bool = False,
      image_to_text: bool = True,
-     device: Optional[Union[str, torch.device]] = None,
+     device: str | torch.device | None = None,
  ) -> list[list[float]]:
      """
      Calculate CLIP similarity scores between one or more images and/or text.

      Parameters:
-         images : Images to use as inputs.
-         text : Text to use as inputs.
-         model : Model from clip or open_clip packages.
-         preprocess : Image preprocessor to apply.
-         tokenizer : Text tokenizer.
-         prob : Compute softmax probabilities.
-         image_to_text : Whether to compute for image-to-text or text-to-image. Ignored
-             if only one of images or text provided.
-         device : Device to use. Defaults is None - use model's device.
+         images: Images to use as inputs.
+         text: Text to use as inputs.
+         model: Model from clip or open_clip packages.
+         preprocess: Image preprocessor to apply.
+         tokenizer: Text tokenizer.
+         prob: Compute softmax probabilities.
+         image_to_text: Whether to compute for image-to-text or text-to-image. Ignored
+             if only one of the images or text provided.
+         device: Device to use. Default is None - use model's device.


      Example:
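
The function's behavior is unchanged here; only the type hints and docstring formatting were modernized. A hedged usage sketch with open_clip (the model choice and image path are placeholders, not taken from this diff):

    import open_clip
    from PIL import Image

    from datachain.lib.clip import clip_similarity_scores

    model, _, preprocess = open_clip.create_model_and_transforms(
        "ViT-B-32", pretrained="laion2b_s34b_b79k"
    )
    tokenizer = open_clip.get_tokenizer("ViT-B-32")

    scores = clip_similarity_scores(
        images=Image.open("cat.jpg"),   # placeholder image path
        text=["a cat", "a dog"],
        model=model,
        preprocess=preprocess,
        tokenizer=tokenizer,
        prob=True,            # softmax over the candidate texts
        image_to_text=True,   # rank texts for each image
    )
    # scores has one row per image, e.g. [[0.98, 0.02]]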
datachain/lib/convert/flatten.py CHANGED
@@ -6,12 +6,14 @@ from datachain.lib.model_store import ModelStore


  def flatten(obj: BaseModel) -> tuple:
-     return tuple(_flatten_fields_values(obj.model_fields, obj))
+     return tuple(_flatten_fields_values(type(obj).model_fields, obj))


  def flatten_list(obj_list: list[BaseModel]) -> tuple:
      return tuple(
-         val for obj in obj_list for val in _flatten_fields_values(obj.model_fields, obj)
+         val
+         for obj in obj_list
+         for val in _flatten_fields_values(type(obj).model_fields, obj)
      )


@@ -43,4 +45,4 @@ def _flatten_fields_values(fields: dict, obj: BaseModel) -> Generator:


  def _flatten(obj: BaseModel) -> tuple:
-     return tuple(_flatten_fields_values(obj.model_fields, obj))
+     return tuple(_flatten_fields_values(type(obj).model_fields, obj))
datachain/lib/convert/python_to_sql.py CHANGED
@@ -1,6 +1,7 @@
  import inspect
  from datetime import datetime
  from enum import Enum
+ from types import UnionType
  from typing import Annotated, Literal, Union, get_args, get_origin

  from pydantic import BaseModel
@@ -69,11 +70,12 @@ def python_to_sql(typ): # noqa: PLR0911
      if inspect.isclass(orig) and issubclass(dict, orig):
          return JSON

-     if orig == Union:
+     if orig in (Union, UnionType):
          if len(args) == 2 and (type(None) in args):
-             return python_to_sql(args[0])
+             non_none_arg = args[0] if args[0] is not type(None) else args[1]
+             return python_to_sql(non_none_arg)

-         if _is_union_str_literal(orig, args):
+         if all(arg is str or get_origin(arg) in (Literal, LiteralEx) for arg in args):
              return String

          if _is_json_inside_union(orig, args):
@@ -95,7 +97,7 @@ def list_of_args_to_type(args) -> SQLType:


  def _is_json_inside_union(orig, args) -> bool:
-     if orig == Union and len(args) >= 2:
+     if orig in (Union, UnionType) and len(args) >= 2:
          # List in JSON: Union[dict, list[dict]]
          args_no_nones = [arg for arg in args if arg != type(None)] # noqa: E721
          if len(args_no_nones) == 2:
@@ -109,9 +111,3 @@ def _is_json_inside_union(orig, args) -> bool:
      if any(inspect.isclass(arg) and issubclass(arg, BaseModel) for arg in args):
          return True
      return False
-
-
- def _is_union_str_literal(orig, args) -> bool:
-     if orig != Union:
-         return False
-     return all(arg is str or get_origin(arg) in (Literal, LiteralEx) for arg in args)
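
The motivation for accepting both typing.Union and types.UnionType is that PEP 604 unions (X | None, Python 3.10+) report a different get_origin() than Optional[X]/Union[X, None]; the old args[0] shortcut also picked the wrong member when None happened to come first in the union, which the new non_none_arg selection fixes. The origin distinction is easy to check in plain Python, independent of datachain:

    from types import UnionType
    from typing import Optional, Union, get_args, get_origin

    # PEP 604 unions and typing.Union report different origins.
    assert get_origin(Optional[str]) is Union      # typing.Union
    assert get_origin(str | None) is UnionType     # types.UnionType

    # Both spell "optional string", so the SQL type mapping has to accept
    # either form and pick out the non-None member.
    assert get_args(Optional[str]) == get_args(str | None) == (str, type(None))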