daft-lance 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- daft_lance/__init__.py +278 -0
- daft_lance/compaction.py +68 -0
- daft_lance/merge.py +323 -0
- daft_lance/rest_config.py +55 -0
- daft_lance/rest_write.py +204 -0
- daft_lance/scalar_index.py +224 -0
- daft_lance/utils.py +189 -0
- daft_lance-0.1.0.dist-info/METADATA +12 -0
- daft_lance-0.1.0.dist-info/RECORD +11 -0
- daft_lance-0.1.0.dist-info/WHEEL +4 -0
- daft_lance-0.1.0.dist-info/licenses/LICENSE +199 -0
daft_lance/__init__.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
1
|
+
# ruff: noqa: I002
|
|
2
|
+
# isort: dont-add-import: from __future__ import annotations
|
|
3
|
+
import pathlib
|
|
4
|
+
import warnings
|
|
5
|
+
from typing import TYPE_CHECKING, Any, Optional, Union
|
|
6
|
+
|
|
7
|
+
from daft import context
|
|
8
|
+
from daft.daft import IOConfig
|
|
9
|
+
from daft.io.object_store_options import io_config_to_storage_options
|
|
10
|
+
|
|
11
|
+
from daft_lance.merge import merge_columns_from_df, merge_columns_internal
|
|
12
|
+
from daft_lance.rest_config import LanceRestConfig
|
|
13
|
+
from daft_lance.rest_write import create_lance_table_rest, register_lance_table_rest, write_lance_rest
|
|
14
|
+
from daft_lance.utils import construct_lance_dataset
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from collections.abc import Callable
|
|
18
|
+
|
|
19
|
+
from daft.dataframe import DataFrame
|
|
20
|
+
from daft.dependencies import pa
|
|
21
|
+
|
|
22
|
+
try:
|
|
23
|
+
from lance.dataset import LanceDataset
|
|
24
|
+
from lance.udf import BatchUDF
|
|
25
|
+
except ImportError:
|
|
26
|
+
BatchUDF = None
|
|
27
|
+
LanceDataset = None
|
|
28
|
+
|
|
29
|
+
LanceDataset = Any
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def merge_columns(
    uri: str | pathlib.Path,
    io_config: IOConfig | None = None,
    *,
    # Kept as a string annotation: this module deliberately avoids
    # `from __future__ import annotations`, and `Callable` / `pa` / `BatchUDF`
    # are not guaranteed to be bound at runtime when this `def` is evaluated.
    transform: "Union[dict[str, str], BatchUDF, Callable[[pa.lib.RecordBatch], pa.lib.RecordBatch], None]" = None,
    read_columns: list[str] | None = None,
    reader_schema: Optional["pa.Schema"] = None,
    storage_options: dict[str, Any] | None = None,
    daft_remote_args: dict[str, Any] | None = None,
    concurrency: int | None = None,
    version: int | str | None = None,
    asof: str | None = None,
    block_size: int | None = None,
    commit_lock: Any | None = None,
    index_cache_size: int | None = None,
    default_scan_options: dict[str, Any] | None = None,
    metadata_cache_size_bytes: int | None = None,
) -> LanceDataset:
    """Merge new columns into a Lance dataset using a transformation (deprecated).

    Always emits a ``DeprecationWarning`` pointing to ``merge_columns_df``.

    Args:
        uri: Location of the Lance dataset.
        io_config: Daft IO configuration; when ``None`` the default IO config of
            the current Daft planning config is used.
        transform: Transformation forwarded to ``merge_columns_internal``. Required.
        read_columns: Forwarded to ``merge_columns_internal``.
        reader_schema: Forwarded to ``merge_columns_internal``.
        storage_options: Storage options for opening/committing the dataset. When
            falsy, they are derived from ``io_config`` and ``uri``.
        daft_remote_args: Forwarded to ``merge_columns_internal``.
        concurrency: Forwarded to ``merge_columns_internal``.
        version, asof, block_size, commit_lock, index_cache_size,
        default_scan_options, metadata_cache_size_bytes: Forwarded unchanged to
            ``construct_lance_dataset`` when opening the dataset.

    Returns:
        Whatever ``merge_columns_internal`` returns for the committed merge.

    Raises:
        ValueError: If ``transform`` is ``None``.
    """
    warnings.warn(
        "merge_columns is deprecated and will be removed in a future release. "
        "Please use merge_columns_df instead.",
        category=DeprecationWarning,
        stacklevel=2,
    )

    if transform is None:
        raise ValueError(
            "merge_columns requires a `transform` function; prefer using merge_columns_df with a prepared DataFrame if no transform is needed."
        )

    if io_config is None:
        io_config = context.get_context().daft_planning_config.default_io_config
    # Stringify the URI for the storage-options lookup so pathlib.Path inputs are
    # handled the same way as in create_scalar_index / compact_files.
    storage_options = storage_options or io_config_to_storage_options(io_config, str(uri))

    lance_ds = construct_lance_dataset(
        uri,
        storage_options=storage_options,
        version=version,
        asof=asof,
        block_size=block_size,
        commit_lock=commit_lock,
        index_cache_size=index_cache_size,
        default_scan_options=default_scan_options,
        metadata_cache_size_bytes=metadata_cache_size_bytes,
    )

    return merge_columns_internal(
        lance_ds,
        uri,
        transform=transform,
        read_columns=read_columns,
        reader_schema=reader_schema,
        storage_options=storage_options,
        daft_remote_args=daft_remote_args,
        concurrency=concurrency,
    )
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def merge_columns_df(
    df: "DataFrame",
    uri: str | pathlib.Path,
    io_config: IOConfig | None = None,
    *,
    read_columns: list[str] | None = None,
    reader_schema: Optional["pa.Schema"] = None,
    storage_options: dict[str, Any] | None = None,
    daft_remote_args: dict[str, Any] | None = None,
    concurrency: int | None = None,
    version: int | str | None = None,
    asof: str | None = None,
    block_size: int | None = None,
    commit_lock: Any | None = None,
    index_cache_size: int | None = None,
    default_scan_options: dict[str, Any] | None = None,
    metadata_cache_size_bytes: int | None = None,
    batch_size: int | None = None,
    left_on: str | None = "_rowaddr",
    right_on: str | None = "_rowaddr",
) -> Optional[LanceDataset]:
    """Row-level merge columns entrypoint using a DataFrame.

    Opens the Lance dataset at ``uri`` and delegates to ``merge_columns_from_df``.

    Args:
        df: DataFrame holding the data to merge; forwarded to ``merge_columns_from_df``.
        uri: Location of the Lance dataset.
        io_config: Daft IO configuration; when ``None`` the default IO config of
            the current Daft planning config is used.
        read_columns, reader_schema, daft_remote_args, concurrency: Forwarded to
            ``merge_columns_from_df``.
        storage_options: Storage options; when falsy they are derived from
            ``io_config`` and ``uri``.
        version, asof, block_size, commit_lock, index_cache_size,
        default_scan_options, metadata_cache_size_bytes: Forwarded unchanged to
            ``construct_lance_dataset`` when opening the dataset.
        batch_size: Explicit batch size. When ``None``, the ``"batch_size"`` entry
            of ``daft_remote_args`` is used if present, otherwise ``None``.
        left_on: Join key forwarded as ``left_on``.
        right_on: Join key forwarded as ``right_on``; falls back to ``left_on``
            when falsy.

    Returns:
        The value returned by ``merge_columns_from_df``.
    """
    if io_config is None:
        io_config = context.get_context().daft_planning_config.default_io_config
    # Stringify the URI for the storage-options lookup so pathlib.Path inputs are
    # handled the same way as in create_scalar_index / compact_files.
    storage_options = storage_options or io_config_to_storage_options(io_config, str(uri))

    lance_ds = construct_lance_dataset(
        uri,
        storage_options=storage_options,
        version=version,
        asof=asof,
        block_size=block_size,
        commit_lock=commit_lock,
        index_cache_size=index_cache_size,
        default_scan_options=default_scan_options,
        metadata_cache_size_bytes=metadata_cache_size_bytes,
    )

    effective_right_on = right_on or left_on

    # An explicit batch_size wins; otherwise honor a legacy "batch_size" entry
    # in daft_remote_args when one was supplied.
    if batch_size is not None:
        effective_batch_size = batch_size
    elif daft_remote_args:
        effective_batch_size = daft_remote_args.get("batch_size", None)
    else:
        effective_batch_size = None

    return merge_columns_from_df(
        df,
        lance_ds=lance_ds,
        uri=uri,
        read_columns=read_columns,
        reader_schema=reader_schema,
        storage_options=storage_options,
        daft_remote_args=daft_remote_args,
        concurrency=concurrency,
        left_on=left_on,
        right_on=effective_right_on,
        batch_size=effective_batch_size,
    )
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def create_scalar_index(
    uri: str | pathlib.Path,
    io_config: IOConfig | None = None,
    *,
    column: str,
    index_type: str = "INVERTED",
    name: str | None = None,
    replace: bool = True,
    storage_options: dict[str, Any] | None = None,
    version: int | str | None = None,
    asof: str | None = None,
    block_size: int | None = None,
    commit_lock: Any | None = None,
    index_cache_size: int | None = None,
    default_scan_options: dict[str, Any] | None = None,
    metadata_cache_size_bytes: int | None = None,
    fragment_group_size: int | None = None,
    num_partitions: int | None = None,
    max_concurrency: int | None = None,
    **kwargs: Any,
) -> None:
    """Build a distributed scalar index using Daft's distributed execution.

    Args:
        uri: Location of the Lance dataset.
        io_config: Daft IO configuration; when ``None`` the default IO config of
            the current Daft planning config is used.
        column, index_type, name, replace, fragment_group_size, num_partitions,
        max_concurrency, **kwargs: Forwarded unchanged to
            ``create_scalar_index_internal``.
        storage_options: Storage options; when falsy they are derived from
            ``io_config`` and ``str(uri)``.
        version, asof, block_size, commit_lock, index_cache_size,
        default_scan_options, metadata_cache_size_bytes: Forwarded unchanged to
            ``construct_lance_dataset`` when opening the dataset.

    Raises:
        ImportError: If the ``lance`` package cannot be imported.
        RuntimeError: If the installed ``lance`` version is older than 0.37.0.
    """
    # Only the `lance` import is guarded by the "install pylance" hint so that a
    # failure to import some other module is not misreported as a missing lance.
    try:
        import lance
    except ImportError as e:
        raise ImportError(
            "Unable to import the `lance` package, please install: `pip install pylance`"
        ) from e

    from packaging import version as packaging_version

    from daft_lance.scalar_index import create_scalar_index_internal

    lance_version = packaging_version.parse(lance.__version__)
    min_required_version = packaging_version.parse("0.37.0")
    if lance_version < min_required_version:
        raise RuntimeError(
            f"Distributed indexing requires pylance >= 0.37.0, but found {lance.__version__}. "
            "The distributed indexing interfaces are not available in older versions. "
            "Please upgrade lance by running: pip install --upgrade pylance"
        )

    if io_config is None:
        io_config = context.get_context().daft_planning_config.default_io_config
    storage_options = storage_options or io_config_to_storage_options(io_config, str(uri))

    lance_ds = construct_lance_dataset(
        uri,
        storage_options=storage_options,
        version=version,
        asof=asof,
        block_size=block_size,
        commit_lock=commit_lock,
        index_cache_size=index_cache_size,
        default_scan_options=default_scan_options,
        metadata_cache_size_bytes=metadata_cache_size_bytes,
    )

    create_scalar_index_internal(
        lance_ds=lance_ds,
        uri=uri,
        column=column,
        index_type=index_type,
        name=name,
        replace=replace,
        storage_options=storage_options,
        fragment_group_size=fragment_group_size,
        num_partitions=num_partitions,
        max_concurrency=max_concurrency,
        **kwargs,
    )
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def compact_files(
    uri: str | pathlib.Path,
    io_config: IOConfig | None = None,
    *,
    storage_options: dict[str, Any] | None = None,
    version: int | str | None = None,
    asof: str | None = None,
    block_size: int | None = None,
    commit_lock: Any | None = None,
    index_cache_size: int | None = None,
    default_scan_options: dict[str, Any] | None = None,
    metadata_cache_size_bytes: int | None = None,
    compaction_options: dict[str, Any] | None = None,
    partition_num: int | None = None,
    concurrency: int | None = None,
) -> Any:
    """Compact the files of a Lance dataset through Daft's distributed UDF execution.

    The dataset is opened with ``lance.dataset`` using the given open options and
    then handed to ``compact_files_internal`` together with ``compaction_options``,
    ``partition_num`` and ``concurrency``.

    Raises:
        ImportError: If ``lance`` (or the internal compaction module) cannot be imported.
    """
    try:
        import lance

        from daft_lance.compaction import compact_files_internal
    except ImportError as import_error:
        raise ImportError(
            "Unable to import the `lance` package, please install: `pip install pylance`"
        ) from import_error

    if io_config is None:
        io_config = context.get_context().daft_planning_config.default_io_config

    # Fall back to options derived from the IO config when none (or an empty
    # mapping) were supplied; Path inputs are stringified for that lookup only.
    if not storage_options:
        uri_for_options = str(uri) if isinstance(uri, pathlib.Path) else uri
        storage_options = io_config_to_storage_options(io_config, uri_for_options)

    open_kwargs = {
        "storage_options": storage_options,
        "version": version,
        "asof": asof,
        "block_size": block_size,
        "commit_lock": commit_lock,
        "index_cache_size": index_cache_size,
        "default_scan_options": default_scan_options,
        "metadata_cache_size_bytes": metadata_cache_size_bytes,
    }
    dataset = lance.dataset(uri, **open_kwargs)

    return compact_files_internal(
        lance_ds=dataset,
        compaction_options=compaction_options,
        partition_num=partition_num,
        concurrency=concurrency,
    )
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
__all__ = [
|
|
270
|
+
"LanceRestConfig",
|
|
271
|
+
"compact_files",
|
|
272
|
+
"create_lance_table_rest",
|
|
273
|
+
"create_scalar_index",
|
|
274
|
+
"merge_columns",
|
|
275
|
+
"merge_columns_df",
|
|
276
|
+
"register_lance_table_rest",
|
|
277
|
+
"write_lance_rest",
|
|
278
|
+
]
|
daft_lance/compaction.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from typing import TYPE_CHECKING, Any
|
|
5
|
+
|
|
6
|
+
from lance.optimize import Compaction, CompactionMetrics, CompactionOptions, CompactionTask, RewriteResult
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
import lance
|
|
10
|
+
import daft
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class CompactionTaskUDF:
    """Callable that executes a single Lance ``CompactionTask`` against a fixed dataset.

    The dataset handle is captured at construction time; each call runs one task
    via ``task.execute(dataset)`` and returns that call's result unchanged.
    """

    def __init__(
        self,
        lance_ds: lance.LanceDataset,
    ) -> None:
        # Dataset every task passed to __call__ will be executed against.
        self.lance_ds = lance_ds

    def __call__(self, task: CompactionTask) -> RewriteResult:
        """Execute ``task`` on the stored dataset and return its result."""
        return task.execute(self.lance_ds)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def compact_files_internal(
    lance_ds: lance.LanceDataset,
    *,
    compaction_options: dict[str, Any] | None = None,
    partition_num: int | None = None,
    concurrency: int | None = None,
) -> CompactionMetrics | None:
    """Execute Lance file compaction in distributed environment using Daft UDF style.

    Args:
        lance_ds: Dataset to compact.
        compaction_options: Keyword arguments used to build ``CompactionOptions``;
            ``None`` means no explicit options.
        partition_num: Requested number of partitions for the task DataFrame.
            ``None`` or ``0`` is treated as 1, and the value is capped at the
            number of planned tasks.
        concurrency: Passed as ``max_concurrency`` to ``daft.cls``.

    Returns:
        The metrics returned by ``Compaction.commit``, or ``None`` when the plan
        contains no tasks.

    Raises:
        ValueError: If ``partition_num`` is negative.
    """
    # `daft` is needed at runtime here; importing it locally keeps this function
    # working even when the module-level import is only visible to type checkers.
    import daft

    logger.info("Starting UDF-style distributed compaction")
    plan = Compaction.plan(
        lance_ds,
        CompactionOptions(
            **(compaction_options or {}),
        ),
    )
    num_tasks = plan.num_tasks()
    logger.info("Compaction plan created with %d tasks", num_tasks)

    if num_tasks == 0:
        logger.info("No compaction tasks needed")
        return None

    effective_partition_num = min(num_tasks, partition_num or 1)
    # Explicit validation instead of `assert`, which is stripped under `python -O`.
    if effective_partition_num < 1:
        raise ValueError(f"partition_num must be a positive integer, got {partition_num!r}")

    df = daft.from_pydict({"task": plan.tasks})
    if effective_partition_num > 1:
        df = df.repartition(effective_partition_num)

    WrappedRunner = daft.cls(
        CompactionTaskUDF,
        max_concurrency=concurrency,
    )
    df = df.select(WrappedRunner(lance_ds)(df["task"]).alias("rewrite"))
    results = df.to_pandas()

    metrics = Compaction.commit(lance_ds, results["rewrite"].to_list())
    logger.info("Compaction completed successfully. Metrics: %s", metrics)
    return metrics
