datachain 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.


@@ -676,7 +676,7 @@ class Catalog:
 
     def parse_url(self, uri: str, **config: Any) -> tuple[Client, str]:
         config = config or self.client_config
-        return Client.parse_url(uri, self.metastore, self.cache, **config)
+        return Client.parse_url(uri, self.cache, **config)
 
     def get_client(self, uri: StorageURI, **config: Any) -> Client:
         """
@@ -1627,8 +1627,17 @@ class Catalog:
         version = self.get_dataset(dataset_name).get_version(dataset_version)
 
         file_signals_values = {}
+        file_schemas = {}
+        # TODO: To remove after we properly fix deserialization
+        for signal, type_name in version.feature_schema.items():
+            from datachain.lib.model_store import ModelStore
 
-        schema = SignalSchema.deserialize(version.feature_schema)
+            type_name_parsed, v = ModelStore.parse_name_version(type_name)
+            fr = ModelStore.get(type_name_parsed, v)
+            if fr and issubclass(fr, File):
+                file_schemas[signal] = type_name
+
+        schema = SignalSchema.deserialize(file_schemas)
         for file_signals in schema.get_signals(File):
             prefix = file_signals.replace(".", DEFAULT_DELIMITER) + DEFAULT_DELIMITER
             file_signals_values[file_signals] = {
@@ -37,7 +37,6 @@ from datachain.storage import StorageURI
 if TYPE_CHECKING:
     from fsspec.spec import AbstractFileSystem
 
-    from datachain.data_storage import AbstractMetastore
 
 logger = logging.getLogger("datachain")
 
@@ -116,13 +115,12 @@ class Client(ABC):
     @staticmethod
     def parse_url(
         source: str,
-        metastore: "AbstractMetastore",
         cache: DataChainCache,
         **kwargs,
     ) -> tuple["Client", str]:
         cls = Client.get_implementation(source)
         storage_url, rel_path = cls.split_url(source)
-        client = cls.from_name(storage_url, metastore, cache, kwargs)
+        client = cls.from_name(storage_url, cache, kwargs)
         return client, rel_path
 
     @classmethod
@@ -136,7 +134,6 @@ class Client(ABC):
     def from_name(
         cls,
         name: str,
-        metastore: "AbstractMetastore",
         cache: DataChainCache,
         kwargs: dict[str, Any],
     ) -> "Client":
datachain/client/local.py CHANGED
@@ -2,7 +2,7 @@ import os
 import posixpath
 from datetime import datetime, timezone
 from pathlib import Path
-from typing import TYPE_CHECKING, Any
+from typing import Any
 from urllib.parse import urlparse
 
 from fsspec.implementations.local import LocalFileSystem
@@ -12,9 +12,6 @@ from datachain.storage import StorageURI
 
 from .fsspec import Client
 
-if TYPE_CHECKING:
-    from datachain.data_storage import AbstractMetastore
-
 
 class FileClient(Client):
     FS_CLASS = LocalFileSystem
@@ -97,9 +94,7 @@ class FileClient(Client):
         return cls.root_dir(), uri.removeprefix(cls.root_path().as_uri())
 
     @classmethod
-    def from_name(
-        cls, name: str, metastore: "AbstractMetastore", cache, kwargs
-    ) -> "FileClient":
+    def from_name(cls, name: str, cache, kwargs) -> "FileClient":
        use_symlinks = kwargs.pop("use_symlinks", False)
        return cls(name, kwargs, cache, use_symlinks=use_symlinks)
 
@@ -67,7 +67,11 @@ def convert_rows_custom_column_types(
     for row in rows:
         row_list = list(row)
         for idx, t in custom_columns_types:
-            row_list[idx] = t.on_read_convert(row_list[idx], dialect)
+            row_list[idx] = (
+                t.default_value(dialect)
+                if row_list[idx] is None
+                else t.on_read_convert(row_list[idx], dialect)
+            )
 
         yield tuple(row_list)
 
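Illustrative sketch (not part of the diff): the hunk above changes row conversion so that a NULL coming back from the database takes the column type's default instead of being passed to the reader-side converter. A minimal, hypothetical stand-in for that pattern, using plain callables instead of datachain's SQLType objects:

import json

# Hypothetical converters: (column index, read-converter, default used for NULL).
converters = [(1, json.loads, [])]

def convert_rows(rows, converters):
    for row in rows:
        row = list(row)
        for idx, convert, default in converters:
            # NULLs skip the converter and take the default, mirroring the
            # `t.default_value(dialect)` branch in the hunk above.
            row[idx] = default if row[idx] is None else convert(row[idx])
        yield tuple(row)

print(list(convert_rows([(1, "[1, 2]"), (2, None)], converters)))
# [(1, [1, 2]), (2, [])]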
@@ -136,7 +140,15 @@ class DataTable:
         self.column_types: dict[str, SQLType] = column_types or {}
 
     @staticmethod
-    def copy_column(column: sa.Column):
+    def copy_column(
+        column: sa.Column,
+        primary_key: Optional[bool] = None,
+        index: Optional[bool] = None,
+        nullable: Optional[bool] = None,
+        default: Optional[Any] = None,
+        server_default: Optional[Any] = None,
+        unique: Optional[bool] = None,
+    ) -> sa.Column:
         """
         Copy a sqlalchemy Column object intended for use as a signal column.
 
@@ -150,12 +162,14 @@ class DataTable:
         return sa.Column(
             column.name,
             column.type,
-            primary_key=column.primary_key,
-            index=column.index,
-            nullable=column.nullable,
-            default=column.default,
-            server_default=column.server_default,
-            unique=column.unique,
+            primary_key=primary_key if primary_key is not None else column.primary_key,
+            index=index if index is not None else column.index,
+            nullable=nullable if nullable is not None else column.nullable,
+            default=default if default is not None else column.default,
+            server_default=(
+                server_default if server_default is not None else column.server_default
+            ),
+            unique=unique if unique is not None else column.unique,
         )
 
     @classmethod
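Illustrative sketch (not part of the diff): each new keyword on copy_column overrides the corresponding column attribute only when it is explicitly provided; None keeps the source column's value. A standalone version of that override pattern with a plain SQLAlchemy column (the helper name here is hypothetical):

from typing import Optional

import sqlalchemy as sa

def copy_column_like(
    column: sa.Column,
    nullable: Optional[bool] = None,
    index: Optional[bool] = None,
) -> sa.Column:
    # A keyword wins only when it is not None; otherwise the source
    # column's attribute is carried over unchanged.
    return sa.Column(
        column.name,
        column.type,
        nullable=nullable if nullable is not None else column.nullable,
        index=index if index is not None else column.index,
    )

src = sa.Column("size", sa.Integer, nullable=False, index=True)
copied = copy_column_like(src, nullable=True)  # override nullable only
assert copied.nullable is True and copied.index is True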
@@ -122,6 +122,11 @@ class SQLiteDatabaseEngine(DatabaseEngine):
         engine = sqlalchemy.create_engine(
             "sqlite+pysqlite:///", creator=lambda: db, future=True
         )
+        # ensure we run SA on_connect init (e.g it registers regexp function),
+        # also makes sure that it's consistent. Otherwise in some cases it
+        # seems we are getting different results if engine object is used in a
+        # different thread first and enine is not used in the Main thread.
+        engine.connect().close()
 
         db.isolation_level = None  # Use autocommit mode
         db.execute("PRAGMA foreign_keys = ON")
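Illustrative sketch (not part of the diff): the eager engine.connect().close() added above matters because SQLAlchemy runs "connect"-time hooks (such as registering SQL functions) only when a DBAPI connection is first handed out. A hypothetical, self-contained demonstration of forcing that initialization in the main thread:

import sqlite3

import sqlalchemy
from sqlalchemy import event, text

db = sqlite3.connect(":memory:", check_same_thread=False)
engine = sqlalchemy.create_engine(
    "sqlite+pysqlite:///", creator=lambda: db, future=True
)

@event.listens_for(engine, "connect")
def _register_functions(dbapi_conn, _record):
    # Runs only when the pool hands out a new DBAPI connection; without an
    # eager connect this might first happen in a worker thread.
    dbapi_conn.create_function("plus_one", 1, lambda x: x + 1)

engine.connect().close()  # run the hook now, in this thread

with engine.connect() as conn:
    print(conn.execute(text("SELECT plus_one(41)")).scalar())  # 42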
@@ -17,7 +17,7 @@ from sqlalchemy.sql.expression import true
 
 from datachain.client import Client
 from datachain.data_storage.serializer import Serializable
-from datachain.dataset import DatasetRecord, RowDict
+from datachain.dataset import DatasetRecord
 from datachain.node import DirType, DirTypeGroup, Entry, Node, NodeWithPath, get_path
 from datachain.sql.functions import path as pathfunc
 from datachain.sql.types import Int, SQLType
@@ -201,23 +201,17 @@ class AbstractWarehouse(ABC, Serializable):
     def dataset_select_paginated(
         self,
         query,
-        limit: Optional[int] = None,
-        order_by: tuple["ColumnElement[Any]", ...] = (),
         page_size: int = SELECT_BATCH_SIZE,
-    ) -> Generator[RowDict, None, None]:
+    ) -> Generator[Sequence, None, None]:
         """
         This is equivalent to `db.execute`, but for selecting rows in batches
         """
-        cols = query.selected_columns
-        cols_names = [c.name for c in cols]
+        limit = query._limit
+        paginated_query = query.limit(page_size)
 
-        if not order_by:
-            ordering = [cols.sys__id]
-        else:
-            ordering = order_by  # type: ignore[assignment]
-
-        # reset query order by and apply new order by id
-        paginated_query = query.order_by(None).order_by(*ordering).limit(page_size)
+        if not paginated_query._order_by_clauses:
+            # default order by is order by `sys__id`
+            paginated_query = paginated_query.order_by(query.selected_columns.sys__id)
 
         results = None
         offset = 0
@@ -236,7 +230,7 @@ class AbstractWarehouse(ABC, Serializable):
             processed = False
             for row in results:
                 processed = True
-                yield RowDict(zip(cols_names, row))
+                yield row
                 num_yielded += 1
 
             if not processed:
datachain/lib/dc.py CHANGED
@@ -508,7 +508,7 @@ class DataChain(DatasetQuery):
 
     def print_json_schema(  # type: ignore[override]
         self, jmespath: Optional[str] = None, model_name: Optional[str] = None
-    ) -> "DataChain":
+    ) -> "Self":
         """Print JSON data model and save it. It returns the chain itself.
 
         Parameters:
@@ -533,7 +533,7 @@ class DataChain(DatasetQuery):
 
     def print_jsonl_schema(  # type: ignore[override]
         self, jmespath: Optional[str] = None, model_name: Optional[str] = None
-    ) -> "DataChain":
+    ) -> "Self":
         """Print JSON data model and save it. It returns the chain itself.
 
         Parameters:
@@ -549,7 +549,7 @@ class DataChain(DatasetQuery):
 
     def save(  # type: ignore[override]
         self, name: Optional[str] = None, version: Optional[int] = None
-    ) -> "DataChain":
+    ) -> "Self":
         """Save to a Dataset. It returns the chain itself.
 
         Parameters:
@@ -785,7 +785,7 @@ class DataChain(DatasetQuery):
             descending (bool): Whether to sort in descending order or not.
         """
         if descending:
-            args = tuple([sqlalchemy.desc(a) for a in args])
+            args = tuple(sqlalchemy.desc(a) for a in args)
 
         return super().order_by(*args)
 
@@ -1206,14 +1206,14 @@ class DataChain(DatasetQuery):
         """
         headers, max_length = self._effective_signals_schema.get_headers_with_length()
         if flatten or max_length < 2:
-            df = pd.DataFrame.from_records(self.to_records())
+            columns = []
             if headers:
-                df.columns = [".".join(filter(None, header)) for header in headers]
-            return df
+                columns = [".".join(filter(None, header)) for header in headers]
+            return pd.DataFrame.from_records(self.to_records(), columns=columns)
 
-        transposed_result = list(map(list, zip(*self.results())))
-        data = {tuple(n): val for n, val in zip(headers, transposed_result)}
-        return pd.DataFrame(data)
+        return pd.DataFrame(
+            self.results(), columns=pd.MultiIndex.from_tuples(map(tuple, headers))
+        )
 
     def show(
         self,
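Illustrative sketch (not part of the diff): for nested signals, the rewritten to_pandas above now hands the result rows straight to pandas together with a MultiIndex built from the signal headers, instead of transposing the results by hand. The header tuples below are made up for illustration:

import pandas as pd

# Hypothetical nested-signal headers, e.g. a `file` signal with two fields.
headers = [("file", "path"), ("file", "size")]
rows = [("cats/1.jpg", 1024), ("cats/2.jpg", 2048)]

df = pd.DataFrame(rows, columns=pd.MultiIndex.from_tuples(headers))
print(df["file"]["size"].tolist())  # [1024, 2048]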
@@ -1232,6 +1232,12 @@ class DataChain(DatasetQuery):
         """
         dc = self.limit(limit) if limit > 0 else self
         df = dc.to_pandas(flatten)
+
+        if df.empty:
+            print("Empty result")
+            print(f"Columns: {list(df.columns)}")
+            return
+
         if transpose:
             df = df.T
 
@@ -1270,7 +1276,7 @@ class DataChain(DatasetQuery):
         source: bool = True,
         nrows: Optional[int] = None,
         **kwargs,
-    ) -> "DataChain":
+    ) -> "Self":
         """Generate chain from list of tabular files.
 
         Parameters:
@@ -1390,7 +1396,8 @@ class DataChain(DatasetQuery):
            dc = DataChain.from_csv("s3://mybucket/dir")
            ```
        """
-        from pyarrow.csv import ParseOptions, ReadOptions
+        from pandas.io.parsers.readers import STR_NA_VALUES
+        from pyarrow.csv import ConvertOptions, ParseOptions, ReadOptions
         from pyarrow.dataset import CsvFileFormat
 
         chain = DataChain.from_storage(path, **kwargs)
@@ -1414,7 +1421,14 @@ class DataChain(DatasetQuery):
 
         parse_options = ParseOptions(delimiter=delimiter)
         read_options = ReadOptions(column_names=column_names)
-        format = CsvFileFormat(parse_options=parse_options, read_options=read_options)
+        convert_options = ConvertOptions(
+            strings_can_be_null=True, null_values=STR_NA_VALUES
+        )
+        format = CsvFileFormat(
+            parse_options=parse_options,
+            read_options=read_options,
+            convert_options=convert_options,
+        )
         return chain.parse_tabular(
             output=output,
             object_name=object_name,
@@ -1623,7 +1637,7 @@ class DataChain(DatasetQuery):
 
            Using glob to match patterns
            ```py
-           dc.filter(C("file.name").glob("*.jpg))
+           dc.filter(C("file.name").glob("*.jpg"))
            ```
 
           Using `datachain.sql.functions`
@@ -11,12 +11,16 @@ from collections.abc import Iterator
 from typing import Any, Callable
 
 import jmespath as jsp
-from pydantic import Field, ValidationError  # noqa: F401
+from pydantic import BaseModel, ConfigDict, Field, ValidationError  # noqa: F401
 
 from datachain.lib.data_model import DataModel  # noqa: F401
 from datachain.lib.file import File
 
 
+class UserModel(BaseModel):
+    model_config = ConfigDict(populate_by_name=True)
+
+
 def generate_uuid():
     return uuid.uuid4()  # Generates a random UUID.
 
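Illustrative sketch (not part of the diff): models generated by read_schema now inherit from UserModel, whose populate_by_name=True config lets pydantic accept either a field's name or its alias at construction time. A small hypothetical model showing the effect:

from pydantic import BaseModel, ConfigDict, Field

class UserModel(BaseModel):
    model_config = ConfigDict(populate_by_name=True)

class Row(UserModel):
    # Hypothetical field whose source column header is not a valid identifier.
    first_name: str = Field(alias="first name")

print(Row(**{"first name": "Ada"}).first_name)  # Ada (by alias)
print(Row(first_name="Ada").first_name)         # Ada (by field name)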
@@ -72,6 +76,8 @@ def read_schema(source_file, data_type="csv", expr=None, model_name=None):
         data_type,
         "--class-name",
         model_name,
+        "--base-class",
+        "datachain.lib.meta_formats.UserModel",
     ]
     try:
         result = subprocess.run(  # noqa: S603
@@ -87,7 +93,7 @@ def read_schema(source_file, data_type="csv", expr=None, model_name=None):
     except subprocess.CalledProcessError as e:
         model_output = f"An error occurred in datamodel-codegen: {e.stderr}"
     print(f"{model_output}")
-    print("\n" + "from datachain.lib.data_model import DataModel" + "\n")
+    print("from datachain.lib.data_model import DataModel")
     print("\n" + f"DataModel.register({model_name})" + "\n")
     print("\n" + f"spec={model_name}" + "\n")
     return model_output
datachain/lib/udf.py CHANGED
@@ -1,6 +1,5 @@
 import sys
 import traceback
-from collections.abc import Iterable, Iterator
 from typing import TYPE_CHECKING, Callable, Optional
 
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
@@ -14,16 +13,19 @@ from datachain.lib.model_store import ModelStore
 from datachain.lib.signal_schema import SignalSchema
 from datachain.lib.udf_signature import UdfSignature
 from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
-from datachain.query.batch import RowBatch
+from datachain.query.batch import UDFInputBatch
 from datachain.query.schema import ColumnParameter
 from datachain.query.udf import UDFBase as _UDFBase
-from datachain.query.udf import UDFProperties, UDFResult
+from datachain.query.udf import UDFProperties
 
 if TYPE_CHECKING:
+    from collections.abc import Iterable, Iterator, Sequence
+
     from typing_extensions import Self
 
     from datachain.catalog import Catalog
-    from datachain.query.batch import BatchingResult
+    from datachain.query.batch import RowsOutput, UDFInput
+    from datachain.query.udf import UDFResult
 
 
 class UdfError(DataChainParamsError):
@@ -42,22 +44,27 @@ class UDFAdapter(_UDFBase):
 
     def run(
         self,
-        udf_inputs: "Iterable[BatchingResult]",
+        udf_fields: "Sequence[str]",
+        udf_inputs: "Iterable[RowsOutput]",
         catalog: "Catalog",
         is_generator: bool,
         cache: bool,
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
-    ) -> Iterator[Iterable["UDFResult"]]:
+    ) -> "Iterator[Iterable[UDFResult]]":
         self.inner._catalog = catalog
         if hasattr(self.inner, "setup") and callable(self.inner.setup):
             self.inner.setup()
 
-        for batch in udf_inputs:
-            n_rows = len(batch.rows) if isinstance(batch, RowBatch) else 1
-            output = self.run_once(catalog, batch, is_generator, cache, cb=download_cb)
-            processed_cb.relative_update(n_rows)
-            yield output
+        yield from super().run(
+            udf_fields,
+            udf_inputs,
+            catalog,
+            is_generator,
+            cache,
+            download_cb,
+            processed_cb,
+        )
 
         if hasattr(self.inner, "teardown") and callable(self.inner.teardown):
             self.inner.teardown()
@@ -65,12 +72,12 @@ class UDFAdapter(_UDFBase):
     def run_once(
         self,
         catalog: "Catalog",
-        arg: "BatchingResult",
+        arg: "UDFInput",
         is_generator: bool = False,
         cache: bool = False,
         cb: Callback = DEFAULT_CALLBACK,
-    ) -> Iterable[UDFResult]:
-        if isinstance(arg, RowBatch):
+    ) -> "Iterable[UDFResult]":
+        if isinstance(arg, UDFInputBatch):
             udf_inputs = [
                 self.bind_parameters(catalog, row, cache=cache, cb=cb)
                 for row in arg.rows
datachain/node.py CHANGED
@@ -47,7 +47,7 @@ class DirTypeGroup:
 @attrs.define
 class Node:
     sys__id: int = 0
-    sys__rand: int = -1
+    sys__rand: int = 0
     vtype: str = ""
     dir_type: Optional[int] = None
     path: str = ""
datachain/query/batch.py CHANGED
@@ -5,21 +5,29 @@ from collections.abc import Generator, Sequence
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Callable, Optional, Union
 
-import sqlalchemy as sa
-
 from datachain.data_storage.schema import PARTITION_COLUMN_ID
 from datachain.data_storage.warehouse import SELECT_BATCH_SIZE
 
 if TYPE_CHECKING:
+    from sqlalchemy import Select
+
     from datachain.dataset import RowDict
 
 
 @dataclass
-class RowBatch:
+class RowsOutputBatch:
+    rows: Sequence[Sequence]
+
+
+RowsOutput = Union[Sequence, RowsOutputBatch]
+
+
+@dataclass
+class UDFInputBatch:
     rows: Sequence["RowDict"]
 
 
-BatchingResult = Union["RowDict", RowBatch]
+UDFInput = Union["RowDict", UDFInputBatch]
 
 
 class BatchingStrategy(ABC):
@@ -28,9 +36,9 @@ class BatchingStrategy(ABC):
     @abstractmethod
     def __call__(
         self,
-        execute: Callable,
-        query: sa.sql.selectable.Select,
-    ) -> Generator[BatchingResult, None, None]:
+        execute: Callable[..., Generator[Sequence, None, None]],
+        query: "Select",
+    ) -> Generator[RowsOutput, None, None]:
         """Apply the provided parameters to the UDF."""
 
 
@@ -42,10 +50,10 @@ class NoBatching(BatchingStrategy):
 
     def __call__(
        self,
-        execute: Callable,
-        query: sa.sql.selectable.Select,
-    ) -> Generator["RowDict", None, None]:
-        return execute(query, limit=query._limit, order_by=query._order_by_clauses)
+        execute: Callable[..., Generator[Sequence, None, None]],
+        query: "Select",
+    ) -> Generator[Sequence, None, None]:
+        return execute(query)
 
 
 class Batch(BatchingStrategy):
@@ -59,31 +67,24 @@ class Batch(BatchingStrategy):
 
     def __call__(
         self,
-        execute: Callable,
-        query: sa.sql.selectable.Select,
-    ) -> Generator[RowBatch, None, None]:
+        execute: Callable[..., Generator[Sequence, None, None]],
+        query: "Select",
+    ) -> Generator[RowsOutputBatch, None, None]:
         # choose page size that is a multiple of the batch size
         page_size = math.ceil(SELECT_BATCH_SIZE / self.count) * self.count
 
         # select rows in batches
-        results: list[RowDict] = []
-
-        with contextlib.closing(
-            execute(
-                query,
-                page_size=page_size,
-                limit=query._limit,
-                order_by=query._order_by_clauses,
-            )
-        ) as rows:
+        results: list[Sequence] = []
+
+        with contextlib.closing(execute(query, page_size=page_size)) as rows:
             for row in rows:
                 results.append(row)
                 if len(results) >= self.count:
                     batch, results = results[: self.count], results[self.count :]
-                    yield RowBatch(batch)
+                    yield RowsOutputBatch(batch)
 
         if len(results) > 0:
-            yield RowBatch(results)
+            yield RowsOutputBatch(results)
 
 
 class Partition(BatchingStrategy):
@@ -95,27 +96,30 @@ class Partition(BatchingStrategy):
 
     def __call__(
         self,
-        execute: Callable,
-        query: sa.sql.selectable.Select,
-    ) -> Generator[RowBatch, None, None]:
+        execute: Callable[..., Generator[Sequence, None, None]],
+        query: "Select",
+    ) -> Generator[RowsOutputBatch, None, None]:
         current_partition: Optional[int] = None
-        batch: list[RowDict] = []
-
-        with contextlib.closing(
-            execute(
-                query,
-                order_by=(PARTITION_COLUMN_ID, "sys__id", *query._order_by_clauses),
-                limit=query._limit,
-            )
-        ) as rows:
+        batch: list[Sequence] = []
+
+        query_fields = [str(c.name) for c in query.selected_columns]
+        partition_column_idx = query_fields.index(PARTITION_COLUMN_ID)
+
+        ordered_query = query.order_by(None).order_by(
+            PARTITION_COLUMN_ID,
+            "sys__id",
+            *query._order_by_clauses,
+        )
+
+        with contextlib.closing(execute(ordered_query)) as rows:
             for row in rows:
-                partition = row[PARTITION_COLUMN_ID]
+                partition = row[partition_column_idx]
                 if current_partition != partition:
                     current_partition = partition
                     if len(batch) > 0:
-                        yield RowBatch(batch)
+                        yield RowsOutputBatch(batch)
                         batch = []
                 batch.append(row)
 
         if len(batch) > 0:
-            yield RowBatch(batch)
+            yield RowsOutputBatch(batch)
@@ -461,6 +461,8 @@ class UDFStep(Step, ABC):
 
         processes = determine_processes(self.parallel)
 
+        udf_fields = [str(c.name) for c in query.selected_columns]
+
         try:
             if workers:
                 from datachain.catalog.loader import get_distributed_class
@@ -473,6 +475,7 @@ class UDFStep(Step, ABC):
                     query,
                     workers,
                     processes,
+                    udf_fields=udf_fields,
                    is_generator=self.is_generator,
                    use_partitioning=use_partitioning,
                    cache=self.cache,
@@ -489,6 +492,7 @@ class UDFStep(Step, ABC):
                    "warehouse_clone_params": self.catalog.warehouse.clone_params(),
                    "table": udf_table,
                    "query": query,
+                    "udf_fields": udf_fields,
                    "batching": batching,
                    "processes": processes,
                    "is_generator": self.is_generator,
@@ -528,6 +532,7 @@ class UDFStep(Step, ABC):
         generated_cb = get_generated_callback(self.is_generator)
         try:
             udf_results = udf.run(
+                udf_fields,
                 udf_inputs,
                 self.catalog,
                 self.is_generator,
@@ -1244,21 +1249,23 @@ class DatasetQuery:
         actual_params = [normalize_param(p) for p in params]
         try:
             query = self.apply_steps().select()
+            query_fields = [str(c.name) for c in query.selected_columns]
 
-            def row_iter() -> Generator[RowDict, None, None]:
+            def row_iter() -> Generator[Sequence, None, None]:
                 # warehouse isn't threadsafe, we need to clone() it
                 # in the thread that uses the results
                 with self.catalog.warehouse.clone() as warehouse:
-                    gen = warehouse.dataset_select_paginated(
-                        query, limit=query._limit, order_by=query._order_by_clauses
-                    )
+                    gen = warehouse.dataset_select_paginated(query)
                     with contextlib.closing(gen) as rows:
                         yield from rows
 
-            async def get_params(row: RowDict) -> tuple:
+            async def get_params(row: Sequence) -> tuple:
+                row_dict = RowDict(zip(query_fields, row))
                 return tuple(
                     [
-                        await p.get_value_async(self.catalog, row, mapper, **kwargs)
+                        await p.get_value_async(
+                            self.catalog, row_dict, mapper, **kwargs
+                        )
                        for p in actual_params
                    ]
                )