datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl

This diff compares publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
Files changed (137)
  1. datachain/__init__.py +20 -0
  2. datachain/asyn.py +11 -12
  3. datachain/cache.py +7 -7
  4. datachain/catalog/__init__.py +2 -2
  5. datachain/catalog/catalog.py +621 -507
  6. datachain/catalog/dependency.py +164 -0
  7. datachain/catalog/loader.py +28 -18
  8. datachain/checkpoint.py +43 -0
  9. datachain/cli/__init__.py +24 -33
  10. datachain/cli/commands/__init__.py +1 -8
  11. datachain/cli/commands/datasets.py +83 -52
  12. datachain/cli/commands/ls.py +17 -17
  13. datachain/cli/commands/show.py +4 -4
  14. datachain/cli/parser/__init__.py +8 -74
  15. datachain/cli/parser/job.py +95 -3
  16. datachain/cli/parser/studio.py +11 -4
  17. datachain/cli/parser/utils.py +1 -2
  18. datachain/cli/utils.py +2 -15
  19. datachain/client/azure.py +4 -4
  20. datachain/client/fsspec.py +45 -28
  21. datachain/client/gcs.py +6 -6
  22. datachain/client/hf.py +29 -2
  23. datachain/client/http.py +157 -0
  24. datachain/client/local.py +15 -11
  25. datachain/client/s3.py +17 -9
  26. datachain/config.py +4 -8
  27. datachain/data_storage/db_engine.py +12 -6
  28. datachain/data_storage/job.py +5 -1
  29. datachain/data_storage/metastore.py +1252 -186
  30. datachain/data_storage/schema.py +58 -45
  31. datachain/data_storage/serializer.py +105 -15
  32. datachain/data_storage/sqlite.py +286 -127
  33. datachain/data_storage/warehouse.py +250 -113
  34. datachain/dataset.py +353 -148
  35. datachain/delta.py +391 -0
  36. datachain/diff/__init__.py +27 -29
  37. datachain/error.py +60 -0
  38. datachain/func/__init__.py +2 -1
  39. datachain/func/aggregate.py +66 -42
  40. datachain/func/array.py +242 -38
  41. datachain/func/base.py +7 -4
  42. datachain/func/conditional.py +110 -60
  43. datachain/func/func.py +96 -45
  44. datachain/func/numeric.py +55 -38
  45. datachain/func/path.py +32 -20
  46. datachain/func/random.py +2 -2
  47. datachain/func/string.py +67 -37
  48. datachain/func/window.py +7 -8
  49. datachain/hash_utils.py +123 -0
  50. datachain/job.py +11 -7
  51. datachain/json.py +138 -0
  52. datachain/lib/arrow.py +58 -22
  53. datachain/lib/audio.py +245 -0
  54. datachain/lib/clip.py +14 -13
  55. datachain/lib/convert/flatten.py +5 -3
  56. datachain/lib/convert/python_to_sql.py +6 -10
  57. datachain/lib/convert/sql_to_python.py +8 -0
  58. datachain/lib/convert/values_to_tuples.py +156 -51
  59. datachain/lib/data_model.py +42 -20
  60. datachain/lib/dataset_info.py +36 -8
  61. datachain/lib/dc/__init__.py +8 -2
  62. datachain/lib/dc/csv.py +25 -28
  63. datachain/lib/dc/database.py +398 -0
  64. datachain/lib/dc/datachain.py +1289 -425
  65. datachain/lib/dc/datasets.py +320 -38
  66. datachain/lib/dc/hf.py +38 -24
  67. datachain/lib/dc/json.py +29 -32
  68. datachain/lib/dc/listings.py +112 -8
  69. datachain/lib/dc/pandas.py +16 -12
  70. datachain/lib/dc/parquet.py +35 -23
  71. datachain/lib/dc/records.py +31 -23
  72. datachain/lib/dc/storage.py +154 -64
  73. datachain/lib/dc/storage_pattern.py +251 -0
  74. datachain/lib/dc/utils.py +24 -16
  75. datachain/lib/dc/values.py +8 -9
  76. datachain/lib/file.py +622 -89
  77. datachain/lib/hf.py +69 -39
  78. datachain/lib/image.py +14 -14
  79. datachain/lib/listing.py +14 -11
  80. datachain/lib/listing_info.py +1 -2
  81. datachain/lib/meta_formats.py +3 -4
  82. datachain/lib/model_store.py +39 -7
  83. datachain/lib/namespaces.py +125 -0
  84. datachain/lib/projects.py +130 -0
  85. datachain/lib/pytorch.py +32 -21
  86. datachain/lib/settings.py +192 -56
  87. datachain/lib/signal_schema.py +427 -104
  88. datachain/lib/tar.py +1 -2
  89. datachain/lib/text.py +8 -7
  90. datachain/lib/udf.py +164 -76
  91. datachain/lib/udf_signature.py +60 -35
  92. datachain/lib/utils.py +118 -4
  93. datachain/lib/video.py +17 -9
  94. datachain/lib/webdataset.py +61 -56
  95. datachain/lib/webdataset_laion.py +15 -16
  96. datachain/listing.py +22 -10
  97. datachain/model/bbox.py +3 -1
  98. datachain/model/ultralytics/bbox.py +16 -12
  99. datachain/model/ultralytics/pose.py +16 -12
  100. datachain/model/ultralytics/segment.py +16 -12
  101. datachain/namespace.py +84 -0
  102. datachain/node.py +6 -6
  103. datachain/nodes_thread_pool.py +0 -1
  104. datachain/plugins.py +24 -0
  105. datachain/project.py +78 -0
  106. datachain/query/batch.py +40 -41
  107. datachain/query/dataset.py +604 -322
  108. datachain/query/dispatch.py +261 -154
  109. datachain/query/metrics.py +4 -6
  110. datachain/query/params.py +2 -3
  111. datachain/query/queue.py +3 -12
  112. datachain/query/schema.py +11 -6
  113. datachain/query/session.py +200 -33
  114. datachain/query/udf.py +34 -2
  115. datachain/remote/studio.py +171 -69
  116. datachain/script_meta.py +12 -12
  117. datachain/semver.py +68 -0
  118. datachain/sql/__init__.py +2 -0
  119. datachain/sql/functions/array.py +33 -1
  120. datachain/sql/postgresql_dialect.py +9 -0
  121. datachain/sql/postgresql_types.py +21 -0
  122. datachain/sql/sqlite/__init__.py +5 -1
  123. datachain/sql/sqlite/base.py +102 -29
  124. datachain/sql/sqlite/types.py +8 -13
  125. datachain/sql/types.py +70 -15
  126. datachain/studio.py +223 -46
  127. datachain/toolkit/split.py +31 -10
  128. datachain/utils.py +101 -59
  129. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
  130. datachain-0.39.0.dist-info/RECORD +173 -0
  131. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
  132. datachain/cli/commands/query.py +0 -53
  133. datachain/query/utils.py +0 -42
  134. datachain-0.14.2.dist-info/RECORD +0 -158
  135. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
  136. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
  137. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/model/ultralytics/pose.py CHANGED
@@ -56,16 +56,16 @@ class YoloPose(DataModel):
         if not summary:
             return YoloPose(box=BBox(), pose=Pose3D())
         name = summary[0].get("name", "")
-        box = (
-            BBox.from_dict(summary[0]["box"], title=name)
-            if "box" in summary[0]
-            else BBox()
-        )
-        pose = (
-            Pose3D.from_dict(summary[0]["keypoints"])
-            if "keypoints" in summary[0]
-            else Pose3D()
-        )
+        if summary[0].get("box"):
+            assert isinstance(summary[0]["box"], dict)
+            box = BBox.from_dict(summary[0]["box"], title=name)
+        else:
+            box = BBox()
+        if summary[0].get("keypoints"):
+            assert isinstance(summary[0]["keypoints"], dict)
+            pose = Pose3D.from_dict(summary[0]["keypoints"])
+        else:
+            pose = Pose3D()
         return YoloPose(
             cls=summary[0]["class"],
             name=name,
@@ -102,8 +102,12 @@ class YoloPoses(DataModel):
             cls.append(s["class"])
             names.append(name)
             confidence.append(s["confidence"])
-            box.append(BBox.from_dict(s.get("box", {}), title=name))
-            pose.append(Pose3D.from_dict(s.get("keypoints", {})))
+            if s.get("box"):
+                assert isinstance(s["box"], dict)
+                box.append(BBox.from_dict(s["box"], title=name))
+            if s.get("keypoints"):
+                assert isinstance(s["keypoints"], dict)
+                pose.append(Pose3D.from_dict(s["keypoints"]))
         return YoloPoses(
             cls=cls,
             name=names,
datachain/model/ultralytics/segment.py CHANGED
@@ -34,16 +34,16 @@ class YoloSegment(DataModel):
         if not summary:
             return YoloSegment(box=BBox(), segment=Segment())
         name = summary[0].get("name", "")
-        box = (
-            BBox.from_dict(summary[0]["box"], title=name)
-            if "box" in summary[0]
-            else BBox()
-        )
-        segment = (
-            Segment.from_dict(summary[0]["segments"], title=name)
-            if "segments" in summary[0]
-            else Segment()
-        )
+        if summary[0].get("box"):
+            assert isinstance(summary[0]["box"], dict)
+            box = BBox.from_dict(summary[0]["box"], title=name)
+        else:
+            box = BBox()
+        if summary[0].get("segments"):
+            assert isinstance(summary[0]["segments"], dict)
+            segment = Segment.from_dict(summary[0]["segments"], title=name)
+        else:
+            segment = Segment()
         return YoloSegment(
             cls=summary[0]["class"],
             name=summary[0]["name"],
@@ -80,8 +80,12 @@ class YoloSegments(DataModel):
             cls.append(s["class"])
             names.append(name)
             confidence.append(s["confidence"])
-            box.append(BBox.from_dict(s.get("box", {}), title=name))
-            segment.append(Segment.from_dict(s.get("segments", {}), title=name))
+            if s.get("box"):
+                assert isinstance(s["box"], dict)
+                box.append(BBox.from_dict(s["box"], title=name))
+            if s.get("segments"):
+                assert isinstance(s["segments"], dict)
+                segment.append(Segment.from_dict(s["segments"], title=name))
         return YoloSegments(
             cls=cls,
             name=names,
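
Note: all four hunks above replace a key-presence check ("box" in summary[0]) with a truthiness check plus an isinstance assert. The distinction matters when the key exists but holds an empty or None value; a minimal standalone sketch with a hypothetical dict d (not code from this package):

    d = {"box": None}
    "box" in d          # True  -- the old guard would pass None to BBox.from_dict
    bool(d.get("box"))  # False -- the new guard falls back to the default BBox()
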
datachain/namespace.py ADDED
@@ -0,0 +1,84 @@
+import builtins
+from dataclasses import dataclass, fields
+from datetime import datetime
+from typing import Any, TypeVar
+
+from datachain.error import InvalidNamespaceNameError
+
+N = TypeVar("N", bound="Namespace")
+NAMESPACE_NAME_RESERVED_CHARS = [".", "@"]
+
+
+def parse_name(name: str) -> tuple[str, str | None]:
+    """
+    Parses namespace name into namespace and optional project name.
+    If both namespace and project are defined in name, they need to be split by dot
+    e.g dev.my-project
+    Valid inputs:
+        - dev.my-project
+        - dev
+    """
+    parts = name.split(".")
+    if len(parts) == 1:
+        return name, None
+    if len(parts) == 2:
+        return parts[0], parts[1]
+    raise InvalidNamespaceNameError(
+        f"Invalid namespace format: {name}. Expected 'namespace' or 'ns1.ns2'."
+    )
+
+
+@dataclass(frozen=True)
+class Namespace:
+    id: int
+    uuid: str
+    name: str
+    descr: str | None
+    created_at: datetime
+
+    @staticmethod
+    def validate_name(name: str) -> None:
+        """Throws exception if name is invalid, otherwise returns None"""
+        if not name:
+            raise InvalidNamespaceNameError("Namespace name cannot be empty")
+
+        for c in NAMESPACE_NAME_RESERVED_CHARS:
+            if c in name:
+                raise InvalidNamespaceNameError(
+                    f"Character {c} is reserved and not allowed in namespace name"
+                )
+
+        if name in [Namespace.default(), Namespace.system()]:
+            raise InvalidNamespaceNameError(
+                f"Namespace name {name} is reserved and cannot be used."
+            )
+
+    @staticmethod
+    def default() -> str:
+        """Name of default namespace"""
+        return "local"
+
+    @staticmethod
+    def system() -> str:
+        """Name of the system namespace"""
+        return "system"
+
+    @property
+    def is_system(self):
+        return self.name == Namespace.system()
+
+    @classmethod
+    def parse(
+        cls: builtins.type[N],
+        id: int,
+        uuid: str,
+        name: str,
+        descr: str | None,
+        created_at: datetime,
+    ) -> "Namespace":
+        return cls(id, uuid, name, descr, created_at)
+
+    @classmethod
+    def from_dict(cls, d: dict[str, Any]) -> "Namespace":
+        kwargs = {f.name: d[f.name] for f in fields(cls) if f.name in d}
+        return cls(**kwargs)
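
Note: a minimal usage sketch of parse_name, following its docstring (inputs are illustrative):

    from datachain.namespace import parse_name

    assert parse_name("dev.my-project") == ("dev", "my-project")
    assert parse_name("dev") == ("dev", None)
    # more than one dot raises InvalidNamespaceNameError, e.g. parse_name("a.b.c")
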
datachain/node.py CHANGED
@@ -1,6 +1,6 @@
 import os
 from datetime import datetime
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any
 
 import attrs
 
@@ -53,11 +53,11 @@ class Node:
     sys__rand: int = 0
     path: str = ""
     etag: str = ""
-    version: Optional[str] = None
+    version: str | None = None
     is_latest: bool = True
-    last_modified: Optional[datetime] = None
+    last_modified: datetime | None = None
     size: int = 0
-    location: Optional[str] = None
+    location: str | None = None
     source: StorageURI = StorageURI("")  # noqa: RUF009
     dir_type: int = DirType.FILE
 
@@ -90,7 +90,7 @@ class Node:
             return self.path + "/"
         return self.path
 
-    def to_file(self, source: Optional[StorageURI] = None) -> File:
+    def to_file(self, source: StorageURI | None = None) -> File:
         if source is None:
             source = self.source
         return File(
@@ -189,7 +189,7 @@ class NodeWithPath:
 TIME_FMT = "%Y-%m-%d %H:%M"
 
 
-def long_line_str(name: str, timestamp: Optional[datetime]) -> str:
+def long_line_str(name: str, timestamp: datetime | None) -> str:
     if timestamp is None:
         time = "-"
     else:
datachain/nodes_thread_pool.py CHANGED
@@ -1,4 +1,3 @@
-import concurrent
 import concurrent.futures
 import threading
 from abc import ABC, abstractmethod
datachain/plugins.py ADDED
@@ -0,0 +1,24 @@
+"""Plugin loader for DataChain callables.
+
+Discovers and invokes entry points in the group "datachain.callables" once
+per process. This enables external packages (e.g., Studio) to register
+their callables with the serializer registry without explicit imports.
+"""
+
+from importlib import metadata as importlib_metadata
+
+_plugins_loaded = False
+
+
+def ensure_plugins_loaded() -> None:
+    global _plugins_loaded  # noqa: PLW0603
+    if _plugins_loaded:
+        return
+
+    # Compatible across importlib.metadata versions
+    eps_obj = importlib_metadata.entry_points()
+    for ep in eps_obj.select(group="datachain.callables"):
+        func = ep.load()
+        func()
+
+    _plugins_loaded = True
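
Note: a hedged sketch of how an external package might hook into this loader. The package and function names are hypothetical; only the entry-point group "datachain.callables" comes from the module above:

    # pyproject.toml of the hypothetical plugin package:
    #
    #   [project.entry-points."datachain.callables"]
    #   register = "my_plugin:register"

    # my_plugin/__init__.py
    def register() -> None:
        # invoked once per process by ensure_plugins_loaded()
        print("registering callables with the serializer registry")
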
datachain/project.py ADDED
@@ -0,0 +1,78 @@
+import builtins
+from dataclasses import dataclass, fields
+from datetime import datetime
+from typing import Any, TypeVar
+
+from datachain.error import InvalidProjectNameError
+from datachain.namespace import Namespace
+
+P = TypeVar("P", bound="Project")
+PROJECT_NAME_RESERVED_CHARS = [".", "@"]
+
+
+@dataclass(frozen=True)
+class Project:
+    id: int
+    uuid: str
+    name: str
+    descr: str | None
+    created_at: datetime
+    namespace: Namespace
+
+    @staticmethod
+    def validate_name(name: str) -> None:
+        """Throws exception if name is invalid, otherwise returns None"""
+        if not name:
+            raise InvalidProjectNameError("Project name cannot be empty")
+
+        for c in PROJECT_NAME_RESERVED_CHARS:
+            if c in name:
+                raise InvalidProjectNameError(
+                    f"Character {c} is reserved and not allowed in project name."
+                )
+
+        if name in [Project.default(), Project.listing()]:
+            raise InvalidProjectNameError(
+                f"Project name {name} is reserved and cannot be used."
+            )
+
+    @staticmethod
+    def default() -> str:
+        """Name of default project"""
+        return "local"
+
+    @staticmethod
+    def listing() -> str:
+        """Name of listing project where all listing datasets will be saved"""
+        return "listing"
+
+    @classmethod
+    def parse(
+        cls: builtins.type[P],
+        namespace_id: int,
+        namespace_uuid: str,
+        namespace_name: str,
+        namespace_descr: str | None,
+        namespace_created_at: datetime,
+        project_id: int,
+        uuid: str,
+        name: str,
+        descr: str | None,
+        created_at: datetime,
+        project_namespace_id: int,
+    ) -> "Project":
+        namespace = Namespace.parse(
+            namespace_id,
+            namespace_uuid,
+            namespace_name,
+            namespace_descr,
+            namespace_created_at,
+        )
+
+        return cls(project_id, uuid, name, descr, created_at, namespace)
+
+    @classmethod
+    def from_dict(cls, d: dict[str, Any]) -> "Project":
+        namespace = Namespace.from_dict(d.pop("namespace"))
+        kwargs = {f.name: d[f.name] for f in fields(cls) if f.name in d}
+        return cls(**kwargs, namespace=namespace)
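
Note: a usage sketch of Project.from_dict, which pops the nested "namespace" dict and builds both dataclasses (all field values below are illustrative):

    from datetime import datetime, timezone

    from datachain.project import Project

    now = datetime.now(timezone.utc)
    project = Project.from_dict(
        {
            "id": 1,
            "uuid": "p-uuid",
            "name": "my-project",
            "descr": None,
            "created_at": now,
            "namespace": {
                "id": 2,
                "uuid": "ns-uuid",
                "name": "dev",
                "descr": None,
                "created_at": now,
            },
        }
    )
    assert project.namespace.name == "dev"
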
datachain/query/batch.py CHANGED
@@ -1,24 +1,14 @@
 import contextlib
 import math
 from abc import ABC, abstractmethod
-from collections.abc import Generator, Sequence
-from dataclasses import dataclass
-from typing import TYPE_CHECKING, Callable, Optional, Union
-
-from datachain.data_storage.schema import PARTITION_COLUMN_ID
-from datachain.data_storage.warehouse import SELECT_BATCH_SIZE
-from datachain.query.utils import get_query_column, get_query_id_column
-
-if TYPE_CHECKING:
-    from sqlalchemy import Select
+from collections.abc import Callable, Generator, Sequence
 
+import sqlalchemy as sa
 
-@dataclass
-class RowsOutputBatch:
-    rows: Sequence[Sequence]
-
+from datachain.data_storage.schema import PARTITION_COLUMN_ID
 
-RowsOutput = Union[Sequence, RowsOutputBatch]
+RowsOutputBatch = Sequence[Sequence]
+RowsOutput = Sequence | RowsOutputBatch
 
 
 class BatchingStrategy(ABC):
@@ -30,8 +20,8 @@ class BatchingStrategy(ABC):
     def __call__(
         self,
        execute: Callable,
-        query: "Select",
-        ids_only: bool = False,
+        query: sa.Select,
+        id_col: sa.ColumnElement | None = None,
     ) -> Generator[RowsOutput, None, None]:
         """Apply the provided parameters to the UDF."""
 
@@ -47,12 +37,16 @@ class NoBatching(BatchingStrategy):
     def __call__(
        self,
        execute: Callable,
-        query: "Select",
-        ids_only: bool = False,
+        query: sa.Select,
+        id_col: sa.ColumnElement | None = None,
     ) -> Generator[Sequence, None, None]:
-        if ids_only:
-            query = query.with_only_columns(get_query_id_column(query))
-        return execute(query)
+        ids_only = False
+        if id_col is not None:
+            query = query.with_only_columns(id_col)
+            ids_only = True
+
+        rows = execute(query)
+        yield from (r[0] for r in rows) if ids_only else rows
 
 
 class Batch(BatchingStrategy):
@@ -69,27 +63,31 @@ class Batch(BatchingStrategy):
     def __call__(
         self,
         execute: Callable,
-        query: "Select",
-        ids_only: bool = False,
-    ) -> Generator[RowsOutputBatch, None, None]:
-        if ids_only:
-            query = query.with_only_columns(get_query_id_column(query))
+        query: sa.Select,
+        id_col: sa.ColumnElement | None = None,
+    ) -> Generator[RowsOutput, None, None]:
+        from datachain.data_storage.warehouse import SELECT_BATCH_SIZE
+
+        ids_only = False
+        if id_col is not None:
+            query = query.with_only_columns(id_col)
+            ids_only = True
 
         # choose page size that is a multiple of the batch size
         page_size = math.ceil(SELECT_BATCH_SIZE / self.count) * self.count
 
         # select rows in batches
-        results: list[Sequence] = []
+        results = []
 
         with contextlib.closing(execute(query, page_size=page_size)) as rows:
             for row in rows:
                 results.append(row)
                 if len(results) >= self.count:
                     batch, results = results[: self.count], results[self.count :]
-                    yield RowsOutputBatch(batch)
+                    yield [r[0] for r in batch] if ids_only else batch
 
         if len(results) > 0:
-            yield RowsOutputBatch(results)
+            yield [r[0] for r in results] if ids_only else results
 
 
 class Partition(BatchingStrategy):
@@ -104,18 +102,19 @@ class Partition(BatchingStrategy):
     def __call__(
         self,
         execute: Callable,
-        query: "Select",
-        ids_only: bool = False,
-    ) -> Generator[RowsOutputBatch, None, None]:
-        id_col = get_query_id_column(query)
-        if (partition_col := get_query_column(query, PARTITION_COLUMN_ID)) is None:
+        query: sa.Select,
+        id_col: sa.ColumnElement | None = None,
+    ) -> Generator[RowsOutput, None, None]:
+        if (partition_col := query.selected_columns.get(PARTITION_COLUMN_ID)) is None:
             raise RuntimeError("partition column not found in query")
 
-        if ids_only:
+        ids_only = False
+        if id_col is not None:
             query = query.with_only_columns(id_col, partition_col)
+            ids_only = True
 
-        current_partition: Optional[int] = None
-        batch: list[Sequence] = []
+        current_partition: int | None = None
+        batch: list = []
 
         query_fields = [str(c.name) for c in query.selected_columns]
         id_column_idx = query_fields.index("sys__id")
@@ -132,9 +131,9 @@ class Partition(BatchingStrategy):
                 if current_partition != partition:
                     current_partition = partition
                     if len(batch) > 0:
-                        yield RowsOutputBatch(batch)
+                        yield batch
                         batch = []
-                batch.append([row[id_column_idx]] if ids_only else row)
+                batch.append(row[id_column_idx] if ids_only else row)
 
         if len(batch) > 0:
-            yield RowsOutputBatch(batch)
+            yield batch
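
Note: across these hunks the RowsOutputBatch wrapper disappears (batches are now plain sequences) and the ids_only flag is replaced by an optional id_col; when a column is passed, each strategy projects the query onto it and yields bare ids. A sketch of the new contract, assuming Batch(count) stores the batch size as self.count (its constructor is not part of this diff) and using a stand-in execute callable:

    import sqlalchemy as sa

    from datachain.query.batch import Batch

    t = sa.table("t", sa.column("sys__id"), sa.column("path"))
    query = sa.select(t.c.sys__id, t.c.path)

    def execute(q, page_size=None):
        yield from [(1, "a"), (2, "b"), (3, "c")]  # fake warehouse rows

    id_col = query.selected_columns["sys__id"]
    batches = list(Batch(2)(execute, query, id_col=id_col))
    assert batches == [[1, 2], [3]]  # plain lists of bare ids, no wrapper
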