datachain 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release.
This version of datachain might be problematic.
- datachain/catalog/catalog.py +11 -2
- datachain/client/fsspec.py +1 -4
- datachain/client/local.py +2 -7
- datachain/data_storage/schema.py +22 -8
- datachain/data_storage/sqlite.py +5 -0
- datachain/data_storage/warehouse.py +8 -14
- datachain/lib/dc.py +28 -14
- datachain/lib/meta_formats.py +8 -2
- datachain/lib/udf.py +21 -14
- datachain/node.py +1 -1
- datachain/query/batch.py +45 -41
- datachain/query/dataset.py +13 -6
- datachain/query/dispatch.py +53 -68
- datachain/query/queue.py +120 -0
- datachain/query/schema.py +4 -0
- datachain/query/udf.py +23 -8
- datachain/sql/default/base.py +3 -0
- datachain/sql/sqlite/base.py +3 -0
- datachain/sql/types.py +120 -11
- datachain/utils.py +17 -2
- {datachain-0.3.0.dist-info → datachain-0.3.2.dist-info}/METADATA +74 -86
- {datachain-0.3.0.dist-info → datachain-0.3.2.dist-info}/RECORD +26 -25
- {datachain-0.3.0.dist-info → datachain-0.3.2.dist-info}/WHEEL +1 -1
- {datachain-0.3.0.dist-info → datachain-0.3.2.dist-info}/LICENSE +0 -0
- {datachain-0.3.0.dist-info → datachain-0.3.2.dist-info}/entry_points.txt +0 -0
- {datachain-0.3.0.dist-info → datachain-0.3.2.dist-info}/top_level.txt +0 -0
datachain/catalog/catalog.py
CHANGED

@@ -676,7 +676,7 @@ class Catalog:

     def parse_url(self, uri: str, **config: Any) -> tuple[Client, str]:
         config = config or self.client_config
-        return Client.parse_url(uri, self.metastore, self.cache, **config)
+        return Client.parse_url(uri, self.cache, **config)

     def get_client(self, uri: StorageURI, **config: Any) -> Client:
         """
@@ -1627,8 +1627,17 @@ class Catalog:
         version = self.get_dataset(dataset_name).get_version(dataset_version)

         file_signals_values = {}
+        file_schemas = {}
+        # TODO: To remove after we properly fix deserialization
+        for signal, type_name in version.feature_schema.items():
+            from datachain.lib.model_store import ModelStore

-        schema = SignalSchema.deserialize(version.feature_schema)
+            type_name_parsed, v = ModelStore.parse_name_version(type_name)
+            fr = ModelStore.get(type_name_parsed, v)
+            if fr and issubclass(fr, File):
+                file_schemas[signal] = type_name
+
+        schema = SignalSchema.deserialize(file_schemas)
         for file_signals in schema.get_signals(File):
             prefix = file_signals.replace(".", DEFAULT_DELIMITER) + DEFAULT_DELIMITER
             file_signals_values[file_signals] = {
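
The second hunk is a workaround for schema deserialization: before rebuilding the `SignalSchema`, it keeps only the signals whose registered type resolves to a `File` subclass. A minimal sketch of that filtering step, reusing the `ModelStore` calls from the hunk above (the helper name is ours, not datachain's):

```py
from datachain.lib.file import File
from datachain.lib.model_store import ModelStore


def file_signal_schemas(feature_schema: dict) -> dict:
    # Keep only signals whose type name resolves to a File subclass;
    # everything else is dropped before SignalSchema.deserialize().
    file_schemas = {}
    for signal, type_name in feature_schema.items():
        name, version = ModelStore.parse_name_version(type_name)
        model = ModelStore.get(name, version)
        if model and issubclass(model, File):
            file_schemas[signal] = type_name
    return file_schemas
```
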
datachain/client/fsspec.py
CHANGED

@@ -37,7 +37,6 @@ from datachain.storage import StorageURI
 if TYPE_CHECKING:
     from fsspec.spec import AbstractFileSystem

-    from datachain.data_storage import AbstractMetastore

 logger = logging.getLogger("datachain")

@@ -116,13 +115,12 @@ class Client(ABC):
     @staticmethod
     def parse_url(
         source: str,
-        metastore: "AbstractMetastore",
         cache: DataChainCache,
         **kwargs,
     ) -> tuple["Client", str]:
         cls = Client.get_implementation(source)
         storage_url, rel_path = cls.split_url(source)
-        client = cls.from_name(storage_url, metastore, cache, kwargs)
+        client = cls.from_name(storage_url, cache, kwargs)
         return client, rel_path

     @classmethod
@@ -136,7 +134,6 @@ class Client(ABC):
     def from_name(
         cls,
         name: str,
-        metastore: "AbstractMetastore",
         cache: DataChainCache,
         kwargs: dict[str, Any],
     ) -> "Client":
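
With the `metastore` parameter gone from `Client.parse_url` and `Client.from_name`, resolving a URI to a client only needs the cache. A hedged sketch of the new call shape (the wrapper function is ours; `cache` is assumed to be an existing `DataChainCache` instance):

```py
from datachain.client import Client


def resolve(source: str, cache):
    # parse_url picks the right Client implementation for the scheme,
    # builds it from the storage URL and cache, and returns the client
    # together with the path relative to the storage root.
    client, rel_path = Client.parse_url(source, cache)
    return client, rel_path
```
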
datachain/client/local.py
CHANGED

@@ -2,7 +2,7 @@ import os
 import posixpath
 from datetime import datetime, timezone
 from pathlib import Path
-from typing import TYPE_CHECKING, Any
+from typing import Any
 from urllib.parse import urlparse

 from fsspec.implementations.local import LocalFileSystem
@@ -12,9 +12,6 @@ from datachain.storage import StorageURI

 from .fsspec import Client

-if TYPE_CHECKING:
-    from datachain.data_storage import AbstractMetastore
-

 class FileClient(Client):
     FS_CLASS = LocalFileSystem
@@ -97,9 +94,7 @@ class FileClient(Client):
         return cls.root_dir(), uri.removeprefix(cls.root_path().as_uri())

     @classmethod
-    def from_name(
-        cls, name: str, metastore: "AbstractMetastore", cache, kwargs
-    ) -> "FileClient":
+    def from_name(cls, name: str, cache, kwargs) -> "FileClient":
         use_symlinks = kwargs.pop("use_symlinks", False)
         return cls(name, kwargs, cache, use_symlinks=use_symlinks)

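
`FileClient.from_name` drops the metastore as well; the only client-specific option it still handles is `use_symlinks`, popped from the kwargs dict. A small hedged sketch (the URL and cache are illustrative):

```py
from datachain.client.local import FileClient


def make_local_client(cache) -> FileClient:
    # use_symlinks is consumed by from_name; remaining kwargs pass through.
    return FileClient.from_name("file:///data", cache, {"use_symlinks": True})
```
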
datachain/data_storage/schema.py
CHANGED

@@ -67,7 +67,11 @@ def convert_rows_custom_column_types(
     for row in rows:
         row_list = list(row)
         for idx, t in custom_columns_types:
-            row_list[idx] = t.on_read_convert(row_list[idx], dialect)
+            row_list[idx] = (
+                t.default_value(dialect)
+                if row_list[idx] is None
+                else t.on_read_convert(row_list[idx], dialect)
+            )

         yield tuple(row_list)

@@ -136,7 +140,15 @@ class DataTable:
         self.column_types: dict[str, SQLType] = column_types or {}

     @staticmethod
-    def copy_column(column: sa.Column) -> sa.Column:
+    def copy_column(
+        column: sa.Column,
+        primary_key: Optional[bool] = None,
+        index: Optional[bool] = None,
+        nullable: Optional[bool] = None,
+        default: Optional[Any] = None,
+        server_default: Optional[Any] = None,
+        unique: Optional[bool] = None,
+    ) -> sa.Column:
         """
         Copy a sqlalchemy Column object intended for use as a signal column.

@@ -150,12 +162,14 @@
         return sa.Column(
             column.name,
             column.type,
-            primary_key=column.primary_key,
-            index=column.index,
-            nullable=column.nullable,
-            default=column.default,
-            server_default=column.server_default,
-            unique=column.unique,
+            primary_key=primary_key if primary_key is not None else column.primary_key,
+            index=index if index is not None else column.index,
+            nullable=nullable if nullable is not None else column.nullable,
+            default=default if default is not None else column.default,
+            server_default=(
+                server_default if server_default is not None else column.server_default
+            ),
+            unique=unique if unique is not None else column.unique,
         )

     @classmethod
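
The reworked `copy_column` keeps every attribute of the source column by default but lets the caller override any of them per copy. A short example of the new keyword overrides (the column itself is a made-up example, not a datachain signal):

```py
import sqlalchemy as sa

from datachain.data_storage.schema import DataTable

src = sa.Column("score", sa.Float, nullable=True, index=True)

# Same name and type, but force NOT NULL and drop the index on the copy.
copy = DataTable.copy_column(src, nullable=False, index=False)
assert copy.name == "score"
assert copy.nullable is False
assert copy.index is False
```
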
datachain/data_storage/sqlite.py
CHANGED

@@ -122,6 +122,11 @@ class SQLiteDatabaseEngine(DatabaseEngine):
         engine = sqlalchemy.create_engine(
             "sqlite+pysqlite:///", creator=lambda: db, future=True
         )
+        # ensure we run SA on_connect init (e.g. it registers regexp function),
+        # also makes sure that it's consistent. Otherwise in some cases it
+        # seems we are getting different results if engine object is used in a
+        # different thread first and engine is not used in the main thread.
+        engine.connect().close()

         db.isolation_level = None  # Use autocommit mode
         db.execute("PRAGMA foreign_keys = ON")

datachain/data_storage/warehouse.py
CHANGED

@@ -17,7 +17,7 @@ from sqlalchemy.sql.expression import true

 from datachain.client import Client
 from datachain.data_storage.serializer import Serializable
-from datachain.dataset import DatasetRecord
+from datachain.dataset import DatasetRecord
 from datachain.node import DirType, DirTypeGroup, Entry, Node, NodeWithPath, get_path
 from datachain.sql.functions import path as pathfunc
 from datachain.sql.types import Int, SQLType
@@ -201,23 +201,17 @@ class AbstractWarehouse(ABC, Serializable):
     def dataset_select_paginated(
         self,
         query,
-        limit: Optional[int] = None,
-        order_by: tuple["ColumnElement[Any]", ...] = (),
         page_size: int = SELECT_BATCH_SIZE,
-    ) -> Generator[
+    ) -> Generator[Sequence, None, None]:
         """
         This is equivalent to `db.execute`, but for selecting rows in batches
         """
-
-
+        limit = query._limit
+        paginated_query = query.limit(page_size)

-        if not
-
-
-            ordering = order_by  # type: ignore[assignment]
-
-        # reset query order by and apply new order by id
-        paginated_query = query.order_by(None).order_by(*ordering).limit(page_size)
+        if not paginated_query._order_by_clauses:
+            # default order by is order by `sys__id`
+            paginated_query = paginated_query.order_by(query.selected_columns.sys__id)

         results = None
         offset = 0
@@ -236,7 +230,7 @@ class AbstractWarehouse(ABC, Serializable):
             processed = False
             for row in results:
                 processed = True
-                yield
+                yield row
                 num_yielded += 1

             if not processed:
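
`dataset_select_paginated` no longer accepts `limit`/`order_by` arguments: it reads the limit from the query itself and, when no ordering is set, falls back to ordering by `sys__id` so pages stay stable across requests. A standalone sketch of that pagination loop in plain SQLAlchemy (connection handling and page size are illustrative; datachain drives this through its warehouse engine):

```py
import sqlalchemy as sa


def select_paginated(conn, query: sa.Select, page_size: int = 2000):
    # Take the overall limit from the query and make the ordering explicit,
    # defaulting to sys__id, mirroring the new warehouse behaviour.
    limit = query._limit
    paginated = query.limit(page_size)
    if not paginated._order_by_clauses:
        paginated = paginated.order_by(query.selected_columns.sys__id)

    offset = 0
    yielded = 0
    while True:
        page = paginated
        if limit is not None:
            remaining = limit - yielded
            if remaining <= 0:
                break
            page = paginated.limit(min(page_size, remaining))
        rows = conn.execute(page.offset(offset)).fetchall()
        if not rows:
            break
        yield from rows
        yielded += len(rows)
        offset += page_size
```
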
datachain/lib/dc.py
CHANGED

@@ -508,7 +508,7 @@ class DataChain(DatasetQuery):

     def print_json_schema(  # type: ignore[override]
         self, jmespath: Optional[str] = None, model_name: Optional[str] = None
-    ) -> "DataChain":
+    ) -> "Self":
         """Print JSON data model and save it. It returns the chain itself.

         Parameters:
@@ -533,7 +533,7 @@ class DataChain(DatasetQuery):

     def print_jsonl_schema(  # type: ignore[override]
         self, jmespath: Optional[str] = None, model_name: Optional[str] = None
-    ) -> "DataChain":
+    ) -> "Self":
         """Print JSON data model and save it. It returns the chain itself.

         Parameters:
@@ -549,7 +549,7 @@ class DataChain(DatasetQuery):

     def save(  # type: ignore[override]
         self, name: Optional[str] = None, version: Optional[int] = None
-    ) -> "DataChain":
+    ) -> "Self":
         """Save to a Dataset. It returns the chain itself.

         Parameters:
@@ -785,7 +785,7 @@ class DataChain(DatasetQuery):
             descending (bool): Whether to sort in descending order or not.
         """
         if descending:
-            args = tuple(
+            args = tuple(sqlalchemy.desc(a) for a in args)

         return super().order_by(*args)

@@ -1206,14 +1206,14 @@ class DataChain(DatasetQuery):
         """
         headers, max_length = self._effective_signals_schema.get_headers_with_length()
         if flatten or max_length < 2:
-
+            columns = []
             if headers:
-
-                return
+                columns = [".".join(filter(None, header)) for header in headers]
+            return pd.DataFrame.from_records(self.to_records(), columns=columns)

-
-
-
+        return pd.DataFrame(
+            self.results(), columns=pd.MultiIndex.from_tuples(map(tuple, headers))
+        )

     def show(
         self,
@@ -1232,6 +1232,12 @@ class DataChain(DatasetQuery):
         """
         dc = self.limit(limit) if limit > 0 else self
         df = dc.to_pandas(flatten)
+
+        if df.empty:
+            print("Empty result")
+            print(f"Columns: {list(df.columns)}")
+            return
+
         if transpose:
             df = df.T

@@ -1270,7 +1276,7 @@ class DataChain(DatasetQuery):
         source: bool = True,
         nrows: Optional[int] = None,
         **kwargs,
-    ) -> "DataChain":
+    ) -> "Self":
         """Generate chain from list of tabular files.

         Parameters:
@@ -1390,7 +1396,8 @@ class DataChain(DatasetQuery):
             dc = DataChain.from_csv("s3://mybucket/dir")
             ```
         """
-        from pyarrow.csv import ParseOptions, ReadOptions
+        from pandas.io.parsers.readers import STR_NA_VALUES
+        from pyarrow.csv import ConvertOptions, ParseOptions, ReadOptions
         from pyarrow.dataset import CsvFileFormat

         chain = DataChain.from_storage(path, **kwargs)
@@ -1414,7 +1421,14 @@ class DataChain(DatasetQuery):

         parse_options = ParseOptions(delimiter=delimiter)
         read_options = ReadOptions(column_names=column_names)
-        format = CsvFileFormat(parse_options=parse_options, read_options=read_options)
+        convert_options = ConvertOptions(
+            strings_can_be_null=True, null_values=STR_NA_VALUES
+        )
+        format = CsvFileFormat(
+            parse_options=parse_options,
+            read_options=read_options,
+            convert_options=convert_options,
+        )
         return chain.parse_tabular(
             output=output,
             object_name=object_name,
@@ -1623,7 +1637,7 @@ class DataChain(DatasetQuery):

         Using glob to match patterns
         ```py
-        dc.filter(C("file.name").glob("*.jpg))
+        dc.filter(C("file.name").glob("*.jpg"))
         ```

         Using `datachain.sql.functions`
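
The `from_csv` change wires pandas' default NA markers into pyarrow, so empty strings and tokens like `NA` or `null` are read back as nulls rather than literal strings. A hedged sketch of the same pyarrow configuration on its own (the file path is illustrative):

```py
from pandas.io.parsers.readers import STR_NA_VALUES
from pyarrow import csv
from pyarrow.dataset import CsvFileFormat, dataset

convert_options = csv.ConvertOptions(
    strings_can_be_null=True, null_values=STR_NA_VALUES
)
fmt = CsvFileFormat(convert_options=convert_options)

# Reading a CSV dataset with this format maps "", "NA", "null", ... to nulls.
ds = dataset("data/example.csv", format=fmt)
```
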
datachain/lib/meta_formats.py
CHANGED

@@ -11,12 +11,16 @@ from collections.abc import Iterator
 from typing import Any, Callable

 import jmespath as jsp
-from pydantic import Field, ValidationError  # noqa: F401
+from pydantic import BaseModel, ConfigDict, Field, ValidationError  # noqa: F401

 from datachain.lib.data_model import DataModel  # noqa: F401
 from datachain.lib.file import File


+class UserModel(BaseModel):
+    model_config = ConfigDict(populate_by_name=True)
+
+
 def generate_uuid():
     return uuid.uuid4()  # Generates a random UUID.

@@ -72,6 +76,8 @@ def read_schema(source_file, data_type="csv", expr=None, model_name=None):
         data_type,
         "--class-name",
         model_name,
+        "--base-class",
+        "datachain.lib.meta_formats.UserModel",
     ]
     try:
         result = subprocess.run(  # noqa: S603
@@ -87,7 +93,7 @@ def read_schema(source_file, data_type="csv", expr=None, model_name=None):
     except subprocess.CalledProcessError as e:
         model_output = f"An error occurred in datamodel-codegen: {e.stderr}"
     print(f"{model_output}")
-    print("
+    print("from datachain.lib.data_model import DataModel")
     print("\n" + f"DataModel.register({model_name})" + "\n")
     print("\n" + f"spec={model_name}" + "\n")
     return model_output
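
Generated models now inherit from `UserModel`, and its `populate_by_name=True` config means fields can be filled either by their Python name or by their alias, which matters when JSON/CSV headers are not valid identifiers. A small pydantic example of the effect (the `Person` model is hypothetical, not generated code):

```py
from pydantic import BaseModel, ConfigDict, Field


class UserModel(BaseModel):
    model_config = ConfigDict(populate_by_name=True)


class Person(UserModel):
    first_name: str = Field(alias="first-name")


# Both spellings are accepted thanks to populate_by_name=True.
print(Person(**{"first-name": "Ada"}))
print(Person(first_name="Ada"))
```
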
datachain/lib/udf.py
CHANGED

@@ -1,6 +1,5 @@
 import sys
 import traceback
-from collections.abc import Iterable, Iterator
 from typing import TYPE_CHECKING, Callable, Optional

 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
@@ -14,16 +13,19 @@ from datachain.lib.model_store import ModelStore
 from datachain.lib.signal_schema import SignalSchema
 from datachain.lib.udf_signature import UdfSignature
 from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
-from datachain.query.batch import
+from datachain.query.batch import UDFInputBatch
 from datachain.query.schema import ColumnParameter
 from datachain.query.udf import UDFBase as _UDFBase
-from datachain.query.udf import UDFProperties, UDFResult
+from datachain.query.udf import UDFProperties

 if TYPE_CHECKING:
+    from collections.abc import Iterable, Iterator, Sequence
+
     from typing_extensions import Self

     from datachain.catalog import Catalog
-    from datachain.query.batch import
+    from datachain.query.batch import RowsOutput, UDFInput
+    from datachain.query.udf import UDFResult


 class UdfError(DataChainParamsError):
@@ -42,22 +44,27 @@ class UDFAdapter(_UDFBase):

     def run(
         self,
-
+        udf_fields: "Sequence[str]",
+        udf_inputs: "Iterable[RowsOutput]",
         catalog: "Catalog",
         is_generator: bool,
         cache: bool,
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
-    ) -> Iterator[Iterable[UDFResult]]:
+    ) -> "Iterator[Iterable[UDFResult]]":
         self.inner._catalog = catalog
         if hasattr(self.inner, "setup") and callable(self.inner.setup):
             self.inner.setup()

-
-
-
-
-
+        yield from super().run(
+            udf_fields,
+            udf_inputs,
+            catalog,
+            is_generator,
+            cache,
+            download_cb,
+            processed_cb,
+        )

         if hasattr(self.inner, "teardown") and callable(self.inner.teardown):
             self.inner.teardown()
@@ -65,12 +72,12 @@ class UDFAdapter(_UDFBase):
     def run_once(
         self,
         catalog: "Catalog",
-        arg: "
+        arg: "UDFInput",
         is_generator: bool = False,
         cache: bool = False,
         cb: Callback = DEFAULT_CALLBACK,
-    ) -> Iterable[UDFResult]:
-        if isinstance(arg,
+    ) -> "Iterable[UDFResult]":
+        if isinstance(arg, UDFInputBatch):
             udf_inputs = [
                 self.bind_parameters(catalog, row, cache=cache, cb=cb)
                 for row in arg.rows
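
Since `UDFAdapter.run` is itself a generator, switching its body to `yield from super().run(...)` means `setup()` runs when iteration starts and `teardown()` only after the parent generator is fully drained. A toy illustration of that ordering (generic names, not datachain's API):

```py
from collections.abc import Iterable, Iterator


def run(inputs: Iterable[int]) -> Iterator[int]:
    print("setup")                      # runs on the first next(), not at call time
    yield from (x * 2 for x in inputs)  # stand-in for super().run(...)
    print("teardown")                   # runs only after the stream is exhausted


print(list(run([1, 2, 3])))  # setup ... teardown ... [2, 4, 6]
```
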
datachain/node.py
CHANGED
datachain/query/batch.py
CHANGED

@@ -5,21 +5,29 @@ from collections.abc import Generator, Sequence
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Callable, Optional, Union

-import sqlalchemy as sa
-
 from datachain.data_storage.schema import PARTITION_COLUMN_ID
 from datachain.data_storage.warehouse import SELECT_BATCH_SIZE

 if TYPE_CHECKING:
+    from sqlalchemy import Select
+
     from datachain.dataset import RowDict


 @dataclass
-class
+class RowsOutputBatch:
+    rows: Sequence[Sequence]
+
+
+RowsOutput = Union[Sequence, RowsOutputBatch]
+
+
+@dataclass
+class UDFInputBatch:
     rows: Sequence["RowDict"]


-
+UDFInput = Union["RowDict", UDFInputBatch]


 class BatchingStrategy(ABC):
@@ -28,9 +36,9 @@ class BatchingStrategy(ABC):
     @abstractmethod
     def __call__(
         self,
-        execute: Callable,
-        query:
-    ) -> Generator[
+        execute: Callable[..., Generator[Sequence, None, None]],
+        query: "Select",
+    ) -> Generator[RowsOutput, None, None]:
         """Apply the provided parameters to the UDF."""


@@ -42,10 +50,10 @@ class NoBatching(BatchingStrategy):

     def __call__(
         self,
-        execute: Callable,
-        query:
-    ) -> Generator[
-        return execute(query
+        execute: Callable[..., Generator[Sequence, None, None]],
+        query: "Select",
+    ) -> Generator[Sequence, None, None]:
+        return execute(query)


 class Batch(BatchingStrategy):
@@ -59,31 +67,24 @@ class Batch(BatchingStrategy):

     def __call__(
         self,
-        execute: Callable,
-        query:
-    ) -> Generator[
+        execute: Callable[..., Generator[Sequence, None, None]],
+        query: "Select",
+    ) -> Generator[RowsOutputBatch, None, None]:
         # choose page size that is a multiple of the batch size
         page_size = math.ceil(SELECT_BATCH_SIZE / self.count) * self.count

         # select rows in batches
-        results: list[
-
-        with contextlib.closing(
-            execute(
-                query,
-                page_size=page_size,
-                limit=query._limit,
-                order_by=query._order_by_clauses,
-            )
-        ) as rows:
+        results: list[Sequence] = []
+
+        with contextlib.closing(execute(query, page_size=page_size)) as rows:
             for row in rows:
                 results.append(row)
                 if len(results) >= self.count:
                     batch, results = results[: self.count], results[self.count :]
-                    yield
+                    yield RowsOutputBatch(batch)

         if len(results) > 0:
-            yield
+            yield RowsOutputBatch(results)


 class Partition(BatchingStrategy):
@@ -95,27 +96,30 @@ class Partition(BatchingStrategy):

     def __call__(
         self,
-        execute: Callable,
-        query:
-    ) -> Generator[
+        execute: Callable[..., Generator[Sequence, None, None]],
+        query: "Select",
+    ) -> Generator[RowsOutputBatch, None, None]:
         current_partition: Optional[int] = None
-        batch: list[
-
-
-
-
-
-
-
-
+        batch: list[Sequence] = []
+
+        query_fields = [str(c.name) for c in query.selected_columns]
+        partition_column_idx = query_fields.index(PARTITION_COLUMN_ID)
+
+        ordered_query = query.order_by(None).order_by(
+            PARTITION_COLUMN_ID,
+            "sys__id",
+            *query._order_by_clauses,
+        )
+
+        with contextlib.closing(execute(ordered_query)) as rows:
             for row in rows:
-                partition = row[
+                partition = row[partition_column_idx]
                 if current_partition != partition:
                     current_partition = partition
                     if len(batch) > 0:
-                        yield
+                        yield RowsOutputBatch(batch)
                     batch = []
                 batch.append(row)

         if len(batch) > 0:
-            yield
+            yield RowsOutputBatch(batch)
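
`Batch.__call__` now pages through the query itself and groups plain row tuples into `RowsOutputBatch` objects of a fixed size, flushing the remainder at the end. The accumulation logic in isolation (a self-contained sketch; in datachain the rows come from `execute(query, page_size=...)`):

```py
from collections.abc import Iterable, Iterator, Sequence
from dataclasses import dataclass


@dataclass
class RowsOutputBatch:
    rows: list


def batched(rows: Iterable[Sequence], count: int) -> Iterator[RowsOutputBatch]:
    pending: list = []
    for row in rows:
        pending.append(row)
        if len(pending) >= count:
            batch, pending = pending[:count], pending[count:]
            yield RowsOutputBatch(batch)
    if pending:
        yield RowsOutputBatch(pending)


# Three full batches of two rows and one trailing single-row batch.
print([b.rows for b in batched([(i,) for i in range(7)], count=2)])
```
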
datachain/query/dataset.py
CHANGED

@@ -461,6 +461,8 @@ class UDFStep(Step, ABC):

         processes = determine_processes(self.parallel)

+        udf_fields = [str(c.name) for c in query.selected_columns]
+
         try:
             if workers:
                 from datachain.catalog.loader import get_distributed_class
@@ -473,6 +475,7 @@
                     query,
                     workers,
                     processes,
+                    udf_fields=udf_fields,
                     is_generator=self.is_generator,
                     use_partitioning=use_partitioning,
                     cache=self.cache,
@@ -489,6 +492,7 @@
                 "warehouse_clone_params": self.catalog.warehouse.clone_params(),
                 "table": udf_table,
                 "query": query,
+                "udf_fields": udf_fields,
                 "batching": batching,
                 "processes": processes,
                 "is_generator": self.is_generator,
@@ -528,6 +532,7 @@
         generated_cb = get_generated_callback(self.is_generator)
         try:
             udf_results = udf.run(
+                udf_fields,
                 udf_inputs,
                 self.catalog,
                 self.is_generator,
@@ -1244,21 +1249,23 @@ class DatasetQuery:
         actual_params = [normalize_param(p) for p in params]
         try:
             query = self.apply_steps().select()
+            query_fields = [str(c.name) for c in query.selected_columns]

-            def row_iter() -> Generator[
+            def row_iter() -> Generator[Sequence, None, None]:
                 # warehouse isn't threadsafe, we need to clone() it
                 # in the thread that uses the results
                 with self.catalog.warehouse.clone() as warehouse:
-                    gen = warehouse.dataset_select_paginated(
-                        query, limit=query._limit, order_by=query._order_by_clauses
-                    )
+                    gen = warehouse.dataset_select_paginated(query)
                     with contextlib.closing(gen) as rows:
                         yield from rows

-            async def get_params(row:
+            async def get_params(row: Sequence) -> tuple:
+                row_dict = RowDict(zip(query_fields, row))
                 return tuple(
                     [
-                        await p.get_value_async(
+                        await p.get_value_async(
+                            self.catalog, row_dict, mapper, **kwargs
+                        )
                         for p in actual_params
                     ]
                 )