mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. mplang/__init__.py +21 -130
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +4 -4
  7. mplang/{core → v1/core}/__init__.py +20 -14
  8. mplang/{core → v1/core}/cluster.py +6 -1
  9. mplang/{core → v1/core}/comm.py +1 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core → v1/core}/dtypes.py +38 -0
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +11 -13
  14. mplang/{core → v1/core}/expr/evaluator.py +8 -8
  15. mplang/{core → v1/core}/expr/printer.py +6 -6
  16. mplang/{core → v1/core}/expr/transformer.py +2 -2
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +13 -11
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +2 -2
  25. mplang/{core → v1/core}/primitive.py +12 -12
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{host.py → v1/host.py} +5 -5
  30. mplang/{kernels → v1/kernels}/__init__.py +1 -1
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/{kernels → v1/kernels}/basic.py +15 -15
  33. mplang/{kernels → v1/kernels}/context.py +19 -16
  34. mplang/{kernels → v1/kernels}/crypto.py +8 -10
  35. mplang/{kernels → v1/kernels}/fhe.py +9 -7
  36. mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
  37. mplang/{kernels → v1/kernels}/phe.py +26 -18
  38. mplang/{kernels → v1/kernels}/spu.py +5 -5
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
  40. mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
  41. mplang/{kernels → v1/kernels}/value.py +2 -2
  42. mplang/{ops → v1/ops}/__init__.py +3 -3
  43. mplang/{ops → v1/ops}/base.py +1 -1
  44. mplang/{ops → v1/ops}/basic.py +6 -5
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/{ops → v1/ops}/fhe.py +2 -2
  47. mplang/{ops → v1/ops}/jax_cc.py +26 -59
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -3
  50. mplang/{ops → v1/ops}/spu.py +3 -3
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +2 -2
  53. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  54. mplang/v1/runtime/channel.py +230 -0
  55. mplang/{runtime → v1/runtime}/cli.py +3 -3
  56. mplang/{runtime → v1/runtime}/client.py +1 -1
  57. mplang/{runtime → v1/runtime}/communicator.py +39 -15
  58. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  59. mplang/{runtime → v1/runtime}/driver.py +4 -4
  60. mplang/v1/runtime/link_comm.py +196 -0
  61. mplang/{runtime → v1/runtime}/server.py +22 -9
  62. mplang/{runtime → v1/runtime}/session.py +24 -51
  63. mplang/{runtime → v1/runtime}/simulation.py +36 -14
  64. mplang/{simp → v1/simp}/api.py +72 -14
  65. mplang/{simp → v1/simp}/mpi.py +1 -1
  66. mplang/{simp → v1/simp}/party.py +5 -5
  67. mplang/{simp → v1/simp}/random.py +2 -2
  68. mplang/v1/simp/smpc.py +238 -0
  69. mplang/v1/utils/table_utils.py +185 -0
  70. mplang/v2/__init__.py +424 -0
  71. mplang/v2/backends/__init__.py +57 -0
  72. mplang/v2/backends/bfv_impl.py +705 -0
  73. mplang/v2/backends/channel.py +217 -0
  74. mplang/v2/backends/crypto_impl.py +723 -0
  75. mplang/v2/backends/field_impl.py +454 -0
  76. mplang/v2/backends/func_impl.py +107 -0
  77. mplang/v2/backends/phe_impl.py +148 -0
  78. mplang/v2/backends/simp_design.md +136 -0
  79. mplang/v2/backends/simp_driver/__init__.py +41 -0
  80. mplang/v2/backends/simp_driver/http.py +168 -0
  81. mplang/v2/backends/simp_driver/mem.py +280 -0
  82. mplang/v2/backends/simp_driver/ops.py +135 -0
  83. mplang/v2/backends/simp_driver/state.py +60 -0
  84. mplang/v2/backends/simp_driver/values.py +52 -0
  85. mplang/v2/backends/simp_worker/__init__.py +29 -0
  86. mplang/v2/backends/simp_worker/http.py +354 -0
  87. mplang/v2/backends/simp_worker/mem.py +102 -0
  88. mplang/v2/backends/simp_worker/ops.py +167 -0
  89. mplang/v2/backends/simp_worker/state.py +49 -0
  90. mplang/v2/backends/spu_impl.py +275 -0
  91. mplang/v2/backends/spu_state.py +187 -0
  92. mplang/v2/backends/store_impl.py +62 -0
  93. mplang/v2/backends/table_impl.py +838 -0
  94. mplang/v2/backends/tee_impl.py +215 -0
  95. mplang/v2/backends/tensor_impl.py +519 -0
  96. mplang/v2/cli.py +603 -0
  97. mplang/v2/cli_guide.md +122 -0
  98. mplang/v2/dialects/__init__.py +36 -0
  99. mplang/v2/dialects/bfv.py +665 -0
  100. mplang/v2/dialects/crypto.py +689 -0
  101. mplang/v2/dialects/dtypes.py +378 -0
  102. mplang/v2/dialects/field.py +210 -0
  103. mplang/v2/dialects/func.py +135 -0
  104. mplang/v2/dialects/phe.py +723 -0
  105. mplang/v2/dialects/simp.py +944 -0
  106. mplang/v2/dialects/spu.py +349 -0
  107. mplang/v2/dialects/store.py +63 -0
  108. mplang/v2/dialects/table.py +407 -0
  109. mplang/v2/dialects/tee.py +346 -0
  110. mplang/v2/dialects/tensor.py +1175 -0
  111. mplang/v2/edsl/README.md +279 -0
  112. mplang/v2/edsl/__init__.py +99 -0
  113. mplang/v2/edsl/context.py +311 -0
  114. mplang/v2/edsl/graph.py +463 -0
  115. mplang/v2/edsl/jit.py +62 -0
  116. mplang/v2/edsl/object.py +53 -0
  117. mplang/v2/edsl/primitive.py +284 -0
  118. mplang/v2/edsl/printer.py +119 -0
  119. mplang/v2/edsl/registry.py +207 -0
  120. mplang/v2/edsl/serde.py +375 -0
  121. mplang/v2/edsl/tracer.py +614 -0
  122. mplang/v2/edsl/typing.py +816 -0
  123. mplang/v2/kernels/Makefile +30 -0
  124. mplang/v2/kernels/__init__.py +23 -0
  125. mplang/v2/kernels/gf128.cpp +148 -0
  126. mplang/v2/kernels/ldpc.cpp +82 -0
  127. mplang/v2/kernels/okvs.cpp +283 -0
  128. mplang/v2/kernels/okvs_opt.cpp +291 -0
  129. mplang/v2/kernels/py_kernels.py +398 -0
  130. mplang/v2/libs/collective.py +330 -0
  131. mplang/v2/libs/device/__init__.py +51 -0
  132. mplang/v2/libs/device/api.py +813 -0
  133. mplang/v2/libs/device/cluster.py +352 -0
  134. mplang/v2/libs/ml/__init__.py +23 -0
  135. mplang/v2/libs/ml/sgb.py +1861 -0
  136. mplang/v2/libs/mpc/__init__.py +41 -0
  137. mplang/v2/libs/mpc/_utils.py +99 -0
  138. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  139. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  140. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  141. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  142. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  143. mplang/v2/libs/mpc/common/constants.py +39 -0
  144. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  145. mplang/v2/libs/mpc/ot/base.py +222 -0
  146. mplang/v2/libs/mpc/ot/extension.py +477 -0
  147. mplang/v2/libs/mpc/ot/silent.py +217 -0
  148. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  149. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  150. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  151. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  152. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  153. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  154. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  155. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  156. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  157. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  158. mplang/v2/libs/mpc/vole/silver.py +336 -0
  159. mplang/v2/runtime/__init__.py +15 -0
  160. mplang/v2/runtime/dialect_state.py +41 -0
  161. mplang/v2/runtime/interpreter.py +871 -0
  162. mplang/v2/runtime/object_store.py +194 -0
  163. mplang/v2/runtime/value.py +141 -0
  164. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
  165. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  166. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  167. mplang/device.py +0 -327
  168. mplang/ops/crypto.py +0 -108
  169. mplang/ops/ibis_cc.py +0 -136
  170. mplang/ops/sql_cc.py +0 -62
  171. mplang/runtime/link_comm.py +0 -78
  172. mplang/simp/smpc.py +0 -201
  173. mplang/utils/table_utils.py +0 -85
  174. mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
  175. /mplang/{core → v1/core}/mask.py +0 -0
  176. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  177. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
  178. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
  179. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
  180. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  181. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  182. /mplang/{simp → v1/simp}/__init__.py +0 -0
  183. /mplang/{utils → v1/utils}/__init__.py +0 -0
  184. /mplang/{utils → v1/utils}/crypto.py +0 -0
  185. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  186. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  187. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  188. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,407 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Table dialect: table operations backed by plaintext/private SQL engines."""
16
+
17
+ from __future__ import annotations
18
+
19
+ from typing import Any, cast
20
+
21
+ import mplang.v2.edsl as el
22
+ import mplang.v2.edsl.typing as elt
23
+
24
# Primitive handles for the table dialect. Each one is constructed with its
# dialect-qualified opcode name ("table.<op>").
# run_sql_p is typed Primitive[Any]; the remaining primitives bind to el.Object.
run_sql_p: el.Primitive[Any] = el.Primitive("table.run_sql")
table2tensor_p: el.Primitive[el.Object] = el.Primitive("table.table2tensor")
tensor2table_p: el.Primitive[el.Object] = el.Primitive("table.tensor2table")
constant_p: el.Primitive[el.Object] = el.Primitive("table.constant")
read_p: el.Primitive[el.Object] = el.Primitive("table.read")
write_p: el.Primitive[el.Object] = el.Primitive("table.write")
30
+
31
+
32
def _current_tracer() -> el.Tracer:
    """Return the active context, which must be an ``el.Tracer``.

    Raises:
        TypeError: If the current context is not an ``el.Tracer``.
    """
    current = el.get_current_context()
    if isinstance(current, el.Tracer):
        return current
    raise TypeError(f"Expected Tracer context, got {type(current)}")
37
+
38
+
39
@run_sql_p.def_trace
def _run_sql_trace(
    query: str,
    *,
    out_type: elt.TableType,
    dialect: str = "duckdb",
    **tables: el.TraceObject,
) -> el.TraceObject:
    """Trace rule for ``table.run_sql``: emit a single graph op for the query.

    Args:
        query: SQL text, stored verbatim in the op's ``query`` attribute.
        out_type: Declared type of the query result; used as the op's only
            output type.
        dialect: SQL dialect name, stored in the op's ``dialect`` attribute.
        **tables: Input tables keyed by name. Keyword order fixes both the
            op's input order and the ``table_names`` attribute.

    Returns:
        A TraceObject wrapping the op's single output value.

    Raises:
        TypeError: If no Tracer is active, ``out_type`` is not a TableType,
            or any table argument is not a TraceObject.
        ValueError: If no tables are supplied.
    """
    tracer = _current_tracer()
    if not isinstance(out_type, elt.TableType):
        raise TypeError("run_sql out_type must be TableType")
    if not tables:
        raise ValueError("run_sql requires at least one table input")

    # All inputs are checked before the op is added to the graph.
    table_names: list[str] = []
    graph_inputs = []
    for table_name, tbl in tables.items():
        if not isinstance(tbl, el.TraceObject):
            raise TypeError(f"Table '{table_name}' must be TraceObject")
        table_names.append(table_name)
        graph_inputs.append(tbl._graph_value)

    op_attrs = {"query": query, "dialect": dialect, "table_names": table_names}
    (result,) = tracer.graph.add_op(
        opcode="table.run_sql",
        inputs=graph_inputs,
        output_types=[out_type],
        attrs=op_attrs,
    )
    return el.TraceObject(result, tracer)
69
+
70
+
71
@table2tensor_p.def_abstract_eval
def _table2tensor_ae(table_t: elt.TableType, *, number_rows: int) -> elt.TensorType:
    """Infer the tensor type produced by ``table.table2tensor``.

    Every column must reduce to the same scalar dtype. For a column exposing
    ``element_type``, that dtype is its element type, and its shape must be
    ``()`` or ``None``. Any other column is used as the dtype directly.

    Args:
        table_t: Type of the input table.
        number_rows: Row count of the resulting tensor. Must be a non-negative
            ``int``; ``bool`` is rejected even though it subclasses ``int``.

    Returns:
        ``TensorType(dtype, (number_rows, number_of_columns))``.

    Raises:
        TypeError: If ``number_rows`` is not an int, a column is a non-scalar
            tensor, column dtypes differ, or the shared dtype is not a BaseType.
        ValueError: If ``number_rows`` is negative or the table has no columns.
    """
    # True/False would otherwise be accepted as a row count of 1/0.
    if isinstance(number_rows, bool) or not isinstance(number_rows, int):
        raise TypeError("number_rows must be an int")
    if number_rows < 0:
        raise ValueError("number_rows must be >= 0")
    if not table_t.schema:
        raise ValueError("Cannot convert empty table to tensor")

    def _scalar_dtype(col: elt.BaseType) -> elt.BaseType:
        # Tensor-like columns (anything with ``element_type``) must be rank-0;
        # their element dtype is what ends up in the output tensor.
        if hasattr(col, "element_type"):
            if col.shape not in ((), None):  # type: ignore[attr-defined]
                raise TypeError(
                    "table2tensor expects scalar columns (rank-0 TensorType)"
                )
            return col.element_type  # type: ignore[attr-defined,no-any-return]
        return col

    columns = list(table_t.schema.items())
    first_name, first_col = columns[0]
    dtype = _scalar_dtype(first_col)
    for col_name, col in columns[1:]:
        col_dtype = _scalar_dtype(col)
        if col_dtype != dtype:
            # Name both columns so the mismatch is easy to locate.
            raise TypeError(
                "All table columns must share the same scalar dtype: "
                f"column {first_name!r} is {dtype}, column {col_name!r} is {col_dtype}"
            )
    if not isinstance(dtype, elt.BaseType):
        raise TypeError(
            "table2tensor column dtype must be a BaseType, "
            f"got {type(dtype).__name__}"
        )
    return elt.TensorType(dtype, (number_rows, len(columns)))
101
+
102
+
103
@tensor2table_p.def_abstract_eval
def _tensor2table_ae(
    tensor_t: elt.TensorType, *, column_names: list[str]
) -> elt.TableType:
    """Infer the table type produced by ``table.tensor2table``.

    A rank-2 tensor of shape ``(N, F)`` becomes a table with ``F`` columns.
    Every column is typed with the tensor's ``element_type``.

    Args:
        tensor_t: Type of the input tensor; must have a known rank of 2.
        column_names: One distinct, non-blank name per tensor column, in order.

    Returns:
        ``TableType`` mapping each column name to ``tensor_t.element_type``.

    Raises:
        TypeError: If the tensor is not rank-2 (including an unknown ``None``
            shape), ``column_names`` is a bare string, or a name is not a str.
        ValueError: If ``column_names`` is empty, its length differs from the
            tensor's second dimension, or a name is blank or duplicated.
    """
    shape = tensor_t.shape
    # Guard against an unknown (None) shape so the caller gets a clear
    # TypeError instead of ``len(None)`` failing.
    if shape is None or len(shape) != 2:
        rank = "unknown" if shape is None else len(shape)
        raise TypeError(
            f"tensor2table expects rank-2 tensor (N, F), got rank {rank}"
        )
    n_cols = shape[1]
    if not column_names:
        raise ValueError("column_names must be provided")
    # A bare string is a sequence of characters and would otherwise be
    # silently split into one-letter column names.
    if isinstance(column_names, str):
        raise TypeError("column_names must be a list of str, not a single str")
    if len(column_names) != n_cols:
        raise ValueError("column_names length must match tensor second dimension")

    schema: dict[str, elt.BaseType] = {}
    for idx, name in enumerate(column_names):
        if not isinstance(name, str):
            raise TypeError(
                f"column_names[{idx}] must be str, got {type(name).__name__}"
            )
        if name.strip() == "":
            raise ValueError("column names must be non-empty/non-whitespace")
        if name in schema:
            raise ValueError(f"duplicate column name: {name!r}")
        # Each column shares the tensor's element dtype.
        schema[name] = tensor_t.element_type
    return elt.TableType(schema)
133
+
134
+
135
def run_sql(
    query: str,
    *,
    out_type: elt.TableType,
    dialect: str = "duckdb",
    **tables: el.TraceObject,
) -> el.TraceObject:
    """Trace a SQL query over one or more named tables.

    Binds the ``table.run_sql`` primitive with ``query`` as its only
    positional argument. ``out_type``, ``dialect`` and the named input
    tables are forwarded as keyword arguments.

    Args:
        query: SQL text to run.
        out_type: Schema (columns and types) of the query result.
        dialect: SQL dialect name; defaults to "duckdb".
        **tables: Input tables, keyed by the name used for them in the query.

    Returns:
        The traced result table.
    """
    # ``out_type``/``dialect``/``query`` are named parameters, so ``tables``
    # can never contain those keys; merging the dict is collision-free.
    bind_kwargs: dict[str, Any] = {"out_type": out_type, "dialect": dialect}
    bind_kwargs.update(tables)
    return run_sql_p.bind(query, **bind_kwargs)  # type: ignore[no-any-return]
154
+
155
+
156
def table2tensor(table: el.TraceObject, *, number_rows: int) -> el.Object:
    """Convert a table whose columns share one scalar dtype into a dense tensor.

    Args:
        table: Input table.
        number_rows: Row count of the resulting tensor.

    Returns:
        Tensor object of shape ``(number_rows, number_of_columns)``.
    """
    primitive = table2tensor_p
    return primitive.bind(table, number_rows=number_rows)
160
+
161
+
162
def tensor2table(tensor: el.TraceObject, *, column_names: list[str]) -> el.Object:
    """Convert a rank-2 tensor ``(N, F)`` into a table with ``F`` named columns.

    Args:
        tensor: Input tensor.
        column_names: One name per tensor column, in column order.

    Returns:
        Table object whose columns all use the tensor's element dtype.
    """
    primitive = tensor2table_p
    return primitive.bind(tensor, column_names=column_names)
166
+
167
+
168
@constant_p.def_abstract_eval
def _constant_ae(*, data: Any) -> elt.TableType:
    """Infer the table type of a ``table.constant`` literal.

    Args:
        data: A PyArrow Table, a pandas DataFrame, or anything accepted by
            ``pd.DataFrame(...)`` (e.g. a dict mapping column names to lists).

    Returns:
        TableType whose columns are, in order, the data's columns. Types come
        from ``dtypes.from_arrow`` (PyArrow input) or ``dtypes.from_pandas``
        (everything else).

    Raises:
        ValueError: If two columns share a name. Non-string pandas column
            labels are converted with ``str()``, so e.g. ``1`` and ``"1"``
            collide.
        Exception: Any error raised by ``pd.DataFrame(data)`` propagates
            unchanged (e.g. ValueError for columns of different lengths).
    """
    import pandas as pd
    import pyarrow as pa

    from mplang.v2.dialects import dtypes

    if isinstance(data, pa.Table):
        fields = [(field.name, dtypes.from_arrow(field.type)) for field in data.schema]
    else:
        df = data if isinstance(data, pd.DataFrame) else pd.DataFrame(data)
        # df.dtypes has one entry per column, even when labels repeat, whereas
        # df[label] returns a DataFrame (with no .dtype) for repeated labels.
        fields = [
            (str(col_name), dtypes.from_pandas(col_dtype))
            for col_name, col_dtype in df.dtypes.items()
        ]

    schema: dict[str, elt.BaseType] = {}
    for name, col_type in fields:
        # A repeated name would silently drop a column from the schema.
        if name in schema:
            raise ValueError(f"duplicate column name in table constant: {name!r}")
        schema[name] = col_type
    return elt.TableType(schema)
207
+
208
+
209
def constant(data: Any) -> el.Object:
    """Create a table constant value.

    The constant is embedded directly into the computation graph and can be
    used like any other table.

    Args:
        data: A dictionary mapping column names to lists of values, a pandas
            DataFrame, a PyArrow Table, or any data convertible to a
            DataFrame. All columns must have the same length.

    Returns:
        Object representing the constant table (TraceObject in trace mode,
        InterpObject in interp mode).

    Raises:
        TypeError: If data cannot be converted to DataFrame.
        ValueError: If columns have different lengths.

    Example:
        >>> # From dict
        >>> table = constant({
        ...     "id": [1, 2, 3],
        ...     "name": ["alice", "bob", "charlie"],
        ...     "score": [95.5, 87.2, 92.8],
        ... })
        >>> # From DataFrame
        >>> import pandas as pd
        >>> df = pd.DataFrame({"a": [1, 2], "b": [3.0, 4.0]})
        >>> table = constant(df)
    """
    # The annotation is ``Any`` (not ``dict[str, list]``) because DataFrames,
    # PyArrow Tables and other DataFrame-convertible values are accepted too.
    return constant_p.bind(data=data)  # type: ignore[no-any-return]
241
+
242
+
243
+ # =============================================================================
244
+ # Table I/O: read and write
245
+ # =============================================================================
246
+
247
+
248
@read_p.def_abstract_eval
def _read_ae(*, path: str, schema: elt.TableType, format: str) -> elt.TableType:
    """Infer the output type of ``table.read``.

    The file is not inspected at trace time, so the caller-declared schema is
    validated and returned as-is.

    Args:
        path: Source file path; must be a non-empty string.
        schema: Declared table type of the file contents.
        format: One of "auto", "csv" or "parquet".

    Returns:
        ``schema`` unchanged.

    Raises:
        ValueError: If ``path`` is not a non-empty string or ``format`` is
            not supported.
        TypeError: If ``schema`` is not a TableType.
    """
    path_is_valid = isinstance(path, str) and bool(path)
    if not path_is_valid:
        raise ValueError("path must be a non-empty string")
    if not isinstance(schema, elt.TableType):
        raise TypeError(f"schema must be TableType, got {type(schema).__name__}")
    if format in ("auto", "csv", "parquet"):
        return schema
    raise ValueError(f"format must be 'auto', 'csv', or 'parquet', got {format!r}")
271
+
272
+
273
def read(
    path: str,
    *,
    schema: elt.TableType,
    format: str = "auto",
) -> el.Object:
    """Read a table from a file at runtime.

    Emits a ``table.read`` operation. The file cannot be inspected while
    tracing, so its schema must be declared up front.

    Args:
        path: File to read. In distributed scenarios, each party interprets
            this path relative to its own filesystem.
        schema: Declared table schema; must match the actual file contents.
        format: "auto" (detect from the .csv / .parquet extension), "csv",
            or "parquet".

    Returns:
        Table object typed with ``schema``.

    Example:
        >>> schema = TableType({
        ...     "id": TensorType(i64, ()),
        ...     "value": TensorType(f64, ()),
        ... })
        >>> tbl = table.read("/data/input.csv", schema=schema)
    """
    result = read_p.bind(path=path, schema=schema, format=format)
    return result  # type: ignore[no-any-return]
305
+
306
+
307
@write_p.def_abstract_eval
def _write_ae(in_types: list[elt.BaseType], *, path: str, format: str) -> elt.TableType:
    """Infer the output type of ``table.write``.

    One or more tables may be written together. Their columns are combined,
    in input order, into a single schema, so column names must be unique
    across all inputs.

    Args:
        in_types: Types of the input tables (at least one).
        path: Destination file path; must be a non-empty string.
        format: One of "auto", "parquet", "csv" or "json".

    Returns:
        A TableType whose schema combines the schemas of all inputs. For a
        single input it has the same schema as that input.

    Raises:
        ValueError: If there are no inputs, column names collide across
            inputs, ``path`` is not a non-empty string, or ``format`` is
            not supported.
        TypeError: If any input is not a TableType.
    """
    if not in_types:
        raise ValueError(
            f"write requires at least one input table, got {len(in_types)}"
        )

    for i, t in enumerate(in_types):
        if not isinstance(t, elt.TableType):
            raise TypeError(f"Input {i} is not TableType: {type(t)}")

    columns: dict[str, elt.BaseType] = {}
    for table_type in cast(list[elt.TableType], in_types):
        for col_name in table_type.schema:
            if col_name in columns:
                raise ValueError(
                    f"Duplicate column name '{col_name}' found across tables. "
                    "When writing multiple tables, column names must be unique."
                )
        columns.update(table_type.schema)

    if not isinstance(path, str) or not path:
        raise ValueError("path must be a non-empty string")
    if format not in ("auto", "parquet", "csv", "json"):
        raise ValueError(
            f"format must be in ['auto', 'parquet', 'csv', 'json'], got {format!r}"
        )
    return elt.TableType(columns)
352
+
353
+
354
def write(
    tables: el.Object | list[el.Object] | Any,
    path: str,
    *,
    format: str = "parquet",
) -> el.Object | None:
    """Write one or more tables to a file.

    Emits a ``table.write`` operation that persists the data at runtime.
    Any input that is not already an ``el.Object`` (e.g. a PyArrow Table,
    DataFrame or dict) is wrapped with ``constant()`` first.

    Args:
        tables: A single table or a list of tables. Each entry may be an
            ``el.Object`` or a runtime value accepted by ``constant()``.
            When several tables are given, column names must be unique
            across all of them.
        path: Destination file path. In distributed scenarios, each party
            interprets this path relative to its own filesystem.
        format: "parquet" (default), "csv", "json" or "auto".

    Returns:
        The written table (for several inputs, a table whose schema combines
        all inputs), or None in interpreter mode.

    Example:
        >>> result = table.run_sql("SELECT ...", out_type=schema, input=tbl)
        >>> table.write(result, "/data/output.parquet")
    """
    items = tables if isinstance(tables, list) else [tables]
    # Build a new list rather than assigning into ``items``, so a list passed
    # by the caller is never mutated by the constant() wrapping.
    inputs = [tbl if isinstance(tbl, el.Object) else constant(tbl) for tbl in items]
    return write_p.bind(*inputs, path=path, format=format)  # type: ignore[no-any-return]
392
+
393
+
394
# Public API of the table dialect: each user-facing helper is listed next to
# the primitive that backs it.
__all__ = [
    "constant",
    "constant_p",
    "read",
    "read_p",
    "run_sql",
    "run_sql_p",
    "table2tensor",
    "table2tensor_p",
    "tensor2table",
    "tensor2table_p",
    "write",
    "write_p",
]