datachain 0.3.2__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of datachain has been flagged as possibly problematic.
- datachain/asyn.py +20 -0
- datachain/catalog/catalog.py +2 -0
- datachain/catalog/loader.py +75 -50
- datachain/client/azure.py +13 -0
- datachain/client/gcs.py +12 -0
- datachain/client/local.py +11 -0
- datachain/client/s3.py +12 -0
- datachain/data_storage/sqlite.py +55 -14
- datachain/data_storage/warehouse.py +17 -3
- datachain/lib/arrow.py +1 -1
- datachain/lib/convert/values_to_tuples.py +14 -8
- datachain/lib/data_model.py +1 -0
- datachain/lib/dc.py +25 -6
- datachain/lib/listing.py +111 -0
- datachain/query/dataset.py +22 -12
- datachain/query/session.py +9 -2
- datachain/sql/sqlite/base.py +30 -4
- {datachain-0.3.2.dist-info → datachain-0.3.3.dist-info}/METADATA +2 -2
- {datachain-0.3.2.dist-info → datachain-0.3.3.dist-info}/RECORD +23 -22
- {datachain-0.3.2.dist-info → datachain-0.3.3.dist-info}/LICENSE +0 -0
- {datachain-0.3.2.dist-info → datachain-0.3.3.dist-info}/WHEEL +0 -0
- {datachain-0.3.2.dist-info → datachain-0.3.3.dist-info}/entry_points.txt +0 -0
- {datachain-0.3.2.dist-info → datachain-0.3.3.dist-info}/top_level.txt +0 -0
datachain/asyn.py
CHANGED
@@ -224,3 +224,23 @@ class OrderedMapper(AsyncMapper[InputT, ResultT]):
     async def _break_iteration(self) -> None:
         self.heap = []
         self._push_result(self._next_yield, None)
+
+
+def iter_over_async(ait, loop):
+    """Wrap an asynchronous iterator into a synchronous one"""
+    ait = ait.__aiter__()
+
+    # helper async fn that just gets the next element from the async iterator
+    async def get_next():
+        try:
+            obj = await ait.__anext__()
+            return False, obj
+        except StopAsyncIteration:
+            return True, None
+
+    # actual sync iterator
+    while True:
+        done, obj = asyncio.run_coroutine_threadsafe(get_next(), loop).result()
+        if done:
+            break
+        yield obj
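A hedged usage sketch for the new iter_over_async() helper: it drives an async iterator from synchronous code by submitting each __anext__() call to an event loop running in another thread (datachain/lib/listing.py uses it with fsspec's IO loop). The async generator below is an illustrative stand-in, not code from this release.

import asyncio

from fsspec.asyn import get_loop

from datachain.asyn import iter_over_async


async def numbers():
    # stand-in async iterator; any AsyncIterator works here
    for i in range(3):
        await asyncio.sleep(0)
        yield i


# get_loop() returns fsspec's dedicated IO loop, which runs in a background
# thread, so run_coroutine_threadsafe() inside iter_over_async() can reach it.
for item in iter_over_async(numbers(), get_loop()):
    print(item)  # 0, 1, 2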
datachain/catalog/catalog.py
CHANGED
@@ -577,6 +577,7 @@ class Catalog:
         warehouse_ready_callback: Optional[
             Callable[["AbstractWarehouse"], None]
         ] = None,
+        in_memory: bool = False,
     ):
         datachain_dir = DataChainDir(cache=cache_dir, tmp=tmp_dir)
         datachain_dir.init()
@@ -590,6 +591,7 @@ class Catalog:
             "tmp_dir": tmp_dir,
         }
         self._warehouse_ready_callback = warehouse_ready_callback
+        self.in_memory = in_memory
 
     @cached_property
     def warehouse(self) -> "AbstractWarehouse":
datachain/catalog/loader.py
CHANGED
@@ -28,8 +28,10 @@ WAREHOUSE_ARG_PREFIX = "DATACHAIN_WAREHOUSE_ARG_"
 DISTRIBUTED_IMPORT_PATH = "DATACHAIN_DISTRIBUTED"
 DISTRIBUTED_ARG_PREFIX = "DATACHAIN_DISTRIBUTED_ARG_"
 
+IN_MEMORY_ERROR_MESSAGE = "In-memory is only supported on SQLite"
 
-def get_id_generator() -> "AbstractIDGenerator":
+
+def get_id_generator(in_memory: bool = False) -> "AbstractIDGenerator":
     id_generator_serialized = os.environ.get(ID_GENERATOR_SERIALIZED)
     if id_generator_serialized:
         id_generator_obj = deserialize(id_generator_serialized)
@@ -43,25 +45,31 @@ def get_id_generator() -> "AbstractIDGenerator":
     id_generator_import_path = os.environ.get(ID_GENERATOR_IMPORT_PATH)
     id_generator_arg_envs = get_envs_by_prefix(ID_GENERATOR_ARG_PREFIX)
     # Convert env variable names to keyword argument names by lowercasing them
-    id_generator_args …
-    …
+    id_generator_args: dict[str, Any] = {
+        k.lower(): v for k, v in id_generator_arg_envs.items()
+    }
+
+    if not id_generator_import_path:
+        id_generator_args["in_memory"] = in_memory
+        return SQLiteIDGenerator(**id_generator_args)
+    if in_memory:
+        raise RuntimeError(IN_MEMORY_ERROR_MESSAGE)
+    # ID generator paths are specified as (for example):
+    # datachain.data_storage.SQLiteIDGenerator
+    if "." not in id_generator_import_path:
+        raise RuntimeError(
+            f"Invalid {ID_GENERATOR_IMPORT_PATH} import path:"
+            f"{id_generator_import_path}"
+        )
+    module_name, _, class_name = id_generator_import_path.rpartition(".")
+    id_generator = import_module(module_name)
+    id_generator_class = getattr(id_generator, class_name)
     return id_generator_class(**id_generator_args)
 
 
-def get_metastore(id_generator: Optional["AbstractIDGenerator"]) -> "AbstractMetastore":
+def get_metastore(
+    id_generator: Optional["AbstractIDGenerator"], in_memory: bool = False
+) -> "AbstractMetastore":
     metastore_serialized = os.environ.get(METASTORE_SERIALIZED)
     if metastore_serialized:
         metastore_obj = deserialize(metastore_serialized)
@@ -78,24 +86,32 @@ def get_metastore(id_generator: Optional["AbstractIDGenerator"]) -> "AbstractMetastore":
     metastore_import_path = os.environ.get(METASTORE_IMPORT_PATH)
     metastore_arg_envs = get_envs_by_prefix(METASTORE_ARG_PREFIX)
     # Convert env variable names to keyword argument names by lowercasing them
-    metastore_args …
-    …
-    if …
-        raise …
-    …
+    metastore_args: dict[str, Any] = {
+        k.lower(): v for k, v in metastore_arg_envs.items()
+    }
+
+    if not metastore_import_path:
+        if not isinstance(id_generator, SQLiteIDGenerator):
+            raise ValueError("SQLiteMetastore can only be used with SQLiteIDGenerator")
+        metastore_args["in_memory"] = in_memory
+        return SQLiteMetastore(id_generator, **metastore_args)
+    if in_memory:
+        raise RuntimeError(IN_MEMORY_ERROR_MESSAGE)
+    # Metastore paths are specified as (for example):
+    # datachain.data_storage.SQLiteMetastore
+    if "." not in metastore_import_path:
+        raise RuntimeError(
+            f"Invalid {METASTORE_IMPORT_PATH} import path: {metastore_import_path}"
+        )
+    module_name, _, class_name = metastore_import_path.rpartition(".")
+    metastore = import_module(module_name)
+    metastore_class = getattr(metastore, class_name)
     return metastore_class(id_generator, **metastore_args)
 
 
-def get_warehouse(id_generator: Optional["AbstractIDGenerator"]) -> "AbstractWarehouse":
+def get_warehouse(
+    id_generator: Optional["AbstractIDGenerator"], in_memory: bool = False
+) -> "AbstractWarehouse":
     warehouse_serialized = os.environ.get(WAREHOUSE_SERIALIZED)
     if warehouse_serialized:
         warehouse_obj = deserialize(warehouse_serialized)
@@ -112,20 +128,26 @@ def get_warehouse(id_generator: Optional["AbstractIDGenerator"]) -> "AbstractWarehouse":
     warehouse_import_path = os.environ.get(WAREHOUSE_IMPORT_PATH)
     warehouse_arg_envs = get_envs_by_prefix(WAREHOUSE_ARG_PREFIX)
     # Convert env variable names to keyword argument names by lowercasing them
-    warehouse_args …
-    …
-    if …
-        raise …
-    …
+    warehouse_args: dict[str, Any] = {
+        k.lower(): v for k, v in warehouse_arg_envs.items()
+    }
+
+    if not warehouse_import_path:
+        if not isinstance(id_generator, SQLiteIDGenerator):
+            raise ValueError("SQLiteWarehouse can only be used with SQLiteIDGenerator")
+        warehouse_args["in_memory"] = in_memory
+        return SQLiteWarehouse(id_generator, **warehouse_args)
+    if in_memory:
+        raise RuntimeError(IN_MEMORY_ERROR_MESSAGE)
+    # Warehouse paths are specified as (for example):
+    # datachain.data_storage.SQLiteWarehouse
+    if "." not in warehouse_import_path:
+        raise RuntimeError(
+            f"Invalid {WAREHOUSE_IMPORT_PATH} import path: {warehouse_import_path}"
+        )
+    module_name, _, class_name = warehouse_import_path.rpartition(".")
+    warehouse = import_module(module_name)
+    warehouse_class = getattr(warehouse, class_name)
     return warehouse_class(id_generator, **warehouse_args)
 
 
@@ -152,7 +174,9 @@ def get_distributed_class(**kwargs):
     return distributed_class(**distributed_args | kwargs)
 
 
-def get_catalog(client_config: Optional[dict[str, Any]] = None) -> Catalog:
+def get_catalog(
+    client_config: Optional[dict[str, Any]] = None, in_memory: bool = False
+) -> Catalog:
     """
     Function that creates Catalog instance with appropriate metastore
     and warehouse classes. Metastore class can be provided with env variable
@@ -164,10 +188,11 @@ def get_catalog(client_config: Optional[dict[str, Any]] = None) -> Catalog:
     and name of variable after, e.g. if it accepts team_id as kwargs
     we can provide DATACHAIN_METASTORE_ARG_TEAM_ID=12345 env variable.
     """
-    id_generator = get_id_generator()
+    id_generator = get_id_generator(in_memory=in_memory)
     return Catalog(
         id_generator=id_generator,
-        metastore=get_metastore(id_generator),
-        warehouse=get_warehouse(id_generator),
+        metastore=get_metastore(id_generator, in_memory=in_memory),
+        warehouse=get_warehouse(id_generator, in_memory=in_memory),
         client_config=client_config,
+        in_memory=in_memory,
     )
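A hedged sketch of how the new in_memory flag threads through the loader: with no custom metastore/warehouse import paths set, everything falls back to the SQLite implementations backed by ":memory:" databases; with a non-SQLite backend configured, the same flag raises the IN_MEMORY_ERROR_MESSAGE RuntimeError. Only signatures visible in this diff are used below.

from datachain.catalog.loader import get_catalog

# Entire catalog (id generator, metastore, warehouse) lives in ephemeral
# SQLite ":memory:" databases; nothing is persisted to the DataChain directory.
catalog = get_catalog(in_memory=True)

# With DATACHAIN_METASTORE / DATACHAIN_WAREHOUSE pointing at a non-SQLite
# backend, the same call raises RuntimeError("In-memory is only supported on SQLite").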
datachain/client/azure.py
CHANGED
@@ -3,6 +3,7 @@ from typing import Any
 from adlfs import AzureBlobFileSystem
 from tqdm import tqdm
 
+from datachain.lib.file import File
 from datachain.node import Entry
 
 from .fsspec import DELIMITER, Client, ResultQueue
@@ -24,6 +25,18 @@ class AzureClient(Client):
             size=v.get("size", ""),
         )
 
+    def info_to_file(self, v: dict[str, Any], path: str) -> File:
+        version_id = v.get("version_id")
+        return File(
+            source=self.uri,
+            path=path,
+            etag=v.get("etag", "").strip('"'),
+            version=version_id or "",
+            is_latest=version_id is None or bool(v.get("is_current_version")),
+            last_modified=v["last_modified"],
+            size=v.get("size", ""),
+        )
+
     async def _fetch_flat(self, start_prefix: str, result_queue: ResultQueue) -> None:
         prefix = start_prefix
         if prefix:
datachain/client/gcs.py
CHANGED
@@ -9,6 +9,7 @@ from dateutil.parser import isoparse
 from gcsfs import GCSFileSystem
 from tqdm import tqdm
 
+from datachain.lib.file import File
 from datachain.node import Entry
 
 from .fsspec import DELIMITER, Client, ResultQueue
@@ -120,3 +121,14 @@ class GCSClient(Client):
             last_modified=self.parse_timestamp(v["updated"]),
             size=v.get("size", ""),
         )
+
+    def info_to_file(self, v: dict[str, Any], path: str) -> File:
+        return File(
+            source=self.uri,
+            path=path,
+            etag=v.get("etag", ""),
+            version=v.get("generation", ""),
+            is_latest=not v.get("timeDeleted"),
+            last_modified=self.parse_timestamp(v["updated"]),
+            size=v.get("size", ""),
+        )
datachain/client/local.py
CHANGED
@@ -7,6 +7,7 @@ from urllib.parse import urlparse
 
 from fsspec.implementations.local import LocalFileSystem
 
+from datachain.lib.file import File
 from datachain.node import Entry
 from datachain.storage import StorageURI
 
@@ -144,6 +145,16 @@ class FileClient(Client):
             size=v.get("size", ""),
         )
 
+    def info_to_file(self, v: dict[str, Any], path: str) -> File:
+        return File(
+            source=self.uri,
+            path=path,
+            size=v.get("size", ""),
+            etag=v["mtime"].hex(),
+            is_latest=True,
+            last_modified=datetime.fromtimestamp(v["mtime"], timezone.utc),
+        )
+
     def fetch_nodes(
         self,
         nodes,
datachain/client/s3.py
CHANGED
@@ -5,6 +5,7 @@ from botocore.exceptions import NoCredentialsError
 from s3fs import S3FileSystem
 from tqdm import tqdm
 
+from datachain.lib.file import File
 from datachain.node import Entry
 
 from .fsspec import DELIMITER, Client, ResultQueue
@@ -167,3 +168,14 @@ class ClientS3(Client):
             owner_name=v.get("Owner", {}).get("DisplayName", ""),
             owner_id=v.get("Owner", {}).get("ID", ""),
         )
+
+    def info_to_file(self, v: dict[str, Any], path: str) -> File:
+        return File(
+            source=self.uri,
+            path=path,
+            size=v["size"],
+            version=ClientS3.clean_s3_version(v.get("VersionId", "")),
+            etag=v.get("ETag", "").strip('"'),
+            is_latest=v.get("IsLatest", True),
+            last_modified=v.get("LastModified", ""),
+        )
datachain/data_storage/sqlite.py
CHANGED
@@ -20,6 +20,8 @@ from sqlalchemy.dialects import sqlite
 from sqlalchemy.schema import CreateIndex, CreateTable, DropTable
 from sqlalchemy.sql import func
 from sqlalchemy.sql.expression import bindparam, cast
+from sqlalchemy.sql.selectable import Select
+from tqdm import tqdm
 
 import datachain.sql.sqlite
 from datachain.data_storage import AbstractDBMetastore, AbstractWarehouse
@@ -35,14 +37,13 @@ from datachain.sql.sqlite import create_user_defined_sql_functions, sqlite_dialect
 from datachain.sql.sqlite.base import load_usearch_extension
 from datachain.sql.types import SQLType
 from datachain.storage import StorageURI
-from datachain.utils import DataChainDir
+from datachain.utils import DataChainDir, batched_it
 
 if TYPE_CHECKING:
     from sqlalchemy.dialects.sqlite import Insert
     from sqlalchemy.engine.base import Engine
     from sqlalchemy.schema import SchemaItem
-    from sqlalchemy.sql.elements import …
-    from sqlalchemy.sql.selectable import Select
+    from sqlalchemy.sql.elements import ColumnElement
     from sqlalchemy.types import TypeEngine
 
 
@@ -54,8 +55,6 @@ RETRY_FACTOR = 2
 
 DETECT_TYPES = sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES
 
-Column = Union[str, "ColumnClause[Any]", "TextClause"]
-
 datachain.sql.sqlite.setup()
 
 quote_schema = sqlite_dialect.identifier_preparer.quote_schema
@@ -82,6 +81,17 @@ def retry_sqlite_locks(func):
     return wrapper
 
 
+def get_db_file_in_memory(
+    db_file: Optional[str] = None, in_memory: bool = False
+) -> Optional[str]:
+    """Get in-memory db_file and check that conflicting arguments are not provided."""
+    if in_memory:
+        if db_file and db_file != ":memory:":
+            raise RuntimeError("A db_file cannot be specified if in_memory is True")
+        db_file = ":memory:"
+    return db_file
+
+
 class SQLiteDatabaseEngine(DatabaseEngine):
     dialect = sqlite_dialect
 
@@ -265,7 +275,10 @@ class SQLiteIDGenerator(AbstractDBIDGenerator):
         table_prefix: Optional[str] = None,
         skip_db_init: bool = False,
         db_file: Optional[str] = None,
+        in_memory: bool = False,
     ):
+        db_file = get_db_file_in_memory(db_file, in_memory)
+
         db = db or SQLiteDatabaseEngine.from_db_file(db_file)
 
         super().__init__(db, table_prefix, skip_db_init)
@@ -383,6 +396,7 @@ class SQLiteMetastore(AbstractDBMetastore):
         partial_id: Optional[int] = None,
         db: Optional["SQLiteDatabaseEngine"] = None,
         db_file: Optional[str] = None,
+        in_memory: bool = False,
     ):
         self.schema: DefaultSchema = DefaultSchema()
         super().__init__(id_generator, uri, partial_id)
@@ -391,6 +405,8 @@ class SQLiteMetastore(AbstractDBMetastore):
         # foreign keys
         self.default_table_names: list[str] = []
 
+        db_file = get_db_file_in_memory(db_file, in_memory)
+
         self.db = db or SQLiteDatabaseEngine.from_db_file(db_file)
 
         self._init_tables()
@@ -555,10 +571,13 @@ class SQLiteWarehouse(AbstractWarehouse):
         id_generator: "SQLiteIDGenerator",
         db: Optional["SQLiteDatabaseEngine"] = None,
         db_file: Optional[str] = None,
+        in_memory: bool = False,
     ):
         self.schema: DefaultSchema = DefaultSchema()
         super().__init__(id_generator)
 
+        db_file = get_db_file_in_memory(db_file, in_memory)
+
         self.db = db or SQLiteDatabaseEngine.from_db_file(db_file)
 
     def __exit__(self, exc_type, exc_value, traceback) -> None:
@@ -631,9 +650,7 @@ class SQLiteWarehouse(AbstractWarehouse):
         self.db.create_table(table, if_not_exists=if_not_exists)
         return table
 
-    def dataset_rows_select(
-        self, select_query: sqlalchemy.sql.selectable.Select, **kwargs
-    ):
+    def dataset_rows_select(self, select_query: Select, **kwargs):
         rows = self.db.execute(select_query, **kwargs)
         yield from convert_rows_custom_column_types(
             select_query.selected_columns, rows, sqlite_dialect
@@ -751,6 +768,34 @@ class SQLiteWarehouse(AbstractWarehouse):
     ) -> list[str]:
         raise NotImplementedError("Exporting dataset table not implemented for SQLite")
 
+    def copy_table(
+        self,
+        table: Table,
+        query: Select,
+        progress_cb: Optional[Callable[[int], None]] = None,
+    ) -> None:
+        if "sys__id" in query.selected_columns:
+            col_id = query.selected_columns.sys__id
+        else:
+            col_id = sqlalchemy.column("sys__id")
+        select_ids = query.with_only_columns(col_id)
+
+        ids = self.db.execute(select_ids).fetchall()
+
+        select_q = query.with_only_columns(
+            *[c for c in query.selected_columns if c.name != "sys__id"]
+        )
+
+        for batch in batched_it(ids, 10_000):
+            batch_ids = [row[0] for row in batch]
+            select_q._where_criteria = (col_id.in_(batch_ids),)
+            q = table.insert().from_select(list(select_q.selected_columns), select_q)
+
+            self.db.execute(q)
+
+            if progress_cb:
+                progress_cb(len(batch_ids))
+
     def create_pre_udf_table(self, query: "Select") -> "Table":
         """
         Create a temporary table from a query for use in a UDF.
@@ -762,11 +807,7 @@ class SQLiteWarehouse(AbstractWarehouse):
         ]
         table = self.create_udf_table(columns)
 
-        …
-        )
-        self.db.execute(
-            table.insert().from_select(list(select_q.selected_columns), select_q)
-        )
+        with tqdm(desc="Preparing", unit=" rows") as pbar:
+            self.copy_table(table, query, progress_cb=pbar.update)
 
         return table
datachain/data_storage/warehouse.py
CHANGED
@@ -6,7 +6,7 @@ import random
 import string
 from abc import ABC, abstractmethod
 from collections.abc import Generator, Iterable, Iterator, Sequence
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 from urllib.parse import urlparse
 
 import attrs
@@ -14,6 +14,7 @@ import sqlalchemy as sa
 from sqlalchemy import Table, case, select
 from sqlalchemy.sql import func
 from sqlalchemy.sql.expression import true
+from tqdm import tqdm
 
 from datachain.client import Client
 from datachain.data_storage.serializer import Serializable
@@ -901,6 +902,17 @@ class AbstractWarehouse(ABC, Serializable):
         self.db.create_table(tbl, if_not_exists=True)
         return tbl
 
+    @abstractmethod
+    def copy_table(
+        self,
+        table: Table,
+        query: "Select",
+        progress_cb: Optional[Callable[[int], None]] = None,
+    ) -> None:
+        """
+        Copy the results of a query into a table.
+        """
+
     @abstractmethod
     def create_pre_udf_table(self, query: "Select") -> "Table":
         """
@@ -928,8 +940,10 @@ class AbstractWarehouse(ABC, Serializable):
         This should be implemented to ensure that the provided tables
         are cleaned up as soon as they are no longer needed.
         """
-        …
+        with tqdm(desc="Cleanup", unit=" tables") as pbar:
+            for name in names:
+                self.db.drop_table(Table(name, self.db.metadata), if_exists=True)
+                pbar.update(1)
 
     def changed_query(
         self,
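A hedged illustration of the batching strategy behind the new SQLiteWarehouse.copy_table() above: the source query's sys__id values are fetched first and rows are then copied in chunks of 10,000 ids, reporting progress after each chunk. The batched_it function below is a stand-in that mimics datachain.utils.batched_it, which this diff imports but does not show.

from itertools import islice


def batched_it(iterable, n):
    # stand-in for datachain.utils.batched_it: yield lists of up to n items
    it = iter(iterable)
    while batch := list(islice(it, n)):
        yield batch


ids = [(i,) for i in range(25_000)]  # rows of (sys__id,) from the id-only select
copied = 0
for batch in batched_it(ids, 10_000):
    batch_ids = [row[0] for row in batch]
    # each batch becomes one INSERT ... FROM SELECT ... WHERE sys__id IN (batch_ids)
    copied += len(batch_ids)  # the amount passed to progress_cb (tqdm.update)
print(copied)  # 25000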
datachain/lib/arrow.py
CHANGED
@@ -122,7 +122,7 @@ def _arrow_type_mapper(col_type: pa.DataType) -> type:  # noqa: PLR0911
     if pa.types.is_string(col_type) or pa.types.is_large_string(col_type):
         return str
     if pa.types.is_list(col_type):
-        return list[_arrow_type_mapper(col_type.value_type)]  # type: ignore[misc]
+        return list[_arrow_type_mapper(col_type.value_type)]  # type: ignore[return-value, misc]
     if pa.types.is_struct(col_type) or pa.types.is_map(col_type):
         return dict
     if isinstance(col_type, pa.lib.DictionaryType):
datachain/lib/convert/values_to_tuples.py
CHANGED
@@ -1,7 +1,12 @@
 from collections.abc import Sequence
 from typing import Any, Union
 
-from datachain.lib.data_model import …
+from datachain.lib.data_model import (
+    DataType,
+    DataTypeNames,
+    DataValuesType,
+    is_chain_type,
+)
 from datachain.lib.utils import DataChainParamsError
 
 
@@ -15,7 +20,7 @@ class ValuesToTupleError(DataChainParamsError):
 def values_to_tuples(  # noqa: C901, PLR0912
     ds_name: str = "",
     output: Union[None, DataType, Sequence[str], dict[str, DataType]] = None,
-    **fr_map,
+    **fr_map: Sequence[DataValuesType],
 ) -> tuple[Any, Any, Any]:
     if output:
         if not isinstance(output, (Sequence, str, dict)):
@@ -47,10 +52,10 @@ def values_to_tuples(  # noqa: C901, PLR0912
             f" number of signals '{len(fr_map)}'",
         )
 
-    types_map = {}
+    types_map: dict[str, type] = {}
     length = -1
     for k, v in fr_map.items():
-        if not isinstance(v, Sequence) or isinstance(v, str):
+        if not isinstance(v, Sequence) or isinstance(v, str):  # type: ignore[unreachable]
             raise ValuesToTupleError(ds_name, f"signals '{k}' is not a sequence")
         len_ = len(v)
 
@@ -64,15 +69,16 @@ def values_to_tuples(  # noqa: C901, PLR0912
         if len_ == 0:
             raise ValuesToTupleError(ds_name, f"signal '{k}' is empty list")
 
-        …
+        first_element = next(iter(v))
+        typ = type(first_element)
         if not is_chain_type(typ):
             raise ValuesToTupleError(
                 ds_name,
                 f"signal '{k}' has unsupported type '{typ.__name__}'."
                 f" Please use DataModel types: {DataTypeNames}",
             )
-        if …
-            types_map[k] = list[type(…
+        if isinstance(first_element, list):
+            types_map[k] = list[type(first_element[0])]  # type: ignore[assignment, misc]
         else:
            types_map[k] = typ
 
@@ -98,7 +104,7 @@ def values_to_tuples(  # noqa: C901, PLR0912
     if len(output) > 1:  # type: ignore[arg-type]
         tuple_type = tuple(output_types)
         res_type = tuple[tuple_type]  # type: ignore[valid-type]
-        res_values = list(zip(*fr_map.values()))
+        res_values: Sequence[Any] = list(zip(*fr_map.values()))
     else:
         res_type = output_types[0]  # type: ignore[misc]
         res_values = next(iter(fr_map.values()))
datachain/lib/data_model.py
CHANGED
@@ -18,6 +18,7 @@ StandardType = Union[
 ]
 DataType = Union[type[BaseModel], StandardType]
 DataTypeNames = "BaseModel, int, str, float, bool, list, dict, bytes, datetime"
+DataValuesType = Union[BaseModel, int, str, float, bool, list, dict, bytes, datetime]
 
 
 class DataModel(BaseModel):
datachain/lib/dc.py
CHANGED
@@ -309,6 +309,7 @@ class DataChain(DatasetQuery):
         *,
         type: Literal["binary", "text", "image"] = "binary",
         session: Optional[Session] = None,
+        in_memory: bool = False,
         recursive: Optional[bool] = True,
         object_name: str = "file",
         update: bool = False,
@@ -332,7 +333,14 @@ class DataChain(DatasetQuery):
         """
         func = get_file(type)
         return (
-            cls(…
+            cls(
+                path,
+                session=session,
+                recursive=recursive,
+                update=update,
+                in_memory=in_memory,
+                **kwargs,
+            )
             .map(**{object_name: func})
             .select(object_name)
         )
@@ -479,7 +487,10 @@ class DataChain(DatasetQuery):
 
     @classmethod
     def datasets(
-        cls, …
+        cls,
+        session: Optional[Session] = None,
+        in_memory: bool = False,
+        object_name: str = "dataset",
     ) -> "DataChain":
         """Generate chain with list of registered datasets.
 
@@ -492,7 +503,7 @@ class DataChain(DatasetQuery):
            print(f"{ds.name}@v{ds.version}")
        ```
        """
-        session = Session.get(session)
+        session = Session.get(session, in_memory=in_memory)
        catalog = session.catalog
 
        datasets = [
@@ -502,6 +513,7 @@ class DataChain(DatasetQuery):
 
        return cls.from_values(
            session=session,
+            in_memory=in_memory,
            output={object_name: DatasetInfo},
            **{object_name: datasets},  # type: ignore[arg-type]
        )
@@ -1142,6 +1154,7 @@ class DataChain(DatasetQuery):
        cls,
        ds_name: str = "",
        session: Optional[Session] = None,
+        in_memory: bool = False,
        output: OutputType = None,
        object_name: str = "",
        **fr_map,
@@ -1158,7 +1171,9 @@ class DataChain(DatasetQuery):
        def _func_fr() -> Iterator[tuple_type]:  # type: ignore[valid-type]
            yield from tuples
 
-        chain = DataChain.from_records(…
+        chain = DataChain.from_records(
+            DataChain.DEFAULT_FILE_RECORD, session=session, in_memory=in_memory
+        )
        if object_name:
            output = {object_name: DataChain._dict_to_data_model(object_name, output)}  # type: ignore[arg-type]
        return chain.gen(_func_fr, output=output)
@@ -1169,6 +1184,7 @@ class DataChain(DatasetQuery):
        df: "pd.DataFrame",
        name: str = "",
        session: Optional[Session] = None,
+        in_memory: bool = False,
        object_name: str = "",
    ) -> "DataChain":
        """Generate chain from pandas data-frame.
@@ -1196,7 +1212,9 @@ class DataChain(DatasetQuery):
                f"import from pandas error - '{column}' cannot be a column name",
            )
 
-        return cls.from_values(…
+        return cls.from_values(
+            name, session, object_name=object_name, in_memory=in_memory, **fr_map
+        )
 
    def to_pandas(self, flatten=False) -> "pd.DataFrame":
        """Return a pandas DataFrame from the chain.
@@ -1505,6 +1523,7 @@ class DataChain(DatasetQuery):
        cls,
        to_insert: Optional[Union[dict, list[dict]]],
        session: Optional[Session] = None,
+        in_memory: bool = False,
    ) -> "DataChain":
        """Create a DataChain from the provided records. This method can be used for
        programmatically generating a chain in contrast of reading data from storages
@@ -1520,7 +1539,7 @@ class DataChain(DatasetQuery):
        single_record = DataChain.from_records(DataChain.DEFAULT_FILE_RECORD)
        ```
        """
-        session = Session.get(session)
+        session = Session.get(session, in_memory=in_memory)
        catalog = session.catalog
 
        name = session.generate_temp_dataset_name()
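A hedged usage sketch of the in_memory flag now accepted by the DataChain constructors in this file: it is simply forwarded to Session.get(), so the temporary datasets behind from_values()/from_pandas() live in an in-memory SQLite catalog. Values below are illustrative.

from datachain.lib.dc import DataChain

chain = DataChain.from_values(
    fib=[1, 1, 2, 3, 5, 8],
    in_memory=True,  # forwarded to Session.get(..., in_memory=True)
)
print(chain.to_pandas())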
datachain/lib/listing.py
ADDED
@@ -0,0 +1,111 @@
+import asyncio
+from collections.abc import AsyncIterator, Iterator, Sequence
+from typing import Callable, Optional
+
+from botocore.exceptions import ClientError
+from fsspec.asyn import get_loop
+
+from datachain.asyn import iter_over_async
+from datachain.client import Client
+from datachain.error import ClientError as DataChainClientError
+from datachain.lib.file import File
+
+ResultQueue = asyncio.Queue[Optional[Sequence[File]]]
+
+DELIMITER = "/"  # Path delimiter
+FETCH_WORKERS = 100
+
+
+async def _fetch_dir(client, prefix, result_queue) -> set[str]:
+    path = f"{client.name}/{prefix}"
+    infos = await client.ls_dir(path)
+    files = []
+    subdirs = set()
+    for info in infos:
+        full_path = info["name"]
+        subprefix = client.rel_path(full_path)
+        if prefix.strip(DELIMITER) == subprefix.strip(DELIMITER):
+            continue
+        if info["type"] == "directory":
+            subdirs.add(subprefix)
+        else:
+            files.append(client.info_to_file(info, subprefix))
+    if files:
+        await result_queue.put(files)
+    return subdirs
+
+
+async def _fetch(
+    client, start_prefix: str, result_queue: ResultQueue, fetch_workers
+) -> None:
+    loop = get_loop()
+
+    queue: asyncio.Queue[str] = asyncio.Queue()
+    queue.put_nowait(start_prefix)
+
+    async def process(queue) -> None:
+        while True:
+            prefix = await queue.get()
+            try:
+                subdirs = await _fetch_dir(client, prefix, result_queue)
+                for subdir in subdirs:
+                    queue.put_nowait(subdir)
+            except Exception:
+                while not queue.empty():
+                    queue.get_nowait()
+                    queue.task_done()
+                raise
+
+            finally:
+                queue.task_done()
+
+    try:
+        workers: list[asyncio.Task] = [
+            loop.create_task(process(queue)) for _ in range(fetch_workers)
+        ]
+
+        # Wait for all fetch tasks to complete
+        await queue.join()
+        # Stop the workers
+        excs = []
+        for worker in workers:
+            if worker.done() and (exc := worker.exception()):
+                excs.append(exc)
+            else:
+                worker.cancel()
+        if excs:
+            raise excs[0]
+    except ClientError as exc:
+        raise DataChainClientError(
+            exc.response.get("Error", {}).get("Message") or exc,
+            exc.response.get("Error", {}).get("Code"),
+        ) from exc
+    finally:
+        # This ensures the progress bar is closed before any exceptions are raised
+        result_queue.put_nowait(None)
+
+
+async def _scandir(client, prefix, fetch_workers) -> AsyncIterator:
+    """Recursively goes through dir tree and yields files"""
+    result_queue: ResultQueue = asyncio.Queue()
+    loop = get_loop()
+    main_task = loop.create_task(_fetch(client, prefix, result_queue, fetch_workers))
+    while (files := await result_queue.get()) is not None:
+        for f in files:
+            yield f
+
+    await main_task
+
+
+def list_bucket(uri: str, client_config=None, fetch_workers=FETCH_WORKERS) -> Callable:
+    """
+    Function that returns another generator function that yields File objects
+    from bucket where each File represents one bucket entry.
+    """
+
+    def list_func() -> Iterator[File]:
+        config = client_config or {}
+        client, path = Client.parse_url(uri, None, **config)  # type: ignore[arg-type]
+        yield from iter_over_async(_scandir(client, path, fetch_workers), get_loop())
+
+    return list_func
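A hedged usage sketch for the new list_bucket() helper: it returns a zero-argument generator function that lazily yields File objects for every entry under the given URI. The bucket name and client_config values here are made up.

from datachain.lib.listing import list_bucket

list_files = list_bucket("s3://example-bucket/images/", client_config={"anon": True})

for file in list_files():  # listing happens lazily via the async fetch workers
    print(file.path, file.size)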
datachain/query/dataset.py
CHANGED
@@ -34,6 +34,7 @@ from sqlalchemy.sql.elements import ColumnClause, ColumnElement
 from sqlalchemy.sql.expression import label
 from sqlalchemy.sql.schema import TableClause
 from sqlalchemy.sql.selectable import Select
+from tqdm import tqdm
 
 from datachain.asyn import ASYNC_WORKERS, AsyncMapper, OrderedMapper
 from datachain.catalog import (
@@ -125,7 +126,10 @@ class QueryGenerator:
     func: QueryGeneratorFunc
     columns: tuple[ColumnElement, ...]
 
-    def exclude(self, column_names: Sequence[str]) -> Select:
+    def only(self, column_names: Sequence[str]) -> Select:
+        return self.func(*(c for c in self.columns if c.name in column_names))
+
+    def exclude(self, column_names: Sequence[str]) -> Select:
         return self.func(*(c for c in self.columns if c.name not in column_names))
 
     def select(self, column_names=None) -> Select:
@@ -465,6 +469,12 @@ class UDFStep(Step, ABC):
 
         try:
             if workers:
+                if self.catalog.in_memory:
+                    raise RuntimeError(
+                        "In-memory databases cannot be used with "
+                        "distributed processing."
+                    )
+
                 from datachain.catalog.loader import get_distributed_class
 
                 distributor = get_distributed_class(min_task_size=self.min_task_size)
@@ -482,6 +492,10 @@ class UDFStep(Step, ABC):
                 )
             elif processes:
                 # Parallel processing (faster for more CPU-heavy UDFs)
+                if self.catalog.in_memory:
+                    raise RuntimeError(
+                        "In-memory databases cannot be used with parallel processing."
+                    )
                 udf_info = {
                     "udf_data": filtered_cloudpickle_dumps(self.udf),
                     "catalog_init": self.catalog.get_init_params(),
@@ -1049,6 +1063,7 @@ class DatasetQuery:
         indexing_feature_schema: Optional[dict] = None,
         indexing_column_types: Optional[dict[str, Any]] = None,
         update: Optional[bool] = False,
+        in_memory: bool = False,
     ):
         if client_config is None:
             client_config = {}
@@ -1057,7 +1072,7 @@ class DatasetQuery:
             client_config["anon"] = True
 
         self.session = Session.get(
-            session, catalog=catalog, client_config=client_config
+            session, catalog=catalog, client_config=client_config, in_memory=in_memory
         )
         self.catalog = catalog or self.session.catalog
         self.steps: list[Step] = []
@@ -1648,18 +1663,13 @@ class DatasetQuery:
 
         dr = self.catalog.warehouse.dataset_rows(dataset)
 
-        …
-        q = q.add_columns(
-            f.row_number().over(order_by=q._order_by_clauses).label("sys__id")
+        with tqdm(desc="Saving", unit=" rows") as pbar:
+            self.catalog.warehouse.copy_table(
+                dr.get_table(),
+                query.select(),
+                progress_cb=pbar.update,
             )
 
-        cols = tuple(c.name for c in q.selected_columns)
-        insert_q = sqlalchemy.insert(dr.get_table()).from_select(cols, q)
-        self.catalog.warehouse.db.execute(insert_q, **kwargs)
         self.catalog.metastore.update_dataset_status(
             dataset, DatasetStatus.COMPLETE, version=version
         )
datachain/query/session.py
CHANGED
@@ -46,6 +46,7 @@ class Session:
         name="",
         catalog: Optional["Catalog"] = None,
         client_config: Optional[dict] = None,
+        in_memory: bool = False,
     ):
         if re.match(r"^[0-9a-zA-Z]+$", name) is None:
             raise ValueError(
@@ -58,7 +59,9 @@ class Session:
         session_uuid = uuid4().hex[: self.SESSION_UUID_LEN]
         self.name = f"{name}_{session_uuid}"
         self.is_new_catalog = not catalog
-        self.catalog = catalog or get_catalog(…
+        self.catalog = catalog or get_catalog(
+            client_config=client_config, in_memory=in_memory
+        )
 
     def __enter__(self):
         return self
@@ -89,6 +92,7 @@ class Session:
         session: Optional["Session"] = None,
         catalog: Optional["Catalog"] = None,
         client_config: Optional[dict] = None,
+        in_memory: bool = False,
     ) -> "Session":
         """Creates a Session() object from a catalog.
 
@@ -102,7 +106,10 @@ class Session:
 
         if cls.GLOBAL_SESSION is None:
             cls.GLOBAL_SESSION_CTX = Session(
-                cls.GLOBAL_SESSION_NAME, …
+                cls.GLOBAL_SESSION_NAME,
+                catalog,
+                client_config=client_config,
+                in_memory=in_memory,
             )
             cls.GLOBAL_SESSION = cls.GLOBAL_SESSION_CTX.__enter__()
             atexit.register(cls._global_cleanup)
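A hedged sketch of the Session-level entry point for the same flag: a Session created with in_memory=True builds its catalog via get_catalog(..., in_memory=True), so everything made inside the context manager is ephemeral. The session name is an arbitrary example (it must match ^[0-9a-zA-Z]+$).

from datachain.lib.dc import DataChain
from datachain.query.session import Session

with Session("scratch", in_memory=True) as session:
    chain = DataChain.from_values(letter=["a", "b", "c"], session=session)
    df = chain.to_pandas()  # data lives only in the in-memory SQLite catalog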
datachain/sql/sqlite/base.py
CHANGED
@@ -221,19 +221,45 @@ def path_name(path):
     return func.ltrim(func.substr(path, func.length(path_parent(path)) + 1), slash)
 
 
-def path_file_ext_length(path):
-    name = path_name(path)
+def name_file_ext_length(name):
     expr = func.length(name) - func.length(
         func.rtrim(name, func.replace(name, dot, empty_str))
     )
     return case((func.instr(name, dot) == 0, 0), else_=expr)
 
 
+def path_file_ext_length(path):
+    name = path_name(path)
+    return name_file_ext_length(name)
+
+
 def path_file_stem(path):
-    …
+    path_length = func.length(path)
+    parent_length = func.length(path_parent(path))
+
+    name_expr = func.rtrim(
+        func.substr(
+            path,
+            1,
+            path_length - name_file_ext_length(path),
+        ),
+        dot,
+    )
+
+    full_path_expr = func.ltrim(
+        func.rtrim(
+            func.substr(
+                path,
+                parent_length + 1,
+                path_length - parent_length - path_file_ext_length(path),
+            ),
+            dot,
+        ),
+        slash,
     )
 
+    return case((func.instr(path, slash) == 0, name_expr), else_=full_path_expr)
+
 
 def path_file_ext(path):
     return func.substr(path, func.length(path) - path_file_ext_length(path) + 1)
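A hedged, plain-Python illustration of what the rewritten SQL path helpers above are intended to compute; these are ordinary functions written for clarity, not the SQLAlchemy expressions from the diff, and edge cases such as repeated or trailing dots may behave differently.

def path_file_stem(path: str) -> str:
    # file name without its last extension: "data/img/cat.jpg" -> "cat"
    name = path.rsplit("/", 1)[-1]
    return name.rsplit(".", 1)[0] if "." in name else name


def path_file_ext(path: str) -> str:
    # last extension without the dot: "data/img/cat.jpg" -> "jpg", "README" -> ""
    name = path.rsplit("/", 1)[-1]
    return name.rsplit(".", 1)[1] if "." in name else ""


assert path_file_stem("data/img/cat.jpg") == "cat"
assert path_file_ext("archive.tar.gz") == "gz"
assert path_file_stem("README") == "README"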
{datachain-0.3.2.dist-info → datachain-0.3.3.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: datachain
-Version: 0.3.2
+Version: 0.3.3
 Summary: Wrangle unstructured AI data at scale
 Author-email: Dmitry Petrov <support@dvc.org>
 License: Apache-2.0
@@ -43,7 +43,7 @@ Requires-Dist: Pillow <11,>=10.0.0
 Requires-Dist: numpy <2,>=1 ; sys_platform == "win32"
 Provides-Extra: dev
 Requires-Dist: datachain[docs,tests] ; extra == 'dev'
-Requires-Dist: mypy ==1.…
+Requires-Dist: mypy ==1.11.1 ; extra == 'dev'
 Requires-Dist: types-python-dateutil ; extra == 'dev'
 Requires-Dist: types-pytz ; extra == 'dev'
 Requires-Dist: types-PyYAML ; extra == 'dev'
{datachain-0.3.2.dist-info → datachain-0.3.3.dist-info}/RECORD
CHANGED
@@ -1,6 +1,6 @@
 datachain/__init__.py,sha256=GeyhE-5LgfJav2OKYGaieP2lBvf2Gm-ihj7thnK9zjI,800
 datachain/__main__.py,sha256=hG3Y4ARGEqe1AWwNMd259rBlqtphx1Wk39YbueQ0yV8,91
-datachain/asyn.py,sha256=…
+datachain/asyn.py,sha256=biF8M8fQujtj5xs0VLi8S16eBtzG6kceWlO_NILbCsg,8197
 datachain/cache.py,sha256=wznC2pge6RhlPTaJfBVGjmBc6bxWCPThu4aTFMltvFU,4076
 datachain/cli.py,sha256=DbmI1sXs7-KCQz6RdLE_JAp3XO3yrTSRJ71LdUzx-XE,33099
 datachain/cli_utils.py,sha256=jrn9ejGXjybeO1ur3fjdSiAyCHZrX0qsLLbJzN9ErPM,2418
@@ -17,17 +17,17 @@ datachain/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 datachain/storage.py,sha256=RiSJLYdHUjnrEWkLBKPcETHpAxld_B2WxLg711t0aZI,3733
 datachain/utils.py,sha256=ROVCLwb37VmFRzgTlSGUDw4eJNgYGiQ4yMX581HfUX8,12988
 datachain/catalog/__init__.py,sha256=g2iAAFx_gEIrqshXlhSEbrc8qDaEH11cjU40n3CHDz4,409
-datachain/catalog/catalog.py,sha256=…
+datachain/catalog/catalog.py,sha256=_BRaD261RnCJgXr_DJcDf58XmbjLiuLMSsX97E8k3z8,80771
 datachain/catalog/datasource.py,sha256=D-VWIVDCM10A8sQavLhRXdYSCG7F4o4ifswEF80_NAQ,1412
-datachain/catalog/loader.py,sha256=…
+datachain/catalog/loader.py,sha256=-6VelNfXUdgUnwInVyA8g86Boxv2xqhTh9xNS-Zlwig,8242
 datachain/catalog/subclass.py,sha256=B5R0qxeTYEyVAAPM1RutBPSoXZc8L5mVVZeSGXki9Sw,2096
 datachain/client/__init__.py,sha256=T4wiYL9KIM0ZZ_UqIyzV8_ufzYlewmizlV4iymHNluE,86
-datachain/client/azure.py,sha256=…
+datachain/client/azure.py,sha256=LXSahE0Z6r4dXqpBkKnq3J5fg7N7ymC1lSn-1SoILGc,2687
 datachain/client/fileslice.py,sha256=bT7TYco1Qe3bqoc8aUkUZcPdPofJDHlryL5BsTn9xsY,3021
 datachain/client/fsspec.py,sha256=G4QTm3KPhlaV74T3gLXJ86345_ak8CH38ezn2ET-oLc,13230
-datachain/client/gcs.py,sha256=…
-datachain/client/local.py,sha256=…
-datachain/client/s3.py,sha256=…
+datachain/client/gcs.py,sha256=P_E3mhzhXR9mJ_wc3AYZuczzwOJ0-D3J5qhJXeSU-xk,4518
+datachain/client/local.py,sha256=H8TNY8pi2kA8y9_f_1XLUjJF66f229qC_b2y4xGkzdU,5300
+datachain/client/s3.py,sha256=aQxfMH8G8bUjmHF1-6P90MSkXsU5DgOPEVlKWLu459I,6568
 datachain/data_storage/__init__.py,sha256=cEOJpyu1JDZtfUupYucCDNFI6e5Wmp_Oyzq6rZv32Y8,398
 datachain/data_storage/db_engine.py,sha256=81Ol1of9TTTzD97ORajCnP366Xz2mEJt6C-kTUCaru4,3406
 datachain/data_storage/id_generator.py,sha256=lCEoU0BM37Ai2aRpSbwo5oQT0GqZnSpYwwvizathRMQ,4292
@@ -35,16 +35,17 @@ datachain/data_storage/job.py,sha256=w-7spowjkOa1P5fUVtJou3OltT0L48P0RYWZ9rSJ9-s
 datachain/data_storage/metastore.py,sha256=nxcY6nwyEmQWMAo33sNGO-FgUFQs2amBGGnZz2ftEz0,55362
 datachain/data_storage/schema.py,sha256=GwJIHkjhrnBxJAV1WvCMM8jiJN5h79LXDyzMmUDtRw0,8523
 datachain/data_storage/serializer.py,sha256=6G2YtOFqqDzJf1KbvZraKGXl2XHZyVml2krunWUum5o,927
-datachain/data_storage/sqlite.py,sha256=…
-datachain/data_storage/warehouse.py,sha256=…
+datachain/data_storage/sqlite.py,sha256=GEE07ZXTAtzdf53J1UDLscS0xZjukRGlmZzG6q0fZI0,28589
+datachain/data_storage/warehouse.py,sha256=tyJJDxFae6XWgLmOoG0B_MJ_Z_UEMoW_wJb96zzwTtA,33471
 datachain/lib/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-datachain/lib/arrow.py,sha256=…
+datachain/lib/arrow.py,sha256=D8N7zCppRdc5sTYT1hNIbROc-sKA_8FN5J_m-KjD3Us,4929
 datachain/lib/clip.py,sha256=16u4b_y2Y15nUS2UN_8ximMo6r_-_4IQpmct2ol-e-g,5730
-datachain/lib/data_model.py,sha256=…
+datachain/lib/data_model.py,sha256=ZvtMRMcPpBxI-rOhkXb-ry1PkGYcEFFK1w1wH12vs4g,1718
 datachain/lib/dataset_info.py,sha256=lONGr71ozo1DS4CQEhnpKORaU4qFb6Ketv8Xm8CVm2U,2188
-datachain/lib/dc.py,sha256=…
+datachain/lib/dc.py,sha256=0pwNb91GW8MnHLfFd2YvEtEH0n77c3nxp5ozwIyW86o,58827
 datachain/lib/file.py,sha256=ZHpdilDPYCob8uqtwUPtBvBNxVvQRq4AC_0IGg5m-G4,12003
 datachain/lib/image.py,sha256=TgYhRhzd4nkytfFMeykQkPyzqb5Le_-tU81unVMPn4Q,2328
+datachain/lib/listing.py,sha256=nXLmGae_oQke4hnurzzWiHTEjHjWiqqHdB41Wb-hMTk,3521
 datachain/lib/meta_formats.py,sha256=Hels85LJmNCz1aYVJvhymNdAt3qdJ2-qoxsIiUezrow,7198
 datachain/lib/model_store.py,sha256=c4USXsBBjrGH8VOh4seIgOiav-qHOwdoixtxfLgU63c,2409
 datachain/lib/pytorch.py,sha256=9PsypKseyKfIimTmTQOgb-pbNXgeeAHLdlWx0qRPULY,5660
@@ -62,17 +63,17 @@ datachain/lib/convert/flatten.py,sha256=YMoC00BqEy3zSpvCp6Q0DfxihuPmgjUJj1g2cesW
 datachain/lib/convert/python_to_sql.py,sha256=4gplGlr_Kg-Z40OpJUzJiarDWj7pwbUOk-dPOYYCJ9Q,2629
 datachain/lib/convert/sql_to_python.py,sha256=lGnKzSF_tz9Y_5SSKkrIU95QEjpcDzvOxIRkEKTQag0,443
 datachain/lib/convert/unflatten.py,sha256=Ogvh_5wg2f38_At_1lN0D_e2uZOOpYEvwvB2xdq56Tw,2012
-datachain/lib/convert/values_to_tuples.py,sha256=…
+datachain/lib/convert/values_to_tuples.py,sha256=YOdbjzHq-uj6-cV2Qq43G72eN2avMNDGl4x5t6yQMl8,3931
 datachain/query/__init__.py,sha256=tv-spkjUCYamMN9ys_90scYrZ8kJ7C7d1MTYVmxGtk4,325
 datachain/query/batch.py,sha256=-vlpINJiertlnaoUVv1C95RatU0F6zuhpIYRufJRo1M,3660
 datachain/query/builtins.py,sha256=EmKPYsoQ46zwdyOn54MuCzvYFmfsBn5F8zyF7UBUfrc,2550
-datachain/query/dataset.py,sha256=…
+datachain/query/dataset.py,sha256=7lxlybS7I5IPsgOqMz-W4vS6kWBDHkHQRqBHlIRYRPw,60473
 datachain/query/dispatch.py,sha256=GBh3EZHDp5AaXxrjOpfrpfsuy7Umnqxu-MAXcK9X3gc,12945
 datachain/query/metrics.py,sha256=vsECqbZfoSDBnvC3GQlziKXmISVYDLgHP1fMPEOtKyo,640
 datachain/query/params.py,sha256=O_j89mjYRLOwWNhYZl-z7mi-rkdP7WyFmaDufsdTryE,863
 datachain/query/queue.py,sha256=waqM_KzavU8C-G95-4211Nd4GXna_u2747Chgwtgz2w,3839
 datachain/query/schema.py,sha256=BvHipN79CnSTbVFcfIEwzo1npe7HmThnk0iY-CSLEkM,7899
-datachain/query/session.py,sha256=…
+datachain/query/session.py,sha256=PkOLANS0s8KPz4wO17tAab-CMzIt7FK8RPzJiibExds,4290
 datachain/query/udf.py,sha256=j3NhmKK5rYG5TclcM2Sr0LhS1tmYLMjzMugx9G9iFLM,8100
 datachain/remote/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 datachain/remote/studio.py,sha256=f5s6qSZ9uB4URGUoU_8_W1KZRRQQVSm6cgEBkBUEfuE,7226
@@ -89,13 +90,13 @@ datachain/sql/functions/path.py,sha256=zixpERotTFP6LZ7I4TiGtyRA8kXOoZmH1yzH9oRW0
 datachain/sql/functions/random.py,sha256=vBwEEj98VH4LjWixUCygQ5Bz1mv1nohsCG0-ZTELlVg,271
 datachain/sql/functions/string.py,sha256=hIrF1fTvlPamDtm8UMnWDcnGfbbjCsHxZXS30U2Rzxo,651
 datachain/sql/sqlite/__init__.py,sha256=TAdJX0Bg28XdqPO-QwUVKy8rg78cgMileHvMNot7d04,166
-datachain/sql/sqlite/base.py,sha256=…
+datachain/sql/sqlite/base.py,sha256=5nLvOv0xcOlEpfZeY3SWbI401MSGM2i29P3SRkd7TAc,12898
 datachain/sql/sqlite/types.py,sha256=yzvp0sXSEoEYXs6zaYC_2YubarQoZH-MiUNXcpuEP4s,1573
 datachain/sql/sqlite/vector.py,sha256=ncW4eu2FlJhrP_CIpsvtkUabZlQdl2D5Lgwy_cbfqR0,469
 datachain/torch/__init__.py,sha256=gIS74PoEPy4TB3X6vx9nLO0Y3sLJzsA8ckn8pRWihJM,579
-datachain-0.3.2.dist-info/LICENSE,sha256=…
-datachain-0.3.2.dist-info/METADATA,sha256=…
-datachain-0.3.2.dist-info/WHEEL,sha256=…
-datachain-0.3.2.dist-info/entry_points.txt,sha256=…
-datachain-0.3.2.dist-info/top_level.txt,sha256=…
-datachain-0.3.2.dist-info/RECORD,,
+datachain-0.3.3.dist-info/LICENSE,sha256=8DnqK5yoPI_E50bEg_zsHKZHY2HqPy4rYN338BHQaRA,11344
+datachain-0.3.3.dist-info/METADATA,sha256=BDBQIVMBj7tqy0TntMooUyMlPEVgVHA4xvMESRHiF0I,16789
+datachain-0.3.3.dist-info/WHEEL,sha256=HiCZjzuy6Dw0hdX5R3LCFPDmFS4BWl8H-8W39XfmgX4,91
+datachain-0.3.3.dist-info/entry_points.txt,sha256=0GMJS6B_KWq0m3VT98vQI2YZodAMkn4uReZ_okga9R4,49
+datachain-0.3.3.dist-info/top_level.txt,sha256=lZPpdU_2jJABLNIg2kvEOBi8PtsYikbN1OdMLHk8bTg,10
+datachain-0.3.3.dist-info/RECORD,,
{datachain-0.3.2.dist-info → datachain-0.3.3.dist-info}/LICENSE
File without changes
{datachain-0.3.2.dist-info → datachain-0.3.3.dist-info}/WHEEL
File without changes
{datachain-0.3.2.dist-info → datachain-0.3.3.dist-info}/entry_points.txt
File without changes
{datachain-0.3.2.dist-info → datachain-0.3.3.dist-info}/top_level.txt
File without changes