datachain 0.34.6__py3-none-any.whl → 0.35.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of datachain might be problematic.

Files changed (105)
  1. datachain/asyn.py +11 -12
  2. datachain/cache.py +5 -5
  3. datachain/catalog/catalog.py +75 -83
  4. datachain/catalog/loader.py +3 -3
  5. datachain/checkpoint.py +1 -2
  6. datachain/cli/__init__.py +2 -4
  7. datachain/cli/commands/datasets.py +13 -13
  8. datachain/cli/commands/ls.py +4 -4
  9. datachain/cli/commands/query.py +3 -3
  10. datachain/cli/commands/show.py +2 -2
  11. datachain/cli/parser/job.py +1 -1
  12. datachain/cli/parser/utils.py +1 -2
  13. datachain/cli/utils.py +1 -2
  14. datachain/client/azure.py +2 -2
  15. datachain/client/fsspec.py +11 -21
  16. datachain/client/gcs.py +3 -3
  17. datachain/client/http.py +4 -4
  18. datachain/client/local.py +4 -4
  19. datachain/client/s3.py +3 -3
  20. datachain/config.py +4 -8
  21. datachain/data_storage/db_engine.py +5 -5
  22. datachain/data_storage/metastore.py +107 -107
  23. datachain/data_storage/schema.py +18 -24
  24. datachain/data_storage/sqlite.py +21 -28
  25. datachain/data_storage/warehouse.py +13 -13
  26. datachain/dataset.py +64 -70
  27. datachain/delta.py +21 -18
  28. datachain/diff/__init__.py +13 -13
  29. datachain/func/aggregate.py +9 -11
  30. datachain/func/array.py +12 -12
  31. datachain/func/base.py +7 -4
  32. datachain/func/conditional.py +9 -13
  33. datachain/func/func.py +45 -42
  34. datachain/func/numeric.py +5 -7
  35. datachain/func/string.py +2 -2
  36. datachain/hash_utils.py +54 -81
  37. datachain/job.py +8 -8
  38. datachain/lib/arrow.py +17 -14
  39. datachain/lib/audio.py +6 -6
  40. datachain/lib/clip.py +5 -4
  41. datachain/lib/convert/python_to_sql.py +4 -22
  42. datachain/lib/convert/values_to_tuples.py +4 -9
  43. datachain/lib/data_model.py +20 -19
  44. datachain/lib/dataset_info.py +6 -6
  45. datachain/lib/dc/csv.py +10 -10
  46. datachain/lib/dc/database.py +28 -29
  47. datachain/lib/dc/datachain.py +98 -97
  48. datachain/lib/dc/datasets.py +22 -22
  49. datachain/lib/dc/hf.py +4 -4
  50. datachain/lib/dc/json.py +9 -10
  51. datachain/lib/dc/listings.py +5 -8
  52. datachain/lib/dc/pandas.py +3 -6
  53. datachain/lib/dc/parquet.py +5 -5
  54. datachain/lib/dc/records.py +5 -5
  55. datachain/lib/dc/storage.py +12 -12
  56. datachain/lib/dc/storage_pattern.py +2 -2
  57. datachain/lib/dc/utils.py +11 -14
  58. datachain/lib/dc/values.py +3 -6
  59. datachain/lib/file.py +32 -28
  60. datachain/lib/hf.py +7 -5
  61. datachain/lib/image.py +13 -13
  62. datachain/lib/listing.py +5 -5
  63. datachain/lib/listing_info.py +1 -2
  64. datachain/lib/meta_formats.py +1 -2
  65. datachain/lib/model_store.py +3 -3
  66. datachain/lib/namespaces.py +4 -6
  67. datachain/lib/projects.py +5 -9
  68. datachain/lib/pytorch.py +10 -10
  69. datachain/lib/settings.py +23 -23
  70. datachain/lib/signal_schema.py +52 -44
  71. datachain/lib/text.py +8 -7
  72. datachain/lib/udf.py +25 -17
  73. datachain/lib/udf_signature.py +11 -11
  74. datachain/lib/video.py +3 -4
  75. datachain/lib/webdataset.py +30 -35
  76. datachain/lib/webdataset_laion.py +15 -16
  77. datachain/listing.py +4 -4
  78. datachain/model/bbox.py +3 -1
  79. datachain/namespace.py +4 -4
  80. datachain/node.py +6 -6
  81. datachain/nodes_thread_pool.py +0 -1
  82. datachain/plugins.py +1 -7
  83. datachain/project.py +4 -4
  84. datachain/query/batch.py +7 -8
  85. datachain/query/dataset.py +80 -87
  86. datachain/query/dispatch.py +7 -7
  87. datachain/query/metrics.py +3 -4
  88. datachain/query/params.py +2 -3
  89. datachain/query/schema.py +7 -6
  90. datachain/query/session.py +7 -7
  91. datachain/query/udf.py +8 -7
  92. datachain/query/utils.py +3 -5
  93. datachain/remote/studio.py +33 -39
  94. datachain/script_meta.py +12 -12
  95. datachain/sql/sqlite/base.py +6 -9
  96. datachain/studio.py +30 -30
  97. datachain/toolkit/split.py +1 -2
  98. datachain/utils.py +21 -21
  99. {datachain-0.34.6.dist-info → datachain-0.35.0.dist-info}/METADATA +2 -3
  100. datachain-0.35.0.dist-info/RECORD +173 -0
  101. datachain-0.34.6.dist-info/RECORD +0 -173
  102. {datachain-0.34.6.dist-info → datachain-0.35.0.dist-info}/WHEEL +0 -0
  103. {datachain-0.34.6.dist-info → datachain-0.35.0.dist-info}/entry_points.txt +0 -0
  104. {datachain-0.34.6.dist-info → datachain-0.35.0.dist-info}/licenses/LICENSE +0 -0
  105. {datachain-0.34.6.dist-info → datachain-0.35.0.dist-info}/top_level.txt +0 -0
datachain/lib/projects.py CHANGED
@@ -1,5 +1,3 @@
-from typing import Optional
-
 from datachain.error import ProjectCreateNotAllowedError, ProjectDeleteNotAllowedError
 from datachain.project import Project
 from datachain.query import Session
@@ -8,8 +6,8 @@ from datachain.query import Session
 def create(
     namespace: str,
     name: str,
-    descr: Optional[str] = None,
-    session: Optional[Session] = None,
+    descr: str | None = None,
+    session: Session | None = None,
 ) -> Project:
     """
     Creates a new project under a specified namespace.
@@ -42,7 +40,7 @@ def create(
     return session.catalog.metastore.create_project(namespace, name, descr)
 
 
-def get(name: str, namespace: str, session: Optional[Session]) -> Project:
+def get(name: str, namespace: str, session: Session | None) -> Project:
     """
     Gets a project by name in some namespace.
     If the project is not found, a `ProjectNotFoundError` is raised.
@@ -62,9 +60,7 @@ def get(name: str, namespace: str, session: Optional[Session]) -> Project:
     return Session.get(session).catalog.metastore.get_project(name, namespace)
 
 
-def ls(
-    namespace: Optional[str] = None, session: Optional[Session] = None
-) -> list[Project]:
+def ls(namespace: str | None = None, session: Session | None = None) -> list[Project]:
     """
     Gets a list of projects in a specific namespace or from all namespaces.
 
@@ -88,7 +84,7 @@ def ls(
     return session.catalog.metastore.list_projects(namespace_id)
 
 
-def delete(name: str, namespace: str, session: Optional[Session] = None) -> None:
+def delete(name: str, namespace: str, session: Session | None = None) -> None:
     """
     Removes a project by name within a namespace.
 
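Note on the pattern above: nearly all of the churn in this release is a mechanical migration from typing.Optional/typing.Union to PEP 604 union syntax, which requires Python 3.10+. A minimal standalone sketch of the equivalence (plain Python, not datachain code):

# PEP 604 unions are the same types as their typing-module spellings.
from typing import Optional, Union

assert (str | None) == Optional[str]
assert (bool | int) == Union[bool, int]

def create(namespace: str, name: str, descr: str | None = None) -> None:
    # `descr: str | None = None` behaves exactly like the old
    # `descr: Optional[str] = None` annotation.
    ...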
datachain/lib/pytorch.py CHANGED
@@ -1,9 +1,9 @@
 import logging
 import os
 import weakref
-from collections.abc import Generator, Iterable, Iterator
+from collections.abc import Callable, Generator, Iterable, Iterator
 from contextlib import closing
-from typing import TYPE_CHECKING, Any, Callable, Optional
+from typing import TYPE_CHECKING, Any
 
 from PIL import Image
 from torch import float32
@@ -43,13 +43,13 @@ class PytorchDataset(IterableDataset):
     def __init__(
         self,
         name: str,
-        version: Optional[str] = None,
-        catalog: Optional["Catalog"] = None,
-        transform: Optional["Transform"] = None,
-        tokenizer: Optional[Callable] = None,
-        tokenizer_kwargs: Optional[dict[str, Any]] = None,
+        version: str | None = None,
+        catalog: Catalog | None = None,
+        transform: "Transform | None" = None,
+        tokenizer: Callable | None = None,
+        tokenizer_kwargs: dict[str, Any] | None = None,
         num_samples: int = 0,
-        dc_settings: Optional[Settings] = None,
+        dc_settings: Settings | None = None,
         remove_prefetched: bool = False,
     ):
         """
@@ -84,7 +84,7 @@ class PytorchDataset(IterableDataset):
         self.prefetch = prefetch
 
         self._cache = catalog.cache
-        self._prefetch_cache: Optional[Cache] = None
+        self._prefetch_cache: Cache | None = None
         self._remove_prefetched = remove_prefetched
         if prefetch and not self.cache:
             tmp_dir = catalog.cache.tmp_dir
@@ -104,7 +104,7 @@ class PytorchDataset(IterableDataset):
         self._ms_params = catalog.metastore.clone_params()
         self._wh_params = catalog.warehouse.clone_params()
         self._catalog_params = catalog.get_init_params()
-        self.catalog: Optional[Catalog] = None
+        self.catalog: Catalog | None = None
 
     def _get_catalog(self) -> "Catalog":
         ms_cls, ms_args, ms_kwargs = self._ms_params
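Two details in this file are worth noting: Callable now comes from collections.abc (the typing.Callable alias is deprecated since Python 3.9), and "Transform | None" stays a quoted string, the usual move when a name is only available to the type checker. A small sketch of that pattern; the torchvision import is a hypothetical stand-in, not datachain's actual dependency:

from collections.abc import Callable  # preferred over the deprecated typing.Callable
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Visible to type checkers only; never imported at runtime.
    from torchvision.transforms import Compose as Transform

class TinyDataset:
    def __init__(
        self,
        transform: "Transform | None" = None,  # quoted: name is undefined at runtime
        tokenizer: Callable | None = None,     # runtime names can use `X | None` unquoted
    ) -> None:
        self.transform = transform
        self.tokenizer = tokenizer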
datachain/lib/settings.py CHANGED
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Union
+from typing import Any
 
 from datachain.lib.utils import DataChainParamsError
 
@@ -15,25 +15,25 @@ class SettingsError(DataChainParamsError):
 class Settings:
     """Settings for datachain."""
 
-    _cache: Optional[bool]
-    _prefetch: Optional[int]
-    _parallel: Optional[Union[bool, int]]
-    _workers: Optional[int]
-    _namespace: Optional[str]
-    _project: Optional[str]
-    _min_task_size: Optional[int]
-    _batch_size: Optional[int]
+    _cache: bool | None
+    _prefetch: int | None
+    _parallel: bool | int | None
+    _workers: int | None
+    _namespace: str | None
+    _project: str | None
+    _min_task_size: int | None
+    _batch_size: int | None
 
     def __init__(  # noqa: C901, PLR0912
         self,
-        cache: Optional[bool] = None,
-        prefetch: Optional[Union[bool, int]] = None,
-        parallel: Optional[Union[bool, int]] = None,
-        workers: Optional[int] = None,
-        namespace: Optional[str] = None,
-        project: Optional[str] = None,
-        min_task_size: Optional[int] = None,
-        batch_size: Optional[int] = None,
+        cache: bool | None = None,
+        prefetch: bool | int | None = None,
+        parallel: bool | int | None = None,
+        workers: int | None = None,
+        namespace: str | None = None,
+        project: str | None = None,
+        min_task_size: int | None = None,
+        batch_size: int | None = None,
     ) -> None:
         if cache is None:
             self._cache = None
@@ -148,27 +148,27 @@ class Settings:
         return self._cache if self._cache is not None else DEFAULT_CACHE
 
     @property
-    def prefetch(self) -> Optional[int]:
+    def prefetch(self) -> int | None:
         return self._prefetch if self._prefetch is not None else DEFAULT_PREFETCH
 
     @property
-    def parallel(self) -> Optional[Union[bool, int]]:
+    def parallel(self) -> bool | int | None:
         return self._parallel if self._parallel is not None else None
 
     @property
-    def workers(self) -> Optional[int]:
+    def workers(self) -> int | None:
         return self._workers if self._workers is not None else None
 
     @property
-    def namespace(self) -> Optional[str]:
+    def namespace(self) -> str | None:
         return self._namespace if self._namespace is not None else None
 
     @property
-    def project(self) -> Optional[str]:
+    def project(self) -> str | None:
         return self._project if self._project is not None else None
 
     @property
-    def min_task_size(self) -> Optional[int]:
+    def min_task_size(self) -> int | None:
         return self._min_task_size if self._min_task_size is not None else None
 
     @property
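The Settings diff is pure annotation churn, but the pattern it annotates is easy to miss: fields store None to mean "unset", and the properties resolve defaults lazily at read time, so only values the user set explicitly are stored. A minimal sketch of that pattern (class name and default value are illustrative, not datachain's):

DEFAULT_PREFETCH = 2  # assumed default, for illustration only

class TinySettings:
    _prefetch: int | None

    def __init__(self, prefetch: int | None = None) -> None:
        self._prefetch = prefetch  # None means "not set by the user"

    @property
    def prefetch(self) -> int | None:
        # Fall back only at read time, so the stored None still
        # distinguishes "unset" from "explicitly set to the default".
        return self._prefetch if self._prefetch is not None else DEFAULT_PREFETCH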
datachain/lib/signal_schema.py CHANGED
@@ -3,22 +3,21 @@ import hashlib
 import json
 import logging
 import math
+import types
 import warnings
-from collections.abc import Iterator, Sequence
+from collections.abc import Callable, Iterator, Mapping, Sequence
 from dataclasses import dataclass
 from datetime import datetime
 from inspect import isclass
-from typing import (  # noqa: UP035
+from typing import (
     IO,
     TYPE_CHECKING,
     Annotated,
     Any,
-    Callable,
-    Dict,
+    Dict,  # type: ignore[UP035]
     Final,
-    List,
+    List,  # type: ignore[UP035]
     Literal,
-    Mapping,
     Optional,
     Union,
     get_args,
@@ -75,7 +74,7 @@ class SignalSchemaWarning(RuntimeWarning):
 
 
 class SignalResolvingError(SignalSchemaError):
-    def __init__(self, path: Optional[list[str]], msg: str):
+    def __init__(self, path: list[str] | None, msg: str):
         name = " '" + ".".join(path) + "'" if path else ""
         super().__init__(f"cannot resolve signal name{name}: {msg}")
 
@@ -95,7 +94,7 @@ class SignalResolvingTypeError(SignalResolvingError):
 
 
 class SignalRemoveError(SignalSchemaError):
-    def __init__(self, path: Optional[list[str]], msg: str):
+    def __init__(self, path: list[str] | None, msg: str):
         name = " '" + ".".join(path) + "'" if path else ""
         super().__init__(f"cannot remove signal name{name}: {msg}")
 
@@ -104,8 +103,8 @@ class CustomType(BaseModel):
     schema_version: int = Field(ge=1, le=2, strict=True)
     name: str
     fields: dict[str, str]
-    bases: list[tuple[str, str, Optional[str]]]
-    hidden_fields: Optional[list[str]] = None
+    bases: list[tuple[str, str, str | None]]
+    hidden_fields: list[str] | None = None
 
     @classmethod
     def deserialize(cls, data: dict[str, Any], type_name: str) -> "CustomType":
@@ -125,8 +124,8 @@ class CustomType(BaseModel):
 
 def create_feature_model(
     name: str,
-    fields: Mapping[str, Union[type, None, tuple[type, Any]]],
-    base: Optional[type] = None,
+    fields: Mapping[str, type | tuple[type, Any] | None],
+    base: type | None = None,
 ) -> type[BaseModel]:
     """
     This gets or returns a dynamic feature model for use in restoring a model
@@ -152,12 +151,12 @@ class SignalSchema:
     values: dict[str, DataType]
     tree: dict[str, Any]
     setup_func: dict[str, Callable]
-    setup_values: Optional[dict[str, Any]]
+    setup_values: dict[str, Any] | None
 
     def __init__(
         self,
         values: dict[str, DataType],
-        setup: Optional[dict[str, Callable]] = None,
+        setup: dict[str, Callable] | None = None,
     ):
         self.values = values
         self.tree = self._build_tree(values)
@@ -196,8 +195,8 @@ class SignalSchema:
         return SignalSchema(signals)
 
     @staticmethod
-    def _get_bases(fr: type) -> list[tuple[str, str, Optional[str]]]:
-        bases: list[tuple[str, str, Optional[str]]] = []
+    def _get_bases(fr: type) -> list[tuple[str, str, str | None]]:
+        bases: list[tuple[str, str, str | None]] = []
         for base in fr.__mro__:
             model_store_name = (
                 ModelStore.get_name(base) if issubclass(base, DataModel) else None
@@ -294,7 +293,7 @@ class SignalSchema:
     @staticmethod
     def _deserialize_custom_type(
         type_name: str, custom_types: dict[str, Any]
-    ) -> Optional[type]:
+    ) -> type | None:
         """Given a type name like MyType@v1 gets a type from ModelStore or recreates
         it based on the information from the custom types dict that includes fields and
         bases."""
@@ -327,7 +326,7 @@ class SignalSchema:
         return None
 
     @staticmethod
-    def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type]:
+    def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> type | None:
         """Convert a string-based type back into a python type."""
         type_name = type_name.strip()
         if not type_name:
@@ -336,7 +335,7 @@ class SignalSchema:
             return None
 
         bracket_idx = type_name.find("[")
-        subtypes: Optional[tuple[Optional[type], ...]] = None
+        subtypes: tuple[type | None, ...] | None = None
         if bracket_idx > -1:
             if bracket_idx == 0:
                 raise ValueError("Type cannot start with '['")
@@ -493,7 +492,7 @@ class SignalSchema:
             return math.isnan(value) or value is None
         return value is None
 
-    def get_file_signal(self) -> Optional[str]:
+    def get_file_signal(self) -> str | None:
         for signal_name, signal_type in self.values.items():
             if (fr := ModelStore.to_pydantic(signal_type)) is not None and issubclass(
                 fr, File
@@ -503,8 +502,8 @@ class SignalSchema:
 
     def slice(
         self,
-        params: dict[str, Union[DataType, Any]],
-        setup: Optional[dict[str, Callable]] = None,
+        params: dict[str, DataType | Any],
+        setup: dict[str, Callable] | None = None,
         is_batch: bool = False,
     ) -> "SignalSchema":
         """
@@ -528,9 +527,13 @@ class SignalSchema:
         schema_origin = get_origin(schema_type)
         param_origin = get_origin(param_type)
 
-        if schema_origin is Union and type(None) in get_args(schema_type):
+        if schema_origin in (Union, types.UnionType) and type(None) in get_args(
+            schema_type
+        ):
             schema_type = get_args(schema_type)[0]
-        if param_origin is Union and type(None) in get_args(param_type):
+        if param_origin in (Union, types.UnionType) and type(None) in get_args(
+            param_type
+        ):
             param_type = get_args(param_type)[0]
 
         if is_batch:
@@ -610,8 +613,8 @@ class SignalSchema:
             raise SignalResolvingError([col_name], "is not found")
 
     def db_signals(
-        self, name: Optional[str] = None, as_columns=False, include_hidden: bool = True
-    ) -> Union[list[str], list[Column]]:
+        self, name: str | None = None, as_columns=False, include_hidden: bool = True
+    ) -> list[str] | list[Column]:
         """
         Returns DB columns as strings or Column objects with proper types
         Optionally, it can filter results by specific object, returning only his signals
@@ -802,7 +805,7 @@ class SignalSchema:
     @staticmethod
     def _build_tree(
         values: dict[str, DataType],
-    ) -> dict[str, tuple[DataType, Optional[dict]]]:
+    ) -> dict[str, tuple[DataType, dict | None]]:
         return {
             name: (val, SignalSchema._build_tree_for_type(val))
             for name, val in values.items()
@@ -834,7 +837,7 @@ class SignalSchema:
                 substree, new_prefix, depth + 1, include_hidden
             )
 
-    def print_tree(self, indent: int = 2, start_at: int = 0, file: Optional[IO] = None):
+    def print_tree(self, indent: int = 2, start_at: int = 0, file: IO | None = None):
         for path, type_, _, depth in self.get_flat_tree():
             total_indent = start_at + depth * indent
             col_name = " " * total_indent + path[-1]
@@ -873,15 +876,20 @@ class SignalSchema:
         return self.values.pop(name)
 
     @staticmethod
-    def _type_to_str(type_: Optional[type], subtypes: Optional[list] = None) -> str:  # noqa: PLR0911
+    def _type_to_str(type_: type | None, subtypes: list | None = None) -> str:  # noqa: C901, PLR0911
         """Convert a type to a string-based representation."""
         if type_ is None:
             return "NoneType"
 
         origin = get_origin(type_)
 
-        if origin == Union:
+        if origin in (Union, types.UnionType):
             args = get_args(type_)
+            if len(args) == 2 and type(None) in args:
+                # This is an Optional type.
+                non_none_type = args[0] if args[1] is type(None) else args[1]
+                type_str = SignalSchema._type_to_str(non_none_type, subtypes)
+                return f"Optional[{type_str}]"
             formatted_types = ", ".join(
                 SignalSchema._type_to_str(arg, subtypes) for arg in args
             )
@@ -892,19 +900,19 @@ class SignalSchema:
             return f"Optional[{type_str}]"
         if origin in (list, List):  # noqa: UP006
             args = get_args(type_)
+            if len(args) == 0:
+                return "list"
             type_str = SignalSchema._type_to_str(args[0], subtypes)
             return f"list[{type_str}]"
         if origin in (dict, Dict):  # noqa: UP006
             args = get_args(type_)
-            type_str = (
-                SignalSchema._type_to_str(args[0], subtypes) if len(args) > 0 else ""
-            )
-            vals = (
-                f", {SignalSchema._type_to_str(args[1], subtypes)}"
-                if len(args) > 1
-                else ""
-            )
-            return f"dict[{type_str}{vals}]"
+            if len(args) == 0:
+                return "dict"
+            key_type = SignalSchema._type_to_str(args[0], subtypes)
+            if len(args) == 1:
+                return f"dict[{key_type}, Any]"
+            val_type = SignalSchema._type_to_str(args[1], subtypes)
+            return f"dict[{key_type}, {val_type}]"
         if origin == Annotated:
             args = get_args(type_)
             return SignalSchema._type_to_str(args[0], subtypes)
@@ -918,7 +926,7 @@ class SignalSchema:
         # Include this type in the list of all subtypes, if requested.
         subtypes.append(type_)
         if not hasattr(type_, "__name__"):
-            # This can happen for some third-party or custom types, mostly on Python 3.9
+            # This can happen for some third-party or custom types
             warnings.warn(
                 f"Unable to determine name of type '{type_}'.",
                 SignalSchemaWarning,
@@ -933,7 +941,7 @@ class SignalSchema:
     @staticmethod
     def _build_tree_for_type(
         model: DataType,
-    ) -> Optional[dict[str, tuple[DataType, Optional[dict]]]]:
+    ) -> dict[str, tuple[DataType, dict | None]] | None:
         if (fr := ModelStore.to_pydantic(model)) is not None:
             return SignalSchema._build_tree_for_model(fr)
         return None
@@ -941,8 +949,8 @@ class SignalSchema:
     @staticmethod
     def _build_tree_for_model(
         model: type[BaseModel],
-    ) -> Optional[dict[str, tuple[DataType, Optional[dict]]]]:
-        res: dict[str, tuple[DataType, Optional[dict]]] = {}
+    ) -> dict[str, tuple[DataType, dict | None]] | None:
+        res: dict[str, tuple[DataType, dict | None]] = {}
 
         for name, f_info in model.model_fields.items():
             anno = f_info.annotation
@@ -991,7 +999,7 @@ class SignalSchema:
         schema: dict[str, Any] = {}
         schema_custom_types: dict[str, CustomType] = {}
 
-        data_model_bases: Optional[list[tuple[str, str, Optional[str]]]] = None
+        data_model_bases: list[tuple[str, str, str | None]] | None = None
 
         signal_partials: dict[str, str] = {}
         partial_versions: dict[str, int] = {}
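The new `import types` and the `origin in (Union, types.UnionType)` checks are the functional core of this file's change: a union written as `X | None` reports a different runtime origin than `Optional[X]`, so schema code that unwraps optionals must accept both. A standalone sketch of the distinction (Python 3.10+):

import types
from typing import Optional, Union, get_args, get_origin

assert get_origin(Optional[int]) is Union          # typing-style union
assert get_origin(int | None) is types.UnionType   # PEP 604 union

# Unwrapping Optional therefore has to accept either origin:
for tp in (Optional[int], int | None):
    if get_origin(tp) in (Union, types.UnionType) and type(None) in get_args(tp):
        non_none = [a for a in get_args(tp) if a is not type(None)]
        assert non_none == [int]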
datachain/lib/text.py CHANGED
@@ -1,16 +1,17 @@
-from typing import Any, Callable, Optional, Union
+from collections.abc import Callable
+from typing import Any
 
 import torch
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
 
 def convert_text(
-    text: Union[str, list[str]],
-    tokenizer: Optional[Callable] = None,
-    tokenizer_kwargs: Optional[dict[str, Any]] = None,
-    encoder: Optional[Callable] = None,
-    device: Optional[Union[str, torch.device]] = None,
-) -> Union[str, list[str], torch.Tensor]:
+    text: str | list[str],
+    tokenizer: Callable | None = None,
+    tokenizer_kwargs: dict[str, Any] | None = None,
+    encoder: Callable | None = None,
+    device: str | torch.device | None = None,
+) -> str | list[str] | torch.Tensor:
     """
     Tokenize and otherwise transform text.
 
datachain/lib/udf.py CHANGED
@@ -4,7 +4,7 @@ import traceback
 from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
 from contextlib import closing, nullcontext
 from functools import partial
-from typing import TYPE_CHECKING, Any, Optional, TypeVar
+from typing import TYPE_CHECKING, Any, TypeVar
 
 import attrs
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
@@ -60,7 +60,7 @@ UDFResult = dict[str, Any]
 class UDFAdapter:
     inner: "UDFBase"
     output: UDFOutputSpec
-    batch_size: Optional[int] = None
+    batch_size: int | None = None
     batch: int = 1
 
     def hash(self) -> str:
@@ -152,7 +152,7 @@ class UDFBase(AbstractUDF):
     prefetch: int = 0
 
     def __init__(self):
-        self.params: Optional[SignalSchema] = None
+        self.params: SignalSchema | None = None
         self.output = None
         self._func = None
 
@@ -197,7 +197,7 @@ class UDFBase(AbstractUDF):
         self,
         sign: "UdfSignature",
         params: "SignalSchema",
-        func: Optional[Callable],
+        func: Callable | None,
     ):
         self.params = params
         self.output = sign.output_schema
@@ -246,7 +246,7 @@ class UDFBase(AbstractUDF):
 
     def to_udf_wrapper(
         self,
-        batch_size: Optional[int] = None,
+        batch_size: int | None = None,
         batch: int = 1,
     ) -> UDFAdapter:
         return UDFAdapter(
@@ -304,11 +304,11 @@ class UDFBase(AbstractUDF):
             self._set_stream_recursive(field_value, catalog, cache, download_cb)
 
     def _prepare_row(self, row, udf_fields, catalog, cache, download_cb):
-        row_dict = RowDict(zip(udf_fields, row))
+        row_dict = RowDict(zip(udf_fields, row, strict=False))
         return self._parse_row(row_dict, catalog, cache, download_cb)
 
     def _prepare_row_and_id(self, row, udf_fields, catalog, cache, download_cb):
-        row_dict = RowDict(zip(udf_fields, row))
+        row_dict = RowDict(zip(udf_fields, row, strict=False))
         udf_input = self._parse_row(row_dict, catalog, cache, download_cb)
         return row_dict["sys__id"], *udf_input
 
@@ -333,7 +333,7 @@ def noop(*args, **kwargs):
 
 async def _prefetch_input(
     row: T,
-    download_cb: Optional["Callback"] = None,
+    download_cb: Callback | None = None,
     after_prefetch: "Callable[[], None]" = noop,
 ) -> T:
     for obj in row:
@@ -356,8 +356,8 @@ def _remove_prefetched(row: T) -> None:
 def _prefetch_inputs(
     prepared_inputs: "Iterable[T]",
     prefetch: int = 0,
-    download_cb: Optional["Callback"] = None,
-    after_prefetch: Optional[Callable[[], None]] = None,
+    download_cb: Callback | None = None,
+    after_prefetch: Callable[[], None] | None = None,
     remove_prefetched: bool = False,
 ) -> "abc.Generator[T, None, None]":
     if not prefetch:
@@ -426,7 +426,10 @@ class Mapper(UDFBase):
         for id_, *udf_args in prepared_inputs:
             result_objs = self.process_safe(udf_args)
             udf_output = self._flatten_row(result_objs)
-            output = [{"sys__id": id_} | dict(zip(self.signal_names, udf_output))]
+            output = [
+                {"sys__id": id_}
+                | dict(zip(self.signal_names, udf_output, strict=False))
+            ]
             processed_cb.relative_update(1)
             yield output
 
@@ -474,7 +477,8 @@ class BatchMapper(UDFBase):
                         row, udf_fields, catalog, cache, download_cb
                     )
                     for row in batch
-                ]
+                ],
+                strict=False,
             )
             result_objs = list(self.process_safe(udf_args))
             n_objs = len(result_objs)
@@ -483,8 +487,9 @@ class BatchMapper(UDFBase):
             )
             udf_outputs = (self._flatten_row(row) for row in result_objs)
             output = [
-                {"sys__id": row_id} | dict(zip(self.signal_names, signals))
-                for row_id, signals in zip(row_ids, udf_outputs)
+                {"sys__id": row_id}
+                | dict(zip(self.signal_names, signals, strict=False))
+                for row_id, signals in zip(row_ids, udf_outputs, strict=False)
             ]
             processed_cb.relative_update(n_rows)
             yield output
@@ -520,7 +525,7 @@ class Generator(UDFBase):
         with safe_closing(self.process_safe(row)) as result_objs:
             for result_obj in result_objs:
                 udf_output = self._flatten_row(result_obj)
-                yield dict(zip(self.signal_names, udf_output))
+                yield dict(zip(self.signal_names, udf_output, strict=False))
 
         prepared_inputs = _prepare_rows(udf_inputs)
         prepared_inputs = _prefetch_inputs(
@@ -559,11 +564,14 @@ class Aggregator(UDFBase):
             *[
                 self._prepare_row(row, udf_fields, catalog, cache, download_cb)
                 for row in batch
-            ]
+            ],
+            strict=False,
         )
         result_objs = self.process_safe(udf_args)
         udf_outputs = (self._flatten_row(row) for row in result_objs)
-        output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
+        output = (
+            dict(zip(self.signal_names, row, strict=False)) for row in udf_outputs
+        )
         processed_cb.relative_update(len(batch))
         yield output
 
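The repeated `strict=False` additions in this file pin down behavior zip() has always had: silently truncating to the shortest iterable. Python 3.10 (PEP 618) added the flag so the choice is explicit, which is what linters now require. A small sketch of the difference; the names echo the signal/value pairing above, but the data is made up:

signal_names = ["sys__id", "value"]
udf_output = (1, "a", "extra")

# strict=False keeps the historical truncating behavior the UDF code relies on:
assert dict(zip(signal_names, udf_output, strict=False)) == {"sys__id": 1, "value": "a"}

# strict=True would surface the length mismatch instead:
try:
    dict(zip(signal_names, udf_output, strict=True))
except ValueError:
    pass  # zip() argument 2 is longer than argument 1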
datachain/lib/udf_signature.py CHANGED
@@ -1,7 +1,7 @@
 import inspect
-from collections.abc import Generator, Iterator, Sequence
+from collections.abc import Callable, Generator, Iterator, Sequence
 from dataclasses import dataclass
-from typing import Any, Callable, Union, get_args, get_origin
+from typing import Any, get_args, get_origin
 
 from datachain.lib.data_model import DataType, DataTypeNames, is_chain_type
 from datachain.lib.signal_schema import SignalSchema
@@ -17,8 +17,8 @@ class UdfSignatureError(DataChainParamsError):
 
 @dataclass
 class UdfSignature:  # noqa: PLW1641
-    func: Union[Callable, UDFBase]
-    params: dict[str, Union[DataType, Any]]
+    func: Callable | UDFBase
+    params: dict[str, DataType | Any]
     output_schema: SignalSchema
 
     DEFAULT_RETURN_TYPE = str
@@ -28,9 +28,9 @@ class UdfSignature:  # noqa: PLW1641
         cls,
         chain: str,
         signal_map: dict[str, Callable],
-        func: Union[None, UDFBase, Callable] = None,
-        params: Union[None, str, Sequence[str]] = None,
-        output: Union[None, DataType, Sequence[str], dict[str, DataType]] = None,
+        func: UDFBase | Callable | None = None,
+        params: str | Sequence[str] | None = None,
+        output: DataType | Sequence[str] | dict[str, DataType] | None = None,
         is_generator: bool = True,
     ) -> "UdfSignature":
         keys = ", ".join(signal_map.keys())
@@ -40,7 +40,7 @@ class UdfSignature:  # noqa: PLW1641
                 f"multiple signals '{keys}' are not supported in processors."
                 " Chain multiple processors instead.",
             )
-        udf_func: Union[UDFBase, Callable]
+        udf_func: UDFBase | Callable
         if len(signal_map) == 1:
             if func is not None:
                 raise UdfSignatureError(
@@ -62,7 +62,7 @@ class UdfSignature:  # noqa: PLW1641
             chain, udf_func
         )
 
-        udf_params: dict[str, Union[DataType, Any]] = {}
+        udf_params: dict[str, DataType | Any] = {}
         if params:
             udf_params = (
                 {params: Any} if isinstance(params, str) else dict.fromkeys(params, Any)
@@ -128,7 +128,7 @@ class UdfSignature:  # noqa: PLW1641
                 f" return type length ({len(func_outs_sign)}) does not match",
             )
 
-            udf_output_map = dict(zip(output, func_outs_sign))
+            udf_output_map = dict(zip(output, func_outs_sign, strict=False))
         elif isinstance(output, dict):
             for key, value in output.items():
                 if not isinstance(key, str):
@@ -164,7 +167,7 @@ class UdfSignature:  # noqa: PLW1641
 
     @staticmethod
     def _func_signature(
-        chain: str, udf_func: Union[Callable, UDFBase]
+        chain: str, udf_func: Callable | UDFBase
     ) -> tuple[dict[str, type], Sequence[type], bool]:
         if isinstance(udf_func, AbstractUDF):
             func = udf_func.process  # type: ignore[unreachable]