datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datachain/__init__.py +20 -0
- datachain/asyn.py +11 -12
- datachain/cache.py +7 -7
- datachain/catalog/__init__.py +2 -2
- datachain/catalog/catalog.py +621 -507
- datachain/catalog/dependency.py +164 -0
- datachain/catalog/loader.py +28 -18
- datachain/checkpoint.py +43 -0
- datachain/cli/__init__.py +24 -33
- datachain/cli/commands/__init__.py +1 -8
- datachain/cli/commands/datasets.py +83 -52
- datachain/cli/commands/ls.py +17 -17
- datachain/cli/commands/show.py +4 -4
- datachain/cli/parser/__init__.py +8 -74
- datachain/cli/parser/job.py +95 -3
- datachain/cli/parser/studio.py +11 -4
- datachain/cli/parser/utils.py +1 -2
- datachain/cli/utils.py +2 -15
- datachain/client/azure.py +4 -4
- datachain/client/fsspec.py +45 -28
- datachain/client/gcs.py +6 -6
- datachain/client/hf.py +29 -2
- datachain/client/http.py +157 -0
- datachain/client/local.py +15 -11
- datachain/client/s3.py +17 -9
- datachain/config.py +4 -8
- datachain/data_storage/db_engine.py +12 -6
- datachain/data_storage/job.py +5 -1
- datachain/data_storage/metastore.py +1252 -186
- datachain/data_storage/schema.py +58 -45
- datachain/data_storage/serializer.py +105 -15
- datachain/data_storage/sqlite.py +286 -127
- datachain/data_storage/warehouse.py +250 -113
- datachain/dataset.py +353 -148
- datachain/delta.py +391 -0
- datachain/diff/__init__.py +27 -29
- datachain/error.py +60 -0
- datachain/func/__init__.py +2 -1
- datachain/func/aggregate.py +66 -42
- datachain/func/array.py +242 -38
- datachain/func/base.py +7 -4
- datachain/func/conditional.py +110 -60
- datachain/func/func.py +96 -45
- datachain/func/numeric.py +55 -38
- datachain/func/path.py +32 -20
- datachain/func/random.py +2 -2
- datachain/func/string.py +67 -37
- datachain/func/window.py +7 -8
- datachain/hash_utils.py +123 -0
- datachain/job.py +11 -7
- datachain/json.py +138 -0
- datachain/lib/arrow.py +58 -22
- datachain/lib/audio.py +245 -0
- datachain/lib/clip.py +14 -13
- datachain/lib/convert/flatten.py +5 -3
- datachain/lib/convert/python_to_sql.py +6 -10
- datachain/lib/convert/sql_to_python.py +8 -0
- datachain/lib/convert/values_to_tuples.py +156 -51
- datachain/lib/data_model.py +42 -20
- datachain/lib/dataset_info.py +36 -8
- datachain/lib/dc/__init__.py +8 -2
- datachain/lib/dc/csv.py +25 -28
- datachain/lib/dc/database.py +398 -0
- datachain/lib/dc/datachain.py +1289 -425
- datachain/lib/dc/datasets.py +320 -38
- datachain/lib/dc/hf.py +38 -24
- datachain/lib/dc/json.py +29 -32
- datachain/lib/dc/listings.py +112 -8
- datachain/lib/dc/pandas.py +16 -12
- datachain/lib/dc/parquet.py +35 -23
- datachain/lib/dc/records.py +31 -23
- datachain/lib/dc/storage.py +154 -64
- datachain/lib/dc/storage_pattern.py +251 -0
- datachain/lib/dc/utils.py +24 -16
- datachain/lib/dc/values.py +8 -9
- datachain/lib/file.py +622 -89
- datachain/lib/hf.py +69 -39
- datachain/lib/image.py +14 -14
- datachain/lib/listing.py +14 -11
- datachain/lib/listing_info.py +1 -2
- datachain/lib/meta_formats.py +3 -4
- datachain/lib/model_store.py +39 -7
- datachain/lib/namespaces.py +125 -0
- datachain/lib/projects.py +130 -0
- datachain/lib/pytorch.py +32 -21
- datachain/lib/settings.py +192 -56
- datachain/lib/signal_schema.py +427 -104
- datachain/lib/tar.py +1 -2
- datachain/lib/text.py +8 -7
- datachain/lib/udf.py +164 -76
- datachain/lib/udf_signature.py +60 -35
- datachain/lib/utils.py +118 -4
- datachain/lib/video.py +17 -9
- datachain/lib/webdataset.py +61 -56
- datachain/lib/webdataset_laion.py +15 -16
- datachain/listing.py +22 -10
- datachain/model/bbox.py +3 -1
- datachain/model/ultralytics/bbox.py +16 -12
- datachain/model/ultralytics/pose.py +16 -12
- datachain/model/ultralytics/segment.py +16 -12
- datachain/namespace.py +84 -0
- datachain/node.py +6 -6
- datachain/nodes_thread_pool.py +0 -1
- datachain/plugins.py +24 -0
- datachain/project.py +78 -0
- datachain/query/batch.py +40 -41
- datachain/query/dataset.py +604 -322
- datachain/query/dispatch.py +261 -154
- datachain/query/metrics.py +4 -6
- datachain/query/params.py +2 -3
- datachain/query/queue.py +3 -12
- datachain/query/schema.py +11 -6
- datachain/query/session.py +200 -33
- datachain/query/udf.py +34 -2
- datachain/remote/studio.py +171 -69
- datachain/script_meta.py +12 -12
- datachain/semver.py +68 -0
- datachain/sql/__init__.py +2 -0
- datachain/sql/functions/array.py +33 -1
- datachain/sql/postgresql_dialect.py +9 -0
- datachain/sql/postgresql_types.py +21 -0
- datachain/sql/sqlite/__init__.py +5 -1
- datachain/sql/sqlite/base.py +102 -29
- datachain/sql/sqlite/types.py +8 -13
- datachain/sql/types.py +70 -15
- datachain/studio.py +223 -46
- datachain/toolkit/split.py +31 -10
- datachain/utils.py +101 -59
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
- datachain-0.39.0.dist-info/RECORD +173 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
- datachain/cli/commands/query.py +0 -53
- datachain/query/utils.py +0 -42
- datachain-0.14.2.dist-info/RECORD +0 -158
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,398 @@
|
|
|
1
|
+
import contextlib
|
|
2
|
+
import itertools
|
|
3
|
+
import os
|
|
4
|
+
import sqlite3
|
|
5
|
+
from collections.abc import Iterator, Mapping, Sequence
|
|
6
|
+
from typing import TYPE_CHECKING, Any
|
|
7
|
+
|
|
8
|
+
import sqlalchemy
|
|
9
|
+
|
|
10
|
+
from datachain.query.schema import ColumnMeta
|
|
11
|
+
from datachain.utils import batched
|
|
12
|
+
|
|
13
|
+
# Default number of rows grouped into one batch (one INSERT execution) by
# ``to_database``; callers can override it via its ``batch_size`` argument.
DEFAULT_DATABASE_BATCH_SIZE = 10_000

