mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. mplang/__init__.py +21 -45
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +5 -7
  7. mplang/v1/core/__init__.py +157 -0
  8. mplang/{core → v1/core}/cluster.py +30 -14
  9. mplang/{core → v1/core}/comm.py +5 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +13 -14
  14. mplang/{core → v1/core}/expr/evaluator.py +65 -24
  15. mplang/{core → v1/core}/expr/printer.py +24 -18
  16. mplang/{core → v1/core}/expr/transformer.py +3 -3
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +23 -16
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +4 -4
  25. mplang/{core → v1/core}/primitive.py +106 -201
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{api.py → v1/host.py} +38 -6
  30. mplang/v1/kernels/__init__.py +41 -0
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/v1/kernels/basic.py +240 -0
  33. mplang/{kernels → v1/kernels}/context.py +42 -27
  34. mplang/{kernels → v1/kernels}/crypto.py +44 -37
  35. mplang/v1/kernels/fhe.py +858 -0
  36. mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
  37. mplang/{kernels → v1/kernels}/phe.py +263 -57
  38. mplang/{kernels → v1/kernels}/spu.py +137 -48
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
  40. mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
  41. mplang/v1/kernels/value.py +626 -0
  42. mplang/{ops → v1/ops}/__init__.py +5 -16
  43. mplang/{ops → v1/ops}/base.py +2 -5
  44. mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/v1/ops/fhe.py +272 -0
  47. mplang/{ops → v1/ops}/jax_cc.py +33 -68
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -4
  50. mplang/{ops → v1/ops}/spu.py +3 -5
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +9 -24
  53. mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
  54. mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
  55. mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
  56. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  57. mplang/v1/runtime/channel.py +230 -0
  58. mplang/{runtime → v1/runtime}/cli.py +35 -20
  59. mplang/{runtime → v1/runtime}/client.py +19 -8
  60. mplang/{runtime → v1/runtime}/communicator.py +59 -15
  61. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  62. mplang/{runtime → v1/runtime}/driver.py +30 -12
  63. mplang/v1/runtime/link_comm.py +196 -0
  64. mplang/{runtime → v1/runtime}/server.py +58 -42
  65. mplang/{runtime → v1/runtime}/session.py +57 -71
  66. mplang/{runtime → v1/runtime}/simulation.py +55 -28
  67. mplang/v1/simp/api.py +353 -0
  68. mplang/{simp → v1/simp}/mpi.py +8 -9
  69. mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
  70. mplang/{simp → v1/simp}/random.py +21 -22
  71. mplang/v1/simp/smpc.py +238 -0
  72. mplang/v1/utils/table_utils.py +185 -0
  73. mplang/v2/__init__.py +424 -0
  74. mplang/v2/backends/__init__.py +57 -0
  75. mplang/v2/backends/bfv_impl.py +705 -0
  76. mplang/v2/backends/channel.py +217 -0
  77. mplang/v2/backends/crypto_impl.py +723 -0
  78. mplang/v2/backends/field_impl.py +454 -0
  79. mplang/v2/backends/func_impl.py +107 -0
  80. mplang/v2/backends/phe_impl.py +148 -0
  81. mplang/v2/backends/simp_design.md +136 -0
  82. mplang/v2/backends/simp_driver/__init__.py +41 -0
  83. mplang/v2/backends/simp_driver/http.py +168 -0
  84. mplang/v2/backends/simp_driver/mem.py +280 -0
  85. mplang/v2/backends/simp_driver/ops.py +135 -0
  86. mplang/v2/backends/simp_driver/state.py +60 -0
  87. mplang/v2/backends/simp_driver/values.py +52 -0
  88. mplang/v2/backends/simp_worker/__init__.py +29 -0
  89. mplang/v2/backends/simp_worker/http.py +354 -0
  90. mplang/v2/backends/simp_worker/mem.py +102 -0
  91. mplang/v2/backends/simp_worker/ops.py +167 -0
  92. mplang/v2/backends/simp_worker/state.py +49 -0
  93. mplang/v2/backends/spu_impl.py +275 -0
  94. mplang/v2/backends/spu_state.py +187 -0
  95. mplang/v2/backends/store_impl.py +62 -0
  96. mplang/v2/backends/table_impl.py +838 -0
  97. mplang/v2/backends/tee_impl.py +215 -0
  98. mplang/v2/backends/tensor_impl.py +519 -0
  99. mplang/v2/cli.py +603 -0
  100. mplang/v2/cli_guide.md +122 -0
  101. mplang/v2/dialects/__init__.py +36 -0
  102. mplang/v2/dialects/bfv.py +665 -0
  103. mplang/v2/dialects/crypto.py +689 -0
  104. mplang/v2/dialects/dtypes.py +378 -0
  105. mplang/v2/dialects/field.py +210 -0
  106. mplang/v2/dialects/func.py +135 -0
  107. mplang/v2/dialects/phe.py +723 -0
  108. mplang/v2/dialects/simp.py +944 -0
  109. mplang/v2/dialects/spu.py +349 -0
  110. mplang/v2/dialects/store.py +63 -0
  111. mplang/v2/dialects/table.py +407 -0
  112. mplang/v2/dialects/tee.py +346 -0
  113. mplang/v2/dialects/tensor.py +1175 -0
  114. mplang/v2/edsl/README.md +279 -0
  115. mplang/v2/edsl/__init__.py +99 -0
  116. mplang/v2/edsl/context.py +311 -0
  117. mplang/v2/edsl/graph.py +463 -0
  118. mplang/v2/edsl/jit.py +62 -0
  119. mplang/v2/edsl/object.py +53 -0
  120. mplang/v2/edsl/primitive.py +284 -0
  121. mplang/v2/edsl/printer.py +119 -0
  122. mplang/v2/edsl/registry.py +207 -0
  123. mplang/v2/edsl/serde.py +375 -0
  124. mplang/v2/edsl/tracer.py +614 -0
  125. mplang/v2/edsl/typing.py +816 -0
  126. mplang/v2/kernels/Makefile +30 -0
  127. mplang/v2/kernels/__init__.py +23 -0
  128. mplang/v2/kernels/gf128.cpp +148 -0
  129. mplang/v2/kernels/ldpc.cpp +82 -0
  130. mplang/v2/kernels/okvs.cpp +283 -0
  131. mplang/v2/kernels/okvs_opt.cpp +291 -0
  132. mplang/v2/kernels/py_kernels.py +398 -0
  133. mplang/v2/libs/collective.py +330 -0
  134. mplang/v2/libs/device/__init__.py +51 -0
  135. mplang/v2/libs/device/api.py +813 -0
  136. mplang/v2/libs/device/cluster.py +352 -0
  137. mplang/v2/libs/ml/__init__.py +23 -0
  138. mplang/v2/libs/ml/sgb.py +1861 -0
  139. mplang/v2/libs/mpc/__init__.py +41 -0
  140. mplang/v2/libs/mpc/_utils.py +99 -0
  141. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  142. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  143. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  144. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  145. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  146. mplang/v2/libs/mpc/common/constants.py +39 -0
  147. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  148. mplang/v2/libs/mpc/ot/base.py +222 -0
  149. mplang/v2/libs/mpc/ot/extension.py +477 -0
  150. mplang/v2/libs/mpc/ot/silent.py +217 -0
  151. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  152. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  153. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  154. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  155. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  156. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  157. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  158. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  159. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  160. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  161. mplang/v2/libs/mpc/vole/silver.py +336 -0
  162. mplang/v2/runtime/__init__.py +15 -0
  163. mplang/v2/runtime/dialect_state.py +41 -0
  164. mplang/v2/runtime/interpreter.py +871 -0
  165. mplang/v2/runtime/object_store.py +194 -0
  166. mplang/v2/runtime/value.py +141 -0
  167. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
  168. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  169. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  170. mplang/core/__init__.py +0 -92
  171. mplang/device.py +0 -340
  172. mplang/kernels/builtin.py +0 -207
  173. mplang/ops/crypto.py +0 -109
  174. mplang/ops/ibis_cc.py +0 -139
  175. mplang/ops/sql.py +0 -61
  176. mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
  177. mplang/runtime/link_comm.py +0 -131
  178. mplang/simp/smpc.py +0 -201
  179. mplang/utils/table_utils.py +0 -73
  180. mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
  181. /mplang/{core → v1/core}/mask.py +0 -0
  182. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  183. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  184. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  185. /mplang/{kernels → v1/simp}/__init__.py +0 -0
  186. /mplang/{utils → v1/utils}/__init__.py +0 -0
  187. /mplang/{utils → v1/utils}/crypto.py +0 -0
  188. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  189. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  190. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  191. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
