datachain 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of datachain might be problematic.
- datachain/asyn.py +20 -0
- datachain/catalog/catalog.py +12 -1
- datachain/catalog/loader.py +75 -50
- datachain/client/azure.py +13 -0
- datachain/client/gcs.py +12 -0
- datachain/client/local.py +11 -0
- datachain/client/s3.py +12 -0
- datachain/data_storage/schema.py +22 -8
- datachain/data_storage/sqlite.py +60 -14
- datachain/data_storage/warehouse.py +17 -3
- datachain/lib/arrow.py +1 -1
- datachain/lib/convert/values_to_tuples.py +14 -8
- datachain/lib/data_model.py +1 -0
- datachain/lib/dc.py +52 -19
- datachain/lib/listing.py +111 -0
- datachain/lib/meta_formats.py +8 -2
- datachain/node.py +1 -1
- datachain/query/dataset.py +22 -12
- datachain/query/schema.py +4 -0
- datachain/query/session.py +9 -2
- datachain/sql/default/base.py +3 -0
- datachain/sql/sqlite/base.py +33 -4
- datachain/sql/types.py +120 -11
- {datachain-0.3.1.dist-info → datachain-0.3.3.dist-info}/METADATA +75 -87
- {datachain-0.3.1.dist-info → datachain-0.3.3.dist-info}/RECORD +29 -28
- {datachain-0.3.1.dist-info → datachain-0.3.3.dist-info}/WHEEL +1 -1
- {datachain-0.3.1.dist-info → datachain-0.3.3.dist-info}/LICENSE +0 -0
- {datachain-0.3.1.dist-info → datachain-0.3.3.dist-info}/entry_points.txt +0 -0
- {datachain-0.3.1.dist-info → datachain-0.3.3.dist-info}/top_level.txt +0 -0
datachain/asyn.py
CHANGED
@@ -224,3 +224,23 @@ class OrderedMapper(AsyncMapper[InputT, ResultT]):
     async def _break_iteration(self) -> None:
         self.heap = []
         self._push_result(self._next_yield, None)
+
+
+def iter_over_async(ait, loop):
+    """Wrap an asynchronous iterator into a synchronous one"""
+    ait = ait.__aiter__()
+
+    # helper async fn that just gets the next element from the async iterator
+    async def get_next():
+        try:
+            obj = await ait.__anext__()
+            return False, obj
+        except StopAsyncIteration:
+            return True, None
+
+    # actual sync iterator
+    while True:
+        done, obj = asyncio.run_coroutine_threadsafe(get_next(), loop).result()
+        if done:
+            break
+        yield obj

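The new iter_over_async helper lets synchronous code consume an async iterator by submitting coroutines to an event loop running elsewhere. A minimal usage sketch (the background-thread loop setup below is an assumption for illustration, not part of the diff):

import asyncio
import threading

from datachain.asyn import iter_over_async

# Run an event loop in a background thread so run_coroutine_threadsafe()
# inside iter_over_async has a live loop to submit coroutines to.
loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()

async def numbers():
    for i in range(3):
        await asyncio.sleep(0)
        yield i

# Consume the async generator from plain synchronous code.
assert list(iter_over_async(numbers(), loop)) == [0, 1, 2]
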
datachain/catalog/catalog.py
CHANGED
@@ -577,6 +577,7 @@ class Catalog:
         warehouse_ready_callback: Optional[
             Callable[["AbstractWarehouse"], None]
         ] = None,
+        in_memory: bool = False,
     ):
         datachain_dir = DataChainDir(cache=cache_dir, tmp=tmp_dir)
         datachain_dir.init()
@@ -590,6 +591,7 @@ class Catalog:
             "tmp_dir": tmp_dir,
         }
         self._warehouse_ready_callback = warehouse_ready_callback
+        self.in_memory = in_memory
 
     @cached_property
     def warehouse(self) -> "AbstractWarehouse":
@@ -1627,8 +1629,17 @@ class Catalog:
         version = self.get_dataset(dataset_name).get_version(dataset_version)
 
         file_signals_values = {}
+        file_schemas = {}
+        # TODO: To remove after we properly fix deserialization
+        for signal, type_name in version.feature_schema.items():
+            from datachain.lib.model_store import ModelStore
 
-
+            type_name_parsed, v = ModelStore.parse_name_version(type_name)
+            fr = ModelStore.get(type_name_parsed, v)
+            if fr and issubclass(fr, File):
+                file_schemas[signal] = type_name
+
+        schema = SignalSchema.deserialize(file_schemas)
         for file_signals in schema.get_signals(File):
             prefix = file_signals.replace(".", DEFAULT_DELIMITER) + DEFAULT_DELIMITER
             file_signals_values[file_signals] = {

datachain/catalog/loader.py
CHANGED
@@ -28,8 +28,10 @@ WAREHOUSE_ARG_PREFIX = "DATACHAIN_WAREHOUSE_ARG_"
 DISTRIBUTED_IMPORT_PATH = "DATACHAIN_DISTRIBUTED"
 DISTRIBUTED_ARG_PREFIX = "DATACHAIN_DISTRIBUTED_ARG_"
 
+IN_MEMORY_ERROR_MESSAGE = "In-memory is only supported on SQLite"
 
-
+
+def get_id_generator(in_memory: bool = False) -> "AbstractIDGenerator":
     id_generator_serialized = os.environ.get(ID_GENERATOR_SERIALIZED)
     if id_generator_serialized:
         id_generator_obj = deserialize(id_generator_serialized)
@@ -43,25 +45,31 @@ def get_id_generator() -> "AbstractIDGenerator":
     id_generator_import_path = os.environ.get(ID_GENERATOR_IMPORT_PATH)
     id_generator_arg_envs = get_envs_by_prefix(ID_GENERATOR_ARG_PREFIX)
     # Convert env variable names to keyword argument names by lowercasing them
-    id_generator_args
+    id_generator_args: dict[str, Any] = {
+        k.lower(): v for k, v in id_generator_arg_envs.items()
+    }
+
+    if not id_generator_import_path:
+        id_generator_args["in_memory"] = in_memory
+        return SQLiteIDGenerator(**id_generator_args)
+    if in_memory:
+        raise RuntimeError(IN_MEMORY_ERROR_MESSAGE)
+    # ID generator paths are specified as (for example):
+    # datachain.data_storage.SQLiteIDGenerator
+    if "." not in id_generator_import_path:
+        raise RuntimeError(
+            f"Invalid {ID_GENERATOR_IMPORT_PATH} import path:"
+            f"{id_generator_import_path}"
+        )
+    module_name, _, class_name = id_generator_import_path.rpartition(".")
+    id_generator = import_module(module_name)
+    id_generator_class = getattr(id_generator, class_name)
     return id_generator_class(**id_generator_args)
 
 
-def get_metastore(
+def get_metastore(
+    id_generator: Optional["AbstractIDGenerator"], in_memory: bool = False
+) -> "AbstractMetastore":
     metastore_serialized = os.environ.get(METASTORE_SERIALIZED)
     if metastore_serialized:
         metastore_obj = deserialize(metastore_serialized)
@@ -78,24 +86,32 @@ def get_metastore(id_generator: Optional["AbstractIDGenerator"]) -> "AbstractMet
     metastore_import_path = os.environ.get(METASTORE_IMPORT_PATH)
     metastore_arg_envs = get_envs_by_prefix(METASTORE_ARG_PREFIX)
     # Convert env variable names to keyword argument names by lowercasing them
-    metastore_args
-    if
-    raise
+    metastore_args: dict[str, Any] = {
+        k.lower(): v for k, v in metastore_arg_envs.items()
+    }
+
+    if not metastore_import_path:
+        if not isinstance(id_generator, SQLiteIDGenerator):
+            raise ValueError("SQLiteMetastore can only be used with SQLiteIDGenerator")
+        metastore_args["in_memory"] = in_memory
+        return SQLiteMetastore(id_generator, **metastore_args)
+    if in_memory:
+        raise RuntimeError(IN_MEMORY_ERROR_MESSAGE)
+    # Metastore paths are specified as (for example):
+    # datachain.data_storage.SQLiteMetastore
+    if "." not in metastore_import_path:
+        raise RuntimeError(
+            f"Invalid {METASTORE_IMPORT_PATH} import path: {metastore_import_path}"
+        )
+    module_name, _, class_name = metastore_import_path.rpartition(".")
+    metastore = import_module(module_name)
+    metastore_class = getattr(metastore, class_name)
     return metastore_class(id_generator, **metastore_args)
 
 
-def get_warehouse(
+def get_warehouse(
+    id_generator: Optional["AbstractIDGenerator"], in_memory: bool = False
+) -> "AbstractWarehouse":
     warehouse_serialized = os.environ.get(WAREHOUSE_SERIALIZED)
     if warehouse_serialized:
         warehouse_obj = deserialize(warehouse_serialized)
@@ -112,20 +128,26 @@ def get_warehouse(id_generator: Optional["AbstractIDGenerator"]) -> "AbstractWar
     warehouse_import_path = os.environ.get(WAREHOUSE_IMPORT_PATH)
     warehouse_arg_envs = get_envs_by_prefix(WAREHOUSE_ARG_PREFIX)
     # Convert env variable names to keyword argument names by lowercasing them
-    warehouse_args
-    if
-    raise
+    warehouse_args: dict[str, Any] = {
+        k.lower(): v for k, v in warehouse_arg_envs.items()
+    }
+
+    if not warehouse_import_path:
+        if not isinstance(id_generator, SQLiteIDGenerator):
+            raise ValueError("SQLiteWarehouse can only be used with SQLiteIDGenerator")
+        warehouse_args["in_memory"] = in_memory
+        return SQLiteWarehouse(id_generator, **warehouse_args)
+    if in_memory:
+        raise RuntimeError(IN_MEMORY_ERROR_MESSAGE)
+    # Warehouse paths are specified as (for example):
+    # datachain.data_storage.SQLiteWarehouse
+    if "." not in warehouse_import_path:
+        raise RuntimeError(
+            f"Invalid {WAREHOUSE_IMPORT_PATH} import path: {warehouse_import_path}"
+        )
+    module_name, _, class_name = warehouse_import_path.rpartition(".")
+    warehouse = import_module(module_name)
+    warehouse_class = getattr(warehouse, class_name)
     return warehouse_class(id_generator, **warehouse_args)
 
 
@@ -152,7 +174,9 @@ def get_distributed_class(**kwargs):
     return distributed_class(**distributed_args | kwargs)
 
 
-def get_catalog(
+def get_catalog(
+    client_config: Optional[dict[str, Any]] = None, in_memory: bool = False
+) -> Catalog:
     """
     Function that creates Catalog instance with appropriate metastore
     and warehouse classes. Metastore class can be provided with env variable
@@ -164,10 +188,11 @@ def get_catalog(client_config: Optional[dict[str, Any]] = None) -> Catalog:
     and name of variable after, e.g. if it accepts team_id as kwargs
     we can provide DATACHAIN_METASTORE_ARG_TEAM_ID=12345 env variable.
     """
-    id_generator = get_id_generator()
+    id_generator = get_id_generator(in_memory=in_memory)
     return Catalog(
         id_generator=id_generator,
-        metastore=get_metastore(id_generator),
-        warehouse=get_warehouse(id_generator),
+        metastore=get_metastore(id_generator, in_memory=in_memory),
+        warehouse=get_warehouse(id_generator, in_memory=in_memory),
         client_config=client_config,
+        in_memory=in_memory,
    )

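With these changes a fully in-memory catalog can be requested in one call. A short sketch based on the signatures above (SQLite backends only; any custom import path combined with in_memory raises the RuntimeError defined at the top of the file):

from datachain.catalog.loader import get_catalog

# All three backends (ID generator, metastore, warehouse) are created as
# in-memory SQLite databases; nothing is written to disk.
catalog = get_catalog(in_memory=True)

# The flag is also stored on the Catalog itself (see catalog.py above).
assert catalog.in_memory is True
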
datachain/client/azure.py
CHANGED
@@ -3,6 +3,7 @@ from typing import Any
 from adlfs import AzureBlobFileSystem
 from tqdm import tqdm
 
+from datachain.lib.file import File
 from datachain.node import Entry
 
 from .fsspec import DELIMITER, Client, ResultQueue
@@ -24,6 +25,18 @@ class AzureClient(Client):
             size=v.get("size", ""),
         )
 
+    def info_to_file(self, v: dict[str, Any], path: str) -> File:
+        version_id = v.get("version_id")
+        return File(
+            source=self.uri,
+            path=path,
+            etag=v.get("etag", "").strip('"'),
+            version=version_id or "",
+            is_latest=version_id is None or bool(v.get("is_current_version")),
+            last_modified=v["last_modified"],
+            size=v.get("size", ""),
+        )
+
     async def _fetch_flat(self, start_prefix: str, result_queue: ResultQueue) -> None:
         prefix = start_prefix
         if prefix:

datachain/client/gcs.py
CHANGED
@@ -9,6 +9,7 @@ from dateutil.parser import isoparse
 from gcsfs import GCSFileSystem
 from tqdm import tqdm
 
+from datachain.lib.file import File
 from datachain.node import Entry
 
 from .fsspec import DELIMITER, Client, ResultQueue
@@ -120,3 +121,14 @@ class GCSClient(Client):
             last_modified=self.parse_timestamp(v["updated"]),
             size=v.get("size", ""),
         )
+
+    def info_to_file(self, v: dict[str, Any], path: str) -> File:
+        return File(
+            source=self.uri,
+            path=path,
+            etag=v.get("etag", ""),
+            version=v.get("generation", ""),
+            is_latest=not v.get("timeDeleted"),
+            last_modified=self.parse_timestamp(v["updated"]),
+            size=v.get("size", ""),
+        )

datachain/client/local.py
CHANGED
@@ -7,6 +7,7 @@ from urllib.parse import urlparse
 
 from fsspec.implementations.local import LocalFileSystem
 
+from datachain.lib.file import File
 from datachain.node import Entry
 from datachain.storage import StorageURI
 
@@ -144,6 +145,16 @@ class FileClient(Client):
             size=v.get("size", ""),
         )
 
+    def info_to_file(self, v: dict[str, Any], path: str) -> File:
+        return File(
+            source=self.uri,
+            path=path,
+            size=v.get("size", ""),
+            etag=v["mtime"].hex(),
+            is_latest=True,
+            last_modified=datetime.fromtimestamp(v["mtime"], timezone.utc),
+        )
+
     def fetch_nodes(
         self,
         nodes,

datachain/client/s3.py
CHANGED
@@ -5,6 +5,7 @@ from botocore.exceptions import NoCredentialsError
 from s3fs import S3FileSystem
 from tqdm import tqdm
 
+from datachain.lib.file import File
 from datachain.node import Entry
 
 from .fsspec import DELIMITER, Client, ResultQueue
@@ -167,3 +168,14 @@ class ClientS3(Client):
             owner_name=v.get("Owner", {}).get("DisplayName", ""),
             owner_id=v.get("Owner", {}).get("ID", ""),
         )
+
+    def info_to_file(self, v: dict[str, Any], path: str) -> File:
+        return File(
+            source=self.uri,
+            path=path,
+            size=v["size"],
+            version=ClientS3.clean_s3_version(v.get("VersionId", "")),
+            etag=v.get("ETag", "").strip('"'),
+            is_latest=v.get("IsLatest", True),
+            last_modified=v.get("LastModified", ""),
+        )

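Each client now maps a raw fsspec info dict to a datachain File object in addition to an Entry. A hedged sketch of the File that ClientS3.info_to_file() builds from one listing entry (the dict values are made up, and clean_s3_version() is skipped here for brevity):

from datetime import datetime, timezone

from datachain.lib.file import File

info = {  # illustrative values only
    "size": 1024,
    "ETag": '"d41d8cd98f00b204e9800998ecf8427e"',
    "IsLatest": True,
    "LastModified": datetime(2024, 8, 1, tzinfo=timezone.utc),
}
file = File(
    source="s3://example-bucket",  # what client.uri would be
    path="images/cat.jpg",
    size=info["size"],
    version=info.get("VersionId", ""),
    etag=info.get("ETag", "").strip('"'),
    is_latest=info.get("IsLatest", True),
    last_modified=info.get("LastModified", ""),
)
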
datachain/data_storage/schema.py
CHANGED
@@ -67,7 +67,11 @@ def convert_rows_custom_column_types(
     for row in rows:
         row_list = list(row)
         for idx, t in custom_columns_types:
-            row_list[idx] =
+            row_list[idx] = (
+                t.default_value(dialect)
+                if row_list[idx] is None
+                else t.on_read_convert(row_list[idx], dialect)
+            )
 
         yield tuple(row_list)
 
@@ -136,7 +140,15 @@ class DataTable:
         self.column_types: dict[str, SQLType] = column_types or {}
 
     @staticmethod
-    def copy_column(
+    def copy_column(
+        column: sa.Column,
+        primary_key: Optional[bool] = None,
+        index: Optional[bool] = None,
+        nullable: Optional[bool] = None,
+        default: Optional[Any] = None,
+        server_default: Optional[Any] = None,
+        unique: Optional[bool] = None,
+    ) -> sa.Column:
         """
         Copy a sqlalchemy Column object intended for use as a signal column.
 
@@ -150,12 +162,14 @@ class DataTable:
         return sa.Column(
             column.name,
             column.type,
-            primary_key=column.primary_key,
-            index=column.index,
-            nullable=column.nullable,
-            default=column.default,
-            server_default=
+            primary_key=primary_key if primary_key is not None else column.primary_key,
+            index=index if index is not None else column.index,
+            nullable=nullable if nullable is not None else column.nullable,
+            default=default if default is not None else column.default,
+            server_default=(
+                server_default if server_default is not None else column.server_default
+            ),
+            unique=unique if unique is not None else column.unique,
         )
 
     @classmethod

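copy_column() now accepts per-attribute overrides, falling back to the source column for anything left as None. A small sketch of the intended use:

import sqlalchemy as sa

from datachain.data_storage.schema import DataTable

src = sa.Column("score", sa.Float, nullable=False)

# Copy the column but relax nullability; every attribute not passed explicitly
# (primary_key, index, default, server_default, unique) is taken from `src`.
relaxed = DataTable.copy_column(src, nullable=True)

assert relaxed.name == "score"
assert relaxed.nullable is True
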
datachain/data_storage/sqlite.py
CHANGED
@@ -20,6 +20,8 @@ from sqlalchemy.dialects import sqlite
 from sqlalchemy.schema import CreateIndex, CreateTable, DropTable
 from sqlalchemy.sql import func
 from sqlalchemy.sql.expression import bindparam, cast
+from sqlalchemy.sql.selectable import Select
+from tqdm import tqdm
 
 import datachain.sql.sqlite
 from datachain.data_storage import AbstractDBMetastore, AbstractWarehouse
@@ -35,14 +37,13 @@ from datachain.sql.sqlite import create_user_defined_sql_functions, sqlite_diale
 from datachain.sql.sqlite.base import load_usearch_extension
 from datachain.sql.types import SQLType
 from datachain.storage import StorageURI
-from datachain.utils import DataChainDir
+from datachain.utils import DataChainDir, batched_it
 
 if TYPE_CHECKING:
     from sqlalchemy.dialects.sqlite import Insert
     from sqlalchemy.engine.base import Engine
     from sqlalchemy.schema import SchemaItem
-    from sqlalchemy.sql.elements import
-    from sqlalchemy.sql.selectable import Select
+    from sqlalchemy.sql.elements import ColumnElement
     from sqlalchemy.types import TypeEngine
 
 
@@ -54,8 +55,6 @@ RETRY_FACTOR = 2
 
 DETECT_TYPES = sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES
 
-Column = Union[str, "ColumnClause[Any]", "TextClause"]
-
 datachain.sql.sqlite.setup()
 
 quote_schema = sqlite_dialect.identifier_preparer.quote_schema
@@ -82,6 +81,17 @@ def retry_sqlite_locks(func):
     return wrapper
 
 
+def get_db_file_in_memory(
+    db_file: Optional[str] = None, in_memory: bool = False
+) -> Optional[str]:
+    """Get in-memory db_file and check that conflicting arguments are not provided."""
+    if in_memory:
+        if db_file and db_file != ":memory:":
+            raise RuntimeError("A db_file cannot be specified if in_memory is True")
+        db_file = ":memory:"
+    return db_file
+
+
 class SQLiteDatabaseEngine(DatabaseEngine):
     dialect = sqlite_dialect
 
@@ -122,6 +132,11 @@ class SQLiteDatabaseEngine(DatabaseEngine):
         engine = sqlalchemy.create_engine(
             "sqlite+pysqlite:///", creator=lambda: db, future=True
         )
+        # ensure we run SA on_connect init (e.g it registers regexp function),
+        # also makes sure that it's consistent. Otherwise in some cases it
+        # seems we are getting different results if engine object is used in a
+        # different thread first and enine is not used in the Main thread.
+        engine.connect().close()
 
         db.isolation_level = None  # Use autocommit mode
         db.execute("PRAGMA foreign_keys = ON")
@@ -260,7 +275,10 @@ class SQLiteIDGenerator(AbstractDBIDGenerator):
         table_prefix: Optional[str] = None,
         skip_db_init: bool = False,
         db_file: Optional[str] = None,
+        in_memory: bool = False,
     ):
+        db_file = get_db_file_in_memory(db_file, in_memory)
+
         db = db or SQLiteDatabaseEngine.from_db_file(db_file)
 
         super().__init__(db, table_prefix, skip_db_init)
@@ -378,6 +396,7 @@ class SQLiteMetastore(AbstractDBMetastore):
         partial_id: Optional[int] = None,
         db: Optional["SQLiteDatabaseEngine"] = None,
         db_file: Optional[str] = None,
+        in_memory: bool = False,
     ):
         self.schema: DefaultSchema = DefaultSchema()
         super().__init__(id_generator, uri, partial_id)
@@ -386,6 +405,8 @@ class SQLiteMetastore(AbstractDBMetastore):
         # foreign keys
         self.default_table_names: list[str] = []
 
+        db_file = get_db_file_in_memory(db_file, in_memory)
+
         self.db = db or SQLiteDatabaseEngine.from_db_file(db_file)
 
         self._init_tables()
@@ -550,10 +571,13 @@ class SQLiteWarehouse(AbstractWarehouse):
         id_generator: "SQLiteIDGenerator",
         db: Optional["SQLiteDatabaseEngine"] = None,
         db_file: Optional[str] = None,
+        in_memory: bool = False,
     ):
         self.schema: DefaultSchema = DefaultSchema()
         super().__init__(id_generator)
 
+        db_file = get_db_file_in_memory(db_file, in_memory)
+
         self.db = db or SQLiteDatabaseEngine.from_db_file(db_file)
 
     def __exit__(self, exc_type, exc_value, traceback) -> None:
@@ -626,9 +650,7 @@ class SQLiteWarehouse(AbstractWarehouse):
         self.db.create_table(table, if_not_exists=if_not_exists)
         return table
 
-    def dataset_rows_select(
-        self, select_query: sqlalchemy.sql.selectable.Select, **kwargs
-    ):
+    def dataset_rows_select(self, select_query: Select, **kwargs):
         rows = self.db.execute(select_query, **kwargs)
         yield from convert_rows_custom_column_types(
             select_query.selected_columns, rows, sqlite_dialect
@@ -746,6 +768,34 @@ class SQLiteWarehouse(AbstractWarehouse):
     ) -> list[str]:
         raise NotImplementedError("Exporting dataset table not implemented for SQLite")
 
+    def copy_table(
+        self,
+        table: Table,
+        query: Select,
+        progress_cb: Optional[Callable[[int], None]] = None,
+    ) -> None:
+        if "sys__id" in query.selected_columns:
+            col_id = query.selected_columns.sys__id
+        else:
+            col_id = sqlalchemy.column("sys__id")
+        select_ids = query.with_only_columns(col_id)
+
+        ids = self.db.execute(select_ids).fetchall()
+
+        select_q = query.with_only_columns(
+            *[c for c in query.selected_columns if c.name != "sys__id"]
+        )
+
+        for batch in batched_it(ids, 10_000):
+            batch_ids = [row[0] for row in batch]
+            select_q._where_criteria = (col_id.in_(batch_ids),)
+            q = table.insert().from_select(list(select_q.selected_columns), select_q)
+
+            self.db.execute(q)
+
+            if progress_cb:
+                progress_cb(len(batch_ids))
+
     def create_pre_udf_table(self, query: "Select") -> "Table":
         """
         Create a temporary table from a query for use in a UDF.
@@ -757,11 +807,7 @@ class SQLiteWarehouse(AbstractWarehouse):
         ]
         table = self.create_udf_table(columns)
 
-        )
-        self.db.execute(
-            table.insert().from_select(list(select_q.selected_columns), select_q)
-        )
+        with tqdm(desc="Preparing", unit=" rows") as pbar:
+            self.copy_table(table, query, progress_cb=pbar.update)
 
         return table

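The small helper added at the top of the module centralizes the in-memory handling that SQLiteIDGenerator, SQLiteMetastore and SQLiteWarehouse now share. A quick sketch of its contract:

from datachain.data_storage.sqlite import get_db_file_in_memory

assert get_db_file_in_memory(None, in_memory=True) == ":memory:"
assert get_db_file_in_memory("db.sqlite", in_memory=False) == "db.sqlite"

# Passing a real file path together with in_memory=True is rejected.
try:
    get_db_file_in_memory("db.sqlite", in_memory=True)
except RuntimeError as exc:
    print(exc)  # A db_file cannot be specified if in_memory is True
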
datachain/data_storage/warehouse.py
CHANGED
@@ -6,7 +6,7 @@ import random
 import string
 from abc import ABC, abstractmethod
 from collections.abc import Generator, Iterable, Iterator, Sequence
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 from urllib.parse import urlparse
 
 import attrs
@@ -14,6 +14,7 @@ import sqlalchemy as sa
 from sqlalchemy import Table, case, select
 from sqlalchemy.sql import func
 from sqlalchemy.sql.expression import true
+from tqdm import tqdm
 
 from datachain.client import Client
 from datachain.data_storage.serializer import Serializable
@@ -901,6 +902,17 @@ class AbstractWarehouse(ABC, Serializable):
         self.db.create_table(tbl, if_not_exists=True)
         return tbl
 
+    @abstractmethod
+    def copy_table(
+        self,
+        table: Table,
+        query: "Select",
+        progress_cb: Optional[Callable[[int], None]] = None,
+    ) -> None:
+        """
+        Copy the results of a query into a table.
+        """
+
     @abstractmethod
     def create_pre_udf_table(self, query: "Select") -> "Table":
         """
@@ -928,8 +940,10 @@ class AbstractWarehouse(ABC, Serializable):
         This should be implemented to ensure that the provided tables
         are cleaned up as soon as they are no longer needed.
         """
-
-
+        with tqdm(desc="Cleanup", unit=" tables") as pbar:
+            for name in names:
+                self.db.drop_table(Table(name, self.db.metadata), if_exists=True)
+                pbar.update(1)
 
     def changed_query(
         self,

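How a concrete warehouse is expected to drive the new abstract copy_table() is shown by SQLiteWarehouse.create_pre_udf_table() above. The same pattern, sketched as a standalone helper with hypothetical warehouse/table/query arguments:

from tqdm import tqdm

def copy_with_progress(warehouse, table, query) -> None:
    """Sketch: wrap the copy_table() contract with a tqdm progress bar,
    mirroring SQLiteWarehouse.create_pre_udf_table() above.
    `warehouse`, `table` and `query` are placeholders for any
    AbstractWarehouse subclass, a target Table and a Select."""
    with tqdm(desc="Preparing", unit=" rows") as pbar:
        warehouse.copy_table(table, query, progress_cb=pbar.update)
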
datachain/lib/arrow.py
CHANGED
@@ -122,7 +122,7 @@ def _arrow_type_mapper(col_type: pa.DataType) -> type:  # noqa: PLR0911
     if pa.types.is_string(col_type) or pa.types.is_large_string(col_type):
         return str
     if pa.types.is_list(col_type):
-        return list[_arrow_type_mapper(col_type.value_type)]  # type: ignore[misc]
+        return list[_arrow_type_mapper(col_type.value_type)]  # type: ignore[return-value, misc]
     if pa.types.is_struct(col_type) or pa.types.is_map(col_type):
         return dict
     if isinstance(col_type, pa.lib.DictionaryType):

datachain/lib/convert/values_to_tuples.py
CHANGED
@@ -1,7 +1,12 @@
 from collections.abc import Sequence
 from typing import Any, Union
 
-from datachain.lib.data_model import
+from datachain.lib.data_model import (
+    DataType,
+    DataTypeNames,
+    DataValuesType,
+    is_chain_type,
+)
 from datachain.lib.utils import DataChainParamsError
 
 
@@ -15,7 +20,7 @@ class ValuesToTupleError(DataChainParamsError):
 def values_to_tuples(  # noqa: C901, PLR0912
     ds_name: str = "",
     output: Union[None, DataType, Sequence[str], dict[str, DataType]] = None,
-    **fr_map,
+    **fr_map: Sequence[DataValuesType],
 ) -> tuple[Any, Any, Any]:
     if output:
         if not isinstance(output, (Sequence, str, dict)):
@@ -47,10 +52,10 @@ def values_to_tuples(  # noqa: C901, PLR0912
             f" number of signals '{len(fr_map)}'",
         )
 
-    types_map = {}
+    types_map: dict[str, type] = {}
     length = -1
     for k, v in fr_map.items():
-        if not isinstance(v, Sequence) or isinstance(v, str):
+        if not isinstance(v, Sequence) or isinstance(v, str):  # type: ignore[unreachable]
             raise ValuesToTupleError(ds_name, f"signals '{k}' is not a sequence")
         len_ = len(v)
@@ -64,15 +69,16 @@ def values_to_tuples(  # noqa: C901, PLR0912
         if len_ == 0:
             raise ValuesToTupleError(ds_name, f"signal '{k}' is empty list")
 
-
+        first_element = next(iter(v))
+        typ = type(first_element)
         if not is_chain_type(typ):
             raise ValuesToTupleError(
                 ds_name,
                 f"signal '{k}' has unsupported type '{typ.__name__}'."
                 f" Please use DataModel types: {DataTypeNames}",
             )
-        if
-        types_map[k] = list[type(
+        if isinstance(first_element, list):
+            types_map[k] = list[type(first_element[0])]  # type: ignore[assignment, misc]
         else:
             types_map[k] = typ
 
@@ -98,7 +104,7 @@ def values_to_tuples(  # noqa: C901, PLR0912
     if len(output) > 1:  # type: ignore[arg-type]
         tuple_type = tuple(output_types)
         res_type = tuple[tuple_type]  # type: ignore[valid-type]
-        res_values = list(zip(*fr_map.values()))
+        res_values: Sequence[Any] = list(zip(*fr_map.values()))
     else:
         res_type = output_types[0]  # type: ignore[misc]
         res_values = next(iter(fr_map.values()))

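A hedged sketch of calling the helper with the stricter typing above: every keyword argument is a sequence of signal values, all of equal length, and the result is the 3-tuple described by the return annotation (the example values are illustrative only):

from datachain.lib.convert.values_to_tuples import values_to_tuples

# Two signals, three rows each; unequal lengths or unsupported element types
# raise ValuesToTupleError (see the checks above).
result = values_to_tuples("demo", output=None, num=[1, 2, 3], name=["a", "b", "c"])
assert len(result) == 3  # per the `tuple[Any, Any, Any]` return annotation
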