mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191) hide show
  1. mplang/__init__.py +21 -45
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +5 -7
  7. mplang/v1/core/__init__.py +157 -0
  8. mplang/{core → v1/core}/cluster.py +30 -14
  9. mplang/{core → v1/core}/comm.py +5 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +13 -14
  14. mplang/{core → v1/core}/expr/evaluator.py +65 -24
  15. mplang/{core → v1/core}/expr/printer.py +24 -18
  16. mplang/{core → v1/core}/expr/transformer.py +3 -3
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +23 -16
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +4 -4
  25. mplang/{core → v1/core}/primitive.py +106 -201
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{api.py → v1/host.py} +38 -6
  30. mplang/v1/kernels/__init__.py +41 -0
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/v1/kernels/basic.py +240 -0
  33. mplang/{kernels → v1/kernels}/context.py +42 -27
  34. mplang/{kernels → v1/kernels}/crypto.py +44 -37
  35. mplang/v1/kernels/fhe.py +858 -0
  36. mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
  37. mplang/{kernels → v1/kernels}/phe.py +263 -57
  38. mplang/{kernels → v1/kernels}/spu.py +137 -48
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
  40. mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
  41. mplang/v1/kernels/value.py +626 -0
  42. mplang/{ops → v1/ops}/__init__.py +5 -16
  43. mplang/{ops → v1/ops}/base.py +2 -5
  44. mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/v1/ops/fhe.py +272 -0
  47. mplang/{ops → v1/ops}/jax_cc.py +33 -68
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -4
  50. mplang/{ops → v1/ops}/spu.py +3 -5
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +9 -24
  53. mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
  54. mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
  55. mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
  56. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  57. mplang/v1/runtime/channel.py +230 -0
  58. mplang/{runtime → v1/runtime}/cli.py +35 -20
  59. mplang/{runtime → v1/runtime}/client.py +19 -8
  60. mplang/{runtime → v1/runtime}/communicator.py +59 -15
  61. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  62. mplang/{runtime → v1/runtime}/driver.py +30 -12
  63. mplang/v1/runtime/link_comm.py +196 -0
  64. mplang/{runtime → v1/runtime}/server.py +58 -42
  65. mplang/{runtime → v1/runtime}/session.py +57 -71
  66. mplang/{runtime → v1/runtime}/simulation.py +55 -28
  67. mplang/v1/simp/api.py +353 -0
  68. mplang/{simp → v1/simp}/mpi.py +8 -9
  69. mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
  70. mplang/{simp → v1/simp}/random.py +21 -22
  71. mplang/v1/simp/smpc.py +238 -0
  72. mplang/v1/utils/table_utils.py +185 -0
  73. mplang/v2/__init__.py +424 -0
  74. mplang/v2/backends/__init__.py +57 -0
  75. mplang/v2/backends/bfv_impl.py +705 -0
  76. mplang/v2/backends/channel.py +217 -0
  77. mplang/v2/backends/crypto_impl.py +723 -0
  78. mplang/v2/backends/field_impl.py +454 -0
  79. mplang/v2/backends/func_impl.py +107 -0
  80. mplang/v2/backends/phe_impl.py +148 -0
  81. mplang/v2/backends/simp_design.md +136 -0
  82. mplang/v2/backends/simp_driver/__init__.py +41 -0
  83. mplang/v2/backends/simp_driver/http.py +168 -0
  84. mplang/v2/backends/simp_driver/mem.py +280 -0
  85. mplang/v2/backends/simp_driver/ops.py +135 -0
  86. mplang/v2/backends/simp_driver/state.py +60 -0
  87. mplang/v2/backends/simp_driver/values.py +52 -0
  88. mplang/v2/backends/simp_worker/__init__.py +29 -0
  89. mplang/v2/backends/simp_worker/http.py +354 -0
  90. mplang/v2/backends/simp_worker/mem.py +102 -0
  91. mplang/v2/backends/simp_worker/ops.py +167 -0
  92. mplang/v2/backends/simp_worker/state.py +49 -0
  93. mplang/v2/backends/spu_impl.py +275 -0
  94. mplang/v2/backends/spu_state.py +187 -0
  95. mplang/v2/backends/store_impl.py +62 -0
  96. mplang/v2/backends/table_impl.py +838 -0
  97. mplang/v2/backends/tee_impl.py +215 -0
  98. mplang/v2/backends/tensor_impl.py +519 -0
  99. mplang/v2/cli.py +603 -0
  100. mplang/v2/cli_guide.md +122 -0
  101. mplang/v2/dialects/__init__.py +36 -0
  102. mplang/v2/dialects/bfv.py +665 -0
  103. mplang/v2/dialects/crypto.py +689 -0
  104. mplang/v2/dialects/dtypes.py +378 -0
  105. mplang/v2/dialects/field.py +210 -0
  106. mplang/v2/dialects/func.py +135 -0
  107. mplang/v2/dialects/phe.py +723 -0
  108. mplang/v2/dialects/simp.py +944 -0
  109. mplang/v2/dialects/spu.py +349 -0
  110. mplang/v2/dialects/store.py +63 -0
  111. mplang/v2/dialects/table.py +407 -0
  112. mplang/v2/dialects/tee.py +346 -0
  113. mplang/v2/dialects/tensor.py +1175 -0
  114. mplang/v2/edsl/README.md +279 -0
  115. mplang/v2/edsl/__init__.py +99 -0
  116. mplang/v2/edsl/context.py +311 -0
  117. mplang/v2/edsl/graph.py +463 -0
  118. mplang/v2/edsl/jit.py +62 -0
  119. mplang/v2/edsl/object.py +53 -0
  120. mplang/v2/edsl/primitive.py +284 -0
  121. mplang/v2/edsl/printer.py +119 -0
  122. mplang/v2/edsl/registry.py +207 -0
  123. mplang/v2/edsl/serde.py +375 -0
  124. mplang/v2/edsl/tracer.py +614 -0
  125. mplang/v2/edsl/typing.py +816 -0
  126. mplang/v2/kernels/Makefile +30 -0
  127. mplang/v2/kernels/__init__.py +23 -0
  128. mplang/v2/kernels/gf128.cpp +148 -0
  129. mplang/v2/kernels/ldpc.cpp +82 -0
  130. mplang/v2/kernels/okvs.cpp +283 -0
  131. mplang/v2/kernels/okvs_opt.cpp +291 -0
  132. mplang/v2/kernels/py_kernels.py +398 -0
  133. mplang/v2/libs/collective.py +330 -0
  134. mplang/v2/libs/device/__init__.py +51 -0
  135. mplang/v2/libs/device/api.py +813 -0
  136. mplang/v2/libs/device/cluster.py +352 -0
  137. mplang/v2/libs/ml/__init__.py +23 -0
  138. mplang/v2/libs/ml/sgb.py +1861 -0
  139. mplang/v2/libs/mpc/__init__.py +41 -0
  140. mplang/v2/libs/mpc/_utils.py +99 -0
  141. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  142. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  143. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  144. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  145. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  146. mplang/v2/libs/mpc/common/constants.py +39 -0
  147. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  148. mplang/v2/libs/mpc/ot/base.py +222 -0
  149. mplang/v2/libs/mpc/ot/extension.py +477 -0
  150. mplang/v2/libs/mpc/ot/silent.py +217 -0
  151. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  152. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  153. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  154. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  155. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  156. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  157. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  158. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  159. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  160. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  161. mplang/v2/libs/mpc/vole/silver.py +336 -0
  162. mplang/v2/runtime/__init__.py +15 -0
  163. mplang/v2/runtime/dialect_state.py +41 -0
  164. mplang/v2/runtime/interpreter.py +871 -0
  165. mplang/v2/runtime/object_store.py +194 -0
  166. mplang/v2/runtime/value.py +141 -0
  167. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
  168. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  169. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  170. mplang/core/__init__.py +0 -92
  171. mplang/device.py +0 -340
  172. mplang/kernels/builtin.py +0 -207
  173. mplang/ops/crypto.py +0 -109
  174. mplang/ops/ibis_cc.py +0 -139
  175. mplang/ops/sql.py +0 -61
  176. mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
  177. mplang/runtime/link_comm.py +0 -131
  178. mplang/simp/smpc.py +0 -201
  179. mplang/utils/table_utils.py +0 -73
  180. mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
  181. /mplang/{core → v1/core}/mask.py +0 -0
  182. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  183. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  184. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  185. /mplang/{kernels → v1/simp}/__init__.py +0 -0
  186. /mplang/{utils → v1/utils}/__init__.py +0 -0
  187. /mplang/{utils → v1/utils}/crypto.py +0 -0
  188. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  189. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  190. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  191. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