mplang/ops/crypto.py DELETED
@@ -1,109 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """
16
- Crypto frontend operations: operation signatures, types, and high-level semantics.
17
-
18
- Scope and contracts:
19
- - This module defines portable API shapes; it does not implement cryptography.
20
- - Backends execute the operations and must meet the security semantics required
21
- by the deployment (confidentiality, authenticity, correctness, etc.).
22
- - The enc/dec API in this frontend uses a conventional 12-byte nonce prefix
23
- (ciphertext = nonce || payload), and dec expects that format. Other security
24
- properties (e.g., AEAD) are backend responsibilities.
25
- """
26
-
27
- from __future__ import annotations
28
-
29
- from mplang.core.dtype import UINT8
30
- from mplang.core.tensor import TensorType
31
- from mplang.ops.base import stateless_mod
32
-
33
# Frontend module namespace under which every crypto op in this file is registered.
_CRYPTO_MOD = stateless_mod("crypto")


@_CRYPTO_MOD.simple_op()
def keygen(*, length: int = 32) -> TensorType:
    """Type rule for generating random bytes (symmetric keys or generic randomness).

    API: keygen(length: int = 32) -> key: u8[length]

    Only the result type is computed here; the backend supplies the actual
    random bytes.

    Args:
        length: Number of bytes to produce; must be positive.

    Returns:
        A 1-D UINT8 tensor type of shape ``(length,)``.

    Raises:
        ValueError: If ``length`` is not positive.
    """
    if length <= 0:
        raise ValueError("length must be > 0")
    key_shape = (length,)
    return TensorType(UINT8, key_shape)
