mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev271__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +391 -17
- mplang/{v2/backends → backends}/__init__.py +9 -7
- mplang/{v2/backends → backends}/bfv_impl.py +6 -6
- mplang/{v2/backends → backends}/crypto_impl.py +6 -6
- mplang/{v2/backends → backends}/field_impl.py +5 -5
- mplang/{v2/backends → backends}/func_impl.py +4 -4
- mplang/{v2/backends → backends}/phe_impl.py +3 -3
- mplang/{v2/backends → backends}/simp_design.md +1 -1
- mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
- mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
- mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
- mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
- mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
- mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
- mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
- mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
- mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
- mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
- mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
- mplang/{v2/backends → backends}/spu_impl.py +8 -8
- mplang/{v2/backends → backends}/spu_state.py +4 -4
- mplang/{v2/backends → backends}/store_impl.py +3 -3
- mplang/{v2/backends → backends}/table_impl.py +8 -8
- mplang/{v2/backends → backends}/tee_impl.py +6 -6
- mplang/{v2/backends → backends}/tensor_impl.py +6 -6
- mplang/{v2/cli.py → cli.py} +9 -9
- mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
- mplang/{v2/dialects → dialects}/__init__.py +5 -5
- mplang/{v2/dialects → dialects}/bfv.py +6 -6
- mplang/{v2/dialects → dialects}/crypto.py +5 -5
- mplang/{v2/dialects → dialects}/dtypes.py +2 -2
- mplang/{v2/dialects → dialects}/field.py +3 -3
- mplang/{v2/dialects → dialects}/func.py +2 -2
- mplang/{v2/dialects → dialects}/phe.py +6 -6
- mplang/{v2/dialects → dialects}/simp.py +6 -6
- mplang/{v2/dialects → dialects}/spu.py +7 -7
- mplang/{v2/dialects → dialects}/store.py +2 -2
- mplang/{v2/dialects → dialects}/table.py +3 -3
- mplang/{v2/dialects → dialects}/tee.py +6 -6
- mplang/{v2/dialects → dialects}/tensor.py +5 -5
- mplang/{v2/edsl → edsl}/__init__.py +3 -3
- mplang/{v2/edsl → edsl}/context.py +6 -6
- mplang/{v2/edsl → edsl}/graph.py +5 -5
- mplang/{v2/edsl → edsl}/jit.py +2 -2
- mplang/{v2/edsl → edsl}/object.py +1 -1
- mplang/{v2/edsl → edsl}/primitive.py +5 -5
- mplang/{v2/edsl → edsl}/printer.py +1 -1
- mplang/{v2/edsl → edsl}/serde.py +1 -1
- mplang/{v2/edsl → edsl}/tracer.py +7 -7
- mplang/{v2/edsl → edsl}/typing.py +1 -1
- mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
- mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
- mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
- mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
- mplang/{v2/libs → libs}/collective.py +5 -5
- mplang/{v2/libs → libs}/device/__init__.py +1 -1
- mplang/{v2/libs → libs}/device/api.py +12 -12
- mplang/{v2/libs → libs}/ml/__init__.py +1 -1
- mplang/{v2/libs → libs}/ml/sgb.py +4 -4
- mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
- mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
- mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
- mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
- mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
- mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
- mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
- mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
- mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
- mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
- mplang/{v2/runtime → runtime}/interpreter.py +11 -11
- mplang/{v2/runtime → runtime}/value.py +2 -2
- mplang/{v1/runtime → utils}/__init__.py +18 -15
- mplang/{v1/utils → utils}/func_utils.py +1 -1
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/METADATA +2 -2
- mplang_nightly-0.1.dev271.dist-info/RECORD +102 -0
- mplang/v1/__init__.py +0 -157
- mplang/v1/_device.py +0 -602
- mplang/v1/analysis/__init__.py +0 -37
- mplang/v1/analysis/diagram.py +0 -567
- mplang/v1/core/__init__.py +0 -157
- mplang/v1/core/cluster.py +0 -343
- mplang/v1/core/comm.py +0 -281
- mplang/v1/core/context_mgr.py +0 -50
- mplang/v1/core/dtypes.py +0 -335
- mplang/v1/core/expr/__init__.py +0 -80
- mplang/v1/core/expr/ast.py +0 -542
- mplang/v1/core/expr/evaluator.py +0 -581
- mplang/v1/core/expr/printer.py +0 -285
- mplang/v1/core/expr/transformer.py +0 -141
- mplang/v1/core/expr/utils.py +0 -78
- mplang/v1/core/expr/visitor.py +0 -85
- mplang/v1/core/expr/walk.py +0 -387
- mplang/v1/core/interp.py +0 -160
- mplang/v1/core/mask.py +0 -325
- mplang/v1/core/mpir.py +0 -965
- mplang/v1/core/mpobject.py +0 -117
- mplang/v1/core/mptype.py +0 -407
- mplang/v1/core/pfunc.py +0 -130
- mplang/v1/core/primitive.py +0 -877
- mplang/v1/core/table.py +0 -218
- mplang/v1/core/tensor.py +0 -75
- mplang/v1/core/tracer.py +0 -383
- mplang/v1/host.py +0 -130
- mplang/v1/kernels/__init__.py +0 -41
- mplang/v1/kernels/base.py +0 -125
- mplang/v1/kernels/basic.py +0 -240
- mplang/v1/kernels/context.py +0 -369
- mplang/v1/kernels/crypto.py +0 -122
- mplang/v1/kernels/fhe.py +0 -858
- mplang/v1/kernels/mock_tee.py +0 -72
- mplang/v1/kernels/phe.py +0 -1864
- mplang/v1/kernels/spu.py +0 -341
- mplang/v1/kernels/sql_duckdb.py +0 -44
- mplang/v1/kernels/stablehlo.py +0 -90
- mplang/v1/kernels/value.py +0 -626
- mplang/v1/ops/__init__.py +0 -35
- mplang/v1/ops/base.py +0 -424
- mplang/v1/ops/basic.py +0 -294
- mplang/v1/ops/crypto.py +0 -262
- mplang/v1/ops/fhe.py +0 -272
- mplang/v1/ops/jax_cc.py +0 -147
- mplang/v1/ops/nnx_cc.py +0 -168
- mplang/v1/ops/phe.py +0 -216
- mplang/v1/ops/spu.py +0 -151
- mplang/v1/ops/sql_cc.py +0 -303
- mplang/v1/ops/tee.py +0 -36
- mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
- mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
- mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
- mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
- mplang/v1/runtime/channel.py +0 -230
- mplang/v1/runtime/cli.py +0 -451
- mplang/v1/runtime/client.py +0 -456
- mplang/v1/runtime/communicator.py +0 -131
- mplang/v1/runtime/data_providers.py +0 -303
- mplang/v1/runtime/driver.py +0 -324
- mplang/v1/runtime/exceptions.py +0 -27
- mplang/v1/runtime/http_api.md +0 -56
- mplang/v1/runtime/link_comm.py +0 -196
- mplang/v1/runtime/server.py +0 -501
- mplang/v1/runtime/session.py +0 -270
- mplang/v1/runtime/simulation.py +0 -324
- mplang/v1/simp/__init__.py +0 -13
- mplang/v1/simp/api.py +0 -353
- mplang/v1/simp/mpi.py +0 -131
- mplang/v1/simp/party.py +0 -225
- mplang/v1/simp/random.py +0 -120
- mplang/v1/simp/smpc.py +0 -238
- mplang/v1/utils/__init__.py +0 -13
- mplang/v1/utils/crypto.py +0 -32
- mplang/v1/utils/spu_utils.py +0 -130
- mplang/v1/utils/table_utils.py +0 -185
- mplang/v2/__init__.py +0 -424
- mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
- /mplang/{v2/backends → backends}/channel.py +0 -0
- /mplang/{v2/edsl → edsl}/README.md +0 -0
- /mplang/{v2/edsl → edsl}/registry.py +0 -0
- /mplang/{v2/kernels → kernels}/Makefile +0 -0
- /mplang/{v2/kernels → kernels}/__init__.py +0 -0
- /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
- /mplang/{v2/libs → libs}/device/cluster.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
- /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
- /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
- /mplang/{v2/runtime → runtime}/object_store.py +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/licenses/LICENSE +0 -0
mplang/v1/ops/phe.py
DELETED
|
@@ -1,216 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
"""PHE (Partially Homomorphic Encryption) frontend operations."""
|
|
16
|
-
|
|
17
|
-
from mplang.v1.core import UINT8, TensorType
|
|
18
|
-
from mplang.v1.ops.base import stateless_mod
|
|
19
|
-
|
|
20
|
-
# Op module for the "phe" namespace; every PHE frontend op below is declared
# through its `@_PHE_MOD.simple_op()` decorator.
_PHE_MOD = stateless_mod("phe")
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
@_PHE_MOD.simple_op()
def keygen(
    *,
    scheme: str = "paillier",
    key_size: int = 2048,
    max_value: int | None = None,
    fxp_bits: int | None = None,
) -> tuple[TensorType, TensorType]:
    """Type rule for PHE key generation; yields ``(public_key, private_key)`` types.

    Both keys are described by the same sentinel type ``UINT8[(-1, 0)]``,
    which marks them as opaque, backend-only handles rather than structured
    tensors. Runtime validation is expected to treat this shape as a
    placeholder and skip dtype/shape checks.

    None of the keyword attributes is inspected here; they are forwarded to
    the backend as-is.

    Args:
        scheme: PHE scheme name (default ``"paillier"``).
        key_size: Modulus size in bits (default 2048).
        max_value: Optional range-encoding bound ``B``. When given, the backend
            encodes/decodes integers and floats within ``[-B, B]`` and treats
            ``(B, N-B)`` as overflow. Choose ``B`` above the largest
            intermediate magnitude expected in homomorphic combinations. When
            omitted, the backend default applies (documented as ``2**32``).
        fxp_bits: Optional number of fixed-point fractional bits for float
            encoding; ``None`` means the backend default.

    Returns:
        A pair of identical sentinel TensorTypes (public key, private key).
    """
    opaque_handle = TensorType(UINT8, (-1, 0))
    public_key_type = private_key_type = opaque_handle
    return public_key_type, private_key_type
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
@_PHE_MOD.simple_op()
def encrypt(plaintext: TensorType, public_key: TensorType) -> TensorType:
    """Type rule for PHE encryption.

    The ciphertext is described by exactly the plaintext's TensorType (same
    shape and dtype). The public key plays no part in the result type.
    """
    del public_key  # only needed by the backend, not for type inference
    return plaintext
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
@_PHE_MOD.simple_op()
def add(operand1: TensorType, operand2: TensorType) -> TensorType:
    """Type rule for PHE addition.

    The result is typed like ``operand1``; ``operand2`` is not checked here.
    How the operands actually combine depends on the backend representation.
    """
    del operand2  # does not influence the result type
    return operand1
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
@_PHE_MOD.simple_op()
def mul(ciphertext: TensorType, plaintext: TensorType) -> TensorType:
    """Type rule for multiplying a PHE ciphertext by a plaintext value.

    The result keeps the ciphertext's type (including its dtype).

    Raises:
        ValueError: If ``plaintext`` has a floating-point dtype.
    """
    plaintext_is_float = plaintext.dtype.is_floating
    if not plaintext_is_float:
        return ciphertext
    raise ValueError(
        "PHE multiplication does not support floating-point plaintext."
    )
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
@_PHE_MOD.simple_op()
def decrypt(ciphertext: TensorType, private_key: TensorType) -> TensorType:
    """Type rule for PHE decryption.

    The recovered plaintext is described by the ciphertext's TensorType
    unchanged. The private key does not affect the result type.
    """
    del private_key  # only needed by the backend, not for type inference
    return ciphertext
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
@_PHE_MOD.simple_op()
def dot(ciphertext: TensorType, plaintext: TensorType) -> TensorType:
    """Type rule for the dot product of a ciphertext with a plaintext.

    The result shape follows ``numpy.dot``:

    * either operand is 0-d: plain scaling, so the result takes the other
      operand's shape;
    * ``plaintext`` is 1-D: the last axis of ``ciphertext`` is contracted
      with it (both 1-D gives a scalar ``()``);
    * otherwise: the last axis of ``ciphertext`` is contracted with the
      second-to-last axis of ``plaintext``.

    The shape is computed arithmetically. No stand-in arrays are allocated,
    so inferring the type of a large product costs O(rank), not a real
    matrix multiplication.

    Args:
        ciphertext: The ciphertext operand (first argument).
        plaintext: The plaintext operand (second argument).

    Returns:
        TensorType with the ciphertext's dtype and the numpy-dot result shape.

    Raises:
        ValueError: If a dimension is negative or the contracted dimensions
            do not match (mirroring numpy's errors).
    """
    ct_shape = tuple(ciphertext.shape)
    pt_shape = tuple(plaintext.shape)
    if any(d < 0 for d in ct_shape + pt_shape):
        raise ValueError("negative dimensions are not allowed")

    # Scalar operand: numpy.dot degenerates to elementwise multiplication.
    if not ct_shape:
        return TensorType(ciphertext.dtype, pt_shape)
    if not pt_shape:
        return TensorType(ciphertext.dtype, ct_shape)

    # Axis of ``plaintext`` that is summed against the last ciphertext axis.
    contract_axis = 0 if len(pt_shape) == 1 else len(pt_shape) - 2
    if ct_shape[-1] != pt_shape[contract_axis]:
        raise ValueError(
            f"shapes {ct_shape} and {pt_shape} not aligned: "
            f"{ct_shape[-1]} (dim {len(ct_shape) - 1}) != "
            f"{pt_shape[contract_axis]} (dim {contract_axis})"
        )

    result_shape = (
        ct_shape[:-1] + pt_shape[:contract_axis] + pt_shape[contract_axis + 1 :]
    )
    return TensorType(ciphertext.dtype, result_shape)
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
@_PHE_MOD.simple_op()
def gather(ciphertext: TensorType, indices: TensorType, *, axis: int = 0) -> TensorType:
    """Type rule for gathering ciphertext elements by index.

    The result shape replaces dimension ``axis`` of the ciphertext with the
    whole shape of ``indices``:
    ``ct.shape[:axis] + indices.shape + ct.shape[axis + 1:]``. This is the
    same shape rule as ``numpy.take``.

    Args:
        ciphertext: The ciphertext to gather from.
        indices: The indices to gather; only its shape is used here.
        axis: The axis along which to gather (default: 0). Negative values
            count from the last axis.

    Returns:
        TensorType with the ciphertext's dtype and the gathered shape.

    Raises:
        ValueError: If ``axis`` is outside ``[-ndim, ndim)`` for the
            ciphertext's rank ``ndim``. This includes every axis for a 0-d
            ciphertext. Previously such axes silently produced a wrong shape.
    """
    ct_shape = tuple(ciphertext.shape)
    ndim = len(ct_shape)
    if not -ndim <= axis < ndim:
        raise ValueError(
            f"axis {axis} is out of bounds for ciphertext of rank {ndim}"
        )

    # Normalize negative axis (safe: ndim > 0 after the bounds check).
    normalized_axis = axis % ndim

    result_shape = (
        ct_shape[:normalized_axis]
        + tuple(indices.shape)
        + ct_shape[normalized_axis + 1 :]
    )
    return TensorType(ciphertext.dtype, result_shape)
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
@_PHE_MOD.simple_op()
def scatter(
    ciphertext: TensorType,
    indices: TensorType,
    updates: TensorType,
    *,
    axis: int = 0,
) -> TensorType:
    """Type rule for scattering ciphertext updates at the given indices.

    Neither ``indices``, ``updates`` nor ``axis`` is inspected here. The
    result is typed exactly like the input ciphertext (same shape and dtype).

    Args:
        ciphertext: The ciphertext to scatter into.
        indices: The indices to scatter at.
        updates: The ciphertext updates to scatter.
        axis: The axis along which to scatter (default: 0).

    Returns:
        The ciphertext's TensorType, unchanged.
    """
    result_type = ciphertext
    return result_type
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
@_PHE_MOD.simple_op()
def concat(operand0: TensorType, operand1: TensorType, *, axis: int = 0) -> TensorType:
    """Type rule for concatenating two ciphertext tensors along an axis.

    Shapes follow ``numpy.concatenate``: both operands must have the same
    rank and agree on every dimension except ``axis``, whose sizes are
    summed. ``axis=None`` (also accepted by numpy) flattens both operands
    before joining them.

    The shape is computed arithmetically, so no zero-filled stand-in arrays
    the size of both operands are allocated and copied.

    Args:
        operand0: The first ciphertext operand to concatenate.
        operand1: The second ciphertext operand to concatenate.
        axis: Axis along which to concatenate. Negative values count from
            the last axis.

    Returns:
        TensorType with the shared dtype and the concatenated shape.

    Raises:
        ValueError: On dtype mismatch, negative dimensions, 0-d operands,
            rank mismatch, an out-of-range ``axis``, or a mismatch in any
            non-concatenation dimension.
    """
    import math

    # All operands should have same dtype
    dtype = operand0.dtype
    if operand1.dtype != dtype:
        raise ValueError("All operands must have the same dtype for concatenation")

    lhs = tuple(operand0.shape)
    rhs = tuple(operand1.shape)
    if any(d < 0 for d in lhs + rhs):
        raise ValueError("negative dimensions are not allowed")

    if axis is None:
        # numpy flattens every operand first, so only total sizes matter.
        return TensorType(dtype, (math.prod(lhs) + math.prod(rhs),))

    ndim = len(lhs)
    if ndim == 0 or not rhs:
        raise ValueError("zero-dimensional arrays cannot be concatenated")
    if len(rhs) != ndim:
        raise ValueError(
            "all the input arrays must have same number of dimensions, but "
            f"operand0 has {ndim} dimension(s) and operand1 has {len(rhs)}"
        )
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} is out of bounds for array of dimension {ndim}")
    axis %= ndim

    for dim, (lhs_size, rhs_size) in enumerate(zip(lhs, rhs)):
        if dim != axis and lhs_size != rhs_size:
            raise ValueError(
                "all the input array dimensions except for the concatenation "
                f"axis must match exactly, but along dimension {dim}, "
                f"operand0 has size {lhs_size} and operand1 has size {rhs_size}"
            )

    result_shape = lhs[:axis] + (lhs[axis] + rhs[axis],) + lhs[axis + 1 :]
    return TensorType(dtype, result_shape)
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
@_PHE_MOD.simple_op()
def reshape(ciphertext: TensorType, *, new_shape: tuple[int, ...]) -> TensorType:
    """Type rule for reshaping a ciphertext.

    A zero-filled numpy stand-in with the ciphertext's shape is reshaped to
    ``new_shape``. numpy thereby validates the request and resolves a ``-1``
    entry into a concrete dimension.

    Args:
        ciphertext: The ciphertext to reshape.
        new_shape: The target shape; may contain one ``-1`` to be inferred.

    Returns:
        TensorType with the ciphertext's dtype and the fully resolved shape.
    """
    import numpy as np

    stand_in = np.zeros(ciphertext.shape)
    resolved_shape = stand_in.reshape(new_shape).shape
    return TensorType(ciphertext.dtype, resolved_shape)
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
@_PHE_MOD.simple_op()
def transpose(
    ciphertext: TensorType, *, axes: tuple[int, ...] | None = None
) -> TensorType:
    """Type rule for permuting the axes of a ciphertext.

    The permuted shape is obtained by transposing a zero-filled numpy
    stand-in of the ciphertext's shape, so ``axes`` is validated and
    interpreted exactly as ``numpy.transpose`` does.

    Args:
        ciphertext: The ciphertext to transpose.
        axes: Permutation of axes; ``None`` reverses the axis order.

    Returns:
        TensorType with the ciphertext's dtype and the permuted shape.
    """
    import numpy as np

    stand_in = np.zeros(ciphertext.shape)
    permuted_shape = np.transpose(stand_in, axes).shape
    return TensorType(ciphertext.dtype, permuted_shape)