mplang/core/__init__.py DELETED
@@ -1,92 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """
16
- Core components for multi-party computation.
17
-
18
- This package provides the fundamental building blocks for multi-party computation,
19
- including type systems, tracing mechanisms, and interpreter contexts.
20
- """
21
-
22
- # Core type system
23
- # Communication interfaces & core symbols
24
- from mplang.core.comm import (
25
- CollectiveMixin,
26
- CommunicatorBase,
27
- ICollective,
28
- ICommunicator,
29
- )
30
- from mplang.core.dtype import DType
31
- from mplang.core.interp import InterpContext, InterpVar
32
- from mplang.core.mask import Mask
33
- from mplang.core.mpobject import MPContext, MPObject
34
- from mplang.core.mptype import MPType, Rank, Shape
35
- from mplang.core.pfunc import PFunction
36
- from mplang.core.primitive import (
37
- constant,
38
- debug_print,
39
- function,
40
- pconv,
41
- peval,
42
- prand,
43
- prank,
44
- pshfl,
45
- pshfl_s,
46
- psize,
47
- set_mask,
48
- uniform_cond,
49
- while_loop,
50
- )
51
- from mplang.core.table import TableLike, TableType
52
- from mplang.core.tensor import TensorLike, TensorType
53
- from mplang.core.tracer import TraceContext, TracedFunction, TraceVar, VarNamer, trace
54
-
55
- __all__ = [
56
- "CollectiveMixin",
57
- "CommunicatorBase",
58
- "DType",
59
- "ICollective",
60
- "ICommunicator",
61
- "InterpContext",
62
- "InterpVar",
63
- "MPContext",
64
- "MPObject",
65
- "MPType",
66
- "Mask",
67
- "PFunction",
68
- "Rank",
69
- "Shape",
70
- "TableLike",
71
- "TableType",
72
- "TensorLike",
73
- "TensorType",
74
- "TraceContext",
75
- "TraceVar",
76
- "TracedFunction",
77
- "VarNamer",
78
- "constant",
79
- "debug_print",
80
- "function",
81
- "pconv",
82
- "peval",
83
- "prand",
84
- "prank",
85
- "pshfl",
86
- "pshfl_s",
87
- "psize",
88
- "set_mask",
89
- "trace",
90
- "uniform_cond",
91
- "while_loop",
92
- ]
mplang/device.py DELETED
@@ -1,340 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """
16
- This module provides the device oriented programming interface for MPC.
17
-
18
- The device oriented programming interface is designed to provide a high-level
19
- abstraction for the MPC programming. It allows the user to write the program
20
- in a device-oriented manner, and the runtime will take care of the data
21
- transformation between devices.
22
- """
23
-
24
- from __future__ import annotations
25
-
26
- from collections.abc import Callable
27
- from functools import partial, wraps
28
- from typing import Any
29
-
30
- from jax.tree_util import tree_map
31
-
32
- import mplang.api as mapi
33
- from mplang import simp
34
- from mplang.core import InterpContext, MPObject, primitive
35
- from mplang.core.cluster import ClusterSpec, Device
36
- from mplang.core.context_mgr import cur_ctx
37
- from mplang.core.tensor import TensorType
38
- from mplang.ops import builtin, crypto, ibis_cc, jax_cc, tee
39
- from mplang.ops.base import FeOperation
40
- from mplang.ops.ibis_cc import IbisCompiler
41
- from mplang.ops.jax_cc import JaxCompiler
42
- from mplang.simp import mpi, smpc
43
-
44
- # Automatic transfer between devices when parameter is not on the target device.
45
- g_auto_trans: bool = True
46
-
47
- _HKDF_INFO_LITERAL: str = "mplang/device/tee/v1"
48
- # Default KEM suite for TEE session establishment; make configurable via ClusterSpec in future.
49
- _TEE_KEM_SUITE: str = "x25519"
50
-
51
-
52
- # `function` decorator could also compile device-level apis.
53
- function = primitive.function
54
-
55
- # magic attribute name to mark a MPObject as a device object
56
- DEVICE_ATTR_NAME = "_devid_"
57
-
58
-
59
- def _is_device_obj(obj: Any) -> bool:
60
- if not isinstance(obj, MPObject):
61
- return False
62
- return DEVICE_ATTR_NAME in obj.attrs
63
-
64
-
65
- def _set_devid(obj: MPObject, dev_id: str) -> MPObject:
66
- if not isinstance(obj, MPObject):
67
- raise TypeError(f"Input must be an instance of Object, {obj}")
68
- obj.attrs[DEVICE_ATTR_NAME] = dev_id
69
- return obj
70
-
71
-
72
- def _get_devid(obj: MPObject) -> str:
73
- if not isinstance(obj, MPObject):
74
- raise TypeError("Input must be an instance of Object")
75
-
76
- return obj.attrs[DEVICE_ATTR_NAME] # type: ignore[no-any-return]
77
-
78
-
79
- _is_mpobj = lambda x: isinstance(x, MPObject)
80
-
81
-
82
- def _device_run_spu(
83
- dev_info: Device, op: FeOperation, *args: Any, **kwargs: Any
84
- ) -> Any:
85
- if not isinstance(op, JaxCompiler):
86
- raise ValueError("SPU device only supports JAX frontend.")
87
- fn, *aargs = args
88
- var = smpc.srun(fn)(*aargs, **kwargs)
89
- return tree_map(partial(_set_devid, dev_id=dev_info.name), var)
90
-
91
-
92
- def _device_run_tee(
93
- dev_info: Device, op: FeOperation, *args: Any, **kwargs: Any
94
- ) -> Any:
95
- if not isinstance(op, JaxCompiler) and not isinstance(op, IbisCompiler):
96
- raise ValueError("TEE device only supports JAX and Ibis frontend.")
97
- assert len(dev_info.members) == 1
98
- rank = dev_info.members[0].rank
99
- var = simp.runAt(rank, op)(*args, **kwargs)
100
- return tree_map(partial(_set_devid, dev_id=dev_info.name), var)
101
-
102
-
103
- def _device_run_ppu(
104
- dev_info: Device, op: FeOperation, *args: Any, **kwargs: Any
105
- ) -> Any:
106
- assert len(dev_info.members) == 1
107
- rank = dev_info.members[0].rank
108
- var = simp.runAt(rank, op)(*args, **kwargs)
109
- return tree_map(partial(_set_devid, dev_id=dev_info.name), var)
110
-
111
-
112
- def _device_run(dev_id: str, op: FeOperation, *args: Any, **kwargs: Any) -> Any:
113
- assert isinstance(op, FeOperation)
114
- cluster_spec = mapi.cur_ctx().cluster_spec
115
- if dev_id not in cluster_spec.devices:
116
- raise ValueError(f"Device {dev_id} not found in cluster spec.")
117
-
118
- if g_auto_trans:
119
-
120
- def trans(obj: Any) -> Any:
121
- if _is_mpobj(obj):
122
- assert _is_device_obj(obj)
123
- return _d2d(dev_id, obj)
124
- else:
125
- return obj
126
-
127
- args, kwargs = tree_map(trans, (args, kwargs))
128
-
129
- dev_info = cluster_spec.devices[dev_id]
130
- if dev_info.kind.upper() == "SPU":
131
- return _device_run_spu(dev_info, op, *args, **kwargs)
132
- elif dev_info.kind.upper() == "TEE":
133
- return _device_run_tee(dev_info, op, *args, **kwargs)
134
- elif dev_info.kind.upper() == "PPU":
135
- return _device_run_ppu(dev_info, op, *args, **kwargs)
136
- else:
137
- raise ValueError(f"Unknown device type: {dev_info.kind}")
138
-
139
-
140
- def device(dev_id: str, *, fe_type: str = "jax") -> Callable:
141
- """Decorator to mark a function to be executed on a specific device.
142
-
143
- Args:
144
- dev_id: The device id.
145
- fe_type: The frontend type of the device, could be "jax" or "ibis".
146
-
147
- Note: 'fe_type' is not needed if the decorated function is already a FeOperation.
148
-
149
- Example:
150
- >>> @device("P0")
151
- ... def foo(x, y):
152
- ... return x + y
153
- """
154
-
155
- def deco(fn: Callable) -> Callable:
156
- @wraps(fn)
157
- def wrapped(*args: Any, **kwargs: Any) -> Any:
158
- if isinstance(fn, FeOperation):
159
- return _device_run(dev_id, fn, *args, **kwargs)
160
- else:
161
- if fe_type == "jax":
162
- return _device_run(dev_id, jax_cc.jax_compile, fn, *args, **kwargs)
163
- elif fe_type == "ibis":
164
- return _device_run(
165
- dev_id, ibis_cc.ibis_compile, fn, *args, **kwargs
166
- )
167
- else:
168
- raise ValueError(f"Unsupported frontend type: {fe_type}")
169
-
170
- return wrapped
171
-
172
- return deco
173
-
174
-
175
- def _d2d(to_dev_id: str, obj: MPObject) -> MPObject:
176
- assert isinstance(obj, MPObject)
177
- frm_dev_id = _get_devid(obj)
178
-
179
- if frm_dev_id == to_dev_id:
180
- return obj
181
-
182
- cluster_spec: ClusterSpec = mapi.cur_ctx().cluster_spec
183
- frm_dev = cluster_spec.devices[frm_dev_id]
184
- to_dev = cluster_spec.devices[to_dev_id]
185
- frm_to_pair = (frm_dev.kind.upper(), to_dev.kind.upper())
186
-
187
- if frm_to_pair == ("SPU", "SPU"):
188
- raise NotImplementedError("Only one SPU is supported for now.")
189
- elif frm_to_pair == ("SPU", "PPU"):
190
- assert len(to_dev.members) == 1
191
- to_rank = to_dev.members[0].rank
192
- var = smpc.revealTo(obj, to_rank)
193
- return tree_map(partial(_set_devid, dev_id=to_dev_id), var) # type: ignore[no-any-return]
194
- elif frm_to_pair == ("PPU", "SPU"):
195
- assert len(frm_dev.members) == 1
196
- frm_rank = frm_dev.members[0].rank
197
- var = smpc.sealFrom(obj, frm_rank)
198
- return tree_map(partial(_set_devid, dev_id=to_dev_id), var) # type: ignore[no-any-return]
199
- elif frm_to_pair == ("PPU", "PPU"):
200
- assert len(frm_dev.members) == 1 and len(to_dev.members) == 1
201
- frm_rank = frm_dev.members[0].rank
202
- to_rank = to_dev.members[0].rank
203
- var = mpi.p2p(frm_rank, to_rank, obj)
204
- return tree_map(partial(_set_devid, dev_id=to_dev_id), var) # type: ignore[no-any-return]
205
- elif frm_to_pair == ("PPU", "TEE"):
206
- # Transparent handshake + encryption for the first transfer; reuse thereafter
207
- assert len(frm_dev.members) == 1 and len(to_dev.members) == 1
208
- frm_rank = frm_dev.members[0].rank
209
- tee_rank = to_dev.members[0].rank
210
- platform = to_dev.config.get("platform")
211
- if not platform:
212
- raise ValueError(
213
- f"TEE device '{to_dev_id}' is missing 'platform' in its config."
214
- )
215
- # Ensure sessions (both directions) exist for this PPU<->TEE pair
216
- sess_p, sess_t = _ensure_tee_session(
217
- frm_dev_id, to_dev_id, frm_rank, tee_rank, platform
218
- )
219
- # Bytes-only path: pack -> enc -> p2p -> dec -> unpack (with static out type)
220
- obj_ty = TensorType.from_obj(obj)
221
- b = simp.runAt(frm_rank, builtin.pack)(obj)
222
- ct = simp.runAt(frm_rank, crypto.enc)(b, sess_p)
223
- ct_at_tee = mpi.p2p(frm_rank, tee_rank, ct)
224
- b_at_tee = simp.runAt(tee_rank, crypto.dec)(ct_at_tee, sess_t)
225
- pt_at_tee = simp.runAt(tee_rank, builtin.unpack)(b_at_tee, out_ty=obj_ty)
226
- return tree_map(partial(_set_devid, dev_id=to_dev_id), pt_at_tee) # type: ignore[no-any-return]
227
- elif frm_to_pair == ("TEE", "PPU"):
228
- # Transparent encryption from TEE to a specific PPU using the reverse-direction session key
229
- assert len(frm_dev.members) == 1 and len(to_dev.members) == 1
230
- tee_rank = frm_dev.members[0].rank
231
- ppu_rank = to_dev.members[0].rank
232
- platform = frm_dev.config.get("platform")
233
- if not platform:
234
- raise ValueError(
235
- f"TEE device '{frm_dev_id}' is missing 'platform' in its config."
236
- )
237
- # Ensure bidirectional session established for this pair
238
- sess_p, sess_t = _ensure_tee_session(
239
- to_dev_id, frm_dev_id, ppu_rank, tee_rank, platform
240
- )
241
- obj_ty = TensorType.from_obj(obj)
242
- b = simp.runAt(tee_rank, builtin.pack)(obj)
243
- ct = simp.runAt(tee_rank, crypto.enc)(b, sess_t)
244
- ct_at_ppu = mpi.p2p(tee_rank, ppu_rank, ct)
245
- b_at_ppu = simp.runAt(ppu_rank, crypto.dec)(ct_at_ppu, sess_p)
246
- pt_at_ppu = simp.runAt(ppu_rank, builtin.unpack)(b_at_ppu, out_ty=obj_ty)
247
- return tree_map(partial(_set_devid, dev_id=to_dev_id), pt_at_ppu) # type: ignore[no-any-return]
248
- else:
249
- supported = [
250
- ("SPU", "PPU"),
251
- ("PPU", "SPU"),
252
- ("PPU", "PPU"),
253
- ("PPU", "TEE"),
254
- ("TEE", "PPU"),
255
- ]
256
- raise ValueError(
257
- f"Unsupported device transfer: {frm_to_pair}. Supported pairs: {supported}."
258
- )
259
-
260
-
261
- def _ensure_tee_session(
262
- frm_dev_id: str, to_dev_id: str, frm_rank: int, tee_rank: int, platform: str
263
- ) -> tuple[MPObject, MPObject]:
264
- """Ensure a TEE session (sess_p at sender, sess_t at TEE) exists.
265
-
266
- Returns (sess_p, sess_t).
267
- """
268
- ctx = cur_ctx().root()
269
- if not hasattr(ctx, "_tee_sessions"):
270
- ctx._tee_sessions = {} # type: ignore[attr-defined]
271
- cache: dict[tuple[str, str], tuple[MPObject, MPObject]] = ctx._tee_sessions # type: ignore
272
-
273
- key = (frm_dev_id, to_dev_id)
274
- if key in cache:
275
- return cache[key]
276
-
277
- # 1) TEE generates (sk, pk) and quote(pk)
278
- # KEM suite currently constant; future: read from tee device config (e.g. cluster_spec.devices[dev_id].config)
279
- tee_sk, tee_pk = simp.runAt(tee_rank, crypto.kem_keygen)(_TEE_KEM_SUITE)
280
- quote = simp.runAt(tee_rank, tee.quote_gen)(tee_pk)
281
-
282
- # 2) Send quote to sender and attest to obtain TEE pk
283
- quote_at_sender = mpi.p2p(tee_rank, frm_rank, quote)
284
- tee_pk_at_sender = simp.runAt(frm_rank, tee.attest)(quote_at_sender, platform)
285
-
286
- # 3) Sender generates its ephemeral keypair and sends its pk to TEE
287
- v_sk, v_pk = simp.runAt(frm_rank, crypto.kem_keygen)(_TEE_KEM_SUITE)
288
- v_pk_at_tee = mpi.p2p(frm_rank, tee_rank, v_pk)
289
-
290
- # 4) Both sides derive the shared secret and session key
291
- shared_p = simp.runAt(frm_rank, crypto.kem_derive)(
292
- v_sk, tee_pk_at_sender, _TEE_KEM_SUITE
293
- )
294
- shared_t = simp.runAt(tee_rank, crypto.kem_derive)(
295
- tee_sk, v_pk_at_tee, _TEE_KEM_SUITE
296
- )
297
- # Use a fixed ASCII string literal for HKDF info on both sides
298
- sess_p = simp.runAt(frm_rank, crypto.hkdf)(shared_p, _HKDF_INFO_LITERAL)
299
- sess_t = simp.runAt(tee_rank, crypto.hkdf)(shared_t, _HKDF_INFO_LITERAL)
300
-
301
- cache[key] = (sess_p, sess_t)
302
- return sess_p, sess_t
303
-
304
-
305
- def put(to_dev_id: str, obj: Any) -> MPObject:
306
- if not isinstance(obj, MPObject):
307
- return device(to_dev_id)(lambda x: x)(obj) # type: ignore[no-any-return]
308
- assert isinstance(obj, MPObject)
309
- return _d2d(to_dev_id, obj)
310
-
311
-
312
- def _fetch(interp: InterpContext, obj: MPObject) -> Any:
313
- dev_id = _get_devid(obj)
314
- cluster_spec = interp.cluster_spec
315
- dev_kind = cluster_spec.devices[dev_id].kind.upper()
316
-
317
- dev_info = cluster_spec.devices[dev_id]
318
- if dev_kind == "SPU":
319
- revealed = mapi.evaluate(interp, smpc.reveal, obj)
320
- result = mapi.fetch(interp, revealed)
321
- # now all members have the same value, return the one at rank 0
322
- return result[dev_info.members[0].rank]
323
- elif dev_kind == "PPU":
324
- assert len(dev_info.members) == 1
325
- rank = dev_info.members[0].rank
326
- result = mapi.fetch(interp, obj)
327
- return result[rank]
328
- elif dev_kind == "TEE":
329
- assert len(dev_info.members) == 1
330
- rank = dev_info.members[0].rank
331
- result = mapi.fetch(interp, obj)
332
- return result[rank]
333
- else:
334
- raise ValueError(f"Unknown device id: {dev_id}")
335
-
336
-
337
- def fetch(interp: InterpContext, objs: Any) -> Any:
338
- ctx = interp or mapi.cur_ctx()
339
- assert isinstance(ctx, InterpContext), f"Expect InterpContext, got {ctx}"
340
- return tree_map(partial(_fetch, ctx), objs)
mplang/kernels/builtin.py DELETED
@@ -1,207 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from __future__ import annotations
16
-
17
- from typing import Any
18
-
19
- import numpy as np
20
- import pandas as pd
21
-
22
- from mplang.core.pfunc import PFunction
23
- from mplang.core.table import TableType
24
- from mplang.core.tensor import TensorType
25
- from mplang.kernels.base import cur_kctx, kernel_def
26
- from mplang.runtime.data_providers import get_provider, resolve_uri
27
- from mplang.utils import table_utils
28
-
29
-
30
- def _to_numpy(obj: Any) -> np.ndarray: # minimal helper to avoid duplicating logic
31
- if isinstance(obj, np.ndarray):
32
- return obj
33
- if hasattr(obj, "numpy"):
34
- try:
35
- return np.asarray(obj.numpy()) # type: ignore
36
- except Exception:
37
- pass
38
- return np.asarray(obj)
39
-
40
-
41
- @kernel_def("builtin.identity")
42
- def _identity(pfunc: PFunction, value: Any) -> Any:
43
- # Runtime guarantees exactly one argument; no extra arity checks here.
44
- return value
45
-
46
-
47
- @kernel_def("builtin.read")
48
- def _read(pfunc: PFunction) -> Any:
49
- path = pfunc.attrs.get("path")
50
- if path is None:
51
- raise ValueError("missing path attr for builtin.read")
52
- out_t = pfunc.outs_info[0]
53
- uri = resolve_uri(str(path))
54
- prov = get_provider(uri.scheme)
55
- if prov is None:
56
- raise NotImplementedError(f"no resource provider for scheme: {uri.scheme}")
57
- ctx = cur_kctx()
58
- try:
59
- return prov.read(uri, out_t, ctx=ctx)
60
- except Exception as e: # pragma: no cover - provider errors
61
- raise RuntimeError(f"builtin.read failed: {e}") from e
62
-
63
-
64
- @kernel_def("builtin.write")
65
- def _write(pfunc: PFunction, obj: Any) -> Any:
66
- path = pfunc.attrs.get("path")
67
- if path is None:
68
- raise ValueError("missing path attr for builtin.write")
69
- uri = resolve_uri(str(path))
70
- prov = get_provider(uri.scheme)
71
- if prov is None:
72
- raise NotImplementedError(f"no resource provider for scheme: {uri.scheme}")
73
- ctx = cur_kctx()
74
- try:
75
- prov.write(uri, obj, ctx=ctx)
76
- return obj
77
- except Exception as e: # pragma: no cover
78
- raise RuntimeError(f"builtin.write failed: {e}") from e
79
-
80
-
81
- @kernel_def("builtin.constant")
82
- def _constant(pfunc: PFunction) -> Any:
83
- data_bytes = pfunc.attrs.get("data_bytes")
84
- if data_bytes is None:
85
- raise ValueError("missing data_bytes attr for builtin.constant")
86
- out_t = pfunc.outs_info[0]
87
- fmt = pfunc.attrs.get("data_format")
88
- if isinstance(out_t, TableType):
89
- if fmt != "bytes[csv]":
90
- raise ValueError(f"unsupported table constant format {fmt}")
91
- df = table_utils.csv_to_dataframe(data_bytes)
92
- return df
93
- # tensor path
94
- shape = out_t.shape # type: ignore[attr-defined,union-attr]
95
- dtype = out_t.dtype.numpy_dtype() # type: ignore[attr-defined,union-attr]
96
- arr = np.frombuffer(data_bytes, dtype=dtype).reshape(shape)
97
- return arr
98
-
99
-
100
- @kernel_def("builtin.rank")
101
- def _rank(pfunc: PFunction) -> Any:
102
- ctx = cur_kctx()
103
- return np.array(ctx.rank, dtype=np.uint64)
104
-
105
-
106
- @kernel_def("builtin.prand")
107
- def _prand(pfunc: PFunction) -> Any:
108
- shape = pfunc.attrs.get("shape", ())
109
- rng = np.random.default_rng()
110
- info = np.iinfo(np.uint64)
111
- data = rng.integers(
112
- low=info.min, high=info.max, size=shape, dtype=np.uint64, endpoint=True
113
- )
114
- return data
115
-
116
-
117
- @kernel_def("builtin.table_to_tensor")
118
- def _table_to_tensor(pfunc: PFunction, table: Any) -> Any:
119
- if not isinstance(table, pd.DataFrame):
120
- raise TypeError("expected pandas DataFrame")
121
- if table.shape[1] == 0:
122
- raise ValueError("cannot pack empty table")
123
- mat = np.column_stack([table[col].to_numpy() for col in table.columns])
124
- return mat
125
-
126
-
127
- @kernel_def("builtin.tensor_to_table")
128
- def _tensor_to_table(pfunc: PFunction, tensor: Any) -> Any:
129
- arr = _to_numpy(tensor)
130
- if arr.ndim != 2:
131
- raise ValueError("tensor_to_table expects rank-2 array")
132
- col_names = pfunc.attrs.get("column_names")
133
- if col_names is None:
134
- raise ValueError("missing column_names attr")
135
- df = pd.DataFrame(arr, columns=list(col_names))
136
- return df
137
-
138
-
139
- def _summ(v: Any) -> str:
140
- try:
141
- if isinstance(v, pd.DataFrame):
142
- return str(v.head(8).to_string(index=False))
143
- arr = _to_numpy(v)
144
- return str(
145
- np.array2string(
146
- arr, threshold=64, edgeitems=3, precision=6, suppress_small=True
147
- )
148
- )
149
- except Exception as e: # pragma: no cover
150
- return f"<unprintable {type(v).__name__}: {e}>"
151
-
152
-
153
- @kernel_def("builtin.debug_print")
154
- def _debug_print(pfunc: PFunction, val: Any) -> Any:
155
- prefix = pfunc.attrs.get("prefix", "")
156
- ctx = cur_kctx()
157
- print(f"[debug_print][rank={ctx.rank}] {prefix}{_summ(val)}")
158
- return val
159
-
160
-
161
- @kernel_def("builtin.pack")
162
- def _pack(pfunc: PFunction, value: Any) -> Any:
163
- outs_info = pfunc.outs_info
164
- if len(outs_info) != 1:
165
- raise ValueError("builtin.pack expects single output type")
166
- out_ty = outs_info[0]
167
- if not isinstance(out_ty, TensorType):
168
- raise TypeError("builtin.pack must return TensorType")
169
- if out_ty.dtype.numpy_dtype() != np.uint8:
170
- raise TypeError("builtin.pack output dtype must be uint8")
171
-
172
- if isinstance(value, pd.DataFrame):
173
- csv_bytes = table_utils.dataframe_to_csv(value)
174
- return np.frombuffer(csv_bytes, dtype=np.uint8)
175
-
176
- arr = _to_numpy(value)
177
- return np.frombuffer(arr.tobytes(order="C"), dtype=np.uint8)
178
-
179
-
180
- @kernel_def("builtin.unpack")
181
- def _unpack(pfunc: PFunction, packed: Any) -> Any:
182
- outs_info = pfunc.outs_info
183
- if len(outs_info) != 1:
184
- raise ValueError("builtin.unpack expects single output type")
185
- out_ty = outs_info[0]
186
-
187
- b = np.asarray(packed, dtype=np.uint8).reshape(-1)
188
-
189
- if isinstance(out_ty, TensorType):
190
- np_dtype = out_ty.dtype.numpy_dtype()
191
- shape = tuple(out_ty.shape)
192
- if any(dim < 0 for dim in shape):
193
- raise ValueError("builtin.unpack does not support dynamic tensor shapes")
194
- elem_count = int(np.prod(shape))
195
- expected = elem_count * np.dtype(np_dtype).itemsize
196
- if b.size != expected:
197
- raise ValueError(
198
- f"unpack size mismatch: got {b.size} bytes, expect {expected} for {np_dtype} {shape}"
199
- )
200
- arr = np.frombuffer(b.tobytes(), dtype=np_dtype)
201
- return arr.reshape(shape)
202
-
203
- if isinstance(out_ty, TableType):
204
- csv_bytes = b.tobytes()
205
- return table_utils.csv_to_dataframe(csv_bytes)
206
-
207
- raise TypeError("builtin.unpack output type must be TensorType or TableType")