49
-
50
-
51
@_CRYPTO_MOD.simple_op()
def enc(plaintext: TensorType, key: TensorType) -> TensorType:
    """Type rule for symmetric encryption.

    API: enc(plaintext: u8[N], key: u8[M]) -> ciphertext: u8[N + 12]

    The ciphertext carries a 12-byte nonce prefix (nonce || payload), so a
    static plaintext length N yields a ciphertext of length N + 12. A negative
    (dynamic) plaintext length is propagated as -1.

    Args:
        plaintext: Plaintext type; must be a 1-D UINT8 tensor.
        key: Key type; not inspected by this type rule.

    Returns:
        The 1-D UINT8 ciphertext tensor type.

    Raises:
        TypeError: If ``plaintext`` is not UINT8 or not 1-D.
    """
    if plaintext.dtype != UINT8:
        raise TypeError("enc expects UINT8 plaintext")
    if len(plaintext.shape) != 1:
        raise TypeError("enc expects 1-D plaintext")
    n_bytes = plaintext.shape[0]
    ct_len = n_bytes + 12 if n_bytes >= 0 else -1
    return TensorType(UINT8, (ct_len,))
66
-
67
-
68
@_CRYPTO_MOD.simple_op()
def dec(ciphertext: TensorType, key: TensorType) -> TensorType:
    """Type rule for symmetric decryption.

    API: dec(ciphertext: u8[N + 12], key: u8[M]) -> plaintext: u8[N]

    The ciphertext is expected to start with a 12-byte nonce, so a static
    ciphertext length L yields a plaintext of length L - 12. A negative
    (dynamic) ciphertext length is propagated as -1.

    Args:
        ciphertext: Ciphertext type; must be a 1-D UINT8 tensor.
        key: Key type; not inspected by this type rule.

    Returns:
        The 1-D UINT8 plaintext tensor type.

    Raises:
        TypeError: If ``ciphertext`` is not UINT8, not 1-D, or statically
            shorter than the 12-byte nonce.
    """
    if ciphertext.dtype != UINT8:
        raise TypeError("dec expects UINT8 ciphertext")
    if len(ciphertext.shape) != 1:
        raise TypeError("dec expects 1-D ciphertext with nonce")
    ct_len = ciphertext.shape[0]
    if ct_len < 0:
        # Dynamic length: the plaintext length is unknown as well.
        return TensorType(UINT8, (-1,))
    if ct_len < 12:
        # Too short to even hold the nonce prefix.
        raise TypeError("dec expects 1-D ciphertext with nonce")
    return TensorType(UINT8, (ct_len - 12,))
85
-
86
-
87
@_CRYPTO_MOD.simple_op()
def kem_keygen(*, suite: str = "x25519") -> tuple[TensorType, TensorType]:
    """Type rule for KEM-style keypair generation.

    Returns:
        The ``(sk, pk)`` types, both 32-byte UINT8 vectors. ``suite`` is
        forwarded to the backend and does not affect the types.
    """
    key_len = 32
    secret_key_type = TensorType(UINT8, (key_len,))
    public_key_type = TensorType(UINT8, (key_len,))
    return secret_key_type, public_key_type
93
-
94
-
95
@_CRYPTO_MOD.simple_op()
def kem_derive(
    sk: TensorType, peer_pk: TensorType, *, suite: str = "x25519"
) -> TensorType:
    """Type rule for KEM-style shared-secret derivation.

    The derived secret is always typed as a 32-byte UINT8 vector. ``sk``,
    ``peer_pk`` and ``suite`` are part of the op signature for the backend
    but are not inspected here.
    """
    secret_len = 32
    return TensorType(UINT8, (secret_len,))
103
-
104
-
105
@_CRYPTO_MOD.simple_op()
def hkdf(secret: TensorType, *, info: str) -> TensorType:
    """Type rule for HKDF-style key derivation: the derived key is 32 UINT8 bytes.

    ``secret`` and the ``info`` context string are consumed by the backend;
    the output type does not depend on them.
    """
    derived_key_len = 32
    return TensorType(UINT8, (derived_key_len,))
