datachain 0.14.4__py3-none-any.whl → 0.15.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- datachain/__init__.py +4 -0
- datachain/catalog/catalog.py +13 -5
- datachain/catalog/loader.py +11 -7
- datachain/data_storage/schema.py +21 -23
- datachain/data_storage/sqlite.py +1 -1
- datachain/data_storage/warehouse.py +6 -8
- datachain/lib/convert/values_to_tuples.py +23 -14
- datachain/lib/dc/__init__.py +4 -1
- datachain/lib/dc/csv.py +3 -3
- datachain/lib/dc/database.py +151 -0
- datachain/lib/dc/datachain.py +25 -15
- datachain/lib/dc/datasets.py +70 -10
- datachain/lib/dc/hf.py +5 -5
- datachain/lib/dc/json.py +7 -7
- datachain/lib/dc/listings.py +3 -3
- datachain/lib/dc/pandas.py +13 -6
- datachain/lib/dc/parquet.py +3 -3
- datachain/lib/dc/records.py +12 -14
- datachain/lib/dc/storage.py +6 -6
- datachain/lib/dc/values.py +3 -3
- datachain/lib/listing.py +2 -2
- datachain/lib/signal_schema.py +34 -10
- datachain/listing.py +4 -4
- datachain/query/dataset.py +10 -12
- datachain/query/dispatch.py +7 -2
- datachain/query/schema.py +4 -1
- {datachain-0.14.4.dist-info → datachain-0.15.0.dist-info}/METADATA +3 -3
- {datachain-0.14.4.dist-info → datachain-0.15.0.dist-info}/RECORD +32 -31
- {datachain-0.14.4.dist-info → datachain-0.15.0.dist-info}/WHEEL +0 -0
- {datachain-0.14.4.dist-info → datachain-0.15.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.14.4.dist-info → datachain-0.15.0.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.14.4.dist-info → datachain-0.15.0.dist-info}/top_level.txt +0 -0
datachain/__init__.py
CHANGED
@@ -5,8 +5,10 @@ from datachain.lib.dc import (
     DataChain,
     Sys,
     datasets,
+    delete_dataset,
     listings,
     read_csv,
+    read_database,
     read_dataset,
     read_hf,
     read_json,
@@ -61,11 +63,13 @@ __all__ = [
     "VideoFragment",
     "VideoFrame",
     "datasets",
+    "delete_dataset",
     "is_chain_type",
     "listings",
     "metrics",
     "param",
     "read_csv",
+    "read_database",
     "read_dataset",
     "read_hf",
     "read_json",
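The two new exports, `delete_dataset` and `read_database`, become importable straight from the top-level package. A minimal, hedged sketch of the new surface (the connection string and dataset name are illustrative, and `delete_dataset` is called with only its name argument because the rest of its signature is not shown in this diff):

```python
import datachain as dc

# Read a SQL query result into a chain via the new top-level helper.
chain = dc.read_database(
    "SELECT 1 AS key, 2.0 AS value",
    "sqlite:///:memory:",  # illustrative connection string
)
chain.show()

# Remove a dataset by name; "temp_dataset" is a made-up name and must exist.
dc.delete_dataset("temp_dataset")
```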
datachain/catalog/catalog.py
CHANGED
@@ -580,15 +580,13 @@ class Catalog:
         source: str,
         update=False,
         client_config=None,
-        object_name="file",
+        column="file",
         skip_indexing=False,
     ) -> tuple[Optional["Listing"], "Client", str]:
         from datachain import read_storage
         from datachain.listing import Listing
 
-        read_storage(
-            source, session=self.session, update=update, object_name=object_name
-        ).exec()
+        read_storage(source, session=self.session, update=update, column=column).exec()
 
         list_ds_name, list_uri, list_path, _ = get_listing(
             source, self.session, update=update
@@ -602,7 +600,7 @@
             self.warehouse.clone(),
             client,
             dataset_name=list_ds_name,
-            object_name=object_name,
+            column=column,
         )
 
         return lst, client, list_path
@@ -1301,7 +1299,17 @@
         name: str,
         version: Optional[int] = None,
         force: Optional[bool] = False,
+        studio: Optional[bool] = False,
     ):
+        from datachain.remote.studio import StudioClient
+
+        if studio:
+            client = StudioClient()
+            response = client.rm_dataset(name, version=version, force=force)
+            if not response.ok:
+                raise DataChainError(response.message)
+            return
+
         dataset = self.get_dataset(name)
         if not version and not force:
             raise ValueError(f"Missing dataset version from input for dataset {name}")
datachain/catalog/loader.py
CHANGED
@@ -1,4 +1,5 @@
 import os
+import sys
 from importlib import import_module
 from typing import TYPE_CHECKING, Any, Optional
 
@@ -15,6 +16,7 @@ METASTORE_ARG_PREFIX = "DATACHAIN_METASTORE_ARG_"
 WAREHOUSE_SERIALIZED = "DATACHAIN__WAREHOUSE"
 WAREHOUSE_IMPORT_PATH = "DATACHAIN_WAREHOUSE"
 WAREHOUSE_ARG_PREFIX = "DATACHAIN_WAREHOUSE_ARG_"
+DISTRIBUTED_IMPORT_PYTHONPATH = "DATACHAIN_DISTRIBUTED_PYTHONPATH"
 DISTRIBUTED_IMPORT_PATH = "DATACHAIN_DISTRIBUTED"
 
 IN_MEMORY_ERROR_MESSAGE = "In-memory is only supported on SQLite"
@@ -100,19 +102,21 @@ def get_warehouse(in_memory: bool = False) -> "AbstractWarehouse":
     return warehouse_class(**warehouse_args)
 
 
-def get_udf_distributor_class() -> type["AbstractUDFDistributor"]:
-    distributed_import_path = os.environ.get(DISTRIBUTED_IMPORT_PATH)
+def get_udf_distributor_class() -> Optional[type["AbstractUDFDistributor"]]:
+    if not (distributed_import_path := os.environ.get(DISTRIBUTED_IMPORT_PATH)):
+        return None
 
-    if not distributed_import_path:
-        raise RuntimeError(
-            f"{DISTRIBUTED_IMPORT_PATH} import path is required "
-            "for distributed UDF processing."
-        )
     # Distributed class paths are specified as (for example): module.classname
     if "." not in distributed_import_path:
         raise RuntimeError(
             f"Invalid {DISTRIBUTED_IMPORT_PATH} import path: {distributed_import_path}"
         )
+
+    # Optional: set the Python path to look for the module
+    distributed_import_pythonpath = os.environ.get(DISTRIBUTED_IMPORT_PYTHONPATH)
+    if distributed_import_pythonpath and distributed_import_pythonpath not in sys.path:
+        sys.path.insert(0, distributed_import_pythonpath)
+
     module_name, _, class_name = distributed_import_path.rpartition(".")
     distributed = import_module(module_name)
     return getattr(distributed, class_name)
datachain/data_storage/schema.py
CHANGED
@@ -30,8 +30,8 @@ if TYPE_CHECKING:
 DEFAULT_DELIMITER = "__"
 
 
-def col_name(name: str, object_name: str = "file") -> str:
-    return f"{object_name}{DEFAULT_DELIMITER}{name}"
+def col_name(name: str, column: str = "file") -> str:
+    return f"{column}{DEFAULT_DELIMITER}{name}"
 
 
 def dedup_columns(columns: Iterable[sa.Column]) -> list[sa.Column]:
@@ -84,19 +84,19 @@ def convert_rows_custom_column_types(
 
 
 class DirExpansion:
-    def __init__(self, object_name: str):
-        self.object_name = object_name
+    def __init__(self, column: str):
+        self.column = column
 
-    def col_name(self, name: str, object_name: Optional[str] = None) -> str:
-        object_name = object_name or self.object_name
-        return col_name(name, object_name)
+    def col_name(self, name: str, column: Optional[str] = None) -> str:
+        column = column or self.column
+        return col_name(name, column)
 
-    def c(self, query, name: str, object_name: Optional[str] = None) -> str:
-        return getattr(query.c, self.col_name(name, object_name=object_name))
+    def c(self, query, name: str, column: Optional[str] = None) -> str:
+        return getattr(query.c, self.col_name(name, column=column))
 
     def base_select(self, q):
         return sa.select(
-            self.c(q, "id", object_name="sys"),
+            self.c(q, "id", column="sys"),
             false().label(self.col_name("is_dir")),
             self.c(q, "source"),
             self.c(q, "path"),
@@ -153,12 +153,12 @@ class DataTable:
         name: str,
         engine: "DatabaseEngine",
         column_types: Optional[dict[str, SQLType]] = None,
-        object_name: str = "file",
+        column: str = "file",
     ):
         self.name: str = name
         self.engine = engine
         self.column_types: dict[str, SQLType] = column_types or {}
-        self.object_name = object_name
+        self.column = column
 
     @staticmethod
     def copy_column(
@@ -224,18 +224,16 @@ class DataTable:
     def columns(self) -> "ReadOnlyColumnCollection[str, sa.Column[Any]]":
         return self.table.columns
 
-    def col_name(self, name: str, object_name: Optional[str] = None) -> str:
-        object_name = object_name or self.object_name
-        return col_name(name, object_name)
+    def col_name(self, name: str, column: Optional[str] = None) -> str:
+        column = column or self.column
+        return col_name(name, column)
 
-    def without_object(
-        self, column_name: str, object_name: Optional[str] = None
-    ) -> str:
-        object_name = object_name or self.object_name
-        return column_name.removeprefix(f"{object_name}{DEFAULT_DELIMITER}")
+    def without_object(self, column_name: str, column: Optional[str] = None) -> str:
+        column = column or self.column
+        return column_name.removeprefix(f"{column}{DEFAULT_DELIMITER}")
 
-    def c(self, name: str, object_name: Optional[str] = None):
-        return getattr(self.columns, self.col_name(name, object_name=object_name))
+    def c(self, name: str, column: Optional[str] = None):
+        return getattr(self.columns, self.col_name(name, column=column))
 
     @property
     def table(self) -> "sa.Table":
@@ -275,7 +273,7 @@ class DataTable:
         ]
 
     def dir_expansion(self):
-        return DirExpansion(self.object_name)
+        return DirExpansion(self.column)
 
 
 PARTITION_COLUMN_ID = "partition_id"
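The `column` argument is a rename of the old `object_name` prefix; the naming scheme itself is unchanged. A quick standalone illustration of what `col_name` produces with the default `__` delimiter (a re-implementation for demonstration, not an import of the private helper):

```python
DEFAULT_DELIMITER = "__"

def col_name(name: str, column: str = "file") -> str:
    # Same logic as datachain.data_storage.schema.col_name after the rename.
    return f"{column}{DEFAULT_DELIMITER}{name}"

assert col_name("path") == "file__path"
assert col_name("id", column="sys") == "sys__id"
```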
datachain/data_storage/sqlite.py
CHANGED
@@ -489,7 +489,7 @@ class SQLiteWarehouse(AbstractWarehouse):
         self, dataset: DatasetRecord, version: int
     ) -> list[StorageURI]:
         dr = self.dataset_rows(dataset, version)
-        query = dr.select(dr.c("source", object_name="file")).distinct()
+        query = dr.select(dr.c("source", column="file")).distinct()
         cur = self.db.cursor()
         cur.row_factory = sqlite3.Row  # type: ignore[assignment]

datachain/data_storage/warehouse.py
CHANGED

@@ -179,7 +179,7 @@ class AbstractWarehouse(ABC, Serializable):
         self,
         dataset: DatasetRecord,
         version: Optional[int] = None,
-        object_name: str = "file",
+        column: str = "file",
     ):
         version = version or dataset.latest_version
 
@@ -188,7 +188,7 @@ class AbstractWarehouse(ABC, Serializable):
             table_name,
             self.db,
             dataset.get_schema(version),
-            object_name=object_name,
+            column=column,
         )
 
     @property
@@ -487,7 +487,7 @@ class AbstractWarehouse(ABC, Serializable):
         dataset_rows: "DataTable",
         path_list: list[str],
         glob_name: str,
-        object_name="file",
+        column="file",
     ) -> Iterator[Node]:
         """Finds all Nodes that correspond to GLOB like path pattern."""
         dr = dataset_rows
@@ -521,7 +521,7 @@ class AbstractWarehouse(ABC, Serializable):
         de = dr.dir_expansion()
         q = de.query(
             dr.select().where(dr.c("is_latest") == true()).subquery(),
-            object_name=dr.object_name,
+            column=dr.column,
         ).subquery()
         q = self.expand_query(de, q, dr)
 
@@ -597,12 +597,10 @@ class AbstractWarehouse(ABC, Serializable):
             with_default(dr.c("is_latest")),
             dr.c("last_modified"),
             with_default(dr.c("size")),
-            with_default(dr.c("rand", object_name="sys")),
+            with_default(dr.c("rand", column="sys")),
             dr.c("location"),
             de.c(q, "source"),
-        ).select_from(
-            q.outerjoin(dr.table, q.c.sys__id == dr.c("id", object_name="sys"))
-        )
+        ).select_from(q.outerjoin(dr.table, q.c.sys__id == dr.c("id", column="sys")))
 
     def get_node_by_path(self, dataset_rows: "DataTable", path: str) -> Node:
         """Gets node that corresponds to some path"""
datachain/lib/convert/values_to_tuples.py
CHANGED

@@ -1,5 +1,6 @@
+import itertools
 from collections.abc import Sequence
-from typing import Any, Union
+from typing import Any, Optional, Union
 
 from datachain.lib.data_model import (
     DataType,
@@ -66,21 +67,29 @@ def values_to_tuples(  # noqa: C901, PLR0912
                     f"signal '{k}' is not present in the output",
                 )
             else:
-                [six removed lines not captured in this diff view]
-                    raise ValuesToTupleError(
-                        ds_name,
-                        f"signal '{k}' has unsupported type '{typ.__name__}'."
-                        f" Please use DataModel types: {DataTypeNames}",
+                # FIXME: Stops as soon as it finds the first non-None value.
+                # If a non-None value appears early, it won't check the remaining items for
+                # `None` values.
+                try:
+                    pos, first_not_none_element = next(
+                        itertools.dropwhile(lambda pair: pair[1] is None, enumerate(v))
                     )
-                [two removed lines not captured in this diff view]
+                except StopIteration:
+                    typ = str  # default to str if all values are None or has length 0
+                    nullable = True
                 else:
-                [one removed line not captured in this diff view]
+                    nullable = pos > 0
+                    typ = type(first_not_none_element)  # type: ignore[assignment]
+                    if not is_chain_type(typ):
+                        raise ValuesToTupleError(
+                            ds_name,
+                            f"signal '{k}' has unsupported type '{typ.__name__}'."
+                            f" Please use DataModel types: {DataTypeNames}",
+                        )
+                    if isinstance(first_not_none_element, list):
+                        typ = list[type(first_not_none_element[0])]  # type: ignore[assignment, misc]
+
+                types_map[k] = Optional[typ] if nullable else typ  # type: ignore[assignment]
 
             if length < 0:
                 length = len_
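The net effect of this change is that a signal's type is now inferred from the first non-`None` value, and a column whose values start with `None` (or contain only `None`) becomes `Optional`. A standalone sketch of the same inference rule, not the library function itself:

```python
import itertools
from typing import Optional


def infer_type(values: list) -> object:
    """Mirror of the new inference rule used by values_to_tuples."""
    try:
        pos, first = next(
            itertools.dropwhile(lambda pair: pair[1] is None, enumerate(values))
        )
    except StopIteration:
        # All values are None (or the input is empty): default to a nullable str.
        return Optional[str]
    typ = type(first)
    # Nullable when at least one leading value was None.
    return Optional[typ] if pos > 0 else typ


print(infer_type([1, 2, 3]))         # <class 'int'>
print(infer_type([None, "a", "b"]))  # typing.Optional[str]
print(infer_type([None, None]))      # typing.Optional[str]
```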
datachain/lib/dc/__init__.py
CHANGED
@@ -1,6 +1,7 @@
 from .csv import read_csv
+from .database import read_database
 from .datachain import C, Column, DataChain
-from .datasets import datasets, read_dataset
+from .datasets import datasets, delete_dataset, read_dataset
 from .hf import read_hf
 from .json import read_json
 from .listings import listings
@@ -19,8 +20,10 @@ __all__ = [
     "DatasetPrepareError",
     "Sys",
     "datasets",
+    "delete_dataset",
     "listings",
     "read_csv",
+    "read_database",
     "read_dataset",
     "read_hf",
     "read_json",
datachain/lib/dc/csv.py
CHANGED
@@ -21,7 +21,7 @@ def read_csv(
     delimiter: Optional[str] = None,
     header: bool = True,
     output: OutputType = None,
-    object_name: str = "",
+    column: str = "",
     model_name: str = "",
     source: bool = True,
     nrows=None,
@@ -42,7 +42,7 @@ def read_csv(
         output : Dictionary or feature class defining column names and their
             corresponding types. List of column names is also accepted, in which
             case types will be inferred.
-        [removed line not captured in this diff view]
+        column : Created column name.
         model_name : Generated model name.
         source : Whether to include info about the source file.
         nrows : Optional row limit.
@@ -119,7 +119,7 @@ def read_csv(
     )
     return chain.parse_tabular(
         output=output,
-        object_name=object_name,
+        column=column,
         model_name=model_name,
         source=source,
         nrows=nrows,
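For `read_csv`, the generated signal prefix is now configured with `column=` instead of the old `object_name=`. A hedged usage sketch (the CSV path and prefix are illustrative):

```python
import datachain as dc

# "data.csv" is an illustrative path; column= replaces the old object_name=.
chain = dc.read_csv("data.csv", column="csv")
# Parsed fields are exposed under the "csv" prefix, e.g. csv.some_field.
chain.show()
```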
datachain/lib/dc/database.py
ADDED

@@ -0,0 +1,151 @@
+import contextlib
+import itertools
+import os
+import sqlite3
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+import sqlalchemy
+
+if TYPE_CHECKING:
+    from collections.abc import Iterator, Mapping, Sequence
+
+    import sqlalchemy.orm  # noqa: TC004
+
+    from datachain.lib.data_model import DataType
+    from datachain.query import Session
+
+    from .datachain import DataChain
+
+ConnectionType = Union[
+    str,
+    sqlalchemy.engine.URL,
+    sqlalchemy.engine.interfaces.Connectable,
+    sqlalchemy.engine.Engine,
+    sqlalchemy.engine.Connection,
+    sqlalchemy.orm.Session,
+    sqlite3.Connection,
+]
+
+
+@contextlib.contextmanager
+def _connect(
+    connection: "ConnectionType",
+) -> "Iterator[Union[sqlalchemy.engine.Connection, sqlalchemy.orm.Session]]":
+    import sqlalchemy.orm
+
+    with contextlib.ExitStack() as stack:
+        engine_kwargs = {"echo": bool(os.environ.get("DEBUG_SHOW_SQL_QUERIES"))}
+        if isinstance(connection, (str, sqlalchemy.URL)):
+            engine = sqlalchemy.create_engine(connection, **engine_kwargs)
+            stack.callback(engine.dispose)
+            yield stack.enter_context(engine.connect())
+        elif isinstance(connection, sqlite3.Connection):
+            engine = sqlalchemy.create_engine(
+                "sqlite://", creator=lambda: connection, **engine_kwargs
+            )
+            # do not close the connection, as it is managed by the caller
+            yield engine.connect()
+        elif isinstance(connection, sqlalchemy.Engine):
+            yield stack.enter_context(connection.connect())
+        elif isinstance(connection, (sqlalchemy.Connection, sqlalchemy.orm.Session)):
+            # do not close the connection, as it is managed by the caller
+            yield connection
+        else:
+            raise TypeError(f"Unsupported connection type: {type(connection).__name__}")
+
+
+def _infer_schema(
+    result: "sqlalchemy.engine.Result",
+    to_infer: list[str],
+    infer_schema_length: Optional[int] = 100,
+) -> tuple[list["sqlalchemy.Row"], dict[str, "DataType"]]:
+    from datachain.lib.convert.values_to_tuples import values_to_tuples
+
+    if not to_infer:
+        return [], {}
+
+    rows = list(itertools.islice(result, infer_schema_length))
+    values = {col: [row._mapping[col] for row in rows] for col in to_infer}
+    _, output_schema, _ = values_to_tuples("", **values)
+    return rows, output_schema
+
+
+def read_database(
+    query: Union[str, "sqlalchemy.sql.expression.Executable"],
+    connection: "ConnectionType",
+    params: Union["Sequence[Mapping[str, Any]]", "Mapping[str, Any]", None] = None,
+    *,
+    output: Optional["dict[str, DataType]"] = None,
+    session: Optional["Session"] = None,
+    settings: Optional[dict] = None,
+    in_memory: bool = False,
+    infer_schema_length: Optional[int] = 100,
+) -> "DataChain":
+    """
+    Read the results of a SQL query into a DataChain, using a given database connection.
+
+    Args:
+        query:
+            The SQL query to execute. Can be a raw SQL string or a SQLAlchemy
+            `Executable` object.
+        connection: SQLAlchemy connectable, str, or a sqlite3 connection
+            Using SQLAlchemy makes it possible to use any DB supported by that
+            library. If a DBAPI2 object, only sqlite3 is supported. The user is
+            responsible for engine disposal and connection closure for the
+            SQLAlchemy connectable; str connections are closed automatically.
+        params: Parameters to pass to execute method.
+        output: A dictionary mapping column names to types, used to override the
+            schema inferred from the query results.
+        session: Session to use for the chain.
+        settings: Settings to use for the chain.
+        in_memory: If True, creates an in-memory session. Defaults to False.
+        infer_schema_length:
+            The maximum number of rows to scan for inferring schema.
+            If set to `None`, the full data may be scanned.
+            The rows used for schema inference are stored in memory,
+            so large values can lead to high memory usage.
+            Only applies if the `output` parameter is not set for the given column.
+
+    Examples:
+        Reading from a SQL query against a user-supplied connection:
+        ```python
+        query = "SELECT key, value FROM tbl"
+        chain = dc.read_database(query, connection, output={"value": float})
+        ```
+
+        Load data from a SQLAlchemy driver/engine:
+        ```python
+        from sqlalchemy import create_engine
+        engine = create_engine("postgresql+psycopg://myuser:mypassword@localhost:5432/mydb")
+        chain = dc.read_database("select * from tbl", engine)
+        ```
+
+        Load data from a parameterized SQLAlchemy query:
+        ```python
+        query = "SELECT key, value FROM tbl WHERE value > :value"
+        dc.read_database(query, engine, params={"value": 50})
+        ```
+
+    Notes:
+        This function works with a variety of databases — including, but not limited to,
+        SQLite, DuckDB, PostgreSQL, and Snowflake, provided the appropriate driver is
+        installed.
+    """
+    from datachain.lib.dc.records import read_records
+
+    output = output or {}
+    if isinstance(query, str):
+        query = sqlalchemy.text(query)
+    kw = {"execution_options": {"stream_results": True}}  # use server-side cursors
+    with _connect(connection) as conn, conn.execute(query, params, **kw) as result:
+        cols = result.keys()
+        to_infer = [k for k in cols if k not in output]  # preserve the order
+        rows, inferred_schema = _infer_schema(result, to_infer, infer_schema_length)
+        records = (row._asdict() for row in itertools.chain(rows, result))
+        return read_records(
+            records,
+            session=session,
+            settings=settings,
+            in_memory=in_memory,
+            schema=inferred_schema | output,
+        )
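An end-to-end sketch of the new `read_database` entry point against a DBAPI2 SQLite connection, which the diff above lists as a supported connection type; the table and rows are made up:

```python
import sqlite3

import datachain as dc

# Build a throwaway sqlite3 connection with some illustrative rows.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE tbl (key TEXT, value REAL)")
conn.executemany("INSERT INTO tbl VALUES (?, ?)", [("a", 1.0), ("b", 2.5)])

# DBAPI2 connections are supported for sqlite3 only; the caller keeps
# ownership of the connection and closes it afterwards.
chain = dc.read_database("SELECT key, value FROM tbl", conn)
chain.show()
conn.close()
```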
datachain/lib/dc/datachain.py
CHANGED
@@ -133,7 +133,7 @@ class DataChain:
                 .choices[0]
                 .message.content,
             )
-            .save()
+            .persist()
         )
 
         try:
@@ -357,7 +357,7 @@ class DataChain:
         self,
         col: str,
         model_name: Optional[str] = None,
-        object_name: Optional[str] = None,
+        column: Optional[str] = None,
         schema_sample_size: int = 1,
     ) -> "DataChain":
         """Explodes a column containing JSON objects (dict or str DataChain type) into
@@ -368,7 +368,7 @@ class DataChain:
             col: the name of the column containing JSON to be exploded.
             model_name: optional generated model name. By default generates the name
                 automatically.
-            [removed line not captured in this diff view]
+            column: optional generated column name. By default generates the
                 name automatically.
             schema_sample_size: the number of rows to use for inferring the schema of
                 the JSON (in case some fields are optional and it's not enough to
@@ -406,10 +406,10 @@ class DataChain:
             )
             return model.model_validate(json_dict)
 
-        if not object_name:
-            object_name = f"{col}_expl"
+        if not column:
+            column = f"{col}_expl"
 
-        return self.map(json_to_model, params=col, output={object_name: model})
+        return self.map(json_to_model, params=col, output={column: model})
 
     @classmethod
     def datasets(
@@ -443,9 +443,20 @@ class DataChain:
         )
         return listings(*args, **kwargs)
 
+    def persist(self) -> "Self":
+        """Saves temporary chain that will be removed after the process ends.
+        Temporary datasets are useful for optimization, for example when we have
+        multiple chains starting with identical sub-chain. We can then persist that
+        common chain and use it to calculate other chains, to avoid re-calculation
+        every time.
+        It returns the chain itself.
+        """
+        schema = self.signals_schema.clone_without_sys_signals().serialize()
+        return self._evolve(query=self._query.save(feature_schema=schema))
+
     def save(  # type: ignore[override]
         self,
-        name: Optional[str] = None,
+        name: str,
         version: Optional[int] = None,
         description: Optional[str] = None,
         labels: Optional[list[str]] = None,
@@ -454,8 +465,7 @@ class DataChain:
         """Save to a Dataset. It returns the chain itself.
 
         Parameters:
-            name : dataset name. [rest of line not captured in this diff view]
-                removed after process ends. Temp dataset are useful for optimization.
+            name : dataset name.
             version : version of a dataset. Default - the last version that exist.
             description : description of a dataset.
             labels : labels of a dataset.
@@ -1112,7 +1122,7 @@ class DataChain:
         if self._query.attached:
             chain = self
         else:
-            chain = self.save()
+            chain = self.persist()
         assert chain.name is not None  # for mypy
         return PytorchDataset(
             chain.name,
@@ -1588,7 +1598,7 @@ class DataChain:
     def parse_tabular(
         self,
        output: OutputType = None,
-        object_name: str = "",
+        column: str = "",
        model_name: str = "",
        source: bool = True,
        nrows: Optional[int] = None,
@@ -1600,7 +1610,7 @@ class DataChain:
        output : Dictionary or feature class defining column names and their
            corresponding types. List of column names is also accepted, in which
            case types will be inferred.
-        [removed line not captured in this diff view]
+        column : Generated column name.
        model_name : Generated model name.
        source : Whether to include info about the source file.
        nrows : Optional row limit.
@@ -1651,14 +1661,14 @@ class DataChain:
            raise DatasetPrepareError(self.name, e) from e
 
        if isinstance(output, dict):
-            model_name = model_name or object_name or ""
+            model_name = model_name or column or ""
            model = dict_to_data_model(model_name, output)
            output = model
        else:
            model = output  # type: ignore[assignment]
 
-        if object_name:
-            output = {object_name: model}  # type: ignore[dict-item]
+        if column:
+            output = {column: model}  # type: ignore[dict-item]
        elif isinstance(output, type(BaseModel)):
            output = {
                name: info.annotation  # type: ignore[misc]