|
mplang/v1/ops/spu.py
DELETED
|
@@ -1,151 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
from __future__ import annotations
|
|
16
|
-
|
|
17
|
-
from collections.abc import Callable
|
|
18
|
-
from typing import Any
|
|
19
|
-
|
|
20
|
-
import jax.numpy as jnp
|
|
21
|
-
import spu.libspu as libspu
|
|
22
|
-
import spu.utils.frontend as spu_fe
|
|
23
|
-
from jax import ShapeDtypeStruct
|
|
24
|
-
from jax.tree_util import PyTreeDef, tree_flatten
|
|
25
|
-
|
|
26
|
-
from mplang.v1.core import MPObject, PFunction, TensorType, get_fn_name
|
|
27
|
-
from mplang.v1.ops.base import stateless_mod
|
|
28
|
-
from mplang.v1.utils.func_utils import normalize_fn
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
class Visibility:
    """Friendly frontend names for ``libspu.Visibility`` members.

    Each attribute is the ``libspu.Visibility`` enum member itself, not a
    copy or wrapper. Downstream serialization and backends therefore receive
    exactly the enum type they expect, while frontend code can write
    ``Visibility.SECRET`` / ``PUBLIC`` / ``PRIVATE``.
    """

    SECRET = libspu.Visibility.VIS_SECRET
    PUBLIC = libspu.Visibility.VIS_PUBLIC
    PRIVATE = libspu.Visibility.VIS_PRIVATE
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
# Op module for the "spu" namespace; the SPU ops below are declared through
# its `simple_op()` / `op_def()` decorators.
_SPU_MOD = stateless_mod("spu")
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
@_SPU_MOD.simple_op()
def makeshares(
    data: TensorType,
    *,
    world_size: int,
    visibility: libspu.Visibility = Visibility.SECRET,
    owner_rank: int = -1,
    enable_private: bool = False,
) -> tuple:
    """Type rule for splitting a plaintext tensor into SPU shares.

    This only validates the attributes and describes the outputs: a tuple
    holding ``data``'s type once per party (``world_size`` entries).
    Assembling the PFunction is left to the ``simple_op`` decorator.

    Raises:
        ValueError: If ``world_size`` is not positive. For PRIVATE
            visibility, also if ``enable_private`` is false or ``owner_rank``
            is outside ``[0, world_size)``.
    """
    if world_size <= 0:
        raise ValueError("world_size must be positive")

    if visibility == Visibility.PRIVATE:
        if not enable_private:
            raise ValueError("PRIVATE visibility disabled; set enable_private=True")
        if not 0 <= owner_rank < world_size:
            raise ValueError(f"owner_rank {owner_rank} out of range [0,{world_size})")

    return (data,) * world_size
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
@_SPU_MOD.op_def()
def reconstruct(*shares: MPObject) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Build the ``spu.reconstruct`` PFunction that recombines SPU shares.

    Every share contributes one input type. The single output is typed like
    the first share.

    Returns:
        ``(pfunc, inputs, out_treedef)``: the PFunction, the shares as a list
        of inputs, and the PyTreeDef of the single output.

    Raises:
        ValueError: If no shares are given.
    """
    if not shares:
        raise ValueError("reconstruct requires at least one share")

    share_types = tuple(map(TensorType.from_obj, shares))
    result_type = share_types[0]
    pfunc = PFunction(
        fn_type="spu.reconstruct",
        ins_info=share_types,
        outs_info=(result_type,),
    )
    _, out_treedef = tree_flatten(result_type)
    return pfunc, list(shares), out_treedef
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
def _compile_jax(
    copts: libspu.CompilerOptions,
    fn: Callable,
    *args: Any,
    **kwargs: Any,
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Compile a JAX function into SPU pphlo MLIR and wrap it as a PFunction.

    Every ``MPObject`` found in ``args``/``kwargs`` becomes a runtime input.
    The remaining arguments are closed over by ``normalize_fn``. All inputs
    are compiled with ``VIS_SECRET`` visibility and named ``in0, in1, ...``;
    outputs are named ``out0, out1, ...``.

    Args:
        copts: SPU compiler options passed through to the frontend compiler.
        fn: The JAX function to compile.
        *args: Positional arguments; ``MPObject`` instances become inputs.
        **kwargs: Keyword arguments; ``MPObject`` instances become inputs.

    Returns:
        ``(pfunc, in_vars, out_tree)``: a PFunction with fn_type
        ``'spu.run_pphlo'``, the input MPObjects in order, and the PyTreeDef
        of ``fn``'s output structure.

    Raises:
        TypeError: If the compiled executable's code is neither ``bytes`` nor
            ``str``. This is an explicit check rather than an ``assert``, so
            it still fires under ``python -O``.
    """

    def is_variable(arg: Any) -> bool:
        return isinstance(arg, MPObject)

    normalized_fn, in_vars = normalize_fn(fn, args, kwargs, is_variable)

    jax_params = [
        ShapeDtypeStruct(arg.shape, jnp.dtype(arg.dtype.name)) for arg in in_vars
    ]
    in_vis = [libspu.Visibility.VIS_SECRET for _ in in_vars]
    in_names = [f"in{idx}" for idx in range(len(in_vars))]

    def out_names_gen(outs: Any) -> list[str]:
        return [f"out{idx}" for idx in range(len(outs))]

    executable, out_info = spu_fe.compile(
        spu_fe.Kind.JAX,
        normalized_fn,
        [jax_params],
        {},
        in_names,
        in_vis,
        out_names_gen,
        static_argnums=(),
        static_argnames=None,
        copts=copts,
    )
    out_info_flat, out_tree = tree_flatten(out_info)
    output_tensor_infos = [TensorType.from_obj(out) for out in out_info_flat]

    # The PFunction stores the MLIR text; accept already-decoded text too.
    executable_code = executable.code
    if isinstance(executable_code, bytes):
        executable_code = executable_code.decode("utf-8")
    elif not isinstance(executable_code, str):
        raise TypeError(f"Expected bytes or str, got {type(executable_code)}")

    pfunc = PFunction(
        fn_type="spu.run_pphlo",
        ins_info=tuple(TensorType.from_obj(x) for x in in_vars),
        outs_info=tuple(output_tensor_infos),
        fn_name=get_fn_name(fn),
        fn_text=executable_code,
        input_visibilities=in_vis,
        input_names=list(executable.input_names),
        output_names=list(executable.output_names),
        executable_name=executable.name,
    )
    return pfunc, in_vars, out_tree
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
@_SPU_MOD.op_def()
def jax_compile(
    fn: Callable, *args: Any, **kwargs: Any
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Compile a JAX function for SPU using default compiler options.

    Thin wrapper around ``_compile_jax`` with a default-constructed
    ``libspu.CompilerOptions``. Returns its ``(pfunc, inputs, out_tree)``
    triple unchanged.
    """
    default_options = libspu.CompilerOptions()
    return _compile_jax(default_options, fn, *args, **kwargs)
|
mplang/v1/ops/sql_cc.py
DELETED
|
@@ -1,303 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
from typing import Any
|
|
16
|
-
|
|
17
|
-
import sqlglot as sg
|
|
18
|
-
from jax.tree_util import PyTreeDef, tree_flatten
|
|
19
|
-
from sqlglot import exp as sge
|
|
20
|
-
from sqlglot.optimizer import annotate_types as opt_annot
|
|
21
|
-
from sqlglot.optimizer import qualify as opt_qualify
|
|
22
|
-
|
|
23
|
-
from mplang.v1.core import MPObject, PFunction, TableType
|
|
24
|
-
from mplang.v1.core.dtypes import (
|
|
25
|
-
BINARY,
|
|
26
|
-
BOOL,
|
|
27
|
-
DATE,
|
|
28
|
-
DECIMAL,
|
|
29
|
-
FLOAT32,
|
|
30
|
-
FLOAT64,
|
|
31
|
-
INT8,
|
|
32
|
-
INT16,
|
|
33
|
-
INT32,
|
|
34
|
-
INT64,
|
|
35
|
-
INTERVAL,
|
|
36
|
-
JSON,
|
|
37
|
-
STRING,
|
|
38
|
-
TIME,
|
|
39
|
-
TIMESTAMP,
|
|
40
|
-
UINT8,
|
|
41
|
-
UINT16,
|
|
42
|
-
UINT32,
|
|
43
|
-
UINT64,
|
|
44
|
-
UUID,
|
|
45
|
-
DType,
|
|
46
|
-
)
|
|
47
|
-
from mplang.v1.ops.base import stateless_mod
|
|
48
|
-
|
|
49
|
-
# Op module for the "sql" namespace; SQL ops in this file are declared through
# its `op_def()` decorator.
_SQL_MOD = stateless_mod("sql")
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
# Static dtype mappings (MPLang <-> SQL)
#
# MP_TO_SQL_TYPE: MPLang DType -> SQL type name. It is used to describe input
# table schemas to sqlglot during output-schema deduction, where any dtype
# missing from this table falls back to "VARCHAR".
MP_TO_SQL_TYPE: dict[DType, str] = {
    # Floats
    FLOAT64: "DOUBLE",
    FLOAT32: "FLOAT",
    # Signed ints
    INT8: "TINYINT",
    INT16: "SMALLINT",
    INT32: "INT",
    INT64: "BIGINT",
    # Unsigned ints (portable approximations)
    # NOTE: each unsigned type maps to a wider signed/decimal SQL type so its
    # full range fits. The reverse table maps these names back to signed
    # dtypes (e.g. "smallint" -> INT16), so unsigned-ness does not round-trip.
    UINT8: "SMALLINT",
    UINT16: "INT",
    UINT32: "BIGINT",
    UINT64: "DECIMAL(38)",
    # Booleans & strings
    BOOL: "BOOLEAN",
    STRING: "VARCHAR",
    # Dates / times
    DATE: "DATE",
    TIME: "TIME",
    TIMESTAMP: "TIMESTAMP",
    # Other table types
    DECIMAL: "DECIMAL",
    JSON: "JSON",
    BINARY: "BLOB",
    UUID: "UUID",
    INTERVAL: "INTERVAL",
}
|
|
81
|
-
|
|
82
|
-
# SQL type name -> MPLang DType. Keys are lower-case base names with any
# parameter list such as "(10, 2)" removed. Output-schema deduction looks up
# sqlglot-annotated projection types here; names not listed fall back to
# STRING.
SQL_TYPE_TO_MP: dict[str, DType] = {
    # Floats
    "double": FLOAT64,
    "double precision": FLOAT64,
    "float": FLOAT32,
    "real": FLOAT32,
    # Signed ints
    "bigint": INT64,
    "long": INT64,
    "int": INT32,
    "integer": INT32,
    "int4": INT32,
    "smallint": INT16,
    "int2": INT16,
    "tinyint": INT8,
    "int1": INT8,
    # Unsigned (rare in SQL)
    "uint8": UINT8,
    "ubyte": UINT8,
    "uint16": UINT16,
    "uint32": UINT32,
    "uint64": UINT64,
    # Unsigned type names as rendered by sqlglot (e.g. DuckDB's UTINYINT).
    # Without these, unsigned projections silently degrade to STRING.
    "utinyint": UINT8,
    "usmallint": UINT16,
    "uint": UINT32,
    "uinteger": UINT32,
    "ubigint": UINT64,
    # Booleans / strings
    "bool": BOOL,
    "boolean": BOOL,
    "char": STRING,
    "varchar": STRING,
    "text": STRING,
    "string": STRING,
    # Dates / times
    "date": DATE,
    "time": TIME,
    "timestamp": TIMESTAMP,
    # Zone-aware and precision variants. The lookup already strips a
    # " with time zone" suffix, but sqlglot renders these types as single
    # tokens (e.g. "TIMESTAMPTZ"), which would otherwise fall back to STRING.
    "timetz": TIME,
    "timestamptz": TIMESTAMP,
    "timestampltz": TIMESTAMP,
    "timestamp_s": TIMESTAMP,
    "timestamp_ms": TIMESTAMP,
    "timestamp_ns": TIMESTAMP,
    "datetime": TIMESTAMP,
    # Decimal / numeric
    "decimal": DECIMAL,
    "numeric": DECIMAL,
    # Others
    "json": JSON,
    "binary": BINARY,
    "varbinary": BINARY,
    "blob": BINARY,
    "uuid": UUID,
    "interval": INTERVAL,
}
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
def _deduce_out_schema(
    parsed: sge.Expression,
    dialect: str,
    in_schemas: dict[str, TableType],
) -> TableType:
    """Infer the output schema of an already-parsed query with sqlglot.

    Steps:
      1. Describe every input table to sqlglot as ``{column: sql_type}``.
         Dtypes are translated through ``MP_TO_SQL_TYPE``; unmapped dtypes
         become ``"VARCHAR"``.
      2. Run sqlglot's ``qualify`` (resolves table/column references and
         expands ``*``), then ``annotate_types`` to attach a type to each
         expression.
      3. Take the projections of the top-level SELECT, or of the first SELECT
         that ``find`` returns when the root is not a SELECT. Translate each
         annotated type back through ``SQL_TYPE_TO_MP``; unknown names become
         ``STRING``.

    NOTE(review): sqlglot's ``qualify`` appears to rewrite ``parsed`` in
    place; callers should not assume ``parsed`` is unchanged afterwards.
    Confirm against the sqlglot version in use.

    Args:
        parsed: The parsed SQL expression.
        dialect: sqlglot dialect name used for qualification.
        in_schemas: Input table name -> TableType.

    Returns:
        TableType built from ``(column name, dtype)`` pairs in projection order.

    Raises:
        NotImplementedError: If no SELECT is found, or a projection has no type.
        ValueError: If two projections resolve to the same output name.
    """
    schema_for_sqlglot: dict[str, dict[str, str]] = {}
    for table_name, table_type in in_schemas.items():
        schema_for_sqlglot[table_name] = {
            column: MP_TO_SQL_TYPE.get(col_dtype, "VARCHAR")
            for column, col_dtype in table_type.columns
        }

    qualified = opt_qualify.qualify(parsed, schema=schema_for_sqlglot, dialect=dialect)
    annotated = opt_annot.annotate_types(qualified, schema=schema_for_sqlglot)

    if isinstance(annotated, sge.Select):
        select = annotated
    else:
        select = annotated.find(sge.Select)
        if select is None:
            raise NotImplementedError(
                "Only SELECT queries are supported for schema deduction"
            )

    columns: list[tuple[str, DType]] = []
    seen_names: set[str] = set()
    anon_counter = 0
    for projection in select.expressions:
        col_name = getattr(projection, "alias_or_name", None) or getattr(
            projection, "name", None
        )
        if not col_name:
            # Synthesize a name for projections sqlglot could not name.
            col_name = f"expr_{anon_counter}"
            anon_counter += 1

        sql_type = getattr(projection, "type", None)
        if sql_type is None:
            raise NotImplementedError(
                "Cannot infer type for projection; please provide out_type explicitly"
            )
        # Reduce e.g. "DECIMAL(10, 2)" to "decimal" before the table lookup.
        type_text = str(sql_type).lower().replace(" with time zone", "").strip()
        base_type = type_text.split("(", 1)[0].strip()
        col_dtype = SQL_TYPE_TO_MP.get(base_type, STRING)

        if col_name in seen_names:
            raise ValueError(
                f"Duplicate output column name '{col_name}' after qualification"
            )
        seen_names.add(col_name)
        columns.append((col_name, col_dtype))

    return TableType.from_pairs(columns)
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
@_SQL_MOD.op_def()
def run_sql(
    query: str,
    *,
    out_type: TableType | None = None,
    dialect: str = "duckdb",
    **in_tables: Any,
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Build a sql.run PFunction from a SQL query with optional schema deduction.

    API: run_sql(query: str, *, out_type: TableType | None = None, dialect: str = "duckdb", **in_tables) -> (PFunction, [MPObject], PyTreeDef)

    Semantics:
    - Parses the SQL and binds only the tables that are actually referenced in the query by name.
      Names introduced by common table expressions (``WITH name AS (...)``) are defined by the
      query itself and are never treated as input tables.
    - If ``out_type`` is not provided, attempts to deduce the output table schema using sqlglot (qualify + annotate types).
    - Returns a triad consisting of the constructed PFunction (``fn_type='sql.run'``), the ordered list of input MPObjects, and the output PyTreeDef.

    Difference vs ``run_sql_raw``: this op can infer ``out_type`` and will parse the SQL to filter inputs; ``run_sql_raw`` requires an explicit ``out_type`` and does not parse/filter inputs.

    Raises:
        ValueError: if ``in_tables`` contains a table the query never references,
            or if a referenced input table has no schema.
        KeyError: if a table referenced by the query is not supplied in ``in_tables``.
        TypeError: if a supplied table is not an MPObject.
    """
    parsed = sg.parse_one(query, read=dialect)

    # Names bound by WITH-clauses are local to the query; references to them are
    # not external inputs and must not be demanded from the caller.
    cte_names = {cte.alias for cte in parsed.find_all(sge.CTE)}

    # Extract required table names from SQL, deduplicated, in the order
    # find_all discovers them.
    required_names: list[str] = []
    for t in parsed.find_all(sge.Table):
        # Prefer .name; fallback to str(this) if needed
        tname = getattr(t, "name", None) or str(t.this)
        if tname in cte_names or tname in required_names:
            continue
        required_names.append(tname)

    # Disallow extras not referenced by the query to avoid surprises
    extra = set(in_tables.keys()) - set(required_names)
    if extra:
        raise ValueError(
            f"Unexpected tables provided that are not referenced in SQL: {sorted(extra)}"
        )

    # Validate required tables and require MPObject for runtime registration
    in_names: list[str] = []
    ins_info: list[TableType] = []
    in_vars: list[MPObject] = []
    for name in required_names:
        if name not in in_tables:
            raise KeyError(f"Missing required table '{name}' for SQL query")
        obj = in_tables[name]
        if not isinstance(obj, MPObject):
            raise TypeError(
                f"Table '{name}' must be an MPObject (for runtime registration), got {type(obj).__name__}"
            )
        # Explicit check instead of `assert`: asserts are stripped under -O and
        # a missing schema would then leak into the PFunction signature.
        if obj.schema is None:
            raise ValueError(f"Input table '{name}' missing schema")
        in_vars.append(obj)
        ins_info.append(obj.schema)
        in_names.append(name)

    if out_type is None:
        # Reuse the schemas validated above (in_names mirrors required_names).
        in_schemas: dict[str, TableType] = dict(zip(in_names, ins_info))
        out_type = _deduce_out_schema(parsed, dialect, in_schemas)

    pfn = PFunction(
        fn_type="sql.run",
        ins_info=tuple(ins_info),
        outs_info=(out_type,),
        fn_name="",
        fn_text=query,
        in_names=tuple(in_names),
        dialect=dialect,
    )
    _, treedef = tree_flatten(out_type)
    return pfn, in_vars, treedef
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
@_SQL_MOD.op_def()
def run_sql_raw(
    query: str,
    out_type: TableType,
    *,
    dialect: str = "duckdb",
    in_tables: dict[str, MPObject] | None = None,
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Build a sql.run PFunction from a SQL query with an explicit output schema.

    API: run_sql_raw(query: str, out_type: TableType, *, dialect: str = "duckdb", in_tables: dict[str, MPObject] | None = None) -> (PFunction, [MPObject], PyTreeDef)

    Semantics:
    - Does not parse the SQL; carries all tables provided via ``in_tables`` in the mapping's iteration order.
    - Requires an explicit ``out_type``; no schema deduction is attempted.
    - Returns a triad consisting of the constructed PFunction (``fn_type='sql.run'``), the ordered list of input MPObjects, and the output PyTreeDef.

    Difference vs ``run_sql``: this op requires ``out_type`` and does not parse/filter inputs; ``run_sql`` can infer ``out_type`` and selects only tables referenced by the query.

    Raises:
        TypeError: if a value in ``in_tables`` is not an MPObject.
        ValueError: if an input table has no schema.
    """

    # Collect inputs strictly as provided by caller
    in_names: list[str] = []
    ins_info: list[TableType] = []
    in_vars: list[MPObject] = []
    if in_tables:
        for name, tbl in in_tables.items():
            if not isinstance(tbl, MPObject):
                raise TypeError(f"Input table '{name}' is not an MPObject {type(tbl)}")
            # Explicit check instead of `assert`: asserts are stripped under -O
            # and a None schema would otherwise end up in ins_info.
            if tbl.schema is None:
                raise ValueError(f"Input table '{name}' is missing a schema")
            in_names.append(name)
            ins_info.append(tbl.schema)
            in_vars.append(tbl)

    pfn = PFunction(
        fn_type="sql.run",
        fn_name="",
        fn_text=query,
        ins_info=tuple(ins_info),
        outs_info=(out_type,),
        in_names=tuple(in_names),
        dialect=dialect,
    )
    _, treedef = tree_flatten(out_type)
    return pfn, in_vars, treedef
|
mplang/v1/ops/tee.py
DELETED
|
@@ -1,36 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
from __future__ import annotations
|
|
16
|
-
|
|
17
|
-
from mplang.v1.core import UINT8, TensorType
|
|
18
|
-
from mplang.v1.ops.base import stateless_mod
|
|
19
|
-
|
|
20
|
-
# Op module for the "tee" namespace; the ops below register through its
# simple_op() decorator. NOTE(review): `stateless_mod` presumably means these
# ops keep no per-module state — confirm against mplang.v1.ops.base.
_TEE_MOD = stateless_mod("tee")
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
@_TEE_MOD.simple_op()
def quote_gen(pk: TensorType) -> TensorType:
    """Declare the result type of TEE quote generation.

    The generated quote binds the supplied ephemeral public key ``pk``.
    The result is a 1-D UINT8 tensor with shape ``(-1,)``, presumably
    marking a length that is not fixed ahead of time.
    """
    del pk  # value itself is unused; the parameter only shapes the op signature
    quote_shape = (-1,)
    return TensorType(UINT8, quote_shape)
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
@_TEE_MOD.simple_op()
def attest(quote: TensorType) -> TensorType:
    """Declare the result type of TEE quote verification.

    Verifying ``quote`` yields the attested TEE public key as a 32-byte
    UINT8 tensor.

    API (mock): attest(quote: u8[33]) -> tee_pk: u8[32]
    """
    del quote  # value itself is unused; the parameter only shapes the op signature
    tee_pk_shape = (32,)
    return TensorType(UINT8, tee_pk_shape)
|