mplang/ops/ibis_cc.py DELETED
@@ -1,139 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
-
16
- import inspect
17
- from collections.abc import Callable
18
- from typing import Any
19
-
20
- import ibis
21
- from jax.tree_util import PyTreeDef, tree_flatten
22
-
23
- from mplang.core import dtype
24
- from mplang.core.mpobject import MPObject
25
- from mplang.core.pfunc import PFunction
26
- from mplang.core.table import TableType
27
- from mplang.ops.base import FeOperation, stateless_mod
28
- from mplang.utils.func_utils import normalize_fn
29
-
30
-
31
def ibis2sql(
    expr: ibis.Table,
    in_schemas: list[ibis.Schema],
    in_names: list[str],
    fn_name: str = "",
) -> PFunction:
    """Compile an ibis table expression to DuckDB SQL wrapped in a PFunction.

    Args:
        expr: The ibis table expression to compile.
        in_schemas: Schemas of the input tables, positionally aligned with
            ``in_names``.
        in_names: Names under which the generated SQL refers to the input
            tables.
        fn_name: Name recorded on the resulting PFunction.

    Returns:
        A PFunction of type ``sql.run`` (dialect ``duckdb``) whose ``fn_text``
        is the generated SQL.

    Raises:
        ValueError: If ``in_schemas`` and ``in_names`` differ in length.
    """
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, which would let mismatched names/schemas through silently.
    if len(in_schemas) != len(in_names):
        raise ValueError(
            "length of input table names and schemas mismatch. "
            f"{len(in_schemas)}!={len(in_names)}"
        )

    def _convert(s: ibis.Schema) -> TableType:
        # Map each ibis column dtype to an mplang dtype via its numpy equivalent.
        return TableType.from_pairs([
            (name, dtype.from_numpy(dt.to_numpy())) for name, dt in s.fields.items()
        ])

    ins_info = [_convert(s) for s in in_schemas]
    outs_info = [_convert(expr.schema())]

    sql = ibis.to_sql(expr, dialect="duckdb")
    # Emit generic sql.run op; runtime maps to backend-specific kernel.
    return PFunction(
        fn_type="sql.run",
        fn_name=fn_name,
        fn_text=sql,
        ins_info=tuple(ins_info),
        outs_info=tuple(outs_info),
        in_names=tuple(in_names),
        dialect="duckdb",
    )
71
-
72
-
73
def is_ibis_function(func: Callable) -> bool:
    """Heuristically decide whether ``func`` is written against ibis tables.

    ``func`` qualifies when its return annotation, or the annotation of any
    of its parameters, is exactly ``ibis.Table``, e.g.
    ``def foo(t0: ibis.Table, t1: ibis.Table) -> ibis.Table``.
    Objects whose signature cannot be inspected are reported as non-ibis.
    """
    try:
        signature = inspect.signature(func)
    except (ValueError, TypeError):
        return False

    annotations = [signature.return_annotation]
    annotations.extend(param.annotation for param in signature.parameters.values())
    return any(anno is ibis.Table for anno in annotations)
93
-
94
-
95
# Frontend module namespace for ibis-based operations.
_IBIS_MOD = stateless_mod("ibis")


class IbisCompiler(FeOperation):
    """Frontend operation that compiles an ibis function into a ``sql.run`` PFunction."""

    def trace(
        self, func: Callable, *args: Any, **kwargs: Any
    ) -> tuple[PFunction, list[MPObject], PyTreeDef]:
        """Compile an ibis function to SQL.

        ``normalize_fn`` separates the MPObject arguments out of
        ``args``/``kwargs``. The normalized function is then called with one
        placeholder ibis table per MPObject input, named ``table0``,
        ``table1``, ... and built from that object's ``schema``.

        Args:
            func: The ibis function to compile; must return an ``ibis.Table``.
            *args: Positional arguments to the function.
            **kwargs: Keyword arguments to the function.

        Returns:
            The compiled PFunction, the MPObject inputs (in ``table{i}``
            order), and the tree definition of the result.

        Raises:
            TypeError: If ``func`` does not return an ``ibis.Table``.
        """

        def is_variable(arg: Any) -> bool:
            return isinstance(arg, MPObject)

        normalized_fn, in_vars = normalize_fn(func, args, kwargs, is_variable)

        in_args, in_schemas, in_names = [], [], []
        for idx, var in enumerate(in_vars):
            # Each schema column is a (name, dtype) pair; convert the dtype to
            # numpy so ibis can build an equivalent schema.
            columns = [(col[0], col[1].to_numpy()) for col in var.schema.columns]
            schema = ibis.schema(columns)
            name = f"table{idx}"
            in_args.append(ibis.table(schema=schema, name=name))
            in_schemas.append(schema)
            in_names.append(name)

        result = normalized_fn(in_args)
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not isinstance(result, ibis.Table):
            raise TypeError(
                f"ibis function must return an ibis.Table, got {type(result).__name__}"
            )
        # Callables such as functools.partial have no __name__.
        fn_name = getattr(func, "__name__", "")
        pfunc = ibis2sql(result, in_schemas, in_names, fn_name)
        _, treedef = tree_flatten(result)
        return pfunc, in_vars, treedef


