datachain 0.34.6__py3-none-any.whl → 0.34.7__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.

Potentially problematic release.

Files changed (105)
  1. datachain/asyn.py +11 -12
  2. datachain/cache.py +5 -5
  3. datachain/catalog/catalog.py +75 -83
  4. datachain/catalog/loader.py +3 -3
  5. datachain/checkpoint.py +1 -2
  6. datachain/cli/__init__.py +2 -4
  7. datachain/cli/commands/datasets.py +13 -13
  8. datachain/cli/commands/ls.py +4 -4
  9. datachain/cli/commands/query.py +3 -3
  10. datachain/cli/commands/show.py +2 -2
  11. datachain/cli/parser/job.py +1 -1
  12. datachain/cli/parser/utils.py +1 -2
  13. datachain/cli/utils.py +1 -2
  14. datachain/client/azure.py +2 -2
  15. datachain/client/fsspec.py +11 -21
  16. datachain/client/gcs.py +3 -3
  17. datachain/client/http.py +4 -4
  18. datachain/client/local.py +4 -4
  19. datachain/client/s3.py +3 -3
  20. datachain/config.py +4 -8
  21. datachain/data_storage/db_engine.py +5 -5
  22. datachain/data_storage/metastore.py +107 -107
  23. datachain/data_storage/schema.py +18 -24
  24. datachain/data_storage/sqlite.py +21 -28
  25. datachain/data_storage/warehouse.py +13 -13
  26. datachain/dataset.py +64 -70
  27. datachain/delta.py +21 -18
  28. datachain/diff/__init__.py +13 -13
  29. datachain/func/aggregate.py +9 -11
  30. datachain/func/array.py +12 -12
  31. datachain/func/base.py +7 -4
  32. datachain/func/conditional.py +9 -13
  33. datachain/func/func.py +45 -42
  34. datachain/func/numeric.py +5 -7
  35. datachain/func/string.py +2 -2
  36. datachain/hash_utils.py +54 -81
  37. datachain/job.py +8 -8
  38. datachain/lib/arrow.py +17 -14
  39. datachain/lib/audio.py +6 -6
  40. datachain/lib/clip.py +5 -4
  41. datachain/lib/convert/python_to_sql.py +4 -22
  42. datachain/lib/convert/values_to_tuples.py +4 -9
  43. datachain/lib/data_model.py +20 -19
  44. datachain/lib/dataset_info.py +6 -6
  45. datachain/lib/dc/csv.py +10 -10
  46. datachain/lib/dc/database.py +28 -29
  47. datachain/lib/dc/datachain.py +98 -97
  48. datachain/lib/dc/datasets.py +22 -22
  49. datachain/lib/dc/hf.py +4 -4
  50. datachain/lib/dc/json.py +9 -10
  51. datachain/lib/dc/listings.py +5 -8
  52. datachain/lib/dc/pandas.py +3 -6
  53. datachain/lib/dc/parquet.py +5 -5
  54. datachain/lib/dc/records.py +5 -5
  55. datachain/lib/dc/storage.py +12 -12
  56. datachain/lib/dc/storage_pattern.py +2 -2
  57. datachain/lib/dc/utils.py +11 -14
  58. datachain/lib/dc/values.py +3 -6
  59. datachain/lib/file.py +26 -26
  60. datachain/lib/hf.py +7 -5
  61. datachain/lib/image.py +13 -13
  62. datachain/lib/listing.py +5 -5
  63. datachain/lib/listing_info.py +1 -2
  64. datachain/lib/meta_formats.py +1 -2
  65. datachain/lib/model_store.py +3 -3
  66. datachain/lib/namespaces.py +4 -6
  67. datachain/lib/projects.py +5 -9
  68. datachain/lib/pytorch.py +10 -10
  69. datachain/lib/settings.py +23 -23
  70. datachain/lib/signal_schema.py +52 -44
  71. datachain/lib/text.py +8 -7
  72. datachain/lib/udf.py +25 -17
  73. datachain/lib/udf_signature.py +11 -11
  74. datachain/lib/video.py +3 -4
  75. datachain/lib/webdataset.py +30 -35
  76. datachain/lib/webdataset_laion.py +15 -16
  77. datachain/listing.py +4 -4
  78. datachain/model/bbox.py +3 -1
  79. datachain/namespace.py +4 -4
  80. datachain/node.py +6 -6
  81. datachain/nodes_thread_pool.py +0 -1
  82. datachain/plugins.py +1 -7
  83. datachain/project.py +4 -4
  84. datachain/query/batch.py +7 -8
  85. datachain/query/dataset.py +80 -87
  86. datachain/query/dispatch.py +7 -7
  87. datachain/query/metrics.py +3 -4
  88. datachain/query/params.py +2 -3
  89. datachain/query/schema.py +7 -6
  90. datachain/query/session.py +7 -7
  91. datachain/query/udf.py +8 -7
  92. datachain/query/utils.py +3 -5
  93. datachain/remote/studio.py +33 -39
  94. datachain/script_meta.py +12 -12
  95. datachain/sql/sqlite/base.py +6 -9
  96. datachain/studio.py +30 -30
  97. datachain/toolkit/split.py +1 -2
  98. datachain/utils.py +21 -21
  99. {datachain-0.34.6.dist-info → datachain-0.34.7.dist-info}/METADATA +2 -3
  100. datachain-0.34.7.dist-info/RECORD +173 -0
  101. datachain-0.34.6.dist-info/RECORD +0 -173
  102. {datachain-0.34.6.dist-info → datachain-0.34.7.dist-info}/WHEEL +0 -0
  103. {datachain-0.34.6.dist-info → datachain-0.34.7.dist-info}/entry_points.txt +0 -0
  104. {datachain-0.34.6.dist-info → datachain-0.34.7.dist-info}/licenses/LICENSE +0 -0
  105. {datachain-0.34.6.dist-info → datachain-0.34.7.dist-info}/top_level.txt +0 -0
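
Nearly all of the source changes below follow one pattern: a typing-syntax modernization. `Optional[X]` and `Union[X, Y]` annotations become PEP 604 `X | None` / `X | Y`, `Callable` moves from `typing` to `collections.abc`, and several `zip()` calls spell out `strict=False`. A minimal before/after sketch of the annotation change (hypothetical function, not taken from the diff):

```python
# Before: typing-module spellings.
#   from typing import Callable, Optional, Union
#   def fetch(url: str, timeout: Optional[int] = None) -> Union[bytes, None]: ...

# After: PEP 604 unions and collections.abc, the style applied throughout 0.34.7.
from collections.abc import Callable

def fetch(url: str, timeout: int | None = None) -> bytes | None:
    """Illustrative stub only; returns nothing useful."""
    return None

on_done: Callable[[bytes], None] | None = None
```

Since unquoted `X | Y` annotations and module-level type aliases are evaluated at import time (and these hunks show no `from __future__ import annotations`), this style requires Python 3.10+.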
datachain/query/dataset.py CHANGED
@@ -8,19 +8,11 @@ import string
 import subprocess
 import sys
 from abc import ABC, abstractmethod
-from collections.abc import Generator, Iterable, Iterator, Sequence
+from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
 from copy import copy
 from functools import wraps
 from types import GeneratorType
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Callable,
-    Optional,
-    Protocol,
-    TypeVar,
-    Union,
-)
+from typing import TYPE_CHECKING, Any, Protocol, TypeVar
 
 import attrs
 import sqlalchemy
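
The consolidated import above follows PEP 585: the `collections.abc` container and callable ABCs are subscriptable at runtime since Python 3.9, so the `typing` aliases (`typing.Callable`, and with PEP 604 also `Optional`/`Union`) are no longer needed. A self-contained sketch:

```python
from collections.abc import Callable, Iterator

# Subscriptable at runtime since Python 3.9; no typing.Callable required.
Handler = Callable[[str], int]

def lengths(items: list[str]) -> Iterator[int]:
    # Builtin generics (list[str]) follow the same PEP 585 rule.
    yield from (len(s) for s in items)

handler: Handler = len
assert list(lengths(["a", "bc"])) == [1, 2] and handler("abc") == 3
```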
@@ -67,11 +59,12 @@ from datachain.utils import (
 
 if TYPE_CHECKING:
     from collections.abc import Mapping
+    from typing import Concatenate
 
     from sqlalchemy.sql.elements import ClauseElement
     from sqlalchemy.sql.schema import Table
     from sqlalchemy.sql.selectable import GenerativeSelect
-    from typing_extensions import Concatenate, ParamSpec, Self
+    from typing_extensions import ParamSpec, Self
 
     from datachain.catalog import Catalog
     from datachain.data_storage import AbstractWarehouse
@@ -83,13 +76,10 @@ if TYPE_CHECKING:
 
 INSERT_BATCH_SIZE = 10000
 
-PartitionByType = Union[
-    str,
-    Function,
-    ColumnElement,
-    Sequence[Union[str, Function, ColumnElement]],
-]
-JoinPredicateType = Union[str, ColumnClause, ColumnElement]
+PartitionByType = (
+    str | Function | ColumnElement | Sequence[str | Function | ColumnElement]
+)
+JoinPredicateType = str | ColumnClause | ColumnElement
 DatasetDependencyType = tuple["DatasetRecord", str]
 
 logger = logging.getLogger("datachain")
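
Unlike annotations, `PartitionByType` and `JoinPredicateType` are ordinary module-level assignments, so the `|` between types really executes at import time; `type.__or__` producing a `types.UnionType` only exists since Python 3.10 (PEP 604). A standalone sketch of the same construct:

```python
from collections.abc import Sequence
from types import UnionType

# Evaluated at import time: on Python < 3.10 this line raises TypeError.
PredicateType = str | int | Sequence[str | int]

assert isinstance(str | int, UnionType)
assert isinstance(3, str | int)  # plain unions also work with isinstance()
```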
@@ -411,14 +401,14 @@ def get_generated_callback(is_generator: bool = False) -> Callback:
 class UDFStep(Step, ABC):
     udf: "UDFAdapter"
     catalog: "Catalog"
-    partition_by: Optional[PartitionByType] = None
+    partition_by: PartitionByType | None = None
     is_generator = False
     # Parameters from Settings
     cache: bool = False
-    parallel: Optional[int] = None
-    workers: Union[bool, int] = False
-    min_task_size: Optional[int] = None
-    batch_size: Optional[int] = None
+    parallel: int | None = None
+    workers: bool | int = False
+    min_task_size: int | None = None
+    batch_size: int | None = None
 
     def hash_inputs(self) -> str:
         partition_by = ensure_sequence(self.partition_by or [])
@@ -624,7 +614,7 @@ class UDFStep(Step, ABC):
 
         return tbl
 
-    def clone(self, partition_by: Optional[PartitionByType] = None) -> "Self":
+    def clone(self, partition_by: PartitionByType | None = None) -> "Self":
         if partition_by is not None:
             return self.__class__(
                 self.udf,
@@ -681,14 +671,14 @@ class UDFStep(Step, ABC):
 class UDFSignal(UDFStep):
     udf: "UDFAdapter"
     catalog: "Catalog"
-    partition_by: Optional[PartitionByType] = None
+    partition_by: PartitionByType | None = None
     is_generator = False
     # Parameters from Settings
     cache: bool = False
-    parallel: Optional[int] = None
-    workers: Union[bool, int] = False
-    min_task_size: Optional[int] = None
-    batch_size: Optional[int] = None
+    parallel: int | None = None
+    workers: bool | int = False
+    min_task_size: int | None = None
+    batch_size: int | None = None
 
     def create_udf_table(self, query: Select) -> "Table":
         udf_output_columns: list[sqlalchemy.Column[Any]] = [
@@ -760,14 +750,14 @@ class RowGenerator(UDFStep):
 
     udf: "UDFAdapter"
     catalog: "Catalog"
-    partition_by: Optional[PartitionByType] = None
+    partition_by: PartitionByType | None = None
     is_generator = True
     # Parameters from Settings
     cache: bool = False
-    parallel: Optional[int] = None
-    workers: Union[bool, int] = False
-    min_task_size: Optional[int] = None
-    batch_size: Optional[int] = None
+    parallel: int | None = None
+    workers: bool | int = False
+    min_task_size: int | None = None
+    batch_size: int | None = None
 
     def create_udf_table(self, query: Select) -> "Table":
         warehouse = self.catalog.warehouse
@@ -814,7 +804,7 @@ class SQLClause(Step, ABC):
 
     def parse_cols(
         self,
-        cols: Sequence[Union[Function, ColumnElement]],
+        cols: Sequence[Function | ColumnElement],
     ) -> tuple[ColumnElement, ...]:
         return tuple(c.get_column() if isinstance(c, Function) else c for c in cols)
 
@@ -825,7 +815,7 @@ class SQLClause(Step, ABC):
 
 @frozen
 class SQLSelect(SQLClause):
-    args: tuple[Union[Function, ColumnElement], ...]
+    args: tuple[Function | ColumnElement, ...]
 
     def hash_inputs(self) -> str:
         return hash_column_elements(self.args)
@@ -844,7 +834,7 @@ class SQLSelect(SQLClause):
 
 @frozen
 class SQLSelectExcept(SQLClause):
-    args: tuple[Union[Function, ColumnElement], ...]
+    args: tuple[Function | ColumnElement, ...]
 
     def hash_inputs(self) -> str:
         return hash_column_elements(self.args)
@@ -890,7 +880,7 @@ class SQLMutate(SQLClause):
 
 @frozen
 class SQLFilter(SQLClause):
-    expressions: tuple[Union[Function, ColumnElement], ...]
+    expressions: tuple[Function | ColumnElement, ...]
 
     def hash_inputs(self) -> str:
         return hash_column_elements(self.expressions)
@@ -906,7 +896,7 @@ class SQLFilter(SQLClause):
 
 @frozen
 class SQLOrderBy(SQLClause):
-    args: tuple[Union[Function, ColumnElement], ...]
+    args: tuple[Function | ColumnElement, ...]
 
     def hash_inputs(self) -> str:
         return hash_column_elements(self.args)
@@ -1011,7 +1001,7 @@ class SQLJoin(Step):
     catalog: "Catalog"
     query1: "DatasetQuery"
     query2: "DatasetQuery"
-    predicates: Union[JoinPredicateType, tuple[JoinPredicateType, ...]]
+    predicates: JoinPredicateType | tuple[JoinPredicateType, ...]
     inner: bool
     full: bool
    rname: str
@@ -1150,8 +1140,8 @@ class SQLJoin(Step):
 
 @frozen
 class SQLGroupBy(SQLClause):
-    cols: Sequence[Union[str, Function, ColumnElement]]
-    group_by: Sequence[Union[str, Function, ColumnElement]]
+    cols: Sequence[str | Function | ColumnElement]
+    group_by: Sequence[str | Function | ColumnElement]
 
     def hash_inputs(self) -> str:
         return hashlib.sha256(
@@ -1211,6 +1201,7 @@ def _validate_columns(
                 missing_left,
             ],
             ["left", "right"],
+            strict=False,
        )
        if missing_columns
    ]
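
The `strict=False` added here (and to the `zip()` calls in `to_db_records` and `get_params` below) makes the long-standing truncation behavior explicit rather than changing it: since Python 3.10, `zip()` accepts a `strict` flag (PEP 618), and linters such as ruff (rule B905) flag calls that leave it implicit. For example:

```python
cols = ["id", "name"]
row = (1, "cat", "extra")

# strict=False is the default: zip stops at the shortest iterable,
# which is the behavior these call sites explicitly keep.
assert dict(zip(cols, row, strict=False)) == {"id": 1, "name": "cat"}

# strict=True would raise instead of silently dropping "extra".
try:
    dict(zip(cols, row, strict=True))
except ValueError:
    pass
```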
@@ -1243,32 +1234,32 @@ class DatasetQuery:
     def __init__(
         self,
         name: str,
-        version: Optional[str] = None,
-        project_name: Optional[str] = None,
-        namespace_name: Optional[str] = None,
-        catalog: Optional["Catalog"] = None,
-        session: Optional[Session] = None,
+        version: str | None = None,
+        project_name: str | None = None,
+        namespace_name: str | None = None,
+        catalog: "Catalog | None" = None,
+        session: Session | None = None,
         in_memory: bool = False,
         update: bool = False,
     ) -> None:
         self.session = Session.get(session, catalog=catalog, in_memory=in_memory)
         self.catalog = catalog or self.session.catalog
         self.steps: list[Step] = []
-        self._chunk_index: Optional[int] = None
-        self._chunk_total: Optional[int] = None
+        self._chunk_index: int | None = None
+        self._chunk_total: int | None = None
         self.temp_table_names: list[str] = []
         self.dependencies: set[DatasetDependencyType] = set()
         self.table = self.get_table()
-        self.starting_step: Optional[QueryStep] = None
-        self.name: Optional[str] = None
-        self.version: Optional[str] = None
-        self.feature_schema: Optional[dict] = None
-        self.column_types: Optional[dict[str, Any]] = None
+        self.starting_step: QueryStep | None = None
+        self.name: str | None = None
+        self.version: str | None = None
+        self.feature_schema: dict | None = None
+        self.column_types: dict[str, Any] | None = None
         self.before_steps: list[Callable] = []
-        self.listing_fn: Optional[Callable] = None
+        self.listing_fn: Callable | None = None
         self.update = update
 
-        self.list_ds_name: Optional[str] = None
+        self.list_ds_name: str | None = None
 
         self.name = name
         self.dialect = self.catalog.warehouse.db.dialect
@@ -1352,7 +1343,7 @@ class DatasetQuery:
         """
         return self.name is not None and self.version is not None
 
-    def c(self, column: Union[C, str]) -> "ColumnClause[Any]":
+    def c(self, column: C | str) -> "ColumnClause[Any]":
         col: sqlalchemy.ColumnClause = (
             sqlalchemy.column(column)
             if isinstance(column, str)
@@ -1447,6 +1438,7 @@ class DatasetQuery:
         # This is needed to always use a new connection with all metastore and warehouse
         # implementations, as errors may close or render unusable the existing
         # connections.
+        assert len(self.temp_table_names) == len(set(self.temp_table_names))
         with self.catalog.metastore.clone(use_new_connection=True) as metastore:
             metastore.cleanup_tables(self.temp_table_names)
         with self.catalog.warehouse.clone(use_new_connection=True) as warehouse:
@@ -1461,7 +1453,7 @@ class DatasetQuery:
             return list(result)
 
     def to_db_records(self) -> list[dict[str, Any]]:
-        return self.db_results(lambda cols, row: dict(zip(cols, row)))
+        return self.db_results(lambda cols, row: dict(zip(cols, row, strict=False)))
 
     @contextlib.contextmanager
     def as_iterable(self, **kwargs) -> Iterator[ResultIter]:
@@ -1500,7 +1492,7 @@ class DatasetQuery:
             yield from rows
 
         async def get_params(row: Sequence) -> tuple:
-            row_dict = RowDict(zip(query_fields, row))
+            row_dict = RowDict(zip(query_fields, row, strict=False))
             return tuple(  # noqa: C409
                 [
                     await p.get_value_async(
@@ -1540,6 +1532,7 @@ class DatasetQuery:
         obj.steps = obj.steps.copy()
         if new_table:
             obj.table = self.get_table()
+            obj.temp_table_names = []
         return obj
 
     @detach
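
Two behavioral tweaks are mixed into the typing cleanup here: `cleanup()` now asserts that `temp_table_names` holds no duplicates, and `clone(new_table=True)` gives the copy its own empty `temp_table_names` list. The clone appears to be a shallow copy (`obj.steps = obj.steps.copy()` guards the same hazard for steps), so without the reset the list would be shared, and a clone and its source could each try to drop the same temporary tables. A toy sketch of the aliasing hazard (hypothetical class, not datachain's API):

```python
import copy

class Query:
    def __init__(self) -> None:
        self.temp_table_names: list[str] = []

    def clone(self) -> "Query":
        obj = copy.copy(self)       # shallow copy: the list would be shared
        obj.temp_table_names = []   # the 0.34.7 fix: clone owns a fresh list
        return obj

q1 = Query()
q2 = q1.clone()
q2.temp_table_names.append("udf_tmp_1")
assert q1.temp_table_names == []    # without the reset, q1 would see it too
```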
@@ -1720,7 +1713,7 @@ class DatasetQuery:
     def join(
         self,
         dataset_query: "DatasetQuery",
-        predicates: Union[JoinPredicateType, Sequence[JoinPredicateType]],
+        predicates: JoinPredicateType | Sequence[JoinPredicateType],
         inner=False,
         full=False,
         rname="{name}_right",
@@ -1762,17 +1755,17 @@ class DatasetQuery:
     def add_signals(
         self,
         udf: "UDFAdapter",
-        partition_by: Optional[PartitionByType] = None,
+        partition_by: PartitionByType | None = None,
         # Parameters from Settings
         cache: bool = False,
-        parallel: Optional[int] = None,
-        workers: Union[bool, int] = False,
-        min_task_size: Optional[int] = None,
-        batch_size: Optional[int] = None,
+        parallel: int | None = None,
+        workers: bool | int = False,
+        min_task_size: int | None = None,
+        batch_size: int | None = None,
         # Parameters are unused, kept only to match the signature of Settings.to_dict
-        prefetch: Optional[int] = None,
-        namespace: Optional[str] = None,
-        project: Optional[str] = None,
+        prefetch: int | None = None,
+        namespace: str | None = None,
+        project: str | None = None,
     ) -> "Self":
         """
         Adds one or more signals based on the results from the provided UDF.
@@ -1813,17 +1806,17 @@ class DatasetQuery:
     def generate(
         self,
         udf: "UDFAdapter",
-        partition_by: Optional[PartitionByType] = None,
+        partition_by: PartitionByType | None = None,
         # Parameters from Settings
         cache: bool = False,
-        parallel: Optional[int] = None,
-        workers: Union[bool, int] = False,
-        min_task_size: Optional[int] = None,
-        batch_size: Optional[int] = None,
+        parallel: int | None = None,
+        workers: bool | int = False,
+        min_task_size: int | None = None,
+        batch_size: int | None = None,
         # Parameters are unused, kept only to match the signature of Settings.to_dict:
-        prefetch: Optional[int] = None,
-        namespace: Optional[str] = None,
-        project: Optional[str] = None,
+        prefetch: int | None = None,
+        namespace: str | None = None,
+        project: str | None = None,
     ) -> "Self":
         query = self.clone()
         steps = query.steps
@@ -1879,23 +1872,23 @@ class DatasetQuery:
 
     def exec(self) -> "Self":
         """Execute the query."""
+        query = self.clone()
         try:
-            query = self.clone()
             query.apply_steps()
         finally:
-            self.cleanup()
+            query.cleanup()
         return query
 
     def save(
         self,
-        name: Optional[str] = None,
-        version: Optional[str] = None,
-        project: Optional[Project] = None,
-        feature_schema: Optional[dict] = None,
-        dependencies: Optional[list[DatasetDependency]] = None,
-        description: Optional[str] = None,
-        attrs: Optional[list[str]] = None,
-        update_version: Optional[str] = "patch",
+        name: str | None = None,
+        version: str | None = None,
+        project: Project | None = None,
+        feature_schema: dict | None = None,
+        dependencies: list[DatasetDependency] | None = None,
+        description: str | None = None,
+        attrs: list[str] | None = None,
+        update_version: str | None = "patch",
         **kwargs,
     ) -> "Self":
         """Save the query as a dataset."""
@@ -1989,5 +1982,5 @@ class DatasetQuery:
         return isinstance(self.last_step, SQLOrderBy)
 
     @property
-    def last_step(self) -> Optional[Step]:
+    def last_step(self) -> Step | None:
         return self.steps[-1] if self.steps else None
datachain/query/dispatch.py CHANGED
@@ -3,9 +3,8 @@ from collections.abc import Iterable, Sequence
 from itertools import chain
 from multiprocessing import cpu_count
 from sys import stdin
-from typing import TYPE_CHECKING, Literal, Optional
+from typing import TYPE_CHECKING, Literal
 
-import multiprocess
 from cloudpickle import load, loads
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from multiprocess import get_context
@@ -27,6 +26,7 @@ from datachain.query.utils import get_query_id_column
 from datachain.utils import batched, flatten, safe_closing
 
 if TYPE_CHECKING:
+    import multiprocess
     from sqlalchemy import Select, Table
 
     from datachain.data_storage import AbstractMetastore, AbstractWarehouse
@@ -41,7 +41,7 @@ FAILED_STATUS = "FAILED"
 NOTIFY_STATUS = "NOTIFY"
 
 
-def get_n_workers_from_arg(n_workers: Optional[int] = None) -> int:
+def get_n_workers_from_arg(n_workers: int | None = None) -> int:
     if not n_workers:
         return cpu_count()
     if n_workers < 1:
@@ -86,7 +86,7 @@ def udf_entrypoint() -> int:
     return 0
 
 
-def udf_worker_entrypoint(fd: Optional[int] = None) -> int:
+def udf_worker_entrypoint(fd: int | None = None) -> int:
     if not (udf_distributor_class := get_udf_distributor_class()):
         raise RuntimeError(
             f"{DISTRIBUTED_IMPORT_PATH} import path is required "
@@ -97,9 +97,9 @@ def udf_worker_entrypoint(fd: Optional[int] = None) -> int:
 
 
 class UDFDispatcher:
-    _catalog: Optional[Catalog] = None
-    task_queue: Optional[multiprocess.Queue] = None
-    done_queue: Optional[multiprocess.Queue] = None
+    _catalog: Catalog | None = None
+    task_queue: "multiprocess.Queue | None" = None
+    done_queue: "multiprocess.Queue | None" = None
 
     def __init__(self, udf_info: UdfInfo, buffer_size: int = DEFAULT_BATCH_SIZE):
         self.udf_data = udf_info["udf_data"]
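
`import multiprocess` now lives under `if TYPE_CHECKING:`, so the module-level name `multiprocess` is no longer bound at runtime (the runtime code keeps only `from multiprocess import get_context`). Annotations that name the module must therefore be quoted, since class-body annotations are otherwise evaluated when the class is created. The pattern in isolation:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by mypy/pyright only; never executed at runtime.
    import multiprocess

class Dispatcher:
    # Quoted: stored as a string in __annotations__ and never evaluated,
    # so the unbound runtime name cannot raise NameError.
    task_queue: "multiprocess.Queue | None" = None
```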
datachain/query/metrics.py CHANGED
@@ -1,10 +1,9 @@
 import os
-from typing import Optional, Union
 
-metrics: dict[str, Union[str, int, float, bool, None]] = {}
+metrics: dict[str, str | int | float | bool | None] = {}
 
 
-def set(key: str, value: Union[str, int, float, bool, None]) -> None:  # noqa: PYI041
+def set(key: str, value: str | int | float | bool | None) -> None:  # noqa: PYI041
     """Set a metric value."""
     if not isinstance(key, str):
         raise TypeError("Key must be a string")
@@ -21,6 +20,6 @@ def set(key: str, value: Union[str, int, float, bool, None]) -> None:  # noqa: PYI041
     metastore.update_job(job_id, metrics=metrics)
 
 
-def get(key: str) -> Optional[Union[str, int, float, bool]]:
+def get(key: str) -> str | int | float | bool | None:
     """Get a metric value."""
     return metrics[key]
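
The retained `# noqa: PYI041` suppresses ruff's redundant-numeric-union rule, which would collapse `int | float` to plain `float` on the grounds that type checkers already accept `int` wherever `float` is annotated; here the union is presumably kept explicit because the values are stored and read back, not merely accepted. What the rule objects to, in isolation (assuming ruff with PYI041 enabled):

```python
# PYI041 flags int | float in parameters, since the numeric tower
# already lets an int be passed where float is expected.
def scale(factor: int | float) -> float:  # noqa: PYI041
    return float(factor) * 2.0

assert scale(2) == scale(2.0) == 4.0
```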
datachain/query/params.py CHANGED
@@ -1,11 +1,10 @@
 import json
 import os
-from typing import Optional
 
-params_cache: Optional[dict[str, str]] = None
+params_cache: dict[str, str] | None = None
 
 
-def param(key: str, default: Optional[str] = None) -> Optional[str]:
+def param(key: str, default: str | None = None) -> str | None:
     """Get query parameter."""
     if not isinstance(key, str):
         raise TypeError("Param key must be a string")
datachain/query/schema.py CHANGED
@@ -1,7 +1,8 @@
 import functools
 from abc import ABC, abstractmethod
+from collections.abc import Callable
 from fnmatch import fnmatch
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any
 
 import attrs
 import sqlalchemy as sa
@@ -42,7 +43,7 @@ class ColumnMeta(type):
 
 
 class Column(sa.ColumnClause, metaclass=ColumnMeta):
-    inherit_cache: Optional[bool] = True
+    inherit_cache: bool | None = True
 
     def __init__(self, text, type_=None, is_literal=False, _selectable=None):
         """Dataset column."""
@@ -177,7 +178,7 @@ class LocalFilename(UDFParameter):
     otherwise None will be returned.
     """
 
-    glob: Optional[str] = None
+    glob: str | None = None
 
     def get_value(
         self,
@@ -186,7 +187,7 @@ class LocalFilename(UDFParameter):
         *,
         cb: Callback = DEFAULT_CALLBACK,
         **kwargs,
-    ) -> Optional[str]:
+    ) -> str | None:
         if self.glob and not fnmatch(row["name"], self.glob):  # type: ignore[type-var]
             # If the glob pattern is specified and the row filename
             # does not match it, then return None
@@ -205,7 +206,7 @@ class LocalFilename(UDFParameter):
         cache: bool = False,
         cb: Callback = DEFAULT_CALLBACK,
         **kwargs,
-    ) -> Optional[str]:
+    ) -> str | None:
         if self.glob and not fnmatch(row["name"], self.glob):  # type: ignore[type-var]
             # If the glob pattern is specified and the row filename
             # does not match it, then return None
@@ -216,7 +217,7 @@ class LocalFilename(UDFParameter):
         return client.cache.get_path(file)
 
 
-UDFParamSpec = Union[str, Column, UDFParameter]
+UDFParamSpec = str | Column | UDFParameter
 
 
 def normalize_param(param: UDFParamSpec) -> UDFParameter:
datachain/query/session.py CHANGED
@@ -3,7 +3,7 @@ import gc
 import logging
 import re
 import sys
-from typing import TYPE_CHECKING, ClassVar, Optional
+from typing import TYPE_CHECKING, ClassVar
 from uuid import uuid4
 
 from datachain.catalog import get_catalog
@@ -39,7 +39,7 @@ class Session:
         catalog (Catalog): Catalog object.
     """
 
-    GLOBAL_SESSION_CTX: Optional["Session"] = None
+    GLOBAL_SESSION_CTX: "Session | None" = None
     SESSION_CONTEXTS: ClassVar[list["Session"]] = []
     ORIGINAL_EXCEPT_HOOK = None
 
@@ -51,8 +51,8 @@ class Session:
     def __init__(
         self,
         name="",
-        catalog: Optional["Catalog"] = None,
-        client_config: Optional[dict] = None,
+        catalog: "Catalog | None" = None,
+        client_config: dict | None = None,
         in_memory: bool = False,
     ):
         if re.match(r"^[0-9a-zA-Z]*$", name) is None:
@@ -126,9 +126,9 @@ class Session:
     @classmethod
     def get(
         cls,
-        session: Optional["Session"] = None,
-        catalog: Optional["Catalog"] = None,
-        client_config: Optional[dict] = None,
+        session: "Session | None" = None,
+        catalog: "Catalog | None" = None,
+        client_config: dict | None = None,
         in_memory: bool = False,
     ) -> "Session":
         """Creates a Session() object from a catalog.
datachain/query/udf.py CHANGED
@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, Callable, Optional, TypedDict, Union
+from collections.abc import Callable
+from typing import TYPE_CHECKING, Any, TypedDict
 
 if TYPE_CHECKING:
     from sqlalchemy import Select, Table
@@ -17,7 +18,7 @@ class UdfInfo(TypedDict):
     query: "Select"
     udf_fields: list[str]
     batching: "BatchingStrategy"
-    processes: Optional[int]
+    processes: int | None
     is_generator: bool
     cache: bool
     rows_total: int
@@ -33,14 +34,14 @@ class AbstractUDFDistributor(ABC):
         query: "Select",
         udf_data: bytes,
         batching: "BatchingStrategy",
-        workers: Union[bool, int],
-        processes: Union[bool, int],
+        workers: bool | int,
+        processes: bool | int,
         udf_fields: list[str],
         rows_total: int,
         use_cache: bool,
         is_generator: bool = False,
-        min_task_size: Optional[Union[str, int]] = None,
-        batch_size: Optional[int] = None,
+        min_task_size: str | int | None = None,
+        batch_size: int | None = None,
     ) -> None: ...
 
     @abstractmethod
@@ -48,4 +49,4 @@ class AbstractUDFDistributor(ABC):
 
     @staticmethod
     @abstractmethod
-    def run_udf(fd: Optional[int] = None) -> int: ...
+    def run_udf(fd: int | None = None) -> int: ...
datachain/query/utils.py CHANGED
@@ -1,8 +1,6 @@
-from typing import Optional, Union
-
 import sqlalchemy as sa
 
-ColT = Union[sa.ColumnClause, sa.Column, sa.ColumnElement, sa.TextClause, sa.Label]
+ColT = sa.ColumnClause | sa.Column | sa.ColumnElement | sa.TextClause | sa.Label
 
 
 def column_name(col: ColT) -> str:
@@ -14,12 +12,12 @@ def column_name(col: ColT) -> str:
     )
 
 
-def get_query_column(query: sa.Select, name: str) -> Optional[ColT]:
+def get_query_column(query: sa.Select, name: str) -> ColT | None:
     """Returns column element from query by name or None if column not found."""
     return next((col for col in query.inner_columns if column_name(col) == name), None)
 
 
-def get_query_id_column(query: sa.Select) -> Optional[sa.ColumnElement]:
+def get_query_id_column(query: sa.Select) -> sa.ColumnElement | None:
     """Returns ID column element from query or None if column not found."""
     col = get_query_column(query, "sys__id")
     return col if col is not None and isinstance(col, sa.ColumnElement) else None