mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. mplang/__init__.py +21 -130
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +4 -4
  7. mplang/{core → v1/core}/__init__.py +20 -14
  8. mplang/{core → v1/core}/cluster.py +6 -1
  9. mplang/{core → v1/core}/comm.py +1 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core → v1/core}/dtypes.py +38 -0
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +11 -13
  14. mplang/{core → v1/core}/expr/evaluator.py +8 -8
  15. mplang/{core → v1/core}/expr/printer.py +6 -6
  16. mplang/{core → v1/core}/expr/transformer.py +2 -2
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +13 -11
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +2 -2
  25. mplang/{core → v1/core}/primitive.py +12 -12
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{host.py → v1/host.py} +5 -5
  30. mplang/{kernels → v1/kernels}/__init__.py +1 -1
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/{kernels → v1/kernels}/basic.py +15 -15
  33. mplang/{kernels → v1/kernels}/context.py +19 -16
  34. mplang/{kernels → v1/kernels}/crypto.py +8 -10
  35. mplang/{kernels → v1/kernels}/fhe.py +9 -7
  36. mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
  37. mplang/{kernels → v1/kernels}/phe.py +26 -18
  38. mplang/{kernels → v1/kernels}/spu.py +5 -5
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
  40. mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
  41. mplang/{kernels → v1/kernels}/value.py +2 -2
  42. mplang/{ops → v1/ops}/__init__.py +3 -3
  43. mplang/{ops → v1/ops}/base.py +1 -1
  44. mplang/{ops → v1/ops}/basic.py +6 -5
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/{ops → v1/ops}/fhe.py +2 -2
  47. mplang/{ops → v1/ops}/jax_cc.py +26 -59
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -3
  50. mplang/{ops → v1/ops}/spu.py +3 -3
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +2 -2
  53. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  54. mplang/v1/runtime/channel.py +230 -0
  55. mplang/{runtime → v1/runtime}/cli.py +3 -3
  56. mplang/{runtime → v1/runtime}/client.py +1 -1
  57. mplang/{runtime → v1/runtime}/communicator.py +39 -15
  58. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  59. mplang/{runtime → v1/runtime}/driver.py +4 -4
  60. mplang/v1/runtime/link_comm.py +196 -0
  61. mplang/{runtime → v1/runtime}/server.py +22 -9
  62. mplang/{runtime → v1/runtime}/session.py +24 -51
  63. mplang/{runtime → v1/runtime}/simulation.py +36 -14
  64. mplang/{simp → v1/simp}/api.py +72 -14
  65. mplang/{simp → v1/simp}/mpi.py +1 -1
  66. mplang/{simp → v1/simp}/party.py +5 -5
  67. mplang/{simp → v1/simp}/random.py +2 -2
  68. mplang/v1/simp/smpc.py +238 -0
  69. mplang/v1/utils/table_utils.py +185 -0
  70. mplang/v2/__init__.py +424 -0
  71. mplang/v2/backends/__init__.py +57 -0
  72. mplang/v2/backends/bfv_impl.py +705 -0
  73. mplang/v2/backends/channel.py +217 -0
  74. mplang/v2/backends/crypto_impl.py +723 -0
  75. mplang/v2/backends/field_impl.py +454 -0
  76. mplang/v2/backends/func_impl.py +107 -0
  77. mplang/v2/backends/phe_impl.py +148 -0
  78. mplang/v2/backends/simp_design.md +136 -0
  79. mplang/v2/backends/simp_driver/__init__.py +41 -0
  80. mplang/v2/backends/simp_driver/http.py +168 -0
  81. mplang/v2/backends/simp_driver/mem.py +280 -0
  82. mplang/v2/backends/simp_driver/ops.py +135 -0
  83. mplang/v2/backends/simp_driver/state.py +60 -0
  84. mplang/v2/backends/simp_driver/values.py +52 -0
  85. mplang/v2/backends/simp_worker/__init__.py +29 -0
  86. mplang/v2/backends/simp_worker/http.py +354 -0
  87. mplang/v2/backends/simp_worker/mem.py +102 -0
  88. mplang/v2/backends/simp_worker/ops.py +167 -0
  89. mplang/v2/backends/simp_worker/state.py +49 -0
  90. mplang/v2/backends/spu_impl.py +275 -0
  91. mplang/v2/backends/spu_state.py +187 -0
  92. mplang/v2/backends/store_impl.py +62 -0
  93. mplang/v2/backends/table_impl.py +838 -0
  94. mplang/v2/backends/tee_impl.py +215 -0
  95. mplang/v2/backends/tensor_impl.py +519 -0
  96. mplang/v2/cli.py +603 -0
  97. mplang/v2/cli_guide.md +122 -0
  98. mplang/v2/dialects/__init__.py +36 -0
  99. mplang/v2/dialects/bfv.py +665 -0
  100. mplang/v2/dialects/crypto.py +689 -0
  101. mplang/v2/dialects/dtypes.py +378 -0
  102. mplang/v2/dialects/field.py +210 -0
  103. mplang/v2/dialects/func.py +135 -0
  104. mplang/v2/dialects/phe.py +723 -0
  105. mplang/v2/dialects/simp.py +944 -0
  106. mplang/v2/dialects/spu.py +349 -0
  107. mplang/v2/dialects/store.py +63 -0
  108. mplang/v2/dialects/table.py +407 -0
  109. mplang/v2/dialects/tee.py +346 -0
  110. mplang/v2/dialects/tensor.py +1175 -0
  111. mplang/v2/edsl/README.md +279 -0
  112. mplang/v2/edsl/__init__.py +99 -0
  113. mplang/v2/edsl/context.py +311 -0
  114. mplang/v2/edsl/graph.py +463 -0
  115. mplang/v2/edsl/jit.py +62 -0
  116. mplang/v2/edsl/object.py +53 -0
  117. mplang/v2/edsl/primitive.py +284 -0
  118. mplang/v2/edsl/printer.py +119 -0
  119. mplang/v2/edsl/registry.py +207 -0
  120. mplang/v2/edsl/serde.py +375 -0
  121. mplang/v2/edsl/tracer.py +614 -0
  122. mplang/v2/edsl/typing.py +816 -0
  123. mplang/v2/kernels/Makefile +30 -0
  124. mplang/v2/kernels/__init__.py +23 -0
  125. mplang/v2/kernels/gf128.cpp +148 -0
  126. mplang/v2/kernels/ldpc.cpp +82 -0
  127. mplang/v2/kernels/okvs.cpp +283 -0
  128. mplang/v2/kernels/okvs_opt.cpp +291 -0
  129. mplang/v2/kernels/py_kernels.py +398 -0
  130. mplang/v2/libs/collective.py +330 -0
  131. mplang/v2/libs/device/__init__.py +51 -0
  132. mplang/v2/libs/device/api.py +813 -0
  133. mplang/v2/libs/device/cluster.py +352 -0
  134. mplang/v2/libs/ml/__init__.py +23 -0
  135. mplang/v2/libs/ml/sgb.py +1861 -0
  136. mplang/v2/libs/mpc/__init__.py +41 -0
  137. mplang/v2/libs/mpc/_utils.py +99 -0
  138. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  139. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  140. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  141. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  142. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  143. mplang/v2/libs/mpc/common/constants.py +39 -0
  144. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  145. mplang/v2/libs/mpc/ot/base.py +222 -0
  146. mplang/v2/libs/mpc/ot/extension.py +477 -0
  147. mplang/v2/libs/mpc/ot/silent.py +217 -0
  148. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  149. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  150. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  151. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  152. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  153. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  154. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  155. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  156. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  157. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  158. mplang/v2/libs/mpc/vole/silver.py +336 -0
  159. mplang/v2/runtime/__init__.py +15 -0
  160. mplang/v2/runtime/dialect_state.py +41 -0
  161. mplang/v2/runtime/interpreter.py +871 -0
  162. mplang/v2/runtime/object_store.py +194 -0
  163. mplang/v2/runtime/value.py +141 -0
  164. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
  165. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  166. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  167. mplang/device.py +0 -327
  168. mplang/ops/crypto.py +0 -108
  169. mplang/ops/ibis_cc.py +0 -136
  170. mplang/ops/sql_cc.py +0 -62
  171. mplang/runtime/link_comm.py +0 -78
  172. mplang/simp/smpc.py +0 -201
  173. mplang/utils/table_utils.py +0 -85
  174. mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
  175. /mplang/{core → v1/core}/mask.py +0 -0
  176. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  177. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
  178. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
  179. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
  180. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  181. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  182. /mplang/{simp → v1/simp}/__init__.py +0 -0
  183. /mplang/{utils → v1/utils}/__init__.py +0 -0
  184. /mplang/{utils → v1/utils}/crypto.py +0 -0
  185. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  186. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  187. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  188. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