ibis_compile = IbisCompiler(_IBIS_MOD, "compile")
mplang/ops/sql.py DELETED
@@ -1,61 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from jax.tree_util import PyTreeDef, tree_flatten
16
-
17
- from mplang.core.mpobject import MPObject
18
- from mplang.core.pfunc import PFunction
19
- from mplang.core.table import TableType
20
- from mplang.ops.base import FeOperation, stateless_mod
21
-
22
# Frontend module namespace for SQL operations.
_SQL_MOD = stateless_mod("sql")


class SqlFE(FeOperation):
    """Frontend operation that wraps raw SQL text into a ``sql.run`` PFunction."""

    def __init__(self, dialect: str = "duckdb"):
        """Create the operation.

        Args:
            dialect: SQL dialect recorded on every PFunction this op emits.
        """
        # Bind to sql module with a stable op name for registry/dispatch
        super().__init__(_SQL_MOD, "run")
        self._dialect = dialect

    def trace(
        self,
        sql: str,
        out_type: TableType,
        in_tables: dict[str, MPObject] | None = None,
    ) -> tuple[PFunction, list[MPObject], PyTreeDef]:
        """Build a ``sql.run`` PFunction for ``sql``.

        Args:
            sql: SQL text that refers to its inputs by the keys of ``in_tables``.
            out_type: Schema of the single output table.
            in_tables: Mapping from table name to MPObject input; each object
                must carry a table schema. ``None`` or empty means no inputs.

        Returns:
            The PFunction, the input MPObjects (in ``in_tables`` iteration
            order), and the tree definition of ``out_type``.

        Raises:
            TypeError: If a value in ``in_tables`` is not an MPObject.
            ValueError: If an input object has no table schema.
        """
        in_names: list[str] = []
        ins_info: list[TableType] = []
        in_vars: list[MPObject] = []
        for name, tbl in (in_tables or {}).items():
            # Explicit checks instead of ``assert``, which ``python -O`` strips.
            if not isinstance(tbl, MPObject):
                raise TypeError(
                    f"in_tables[{name!r}] must be an MPObject, got {type(tbl).__name__}"
                )
            if tbl.schema is None:
                raise ValueError(f"in_tables[{name!r}] has no table schema")
            in_names.append(name)
            ins_info.append(tbl.schema)
            in_vars.append(tbl)

        pfn = PFunction(
            fn_type="sql.run",
            fn_name="",
            fn_text=sql,
            ins_info=tuple(ins_info),
            outs_info=(out_type,),
            in_names=tuple(in_names),
            dialect=self._dialect,
        )
        _, treedef = tree_flatten(out_type)
        return pfn, in_vars, treedef


