mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
- mplang/__init__.py +21 -45
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +5 -7
- mplang/v1/core/__init__.py +157 -0
- mplang/{core → v1/core}/cluster.py +30 -14
- mplang/{core → v1/core}/comm.py +5 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +13 -14
- mplang/{core → v1/core}/expr/evaluator.py +65 -24
- mplang/{core → v1/core}/expr/printer.py +24 -18
- mplang/{core → v1/core}/expr/transformer.py +3 -3
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +23 -16
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +4 -4
- mplang/{core → v1/core}/primitive.py +106 -201
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{api.py → v1/host.py} +38 -6
- mplang/v1/kernels/__init__.py +41 -0
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/v1/kernels/basic.py +240 -0
- mplang/{kernels → v1/kernels}/context.py +42 -27
- mplang/{kernels → v1/kernels}/crypto.py +44 -37
- mplang/v1/kernels/fhe.py +858 -0
- mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
- mplang/{kernels → v1/kernels}/phe.py +263 -57
- mplang/{kernels → v1/kernels}/spu.py +137 -48
- mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
- mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
- mplang/v1/kernels/value.py +626 -0
- mplang/{ops → v1/ops}/__init__.py +5 -16
- mplang/{ops → v1/ops}/base.py +2 -5
- mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
- mplang/v1/ops/crypto.py +262 -0
- mplang/v1/ops/fhe.py +272 -0
- mplang/{ops → v1/ops}/jax_cc.py +33 -68
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -4
- mplang/{ops → v1/ops}/spu.py +3 -5
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +9 -24
- mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
- mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
- mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +35 -20
- mplang/{runtime → v1/runtime}/client.py +19 -8
- mplang/{runtime → v1/runtime}/communicator.py +59 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +30 -12
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +58 -42
- mplang/{runtime → v1/runtime}/session.py +57 -71
- mplang/{runtime → v1/runtime}/simulation.py +55 -28
- mplang/v1/simp/api.py +353 -0
- mplang/{simp → v1/simp}/mpi.py +8 -9
- mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
- mplang/{simp → v1/simp}/random.py +21 -22
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/core/__init__.py +0 -92
- mplang/device.py +0 -340
- mplang/kernels/builtin.py +0 -207
- mplang/ops/crypto.py +0 -109
- mplang/ops/ibis_cc.py +0 -139
- mplang/ops/sql.py +0 -61
- mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
- mplang/runtime/link_comm.py +0 -131
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -73
- mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{kernels → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
mplang/v2/backends/table_impl.py (new file)

@@ -0,0 +1,838 @@
# Copyright 2025 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Table Runtime Implementation.

Implements execution logic for Table primitives using DuckDB and PyArrow.
"""

from __future__ import annotations

import base64
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, ClassVar, Protocol, Self, runtime_checkable

import duckdb
import pandas as pd
import pyarrow as pa

import mplang.v2.edsl.typing as elt
from mplang.v2.backends.tensor_impl import TensorValue
from mplang.v2.dialects import table
from mplang.v2.edsl import serde
from mplang.v2.edsl.graph import Operation
from mplang.v2.runtime.interpreter import Interpreter
from mplang.v2.runtime.value import WrapValue


class BatchReader(ABC):
    @property
    @abstractmethod
    def schema(self) -> pa.Schema: ...

    @abstractmethod
    def read_next_batch(self) -> pa.RecordBatch: ...
    @abstractmethod
    def close(self) -> None: ...

    def __enter__(self) -> Self:
        return self

    def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None:
        self.close()

    def __iter__(self) -> Self:
        return self

    def __next__(self) -> pa.RecordBatch:
        return self.read_next_batch()


class TableReader(BatchReader):
    """A reader for streaming table data from PyArrow RecordBatchReader or Table.

    This class provides an efficient way to read large tables in batches,
    with support for custom batch sizes and proper handling of data boundaries.
    It implements the iterator protocol for easy consumption of data.
    """

    def __init__(
        self,
        data: pa.RecordBatchReader | pa.Table,
        num_rows: int = -1,
        batch_size: int = -1,
    ) -> None:
        """Initialize a TableReader.

        Args:
            data: Either a RecordBatchReader or Table to read from
            num_rows: Expected number of rows in the data. -1 indicates unknown
            batch_size: Size of each batch to read. -1 means use default/reader's batch size
        """
        # Store the underlying reader and row count based on input type
        if isinstance(data, pa.RecordBatchReader):
            self._reader = data
            self._num_rows = num_rows
        else:
            # Convert Table to RecordBatchReader for consistent interface
            self._reader = data.to_reader()
            self._num_rows = data.num_rows

        # Configuration for batch reading
        self._batch_size = batch_size

        # Internal state for handling custom batch sizes
        self._remain: pa.RecordBatch | None = (
            None  # Stores partial batch from previous read
        )
        self._eof = False  # Flag to indicate end of data

    @property
    def num_rows(self) -> int:
        """Get the total number of rows in the table.

        Returns:
            Total number of rows, or -1 if unknown
        """
        return self._num_rows

    @property
    def schema(self) -> pa.Schema:
        """Get the schema of the table.

        Returns:
            PyArrow Schema describing the table's columns and types
        """
        return self._reader.schema

    def read_all(self) -> pa.Table:
        """Read all remaining data as a Table.

        This is a convenience method that reads all data from the reader
        and returns it as a single PyArrow Table.

        Returns:
            Complete table containing all remaining data
        """
        return self._reader.read_all()

    def read_next_batch(self) -> pa.RecordBatch:
        """Read the next batch of records.

        This method respects the configured batch size. If the native reader
        returns batches larger than the configured size, this method will split
        them appropriately. Any partial data from previous reads is included
        in the returned batch.

        Returns:
            Next RecordBatch of data

        Raises:
            StopIteration: When no more data is available
        """
        # Check if we've reached end of file
        if self._eof:
            raise StopIteration

        # Get the next batch using internal logic
        batch = self._read_next_batch()

        # Handle end of data
        if batch is None:
            self._eof = True
            raise StopIteration

        return batch

    def _read_next_batch(self) -> pa.RecordBatch | None:
        """Internal method to read and process the next batch.

        This method handles the complex logic of:
        - Using default batch size when none is specified
        - Accumulating data from multiple native batches to reach the target size
        - Splitting oversized batches and saving the remainder
        - Converting between Table and RecordBatch formats as needed

        Returns:
            Next RecordBatch of the configured size, or None if no more data
        """
        # If no batch size specified, just return the reader's native batches
        if self._batch_size <= 0:
            try:
                batch = self._reader.read_next_batch()
                # Convert to RecordBatch if the reader returns a Table
                if isinstance(batch, pa.Table) and batch.num_rows > 0:
                    return batch.to_batches()[0]
                return batch
            except StopIteration:
                return None

        # We have a custom batch size - need to accumulate/split batches
        batches: list[pa.RecordBatch] = []
        num_rows: int = 0

        # First, include any remaining data from the previous read
        if self._remain is not None:
            num_rows = self._remain.num_rows
            batches = [self._remain]
            self._remain = None

        # Keep reading until we have enough rows or run out of data
        while num_rows < self._batch_size:
            try:
                batch = self._reader.read_next_batch()

                # Handle the case where reader returns a Table instead of RecordBatch
                if isinstance(batch, pa.Table):
                    if batch.num_rows > 0:
                        # Convert each batch from the Table
                        for rb in batch.to_batches():
                            num_rows += rb.num_rows
                            if rb.num_rows > 0:  # Skip empty batches
                                batches.append(rb)
                else:
                    # Already a RecordBatch
                    num_rows += batch.num_rows
                    if batch.num_rows > 0:  # Skip empty batches
                        batches.append(batch)
            except StopIteration:
                # Mark EOF but continue processing what we have
                self._eof = True
                break

        # If we didn't get any data, return None
        if num_rows == 0:
            return None

        # Split the last batch if we have more rows than needed
        if num_rows > self._batch_size:
            last = batches[-1]
            remain_size = num_rows - self._batch_size
            last_size = last.num_rows - remain_size

            # Keep only what we need from the last batch
            batches[-1] = last.slice(0, last_size)
            # Save the remainder for the next read
            self._remain = last.slice(last_size, remain_size)

        # Optimized path: if we only have one batch, return it directly
        if len(batches) == 1:
            return batches[0]

        # Otherwise, combine all batches and return as a single RecordBatch
        combined = pa.Table.from_batches(batches)
        return combined.to_batches()[0]

    def close(self) -> None:
        """Close the reader and release all resources.

        This method should be called when the reader is no longer needed.
        It closes the underlying reader and clears internal state.
        """
        # Close the underlying reader
        self._reader.close()
        # Clear internal state
        self._remain = None
        self._eof = False


DEFAULT_BATCH_SIZE = 1_000_000


class TableSource(ABC):
    """Abstract base class for lazy table operations.

    Provides deferred execution for table operations to prevent OOM issues.
    """

    @abstractmethod
    def register(
        self, conn: duckdb.DuckDBPyConnection, name: str, replace: bool = True
    ) -> None: ...

    @abstractmethod
    def open(self, batch_size: int = DEFAULT_BATCH_SIZE) -> TableReader:
        """Read data as a stream of record batches."""
        ...


class ParquetReader(pa.RecordBatchReader):
    """A reader that implements the pa.RecordBatchReader interface for Parquet files."""

    def __init__(self, source: Any, columns: list[str] | None = None):
        import pyarrow.parquet as pq

        file = pq.ParquetFile(source)

        # Use schema_arrow to get the proper pa.Schema
        if columns:
            # Filter the schema to only include selected columns
            fields = [
                file.schema_arrow.field(col)
                for col in columns
                if col in file.schema_arrow.names
            ]
            schema = pa.schema(fields)
        else:
            schema = file.schema_arrow

        self._file = file
        self._schema = schema
        self._cast = False
        self._num_rows = int(file.metadata.num_rows)
        self._iter = file.iter_batches(columns=columns)

    @property
    def num_rows(self) -> int:
        return self._num_rows

    @property
    def schema(self) -> pa.Schema:
        return self._schema

    def cast(self, target_schema: pa.Schema) -> ParquetReader:
        # Validate that the number of columns is the same
        if len(target_schema) != len(self._schema):
            raise ValueError(
                f"Cannot cast schema: target schema has {len(target_schema)} columns, "
                f"but current schema has {len(self._schema)} columns"
            )

        # Check if there are any changes in the schema
        schema_changed = False
        for i, (target_field, current_field) in enumerate(
            zip(target_schema, self._schema, strict=True)
        ):
            # Check if field names are the same (allowing type changes)
            if target_field.name != current_field.name:
                raise ValueError(
                    f"Cannot cast schema: field name at position {i} differs. "
                    f"Current: '{current_field.name}', Target: '{target_field.name}'. "
                    f"Field names must match."
                )
            # Check if types are different
            if target_field.type != current_field.type:
                schema_changed = True

        # Only set _cast if there are actual changes
        if schema_changed:
            self._schema = target_schema
            self._cast = True

        return self

    def read_all(self) -> pa.Table:
        batches = []
        try:
            while True:
                batch = self.read_next_batch()
                batches.append(batch)
        except StopIteration:
            pass
        if batches:
            return pa.Table.from_batches(batches)
        return pa.Table.from_batches([])

    def read_next_batch(self) -> pa.RecordBatch:
        batch = next(self._iter)
        if self._cast:
            return batch.cast(self._schema)
        else:
            return batch

    def close(self) -> None:
        """Close the Parquet reader and release resources."""
        self._file.close()


_type_mapping = {
    elt.bool_: pa.bool_(),
    elt.i8: pa.int8(),
    elt.i16: pa.int16(),
    elt.i32: pa.int32(),
    elt.i64: pa.int64(),
    elt.u8: pa.uint8(),
    elt.u16: pa.uint16(),
    elt.u32: pa.uint32(),
    elt.u64: pa.uint64(),
    elt.f16: pa.float16(),
    elt.f32: pa.float32(),
    elt.f64: pa.float64(),
    elt.STRING: pa.string(),
    elt.DATE: pa.date64(),
    elt.TIME: pa.time32("ms"),
    elt.TIMESTAMP: pa.timestamp("ms"),
    elt.DECIMAL: pa.decimal128(38, 10),
    elt.BINARY: pa.binary(),
    elt.JSON: pa.json_(),
}


def _pa_schema(s: elt.TableType) -> pa.Schema:
    fields = []
    for k, v in s.schema.items():
        if v not in _type_mapping:
            raise ValueError(f"cannot convert to pyarrow type. type={v}, name={k}")
        fields.append(pa.field(k, _type_mapping[v]))

    return pa.schema(fields)


@dataclass
class FileTableSource(TableSource):
    """Lazy table handle for file-based operations with streaming reads."""

    path: str
    format: str
    schema: pa.Schema | None = None

    def register(
        self, conn: duckdb.DuckDBPyConnection, name: str, replace: bool = True
    ) -> None:
        """Register the file as a view in DuckDB."""
        func_name = ""
        match self.format:
            case "parquet":
                func_name = "read_parquet"
            case "csv":
                func_name = "read_csv_auto"
            case "json":
                func_name = "read_json_auto"
            case _:
                raise ValueError(f"Unsupported format: {self.format}")

        safe_path = self.path.replace("'", "''")
        base_query = f"SELECT * FROM {func_name}('{safe_path}')"
        if replace:
            query = f"CREATE OR REPLACE VIEW {name} AS {base_query}"
        else:
            query = f"CREATE VIEW {name} AS {base_query}"
        conn.execute(query)

    def open(self, batch_size: int = DEFAULT_BATCH_SIZE) -> TableReader:
        """Create a streaming reader for the file."""
        import pyarrow.csv as pa_csv
        import pyarrow.json as pa_json

        columns = self.schema.names if self.schema else None

        reader = None
        num_rows = -1
        match self.format.lower():
            case "parquet":
                reader = ParquetReader(self.path, columns)
                num_rows = reader.num_rows
            case "csv":
                read_options = pa_csv.ReadOptions(use_threads=True)
                convert_options = pa_csv.ConvertOptions(
                    column_types=self.schema,
                    include_columns=columns,
                )
                reader = pa_csv.open_csv(
                    self.path,
                    read_options=read_options,
                    convert_options=convert_options,
                )
            case "json":
                read_options = pa_json.ReadOptions(use_threads=True)
                table = pa_json.read_json(self.path, read_options=read_options)
                if columns:
                    table = table.select(columns)
                reader = table.to_reader()
                num_rows = table.num_rows
            case _:
                raise ValueError(f"Unsupported format: {self.format}")

        if self.schema and self.schema != reader.schema:
            reader = reader.cast(self.schema)

        return TableReader(reader, num_rows=num_rows, batch_size=batch_size)


class DuckDBState:
    def __init__(self, conn: duckdb.DuckDBPyConnection) -> None:
        self.conn = conn
        self.tables: dict[str, Any] = {}


@dataclass(frozen=True)
class QueryTableSource(TableSource):
    """Handle for existing DuckDB relations (kept for compatibility)."""

    relation: duckdb.DuckDBPyRelation
    state: DuckDBState

    def register(
        self, conn: duckdb.DuckDBPyConnection, name: str, replace: bool = True
    ) -> None:
        self.relation.create_view(name, replace)

    def open(self, batch_size: int = DEFAULT_BATCH_SIZE) -> TableReader:
        """Read from the DuckDB relation."""
        if batch_size <= 0:
            batch_size = DEFAULT_BATCH_SIZE
        reader = self.relation.arrow(batch_size)
        return TableReader(reader)


# =============================================================================
# TableValue Wrapper
# =============================================================================


@serde.register_class
class TableValue(WrapValue[pa.Table | TableSource]):
    """Runtime value wrapping a PyArrow Table.

    Provides serialization via Arrow IPC format (streaming).
    Future: may extend to support other backends (Polars, DuckDB relations, etc.)
    """

    _serde_kind: ClassVar[str] = "table_impl.TableValue"

    @property
    def data(self) -> pa.Table:
        """Get the underlying PyArrow Table data.

        For lazy TableSource, this triggers a full read of the data and caches
        the result in self._data. Subsequent calls will return the cached table.

        Returns:
            The PyArrow Table containing all data
        """
        if isinstance(self._data, TableSource):
            source = self._data
            with source.open() as reader:
                self._data = reader.read_all()

        return self._data

    # =========== Wrap/Unwrap ===========

    def _convert(self, data: Any) -> pa.Table | TableSource:
        """Convert input data to pa.Table or TableSource."""
        if isinstance(data, TableValue):
            return data.unwrap()
        if isinstance(data, pd.DataFrame):
            data = pa.Table.from_pandas(data)
        if not isinstance(data, pa.Table | TableSource):
            raise TypeError(f"Expected pa.Table or TableSource, got {type(data)}")
        return data

    # =========== Serialization ===========

    def to_json(self) -> dict[str, Any]:
        # Serialize using Arrow IPC streaming format
        data = self.data
        sink = pa.BufferOutputStream()
        with pa.ipc.new_stream(sink, data.schema) as writer:
            writer.write_table(data)
        ipc_bytes = sink.getvalue().to_pybytes()
        return {"ipc": base64.b64encode(ipc_bytes).decode("ascii")}

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> TableValue:
        ipc_bytes = base64.b64decode(data["ipc"])
        reader = pa.ipc.open_stream(ipc_bytes)
        table = reader.read_all()
        return cls(table)


# Module-level helpers for convenience (delegate to class methods)
def _wrap(val: pa.Table | pd.DataFrame | TableSource | TableValue) -> TableValue:
    """Wrap a table-like value into TableValue."""
    return TableValue.wrap(val)


def _unwrap(val: TableValue | pa.Table | pd.DataFrame) -> pa.Table:
    """Unwrap TableValue to pa.Table, also accepts raw pa.Table/DataFrame."""
    if isinstance(val, TableValue):
        return val.data
    if isinstance(val, pd.DataFrame):
        return pa.Table.from_pandas(val)
    if isinstance(val, pa.Table):
        return val
    # Handle RecordBatchReader from newer PyArrow versions
    if isinstance(val, pa.RecordBatchReader):
        return val.read_all()
    raise TypeError(
        f"Expected TableValue, pa.Table, pd.DataFrame, or RecordBatchReader, got {type(val)}"
    )


# =============================================================================
# Table Primitive Implementations
# =============================================================================


@table.run_sql_p.def_impl
def run_sql_impl(interpreter: Interpreter, op: Operation, *args: Any) -> TableValue:
    """Execute SQL query on input tables."""
    query = op.attrs["query"]
    dialect = op.attrs.get("dialect", "duckdb")
    table_names = op.attrs["table_names"]

    if dialect != "duckdb":
        raise ValueError(f"Unsupported dialect: {dialect}")

    state: DuckDBState | None = None
    tables: list[TableValue] = []
    for arg in args:
        tbl = _wrap(arg)
        tables.append(tbl)
        data = tbl.unwrap()
        if isinstance(data, QueryTableSource):
            if state is None:
                state = data.state
            elif state != data.state:
                raise ValueError("All tables must belong to the same DuckDB connection")

    if state is None:
        conn = duckdb.connect()
        state = DuckDBState(conn)

    try:
        conn = state.conn
        # register tables or create view
        for name, tbl in zip(table_names, tables, strict=True):
            data = tbl.unwrap()
            if name in state.tables:
                if state.tables[name] is not data:
                    # TODO: rename and rewrite sql??
                    raise ValueError(f"{name} has been registered.")
            else:
                state.tables[name] = data
                if isinstance(data, TableSource):
                    data.register(state.conn, name)
                else:
                    conn.register(name, data)

        relation = conn.sql(query)
        return _wrap(QueryTableSource(relation, state))
    except Exception as e:
        raise RuntimeError(f"Failed to execute SQL query: {query}") from e


@table.table2tensor_p.def_impl
def table2tensor_impl(interpreter: Interpreter, op: Operation, table_val: Any) -> Any:
    """Convert table to tensor (numpy array).

    Returns TensorValue if tensor_impl is available, otherwise raw np.ndarray.
    """
    from mplang.v2.backends.tensor_impl import TensorValue

    tbl = _unwrap(table_val)
    df = tbl.to_pandas()
    # Convert to numpy array
    # Note: This assumes the table is homogeneous as enforced by abstract_eval
    arr = df.to_numpy()
    return TensorValue.wrap(arr)


@table.tensor2table_p.def_impl
def tensor2table_impl(
    interpreter: Interpreter, op: Operation, tensor_val: TensorValue
) -> TableValue:
    """Convert tensor (numpy array) to table."""
    column_names = op.attrs["column_names"]

    # Unwrap TensorValue
    arr = tensor_val.unwrap()

    if arr.ndim != 2:
        raise ValueError(f"Expected 2D array, got {arr.ndim}D")

    if arr.shape[1] != len(column_names):
        raise ValueError(
            f"Shape mismatch: tensor has {arr.shape[1]} columns, "
            f"but {len(column_names)} names provided"
        )

    # Create dictionary for DataFrame/Table creation
    data = {}
    for i, name in enumerate(column_names):
        data[name] = arr[:, i]

    return _wrap(pa.Table.from_pydict(data))


@table.constant_p.def_impl
def constant_impl(interpreter: Interpreter, op: Operation) -> TableValue:
    """Create constant table."""
    # data is stored in attrs by default bind if not TraceObject
    data = op.attrs["data"]

    # Handle pandas DataFrame if passed directly (though attrs usually store basic types)
    # If data was a DataFrame, it might have been stored as is if the IR supports it.
    # If data was a dict, it's fine.

    if isinstance(data, pa.Table):
        return _wrap(data)
    else:
        return _wrap(pa.table(data))


def _infer_format(path: str, format_hint: str) -> str:
    """Infer file format from path extension or hint."""
    if format_hint != "auto":
        return format_hint

    path_lower = path.lower()
    if path_lower.endswith((".parquet", ".pq")):
        return "parquet"
    elif path_lower.endswith(".csv"):
        return "csv"
    elif path_lower.endswith((".json", ".jsonl")):
        return "json"
    else:
        # Default to parquet
        return "parquet"


@table.read_p.def_impl
def read_impl(interpreter: Interpreter, op: Operation) -> TableValue:
    """Read table from file.

    Supported formats: parquet, csv, json
    """
    import os

    path: str = op.attrs["path"]
    schema: elt.TableType = op.attrs["schema"]
    format_hint: str = op.attrs.get("format", "auto")
    fmt = _infer_format(path, format_hint)
    if not os.path.exists(path):
        raise FileNotFoundError(f"{path} not exists")

    pa_schema = _pa_schema(schema) if schema else None
    return _wrap(FileTableSource(path=path, format=fmt, schema=pa_schema))


class MultiTableReader(BatchReader):
    def __init__(self, readers: list[TableReader]) -> None:
        fields = {}
        for r in readers:
            for f in r.schema:
                if f.name in fields:
                    raise ValueError(f"Field name conflict. {f.name}")
                fields[f.name] = f

        self._readers = readers
        self._schema = pa.schema(list(fields.values()))

    @property
    def schema(self) -> pa.Schema:
        return self._schema

    def read_next_batch(self) -> pa.RecordBatch:
        num_rows = -1
        columns: list[pa.ChunkedArray] = []
        for idx, r in enumerate(self._readers):
            batch = r.read_next_batch()
            if num_rows == -1:
                num_rows = batch.num_rows
            elif num_rows != batch.num_rows:
                raise ValueError(
                    f"Batch {idx} has {batch.num_rows} rows, expected {num_rows}"
                )
            columns.extend(batch.columns)
        return pa.RecordBatch.from_arrays(columns, names=self._schema.names)

    def close(self) -> None:
        for r in self._readers:
            r.close()


@table.write_p.def_impl
def write_impl(interpreter: Interpreter, op: Operation, *tables: TableValue) -> None:
    """Write table to file.

    Supported formats: parquet, csv, json

    For LazyTable, performs streaming writes when supported.
    For regular Tables, performs direct writes.
    """
    import os

    path: str = op.attrs["path"]
    format_hint: str = op.attrs.get("format", "parquet")

    fmt = _infer_format(path, format_hint)

    batch_size = DEFAULT_BATCH_SIZE if len(tables) > 1 else -1
    readers: list[TableReader] = []
    for t in tables:
        data = t.unwrap()
        readers.append(
            data.open(batch_size)
            if isinstance(data, TableSource)
            else TableReader(data)
        )

    reader: BatchReader = readers[0] if len(readers) == 1 else MultiTableReader(readers)

    import pyarrow.csv as pa_csv
    import pyarrow.parquet as pa_pq

    @runtime_checkable
    class BatchWriter(Protocol):
        def write_batch(self, batch: pa.RecordBatch) -> None: ...
        def close(self) -> None: ...

    class JsonWriter(BatchWriter):
        def __init__(self, path: str) -> None:
            self._path = path
            self._batches: list[pa.RecordBatch] = []

        def write_batch(self, batch: pa.RecordBatch) -> None:
            self._batches.append(batch)

        def close(self) -> None:
            # PyArrow doesn't have direct JSON write, convert to pandas
            tbl = pa.Table.from_batches(self._batches)
            df = tbl.to_pandas()
            df.to_json(self._path, orient="records", lines=True)

    def _safe_remove_file(path: str) -> None:
        if os.path.exists(path):
            try:
                os.remove(path)
            except Exception:
                pass  # Ignore cleanup errors

    try:
        match fmt:
            case "parquet":
                writer = pa_pq.ParquetWriter(path, reader.schema)
            case "csv":
                writer = pa_csv.CSVWriter(path, reader.schema)
            case "json":
                writer = JsonWriter(path)
            case _:
                raise ValueError(f"Unsupported format: {fmt}")
    except Exception as e:
        reader.close()
        _safe_remove_file(path)
        raise e

    try:
        for batch in reader:
            writer.write_batch(batch)
    except Exception as e:
        _safe_remove_file(path)
        raise e
    finally:
        reader.close()
        writer.close()