datachain 0.5.0__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -988,6 +988,14 @@ class Catalog:
         schema = {
             c.name: c.type.to_dict() for c in columns if isinstance(c.type, SQLType)
         }
+
+        job_id = job_id or os.getenv("DATACHAIN_JOB_ID")
+        if not job_id:
+            from datachain.query.session import Session
+
+            session = Session.get(catalog=self)
+            job_id = session.job_id
+
         dataset = self.metastore.create_dataset_version(
             dataset,
             version,
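The new code resolves the job ID for a dataset version in a fixed order: explicit argument, then the DATACHAIN_JOB_ID environment variable, then the job ID of the active Session (which itself falls back to a random UUID). A minimal standalone sketch of that resolution order, not datachain code:

    import os
    import uuid

    def resolve_job_id(explicit_job_id=None, session_job_id=None):
        # argument -> DATACHAIN_JOB_ID env var -> session job id -> fresh UUID
        return (
            explicit_job_id
            or os.getenv("DATACHAIN_JOB_ID")
            or session_job_id
            or str(uuid.uuid4())
        )

    print(resolve_job_id(session_job_id="abc123"))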
@@ -50,7 +50,6 @@ if TYPE_CHECKING:
     from datachain.data_storage import AbstractIDGenerator, schema
     from datachain.data_storage.db_engine import DatabaseEngine
 
-
 logger = logging.getLogger("datachain")
 
 
@@ -384,6 +383,11 @@ class AbstractMetastore(ABC, Serializable):
     ) -> None:
         """Set the status of the given job and dataset."""
 
+    @abstractmethod
+    def get_job_dataset_versions(self, job_id: str) -> list[tuple[str, int]]:
+        """Returns dataset names and versions for the job."""
+        raise NotImplementedError
+
 
 class AbstractDBMetastore(AbstractMetastore):
     """
@@ -1519,3 +1523,18 @@ class AbstractDBMetastore(AbstractMetastore):
             .values(status=dataset_status)
         )
         self.db.execute(query, conn=conn)  # type: ignore[attr-defined]
+
+    def get_job_dataset_versions(self, job_id: str) -> list[tuple[str, int]]:
+        """Returns dataset names and versions for the job."""
+        dv = self._datasets_versions
+        ds = self._datasets
+
+        join_condition = dv.c.dataset_id == ds.c.id
+
+        query = (
+            self._datasets_versions_select(ds.c.name, dv.c.version)
+            .select_from(dv.join(ds, join_condition))
+            .where(dv.c.job_id == job_id)
+        )
+
+        return list(self.db.execute(query))
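get_job_dataset_versions returns plain (dataset_name, version) tuples for everything a job created. A small illustration, with made-up data, of how a caller such as Session groups them per dataset before removing versions:

    from collections import defaultdict

    # Hypothetical result of metastore.get_job_dataset_versions(job_id)
    versions = [("cats", 1), ("cats", 2), ("dogs", 1)]

    by_dataset: dict[str, list[int]] = defaultdict(list)
    for dataset_name, version in versions:
        by_dataset[dataset_name].append(version)

    print(dict(by_dataset))  # {'cats': [1, 2], 'dogs': [1]}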
@@ -15,6 +15,7 @@ from typing import (
 )
 
 import sqlalchemy
+from packaging import version
 from sqlalchemy import MetaData, Table, UniqueConstraint, exists, select
 from sqlalchemy.dialects import sqlite
 from sqlalchemy.schema import CreateIndex, CreateTable, DropTable
@@ -153,7 +154,7 @@ class SQLiteDatabaseEngine(DatabaseEngine):
         if os.environ.get("DEBUG_SHOW_SQL_QUERIES"):
             import sys
 
-            db.set_trace_callback(sys.stderr.write)
+            db.set_trace_callback(lambda stmt: print(stmt, file=sys.stderr))
 
         load_usearch_extension(db)
 
@@ -345,45 +346,36 @@ class SQLiteIDGenerator(AbstractDBIDGenerator):
     def get_next_ids(self, uri: str, count: int) -> range:
         """Returns a range of IDs for the given URI."""
 
-        # NOTE: we can't use RETURNING clause here because it is only available
-        # in sqlalchemy v2, see
-        # https://github.com/sqlalchemy/sqlalchemy/issues/6195#issuecomment-1248700677
-        # After we upgrade to sqlalchemy v2, we can use the following code,
-        # leaving fallback to the current implementation for older versions of SQLite,
-        # which is still supported, for example, in Ubuntu 20.04 LTS (Focal Fossa),
-        # where SQLite version 3.31.1 is used.
-
-        # sqlite_version = version.parse(sqlite3.sqlite_version)
-        # if sqlite_version >= version.parse("3.35.0"):
-        #     # RETURNING is supported on SQLite 3.35.0 (2021-03-12) or newer
-        #     stmt = (
-        #         sqlite.insert(self._table)
-        #         .values(uri=uri, last_id=count)
-        #         .on_conflict_do_update(
-        #             index_elements=["uri"],
-        #             set_={"last_id": self._table.c.last_id + count},
-        #         )
-        #         .returning(self._table.c.last_id)
-        #     )
-        #     last_id = self._db.execute(stmt).fetchone()[0]
-        # else:
-        #     (fallback to the current implementation with a transaction)
-
-        # Transactions ensure no concurrency conflicts
-        with self._db.transaction() as conn:
-            # UPSERT syntax was added to SQLite with version 3.24.0 (2018-06-04).
-            stmt_ins = (
+        sqlite_version = version.parse(sqlite3.sqlite_version)
+        is_returning_supported = sqlite_version >= version.parse("3.35.0")
+        if is_returning_supported:
+            stmt = (
                 sqlite.insert(self._table)
                 .values(uri=uri, last_id=count)
                 .on_conflict_do_update(
                     index_elements=["uri"],
                     set_={"last_id": self._table.c.last_id + count},
                 )
+                .returning(self._table.c.last_id)
             )
-            self._db.execute(stmt_ins, conn=conn)
+            last_id = self._db.execute(stmt).fetchone()[0]
+        else:
+            # Older versions of SQLite are still the default under Ubuntu LTS,
+            # e.g. Ubuntu 20.04 LTS (Focal Fossa) uses 3.31.1
+            # Transactions ensure no concurrency conflicts
+            with self._db.transaction() as conn:
+                stmt_ins = (
+                    sqlite.insert(self._table)
+                    .values(uri=uri, last_id=count)
+                    .on_conflict_do_update(
+                        index_elements=["uri"],
+                        set_={"last_id": self._table.c.last_id + count},
+                    )
+                )
+                self._db.execute(stmt_ins, conn=conn)
 
-        stmt_sel = select(self._table.c.last_id).where(self._table.c.uri == uri)
-        last_id = self._db.execute(stmt_sel, conn=conn).fetchone()[0]
+                stmt_sel = select(self._table.c.last_id).where(self._table.c.uri == uri)
+                last_id = self._db.execute(stmt_sel, conn=conn).fetchone()[0]
 
         return range(last_id - count + 1, last_id + 1)
 
datachain/lib/arrow.py CHANGED
@@ -1,8 +1,9 @@
 import re
 from collections.abc import Sequence
 from tempfile import NamedTemporaryFile
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
+import orjson
 import pyarrow as pa
 from pyarrow.dataset import CsvFileFormat, dataset
 from tqdm import tqdm
@@ -10,6 +11,7 @@ from tqdm import tqdm
 from datachain.lib.data_model import dict_to_data_model
 from datachain.lib.file import ArrowRow, File
 from datachain.lib.model_store import ModelStore
+from datachain.lib.signal_schema import SignalSchema
 from datachain.lib.udf import Generator
 
 if TYPE_CHECKING:
@@ -20,6 +22,9 @@ if TYPE_CHECKING:
     from datachain.lib.dc import DataChain
 
 
+DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY = b"DataChain SignalSchema"
+
+
 class ArrowGenerator(Generator):
     def __init__(
         self,
@@ -61,28 +66,35 @@ class ArrowGenerator(Generator):
             path, filesystem=file.get_fs(), schema=self.input_schema, **self.kwargs
         )
         hf_schema = _get_hf_schema(ds.schema)
+        use_datachain_schema = (
+            bool(ds.schema.metadata)
+            and DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY in ds.schema.metadata
+        )
         index = 0
         with tqdm(desc="Parsed by pyarrow", unit=" rows") as pbar:
             for record_batch in ds.to_batches():
                 for record in record_batch.to_pylist():
-                    vals = list(record.values())
-                    if self.output_schema:
-                        fields = self.output_schema.model_fields
-                        vals_dict = {}
-                        for i, ((field, field_info), val) in enumerate(
-                            zip(fields.items(), vals)
-                        ):
-                            anno = field_info.annotation
-                            if hf_schema:
-                                from datachain.lib.hf import convert_feature
-
-                                feat = list(hf_schema[0].values())[i]
-                                vals_dict[field] = convert_feature(val, feat, anno)
-                            elif ModelStore.is_pydantic(anno):
-                                vals_dict[field] = anno(**val)  # type: ignore[misc]
-                            else:
-                                vals_dict[field] = val
-                        vals = [self.output_schema(**vals_dict)]
+                    if use_datachain_schema and self.output_schema:
+                        vals = [_nested_model_instantiate(record, self.output_schema)]
+                    else:
+                        vals = list(record.values())
+                        if self.output_schema:
+                            fields = self.output_schema.model_fields
+                            vals_dict = {}
+                            for i, ((field, field_info), val) in enumerate(
+                                zip(fields.items(), vals)
+                            ):
+                                anno = field_info.annotation
+                                if hf_schema:
+                                    from datachain.lib.hf import convert_feature
+
+                                    feat = list(hf_schema[0].values())[i]
+                                    vals_dict[field] = convert_feature(val, feat, anno)
+                                elif ModelStore.is_pydantic(anno):
+                                    vals_dict[field] = anno(**val)  # type: ignore[misc]
+                                else:
+                                    vals_dict[field] = val
+                            vals = [self.output_schema(**vals_dict)]
                     if self.source:
                         kwargs: dict = self.kwargs
                         # Can't serialize CsvFileFormat; may lose formatting options.
@@ -113,6 +125,9 @@ def schema_to_output(schema: pa.Schema, col_names: Optional[Sequence[str]] = Non
         )
     if not col_names:
         col_names = schema.names
+    signal_schema = _get_datachain_schema(schema)
+    if signal_schema:
+        return signal_schema.values
    columns = _convert_col_names(col_names)  # type: ignore[arg-type]
    hf_schema = _get_hf_schema(schema)
    if hf_schema:
@@ -197,3 +212,33 @@ def _get_hf_schema(
         features = schema_from_arrow(schema)
         return features, get_output_schema(features)
     return None
+
+
+def _get_datachain_schema(schema: "pa.Schema") -> Optional[SignalSchema]:
+    """Return a restored SignalSchema from parquet metadata, if any is found."""
+    if schema.metadata and DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY in schema.metadata:
+        serialized_signal_schema = orjson.loads(
+            schema.metadata[DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY]
+        )
+        return SignalSchema.deserialize(serialized_signal_schema)
+    return None
+
+
+def _nested_model_instantiate(
+    column_values: dict[str, Any], model: type["BaseModel"], prefix: str = ""
+) -> "BaseModel":
+    """Instantiate the given model and all sub-models/fields based on the provided
+    column values."""
+    vals_dict = {}
+    for field, field_info in model.model_fields.items():
+        anno = field_info.annotation
+        cur_path = f"{prefix}.{field}" if prefix else field
+        if ModelStore.is_pydantic(anno):
+            vals_dict[field] = _nested_model_instantiate(
+                column_values,
+                anno,  # type: ignore[arg-type]
+                prefix=cur_path,
+            )
+        elif cur_path in column_values:
+            vals_dict[field] = column_values[cur_path]
+    return model(**vals_dict)
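For intuition about _nested_model_instantiate: it rebuilds nested pydantic objects from flat records whose keys are dot-separated paths. A small standalone sketch with hypothetical models (not datachain code):

    from pydantic import BaseModel

    class GPS(BaseModel):
        lat: float
        lon: float

    class Photo(BaseModel):
        name: str
        gps: GPS

    # Flat record with dotted paths, as the new parquet reader sees it.
    record = {"name": "IMG_001.jpg", "gps.lat": 52.52, "gps.lon": 13.40}

    # Hand-rolled equivalent of the recursion for this two-level case.
    photo = Photo(
        name=record["name"],
        gps=GPS(lat=record["gps.lat"], lon=record["gps.lon"]),
    )
    print(photo)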
datachain/lib/dc.py CHANGED
@@ -16,6 +16,7 @@ from typing import (
     overload,
 )
 
+import orjson
 import pandas as pd
 import sqlalchemy
 from pydantic import BaseModel
@@ -58,7 +59,7 @@ from datachain.query.dataset import (
 from datachain.query.schema import DEFAULT_DELIMITER, Column, DatasetRow
 from datachain.sql.functions import path as pathfunc
 from datachain.telemetry import telemetry
-from datachain.utils import inside_notebook
+from datachain.utils import batched_it, inside_notebook
 
 if TYPE_CHECKING:
     from typing_extensions import Concatenate, ParamSpec, Self
@@ -71,6 +72,10 @@ C = Column
 
 _T = TypeVar("_T")
 D = TypeVar("D", bound="DataChain")
+UDFObjT = TypeVar("UDFObjT", bound=UDFBase)
+
+
+DEFAULT_PARQUET_CHUNK_SIZE = 100_000
 
 
 def resolve_columns(
@@ -819,7 +824,7 @@ class DataChain:
 
     def gen(
         self,
-        func: Optional[Callable] = None,
+        func: Optional[Union[Callable, Generator]] = None,
         params: Union[None, str, Sequence[str]] = None,
         output: OutputType = None,
         **signal_map,
@@ -931,12 +936,12 @@ class DataChain:
 
     def _udf_to_obj(
         self,
-        target_class: type[UDFBase],
-        func: Optional[Callable],
+        target_class: type[UDFObjT],
+        func: Optional[Union[Callable, UDFObjT]],
         params: Union[None, str, Sequence[str]],
         output: OutputType,
         signal_map,
-    ) -> UDFBase:
+    ) -> UDFObjT:
         is_generator = target_class.is_output_batched
         name = self.name or ""
 
@@ -1103,6 +1108,29 @@ class DataChain:
         rows = (row_factory(db_signals, r) for r in rows)
         yield from rows
 
+    def to_columnar_data_with_names(
+        self, chunk_size: int = DEFAULT_PARQUET_CHUNK_SIZE
+    ) -> tuple[list[str], Iterator[list[list[Any]]]]:
+        """Returns column names and the results as an iterator that provides chunks,
+        with each chunk containing a list of columns, where each column contains a
+        list of the row values for that column in that chunk. Useful for columnar data
+        formats, such as parquet or other OLAP databases.
+        """
+        headers, _ = self._effective_signals_schema.get_headers_with_length()
+        column_names = [".".join(filter(None, header)) for header in headers]
+
+        results_iter = self.collect_flatten()
+
+        def column_chunks() -> Iterator[list[list[Any]]]:
+            for chunk_iter in batched_it(results_iter, chunk_size):
+                columns: list[list[Any]] = [[] for _ in column_names]
+                for row in chunk_iter:
+                    for i, col in enumerate(columns):
+                        col.append(row[i])
+                yield columns
+
+        return column_names, column_chunks()
+
     @overload
     def results(self) -> list[tuple[Any, ...]]: ...
 
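A rough sketch of how this columnar iterator could be consumed, assuming `chain` is an existing DataChain instance (illustrative only):

    column_names, chunks = chain.to_columnar_data_with_names(chunk_size=10_000)
    for chunk in chunks:
        # chunk is a list of columns; each column holds the values for that chunk
        for name, values in zip(column_names, chunk):
            print(name, len(values))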
@@ -1808,21 +1836,96 @@ class DataChain:
         self,
         path: Union[str, os.PathLike[str], BinaryIO],
         partition_cols: Optional[Sequence[str]] = None,
+        chunk_size: int = DEFAULT_PARQUET_CHUNK_SIZE,
         **kwargs,
     ) -> None:
-        """Save chain to parquet file.
+        """Save chain to parquet file with SignalSchema metadata.
 
         Parameters:
             path : Path or a file-like binary object to save the file.
             partition_cols : Column names by which to partition the dataset.
+            chunk_size : The chunk size of results to read and convert to columnar
+                data, to avoid running out of memory.
         """
+        import pyarrow as pa
+        import pyarrow.parquet as pq
+
+        from datachain.lib.arrow import DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY
+
         _partition_cols = list(partition_cols) if partition_cols else None
-        return self.to_pandas().to_parquet(
-            path,
-            partition_cols=_partition_cols,
-            **kwargs,
+        signal_schema_metadata = orjson.dumps(
+            self._effective_signals_schema.serialize()
         )
 
+        column_names, column_chunks = self.to_columnar_data_with_names(chunk_size)
+
+        parquet_schema = None
+        parquet_writer = None
+        first_chunk = True
+
+        for chunk in column_chunks:
+            # pyarrow infers the best parquet schema from the python types of
+            # the input data.
+            table = pa.Table.from_pydict(
+                dict(zip(column_names, chunk)),
+                schema=parquet_schema,
+            )
+
+            # Preserve any existing metadata, and add the DataChain SignalSchema.
+            existing_metadata = table.schema.metadata or {}
+            merged_metadata = {
+                **existing_metadata,
+                DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY: signal_schema_metadata,
+            }
+            table = table.replace_schema_metadata(merged_metadata)
+            parquet_schema = table.schema
+
+            if _partition_cols:
+                # Write to a partitioned parquet dataset.
+                pq.write_to_dataset(
+                    table,
+                    root_path=path,
+                    partition_cols=_partition_cols,
+                    **kwargs,
+                )
+            else:
+                if first_chunk:
+                    # Write to a single parquet file.
+                    parquet_writer = pq.ParquetWriter(path, parquet_schema, **kwargs)
+                    first_chunk = False
+
+                assert parquet_writer
+                parquet_writer.write_table(table)
+
+        if parquet_writer:
+            parquet_writer.close()
+
+    def to_csv(
+        self,
+        path: Union[str, os.PathLike[str]],
+        delimiter: str = ",",
+        **kwargs,
+    ) -> None:
+        """Save chain to a csv (comma-separated values) file.
+
+        Parameters:
+            path : Path to save the file.
+            delimiter : Delimiter to use for the resulting file.
+        """
+        import csv
+
+        headers, _ = self._effective_signals_schema.get_headers_with_length()
+        column_names = [".".join(filter(None, header)) for header in headers]
+
+        results_iter = self.collect_flatten()
+
+        with open(path, "w", newline="") as f:
+            writer = csv.writer(f, delimiter=delimiter, **kwargs)
+            writer.writerow(column_names)
+
+            for row in results_iter:
+                writer.writerow(row)
+
     @classmethod
     def from_records(
         cls,
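The new writers can be exercised roughly as below; the chain construction is illustrative, but the parquet metadata key is the one added in arrow.py:

    import pyarrow.parquet as pq

    from datachain.lib.dc import DataChain

    chain = DataChain.from_values(num=[1, 2, 3])  # illustrative input
    chain.to_parquet("numbers.parquet")
    chain.to_csv("numbers.csv", delimiter=";")

    # The serialized SignalSchema travels in the parquet schema metadata.
    schema = pq.read_schema("numbers.parquet")
    print(b"DataChain SignalSchema" in (schema.metadata or {}))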
datachain/lib/udf.py CHANGED
@@ -1,31 +1,33 @@
 import sys
 import traceback
-from typing import TYPE_CHECKING, Callable, Optional
+from collections.abc import Iterable, Iterator, Mapping, Sequence
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Callable, Optional
 
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from pydantic import BaseModel
 
 from datachain.dataset import RowDict
 from datachain.lib.convert.flatten import flatten
-from datachain.lib.convert.unflatten import unflatten_to_json
 from datachain.lib.file import File
-from datachain.lib.model_store import ModelStore
 from datachain.lib.signal_schema import SignalSchema
-from datachain.lib.udf_signature import UdfSignature
 from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
-from datachain.query.batch import UDFInputBatch
-from datachain.query.schema import ColumnParameter
-from datachain.query.udf import UDFBase as _UDFBase
-from datachain.query.udf import UDFProperties
+from datachain.query.batch import (
+    Batch,
+    BatchingStrategy,
+    NoBatching,
+    Partition,
+    RowsOutputBatch,
+    UDFInputBatch,
+)
+from datachain.query.schema import ColumnParameter, UDFParameter
 
 if TYPE_CHECKING:
-    from collections.abc import Iterable, Iterator, Sequence
-
     from typing_extensions import Self
 
     from datachain.catalog import Catalog
+    from datachain.lib.udf_signature import UdfSignature
     from datachain.query.batch import RowsOutput, UDFInput
-    from datachain.query.udf import UDFResult
 
 
 class UdfError(DataChainParamsError):
@@ -33,14 +35,47 @@ class UdfError(DataChainParamsError):
         super().__init__(f"UDF error: {msg}")
 
 
-class UDFAdapter(_UDFBase):
+ColumnType = Any
+
+# Specification for the output of a UDF
+UDFOutputSpec = Mapping[str, ColumnType]
+
+# Result type when calling the UDF wrapper around the actual
+# Python function / class implementing it.
+UDFResult = dict[str, Any]
+
+
+@dataclass
+class UDFProperties:
+    """Container for basic UDF properties."""
+
+    params: list[UDFParameter]
+    output: UDFOutputSpec
+    batch: int = 1
+
+    def get_batching(self, use_partitioning: bool = False) -> BatchingStrategy:
+        if use_partitioning:
+            return Partition()
+        if self.batch == 1:
+            return NoBatching()
+        if self.batch > 1:
+            return Batch(self.batch)
+        raise ValueError(f"invalid batch size {self.batch}")
+
+    def signal_names(self) -> Iterable[str]:
+        return self.output.keys()
+
+
+class UDFAdapter:
     def __init__(
         self,
         inner: "UDFBase",
         properties: UDFProperties,
     ):
         self.inner = inner
-        super().__init__(properties)
+        self.properties = properties
+        self.signal_names = properties.signal_names()
+        self.output = properties.output
 
     def run(
         self,
@@ -51,20 +86,23 @@ class UDFAdapter(_UDFBase):
         cache: bool,
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
-    ) -> "Iterator[Iterable[UDFResult]]":
-        self.inner._catalog = catalog
+    ) -> Iterator[Iterable[UDFResult]]:
+        self.inner.catalog = catalog
         if hasattr(self.inner, "setup") and callable(self.inner.setup):
             self.inner.setup()
 
-        yield from super().run(
-            udf_fields,
-            udf_inputs,
-            catalog,
-            is_generator,
-            cache,
-            download_cb,
-            processed_cb,
-        )
+        for batch in udf_inputs:
+            if isinstance(batch, RowsOutputBatch):
+                n_rows = len(batch.rows)
+                inputs: UDFInput = UDFInputBatch(
+                    [RowDict(zip(udf_fields, row)) for row in batch.rows]
+                )
+            else:
+                n_rows = 1
+                inputs = RowDict(zip(udf_fields, batch))
+            output = self.run_once(catalog, inputs, is_generator, cache, cb=download_cb)
+            processed_cb.relative_update(n_rows)
+            yield output
 
         if hasattr(self.inner, "teardown") and callable(self.inner.teardown):
             self.inner.teardown()
@@ -76,23 +114,46 @@ class UDFAdapter(_UDFBase):
         is_generator: bool = False,
         cache: bool = False,
         cb: Callback = DEFAULT_CALLBACK,
-    ) -> "Iterable[UDFResult]":
+    ) -> Iterable[UDFResult]:
         if isinstance(arg, UDFInputBatch):
             udf_inputs = [
                 self.bind_parameters(catalog, row, cache=cache, cb=cb)
                 for row in arg.rows
             ]
-            udf_outputs = self.inner(udf_inputs, cache=cache, download_cb=cb)
+            udf_outputs = self.inner.run_once(udf_inputs, cache=cache, download_cb=cb)
             return self._process_results(arg.rows, udf_outputs, is_generator)
         if isinstance(arg, RowDict):
             udf_inputs = self.bind_parameters(catalog, arg, cache=cache, cb=cb)
-            udf_outputs = self.inner(*udf_inputs, cache=cache, download_cb=cb)
+            udf_outputs = self.inner.run_once(udf_inputs, cache=cache, download_cb=cb)
             if not is_generator:
                 # udf_outputs is generator already if is_generator=True
                 udf_outputs = [udf_outputs]
             return self._process_results([arg], udf_outputs, is_generator)
         raise ValueError(f"Unexpected UDF argument: {arg}")
 
+    def bind_parameters(self, catalog: "Catalog", row: "RowDict", **kwargs) -> list:
+        return [p.get_value(catalog, row, **kwargs) for p in self.properties.params]
+
+    def _process_results(
+        self,
+        rows: Sequence["RowDict"],
+        results: Sequence[Sequence[Any]],
+        is_generator=False,
+    ) -> Iterable[UDFResult]:
+        """Create a list of dictionaries representing UDF results."""
+
+        # outputting rows
+        if is_generator:
+            # each row in results is a tuple of column values
+            return (dict(zip(self.signal_names, row)) for row in results)
+
+        # outputting signals
+        row_ids = [row["sys__id"] for row in rows]
+        return [
+            {"sys__id": row_id} | dict(zip(self.signal_names, signals))
+            for row_id, signals in zip(row_ids, results)
+        ]
+
 
 class UDFBase(AbstractUDF):
     """Base class for stateful user-defined functions.
@@ -146,14 +207,14 @@
     is_output_batched = False
     is_input_grouped = False
     params_spec: Optional[list[str]]
+    catalog: "Optional[Catalog]"
 
     def __init__(self):
         self.params = None
         self.output = None
         self.params_spec = None
         self.output_spec = None
-        self._contains_stream = None
-        self._catalog = None
+        self.catalog = None
         self._func = None
 
     def process(self, *args, **kwargs):
@@ -174,9 +235,9 @@
 
     def _init(
         self,
-        sign: UdfSignature,
+        sign: "UdfSignature",
         params: SignalSchema,
-        func: Callable,
+        func: Optional[Callable],
     ):
         self.params = params
         self.output = sign.output_schema
@@ -190,13 +251,13 @@
     @classmethod
     def _create(
         cls,
-        sign: UdfSignature,
+        sign: "UdfSignature",
         params: SignalSchema,
     ) -> "Self":
         if isinstance(sign.func, AbstractUDF):
             if not isinstance(sign.func, cls):  # type: ignore[unreachable]
                 raise UdfError(
-                    f"cannot create UDF: provided UDF '{sign.func.__name__}'"
+                    f"cannot create UDF: provided UDF '{type(sign.func).__name__}'"
                     f" must be a child of target class '{cls.__name__}'",
                 )
             result = sign.func
@@ -212,13 +273,6 @@
     def name(self):
         return self.__class__.__name__
 
-    def set_catalog(self, catalog):
-        self._catalog = catalog.copy(db=False)
-
-    @property
-    def catalog(self):
-        return self._catalog
-
     def to_udf_wrapper(self, batch: int = 1) -> UDFAdapter:
         assert self.params_spec is not None
         properties = UDFProperties(
@@ -229,11 +283,9 @@
     def validate_results(self, results, *args, **kwargs):
         return results
 
-    def __call__(self, *rows, cache, download_cb):
-        if self.is_input_grouped:
-            objs = self._parse_grouped_rows(rows[0], cache, download_cb)
-        elif self.is_input_batched:
-            objs = zip(*self._parse_rows(rows[0], cache, download_cb))
+    def run_once(self, rows, cache, download_cb):
+        if self.is_input_batched:
+            objs = zip(*self._parse_rows(rows, cache, download_cb))
         else:
             objs = self._parse_rows([rows], cache, download_cb)[0]
 
@@ -259,8 +311,8 @@
         ):
             res = list(res)
             assert len(res) == len(
-                rows[0]
-            ), f"{self.name} returns {len(res)} rows while len(rows[0]) expected"
+                rows
+            ), f"{self.name} returns {len(res)} rows while {len(rows)} expected"
 
         return res
 
@@ -283,41 +335,11 @@
             for obj in obj_row:
                 if isinstance(obj, File):
                     obj._set_stream(
-                        self._catalog, caching_enabled=cache, download_cb=download_cb
+                        self.catalog, caching_enabled=cache, download_cb=download_cb
                     )
             objs.append(obj_row)
         return objs
 
-    def _parse_grouped_rows(self, group, cache, download_cb):
-        spec_map = {}
-        output_map = {}
-        for name, (anno, subtree) in self.params.tree.items():
-            if ModelStore.is_pydantic(anno):
-                length = sum(1 for _ in self.params._get_flat_tree(subtree, [], 0))
-            else:
-                length = 1
-            spec_map[name] = anno, length
-            output_map[name] = []
-
-        for flat_obj in group:
-            position = 0
-            for signal, (cls, length) in spec_map.items():
-                slice = flat_obj[position : position + length]
-                position += length
-
-                if ModelStore.is_pydantic(cls):
-                    obj = cls(**unflatten_to_json(cls, slice))
-                else:
-                    obj = slice[0]
-
-                if isinstance(obj, File):
-                    obj._set_stream(
-                        self._catalog, caching_enabled=cache, download_cb=download_cb
-                    )
-                output_map[signal].append(obj)
-
-        return list(output_map.values())
-
     def process_safe(self, obj_rows):
         try:
             result_objs = self.process(*obj_rows)
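UDFProperties.get_batching, now defined alongside UDFAdapter in datachain/lib/udf.py, chooses a batching strategy purely from the batch size. A self-contained sketch of the same decision with stand-in strategy classes (the real ones live in datachain.query.batch):

    class NoBatching: ...
    class Partition: ...

    class Batch:
        def __init__(self, count: int):
            self.count = count

    def get_batching(batch: int, use_partitioning: bool = False):
        # Same order of checks as UDFProperties.get_batching
        if use_partitioning:
            return Partition()
        if batch == 1:
            return NoBatching()
        if batch > 1:
            return Batch(batch)
        raise ValueError(f"invalid batch size {batch}")

    print(type(get_batching(1)).__name__)   # NoBatching
    print(type(get_batching(32)).__name__)  # Batch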
@@ -1,10 +1,11 @@
 import inspect
 from collections.abc import Generator, Iterator, Sequence
 from dataclasses import dataclass
-from typing import Callable, Optional, Union, get_args, get_origin
+from typing import Callable, Union, get_args, get_origin
 
 from datachain.lib.data_model import DataType, DataTypeNames, is_chain_type
 from datachain.lib.signal_schema import SignalSchema
+from datachain.lib.udf import UDFBase
 from datachain.lib.utils import AbstractUDF, DataChainParamsError
 
 
@@ -16,7 +17,7 @@ class UdfSignatureError(DataChainParamsError):
 
 @dataclass
 class UdfSignature:
-    func: Callable
+    func: Union[Callable, UDFBase]
     params: Sequence[str]
     output_schema: SignalSchema
 
@@ -27,7 +28,7 @@
         cls,
         chain: str,
         signal_map: dict[str, Callable],
-        func: Optional[Callable] = None,
+        func: Union[None, UDFBase, Callable] = None,
         params: Union[None, str, Sequence[str]] = None,
         output: Union[None, DataType, Sequence[str], dict[str, DataType]] = None,
         is_generator: bool = True,
@@ -39,6 +40,7 @@
                 f"multiple signals '{keys}' are not supported in processors."
                 " Chain multiple processors instead.",
             )
+        udf_func: Union[UDFBase, Callable]
        if len(signal_map) == 1:
            if func is not None:
                raise UdfSignatureError(
@@ -53,7 +55,7 @@
             udf_func = func
             signal_name = None
 
-        if not callable(udf_func):
+        if not isinstance(udf_func, UDFBase) and not callable(udf_func):
             raise UdfSignatureError(chain, f"UDF '{udf_func}' is not callable")
 
         func_params_map_sign, func_outs_sign, is_iterator = (
@@ -73,7 +75,7 @@
         if not func_outs_sign:
             raise UdfSignatureError(
                 chain,
-                f"outputs are not defined in function '{udf_func.__name__}'"
+                f"outputs are not defined in function '{udf_func}'"
                 " hints or 'output'",
             )
 
@@ -154,7 +156,7 @@
 
     @staticmethod
    def _func_signature(
-        chain: str, udf_func: Callable
+        chain: str, udf_func: Union[Callable, UDFBase]
     ) -> tuple[dict[str, type], Sequence[type], bool]:
         if isinstance(udf_func, AbstractUDF):
             func = udf_func.process  # type: ignore[unreachable]
@@ -42,6 +42,7 @@ from datachain.data_storage.schema import (
 )
 from datachain.dataset import DatasetStatus, RowDict
 from datachain.error import DatasetNotFoundError, QueryScriptCancelError
+from datachain.lib.udf import UDFAdapter
 from datachain.progress import CombinedDownloadCallback
 from datachain.sql.functions import rand
 from datachain.utils import (
@@ -53,7 +54,6 @@ from datachain.utils import (
 
 from .schema import C, UDFParamSpec, normalize_param
 from .session import Session
-from .udf import UDFBase
 
 if TYPE_CHECKING:
     from sqlalchemy.sql.elements import ClauseElement
@@ -299,7 +299,7 @@ def adjust_outputs(
     return row
 
 
-def get_udf_col_types(warehouse: "AbstractWarehouse", udf: UDFBase) -> list[tuple]:
+def get_udf_col_types(warehouse: "AbstractWarehouse", udf: UDFAdapter) -> list[tuple]:
     """Optimization: Precompute UDF column types so these don't have to be computed
     in the convert_type function for each row in a loop."""
     dialect = warehouse.db.dialect
@@ -320,7 +320,7 @@ def process_udf_outputs(
     warehouse: "AbstractWarehouse",
     udf_table: "Table",
     udf_results: Iterator[Iterable["UDFResult"]],
-    udf: UDFBase,
+    udf: UDFAdapter,
     batch_size: int = INSERT_BATCH_SIZE,
     cb: Callback = DEFAULT_CALLBACK,
 ) -> None:
@@ -364,7 +364,7 @@ def get_generated_callback(is_generator: bool = False) -> Callback:
 
 @frozen
 class UDFStep(Step, ABC):
-    udf: UDFBase
+    udf: UDFAdapter
     catalog: "Catalog"
     partition_by: Optional[PartitionByType] = None
     parallel: Optional[int] = None
@@ -1465,7 +1465,7 @@ class DatasetQuery:
     @detach
     def add_signals(
         self,
-        udf: UDFBase,
+        udf: UDFAdapter,
         parallel: Optional[int] = None,
         workers: Union[bool, int] = False,
         min_task_size: Optional[int] = None,
@@ -1509,7 +1509,7 @@ class DatasetQuery:
     @detach
     def generate(
         self,
-        udf: UDFBase,
+        udf: UDFAdapter,
         parallel: Optional[int] = None,
         workers: Union[bool, int] = False,
         min_task_size: Optional[int] = None,
@@ -13,6 +13,7 @@ from multiprocess import get_context
 
 from datachain.catalog import Catalog
 from datachain.catalog.loader import get_distributed_class
+from datachain.lib.udf import UDFAdapter, UDFResult
 from datachain.query.dataset import (
     get_download_callback,
     get_generated_callback,
@@ -27,7 +28,6 @@ from datachain.query.queue import (
     put_into_queue,
     unmarshal,
 )
-from datachain.query.udf import UDFBase, UDFResult
 from datachain.utils import batched_it
 
 DEFAULT_BATCH_SIZE = 10000
@@ -336,7 +336,7 @@ class ProcessedCallback(Callback):
 @attrs.define
 class UDFWorker:
     catalog: Catalog
-    udf: UDFBase
+    udf: UDFAdapter
     task_queue: "multiprocess.Queue"
     done_queue: "multiprocess.Queue"
     is_generator: bool
@@ -1,5 +1,8 @@
 import atexit
+import logging
+import os
 import re
+import sys
 from typing import TYPE_CHECKING, Optional
 from uuid import uuid4
 
@@ -9,6 +12,8 @@ from datachain.error import TableMissingError
 if TYPE_CHECKING:
     from datachain.catalog import Catalog
 
+logger = logging.getLogger("datachain")
+
 
 class Session:
     """
@@ -35,6 +40,7 @@ class Session:
 
     GLOBAL_SESSION_CTX: Optional["Session"] = None
     GLOBAL_SESSION: Optional["Session"] = None
+    ORIGINAL_EXCEPT_HOOK = None
 
     DATASET_PREFIX = "session_"
     GLOBAL_SESSION_NAME = "global"
@@ -58,6 +64,7 @@ class Session:
 
         session_uuid = uuid4().hex[: self.SESSION_UUID_LEN]
         self.name = f"{name}_{session_uuid}"
+        self.job_id = os.getenv("DATACHAIN_JOB_ID") or str(uuid4())
         self.is_new_catalog = not catalog
         self.catalog = catalog or get_catalog(
             client_config=client_config, in_memory=in_memory
@@ -67,6 +74,9 @@ class Session:
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
+        if exc_type:
+            self._cleanup_created_versions(self.name)
+
         self._cleanup_temp_datasets()
         if self.is_new_catalog:
             self.catalog.metastore.close_on_exit()
@@ -88,6 +98,21 @@ class Session:
             except TableMissingError:
                 pass
 
+    def _cleanup_created_versions(self, job_id: str) -> None:
+        versions = self.catalog.metastore.get_job_dataset_versions(job_id)
+        if not versions:
+            return
+
+        datasets = {}
+        for dataset_name, version in versions:
+            if dataset_name not in datasets:
+                datasets[dataset_name] = self.catalog.get_dataset(dataset_name)
+            dataset = datasets[dataset_name]
+            logger.info(
+                "Removing dataset version %s@%s due to exception", dataset_name, version
+            )
+            self.catalog.remove_dataset_version(dataset, version)
+
     @classmethod
     def get(
         cls,
@@ -114,9 +139,23 @@ class Session:
                 in_memory=in_memory,
             )
             cls.GLOBAL_SESSION = cls.GLOBAL_SESSION_CTX.__enter__()
+
             atexit.register(cls._global_cleanup)
+            cls.ORIGINAL_EXCEPT_HOOK = sys.excepthook
+            sys.excepthook = cls.except_hook
+
         return cls.GLOBAL_SESSION
 
+    @staticmethod
+    def except_hook(exc_type, exc_value, exc_traceback):
+        Session._global_cleanup()
+        if Session.GLOBAL_SESSION_CTX is not None:
+            job_id = Session.GLOBAL_SESSION_CTX.job_id
+            Session.GLOBAL_SESSION_CTX._cleanup_created_versions(job_id)
+
+        if Session.ORIGINAL_EXCEPT_HOOK:
+            Session.ORIGINAL_EXCEPT_HOOK(exc_type, exc_value, exc_traceback)
+
     @classmethod
     def cleanup_for_tests(cls):
         if cls.GLOBAL_SESSION_CTX is not None:
@@ -125,6 +164,9 @@ class Session:
             cls.GLOBAL_SESSION_CTX = None
             atexit.unregister(cls._global_cleanup)
 
+        if cls.ORIGINAL_EXCEPT_HOOK:
+            sys.excepthook = cls.ORIGINAL_EXCEPT_HOOK
+
     @staticmethod
     def _global_cleanup():
         if Session.GLOBAL_SESSION_CTX is not None:
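The excepthook wiring follows the usual save-wrap-delegate pattern: keep the original hook, run cleanup, then hand the exception back so the traceback still prints. A generic sketch of the same idea outside datachain:

    import sys

    _original_hook = sys.excepthook

    def cleanup_then_report(exc_type, exc_value, exc_tb):
        # Best-effort cleanup first (datachain removes the failed job's
        # dataset versions here), then defer to the saved hook.
        print("cleaning up after", exc_type.__name__, file=sys.stderr)
        _original_hook(exc_type, exc_value, exc_tb)

    sys.excepthook = cleanup_then_report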
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: datachain
-Version: 0.5.0
+Version: 0.5.1
 Summary: Wrangle unstructured AI data at scale
 Author-email: Dmitry Petrov <support@dvc.org>
 License: Apache-2.0
@@ -18,7 +18,7 @@ datachain/storage.py,sha256=RiSJLYdHUjnrEWkLBKPcETHpAxld_B2WxLg711t0aZI,3733
 datachain/telemetry.py,sha256=0A4IOPPp9VlP5pyW9eBfaTK3YhHGzHl7dQudQjUAx9A,994
 datachain/utils.py,sha256=KeFSRHsiYthnTu4a6bH-rw04mX1m8krTX0f2NqfQGFI,12114
 datachain/catalog/__init__.py,sha256=g2iAAFx_gEIrqshXlhSEbrc8qDaEH11cjU40n3CHDz4,409
-datachain/catalog/catalog.py,sha256=FuKuIiCwPgN5Ea25hnFe_ZFZH9YEUZ2ma9k_Lczk-JU,63867
+datachain/catalog/catalog.py,sha256=BsMyk2RQibQYHgrmovFZeSEpPVMTwgb_7ntVYdc7t-E,64090
 datachain/catalog/datasource.py,sha256=D-VWIVDCM10A8sQavLhRXdYSCG7F4o4ifswEF80_NAQ,1412
 datachain/catalog/loader.py,sha256=-6VelNfXUdgUnwInVyA8g86Boxv2xqhTh9xNS-Zlwig,8242
 datachain/client/__init__.py,sha256=T4wiYL9KIM0ZZ_UqIyzV8_ufzYlewmizlV4iymHNluE,86
@@ -33,17 +33,17 @@ datachain/data_storage/__init__.py,sha256=cEOJpyu1JDZtfUupYucCDNFI6e5Wmp_Oyzq6rZ
 datachain/data_storage/db_engine.py,sha256=81Ol1of9TTTzD97ORajCnP366Xz2mEJt6C-kTUCaru4,3406
 datachain/data_storage/id_generator.py,sha256=lCEoU0BM37Ai2aRpSbwo5oQT0GqZnSpYwwvizathRMQ,4292
 datachain/data_storage/job.py,sha256=w-7spowjkOa1P5fUVtJou3OltT0L48P0RYWZ9rSJ9-s,383
-datachain/data_storage/metastore.py,sha256=NV4FJ_W16Q19Sx70i5Qtre-n4DC2kMD0qw0vBz3j7Ks,52228
+datachain/data_storage/metastore.py,sha256=HfCxk4lmDUg2Q4WsFNQGMWxllP0mToA00fxkFTwdNIE,52919
 datachain/data_storage/schema.py,sha256=AGbjyEir5UmRZXI3m0jChZogUh5wd8csj6-YlUWaAxQ,8383
 datachain/data_storage/serializer.py,sha256=6G2YtOFqqDzJf1KbvZraKGXl2XHZyVml2krunWUum5o,927
-datachain/data_storage/sqlite.py,sha256=EBKJncuzcyQfcKFm2mUjvHjHRTODsteM-k_zndunBrw,28834
+datachain/data_storage/sqlite.py,sha256=fW08P7AbJ0cDbTbcTKuAGpvMXvBjg-QkGsKT_Dslyws,28383
 datachain/data_storage/warehouse.py,sha256=fXhVfao3NfWFGbbG5uJ-Ga4bX1FiKVfcbDyQgECYfk8,32122
 datachain/lib/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-datachain/lib/arrow.py,sha256=aUsoQmxDmuSnB8Ik9p57Y66gc_dgx6NBqkDDIfLsvno,7630
+datachain/lib/arrow.py,sha256=0R2CYsN82nNa5_03iS6jVix9EKeeqNZNAMgpSQP2hfo,9482
 datachain/lib/clip.py,sha256=lm5CzVi4Cj1jVLEKvERKArb-egb9j1Ls-fwTItT6vlI,6150
 datachain/lib/data_model.py,sha256=gHIjlow84GMRDa78yLL1Ud-N18or21fnTyPEwsatpXY,2045
 datachain/lib/dataset_info.py,sha256=srPPhI2UHf6hFPBecyFEVw2SS5aPisIIMsvGgKqi7ss,2366
-datachain/lib/dc.py,sha256=yTyHrKIswCzdlvl2n-wdEVZEEF5VQpkLJPzPfUL9CTU,72054
+datachain/lib/dc.py,sha256=HLOAkJEKFHJV_PqwSu0Pyl1m7JmUea8_wiMJFr14Nfk,75960
 datachain/lib/file.py,sha256=LjTW_-PDAnoUhvyB4bJ8Y8n__XGqrxvmd9mDOF0Gir8,14875
 datachain/lib/hf.py,sha256=cPnmLuprr0pYABH7KqA5FARQ1JGlywdDwD3yDzVAm4k,5920
 datachain/lib/image.py,sha256=AMXYwQsmarZjRbPCZY3M1jDsM2WAB_b3cTY4uOIuXNU,2675
@@ -56,8 +56,8 @@ datachain/lib/settings.py,sha256=39thOpYJw-zPirzeNO6pmRC2vPrQvt4eBsw1xLWDFsw,234
 datachain/lib/signal_schema.py,sha256=iqgubjCBRiUJB30miv05qFX4uU04dA_Pzi3DCUsHZGs,24177
 datachain/lib/tar.py,sha256=3WIzao6yD5fbLqXLTt9GhPGNonbFIs_fDRu-9vgLgsA,1038
 datachain/lib/text.py,sha256=UNHm8fhidk7wdrWqacEWaA6I9ykfYqarQ2URby7jc7M,1261
-datachain/lib/udf.py,sha256=nG7DDuPgZ5ZuijwvDoCq-OZMxlDM8vFNzyxMmik0Y1c,11716
-datachain/lib/udf_signature.py,sha256=gMStcEeYJka5M6cg50Z9orC6y6HzCAJ3MkFqqn1fjZg,7137
+datachain/lib/udf.py,sha256=oHhJWb0gVTxcybGzYDzAeN0Gb1IMhZBoGefncT88dIY,12339
+datachain/lib/udf_signature.py,sha256=GXw24A-Olna6DWCdgy2bC-gZh_gLGPQ-KvjuI6pUjC0,7281
 datachain/lib/utils.py,sha256=5-kJlAZE0D9nXXweAjo7-SP_AWGo28feaDByONYaooQ,463
 datachain/lib/vfile.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 datachain/lib/webdataset.py,sha256=o7SHk5HOUWsZ5Ln04xOM04eQqiBHiJNO7xLgyVBrwo8,6924
@@ -70,14 +70,13 @@ datachain/lib/convert/unflatten.py,sha256=Ogvh_5wg2f38_At_1lN0D_e2uZOOpYEvwvB2xd
 datachain/lib/convert/values_to_tuples.py,sha256=YOdbjzHq-uj6-cV2Qq43G72eN2avMNDGl4x5t6yQMl8,3931
 datachain/query/__init__.py,sha256=0NBOZVgIDpCcj1Ci883dQ9A0iiwe03xzmotkOCFbxYc,293
 datachain/query/batch.py,sha256=-vlpINJiertlnaoUVv1C95RatU0F6zuhpIYRufJRo1M,3660
-datachain/query/dataset.py,sha256=tLCTaj4K93BY93GgOPv9PknZByEF89zpHc7y9s8ZF_w,53610
-datachain/query/dispatch.py,sha256=CFAc09O6UllcyUSSEY1GUlEMPzeO8RYhXinNN4HBl9M,12405
+datachain/query/dataset.py,sha256=1c7y178ccFSeL_WIba0vT87Md_Oo4F8zaTVDjB9Bp3I,53641
+datachain/query/dispatch.py,sha256=JVcZ4REE_GOsqXbar_Cb_fk-pHgQoabQLzXwuu7IhOg,12409
 datachain/query/metrics.py,sha256=r5b0ygYhokbXp8Mg3kCH8iFSRw0jxzyeBe-C-J_bKFc,938
 datachain/query/params.py,sha256=O_j89mjYRLOwWNhYZl-z7mi-rkdP7WyFmaDufsdTryE,863
 datachain/query/queue.py,sha256=waqM_KzavU8C-G95-4211Nd4GXna_u2747Chgwtgz2w,3839
 datachain/query/schema.py,sha256=I8zLWJuWl5N332ni9mAzDYtcxMJupVPgWkSDe8spNEk,8019
-datachain/query/session.py,sha256=UPH5Z4fzCDsvj81ji0e8GA6Mgra3bOAEpVq4htqOtis,4317
-datachain/query/udf.py,sha256=HB2hbEuiGA4ch9P2mh9iLA5Jj9mRj-4JFy9VfjTLJ8U,3622
+datachain/query/session.py,sha256=kpFFJMfWBnxaMPojMGhJRbk-BOsSYI8Ckl6vvqnx7d0,5787
 datachain/remote/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 datachain/remote/studio.py,sha256=f5s6qSZ9uB4URGUoU_8_W1KZRRQQVSm6cgEBkBUEfuE,7226
 datachain/sql/__init__.py,sha256=A2djrbQwSMUZZEIKGnm-mnRA-NDSbiDJNpAmmwGNyIo,303
@@ -97,9 +96,9 @@ datachain/sql/sqlite/base.py,sha256=WLPHBhZbXbiqPoRV1VgDrXJqku4UuvJpBhYeQ0k5rI8,
 datachain/sql/sqlite/types.py,sha256=yzvp0sXSEoEYXs6zaYC_2YubarQoZH-MiUNXcpuEP4s,1573
 datachain/sql/sqlite/vector.py,sha256=ncW4eu2FlJhrP_CIpsvtkUabZlQdl2D5Lgwy_cbfqR0,469
 datachain/torch/__init__.py,sha256=gIS74PoEPy4TB3X6vx9nLO0Y3sLJzsA8ckn8pRWihJM,579
-datachain-0.5.0.dist-info/LICENSE,sha256=8DnqK5yoPI_E50bEg_zsHKZHY2HqPy4rYN338BHQaRA,11344
-datachain-0.5.0.dist-info/METADATA,sha256=tKSZNiHZY0WJ_w6irkpSF7qDfuOTfiYNEQ6St3eBs-M,17156
-datachain-0.5.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
-datachain-0.5.0.dist-info/entry_points.txt,sha256=0GMJS6B_KWq0m3VT98vQI2YZodAMkn4uReZ_okga9R4,49
-datachain-0.5.0.dist-info/top_level.txt,sha256=lZPpdU_2jJABLNIg2kvEOBi8PtsYikbN1OdMLHk8bTg,10
-datachain-0.5.0.dist-info/RECORD,,
+datachain-0.5.1.dist-info/LICENSE,sha256=8DnqK5yoPI_E50bEg_zsHKZHY2HqPy4rYN338BHQaRA,11344
+datachain-0.5.1.dist-info/METADATA,sha256=n8TFKjDmTzNBMaW5Oa6MUUUOAQbAjPzkAMaKCW3Y9NU,17156
+datachain-0.5.1.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+datachain-0.5.1.dist-info/entry_points.txt,sha256=0GMJS6B_KWq0m3VT98vQI2YZodAMkn4uReZ_okga9R4,49
+datachain-0.5.1.dist-info/top_level.txt,sha256=lZPpdU_2jJABLNIg2kvEOBi8PtsYikbN1OdMLHk8bTg,10
+datachain-0.5.1.dist-info/RECORD,,
datachain/query/udf.py DELETED
@@ -1,126 +0,0 @@
-import typing
-from collections.abc import Iterable, Iterator, Sequence
-from dataclasses import dataclass
-from typing import (
-    TYPE_CHECKING,
-    Any,
-)
-
-from fsspec.callbacks import DEFAULT_CALLBACK, Callback
-
-from datachain.dataset import RowDict
-
-from .batch import (
-    Batch,
-    BatchingStrategy,
-    NoBatching,
-    Partition,
-    RowsOutputBatch,
-    UDFInputBatch,
-)
-from .schema import UDFParameter
-
-if TYPE_CHECKING:
-    from datachain.catalog import Catalog
-
-    from .batch import RowsOutput, UDFInput
-
-ColumnType = Any
-
-
-# Specification for the output of a UDF
-UDFOutputSpec = typing.Mapping[str, ColumnType]
-
-# Result type when calling the UDF wrapper around the actual
-# Python function / class implementing it.
-UDFResult = dict[str, Any]
-
-
-@dataclass
-class UDFProperties:
-    """Container for basic UDF properties."""
-
-    params: list[UDFParameter]
-    output: UDFOutputSpec
-    batch: int = 1
-
-    def get_batching(self, use_partitioning: bool = False) -> BatchingStrategy:
-        if use_partitioning:
-            return Partition()
-        if self.batch == 1:
-            return NoBatching()
-        if self.batch > 1:
-            return Batch(self.batch)
-        raise ValueError(f"invalid batch size {self.batch}")
-
-    def signal_names(self) -> Iterable[str]:
-        return self.output.keys()
-
-
-class UDFBase:
-    """A base class for implementing stateful UDFs."""
-
-    def __init__(
-        self,
-        properties: UDFProperties,
-    ):
-        self.properties = properties
-        self.signal_names = properties.signal_names()
-        self.output = properties.output
-
-    def run(
-        self,
-        udf_fields: "Sequence[str]",
-        udf_inputs: "Iterable[RowsOutput]",
-        catalog: "Catalog",
-        is_generator: bool,
-        cache: bool,
-        download_cb: Callback = DEFAULT_CALLBACK,
-        processed_cb: Callback = DEFAULT_CALLBACK,
-    ) -> Iterator[Iterable["UDFResult"]]:
-        for batch in udf_inputs:
-            if isinstance(batch, RowsOutputBatch):
-                n_rows = len(batch.rows)
-                inputs: UDFInput = UDFInputBatch(
-                    [RowDict(zip(udf_fields, row)) for row in batch.rows]
-                )
-            else:
-                n_rows = 1
-                inputs = RowDict(zip(udf_fields, batch))
-            output = self.run_once(catalog, inputs, is_generator, cache, cb=download_cb)
-            processed_cb.relative_update(n_rows)
-            yield output
-
-    def run_once(
-        self,
-        catalog: "Catalog",
-        arg: "UDFInput",
-        is_generator: bool = False,
-        cache: bool = False,
-        cb: Callback = DEFAULT_CALLBACK,
-    ) -> Iterable[UDFResult]:
-        raise NotImplementedError
-
-    def bind_parameters(self, catalog: "Catalog", row: "RowDict", **kwargs) -> list:
-        return [p.get_value(catalog, row, **kwargs) for p in self.properties.params]
-
-    def _process_results(
-        self,
-        rows: Sequence["RowDict"],
-        results: Sequence[Sequence[Any]],
-        is_generator=False,
-    ) -> Iterable[UDFResult]:
-        """Create a list of dictionaries representing UDF results."""
-
-        # outputting rows
-        if is_generator:
-            # each row in results is a tuple of column values
-            return (dict(zip(self.signal_names, row)) for row in results)
-
-        # outputting signals
-        row_ids = [row["sys__id"] for row in rows]
-        return [
-            {"sys__id": row_id} | dict(zip(self.signal_names, signals))
-            for row_id, signals in zip(row_ids, results)
-            if signals is not None  # skip rows with no output
-        ]