mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -45
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +5 -7
- mplang/v1/core/__init__.py +157 -0
- mplang/{core → v1/core}/cluster.py +30 -14
- mplang/{core → v1/core}/comm.py +5 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +13 -14
- mplang/{core → v1/core}/expr/evaluator.py +65 -24
- mplang/{core → v1/core}/expr/printer.py +24 -18
- mplang/{core → v1/core}/expr/transformer.py +3 -3
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +23 -16
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +4 -4
- mplang/{core → v1/core}/primitive.py +106 -201
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{api.py → v1/host.py} +38 -6
- mplang/v1/kernels/__init__.py +41 -0
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/v1/kernels/basic.py +240 -0
- mplang/{kernels → v1/kernels}/context.py +42 -27
- mplang/{kernels → v1/kernels}/crypto.py +44 -37
- mplang/v1/kernels/fhe.py +858 -0
- mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
- mplang/{kernels → v1/kernels}/phe.py +263 -57
- mplang/{kernels → v1/kernels}/spu.py +137 -48
- mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
- mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
- mplang/v1/kernels/value.py +626 -0
- mplang/{ops → v1/ops}/__init__.py +5 -16
- mplang/{ops → v1/ops}/base.py +2 -5
- mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
- mplang/v1/ops/crypto.py +262 -0
- mplang/v1/ops/fhe.py +272 -0
- mplang/{ops → v1/ops}/jax_cc.py +33 -68
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -4
- mplang/{ops → v1/ops}/spu.py +3 -5
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +9 -24
- mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
- mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
- mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +35 -20
- mplang/{runtime → v1/runtime}/client.py +19 -8
- mplang/{runtime → v1/runtime}/communicator.py +59 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +30 -12
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +58 -42
- mplang/{runtime → v1/runtime}/session.py +57 -71
- mplang/{runtime → v1/runtime}/simulation.py +55 -28
- mplang/v1/simp/api.py +353 -0
- mplang/{simp → v1/simp}/mpi.py +8 -9
- mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
- mplang/{simp → v1/simp}/random.py +21 -22
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/core/__init__.py +0 -92
- mplang/device.py +0 -340
- mplang/kernels/builtin.py +0 -207
- mplang/ops/crypto.py +0 -109
- mplang/ops/ibis_cc.py +0 -139
- mplang/ops/sql.py +0 -61
- mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
- mplang/runtime/link_comm.py +0 -131
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -73
- mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{kernels → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,407 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Table dialect: table operations backed by plaintext/private SQL engines."""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
from typing import Any, cast
|
|
20
|
+
|
|
21
|
+
import mplang.v2.edsl as el
|
|
22
|
+
import mplang.v2.edsl.typing as elt
|
|
23
|
+
|
|
24
|
+
# Primitive handles for the table dialect; each one maps to a graph opcode
# of the same name ("table.<op>") and carries the trace/abstract-eval rules
# registered below.
run_sql_p: el.Primitive[Any] = el.Primitive("table.run_sql")
table2tensor_p: el.Primitive[el.Object] = el.Primitive("table.table2tensor")
tensor2table_p: el.Primitive[el.Object] = el.Primitive("table.tensor2table")
constant_p: el.Primitive[el.Object] = el.Primitive("table.constant")
read_p: el.Primitive[el.Object] = el.Primitive("table.read")
write_p: el.Primitive[el.Object] = el.Primitive("table.write")
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _current_tracer() -> el.Tracer:
    """Return the active EDSL context, requiring it to be a Tracer.

    Raises:
        TypeError: If the current context is not a Tracer.
    """
    ctx = el.get_current_context()
    if isinstance(ctx, el.Tracer):
        return ctx
    raise TypeError(f"Expected Tracer context, got {type(ctx)}")
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@run_sql_p.def_trace
def _run_sql_trace(
    query: str,
    *,
    out_type: elt.TableType,
    dialect: str = "duckdb",
    **tables: el.TraceObject,
) -> el.TraceObject:
    """Trace rule for `table.run_sql`: validate inputs and emit the graph op.

    Keyword arguments in ``tables`` name the tables visible to the SQL query;
    their names are recorded in the op attrs so the kernel can bind them.
    """
    tracer = _current_tracer()
    if not isinstance(out_type, elt.TableType):
        raise TypeError("run_sql out_type must be TableType")
    if not tables:
        raise ValueError("run_sql requires at least one table input")

    # Reject non-traced inputs up front, in declaration order.
    for name, table in tables.items():
        if not isinstance(table, el.TraceObject):
            raise TypeError(f"Table '{name}' must be TraceObject")

    names = list(tables)
    inputs = [table._graph_value for table in tables.values()]

    [result] = tracer.graph.add_op(
        opcode="table.run_sql",
        inputs=inputs,
        output_types=[out_type],
        attrs={"query": query, "dialect": dialect, "table_names": names},
    )
    return el.TraceObject(result, tracer)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@table2tensor_p.def_abstract_eval
def _table2tensor_ae(table_t: elt.TableType, *, number_rows: int) -> elt.TensorType:
    """Infer tensor type for table.table2tensor.

    All columns must share one scalar dtype; the result is a
    (number_rows, n_columns) tensor of that dtype.
    """
    if not isinstance(number_rows, int):
        raise TypeError("number_rows must be an int")
    if number_rows < 0:
        raise ValueError("number_rows must be >= 0")
    if not table_t.schema:
        raise ValueError("Cannot convert empty table to tensor")

    def _scalar_dtype(col: elt.BaseType) -> elt.BaseType:
        # Columns may themselves be TensorTypes; only rank-0 (scalar)
        # columns are accepted, and we unwrap them to their element dtype.
        if not hasattr(col, "element_type"):
            return col
        if col.shape not in ((), None):  # type: ignore[attr-defined]
            raise TypeError(
                "table2tensor expects scalar columns (rank-0 TensorType)"
            )
        return col.element_type  # type: ignore[attr-defined,no-any-return]

    column_types = list(table_t.schema.values())
    first_scalar = _scalar_dtype(column_types[0])
    for other in column_types[1:]:
        if _scalar_dtype(other) != first_scalar:
            raise TypeError("All table columns must share the same scalar dtype")
    if not isinstance(first_scalar, elt.BaseType):
        raise TypeError("All table columns must share the same dtype for table2tensor")
    return elt.TensorType(first_scalar, (number_rows, len(column_types)))
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
@tensor2table_p.def_abstract_eval
def _tensor2table_ae(
    tensor_t: elt.TensorType, *, column_names: list[str]
) -> elt.TableType:
    """Infer table type for table.tensor2table.

    A rank-2 (N, F) tensor becomes a table with F columns named by
    ``column_names``; every column carries the tensor's element dtype.
    """
    if len(tensor_t.shape) != 2:
        raise TypeError(
            f"tensor2table expects rank-2 tensor (N, F), got rank {len(tensor_t.shape)}"
        )
    if not column_names:
        raise ValueError("column_names must be provided")
    if len(column_names) != tensor_t.shape[1]:
        raise ValueError("column_names length must match tensor second dimension")

    schema: dict[str, elt.BaseType] = {}
    for idx, name in enumerate(column_names):
        if not isinstance(name, str):
            raise TypeError(
                f"column_names[{idx}] must be str, got {type(name).__name__}"
            )
        if not name.strip():
            raise ValueError("column names must be non-empty/non-whitespace")
        # The schema dict doubles as the duplicate-name check.
        if name in schema:
            raise ValueError(f"duplicate column name: {name!r}")
        schema[name] = tensor_t.element_type
    return elt.TableType(schema)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def run_sql(
    query: str,
    *,
    out_type: elt.TableType,
    dialect: str = "duckdb",
    **tables: el.TraceObject,
) -> el.TraceObject:
    """Trace a SQL query over plaintext/private tables.

    Emits a `table.run_sql` op. Each keyword in ``tables`` names a table
    visible to ``query``; ``out_type`` describes the resulting table
    schema (columns and their types).
    """
    return run_sql_p.bind(query, out_type=out_type, dialect=dialect, **tables)  # type: ignore[no-any-return]
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def table2tensor(table: el.TraceObject, *, number_rows: int) -> el.Object:
    """Pack a homogeneous table's columns into a dense (number_rows, n_cols) tensor."""
    return table2tensor_p.bind(table, number_rows=number_rows)
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def tensor2table(tensor: el.TraceObject, *, column_names: list[str]) -> el.Object:
    """Split a rank-2 (N, F) tensor into a table whose columns are ``column_names``."""
    return tensor2table_p.bind(tensor, column_names=column_names)
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
@constant_p.def_abstract_eval
def _constant_ae(*, data: Any) -> elt.TableType:
    """Infer table type for constant data.

    Args:
        data: Dictionary mapping column names to lists of values,
            pandas DataFrame, PyArrow Table, or any data convertible
            to a DataFrame.

    Returns:
        TableType inferred from the data's schema.

    Raises:
        TypeError: If data cannot be converted to DataFrame.
    """
    import pandas as pd
    import pyarrow as pa

    from mplang.v2.dialects import dtypes

    # A PyArrow table exposes its schema directly — no conversion needed.
    if isinstance(data, pa.Table):
        arrow_schema: dict[str, elt.BaseType] = {
            field.name: dtypes.from_arrow(field.type) for field in data.schema
        }
        return elt.TableType(arrow_schema)

    # Everything else (dict, records, DataFrame) goes through pandas so
    # column dtypes are inferred uniformly.
    df = data if isinstance(data, pd.DataFrame) else pd.DataFrame(data)
    schema: dict[str, elt.BaseType] = {
        str(col_name): dtypes.from_pandas(df[col_name].dtype)
        for col_name in df.columns
    }
    return elt.TableType(schema)
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def constant(data: dict[str, list]) -> el.Object:
    """Embed a constant table value directly into the computation graph.

    Args:
        data: Column-name -> values mapping, a pandas DataFrame, or any
            data convertible to a DataFrame. All columns must have the
            same length.

    Returns:
        Object representing the constant table (TraceObject in trace mode,
        InterpObject in interp mode).

    Raises:
        TypeError: If data cannot be converted to DataFrame.
        ValueError: If columns have different lengths.

    Example:
        >>> tbl = constant({
        ...     "id": [1, 2, 3],
        ...     "name": ["alice", "bob", "charlie"],
        ...     "score": [95.5, 87.2, 92.8],
        ... })
        >>> import pandas as pd
        >>> tbl = constant(pd.DataFrame({"a": [1, 2], "b": [3.0, 4.0]}))
    """
    return constant_p.bind(data=data)  # type: ignore[no-any-return]
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
# =============================================================================
|
|
244
|
+
# Table I/O: read and write
|
|
245
|
+
# =============================================================================
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
@read_p.def_abstract_eval
def _read_ae(*, path: str, schema: elt.TableType, format: str) -> elt.TableType:
    """Abstract eval for table.read: validate attrs and echo the schema.

    The file cannot be inspected at trace time, so the user-supplied
    ``schema`` is returned unchanged as the output type.

    Raises:
        TypeError: If schema is not a TableType.
        ValueError: If path is empty or format is invalid.
    """
    if not isinstance(path, str) or not path:
        raise ValueError("path must be a non-empty string")
    if not isinstance(schema, elt.TableType):
        raise TypeError(f"schema must be TableType, got {type(schema).__name__}")
    if format not in {"auto", "csv", "parquet"}:
        raise ValueError(f"format must be 'auto', 'csv', or 'parquet', got {format!r}")
    return schema
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def read(
    path: str,
    *,
    schema: elt.TableType,
    format: str = "auto",
) -> el.Object:
    """Trace a `table.read` op that loads a table from *path* at runtime.

    Because the file cannot be inspected at trace/compile time, the
    expected ``schema`` must be supplied by the caller. In distributed
    scenarios each party resolves ``path`` against its own filesystem.

    Args:
        path: File path to read from.
        schema: Expected table schema. Must match the actual file structure.
        format: "auto" (detect from extension: .csv, .parquet), "csv", or
            "parquet".

    Returns:
        Table object carrying the supplied schema.

    Example:
        >>> schema = TableType({
        ...     "id": TensorType(i64, ()),
        ...     "value": TensorType(f64, ()),
        ... })
        >>> tbl = table.read("/data/input.csv", schema=schema)
    """
    return read_p.bind(path=path, schema=schema, format=format)  # type: ignore[no-any-return]
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
@write_p.def_abstract_eval
def _write_ae(in_types: list[elt.BaseType], *, path: str, format: str) -> elt.TableType:
    """Infer output type for table.write.

    Args:
        in_types: Types of the input tables (one or more TableType).
        path: File path to write to.
        format: Output format ("auto", "parquet", "csv", "json").

    Returns:
        The merged table type across all inputs (column names must be
        unique across tables).

    Raises:
        TypeError: If any input is not a TableType.
        ValueError: If no inputs are given, column names collide across
            tables, path is empty, or format is invalid.
    """
    if not in_types:
        raise ValueError(
            f"write requires at least one input table, got {len(in_types)}"
        )

    # Verify all inputs are TableType before merging their schemas.
    for i, t in enumerate(in_types):
        if not isinstance(t, elt.TableType):
            raise TypeError(f"Input {i} is not TableType: {type(t)}")

    table_types = cast(list[elt.TableType], in_types)
    columns: dict[str, elt.BaseType] = {}
    for table_type in table_types:
        for col_name in table_type.schema:
            if col_name in columns:
                raise ValueError(
                    f"Duplicate column name '{col_name}' found across tables. "
                    "When writing multiple tables, column names must be unique."
                )
        columns.update(table_type.schema)

    if not isinstance(path, str) or not path:
        raise ValueError("path must be a non-empty string")
    if format not in ("auto", "parquet", "csv", "json"):
        raise ValueError(
            f"format must be in ['auto', 'parquet', 'csv', 'json'], got {format!r}"
        )
    return elt.TableType(columns)
|
|
352
|
+
|
|
353
|
+
|
|
354
|
+
def write(
    tables: el.Object | list[el.Object] | Any,
    path: str,
    *,
    format: str = "parquet",
) -> el.Object | None:
    """Write one or more tables to a file.

    Creates a table.write operation that persists the table data at runtime.
    The operation returns the input table unchanged, allowing chaining.

    If a runtime value (e.g. PyArrow Table, DataFrame, dict) is passed
    instead of a traced object, it is wrapped with table.constant()
    automatically.

    Args:
        tables: Table(s) to write: a traced Object, a list of them, or raw
            data convertible via table.constant().
        path: Destination file path. In distributed scenarios, each party
            interprets this path relative to its own filesystem.
        format: Output format accepted by the write op ("auto", "parquet",
            "csv", "json"). Defaults to "parquet".

    Returns:
        The input table (passthrough for chaining), or None in interpreter mode.

    Example:
        >>> result = table.run_sql("SELECT ...", out_type=schema, input=tbl)
        >>> table.write(result, "/data/output.parquet")
    """
    # Normalize to a list WITHOUT mutating a caller-supplied list (the
    # previous implementation rebound elements of the caller's list in place).
    items = tables if isinstance(tables, list) else [tables]
    # Auto-wrap raw runtime values as table constants.
    wrapped = [tbl if isinstance(tbl, el.Object) else constant(tbl) for tbl in items]
    return write_p.bind(*wrapped, path=path, format=format)  # type: ignore[no-any-return]
|
|
392
|
+
|
|
393
|
+
|
|
394
|
+
# Public API: user-facing helpers plus their primitive handles
# (kept alphabetically sorted).
__all__ = [
    "constant",
    "constant_p",
    "read",
    "read_p",
    "run_sql",
    "run_sql_p",
    "table2tensor",
    "table2tensor_p",
    "tensor2table",
    "tensor2table_p",
    "write",
    "write_p",
]
|