datachain 0.34.6__py3-none-any.whl → 0.35.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of datachain might be problematic. Click here for more details.

Files changed (105)
  1. datachain/asyn.py +11 -12
  2. datachain/cache.py +5 -5
  3. datachain/catalog/catalog.py +75 -83
  4. datachain/catalog/loader.py +3 -3
  5. datachain/checkpoint.py +1 -2
  6. datachain/cli/__init__.py +2 -4
  7. datachain/cli/commands/datasets.py +13 -13
  8. datachain/cli/commands/ls.py +4 -4
  9. datachain/cli/commands/query.py +3 -3
  10. datachain/cli/commands/show.py +2 -2
  11. datachain/cli/parser/job.py +1 -1
  12. datachain/cli/parser/utils.py +1 -2
  13. datachain/cli/utils.py +1 -2
  14. datachain/client/azure.py +2 -2
  15. datachain/client/fsspec.py +11 -21
  16. datachain/client/gcs.py +3 -3
  17. datachain/client/http.py +4 -4
  18. datachain/client/local.py +4 -4
  19. datachain/client/s3.py +3 -3
  20. datachain/config.py +4 -8
  21. datachain/data_storage/db_engine.py +5 -5
  22. datachain/data_storage/metastore.py +107 -107
  23. datachain/data_storage/schema.py +18 -24
  24. datachain/data_storage/sqlite.py +21 -28
  25. datachain/data_storage/warehouse.py +13 -13
  26. datachain/dataset.py +64 -70
  27. datachain/delta.py +21 -18
  28. datachain/diff/__init__.py +13 -13
  29. datachain/func/aggregate.py +9 -11
  30. datachain/func/array.py +12 -12
  31. datachain/func/base.py +7 -4
  32. datachain/func/conditional.py +9 -13
  33. datachain/func/func.py +45 -42
  34. datachain/func/numeric.py +5 -7
  35. datachain/func/string.py +2 -2
  36. datachain/hash_utils.py +54 -81
  37. datachain/job.py +8 -8
  38. datachain/lib/arrow.py +17 -14
  39. datachain/lib/audio.py +6 -6
  40. datachain/lib/clip.py +5 -4
  41. datachain/lib/convert/python_to_sql.py +4 -22
  42. datachain/lib/convert/values_to_tuples.py +4 -9
  43. datachain/lib/data_model.py +20 -19
  44. datachain/lib/dataset_info.py +6 -6
  45. datachain/lib/dc/csv.py +10 -10
  46. datachain/lib/dc/database.py +28 -29
  47. datachain/lib/dc/datachain.py +98 -97
  48. datachain/lib/dc/datasets.py +22 -22
  49. datachain/lib/dc/hf.py +4 -4
  50. datachain/lib/dc/json.py +9 -10
  51. datachain/lib/dc/listings.py +5 -8
  52. datachain/lib/dc/pandas.py +3 -6
  53. datachain/lib/dc/parquet.py +5 -5
  54. datachain/lib/dc/records.py +5 -5
  55. datachain/lib/dc/storage.py +12 -12
  56. datachain/lib/dc/storage_pattern.py +2 -2
  57. datachain/lib/dc/utils.py +11 -14
  58. datachain/lib/dc/values.py +3 -6
  59. datachain/lib/file.py +32 -28
  60. datachain/lib/hf.py +7 -5
  61. datachain/lib/image.py +13 -13
  62. datachain/lib/listing.py +5 -5
  63. datachain/lib/listing_info.py +1 -2
  64. datachain/lib/meta_formats.py +1 -2
  65. datachain/lib/model_store.py +3 -3
  66. datachain/lib/namespaces.py +4 -6
  67. datachain/lib/projects.py +5 -9
  68. datachain/lib/pytorch.py +10 -10
  69. datachain/lib/settings.py +23 -23
  70. datachain/lib/signal_schema.py +52 -44
  71. datachain/lib/text.py +8 -7
  72. datachain/lib/udf.py +25 -17
  73. datachain/lib/udf_signature.py +11 -11
  74. datachain/lib/video.py +3 -4
  75. datachain/lib/webdataset.py +30 -35
  76. datachain/lib/webdataset_laion.py +15 -16
  77. datachain/listing.py +4 -4
  78. datachain/model/bbox.py +3 -1
  79. datachain/namespace.py +4 -4
  80. datachain/node.py +6 -6
  81. datachain/nodes_thread_pool.py +0 -1
  82. datachain/plugins.py +1 -7
  83. datachain/project.py +4 -4
  84. datachain/query/batch.py +7 -8
  85. datachain/query/dataset.py +80 -87
  86. datachain/query/dispatch.py +7 -7
  87. datachain/query/metrics.py +3 -4
  88. datachain/query/params.py +2 -3
  89. datachain/query/schema.py +7 -6
  90. datachain/query/session.py +7 -7
  91. datachain/query/udf.py +8 -7
  92. datachain/query/utils.py +3 -5
  93. datachain/remote/studio.py +33 -39
  94. datachain/script_meta.py +12 -12
  95. datachain/sql/sqlite/base.py +6 -9
  96. datachain/studio.py +30 -30
  97. datachain/toolkit/split.py +1 -2
  98. datachain/utils.py +21 -21
  99. {datachain-0.34.6.dist-info → datachain-0.35.0.dist-info}/METADATA +2 -3
  100. datachain-0.35.0.dist-info/RECORD +173 -0
  101. datachain-0.34.6.dist-info/RECORD +0 -173
  102. {datachain-0.34.6.dist-info → datachain-0.35.0.dist-info}/WHEEL +0 -0
  103. {datachain-0.34.6.dist-info → datachain-0.35.0.dist-info}/entry_points.txt +0 -0
  104. {datachain-0.34.6.dist-info → datachain-0.35.0.dist-info}/licenses/LICENSE +0 -0
  105. {datachain-0.34.6.dist-info → datachain-0.35.0.dist-info}/top_level.txt +0 -0
