datachain 0.28.1__py3-none-any.whl → 0.29.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

datachain/data_storage/warehouse.py CHANGED
@@ -21,6 +21,7 @@ from datachain.lib.file import File
  from datachain.lib.signal_schema import SignalSchema
  from datachain.node import DirType, DirTypeGroup, Node, NodeWithPath, get_path
  from datachain.query.batch import RowsOutput
+ from datachain.query.schema import ColumnMeta
  from datachain.query.utils import get_query_id_column
  from datachain.sql.functions import path as pathfunc
  from datachain.sql.types import Int, SQLType
@@ -400,7 +401,7 @@ class AbstractWarehouse(ABC, Serializable):
  expressions: tuple[_ColumnsClauseArgument[Any], ...] = (
  sa.func.count(table.c.sys__id),
  )
- size_column_names = [s.replace(".", "__") + "__size" for s in file_signals]
+ size_column_names = [ColumnMeta.to_db_name(s) + "__size" for s in file_signals]
  size_columns = [c for c in table.columns if c.name in size_column_names]

  if size_columns:
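
For context, `ColumnMeta.to_db_name` produces the same double-underscore database names that the old string replacement did; a quick sanity check (a sketch, assuming `datachain.query.schema.ColumnMeta` is importable exactly as in the hunk above):

```py
from datachain.query.schema import ColumnMeta

# "file.size" -> "file__size", matching the previous s.replace(".", "__") behavior
assert ColumnMeta.to_db_name("file.size") == "file.size".replace(".", "__")
```
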
datachain/lib/dc/database.py CHANGED
@@ -6,6 +6,10 @@ from typing import TYPE_CHECKING, Any, Optional, Union

  import sqlalchemy

+ from datachain.query.schema import ColumnMeta
+
+ DEFAULT_DATABASE_BATCH_SIZE = 10_000
+
  if TYPE_CHECKING:
  from collections.abc import Iterator, Mapping, Sequence

@@ -30,7 +34,7 @@ if TYPE_CHECKING:
  @contextlib.contextmanager
  def _connect(
  connection: "ConnectionType",
- ) -> "Iterator[Union[sqlalchemy.engine.Connection, sqlalchemy.orm.Session]]":
+ ) -> "Iterator[sqlalchemy.engine.Connection]":
  import sqlalchemy.orm

  with contextlib.ExitStack() as stack:
@@ -47,27 +51,184 @@ def _connect(
  yield engine.connect()
  elif isinstance(connection, sqlalchemy.Engine):
  yield stack.enter_context(connection.connect())
- elif isinstance(connection, (sqlalchemy.Connection, sqlalchemy.orm.Session)):
+ elif isinstance(connection, sqlalchemy.Connection):
  # do not close the connection, as it is managed by the caller
  yield connection
+ elif isinstance(connection, sqlalchemy.orm.Session):
+ # For Session objects, get the underlying bind (Engine or Connection)
+ # Sessions don't support DDL operations directly
+ bind = connection.get_bind()
+ if isinstance(bind, sqlalchemy.Engine):
+ yield stack.enter_context(bind.connect())
+ else:
+ # bind is already a Connection
+ yield bind
  else:
  raise TypeError(f"Unsupported connection type: {type(connection).__name__}")


- def _infer_schema(
- result: "sqlalchemy.engine.Result",
- to_infer: list[str],
- infer_schema_length: Optional[int] = 100,
- ) -> tuple[list["sqlalchemy.Row"], dict[str, "DataType"]]:
- from datachain.lib.convert.values_to_tuples import values_to_tuples
+ def to_database(
+ chain: "DataChain",
+ table_name: str,
+ connection: "ConnectionType",
+ *,
+ batch_rows: int = DEFAULT_DATABASE_BATCH_SIZE,
+ on_conflict: Optional[str] = None,
+ column_mapping: Optional[dict[str, Optional[str]]] = None,
+ ) -> None:
+ """
+ Implementation function for exporting DataChain to database tables.

- if not to_infer:
- return [], {}
+ This is the core implementation that handles the actual database operations.
+ For user-facing documentation, see DataChain.to_database() method.
+ """
+ from datachain.utils import batched

- rows = list(itertools.islice(result, infer_schema_length))
- values = {col: [row._mapping[col] for row in rows] for col in to_infer}
- _, output_schema, _ = values_to_tuples("", **values)
- return rows, output_schema
+ if on_conflict and on_conflict not in ("ignore", "update"):
+ raise ValueError(
+ f"on_conflict must be 'ignore' or 'update', got: {on_conflict}"
+ )
+
+ signals_schema = chain.signals_schema.clone_without_sys_signals()
+ all_columns = [
+ sqlalchemy.Column(c.name, c.type) # type: ignore[union-attr]
+ for c in signals_schema.db_signals(as_columns=True)
+ ]
+
+ column_mapping = column_mapping or {}
+ normalized_column_mapping = _normalize_column_mapping(column_mapping)
+ column_indices_and_names, columns = _prepare_columns(
+ all_columns, normalized_column_mapping
+ )
+
+ with _connect(connection) as conn:
+ metadata = sqlalchemy.MetaData()
+ table = sqlalchemy.Table(table_name, metadata, *columns)
+
+ # Check if table already exists to determine if we should clean up on error.
+ inspector = sqlalchemy.inspect(conn)
+ assert inspector # to satisfy mypy
+ table_existed_before = table_name in inspector.get_table_names()
+
+ try:
+ table.create(conn, checkfirst=True)
+ rows_iter = chain._leaf_values()
+ for batch in batched(rows_iter, batch_rows):
+ _process_batch(
+ conn, table, batch, on_conflict, column_indices_and_names
+ )
+ conn.commit()
+ except Exception:
+ if not table_existed_before:
+ try:
+ table.drop(conn, checkfirst=True)
+ conn.commit()
+ except sqlalchemy.exc.SQLAlchemyError:
+ pass
+ raise
+
+
+ def _normalize_column_mapping(
+ column_mapping: dict[str, Optional[str]],
+ ) -> dict[str, Optional[str]]:
+ """
+ Convert column mapping keys from DataChain format (dots) to database format
+ (double underscores).
+
+ This allows users to specify column mappings using the intuitive DataChain
+ format like: {"nested_data.value": "data_value"} instead of
+ {"nested_data__value": "data_value"}
+ """
+ if not column_mapping:
+ return {}
+
+ normalized_mapping: dict[str, Optional[str]] = {}
+ original_keys: dict[str, str] = {}
+ for key, value in column_mapping.items():
+ db_key = ColumnMeta.to_db_name(key)
+ if db_key in normalized_mapping:
+ prev = original_keys[db_key]
+ raise ValueError(
+ "Column mapping collision: multiple keys map to the same "
+ f"database column name '{db_key}': '{prev}' and '{key}'. "
+ )
+ normalized_mapping[db_key] = value
+ original_keys[db_key] = key
+
+ # If it's a defaultdict, preserve the default factory
+ if hasattr(column_mapping, "default_factory"):
+ from collections import defaultdict
+
+ default_factory = column_mapping.default_factory
+ result: dict[str, Optional[str]] = defaultdict(default_factory)
+ result.update(normalized_mapping)
+ return result
+
+ return normalized_mapping
+
+
+ def _prepare_columns(all_columns, column_mapping):
+ """Prepare column mapping and column definitions."""
+ column_indices_and_names = [] # List of (index, target_name) tuples
+ columns = []
+ for idx, col in enumerate(all_columns):
+ if col.name in column_mapping or hasattr(column_mapping, "default_factory"):
+ mapped_name = column_mapping[col.name]
+ if mapped_name:
+ columns.append(sqlalchemy.Column(mapped_name, col.type))
+ column_indices_and_names.append((idx, mapped_name))
+ else:
+ columns.append(col)
+ column_indices_and_names.append((idx, col.name))
+ return column_indices_and_names, columns
+
+
+ def _process_batch(conn, table, batch, on_conflict, column_indices_and_names):
+ """Process a batch of rows with conflict resolution."""
+
+ def prepare_row(row_values):
+ """Convert a row tuple to a dictionary with proper DB column names."""
+ return {
+ target_name: row_values[idx]
+ for idx, target_name in column_indices_and_names
+ }
+
+ rows_to_insert = [prepare_row(row) for row in batch]
+
+ supports_conflict = on_conflict and conn.engine.name in ("postgresql", "sqlite")
+
+ if supports_conflict:
+ # Use dialect-specific insert for conflict resolution
+ if conn.engine.name == "postgresql":
+ from sqlalchemy.dialects.postgresql import insert as pg_insert
+
+ insert_stmt = pg_insert(table)
+ elif conn.engine.name == "sqlite":
+ from sqlalchemy.dialects.sqlite import insert as sqlite_insert
+
+ insert_stmt = sqlite_insert(table)
+ else:
+ insert_stmt = table.insert()
+
+ if supports_conflict:
+ if on_conflict == "ignore":
+ insert_stmt = insert_stmt.on_conflict_do_nothing()
+ elif on_conflict == "update":
+ update_values = {
+ col.name: insert_stmt.excluded[col.name] for col in table.columns
+ }
+ insert_stmt = insert_stmt.on_conflict_do_update(set_=update_values)
+ elif on_conflict:
+ import warnings
+
+ warnings.warn(
+ f"Database does not support conflict resolution. "
+ f"Ignoring on_conflict='{on_conflict}' parameter.",
+ UserWarning,
+ stacklevel=2,
+ )
+
+ conn.execute(insert_stmt, rows_to_insert)


  def read_database(
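
To illustrate the new `sqlalchemy.orm.Session` branch in `_connect()` above, a Session passed in is unwrapped to its underlying Engine or Connection before any DDL or inserts run. A minimal sketch using the public `to_database` entry point (the engine URL and sample values are hypothetical):

```py
import sqlalchemy as sa
from sqlalchemy.orm import Session

import datachain as dc

engine = sa.create_engine("sqlite:///example.db")
chain = dc.read_values(num=[1, 2, 3])

with Session(engine) as session:
    # internally, _connect() calls session.get_bind() and runs DDL/inserts on that bind
    chain.to_database("numbers", session)
```
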
@@ -151,3 +312,19 @@ def read_database(
  in_memory=in_memory,
  schema=inferred_schema | output,
  )
+
+
+ def _infer_schema(
+ result: "sqlalchemy.engine.Result",
+ to_infer: list[str],
+ infer_schema_length: Optional[int] = 100,
+ ) -> tuple[list["sqlalchemy.Row"], dict[str, "DataType"]]:
+ from datachain.lib.convert.values_to_tuples import values_to_tuples
+
+ if not to_infer:
+ return [], {}
+
+ rows = list(itertools.islice(result, infer_schema_length))
+ values = {col: [row._mapping[col] for row in rows] for col in to_infer}
+ _, output_schema, _ = values_to_tuples("", **values)
+ return rows, output_schema
datachain/lib/dc/datachain.py CHANGED
@@ -58,6 +58,7 @@ from datachain.query.schema import DEFAULT_DELIMITER, Column
  from datachain.sql.functions import path as pathfunc
  from datachain.utils import batched_it, inside_notebook, row_to_nested_dict

+ from .database import DEFAULT_DATABASE_BATCH_SIZE
  from .utils import (
  DatasetMergeError,
  DatasetPrepareError,
@@ -77,11 +78,23 @@ UDFObjT = TypeVar("UDFObjT", bound=UDFBase)
  DEFAULT_PARQUET_CHUNK_SIZE = 100_000

  if TYPE_CHECKING:
+ import sqlite3
+
  import pandas as pd
  from typing_extensions import ParamSpec, Self

  P = ParamSpec("P")

+ ConnectionType = Union[
+ str,
+ sqlalchemy.engine.URL,
+ sqlalchemy.engine.interfaces.Connectable,
+ sqlalchemy.engine.Engine,
+ sqlalchemy.engine.Connection,
+ "sqlalchemy.orm.Session",
+ sqlite3.Connection,
+ ]
+

  T = TypeVar("T", bound="DataChain")

@@ -324,6 +337,7 @@ class DataChain:
  sys: Optional[bool] = None,
  namespace: Optional[str] = None,
  project: Optional[str] = None,
+ batch_rows: Optional[int] = None,
  ) -> "Self":
  """Change settings for chain.

@@ -331,22 +345,24 @@ class DataChain:
  It returns chain, so, it can be chained later with next operation.

  Parameters:
- cache : data caching (default=False)
+ cache : data caching. (default=False)
  parallel : number of thread for processors. True is a special value to
- enable all available CPUs (default=1)
+ enable all available CPUs. (default=1)
  workers : number of distributed workers. Only for Studio mode. (default=1)
- min_task_size : minimum number of tasks (default=1)
- prefetch: number of workers to use for downloading files in advance.
+ min_task_size : minimum number of tasks. (default=1)
+ prefetch : number of workers to use for downloading files in advance.
  This is enabled by default and uses 2 workers.
  To disable prefetching, set it to 0.
- namespace: namespace name.
- project: project name.
+ namespace : namespace name.
+ project : project name.
+ batch_rows : row limit per insert to balance speed and memory usage.
+ (default=2000)

  Example:
  ```py
  chain = (
  chain
- .settings(cache=True, parallel=8)
+ .settings(cache=True, parallel=8, batch_rows=300)
  .map(laion=process_webdataset(spec=WDSLaion), params="file")
  )
  ```
@@ -356,7 +372,14 @@ class DataChain:
  settings = copy.copy(self._settings)
  settings.add(
  Settings(
- cache, parallel, workers, min_task_size, prefetch, namespace, project
+ cache,
+ parallel,
+ workers,
+ min_task_size,
+ prefetch,
+ namespace,
+ project,
+ batch_rows,
  )
  )
  return self._evolve(settings=settings, _sys=sys)
@@ -711,7 +734,7 @@ class DataChain:

  return self._evolve(
  query=self._query.add_signals(
- udf_obj.to_udf_wrapper(),
+ udf_obj.to_udf_wrapper(self._settings.batch_rows),
  **self._settings.to_dict(),
  ),
  signal_schema=self.signals_schema | udf_obj.output,
@@ -749,7 +772,7 @@ class DataChain:
  udf_obj.prefetch = prefetch
  return self._evolve(
  query=self._query.generate(
- udf_obj.to_udf_wrapper(),
+ udf_obj.to_udf_wrapper(self._settings.batch_rows),
  **self._settings.to_dict(),
  ),
  signal_schema=udf_obj.output,
@@ -885,7 +908,7 @@ class DataChain:
  udf_obj = self._udf_to_obj(Aggregator, func, params, output, signal_map)
  return self._evolve(
  query=self._query.generate(
- udf_obj.to_udf_wrapper(),
+ udf_obj.to_udf_wrapper(self._settings.batch_rows),
  partition_by=processed_partition_by,
  **self._settings.to_dict(),
  ),
@@ -917,11 +940,24 @@ class DataChain:
  )
  chain.save("new_dataset")
  ```
+
+ .. deprecated:: 0.29.0
+ This method is deprecated and will be removed in a future version.
+ Use `agg()` instead, which provides the similar functionality.
  """
+ import warnings
+
+ warnings.warn(
+ "batch_map() is deprecated and will be removed in a future version. "
+ "Use agg() instead, which provides the similar functionality.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
  udf_obj = self._udf_to_obj(BatchMapper, func, params, output, signal_map)
+
  return self._evolve(
  query=self._query.add_signals(
- udf_obj.to_udf_wrapper(batch),
+ udf_obj.to_udf_wrapper(self._settings.batch_rows, batch=batch),
  **self._settings.to_dict(),
  ),
  signal_schema=self.signals_schema | udf_obj.output,
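
Since `batch_map()` now warns and points at `agg()`, a rough migration sketch follows (hypothetical signal names; this assumes `agg()` without `partition_by` treats the whole chain as a single group, which is what the "similar functionality" note relies on):

```py
import datachain as dc

chain = dc.read_values(num=[1, 2, 3, 4])

# before (now emits a DeprecationWarning):
#   chain.batch_map(total=lambda num: [sum(num)], output=int, batch=1000)

# after: agg() receives the grouped values at once and yields the output rows
chain.agg(total=lambda num: [sum(num)], output=int).show()
```
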
@@ -2253,6 +2289,97 @@ class DataChain:
  """
  self.to_json(path, fs_kwargs, include_outer_list=False)

+ def to_database(
+ self,
+ table_name: str,
+ connection: "ConnectionType",
+ *,
+ batch_rows: int = DEFAULT_DATABASE_BATCH_SIZE,
+ on_conflict: Optional[str] = None,
+ column_mapping: Optional[dict[str, Optional[str]]] = None,
+ ) -> None:
+ """Save chain to a database table using a given database connection.
+
+ This method exports all DataChain records to a database table, creating the
+ table if it doesn't exist and appending data if it does. The table schema
+ is automatically inferred from the DataChain's signal schema.
+
+ Parameters:
+ table_name: Name of the database table to create/write to.
+ connection: SQLAlchemy connectable, str, or a sqlite3 connection
+ Using SQLAlchemy makes it possible to use any DB supported by that
+ library. If a DBAPI2 object, only sqlite3 is supported. The user is
+ responsible for engine disposal and connection closure for the
+ SQLAlchemy connectable; str connections are closed automatically.
+ batch_rows: Number of rows to insert per batch for optimal performance.
+ Larger batches are faster but use more memory. Default: 10,000.
+ on_conflict: Strategy for handling duplicate rows (requires table
+ constraints):
+ - None: Raise error (`sqlalchemy.exc.IntegrityError`) on conflict
+ (default)
+ - "ignore": Skip duplicate rows silently
+ - "update": Update existing rows with new values
+ column_mapping: Optional mapping to rename or skip columns:
+ - Dict mapping DataChain column names to database column names
+ - Set values to None to skip columns entirely, or use `defaultdict` to
+ skip all columns except those specified.
+
+ Examples:
+ Basic usage with PostgreSQL:
+ ```py
+ import sqlalchemy as sa
+ import datachain as dc
+
+ chain = dc.read_storage("s3://my-bucket/")
+ engine = sa.create_engine("postgresql://user:pass@localhost/mydb")
+ chain.to_database("files_table", engine)
+ ```
+
+ Using SQLite with connection string:
+ ```py
+ chain.to_database("my_table", "sqlite:///data.db")
+ ```
+
+ Column mapping and renaming:
+ ```py
+ mapping = {
+ "user.id": "id",
+ "user.name": "name",
+ "user.password": None # Skip this column
+ }
+ chain.to_database("users", engine, column_mapping=mapping)
+ ```
+
+ Handling conflicts (requires PRIMARY KEY or UNIQUE constraints):
+ ```py
+ # Skip duplicates
+ chain.to_database("my_table", engine, on_conflict="ignore")
+
+ # Update existing records
+ chain.to_database("my_table", engine, on_conflict="update")
+ ```
+
+ Working with different databases:
+ ```py
+ # MySQL
+ mysql_engine = sa.create_engine("mysql+pymysql://user:pass@host/db")
+ chain.to_database("mysql_table", mysql_engine)
+
+ # SQLite in-memory
+ chain.to_database("temp_table", "sqlite:///:memory:")
+ ```
+ """
+ from .database import to_database
+
+ to_database(
+ self,
+ table_name,
+ connection,
+ batch_rows=batch_rows,
+ on_conflict=on_conflict,
+ column_mapping=column_mapping,
+ )
+
  @classmethod
  def from_records(
  cls,
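
Building on the examples in the docstring above, the `defaultdict` form keeps only the columns you name; every other column resolves to the default `None` and is skipped (a sketch with hypothetical signal names, reusing `chain` and `engine` from those examples):

```py
from collections import defaultdict

# columns not listed here map to None and are dropped from the exported table
mapping = defaultdict(lambda: None, {
    "file.path": "path",
    "file.size": "size_bytes",
})
chain.to_database("files", engine, column_mapping=mapping)
```
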
@@ -2340,7 +2467,7 @@ class DataChain:
  def setup(self, **kwargs) -> "Self":
  """Setup variables to pass to UDF functions.

- Use before running map/gen/agg/batch_map to save an object and pass it as an
+ Use before running map/gen/agg to save an object and pass it as an
  argument to the UDF.

  The value must be a callable (a `lambda: <value>` syntax can be used to quickly
datachain/lib/dc/records.py CHANGED
@@ -15,6 +15,8 @@ if TYPE_CHECKING:

  P = ParamSpec("P")

+ READ_RECORDS_BATCH_SIZE = 10000
+

  def read_records(
  to_insert: Optional[Union[dict, Iterable[dict]]],
@@ -41,7 +43,7 @@ def read_records(
  Notes:
  This call blocks until all records are inserted.
  """
- from datachain.query.dataset import INSERT_BATCH_SIZE, adjust_outputs, get_col_types
+ from datachain.query.dataset import adjust_outputs, get_col_types
  from datachain.sql.types import SQLType
  from datachain.utils import batched

@@ -94,7 +96,7 @@ def read_records(
  {c.name: c.type for c in columns if isinstance(c.type, SQLType)},
  )
  records = (adjust_outputs(warehouse, record, col_types) for record in to_insert)
- for chunk in batched(records, INSERT_BATCH_SIZE):
+ for chunk in batched(records, READ_RECORDS_BATCH_SIZE):
  warehouse.insert_rows(table, chunk)
  warehouse.insert_rows_done(table)
  return read_dataset(name=dsr.full_name, session=session, settings=settings)
datachain/lib/settings.py CHANGED
@@ -1,4 +1,5 @@
  from datachain.lib.utils import DataChainParamsError
+ from datachain.utils import DEFAULT_CHUNK_ROWS


  class SettingsError(DataChainParamsError):
@@ -16,6 +17,7 @@ class Settings:
  prefetch=None,
  namespace=None,
  project=None,
+ batch_rows=None,
  ):
  self._cache = cache
  self.parallel = parallel
@@ -24,6 +26,7 @@ class Settings:
  self.prefetch = prefetch
  self.namespace = namespace
  self.project = project
+ self._chunk_rows = batch_rows

  if not isinstance(cache, bool) and cache is not None:
  raise SettingsError(
@@ -53,6 +56,18 @@ class Settings:
  f", {min_task_size.__class__.__name__} was given"
  )

+ if batch_rows is not None and not isinstance(batch_rows, int):
+ raise SettingsError(
+ "'batch_rows' argument must be int or None"
+ f", {batch_rows.__class__.__name__} was given"
+ )
+
+ if batch_rows is not None and batch_rows <= 0:
+ raise SettingsError(
+ "'batch_rows' argument must be positive integer"
+ f", {batch_rows} was given"
+ )
+
  @property
  def cache(self):
  return self._cache if self._cache is not None else False
@@ -61,6 +76,10 @@ class Settings:
  def workers(self):
  return self._workers if self._workers is not None else False

+ @property
+ def batch_rows(self):
+ return self._chunk_rows if self._chunk_rows is not None else DEFAULT_CHUNK_ROWS
+
  def to_dict(self):
  res = {}
  if self._cache is not None:
@@ -75,6 +94,8 @@ class Settings:
  res["namespace"] = self.namespace
  if self.project is not None:
  res["project"] = self.project
+ if self._chunk_rows is not None:
+ res["batch_rows"] = self._chunk_rows
  return res

  def add(self, settings: "Settings"):
@@ -86,3 +107,5 @@ class Settings:
  self.project = settings.project or self.project
  if settings.prefetch is not None:
  self.prefetch = settings.prefetch
+ if settings._chunk_rows is not None:
+ self._chunk_rows = settings._chunk_rows
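
In practice the new setting is validated when the chain's settings are built, and `Settings.batch_rows` falls back to `DEFAULT_CHUNK_ROWS` (2000) when left unset. A small sketch of that behaviour (sample values are hypothetical):

```py
import datachain as dc
from datachain.lib.settings import SettingsError

chain = dc.read_values(num=[1, 2, 3])

chain.settings(batch_rows=500)    # accepted: positive int
chain.settings()                  # unset: the batch_rows property falls back to 2000

try:
    chain.settings(batch_rows=0)  # rejected by the new validation
except SettingsError as exc:
    print(exc)
```
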
datachain/lib/signal_schema.py CHANGED
@@ -34,7 +34,7 @@ from datachain.lib.data_model import DataModel, DataType, DataValue
  from datachain.lib.file import File
  from datachain.lib.model_store import ModelStore
  from datachain.lib.utils import DataChainParamsError
- from datachain.query.schema import DEFAULT_DELIMITER, Column
+ from datachain.query.schema import DEFAULT_DELIMITER, Column, ColumnMeta
  from datachain.sql.types import SQLType

  if TYPE_CHECKING:
@@ -590,7 +590,7 @@ class SignalSchema:

  if name:
  if "." in name:
- name = name.replace(".", "__")
+ name = ColumnMeta.to_db_name(name)
  signals = [
  s
  s
datachain/lib/udf.py CHANGED
@@ -62,19 +62,21 @@ class UDFProperties:
  return self.udf.get_batching(use_partitioning)

  @property
- def batch(self):
- return self.udf.batch
+ def batch_rows(self):
+ return self.udf.batch_rows


  @attrs.define(slots=False)
  class UDFAdapter:
  inner: "UDFBase"
  output: UDFOutputSpec
+ batch_rows: Optional[int] = None
  batch: int = 1

  def get_batching(self, use_partitioning: bool = False) -> BatchingStrategy:
  if use_partitioning:
  return Partition()
+
  if self.batch == 1:
  return NoBatching()
  if self.batch > 1:
@@ -233,10 +235,15 @@ class UDFBase(AbstractUDF):
  def signal_names(self) -> Iterable[str]:
  return self.output.to_udf_spec().keys()

- def to_udf_wrapper(self, batch: int = 1) -> UDFAdapter:
+ def to_udf_wrapper(
+ self,
+ batch_rows: Optional[int] = None,
+ batch: int = 1,
+ ) -> UDFAdapter:
  return UDFAdapter(
  self,
  self.output.to_udf_spec(),
+ batch_rows,
  batch,
  )

@@ -418,11 +425,27 @@ class Mapper(UDFBase):


  class BatchMapper(UDFBase):
- """Inherit from this class to pass to `DataChain.batch_map()`."""
+ """Inherit from this class to pass to `DataChain.batch_map()`.
+
+ .. deprecated:: 0.29.0
+ This class is deprecated and will be removed in a future version.
+ Use `Aggregator` instead, which provides the similar functionality.
+ """

  is_input_batched = True
  is_output_batched = True

+ def __init__(self):
+ import warnings
+
+ warnings.warn(
+ "BatchMapper is deprecated and will be removed in a future version. "
+ "Use Aggregator instead, which provides the similar functionality.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ super().__init__()
+
  def run(
  self,
  udf_fields: Sequence[str],
datachain/query/dataset.py CHANGED
@@ -333,32 +333,24 @@ def process_udf_outputs(
  udf_table: "Table",
  udf_results: Iterator[Iterable["UDFResult"]],
  udf: "UDFAdapter",
- batch_size: int = INSERT_BATCH_SIZE,
  cb: Callback = DEFAULT_CALLBACK,
  ) -> None:
- import psutil
-
- rows: list[UDFResult] = []
  # Optimization: Compute row types once, rather than for every row.
  udf_col_types = get_col_types(warehouse, udf.output)
+ batch_rows = udf.batch_rows or INSERT_BATCH_SIZE

- for udf_output in udf_results:
- if not udf_output:
- continue
- with safe_closing(udf_output):
- for row in udf_output:
- cb.relative_update()
- rows.append(adjust_outputs(warehouse, row, udf_col_types))
- if len(rows) >= batch_size or (
- len(rows) % 10 == 0 and psutil.virtual_memory().percent > 80
- ):
- for row_chunk in batched(rows, batch_size):
- warehouse.insert_rows(udf_table, row_chunk)
- rows.clear()
+ def _insert_rows():
+ for udf_output in udf_results:
+ if not udf_output:
+ continue
+
+ with safe_closing(udf_output):
+ for row in udf_output:
+ cb.relative_update()
+ yield adjust_outputs(warehouse, row, udf_col_types)

- if rows:
- for row_chunk in batched(rows, batch_size):
- warehouse.insert_rows(udf_table, row_chunk)
+ for row_chunk in batched(_insert_rows(), batch_rows):
+ warehouse.insert_rows(udf_table, row_chunk)

  warehouse.insert_rows_done(udf_table)

@@ -401,6 +393,7 @@ class UDFStep(Step, ABC):
  min_task_size: Optional[int] = None
  is_generator = False
  cache: bool = False
+ batch_rows: Optional[int] = None

  @abstractmethod
  def create_udf_table(self, query: Select) -> "Table":
@@ -602,6 +595,7 @@ class UDFStep(Step, ABC):
  parallel=self.parallel,
  workers=self.workers,
  min_task_size=self.min_task_size,
+ batch_rows=self.batch_rows,
  )
  return self.__class__(self.udf, self.catalog)

@@ -1633,6 +1627,7 @@ class DatasetQuery:
  min_task_size: Optional[int] = None,
  partition_by: Optional[PartitionByType] = None,
  cache: bool = False,
+ batch_rows: Optional[int] = None,
  ) -> "Self":
  """
  Adds one or more signals based on the results from the provided UDF.
@@ -1658,6 +1653,7 @@ class DatasetQuery:
  workers=workers,
  min_task_size=min_task_size,
  cache=cache,
+ batch_rows=batch_rows,
  )
  )
  return query
@@ -1679,6 +1675,7 @@ class DatasetQuery:
  namespace: Optional[str] = None,
  project: Optional[str] = None,
  cache: bool = False,
+ batch_rows: Optional[int] = None,
  ) -> "Self":
  query = self.clone()
  steps = query.steps
@@ -1691,6 +1688,7 @@ class DatasetQuery:
  workers=workers,
  min_task_size=min_task_size,
  cache=cache,
+ batch_rows=batch_rows,
  )
  )
  return query
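
Putting the pieces together: `settings(batch_rows=...)` is stored on the chain, passed into `to_udf_wrapper()` and `UDFAdapter.batch_rows`, and `process_udf_outputs()` above then uses it (falling back to `INSERT_BATCH_SIZE`) to chunk the inserts of UDF results. A sketch with hypothetical values:

```py
import datachain as dc

chain = (
    dc.read_values(num=list(range(10_000)))
    .settings(batch_rows=500)  # each warehouse insert carries at most 500 UDF output rows
    .map(double=lambda num: num * 2, output=int)
)
chain.save("doubled")
```
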
datachain/utils.py CHANGED
@@ -11,7 +11,6 @@ import time
  from collections.abc import Iterable, Iterator, Sequence
  from contextlib import contextmanager
  from datetime import date, datetime, timezone
- from itertools import chain, islice
  from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
  from uuid import UUID

@@ -26,6 +25,8 @@ if TYPE_CHECKING:
  from typing_extensions import Self


+ DEFAULT_CHUNK_ROWS = 2000
+
  logger = logging.getLogger("datachain")

  NUL = b"\0"
@@ -225,30 +226,44 @@ def get_envs_by_prefix(prefix: str) -> dict[str, str]:
  _T_co = TypeVar("_T_co", covariant=True)


- def batched(iterable: Iterable[_T_co], n: int) -> Iterator[tuple[_T_co, ...]]:
- """Batch data into tuples of length n. The last batch may be shorter."""
- # Based on: https://docs.python.org/3/library/itertools.html#itertools-recipes
- # batched('ABCDEFG', 3) --> ABC DEF G
- if n < 1:
- raise ValueError("Batch size must be at least one")
- it = iter(iterable)
- while batch := tuple(islice(it, n)):
+ def _dynamic_batched_core(
+ iterable: Iterable[_T_co],
+ batch_rows: int,
+ ) -> Iterator[list[_T_co]]:
+ """Core batching logic that yields lists."""
+
+ batch: list[_T_co] = []
+
+ for item in iterable:
+ # Check if adding this item would exceed limits
+ if len(batch) >= batch_rows and batch: # Yield current batch if we have one
+ yield batch
+ batch = []
+
+ batch.append(item)
+
+ # Yield any remaining items
+ if batch:
  yield batch


- def batched_it(iterable: Iterable[_T_co], n: int) -> Iterator[Iterator[_T_co]]:
- """Batch data into iterators of length n. The last batch may be shorter."""
- # batched('ABCDEFG', 3) --> ABC DEF G
- if n < 1:
- raise ValueError("Batch size must be at least one")
- it = iter(iterable)
- while True:
- chunk_it = islice(it, n)
- try:
- first_el = next(chunk_it)
- except StopIteration:
- return
- yield chain((first_el,), chunk_it)
+ def batched(iterable: Iterable[_T_co], batch_rows: int) -> Iterator[tuple[_T_co, ...]]:
+ """
+ Batch data into tuples of length batch_rows .
+ The last batch may be shorter.
+ """
+ yield from (tuple(batch) for batch in _dynamic_batched_core(iterable, batch_rows))
+
+
+ def batched_it(
+ iterable: Iterable[_T_co],
+ batch_rows: int = DEFAULT_CHUNK_ROWS,
+ ) -> Iterator[Iterator[_T_co]]:
+ """
+ Batch data into iterators with dynamic sizing
+ based on row count and memory usage.
+ """
+ yield from (iter(batch) for batch in _dynamic_batched_core(iterable, batch_rows))


  def flatten(items):
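
The rewritten helpers keep the familiar splitting behaviour for positive batch sizes, just parameterised by `batch_rows`; for example:

```py
from datachain.utils import batched, batched_it

print(list(batched("ABCDEFG", 3)))
# [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]

for chunk in batched_it(range(7), 3):
    print(list(chunk))
# [0, 1, 2], then [3, 4, 5], then [6]
```
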
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: datachain
- Version: 0.28.1
+ Version: 0.29.0
  Summary: Wrangle unstructured AI data at scale
  Author-email: Dmitry Petrov <support@dvc.org>
  License-Expression: Apache-2.0
@@ -19,7 +19,7 @@ datachain/script_meta.py,sha256=V-LaFOZG84pD0Zc0NvejYdzwDgzITv6yHvAHggDCnuY,4978
  datachain/semver.py,sha256=UB8GHPBtAP3UJGeiuJoInD7SK-DnB93_Xd1qy_CQ9cU,2074
  datachain/studio.py,sha256=-BmKLVNBLPFveUgVVE2So3aaiGndO2jK2qbHZ0zBDd8,15239
  datachain/telemetry.py,sha256=0A4IOPPp9VlP5pyW9eBfaTK3YhHGzHl7dQudQjUAx9A,994
- datachain/utils.py,sha256=DNqOi-Ydb7InyWvD9m7_yailxz6-YGpZzh00biQaHNo,15305
+ datachain/utils.py,sha256=Gp5JVr_m7nVWQGDOjrGnZjRXF9-Ai-MBxiPJIcpPvWQ,15451
  datachain/catalog/__init__.py,sha256=cMZzSz3VoUi-6qXSVaHYN-agxQuAcz2XSqnEPZ55crE,353
  datachain/catalog/catalog.py,sha256=QTWCXy75iWo-0MCXyfV_WbsKeZ1fpLpvL8d60rxn1ws,65528
  datachain/catalog/datasource.py,sha256=IkGMh0Ttg6Q-9DWfU_H05WUnZepbGa28HYleECi6K7I,1353
@@ -53,7 +53,7 @@ datachain/data_storage/metastore.py,sha256=Qw332arvhgXB4UY0yX-Hu8Vgl3smU12l6bvxr
  datachain/data_storage/schema.py,sha256=o3JbURKXRg3IJyIVA4QjHHkn6byRuz7avbydU2FlvNY,9897
  datachain/data_storage/serializer.py,sha256=6G2YtOFqqDzJf1KbvZraKGXl2XHZyVml2krunWUum5o,927
  datachain/data_storage/sqlite.py,sha256=TTQjdDXUaZSr3MEaxZjDhsVIkIJqxFNA-sD25TO3m_4,30228
- datachain/data_storage/warehouse.py,sha256=nhF8yfpdJpstpXnv_sj7WFzU97JkvSeqetqJQp33cyE,32563
+ datachain/data_storage/warehouse.py,sha256=66PETLzfkgSmj-EF604m62xmFMQBXaRZSw8sdKGMam8,32613
  datachain/diff/__init__.py,sha256=-OFZzgOplqO84iWgGY7kfe60NXaWR9JRIh9T-uJboAM,9668
  datachain/fs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  datachain/fs/reference.py,sha256=A8McpXF0CqbXPqanXuvpKu50YLB3a2ZXA3YAPxtBXSM,914
@@ -85,11 +85,11 @@ datachain/lib/model_store.py,sha256=dkL2rcT5ag-kbgkhQPL_byEs-TCYr29qvdltroL5NxM,
  datachain/lib/namespaces.py,sha256=it52UbbwB8dzhesO2pMs_nThXiPQ1Ph9sD9I3GQkg5s,2099
  datachain/lib/projects.py,sha256=8lN0qV8czX1LGtWURCUvRlSJk-RpO9w9Rra_pOZus6g,2595
  datachain/lib/pytorch.py,sha256=S-st2SAczYut13KMf6eSqP_OQ8otWI5TRmzhK5fN3k0,7828
- datachain/lib/settings.py,sha256=9wi0FoHxRxNiyn99pR28IYsMkoo47jQxeXuObQr2Ar0,2929
- datachain/lib/signal_schema.py,sha256=JMsL8c4iCRH9PoRumvjimsOLQQslTjm_aDR2jh1zT2Q,38558
+ datachain/lib/settings.py,sha256=n0YYhCVdgCdMkCSLY7kscJF9mUhlQ0a4ENWBsJFynkw,3809
+ datachain/lib/signal_schema.py,sha256=FmsfEAdRDeAzv1ApQnRXzkkyNeY9fTaXpjMzSMhDh7M,38574
  datachain/lib/tar.py,sha256=MLcVjzIgBqRuJacCNpZ6kwSZNq1i2tLyROc8PVprHsA,999
  datachain/lib/text.py,sha256=UNHm8fhidk7wdrWqacEWaA6I9ykfYqarQ2URby7jc7M,1261
- datachain/lib/udf.py,sha256=SUnJWRDC3TlLhvpi8iqqJbeZGn5DChot7DyH-0Q-z20,17305
+ datachain/lib/udf.py,sha256=IB1IKF5KyA-NiyfhVzmBPpF_aITPS3zSlrt24f_Ofjo,17956
  datachain/lib/udf_signature.py,sha256=Yz20iJ-WF1pijT3hvcDIKFzgWV9gFxZM73KZRx3NbPk,7560
  datachain/lib/utils.py,sha256=RLji1gHnfDXtJCnBo8BcNu1obndFpVsXJ_1Vb-FQ9Qo,4554
  datachain/lib/video.py,sha256=ddVstiMkfxyBPDsnjCKY0d_93bw-DcMqGqN60yzsZoo,6851
@@ -103,15 +103,15 @@ datachain/lib/convert/unflatten.py,sha256=ysMkstwJzPMWUlnxn-Z-tXJR3wmhjHeSN_P-sD
  datachain/lib/convert/values_to_tuples.py,sha256=j5yZMrVUH6W7b-7yUvdCTGI7JCUAYUOzHUGPoyZXAB0,4360
  datachain/lib/dc/__init__.py,sha256=TFci5HTvYGjBesNUxDAnXaX36PnzPEUSn5a6JxB9o0U,872
  datachain/lib/dc/csv.py,sha256=q6a9BpapGwP6nwy6c5cklxQumep2fUp9l2LAjtTJr6s,4411
- datachain/lib/dc/database.py,sha256=g5M6NjYR1T0vKte-abV-3Ejnm-HqxTIMir5cRi_SziE,6051
- datachain/lib/dc/datachain.py,sha256=U2CV8-ewfu-sW1D2BysdqCtbnEA7uNL1ZhYLWPAFB1o,93298
+ datachain/lib/dc/database.py,sha256=MPE-KzwcR2DhWLCEbl1gWFp63dLqjWuiJ1iEfC2BrJI,12443
+ datachain/lib/dc/datachain.py,sha256=_C9PZjUHVewpdp94AR2GS3QEI96Svsyx52dLJVM4tm4,98143
  datachain/lib/dc/datasets.py,sha256=P6CIJizD2IYFwOQG5D3VbQRjDmUiRH0ysdtb551Xdm8,15098
  datachain/lib/dc/hf.py,sha256=AP_MUHg6HJWae10PN9hD_beQVjrl0cleZ6Cvhtl1yoI,2901
  datachain/lib/dc/json.py,sha256=dNijfJ-H92vU3soyR7X1IiDrWhm6yZIGG3bSnZkPdAE,2733
  datachain/lib/dc/listings.py,sha256=V379Cb-7ZyquM0w7sWArQZkzInZy4GB7QQ1ZfowKzQY,4544
  datachain/lib/dc/pandas.py,sha256=ObueUXDUFKJGu380GmazdG02ARpKAHPhSaymfmOH13E,1489
  datachain/lib/dc/parquet.py,sha256=zYcSgrWwyEDW9UxGUSVdIVsCu15IGEf0xL8KfWQqK94,1782
- datachain/lib/dc/records.py,sha256=FpPbApWopUri1gIaSMsfXN4fevja4mjmfb6Q5eiaGxI,3116
+ datachain/lib/dc/records.py,sha256=4N1Fq-j5r4GK-PR5jIO-9B2u_zTNX9l-6SmcRhQDAsw,3136
  datachain/lib/dc/storage.py,sha256=FXroEdxOZfbuEBIWfWTkbGwrI0D4_mrLZSRsIQm0WFE,7693
  datachain/lib/dc/utils.py,sha256=VawOAlJSvAtZbsMg33s5tJe21TRx1Km3QggI1nN6tnw,3984
  datachain/lib/dc/values.py,sha256=7l1n352xWrEdql2NhBcZ3hj8xyPglWiY4qHjFPjn6iw,1428
@@ -126,7 +126,7 @@ datachain/model/ultralytics/pose.py,sha256=pBlmt63Qe68FKmexHimUGlNbNOoOlMHXG4fzX
  datachain/model/ultralytics/segment.py,sha256=63bDCj43E6iZ0hFI5J6uQfksdCmjEp6sEm1XzVaE8pw,2986
  datachain/query/__init__.py,sha256=7DhEIjAA8uZJfejruAVMZVcGFmvUpffuZJwgRqNwe-c,263
  datachain/query/batch.py,sha256=-goxLpE0EUvaDHu66rstj53UnfHpYfBUGux8GSpJ93k,4306
- datachain/query/dataset.py,sha256=cYNrg1QyrZpO-oup3mqmSYHUvgEYBKe8RgkVbyQa6p0,62777
+ datachain/query/dataset.py,sha256=OJZ_YwpS5i4B0wVmosMmMNW1qABr6zyOmqNHQdAWir4,62704
  datachain/query/dispatch.py,sha256=A0nPxn6mEN5d9dDo6S8m16Ji_9IvJLXrgF2kqXdi4fs,15546
  datachain/query/metrics.py,sha256=DOK5HdNVaRugYPjl8qnBONvTkwjMloLqAr7Mi3TjCO0,858
  datachain/query/params.py,sha256=O_j89mjYRLOwWNhYZl-z7mi-rkdP7WyFmaDufsdTryE,863
@@ -158,9 +158,9 @@ datachain/sql/sqlite/vector.py,sha256=ncW4eu2FlJhrP_CIpsvtkUabZlQdl2D5Lgwy_cbfqR
  datachain/toolkit/__init__.py,sha256=eQ58Q5Yf_Fgv1ZG0IO5dpB4jmP90rk8YxUWmPc1M2Bo,68
  datachain/toolkit/split.py,sha256=ktGWzY4kyzjWyR86dhvzw-Zhl0lVk_LOX3NciTac6qo,2914
  datachain/torch/__init__.py,sha256=gIS74PoEPy4TB3X6vx9nLO0Y3sLJzsA8ckn8pRWihJM,579
- datachain-0.28.1.dist-info/licenses/LICENSE,sha256=8DnqK5yoPI_E50bEg_zsHKZHY2HqPy4rYN338BHQaRA,11344
- datachain-0.28.1.dist-info/METADATA,sha256=9rZc1mFjNj6S3v6FjgrhM7bUdi6kO_5606CB7HQCfeo,13766
- datachain-0.28.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- datachain-0.28.1.dist-info/entry_points.txt,sha256=0GMJS6B_KWq0m3VT98vQI2YZodAMkn4uReZ_okga9R4,49
- datachain-0.28.1.dist-info/top_level.txt,sha256=lZPpdU_2jJABLNIg2kvEOBi8PtsYikbN1OdMLHk8bTg,10
- datachain-0.28.1.dist-info/RECORD,,
+ datachain-0.29.0.dist-info/licenses/LICENSE,sha256=8DnqK5yoPI_E50bEg_zsHKZHY2HqPy4rYN338BHQaRA,11344
+ datachain-0.29.0.dist-info/METADATA,sha256=g5YmnSXxBvUz_ZO1ZoEPHkzRyQGW5ZbPc8a4ZRJqHXE,13766
+ datachain-0.29.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ datachain-0.29.0.dist-info/entry_points.txt,sha256=0GMJS6B_KWq0m3VT98vQI2YZodAMkn4uReZ_okga9R4,49
+ datachain-0.29.0.dist-info/top_level.txt,sha256=lZPpdU_2jJABLNIg2kvEOBi8PtsYikbN1OdMLHk8bTg,10
+ datachain-0.29.0.dist-info/RECORD,,