datachain 0.5.0__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datachain/catalog/catalog.py +8 -0
- datachain/data_storage/metastore.py +20 -1
- datachain/data_storage/sqlite.py +24 -32
- datachain/lib/arrow.py +64 -19
- datachain/lib/dc.py +113 -10
- datachain/lib/udf.py +100 -78
- datachain/lib/udf_signature.py +8 -6
- datachain/query/dataset.py +6 -6
- datachain/query/dispatch.py +2 -2
- datachain/query/session.py +42 -0
- {datachain-0.5.0.dist-info → datachain-0.5.1.dist-info}/METADATA +1 -1
- {datachain-0.5.0.dist-info → datachain-0.5.1.dist-info}/RECORD +16 -17
- datachain/query/udf.py +0 -126
- {datachain-0.5.0.dist-info → datachain-0.5.1.dist-info}/LICENSE +0 -0
- {datachain-0.5.0.dist-info → datachain-0.5.1.dist-info}/WHEEL +0 -0
- {datachain-0.5.0.dist-info → datachain-0.5.1.dist-info}/entry_points.txt +0 -0
- {datachain-0.5.0.dist-info → datachain-0.5.1.dist-info}/top_level.txt +0 -0
datachain/catalog/catalog.py
CHANGED
@@ -988,6 +988,14 @@ class Catalog:
         schema = {
             c.name: c.type.to_dict() for c in columns if isinstance(c.type, SQLType)
         }
+
+        job_id = job_id or os.getenv("DATACHAIN_JOB_ID")
+        if not job_id:
+            from datachain.query.session import Session
+
+            session = Session.get(catalog=self)
+            job_id = session.job_id
+
         dataset = self.metastore.create_dataset_version(
             dataset,
             version,
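Note on the catalog change above: an explicit job_id argument takes precedence, then the DATACHAIN_JOB_ID environment variable, then the session's job id (which, per the session.py change later in this diff, is itself seeded from that same variable or a fresh UUID). A standalone sketch of that precedence; resolve_job_id and the placeholder session id are illustrative only, not DataChain API:

    import os

    def resolve_job_id(job_id=None, session_job_id="0f3e"):  # hypothetical helper
        # Same precedence as the new code in create_dataset_version above.
        return job_id or os.getenv("DATACHAIN_JOB_ID") or session_job_id

    os.environ["DATACHAIN_JOB_ID"] = "job-from-env"
    assert resolve_job_id("explicit") == "explicit"
    assert resolve_job_id() == "job-from-env"
    del os.environ["DATACHAIN_JOB_ID"]
    assert resolve_job_id() == "0f3e"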
datachain/data_storage/metastore.py
CHANGED
@@ -50,7 +50,6 @@ if TYPE_CHECKING:
     from datachain.data_storage import AbstractIDGenerator, schema
     from datachain.data_storage.db_engine import DatabaseEngine
 
-
 logger = logging.getLogger("datachain")
 
 
@@ -384,6 +383,11 @@ class AbstractMetastore(ABC, Serializable):
     ) -> None:
         """Set the status of the given job and dataset."""
 
+    @abstractmethod
+    def get_job_dataset_versions(self, job_id: str) -> list[tuple[str, int]]:
+        """Returns dataset names and versions for the job."""
+        raise NotImplementedError
+
 
 class AbstractDBMetastore(AbstractMetastore):
     """
@@ -1519,3 +1523,18 @@ class AbstractDBMetastore(AbstractMetastore):
             .values(status=dataset_status)
         )
         self.db.execute(query, conn=conn)  # type: ignore[attr-defined]
+
+    def get_job_dataset_versions(self, job_id: str) -> list[tuple[str, int]]:
+        """Returns dataset names and versions for the job."""
+        dv = self._datasets_versions
+        ds = self._datasets
+
+        join_condition = dv.c.dataset_id == ds.c.id
+
+        query = (
+            self._datasets_versions_select(ds.c.name, dv.c.version)
+            .select_from(dv.join(ds, join_condition))
+            .where(dv.c.job_id == job_id)
+        )
+
+        return list(self.db.execute(query))
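For context on the new metastore query: it joins the dataset-versions table to the datasets table on dataset_id and filters by job_id. A self-contained SQLAlchemy Core sketch of the same query shape, with trimmed stand-in tables (the real tables carry more columns):

    import sqlalchemy as sa

    metadata = sa.MetaData()
    datasets = sa.Table(
        "datasets", metadata,
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column("name", sa.Text),
    )
    datasets_versions = sa.Table(
        "datasets_versions", metadata,
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column("dataset_id", sa.Integer, sa.ForeignKey("datasets.id")),
        sa.Column("version", sa.Integer),
        sa.Column("job_id", sa.Text),
    )

    job_id = "some-job"
    query = (
        sa.select(datasets.c.name, datasets_versions.c.version)
        .select_from(
            datasets_versions.join(
                datasets, datasets_versions.c.dataset_id == datasets.c.id
            )
        )
        .where(datasets_versions.c.job_id == job_id)
    )
    # Executing `query` yields (name, version) tuples, e.g. [("my_dataset", 1)].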
datachain/data_storage/sqlite.py
CHANGED
@@ -15,6 +15,7 @@ from typing import (
 )
 
 import sqlalchemy
+from packaging import version
 from sqlalchemy import MetaData, Table, UniqueConstraint, exists, select
 from sqlalchemy.dialects import sqlite
 from sqlalchemy.schema import CreateIndex, CreateTable, DropTable
@@ -153,7 +154,7 @@ class SQLiteDatabaseEngine(DatabaseEngine):
         if os.environ.get("DEBUG_SHOW_SQL_QUERIES"):
             import sys
 
-            db.set_trace_callback(sys.stderr.write)
+            db.set_trace_callback(lambda stmt: print(stmt, file=sys.stderr))
 
         load_usearch_extension(db)
 
@@ -345,45 +346,36 @@ class SQLiteIDGenerator(AbstractDBIDGenerator):
     def get_next_ids(self, uri: str, count: int) -> range:
         """Returns a range of IDs for the given URI."""
 
-
-
-
-
-        # leaving fallback to the current implementation for older versions of SQLite,
-        # which is still supported, for example, in Ubuntu 20.04 LTS (Focal Fossa),
-        # where SQLite version 3.31.1 is used.
-
-        # sqlite_version = version.parse(sqlite3.sqlite_version)
-        # if sqlite_version >= version.parse("3.35.0"):
-        #     # RETURNING is supported on SQLite 3.35.0 (2021-03-12) or newer
-        #     stmt = (
-        #         sqlite.insert(self._table)
-        #         .values(uri=uri, last_id=count)
-        #         .on_conflict_do_update(
-        #             index_elements=["uri"],
-        #             set_={"last_id": self._table.c.last_id + count},
-        #         )
-        #         .returning(self._table.c.last_id)
-        #     )
-        #     last_id = self._db.execute(stmt).fetchone()[0]
-        # else:
-        #     (fallback to the current implementation with a transaction)
-
-        # Transactions ensure no concurrency conflicts
-        with self._db.transaction() as conn:
-            # UPSERT syntax was added to SQLite with version 3.24.0 (2018-06-04).
-            stmt_ins = (
+        sqlite_version = version.parse(sqlite3.sqlite_version)
+        is_returning_supported = sqlite_version >= version.parse("3.35.0")
+        if is_returning_supported:
+            stmt = (
                 sqlite.insert(self._table)
                 .values(uri=uri, last_id=count)
                 .on_conflict_do_update(
                     index_elements=["uri"],
                     set_={"last_id": self._table.c.last_id + count},
                 )
+                .returning(self._table.c.last_id)
             )
-            self._db.execute(
+            last_id = self._db.execute(stmt).fetchone()[0]
+        else:
+            # Older versions of SQLite are still the default under Ubuntu LTS,
+            # e.g. Ubuntu 20.04 LTS (Focal Fossa) uses 3.31.1
+            # Transactions ensure no concurrency conflicts
+            with self._db.transaction() as conn:
+                stmt_ins = (
+                    sqlite.insert(self._table)
+                    .values(uri=uri, last_id=count)
+                    .on_conflict_do_update(
+                        index_elements=["uri"],
+                        set_={"last_id": self._table.c.last_id + count},
+                    )
+                )
+                self._db.execute(stmt_ins, conn=conn)
 
-
-
+                stmt_sel = select(self._table.c.last_id).where(self._table.c.uri == uri)
+                last_id = self._db.execute(stmt_sel, conn=conn).fetchone()[0]
 
         return range(last_id - count + 1, last_id + 1)
 
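The rewritten get_next_ids can be demonstrated outside DataChain with the stdlib sqlite3 driver. A minimal sketch under made-up table and URI names: on SQLite 3.35.0+ a single UPSERT ... RETURNING statement reserves the ID range in one step; older versions (e.g. the 3.31.1 shipped with Ubuntu 20.04) fall back to an upsert plus a SELECT inside a transaction:

    import sqlite3

    from packaging import version

    supports_returning = version.parse(sqlite3.sqlite_version) >= version.parse("3.35.0")

    con = sqlite3.connect(":memory:")
    con.execute("CREATE TABLE ids (uri TEXT PRIMARY KEY, last_id INTEGER)")

    def get_next_ids(uri: str, count: int) -> range:
        if supports_returning:
            (last_id,) = con.execute(
                "INSERT INTO ids (uri, last_id) VALUES (?, ?) "
                "ON CONFLICT (uri) DO UPDATE SET last_id = last_id + ? "
                "RETURNING last_id",
                (uri, count, count),
            ).fetchone()
        else:
            with con:  # transaction avoids torn updates under concurrency
                con.execute(
                    "INSERT INTO ids (uri, last_id) VALUES (?, ?) "
                    "ON CONFLICT (uri) DO UPDATE SET last_id = last_id + ?",
                    (uri, count, count),
                )
                (last_id,) = con.execute(
                    "SELECT last_id FROM ids WHERE uri = ?", (uri,)
                ).fetchone()
        return range(last_id - count + 1, last_id + 1)

    assert list(get_next_ids("s3://bucket", 3)) == [1, 2, 3]
    assert list(get_next_ids("s3://bucket", 2)) == [4, 5]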
datachain/lib/arrow.py
CHANGED
@@ -1,8 +1,9 @@
 import re
 from collections.abc import Sequence
 from tempfile import NamedTemporaryFile
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
+import orjson
 import pyarrow as pa
 from pyarrow.dataset import CsvFileFormat, dataset
 from tqdm import tqdm
@@ -10,6 +11,7 @@ from tqdm import tqdm
 from datachain.lib.data_model import dict_to_data_model
 from datachain.lib.file import ArrowRow, File
 from datachain.lib.model_store import ModelStore
+from datachain.lib.signal_schema import SignalSchema
 from datachain.lib.udf import Generator
 
 if TYPE_CHECKING:
@@ -20,6 +22,9 @@ if TYPE_CHECKING:
     from datachain.lib.dc import DataChain
 
 
+DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY = b"DataChain SignalSchema"
+
+
 class ArrowGenerator(Generator):
     def __init__(
         self,
@@ -61,28 +66,35 @@ class ArrowGenerator(Generator):
             path, filesystem=file.get_fs(), schema=self.input_schema, **self.kwargs
         )
         hf_schema = _get_hf_schema(ds.schema)
+        use_datachain_schema = (
+            bool(ds.schema.metadata)
+            and DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY in ds.schema.metadata
+        )
        index = 0
         with tqdm(desc="Parsed by pyarrow", unit=" rows") as pbar:
             for record_batch in ds.to_batches():
                 for record in record_batch.to_pylist():
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                    if use_datachain_schema and self.output_schema:
+                        vals = [_nested_model_instantiate(record, self.output_schema)]
+                    else:
+                        vals = list(record.values())
+                        if self.output_schema:
+                            fields = self.output_schema.model_fields
+                            vals_dict = {}
+                            for i, ((field, field_info), val) in enumerate(
+                                zip(fields.items(), vals)
+                            ):
+                                anno = field_info.annotation
+                                if hf_schema:
+                                    from datachain.lib.hf import convert_feature
+
+                                    feat = list(hf_schema[0].values())[i]
+                                    vals_dict[field] = convert_feature(val, feat, anno)
+                                elif ModelStore.is_pydantic(anno):
+                                    vals_dict[field] = anno(**val)  # type: ignore[misc]
+                                else:
+                                    vals_dict[field] = val
+                            vals = [self.output_schema(**vals_dict)]
                     if self.source:
                         kwargs: dict = self.kwargs
                         # Can't serialize CsvFileFormat; may lose formatting options.
@@ -113,6 +125,9 @@ def schema_to_output(schema: pa.Schema, col_names: Optional[Sequence[str]] = Non
     )
     if not col_names:
         col_names = schema.names
+    signal_schema = _get_datachain_schema(schema)
+    if signal_schema:
+        return signal_schema.values
     columns = _convert_col_names(col_names)  # type: ignore[arg-type]
     hf_schema = _get_hf_schema(schema)
     if hf_schema:
@@ -197,3 +212,33 @@ def _get_hf_schema(
         features = schema_from_arrow(schema)
         return features, get_output_schema(features)
     return None
+
+
+def _get_datachain_schema(schema: "pa.Schema") -> Optional[SignalSchema]:
+    """Return a restored SignalSchema from parquet metadata, if any is found."""
+    if schema.metadata and DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY in schema.metadata:
+        serialized_signal_schema = orjson.loads(
+            schema.metadata[DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY]
+        )
+        return SignalSchema.deserialize(serialized_signal_schema)
+    return None
+
+
+def _nested_model_instantiate(
+    column_values: dict[str, Any], model: type["BaseModel"], prefix: str = ""
+) -> "BaseModel":
+    """Instantiate the given model and all sub-models/fields based on the provided
+    column values."""
+    vals_dict = {}
+    for field, field_info in model.model_fields.items():
+        anno = field_info.annotation
+        cur_path = f"{prefix}.{field}" if prefix else field
+        if ModelStore.is_pydantic(anno):
+            vals_dict[field] = _nested_model_instantiate(
+                column_values,
+                anno,  # type: ignore[arg-type]
+                prefix=cur_path,
+            )
+        elif cur_path in column_values:
+            vals_dict[field] = column_values[cur_path]
+    return model(**vals_dict)
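The new DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY mechanism depends on Arrow schema metadata surviving a parquet write/read round trip. A minimal sketch of that round trip (the dict stands in for the real SignalSchema.serialize() output):

    import orjson
    import pyarrow as pa
    import pyarrow.parquet as pq

    KEY = b"DataChain SignalSchema"
    serialized = {"name": "str", "size": "int"}  # stand-in for SignalSchema.serialize()

    table = pa.table({"name": ["a"], "size": [1]})
    metadata = {**(table.schema.metadata or {}), KEY: orjson.dumps(serialized)}
    table = table.replace_schema_metadata(metadata)
    pq.write_table(table, "example.parquet")

    # Later, schema_to_output() can detect and restore the schema from metadata.
    schema = pq.read_schema("example.parquet")
    assert orjson.loads(schema.metadata[KEY]) == serialized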
datachain/lib/dc.py
CHANGED
@@ -16,6 +16,7 @@ from typing import (
     overload,
 )
 
+import orjson
 import pandas as pd
 import sqlalchemy
 from pydantic import BaseModel
@@ -58,7 +59,7 @@ from datachain.query.dataset import (
 from datachain.query.schema import DEFAULT_DELIMITER, Column, DatasetRow
 from datachain.sql.functions import path as pathfunc
 from datachain.telemetry import telemetry
-from datachain.utils import inside_notebook
+from datachain.utils import batched_it, inside_notebook
 
 if TYPE_CHECKING:
     from typing_extensions import Concatenate, ParamSpec, Self
@@ -71,6 +72,10 @@ C = Column
 
 _T = TypeVar("_T")
 D = TypeVar("D", bound="DataChain")
+UDFObjT = TypeVar("UDFObjT", bound=UDFBase)
+
+
+DEFAULT_PARQUET_CHUNK_SIZE = 100_000
 
 
 def resolve_columns(
@@ -819,7 +824,7 @@
 
     def gen(
         self,
-        func: Optional[Callable] = None,
+        func: Optional[Union[Callable, Generator]] = None,
         params: Union[None, str, Sequence[str]] = None,
         output: OutputType = None,
         **signal_map,
@@ -931,12 +936,12 @@
 
     def _udf_to_obj(
         self,
-        target_class: type[UDFBase],
-        func: Optional[Callable],
+        target_class: type[UDFObjT],
+        func: Optional[Union[Callable, UDFObjT]],
         params: Union[None, str, Sequence[str]],
         output: OutputType,
         signal_map,
-    ) -> UDFBase:
+    ) -> UDFObjT:
         is_generator = target_class.is_output_batched
         name = self.name or ""
 
@@ -1103,6 +1108,29 @@
         rows = (row_factory(db_signals, r) for r in rows)
         yield from rows
 
+    def to_columnar_data_with_names(
+        self, chunk_size: int = DEFAULT_PARQUET_CHUNK_SIZE
+    ) -> tuple[list[str], Iterator[list[list[Any]]]]:
+        """Returns column names and the results as an iterator that provides chunks,
+        with each chunk containing a list of columns, where each column contains a
+        list of the row values for that column in that chunk. Useful for columnar data
+        formats, such as parquet or other OLAP databases.
+        """
+        headers, _ = self._effective_signals_schema.get_headers_with_length()
+        column_names = [".".join(filter(None, header)) for header in headers]
+
+        results_iter = self.collect_flatten()
+
+        def column_chunks() -> Iterator[list[list[Any]]]:
+            for chunk_iter in batched_it(results_iter, chunk_size):
+                columns: list[list[Any]] = [[] for _ in column_names]
+                for row in chunk_iter:
+                    for i, col in enumerate(columns):
+                        col.append(row[i])
+                yield columns
+
+        return column_names, column_chunks()
+
     @overload
     def results(self) -> list[tuple[Any, ...]]: ...
 
@@ -1808,21 +1836,96 @@
         self,
         path: Union[str, os.PathLike[str], BinaryIO],
         partition_cols: Optional[Sequence[str]] = None,
+        chunk_size: int = DEFAULT_PARQUET_CHUNK_SIZE,
         **kwargs,
     ) -> None:
-        """Save chain to parquet file.
+        """Save chain to parquet file with SignalSchema metadata.
 
         Parameters:
             path : Path or a file-like binary object to save the file.
             partition_cols : Column names by which to partition the dataset.
+            chunk_size : The chunk size of results to read and convert to columnar
+                data, to avoid running out of memory.
         """
+        import pyarrow as pa
+        import pyarrow.parquet as pq
+
+        from datachain.lib.arrow import DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY
+
         _partition_cols = list(partition_cols) if partition_cols else None
-
-
-            partition_cols=_partition_cols,
-            **kwargs,
+        signal_schema_metadata = orjson.dumps(
+            self._effective_signals_schema.serialize()
         )
 
+        column_names, column_chunks = self.to_columnar_data_with_names(chunk_size)
+
+        parquet_schema = None
+        parquet_writer = None
+        first_chunk = True
+
+        for chunk in column_chunks:
+            # pyarrow infers the best parquet schema from the python types of
+            # the input data.
+            table = pa.Table.from_pydict(
+                dict(zip(column_names, chunk)),
+                schema=parquet_schema,
+            )
+
+            # Preserve any existing metadata, and add the DataChain SignalSchema.
+            existing_metadata = table.schema.metadata or {}
+            merged_metadata = {
+                **existing_metadata,
+                DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY: signal_schema_metadata,
+            }
+            table = table.replace_schema_metadata(merged_metadata)
+            parquet_schema = table.schema
+
+            if _partition_cols:
+                # Write to a partitioned parquet dataset.
+                pq.write_to_dataset(
+                    table,
+                    root_path=path,
+                    partition_cols=_partition_cols,
+                    **kwargs,
+                )
+            else:
+                if first_chunk:
+                    # Write to a single parquet file.
+                    parquet_writer = pq.ParquetWriter(path, parquet_schema, **kwargs)
+                    first_chunk = False
+
+                assert parquet_writer
+                parquet_writer.write_table(table)
+
+        if parquet_writer:
+            parquet_writer.close()
+
+    def to_csv(
+        self,
+        path: Union[str, os.PathLike[str]],
+        delimiter: str = ",",
+        **kwargs,
+    ) -> None:
+        """Save chain to a csv (comma-separated values) file.
+
+        Parameters:
+            path : Path to save the file.
+            delimiter : Delimiter to use for the resulting file.
+        """
+        import csv
+
+        headers, _ = self._effective_signals_schema.get_headers_with_length()
+        column_names = [".".join(filter(None, header)) for header in headers]
+
+        results_iter = self.collect_flatten()
+
+        with open(path, "w", newline="") as f:
+            writer = csv.writer(f, delimiter=delimiter, **kwargs)
+            writer.writerow(column_names)
+
+            for row in results_iter:
+                writer.writerow(row)
+
     @classmethod
     def from_records(
         cls,
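Hypothetical usage of the reworked to_parquet and the new to_csv (the dataset name is made up; chunk_size defaults to DEFAULT_PARQUET_CHUNK_SIZE, i.e. 100_000 rows):

    from datachain.lib.dc import DataChain

    chain = DataChain.from_dataset("my_dataset")

    # Streams rows in columnar chunks and embeds the serialized SignalSchema in
    # the parquet metadata, so reading it back can restore nested models instead
    # of flat columns.
    chain.to_parquet("out.parquet", chunk_size=50_000)

    # Plain flat export; column names are flattened signal paths ("file.path", ...).
    chain.to_csv("out.csv", delimiter=";")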
datachain/lib/udf.py
CHANGED
@@ -1,31 +1,33 @@
 import sys
 import traceback
-from
+from collections.abc import Iterable, Iterator, Mapping, Sequence
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Callable, Optional
 
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from pydantic import BaseModel
 
 from datachain.dataset import RowDict
 from datachain.lib.convert.flatten import flatten
-from datachain.lib.convert.unflatten import unflatten_to_json
 from datachain.lib.file import File
-from datachain.lib.model_store import ModelStore
 from datachain.lib.signal_schema import SignalSchema
-from datachain.lib.udf_signature import UdfSignature
 from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
-from datachain.query.batch import
-
-
-
+from datachain.query.batch import (
+    Batch,
+    BatchingStrategy,
+    NoBatching,
+    Partition,
+    RowsOutputBatch,
+    UDFInputBatch,
+)
+from datachain.query.schema import ColumnParameter, UDFParameter
 
 if TYPE_CHECKING:
-    from collections.abc import Iterable, Iterator, Sequence
-
     from typing_extensions import Self
 
     from datachain.catalog import Catalog
+    from datachain.lib.udf_signature import UdfSignature
     from datachain.query.batch import RowsOutput, UDFInput
-    from datachain.query.udf import UDFResult
 
 
 class UdfError(DataChainParamsError):
@@ -33,14 +35,47 @@ class UdfError(DataChainParamsError):
         super().__init__(f"UDF error: {msg}")
 
 
-
+ColumnType = Any
+
+# Specification for the output of a UDF
+UDFOutputSpec = Mapping[str, ColumnType]
+
+# Result type when calling the UDF wrapper around the actual
+# Python function / class implementing it.
+UDFResult = dict[str, Any]
+
+
+@dataclass
+class UDFProperties:
+    """Container for basic UDF properties."""
+
+    params: list[UDFParameter]
+    output: UDFOutputSpec
+    batch: int = 1
+
+    def get_batching(self, use_partitioning: bool = False) -> BatchingStrategy:
+        if use_partitioning:
+            return Partition()
+        if self.batch == 1:
+            return NoBatching()
+        if self.batch > 1:
+            return Batch(self.batch)
+        raise ValueError(f"invalid batch size {self.batch}")
+
+    def signal_names(self) -> Iterable[str]:
+        return self.output.keys()
+
+
+class UDFAdapter:
     def __init__(
         self,
         inner: "UDFBase",
         properties: UDFProperties,
     ):
         self.inner = inner
-
+        self.properties = properties
+        self.signal_names = properties.signal_names()
+        self.output = properties.output
 
     def run(
         self,
@@ -51,20 +86,23 @@ class UDFAdapter(_UDFBase):
         cache: bool,
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
-    ) -> Iterator[Iterable["UDFResult"]]:
-        self.inner.set_catalog(catalog)
+    ) -> Iterator[Iterable[UDFResult]]:
+        self.inner.catalog = catalog
         if hasattr(self.inner, "setup") and callable(self.inner.setup):
             self.inner.setup()
 
-
-
-
-
-
-
-
-
-
+        for batch in udf_inputs:
+            if isinstance(batch, RowsOutputBatch):
+                n_rows = len(batch.rows)
+                inputs: UDFInput = UDFInputBatch(
+                    [RowDict(zip(udf_fields, row)) for row in batch.rows]
+                )
+            else:
+                n_rows = 1
+                inputs = RowDict(zip(udf_fields, batch))
+            output = self.run_once(catalog, inputs, is_generator, cache, cb=download_cb)
+            processed_cb.relative_update(n_rows)
+            yield output
 
         if hasattr(self.inner, "teardown") and callable(self.inner.teardown):
             self.inner.teardown()
@@ -76,23 +114,46 @@ class UDFAdapter(_UDFBase):
         is_generator: bool = False,
         cache: bool = False,
         cb: Callback = DEFAULT_CALLBACK,
-    ) -> Iterable["UDFResult"]:
+    ) -> Iterable[UDFResult]:
         if isinstance(arg, UDFInputBatch):
             udf_inputs = [
                 self.bind_parameters(catalog, row, cache=cache, cb=cb)
                 for row in arg.rows
             ]
-            udf_outputs = self.inner(udf_inputs, cache=cache, download_cb=cb)
+            udf_outputs = self.inner.run_once(udf_inputs, cache=cache, download_cb=cb)
             return self._process_results(arg.rows, udf_outputs, is_generator)
         if isinstance(arg, RowDict):
             udf_inputs = self.bind_parameters(catalog, arg, cache=cache, cb=cb)
-            udf_outputs = self.inner(udf_inputs, cache=cache, download_cb=cb)
+            udf_outputs = self.inner.run_once(udf_inputs, cache=cache, download_cb=cb)
             if not is_generator:
                 # udf_outputs is generator already if is_generator=True
                 udf_outputs = [udf_outputs]
             return self._process_results([arg], udf_outputs, is_generator)
         raise ValueError(f"Unexpected UDF argument: {arg}")
 
+    def bind_parameters(self, catalog: "Catalog", row: "RowDict", **kwargs) -> list:
+        return [p.get_value(catalog, row, **kwargs) for p in self.properties.params]
+
+    def _process_results(
+        self,
+        rows: Sequence["RowDict"],
+        results: Sequence[Sequence[Any]],
+        is_generator=False,
+    ) -> Iterable[UDFResult]:
+        """Create a list of dictionaries representing UDF results."""
+
+        # outputting rows
+        if is_generator:
+            # each row in results is a tuple of column values
+            return (dict(zip(self.signal_names, row)) for row in results)
+
+        # outputting signals
+        row_ids = [row["sys__id"] for row in rows]
+        return [
+            {"sys__id": row_id} | dict(zip(self.signal_names, signals))
+            for row_id, signals in zip(row_ids, results)
+        ]
+
 
 class UDFBase(AbstractUDF):
     """Base class for stateful user-defined functions.
@@ -146,14 +207,14 @@ class UDFBase(AbstractUDF):
     is_output_batched = False
     is_input_grouped = False
     params_spec: Optional[list[str]]
+    catalog: "Optional[Catalog]"
 
     def __init__(self):
         self.params = None
         self.output = None
         self.params_spec = None
         self.output_spec = None
-        self.
-        self._catalog = None
+        self.catalog = None
         self._func = None
 
     def process(self, *args, **kwargs):
@@ -174,9 +235,9 @@ class UDFBase(AbstractUDF):
 
     def _init(
         self,
-        sign: UdfSignature,
+        sign: "UdfSignature",
         params: SignalSchema,
-        func: Callable,
+        func: Optional[Callable],
     ):
         self.params = params
         self.output = sign.output_schema
@@ -190,13 +251,13 @@ class UDFBase(AbstractUDF):
     @classmethod
     def _create(
         cls,
-        sign: UdfSignature,
+        sign: "UdfSignature",
         params: SignalSchema,
     ) -> "Self":
         if isinstance(sign.func, AbstractUDF):
             if not isinstance(sign.func, cls):  # type: ignore[unreachable]
                 raise UdfError(
-                    f"cannot create UDF: provided UDF '{sign.func.__name__}'"
+                    f"cannot create UDF: provided UDF '{type(sign.func).__name__}'"
                     f" must be a child of target class '{cls.__name__}'",
                 )
             result = sign.func
@@ -212,13 +273,6 @@ class UDFBase(AbstractUDF):
     def name(self):
         return self.__class__.__name__
 
-    def set_catalog(self, catalog):
-        self._catalog = catalog.copy(db=False)
-
-    @property
-    def catalog(self):
-        return self._catalog
-
     def to_udf_wrapper(self, batch: int = 1) -> UDFAdapter:
         assert self.params_spec is not None
         properties = UDFProperties(
@@ -229,11 +283,9 @@ class UDFBase(AbstractUDF):
     def validate_results(self, results, *args, **kwargs):
         return results
 
-    def __call__(self, rows, cache, download_cb):
-        if self.is_input_grouped:
-            objs = self._parse_grouped_rows(rows[0], cache, download_cb)
-        elif self.is_input_batched:
-            objs = zip(*self._parse_rows(rows[0], cache, download_cb))
+    def run_once(self, rows, cache, download_cb):
+        if self.is_input_batched:
+            objs = zip(*self._parse_rows(rows, cache, download_cb))
         else:
             objs = self._parse_rows([rows], cache, download_cb)[0]
 
@@ -259,8 +311,8 @@ class UDFBase(AbstractUDF):
         ):
             res = list(res)
             assert len(res) == len(
-                rows
-            ), f"{self.name} returns {len(res)} rows while len(rows
+                rows
+            ), f"{self.name} returns {len(res)} rows while {len(rows)} expected"
 
         return res
 
@@ -283,41 +335,11 @@ class UDFBase(AbstractUDF):
             for obj in obj_row:
                 if isinstance(obj, File):
                     obj._set_stream(
-                        self._catalog, caching_enabled=cache, download_cb=download_cb
+                        self.catalog, caching_enabled=cache, download_cb=download_cb
                     )
            objs.append(obj_row)
        return objs
 
-    def _parse_grouped_rows(self, group, cache, download_cb):
-        spec_map = {}
-        output_map = {}
-        for name, (anno, subtree) in self.params.tree.items():
-            if ModelStore.is_pydantic(anno):
-                length = sum(1 for _ in self.params._get_flat_tree(subtree, [], 0))
-            else:
-                length = 1
-            spec_map[name] = anno, length
-            output_map[name] = []
-
-        for flat_obj in group:
-            position = 0
-            for signal, (cls, length) in spec_map.items():
-                slice = flat_obj[position : position + length]
-                position += length
-
-                if ModelStore.is_pydantic(cls):
-                    obj = cls(**unflatten_to_json(cls, slice))
-                else:
-                    obj = slice[0]
-
-                if isinstance(obj, File):
-                    obj._set_stream(
-                        self._catalog, caching_enabled=cache, download_cb=download_cb
-                    )
-                output_map[signal].append(obj)
-
-        return list(output_map.values())
-
     def process_safe(self, obj_rows):
         try:
             result_objs = self.process(*obj_rows)
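UDFProperties moved here from the now-deleted datachain/query/udf.py (see the end of this diff) with its behavior unchanged; the batching selection can be exercised standalone:

    from datachain.lib.udf import UDFProperties
    from datachain.query.batch import Batch, NoBatching, Partition

    props = UDFProperties(params=[], output={}, batch=1)
    assert isinstance(props.get_batching(), NoBatching)
    assert isinstance(props.get_batching(use_partitioning=True), Partition)
    assert isinstance(UDFProperties([], {}, batch=1000).get_batching(), Batch)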
datachain/lib/udf_signature.py
CHANGED
@@ -1,10 +1,11 @@
 import inspect
 from collections.abc import Generator, Iterator, Sequence
 from dataclasses import dataclass
-from typing import Callable,
+from typing import Callable, Union, get_args, get_origin
 
 from datachain.lib.data_model import DataType, DataTypeNames, is_chain_type
 from datachain.lib.signal_schema import SignalSchema
+from datachain.lib.udf import UDFBase
 from datachain.lib.utils import AbstractUDF, DataChainParamsError
 
 
@@ -16,7 +17,7 @@ class UdfSignatureError(DataChainParamsError):
 
 @dataclass
 class UdfSignature:
-    func: Callable
+    func: Union[Callable, UDFBase]
     params: Sequence[str]
     output_schema: SignalSchema
 
@@ -27,7 +28,7 @@ class UdfSignature:
         cls,
         chain: str,
         signal_map: dict[str, Callable],
-        func:
+        func: Union[None, UDFBase, Callable] = None,
         params: Union[None, str, Sequence[str]] = None,
         output: Union[None, DataType, Sequence[str], dict[str, DataType]] = None,
         is_generator: bool = True,
@@ -39,6 +40,7 @@ class UdfSignature:
                 f"multiple signals '{keys}' are not supported in processors."
                 " Chain multiple processors instead.",
             )
+        udf_func: Union[UDFBase, Callable]
         if len(signal_map) == 1:
             if func is not None:
                 raise UdfSignatureError(
@@ -53,7 +55,7 @@ class UdfSignature:
             udf_func = func
             signal_name = None
 
-        if not callable(udf_func):
+        if not isinstance(udf_func, UDFBase) and not callable(udf_func):
             raise UdfSignatureError(chain, f"UDF '{udf_func}' is not callable")
 
         func_params_map_sign, func_outs_sign, is_iterator = (
@@ -73,7 +75,7 @@ class UdfSignature:
         if not func_outs_sign:
             raise UdfSignatureError(
                 chain,
-                f"outputs are not defined in function '{udf_func.__name__}'"
+                f"outputs are not defined in function '{udf_func}'"
                 " hints or 'output'",
             )
 
@@ -154,7 +156,7 @@ class UdfSignature:
 
     @staticmethod
     def _func_signature(
-        chain: str, udf_func: Callable
+        chain: str, udf_func: Union[Callable, UDFBase]
     ) -> tuple[dict[str, type], Sequence[type], bool]:
         if isinstance(udf_func, AbstractUDF):
             func = udf_func.process  # type: ignore[unreachable]
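The widened isinstance check matters because this release renames UDFBase.__call__ to run_once (see datachain/lib/udf.py above), so UDF instances are no longer callable and a plain callable() test would reject them. A toy illustration:

    class ToyUDF:  # stands in for a UDFBase subclass without __call__
        def process(self, row):
            return row

    assert not callable(ToyUDF())      # callable() alone would reject it
    assert callable(lambda row: row)   # ordinary functions still pass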
datachain/query/dataset.py
CHANGED
@@ -42,6 +42,7 @@ from datachain.data_storage.schema import (
 )
 from datachain.dataset import DatasetStatus, RowDict
 from datachain.error import DatasetNotFoundError, QueryScriptCancelError
+from datachain.lib.udf import UDFAdapter
 from datachain.progress import CombinedDownloadCallback
 from datachain.sql.functions import rand
 from datachain.utils import (
@@ -53,7 +54,6 @@ from datachain.utils import (
 
 from .schema import C, UDFParamSpec, normalize_param
 from .session import Session
-from .udf import UDFBase
 
 if TYPE_CHECKING:
     from sqlalchemy.sql.elements import ClauseElement
@@ -299,7 +299,7 @@ def adjust_outputs(
     return row
 
 
-def get_udf_col_types(warehouse: "AbstractWarehouse", udf: UDFBase) -> list[tuple]:
+def get_udf_col_types(warehouse: "AbstractWarehouse", udf: UDFAdapter) -> list[tuple]:
     """Optimization: Precompute UDF column types so these don't have to be computed
     in the convert_type function for each row in a loop."""
     dialect = warehouse.db.dialect
@@ -320,7 +320,7 @@ def process_udf_outputs(
     warehouse: "AbstractWarehouse",
     udf_table: "Table",
     udf_results: Iterator[Iterable["UDFResult"]],
-    udf: UDFBase,
+    udf: UDFAdapter,
     batch_size: int = INSERT_BATCH_SIZE,
     cb: Callback = DEFAULT_CALLBACK,
 ) -> None:
@@ -364,7 +364,7 @@ def get_generated_callback(is_generator: bool = False) -> Callback:
 
 @frozen
 class UDFStep(Step, ABC):
-    udf: UDFBase
+    udf: UDFAdapter
     catalog: "Catalog"
     partition_by: Optional[PartitionByType] = None
     parallel: Optional[int] = None
@@ -1465,7 +1465,7 @@ class DatasetQuery:
     @detach
     def add_signals(
         self,
-        udf: UDFBase,
+        udf: UDFAdapter,
         parallel: Optional[int] = None,
         workers: Union[bool, int] = False,
         min_task_size: Optional[int] = None,
@@ -1509,7 +1509,7 @@ class DatasetQuery:
     @detach
     def generate(
         self,
-        udf: UDFBase,
+        udf: UDFAdapter,
         parallel: Optional[int] = None,
         workers: Union[bool, int] = False,
         min_task_size: Optional[int] = None,
datachain/query/dispatch.py
CHANGED
@@ -13,6 +13,7 @@ from multiprocess import get_context
 
 from datachain.catalog import Catalog
 from datachain.catalog.loader import get_distributed_class
+from datachain.lib.udf import UDFAdapter, UDFResult
 from datachain.query.dataset import (
     get_download_callback,
     get_generated_callback,
@@ -27,7 +28,6 @@ from datachain.query.queue import (
     put_into_queue,
     unmarshal,
 )
-from datachain.query.udf import UDFBase, UDFResult
 from datachain.utils import batched_it
 
 DEFAULT_BATCH_SIZE = 10000
@@ -336,7 +336,7 @@ class ProcessedCallback(Callback):
 @attrs.define
 class UDFWorker:
     catalog: Catalog
-    udf: UDFBase
+    udf: UDFAdapter
     task_queue: "multiprocess.Queue"
     done_queue: "multiprocess.Queue"
     is_generator: bool
datachain/query/session.py
CHANGED
@@ -1,5 +1,8 @@
 import atexit
+import logging
+import os
 import re
+import sys
 from typing import TYPE_CHECKING, Optional
 from uuid import uuid4
 
@@ -9,6 +12,8 @@ from datachain.error import TableMissingError
 if TYPE_CHECKING:
     from datachain.catalog import Catalog
 
+logger = logging.getLogger("datachain")
+
 
 class Session:
     """
@@ -35,6 +40,7 @@ class Session:
 
     GLOBAL_SESSION_CTX: Optional["Session"] = None
     GLOBAL_SESSION: Optional["Session"] = None
+    ORIGINAL_EXCEPT_HOOK = None
 
     DATASET_PREFIX = "session_"
     GLOBAL_SESSION_NAME = "global"
@@ -58,6 +64,7 @@ class Session:
 
         session_uuid = uuid4().hex[: self.SESSION_UUID_LEN]
         self.name = f"{name}_{session_uuid}"
+        self.job_id = os.getenv("DATACHAIN_JOB_ID") or str(uuid4())
         self.is_new_catalog = not catalog
         self.catalog = catalog or get_catalog(
             client_config=client_config, in_memory=in_memory
@@ -67,6 +74,9 @@ class Session:
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
+        if exc_type:
+            self._cleanup_created_versions(self.name)
+
         self._cleanup_temp_datasets()
         if self.is_new_catalog:
             self.catalog.metastore.close_on_exit()
@@ -88,6 +98,21 @@ class Session:
         except TableMissingError:
             pass
 
+    def _cleanup_created_versions(self, job_id: str) -> None:
+        versions = self.catalog.metastore.get_job_dataset_versions(job_id)
+        if not versions:
+            return
+
+        datasets = {}
+        for dataset_name, version in versions:
+            if dataset_name not in datasets:
+                datasets[dataset_name] = self.catalog.get_dataset(dataset_name)
+            dataset = datasets[dataset_name]
+            logger.info(
+                "Removing dataset version %s@%s due to exception", dataset_name, version
+            )
+            self.catalog.remove_dataset_version(dataset, version)
+
     @classmethod
     def get(
         cls,
@@ -114,9 +139,23 @@ class Session:
                 in_memory=in_memory,
             )
             cls.GLOBAL_SESSION = cls.GLOBAL_SESSION_CTX.__enter__()
+
             atexit.register(cls._global_cleanup)
+            cls.ORIGINAL_EXCEPT_HOOK = sys.excepthook
+            sys.excepthook = cls.except_hook
+
         return cls.GLOBAL_SESSION
 
+    @staticmethod
+    def except_hook(exc_type, exc_value, exc_traceback):
+        Session._global_cleanup()
+        if Session.GLOBAL_SESSION_CTX is not None:
+            job_id = Session.GLOBAL_SESSION_CTX.job_id
+            Session.GLOBAL_SESSION_CTX._cleanup_created_versions(job_id)
+
+        if Session.ORIGINAL_EXCEPT_HOOK:
+            Session.ORIGINAL_EXCEPT_HOOK(exc_type, exc_value, exc_traceback)
+
     @classmethod
     def cleanup_for_tests(cls):
         if cls.GLOBAL_SESSION_CTX is not None:
@@ -125,6 +164,9 @@ class Session:
             cls.GLOBAL_SESSION_CTX = None
             atexit.unregister(cls._global_cleanup)
 
+            if cls.ORIGINAL_EXCEPT_HOOK:
+                sys.excepthook = cls.ORIGINAL_EXCEPT_HOOK
+
     @staticmethod
     def _global_cleanup():
         if Session.GLOBAL_SESSION_CTX is not None:
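The excepthook installation above follows the usual chaining pattern: remember the previous hook, run cleanup, then delegate to it. A minimal standalone sketch (the print stands in for the dataset-version cleanup):

    import sys

    original_hook = sys.excepthook

    def except_hook(exc_type, exc_value, exc_traceback):
        print("cleaning up dataset versions created by this job")  # stand-in
        if original_hook:
            original_hook(exc_type, exc_value, exc_traceback)

    sys.excepthook = except_hook  # installed once, next to atexit.register(...)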
{datachain-0.5.0.dist-info → datachain-0.5.1.dist-info}/RECORD
CHANGED
@@ -18,7 +18,7 @@ datachain/storage.py,sha256=RiSJLYdHUjnrEWkLBKPcETHpAxld_B2WxLg711t0aZI,3733
 datachain/telemetry.py,sha256=0A4IOPPp9VlP5pyW9eBfaTK3YhHGzHl7dQudQjUAx9A,994
 datachain/utils.py,sha256=KeFSRHsiYthnTu4a6bH-rw04mX1m8krTX0f2NqfQGFI,12114
 datachain/catalog/__init__.py,sha256=g2iAAFx_gEIrqshXlhSEbrc8qDaEH11cjU40n3CHDz4,409
-datachain/catalog/catalog.py,sha256=
+datachain/catalog/catalog.py,sha256=BsMyk2RQibQYHgrmovFZeSEpPVMTwgb_7ntVYdc7t-E,64090
 datachain/catalog/datasource.py,sha256=D-VWIVDCM10A8sQavLhRXdYSCG7F4o4ifswEF80_NAQ,1412
 datachain/catalog/loader.py,sha256=-6VelNfXUdgUnwInVyA8g86Boxv2xqhTh9xNS-Zlwig,8242
 datachain/client/__init__.py,sha256=T4wiYL9KIM0ZZ_UqIyzV8_ufzYlewmizlV4iymHNluE,86
@@ -33,17 +33,17 @@ datachain/data_storage/__init__.py,sha256=cEOJpyu1JDZtfUupYucCDNFI6e5Wmp_Oyzq6rZ
 datachain/data_storage/db_engine.py,sha256=81Ol1of9TTTzD97ORajCnP366Xz2mEJt6C-kTUCaru4,3406
 datachain/data_storage/id_generator.py,sha256=lCEoU0BM37Ai2aRpSbwo5oQT0GqZnSpYwwvizathRMQ,4292
 datachain/data_storage/job.py,sha256=w-7spowjkOa1P5fUVtJou3OltT0L48P0RYWZ9rSJ9-s,383
-datachain/data_storage/metastore.py,sha256=
+datachain/data_storage/metastore.py,sha256=HfCxk4lmDUg2Q4WsFNQGMWxllP0mToA00fxkFTwdNIE,52919
 datachain/data_storage/schema.py,sha256=AGbjyEir5UmRZXI3m0jChZogUh5wd8csj6-YlUWaAxQ,8383
 datachain/data_storage/serializer.py,sha256=6G2YtOFqqDzJf1KbvZraKGXl2XHZyVml2krunWUum5o,927
-datachain/data_storage/sqlite.py,sha256=
+datachain/data_storage/sqlite.py,sha256=fW08P7AbJ0cDbTbcTKuAGpvMXvBjg-QkGsKT_Dslyws,28383
 datachain/data_storage/warehouse.py,sha256=fXhVfao3NfWFGbbG5uJ-Ga4bX1FiKVfcbDyQgECYfk8,32122
 datachain/lib/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-datachain/lib/arrow.py,sha256=
+datachain/lib/arrow.py,sha256=0R2CYsN82nNa5_03iS6jVix9EKeeqNZNAMgpSQP2hfo,9482
 datachain/lib/clip.py,sha256=lm5CzVi4Cj1jVLEKvERKArb-egb9j1Ls-fwTItT6vlI,6150
 datachain/lib/data_model.py,sha256=gHIjlow84GMRDa78yLL1Ud-N18or21fnTyPEwsatpXY,2045
 datachain/lib/dataset_info.py,sha256=srPPhI2UHf6hFPBecyFEVw2SS5aPisIIMsvGgKqi7ss,2366
-datachain/lib/dc.py,sha256=
+datachain/lib/dc.py,sha256=HLOAkJEKFHJV_PqwSu0Pyl1m7JmUea8_wiMJFr14Nfk,75960
 datachain/lib/file.py,sha256=LjTW_-PDAnoUhvyB4bJ8Y8n__XGqrxvmd9mDOF0Gir8,14875
 datachain/lib/hf.py,sha256=cPnmLuprr0pYABH7KqA5FARQ1JGlywdDwD3yDzVAm4k,5920
 datachain/lib/image.py,sha256=AMXYwQsmarZjRbPCZY3M1jDsM2WAB_b3cTY4uOIuXNU,2675
@@ -56,8 +56,8 @@ datachain/lib/settings.py,sha256=39thOpYJw-zPirzeNO6pmRC2vPrQvt4eBsw1xLWDFsw,234
 datachain/lib/signal_schema.py,sha256=iqgubjCBRiUJB30miv05qFX4uU04dA_Pzi3DCUsHZGs,24177
 datachain/lib/tar.py,sha256=3WIzao6yD5fbLqXLTt9GhPGNonbFIs_fDRu-9vgLgsA,1038
 datachain/lib/text.py,sha256=UNHm8fhidk7wdrWqacEWaA6I9ykfYqarQ2URby7jc7M,1261
-datachain/lib/udf.py,sha256=
-datachain/lib/udf_signature.py,sha256=
+datachain/lib/udf.py,sha256=oHhJWb0gVTxcybGzYDzAeN0Gb1IMhZBoGefncT88dIY,12339
+datachain/lib/udf_signature.py,sha256=GXw24A-Olna6DWCdgy2bC-gZh_gLGPQ-KvjuI6pUjC0,7281
 datachain/lib/utils.py,sha256=5-kJlAZE0D9nXXweAjo7-SP_AWGo28feaDByONYaooQ,463
 datachain/lib/vfile.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 datachain/lib/webdataset.py,sha256=o7SHk5HOUWsZ5Ln04xOM04eQqiBHiJNO7xLgyVBrwo8,6924
@@ -70,14 +70,13 @@ datachain/lib/convert/unflatten.py,sha256=Ogvh_5wg2f38_At_1lN0D_e2uZOOpYEvwvB2xd
 datachain/lib/convert/values_to_tuples.py,sha256=YOdbjzHq-uj6-cV2Qq43G72eN2avMNDGl4x5t6yQMl8,3931
 datachain/query/__init__.py,sha256=0NBOZVgIDpCcj1Ci883dQ9A0iiwe03xzmotkOCFbxYc,293
 datachain/query/batch.py,sha256=-vlpINJiertlnaoUVv1C95RatU0F6zuhpIYRufJRo1M,3660
-datachain/query/dataset.py,sha256=
-datachain/query/dispatch.py,sha256=
+datachain/query/dataset.py,sha256=1c7y178ccFSeL_WIba0vT87Md_Oo4F8zaTVDjB9Bp3I,53641
+datachain/query/dispatch.py,sha256=JVcZ4REE_GOsqXbar_Cb_fk-pHgQoabQLzXwuu7IhOg,12409
 datachain/query/metrics.py,sha256=r5b0ygYhokbXp8Mg3kCH8iFSRw0jxzyeBe-C-J_bKFc,938
 datachain/query/params.py,sha256=O_j89mjYRLOwWNhYZl-z7mi-rkdP7WyFmaDufsdTryE,863
 datachain/query/queue.py,sha256=waqM_KzavU8C-G95-4211Nd4GXna_u2747Chgwtgz2w,3839
 datachain/query/schema.py,sha256=I8zLWJuWl5N332ni9mAzDYtcxMJupVPgWkSDe8spNEk,8019
-datachain/query/session.py,sha256=
-datachain/query/udf.py,sha256=HB2hbEuiGA4ch9P2mh9iLA5Jj9mRj-4JFy9VfjTLJ8U,3622
+datachain/query/session.py,sha256=kpFFJMfWBnxaMPojMGhJRbk-BOsSYI8Ckl6vvqnx7d0,5787
 datachain/remote/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 datachain/remote/studio.py,sha256=f5s6qSZ9uB4URGUoU_8_W1KZRRQQVSm6cgEBkBUEfuE,7226
 datachain/sql/__init__.py,sha256=A2djrbQwSMUZZEIKGnm-mnRA-NDSbiDJNpAmmwGNyIo,303
@@ -97,9 +96,9 @@ datachain/sql/sqlite/base.py,sha256=WLPHBhZbXbiqPoRV1VgDrXJqku4UuvJpBhYeQ0k5rI8,
 datachain/sql/sqlite/types.py,sha256=yzvp0sXSEoEYXs6zaYC_2YubarQoZH-MiUNXcpuEP4s,1573
 datachain/sql/sqlite/vector.py,sha256=ncW4eu2FlJhrP_CIpsvtkUabZlQdl2D5Lgwy_cbfqR0,469
 datachain/torch/__init__.py,sha256=gIS74PoEPy4TB3X6vx9nLO0Y3sLJzsA8ckn8pRWihJM,579
-datachain-0.5.
-datachain-0.5.
-datachain-0.5.
-datachain-0.5.
-datachain-0.5.
-datachain-0.5.
+datachain-0.5.1.dist-info/LICENSE,sha256=8DnqK5yoPI_E50bEg_zsHKZHY2HqPy4rYN338BHQaRA,11344
+datachain-0.5.1.dist-info/METADATA,sha256=n8TFKjDmTzNBMaW5Oa6MUUUOAQbAjPzkAMaKCW3Y9NU,17156
+datachain-0.5.1.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+datachain-0.5.1.dist-info/entry_points.txt,sha256=0GMJS6B_KWq0m3VT98vQI2YZodAMkn4uReZ_okga9R4,49
+datachain-0.5.1.dist-info/top_level.txt,sha256=lZPpdU_2jJABLNIg2kvEOBi8PtsYikbN1OdMLHk8bTg,10
+datachain-0.5.1.dist-info/RECORD,,
datachain/query/udf.py
DELETED
@@ -1,126 +0,0 @@
-import typing
-from collections.abc import Iterable, Iterator, Sequence
-from dataclasses import dataclass
-from typing import (
-    TYPE_CHECKING,
-    Any,
-)
-
-from fsspec.callbacks import DEFAULT_CALLBACK, Callback
-
-from datachain.dataset import RowDict
-
-from .batch import (
-    Batch,
-    BatchingStrategy,
-    NoBatching,
-    Partition,
-    RowsOutputBatch,
-    UDFInputBatch,
-)
-from .schema import UDFParameter
-
-if TYPE_CHECKING:
-    from datachain.catalog import Catalog
-
-    from .batch import RowsOutput, UDFInput
-
-ColumnType = Any
-
-
-# Specification for the output of a UDF
-UDFOutputSpec = typing.Mapping[str, ColumnType]
-
-# Result type when calling the UDF wrapper around the actual
-# Python function / class implementing it.
-UDFResult = dict[str, Any]
-
-
-@dataclass
-class UDFProperties:
-    """Container for basic UDF properties."""
-
-    params: list[UDFParameter]
-    output: UDFOutputSpec
-    batch: int = 1
-
-    def get_batching(self, use_partitioning: bool = False) -> BatchingStrategy:
-        if use_partitioning:
-            return Partition()
-        if self.batch == 1:
-            return NoBatching()
-        if self.batch > 1:
-            return Batch(self.batch)
-        raise ValueError(f"invalid batch size {self.batch}")
-
-    def signal_names(self) -> Iterable[str]:
-        return self.output.keys()
-
-
-class UDFBase:
-    """A base class for implementing stateful UDFs."""
-
-    def __init__(
-        self,
-        properties: UDFProperties,
-    ):
-        self.properties = properties
-        self.signal_names = properties.signal_names()
-        self.output = properties.output
-
-    def run(
-        self,
-        udf_fields: "Sequence[str]",
-        udf_inputs: "Iterable[RowsOutput]",
-        catalog: "Catalog",
-        is_generator: bool,
-        cache: bool,
-        download_cb: Callback = DEFAULT_CALLBACK,
-        processed_cb: Callback = DEFAULT_CALLBACK,
-    ) -> Iterator[Iterable["UDFResult"]]:
-        for batch in udf_inputs:
-            if isinstance(batch, RowsOutputBatch):
-                n_rows = len(batch.rows)
-                inputs: UDFInput = UDFInputBatch(
-                    [RowDict(zip(udf_fields, row)) for row in batch.rows]
-                )
-            else:
-                n_rows = 1
-                inputs = RowDict(zip(udf_fields, batch))
-            output = self.run_once(catalog, inputs, is_generator, cache, cb=download_cb)
-            processed_cb.relative_update(n_rows)
-            yield output
-
-    def run_once(
-        self,
-        catalog: "Catalog",
-        arg: "UDFInput",
-        is_generator: bool = False,
-        cache: bool = False,
-        cb: Callback = DEFAULT_CALLBACK,
-    ) -> Iterable[UDFResult]:
-        raise NotImplementedError
-
-    def bind_parameters(self, catalog: "Catalog", row: "RowDict", **kwargs) -> list:
-        return [p.get_value(catalog, row, **kwargs) for p in self.properties.params]
-
-    def _process_results(
-        self,
-        rows: Sequence["RowDict"],
-        results: Sequence[Sequence[Any]],
-        is_generator=False,
-    ) -> Iterable[UDFResult]:
-        """Create a list of dictionaries representing UDF results."""
-
-        # outputting rows
-        if is_generator:
-            # each row in results is a tuple of column values
-            return (dict(zip(self.signal_names, row)) for row in results)
-
-        # outputting signals
-        row_ids = [row["sys__id"] for row in rows]
-        return [
-            {"sys__id": row_id} | dict(zip(self.signal_names, signals))
-            for row_id, signals in zip(row_ids, results)
-            if signals is not None  # skip rows with no output
-        ]

{datachain-0.5.0.dist-info → datachain-0.5.1.dist-info}/LICENSE
File without changes

{datachain-0.5.0.dist-info → datachain-0.5.1.dist-info}/WHEEL
File without changes

{datachain-0.5.0.dist-info → datachain-0.5.1.dist-info}/entry_points.txt
File without changes

{datachain-0.5.0.dist-info → datachain-0.5.1.dist-info}/top_level.txt
File without changes