datachain/func/numeric.py CHANGED
@@ -1,12 +1,10 @@
1
- from typing import Union
2
-
3
1
  from datachain.query.schema import Column
4
2
  from datachain.sql.functions import numeric
5
3
 
6
4
  from .func import Func
7
5
 
8
6
 
9
- def bit_and(*args: Union[str, Column, Func, int]) -> Func:
7
+ def bit_and(*args: str | Column | Func | int) -> Func:
10
8
  """
11
9
  Returns a function that computes the bitwise AND operation between two values.
12
10
 
@@ -51,7 +49,7 @@ def bit_and(*args: Union[str, Column, Func, int]) -> Func:
51
49
  )
52
50
 
53
51
 
54
- def bit_or(*args: Union[str, Column, Func, int]) -> Func:
52
+ def bit_or(*args: str | Column | Func | int) -> Func:
55
53
  """
56
54
  Returns a function that computes the bitwise OR operation between two values.
57
55
 
@@ -96,7 +94,7 @@ def bit_or(*args: Union[str, Column, Func, int]) -> Func:
96
94
  )
97
95
 
98
96
 
99
- def bit_xor(*args: Union[str, Column, Func, int]) -> Func:
97
+ def bit_xor(*args: str | Column | Func | int) -> Func:
100
98
  """
101
99
  Returns a function that computes the bitwise XOR operation between two values.
102
100
 
@@ -141,7 +139,7 @@ def bit_xor(*args: Union[str, Column, Func, int]) -> Func:
141
139
  )
142
140
 
143
141
 
144
- def int_hash_64(col: Union[str, Column, Func, int]) -> Func:
142
+ def int_hash_64(col: str | Column | Func | int) -> Func:
145
143
  """
146
144
  Returns a function that computes the 64-bit hash of an integer.
147
145
 
@@ -177,7 +175,7 @@ def int_hash_64(col: Union[str, Column, Func, int]) -> Func:
177
175
  )
178
176
 
179
177
 
180
- def bit_hamming_distance(*args: Union[str, Column, Func, int]) -> Func:
178
+ def bit_hamming_distance(*args: str | Column | Func | int) -> Func:
181
179
  """
182
180
  Returns a function that computes the Hamming distance between two integers.
183
181
 
datachain/func/string.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Optional, get_origin
1
+ from typing import get_origin
2
2
 
3
3
  from sqlalchemy import literal
4
4
 
@@ -44,7 +44,7 @@ def length(col: ColT) -> Func:
44
44
  return Func("length", inner=string.length, cols=[col], result_type=int)
45
45
 
46
46
 
47
- def split(col: ColT, sep: str, limit: Optional[int] = None) -> Func:
47
+ def split(col: ColT, sep: str, limit: int | None = None) -> Func:
48
48
  """
49
49
  Takes a column and split character and returns an array of the parts.
50
50
 
datachain/hash_utils.py CHANGED
@@ -3,101 +3,74 @@ import inspect
3
3
  import json
