datachain 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of datachain might be problematic.
- datachain/asyn.py +20 -0
- datachain/catalog/catalog.py +12 -1
- datachain/catalog/loader.py +75 -50
- datachain/client/azure.py +13 -0
- datachain/client/gcs.py +12 -0
- datachain/client/local.py +11 -0
- datachain/client/s3.py +12 -0
- datachain/data_storage/schema.py +22 -8
- datachain/data_storage/sqlite.py +60 -14
- datachain/data_storage/warehouse.py +17 -3
- datachain/lib/arrow.py +1 -1
- datachain/lib/convert/values_to_tuples.py +14 -8
- datachain/lib/data_model.py +1 -0
- datachain/lib/dc.py +52 -19
- datachain/lib/listing.py +111 -0
- datachain/lib/meta_formats.py +8 -2
- datachain/node.py +1 -1
- datachain/query/dataset.py +22 -12
- datachain/query/schema.py +4 -0
- datachain/query/session.py +9 -2
- datachain/sql/default/base.py +3 -0
- datachain/sql/sqlite/base.py +33 -4
- datachain/sql/types.py +120 -11
- {datachain-0.3.1.dist-info → datachain-0.3.3.dist-info}/METADATA +75 -87
- {datachain-0.3.1.dist-info → datachain-0.3.3.dist-info}/RECORD +29 -28
- {datachain-0.3.1.dist-info → datachain-0.3.3.dist-info}/WHEEL +1 -1
- {datachain-0.3.1.dist-info → datachain-0.3.3.dist-info}/LICENSE +0 -0
- {datachain-0.3.1.dist-info → datachain-0.3.3.dist-info}/entry_points.txt +0 -0
- {datachain-0.3.1.dist-info → datachain-0.3.3.dist-info}/top_level.txt +0 -0
datachain/lib/data_model.py
CHANGED
@@ -18,6 +18,7 @@ StandardType = Union[
 ]
 DataType = Union[type[BaseModel], StandardType]
 DataTypeNames = "BaseModel, int, str, float, bool, list, dict, bytes, datetime"
+DataValuesType = Union[BaseModel, int, str, float, bool, list, dict, bytes, datetime]


 class DataModel(BaseModel):
datachain/lib/dc.py
CHANGED
@@ -309,6 +309,7 @@ class DataChain(DatasetQuery):
         *,
         type: Literal["binary", "text", "image"] = "binary",
         session: Optional[Session] = None,
+        in_memory: bool = False,
         recursive: Optional[bool] = True,
         object_name: str = "file",
         update: bool = False,
@@ -332,7 +333,14 @@ class DataChain(DatasetQuery):
         """
         func = get_file(type)
         return (
-            cls(
+            cls(
+                path,
+                session=session,
+                recursive=recursive,
+                update=update,
+                in_memory=in_memory,
+                **kwargs,
+            )
             .map(**{object_name: func})
             .select(object_name)
         )
@@ -479,7 +487,10 @@ class DataChain(DatasetQuery):

     @classmethod
     def datasets(
-        cls,
+        cls,
+        session: Optional[Session] = None,
+        in_memory: bool = False,
+        object_name: str = "dataset",
     ) -> "DataChain":
         """Generate chain with list of registered datasets.

@@ -492,7 +503,7 @@ class DataChain(DatasetQuery):
             print(f"{ds.name}@v{ds.version}")
             ```
         """
-        session = Session.get(session)
+        session = Session.get(session, in_memory=in_memory)
         catalog = session.catalog

         datasets = [
@@ -502,13 +513,14 @@ class DataChain(DatasetQuery):

         return cls.from_values(
             session=session,
+            in_memory=in_memory,
             output={object_name: DatasetInfo},
             **{object_name: datasets},  # type: ignore[arg-type]
         )

     def print_json_schema(  # type: ignore[override]
         self, jmespath: Optional[str] = None, model_name: Optional[str] = None
-    ) -> "
+    ) -> "Self":
         """Print JSON data model and save it. It returns the chain itself.

         Parameters:
@@ -533,7 +545,7 @@ class DataChain(DatasetQuery):

     def print_jsonl_schema(  # type: ignore[override]
         self, jmespath: Optional[str] = None, model_name: Optional[str] = None
-    ) -> "
+    ) -> "Self":
         """Print JSON data model and save it. It returns the chain itself.

         Parameters:
@@ -549,7 +561,7 @@ class DataChain(DatasetQuery):

     def save(  # type: ignore[override]
         self, name: Optional[str] = None, version: Optional[int] = None
-    ) -> "
+    ) -> "Self":
         """Save to a Dataset. It returns the chain itself.

         Parameters:
@@ -785,7 +797,7 @@ class DataChain(DatasetQuery):
             descending (bool): Whether to sort in descending order or not.
         """
         if descending:
-            args = tuple(
+            args = tuple(sqlalchemy.desc(a) for a in args)

         return super().order_by(*args)

@@ -1142,6 +1154,7 @@ class DataChain(DatasetQuery):
         cls,
         ds_name: str = "",
         session: Optional[Session] = None,
+        in_memory: bool = False,
         output: OutputType = None,
         object_name: str = "",
         **fr_map,
@@ -1158,7 +1171,9 @@ class DataChain(DatasetQuery):
         def _func_fr() -> Iterator[tuple_type]:  # type: ignore[valid-type]
             yield from tuples

-        chain = DataChain.from_records(
+        chain = DataChain.from_records(
+            DataChain.DEFAULT_FILE_RECORD, session=session, in_memory=in_memory
+        )
         if object_name:
             output = {object_name: DataChain._dict_to_data_model(object_name, output)}  # type: ignore[arg-type]
         return chain.gen(_func_fr, output=output)
@@ -1169,6 +1184,7 @@ class DataChain(DatasetQuery):
         df: "pd.DataFrame",
         name: str = "",
         session: Optional[Session] = None,
+        in_memory: bool = False,
         object_name: str = "",
     ) -> "DataChain":
         """Generate chain from pandas data-frame.
@@ -1196,7 +1212,9 @@ class DataChain(DatasetQuery):
                 f"import from pandas error - '{column}' cannot be a column name",
             )

-        return cls.from_values(
+        return cls.from_values(
+            name, session, object_name=object_name, in_memory=in_memory, **fr_map
+        )

     def to_pandas(self, flatten=False) -> "pd.DataFrame":
         """Return a pandas DataFrame from the chain.
@@ -1206,14 +1224,14 @@ class DataChain(DatasetQuery):
         """
         headers, max_length = self._effective_signals_schema.get_headers_with_length()
         if flatten or max_length < 2:
-
+            columns = []
             if headers:
-
-                return
+                columns = [".".join(filter(None, header)) for header in headers]
+            return pd.DataFrame.from_records(self.to_records(), columns=columns)

-
-
-
+        return pd.DataFrame(
+            self.results(), columns=pd.MultiIndex.from_tuples(map(tuple, headers))
+        )

     def show(
         self,
@@ -1232,6 +1250,12 @@ class DataChain(DatasetQuery):
         """
         dc = self.limit(limit) if limit > 0 else self
         df = dc.to_pandas(flatten)
+
+        if df.empty:
+            print("Empty result")
+            print(f"Columns: {list(df.columns)}")
+            return
+
         if transpose:
             df = df.T

@@ -1270,7 +1294,7 @@ class DataChain(DatasetQuery):
         source: bool = True,
         nrows: Optional[int] = None,
         **kwargs,
-    ) -> "
+    ) -> "Self":
         """Generate chain from list of tabular files.

         Parameters:
@@ -1390,7 +1414,8 @@ class DataChain(DatasetQuery):
             dc = DataChain.from_csv("s3://mybucket/dir")
             ```
         """
-        from
+        from pandas.io.parsers.readers import STR_NA_VALUES
+        from pyarrow.csv import ConvertOptions, ParseOptions, ReadOptions
         from pyarrow.dataset import CsvFileFormat

         chain = DataChain.from_storage(path, **kwargs)
@@ -1414,7 +1439,14 @@ class DataChain(DatasetQuery):

         parse_options = ParseOptions(delimiter=delimiter)
         read_options = ReadOptions(column_names=column_names)
-
+        convert_options = ConvertOptions(
+            strings_can_be_null=True, null_values=STR_NA_VALUES
+        )
+        format = CsvFileFormat(
+            parse_options=parse_options,
+            read_options=read_options,
+            convert_options=convert_options,
+        )
         return chain.parse_tabular(
             output=output,
             object_name=object_name,
@@ -1491,6 +1523,7 @@ class DataChain(DatasetQuery):
         cls,
         to_insert: Optional[Union[dict, list[dict]]],
         session: Optional[Session] = None,
+        in_memory: bool = False,
     ) -> "DataChain":
         """Create a DataChain from the provided records. This method can be used for
         programmatically generating a chain in contrast of reading data from storages
@@ -1506,7 +1539,7 @@ class DataChain(DatasetQuery):
             single_record = DataChain.from_records(DataChain.DEFAULT_FILE_RECORD)
             ```
         """
-        session = Session.get(session)
+        session = Session.get(session, in_memory=in_memory)
         catalog = session.catalog

         name = session.generate_temp_dataset_name()
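The recurring change in this file threads a new in_memory flag from the public constructors (from_storage, from_values, from_pandas, from_records, datasets) down to Session.get(). A minimal sketch of how the flag might be used; the column data below is illustrative and not taken from the diff:

from datachain.lib.dc import DataChain

# Hypothetical usage of the new flag: the chain is backed by an in-memory
# SQLite database, so nothing is written to the on-disk workspace.
chain = DataChain.from_values(
    num=[1, 2, 3],      # illustrative values column
    in_memory=True,     # new flag in this diff; defaults to False
)
chain.show()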
datachain/lib/listing.py
ADDED
@@ -0,0 +1,111 @@
+import asyncio
+from collections.abc import AsyncIterator, Iterator, Sequence
+from typing import Callable, Optional
+
+from botocore.exceptions import ClientError
+from fsspec.asyn import get_loop
+
+from datachain.asyn import iter_over_async
+from datachain.client import Client
+from datachain.error import ClientError as DataChainClientError
+from datachain.lib.file import File
+
+ResultQueue = asyncio.Queue[Optional[Sequence[File]]]
+
+DELIMITER = "/"  # Path delimiter
+FETCH_WORKERS = 100
+
+
+async def _fetch_dir(client, prefix, result_queue) -> set[str]:
+    path = f"{client.name}/{prefix}"
+    infos = await client.ls_dir(path)
+    files = []
+    subdirs = set()
+    for info in infos:
+        full_path = info["name"]
+        subprefix = client.rel_path(full_path)
+        if prefix.strip(DELIMITER) == subprefix.strip(DELIMITER):
+            continue
+        if info["type"] == "directory":
+            subdirs.add(subprefix)
+        else:
+            files.append(client.info_to_file(info, subprefix))
+    if files:
+        await result_queue.put(files)
+    return subdirs
+
+
+async def _fetch(
+    client, start_prefix: str, result_queue: ResultQueue, fetch_workers
+) -> None:
+    loop = get_loop()
+
+    queue: asyncio.Queue[str] = asyncio.Queue()
+    queue.put_nowait(start_prefix)
+
+    async def process(queue) -> None:
+        while True:
+            prefix = await queue.get()
+            try:
+                subdirs = await _fetch_dir(client, prefix, result_queue)
+                for subdir in subdirs:
+                    queue.put_nowait(subdir)
+            except Exception:
+                while not queue.empty():
+                    queue.get_nowait()
+                    queue.task_done()
+                raise
+
+            finally:
+                queue.task_done()
+
+    try:
+        workers: list[asyncio.Task] = [
+            loop.create_task(process(queue)) for _ in range(fetch_workers)
+        ]
+
+        # Wait for all fetch tasks to complete
+        await queue.join()
+        # Stop the workers
+        excs = []
+        for worker in workers:
+            if worker.done() and (exc := worker.exception()):
+                excs.append(exc)
+            else:
+                worker.cancel()
+        if excs:
+            raise excs[0]
+    except ClientError as exc:
+        raise DataChainClientError(
+            exc.response.get("Error", {}).get("Message") or exc,
+            exc.response.get("Error", {}).get("Code"),
+        ) from exc
+    finally:
+        # This ensures the progress bar is closed before any exceptions are raised
+        result_queue.put_nowait(None)
+
+
+async def _scandir(client, prefix, fetch_workers) -> AsyncIterator:
+    """Recursively goes through dir tree and yields files"""
+    result_queue: ResultQueue = asyncio.Queue()
+    loop = get_loop()
+    main_task = loop.create_task(_fetch(client, prefix, result_queue, fetch_workers))
+    while (files := await result_queue.get()) is not None:
+        for f in files:
+            yield f
+
+    await main_task
+
+
+def list_bucket(uri: str, client_config=None, fetch_workers=FETCH_WORKERS) -> Callable:
+    """
+    Function that returns another generator function that yields File objects
+    from bucket where each File represents one bucket entry.
+    """
+
+    def list_func() -> Iterator[File]:
+        config = client_config or {}
+        client, path = Client.parse_url(uri, None, **config)  # type: ignore[arg-type]
+        yield from iter_over_async(_scandir(client, path, fetch_workers), get_loop())
+
+    return list_func
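The new module's public entry point is list_bucket(), which returns a generator function rather than listing immediately. A hedged usage sketch; the bucket URI and client options are placeholders:

from datachain.lib.listing import list_bucket

# list_bucket() returns a callable; calling it walks the bucket with
# FETCH_WORKERS concurrent directory listings and yields File objects.
list_files = list_bucket("s3://my-bucket/prefix", client_config={"anon": True})
for file in list_files():
    print(file)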
datachain/lib/meta_formats.py
CHANGED
@@ -11,12 +11,16 @@ from collections.abc import Iterator
 from typing import Any, Callable

 import jmespath as jsp
-from pydantic import Field, ValidationError  # noqa: F401
+from pydantic import BaseModel, ConfigDict, Field, ValidationError  # noqa: F401

 from datachain.lib.data_model import DataModel  # noqa: F401
 from datachain.lib.file import File


+class UserModel(BaseModel):
+    model_config = ConfigDict(populate_by_name=True)
+
+
 def generate_uuid():
     return uuid.uuid4()  # Generates a random UUID.

@@ -72,6 +76,8 @@ def read_schema(source_file, data_type="csv", expr=None, model_name=None):
         data_type,
         "--class-name",
         model_name,
+        "--base-class",
+        "datachain.lib.meta_formats.UserModel",
     ]
     try:
         result = subprocess.run(  # noqa: S603
@@ -87,7 +93,7 @@ def read_schema(source_file, data_type="csv", expr=None, model_name=None):
    except subprocess.CalledProcessError as e:
        model_output = f"An error occurred in datamodel-codegen: {e.stderr}"
    print(f"{model_output}")
-    print("
+    print("from datachain.lib.data_model import DataModel")
    print("\n" + f"DataModel.register({model_name})" + "\n")
    print("\n" + f"spec={model_name}" + "\n")
    return model_output
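Generated models now inherit from UserModel, whose populate_by_name=True config lets fields be populated either by alias or by Python field name. A small illustration of that pydantic v2 behavior; the model and alias below are made up, not part of the diff:

from pydantic import BaseModel, ConfigDict, Field

class UserModel(BaseModel):
    model_config = ConfigDict(populate_by_name=True)

class Person(UserModel):
    # datamodel-codegen emits aliases like this when a source column
    # is not a valid Python identifier.
    first_name: str = Field(alias="First Name")

Person(**{"First Name": "Ada"})   # populate by alias
Person(first_name="Ada")          # also valid thanks to populate_by_name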
datachain/node.py
CHANGED
datachain/query/dataset.py
CHANGED
@@ -34,6 +34,7 @@ from sqlalchemy.sql.elements import ColumnClause, ColumnElement
 from sqlalchemy.sql.expression import label
 from sqlalchemy.sql.schema import TableClause
 from sqlalchemy.sql.selectable import Select
+from tqdm import tqdm

 from datachain.asyn import ASYNC_WORKERS, AsyncMapper, OrderedMapper
 from datachain.catalog import (
@@ -125,7 +126,10 @@ class QueryGenerator:
     func: QueryGeneratorFunc
     columns: tuple[ColumnElement, ...]

-    def
+    def only(self, column_names: Sequence[str]) -> Select:
+        return self.func(*(c for c in self.columns if c.name in column_names))
+
+    def exclude(self, column_names: Sequence[str]) -> Select:
         return self.func(*(c for c in self.columns if c.name not in column_names))

     def select(self, column_names=None) -> Select:
@@ -465,6 +469,12 @@ class UDFStep(Step, ABC):

         try:
             if workers:
+                if self.catalog.in_memory:
+                    raise RuntimeError(
+                        "In-memory databases cannot be used with "
+                        "distributed processing."
+                    )
+
                 from datachain.catalog.loader import get_distributed_class

                 distributor = get_distributed_class(min_task_size=self.min_task_size)
@@ -482,6 +492,10 @@ class UDFStep(Step, ABC):
                 )
             elif processes:
                 # Parallel processing (faster for more CPU-heavy UDFs)
+                if self.catalog.in_memory:
+                    raise RuntimeError(
+                        "In-memory databases cannot be used with parallel processing."
+                    )
                 udf_info = {
                     "udf_data": filtered_cloudpickle_dumps(self.udf),
                     "catalog_init": self.catalog.get_init_params(),
@@ -1049,6 +1063,7 @@ class DatasetQuery:
         indexing_feature_schema: Optional[dict] = None,
         indexing_column_types: Optional[dict[str, Any]] = None,
         update: Optional[bool] = False,
+        in_memory: bool = False,
     ):
         if client_config is None:
             client_config = {}
@@ -1057,7 +1072,7 @@ class DatasetQuery:
             client_config["anon"] = True

         self.session = Session.get(
-            session, catalog=catalog, client_config=client_config
+            session, catalog=catalog, client_config=client_config, in_memory=in_memory
         )
         self.catalog = catalog or self.session.catalog
         self.steps: list[Step] = []
@@ -1648,18 +1663,13 @@ class DatasetQuery:

         dr = self.catalog.warehouse.dataset_rows(dataset)

-
-
-
-
-
-        q = q.add_columns(
-            f.row_number().over(order_by=q._order_by_clauses).label("sys__id")
+        with tqdm(desc="Saving", unit=" rows") as pbar:
+            self.catalog.warehouse.copy_table(
+                dr.get_table(),
+                query.select(),
+                progress_cb=pbar.update,
             )

-        cols = tuple(c.name for c in q.selected_columns)
-        insert_q = sqlalchemy.insert(dr.get_table()).from_select(cols, q)
-        self.catalog.warehouse.db.execute(insert_q, **kwargs)
         self.catalog.metastore.update_dataset_status(
             dataset, DatasetStatus.COMPLETE, version=version
         )
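Saving now streams rows through warehouse.copy_table() with a tqdm progress callback instead of a single INSERT ... FROM SELECT. The callback contract is the usual tqdm one; a minimal sketch with a hypothetical copy loop (copy_rows is illustrative, not the warehouse API):

from tqdm import tqdm

def copy_rows(batches, write_batch, progress_cb=None):
    # Hypothetical helper mirroring the progress_cb contract above:
    # report the number of rows written after each batch.
    for batch in batches:
        write_batch(batch)
        if progress_cb:
            progress_cb(len(batch))

with tqdm(desc="Saving", unit=" rows") as pbar:
    copy_rows([[1, 2], [3, 4, 5]], write_batch=lambda b: None, progress_cb=pbar.update)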
datachain/query/schema.py
CHANGED
@@ -45,6 +45,10 @@ class Column(sa.ColumnClause, metaclass=ColumnMeta):
         """Search for matches using glob pattern matching."""
         return self.op("GLOB")(glob_str)

+    def regexp(self, regexp_str):
+        """Search for matches using regexp pattern matching."""
+        return self.op("REGEXP")(regexp_str)
+

 class UDFParameter(ABC):
     @abstractmethod
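The new Column.regexp() mirrors the existing glob() helper and compiles to a REGEXP comparison (the backend must supply a REGEXP handler; that registration is not shown in the visible hunks). A hedged sketch; the column name and pattern are illustrative:

from datachain.query.schema import Column

# Hypothetical: match rows whose path ends in .jpg or .jpeg.
is_image = Column("file__path").regexp(r"\.jpe?g$")
# The resulting SQLAlchemy expression renders as a <column> REGEXP <pattern> comparison.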
datachain/query/session.py
CHANGED
@@ -46,6 +46,7 @@ class Session:
         name="",
         catalog: Optional["Catalog"] = None,
         client_config: Optional[dict] = None,
+        in_memory: bool = False,
     ):
         if re.match(r"^[0-9a-zA-Z]+$", name) is None:
             raise ValueError(
@@ -58,7 +59,9 @@ class Session:
         session_uuid = uuid4().hex[: self.SESSION_UUID_LEN]
         self.name = f"{name}_{session_uuid}"
         self.is_new_catalog = not catalog
-        self.catalog = catalog or get_catalog(
+        self.catalog = catalog or get_catalog(
+            client_config=client_config, in_memory=in_memory
+        )

     def __enter__(self):
         return self
@@ -89,6 +92,7 @@ class Session:
         session: Optional["Session"] = None,
         catalog: Optional["Catalog"] = None,
         client_config: Optional[dict] = None,
+        in_memory: bool = False,
     ) -> "Session":
         """Creates a Session() object from a catalog.

@@ -102,7 +106,10 @@ class Session:

         if cls.GLOBAL_SESSION is None:
             cls.GLOBAL_SESSION_CTX = Session(
-                cls.GLOBAL_SESSION_NAME,
+                cls.GLOBAL_SESSION_NAME,
+                catalog,
+                client_config=client_config,
+                in_memory=in_memory,
             )
             cls.GLOBAL_SESSION = cls.GLOBAL_SESSION_CTX.__enter__()
             atexit.register(cls._global_cleanup)
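Session accepts the same in_memory flag and forwards it to get_catalog(), and Session.get() passes it through when it lazily creates the global session. A minimal sketch at the session level; the session name is arbitrary but must satisfy the [0-9a-zA-Z]+ check above:

from datachain.query.session import Session

# Hypothetical: temp datasets created under this session live in an
# in-memory database and vanish when the context manager exits.
with Session("demoSession", in_memory=True) as session:
    catalog = session.catalog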
datachain/sql/default/base.py
CHANGED
@@ -1,8 +1,10 @@
 from datachain.sql.types import (
+    DBDefaults,
     TypeConverter,
     TypeDefaults,
     TypeReadConverter,
     register_backend_types,
+    register_db_defaults,
     register_type_defaults,
     register_type_read_converters,
 )
@@ -18,5 +20,6 @@ def setup() -> None:
     register_backend_types("default", TypeConverter())
     register_type_read_converters("default", TypeReadConverter())
     register_type_defaults("default", TypeDefaults())
+    register_db_defaults("default", DBDefaults())

 setup_is_complete = True
datachain/sql/sqlite/base.py
CHANGED
@@ -22,8 +22,10 @@ from datachain.sql.sqlite.types import (
     register_type_converters,
 )
 from datachain.sql.types import (
+    DBDefaults,
     TypeDefaults,
     register_backend_types,
+    register_db_defaults,
     register_type_defaults,
     register_type_read_converters,
 )
@@ -66,6 +68,7 @@ def setup():
     register_backend_types("sqlite", SQLiteTypeConverter())
     register_type_read_converters("sqlite", SQLiteTypeReadConverter())
     register_type_defaults("sqlite", TypeDefaults())
+    register_db_defaults("sqlite", DBDefaults())

     compiles(sql_path.parent, "sqlite")(compile_path_parent)
     compiles(sql_path.name, "sqlite")(compile_path_name)
@@ -218,19 +221,45 @@ def path_name(path):
     return func.ltrim(func.substr(path, func.length(path_parent(path)) + 1), slash)


-def
-    name = path_name(path)
+def name_file_ext_length(name):
     expr = func.length(name) - func.length(
         func.rtrim(name, func.replace(name, dot, empty_str))
     )
     return case((func.instr(name, dot) == 0, 0), else_=expr)


+def path_file_ext_length(path):
+    name = path_name(path)
+    return name_file_ext_length(name)
+
+
 def path_file_stem(path):
-
-
+    path_length = func.length(path)
+    parent_length = func.length(path_parent(path))
+
+    name_expr = func.rtrim(
+        func.substr(
+            path,
+            1,
+            path_length - name_file_ext_length(path),
+        ),
+        dot,
+    )
+
+    full_path_expr = func.ltrim(
+        func.rtrim(
+            func.substr(
+                path,
+                parent_length + 1,
+                path_length - parent_length - path_file_ext_length(path),
+            ),
+            dot,
+        ),
+        slash,
     )

+    return case((func.instr(path, slash) == 0, name_expr), else_=full_path_expr)
+

 def path_file_ext(path):
     return func.substr(path, func.length(path) - path_file_ext_length(path) + 1)
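The reworked path_file_stem() builds a SQLite expression that, for plain names, strips the extension and, for full paths, also strips the parent directory. Conceptually it behaves like this pure-Python sketch (an approximation for intuition; edge cases such as dotfiles or trailing dots may differ):

import posixpath

def file_stem(path: str) -> str:
    # Python analogue of the SQL expression: drop the parent directories,
    # then drop the trailing extension.
    return posixpath.splitext(posixpath.basename(path))[0]

assert file_stem("images/cat.jpg") == "cat"
assert file_stem("archive.tar") == "archive"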