mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. mplang/__init__.py +21 -45
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +5 -7
  7. mplang/v1/core/__init__.py +157 -0
  8. mplang/{core → v1/core}/cluster.py +30 -14
  9. mplang/{core → v1/core}/comm.py +5 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +13 -14
  14. mplang/{core → v1/core}/expr/evaluator.py +65 -24
  15. mplang/{core → v1/core}/expr/printer.py +24 -18
  16. mplang/{core → v1/core}/expr/transformer.py +3 -3
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +23 -16
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +4 -4
  25. mplang/{core → v1/core}/primitive.py +106 -201
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{api.py → v1/host.py} +38 -6
  30. mplang/v1/kernels/__init__.py +41 -0
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/v1/kernels/basic.py +240 -0
  33. mplang/{kernels → v1/kernels}/context.py +42 -27
  34. mplang/{kernels → v1/kernels}/crypto.py +44 -37
  35. mplang/v1/kernels/fhe.py +858 -0
  36. mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
  37. mplang/{kernels → v1/kernels}/phe.py +263 -57
  38. mplang/{kernels → v1/kernels}/spu.py +137 -48
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
  40. mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
  41. mplang/v1/kernels/value.py +626 -0
  42. mplang/{ops → v1/ops}/__init__.py +5 -16
  43. mplang/{ops → v1/ops}/base.py +2 -5
  44. mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/v1/ops/fhe.py +272 -0
  47. mplang/{ops → v1/ops}/jax_cc.py +33 -68
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -4
  50. mplang/{ops → v1/ops}/spu.py +3 -5
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +9 -24
  53. mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
  54. mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
  55. mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
  56. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  57. mplang/v1/runtime/channel.py +230 -0
  58. mplang/{runtime → v1/runtime}/cli.py +35 -20
  59. mplang/{runtime → v1/runtime}/client.py +19 -8
  60. mplang/{runtime → v1/runtime}/communicator.py +59 -15
  61. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  62. mplang/{runtime → v1/runtime}/driver.py +30 -12
  63. mplang/v1/runtime/link_comm.py +196 -0
  64. mplang/{runtime → v1/runtime}/server.py +58 -42
  65. mplang/{runtime → v1/runtime}/session.py +57 -71
  66. mplang/{runtime → v1/runtime}/simulation.py +55 -28
  67. mplang/v1/simp/api.py +353 -0
  68. mplang/{simp → v1/simp}/mpi.py +8 -9
  69. mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
  70. mplang/{simp → v1/simp}/random.py +21 -22
  71. mplang/v1/simp/smpc.py +238 -0
  72. mplang/v1/utils/table_utils.py +185 -0
  73. mplang/v2/__init__.py +424 -0
  74. mplang/v2/backends/__init__.py +57 -0
  75. mplang/v2/backends/bfv_impl.py +705 -0
  76. mplang/v2/backends/channel.py +217 -0
  77. mplang/v2/backends/crypto_impl.py +723 -0
  78. mplang/v2/backends/field_impl.py +454 -0
  79. mplang/v2/backends/func_impl.py +107 -0
  80. mplang/v2/backends/phe_impl.py +148 -0
  81. mplang/v2/backends/simp_design.md +136 -0
  82. mplang/v2/backends/simp_driver/__init__.py +41 -0
  83. mplang/v2/backends/simp_driver/http.py +168 -0
  84. mplang/v2/backends/simp_driver/mem.py +280 -0
  85. mplang/v2/backends/simp_driver/ops.py +135 -0
  86. mplang/v2/backends/simp_driver/state.py +60 -0
  87. mplang/v2/backends/simp_driver/values.py +52 -0
  88. mplang/v2/backends/simp_worker/__init__.py +29 -0
  89. mplang/v2/backends/simp_worker/http.py +354 -0
  90. mplang/v2/backends/simp_worker/mem.py +102 -0
  91. mplang/v2/backends/simp_worker/ops.py +167 -0
  92. mplang/v2/backends/simp_worker/state.py +49 -0
  93. mplang/v2/backends/spu_impl.py +275 -0
  94. mplang/v2/backends/spu_state.py +187 -0
  95. mplang/v2/backends/store_impl.py +62 -0
  96. mplang/v2/backends/table_impl.py +838 -0
  97. mplang/v2/backends/tee_impl.py +215 -0
  98. mplang/v2/backends/tensor_impl.py +519 -0
  99. mplang/v2/cli.py +603 -0
  100. mplang/v2/cli_guide.md +122 -0
  101. mplang/v2/dialects/__init__.py +36 -0
  102. mplang/v2/dialects/bfv.py +665 -0
  103. mplang/v2/dialects/crypto.py +689 -0
  104. mplang/v2/dialects/dtypes.py +378 -0
  105. mplang/v2/dialects/field.py +210 -0
  106. mplang/v2/dialects/func.py +135 -0
  107. mplang/v2/dialects/phe.py +723 -0
  108. mplang/v2/dialects/simp.py +944 -0
  109. mplang/v2/dialects/spu.py +349 -0
  110. mplang/v2/dialects/store.py +63 -0
  111. mplang/v2/dialects/table.py +407 -0
  112. mplang/v2/dialects/tee.py +346 -0
  113. mplang/v2/dialects/tensor.py +1175 -0
  114. mplang/v2/edsl/README.md +279 -0
  115. mplang/v2/edsl/__init__.py +99 -0
  116. mplang/v2/edsl/context.py +311 -0
  117. mplang/v2/edsl/graph.py +463 -0
  118. mplang/v2/edsl/jit.py +62 -0
  119. mplang/v2/edsl/object.py +53 -0
  120. mplang/v2/edsl/primitive.py +284 -0
  121. mplang/v2/edsl/printer.py +119 -0
  122. mplang/v2/edsl/registry.py +207 -0
  123. mplang/v2/edsl/serde.py +375 -0
  124. mplang/v2/edsl/tracer.py +614 -0
  125. mplang/v2/edsl/typing.py +816 -0
  126. mplang/v2/kernels/Makefile +30 -0
  127. mplang/v2/kernels/__init__.py +23 -0
  128. mplang/v2/kernels/gf128.cpp +148 -0
  129. mplang/v2/kernels/ldpc.cpp +82 -0
  130. mplang/v2/kernels/okvs.cpp +283 -0
  131. mplang/v2/kernels/okvs_opt.cpp +291 -0
  132. mplang/v2/kernels/py_kernels.py +398 -0
  133. mplang/v2/libs/collective.py +330 -0
  134. mplang/v2/libs/device/__init__.py +51 -0
  135. mplang/v2/libs/device/api.py +813 -0
  136. mplang/v2/libs/device/cluster.py +352 -0
  137. mplang/v2/libs/ml/__init__.py +23 -0
  138. mplang/v2/libs/ml/sgb.py +1861 -0
  139. mplang/v2/libs/mpc/__init__.py +41 -0
  140. mplang/v2/libs/mpc/_utils.py +99 -0
  141. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  142. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  143. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  144. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  145. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  146. mplang/v2/libs/mpc/common/constants.py +39 -0
  147. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  148. mplang/v2/libs/mpc/ot/base.py +222 -0
  149. mplang/v2/libs/mpc/ot/extension.py +477 -0
  150. mplang/v2/libs/mpc/ot/silent.py +217 -0
  151. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  152. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  153. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  154. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  155. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  156. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  157. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  158. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  159. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  160. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  161. mplang/v2/libs/mpc/vole/silver.py +336 -0
  162. mplang/v2/runtime/__init__.py +15 -0
  163. mplang/v2/runtime/dialect_state.py +41 -0
  164. mplang/v2/runtime/interpreter.py +871 -0
  165. mplang/v2/runtime/object_store.py +194 -0
  166. mplang/v2/runtime/value.py +141 -0
  167. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
  168. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  169. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  170. mplang/core/__init__.py +0 -92
  171. mplang/device.py +0 -340
  172. mplang/kernels/builtin.py +0 -207
  173. mplang/ops/crypto.py +0 -109
  174. mplang/ops/ibis_cc.py +0 -139
  175. mplang/ops/sql.py +0 -61
  176. mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
  177. mplang/runtime/link_comm.py +0 -131
  178. mplang/simp/smpc.py +0 -201
  179. mplang/utils/table_utils.py +0 -73
  180. mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
  181. /mplang/{core → v1/core}/mask.py +0 -0
  182. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  183. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  184. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  185. /mplang/{kernels → v1/simp}/__init__.py +0 -0
  186. /mplang/{utils → v1/utils}/__init__.py +0 -0
  187. /mplang/{utils → v1/utils}/crypto.py +0 -0
  188. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  189. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  190. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  191. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -21,26 +21,25 @@ import jax.numpy as jnp