4
4
  import textwrap
5
5
  from collections.abc import Sequence
6
- from typing import TypeVar, Union
7
-
8
- from sqlalchemy.sql.elements import (
9
- BinaryExpression,
10
- BindParameter,
11
- ColumnElement,
12
- Label,
13
- Over,
14
- UnaryExpression,
15
- )
16
- from sqlalchemy.sql.functions import Function
17
-
18
- T = TypeVar("T", bound=ColumnElement)
19
- ColumnLike = Union[str, T]
6
+ from typing import TypeAlias, TypeVar
20
7
 
8
+ from sqlalchemy.sql.elements import ClauseElement, ColumnElement
21
9
 
22
- def serialize_column_element(expr: Union[str, ColumnElement]) -> dict: # noqa: PLR0911
10
+ T = TypeVar("T", bound=ColumnElement)
11
+ ColumnLike: TypeAlias = str | T
12
+
13
+
14
+ def _serialize_value(val): # noqa: PLR0911
15
+ """Helper to serialize arbitrary values recursively."""
16
+ if val is None:
17
+ return None
18
+ if isinstance(val, (str, int, float, bool)):
19
+ return val
20
+ if isinstance(val, ClauseElement):
21
+ return serialize_column_element(val)
22
+ if isinstance(val, dict):
23
+ # Sort dict keys for deterministic serialization
24
+ return {k: _serialize_value(v) for k, v in sorted(val.items())}
25
+ if isinstance(val, (list, tuple)):
26
+ return [_serialize_value(v) for v in val]
27
+ if callable(val):
28
+ return val.__name__ if hasattr(val, "__name__") else str(val)
29
+ return str(val)
30
+
31
+
32
+ def serialize_column_element(expr: str | ColumnElement) -> dict:
23
33
  """
24
34
  Recursively serialize a SQLAlchemy ColumnElement into a deterministic structure.
35
+ Uses SQLAlchemy's _traverse_internals to automatically handle all expression types.
25
36
  """
37
+ from sqlalchemy.sql.elements import BindParameter
26
38
 
27
- # Binary operations: col > 5, col1 + col2, etc.
28
- if isinstance(expr, BinaryExpression):
29
- op = (
30
- expr.operator.__name__
31
- if hasattr(expr.operator, "__name__")
32
- else str(expr.operator)
33
- )
34
- return {
35
- "type": "binary",
36
- "op": op,
37
- "left": serialize_column_element(expr.left),
38
- "right": serialize_column_element(expr.right),
39
- }
40
-
41
- # Unary operations: -col, NOT col, etc.
42
- if isinstance(expr, UnaryExpression):
43
- op = (
44
- expr.operator.__name__
45
- if expr.operator is not None and hasattr(expr.operator, "__name__")
46
- else str(expr.operator)
47
- )
48
-
49
- return {
50
- "type": "unary",
51
- "op": op,
52
- "element": serialize_column_element(expr.element), # type: ignore[arg-type]
53
- }
54
-
55
- # Function calls: func.lower(col), func.count(col), etc.
56
- if isinstance(expr, Function):
57
- return {
58
- "type": "function",
59
- "name": expr.name,
60
- "clauses": [serialize_column_element(c) for c in expr.clauses],
61
- }
62
-
63
- # Window functions: func.row_number().over(partition_by=..., order_by=...)
64
- if isinstance(expr, Over):
65
- return {
66
- "type": "window",
67
- "function": serialize_column_element(expr.element),
68
- "partition_by": [
69
- serialize_column_element(p) for p in getattr(expr, "partition_by", [])
70
- ],
71
- "order_by": [
72
- serialize_column_element(o) for o in getattr(expr, "order_by", [])
73
- ],
74
- }
75
-
76
- # Labeled expressions: col.label("alias")
77
- if isinstance(expr, Label):
78
- return {
79
- "type": "label",
80
- "name": expr.name,
81
- "element": serialize_column_element(expr.element),
82
- }
83
-
84
- # Bound values (constants)
39
+ # Special case: BindParameter has non-deterministic 'key' attribute, only use value
85
40
  if isinstance(expr, BindParameter):