mplang/device.py DELETED
@@ -1,327 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """
16
- This module provides the device oriented programming interface for MPC.
17
-
18
- The device oriented programming interface is designed to provide a high-level
19
- abstraction for the MPC programming. It allows the user to write the program
20
- in a device-oriented manner, and the runtime will take care of the data
21
- transformation between devices.
22
- """
23
-
24
- from __future__ import annotations
25
-
26
- from collections.abc import Callable
27
- from functools import partial, wraps
28
- from typing import Any
29
-
30
- from jax.tree_util import tree_map
31
-
32
- import mplang.host as mphost
33
- from mplang.core import (
34
- ClusterSpec,
35
- Device,
36
- InterpContext,
37
- MPObject,
38
- TensorType,
39
- cur_ctx,
40
- primitive,
41
- )
42
- from mplang.ops import basic, crypto, ibis_cc, jax_cc, tee
43
- from mplang.ops.base import FeOperation
44
- from mplang.ops.ibis_cc import IbisRunner
45
- from mplang.ops.jax_cc import JaxRunner
46
- from mplang.simp import mpi, smpc
47
- from mplang.simp.api import run_at
48
-
49
- # Automatic transfer between devices when parameter is not on the target device.
50
- g_auto_trans: bool = True
51
-
52
- _HKDF_INFO_LITERAL: str = "mplang/device/tee/v1"
53
- # Default KEM suite for TEE session establishment; make configurable via ClusterSpec in future.
54
- _TEE_KEM_SUITE: str = "x25519"
55
-
56
-
57
- # `function` decorator could also compile device-level apis.
58
- function = primitive.function
59
-
60
- # magic attribute name to mark a MPObject as a device object
61
- DEVICE_ATTR_NAME = "_devid_"
62
-
63
-
64
- def _is_device_obj(obj: Any) -> bool:
65
- if not isinstance(obj, MPObject):
66
- return False
67
- return DEVICE_ATTR_NAME in obj.attrs
68
-
69
-
70
- def _set_devid(obj: MPObject, dev_id: str) -> MPObject:
71
- if not isinstance(obj, MPObject):
72
- raise TypeError(f"Input must be an instance of Object, {obj}")
73
- obj.attrs[DEVICE_ATTR_NAME] = dev_id
74
- return obj
75
-
76
-
77
- def _get_devid(obj: MPObject) -> str:
78
- if not isinstance(obj, MPObject):
79
- raise TypeError("Input must be an instance of Object")
80
-
81
- return obj.attrs[DEVICE_ATTR_NAME] # type: ignore[no-any-return]
82
-
83
-
84
- _is_mpobj = lambda x: isinstance(x, MPObject)
85
-
86
-
87
- def _device_run_spu(
88
- dev_info: Device, op: FeOperation, *args: Any, **kwargs: Any
89
- ) -> Any:
90
- if not isinstance(op, JaxRunner):
91
- raise ValueError("SPU device only supports JAX frontend.")
92
- fn, *aargs = args
93
- var = smpc.srun(fn)(*aargs, **kwargs)
94
- return tree_map(partial(_set_devid, dev_id=dev_info.name), var)
95
-
96
-
97
- def _device_run_tee(
98
- dev_info: Device, op: FeOperation, *args: Any, **kwargs: Any
99
- ) -> Any:
100
- if not isinstance(op, JaxRunner) and not isinstance(op, IbisRunner):
101
- raise ValueError("TEE device only supports JAX and Ibis frontend.")
102
- assert len(dev_info.members) == 1
103
- rank = dev_info.members[0].rank
104
- var = run_at(rank, op, *args, **kwargs)
105
- return tree_map(partial(_set_devid, dev_id=dev_info.name), var)
106
-
107
-
108
- def _device_run_ppu(
109
- dev_info: Device, op: FeOperation, *args: Any, **kwargs: Any
110
- ) -> Any:
111
- assert len(dev_info.members) == 1
112
- rank = dev_info.members[0].rank
113
- var = run_at(rank, op, *args, **kwargs)
114
- return tree_map(partial(_set_devid, dev_id=dev_info.name), var)
115
-
116
-
117
- def _device_run(dev_id: str, op: FeOperation, *args: Any, **kwargs: Any) -> Any:
118
- assert isinstance(op, FeOperation)
119
- cluster_spec = mphost.cur_ctx().cluster_spec
120
- if dev_id not in cluster_spec.devices:
121
- raise ValueError(f"Device {dev_id} not found in cluster spec.")
122
-
123
- if g_auto_trans:
124
-
125
- def trans(obj: Any) -> Any:
126
- if _is_mpobj(obj):
127
- assert _is_device_obj(obj)
128
- return _d2d(dev_id, obj)
129
- else:
130
- return obj
131
-
132
- args, kwargs = tree_map(trans, (args, kwargs))
133
-
134
- dev_info = cluster_spec.devices[dev_id]
135
- if dev_info.kind.upper() == "SPU":
136
- return _device_run_spu(dev_info, op, *args, **kwargs)
137
- elif dev_info.kind.upper() == "TEE":
138
- return _device_run_tee(dev_info, op, *args, **kwargs)
139
- elif dev_info.kind.upper() == "PPU":
140
- return _device_run_ppu(dev_info, op, *args, **kwargs)
141
- else:
142
- raise ValueError(f"Unknown device type: {dev_info.kind}")
143
-
144
-
145
- def device(dev_id: str, *, fe_type: str = "jax") -> Callable:
146
- """Decorator to mark a function to be executed on a specific device.
147
-
148
- Args:
149
- dev_id: The device id.
150
- fe_type: The frontend type of the device, could be "jax" or "ibis".
151
-
152
- Note: 'fe_type' is not needed if the decorated function is already a FeOperation.
153
-
154
- Example:
155
- >>> @device("P0")
156
- ... def foo(x, y):
157
- ... return x + y
158
- """
159
-
160
- def deco(fn: Callable) -> Callable:
161
- @wraps(fn)
162
- def wrapped(*args: Any, **kwargs: Any) -> Any:
163
- if isinstance(fn, FeOperation):
164
- return _device_run(dev_id, fn, *args, **kwargs)
165
- else:
166
- if fe_type == "jax":
167
- return _device_run(dev_id, jax_cc.run_jax, fn, *args, **kwargs)
168
- elif fe_type == "ibis":
169
- return _device_run(dev_id, ibis_cc.run_ibis, fn, *args, **kwargs)
170
- else:
171
- raise ValueError(f"Unsupported frontend type: {fe_type}")
172
-
173
- return wrapped
174
-
175
- return deco
176
-
177
-
178
- def _d2d(to_dev_id: str, obj: MPObject) -> MPObject:
179
- assert isinstance(obj, MPObject)
180
- frm_dev_id = _get_devid(obj)
181
-
182
- if frm_dev_id == to_dev_id:
183
- return obj
184
-
185
- cluster_spec: ClusterSpec = mphost.cur_ctx().cluster_spec
186
- frm_dev = cluster_spec.devices[frm_dev_id]
187
- to_dev = cluster_spec.devices[to_dev_id]
188
- frm_to_pair = (frm_dev.kind.upper(), to_dev.kind.upper())
189
-
190
- if frm_to_pair == ("SPU", "SPU"):
191
- raise NotImplementedError("Only one SPU is supported for now.")
192
- elif frm_to_pair == ("SPU", "PPU"):
193
- assert len(to_dev.members) == 1
194
- to_rank = to_dev.members[0].rank
195
- var = smpc.revealTo(obj, to_rank)
196
- return tree_map(partial(_set_devid, dev_id=to_dev_id), var) # type: ignore[no-any-return]
197
- elif frm_to_pair == ("PPU", "SPU"):
198
- assert len(frm_dev.members) == 1
199
- frm_rank = frm_dev.members[0].rank
200
- var = smpc.sealFrom(obj, frm_rank)
201
- return tree_map(partial(_set_devid, dev_id=to_dev_id), var) # type: ignore[no-any-return]
202
- elif frm_to_pair == ("PPU", "PPU"):
203
- assert len(frm_dev.members) == 1 and len(to_dev.members) == 1
204
- frm_rank = frm_dev.members[0].rank
205
- to_rank = to_dev.members[0].rank
206
- var = mpi.p2p(frm_rank, to_rank, obj)
207
- return tree_map(partial(_set_devid, dev_id=to_dev_id), var) # type: ignore[no-any-return]
208
- elif frm_to_pair == ("PPU", "TEE"):
209
- # Transparent handshake + encryption for the first transfer; reuse thereafter
210
- assert len(frm_dev.members) == 1 and len(to_dev.members) == 1
211
- frm_rank = frm_dev.members[0].rank
212
- tee_rank = to_dev.members[0].rank
213
- # Ensure sessions (both directions) exist for this PPU<->TEE pair
214
- sess_p, sess_t = _ensure_tee_session(frm_dev_id, to_dev_id, frm_rank, tee_rank)
215
- # Bytes-only path: pack -> enc -> p2p -> dec -> unpack (with static out type)
216
- obj_ty = TensorType.from_obj(obj)
217
- b = run_at(frm_rank, basic.pack, obj)
218
- ct = run_at(frm_rank, crypto.enc, b, sess_p)
219
- ct_at_tee = mpi.p2p(frm_rank, tee_rank, ct)
220
- b_at_tee = run_at(tee_rank, crypto.dec, ct_at_tee, sess_t)
221
- pt_at_tee = run_at(tee_rank, basic.unpack, b_at_tee, out_ty=obj_ty)
222
- return tree_map(partial(_set_devid, dev_id=to_dev_id), pt_at_tee) # type: ignore[no-any-return]
223
- elif frm_to_pair == ("TEE", "PPU"):
224
- # Transparent encryption from TEE to a specific PPU using the reverse-direction session key
225
- assert len(frm_dev.members) == 1 and len(to_dev.members) == 1
226
- tee_rank = frm_dev.members[0].rank
227
- ppu_rank = to_dev.members[0].rank
228
- # Ensure bidirectional session established for this pair
229
- sess_p, sess_t = _ensure_tee_session(to_dev_id, frm_dev_id, ppu_rank, tee_rank)
230
- obj_ty = TensorType.from_obj(obj)
231
- b = run_at(tee_rank, basic.pack, obj)
232
- ct = run_at(tee_rank, crypto.enc, b, sess_t)
233
- ct_at_ppu = mpi.p2p(tee_rank, ppu_rank, ct)
234
- b_at_ppu = run_at(ppu_rank, crypto.dec, ct_at_ppu, sess_p)
235
- pt_at_ppu = run_at(ppu_rank, basic.unpack, b_at_ppu, out_ty=obj_ty)
236
- return tree_map(partial(_set_devid, dev_id=to_dev_id), pt_at_ppu) # type: ignore[no-any-return]
237
- else:
238
- supported = [
239
- ("SPU", "PPU"),
240
- ("PPU", "SPU"),
241
- ("PPU", "PPU"),
242
- ("PPU", "TEE"),
243
- ("TEE", "PPU"),
244
- ]
245
- raise ValueError(
246
- f"Unsupported device transfer: {frm_to_pair}. Supported pairs: {supported}."
247
- )
248
-
249
-
250
- def _ensure_tee_session(
251
- frm_dev_id: str, to_dev_id: str, frm_rank: int, tee_rank: int
252
- ) -> tuple[MPObject, MPObject]:
253
- """Ensure a TEE session (sess_p at sender, sess_t at TEE) exists.
254
-
255
- Returns (sess_p, sess_t).
256
- """
257
- ctx = cur_ctx().root()
258
- if not hasattr(ctx, "_tee_sessions"):
259
- ctx._tee_sessions = {} # type: ignore[attr-defined]
260
- cache: dict[tuple[str, str], tuple[MPObject, MPObject]] = ctx._tee_sessions # type: ignore
261
-
262
- key = (frm_dev_id, to_dev_id)
263
- if key in cache:
264
- return cache[key]
265
-
266
- # 1) TEE generates (sk, pk) and quote(pk)
267
- # KEM suite currently constant; future: read from tee device config (e.g. cluster_spec.devices[dev_id].config)
268
- tee_sk, tee_pk = run_at(tee_rank, crypto.kem_keygen, _TEE_KEM_SUITE)
269
- quote = run_at(tee_rank, tee.quote_gen, tee_pk)
270
-
271
- # 2) Send quote to sender and attest to obtain TEE pk
272
- quote_at_sender = mpi.p2p(tee_rank, frm_rank, quote)
273
- tee_pk_at_sender = run_at(frm_rank, tee.attest, quote_at_sender)
274
-
275
- # 3) Sender generates its ephemeral keypair and sends its pk to TEE
276
- v_sk, v_pk = run_at(frm_rank, crypto.kem_keygen, _TEE_KEM_SUITE)
277
- v_pk_at_tee = mpi.p2p(frm_rank, tee_rank, v_pk)
278
-
279
- # 4) Both sides derive the shared secret and session key
280
- shared_p = run_at(
281
- frm_rank, crypto.kem_derive, v_sk, tee_pk_at_sender, _TEE_KEM_SUITE
282
- )
283
- shared_t = run_at(tee_rank, crypto.kem_derive, tee_sk, v_pk_at_tee, _TEE_KEM_SUITE)
284
- # Use a fixed ASCII string literal for HKDF info on both sides
285
- sess_p = run_at(frm_rank, crypto.hkdf, shared_p, _HKDF_INFO_LITERAL)
286
- sess_t = run_at(tee_rank, crypto.hkdf, shared_t, _HKDF_INFO_LITERAL)
287
-
288
- cache[key] = (sess_p, sess_t)
289
- return sess_p, sess_t
290
-
291
-
292
- def put(to_dev_id: str, obj: Any) -> MPObject:
293
- if not isinstance(obj, MPObject):
294
- return device(to_dev_id)(lambda x: x)(obj) # type: ignore[no-any-return]
295
- assert isinstance(obj, MPObject)
296
- return _d2d(to_dev_id, obj)
297
-
298
-
299
- def _fetch(interp: InterpContext, obj: MPObject) -> Any:
300
- dev_id = _get_devid(obj)
301
- cluster_spec = interp.cluster_spec
302
- dev_kind = cluster_spec.devices[dev_id].kind.upper()
303
-
304
- dev_info = cluster_spec.devices[dev_id]
305
- if dev_kind == "SPU":
306
- revealed = mphost.evaluate(interp, smpc.reveal, obj)
307
- result = mphost.fetch(interp, revealed)
308
- # now all members have the same value, return the one at rank 0
309
- return result[dev_info.members[0].rank]
310
- elif dev_kind == "PPU":
311
- assert len(dev_info.members) == 1
312
- rank = dev_info.members[0].rank
313
- result = mphost.fetch(interp, obj)
314
- return result[rank]
315
- elif dev_kind == "TEE":
316
- assert len(dev_info.members) == 1
317
- rank = dev_info.members[0].rank
318
- result = mphost.fetch(interp, obj)
319
- return result[rank]
320
- else:
321
- raise ValueError(f"Unknown device id: {dev_id}")
322
-
323
-
324
- def fetch(interp: InterpContext, objs: Any) -> Any:
325
- ctx = interp or mphost.cur_ctx()
326
- assert isinstance(ctx, InterpContext), f"Expect InterpContext, got {ctx}"
327
- return tree_map(partial(_fetch, ctx), objs)
mplang/ops/crypto.py DELETED
@@ -1,108 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """
16
- Crypto frontend operations: operation signatures, types, and high-level semantics.
17
-
18
- Scope and contracts:
19
- - This module defines portable API shapes; it does not implement cryptography.
20
- - Backends execute the operations and must meet the security semantics required
21
- by the deployment (confidentiality, authenticity, correctness, etc.).
22
- - The enc/dec API in this frontend uses a conventional 12-byte nonce prefix
23
- (ciphertext = nonce || payload), and dec expects that format. Other security
24
- properties (e.g., AEAD) are backend responsibilities.
25
- """
26
-
27
- from __future__ import annotations
28
-
29
- from mplang.core import UINT8, TensorType
30
- from mplang.ops.base import stateless_mod
31
-
32
- _CRYPTO_MOD = stateless_mod("crypto")
33
-
34
-
35
- @_CRYPTO_MOD.simple_op()
36
- def keygen(*, length: int = 32) -> TensorType:
37
- """Generate random bytes for symmetric keys or generic randomness.
38
-
39
- API: keygen(length: int = 32) -> key: u8[length]
40
-
41
- Notes:
42
- - Frontend defines the type/shape; backend provides randomness.
43
- - Raises ValueError when length <= 0.
44
- """
45
- if length <= 0:
46
- raise ValueError("length must be > 0")
47
- return TensorType(UINT8, (length,))
48
-
49
-
50
- @_CRYPTO_MOD.simple_op()
51
- def enc(plaintext: TensorType, key: TensorType) -> TensorType:
52
- """Symmetric encryption.
53
-
54
- API: enc(plaintext: u8[N], key: u8[M]) -> ciphertext: u8[N + 12]
55
- """
56
- pt_ty = plaintext
57
- if pt_ty.dtype != UINT8:
58
- raise TypeError("enc expects UINT8 plaintext")
59
- if len(pt_ty.shape) != 1:
60
- raise TypeError("enc expects 1-D plaintext")
61
- length = pt_ty.shape[0]
62
- if length >= 0:
63
- return TensorType(UINT8, (length + 12,))
64
- return TensorType(UINT8, (-1,))
65
-
66
-
67
- @_CRYPTO_MOD.simple_op()
68
- def dec(ciphertext: TensorType, key: TensorType) -> TensorType:
69
- """Symmetric decryption.
70
-
71
- API: dec(ciphertext: u8[N + 12], key: u8[M]) -> plaintext: u8[N]
72
- """
73
- ct_ty = ciphertext
74
- if ct_ty.dtype != UINT8:
75
- raise TypeError("dec expects UINT8 ciphertext")
76
- if len(ct_ty.shape) != 1:
77
- raise TypeError("dec expects 1-D ciphertext with nonce")
78
- length = ct_ty.shape[0]
79
- if length >= 0 and length < 12:
80
- raise TypeError("dec expects 1-D ciphertext with nonce")
81
- if length >= 0:
82
- return TensorType(UINT8, (length - 12,))
83
- return TensorType(UINT8, (-1,))
84
-
85
-
86
- @_CRYPTO_MOD.simple_op()
87
- def kem_keygen(*, suite: str = "x25519") -> tuple[TensorType, TensorType]:
88
- """KEM-style keypair generation: returns (sk, pk) bytes."""
89
- sk_ty = TensorType(UINT8, (32,))
90
- pk_ty = TensorType(UINT8, (32,))
91
- return sk_ty, pk_ty
92
-
93
-
94
- @_CRYPTO_MOD.simple_op()
95
- def kem_derive(
96
- sk: TensorType, peer_pk: TensorType, *, suite: str = "x25519"
97
- ) -> TensorType:
98
- """KEM-style shared secret derivation: returns secret bytes."""
99
- _ = sk
100
- _ = peer_pk
101
- return TensorType(UINT8, (32,))
102
-
103
-
104
- @_CRYPTO_MOD.simple_op()
105
- def hkdf(secret: TensorType, *, info: str) -> TensorType:
106
- """HKDF-style key derivation: returns a 32-byte key."""
107
- _ = secret
108
- return TensorType(UINT8, (32,))
mplang/ops/ibis_cc.py DELETED
@@ -1,136 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
-
16
- import inspect
17
- from collections.abc import Callable
18
- from typing import Any
19
-
20
- import ibis
21
- from jax.tree_util import PyTreeDef, tree_flatten
22
-
23
- from mplang.core import MPObject, PFunction, TableType, dtypes
24
- from mplang.ops.base import FeOperation, stateless_mod
25
- from mplang.utils.func_utils import normalize_fn
26
-
27
-
28
- def ibis2sql(
29
- expr: ibis.Table,
30
- in_schemas: list[ibis.Schema],
31
- in_names: list[str],
32
- fn_name: str = "",
33
- ) -> PFunction:
34
- """
35
- Compile a ibis expr to sql and return the PFunction.
36
-
37
- Args:
38
- expr: ibis expr.
39
- in_schemas: the input table schemas
40
- in_names: the input table names, If there is only one table, it is usually defaulted to "table"
41
- Return:
42
- PFunction: The compiled PFunction
43
- """
44
- assert len(in_schemas) == len(in_names), (
45
- f"length of input table names and schemas mismatch. {len(in_schemas)}!={len(in_names)}"
46
- )
47
-
48
- def _convert(s: ibis.Schema) -> TableType:
49
- return TableType.from_pairs([
50
- (name, dtypes.from_numpy(dt.to_numpy())) for name, dt in s.fields.items()
51
- ])
52
-
53
- ins_info = [_convert(s) for s in in_schemas]
54
- outs_info = [_convert(expr.schema())]
55
-
56
- sql = ibis.to_sql(expr, dialect="duckdb")
57
- # Emit generic sql.run op; runtime maps to backend-specific kernel.
58
- pfn = PFunction(
59
- fn_type="sql.run",
60
- fn_name=fn_name,
61
- fn_text=sql,
62
- ins_info=tuple(ins_info),
63
- outs_info=tuple(outs_info),
64
- in_names=tuple(in_names),
65
- dialect="duckdb",
66
- )
67
- return pfn
68
-
69
-
70
- def is_ibis_function(func: Callable) -> bool:
71
- """
72
- Verify whether a function is an ibis function.
73
- The func signature should like def foo(t0:ibis.Table, t1:ibis.Table)->ibis.Table
74
- """
75
- try:
76
- sig = inspect.signature(func)
77
- except (ValueError, TypeError):
78
- return False
79
-
80
- ret_anno = sig.return_annotation
81
- if ret_anno is ibis.Table:
82
- return True
83
-
84
- for param in sig.parameters.values():
85
- par_anno = param.annotation
86
- if par_anno is ibis.Table:
87
- return True
88
-
89
- return False
90
-
91
-
92
- _IBIS_MOD = stateless_mod("ibis")
93
-
94
-
95
- class IbisRunner(FeOperation):
96
- """Ibis runner frontend operation."""
97
-
98
- def trace(
99
- self, func: Callable, *args: Any, **kwargs: Any
100
- ) -> tuple[PFunction, list[MPObject], PyTreeDef]:
101
- """Compile an Ibis function to SQL format.
102
-
103
- Args:
104
- func: The Ibis function to compile
105
- *args: Positional arguments to the function
106
- **kwargs: Keyword arguments to the function
107
-
108
- Returns:
109
- tuple[PFunction, list[MPObject], Any]: The compiled PFunction, input variables, and output tree
110
- """
111
-
112
- def is_variable(arg: Any) -> bool:
113
- return isinstance(arg, MPObject)
114
-
115
- normalized_fn, in_vars = normalize_fn(func, args, kwargs, is_variable)
116
-
117
- in_args, in_schemas, in_names = [], [], []
118
- idx = 0
119
- for arg in in_vars:
120
- columns = [(p[0], p[1].to_numpy()) for p in arg.schema.columns]
121
- schema = ibis.schema(columns)
122
- name = f"table{idx}"
123
- table = ibis.table(schema=schema, name=name)
124
- in_args.append(table)
125
- in_schemas.append(schema)
126
- in_names.append(name)
127
- idx += 1
128
-
129
- result = normalized_fn(in_args)
130
- assert isinstance(result, ibis.Table)
131
- pfunc = ibis2sql(result, in_schemas, in_names, func.__name__)
132
- _, treedef = tree_flatten(result)
133
- return pfunc, in_vars, treedef
134
-
135
-
136
- run_ibis = IbisRunner(_IBIS_MOD, "run")
mplang/ops/sql_cc.py DELETED
@@ -1,62 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from jax.tree_util import PyTreeDef, tree_flatten
16
-
17
- from mplang.core.mpobject import MPObject
18
- from mplang.core.pfunc import PFunction
19
- from mplang.core.table import TableType
20
- from mplang.ops.base import FeOperation, stateless_mod
21
-
22
- _SQL_MOD = stateless_mod("sql")
23
-
24
-
25
- class SqlRunner(FeOperation):
26
- def __init__(self, dialect: str = "duckdb"):
27
- # Bind to sql module with a stable op name for registry/dispatch
28
- super().__init__(_SQL_MOD, "run")
29
- self._dialect = dialect
30
-
31
- # TODO(jint): we should deduce out_type according to query and in_tables' schema
32
- def trace(
33
- self,
34
- query: str,
35
- out_type: TableType,
36
- in_tables: dict[str, MPObject] | None = None,
37
- ) -> tuple[PFunction, list[MPObject], PyTreeDef]:
38
- in_names: list[str] = []
39
- ins_info: list[TableType] = []
40
- in_vars: list[MPObject] = []
41
- if in_tables:
42
- for name, tbl in in_tables.items():
43
- assert isinstance(tbl, MPObject)
44
- assert tbl.schema is not None
45
- in_names.append(name)
46
- ins_info.append(tbl.schema)
47
- in_vars.append(tbl)
48
-
49
- pfn = PFunction(
50
- fn_type="sql.run",
51
- fn_name="",
52
- fn_text=query,
53
- ins_info=tuple(ins_info),
54
- outs_info=(out_type,),
55
- in_names=tuple(in_names),
56
- dialect=self._dialect,
57
- )
58
- _, treedef = tree_flatten(out_type)
59
- return pfn, in_vars, treedef
60
-
61
-
62
- run_sql = SqlRunner("duckdb")