mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev271__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. mplang/__init__.py +391 -17
  2. mplang/{v2/backends → backends}/__init__.py +9 -7
  3. mplang/{v2/backends → backends}/bfv_impl.py +6 -6
  4. mplang/{v2/backends → backends}/crypto_impl.py +6 -6
  5. mplang/{v2/backends → backends}/field_impl.py +5 -5
  6. mplang/{v2/backends → backends}/func_impl.py +4 -4
  7. mplang/{v2/backends → backends}/phe_impl.py +3 -3
  8. mplang/{v2/backends → backends}/simp_design.md +1 -1
  9. mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
  10. mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
  11. mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
  12. mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
  13. mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
  14. mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
  15. mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
  16. mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
  17. mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
  18. mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
  19. mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
  20. mplang/{v2/backends → backends}/spu_impl.py +8 -8
  21. mplang/{v2/backends → backends}/spu_state.py +4 -4
  22. mplang/{v2/backends → backends}/store_impl.py +3 -3
  23. mplang/{v2/backends → backends}/table_impl.py +8 -8
  24. mplang/{v2/backends → backends}/tee_impl.py +6 -6
  25. mplang/{v2/backends → backends}/tensor_impl.py +6 -6
  26. mplang/{v2/cli.py → cli.py} +9 -9
  27. mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
  28. mplang/{v2/dialects → dialects}/__init__.py +5 -5
  29. mplang/{v2/dialects → dialects}/bfv.py +6 -6
  30. mplang/{v2/dialects → dialects}/crypto.py +5 -5
  31. mplang/{v2/dialects → dialects}/dtypes.py +2 -2
  32. mplang/{v2/dialects → dialects}/field.py +3 -3
  33. mplang/{v2/dialects → dialects}/func.py +2 -2
  34. mplang/{v2/dialects → dialects}/phe.py +6 -6
  35. mplang/{v2/dialects → dialects}/simp.py +6 -6
  36. mplang/{v2/dialects → dialects}/spu.py +7 -7
  37. mplang/{v2/dialects → dialects}/store.py +2 -2
  38. mplang/{v2/dialects → dialects}/table.py +3 -3
  39. mplang/{v2/dialects → dialects}/tee.py +6 -6
  40. mplang/{v2/dialects → dialects}/tensor.py +5 -5
  41. mplang/{v2/edsl → edsl}/__init__.py +3 -3
  42. mplang/{v2/edsl → edsl}/context.py +6 -6
  43. mplang/{v2/edsl → edsl}/graph.py +5 -5
  44. mplang/{v2/edsl → edsl}/jit.py +2 -2
  45. mplang/{v2/edsl → edsl}/object.py +1 -1
  46. mplang/{v2/edsl → edsl}/primitive.py +5 -5
  47. mplang/{v2/edsl → edsl}/printer.py +1 -1
  48. mplang/{v2/edsl → edsl}/serde.py +1 -1
  49. mplang/{v2/edsl → edsl}/tracer.py +7 -7
  50. mplang/{v2/edsl → edsl}/typing.py +1 -1
  51. mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
  52. mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
  53. mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
  54. mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
  55. mplang/{v2/libs → libs}/collective.py +5 -5
  56. mplang/{v2/libs → libs}/device/__init__.py +1 -1
  57. mplang/{v2/libs → libs}/device/api.py +12 -12
  58. mplang/{v2/libs → libs}/ml/__init__.py +1 -1
  59. mplang/{v2/libs → libs}/ml/sgb.py +4 -4
  60. mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
  61. mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
  62. mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
  63. mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
  64. mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
  65. mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
  66. mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
  67. mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
  68. mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
  69. mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
  70. mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
  71. mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
  72. mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
  73. mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
  74. mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
  75. mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
  76. mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
  77. mplang/{v2/runtime → runtime}/interpreter.py +11 -11
  78. mplang/{v2/runtime → runtime}/value.py +2 -2
  79. mplang/{v1/runtime → utils}/__init__.py +18 -15
  80. mplang/{v1/utils → utils}/func_utils.py +1 -1
  81. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/METADATA +2 -2
  82. mplang_nightly-0.1.dev271.dist-info/RECORD +102 -0
  83. mplang/v1/__init__.py +0 -157
  84. mplang/v1/_device.py +0 -602
  85. mplang/v1/analysis/__init__.py +0 -37
  86. mplang/v1/analysis/diagram.py +0 -567
  87. mplang/v1/core/__init__.py +0 -157
  88. mplang/v1/core/cluster.py +0 -343
  89. mplang/v1/core/comm.py +0 -281
  90. mplang/v1/core/context_mgr.py +0 -50
  91. mplang/v1/core/dtypes.py +0 -335
  92. mplang/v1/core/expr/__init__.py +0 -80
  93. mplang/v1/core/expr/ast.py +0 -542
  94. mplang/v1/core/expr/evaluator.py +0 -581
  95. mplang/v1/core/expr/printer.py +0 -285
  96. mplang/v1/core/expr/transformer.py +0 -141
  97. mplang/v1/core/expr/utils.py +0 -78
  98. mplang/v1/core/expr/visitor.py +0 -85
  99. mplang/v1/core/expr/walk.py +0 -387
  100. mplang/v1/core/interp.py +0 -160
  101. mplang/v1/core/mask.py +0 -325
  102. mplang/v1/core/mpir.py +0 -965
  103. mplang/v1/core/mpobject.py +0 -117
  104. mplang/v1/core/mptype.py +0 -407
  105. mplang/v1/core/pfunc.py +0 -130
  106. mplang/v1/core/primitive.py +0 -877
  107. mplang/v1/core/table.py +0 -218
  108. mplang/v1/core/tensor.py +0 -75
  109. mplang/v1/core/tracer.py +0 -383
  110. mplang/v1/host.py +0 -130
  111. mplang/v1/kernels/__init__.py +0 -41
  112. mplang/v1/kernels/base.py +0 -125
  113. mplang/v1/kernels/basic.py +0 -240
  114. mplang/v1/kernels/context.py +0 -369
  115. mplang/v1/kernels/crypto.py +0 -122
  116. mplang/v1/kernels/fhe.py +0 -858
  117. mplang/v1/kernels/mock_tee.py +0 -72
  118. mplang/v1/kernels/phe.py +0 -1864
  119. mplang/v1/kernels/spu.py +0 -341
  120. mplang/v1/kernels/sql_duckdb.py +0 -44
  121. mplang/v1/kernels/stablehlo.py +0 -90
  122. mplang/v1/kernels/value.py +0 -626
  123. mplang/v1/ops/__init__.py +0 -35
  124. mplang/v1/ops/base.py +0 -424
  125. mplang/v1/ops/basic.py +0 -294
  126. mplang/v1/ops/crypto.py +0 -262
  127. mplang/v1/ops/fhe.py +0 -272
  128. mplang/v1/ops/jax_cc.py +0 -147
  129. mplang/v1/ops/nnx_cc.py +0 -168
  130. mplang/v1/ops/phe.py +0 -216
  131. mplang/v1/ops/spu.py +0 -151
  132. mplang/v1/ops/sql_cc.py +0 -303
  133. mplang/v1/ops/tee.py +0 -36
  134. mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
  135. mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
  136. mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
  137. mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
  138. mplang/v1/runtime/channel.py +0 -230
  139. mplang/v1/runtime/cli.py +0 -451
  140. mplang/v1/runtime/client.py +0 -456
  141. mplang/v1/runtime/communicator.py +0 -131
  142. mplang/v1/runtime/data_providers.py +0 -303
  143. mplang/v1/runtime/driver.py +0 -324
  144. mplang/v1/runtime/exceptions.py +0 -27
  145. mplang/v1/runtime/http_api.md +0 -56
  146. mplang/v1/runtime/link_comm.py +0 -196
  147. mplang/v1/runtime/server.py +0 -501
  148. mplang/v1/runtime/session.py +0 -270
  149. mplang/v1/runtime/simulation.py +0 -324
  150. mplang/v1/simp/__init__.py +0 -13
  151. mplang/v1/simp/api.py +0 -353
  152. mplang/v1/simp/mpi.py +0 -131
  153. mplang/v1/simp/party.py +0 -225
  154. mplang/v1/simp/random.py +0 -120
  155. mplang/v1/simp/smpc.py +0 -238
  156. mplang/v1/utils/__init__.py +0 -13
  157. mplang/v1/utils/crypto.py +0 -32
  158. mplang/v1/utils/spu_utils.py +0 -130
  159. mplang/v1/utils/table_utils.py +0 -185
  160. mplang/v2/__init__.py +0 -424
  161. mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
  162. /mplang/{v2/backends → backends}/channel.py +0 -0
  163. /mplang/{v2/edsl → edsl}/README.md +0 -0
  164. /mplang/{v2/edsl → edsl}/registry.py +0 -0
  165. /mplang/{v2/kernels → kernels}/Makefile +0 -0
  166. /mplang/{v2/kernels → kernels}/__init__.py +0 -0
  167. /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
  168. /mplang/{v2/libs → libs}/device/cluster.py +0 -0
  169. /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
  170. /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
  171. /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
  172. /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
  173. /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
  174. /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
  175. /mplang/{v2/runtime → runtime}/__init__.py +0 -0
  176. /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
  177. /mplang/{v2/runtime → runtime}/object_store.py +0 -0
  178. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/WHEEL +0 -0
  179. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/entry_points.txt +0 -0
  180. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/licenses/LICENSE +0 -0