sql_run = SqlFE("duckdb")
mplang/protos/v1alpha1/mpir_pb2_grpc.py DELETED
@@ -1,3 +0,0 @@
1
- # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
2
- """Client and server classes corresponding to protobuf-defined services."""
3
- import grpc
mplang/runtime/link_comm.py DELETED
@@ -1,131 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from __future__ import annotations
16
-
17
- import logging
18
- from typing import Any
19
-
20
- import cloudpickle as pickle
21
- import spu.libspu as libspu
22
-
23
- from mplang.core.comm import ICollective, ICommunicator
24
-
25
-
26
class LinkCommunicator(ICommunicator, ICollective):
    """ICommunicator/ICollective implementation on top of a libspu link context.

    Payloads are serialized with cloudpickle and hex-encoded before being
    handed to the link; the receiving side reverses both steps.

    Security: ``pickle.loads`` can execute arbitrary code embedded in a
    payload, so every peer listed in ``addrs`` must be fully trusted.
    """

    def __init__(self, rank: int, addrs: list[str], *, mem_link: bool = False):
        """Build the link descriptor and open the link context.

        Args:
            rank: Index of this party in ``addrs``.
            addrs: Address of every party; party ``i`` is registered as ``P{i}``.
            mem_link: Create an in-memory link (``create_mem``) instead of a
                brpc link (``create_brpc``).
        """
        self._rank = rank
        self._world_size = len(addrs)

        desc = libspu.link.Desc()  # type: ignore
        desc.recv_timeout_ms = 100 * 1000  # 100 seconds
        desc.http_max_payload_size = 32 * 1024 * 1024  # Default set link payload to 32M
        # Use ``party_idx`` rather than ``rank`` so the loop does not shadow the
        # constructor argument.
        for party_idx, addr in enumerate(addrs):
            desc.add_party(f"P{party_idx}", addr)

        if mem_link:
            self.lctx = libspu.link.create_mem(desc, self._rank)
        else:
            self.lctx = libspu.link.create_brpc(desc, self._rank)

        # Lazy %-style args: the message is only formatted if INFO is enabled.
        logging.info(
            "LinkCommunicator initialized with rank=%s, world_size=%s, addrs=%s",
            self._rank,
            self._world_size,
            addrs,
        )

        # Source of message ids handed out by new_id().
        self._counter = 0

    @property
    def rank(self) -> int:
        """This party's rank as reported by the link context."""
        return self.lctx.rank  # type: ignore[no-any-return]

    @property
    def world_size(self) -> int:
        """Number of parties as reported by the link context."""
        return self.lctx.world_size  # type: ignore[no-any-return]

    def get_lctx(self) -> libspu.link.Context:
        """Get the link context"""
        return self.lctx

    # override
    def new_id(self) -> str:
        """Return the next message id ("0", "1", ...) and advance the counter."""
        res = self._counter
        self._counter += 1
        return str(res)

    def wrap(self, obj: Any) -> str:
        """Serialize ``obj`` with cloudpickle and return it hex-encoded."""
        data = pickle.dumps(obj)
        return data.hex()  # type: ignore[no-any-return]

    def unwrap(self, obj: str) -> Any:
        """Inverse of ``wrap``; only use on payloads from trusted peers."""
        data = bytes.fromhex(obj)
        return pickle.loads(data)  # type: ignore[no-any-return]

    def send(self, to: int, key: str, data: Any) -> None:
        """Send ``data`` tagged with ``key`` to party ``to``."""
        serialized = pickle.dumps((key, data))
        self.lctx.send(to, serialized.hex())

    def recv(self, frm: int, key: str) -> Any:
        """Receive the next message from ``frm`` and check it is tagged ``key``."""
        serialized = self.lctx.recv(frm)
        # assumes lctx.recv returns bytes holding the hex text written by send()
        rkey, data = pickle.loads(bytes.fromhex(serialized.decode()))
        assert key == rkey, f"recv key {key} != {rkey}"
        return data  # type: ignore[no-any-return]

    def p2p(self, frm: int, to: int, data: Any) -> Any:
        """Send ``data`` from ``frm`` to ``to``; returns it on ``to``, None elsewhere."""
        assert 0 <= frm < self.world_size
        assert 0 <= to < self.world_size

        # TODO: link handles cid internally?
        # Every party draws an id, including non-participants, presumably so the
        # counters stay aligned across parties and recv()'s key check matches.
        cid = self.new_id()

        if self.rank == frm:
            self.send(to, cid, data)
            return None
        elif self.rank == to:
            return self.recv(frm, cid)
        else:
            return None

    def gather(self, root: int, data: Any) -> list[Any]:
        """Gather ``data`` from every party to ``root`` via the link context."""
        assert 0 <= root < self.world_size
        rets = self.lctx.gather(self.wrap(data), root)
        return [self.unwrap(ret) for ret in rets]

    def scatter(self, root: int, args: list[Any]) -> Any:
        """Scatter ``args`` (one item per party) from ``root``."""
        assert 0 <= root < self.world_size
        assert len(args) == self.world_size, f"{len(args)} != {self.world_size}"
        ret = self.lctx.scatter([self.wrap(arg) for arg in args], root)
        return self.unwrap(ret)

    def allgather(self, arg: Any) -> list[Any]:
        """Collect ``arg`` from every party on every party."""
        rets = self.lctx.all_gather(self.wrap(arg))
        return [self.unwrap(ret) for ret in rets]

    def bcast(self, root: int, arg: Any) -> Any:
        """Broadcast ``arg`` from ``root`` to every party."""
        assert 0 <= root < self.world_size
        ret = self.lctx.broadcast(self.wrap(arg), root)
        return self.unwrap(ret)

    # Masked collectives are not available on the link backend.
    def gather_m(self, pmask: int, root: int, data: Any) -> list[Any]:
        raise ValueError("Not supported by LinkCommunicator")

    def scatter_m(self, pmask: int, root: int, args: list[Any]) -> Any:
        raise ValueError("Not supported by LinkCommunicator")

    def allgather_m(self, pmask: int, arg: Any) -> list[Any]:
        raise ValueError("Not supported by LinkCommunicator")

    def bcast_m(self, pmask: int, root: int, arg: Any) -> Any:
        raise ValueError("Not supported by LinkCommunicator")
mplang/simp/smpc.py DELETED
@@ -1,201 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
-
16
- from abc import ABC, abstractmethod
17
- from collections.abc import Callable
18
- from enum import Enum
19
- from functools import wraps
20
- from typing import Any
21
-
22
- from jax.tree_util import tree_unflatten
23
-
24
- from mplang.core import Mask, MPObject, Rank, peval, psize
25
- from mplang.core.context_mgr import cur_ctx
26
- from mplang.ops import spu
27
- from mplang.simp import mpi
28
-
29
-
30
class SecureAPI(ABC):
    """Interface for secure-computation backends.

    An implementation seals plaintext MPObjects into secret form, evaluates
    functions on sealed values, and reveals sealed values to chosen parties.
    """

    @abstractmethod
    def seal(self, obj: MPObject, frm_mask: Mask | None) -> list[MPObject]:
        """Seal ``obj`` from the parties in ``frm_mask``; one result per party."""

    @abstractmethod
    def sealFrom(self, obj: MPObject, root: Rank) -> MPObject:
        """Seal ``obj`` from the single party ``root``."""

    @abstractmethod
    def seval(self, fe_type: str, pyfn: Callable, *args: Any, **kwargs: Any) -> Any:
        """Run a function in the secure environment."""

    @abstractmethod
    def reveal(self, obj: MPObject, to_mask: Mask) -> MPObject:
        """Reveal the sealed ``obj`` to the parties in ``to_mask``."""

    @abstractmethod
    def revealTo(self, obj: MPObject, to_rank: Rank) -> MPObject:
        """Reveal the sealed ``obj`` to the single party ``to_rank``."""
