datachain 0.30.5__py3-none-any.whl → 0.39.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. datachain/__init__.py +4 -0
  2. datachain/asyn.py +11 -12
  3. datachain/cache.py +5 -5
  4. datachain/catalog/__init__.py +0 -2
  5. datachain/catalog/catalog.py +276 -354
  6. datachain/catalog/dependency.py +164 -0
  7. datachain/catalog/loader.py +8 -3
  8. datachain/checkpoint.py +43 -0
  9. datachain/cli/__init__.py +10 -17
  10. datachain/cli/commands/__init__.py +1 -8
  11. datachain/cli/commands/datasets.py +42 -27
  12. datachain/cli/commands/ls.py +15 -15
  13. datachain/cli/commands/show.py +2 -2
  14. datachain/cli/parser/__init__.py +3 -43
  15. datachain/cli/parser/job.py +1 -1
  16. datachain/cli/parser/utils.py +1 -2
  17. datachain/cli/utils.py +2 -15
  18. datachain/client/azure.py +2 -2
  19. datachain/client/fsspec.py +34 -23
  20. datachain/client/gcs.py +3 -3
  21. datachain/client/http.py +157 -0
  22. datachain/client/local.py +11 -7
  23. datachain/client/s3.py +3 -3
  24. datachain/config.py +4 -8
  25. datachain/data_storage/db_engine.py +12 -6
  26. datachain/data_storage/job.py +2 -0
  27. datachain/data_storage/metastore.py +716 -137
  28. datachain/data_storage/schema.py +20 -27
  29. datachain/data_storage/serializer.py +105 -15
  30. datachain/data_storage/sqlite.py +114 -114
  31. datachain/data_storage/warehouse.py +140 -48
  32. datachain/dataset.py +109 -89
  33. datachain/delta.py +117 -42
  34. datachain/diff/__init__.py +25 -33
  35. datachain/error.py +24 -0
  36. datachain/func/aggregate.py +9 -11
  37. datachain/func/array.py +12 -12
  38. datachain/func/base.py +7 -4
  39. datachain/func/conditional.py +9 -13
  40. datachain/func/func.py +63 -45
  41. datachain/func/numeric.py +5 -7
  42. datachain/func/string.py +2 -2
  43. datachain/hash_utils.py +123 -0
  44. datachain/job.py +11 -7
  45. datachain/json.py +138 -0
  46. datachain/lib/arrow.py +18 -15
  47. datachain/lib/audio.py +60 -59
  48. datachain/lib/clip.py +14 -13
  49. datachain/lib/convert/python_to_sql.py +6 -10
  50. datachain/lib/convert/values_to_tuples.py +151 -53
  51. datachain/lib/data_model.py +23 -19
  52. datachain/lib/dataset_info.py +7 -7
  53. datachain/lib/dc/__init__.py +2 -1
  54. datachain/lib/dc/csv.py +22 -26
  55. datachain/lib/dc/database.py +37 -34
  56. datachain/lib/dc/datachain.py +518 -324
  57. datachain/lib/dc/datasets.py +38 -30
  58. datachain/lib/dc/hf.py +16 -20
  59. datachain/lib/dc/json.py +17 -18
  60. datachain/lib/dc/listings.py +5 -8
  61. datachain/lib/dc/pandas.py +3 -6
  62. datachain/lib/dc/parquet.py +33 -21
  63. datachain/lib/dc/records.py +9 -13
  64. datachain/lib/dc/storage.py +103 -65
  65. datachain/lib/dc/storage_pattern.py +251 -0
  66. datachain/lib/dc/utils.py +17 -14
  67. datachain/lib/dc/values.py +3 -6
  68. datachain/lib/file.py +187 -50
  69. datachain/lib/hf.py +7 -5
  70. datachain/lib/image.py +13 -13
  71. datachain/lib/listing.py +5 -5
  72. datachain/lib/listing_info.py +1 -2
  73. datachain/lib/meta_formats.py +2 -3
  74. datachain/lib/model_store.py +20 -8
  75. datachain/lib/namespaces.py +59 -7
  76. datachain/lib/projects.py +51 -9
  77. datachain/lib/pytorch.py +31 -23
  78. datachain/lib/settings.py +188 -85
  79. datachain/lib/signal_schema.py +302 -64
  80. datachain/lib/text.py +8 -7
  81. datachain/lib/udf.py +103 -63
  82. datachain/lib/udf_signature.py +59 -34
  83. datachain/lib/utils.py +20 -0
  84. datachain/lib/video.py +3 -4
  85. datachain/lib/webdataset.py +31 -36
  86. datachain/lib/webdataset_laion.py +15 -16
  87. datachain/listing.py +12 -5
  88. datachain/model/bbox.py +3 -1
  89. datachain/namespace.py +22 -3
  90. datachain/node.py +6 -6
  91. datachain/nodes_thread_pool.py +0 -1
  92. datachain/plugins.py +24 -0
  93. datachain/project.py +4 -4
  94. datachain/query/batch.py +10 -12
  95. datachain/query/dataset.py +376 -194
  96. datachain/query/dispatch.py +112 -84
  97. datachain/query/metrics.py +3 -4
  98. datachain/query/params.py +2 -3
  99. datachain/query/queue.py +2 -1
  100. datachain/query/schema.py +7 -6
  101. datachain/query/session.py +190 -33
  102. datachain/query/udf.py +9 -6
  103. datachain/remote/studio.py +90 -53
  104. datachain/script_meta.py +12 -12
  105. datachain/sql/sqlite/base.py +37 -25
  106. datachain/sql/sqlite/types.py +1 -1
  107. datachain/sql/types.py +36 -5
  108. datachain/studio.py +49 -40
  109. datachain/toolkit/split.py +31 -10
  110. datachain/utils.py +39 -48
  111. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/METADATA +26 -38
  112. datachain-0.39.0.dist-info/RECORD +173 -0
  113. datachain/cli/commands/query.py +0 -54
  114. datachain/query/utils.py +0 -36
  115. datachain-0.30.5.dist-info/RECORD +0 -168
  116. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/WHEEL +0 -0
  117. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
  118. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
  119. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/lib/convert/values_to_tuples.py CHANGED
@@ -1,63 +1,177 @@
 import itertools
 from collections.abc import Sequence
-from typing import Any, Union
-
-from datachain.lib.data_model import (
-    DataType,
-    DataTypeNames,
-    DataValue,
-    is_chain_type,
-)
+from typing import Any
+
+from datachain.lib.data_model import DataType, DataTypeNames, DataValue, is_chain_type
 from datachain.lib.utils import DataChainParamsError
 
 
 class ValuesToTupleError(DataChainParamsError):
     def __init__(self, ds_name: str, msg: str):
+        self.ds_name = ds_name
+        self.msg = msg
+
         if ds_name:
             ds_name = f"' {ds_name}'"
+
         super().__init__(f"Cannot convert signals for dataset{ds_name}: {msg}")
 
+    def __reduce__(self):
+        return ValuesToTupleError, (self.ds_name, self.msg)
 
-def values_to_tuples(  # noqa: C901, PLR0912
-    ds_name: str = "",
-    output: Union[None, DataType, Sequence[str], dict[str, DataType]] = None,
-    **fr_map: Sequence[DataValue],
-) -> tuple[Any, Any, Any]:
-    if output:
-        if not isinstance(output, (Sequence, str, dict)):
-            if len(fr_map) != 1:
-                raise ValuesToTupleError(
-                    ds_name,
-                    f"only one output type was specified, {len(fr_map)} expected",
-                )
-            if not isinstance(output, type):
-                raise ValuesToTupleError(
-                    ds_name,
-                    f"output must specify a type while '{output}' was given",
-                )
 
-            key: str = next(iter(fr_map.keys()))
-            output = {key: output}  # type: ignore[dict-item]
+def _find_first_non_none(sequence: Sequence[Any]) -> Any | None:
+    """Find the first non-None element in a sequence."""
+    try:
+        return next(itertools.dropwhile(lambda i: i is None, sequence))
+    except StopIteration:
+        return None
+
+
+def _infer_list_item_type(lst: list) -> type:
+    """Infer the item type of a list, handling None values and nested lists."""
+    if len(lst) == 0:
+        # Default to str when list is empty to avoid generic list
+        return str
+
+    first_item = _find_first_non_none(lst)
+    if first_item is None:
+        # Default to str when all items are None
+        return str
+
+    item_type = type(first_item)
+
+    # Handle nested lists one level deep
+    if isinstance(first_item, list) and len(first_item) > 0:
+        nested_item = _find_first_non_none(first_item)
+        if nested_item is not None:
+            return list[type(nested_item)]  # type: ignore[misc, return-value]
+        # Default to str for nested lists with all None
+        return list[str]  # type: ignore[return-value]
+
+    return item_type
+
+
+def _infer_dict_value_type(dct: dict) -> type:
+    """Infer the value type of a dict, handling None values and list values."""
+    if len(dct) == 0:
+        # Default to str when dict is empty to avoid generic dict values
+        return str
+
+    # Find first non-None value
+    first_value = None
+    for val in dct.values():
+        if val is not None:
+            first_value = val
+            break
+
+    if first_value is None:
+        # Default to str when all values are None
+        return str
+
+    # Handle list values
+    if isinstance(first_value, list) and len(first_value) > 0:
+        list_item = _find_first_non_none(first_value)
+        if list_item is not None:
+            return list[type(list_item)]  # type: ignore[misc, return-value]
+        # Default to str for lists with all None
+        return list[str]  # type: ignore[return-value]
+
+    return type(first_value)
+
+
+def _infer_type_from_sequence(
+    sequence: Sequence[DataValue], signal_name: str, ds_name: str
+) -> type:
+    """
+    Infer the type from a sequence of values.
+
+    Returns str if all values are None, otherwise infers from the first non-None value.
+    Handles lists and dicts with proper type inference for nested structures.
+    """
+    first_element = _find_first_non_none(sequence)
+
+    if first_element is None:
+        # Default to str if column is empty or all values are None
+        return str
 
-    if not isinstance(output, dict):
+    typ = type(first_element)
+
+    if not is_chain_type(typ):
+        raise ValuesToTupleError(
+            ds_name,
+            f"signal '{signal_name}' has unsupported type '{typ.__name__}'."
+            f" Please use DataModel types: {DataTypeNames}",
+        )
+
+    if isinstance(first_element, list):
+        item_type = _infer_list_item_type(first_element)
+        return list[item_type]  # type: ignore[valid-type, return-value]
+
+    if isinstance(first_element, dict):
+        # If the first dict is empty, use str as default key/value types
+        if len(first_element) == 0:
+            return dict[str, str]  # type: ignore[return-value]
+        first_key = next(iter(first_element.keys()))
+        value_type = _infer_dict_value_type(first_element)
+        return dict[type(first_key), value_type]  # type: ignore[misc, return-value]
+
+    return typ
+
+
+def _validate_and_normalize_output(
+    output: DataType | Sequence[str] | dict[str, DataType] | None,
+    fr_map: dict[str, Sequence[DataValue]],
+    ds_name: str,
+) -> dict[str, DataType] | None:
+    """Validate and normalize the output parameter to a dict format."""
+    if not output:
+        return None
+
+    if not isinstance(output, (Sequence, str, dict)):
+        if len(fr_map) != 1:
             raise ValuesToTupleError(
                 ds_name,
-                "output type must be dict[str, DataType] while "
-                f"'{type(output).__name__}' is given",
+                f"only one output type was specified, {len(fr_map)} expected",
             )
-
-        if len(output) != len(fr_map):
+        if not isinstance(output, type):
             raise ValuesToTupleError(
                 ds_name,
-                f"number of outputs '{len(output)}' should match"
-                f" number of signals '{len(fr_map)}'",
+                f"output must specify a type while '{output}' was given",
             )
 
+        key: str = next(iter(fr_map.keys()))
+        return {key: output}  # type: ignore[dict-item]
+
+    if not isinstance(output, dict):
+        raise ValuesToTupleError(
+            ds_name,
+            "output type must be dict[str, DataType] while "
+            f"'{type(output).__name__}' is given",
+        )
+
+    if len(output) != len(fr_map):
+        raise ValuesToTupleError(
+            ds_name,
+            f"number of outputs '{len(output)}' should match"
+            f" number of signals '{len(fr_map)}'",
+        )
+
+    return output  # type: ignore[return-value]
+
+
+def values_to_tuples(
+    ds_name: str = "",
+    output: DataType | Sequence[str] | dict[str, DataType] | None = None,
+    **fr_map: Sequence[DataValue],
+) -> tuple[Any, Any, Any]:
+    output = _validate_and_normalize_output(output, fr_map, ds_name)
+
     types_map: dict[str, type] = {}
     length = -1
     for k, v in fr_map.items():
         if not isinstance(v, Sequence) or isinstance(v, str):  # type: ignore[unreachable]
-            raise ValuesToTupleError(ds_name, f"signals '{k}' is not a sequence")
+            raise ValuesToTupleError(ds_name, f"signal '{k}' is not a sequence")
         len_ = len(v)
 
         if output:
@@ -70,23 +184,7 @@ def values_to_tuples(  # noqa: C901, PLR0912
             # FIXME: Stops as soon as it finds the first non-None value.
             # If a non-None value appears early, it won't check the remaining items for
             # `None` values.
-            try:
-                first_not_none_element = next(
-                    itertools.dropwhile(lambda i: i is None, v)
-                )
-            except StopIteration:
-                # set default type to `str` if column is empty or all values are `None`
-                typ = str
-            else:
-                typ = type(first_not_none_element)  # type: ignore[assignment]
-                if not is_chain_type(typ):
-                    raise ValuesToTupleError(
-                        ds_name,
-                        f"signal '{k}' has unsupported type '{typ.__name__}'."
-                        f" Please use DataModel types: {DataTypeNames}",
-                    )
-                if isinstance(first_not_none_element, list):
-                    typ = list[type(first_not_none_element[0])]  # type: ignore[assignment, misc]
+            typ = _infer_type_from_sequence(v, k, ds_name)
         types_map[k] = typ
 
         if length < 0:
@@ -111,7 +209,7 @@ def values_to_tuples(  # noqa: C901, PLR0912
     if len(output) > 1:  # type: ignore[arg-type]
         tuple_type = tuple(output_types)
         res_type = tuple[tuple_type]  # type: ignore[valid-type]
-        res_values: Sequence[Any] = list(zip(*fr_map.values()))
+        res_values: Sequence[Any] = list(zip(*fr_map.values(), strict=False))
     else:
         res_type = output_types[0]  # type: ignore[misc]
         res_values = next(iter(fr_map.values()))
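The split into `_find_first_non_none`, `_infer_list_item_type`, `_infer_dict_value_type`, and `_infer_type_from_sequence` is behavioral, not just cosmetic: all-`None` and empty columns fall back to `str`, `None` entries no longer poison list item types, and dict columns get typed keys and values instead of a bare `dict`. The new `__reduce__` on `ValuesToTupleError` also lets the exception round-trip through pickle, presumably so it survives crossing worker-process boundaries. A minimal sketch of the inference behavior, read off the hunks above (these are private helpers, so the import path may change):

```python
from datachain.lib.convert.values_to_tuples import _infer_type_from_sequence

# All values None: falls back to str instead of failing.
assert _infer_type_from_sequence([None, None], "labels", "demo") is str

# None entries inside a list no longer poison the inferred item type.
assert _infer_type_from_sequence([[None, 2]], "ids", "demo") == list[int]

# Dict columns now get typed keys/values; the old code returned a bare dict.
assert _infer_type_from_sequence([{"a": 1.0}], "meta", "demo") == dict[str, float]
```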
datachain/lib/data_model.py CHANGED
@@ -1,8 +1,9 @@
 import inspect
+import types
 import uuid
 from collections.abc import Sequence
 from datetime import datetime
-from typing import ClassVar, Optional, Union, get_args, get_origin
+from typing import ClassVar, Union, get_args, get_origin
 
 from pydantic import AliasChoices, BaseModel, Field, create_model
 from pydantic.fields import FieldInfo
@@ -10,19 +11,19 @@ from pydantic.fields import FieldInfo
 from datachain.lib.model_store import ModelStore
 from datachain.lib.utils import normalize_col_names
 
-StandardType = Union[
-    type[int],
-    type[str],
-    type[float],
-    type[bool],
-    type[list],
-    type[dict],
-    type[bytes],
-    type[datetime],
-]
-DataType = Union[type[BaseModel], StandardType]
+StandardType = (
+    type[int]
+    | type[str]
+    | type[float]
+    | type[bool]
+    | type[list]
+    | type[dict]
+    | type[bytes]
+    | type[datetime]
+)
+DataType = type[BaseModel] | StandardType
 DataTypeNames = "BaseModel, int, str, float, bool, list, dict, bytes, datetime"
-DataValue = Union[BaseModel, int, str, float, bool, list, dict, bytes, datetime]
+DataValue = BaseModel | int | str | float | bool | list | dict | bytes | datetime
 
 
 class DataModel(BaseModel):
@@ -37,7 +38,7 @@ class DataModel(BaseModel):
         ModelStore.register(cls)
 
     @staticmethod
-    def register(models: Union[DataType, Sequence[DataType]]):
+    def register(models: DataType | Sequence[DataType]):
         """For registering classes manually. It accepts a single class or a sequence of
         classes."""
         if not isinstance(models, Sequence):
@@ -63,8 +64,11 @@ def is_chain_type(t: type) -> bool:
     if orig is list and len(args) == 1:
         return is_chain_type(get_args(t)[0])
 
-    if orig is Union and len(args) == 2 and (type(None) in args):
-        return is_chain_type(args[0])
+    if orig is dict and len(args) == 2:
+        return is_chain_type(args[0]) and is_chain_type(args[1])
+
+    if orig in (Union, types.UnionType) and len(args) == 2 and (type(None) in args):
+        return is_chain_type(args[0] if args[1] is type(None) else args[1])
 
     return False
 
@@ -72,19 +76,19 @@ def is_chain_type(t: type) -> bool:
 def dict_to_data_model(
     name: str,
     data_dict: dict[str, DataType],
-    original_names: Optional[list[str]] = None,
+    original_names: list[str] | None = None,
 ) -> type[BaseModel]:
     if not original_names:
         # Gets a map of a normalized_name -> original_name
         columns = normalize_col_names(list(data_dict))
-        data_dict = dict(zip(columns.keys(), data_dict.values()))
+        data_dict = dict(zip(columns.keys(), data_dict.values(), strict=False))
        original_names = list(columns.values())
 
     fields = {
         name: (
             anno
             if inspect.isclass(anno) and issubclass(anno, BaseModel)
-            else Optional[anno],
+            else anno | None,
             Field(
                 validation_alias=AliasChoices(name, original_names[idx] or name),
                 default=None,
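Beyond the mechanical `Union`-to-`|` rewrite, the `is_chain_type` hunks change acceptance rules: `dict[K, V]` annotations are now validated recursively, PEP 604 unions (`types.UnionType`) are recognized alongside `typing.Union`, and an optional whose `None` arm comes first is checked against its non-`None` arm rather than against `NoneType`. A quick sketch of the behavior implied by the hunks above:

```python
from typing import Optional

from datachain.lib.data_model import is_chain_type

assert is_chain_type(dict[str, int])   # dict args are now checked recursively
assert is_chain_type(int | None)       # PEP 604 optional (types.UnionType)
assert is_chain_type(None | int)       # None-first optionals resolve correctly now
assert is_chain_type(Optional[float])  # typing.Optional still accepted
```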
datachain/lib/dataset_info.py CHANGED
@@ -1,10 +1,10 @@
-import json
 from datetime import datetime
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any
 from uuid import uuid4
 
 from pydantic import Field, field_validator
 
+from datachain import json
 from datachain.dataset import (
     DEFAULT_DATASET_VERSION,
     DatasetListRecord,
@@ -28,9 +28,9 @@ class DatasetInfo(DataModel):
     version: str = Field(default=DEFAULT_DATASET_VERSION)
     status: int = Field(default=DatasetStatus.CREATED)
     created_at: datetime = Field(default=TIME_ZERO)
-    finished_at: Optional[datetime] = Field(default=None)
-    num_objects: Optional[int] = Field(default=None)
-    size: Optional[int] = Field(default=None)
+    finished_at: datetime | None = Field(default=None)
+    num_objects: int | None = Field(default=None)
+    size: int | None = Field(default=None)
     params: dict[str, str] = Field(default={})
     metrics: dict[str, Any] = Field(default={})
     error_message: str = Field(default="")
@@ -59,7 +59,7 @@ class DatasetInfo(DataModel):
 
     @staticmethod
     def _validate_dict(
-        v: Optional[Union[str, dict]],
+        v: str | dict | None,
     ) -> dict:
         if v is None or v == "":
             return {}
@@ -88,7 +88,7 @@ class DatasetInfo(DataModel):
         cls,
         dataset: DatasetListRecord,
         version: DatasetListVersion,
-        job: Optional[Job],
+        job: Job | None,
     ) -> "Self":
         return cls(
             uuid=version.uuid,
datachain/lib/dc/__init__.py CHANGED
@@ -9,7 +9,7 @@ from .pandas import read_pandas
 from .parquet import read_parquet
 from .records import read_records
 from .storage import read_storage
-from .utils import DatasetMergeError, DatasetPrepareError, Sys, is_studio
+from .utils import DatasetMergeError, DatasetPrepareError, Sys, is_local, is_studio
 from .values import read_values
 
 __all__ = [
@@ -21,6 +21,7 @@ __all__ = [
     "Sys",
     "datasets",
     "delete_dataset",
+    "is_local",
     "is_studio",
     "listings",
     "move_dataset",
datachain/lib/dc/csv.py CHANGED
@@ -1,10 +1,6 @@
-from collections.abc import Sequence
-from typing import (
-    TYPE_CHECKING,
-    Callable,
-    Optional,
-    Union,
-)
+import os
+from collections.abc import Callable, Sequence
+from typing import TYPE_CHECKING
 
 from datachain.lib.dc.utils import DatasetPrepareError, OutputType
 from datachain.lib.model_store import ModelStore
@@ -17,38 +13,38 @@ if TYPE_CHECKING:
 
 
 def read_csv(
-    path,
-    delimiter: Optional[str] = None,
+    path: str | os.PathLike[str] | list[str] | list[os.PathLike[str]],
+    delimiter: str | None = None,
     header: bool = True,
     output: OutputType = None,
     column: str = "",
     model_name: str = "",
     source: bool = True,
-    nrows=None,
-    session: Optional[Session] = None,
-    settings: Optional[dict] = None,
-    column_types: Optional[dict[str, "Union[str, ArrowDataType]"]] = None,
-    parse_options: Optional[dict[str, "Union[str, Union[bool, Callable]]"]] = None,
+    nrows: int | None = None,
+    session: Session | None = None,
+    settings: dict | None = None,
+    column_types: dict[str, "str | ArrowDataType"] | None = None,
+    parse_options: dict[str, str | bool | Callable] | None = None,
     **kwargs,
 ) -> "DataChain":
     """Generate chain from csv files.
 
     Parameters:
-        path : Storage URI with directory. URI must start with storage prefix such
+        path: Storage URI with directory. URI must start with storage prefix such
             as `s3://`, `gs://`, `az://` or "file:///".
-        delimiter : Character for delimiting columns. Takes precedence if also
+        delimiter: Character for delimiting columns. Takes precedence if also
            specified in `parse_options`. Defaults to ",".
-        header : Whether the files include a header row.
-        output : Dictionary or feature class defining column names and their
+        header: Whether the files include a header row.
+        output: Dictionary or feature class defining column names and their
            corresponding types. List of column names is also accepted, in which
            case types will be inferred.
-        column : Created column name.
-        model_name : Generated model name.
-        source : Whether to include info about the source file.
-        nrows : Optional row limit.
-        session : Session to use for the chain.
-        settings : Settings to use for the chain.
-        column_types : Dictionary of column names and their corresponding types.
+        column: Created column name.
+        model_name: Generated model name.
+        source: Whether to include info about the source file.
+        nrows: Optional row limit.
+        session: Session to use for the chain.
+        settings: Settings to use for the chain.
+        column_types: Dictionary of column names and their corresponding types.
            It is passed to CSV reader and for each column specified type auto
            inference is disabled.
        parse_options: Tells the parser how to process lines.
@@ -67,7 +63,7 @@ def read_csv(
         chain = dc.read_csv("s3://mybucket/dir")
         ```
     """
-    from pandas.io.parsers.readers import STR_NA_VALUES
+    from pandas._libs.parsers import STR_NA_VALUES
     from pyarrow.csv import ConvertOptions, ParseOptions, ReadOptions
     from pyarrow.dataset import CsvFileFormat
     from pyarrow.lib import type_for_alias
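The `read_csv` rework is mostly typing: `path` now explicitly accepts strings, `os.PathLike`, or lists of either, `nrows` is an `int | None`, and the docstring's `name :` style is normalized to `name:`. The `STR_NA_VALUES` import also moves to `pandas._libs.parsers`, presumably because newer pandas releases drop the `pandas.io.parsers.readers` re-export. A usage sketch against the new signature (bucket, column names, and option values are illustrative, not from the diff):

```python
import datachain as dc

# Hypothetical bucket and columns; `column_types` feeds pyarrow's
# ConvertOptions and `parse_options` feeds ParseOptions, per the diff above.
chain = dc.read_csv(
    ["s3://mybucket/2023/", "s3://mybucket/2024/"],  # list paths now type-check
    delimiter=";",
    column_types={"price": "float64"},  # disables inference for this column
    parse_options={"ignore_empty_lines": True},
    nrows=1_000,
)
```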
datachain/lib/dc/database.py CHANGED
@@ -2,7 +2,8 @@ import contextlib
 import itertools
 import os
 import sqlite3
-from typing import TYPE_CHECKING, Any, Optional, Union
+from collections.abc import Iterator, Mapping, Sequence
+from typing import TYPE_CHECKING, Any
 
 import sqlalchemy
 
@@ -12,8 +13,6 @@ from datachain.utils import batched
 DEFAULT_DATABASE_BATCH_SIZE = 10_000
 
 if TYPE_CHECKING:
-    from collections.abc import Iterator, Mapping, Sequence
-
     import sqlalchemy.orm  # noqa: TC004
 
     from datachain.lib.data_model import DataType
@@ -21,21 +20,21 @@ if TYPE_CHECKING:
 
     from .datachain import DataChain
 
-    ConnectionType = Union[
-        str,
-        sqlalchemy.engine.URL,
-        sqlalchemy.engine.interfaces.Connectable,
-        sqlalchemy.engine.Engine,
-        sqlalchemy.engine.Connection,
-        sqlalchemy.orm.Session,
-        sqlite3.Connection,
-    ]
+    ConnectionType = (
+        str
+        | sqlalchemy.engine.URL
+        | sqlalchemy.engine.interfaces.Connectable
+        | sqlalchemy.engine.Engine
+        | sqlalchemy.engine.Connection
+        | sqlalchemy.orm.Session
+        | sqlite3.Connection
+    )
 
 
 @contextlib.contextmanager
 def _connect(
     connection: "ConnectionType",
-) -> "Iterator[sqlalchemy.engine.Connection]":
+) -> Iterator[sqlalchemy.engine.Connection]:
     import sqlalchemy.orm
 
     with contextlib.ExitStack() as stack:
@@ -46,10 +45,14 @@ def _connect(
             yield stack.enter_context(engine.connect())
         elif isinstance(connection, sqlite3.Connection):
             engine = sqlalchemy.create_engine(
-                "sqlite://", creator=lambda: connection, **engine_kwargs
+                "sqlite://",
+                creator=lambda: connection,
+                poolclass=sqlalchemy.pool.StaticPool,
+                **engine_kwargs,
             )
-            # do not close the connection, as it is managed by the caller
-            yield engine.connect()
+            # Close only the SQLAlchemy connection wrapper; the underlying
+            # sqlite3 connection remains managed by the caller via StaticPool.
+            yield stack.enter_context(engine.connect())
         elif isinstance(connection, sqlalchemy.Engine):
             yield stack.enter_context(connection.connect())
         elif isinstance(connection, sqlalchemy.Connection):
@@ -73,10 +76,10 @@ def to_database(
     table_name: str,
     connection: "ConnectionType",
     *,
-    batch_rows: int = DEFAULT_DATABASE_BATCH_SIZE,
-    on_conflict: Optional[str] = None,
-    conflict_columns: Optional[list[str]] = None,
-    column_mapping: Optional[dict[str, Optional[str]]] = None,
+    batch_size: int = DEFAULT_DATABASE_BATCH_SIZE,
+    on_conflict: str | None = None,
+    conflict_columns: list[str] | None = None,
+    column_mapping: dict[str, str | None] | None = None,
 ) -> int:
     """
     Implementation function for exporting DataChain to database tables.
@@ -124,7 +127,7 @@ def to_database(
         table.create(conn, checkfirst=True)
 
         rows_iter = chain._leaf_values()
-        for batch in batched(rows_iter, batch_rows):
+        for batch in batched(rows_iter, batch_size):
            rows_affected = _process_batch(
                conn,
                table,
@@ -150,8 +153,8 @@ def to_database(
 
 
 def _normalize_column_mapping(
-    column_mapping: dict[str, Optional[str]],
-) -> dict[str, Optional[str]]:
+    column_mapping: dict[str, str | None],
+) -> dict[str, str | None]:
     """
     Convert column mapping keys from DataChain format (dots) to database format
     (double underscores).
@@ -163,7 +166,7 @@ def _normalize_column_mapping(
     if not column_mapping:
         return {}
 
-    normalized_mapping: dict[str, Optional[str]] = {}
+    normalized_mapping: dict[str, str | None] = {}
     original_keys: dict[str, str] = {}
     for key, value in column_mapping.items():
         db_key = ColumnMeta.to_db_name(key)
@@ -181,7 +184,7 @@ def _normalize_column_mapping(
         from collections import defaultdict
 
         default_factory = column_mapping.default_factory
-        result: dict[str, Optional[str]] = defaultdict(default_factory)
+        result: dict[str, str | None] = defaultdict(default_factory)
         result.update(normalized_mapping)
         return result
 
@@ -189,8 +192,8 @@ def _normalize_column_mapping(
 
 
 def _normalize_conflict_columns(
-    conflict_columns: Optional[list[str]], column_mapping: dict[str, Optional[str]]
-) -> Optional[list[str]]:
+    conflict_columns: list[str] | None, column_mapping: dict[str, str | None]
+) -> list[str] | None:
     """
     Normalize conflict_columns by converting DataChain format to database format
     and applying column mapping.
@@ -297,15 +300,15 @@ def _process_batch(
 
 
 def read_database(
-    query: Union[str, "sqlalchemy.sql.expression.Executable"],
+    query: "str | sqlalchemy.sql.expression.Executable",
     connection: "ConnectionType",
-    params: Union["Sequence[Mapping[str, Any]]", "Mapping[str, Any]", None] = None,
+    params: Sequence[Mapping[str, Any]] | Mapping[str, Any] | None = None,
     *,
-    output: Optional["dict[str, DataType]"] = None,
-    session: Optional["Session"] = None,
-    settings: Optional[dict] = None,
+    output: dict[str, "DataType"] | None = None,
+    session: "Session | None" = None,
+    settings: dict | None = None,
     in_memory: bool = False,
-    infer_schema_length: Optional[int] = 100,
+    infer_schema_length: int | None = 100,
 ) -> "DataChain":
     """
     Read the results of a SQL query into a DataChain, using a given database connection.
@@ -382,7 +385,7 @@ def read_database(
 def _infer_schema(
     result: "sqlalchemy.engine.Result",
     to_infer: list[str],
-    infer_schema_length: Optional[int] = 100,
+    infer_schema_length: int | None = 100,
 ) -> tuple[list["sqlalchemy.Row"], dict[str, "DataType"]]:
     from datachain.lib.convert.values_to_tuples import values_to_tuples
 
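Two behavioral notes on this file: `to_database` renames its batching knob from `batch_rows` to `batch_size`, and `_connect` now wraps a caller-supplied `sqlite3.Connection` in SQLAlchemy's `StaticPool`, so every checkout reuses that single connection and only the SQLAlchemy wrapper is closed on exit. A minimal sketch of what the pooling fix enables (assumes `read_database` is re-exported at the package root like the other `read_*` helpers; table and data are illustrative):

```python
import sqlite3

import datachain as dc

conn = sqlite3.connect(":memory:")  # caller-owned connection
conn.execute("CREATE TABLE items (id INTEGER, name TEXT)")
conn.executemany("INSERT INTO items VALUES (?, ?)", [(1, "a"), (2, "b")])

# StaticPool hands this exact connection to SQLAlchemy on every checkout,
# which is what makes an in-memory database visible to read_database at all.
chain = dc.read_database("SELECT * FROM items", conn)

conn.execute("SELECT 1")  # still open: only the SQLAlchemy wrapper was closed
```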