datachain 0.30.5__py3-none-any.whl → 0.39.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. datachain/__init__.py +4 -0
  2. datachain/asyn.py +11 -12
  3. datachain/cache.py +5 -5
  4. datachain/catalog/__init__.py +0 -2
  5. datachain/catalog/catalog.py +276 -354
  6. datachain/catalog/dependency.py +164 -0
  7. datachain/catalog/loader.py +8 -3
  8. datachain/checkpoint.py +43 -0
  9. datachain/cli/__init__.py +10 -17
  10. datachain/cli/commands/__init__.py +1 -8
  11. datachain/cli/commands/datasets.py +42 -27
  12. datachain/cli/commands/ls.py +15 -15
  13. datachain/cli/commands/show.py +2 -2
  14. datachain/cli/parser/__init__.py +3 -43
  15. datachain/cli/parser/job.py +1 -1
  16. datachain/cli/parser/utils.py +1 -2
  17. datachain/cli/utils.py +2 -15
  18. datachain/client/azure.py +2 -2
  19. datachain/client/fsspec.py +34 -23
  20. datachain/client/gcs.py +3 -3
  21. datachain/client/http.py +157 -0
  22. datachain/client/local.py +11 -7
  23. datachain/client/s3.py +3 -3
  24. datachain/config.py +4 -8
  25. datachain/data_storage/db_engine.py +12 -6
  26. datachain/data_storage/job.py +2 -0
  27. datachain/data_storage/metastore.py +716 -137
  28. datachain/data_storage/schema.py +20 -27
  29. datachain/data_storage/serializer.py +105 -15
  30. datachain/data_storage/sqlite.py +114 -114
  31. datachain/data_storage/warehouse.py +140 -48
  32. datachain/dataset.py +109 -89
  33. datachain/delta.py +117 -42
  34. datachain/diff/__init__.py +25 -33
  35. datachain/error.py +24 -0
  36. datachain/func/aggregate.py +9 -11
  37. datachain/func/array.py +12 -12
  38. datachain/func/base.py +7 -4
  39. datachain/func/conditional.py +9 -13
  40. datachain/func/func.py +63 -45
  41. datachain/func/numeric.py +5 -7
  42. datachain/func/string.py +2 -2
  43. datachain/hash_utils.py +123 -0
  44. datachain/job.py +11 -7
  45. datachain/json.py +138 -0
  46. datachain/lib/arrow.py +18 -15
  47. datachain/lib/audio.py +60 -59
  48. datachain/lib/clip.py +14 -13
  49. datachain/lib/convert/python_to_sql.py +6 -10
  50. datachain/lib/convert/values_to_tuples.py +151 -53
  51. datachain/lib/data_model.py +23 -19
  52. datachain/lib/dataset_info.py +7 -7
  53. datachain/lib/dc/__init__.py +2 -1
  54. datachain/lib/dc/csv.py +22 -26
  55. datachain/lib/dc/database.py +37 -34
  56. datachain/lib/dc/datachain.py +518 -324
  57. datachain/lib/dc/datasets.py +38 -30
  58. datachain/lib/dc/hf.py +16 -20
  59. datachain/lib/dc/json.py +17 -18
  60. datachain/lib/dc/listings.py +5 -8
  61. datachain/lib/dc/pandas.py +3 -6
  62. datachain/lib/dc/parquet.py +33 -21
  63. datachain/lib/dc/records.py +9 -13
  64. datachain/lib/dc/storage.py +103 -65
  65. datachain/lib/dc/storage_pattern.py +251 -0
  66. datachain/lib/dc/utils.py +17 -14
  67. datachain/lib/dc/values.py +3 -6
  68. datachain/lib/file.py +187 -50
  69. datachain/lib/hf.py +7 -5
  70. datachain/lib/image.py +13 -13
  71. datachain/lib/listing.py +5 -5
  72. datachain/lib/listing_info.py +1 -2
  73. datachain/lib/meta_formats.py +2 -3
  74. datachain/lib/model_store.py +20 -8
  75. datachain/lib/namespaces.py +59 -7
  76. datachain/lib/projects.py +51 -9
  77. datachain/lib/pytorch.py +31 -23
  78. datachain/lib/settings.py +188 -85
  79. datachain/lib/signal_schema.py +302 -64
  80. datachain/lib/text.py +8 -7
  81. datachain/lib/udf.py +103 -63
  82. datachain/lib/udf_signature.py +59 -34
  83. datachain/lib/utils.py +20 -0
  84. datachain/lib/video.py +3 -4
  85. datachain/lib/webdataset.py +31 -36
  86. datachain/lib/webdataset_laion.py +15 -16
  87. datachain/listing.py +12 -5
  88. datachain/model/bbox.py +3 -1
  89. datachain/namespace.py +22 -3
  90. datachain/node.py +6 -6
  91. datachain/nodes_thread_pool.py +0 -1
  92. datachain/plugins.py +24 -0
  93. datachain/project.py +4 -4
  94. datachain/query/batch.py +10 -12
  95. datachain/query/dataset.py +376 -194
  96. datachain/query/dispatch.py +112 -84
  97. datachain/query/metrics.py +3 -4
  98. datachain/query/params.py +2 -3
  99. datachain/query/queue.py +2 -1
  100. datachain/query/schema.py +7 -6
  101. datachain/query/session.py +190 -33
  102. datachain/query/udf.py +9 -6
  103. datachain/remote/studio.py +90 -53
  104. datachain/script_meta.py +12 -12
  105. datachain/sql/sqlite/base.py +37 -25
  106. datachain/sql/sqlite/types.py +1 -1
  107. datachain/sql/types.py +36 -5
  108. datachain/studio.py +49 -40
  109. datachain/toolkit/split.py +31 -10
  110. datachain/utils.py +39 -48
  111. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/METADATA +26 -38
  112. datachain-0.39.0.dist-info/RECORD +173 -0
  113. datachain/cli/commands/query.py +0 -54
  114. datachain/query/utils.py +0 -36
  115. datachain-0.30.5.dist-info/RECORD +0 -168
  116. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/WHEEL +0 -0
  117. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
  118. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
  119. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/json.py ADDED
@@ -0,0 +1,138 @@
+ """DataChain JSON utilities.
+
+ This module wraps :mod:`ujson` so we can guarantee consistent handling
+ of values that the encoder does not support out of the box (for example
+ ``datetime`` objects or ``bytes``).
+ All code inside DataChain should import this module instead of using
+ :mod:`ujson` directly.
+ """
+
+ import datetime as _dt
+ import json as _json
+ import uuid as _uuid
+ from collections.abc import Callable
+ from typing import Any
+
+ import ujson as _ujson
+
+ __all__ = [
+     "JSONDecodeError",
+     "dump",
+     "dumps",
+     "load",
+     "loads",
+ ]
+
+ JSONDecodeError = (_ujson.JSONDecodeError, _json.JSONDecodeError)
+
+ _SENTINEL = object()
+ _Default = Callable[[Any], Any]
+ DEFAULT_PREVIEW_BYTES = 1024
+
+
+ # To make it look like Pydantic's ISO format with 'Z' for UTC.
+ # It is minor but nice to have consistency.
+ def _format_datetime(value: _dt.datetime) -> str:
+     iso = value.isoformat()
+
+     offset = value.utcoffset()
+     if value.tzinfo is None or offset is None:
+         return iso
+
+     if offset == _dt.timedelta(0) and iso.endswith(("+00:00", "-00:00")):
+         return iso[:-6] + "Z"
+
+     return iso
+
+
+ def _format_time(value: _dt.time) -> str:
+     iso = value.isoformat()
+
+     offset = value.utcoffset()
+     if value.tzinfo is None or offset is None:
+         return iso
+
+     if offset == _dt.timedelta(0) and iso.endswith(("+00:00", "-00:00")):
+         return iso[:-6] + "Z"
+
+     return iso
+
+
+ def _coerce(value: Any, serialize_bytes: bool) -> Any:
+     """Return a JSON-serializable representation for supported extra types."""
+
+     if isinstance(value, _dt.datetime):
+         return _format_datetime(value)
+     if isinstance(value, _dt.date):
+         return value.isoformat()
+     if isinstance(value, _dt.time):
+         return _format_time(value)
+     if isinstance(value, _uuid.UUID):
+         return str(value)
+     if serialize_bytes and isinstance(value, (bytes, bytearray)):
+         return list(bytes(value)[:DEFAULT_PREVIEW_BYTES])
+     return _SENTINEL
+
+
+ def _base_default(value: Any, serialize_bytes: bool) -> Any:
+     converted = _coerce(value, serialize_bytes)
+     if converted is not _SENTINEL:
+         return converted
+     raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")
+
+
+ def _build_default(user_default: _Default | None, serialize_bytes: bool) -> _Default:
+     if user_default is None:
+         return lambda value: _base_default(value, serialize_bytes)
+
+     def combined(value: Any) -> Any:
+         converted = _coerce(value, serialize_bytes)
+         if converted is not _SENTINEL:
+             return converted
+         return user_default(value)
+
+     return combined
+
+
+ def dumps(
+     obj: Any,
+     *,
+     default: _Default | None = None,
+     serialize_bytes: bool = False,
+     **kwargs: Any,
+ ) -> str:
+     """Serialize *obj* to a JSON-formatted ``str``."""
+
+     if serialize_bytes:
+         return _json.dumps(obj, default=_build_default(default, True), **kwargs)
+
+     return _ujson.dumps(obj, default=_build_default(default, False), **kwargs)
+
+
+ def dump(
+     obj: Any,
+     fp,
+     *,
+     default: _Default | None = None,
+     serialize_bytes: bool = False,
+     **kwargs: Any,
+ ) -> None:
+     """Serialize *obj* as a JSON formatted stream to *fp*."""
+
+     if serialize_bytes:
+         _json.dump(obj, fp, default=_build_default(default, True), **kwargs)
+         return
+
+     _ujson.dump(obj, fp, default=_build_default(default, False), **kwargs)
+
+
+ def loads(s: str | bytes | bytearray, **kwargs: Any) -> Any:
+     """Deserialize *s* to a Python object."""
+
+     return _ujson.loads(s, **kwargs)
+
+
+ def load(fp, **kwargs: Any) -> Any:
+     """Deserialize JSON content from *fp* to a Python object."""
+
+     return loads(fp.read(), **kwargs)
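Editor's note: a minimal usage sketch of the new module (an illustration based only on the code above, not part of the wheel). The `serialize_bytes` path presumably goes through stdlib `json` because ujson handles bytes itself rather than routing them through `default`:

```python
import datetime
import uuid

from datachain import json

# Timezone-aware UTC datetimes get a trailing "Z", matching Pydantic's ISO style.
ts = datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc)
print(json.dumps({"ts": ts}))  # e.g. {"ts":"2024-01-01T00:00:00Z"}

# UUIDs are coerced to strings via the shared _coerce() helper.
print(json.dumps({"id": uuid.uuid4()}))

# bytes raise TypeError by default; with serialize_bytes=True the first
# DEFAULT_PREVIEW_BYTES (1024) bytes are emitted as a list of ints.
print(json.dumps({"blob": b"\x00\x01"}, serialize_bytes=True))  # e.g. {"blob": [0, 1]}
```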
datachain/lib/arrow.py CHANGED
@@ -1,13 +1,13 @@
  from collections.abc import Sequence
  from itertools import islice
- from typing import TYPE_CHECKING, Any, Optional
+ from typing import TYPE_CHECKING, Any

  import pyarrow as pa
- import ujson as json
  from pyarrow._csv import ParseOptions
  from pyarrow.dataset import CsvFileFormat, dataset
  from tqdm.auto import tqdm

+ from datachain import json
  from datachain.fs.reference import ReferenceFileSystem
  from datachain.lib.data_model import dict_to_data_model
  from datachain.lib.file import ArrowRow, File
@@ -44,10 +44,10 @@ class ArrowGenerator(Generator):

      def __init__(
          self,
-         input_schema: Optional["pa.Schema"] = None,
-         output_schema: Optional[type["BaseModel"]] = None,
+         input_schema: pa.Schema | None = None,
+         output_schema: type["BaseModel"] | None = None,
          source: bool = True,
-         nrows: Optional[int] = None,
+         nrows: int | None = None,
          **kwargs,
      ):
          """
@@ -112,7 +112,7 @@ class ArrowGenerator(Generator):
          record: dict[str, Any],
          file: File,
          index: int,
-         hf_schema: Optional[tuple["Features", dict[str, "DataType"]]],
+         hf_schema: tuple["Features", dict[str, "DataType"]] | None,
          use_datachain_schema: bool,
      ):
          if use_datachain_schema and self.output_schema:
@@ -141,7 +141,7 @@ class ArrowGenerator(Generator):
      def _process_non_datachain_record(
          self,
          record: dict[str, Any],
-         hf_schema: Optional[tuple["Features", dict[str, "DataType"]]],
+         hf_schema: tuple["Features", dict[str, "DataType"]] | None,
      ):
          vals = list(record.values())
          if not self.output_schema:
@@ -149,7 +149,9 @@ class ArrowGenerator(Generator):

          fields = self.output_schema.model_fields
          vals_dict = {}
-         for i, ((field, field_info), val) in enumerate(zip(fields.items(), vals)):
+         for i, ((field, field_info), val) in enumerate(
+             zip(fields.items(), vals, strict=False)
+         ):
              anno = field_info.annotation
              if hf_schema:
                  from datachain.lib.hf import convert_feature
@@ -180,7 +182,7 @@ def infer_schema(chain: "DataChain", **kwargs) -> pa.Schema:


  def schema_to_output(
-     schema: pa.Schema, col_names: Optional[Sequence[str]] = None
+     schema: pa.Schema, col_names: Sequence[str] | None = None
  ) -> tuple[dict[str, type], list[str]]:
      """
      Generate UDF output schema from pyarrow schema.
@@ -205,14 +207,15 @@ def schema_to_output(
      hf_schema = _get_hf_schema(schema)
      if hf_schema:
          return {
-             column: hf_type for hf_type, column in zip(hf_schema[1].values(), col_names)
+             column: hf_type
+             for hf_type, column in zip(hf_schema[1].values(), col_names, strict=False)
          }, list(normalized_col_dict.values())

      output = {}
-     for field, column in zip(schema, col_names):
+     for field, column in zip(schema, col_names, strict=False):
          dtype = arrow_type_mapper(field.type, column)
          if field.nullable and not ModelStore.is_pydantic(dtype):
-             dtype = Optional[dtype]  # type: ignore[assignment]
+             dtype = dtype | None  # type: ignore[assignment]
          output[column] = dtype

      return output, list(normalized_col_dict.values())
@@ -243,7 +246,7 @@ def arrow_type_mapper(col_type: pa.DataType, column: str = "") -> type:  # noqa:
          for field in col_type:
              dtype = arrow_type_mapper(field.type, field.name)
              if field.nullable and not ModelStore.is_pydantic(dtype):
-                 dtype = Optional[dtype]  # type: ignore[assignment]
+                 dtype = dtype | None  # type: ignore[assignment]
              type_dict[field.name] = dtype
          return dict_to_data_model(f"ArrowDataModel_{column}", type_dict)
      if pa.types.is_map(col_type):
@@ -257,7 +260,7 @@ def arrow_type_mapper(col_type: pa.DataType, column: str = "") -> type:  # noqa:

  def _get_hf_schema(
      schema: "pa.Schema",
- ) -> Optional[tuple["Features", dict[str, "DataType"]]]:
+ ) -> tuple["Features", dict[str, "DataType"]] | None:
      if schema.metadata and b"huggingface" in schema.metadata:
          from datachain.lib.hf import get_output_schema, schema_from_arrow
@@ -266,7 +269,7 @@ def _get_hf_schema(
      return None


- def _get_datachain_schema(schema: "pa.Schema") -> Optional[SignalSchema]:
+ def _get_datachain_schema(schema: "pa.Schema") -> SignalSchema | None:
      """Return a restored SignalSchema from parquet metadata, if any is found."""
      if schema.metadata and DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY in schema.metadata:
          serialized_signal_schema = json.loads(
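Editor's note: two mechanical patterns dominate this file and recur across the release: `typing.Optional` is replaced with PEP 604 unions (which also work on runtime-computed types, as in `dtype | None` above), and every `zip()` gains an explicit `strict` flag. A quick illustrative check, not from the package:

```python
from typing import Optional

# PEP 604 unions compare equal to the typing.Optional spelling at runtime
# (Python 3.10+), so annotations built dynamically from dtypes keep working.
dtype = int  # stands in for a type produced by arrow_type_mapper()
assert (dtype | None) == Optional[int]

# zip(..., strict=False) preserves the old behavior of stopping at the
# shortest iterable, while making the choice explicit for linters (B905).
assert list(zip("ab", [1, 2, 3], strict=False)) == [("a", 1), ("b", 2)]
```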
datachain/lib/audio.py CHANGED
@@ -1,5 +1,6 @@
  import posixpath
- from typing import TYPE_CHECKING, Optional, Union
+ import re
+ from typing import TYPE_CHECKING

  from datachain.lib.file import FileError

@@ -9,7 +10,7 @@ if TYPE_CHECKING:
      from datachain.lib.file import Audio, AudioFile, File

  try:
-     import torchaudio
+     import soundfile as sf
  except ImportError as exc:
      raise ImportError(
          "Missing dependencies for processing audio.\n"
@@ -18,7 +19,7 @@ except ImportError as exc:
      ) from exc


- def audio_info(file: "Union[File, AudioFile]") -> "Audio":
+ def audio_info(file: "File | AudioFile") -> "Audio":
      """Extract metadata like sample rate, channels, duration, and format."""
      from datachain.lib.file import Audio

@@ -26,18 +27,25 @@ def audio_info(file: "Union[File, AudioFile]") -> "Audio":

      try:
          with file.open() as f:
-             info = torchaudio.info(f)
+             info = sf.info(f)
+
+         sample_rate = int(info.samplerate)
+         channels = int(info.channels)
+         frames = int(info.frames)
+         duration = float(info.duration)

-         sample_rate = int(info.sample_rate)
-         channels = int(info.num_channels)
-         frames = int(info.num_frames)
-         duration = float(frames / sample_rate) if sample_rate > 0 else 0.0
+         # soundfile provides format and subtype
+         if info.format:
+             format_name = info.format.lower()
+         else:
+             format_name = file.get_file_ext().lower()

-         codec_name = getattr(info, "encoding", "")
-         file_ext = file.get_file_ext().lower()
-         format_name = _encoding_to_format(codec_name, file_ext)
+         if not format_name:
+             format_name = "unknown"
+         codec_name = info.subtype if info.subtype else ""

-         bits_per_sample = getattr(info, "bits_per_sample", 0)
+         # Calculate bit rate from subtype
+         bits_per_sample = _get_bits_per_sample(info.subtype)
          bit_rate = (
              bits_per_sample * sample_rate * channels if bits_per_sample > 0 else -1
          )
@@ -58,48 +66,43 @@ def audio_info(file: "Union[File, AudioFile]") -> "Audio":
      )


- def _encoding_to_format(encoding: str, file_ext: str) -> str:
+ def _get_bits_per_sample(subtype: str) -> int:
      """
-     Map torchaudio encoding to a format name.
+     Map soundfile subtype to bits per sample.

      Args:
-         encoding: The encoding string from torchaudio.info()
-         file_ext: The file extension as a fallback
+         subtype: The subtype string from soundfile

      Returns:
-         Format name as a string
+         Bits per sample, or 0 if unknown
      """
-     # Direct mapping for formats that match exactly
-     encoding_map = {
-         "FLAC": "flac",
-         "MP3": "mp3",
-         "VORBIS": "ogg",
-         "AMR_WB": "amr",
-         "AMR_NB": "amr",
-         "OPUS": "opus",
-         "GSM": "gsm",
+     if not subtype:
+         return 0
+
+     # Common PCM and floating-point subtypes
+     pcm_bits = {
+         "PCM_16": 16,
+         "PCM_24": 24,
+         "PCM_32": 32,
+         "PCM_S8": 8,
+         "PCM_U8": 8,
+         "FLOAT": 32,
+         "DOUBLE": 64,
      }

-     if encoding in encoding_map:
-         return encoding_map[encoding]
+     if subtype in pcm_bits:
+         return pcm_bits[subtype]

-     # For PCM variants, use file extension to determine format
-     if encoding.startswith("PCM_"):
-         # Common PCM formats by extension
-         pcm_formats = {
-             "wav": "wav",
-             "aiff": "aiff",
-             "au": "au",
-             "raw": "raw",
-         }
-         return pcm_formats.get(file_ext, "wav")  # Default to wav for PCM
+     # Handle variants such as PCM_S16LE, PCM_F32LE, etc.
+     match = re.search(r"PCM_(?:[A-Z]*?)(\d+)", subtype)
+     if match:
+         return int(match.group(1))

-     # Fallback to file extension if encoding is unknown
-     return file_ext if file_ext else "unknown"
+     return 0


  def audio_to_np(
-     audio: "AudioFile", start: float = 0, duration: Optional[float] = None
+     audio: "AudioFile", start: float = 0, duration: float | None = None
  ) -> "tuple[ndarray, int]":
      """Load audio fragment as numpy array.
      Multi-channel audio is transposed to (samples, channels)."""
@@ -114,27 +117,27 @@ def audio_to_np(

      try:
          with audio.open() as f:
-             info = torchaudio.info(f)
-             sample_rate = info.sample_rate
+             info = sf.info(f)
+             sample_rate = info.samplerate

              frame_offset = int(start * sample_rate)
              num_frames = int(duration * sample_rate) if duration is not None else -1

              # Reset file pointer to the beginning
-             # This is important to ensure we read from the correct position later
              f.seek(0)

-             waveform, sr = torchaudio.load(
-                 f, frame_offset=frame_offset, num_frames=num_frames
+             # Read audio data with offset and frame count
+             audio_np, sr = sf.read(
+                 f,
+                 start=frame_offset,
+                 frames=num_frames,
+                 always_2d=False,
+                 dtype="float32",
              )

-             audio_np = waveform.numpy()
-
-             if audio_np.shape[0] > 1:
-                 audio_np = audio_np.T
-             else:
-                 audio_np = audio_np.squeeze()
-
+             # soundfile returns shape (frames,) for mono or
+             # (frames, channels) for multi-channel
+             # We keep this format as it matches expected output
          return audio_np, int(sr)
      except Exception as exc:
          raise FileError(
@@ -146,17 +149,15 @@ def audio_to_bytes(
      audio: "AudioFile",
      format: str = "wav",
      start: float = 0,
-     duration: Optional[float] = None,
+     duration: float | None = None,
  ) -> bytes:
      """Convert audio to bytes using soundfile.

      If duration is None, converts from start to end of file.
      If start is 0 and duration is None, converts entire file."""
-     y, sr = audio_to_np(audio, start, duration)
-
      import io

-     import soundfile as sf
+     y, sr = audio_to_np(audio, start, duration)

      buffer = io.BytesIO()
      sf.write(buffer, y, sr, format=format)
@@ -166,9 +167,9 @@
  def save_audio(
      audio: "AudioFile",
      output: str,
-     format: Optional[str] = None,
+     format: str | None = None,
      start: float = 0,
-     end: Optional[float] = None,
+     end: float | None = None,
  ) -> "AudioFile":
      """Save audio file or extract fragment to specified format.

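Editor's note: the torchaudio → soundfile migration above replaces the encoding-to-format mapping with a bit-depth lookup. Below is a self-contained restatement of the new helper's logic (runnable without soundfile installed; it mirrors `_get_bits_per_sample` from the diff):

```python
import re


def bits_per_sample(subtype: str) -> int:
    """Mirror of _get_bits_per_sample from the diff above."""
    table = {
        "PCM_16": 16, "PCM_24": 24, "PCM_32": 32,
        "PCM_S8": 8, "PCM_U8": 8, "FLOAT": 32, "DOUBLE": 64,
    }
    if not subtype:
        return 0
    if subtype in table:
        return table[subtype]
    # Fallback regex for variants such as PCM_S16LE or PCM_F32LE.
    match = re.search(r"PCM_(?:[A-Z]*?)(\d+)", subtype)
    return int(match.group(1)) if match else 0


assert bits_per_sample("PCM_24") == 24
assert bits_per_sample("PCM_S16LE") == 16  # via the regex fallback
assert bits_per_sample("VORBIS") == 0      # unknown depth, so bit_rate becomes -1
```

Note the shape handling in `audio_to_np`: soundfile already returns `(frames, channels)`, so the manual transpose of torchaudio's `(channels, frames)` output is gone while the documented output shape stays the same.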
datachain/lib/clip.py CHANGED
@@ -1,5 +1,6 @@
  import inspect
- from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union
+ from collections.abc import Callable
+ from typing import TYPE_CHECKING, Any, Literal, Union

  import torch
  from transformers.modeling_utils import PreTrainedModel
@@ -32,28 +33,28 @@ def _get_encoder(model: Any, type: Literal["image", "text"]) -> Callable:


  def clip_similarity_scores(
-     images: Union[None, "Image.Image", list["Image.Image"]],
-     text: Union[None, str, list[str]],
+     images: Union["Image.Image", list["Image.Image"]] | None,
+     text: str | list[str] | None,
      model: Any,
      preprocess: Callable,
      tokenizer: Callable,
      prob: bool = False,
      image_to_text: bool = True,
-     device: Optional[Union[str, torch.device]] = None,
+     device: str | torch.device | None = None,
  ) -> list[list[float]]:
      """
      Calculate CLIP similarity scores between one or more images and/or text.

      Parameters:
-         images : Images to use as inputs.
-         text : Text to use as inputs.
-         model : Model from clip or open_clip packages.
-         preprocess : Image preprocessor to apply.
-         tokenizer : Text tokenizer.
-         prob : Compute softmax probabilities.
-         image_to_text : Whether to compute for image-to-text or text-to-image. Ignored
-             if only one of images or text provided.
-         device : Device to use. Defaults is None - use model's device.
+         images: Images to use as inputs.
+         text: Text to use as inputs.
+         model: Model from clip or open_clip packages.
+         preprocess: Image preprocessor to apply.
+         tokenizer: Text tokenizer.
+         prob: Compute softmax probabilities.
+         image_to_text: Whether to compute for image-to-text or text-to-image. Ignored
+             if only one of the images or text provided.
+         device: Device to use. Default is None - use model's device.


      Example:
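Editor's note: the hunk is truncated before the docstring's own example. As a stand-in, here is a hedged usage sketch based only on the signature above; the open_clip model names and the local image path are assumptions, not taken from the diff:

```python
import open_clip
from PIL import Image

from datachain.lib.clip import clip_similarity_scores

# Assumed open_clip setup; any clip/open_clip model+preprocess+tokenizer works.
model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="laion2b_s34b_b79k"
)
tokenizer = open_clip.get_tokenizer("ViT-B-32")

scores = clip_similarity_scores(
    images=Image.open("cat.jpg"),  # hypothetical local file
    text=["a cat", "a dog"],
    model=model,
    preprocess=preprocess,
    tokenizer=tokenizer,
    prob=True,           # softmax over the candidate texts
    image_to_text=True,  # one row per image, scored against each text
)
# e.g. [[0.98, 0.02]]
```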
datachain/lib/convert/python_to_sql.py CHANGED
@@ -1,6 +1,7 @@
  import inspect
  from datetime import datetime
  from enum import Enum
+ from types import UnionType
  from typing import Annotated, Literal, Union, get_args, get_origin

  from pydantic import BaseModel
@@ -69,11 +70,12 @@ def python_to_sql(typ):  # noqa: PLR0911
      if inspect.isclass(orig) and issubclass(dict, orig):
          return JSON

-     if orig == Union:
+     if orig in (Union, UnionType):
          if len(args) == 2 and (type(None) in args):
-             return python_to_sql(args[0])
+             non_none_arg = args[0] if args[0] is not type(None) else args[1]
+             return python_to_sql(non_none_arg)

-         if _is_union_str_literal(orig, args):
+         if all(arg is str or get_origin(arg) in (Literal, LiteralEx) for arg in args):
              return String

          if _is_json_inside_union(orig, args):
@@ -95,7 +97,7 @@ def list_of_args_to_type(args) -> SQLType:


  def _is_json_inside_union(orig, args) -> bool:
-     if orig == Union and len(args) >= 2:
+     if orig in (Union, UnionType) and len(args) >= 2:
          # List in JSON: Union[dict, list[dict]]
          args_no_nones = [arg for arg in args if arg != type(None)]  # noqa: E721
          if len(args_no_nones) == 2:
@@ -109,9 +111,3 @@ def _is_json_inside_union(orig, args) -> bool:
      if any(inspect.isclass(arg) and issubclass(arg, BaseModel) for arg in args):
          return True
      return False
-
-
- def _is_union_str_literal(orig, args) -> bool:
-     if orig != Union:
-         return False
-     return all(arg is str or get_origin(arg) in (Literal, LiteralEx) for arg in args)
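Editor's note: these changes close two related gaps. PEP 604 unions (`types.UnionType`, produced by `int | None`) were not recognized as unions at all, and Optional unwrapping blindly took `args[0]`, which is `NoneType` when the union is written None-first. A small sketch of the corrected unwrapping logic (illustrative only, not the package code):

```python
from types import UnionType
from typing import Optional, Union, get_args, get_origin


def unwrap_optional(typ):
    """Return the non-None member of a two-member Optional-style union."""
    orig = get_origin(typ)
    args = get_args(typ)
    if orig in (Union, UnionType) and len(args) == 2 and type(None) in args:
        # The old code returned args[0] unconditionally, which is wrong
        # when NoneType happens to come first.
        return args[0] if args[0] is not type(None) else args[1]
    return typ


assert unwrap_optional(Optional[int]) is int
assert unwrap_optional(int | None) is int        # PEP 604 spelling now handled
assert unwrap_optional(Union[None, str]) is str  # None-first no longer breaks
```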