48
-
49
-
50
class Delegation(SecureAPI):
    """SecureAPI that delegates secure operations to a trusted third party.

    Placeholder: every operation currently raises ``NotImplementedError``.
    """

    def seal(self, obj: MPObject, frm_mask: Mask | None = None) -> list[MPObject]:
        """Not implemented yet."""
        raise NotImplementedError("TODO")

    def sealFrom(self, obj: MPObject, root: Rank) -> MPObject:
        """Not implemented yet."""
        raise NotImplementedError("TODO")

    def seval(self, fe_type: str, pyfn: Callable, *args: Any, **kwargs: Any) -> Any:
        """Not implemented yet."""
        raise NotImplementedError("TODO")

    def reveal(self, obj: MPObject, to_mask: Mask) -> MPObject:
        """Not implemented yet."""
        raise NotImplementedError("TODO")

    def revealTo(self, obj: MPObject, to_rank: Rank) -> MPObject:
        """Not implemented yet."""
        raise NotImplementedError("TODO")
67
-
68
-
69
class SPU(SecureAPI):
    """SecureAPI implementation that secret-shares values into the SPU device."""

    def get_spu_mask(self) -> Mask:
        """Return the mask of parties forming the cluster's SPU device.

        Raises:
            ValueError: If the cluster specification declares no SPU device or
                more than one.
        """
        spu_devices = cur_ctx().cluster_spec.get_devices_by_kind("SPU")
        if not spu_devices:
            raise ValueError("No SPU device found in the cluster specification")
        if len(spu_devices) > 1:
            raise ValueError("Multiple SPU devices found in the cluster specification")
        spu_device = spu_devices[0]
        spu_mask = Mask.from_ranks([member.rank for member in spu_device.members])
        return spu_mask

    def seal(self, obj: MPObject, frm_mask: Mask | None = None) -> list[MPObject]:
        """Secret-share ``obj`` into the SPU.

        The parties in ``frm_mask`` each split their value into one share per
        SPU party; each source party's shares are then scattered to the SPU.

        Args:
            obj: The plaintext object to seal.
            frm_mask: Source parties. Defaults to ``obj.pmask``; required when
                ``obj.pmask`` is dynamic (None).

        Returns:
            One sealed object per party in ``frm_mask``, in mask order.

        Raises:
            ValueError: If neither mask is known statically, or ``frm_mask``
                is not a subset of ``obj.pmask``.
        """
        spu_mask: Mask = self.get_spu_mask()
        if obj.pmask is None:
            if frm_mask is None:
                # NOTE: The length of the return list is statically determined by obj_mask,
                # so only static masks are supported here.
                raise ValueError("Seal does not support dynamic masks.")
            else:
                # Force seal from the given mask, the runtime will raise error if the mask
                # does not match obj.pmask.
                # TODO(jint): add set_pmask primitive.
                pass
        else:
            if frm_mask is None:
                frm_mask = obj.pmask
            else:
                if not Mask(frm_mask).is_subset(obj.pmask):
                    raise ValueError(f"Cannot seal from {frm_mask} to {obj.pmask}, ")

        # Get the world_size from spu_mask (number of parties in SPU computation)
        world_size = Mask(spu_mask).num_parties()
        pfunc, ins, _ = spu.makeshares(
            obj, world_size=world_size, visibility=spu.Visibility.SECRET
        )
        assert len(ins) == 1
        shares = peval(pfunc, ins, frm_mask)

        # scatter the shares to each party.
        return [mpi.scatter_m(spu_mask, rank, shares) for rank in Mask(frm_mask)]

    def sealFrom(self, obj: MPObject, root: Rank) -> MPObject:
        """Seal ``obj`` from the single party ``root``."""
        # Call this instance's seal(), not the module-level seal(): the latter
        # dispatches through the global ``mode`` and could pick another backend.
        results = self.seal(obj, frm_mask=Mask.from_ranks(root))
        assert len(results) == 1, f"Expected one result, got {len(results)}"
        return results[0]

    def seval(self, fe_type: str, pyfn: Callable, *args: Any, **kwargs: Any) -> Any:
        """Compile ``pyfn`` for the SPU and evaluate it on the SPU parties.

        Raises:
            ValueError: If ``fe_type`` is not ``"jax"``.
        """
        if fe_type != "jax":
            raise ValueError(f"Unsupported fe_type: {fe_type}")

        spu_mask = self.get_spu_mask()
        pfunc, in_vars, out_tree = spu.jax_compile(pyfn, *args, **kwargs)
        assert all(var.pmask == spu_mask for var in in_vars), in_vars
        out_flat = peval(pfunc, in_vars, spu_mask)
        return tree_unflatten(out_tree, out_flat)

    def reveal(self, obj: MPObject, to_mask: Mask) -> MPObject:
        """Reveal the SPU-held ``obj`` to the parties in ``to_mask``.

        Each SPU party broadcasts its share to ``to_mask``; the receivers then
        reconstruct the plaintext from the collected shares.
        """
        spu_mask = self.get_spu_mask()

        assert obj.pmask == spu_mask, (obj.pmask, spu_mask)

        # (n_parties, n_shares)
        shares = [mpi.bcast_m(to_mask, rank, obj) for rank in Mask(spu_mask)]
        assert len(shares) == Mask(spu_mask).num_parties(), (shares, spu_mask)
        assert all(share.pmask == to_mask for share in shares)

        # Reconstruct the original object from shares
        pfunc, ins, _ = spu.reconstruct(*shares)
        return peval(pfunc, ins, to_mask)[0]  # type: ignore[no-any-return]

    def revealTo(self, obj: MPObject, to_rank: Rank) -> MPObject:
        """Reveal the SPU-held ``obj`` to the single party ``to_rank``."""
        return self.reveal(obj, to_mask=Mask.from_ranks(to_rank))
