mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev270__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180):
  1. mplang/__init__.py +391 -17
  2. mplang/{v2/backends → backends}/__init__.py +9 -7
  3. mplang/{v2/backends → backends}/bfv_impl.py +6 -6
  4. mplang/{v2/backends → backends}/crypto_impl.py +6 -6
  5. mplang/{v2/backends → backends}/field_impl.py +5 -5
  6. mplang/{v2/backends → backends}/func_impl.py +4 -4
  7. mplang/{v2/backends → backends}/phe_impl.py +3 -3
  8. mplang/{v2/backends → backends}/simp_design.md +1 -1
  9. mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
  10. mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
  11. mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
  12. mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
  13. mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
  14. mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
  15. mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
  16. mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
  17. mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
  18. mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
  19. mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
  20. mplang/{v2/backends → backends}/spu_impl.py +8 -8
  21. mplang/{v2/backends → backends}/spu_state.py +4 -4
  22. mplang/{v2/backends → backends}/store_impl.py +3 -3
  23. mplang/{v2/backends → backends}/table_impl.py +8 -8
  24. mplang/{v2/backends → backends}/tee_impl.py +6 -6
  25. mplang/{v2/backends → backends}/tensor_impl.py +6 -6
  26. mplang/{v2/cli.py → cli.py} +9 -9
  27. mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
  28. mplang/{v2/dialects → dialects}/__init__.py +5 -5
  29. mplang/{v2/dialects → dialects}/bfv.py +6 -6
  30. mplang/{v2/dialects → dialects}/crypto.py +5 -5
  31. mplang/{v2/dialects → dialects}/dtypes.py +2 -2
  32. mplang/{v2/dialects → dialects}/field.py +3 -3
  33. mplang/{v2/dialects → dialects}/func.py +2 -2
  34. mplang/{v2/dialects → dialects}/phe.py +6 -6
  35. mplang/{v2/dialects → dialects}/simp.py +6 -6
  36. mplang/{v2/dialects → dialects}/spu.py +7 -7
  37. mplang/{v2/dialects → dialects}/store.py +2 -2
  38. mplang/{v2/dialects → dialects}/table.py +3 -3
  39. mplang/{v2/dialects → dialects}/tee.py +6 -6
  40. mplang/{v2/dialects → dialects}/tensor.py +5 -5
  41. mplang/{v2/edsl → edsl}/__init__.py +3 -3
  42. mplang/{v2/edsl → edsl}/context.py +6 -6
  43. mplang/{v2/edsl → edsl}/graph.py +5 -5
  44. mplang/{v2/edsl → edsl}/jit.py +2 -2
  45. mplang/{v2/edsl → edsl}/object.py +1 -1
  46. mplang/{v2/edsl → edsl}/primitive.py +5 -5
  47. mplang/{v2/edsl → edsl}/printer.py +1 -1
  48. mplang/{v2/edsl → edsl}/serde.py +1 -1
  49. mplang/{v2/edsl → edsl}/tracer.py +7 -7
  50. mplang/{v2/edsl → edsl}/typing.py +1 -1
  51. mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
  52. mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
  53. mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
  54. mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
  55. mplang/{v2/libs → libs}/collective.py +5 -5
  56. mplang/{v2/libs → libs}/device/__init__.py +1 -1
  57. mplang/{v2/libs → libs}/device/api.py +12 -12
  58. mplang/{v2/libs → libs}/ml/__init__.py +1 -1
  59. mplang/{v2/libs → libs}/ml/sgb.py +4 -4
  60. mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
  61. mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
  62. mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
  63. mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
  64. mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
  65. mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
  66. mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
  67. mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
  68. mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
  69. mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
  70. mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
  71. mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
  72. mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
  73. mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
  74. mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
  75. mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
  76. mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
  77. mplang/{v2/runtime → runtime}/interpreter.py +11 -11
  78. mplang/{v2/runtime → runtime}/value.py +2 -2
  79. mplang/{v1/runtime → utils}/__init__.py +18 -15
  80. mplang/{v1/utils → utils}/func_utils.py +1 -1
  81. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/METADATA +2 -2
  82. mplang_nightly-0.1.dev270.dist-info/RECORD +102 -0
  83. mplang/v1/__init__.py +0 -157
  84. mplang/v1/_device.py +0 -602
  85. mplang/v1/analysis/__init__.py +0 -37
  86. mplang/v1/analysis/diagram.py +0 -567
  87. mplang/v1/core/__init__.py +0 -157
  88. mplang/v1/core/cluster.py +0 -343
  89. mplang/v1/core/comm.py +0 -281
  90. mplang/v1/core/context_mgr.py +0 -50
  91. mplang/v1/core/dtypes.py +0 -335
  92. mplang/v1/core/expr/__init__.py +0 -80
  93. mplang/v1/core/expr/ast.py +0 -542
  94. mplang/v1/core/expr/evaluator.py +0 -581
  95. mplang/v1/core/expr/printer.py +0 -285
  96. mplang/v1/core/expr/transformer.py +0 -141
  97. mplang/v1/core/expr/utils.py +0 -78
  98. mplang/v1/core/expr/visitor.py +0 -85
  99. mplang/v1/core/expr/walk.py +0 -387
  100. mplang/v1/core/interp.py +0 -160
  101. mplang/v1/core/mask.py +0 -325
  102. mplang/v1/core/mpir.py +0 -965
  103. mplang/v1/core/mpobject.py +0 -117
  104. mplang/v1/core/mptype.py +0 -407
  105. mplang/v1/core/pfunc.py +0 -130
  106. mplang/v1/core/primitive.py +0 -877
  107. mplang/v1/core/table.py +0 -218
  108. mplang/v1/core/tensor.py +0 -75
  109. mplang/v1/core/tracer.py +0 -383
  110. mplang/v1/host.py +0 -130
  111. mplang/v1/kernels/__init__.py +0 -41
  112. mplang/v1/kernels/base.py +0 -125
  113. mplang/v1/kernels/basic.py +0 -240
  114. mplang/v1/kernels/context.py +0 -369
  115. mplang/v1/kernels/crypto.py +0 -122
  116. mplang/v1/kernels/fhe.py +0 -858
  117. mplang/v1/kernels/mock_tee.py +0 -72
  118. mplang/v1/kernels/phe.py +0 -1864
  119. mplang/v1/kernels/spu.py +0 -341
  120. mplang/v1/kernels/sql_duckdb.py +0 -44
  121. mplang/v1/kernels/stablehlo.py +0 -90
  122. mplang/v1/kernels/value.py +0 -626
  123. mplang/v1/ops/__init__.py +0 -35
  124. mplang/v1/ops/base.py +0 -424
  125. mplang/v1/ops/basic.py +0 -294
  126. mplang/v1/ops/crypto.py +0 -262
  127. mplang/v1/ops/fhe.py +0 -272
  128. mplang/v1/ops/jax_cc.py +0 -147
  129. mplang/v1/ops/nnx_cc.py +0 -168
  130. mplang/v1/ops/phe.py +0 -216
  131. mplang/v1/ops/spu.py +0 -151
  132. mplang/v1/ops/sql_cc.py +0 -303
  133. mplang/v1/ops/tee.py +0 -36
  134. mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
  135. mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
  136. mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
  137. mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
  138. mplang/v1/runtime/channel.py +0 -230
  139. mplang/v1/runtime/cli.py +0 -451
  140. mplang/v1/runtime/client.py +0 -456
  141. mplang/v1/runtime/communicator.py +0 -131
  142. mplang/v1/runtime/data_providers.py +0 -303
  143. mplang/v1/runtime/driver.py +0 -324
  144. mplang/v1/runtime/exceptions.py +0 -27
  145. mplang/v1/runtime/http_api.md +0 -56
  146. mplang/v1/runtime/link_comm.py +0 -196
  147. mplang/v1/runtime/server.py +0 -501
  148. mplang/v1/runtime/session.py +0 -270
  149. mplang/v1/runtime/simulation.py +0 -324
  150. mplang/v1/simp/__init__.py +0 -13
  151. mplang/v1/simp/api.py +0 -353
  152. mplang/v1/simp/mpi.py +0 -131
  153. mplang/v1/simp/party.py +0 -225
  154. mplang/v1/simp/random.py +0 -120
  155. mplang/v1/simp/smpc.py +0 -238
  156. mplang/v1/utils/__init__.py +0 -13
  157. mplang/v1/utils/crypto.py +0 -32
  158. mplang/v1/utils/spu_utils.py +0 -130
  159. mplang/v1/utils/table_utils.py +0 -185
  160. mplang/v2/__init__.py +0 -424
  161. mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
  162. /mplang/{v2/backends → backends}/channel.py +0 -0
  163. /mplang/{v2/edsl → edsl}/README.md +0 -0
  164. /mplang/{v2/edsl → edsl}/registry.py +0 -0
  165. /mplang/{v2/kernels → kernels}/Makefile +0 -0
  166. /mplang/{v2/kernels → kernels}/__init__.py +0 -0
  167. /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
  168. /mplang/{v2/libs → libs}/device/cluster.py +0 -0
  169. /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
  170. /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
  171. /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
  172. /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
  173. /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
  174. /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
  175. /mplang/{v2/runtime → runtime}/__init__.py +0 -0
  176. /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
  177. /mplang/{v2/runtime → runtime}/object_store.py +0 -0
  178. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/WHEEL +0 -0
  179. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/entry_points.txt +0 -0
  180. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/licenses/LICENSE +0 -0
mplang/v1/simp/party.py DELETED
@@ -1,225 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from __future__ import annotations
16
-
17
- import importlib
18
- import pathlib
19
- import pkgutil
20
- from collections.abc import Callable
21
- from functools import wraps
22
- from types import ModuleType
23
- from typing import Any
24
-
25
- from mplang.v1.ops.base import FeOperation
26
- from mplang.v1.simp.api import run_at, run_jax_at
27
- from mplang.v1.simp.mpi import p2p
28
-
29
-
30
def P2P(src: Party, dst: Party, value: Any) -> Any:
    """Send ``value`` from one party to another, addressed by Party objects.

    A readability wrapper over ``p2p``: ``P2P(P0, P2, x)`` is the same call as
    ``p2p(0, 2, x)``, without bare rank numbers in user code and tutorials.

    Parameters
    ----------
    src : Party
        Party the value is sent from.
    dst : Party
        Party the value is sent to.
    value : Any
        The value to move.

    Returns
    -------
    Any
        Whatever ``p2p`` returns for this transfer.

    Raises
    ------
    TypeError
        If ``src`` or ``dst`` is not a ``Party`` instance.
    """
    both_are_parties = isinstance(src, Party) and isinstance(dst, Party)
    if not both_are_parties:  # defensive: raw ints would silently look valid
        raise TypeError("P2P expects Party objects, e.g. P2P(P0, P2, value)")
    src_rank = src.rank
    dst_rank = dst.rank
    return p2p(src_rank, dst_rank, value)
54
-
55
-
56
# Party-scoped module registration & dispatch.
#
# This module exposes *module-like* groups of callables bound to a specific
# party (rank) through attribute access:
#
#     load_module("mplang.v1.ops.crypto", alias="crypto")
#     P0.crypto.encrypt(x)  # runs encrypt() via run_at(0, ...)
#
# Core concepts:
#   * Registry (``_NAMESPACE_REGISTRY``): maps alias -> importable module path.
#   * Lazy import: the target module is imported on first attribute access.
#   * Wrapping: fetched callables are wrapped so that invoking them routes
#     through ``run_at`` with the owning party's rank.
#
# Only *callable* attributes are exposed; non-callable attributes raise
# ``AttributeError`` to avoid surprising divergence between local and
# distributed semantics.
#
# Public surface: ``Party``, ``P`` (index form, ``P[i]``), ``P0``/``P1``/``P2``,
# ``P2P`` and ``load_module``. Internals (the proxy class and this registry)
# are not part of the stability guarantee.

# Alias used as a Party attribute (e.g. "crypto") -> fully-qualified module path.
_NAMESPACE_REGISTRY: dict[str, str] = {}
80
-
81
-
82
class _PartyModuleProxy:
    """Lazy module proxy bound to a specific party.

    Attribute access resolves a callable inside the registered module and
    returns a wrapper that executes it through ``run_at`` on the party's rank.
    Non-callable attributes are rejected explicitly to keep semantics clear.
    """

    # Instance attributes assigned in ``__init__``. ``__getattr__`` refuses
    # these by name: if they are missing, ``__init__`` never ran (e.g. the
    # instance was built by ``copy.copy`` or unpickling), and resolving them
    # through ``_ensure`` would recurse forever.
    _OWN_ATTRS = frozenset({"_party", "_name", "_module"})

    def __init__(self, party: Party, name: str):
        self._party: Party = party
        self._name: str = name
        self._module: ModuleType | None = None  # loaded lazily

    def _ensure(self) -> None:
        """Import the module registered under ``self._name`` on first use."""
        if self._module is None:
            self._module = importlib.import_module(_NAMESPACE_REGISTRY[self._name])

    def __getattr__(self, item: str) -> Callable[..., Any]:
        """Return ``item`` from the registered module, bound to this party.

        Raises:
            AttributeError: for the proxy's own (missing) slots, for dunder
                names, for names absent from the module, and for non-callable
                module attributes.
        """
        # ``__getattr__`` only runs when normal lookup fails. Dunder probes
        # (``__setstate__``, ``__deepcopy__``, ``__wrapped__``, ...) come from
        # Python protocols rather than user code and must not trigger an
        # import of the target module.
        if item in self._OWN_ATTRS or (item.startswith("__") and item.endswith("__")):
            raise AttributeError(item)

        self._ensure()
        op = getattr(self._module, item)
        if not callable(op):
            raise AttributeError(
                f"Attribute '{item}' of party module '{self._name}' is not callable (got {type(op).__name__})"
            )

        @wraps(op)
        def _wrapped(*args: Any, **kw: Any) -> Any:
            # Inline runAt to reduce an extra partial layer while preserving semantics.
            return run_at(self._party.rank, op, *args, **kw)

        # Give the wrapper a party-qualified name for debugging / logs while
        # keeping the metadata copied by ``wraps``. Objects without
        # ``__name__`` (e.g. FeOperation instances) fall back to ``.name`` and
        # then to their type name. ``_wrapped`` is always a plain function and
        # the f-string is always a str, so this assignment cannot fail.
        base_name = getattr(op, "__name__", None)
        if base_name is None:
            base_name = getattr(op, "name", None) or type(op).__name__
        _wrapped.__name__ = f"{base_name}@P{self._party.rank}"
        return _wrapped
122
-
123
-
124
class Party:
    """A single participant, identified by its integer rank.

    Calling a party runs a function on that rank. Attribute access exposes
    modules registered in ``_NAMESPACE_REGISTRY`` (see ``load_module``) as
    party-bound proxies.
    """

    def __init__(self, rank: int) -> None:
        # Coerce once so rank-like inputs (e.g. "2") are stored as plain ints.
        self.rank: int = int(rank)

    def __repr__(self) -> str:  # pragma: no cover
        return f"Party(rank={self.rank})"

    def __call__(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        """Run ``fn(*args, **kwargs)`` on this party's rank.

        ``FeOperation`` instances are dispatched through ``run_at``. Any other
        callable is treated as a JAX function and goes through ``run_jax_at``.

        Raises:
            TypeError: If ``fn`` is not callable.
        """
        if not callable(fn):
            raise TypeError(
                f"First argument to Party({self.rank}) must be callable, got {fn!r}"
            )
        # TODO(jint): implicitly assume non-FeOperation as JAX function is a bit too magical?
        runner = run_at if isinstance(fn, FeOperation) else run_jax_at
        return runner(self.rank, fn, *args, **kwargs)

    def __getattr__(self, name: str) -> _PartyModuleProxy:
        """Return a proxy for the module registered under alias ``name``."""
        if name not in _NAMESPACE_REGISTRY:
            raise AttributeError(
                f"Party has no attribute '{name}'. Registered: {list(_NAMESPACE_REGISTRY)}"
            )
        return _PartyModuleProxy(self, name)
149
-
150
-
151
class _PartyIndex:
    """Subscriptable factory: ``P[i]`` builds a new ``Party(i)`` on each lookup."""

    def __getitem__(self, rank: int) -> Party:
        party = Party(rank)
        return party
154
-
155
-
156
def _load_prelude_modules() -> None:
    """Auto-register public ``mplang.v1.ops`` submodules as party namespaces.

    Every immediate child module of ``mplang.v1.ops`` whose name does not
    start with an underscore is registered under its own name, making it
    reachable as ``P0.<name>`` without an explicit ``load_module`` call.
    Aliases that are already registered are left untouched. If
    ``mplang.v1.ops`` cannot be imported (e.g. a minimal install), nothing is
    registered.

    This keeps user ergonomics high at the cost of a slightly larger implicit
    surface; if that grows unwieldy, switch to an allowlist.
    """
    try:
        import mplang.v1.ops as _fe  # type: ignore
    except ImportError:  # pragma: no cover
        # Also covers ModuleNotFoundError, which subclasses ImportError.
        # Frontend package not present (minimal install); safe to skip.
        return

    # Use the package ``__path__`` rather than ``Path(__file__).parent``: it
    # also works for packages whose ``__file__`` is None (namespace packages)
    # or that span several directories.
    for m in pkgutil.iter_modules(_fe.__path__):
        if m.name.startswith("_"):
            continue
        _NAMESPACE_REGISTRY.setdefault(m.name, f"mplang.v1.ops.{m.name}")
177
-
178
-
179
def load_module(module: str, alias: str | None = None) -> None:
    """Register a module for party-scoped (per-rank) callable access.

    After registration, each party object (e.g. ``P0``) can access callable
    attributes of the target module through the chosen alias and have them
    executed on that party automatically. Non-callable attributes are
    intentionally not exposed to avoid ambiguity between local data and
    distributed execution semantics.

    Parameters
    ----------
    module : str
        The fully-qualified import path of the module to expose. It must be
        importable via ``importlib.import_module``.
    alias : str | None, optional
        The short name used as an attribute on ``Party``/``P0``/``P1``/... .
        If omitted, the last path segment of ``module`` is used.

    Raises
    ------
    ValueError
        If ``module`` is empty, or if the alias is empty. This happens when
        the alias is given explicitly as ``""`` or derived from a ``module``
        that ends in ``"."``.
    ValueError
        If the alias is already registered to a *different* module path.

    Notes
    -----
    Registration is idempotent when the alias maps to the same module. The
    module object is imported lazily on first attribute access, so calling
    ``load_module`` has negligible upfront cost. Empty values are rejected
    here because they would otherwise only fail later, on first use.

    Examples
    --------
    >>> load_module("mplang.v1.ops.crypto", alias="crypto")
    >>> # Now call an op on party 0
    >>> P0.crypto.encrypt(data)
    """
    if not module:
        raise ValueError("module path must be a non-empty string")
    if alias is None:
        alias = module.rsplit(".", 1)[-1]
    if not alias:
        raise ValueError(f"Alias for module '{module}' must be non-empty")
    prev = _NAMESPACE_REGISTRY.get(alias)
    if prev and prev != module:
        raise ValueError(f"Alias '{alias}' already registered for '{prev}'")
    _NAMESPACE_REGISTRY[alias] = module
220
-
221
-
222
# ``P[i]`` builds ``Party(i)`` on demand; P0..P2 are ready-made shortcuts.
P = _PartyIndex()
P0, P1, P2 = (Party(r) for r in range(3))

# Populate the namespace registry with the built-in ops modules at import time.
_load_prelude_modules()
mplang/v1/simp/random.py DELETED
@@ -1,120 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from __future__ import annotations
16
-
17
- from functools import partial
18
-
19
- import jax
20
- import jax.numpy as jnp
21
- import jax.random as jr
22
- from jax.typing import ArrayLike
23
-
24
- from mplang.v1.core import MPObject, Shape, function, pmask, psize
25
- from mplang.v1.simp.api import prand, prank, run_jax
26
-
27
-
28
@function
def key_split(key: MPObject) -> tuple[MPObject, MPObject]:
    """Split ``key`` into two keys, returned as ``(subkey, key)``."""

    def kernel(raw_key: jax.Array) -> tuple[jax.Array, jax.Array]:
        # TODO: MPObject tensors do not implement slicing yet, so
        # ``subkey, key = run_jax(jr.split, key)`` cannot be written directly;
        # the split is performed inside the traced kernel instead.
        first, second = jr.split(raw_key)
        return first, second

    return run_jax(kernel, key)  # type: ignore[no-any-return]
40
-
41
-
42
@function
def ukey(seed: int | ArrayLike) -> MPObject:
    """Derive a random key from ``seed`` on every executing party.

    Parties given the same seed derive the same key.
    """

    def kernel() -> jax.Array:
        # ``jax.random.key`` returns a typed key (dtype ``jax._src.prng.KeyTy``)
        # that MPObject cannot handle, so return the underlying key data.
        typed_key = jax.random.key(seed)
        return jax.random.key_data(typed_key)

    return run_jax(kernel)  # type: ignore[no-any-return]
52
-
53
-
54
@function
def urandint(
    key: MPObject | ArrayLike,
    low: int,
    high: int,
    shape: Shape = (),
) -> MPObject:
    """Draw random integers in ``[low, high)`` of the given ``shape`` from ``key``.

    The result is determined by ``key``, so parties holding the same key
    obtain the same values.
    """
    sampler = partial(jr.randint, minval=low, maxval=high, shape=shape)
    return run_jax(sampler, key)  # type: ignore[no-any-return]
64
-
65
-
66
# Private(different per-party) related functions begin.


@function
def prandint(low: int, high: int, shape: Shape = ()) -> MPObject:
    """Each party privately draws random integers in ``[low, high)`` of ``shape``.

    Values are taken from ``prand(shape)``, which the kernel treats as
    ``uint64``, reduced modulo ``high - low`` and shifted by ``low``. Unless
    ``high - low`` divides 2**64, this carries a small modulo bias, which is
    negligible for small ranges.

    Raises:
        ValueError: If ``high <= low``.
    """
    range_size = high - low
    # Validate before drawing randomness, so a bad range fails immediately
    # instead of after a ``prand`` op has been issued and from inside the
    # kernel trace.
    if range_size <= 0:
        raise ValueError("'high' must be greater than 'low'")

    def kernel(rand_u64: jnp.ndarray) -> jnp.ndarray:
        remainder = jax.lax.rem(rand_u64, jnp.uint64(range_size))
        return low + remainder.astype(jnp.int64)

    rand_u64 = prand(shape)
    return run_jax(kernel, rand_u64)  # type: ignore[no-any-return]
84
-
85
-
86
@function
def pperm(key: MPObject) -> MPObject:
    """Parties jointly generate a random permutation.

    Each party computes ``jr.permutation(key, world_size)`` and keeps the
    entry at its own rank. Assuming every party holds the same key (e.g. one
    produced by ``ukey``), each party ends up with one number in
    ``range(world_size)``, and together the parties hold a permutation of
    ``0..world_size-1``.

    Note: this function is NOT 'secure'; every party can compute the whole
    permutation.

    Raises:
        ValueError: If ``key`` has a dynamic pmask.
        ValueError: If ``key`` is not held by all parties.
        ValueError: If the current execution mask does not cover all parties.
    """
    if key.pmask is None:
        raise ValueError("dynamic pmask is not supported for pperm")

    size = psize()
    full_mask = (1 << size) - 1

    if key.pmask != full_mask:
        raise ValueError(
            "key must be a MPObject with mask covering all parties, "
            f"got {key.pmask} with world size {size}"
        )

    cur_mask = pmask()
    if cur_mask is None or cur_mask != full_mask:
        # Report the execution mask that failed this check (previously the
        # message showed key.pmask, which had already passed).
        raise ValueError(
            "pperm must be run with a mask covering all parties, "
            f"got {cur_mask} with world size {size}"
        )

    def kernel(key: jax.Array) -> jax.Array:
        return jr.permutation(key, size)

    perm = run_jax(kernel, key)
    rank = prank()
    return run_jax(lambda perm, rank: perm[rank], perm, rank)  # type: ignore[no-any-return]
mplang/v1/simp/smpc.py DELETED
@@ -1,238 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """
16
- SMPC on simp: conventions and object semantics
17
-
18
- Overview
19
- - simp is party-centric. Objects produced purely by simp code carry only an execution
20
- mask ("pmask") and have no security device semantics by default.
21
- - Secure semantics (secret sharing, protected execution, declassification) are introduced
22
- only when using the device API or the helpers in this module: "seal", "srun", "reveal".
23
-
24
- Definitions
25
- - "__device__" attribute is attached by the device API to indicate the concrete device
26
- an object is bound to (e.g., an SPU/TEE/PPU name). See mplang.device.DEVICE_ATTR_NAME.
27
- - pmask describes which parties currently hold/execute the value under the simp model.
28
-
29
- Conventions
30
- 1) If an object has NO "__device__" attribute (i.e., it has not gone through the device API):
31
- - It is a simp object, privately owned on the parties indicated by its pmask.
32
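- - When sealed via "seal(obj)", we infer the source PPU device(s) from pmask, then move the value onto a secure device (SPU/TEE):
33
- • one-hot pmask {pi} → bind the value to PPU(pi), then seal it onto the secure device (a one-element list is returned).
34
- • multi-party pmask → fan out per party: each party's copy is bound to that party's PPU and sealed independently onto the secure device.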
35
- - Such objects CANNOT be passed to "srun"/"reveal" directly; seal first.
36
-
37
- 2) If an object HAS a "__device__" attribute:
38
- - Its behavior follows the bound device (e.g., SPU/TEE/PPU) and its membership.
39
- - "srun" executes on that device; "reveal" declassifies from that device to the requested parties.
40
- - pmask must be consistent with the device membership during transitions; inconsistencies raise errors.
41
-
42
- Notes
43
- - "seal"/"seal_from" construct secret shares on the chosen secure device and attach the
44
- "__device__" attribute to outputs. "srun"/"reveal" assume inputs are already sealed
45
- (device-bound) and validate pmask ↔ device-membership consistency.
46
- - These rules align with "design/simp_vs_device.md" and keep routing unambiguous.
47
-
48
- Examples (obj state → interpretation/behavior)
49
- - {pmask={A}, dev_attr=None}: simp plaintext on party A. "seal" routes to PPU(A);
50
- must "seal" before "srun"/"reveal".
51
- - {pmask={A,B}, dev_attr=None}: simp plaintext held by A and B. "seal" produces two
52
- per-party sealed objects via PPU(A) and PPU(B), respectively.
53
- - {pmask={A,B}, dev_attr="spu:spu0"}: device object on SPU(spu0) whose members are {A,B};
54
- "srun" runs on spu0; "reveal(to={A})" reveals result to party A.
55
- - {pmask={A}, dev_attr="ppu:A"}: device object on PPU(A); "reveal(to={A})" returns A's plaintext.
56
- - {pmask=None, dev_attr=None}: dynamic pmask; "seal" is unsupported and will error.
57
- - {pmask={A}, dev_attr="spu:spu0"} where A ∉ members(spu0): inconsistent; operations will error.
58
- """
59
-
60
- from collections.abc import Callable
61
- from typing import Any
62
-
63
- from mplang.v1 import _device
64
- from mplang.v1.core import Mask, MPObject, Rank, psize
65
- from mplang.v1.core.cluster import Device
66
- from mplang.v1.core.context_mgr import cur_ctx
67
- from mplang.v1.core.primitive import pconv
68
- from mplang.v1.simp.api import set_mask
69
- from mplang.v1.utils.func_utils import normalize_fn
70
-
71
-
72
def _determine_secure_device(*args: MPObject) -> Device:
    """Determine the secure device that ``args`` are bound to.

    With no arguments, fall back to the first SPU device in the cluster
    specification, or to the first TEE device if there is no SPU.

    Raises:
        ValueError: If no arguments are given and the cluster has no SPU or
            TEE device.
        ValueError: If any argument is not a sealed (device) object.
        ValueError: If the arguments are bound to different devices.
        ValueError: If the bound device is missing from the cluster
            specification.
    """
    if not args:
        # Fallback when no argument pins a device: prefer SPU over TEE.
        cluster_spec = cur_ctx().cluster_spec
        for kind in ("SPU", "TEE"):
            devices = cluster_spec.get_devices_by_kind(kind)
            if devices:
                return devices[0]
        raise ValueError(
            "No secure device (SPU or TEE) found in the cluster specification"
        )

    dev_names: list[str] = []
    for arg in args:
        if not _device.is_device_obj(arg):
            raise ValueError(
                "srun/reveal expect sealed inputs with a device attribute; "
                f"got an unsealed object: {arg}. Please call seal()/seal_from() first."
            )
        dev_names.append(_device.get_dev_attr(arg))

    if len(set(dev_names)) != 1:
        raise ValueError(f"Ambiguous secure devices among arguments: {dev_names}")

    dev_name = dev_names[0]

    cluster_spec = cur_ctx().cluster_spec
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if dev_name not in cluster_spec.devices:
        raise ValueError(
            f"Device '{dev_name}' is not defined in the cluster specification"
        )
    return cluster_spec.devices[dev_name]
105
-
106
-
107
def _get_ppu_from_rank(rank: Rank) -> Device:
    """Return the PPU device whose (single) member has the given ``rank``.

    Raises:
        ValueError: If no PPU device belongs to ``rank``.
    """
    ppu_devices = cur_ctx().cluster_spec.get_devices_by_kind("PPU")
    for ppu in ppu_devices:
        assert len(ppu.members) == 1, "Expected single member PPU devices."
        if ppu.members[0].rank != rank:
            continue
        return ppu
    raise ValueError(f"No PPU device found for rank {rank}.")
114
-
115
-
116
def seal(obj: MPObject) -> list[MPObject] | MPObject:
    """Seal a simp object to a secure device.

    Args:
        obj: The object to seal. Its pmask must be static (not ``None``).

    Returns:
        * If ``obj`` is already a device object: a single object moved onto
          the secure device picked by ``_determine_secure_device()``.
        * Otherwise (a plaintext simp object): a ``list`` with one sealed
          object per rank in ``obj.pmask``. It is a list even when the pmask
          holds a single party. Each per-rank copy is first bound to that
          rank's PPU and then sealed onto the secure device.

    Raises:
        ValueError: If ``obj.pmask`` is dynamic (``None``).
        ValueError: If no secure device is available.
        ValueError: If a rank in the pmask has no PPU device.
    """

    if obj.pmask is None:
        raise ValueError("Seal does not support dynamic masks.")

    if _device.is_device_obj(obj):
        sdev = _determine_secure_device()
        return _device._d2d(sdev.name, obj)

    # Plaintext simp object: treat it as one PPU object per holding party.
    rets: list[MPObject] = []
    for rank in obj.pmask:
        ppu_obj = set_mask(obj, Mask.from_ranks([rank]))
        _device.set_dev_attr(ppu_obj, _get_ppu_from_rank(rank).name)
        # set_dev_attr marks ppu_obj as a device object, so this recursive
        # call is expected to take the device branch and return one object.
        sealed = seal(ppu_obj)
        assert isinstance(sealed, MPObject), (
            "Expected single sealed object per rank"
        )
        rets.append(sealed)
    return rets
146
-
147
-
148
def seal_from(from_rank: Rank, obj: MPObject) -> MPObject:
    """Seal the copy of ``obj`` held by one party onto the secure device.

    The object is restricted to ``from_rank`` with ``set_mask`` and passed
    to ``seal``. The value is sourced from that party's PPU, and the result
    lives on the secure device, not on the PPU.

    NOTE(review): this expects a plaintext (non-device) ``obj``. For a device
    object ``seal`` returns a single object rather than a list, and the
    assertion below would fail; confirm whether ``set_mask`` keeps the device
    attribute.

    Args:
        from_rank: The party rank whose value is sealed.
        obj: The simp object to seal.

    Returns:
        The single sealed object.
    """
    obj = set_mask(obj, Mask.from_ranks([from_rank]))
    out = seal(obj)
    # For a plaintext object ``seal`` returns one sealed value per rank in the
    # pmask; with a one-rank mask that is a one-element list.
    assert isinstance(out, list), "seal_from should return a list of sealed objects."
    assert len(out) == 1, "seal_from should return a single sealed object."
    return out[0]
163
-
164
-
165
# reveal :: s a -> m a
def reveal(obj: MPObject, to_mask: Mask | None = None) -> MPObject:
    """Reveal a sealed (device-bound) object to a set of parties.

    Args:
        obj: A device object, e.g. the output of ``seal`` or ``srun``.
        to_mask: Parties that should receive the plaintext. Defaults to every
            rank in ``range(psize())`` that has a PPU device.

    Returns:
        The per-party revealed values, combined with ``pconv``.

    Raises:
        TypeError: If ``obj`` is not an ``MPObject``.
        ValueError: If ``obj`` is not a device object.
        ValueError: If a rank in an explicit ``to_mask`` has no PPU device.
    """
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if not isinstance(obj, MPObject):
        raise TypeError(f"reveal expects an MPObject, got {type(obj).__name__}.")

    if not _device.is_device_obj(obj):
        raise ValueError(f"reveal does not support non-device object={obj}.")

    if to_mask is None:
        # Default target: every rank that owns a PPU; other ranks are skipped.
        ranks = []
        for rank in range(psize()):
            try:
                _get_ppu_from_rank(rank)
            except ValueError:
                continue
            ranks.append(rank)
        to_mask = Mask.from_ranks(ranks)
    rets = [reveal_to(rank, obj) for rank in to_mask]
    return pconv(rets)
184
-
185
-
186
def reveal_to(to_rank: Rank, obj: MPObject) -> MPObject:
    """Move a sealed object onto the PPU of ``to_rank``, revealing it to that party.

    Raises:
        ValueError: If ``obj`` is not a device object.
        ValueError: If ``to_rank`` has no PPU device.
    """
    if _device.is_device_obj(obj):
        target = _get_ppu_from_rank(to_rank)
        return _device._d2d(target.name, obj)
    raise ValueError("reveal_to expects a device object (sealed value).")
193
-
194
-
195
def srun(fe_type: str, pyfn: Callable, *args: Any, **kwargs: Any) -> Any:
    """Execute ``pyfn`` on sealed (secret-shared) inputs via secure computation.

    MPObject leaves of ``args``/``kwargs`` are flattened out. They must all be
    bound to the same secure device, and ``pyfn`` runs on that device.

    Args:
        fe_type: The front-end type, e.g., "jax".
        pyfn: The function to evaluate on the sealed values.
        *args: Positional arguments (sealed values).
        **kwargs: Keyword arguments (sealed values).

    Returns:
        The computation result, still in sealed form.

    Raises:
        ValueError: If the inputs are not sealed or name different devices.
        ValueError: If the resolved device is neither an SPU nor a TEE.
    """

    def _is_mpobject(leaf: Any) -> bool:
        return isinstance(leaf, MPObject)

    fn_flat, args_flat = normalize_fn(pyfn, args, kwargs, _is_mpobject)

    secure_dev = _determine_secure_device(*args_flat)
    kind = secure_dev.kind.upper()
    if kind not in {"SPU", "TEE"}:
        raise ValueError(f"Unsupported secure device kind: {kind}")

    runner = _device.device(secure_dev.name, fe_type=fe_type)
    return runner(fn_flat)(args_flat)
222
-
223
-
224
def srun_jax(jax_fn: Callable, *args: Any, **kwargs: Any) -> Any:
    """Shorthand for ``srun("jax", jax_fn, ...)``.

    Runs a JAX function over sealed (secret-shared) values on their secure
    device.

    Args:
        jax_fn: The JAX function to evaluate on the sealed values.
        *args: Positional arguments (sealed values).
        **kwargs: Keyword arguments (sealed values).

    Returns:
        The computation result, still in sealed form.
    """
    frontend = "jax"
    return srun(frontend, jax_fn, *args, **kwargs)
mplang/v1/utils/__init__.py DELETED
@@ -1,13 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
mplang/v1/utils/crypto.py DELETED
@@ -1,32 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """Lightweight cryptographic utilities.
16
-
17
- Currently only exposes `blake2b` which is used by mock crypto backends.
18
- Previously contained mock signing/key utilities which were unused in code
19
- paths and have been removed to avoid confusion.
20
- """
21
-
22
- from __future__ import annotations
23
-
24
- import hashlib
25
-
26
-
27
def blake2b(data: bytes) -> bytes:
    """Hash ``data`` with BLAKE2b, producing a 32-byte (256-bit) digest.

    Args:
        data: The bytes to hash.

    Returns:
        The raw 32-byte digest.
    """
    hasher = hashlib.blake2b(digest_size=32)
    hasher.update(data)
    return hasher.digest()


# ``blake2b`` is the only public name exported by this module.
__all__ = ["blake2b"]