86
- return {"type": "bind", "value": expr.value}
87
-
88
- # Plain columns
89
- if hasattr(expr, "name"):
90
- return {"type": "column", "name": expr.name}
91
-
92
- # Fallback: stringify unknown nodes
41
+ return {"type": "bind", "value": _serialize_value(expr.value)}
42
+
43
+ # Generic handling for all ClauseElement types using SQLAlchemy's internals
44
+ if isinstance(expr, ClauseElement):
45
+ # All standard SQLAlchemy types have _traverse_internals
46
+ if hasattr(expr, "_traverse_internals"):
47
+ result = {"type": expr.__class__.__name__}
48
+ for attr_name, _ in expr._traverse_internals:
49
+ # Skip 'table' attribute - table names can be auto-generated/random
50
+ # and are not semantically important for hashing
51
+ if attr_name == "table":
52
+ continue
53
+ if hasattr(expr, attr_name):
54
+ val = getattr(expr, attr_name)
55
+ result[attr_name] = _serialize_value(val)
56
+ return result
57
+ # Rare case: custom user-defined ClauseElement without _traverse_internals
58
+ # We don't know its structure, so just stringify it
59
+ return {"type": expr.__class__.__name__, "repr": str(expr)}
60
+
61
+ # Absolute fallback: stringify completely unknown types
93
62
  return {"type": "other", "repr": str(expr)}
94
63
 
95
64
 
96
- def hash_column_elements(columns: Sequence[ColumnLike]) -> str:
65
+ def hash_column_elements(columns: ColumnLike | Sequence[ColumnLike]) -> str:
97
66
  """
98
67
  Hash a list of ColumnElements deterministically, dialect agnostic.
99
68
  Only accepts ordered iterables (like list or tuple).
100
69
  """
70
+ # Handle case where a single ColumnElement is passed instead of a sequence
71
+ if isinstance(columns, (ColumnElement, str)):
72
+ columns = (columns,)
73
+
101
74
  serialized = [serialize_column_element(c) for c in columns]
102
75
  json_str = json.dumps(serialized, sort_keys=True) # stable JSON
103
76
  return hashlib.sha256(json_str.encode("utf-8")).hexdigest()
datachain/job.py CHANGED
@@ -2,7 +2,7 @@ import json
2
2
  import uuid
3
3
  from dataclasses import dataclass
4
4
  from datetime import datetime
5
- from typing import Any, Optional, TypeVar, Union
5
+ from typing import Any, TypeVar
6
6
 
7
7
  J = TypeVar("J", bound="Job")
8
8
 
@@ -18,29 +18,29 @@ class Job:
18
18
  workers: int
19
19
  params: dict[str, str]
20
20
  metrics: dict[str, Any]
21
- finished_at: Optional[datetime] = None
22
- python_version: Optional[str] = None
21
+ finished_at: datetime | None = None
22
+ python_version: str | None = None
23
23
  error_message: str = ""
24
24
  error_stack: str = ""
25
- parent_job_id: Optional[str] = None
25
+ parent_job_id: str | None = None
26
26
 
27
27
  @classmethod
28
28
  def parse(
29
29
  cls,
30
- id: Union[str, uuid.UUID],
30
+ id: str | uuid.UUID,
31
31
  name: str,
32
32
  status: int,
33
33
  created_at: datetime,
34
- finished_at: Optional[datetime],
34
+ finished_at: datetime | None,
35
35
  query: str,
36
36
  query_type: int,
37
37
  workers: int,
38
- python_version: Optional[str],
38
+ python_version: str | None,
39
39
  error_message: str,
40
40
  error_stack: str,
41
41
  params: str,
42
42
  metrics: str,
43
- parent_job_id: Optional[str],
43
+ parent_job_id: str | None,
44
44
  ) -> "Job":