21
21
  import jax.random as jr
22
22
  from jax.typing import ArrayLike
23
23
 
24
- import mplang.core.primitive as prim
25
- from mplang import simp
26
- from mplang.core import MPObject, Shape
24
+ from mplang.v1.core import MPObject, Shape, function, pmask, psize
25
+ from mplang.v1.simp.api import prand, prank, run_jax
27
26
 
28
27
 
29
@function
def key_split(key: MPObject) -> tuple[MPObject, MPObject]:
    """Split ``key`` into a ``(subkey, key)`` pair on the parties holding it."""

    def kernel(raw_key: jax.Array) -> tuple[jax.Array, jax.Array]:
        # MPObject tensors cannot be sliced yet, so unpacking the output of
        # `run_jax(jr.split, key)` is not possible; the split is performed
        # inside the traced kernel instead.
        fresh_subkey, next_key = jr.split(raw_key)
        return fresh_subkey, next_key

    return run_jax(kernel, key)  # type: ignore[no-any-return]
41
40
 
42
41
 
43
- @prim.function
42
+ @function
44
43
  def ukey(seed: int | ArrayLike) -> MPObject:
45
44
  """Party uniformly generate a random key."""
46
45
 
@@ -49,10 +48,10 @@ def ukey(seed: int | ArrayLike) -> MPObject:
49
48
  # Note: key.dtype is jax._src.prng.KeyTy, which could not be handled by MPObject.
