mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff shows the differences between two publicly released versions of this package, as they appear in the supported public registries. It is provided for informational purposes only.
Files changed (188)
  1. mplang/__init__.py +21 -130
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +4 -4
  7. mplang/{core → v1/core}/__init__.py +20 -14
  8. mplang/{core → v1/core}/cluster.py +6 -1
  9. mplang/{core → v1/core}/comm.py +1 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core → v1/core}/dtypes.py +38 -0
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +11 -13
  14. mplang/{core → v1/core}/expr/evaluator.py +8 -8
  15. mplang/{core → v1/core}/expr/printer.py +6 -6
  16. mplang/{core → v1/core}/expr/transformer.py +2 -2
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +13 -11
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +2 -2
  25. mplang/{core → v1/core}/primitive.py +12 -12
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{host.py → v1/host.py} +5 -5
  30. mplang/{kernels → v1/kernels}/__init__.py +1 -1
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/{kernels → v1/kernels}/basic.py +15 -15
  33. mplang/{kernels → v1/kernels}/context.py +19 -16
  34. mplang/{kernels → v1/kernels}/crypto.py +8 -10
  35. mplang/{kernels → v1/kernels}/fhe.py +9 -7
  36. mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
  37. mplang/{kernels → v1/kernels}/phe.py +26 -18
  38. mplang/{kernels → v1/kernels}/spu.py +5 -5
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
  40. mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
  41. mplang/{kernels → v1/kernels}/value.py +2 -2
  42. mplang/{ops → v1/ops}/__init__.py +3 -3
  43. mplang/{ops → v1/ops}/base.py +1 -1
  44. mplang/{ops → v1/ops}/basic.py +6 -5
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/{ops → v1/ops}/fhe.py +2 -2
  47. mplang/{ops → v1/ops}/jax_cc.py +26 -59
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -3
  50. mplang/{ops → v1/ops}/spu.py +3 -3
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +2 -2
  53. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  54. mplang/v1/runtime/channel.py +230 -0
  55. mplang/{runtime → v1/runtime}/cli.py +3 -3
  56. mplang/{runtime → v1/runtime}/client.py +1 -1
  57. mplang/{runtime → v1/runtime}/communicator.py +39 -15
  58. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  59. mplang/{runtime → v1/runtime}/driver.py +4 -4
  60. mplang/v1/runtime/link_comm.py +196 -0
  61. mplang/{runtime → v1/runtime}/server.py +22 -9
  62. mplang/{runtime → v1/runtime}/session.py +24 -51
  63. mplang/{runtime → v1/runtime}/simulation.py +36 -14
  64. mplang/{simp → v1/simp}/api.py +72 -14
  65. mplang/{simp → v1/simp}/mpi.py +1 -1
  66. mplang/{simp → v1/simp}/party.py +5 -5
  67. mplang/{simp → v1/simp}/random.py +2 -2
  68. mplang/v1/simp/smpc.py +238 -0
  69. mplang/v1/utils/table_utils.py +185 -0
  70. mplang/v2/__init__.py +424 -0
  71. mplang/v2/backends/__init__.py +57 -0
  72. mplang/v2/backends/bfv_impl.py +705 -0
  73. mplang/v2/backends/channel.py +217 -0
  74. mplang/v2/backends/crypto_impl.py +723 -0
  75. mplang/v2/backends/field_impl.py +454 -0
  76. mplang/v2/backends/func_impl.py +107 -0
  77. mplang/v2/backends/phe_impl.py +148 -0
  78. mplang/v2/backends/simp_design.md +136 -0
  79. mplang/v2/backends/simp_driver/__init__.py +41 -0
  80. mplang/v2/backends/simp_driver/http.py +168 -0
  81. mplang/v2/backends/simp_driver/mem.py +280 -0
  82. mplang/v2/backends/simp_driver/ops.py +135 -0
  83. mplang/v2/backends/simp_driver/state.py +60 -0
  84. mplang/v2/backends/simp_driver/values.py +52 -0
  85. mplang/v2/backends/simp_worker/__init__.py +29 -0
  86. mplang/v2/backends/simp_worker/http.py +354 -0
  87. mplang/v2/backends/simp_worker/mem.py +102 -0
  88. mplang/v2/backends/simp_worker/ops.py +167 -0
  89. mplang/v2/backends/simp_worker/state.py +49 -0
  90. mplang/v2/backends/spu_impl.py +275 -0
  91. mplang/v2/backends/spu_state.py +187 -0
  92. mplang/v2/backends/store_impl.py +62 -0
  93. mplang/v2/backends/table_impl.py +838 -0
  94. mplang/v2/backends/tee_impl.py +215 -0
  95. mplang/v2/backends/tensor_impl.py +519 -0
  96. mplang/v2/cli.py +603 -0
  97. mplang/v2/cli_guide.md +122 -0
  98. mplang/v2/dialects/__init__.py +36 -0
  99. mplang/v2/dialects/bfv.py +665 -0
  100. mplang/v2/dialects/crypto.py +689 -0
  101. mplang/v2/dialects/dtypes.py +378 -0
  102. mplang/v2/dialects/field.py +210 -0
  103. mplang/v2/dialects/func.py +135 -0
  104. mplang/v2/dialects/phe.py +723 -0
  105. mplang/v2/dialects/simp.py +944 -0
  106. mplang/v2/dialects/spu.py +349 -0
  107. mplang/v2/dialects/store.py +63 -0
  108. mplang/v2/dialects/table.py +407 -0
  109. mplang/v2/dialects/tee.py +346 -0
  110. mplang/v2/dialects/tensor.py +1175 -0
  111. mplang/v2/edsl/README.md +279 -0
  112. mplang/v2/edsl/__init__.py +99 -0
  113. mplang/v2/edsl/context.py +311 -0
  114. mplang/v2/edsl/graph.py +463 -0
  115. mplang/v2/edsl/jit.py +62 -0
  116. mplang/v2/edsl/object.py +53 -0
  117. mplang/v2/edsl/primitive.py +284 -0
  118. mplang/v2/edsl/printer.py +119 -0
  119. mplang/v2/edsl/registry.py +207 -0
  120. mplang/v2/edsl/serde.py +375 -0
  121. mplang/v2/edsl/tracer.py +614 -0
  122. mplang/v2/edsl/typing.py +816 -0
  123. mplang/v2/kernels/Makefile +30 -0
  124. mplang/v2/kernels/__init__.py +23 -0
  125. mplang/v2/kernels/gf128.cpp +148 -0
  126. mplang/v2/kernels/ldpc.cpp +82 -0
  127. mplang/v2/kernels/okvs.cpp +283 -0
  128. mplang/v2/kernels/okvs_opt.cpp +291 -0
  129. mplang/v2/kernels/py_kernels.py +398 -0
  130. mplang/v2/libs/collective.py +330 -0
  131. mplang/v2/libs/device/__init__.py +51 -0
  132. mplang/v2/libs/device/api.py +813 -0
  133. mplang/v2/libs/device/cluster.py +352 -0
  134. mplang/v2/libs/ml/__init__.py +23 -0
  135. mplang/v2/libs/ml/sgb.py +1861 -0
  136. mplang/v2/libs/mpc/__init__.py +41 -0
  137. mplang/v2/libs/mpc/_utils.py +99 -0
  138. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  139. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  140. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  141. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  142. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  143. mplang/v2/libs/mpc/common/constants.py +39 -0
  144. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  145. mplang/v2/libs/mpc/ot/base.py +222 -0
  146. mplang/v2/libs/mpc/ot/extension.py +477 -0
  147. mplang/v2/libs/mpc/ot/silent.py +217 -0
  148. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  149. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  150. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  151. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  152. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  153. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  154. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  155. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  156. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  157. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  158. mplang/v2/libs/mpc/vole/silver.py +336 -0
  159. mplang/v2/runtime/__init__.py +15 -0
  160. mplang/v2/runtime/dialect_state.py +41 -0
  161. mplang/v2/runtime/interpreter.py +871 -0
  162. mplang/v2/runtime/object_store.py +194 -0
  163. mplang/v2/runtime/value.py +141 -0
  164. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
  165. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  166. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  167. mplang/device.py +0 -327
  168. mplang/ops/crypto.py +0 -108
  169. mplang/ops/ibis_cc.py +0 -136
  170. mplang/ops/sql_cc.py +0 -62
  171. mplang/runtime/link_comm.py +0 -78
  172. mplang/simp/smpc.py +0 -201
  173. mplang/utils/table_utils.py +0 -85
  174. mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
  175. /mplang/{core → v1/core}/mask.py +0 -0
  176. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  177. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
  178. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
  179. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
  180. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  181. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  182. /mplang/{simp → v1/simp}/__init__.py +0 -0
  183. /mplang/{utils → v1/utils}/__init__.py +0 -0
  184. /mplang/{utils → v1/utils}/crypto.py +0 -0
  185. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  186. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  187. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  188. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
mplang/v2/backends/table_impl.py (new file)
@@ -0,0 +1,838 @@
+ # Copyright 2025 Ant Group Co., Ltd.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Table Runtime Implementation.
+
+ Implements execution logic for Table primitives using DuckDB and PyArrow.
+ """
+
+ from __future__ import annotations
+
+ import base64
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass
+ from typing import Any, ClassVar, Protocol, Self, runtime_checkable
+
+ import duckdb
+ import pandas as pd
+ import pyarrow as pa
+
+ import mplang.v2.edsl.typing as elt
+ from mplang.v2.backends.tensor_impl import TensorValue
+ from mplang.v2.dialects import table
+ from mplang.v2.edsl import serde
+ from mplang.v2.edsl.graph import Operation
+ from mplang.v2.runtime.interpreter import Interpreter
+ from mplang.v2.runtime.value import WrapValue
+
+
+ class BatchReader(ABC):
+     @property
+     @abstractmethod
+     def schema(self) -> pa.Schema: ...
+
+     @abstractmethod
+     def read_next_batch(self) -> pa.RecordBatch: ...
+     @abstractmethod
+     def close(self) -> None: ...
+
+     def __enter__(self) -> Self:
+         return self
+
+     def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None:
+         self.close()
+
+     def __iter__(self) -> Self:
+         return self
+
+     def __next__(self) -> pa.RecordBatch:
+         return self.read_next_batch()
+
+
+ class TableReader(BatchReader):
+     """A reader for streaming table data from PyArrow RecordBatchReader or Table.
+
+     This class provides an efficient way to read large tables in batches,
+     with support for custom batch sizes and proper handling of data boundaries.
+     It implements the iterator protocol for easy consumption of data.
+     """
+
+     def __init__(
+         self,
+         data: pa.RecordBatchReader | pa.Table,
+         num_rows: int = -1,
+         batch_size: int = -1,
+     ) -> None:
+         """Initialize a TableReader.
+
+         Args:
+             data: Either a RecordBatchReader or Table to read from
+             num_rows: Expected number of rows in the data. -1 indicates unknown
+             batch_size: Size of each batch to read. -1 means use default/reader's batch size
+         """
+         # Store the underlying reader and row count based on input type
+         if isinstance(data, pa.RecordBatchReader):
+             self._reader = data
+             self._num_rows = num_rows
+         else:
+             # Convert Table to RecordBatchReader for consistent interface
+             self._reader = data.to_reader()
+             self._num_rows = data.num_rows
+
+         # Configuration for batch reading
+         self._batch_size = batch_size
+
+         # Internal state for handling custom batch sizes
+         self._remain: pa.RecordBatch | None = (
+             None  # Stores partial batch from previous read
+         )
+         self._eof = False  # Flag to indicate end of data
+
+     @property
+     def num_rows(self) -> int:
+         """Get the total number of rows in the table.
+
+         Returns:
+             Total number of rows, or -1 if unknown
+         """
+         return self._num_rows
+
+     @property
+     def schema(self) -> pa.Schema:
+         """Get the schema of the table.
+
+         Returns:
+             PyArrow Schema describing the table's columns and types
+         """
+         return self._reader.schema
+
+     def read_all(self) -> pa.Table:
+         """Read all remaining data as a Table.
+
+         This is a convenience method that reads all data from the reader
+         and returns it as a single PyArrow Table.
+
+         Returns:
+             Complete table containing all remaining data
+         """
+         return self._reader.read_all()
+
+     def read_next_batch(self) -> pa.RecordBatch:
+         """Read the next batch of records.
+
+         This method respects the configured batch size. If the native reader
+         returns batches larger than the configured size, this method will split
+         them appropriately. Any partial data from previous reads is included
+         in the returned batch.
+
+         Returns:
+             Next RecordBatch of data
+
+         Raises:
+             StopIteration: When no more data is available
+         """
+         # Check if we've reached end of file
+         if self._eof:
+             raise StopIteration
+
+         # Get the next batch using internal logic
+         batch = self._read_next_batch()
+
+         # Handle end of data
+         if batch is None:
+             self._eof = True
+             raise StopIteration
+
+         return batch
+
+     def _read_next_batch(self) -> pa.RecordBatch | None:
+         """Internal method to read and process the next batch.
+
+         This method handles the complex logic of:
+         - Using default batch size when none is specified
+         - Accumulating data from multiple native batches to reach the target size
+         - Splitting oversized batches and saving the remainder
+         - Converting between Table and RecordBatch formats as needed
+
+         Returns:
+             Next RecordBatch of the configured size, or None if no more data
+         """
+         # If no batch size specified, just return the reader's native batches
+         if self._batch_size <= 0:
+             try:
+                 batch = self._reader.read_next_batch()
+                 # Convert to RecordBatch if the reader returns a Table
+                 if isinstance(batch, pa.Table) and batch.num_rows > 0:
+                     return batch.to_batches()[0]
+                 return batch
+             except StopIteration:
+                 return None
+
+         # We have a custom batch size - need to accumulate/split batches
+         batches: list[pa.RecordBatch] = []
+         num_rows: int = 0
+
+         # First, include any remaining data from the previous read
+         if self._remain is not None:
+             num_rows = self._remain.num_rows
+             batches = [self._remain]
+             self._remain = None
+
+         # Keep reading until we have enough rows or run out of data
+         while num_rows < self._batch_size:
+             try:
+                 batch = self._reader.read_next_batch()
+
+                 # Handle the case where reader returns a Table instead of RecordBatch
+                 if isinstance(batch, pa.Table):
+                     if batch.num_rows > 0:
+                         # Convert each batch from the Table
+                         for rb in batch.to_batches():
+                             num_rows += rb.num_rows
+                             if rb.num_rows > 0:  # Skip empty batches
+                                 batches.append(rb)
+                 else:
+                     # Already a RecordBatch
+                     num_rows += batch.num_rows
+                     if batch.num_rows > 0:  # Skip empty batches
+                         batches.append(batch)
+             except StopIteration:
+                 # Mark EOF but continue processing what we have
+                 self._eof = True
+                 break
+
+         # If we didn't get any data, return None
+         if num_rows == 0:
+             return None
+
+         # Split the last batch if we have more rows than needed
+         if num_rows > self._batch_size:
+             last = batches[-1]
+             remain_size = num_rows - self._batch_size
+             last_size = last.num_rows - remain_size
+
+             # Keep only what we need from the last batch
+             batches[-1] = last.slice(0, last_size)
+             # Save the remainder for the next read
+             self._remain = last.slice(last_size, remain_size)
+
+         # Optimized path: if we only have one batch, return it directly
+         if len(batches) == 1:
+             return batches[0]
+
+         # Otherwise, combine all batches and return as a single RecordBatch
+         combined = pa.Table.from_batches(batches)
+         return combined.to_batches()[0]
+
+     def close(self) -> None:
+         """Close the reader and release all resources.
+
+         This method should be called when the reader is no longer needed.
+         It closes the underlying reader and clears internal state.
+         """
+         # Close the underlying reader
+         self._reader.close()
+         # Clear internal state
+         self._remain = None
+         self._eof = False
+
+
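For reference, a minimal sketch of how the streaming reader above can be driven with a custom batch size; the import path mplang.v2.backends.table_impl is assumed from the file path in this diff.

import pyarrow as pa

from mplang.v2.backends.table_impl import TableReader  # module path assumed from this diff

# Re-chunk a 10-row in-memory table into batches of at most 4 rows.
source = pa.table({"x": list(range(10))})
with TableReader(source, batch_size=4) as reader:
    sizes = [batch.num_rows for batch in reader]

print(sizes)  # expected [4, 4, 2]: full batches first, remainder last

The context-manager and iterator behaviour come from BatchReader, so exhaustion surfaces as StopIteration and close() releases the wrapped RecordBatchReader.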
+ DEFAULT_BATCH_SIZE = 1_000_000
+
+
+ class TableSource(ABC):
+     """Abstract base class for lazy table operations.
+
+     Provides deferred execution for table operations to prevent OOM issues.
+     """
+
+     @abstractmethod
+     def register(
+         self, conn: duckdb.DuckDBPyConnection, name: str, replace: bool = True
+     ) -> None: ...
+
+     @abstractmethod
+     def open(self, batch_size: int = DEFAULT_BATCH_SIZE) -> TableReader:
+         """Read data as a stream of record batches."""
+         ...
+
+
+ class ParquetReader(pa.RecordBatchReader):
+     """A reader that implements the pa.RecordBatchReader interface for Parquet files."""
+
+     def __init__(self, source: Any, columns: list[str] | None = None):
+         import pyarrow.parquet as pq
+
+         file = pq.ParquetFile(source)
+
+         # Use schema_arrow to get the proper pa.Schema
+         if columns:
+             # Filter the schema to only include selected columns
+             fields = [
+                 file.schema_arrow.field(col)
+                 for col in columns
+                 if col in file.schema_arrow.names
+             ]
+             schema = pa.schema(fields)
+         else:
+             schema = file.schema_arrow
+
+         self._file = file
+         self._schema = schema
+         self._cast = False
+         self._num_rows = int(file.metadata.num_rows)
+         self._iter = file.iter_batches(columns=columns)
+
+     @property
+     def num_rows(self) -> int:
+         return self._num_rows
+
+     @property
+     def schema(self) -> pa.Schema:
+         return self._schema
+
+     def cast(self, target_schema: pa.Schema) -> ParquetReader:
+         # Validate that the number of columns is the same
+         if len(target_schema) != len(self._schema):
+             raise ValueError(
+                 f"Cannot cast schema: target schema has {len(target_schema)} columns, "
+                 f"but current schema has {len(self._schema)} columns"
+             )
+
+         # Check if there are any changes in the schema
+         schema_changed = False
+         for i, (target_field, current_field) in enumerate(
+             zip(target_schema, self._schema, strict=True)
+         ):
+             # Check if field names are the same (allowing type changes)
+             if target_field.name != current_field.name:
+                 raise ValueError(
+                     f"Cannot cast schema: field name at position {i} differs. "
+                     f"Current: '{current_field.name}', Target: '{target_field.name}'. "
+                     f"Field names must match."
+                 )
+             # Check if types are different
+             if target_field.type != current_field.type:
+                 schema_changed = True
+
+         # Only set _cast if there are actual changes
+         if schema_changed:
+             self._schema = target_schema
+             self._cast = True
+
+         return self
+
+     def read_all(self) -> pa.Table:
+         batches = []
+         try:
+             while True:
+                 batch = self.read_next_batch()
+                 batches.append(batch)
+         except StopIteration:
+             pass
+         if batches:
+             return pa.Table.from_batches(batches)
+         return pa.Table.from_batches([])
+
+     def read_next_batch(self) -> pa.RecordBatch:
+         batch = next(self._iter)
+         if self._cast:
+             return batch.cast(self._schema)
+         else:
+             return batch
+
+     def close(self) -> None:
+         """Close the Parquet reader and release resources."""
+         self._file.close()
+
+
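ParquetReader above is a thin wrapper over pyarrow.parquet.ParquetFile; a minimal sketch of the underlying PyArrow calls it builds on (the file path is a placeholder):

import pyarrow as pa
import pyarrow.parquet as pq

# Write a tiny file so the sketch is self-contained ("demo.parquet" is a placeholder path).
pq.write_table(pa.table({"a": [1, 2, 3], "b": ["x", "y", "z"]}), "demo.parquet")

pf = pq.ParquetFile("demo.parquet")
print(pf.schema_arrow)        # the pa.Schema used above for column filtering
print(pf.metadata.num_rows)   # row count exposed as ParquetReader.num_rows
for rb in pf.iter_batches(columns=["a"]):  # streaming, column-pruned batches
    print(rb.num_rows, rb.schema.names)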
+ _type_mapping = {
+     elt.bool_: pa.bool_(),
+     elt.i8: pa.int8(),
+     elt.i16: pa.int16(),
+     elt.i32: pa.int32(),
+     elt.i64: pa.int64(),
+     elt.u8: pa.uint8(),
+     elt.u16: pa.uint16(),
+     elt.u32: pa.uint32(),
+     elt.u64: pa.uint64(),
+     elt.f16: pa.float16(),
+     elt.f32: pa.float32(),
+     elt.f64: pa.float64(),
+     elt.STRING: pa.string(),
+     elt.DATE: pa.date64(),
+     elt.TIME: pa.time32("ms"),
+     elt.TIMESTAMP: pa.timestamp("ms"),
+     elt.DECIMAL: pa.decimal128(38, 10),
+     elt.BINARY: pa.binary(),
+     elt.JSON: pa.json_(),
+ }
+
+
+ def _pa_schema(s: elt.TableType) -> pa.Schema:
+     fields = []
+     for k, v in s.schema.items():
+         if v not in _type_mapping:
+             raise ValueError(f"cannot convert to pyarrow type. type={v}, name={k}")
+         fields.append(pa.field(k, _type_mapping[v]))
+
+     return pa.schema(fields)
+
+
+ @dataclass
+ class FileTableSource(TableSource):
+     """Lazy table handle for file-based operations with streaming reads."""
+
+     path: str
+     format: str
+     schema: pa.Schema | None = None
+
+     def register(
+         self, conn: duckdb.DuckDBPyConnection, name: str, replace: bool = True
+     ) -> None:
+         """Register the file as a view in DuckDB."""
+         func_name = ""
+         match self.format:
+             case "parquet":
+                 func_name = "read_parquet"
+             case "csv":
+                 func_name = "read_csv_auto"
+             case "json":
+                 func_name = "read_json_auto"
+             case _:
+                 raise ValueError(f"Unsupported format: {self.format}")
+
+         safe_path = self.path.replace("'", "''")
+         base_query = f"SELECT * FROM {func_name}('{safe_path}')"
+         if replace:
+             query = f"CREATE OR REPLACE VIEW {name} AS {base_query}"
+         else:
+             query = f"CREATE VIEW {name} AS {base_query}"
+         conn.execute(query)
+
+     def open(self, batch_size: int = DEFAULT_BATCH_SIZE) -> TableReader:
+         """Create a streaming reader for the file."""
+         import pyarrow.csv as pa_csv
+         import pyarrow.json as pa_json
+
+         columns = self.schema.names if self.schema else None
+
+         reader = None
+         num_rows = -1
+         match self.format.lower():
+             case "parquet":
+                 reader = ParquetReader(self.path, columns)
+                 num_rows = reader.num_rows
+             case "csv":
+                 read_options = pa_csv.ReadOptions(use_threads=True)
+                 convert_options = pa_csv.ConvertOptions(
+                     column_types=self.schema,
+                     include_columns=columns,
+                 )
+                 reader = pa_csv.open_csv(
+                     self.path,
+                     read_options=read_options,
+                     convert_options=convert_options,
+                 )
+             case "json":
+                 read_options = pa_json.ReadOptions(use_threads=True)
+                 table = pa_json.read_json(self.path, read_options=read_options)
+                 if columns:
+                     table = table.select(columns)
+                 reader = table.to_reader()
+                 num_rows = table.num_rows
+             case _:
+                 raise ValueError(f"Unsupported format: {self.format}")
+
+         if self.schema and self.schema != reader.schema:
+             reader = reader.cast(self.schema)
+
+         return TableReader(reader, num_rows=num_rows, batch_size=batch_size)
+
+
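A sketch of the two ways a FileTableSource can be consumed, either registered as a DuckDB view or opened as a batch stream; the module path is again assumed from this diff and users.csv is a placeholder file.

import duckdb

from mplang.v2.backends.table_impl import FileTableSource  # module path assumed from this diff

src = FileTableSource(path="users.csv", format="csv")  # placeholder file

# 1) Register as a DuckDB view (CREATE OR REPLACE VIEW ... AS SELECT * FROM read_csv_auto(...)).
conn = duckdb.connect()
src.register(conn, "users")
print(conn.sql("SELECT COUNT(*) FROM users").fetchall())

# 2) Stream it as Arrow record batches via TableReader.
with src.open(batch_size=100_000) as reader:
    for batch in reader:
        print(batch.num_rows)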
+ class DuckDBState:
+     def __init__(self, conn: duckdb.DuckDBPyConnection) -> None:
+         self.conn = conn
+         self.tables: dict[str, Any] = {}
+
+
+ @dataclass(frozen=True)
+ class QueryTableSource(TableSource):
+     """Handle for existing DuckDB relations (kept for compatibility)."""
+
+     relation: duckdb.DuckDBPyRelation
+     state: DuckDBState
+
+     def register(
+         self, conn: duckdb.DuckDBPyConnection, name: str, replace: bool = True
+     ) -> None:
+         self.relation.create_view(name, replace)
+
+     def open(self, batch_size: int = DEFAULT_BATCH_SIZE) -> TableReader:
+         """Read from the DuckDB relation."""
+         if batch_size <= 0:
+             batch_size = DEFAULT_BATCH_SIZE
+         reader = self.relation.arrow(batch_size)
+         return TableReader(reader)
+
+
+ # =============================================================================
+ # TableValue Wrapper
+ # =============================================================================
+
+
+ @serde.register_class
+ class TableValue(WrapValue[pa.Table | TableSource]):
+     """Runtime value wrapping a PyArrow Table.
+
+     Provides serialization via Arrow IPC format (streaming).
+     Future: may extend to support other backends (Polars, DuckDB relations, etc.)
+     """
+
+     _serde_kind: ClassVar[str] = "table_impl.TableValue"
+
+     @property
+     def data(self) -> pa.Table:
+         """Get the underlying PyArrow Table data.
+
+         For lazy TableSource, this triggers a full read of the data and caches
+         the result in self._data. Subsequent calls will return the cached table.
+
+         Returns:
+             The PyArrow Table containing all data
+         """
+         if isinstance(self._data, TableSource):
+             source = self._data
+             with source.open() as reader:
+                 self._data = reader.read_all()
+
+         return self._data
+
+     # =========== Wrap/Unwrap ===========
+
+     def _convert(self, data: Any) -> pa.Table | TableSource:
+         """Convert input data to pa.Table or TableSource."""
+         if isinstance(data, TableValue):
+             return data.unwrap()
+         if isinstance(data, pd.DataFrame):
+             data = pa.Table.from_pandas(data)
+         if not isinstance(data, pa.Table | TableSource):
+             raise TypeError(f"Expected pa.Table or TableSource, got {type(data)}")
+         return data
+
+     # =========== Serialization ===========
+
+     def to_json(self) -> dict[str, Any]:
+         # Serialize using Arrow IPC streaming format
+         data = self.data
+         sink = pa.BufferOutputStream()
+         with pa.ipc.new_stream(sink, data.schema) as writer:
+             writer.write_table(data)
+         ipc_bytes = sink.getvalue().to_pybytes()
+         return {"ipc": base64.b64encode(ipc_bytes).decode("ascii")}
+
+     @classmethod
+     def from_json(cls, data: dict[str, Any]) -> TableValue:
+         ipc_bytes = base64.b64decode(data["ipc"])
+         reader = pa.ipc.open_stream(ipc_bytes)
+         table = reader.read_all()
+         return cls(table)
+
+
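to_json and from_json above lean on the Arrow IPC streaming format; stripped of the WrapValue machinery, the round trip they perform is roughly this plain-PyArrow sketch:

import base64

import pyarrow as pa

tbl = pa.table({"id": [1, 2], "name": ["a", "b"]})

# Serialize: Arrow IPC stream -> bytes -> base64 text (as in TableValue.to_json).
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, tbl.schema) as writer:
    writer.write_table(tbl)
encoded = base64.b64encode(sink.getvalue().to_pybytes()).decode("ascii")

# Deserialize: base64 text -> bytes -> Arrow table (as in TableValue.from_json).
decoded = pa.ipc.open_stream(base64.b64decode(encoded)).read_all()
assert decoded.equals(tbl)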
+ # Module-level helpers for convenience (delegate to class methods)
+ def _wrap(val: pa.Table | pd.DataFrame | TableSource | TableValue) -> TableValue:
+     """Wrap a table-like value into TableValue."""
+     return TableValue.wrap(val)
+
+
+ def _unwrap(val: TableValue | pa.Table | pd.DataFrame) -> pa.Table:
+     """Unwrap TableValue to pa.Table, also accepts raw pa.Table/DataFrame."""
+     if isinstance(val, TableValue):
+         return val.data
+     if isinstance(val, pd.DataFrame):
+         return pa.Table.from_pandas(val)
+     if isinstance(val, pa.Table):
+         return val
+     # Handle RecordBatchReader from newer PyArrow versions
+     if isinstance(val, pa.RecordBatchReader):
+         return val.read_all()
+     raise TypeError(
+         f"Expected TableValue, pa.Table, pd.DataFrame, or RecordBatchReader, got {type(val)}"
+     )
+
+
+ # =============================================================================
+ # Table Primitive Implementations
+ # =============================================================================
+
+
+ @table.run_sql_p.def_impl
+ def run_sql_impl(interpreter: Interpreter, op: Operation, *args: Any) -> TableValue:
+     """Execute SQL query on input tables."""
+     query = op.attrs["query"]
+     dialect = op.attrs.get("dialect", "duckdb")
+     table_names = op.attrs["table_names"]
+
+     if dialect != "duckdb":
+         raise ValueError(f"Unsupported dialect: {dialect}")
+
+     state: DuckDBState | None = None
+     tables: list[TableValue] = []
+     for arg in args:
+         tbl = _wrap(arg)
+         tables.append(tbl)
+         data = tbl.unwrap()
+         if isinstance(data, QueryTableSource):
+             if state is None:
+                 state = data.state
+             elif state != data.state:
+                 raise ValueError("All tables must belong to the same DuckDB connection")
+
+     if state is None:
+         conn = duckdb.connect()
+         state = DuckDBState(conn)
+
+     try:
+         conn = state.conn
+         # register tables or create view
+         for name, tbl in zip(table_names, tables, strict=True):
+             data = tbl.unwrap()
+             if name in state.tables:
+                 if state.tables[name] is not data:
+                     # TODO: rename and rewrite sql??
+                     raise ValueError(f"{name} has been registered.")
+             else:
+                 state.tables[name] = data
+                 if isinstance(data, TableSource):
+                     data.register(state.conn, name)
+                 else:
+                     conn.register(name, data)
+
+         relation = conn.sql(query)
+         return _wrap(QueryTableSource(relation, state))
+     except Exception as e:
+         raise RuntimeError(f"Failed to execute SQL query: {query}") from e
+
+
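The SQL path above reduces to the standard DuckDB-on-Arrow pattern: register each input, run the query, and hold on to the lazy relation. A standalone sketch of that pattern without the mplang plumbing:

import duckdb
import pyarrow as pa

orders = pa.table({"user": ["a", "a", "b"], "amount": [10, 5, 7]})

conn = duckdb.connect()
conn.register("orders", orders)  # register an Arrow table under a view name
rel = conn.sql("SELECT user, SUM(amount) AS total FROM orders GROUP BY user")

# QueryTableSource defers this step; materializing pulls the result back as Arrow.
print(rel.arrow())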
+ @table.table2tensor_p.def_impl
+ def table2tensor_impl(interpreter: Interpreter, op: Operation, table_val: Any) -> Any:
+     """Convert table to tensor (numpy array).
+
+     Returns TensorValue if tensor_impl is available, otherwise raw np.ndarray.
+     """
+     from mplang.v2.backends.tensor_impl import TensorValue
+
+     tbl = _unwrap(table_val)
+     df = tbl.to_pandas()
+     # Convert to numpy array
+     # Note: This assumes the table is homogeneous as enforced by abstract_eval
+     arr = df.to_numpy()
+     return TensorValue.wrap(arr)
+
+
+ @table.tensor2table_p.def_impl
+ def tensor2table_impl(
+     interpreter: Interpreter, op: Operation, tensor_val: TensorValue
+ ) -> TableValue:
+     """Convert tensor (numpy array) to table."""
+     column_names = op.attrs["column_names"]
+
+     # Unwrap TensorValue
+     arr = tensor_val.unwrap()
+
+     if arr.ndim != 2:
+         raise ValueError(f"Expected 2D array, got {arr.ndim}D")
+
+     if arr.shape[1] != len(column_names):
+         raise ValueError(
+             f"Shape mismatch: tensor has {arr.shape[1]} columns, "
+             f"but {len(column_names)} names provided"
+         )
+
+     # Create dictionary for DataFrame/Table creation
+     data = {}
+     for i, name in enumerate(column_names):
+         data[name] = arr[:, i]
+
+     return _wrap(pa.Table.from_pydict(data))
+
+
+ @table.constant_p.def_impl
+ def constant_impl(interpreter: Interpreter, op: Operation) -> TableValue:
+     """Create constant table."""
+     # data is stored in attrs by default bind if not TraceObject
+     data = op.attrs["data"]
+
+     # Handle pandas DataFrame if passed directly (though attrs usually store basic types)
+     # If data was a DataFrame, it might have been stored as is if the IR supports it.
+     # If data was a dict, it's fine.
+
+     if isinstance(data, pa.Table):
+         return _wrap(data)
+     else:
+         return _wrap(pa.table(data))
+
+
+ def _infer_format(path: str, format_hint: str) -> str:
+     """Infer file format from path extension or hint."""
+     if format_hint != "auto":
+         return format_hint
+
+     path_lower = path.lower()
+     if path_lower.endswith((".parquet", ".pq")):
+         return "parquet"
+     elif path_lower.endswith(".csv"):
+         return "csv"
+     elif path_lower.endswith((".json", ".jsonl")):
+         return "json"
+     else:
+         # Default to parquet
+         return "parquet"
+
+
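The inference above is purely extension based, with parquet as the fallback, and an explicit hint short-circuits it; for example (hypothetical calls, assuming _infer_format is imported from this module):

assert _infer_format("data/users.csv", "auto") == "csv"
assert _infer_format("data/events.jsonl", "auto") == "json"
assert _infer_format("data/blob.bin", "auto") == "parquet"      # unknown extension -> default
assert _infer_format("data/users.csv", "parquet") == "parquet"  # explicit hint wins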
+ @table.read_p.def_impl
+ def read_impl(interpreter: Interpreter, op: Operation) -> TableValue:
+     """Read table from file.
+
+     Supported formats: parquet, csv, json
+     """
+     import os
+
+     path: str = op.attrs["path"]
+     schema: elt.TableType = op.attrs["schema"]
+     format_hint: str = op.attrs.get("format", "auto")
+     fmt = _infer_format(path, format_hint)
+     if not os.path.exists(path):
+         raise FileNotFoundError(f"{path} not exists")
+
+     pa_schema = _pa_schema(schema) if schema else None
+     return _wrap(FileTableSource(path=path, format=fmt, schema=pa_schema))
+
+
+ class MultiTableReader(BatchReader):
+     def __init__(self, readers: list[TableReader]) -> None:
+         fields = {}
+         for r in readers:
+             for f in r.schema:
+                 if f.name in fields:
+                     raise ValueError(f"Field name conflict. {f.name}")
+                 fields[f.name] = f
+
+         self._readers = readers
+         self._schema = pa.schema(list(fields.values()))
+
+     @property
+     def schema(self) -> pa.Schema:
+         return self._schema
+
+     def read_next_batch(self) -> pa.RecordBatch:
+         num_rows = -1
+         columns: list[pa.ChunkedArray] = []
+         for idx, r in enumerate(self._readers):
+             batch = r.read_next_batch()
+             if num_rows == -1:
+                 num_rows = batch.num_rows
+             elif num_rows != batch.num_rows:
+                 raise ValueError(
+                     f"Batch {idx} has {batch.num_rows} rows, expected {num_rows}"
+                 )
+             columns.extend(batch.columns)
+         return pa.RecordBatch.from_arrays(columns, names=self._schema.names)
+
+     def close(self) -> None:
+         for r in self._readers:
+             r.close()
+
+
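MultiTableReader above stitches equal-length batches together column-wise; the core PyArrow operation is just RecordBatch.from_arrays over the concatenated columns, e.g.:

import pyarrow as pa

left = pa.RecordBatch.from_pydict({"id": [1, 2, 3]})
right = pa.RecordBatch.from_pydict({"score": [0.1, 0.2, 0.3]})

# Column-wise concatenation of two batches with matching row counts.
merged = pa.RecordBatch.from_arrays(
    list(left.columns) + list(right.columns),
    names=["id", "score"],
)
print(merged.num_rows, merged.schema.names)  # 3 ['id', 'score']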
+ @table.write_p.def_impl
+ def write_impl(interpreter: Interpreter, op: Operation, *tables: TableValue) -> None:
+     """Write table to file.
+
+     Supported formats: parquet, csv, json
+
+     For LazyTable, performs streaming writes when supported.
+     For regular Tables, performs direct writes.
+     """
+     import os
+
+     path: str = op.attrs["path"]
+     format_hint: str = op.attrs.get("format", "parquet")
+
+     fmt = _infer_format(path, format_hint)
+
+     batch_size = DEFAULT_BATCH_SIZE if len(tables) > 1 else -1
+     readers: list[TableReader] = []
+     for t in tables:
+         data = t.unwrap()
+         readers.append(
+             data.open(batch_size)
+             if isinstance(data, TableSource)
+             else TableReader(data)
+         )
+
+     reader: BatchReader = readers[0] if len(readers) == 1 else MultiTableReader(readers)
+
+     import pyarrow.csv as pa_csv
+     import pyarrow.parquet as pa_pq
+
+     @runtime_checkable
+     class BatchWriter(Protocol):
+         def write_batch(self, batch: pa.RecordBatch) -> None: ...
+         def close(self) -> None: ...
+
+     class JsonWriter(BatchWriter):
+         def __init__(self, path: str) -> None:
+             self._path = path
+             self._batches: list[pa.RecordBatch] = []
+
+         def write_batch(self, batch: pa.RecordBatch) -> None:
+             self._batches.append(batch)
+
+         def close(self) -> None:
+             # PyArrow doesn't have direct JSON write, convert to pandas
+             tbl = pa.Table.from_batches(self._batches)
+             df = tbl.to_pandas()
+             df.to_json(self._path, orient="records", lines=True)
+
+     def _safe_remove_file(path: str) -> None:
+         if os.path.exists(path):
+             try:
+                 os.remove(path)
+             except Exception:
+                 pass  # Ignore cleanup errors
+
+     try:
+         match fmt:
+             case "parquet":
+                 writer = pa_pq.ParquetWriter(path, reader.schema)
+             case "csv":
+                 writer = pa_csv.CSVWriter(path, reader.schema)
+             case "json":
+                 writer = JsonWriter(path)
+             case _:
+                 raise ValueError(f"Unsupported format: {fmt}")
+     except Exception as e:
+         reader.close()
+         _safe_remove_file(path)
+         raise e
+
+     try:
+         for batch in reader:
+             writer.write_batch(batch)
+     except Exception as e:
+         _safe_remove_file(path)
+         raise e
+     finally:
+         reader.close()
+         writer.close()
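Finally, write_impl above follows the standard streaming ParquetWriter loop (open with the reader's schema, copy batch by batch, close); a condensed standalone sketch with a placeholder output path:

import pyarrow as pa
import pyarrow.parquet as pq

source = pa.table({"x": list(range(1_000))}).to_reader(max_chunksize=256)

# Open the writer with the reader's schema, then copy batch by batch.
with pq.ParquetWriter("out.parquet", source.schema) as writer:
    for batch in source:
        writer.write_batch(batch)

print(pq.read_table("out.parquet").num_rows)  # 1000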