45
45
  return cls(
46
46
  str(id),
datachain/lib/arrow.py CHANGED
@@ -1,6 +1,6 @@
1
1
  from collections.abc import Sequence
2
2
  from itertools import islice
3
- from typing import TYPE_CHECKING, Any, Optional
3
+ from typing import TYPE_CHECKING, Any
4
4
 
5
5
  import pyarrow as pa
6
6
  import ujson as json
@@ -44,10 +44,10 @@ class ArrowGenerator(Generator):
44
44
 
45
45
  def __init__(
46
46
  self,
47
- input_schema: Optional["pa.Schema"] = None,
48
- output_schema: Optional[type["BaseModel"]] = None,
47
+ input_schema: pa.Schema | None = None,
48
+ output_schema: type["BaseModel"] | None = None,
49
49
  source: bool = True,
50
- nrows: Optional[int] = None,
50
+ nrows: int | None = None,
51
51
  **kwargs,
52
52
  ):
53
53
  """
@@ -112,7 +112,7 @@ class ArrowGenerator(Generator):
112
112
  record: dict[str, Any],
113
113
  file: File,
114
114
  index: int,
115
- hf_schema: Optional[tuple["Features", dict[str, "DataType"]]],
115
+ hf_schema: tuple["Features", dict[str, "DataType"]] | None,
116
116
  use_datachain_schema: bool,
117
117
  ):
118
118
  if use_datachain_schema and self.output_schema:
@@ -141,7 +141,7 @@ class ArrowGenerator(Generator):
141
141
  def _process_non_datachain_record(
142
142
  self,
143
143
  record: dict[str, Any],
144
- hf_schema: Optional[tuple["Features", dict[str, "DataType"]]],
144
+ hf_schema: tuple["Features", dict[str, "DataType"]] | None,
145
145
  ):
146
146
  vals = list(record.values())
147
147
  if not self.output_schema:
@@ -149,7 +149,9 @@ class ArrowGenerator(Generator):
149
149
 
150
150
  fields = self.output_schema.model_fields
151
151
  vals_dict = {}
152
- for i, ((field, field_info), val) in enumerate(zip(fields.items(), vals)):
152
+ for i, ((field, field_info), val) in enumerate(
153
+ zip(fields.items(), vals, strict=False)
154
+ ):
153
155
  anno = field_info.annotation
154
156
  if hf_schema:
155
157
  from datachain.lib.hf import convert_feature
@@ -180,7 +182,7 @@ def infer_schema(chain: "DataChain", **kwargs) -> pa.Schema:
180
182
 
181
183
 
182
184
  def schema_to_output(
183
- schema: pa.Schema, col_names: Optional[Sequence[str]] = None
185
+ schema: pa.Schema, col_names: Sequence[str] | None = None
184
186
  ) -> tuple[dict[str, type], list[str]]:
185
187
  """
186
188
  Generate UDF output schema from pyarrow schema.
@@ -205,14 +207,15 @@ def schema_to_output(
205
207
  hf_schema = _get_hf_schema(schema)
206
208
  if hf_schema:
207
209
  return {
208
- column: hf_type for hf_type, column in zip(hf_schema[1].values(), col_names)
210
+ column: hf_type
211
+ for hf_type, column in zip(hf_schema[1].values(), col_names, strict=False)
209
212
  }, list(normalized_col_dict.values())
210
213
 
211
214
  output = {}
212
- for field, column in zip(schema, col_names):
215
+ for field, column in zip(schema, col_names, strict=False):
213
216
  dtype = arrow_type_mapper(field.type, column)
214
217
  if field.nullable and not ModelStore.is_pydantic(dtype):
215
- dtype = Optional[dtype] # type: ignore[assignment]
218
+ dtype = dtype | None # type: ignore[assignment]
216
219
  output[column] = dtype
217
220
 
218
221
  return output, list(normalized_col_dict.values())
@@ -243,7 +246,7 @@ def arrow_type_mapper(col_type: pa.DataType, column: str = "") -> type: # noqa:
243
246
  for field in col_type:
244
247
  dtype = arrow_type_mapper(field.type, field.name)
245
248
  if field.nullable and not ModelStore.is_pydantic(dtype):
246
- dtype = Optional[dtype] # type: ignore[assignment]
249
+ dtype = dtype | None # type: ignore[assignment]
247
250
  type_dict[field.name] = dtype
248
251
  return dict_to_data_model(f"ArrowDataModel_{column}", type_dict)
249
252
  if pa.types.is_map(col_type):
@@ -257,7 +260,7 @@ def arrow_type_mapper(col_type: pa.DataType, column: str = "") -> type: # noqa:
257
260
 
258
261
  def _get_hf_schema(
259
262
  schema: "pa.Schema",
260
- ) -> Optional[tuple["Features", dict[str, "DataType"]]]:
263
+ ) -> tuple["Features", dict[str, "DataType"]] | None:
261
264
  if schema.metadata and b"huggingface" in schema.metadata:
262
265
  from datachain.lib.hf import get_output_schema, schema_from_arrow
263
266
 
@@ -266,7 +269,7 @@ def _get_hf_schema(
266
269
  return None
267
270
 
268
271
 
269
- def _get_datachain_schema(schema: "pa.Schema") -> Optional[SignalSchema]:
272
+ def _get_datachain_schema(schema: "pa.Schema") -> SignalSchema | None:
270
273
  """Return a restored SignalSchema from parquet metadata, if any is found."""
271
274
  if schema.metadata and DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY in schema.metadata:
272
275
  serialized_signal_schema = json.loads(
datachain/lib/audio.py CHANGED
@@ -1,5 +1,5 @@
1
1
  import posixpath
2
- from typing import TYPE_CHECKING, Optional, Union
2
+ from typing import TYPE_CHECKING
3
3
 
4
4
  from datachain.lib.file import FileError
5
5
 
@@ -18,7 +18,7 @@ except ImportError as exc:
18
18
  ) from exc
19
19
 
20
20
 
21
- def audio_info(file: "Union[File, AudioFile]") -> "Audio":
21
+ def audio_info(file: "File | AudioFile") -> "Audio":
22
22
  """Extract metadata like sample rate, channels, duration, and format."""
23
23
  from datachain.lib.file import Audio
24
24
 
@@ -99,7 +99,7 @@ def _encoding_to_format(encoding: str, file_ext: str) -> str:
99
99
 
100
100
 
101
101
  def audio_to_np(
102
- audio: "AudioFile", start: float = 0, duration: Optional[float] = None
102
+ audio: "AudioFile", start: float = 0, duration: float | None = None
103
103
  ) -> "tuple[ndarray, int]":
104
104
  """Load audio fragment as numpy array.
105
105
  Multi-channel audio is transposed to (samples, channels)."""
@@ -146,7 +146,7 @@ def audio_to_bytes(
146
146
  audio: "AudioFile",
147
147
  format: str = "wav",
148
148
  start: float = 0,
149
- duration: Optional[float] = None,
149
+ duration: float | None = None,
150
150
  ) -> bytes:
151
151
  """Convert audio to bytes using soundfile.
152
152
 
@@ -166,9 +166,9 @@ def audio_to_bytes(
166
166
  def save_audio(
167
167
  audio: "AudioFile",
168
168
  output: str,
169
- format: Optional[str] = None,
169
+ format: str | None = None,
170
170
  start: float = 0,
171
- end: Optional[float] = None,
171
+ end: float | None = None,
172
172
  ) -> "AudioFile":
173
173
  """Save audio file or extract fragment to specified format.
174
174
 
datachain/lib/clip.py CHANGED
@@ -1,5 +1,6 @@
1
1
  import inspect
2
- from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union
2
+ from collections.abc import Callable
3
+ from typing import TYPE_CHECKING, Any, Literal, Union
3
4
 
4
5
  import torch
5
6
  from transformers.modeling_utils import PreTrainedModel
@@ -32,14 +33,14 @@ def _get_encoder(model: Any, type: Literal["image", "text"]) -> Callable:
32
33
 
33
34
 
34
35
  def clip_similarity_scores(
35
- images: Union[None, "Image.Image", list["Image.Image"]],
36
- text: Union[None, str, list[str]],
36
+ images: Union["Image.Image", list["Image.Image"]] | None,
37
+ text: str | list[str] | None,
37
38
  model: Any,
38
39
  preprocess: Callable,
39
40
  tokenizer: Callable,
40
41
  prob: bool = False,
41
42
  image_to_text: bool = True,
42
- device: Optional[Union[str, torch.device]] = None,
43
+ device: str | torch.device | None = None,
43
44
  ) -> list[list[float]]:
44
45
  """
45
46
  Calculate CLIP similarity scores between one or more images and/or text.
datachain/lib/convert/python_to_sql.py CHANGED
@@ -1,14 +1,9 @@
1
1
  import inspect
2
- import sys
3
2
  from datetime import datetime
4
3
  from enum import Enum
4
+ from types import UnionType
5
5
  from typing import Annotated, Literal, Union, get_args, get_origin
6
6
 
7
- if sys.version_info >= (3, 10):
8
- from types import UnionType
9
- else:
10
- UnionType = None
11
-
12
7
  from pydantic import BaseModel
13
8
  from typing_extensions import Literal as LiteralEx
14
9
 
@@ -40,13 +35,6 @@ PYTHON_TO_SQL = {
40
35
  }
41
36
 
42
37
 
43
- def _is_union(orig) -> bool:
44
- if orig == Union:
45
- return True
46
- # some code is unreachab in python<3.10
47
- return UnionType is not None and orig is UnionType # type: ignore[unreachable]
48
-
49
-
50
38
  def python_to_sql(typ): # noqa: PLR0911
51
39
  if inspect.isclass(typ):
52
40
  if issubclass(typ, SQLType):
@@ -82,12 +70,12 @@ def python_to_sql(typ): # noqa: PLR0911
82
70
  if inspect.isclass(orig) and issubclass(dict, orig):
83
71
  return JSON
84
72
 
85
- if _is_union(orig):
73
+ if orig in (Union, UnionType):
86
74
  if len(args) == 2 and (type(None) in args):
87
75
  non_none_arg = args[0] if args[0] is not type(None) else args[1]
88
76
  return python_to_sql(non_none_arg)
89
77
 
90
- if _is_union_str_literal(orig, args):
78
+ if all(arg is str or get_origin(arg) in (Literal, LiteralEx) for arg in args):
91
79
  return String
92
80
 
93
81
  if _is_json_inside_union(orig, args):
@@ -109,7 +97,7 @@ def list_of_args_to_type(args) -> SQLType:
109
97
 
110
98
 
111
99
  def _is_json_inside_union(orig, args) -> bool:
112
- if _is_union(orig) and len(args) >= 2:
100
+ if orig in (Union, UnionType) and len(args) >= 2:
113
101
  # List in JSON: Union[dict, list[dict]]
114
102
  args_no_nones = [arg for arg in args if arg != type(None)] # noqa: E721
115
103
  if len(args_no_nones) == 2:
@@ -123,9 +111,3 @@ def _is_json_inside_union(orig, args) -> bool:
123
111
  if any(inspect.isclass(arg) and issubclass(arg, BaseModel) for arg in args):
124
112
  return True
125
113
  return False
126
-
127
-
128
- def _is_union_str_literal(orig, args) -> bool:
129
- if not _is_union(orig):
130
- return False
131
- return all(arg is str or get_origin(arg) in (Literal, LiteralEx) for arg in args)
datachain/lib/convert/values_to_tuples.py CHANGED
@@ -1,13 +1,8 @@
1
1
  import itertools
2
2
  from collections.abc import Sequence
3
- from typing import Any, Union
3
+ from typing import Any
4
4
 
5
- from datachain.lib.data_model import (
6
- DataType,
7
- DataTypeNames,
8
- DataValue,
9
- is_chain_type,
10
- )
5
+ from datachain.lib.data_model import DataType, DataTypeNames, DataValue, is_chain_type
11
6
  from datachain.lib.utils import DataChainParamsError
12
7
 
13
8
 
@@ -20,7 +15,7 @@ class ValuesToTupleError(DataChainParamsError):
20
15
 
21
16
  def values_to_tuples( # noqa: C901, PLR0912
22
17
  ds_name: str = "",
23
- output: Union[None, DataType, Sequence[str], dict[str, DataType]] = None,
18
+ output: DataType | Sequence[str] | dict[str, DataType] | None = None,
24
19
  **fr_map: Sequence[DataValue],
25
20
  ) -> tuple[Any, Any, Any]:
26
21
  if output:
@@ -111,7 +106,7 @@ def values_to_tuples( # noqa: C901, PLR0912
111
106
  if len(output) > 1: # type: ignore[arg-type]
112
107
  tuple_type = tuple(output_types)
113
108
  res_type = tuple[tuple_type] # type: ignore[valid-type]
114
- res_values: Sequence[Any] = list(zip(*fr_map.values()))
109
+ res_values: Sequence[Any] = list(zip(*fr_map.values(), strict=False))
115
110
  else:
116
111
  res_type = output_types[0] # type: ignore[misc]
117
112
  res_values = next(iter(fr_map.values()))
datachain/lib/data_model.py CHANGED
@@ -1,8 +1,9 @@
1
1
  import inspect
2
+ import types
2
3
  import uuid
3
4
  from collections.abc import Sequence
4
5
  from datetime import datetime
5
- from typing import ClassVar, Optional, Union, get_args, get_origin
6
+ from typing import ClassVar, Union, get_args, get_origin
6
7
 
7
8
  from pydantic import AliasChoices, BaseModel, Field, create_model
8
9
  from pydantic.fields import FieldInfo
@@ -10,19 +11,19 @@ from pydantic.fields import FieldInfo
10
11
  from datachain.lib.model_store import ModelStore
11
12
  from datachain.lib.utils import normalize_col_names
12
13
 
13
- StandardType = Union[
14
- type[int],
15
- type[str],
16
- type[float],
17
- type[bool],
18
- type[list],
19
- type[dict],
20
- type[bytes],
21
- type[datetime],
22
- ]
23
- DataType = Union[type[BaseModel], StandardType]
14
+ StandardType = (
15
+ type[int]
16
+ | type[str]
17
+ | type[float]
18
+ | type[bool]
19
+ | type[list]
20
+ | type[dict]
21
+ | type[bytes]
22
+ | type[datetime]
23
+ )
24
+ DataType = type[BaseModel] | StandardType
24
25
  DataTypeNames = "BaseModel, int, str, float, bool, list, dict, bytes, datetime"
25
- DataValue = Union[BaseModel, int, str, float, bool, list, dict, bytes, datetime]
26
+ DataValue = BaseModel | int | str | float | bool | list | dict | bytes | datetime
26
27
 
27
28
 
28
29
  class DataModel(BaseModel):
@@ -37,7 +38,7 @@ class DataModel(BaseModel):
37
38
  ModelStore.register(cls)
38
39
 
39
40
  @staticmethod
40
- def register(models: Union[DataType, Sequence[DataType]]):
41
+ def register(models: DataType | Sequence[DataType]):
41
42
  """For registering classes manually. It accepts a single class or a sequence of
42
43
  classes."""
43
44
  if not isinstance(models, Sequence):
@@ -63,8 +64,8 @@ def is_chain_type(t: type) -> bool:
63
64
  if orig is list and len(args) == 1:
64
65
  return is_chain_type(get_args(t)[0])
65
66
 
66
- if orig is Union and len(args) == 2 and (type(None) in args):
67
- return is_chain_type(args[0])
67
+ if orig in (Union, types.UnionType) and len(args) == 2 and (type(None) in args):
68
+ return is_chain_type(args[0] if args[1] is type(None) else args[1])
68
69
 
69
70
  return False
70
71
 
@@ -72,19 +73,19 @@ def is_chain_type(t: type) -> bool:
72
73
  def dict_to_data_model(
73
74
  name: str,
74
75
  data_dict: dict[str, DataType],
75
- original_names: Optional[list[str]] = None,
76
+ original_names: list[str] | None = None,
76
77
  ) -> type[BaseModel]:
77
78
  if not original_names:
78
79
  # Gets a map of a normalized_name -> original_name
79
80
  columns = normalize_col_names(list(data_dict))
80
- data_dict = dict(zip(columns.keys(), data_dict.values()))
81
+ data_dict = dict(zip(columns.keys(), data_dict.values(), strict=False))
81
82
  original_names = list(columns.values())
82
83
 
83
84
  fields = {
84
85
  name: (
85
86
  anno
86
87
  if inspect.isclass(anno) and issubclass(anno, BaseModel)
87
- else Optional[anno],
88
+ else anno | None,
88
89
  Field(
89
90
  validation_alias=AliasChoices(name, original_names[idx] or name),
90
91
  default=None,
datachain/lib/dataset_info.py CHANGED
@@ -1,6 +1,6 @@
1
1
  import json
2
2
  from datetime import datetime
3
- from typing import TYPE_CHECKING, Any, Optional, Union
3
+ from typing import TYPE_CHECKING, Any
4
4
  from uuid import uuid4
5
5
 
6
6
  from pydantic import Field, field_validator
@@ -28,9 +28,9 @@ class DatasetInfo(DataModel):
28
28
  version: str = Field(default=DEFAULT_DATASET_VERSION)
29
29
  status: int = Field(default=DatasetStatus.CREATED)
30
30
  created_at: datetime = Field(default=TIME_ZERO)
31
- finished_at: Optional[datetime] = Field(default=None)
32
- num_objects: Optional[int] = Field(default=None)
33
- size: Optional[int] = Field(default=None)
31
+ finished_at: datetime | None = Field(default=None)
32
+ num_objects: int | None = Field(default=None)
33
+ size: int | None = Field(default=None)
34
34
  params: dict[str, str] = Field(default={})
35
35
  metrics: dict[str, Any] = Field(default={})
36
36
  error_message: str = Field(default="")
@@ -59,7 +59,7 @@ class DatasetInfo(DataModel):
59
59
 
60
60
  @staticmethod
61
61
  def _validate_dict(
62
- v: Optional[Union[str, dict]],
62
+ v: str | dict | None,
63
63
  ) -> dict:
64
64
  if v is None or v == "":
65
65
  return {}
@@ -88,7 +88,7 @@ class DatasetInfo(DataModel):
88
88
  cls,
89
89
  dataset: DatasetListRecord,
90
90
  version: DatasetListVersion,
91
- job: Optional[Job],
91
+ job: Job | None,
92
92
  ) -> "Self":
93
93
  return cls(
94
94
  uuid=version.uuid,