# Imports and aliases below are only needed for type annotations; they are
# referenced as string annotations at runtime, so they are never imported
# outside of static type checking.
if TYPE_CHECKING:
    import sqlalchemy.orm  # noqa: TC004

    from datachain.lib.data_model import DataType
    from datachain.query import Session

    from .datachain import DataChain

    # Union of the connection-like objects accepted by ``to_database`` and
    # ``read_database`` (dispatched at runtime by ``_connect``).
    ConnectionType = (
        str
        | sqlalchemy.engine.URL
        | sqlalchemy.engine.interfaces.Connectable
        | sqlalchemy.engine.Engine
        | sqlalchemy.engine.Connection
        | sqlalchemy.orm.Session
        | sqlite3.Connection
    )
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@contextlib.contextmanager
def _connect(
    connection: "ConnectionType",
) -> Iterator[sqlalchemy.engine.Connection]:
    """Yield a SQLAlchemy ``Connection`` for any supported connection-like object.

    What gets cleaned up on exit depends on what the caller passed:

    * ``str`` / ``sqlalchemy.URL``: a private engine is created; the
      connection is closed and the engine disposed on exit.
    * ``sqlite3.Connection``: wrapped in a ``StaticPool`` engine; only the
      SQLAlchemy connection wrapper is closed on exit, the engine is not
      disposed and the ``sqlite3`` connection is left to the caller.
    * ``sqlalchemy.Engine``: a new connection is opened and closed on exit;
      the engine itself is left untouched.
    * ``sqlalchemy.Connection``: yielded as-is and never closed.
    * ``sqlalchemy.orm.Session``: its bind is used - a fresh connection from
      an ``Engine`` bind (closed on exit) or the bound ``Connection`` as-is.

    SQL echo is enabled when the ``DEBUG_SHOW_SQL_QUERIES`` environment
    variable is set to a non-empty value (applies to engines created here).

    Raises:
        TypeError: If ``connection`` is none of the supported types.
    """
    import sqlalchemy.orm

    echo = bool(os.environ.get("DEBUG_SHOW_SQL_QUERIES"))
    with contextlib.ExitStack() as cleanup:
        if isinstance(connection, (str, sqlalchemy.URL)):
            private_engine = sqlalchemy.create_engine(connection, echo=echo)
            # Registered first so it runs last: connection closes, then dispose.
            cleanup.callback(private_engine.dispose)
            conn = cleanup.enter_context(private_engine.connect())
        elif isinstance(connection, sqlite3.Connection):
            wrapper_engine = sqlalchemy.create_engine(
                "sqlite://",
                creator=lambda: connection,
                poolclass=sqlalchemy.pool.StaticPool,
                echo=echo,
            )
            # No dispose() here: only the SQLAlchemy wrapper is closed, the
            # caller keeps ownership of the underlying sqlite3 connection.
            conn = cleanup.enter_context(wrapper_engine.connect())
        elif isinstance(connection, sqlalchemy.Engine):
            conn = cleanup.enter_context(connection.connect())
        elif isinstance(connection, sqlalchemy.Connection):
            # Caller-managed: hand it through without closing it.
            conn = connection
        elif isinstance(connection, sqlalchemy.orm.Session):
            # Sessions don't support DDL directly, so work on their bind.
            bind = connection.get_bind()
            conn = (
                cleanup.enter_context(bind.connect())
                if isinstance(bind, sqlalchemy.Engine)
                else bind
            )
        else:
            kind = type(connection).__name__
            raise TypeError(f"Unsupported connection type: {kind}")
        yield conn
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def to_database(
    chain: "DataChain",
    table_name: str,
    connection: "ConnectionType",
    *,
    batch_size: int = DEFAULT_DATABASE_BATCH_SIZE,
    on_conflict: str | None = None,
    conflict_columns: list[str] | None = None,
    column_mapping: dict[str, str | None] | None = None,
) -> int:
    """
    Implementation function for exporting DataChain to database tables.

    This is the core implementation that handles the actual database operations.
    For user-facing documentation, see DataChain.to_database() method.

    The table is created if missing, and all rows are inserted inside a single
    transaction. If anything fails and this call established that the table did
    not exist beforehand, the table is dropped again (best effort) before the
    original error is re-raised. A table that already existed, or whose
    prior existence could not be determined, is never dropped.

    Returns:
        int: Number of rows affected (inserted/updated), or -1 if any batch
        reported a negative rowcount (the DBAPI convention for "unknown").

    Raises:
        ValueError: If ``on_conflict`` is not ``None``, ``"ignore"`` or
            ``"update"``, or if the column/conflict settings are invalid.
    """
    if on_conflict and on_conflict not in ("ignore", "update"):
        raise ValueError(
            f"on_conflict must be 'ignore' or 'update', got: {on_conflict}"
        )

    signals_schema = chain.signals_schema.clone_without_sys_signals()
    all_columns = [
        sqlalchemy.Column(c.name, c.type)  # type: ignore[union-attr]
        for c in signals_schema.db_signals(as_columns=True)
    ]

    normalized_column_mapping = _normalize_column_mapping(column_mapping or {})
    column_indices_and_names, columns = _prepare_columns(
        all_columns, normalized_column_mapping
    )

    normalized_conflict_columns = _normalize_conflict_columns(
        conflict_columns, normalized_column_mapping
    )

    with _connect(connection) as conn:
        metadata = sqlalchemy.MetaData()
        table = sqlalchemy.Table(table_name, metadata, *columns)

        # Only clean up a table we positively observed as absent. If
        # ``conn.begin()`` or the inspector fails first (e.g. the caller's
        # Connection already has a transaction in progress), the table may
        # hold user data, so it must not be dropped.
        drop_on_error = False
        total_rows_affected = 0
        try:
            with conn.begin():
                inspector = sqlalchemy.inspect(conn)
                assert inspector  # to satisfy mypy
                drop_on_error = table_name not in inspector.get_table_names()

                table.create(conn, checkfirst=True)

                for batch in batched(chain._leaf_values(), batch_size):
                    rows_affected = _process_batch(
                        conn,
                        table,
                        batch,
                        on_conflict,
                        normalized_conflict_columns,
                        column_indices_and_names,
                    )
                    # A negative rowcount means "unknown"; once any batch is
                    # unknown, the total is unknown too.
                    if rows_affected < 0 or total_rows_affected < 0:
                        total_rows_affected = -1
                    else:
                        total_rows_affected += rows_affected
        except Exception:
            if drop_on_error:
                try:
                    # Needed for databases without transactional DDL, where
                    # the CREATE survived the rollback above.
                    table.drop(conn, checkfirst=True)
                    conn.commit()
                except sqlalchemy.exc.SQLAlchemyError:
                    # Best-effort cleanup: keep the original error as the one
                    # raised, but don't leave the connection in a failed
                    # transaction.
                    with contextlib.suppress(sqlalchemy.exc.SQLAlchemyError):
                        conn.rollback()
            raise

    return total_rows_affected
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def _normalize_column_mapping(
|
|
156
|
+
column_mapping: dict[str, str | None],
|
|
157
|
+
) -> dict[str, str | None]:
|
|
158
|
+
"""
|
|
159
|
+
Convert column mapping keys from DataChain format (dots) to database format
|
|
160
|
+
(double underscores).
|
|
161
|
+
|
|
162
|
+
This allows users to specify column mappings using the intuitive DataChain
|
|
163
|
+
format like: {"nested_data.value": "data_value"} instead of
|
|
164
|
+
{"nested_data__value": "data_value"}
|
|
165
|
+
"""
|
|
166
|
+
if not column_mapping:
|
|
167
|
+
return {}
|
|
168
|
+
|
|
169
|
+
normalized_mapping: dict[str, str | None] = {}
|
|
170
|
+
original_keys: dict[str, str] = {}
|
|
171
|
+
for key, value in column_mapping.items():
|
|
172
|
+
db_key = ColumnMeta.to_db_name(key)
|
|
173
|
+
if db_key in normalized_mapping:
|
|
174
|
+
prev = original_keys[db_key]
|
|
175
|
+
raise ValueError(
|
|
176
|
+
"Column mapping collision: multiple keys map to the same "
|
|
177
|
+
f"database column name '{db_key}': '{prev}' and '{key}'. "
|
|
178
|
+
)
|
|
179
|
+
normalized_mapping[db_key] = value
|
|
180
|
+
original_keys[db_key] = key
|
|
181
|
+
|
|
182
|
+
# If it's a defaultdict, preserve the default factory
|
|
183
|
+
if hasattr(column_mapping, "default_factory"):
|
|
184
|
+
from collections import defaultdict
|
|
185
|
+
|
|
186
|
+
default_factory = column_mapping.default_factory
|
|
187
|
+
result: dict[str, str | None] = defaultdict(default_factory)
|
|
188
|
+
result.update(normalized_mapping)
|
|
189
|
+
return result
|
|
190
|
+
|
|
191
|
+
return normalized_mapping
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def _normalize_conflict_columns(
|
|
195
|
+
conflict_columns: list[str] | None, column_mapping: dict[str, str | None]
|
|
196
|
+
) -> list[str] | None:
|
|
197
|
+
"""
|
|
198
|
+
Normalize conflict_columns by converting DataChain format to database format
|
|
199
|
+
and applying column mapping.
|
|
200
|
+
"""
|
|
201
|
+
if not conflict_columns:
|
|
202
|
+
return None
|
|
203
|
+
|
|
204
|
+
normalized_columns = []
|
|
205
|
+
for col in conflict_columns:
|
|
206
|
+
db_col = ColumnMeta.to_db_name(col)
|
|
207
|
+
|
|
208
|
+
if db_col in column_mapping or hasattr(column_mapping, "default_factory"):
|
|
209
|
+
mapped_name = column_mapping[db_col]
|
|
210
|
+
if mapped_name:
|
|
211
|
+
normalized_columns.append(mapped_name)
|
|
212
|
+
else:
|
|
213
|
+
normalized_columns.append(db_col)
|
|
214
|
+
|
|
215
|
+
return normalized_columns
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def _prepare_columns(all_columns, column_mapping):
|
|
219
|
+
"""Prepare column mapping and column definitions."""
|
|
220
|
+
column_indices_and_names = [] # List of (index, target_name) tuples
|
|
221
|
+
columns = []
|
|
222
|
+
for idx, col in enumerate(all_columns):
|
|
223
|
+
if col.name in column_mapping or hasattr(column_mapping, "default_factory"):
|
|
224
|
+
mapped_name = column_mapping[col.name]
|
|
225
|
+
if mapped_name:
|
|
226
|
+
columns.append(sqlalchemy.Column(mapped_name, col.type))
|
|
227
|
+
column_indices_and_names.append((idx, mapped_name))
|
|
228
|
+
else:
|
|
229
|
+
columns.append(col)
|
|
230
|
+
column_indices_and_names.append((idx, col.name))
|
|
231
|
+
return column_indices_and_names, columns
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
def _process_batch(
|
|
235
|
+
conn, table, batch, on_conflict, conflict_columns, column_indices_and_names
|
|
236
|
+
) -> int:
|
|
237
|
+
"""Process a batch of rows with conflict resolution.
|
|
238
|
+
|
|
239
|
+
Returns:
|
|
240
|
+
int: Number of rows affected by the insert operation.
|
|
241
|
+
"""
|
|
242
|
+
|
|
243
|
+
def prepare_row(row_values):
|
|
244
|
+
"""Convert a row tuple to a dictionary with proper DB column names."""
|
|
245
|
+
return {
|
|
246
|
+
target_name: row_values[idx]
|
|
247
|
+
for idx, target_name in column_indices_and_names
|
|
248
|
+
}
|
|
249
|
+
|
|
250
|
+
rows_to_insert = [prepare_row(row) for row in batch]
|
|
251
|
+
|
|
252
|
+
supports_conflict = on_conflict and conn.engine.name in ("postgresql", "sqlite")
|
|
253
|
+
|
|
254
|
+
insert_stmt: Any # Can be PostgreSQL, SQLite, or regular insert statement
|
|
255
|
+
if supports_conflict:
|
|
256
|
+
# Use dialect-specific insert for conflict resolution
|
|
257
|
+
if conn.engine.name == "postgresql":
|
|
258
|
+
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
|
259
|
+
|
|
260
|
+
insert_stmt = pg_insert(table)
|
|
261
|
+
elif conn.engine.name == "sqlite":
|
|
262
|
+
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
|
263
|
+
|
|
264
|
+
insert_stmt = sqlite_insert(table)
|
|
265
|
+
else:
|
|
266
|
+
insert_stmt = table.insert()
|
|
267
|
+
|
|
268
|
+
if supports_conflict:
|
|
269
|
+
if on_conflict == "ignore":
|
|
270
|
+
insert_stmt = insert_stmt.on_conflict_do_nothing()
|
|
271
|
+
elif on_conflict == "update":
|
|
272
|
+
update_values = {
|
|
273
|
+
col.name: insert_stmt.excluded[col.name] for col in table.columns
|
|
274
|
+
}
|
|
275
|
+
if conn.engine.name == "postgresql":
|
|
276
|
+
if not conflict_columns:
|
|
277
|
+
raise ValueError(
|
|
278
|
+
"conflict_columns parameter is required when "
|
|
279
|
+
"on_conflict='update' with PostgreSQL. Specify the column "
|
|
280
|
+
"names that form a unique constraint."
|
|
281
|
+
)
|
|
282
|
+
|
|
283
|
+
insert_stmt = insert_stmt.on_conflict_do_update(
|
|
284
|
+
index_elements=conflict_columns, set_=update_values
|
|
285
|
+
)
|
|
286
|
+
else:
|
|
287
|
+
insert_stmt = insert_stmt.on_conflict_do_update(set_=update_values)
|
|
288
|
+
elif on_conflict:
|
|
289
|
+
import warnings
|
|
290
|
+
|
|
291
|
+
warnings.warn(
|
|
292
|
+
f"Database does not support conflict resolution. "
|
|
293
|
+
f"Ignoring on_conflict='{on_conflict}' parameter.",
|
|
294
|
+
UserWarning,
|
|
295
|
+
stacklevel=2,
|
|
296
|
+
)
|
|
297
|
+
|
|
298
|
+
result = conn.execute(insert_stmt, rows_to_insert)
|
|
299
|
+
return result.rowcount
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
def read_database(
    query: "str | sqlalchemy.sql.expression.Executable",
    connection: "ConnectionType",
    params: Sequence[Mapping[str, Any]] | Mapping[str, Any] | None = None,
    *,
    output: dict[str, "DataType"] | None = None,
    session: "Session | None" = None,
    settings: dict | None = None,
    in_memory: bool = False,
    infer_schema_length: int | None = 100,
) -> "DataChain":
    """
    Read the results of a SQL query into a DataChain, using a given database connection.

    Args:
        query:
            The SQL query to execute. Can be a raw SQL string or a SQLAlchemy
            `Executable` object.
        connection: A SQLAlchemy connectable (engine, connection or ORM
            session), a database URL (str or `sqlalchemy.URL`), or a sqlite3
            connection. Using SQLAlchemy makes it possible to use any DB
            supported by that library. If a DBAPI2 object, only sqlite3 is
            supported. The user is responsible for engine disposal and
            connection closure for SQLAlchemy connectables and sqlite3
            connections; for URL connections a temporary engine is created
            and disposed automatically.
        params: Parameters to pass to execute method.
        output: A dictionary mapping column names to types, used to override the
            schema inferred from the query results.
        session: Session to use for the chain.
        settings: Settings to use for the chain.
        in_memory: If True, creates an in-memory session. Defaults to False.
        infer_schema_length:
            The maximum number of rows to scan for inferring schema.
            If set to `None`, the full data may be scanned.
            The rows used for schema inference are stored in memory,
            so large values can lead to high memory usage.
            Only applies if the `output` parameter is not set for the given column.

    Examples:
        Reading from a SQL query against a user-supplied connection:
        ```python
        query = "SELECT key, value FROM tbl"
        chain = dc.read_database(query, connection, output={"value": float})
        ```

        Load data from a SQLAlchemy driver/engine:
        ```python
        from sqlalchemy import create_engine
        engine = create_engine("postgresql+psycopg://myuser:mypassword@localhost:5432/mydb")
        chain = dc.read_database("select * from tbl", engine)
        ```

        Load data from a parameterized SQLAlchemy query:
        ```python
        query = "SELECT key, value FROM tbl WHERE value > :value"
        dc.read_database(query, engine, params={"value": 50})
        ```

    Notes:
        - This function works with a variety of databases — including,
          but not limited to, SQLite, DuckDB, PostgreSQL, and Snowflake,
          provided the appropriate driver is installed.
        - This call is blocking, and will execute the query and return once the
          results are saved.
    """
    from datachain.lib.dc.records import read_records

    schema_overrides = output or {}
    statement = sqlalchemy.text(query) if isinstance(query, str) else query
    # Server-side cursors let large result sets stream instead of buffering.
    execution_kwargs = {"execution_options": {"stream_results": True}}
    with (
        _connect(connection) as conn,
        conn.execute(statement, params, **execution_kwargs) as result,
    ):
        # Only columns without an explicit type need inference; keep order.
        pending = [name for name in result.keys() if name not in schema_overrides]
        sampled, inferred = _infer_schema(result, pending, infer_schema_length)
        # Re-emit the rows consumed for inference ahead of the remaining stream.
        all_rows = itertools.chain(sampled, result)
        return read_records(
            (row._asdict() for row in all_rows),
            session=session,
            settings=settings,
            in_memory=in_memory,
            schema=inferred | schema_overrides,
        )
|
|
383
|
+
|
|
384
|
+
|
|
385
|
+
def _infer_schema(
    result: "sqlalchemy.engine.Result",
    to_infer: list[str],
    infer_schema_length: int | None = 100,
) -> tuple[list["sqlalchemy.Row"], dict[str, "DataType"]]:
    """Infer DataChain types for the ``to_infer`` columns from a result sample.

    Up to ``infer_schema_length`` rows (all remaining rows when ``None``) are
    consumed from ``result`` and returned, so the caller can re-emit them ahead
    of the rest of the stream. Type inference is delegated to
    ``values_to_tuples``.

    Returns:
        A ``(sampled_rows, schema)`` pair. When ``to_infer`` is empty nothing
        is read from ``result`` and ``([], {})`` is returned.
    """
    from datachain.lib.convert.values_to_tuples import values_to_tuples

    if not to_infer:
        return [], {}

    sample = [*itertools.islice(result, infer_schema_length)]
    columns = {}
    for name in to_infer:
        columns[name] = [row._mapping[name] for row in sample]
    _, schema, _ = values_to_tuples("", **columns)
    return sample, schema
|