50
49
  return jax.random.key_data(key)
51
50
 
52
- return simp.run(kernel)() # type: ignore[no-any-return]
51
+ return run_jax(kernel) # type: ignore[no-any-return]
53
52
 
54
53
 
55
- @prim.function
54
+ @function
56
55
  def urandint(
57
56
  key: MPObject | ArrayLike,
58
57
  low: int,
@@ -61,13 +60,13 @@ def urandint(
61
60
  ) -> MPObject:
62
61
  """Party uniformly generate a random integer in the range [low, high) with the given shape."""
63
62
 
64
- return simp.run(partial(jr.randint, minval=low, maxval=high, shape=shape))(key) # type: ignore[no-any-return]
63
+ return run_jax(partial(jr.randint, minval=low, maxval=high, shape=shape), key) # type: ignore[no-any-return]
65
64
 
66
65
 
67
66
  # Private(different per-party) related functions begin.
68
67
 
69
68
 
70
- @prim.function
69
+ @function
71
70
  def prandint(low: int, high: int, shape: Shape = ()) -> MPObject:
72
71
  """Party privately generate a random integer in the range [low, high) with the given shape."""
73
72
 
@@ -80,11 +79,11 @@ def prandint(low: int, high: int, shape: Shape = ()) -> MPObject:
80
79
  result = low + remainder.astype(jnp.int64)
81
80
  return result
82
81
 
83
- rand_u64 = prim.prand(shape)
84
- return simp.run(kernel)(rand_u64) # type: ignore[no-any-return]
82
+ rand_u64 = prand(shape)
83
+ return run_jax(kernel, rand_u64) # type: ignore[no-any-return]
85
84
 
86
85
 
87
- @prim.function
86
+ @function
88
87
  def pperm(key: MPObject) -> MPObject:
89
88
  """Party jointly generate a random permutation.
90
89
 
@@ -97,25 +96,25 @@ def pperm(key: MPObject) -> MPObject:
97
96
  if key.pmask is None:
98
97
  raise ValueError("dynamic pmask is not supported for pperm")
99
98
 
100
- full_mask = (1 << prim.psize()) - 1
99
+ full_mask = (1 << psize()) - 1
101
100
 
102
101
  if key.pmask != full_mask:
103
102
  raise ValueError(
104
103
  "key must be a MPObject with mask covering all parties, "
105
- f"got {key.pmask} with world size {prim.psize()}"
104
+ f"got {key.pmask} with world size {psize()}"
106
105
  )
107
106
 
108
- if prim.pmask() is None or prim.pmask() != full_mask:
107
+ if pmask() is None or pmask() != full_mask:
109
108
  raise ValueError(
110
109
  "pperm must be run with a mask covering all parties, "
111
- f"got {key.pmask} with world size {prim.psize()}"
110
+ f"got {key.pmask} with world size {psize()}"
112
111
  )
113
112
 
114
- size = prim.psize()
113
+ size = psize()
115
114
 
116
115
  def kernel(key: jax.Array) -> jax.Array:
117
116
  return jr.permutation(key, size)
118
117
 
119
- perm = simp.run(kernel)(key)
120
- rank = prim.prank()
121
- return simp.run(lambda perm, rank: perm[rank])(perm, rank) # type: ignore[no-any-return]
118
+ perm = run_jax(kernel, key)
119
+ rank = prank()
120
+ return run_jax(lambda perm, rank: perm[rank], perm, rank) # type: ignore[no-any-return]
mplang/v1/simp/smpc.py ADDED
@@ -0,0 +1,238 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ SMPC on simp: conventions and object semantics
17
+
18
+ Overview
19
+ - simp is party-centric. Objects produced purely by simp code carry only an execution
20
+ mask ("pmask") and have no security device semantics by default.
21
+ - Secure semantics (secret sharing, protected execution, declassification) are introduced
22
+ only when using the device API or the helpers in this module: "seal", "srun", "reveal".
23
+
24
+ Definitions
25
+ - "__device__" attribute is attached by the device API to indicate the concrete device
26
+ an object is bound to (e.g., an SPU/TEE/PPU name). See mplang.v1._device (get_dev_attr/set_dev_attr).
27
+ - pmask describes which parties currently hold/execute the value under the simp model.
28
+
29
+ Conventions
30
+ 1) If an object has NO "__device__" attribute (i.e., it has not gone through the device API):
31
+ - It is a simp object, privately owned on the parties indicated by its pmask.
32
+ - When sealed via "seal(obj)", we infer target PPU device(s) from pmask:
33
+ • one-hot pmask {pi} → route to PPU(pi).
34
+ • multi-party pmask → fan out per party and seal independently to each party's PPU.
35
+ - Such objects CANNOT be passed to "srun"/"reveal" directly; seal first.
36
+
37
+ 2) If an object HAS a "__device__" attribute:
38
+ - Its behavior follows the bound device (e.g., SPU/TEE/PPU) and its membership.
39
+ - "srun" executes on that device; "reveal" declassifies from that device to the requested parties.
40
+ - pmask must be consistent with the device membership during transitions; inconsistencies raise errors.
41
+
42
+ Notes
43
+ - "seal"/"seal_from" construct secret shares on the chosen secure device and attach the
44
+ "__device__" attribute to outputs. "srun"/"reveal" assume inputs are already sealed
45
+ (device-bound) and validate pmask ↔ device-membership consistency.
46
+ - These rules align with "design/simp_vs_device.md" and keep routing unambiguous.
47
+
48
+ Examples (obj state → interpretation/behavior)
49
+ - {pmask={A}, dev_attr=None}: simp plaintext on party A. "seal" routes to PPU(A);
50
+ must "seal" before "srun"/"reveal".
51
+ - {pmask={A,B}, dev_attr=None}: simp plaintext held by A and B. "seal" produces two
52
+ per-party sealed objects via PPU(A) and PPU(B), respectively.
53
+ - {pmask={A,B}, dev_attr="spu:spu0"}: device object on SPU(spu0) whose members are {A,B};
54
+ "srun" runs on spu0; "reveal(to={A})" reveals result to party A.
55
+ - {pmask={A}, dev_attr="ppu:A"}: device object on PPU(A); "reveal(to={A})" returns A's plaintext.
56
+ - {pmask=None, dev_attr=None}: dynamic pmask; "seal" is unsupported and will error.
57
+ - {pmask={A}, dev_attr="spu:spu0"} where A ∉ members(spu0): inconsistent; operations will error.
58
+ """
59
+
60
+ from collections.abc import Callable
61
+ from typing import Any
62
+
63
+ from mplang.v1 import _device
64
+ from mplang.v1.core import Mask, MPObject, Rank, psize
65
+ from mplang.v1.core.cluster import Device
66
+ from mplang.v1.core.context_mgr import cur_ctx
67
+ from mplang.v1.core.primitive import pconv
68
+ from mplang.v1.simp.api import set_mask
69
+ from mplang.v1.utils.func_utils import normalize_fn
70
+
71
+
72
def _determine_secure_device(*args: MPObject) -> Device:
    """Determine the device a secure computation should run on.

    With no ``args``, return the first SPU device in the cluster spec, or the
    first TEE device if there is no SPU.

    With ``args``, every argument must be device-bound (sealed) and all of them
    must be bound to the same device; that device is returned. Its kind is not
    checked here, so callers that need a secure device must check it.

    Args:
        *args: Sealed objects whose device binding selects the device.

    Returns:
        The selected ``Device`` from the current context's cluster spec.

    Raises:
        ValueError: if no SPU/TEE exists (no-arg case), if an argument is not a
            device object, if arguments are bound to different devices, or if
            the bound device name is missing from the cluster spec.
    """
    if not args:
        # Fallback when no args are provided: prefer SPU, then TEE.
        cluster_spec = cur_ctx().cluster_spec
        for kind in ("SPU", "TEE"):
            devices = cluster_spec.get_devices_by_kind(kind)
            if devices:
                return devices[0]
        raise ValueError(
            "No secure device (SPU or TEE) found in the cluster specification"
        )

    dev_names: list[str] = []
    for arg in args:
        if not _device.is_device_obj(arg):
            raise ValueError(
                "srun/reveal expect sealed inputs with a device attribute; "
                f"got an unsealed object: {arg}. Please call seal()/seal_from() first."
            )
        dev_names.append(_device.get_dev_attr(arg))

    if len(set(dev_names)) != 1:
        raise ValueError(f"Ambiguous secure devices among arguments: {dev_names}")

    dev_name = dev_names[0]

    cluster_spec = cur_ctx().cluster_spec
    # Explicit check rather than `assert`, which is stripped under `python -O`
    # and would turn this into an opaque KeyError.
    if dev_name not in cluster_spec.devices:
        raise ValueError(
            f"Device {dev_name!r} is not defined in the cluster specification"
        )
    return cluster_spec.devices[dev_name]
105
+
106
+
107
def _get_ppu_from_rank(rank: Rank) -> Device:
    """Return the PPU device whose single member has the given ``rank``.

    Raises:
        ValueError: if no PPU device in the cluster spec belongs to ``rank``.
    """
    ppu_devices = cur_ctx().cluster_spec.get_devices_by_kind("PPU")
    for candidate in ppu_devices:
        # Every PPU is expected to belong to exactly one party.
        assert len(candidate.members) == 1, "Expected single member PPU devices."
        if candidate.members[0].rank != rank:
            continue
        return candidate
    raise ValueError(f"No PPU device found for rank {rank}.")
114
+
115
+
116
def seal(obj: MPObject) -> list[MPObject] | MPObject:
    """Seal an object onto a secure device.

    Args:
        obj: The object to seal. Its pmask must be static (not ``None``).

    Returns:
        - If ``obj`` is already device-bound: a single ``MPObject`` moved to the
          default secure device (the first SPU in the cluster spec, else the
          first TEE).
        - Otherwise (a plaintext simp object): a ``list`` with one sealed object
          per rank in ``obj.pmask``, in pmask iteration order. The result is a
          list even when the pmask names a single party; ``seal_from`` relies
          on this.

    Raises:
        ValueError: if ``obj.pmask`` is ``None`` (dynamic mask), if a rank in
            the pmask has no PPU device, or if no SPU/TEE device exists.
    """
    if obj.pmask is None:
        raise ValueError("Seal does not support dynamic masks.")

    if _device.is_device_obj(obj):
        sdev = _determine_secure_device()
        return _device._d2d(sdev.name, obj)

    # Plaintext simp object: view each holder's copy as an object on that
    # party's PPU, then seal each one through the device branch above.
    rets: list[MPObject] = []
    for rank in obj.pmask:
        ppu_obj = set_mask(obj, Mask.from_ranks([rank]))
        _device.set_dev_attr(ppu_obj, _get_ppu_from_rank(rank).name)
        sealed = seal(ppu_obj)
        assert isinstance(sealed, MPObject), (
            "Expected single sealed object per rank"
        )
        rets.append(sealed)
    return rets
146
+
147
+
148
def seal_from(from_rank: Rank, obj: MPObject) -> MPObject:
    """Seal the copy of a simp object held by one party onto a secure device.

    The object is first restricted to ``from_rank`` with ``set_mask``. It is
    then sealed with ``seal``, which routes it through that party's PPU and
    moves it to the default secure device (the first SPU in the cluster spec,
    else the first TEE).

    NOTE(review): ``obj`` is expected to be a plaintext (non-device) simp
    object. Only in that case does ``seal`` return the list this function
    unpacks.

    Args:
        from_rank: The party rank whose copy of ``obj`` is sealed.
        obj: The simp object to seal.

    Returns:
        The single sealed (device-bound) object.
    """
    obj = set_mask(obj, Mask.from_ranks([from_rank]))
    out = seal(obj)
    assert isinstance(out, list), "seal_from should return a list of sealed objects."
    assert len(out) == 1, "seal_from should return a single sealed object."
    return out[0]
163
+
164
+
165
# reveal :: s a -> m a
def reveal(obj: MPObject, to_mask: Mask | None = None) -> MPObject:
    """Reveal a sealed object to a set of parties.

    Args:
        obj: A sealed (device-bound) object.
        to_mask: Ranks that should receive the plaintext. Defaults to every
            rank in ``range(psize())`` that owns a PPU device.

    Returns:
        The per-rank ``reveal_to`` results combined with ``pconv``.

    Raises:
        ValueError: if ``obj`` is not a device object, or if there is no target
            rank to reveal to.
    """
    assert isinstance(obj, MPObject), "reveal expects an MPObject."

    if not _device.is_device_obj(obj):
        raise ValueError(f"reveal does not support non-device object={obj}.")

    if to_mask is None:
        ranks = []
        for rank in range(psize()):
            try:
                _get_ppu_from_rank(rank)
            except ValueError:
                # This rank has no PPU, so it cannot receive plaintext.
                continue
            ranks.append(rank)
        to_mask = Mask.from_ranks(ranks)

    rets = [reveal_to(rank, obj) for rank in to_mask]
    if not rets:
        # Fail with an explicit message instead of handing pconv an empty list.
        raise ValueError(
            "reveal has no target parties: to_mask is empty or no PPU device exists."
        )
    return pconv(rets)
184
+
185
+
186
def reveal_to(to_rank: Rank, obj: MPObject) -> MPObject:
    """Declassify a sealed object to a single party.

    The object is moved (via ``_d2d``) onto the PPU device owned by ``to_rank``.

    Raises:
        ValueError: if ``obj`` is not a device object, or ``to_rank`` has no PPU.
    """
    if _device.is_device_obj(obj):
        target_ppu = _get_ppu_from_rank(to_rank)
        return _device._d2d(target_ppu.name, obj)
    raise ValueError("reveal_to expects a device object (sealed value).")
193
+
194
+
195
def srun(fe_type: str, pyfn: Callable, *args: Any, **kwargs: Any) -> Any:
    """Execute ``pyfn`` over sealed values on their secure device.

    The target device is the one all sealed ``MPObject`` arguments are bound
    to. If there are none, it is the first SPU (else the first TEE) in the
    cluster spec. Only SPU and TEE devices are accepted.

    Args:
        fe_type: The front-end type used to compile ``pyfn``, e.g. ``"jax"``.
        pyfn: The function to execute.
        *args: Positional arguments; ``MPObject`` leaves must be sealed.
        **kwargs: Keyword arguments; ``MPObject`` leaves must be sealed.

    Returns:
        The device-bound (sealed) result of the computation.

    Raises:
        ValueError: if the arguments are unsealed or bound to different
            devices, or if the selected device is neither SPU nor TEE.
    """

    def _is_mp_leaf(candidate: Any) -> bool:
        return isinstance(candidate, MPObject)

    # normalize_fn separates the MPObject leaves (chosen by the predicate)
    # from the rest of the call and yields a function over those leaves.
    fn_flat, args_flat = normalize_fn(pyfn, args, kwargs, _is_mp_leaf)

    secure_dev = _determine_secure_device(*args_flat)

    dev_kind = secure_dev.kind.upper()
    if dev_kind not in {"SPU", "TEE"}:
        raise ValueError(f"Unsupported secure device kind: {dev_kind}")

    # The flat leaves are handed to the compiled function as a single list.
    device_fn = _device.device(secure_dev.name, fe_type=fe_type)(fn_flat)
    return device_fn(args_flat)
222
+
223
+
224
def srun_jax(jax_fn: Callable, *args: Any, **kwargs: Any) -> Any:
    """Shorthand for ``srun`` with the ``"jax"`` front end.

    Args:
        jax_fn: A JAX function to execute on the secure device.
        *args: Positional arguments (sealed values).
        **kwargs: Keyword arguments (sealed values).

    Returns:
        The device-bound (sealed) result of the computation.
    """
    frontend = "jax"
    return srun(frontend, jax_fn, *args, **kwargs)
mplang/v1/utils/table_utils.py ADDED
@@ -0,0 +1,185 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import io
18
+ from typing import Any
19
+
20
+ import pyarrow as pa
21
+ import pyarrow.csv as pa_csv
22
+ import pyarrow.orc as pa_orc
23
+ import pyarrow.parquet as pa_pq
24
+
25
+ from mplang.v1.core.table import TableLike
26
+
27
# Public API of this module.
__all__ = [
    "decode_table",
    "encode_table",
    "read_table",
    "write_table",
]
28
+
29
+
30
+ def _parse_kwargs(kwargs: dict[str, Any], keys: list[str]) -> dict[str, Any] | None:
31
+ if not kwargs:
32
+ return None
33
+
34
+ return {key: kwargs[key] for key in keys if key in kwargs}
35
+
36
+
37
+ _csv_read_option_keys = [
38
+ "skip_rows",
39
+ "skip_rows_after_names",
40
+ "column_names",
41
+ "autogenerate_column_names",
42
+ "encoding",
43
+ ]
44
+ _csv_parse_option_keys = [
45
+ "delimiter",
46
+ "quote_char",
47
+ "double_quote",
48
+ "escape_char",
49
+ "newlines_in_values",
50
+ "ignore_empty_lines",
51
+ ]
52
+ _csv_convert_option_keys = [
53
+ "check_utf8",
54
+ "column_types",
55
+ "null_values",
56
+ "true_values",
57
+ "false_values",
58
+ "decimal_point",
59
+ "strings_can_be_null",
60
+ "quoted_strings_can_be_null",
61
+ "include_columns",
62
+ "include_missing_columns",
63
+ "auto_dict_encode",
64
+ "auto_dict_max_cardinality",
65
+ "timestamp_parsers",
66
+ ]
67
+
68
+
69
+ def read_table(
70
+ source: Any,
71
+ format: str = "parquet",
72
+ columns: list[str] | None = None,
73
+ **kwargs: Any,
74
+ ) -> pa.Table:
75
+ """Read data from a file and return a PyArrow table.
76
+
77
+ Args:
78
+ source: The source to read data from (file path, file-like object, etc.)
79
+ format: The format of the data source ("parquet", "csv", or "orc")
80
+ columns: List of column names to read (None means all columns)
81
+ **kwargs: Additional keyword arguments passed to the underlying reader
82
+
83
+ Returns:
84
+ A PyArrow Table containing the data from the source
85
+
86
+ Raises:
87
+ ValueError: If an unsupported format is specified
88
+ """
89
+ match format:
90
+ case "csv":
91
+ if columns:
92
+ kwargs["include_columns"] = columns
93
+ read_args = _parse_kwargs(kwargs, _csv_read_option_keys)
94
+ parse_args = _parse_kwargs(kwargs, _csv_parse_option_keys)
95
+ convert_args = _parse_kwargs(kwargs, _csv_convert_option_keys)
96
+
97
+ read_opts = pa_csv.ReadOptions(**read_args) if read_args else None
98
+ parse_opts = pa_csv.ParseOptions(**parse_args) if parse_args else None
99
+ conv_opts = pa_csv.ConvertOptions(**convert_args) if convert_args else None
100
+ return pa_csv.read_csv(
101
+ source,
102
+ read_options=read_opts,
103
+ parse_options=parse_opts,
104
+ convert_options=conv_opts,
105
+ )
106
+ case "orc":
107
+ return pa_orc.read_table(source, columns=columns, **kwargs)
108
+ case "parquet":
109
+ return pa_pq.read_table(source, columns=columns, **kwargs)
110
+ case _:
111
+ raise ValueError(f"unsupported data format. {format}")
112
+
113
+
114
def write_table(
    data: TableLike,
    where: Any,
    format: str = "parquet",
    **kwargs: Any,
) -> None:
    """Serialize a table-like object to ``where`` in the requested format.

    Args:
        data: A ``pa.Table``, or anything ``pa.table`` can convert into one.
        where: The destination (file path, file-like object, etc.).
        format: One of "parquet", "csv" or "orc".
        **kwargs: Extra options. For "csv" they build a
            ``pyarrow.csv.WriteOptions``; for "orc"/"parquet" they are passed
            to the pyarrow writer directly.

    Raises:
        ValueError: If the table has no columns or the format is unsupported.
    """
    if isinstance(data, pa.Table):
        table = data
    else:
        table = pa.table(data)

    if not table.column_names:
        raise ValueError("Cannot convert Table with no columns.")

    if format == "csv":
        write_options = pa_csv.WriteOptions(**kwargs) if kwargs else None
        pa_csv.write_csv(table, where, write_options=write_options)
    elif format == "orc":
        pa_orc.write_table(table, where, **kwargs)
    elif format == "parquet":
        pa_pq.write_table(table, where, **kwargs)
    else:
        raise ValueError(f"unsupported data format. {format}")
149
+
150
+
151
def decode_table(
    data: bytes,
    format: str = "parquet",
    columns: list[str] | None = None,
    **kwargs: Any,
) -> pa.Table:
    """Parse an in-memory encoded table back into a PyArrow table.

    The bytes are wrapped in an ``io.BytesIO`` and handed to ``read_table``.

    Args:
        data: The encoded table bytes.
        format: The encoding of ``data`` ("parquet", "csv" or "orc").
        columns: Column names to keep; ``None`` keeps every column.
        **kwargs: Forwarded unchanged to ``read_table``.

    Returns:
        The decoded PyArrow Table.
    """
    return read_table(io.BytesIO(data), format=format, columns=columns, **kwargs)
170
+
171
+
172
def encode_table(data: TableLike, format: str = "parquet", **kwargs: Any) -> bytes:
    """Serialize a table-like object into an in-memory byte string.

    The table is written with ``write_table`` into an ``io.BytesIO`` sink, and
    the sink's full contents are returned.

    Args:
        data: A ``pa.Table``, or anything ``pa.table`` can convert into one.
        format: Target encoding ("parquet", "csv" or "orc").
        **kwargs: Forwarded unchanged to ``write_table``.

    Returns:
        The encoded table bytes.
    """
    sink = io.BytesIO()
    write_table(data, sink, format, **kwargs)
    encoded = sink.getvalue()
    return encoded