mplang/v1/ops/phe.py DELETED
@@ -1,216 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """PHE (Partially Homomorphic Encryption) frontend operations."""
16
-
17
- from mplang.v1.core import UINT8, TensorType
18
- from mplang.v1.ops.base import stateless_mod
19
-
20
- _PHE_MOD = stateless_mod("phe")
21
-
22
-
23
- @_PHE_MOD.simple_op()
24
- def keygen(
25
- *,
26
- scheme: str = "paillier",
27
- key_size: int = 2048,
28
- max_value: int | None = None,
29
- fxp_bits: int | None = None,
30
- ) -> tuple[TensorType, TensorType]:
31
- """Generate a PHE key pair: returns (public_key, private_key).
32
-
33
- Keys are represented with a sentinel TensorType UINT8[(-1, 0)] to indicate
34
- non-structural, backend-only handles. Runtime validation will treat this
35
- shape as an opaque placeholder and skip dtype/shape checks.
36
-
37
- Attributes (forwarded to backend):
38
- scheme: PHE scheme (default: 'paillier')
39
- key_size: Modulus size in bits (default: 2048)
40
- max_value: Optional range-encoding bound B. If provided, the backend will
41
- encode/decode integers/floats within [-B, B] and treat (B, N-B) as overflow.
42
- Pick B to exceed the largest intermediate magnitude you expect in homomorphic
43
- combinations. If omitted, backend default is used (currently 2**32).
44
- fxp_bits: Optional fixed-point fractional bits for float encoding (default backend value).
45
- """
46
- key_spec = TensorType(UINT8, (-1, 0))
47
- return key_spec, key_spec
48
-
49
-
50
- @_PHE_MOD.simple_op()
51
- def encrypt(plaintext: TensorType, public_key: TensorType) -> TensorType:
52
- """Encrypt plaintext using PHE public key: returns ciphertext with same semantic type as plaintext."""
53
- _ = public_key
54
- return plaintext
55
-
56
-
57
- @_PHE_MOD.simple_op()
58
- def add(operand1: TensorType, operand2: TensorType) -> TensorType:
59
- """Add two PHE operands (semantics depend on backend representation)."""
60
- _ = operand2
61
- return operand1
62
-
63
-
64
- @_PHE_MOD.simple_op()
65
- def mul(ciphertext: TensorType, plaintext: TensorType) -> TensorType:
66
- """Multiply a PHE ciphertext with a plaintext value (ciphertext dtype preserved)."""
67
- if plaintext.dtype.is_floating:
68
- raise ValueError(
69
- "PHE multiplication does not support floating-point plaintext."
70
- )
71
- return ciphertext
72
-
73
-
74
- @_PHE_MOD.simple_op()
75
- def decrypt(ciphertext: TensorType, private_key: TensorType) -> TensorType:
76
- """Decrypt ciphertext using PHE private key: returns plaintext with same semantic type as ciphertext."""
77
- _ = private_key
78
- return ciphertext
79
-
80
-
81
- @_PHE_MOD.simple_op()
82
- def dot(ciphertext: TensorType, plaintext: TensorType) -> TensorType:
83
- """Compute dot product of ciphertext with plaintext.
84
-
85
- Args:
86
- ciphertext: The ciphertext operand (first argument)
87
- plaintext: The plaintext operand (second argument)
88
-
89
- Returns:
90
- TensorType: Result tensor type with computed shape following numpy dot product rules
91
- """
92
- # For dot product, we need to calculate the result shape
93
- # This follows numpy dot product rules
94
- import numpy as np
95
-
96
- # Create dummy arrays to determine result shape
97
- dummy_ct = np.zeros(ciphertext.shape)
98
- dummy_pt = np.zeros(plaintext.shape)
99
- dummy_result = np.dot(dummy_ct, dummy_pt)
100
-
101
- return TensorType(ciphertext.dtype, dummy_result.shape)
102
-
103
-
104
- @_PHE_MOD.simple_op()
105
- def gather(ciphertext: TensorType, indices: TensorType, *, axis: int = 0) -> TensorType:
106
- """Gather elements from ciphertext using indices.
107
-
108
- Args:
109
- ciphertext: The ciphertext to gather from
110
- indices: The indices to gather
111
- axis: The axis along which to gather (default: 0)
112
- """
113
- # Calculate result shape based on axis parameter
114
- ct_shape = list(ciphertext.shape)
115
- indices_shape = list(indices.shape)
116
-
117
- # Normalize negative axis
118
- normalized_axis = axis if axis >= 0 else len(ct_shape) + axis
119
-
120
- # Result shape: replace the axis dimension with indices shape
121
- result_shape = (
122
- ct_shape[:normalized_axis] + indices_shape + ct_shape[normalized_axis + 1 :]
123
- )
124
- return TensorType(ciphertext.dtype, tuple(result_shape))
125
-
126
-
127
- @_PHE_MOD.simple_op()
128
- def scatter(
129
- ciphertext: TensorType,
130
- indices: TensorType,
131
- updates: TensorType,
132
- *,
133
- axis: int = 0,
134
- ) -> TensorType:
135
- """Scatter updates into ciphertext at specified indices.
136
-
137
- Args:
138
- ciphertext: The ciphertext to scatter into
139
- indices: The indices to scatter at
140
- updates: The ciphertext updates to scatter
141
- axis: The axis along which to scatter (default: 0)
142
-
143
- Returns:
144
- TensorType: Result tensor type with same shape and dtype as original ciphertext
145
- """
146
- return ciphertext
147
-
148
-
149
- @_PHE_MOD.simple_op()
150
- def concat(operand0: TensorType, operand1: TensorType, *, axis: int = 0) -> TensorType:
151
- """Concatenate ciphertext tensors along specified axis.
152
-
153
- Args:
154
- operand0: The first ciphertext operand to concatenate
155
- operand1: The second ciphertext operand to concatenate
156
- axis: Axis along which to concatenate
157
-
158
- Returns:
159
- TensorType: Result tensor type with computed shape following numpy concatenation rules
160
- """
161
- # All operands should have same dtype
162
- first_dtype = operand0.dtype
163
- if operand1.dtype != first_dtype:
164
- raise ValueError("All operands must have the same dtype for concatenation")
165
-
166
- # Calculate result shape using numpy concatenation logic
167
- import numpy as np
168
-
169
- dummy_arrays = [np.zeros(operand0.shape), np.zeros(operand1.shape)]
170
- dummy_result = np.concatenate(dummy_arrays, axis=axis)
171
-
172
- return TensorType(first_dtype, dummy_result.shape)
173
-
174
-
175
- @_PHE_MOD.simple_op()
176
- def reshape(ciphertext: TensorType, *, new_shape: tuple[int, ...]) -> TensorType:
177
- """Reshape ciphertext to new shape.
178
-
179
- Args:
180
- ciphertext: The ciphertext to reshape
181
- new_shape: The target shape (can contain -1 for inferred dimension)
182
-
183
- Returns:
184
- TensorType: Result tensor type with computed shape following numpy reshape rules
185
- """
186
- # Calculate the actual result shape (handling -1 inference)
187
- import numpy as np
188
-
189
- dummy_array = np.zeros(ciphertext.shape)
190
- # use this to check the correctness of new_shape
191
- dummy_result = dummy_array.reshape(new_shape)
192
- actual_shape = dummy_result.shape
193
-
194
- return TensorType(ciphertext.dtype, actual_shape)
195
-
196
-
197
- @_PHE_MOD.simple_op()
198
- def transpose(
199
- ciphertext: TensorType, *, axes: tuple[int, ...] | None = None
200
- ) -> TensorType:
201
- """Transpose ciphertext by permuting axes.
202
-
203
- Args:
204
- ciphertext: The ciphertext to transpose
205
- axes: Permutation of axes (None for default reverse order)
206
-
207
- Returns:
208
- TensorType: Result tensor type with computed shape following numpy transpose rules
209
- """
210
- # Calculate result shape using numpy transpose logic
211
- import numpy as np
212
-
213
- dummy_array = np.zeros(ciphertext.shape)
214
- dummy_result = np.transpose(dummy_array, axes)
215
-
216
- return TensorType(ciphertext.dtype, dummy_result.shape)
mplang/v1/ops/spu.py DELETED
@@ -1,151 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from __future__ import annotations
16
-
17
- from collections.abc import Callable
18
- from typing import Any
19
-
20
- import jax.numpy as jnp
21
- import spu.libspu as libspu
22
- import spu.utils.frontend as spu_fe
23
- from jax import ShapeDtypeStruct
24
- from jax.tree_util import PyTreeDef, tree_flatten
25
-
26
- from mplang.v1.core import MPObject, PFunction, TensorType, get_fn_name
27
- from mplang.v1.ops.base import stateless_mod
28
- from mplang.v1.utils.func_utils import normalize_fn
29
-
30
-
31
- class Visibility:
32
- """Frontend visibility constants mapping to libspu.Visibility.
33
-
34
- Note: these are direct aliases to libspu.Visibility members so that
35
- downstream serialization and backends receive the exact enum type
36
- they expect. Keep the friendly names (SECRET/PUBLIC/PRIVATE) for
37
- frontend ergonomics.
38
- """
39
-
40
- SECRET = libspu.Visibility.VIS_SECRET
41
- PUBLIC = libspu.Visibility.VIS_PUBLIC
42
- PRIVATE = libspu.Visibility.VIS_PRIVATE
43
-
44
-
45
- _SPU_MOD = stateless_mod("spu")
46
-
47
-
48
- @_SPU_MOD.simple_op()
49
- def makeshares(
50
- data: TensorType,
51
- *,
52
- world_size: int,
53
- visibility: libspu.Visibility = Visibility.SECRET,
54
- owner_rank: int = -1,
55
- enable_private: bool = False,
56
- ) -> tuple:
57
- """Create SPU shares from a plaintext tensor (type-only kernel).
58
-
59
- Returns a PyTree of TensorType repeated `world_size` times.
60
- Validation only; PFunction assembly handled by typed_op decorator.
61
- """
62
- if world_size <= 0:
63
- raise ValueError("world_size must be positive")
64
- if visibility == Visibility.PRIVATE:
65
- if not enable_private:
66
- raise ValueError("PRIVATE visibility disabled; set enable_private=True")
67
- if owner_rank < 0 or owner_rank >= world_size:
68
- raise ValueError(f"owner_rank {owner_rank} out of range [0,{world_size})")
69
- return tuple(data for _ in range(world_size))
70
-
71
-
72
- @_SPU_MOD.op_def()
73
- def reconstruct(*shares: MPObject) -> tuple[PFunction, list[MPObject], PyTreeDef]:
74
- """Reconstruct plaintext tensor from shares."""
75
- if len(shares) == 0:
76
- raise ValueError("reconstruct requires at least one share")
77
-
78
- ins_info = tuple(TensorType.from_obj(s) for s in shares)
79
- outs_info = (ins_info[0],)
80
- pfunc = PFunction(
81
- fn_type="spu.reconstruct",
82
- ins_info=ins_info,
83
- outs_info=outs_info,
84
- )
85
- _, treedef = tree_flatten(outs_info[0])
86
- return pfunc, list(shares), treedef
87
-
88
-
89
- def _compile_jax(
90
- copts: libspu.CompilerOptions,
91
- fn: Callable,
92
- *args: Any,
93
- **kwargs: Any,
94
- ) -> tuple[PFunction, list[MPObject], PyTreeDef]:
95
- """Compile a JAX function into SPU pphlo MLIR and wrap as PFunction.
96
-
97
- Resulting PFunction uses fn_type 'spu.run_pphlo'.
98
- """
99
-
100
- def is_variable(arg: Any) -> bool:
101
- return isinstance(arg, MPObject)
102
-
103
- normalized_fn, in_vars = normalize_fn(fn, args, kwargs, is_variable)
104
-
105
- jax_params = [
106
- ShapeDtypeStruct(arg.shape, jnp.dtype(arg.dtype.name)) for arg in in_vars
107
- ]
108
- in_vis = [libspu.Visibility.VIS_SECRET for _ in in_vars]
109
- in_names = [f"in{idx}" for idx in range(len(in_vars))]
110
- out_names_gen = lambda outs: [f"out{idx}" for idx in range(len(outs))]
111
-
112
- executable, out_info = spu_fe.compile(
113
- spu_fe.Kind.JAX,
114
- normalized_fn,
115
- [jax_params],
116
- {},
117
- in_names,
118
- in_vis,
119
- out_names_gen,
120
- static_argnums=(),
121
- static_argnames=None,
122
- copts=copts,
123
- )
124
- out_info_flat, out_tree = tree_flatten(out_info)
125
- output_tensor_infos = [TensorType.from_obj(out) for out in out_info_flat]
126
-
127
- executable_code = executable.code
128
- assert isinstance(executable_code, bytes), (
129
- f"Expected bytes, got {type(executable_code)}"
130
- )
131
- executable_code = executable_code.decode("utf-8")
132
-
133
- pfunc = PFunction(
134
- fn_type="spu.run_pphlo",
135
- ins_info=tuple(TensorType.from_obj(x) for x in in_vars),
136
- outs_info=tuple(output_tensor_infos),
137
- fn_name=get_fn_name(fn),
138
- fn_text=executable_code,
139
- input_visibilities=in_vis,
140
- input_names=list(executable.input_names),
141
- output_names=list(executable.output_names),
142
- executable_name=executable.name,
143
- )
144
- return pfunc, in_vars, out_tree
145
-
146
-
147
- @_SPU_MOD.op_def()
148
- def jax_compile(
149
- fn: Callable, *args: Any, **kwargs: Any
150
- ) -> tuple[PFunction, list[MPObject], PyTreeDef]:
151
- return _compile_jax(libspu.CompilerOptions(), fn, *args, **kwargs)
mplang/v1/ops/sql_cc.py DELETED
@@ -1,303 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from typing import Any
16
-
17
- import sqlglot as sg
18
- from jax.tree_util import PyTreeDef, tree_flatten
19
- from sqlglot import exp as sge
20
- from sqlglot.optimizer import annotate_types as opt_annot
21
- from sqlglot.optimizer import qualify as opt_qualify
22
-
23
- from mplang.v1.core import MPObject, PFunction, TableType
24
- from mplang.v1.core.dtypes import (
25
- BINARY,
26
- BOOL,
27
- DATE,
28
- DECIMAL,
29
- FLOAT32,
30
- FLOAT64,
31
- INT8,
32
- INT16,
33
- INT32,
34
- INT64,
35
- INTERVAL,
36
- JSON,
37
- STRING,
38
- TIME,
39
- TIMESTAMP,
40
- UINT8,
41
- UINT16,
42
- UINT32,
43
- UINT64,
44
- UUID,
45
- DType,
46
- )
47
- from mplang.v1.ops.base import stateless_mod
48
-
49
- _SQL_MOD = stateless_mod("sql")
50
-
51
-
52
- # Static dtype mappings (MPLang <-> SQL)
53
- MP_TO_SQL_TYPE: dict[DType, str] = {
54
- # Floats
55
- FLOAT64: "DOUBLE",
56
- FLOAT32: "FLOAT",
57
- # Signed ints
58
- INT8: "TINYINT",
59
- INT16: "SMALLINT",
60
- INT32: "INT",
61
- INT64: "BIGINT",
62
- # Unsigned ints (portable approximations)
63
- UINT8: "SMALLINT",
64
- UINT16: "INT",
65
- UINT32: "BIGINT",
66
- UINT64: "DECIMAL(38)",
67
- # Booleans & strings
68
- BOOL: "BOOLEAN",
69
- STRING: "VARCHAR",
70
- # Dates / times
71
- DATE: "DATE",
72
- TIME: "TIME",
73
- TIMESTAMP: "TIMESTAMP",
74
- # Other table types
75
- DECIMAL: "DECIMAL",
76
- JSON: "JSON",
77
- BINARY: "BLOB",
78
- UUID: "UUID",
79
- INTERVAL: "INTERVAL",
80
- }
81
-
82
- SQL_TYPE_TO_MP: dict[str, DType] = {
83
- # Floats
84
- "double": FLOAT64,
85
- "double precision": FLOAT64,
86
- "float": FLOAT32,
87
- "real": FLOAT32,
88
- # Signed ints
89
- "bigint": INT64,
90
- "long": INT64,
91
- "int": INT32,
92
- "integer": INT32,
93
- "int4": INT32,
94
- "smallint": INT16,
95
- "int2": INT16,
96
- "tinyint": INT8,
97
- "int1": INT8,
98
- # Unsigned (rare in SQL)
99
- "uint8": UINT8,
100
- "ubyte": UINT8,
101
- "uint16": UINT16,
102
- "uint32": UINT32,
103
- "uint64": UINT64,
104
- # Booleans / strings
105
- "bool": BOOL,
106
- "boolean": BOOL,
107
- "char": STRING,
108
- "varchar": STRING,
109
- "text": STRING,
110
- "string": STRING,
111
- # Dates / times
112
- "date": DATE,
113
- "time": TIME,
114
- "timestamp": TIMESTAMP,
115
- # Decimal / numeric
116
- "decimal": DECIMAL,
117
- "numeric": DECIMAL,
118
- # Others
119
- "json": JSON,
120
- "binary": BINARY,
121
- "varbinary": BINARY,
122
- "blob": BINARY,
123
- "uuid": UUID,
124
- "interval": INTERVAL,
125
- }
126
-
127
-
128
- def _deduce_out_schema(
129
- parsed: sge.Expression,
130
- dialect: str,
131
- in_schemas: dict[str, TableType],
132
- ) -> TableType:
133
- """Deduce output schema using sqlglot's qualify + annotate_types.
134
-
135
- This implementation leverages sqlglot's optimizer to resolve table/column
136
- references (including star expansion) and annotate expression types. It then
137
- maps sqlglot DataType to mplang DType and returns a TableType.
138
- """
139
-
140
- # 1) Build sqlglot schema from MPObject/TableType inputs
141
- def _dtype_to_sql(dt: DType) -> str:
142
- return MP_TO_SQL_TYPE.get(dt, "VARCHAR")
143
-
144
- sqlglot_schema: dict[str, dict[str, str]] = {
145
- tname: {col: _dtype_to_sql(dt) for col, dt in schema.columns}
146
- for tname, schema in in_schemas.items()
147
- }
148
-
149
- # 2) Parse with read dialect; 3) Qualify (resolve names, expand star); 4) Annotate types
150
- qualified = opt_qualify.qualify(parsed, schema=sqlglot_schema, dialect=dialect)
151
- typed = opt_annot.annotate_types(qualified, schema=sqlglot_schema)
152
-
153
- # 5) Extract projection names and types
154
- select = typed if isinstance(typed, sge.Select) else typed.find(sge.Select)
155
- if select is None:
156
- raise NotImplementedError(
157
- "Only SELECT queries are supported for schema deduction"
158
- )
159
-
160
- def _sqlglot_type_to_dtype(tobj: Any) -> DType:
161
- ts = str(tobj).lower().replace(" with time zone", "").strip()
162
- base = ts.split("(", 1)[0].strip()
163
- return SQL_TYPE_TO_MP.get(base, STRING)
164
-
165
- pairs: list[tuple[str, DType]] = []
166
- idx = 0
167
- used: set[str] = set()
168
- for proj in select.expressions:
169
- name = getattr(proj, "alias_or_name", None) or getattr(proj, "name", None)
170
- if not name:
171
- name = f"expr_{idx}"
172
- idx += 1
173
- t = getattr(proj, "type", None)
174
- if t is None:
175
- raise NotImplementedError(
176
- "Cannot infer type for projection; please provide out_type explicitly"
177
- )
178
- dtype = _sqlglot_type_to_dtype(t)
179
- if name in used:
180
- raise ValueError(
181
- f"Duplicate output column name '{name}' after qualification"
182
- )
183
- used.add(name)
184
- pairs.append((name, dtype))
185
-
186
- return TableType.from_pairs(pairs)
187
-
188
-
189
- @_SQL_MOD.op_def()
190
- def run_sql(
191
- query: str,
192
- *,
193
- out_type: TableType | None = None,
194
- dialect: str = "duckdb",
195
- **in_tables: Any,
196
- ) -> tuple[PFunction, list[MPObject], PyTreeDef]:
197
- """Build a sql.run PFunction from a SQL query with optional schema deduction.
198
-
199
- API: run_sql(query: str, *, out_type: TableType | None = None, dialect: str = "duckdb", **in_tables) -> (PFunction, [MPObject], PyTreeDef)
200
-
201
- Semantics:
202
- - Parses the SQL and binds only the tables that are actually referenced in the query by name.
203
- - If ``out_type`` is not provided, attempts to deduce the output table schema using sqlglot (qualify + annotate types).
204
- - Returns a triad consisting of the constructed PFunction (``fn_type='sql.run'``), the ordered list of input MPObjects, and the output PyTreeDef.
205
-
206
- Difference vs ``run_sql_raw``: this op can infer ``out_type`` and will parse the SQL to filter inputs; ``run_sql_raw`` requires an explicit ``out_type`` and does not parse/filter inputs.
207
- """
208
- # Extract required table names from SQL (order by first appearance)
209
- parsed = sg.parse_one(query, read=dialect)
210
- required_names: list[str] = []
211
- for t in parsed.find_all(sge.Table):
212
- # Prefer .name; fallback to str(this) if needed
213
- tname = getattr(t, "name", None) or str(t.this)
214
- if tname not in required_names:
215
- required_names.append(tname)
216
-
217
- # Disallow extras not referenced by the query to avoid surprises
218
- extra = set(in_tables.keys()) - set(required_names)
219
- if extra:
220
- raise ValueError(
221
- f"Unexpected tables provided that are not referenced in SQL: {sorted(extra)}"
222
- )
223
-
224
- # Validate required tables and require MPObject for runtime registration
225
- in_names: list[str] = []
226
- ins_info: list[TableType] = []
227
- in_vars: list[MPObject] = []
228
- for name in required_names:
229
- if name not in in_tables:
230
- raise KeyError(f"Missing required table '{name}' for SQL query")
231
- obj = in_tables[name]
232
- if not isinstance(obj, MPObject):
233
- raise TypeError(
234
- f"Table '{name}' must be an MPObject (for runtime registration), got {type(obj).__name__}"
235
- )
236
- assert obj.schema is not None, f"Input table '{name}' missing schema"
237
- in_vars.append(obj)
238
- ins_info.append(obj.schema)
239
- in_names.append(name)
240
-
241
- if out_type is None:
242
- in_schemas: dict[str, TableType] = {
243
- n: in_tables[n].schema for n in required_names
244
- }
245
- out_type = _deduce_out_schema(parsed, dialect, in_schemas)
246
-
247
- pfn = PFunction(
248
- fn_type="sql.run",
249
- ins_info=tuple(ins_info),
250
- outs_info=(out_type,),
251
- fn_name="",
252
- fn_text=query,
253
- in_names=tuple(in_names),
254
- dialect=dialect,
255
- )
256
- _, treedef = tree_flatten(out_type)
257
- return pfn, in_vars, treedef
258
-
259
-
260
- @_SQL_MOD.op_def()
261
- def run_sql_raw(
262
- query: str,
263
- out_type: TableType,
264
- *,
265
- dialect: str = "duckdb",
266
- in_tables: dict[str, MPObject] | None = None,
267
- ) -> tuple[PFunction, list[MPObject], PyTreeDef]:
268
- """Build a sql.run PFunction from a SQL query with an explicit output schema.
269
-
270
- API: run_sql_raw(query: str, out_type: TableType, *, dialect: str = "duckdb", in_tables: dict[str, MPObject] | None = None) -> (PFunction, [MPObject], PyTreeDef)
271
-
272
- Semantics:
273
- - Does not parse the SQL; carries all tables provided via ``in_tables`` in the mapping's iteration order.
274
- - Requires an explicit ``out_type``; no schema deduction is attempted.
275
- - Returns a triad consisting of the constructed PFunction (``fn_type='sql.run'``), the ordered list of input MPObjects, and the output PyTreeDef.
276
-
277
- Difference vs ``run_sql``: this op requires ``out_type`` and does not parse/filter inputs; ``run_sql`` can infer ``out_type`` and selects only tables referenced by the query.
278
- """
279
-
280
- # Collect inputs strictly as provided by caller
281
- in_names: list[str] = []
282
- ins_info: list[TableType] = []
283
- in_vars: list[MPObject] = []
284
- if in_tables:
285
- for name, tbl in in_tables.items():
286
- if not isinstance(tbl, MPObject):
287
- raise TypeError(f"Input table '{name}' is not an MPObject {type(tbl)}")
288
- assert tbl.schema is not None, f"Input table '{name}' is missing a schema"
289
- in_names.append(name)
290
- ins_info.append(tbl.schema)
291
- in_vars.append(tbl)
292
-
293
- pfn = PFunction(
294
- fn_type="sql.run",
295
- fn_name="",
296
- fn_text=query,
297
- ins_info=tuple(ins_info),
298
- outs_info=(out_type,),
299
- in_names=tuple(in_names),
300
- dialect=dialect,
301
- )
302
- _, treedef = tree_flatten(out_type)
303
- return pfn, in_vars, treedef
mplang/v1/ops/tee.py DELETED
@@ -1,36 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from __future__ import annotations
16
-
17
- from mplang.v1.core import UINT8, TensorType
18
- from mplang.v1.ops.base import stateless_mod
19
-
20
- _TEE_MOD = stateless_mod("tee")
21
-
22
-
23
- @_TEE_MOD.simple_op()
24
- def quote_gen(pk: TensorType) -> TensorType:
25
- """TEE quote generation binding the provided ephemeral public key."""
26
- _ = pk # Mark as used for the decorator
27
- return TensorType(UINT8, (-1,))
28
-
29
-
30
- @_TEE_MOD.simple_op()
31
- def attest(quote: TensorType) -> TensorType:
32
- """TEE quote verification returning the attested TEE public key.
33
- API (mock): attest(quote: u8[33]) -> tee_pk: u8[32]
34
- """
35
- _ = quote # Mark as used for the decorator
36
- return TensorType(UINT8, (32,))