datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. datachain/__init__.py +20 -0
  2. datachain/asyn.py +11 -12
  3. datachain/cache.py +7 -7
  4. datachain/catalog/__init__.py +2 -2
  5. datachain/catalog/catalog.py +621 -507
  6. datachain/catalog/dependency.py +164 -0
  7. datachain/catalog/loader.py +28 -18
  8. datachain/checkpoint.py +43 -0
  9. datachain/cli/__init__.py +24 -33
  10. datachain/cli/commands/__init__.py +1 -8
  11. datachain/cli/commands/datasets.py +83 -52
  12. datachain/cli/commands/ls.py +17 -17
  13. datachain/cli/commands/show.py +4 -4
  14. datachain/cli/parser/__init__.py +8 -74
  15. datachain/cli/parser/job.py +95 -3
  16. datachain/cli/parser/studio.py +11 -4
  17. datachain/cli/parser/utils.py +1 -2
  18. datachain/cli/utils.py +2 -15
  19. datachain/client/azure.py +4 -4
  20. datachain/client/fsspec.py +45 -28
  21. datachain/client/gcs.py +6 -6
  22. datachain/client/hf.py +29 -2
  23. datachain/client/http.py +157 -0
  24. datachain/client/local.py +15 -11
  25. datachain/client/s3.py +17 -9
  26. datachain/config.py +4 -8
  27. datachain/data_storage/db_engine.py +12 -6
  28. datachain/data_storage/job.py +5 -1
  29. datachain/data_storage/metastore.py +1252 -186
  30. datachain/data_storage/schema.py +58 -45
  31. datachain/data_storage/serializer.py +105 -15
  32. datachain/data_storage/sqlite.py +286 -127
  33. datachain/data_storage/warehouse.py +250 -113
  34. datachain/dataset.py +353 -148
  35. datachain/delta.py +391 -0
  36. datachain/diff/__init__.py +27 -29
  37. datachain/error.py +60 -0
  38. datachain/func/__init__.py +2 -1
  39. datachain/func/aggregate.py +66 -42
  40. datachain/func/array.py +242 -38
  41. datachain/func/base.py +7 -4
  42. datachain/func/conditional.py +110 -60
  43. datachain/func/func.py +96 -45
  44. datachain/func/numeric.py +55 -38
  45. datachain/func/path.py +32 -20
  46. datachain/func/random.py +2 -2
  47. datachain/func/string.py +67 -37
  48. datachain/func/window.py +7 -8
  49. datachain/hash_utils.py +123 -0
  50. datachain/job.py +11 -7
  51. datachain/json.py +138 -0
  52. datachain/lib/arrow.py +58 -22
  53. datachain/lib/audio.py +245 -0
  54. datachain/lib/clip.py +14 -13
  55. datachain/lib/convert/flatten.py +5 -3
  56. datachain/lib/convert/python_to_sql.py +6 -10
  57. datachain/lib/convert/sql_to_python.py +8 -0
  58. datachain/lib/convert/values_to_tuples.py +156 -51
  59. datachain/lib/data_model.py +42 -20
  60. datachain/lib/dataset_info.py +36 -8
  61. datachain/lib/dc/__init__.py +8 -2
  62. datachain/lib/dc/csv.py +25 -28
  63. datachain/lib/dc/database.py +398 -0
  64. datachain/lib/dc/datachain.py +1289 -425
  65. datachain/lib/dc/datasets.py +320 -38
  66. datachain/lib/dc/hf.py +38 -24
  67. datachain/lib/dc/json.py +29 -32
  68. datachain/lib/dc/listings.py +112 -8
  69. datachain/lib/dc/pandas.py +16 -12
  70. datachain/lib/dc/parquet.py +35 -23
  71. datachain/lib/dc/records.py +31 -23
  72. datachain/lib/dc/storage.py +154 -64
  73. datachain/lib/dc/storage_pattern.py +251 -0
  74. datachain/lib/dc/utils.py +24 -16
  75. datachain/lib/dc/values.py +8 -9
  76. datachain/lib/file.py +622 -89
  77. datachain/lib/hf.py +69 -39
  78. datachain/lib/image.py +14 -14
  79. datachain/lib/listing.py +14 -11
  80. datachain/lib/listing_info.py +1 -2
  81. datachain/lib/meta_formats.py +3 -4
  82. datachain/lib/model_store.py +39 -7
  83. datachain/lib/namespaces.py +125 -0
  84. datachain/lib/projects.py +130 -0
  85. datachain/lib/pytorch.py +32 -21
  86. datachain/lib/settings.py +192 -56
  87. datachain/lib/signal_schema.py +427 -104
  88. datachain/lib/tar.py +1 -2
  89. datachain/lib/text.py +8 -7
  90. datachain/lib/udf.py +164 -76
  91. datachain/lib/udf_signature.py +60 -35
  92. datachain/lib/utils.py +118 -4
  93. datachain/lib/video.py +17 -9
  94. datachain/lib/webdataset.py +61 -56
  95. datachain/lib/webdataset_laion.py +15 -16
  96. datachain/listing.py +22 -10
  97. datachain/model/bbox.py +3 -1
  98. datachain/model/ultralytics/bbox.py +16 -12
  99. datachain/model/ultralytics/pose.py +16 -12
  100. datachain/model/ultralytics/segment.py +16 -12
  101. datachain/namespace.py +84 -0
  102. datachain/node.py +6 -6
  103. datachain/nodes_thread_pool.py +0 -1
  104. datachain/plugins.py +24 -0
  105. datachain/project.py +78 -0
  106. datachain/query/batch.py +40 -41
  107. datachain/query/dataset.py +604 -322
  108. datachain/query/dispatch.py +261 -154
  109. datachain/query/metrics.py +4 -6
  110. datachain/query/params.py +2 -3
  111. datachain/query/queue.py +3 -12
  112. datachain/query/schema.py +11 -6
  113. datachain/query/session.py +200 -33
  114. datachain/query/udf.py +34 -2
  115. datachain/remote/studio.py +171 -69
  116. datachain/script_meta.py +12 -12
  117. datachain/semver.py +68 -0
  118. datachain/sql/__init__.py +2 -0
  119. datachain/sql/functions/array.py +33 -1
  120. datachain/sql/postgresql_dialect.py +9 -0
  121. datachain/sql/postgresql_types.py +21 -0
  122. datachain/sql/sqlite/__init__.py +5 -1
  123. datachain/sql/sqlite/base.py +102 -29
  124. datachain/sql/sqlite/types.py +8 -13
  125. datachain/sql/types.py +70 -15
  126. datachain/studio.py +223 -46
  127. datachain/toolkit/split.py +31 -10
  128. datachain/utils.py +101 -59
  129. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
  130. datachain-0.39.0.dist-info/RECORD +173 -0
  131. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
  132. datachain/cli/commands/query.py +0 -53
  133. datachain/query/utils.py +0 -42
  134. datachain-0.14.2.dist-info/RECORD +0 -158
  135. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
  136. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
  137. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/datachain/lib/dc/database.py
@@ -0,0 +1,398 @@
+import contextlib
+import itertools
+import os
+import sqlite3
+from collections.abc import Iterator, Mapping, Sequence
+from typing import TYPE_CHECKING, Any
+
+import sqlalchemy
+
+from datachain.query.schema import ColumnMeta
+from datachain.utils import batched
+
+DEFAULT_DATABASE_BATCH_SIZE = 10_000
+
+if TYPE_CHECKING:
+    import sqlalchemy.orm  # noqa: TC004
+
+    from datachain.lib.data_model import DataType
+    from datachain.query import Session
+
+    from .datachain import DataChain
+
+ConnectionType = (
+    str
+    | sqlalchemy.engine.URL
+    | sqlalchemy.engine.interfaces.Connectable
+    | sqlalchemy.engine.Engine
+    | sqlalchemy.engine.Connection
+    | sqlalchemy.orm.Session
+    | sqlite3.Connection
+)
+
+
+@contextlib.contextmanager
+def _connect(
+    connection: "ConnectionType",
+) -> Iterator[sqlalchemy.engine.Connection]:
+    import sqlalchemy.orm
+
+    with contextlib.ExitStack() as stack:
+        engine_kwargs = {"echo": bool(os.environ.get("DEBUG_SHOW_SQL_QUERIES"))}
+        if isinstance(connection, (str, sqlalchemy.URL)):
+            engine = sqlalchemy.create_engine(connection, **engine_kwargs)
+            stack.callback(engine.dispose)
+            yield stack.enter_context(engine.connect())
+        elif isinstance(connection, sqlite3.Connection):
+            engine = sqlalchemy.create_engine(
+                "sqlite://",
+                creator=lambda: connection,
+                poolclass=sqlalchemy.pool.StaticPool,
+                **engine_kwargs,
+            )
+            # Close only the SQLAlchemy connection wrapper; the underlying
+            # sqlite3 connection remains managed by the caller via StaticPool.
+            yield stack.enter_context(engine.connect())
+        elif isinstance(connection, sqlalchemy.Engine):
+            yield stack.enter_context(connection.connect())
+        elif isinstance(connection, sqlalchemy.Connection):
+            # do not close the connection, as it is managed by the caller
+            yield connection
+        elif isinstance(connection, sqlalchemy.orm.Session):
+            # For Session objects, get the underlying bind (Engine or Connection)
+            # Sessions don't support DDL operations directly
+            bind = connection.get_bind()
+            if isinstance(bind, sqlalchemy.Engine):
+                yield stack.enter_context(bind.connect())
+            else:
+                # bind is already a Connection
+                yield bind
+        else:
+            raise TypeError(f"Unsupported connection type: {type(connection).__name__}")
+
+
+def to_database(
+    chain: "DataChain",
+    table_name: str,
+    connection: "ConnectionType",
+    *,
+    batch_size: int = DEFAULT_DATABASE_BATCH_SIZE,
+    on_conflict: str | None = None,
+    conflict_columns: list[str] | None = None,
+    column_mapping: dict[str, str | None] | None = None,
+) -> int:
+    """
+    Implementation function for exporting DataChain to database tables.
+
+    This is the core implementation that handles the actual database operations.
+    For user-facing documentation, see DataChain.to_database() method.
+
+    Returns:
+        int: Number of rows affected (inserted/updated).
+    """
+    if on_conflict and on_conflict not in ("ignore", "update"):
+        raise ValueError(
+            f"on_conflict must be 'ignore' or 'update', got: {on_conflict}"
+        )
+
+    signals_schema = chain.signals_schema.clone_without_sys_signals()
+    all_columns = [
+        sqlalchemy.Column(c.name, c.type)  # type: ignore[union-attr]
+        for c in signals_schema.db_signals(as_columns=True)
+    ]
+
+    column_mapping = column_mapping or {}
+    normalized_column_mapping = _normalize_column_mapping(column_mapping)
+    column_indices_and_names, columns = _prepare_columns(
+        all_columns, normalized_column_mapping
+    )
+
+    normalized_conflict_columns = _normalize_conflict_columns(
+        conflict_columns, normalized_column_mapping
+    )
+
+    with _connect(connection) as conn:
+        metadata = sqlalchemy.MetaData()
+        table = sqlalchemy.Table(table_name, metadata, *columns)
+
+        table_existed_before = False
+        total_rows_affected = 0
+        try:
+            with conn.begin():
+                # Check if table exists to determine if we should clean up on error.
+                inspector = sqlalchemy.inspect(conn)
+                assert inspector  # to satisfy mypy
+                table_existed_before = table_name in inspector.get_table_names()
+
+                table.create(conn, checkfirst=True)
+
+                rows_iter = chain._leaf_values()
+                for batch in batched(rows_iter, batch_size):
+                    rows_affected = _process_batch(
+                        conn,
+                        table,
+                        batch,
+                        on_conflict,
+                        normalized_conflict_columns,
+                        column_indices_and_names,
+                    )
+                    if rows_affected < 0 or total_rows_affected < 0:
+                        total_rows_affected = -1
+                    else:
+                        total_rows_affected += rows_affected
+        except Exception:
+            if not table_existed_before:
+                try:
+                    table.drop(conn, checkfirst=True)
+                    conn.commit()
+                except sqlalchemy.exc.SQLAlchemyError:
+                    pass
+            raise
+
+    return total_rows_affected
+
+
+def _normalize_column_mapping(
+    column_mapping: dict[str, str | None],
+) -> dict[str, str | None]:
+    """
+    Convert column mapping keys from DataChain format (dots) to database format
+    (double underscores).
+
+    This allows users to specify column mappings using the intuitive DataChain
+    format like {"nested_data.value": "data_value"} instead of
+    {"nested_data__value": "data_value"}.
+    """
+    if not column_mapping:
+        return {}
+
+    normalized_mapping: dict[str, str | None] = {}
+    original_keys: dict[str, str] = {}
+    for key, value in column_mapping.items():
+        db_key = ColumnMeta.to_db_name(key)
+        if db_key in normalized_mapping:
+            prev = original_keys[db_key]
+            raise ValueError(
+                "Column mapping collision: multiple keys map to the same "
+                f"database column name '{db_key}': '{prev}' and '{key}'."
+            )
+        normalized_mapping[db_key] = value
+        original_keys[db_key] = key
+
+    # If it's a defaultdict, preserve the default factory
+    if hasattr(column_mapping, "default_factory"):
+        from collections import defaultdict
+
+        default_factory = column_mapping.default_factory
+        result: dict[str, str | None] = defaultdict(default_factory)
+        result.update(normalized_mapping)
+        return result
+
+    return normalized_mapping
+
+
+def _normalize_conflict_columns(
+    conflict_columns: list[str] | None, column_mapping: dict[str, str | None]
+) -> list[str] | None:
+    """
+    Normalize conflict_columns by converting DataChain format to database format
+    and applying column mapping.
+    """
+    if not conflict_columns:
+        return None
+
+    normalized_columns = []
+    for col in conflict_columns:
+        db_col = ColumnMeta.to_db_name(col)
+
+        if db_col in column_mapping or hasattr(column_mapping, "default_factory"):
+            mapped_name = column_mapping[db_col]
+            if mapped_name:
+                normalized_columns.append(mapped_name)
+        else:
+            normalized_columns.append(db_col)
+
+    return normalized_columns
+
+
+def _prepare_columns(all_columns, column_mapping):
+    """Prepare column mapping and column definitions."""
+    column_indices_and_names = []  # List of (index, target_name) tuples
+    columns = []
+    for idx, col in enumerate(all_columns):
+        if col.name in column_mapping or hasattr(column_mapping, "default_factory"):
+            mapped_name = column_mapping[col.name]
+            if mapped_name:
+                columns.append(sqlalchemy.Column(mapped_name, col.type))
+                column_indices_and_names.append((idx, mapped_name))
+        else:
+            columns.append(col)
+            column_indices_and_names.append((idx, col.name))
+    return column_indices_and_names, columns
+
+
+def _process_batch(
+    conn, table, batch, on_conflict, conflict_columns, column_indices_and_names
+) -> int:
+    """Process a batch of rows with conflict resolution.
+
+    Returns:
+        int: Number of rows affected by the insert operation.
+    """
+
+    def prepare_row(row_values):
+        """Convert a row tuple to a dictionary with proper DB column names."""
+        return {
+            target_name: row_values[idx]
+            for idx, target_name in column_indices_and_names
+        }
+
+    rows_to_insert = [prepare_row(row) for row in batch]
+
+    supports_conflict = on_conflict and conn.engine.name in ("postgresql", "sqlite")
+
+    insert_stmt: Any  # Can be PostgreSQL, SQLite, or regular insert statement
+    if supports_conflict:
+        # Use dialect-specific insert for conflict resolution
+        if conn.engine.name == "postgresql":
+            from sqlalchemy.dialects.postgresql import insert as pg_insert
+
+            insert_stmt = pg_insert(table)
+        elif conn.engine.name == "sqlite":
+            from sqlalchemy.dialects.sqlite import insert as sqlite_insert
+
+            insert_stmt = sqlite_insert(table)
+    else:
+        insert_stmt = table.insert()
+
+    if supports_conflict:
+        if on_conflict == "ignore":
+            insert_stmt = insert_stmt.on_conflict_do_nothing()
+        elif on_conflict == "update":
+            update_values = {
+                col.name: insert_stmt.excluded[col.name] for col in table.columns
+            }
+            if conn.engine.name == "postgresql":
+                if not conflict_columns:
+                    raise ValueError(
+                        "conflict_columns parameter is required when "
+                        "on_conflict='update' with PostgreSQL. Specify the column "
+                        "names that form a unique constraint."
+                    )
+
+                insert_stmt = insert_stmt.on_conflict_do_update(
+                    index_elements=conflict_columns, set_=update_values
+                )
+            else:
+                insert_stmt = insert_stmt.on_conflict_do_update(set_=update_values)
+    elif on_conflict:
+        import warnings
+
+        warnings.warn(
+            "Database does not support conflict resolution. "
+            f"Ignoring on_conflict='{on_conflict}' parameter.",
+            UserWarning,
+            stacklevel=2,
+        )
+
+    result = conn.execute(insert_stmt, rows_to_insert)
+    return result.rowcount
+
+
+def read_database(
+    query: "str | sqlalchemy.sql.expression.Executable",
+    connection: "ConnectionType",
+    params: Sequence[Mapping[str, Any]] | Mapping[str, Any] | None = None,
+    *,
+    output: dict[str, "DataType"] | None = None,
+    session: "Session | None" = None,
+    settings: dict | None = None,
+    in_memory: bool = False,
+    infer_schema_length: int | None = 100,
+) -> "DataChain":
+    """
+    Read the results of a SQL query into a DataChain, using a given database connection.
+
+    Args:
+        query:
+            The SQL query to execute. Can be a raw SQL string or a SQLAlchemy
+            `Executable` object.
+        connection: SQLAlchemy connectable, str, or a sqlite3 connection.
+            Using SQLAlchemy makes it possible to use any DB supported by that
+            library. If a DBAPI2 object, only sqlite3 is supported. The user is
+            responsible for engine disposal and connection closure for the
+            SQLAlchemy connectable; str connections are closed automatically.
+        params: Parameters to pass to the execute method.
+        output: A dictionary mapping column names to types, used to override the
+            schema inferred from the query results.
+        session: Session to use for the chain.
+        settings: Settings to use for the chain.
+        in_memory: If True, creates an in-memory session. Defaults to False.
+        infer_schema_length:
+            The maximum number of rows to scan for inferring schema.
+            If set to `None`, the full data may be scanned.
+            The rows used for schema inference are stored in memory,
+            so large values can lead to high memory usage.
+            Only applies if the `output` parameter is not set for the given column.
+
+    Examples:
+        Reading from a SQL query against a user-supplied connection:
+        ```python
+        query = "SELECT key, value FROM tbl"
+        chain = dc.read_database(query, connection, output={"value": float})
+        ```
+
+        Load data from a SQLAlchemy driver/engine:
+        ```python
+        from sqlalchemy import create_engine
+        engine = create_engine("postgresql+psycopg://myuser:mypassword@localhost:5432/mydb")
+        chain = dc.read_database("select * from tbl", engine)
+        ```
+
+        Load data from a parameterized SQLAlchemy query:
+        ```python
+        query = "SELECT key, value FROM tbl WHERE value > :value"
+        dc.read_database(query, engine, params={"value": 50})
+        ```
+
+    Notes:
+        - This function works with a variety of databases, including but not
+          limited to SQLite, DuckDB, PostgreSQL, and Snowflake, provided the
+          appropriate driver is installed.
+        - This call is blocking, and will execute the query and return once the
+          results are saved.
+    """
+    from datachain.lib.dc.records import read_records
+
+    output = output or {}
+    if isinstance(query, str):
+        query = sqlalchemy.text(query)
+    kw = {"execution_options": {"stream_results": True}}  # use server-side cursors
+    with _connect(connection) as conn, conn.execute(query, params, **kw) as result:
+        cols = result.keys()
+        to_infer = [k for k in cols if k not in output]  # preserve the order
+        rows, inferred_schema = _infer_schema(result, to_infer, infer_schema_length)
+        records = (row._asdict() for row in itertools.chain(rows, result))
+        return read_records(
+            records,
+            session=session,
+            settings=settings,
+            in_memory=in_memory,
+            schema=inferred_schema | output,
+        )
+
+
+def _infer_schema(
+    result: "sqlalchemy.engine.Result",
+    to_infer: list[str],
+    infer_schema_length: int | None = 100,
+) -> tuple[list["sqlalchemy.Row"], dict[str, "DataType"]]:
+    from datachain.lib.convert.values_to_tuples import values_to_tuples
+
+    if not to_infer:
+        return [], {}
+
+    rows = list(itertools.islice(result, infer_schema_length))
+    values = {col: [row._mapping[col] for row in rows] for col in to_infer}
+    _, output_schema, _ = values_to_tuples("", **values)
+    return rows, output_schema
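
Taken together, `to_database()` and `read_database()` allow a round trip between a DataChain and a SQL database over any connection type accepted by `_connect()`. The following is a minimal usage sketch, not part of the diff above; it assumes the release's public wrappers (`dc.read_values`, the `DataChain.to_database()` method referenced in the docstring, and `dc.read_database`) mirror the module-level signatures shown in the hunk.

```python
# Minimal round-trip sketch (not part of the diff). Assumed API shapes:
# dc.read_values(), chain.to_database(table_name, connection), dc.read_database().
import sqlite3

import datachain as dc

# The caller owns this connection; per _connect(), it is wrapped in a
# StaticPool-backed engine and only the SQLAlchemy wrapper is closed.
conn = sqlite3.connect(":memory:")

chain = dc.read_values(key=["a", "b", "c"], value=[1.0, 2.0, 3.0])
chain.to_database("metrics", conn)

# Read the rows back, overriding the inferred type of one column and
# passing bind parameters through to Connection.execute().
loaded = dc.read_database(
    "SELECT key, value FROM metrics WHERE value > :min_value",
    conn,
    params={"min_value": 1.5},
    output={"value": float},
)
```

For upserts, `on_conflict="ignore"` or `on_conflict="update"` routes through the dialect-specific insert paths in `_process_batch()`; on PostgreSQL, `on_conflict="update"` additionally requires `conflict_columns` naming a unique constraint.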