143
-
144
-
145
class SEE(Enum):
    """Secure Execution Environment choices for the global ``mode``."""

    MOCK = 0  # served by the Delegation backend
    SPU = 1  # served by the SPU backend
    TEE = 2  # not implemented yet


# Currently selected secure execution environment, read by _get_sapi().
# TODO(jint): move me to options.py
mode: SEE = SEE.SPU
155
-
156
-
157
- def _get_sapi() -> SecureAPI:
158
- """Get the current secure API based on the mode."""
159
- if mode == SEE.MOCK:
160
- return Delegation()
161
- elif mode == SEE.SPU:
162
- return SPU()
163
- elif mode == SEE.TEE:
164
- raise NotImplementedError("TEE is not implemented yet")
165
- else:
166
- raise ValueError(f"Unknown mode: {mode}")
167
-
168
-
169
# seal :: m a -> [s a]
def seal(obj: MPObject, frm_mask: Mask | None = None) -> list[MPObject]:
    """Seal a simp object into secret form.

    Returns a list of sealed objects whose i'th element is the secret from
    the i'th party. The work is done by the backend selected via ``mode``.
    """
    backend = _get_sapi()
    return backend.seal(obj, frm_mask=frm_mask)
175
-
176
-
177
# sealFrom :: m a -> m Rank -> s a
def sealFrom(obj: MPObject, root: Rank) -> MPObject:
    """Seal a simp object held by the single party ``root``.

    Delegates to the backend selected via ``mode``.
    """
    backend = _get_sapi()
    return backend.sealFrom(obj, root)
181
-
182
-
183
# reveal :: s a -> m a
def reveal(obj: MPObject, to_mask: Mask | None = None) -> MPObject:
    """Reveal a sealed object to the parties in ``to_mask``.

    Args:
        obj: The sealed object.
        to_mask: Parties that receive the plaintext; ``None`` (the default)
            means all ``psize()`` parties.

    Returns:
        The revealed object, as produced by the backend selected via ``mode``.
    """
    # Compare with None explicitly: ``to_mask or ...`` would also replace any
    # falsy mask (presumably an empty mask) with "all parties", silently
    # widening who sees the plaintext.
    if to_mask is None:
        to_mask = Mask.all(psize())
    return _get_sapi().reveal(obj, to_mask)
188
-
189
-
190
# revealTo :: s a -> m Rank -> m a
def revealTo(obj: MPObject, to_rank: Rank) -> MPObject:
    """Reveal a sealed object to the single party ``to_rank``.

    Delegates to the backend selected via ``mode``.
    """
    backend = _get_sapi()
    return backend.revealTo(obj, to_rank)
193
-
194
-
195
# srun :: (a -> a) -> s a -> s a
def srun(pyfn: Callable, *, fe_type: str = "jax") -> Callable:
    """Wrap ``pyfn`` so each call is evaluated in the secure environment.

    Every call of the returned wrapper forwards ``fe_type``, ``pyfn`` and the
    call arguments to ``seval`` of the backend selected via ``mode``. The
    wrapper carries ``pyfn``'s metadata via ``functools.wraps``.
    """

    @wraps(pyfn)
    def _secure_call(*args: Any, **kwargs: Any) -> Any:
        backend = _get_sapi()
        return backend.seval(fe_type, pyfn, *args, **kwargs)

    return _secure_call