|
daft_lance/merge.py
ADDED
|
@@ -0,0 +1,323 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import warnings
|
|
4
|
+
from typing import TYPE_CHECKING, Any
|
|
5
|
+
|
|
6
|
+
import daft.pickle
|
|
7
|
+
|
|
8
|
+
# mypy: disable-error-code="import-untyped"
|
|
9
|
+
from daft.datatype import DataType
|
|
10
|
+
from daft.udf import cls as daft_cls
|
|
11
|
+
from daft.udf import method
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
import pathlib
|
|
15
|
+
from collections.abc import Callable
|
|
16
|
+
|
|
17
|
+
import lance
|
|
18
|
+
|
|
19
|
+
from daft.dependencies import pa
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
_FRAGMENT_HANDLER_RETURN_DTYPE = DataType.struct({"fragment_meta": DataType.binary(), "schema": DataType.binary()})
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@daft_cls
class FragmentHandler:
    """Deprecated per-fragment handler that applies ``merge_columns`` to Lance fragments.

    Instantiation emits a ``DeprecationWarning``. Calling the instance with a
    sequence of fragment ids yields one struct per id holding the pickled
    fragment metadata and the pickled schema returned by the merge.
    """

    def __init__(
        self,
        lance_ds: lance.LanceDataset,
        transform: dict[str, str] | lance.udf.BatchUDF | Callable[[pa.lib.RecordBatch], pa.lib.RecordBatch],
        read_columns: list[str] | None,
        reader_schema: pa.Schema | None = None,
    ):
        warnings.warn(
            "FragmentHandler is deprecated and will be removed in a future version.",
            category=DeprecationWarning,
            stacklevel=2,
        )
        self.lance_ds, self.transform = lance_ds, transform
        self.read_columns, self.reader_schema = read_columns, reader_schema

    @method.batch(return_dtype=_FRAGMENT_HANDLER_RETURN_DTYPE)
    def __call__(self, fragment_ids: Any) -> list[dict[str, bytes]]:
        """Merge columns into every fragment in ``fragment_ids`` and return pickled results."""

        def _merge_one(fragment_id: Any) -> dict[str, bytes]:
            # Look the fragment up, run the column merge, and pickle both outputs.
            target = self.lance_ds.get_fragment(fragment_id)
            meta, merged_schema = target.merge_columns(self.transform, self.read_columns, None, self.reader_schema)
            return {
                "fragment_meta": daft.pickle.dumps(meta),
                "schema": daft.pickle.dumps(merged_schema),
            }

        return [_merge_one(fid) for fid in fragment_ids]
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def merge_columns_internal(
    lance_ds: lance.LanceDataset,
    url: str | pathlib.Path,
    *,
    transform: dict[str, str] | lance.udf.BatchUDF | Callable[[pa.RecordBatch], pa.RecordBatch],
    read_columns: list[str] | None = None,
    reader_schema: pa.Schema | None = None,
    storage_options: dict[str, Any] | None = None,
    daft_remote_args: dict[str, Any] | None = None,
    concurrency: int | None = None,
) -> lance.LanceDataset:
    """Run ``FragmentHandler`` over every fragment of ``lance_ds`` and commit one Merge operation.

    Emits a ``DeprecationWarning``. ``daft_remote_args`` and ``concurrency`` are
    accepted but not used by this implementation.

    Raises:
        ValueError: If no schema could be obtained from the per-fragment results.
    """
    warnings.warn(
        "FragmentHandler is deprecated and will be removed in a future version.",
        category=DeprecationWarning,
        stacklevel=2,
    )
    import lance

    from daft import from_pylist

    # NOTE: Legacy remote args (num_cpus/num_gpus/memory_bytes/batch_size) were
    # resource hints for the old @udf path; the daft.cls interface does not
    # expose them, so they are ignored here.
    frame = from_pylist([{"fragment_id": frag.metadata.id} for frag in lance_ds.get_fragments()])

    # Bind the Lance-specific state once, then apply the batch method to the
    # fragment_id column.
    fragment_handler = FragmentHandler(lance_ds, transform, read_columns, reader_schema)
    frame = frame.with_column("commit_message", fragment_handler(frame["fragment_id"]))  # type: ignore[arg-type]

    merged_schema = None
    merged_fragments = []
    for message in frame.collect().to_pydict()["commit_message"]:
        meta_blob, schema_blob = message["fragment_meta"], message["schema"]
        merged_fragments.append(daft.pickle.loads(meta_blob))
        # Only the first available schema is kept.
        if merged_schema is None:
            merged_schema = daft.pickle.loads(schema_blob)

    if merged_schema is None:
        raise ValueError("No schema for new fragment found")

    merge_op = lance.LanceOperation.Merge(merged_fragments, merged_schema)
    return lance_ds.commit(
        url,
        merge_op,
        read_version=lance_ds.version,
        storage_options=storage_options,
    )
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
@daft_cls
class GroupFragmentMergeUDF:
    """Per-group merge handler that calls Lance ``fragment.merge`` with a keyed join."""

    def __init__(
        self,
        lance_ds: lance.LanceDataset,
        left_on: str | None = "_rowaddr",
        right_on: str | None = None,
        read_columns: list[str] | None = None,
        reader_schema: pa.Schema | None = None,
        batch_size: int | None = 9223372036854775807,
    ):
        """Per-group merge handler that directly invokes Lance fragment.merge with keyed join.

        Args:
            lance_ds: Target Lance dataset.
            left_on: Key column on the Lance fragment (default "_rowaddr").
            right_on: Key column name present in the provided reader data (defaults to left_on).
            read_columns: Names for columns provided to the handler via map_groups (must include right_on).
            reader_schema: Optional Arrow schema for the reader.
            batch_size: Optional batch size when building RecordBatchReader from the provided data.
        """
        self.lance_ds = lance_ds
        self.left_on = left_on or "_rowaddr"
        self.right_on = right_on or self.left_on
        self.read_columns = read_columns or []
        self.reader_schema = reader_schema
        self.batch_size = batch_size

    @method.batch(return_dtype=_FRAGMENT_HANDLER_RETURN_DTYPE)
    def __call__(self, *cols: Any) -> list[dict[str, bytes]]:
        """Merge one group's rows into its fragment.

        Args:
            *cols: Data columns in ``read_columns`` order, followed by the
                fragment-id column as the last argument.

        Returns:
            An empty list when no columns or no fragment ids were provided;
            otherwise a single-element list with the pickled fragment metadata and
            schema, or empty byte strings when there are no new columns to merge.

        Raises:
            ValueError: If the number of data columns does not match
                ``read_columns`` or the join key is missing from the reader data.
        """
        from daft.dependencies import pa as _pa

        if not cols:
            return []
        # Last argument is the fragment_id series, preceding args are data columns as per read_columns
        *data_cols, fragment_ids = cols
        ids = fragment_ids.to_pylist() if hasattr(fragment_ids, "to_pylist") else list(fragment_ids)
        if not ids:
            return []
        # The first id identifies the fragment for the whole group.
        frag_id = ids[0]

        if len(self.read_columns) != len(data_cols):
            raise ValueError(
                f"GroupFragmentMergeUDF expected {len(self.read_columns)} data columns, received {len(data_cols)}."
            )

        key_is_rowaddr = self.right_on == "_rowaddr"
        arrays: list[_pa.Array] = []

        for col_name, s in zip(self.read_columns, data_cols):
            pylist = s.to_pylist() if hasattr(s, "to_pylist") else list(s)

            if col_name == self.right_on:
                key_arr: _pa.Array
                if key_is_rowaddr:
                    key_arr = _pa.array(pylist, type=_pa.uint64())
                else:
                    key_arr = _pa.array([None if v is None else int(v) for v in pylist], type=_pa.int64())
                # Stored as int64 here; the join key is coerced to its expected type below.
                arrays.append(key_arr.cast(_pa.int64()))
            else:
                arr = _pa.array(pylist)
                # Integer columns are normalized to int64; everything else is kept as inferred.
                arrays.append(arr.cast(_pa.int64()) if _pa.types.is_integer(arr.type) else arr)

        tbl = _pa.Table.from_arrays(arrays, names=self.read_columns)

        # Ensure the join key exists in the reader data
        if self.right_on not in tbl.schema.names:
            # Literal braces are doubled so the f-string does not treat the dict
            # example as a replacement field.
            raise ValueError(
                f"Reader data missing join key '{self.right_on}'. Ensure the DataFrame includes this column (e.g., read with default_scan_options={{'with_row_address': True}} to expose '_rowaddr'). Hint: join key must be Int64; will be coerced automatically."
            )

        # After building the table, ensure the join key field is the correct type; cast if necessary
        join_idx = tbl.schema.get_field_index(self.right_on)
        if join_idx != -1:
            join_field = tbl.schema.field(join_idx)
            # Use appropriate type based on the join key name
            expected_type = _pa.uint64() if key_is_rowaddr else _pa.int64()
            if join_field.type != expected_type:
                fields = [
                    _pa.field(name, expected_type) if name == self.right_on else tbl.schema.field(i)
                    for i, name in enumerate(tbl.schema.names)
                ]
                tbl = tbl.cast(_pa.schema(fields))

        # Enforce that reader stream contains only join key + new columns (exclude existing dataset fields)
        df_schema = tbl.schema
        existing_fields: set[str] = set()
        try:
            existing_fields = {getattr(f, "name", str(f)) for f in self.lance_ds.schema}
        except Exception:
            # Best-effort fallbacks for schema objects that cannot be iterated directly.
            names = []
            try:
                names = list(getattr(self.lance_ds.schema, "names", []))
            except Exception:
                try:
                    names = [getattr(f, "name", str(f)) for f in getattr(self.lance_ds.schema, "fields", [])]
                except Exception:
                    names = []
            existing_fields = set(names)

        new_column_names = [name for name in df_schema.names if name not in existing_fields and name != self.right_on]
        if not new_column_names:
            # No new columns to merge; return early
            return [{"fragment_meta": b"", "schema": b""}]  # Return empty bytes instead of None

        # Filter table to only include join key + new columns
        filtered_names = [name for name in df_schema.names if name == self.right_on or name in new_column_names]
        tbl = tbl.select(filtered_names)

        # Build RecordBatchReader from table batches
        batches = tbl.to_batches(max_chunksize=self.batch_size) if self.batch_size is not None else tbl.to_batches()
        reader = _pa.RecordBatchReader.from_batches(tbl.schema, batches)

        fragment = self.lance_ds.get_fragment(frag_id)
        # Build schema argument: use the table's schema (including join key and new columns) unless an explicit reader_schema is provided
        schema_arg = tbl.schema if self.reader_schema is None else self.reader_schema
        fragment_meta, schema = fragment.merge(reader, left_on=self.left_on, right_on=self.right_on, schema=schema_arg)
        return [{"fragment_meta": daft.pickle.dumps(fragment_meta), "schema": daft.pickle.dumps(schema)}]
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def merge_columns_from_df(
    df: daft.DataFrame,
    lance_ds: lance.LanceDataset,
    uri: str | pathlib.Path,
    *,
    read_columns: list[str] | None = None,
    reader_schema: pa.Schema | None = None,
    storage_options: dict[str, Any] | None = None,
    daft_remote_args: dict[str, Any] | None = None,
    concurrency: int | None = None,
    left_on: str | None = "_rowaddr",
    right_on: str | None = None,
    batch_size: int | None = 9223372036854775807,
) -> lance.LanceDataset | None:
    """Merge new columns from ``df`` into ``lance_ds``, one fragment group at a time.

    ``df`` is grouped by ``fragment_id``; each group is handed to
    ``GroupFragmentMergeUDF`` and the resulting fragment metadata is committed as
    a single ``LanceOperation.Merge``.

    Args:
        df: DataFrame containing ``fragment_id``, the join key and the new columns.
        lance_ds: Target Lance dataset.
        uri: Dataset location passed to ``lance_ds.commit``.
        read_columns: Columns passed to the UDF. When ``None`` they are derived as
            the join key followed by every ``df`` column that is neither in the
            dataset schema nor ``fragment_id``.
        reader_schema: Forwarded to ``GroupFragmentMergeUDF``.
        storage_options: Forwarded to ``lance_ds.commit``.
        daft_remote_args: Accepted but not used by this implementation.
        concurrency: Accepted but not used by this implementation.
        left_on: Join key on the fragment side.
        right_on: Join key in ``df``; falls back to ``left_on``.
        batch_size: Forwarded to ``GroupFragmentMergeUDF``.

    Returns:
        The result of ``lance_ds.commit``, or ``None`` when no group produced a
        schema (nothing to merge).

    Raises:
        ValueError: If ``fragment_id`` or the join key is missing from ``df``, or
            if ``read_columns`` is derived and ``df`` has no new columns.
    """
    import lance

    # Validate required keys
    if "fragment_id" not in df.column_names:
        raise ValueError("DataFrame must contain 'fragment_id' column for row-level merge workflow")
    # Same fallback chain as GroupFragmentMergeUDF, so a None/None pair resolves
    # to "_rowaddr" instead of a literal None join key.
    join_key = right_on or left_on or "_rowaddr"
    if join_key not in df.column_names:
        raise ValueError(
            f"DataFrame must contain join key column '{join_key}'. If missing, read with default_scan_options={{'with_row_address': True}} to expose '_rowaddr', or include the key explicitly."
        )

    # Derive read_columns if not provided: exactly [join_key] + new columns (not present in dataset schema)
    if read_columns is None:
        # Compute dataset existing field names robustly
        existing_fields: set[str] = set()
        try:
            existing_fields = {getattr(f, "name", str(f)) for f in lance_ds.schema}
        except Exception:
            # Best-effort fallbacks for schema objects that cannot be iterated directly.
            names = []
            try:
                names = list(getattr(lance_ds.schema, "names", []))
            except Exception:
                try:
                    names = [getattr(f, "name", str(f)) for f in getattr(lance_ds.schema, "fields", [])]
                except Exception:
                    names = []
            existing_fields = set(names)
        new_cols = [c for c in df.column_names if c not in existing_fields and c not in ("fragment_id", join_key)]
        if not new_cols:
            raise ValueError(
                "No new columns to merge; Lance requires the reader stream to include only the join key and new columns not present in the dataset."
            )
        read_columns = [join_key] + new_cols

    handler_udf = GroupFragmentMergeUDF(
        lance_ds,
        left_on,
        right_on,
        read_columns,
        reader_schema,
        batch_size,
    )

    # map_groups: pass data columns followed by fragment_id
    grouped = df.groupby("fragment_id").map_groups(
        handler_udf(*(df[c] for c in read_columns), df["fragment_id"]).alias("commit_message")  # type: ignore[attr-defined]
    )

    commit_messages = grouped.collect().to_pydict()["commit_message"]
    new_schema = None
    fragment_metas = []
    for commit_message in commit_messages:
        fragment_meta = commit_message["fragment_meta"]
        schema = commit_message["schema"]
        # Skip "nothing to merge" results: the UDF reports them as empty byte
        # strings, and None is tolerated as well. Unpickling b"" would fail.
        if not fragment_meta or not schema:
            continue
        fragment_metas.append(daft.pickle.loads(fragment_meta))
        if new_schema is None:
            new_schema = daft.pickle.loads(schema)

    # If there are no new columns to merge, we can return early
    if new_schema is None:
        return None

    op = lance.LanceOperation.Merge(fragment_metas, new_schema)
    return lance_ds.commit(
        uri,
        op,
        read_version=lance_ds.version,
        storage_options=storage_options,
    )
|