mplang-nightly 0.1.dev156__py3-none-any.whl → 0.1.dev158__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mplang/device.py CHANGED
@@ -207,8 +207,15 @@ def _d2d(to_dev_id: str, obj: MPObject) -> MPObject:
207
207
  assert len(frm_dev.members) == 1 and len(to_dev.members) == 1
208
208
  frm_rank = frm_dev.members[0].rank
209
209
  tee_rank = to_dev.members[0].rank
210
+ platform = to_dev.config.get("platform")
211
+ if not platform:
212
+ raise ValueError(
213
+ f"TEE device '{to_dev_id}' is missing 'platform' in its config."
214
+ )
210
215
  # Ensure sessions (both directions) exist for this PPU<->TEE pair
211
- sess_p, sess_t = _ensure_tee_session(frm_dev_id, to_dev_id, frm_rank, tee_rank)
216
+ sess_p, sess_t = _ensure_tee_session(
217
+ frm_dev_id, to_dev_id, frm_rank, tee_rank, platform
218
+ )
212
219
  # Bytes-only path: pack -> enc -> p2p -> dec -> unpack (with static out type)
213
220
  obj_ty = TensorType.from_obj(obj)
214
221
  b = simp.runAt(frm_rank, builtin.pack)(obj)
@@ -222,8 +229,15 @@ def _d2d(to_dev_id: str, obj: MPObject) -> MPObject:
222
229
  assert len(frm_dev.members) == 1 and len(to_dev.members) == 1
223
230
  tee_rank = frm_dev.members[0].rank
224
231
  ppu_rank = to_dev.members[0].rank
232
+ platform = frm_dev.config.get("platform")
233
+ if not platform:
234
+ raise ValueError(
235
+ f"TEE device '{frm_dev_id}' is missing 'platform' in its config."
236
+ )
225
237
  # Ensure bidirectional session established for this pair
226
- sess_p, sess_t = _ensure_tee_session(to_dev_id, frm_dev_id, ppu_rank, tee_rank)
238
+ sess_p, sess_t = _ensure_tee_session(
239
+ to_dev_id, frm_dev_id, ppu_rank, tee_rank, platform
240
+ )
227
241
  obj_ty = TensorType.from_obj(obj)
228
242
  b = simp.runAt(tee_rank, builtin.pack)(obj)
229
243
  ct = simp.runAt(tee_rank, crypto.enc)(b, sess_t)
@@ -245,7 +259,7 @@ def _d2d(to_dev_id: str, obj: MPObject) -> MPObject:
245
259
 
246
260
 
247
261
  def _ensure_tee_session(
248
- frm_dev_id: str, to_dev_id: str, frm_rank: int, tee_rank: int
262
+ frm_dev_id: str, to_dev_id: str, frm_rank: int, tee_rank: int, platform: str
249
263
  ) -> tuple[MPObject, MPObject]:
250
264
  """Ensure a TEE session (sess_p at sender, sess_t at TEE) exists.
251
265
 
@@ -263,11 +277,11 @@ def _ensure_tee_session(
263
277
  # 1) TEE generates (sk, pk) and quote(pk)
264
278
  # KEM suite currently constant; future: read from tee device config (e.g. cluster_spec.devices[dev_id].config)
265
279
  tee_sk, tee_pk = simp.runAt(tee_rank, crypto.kem_keygen)(_TEE_KEM_SUITE)
266
- quote = simp.runAt(tee_rank, tee.quote)(tee_pk)
280
+ quote = simp.runAt(tee_rank, tee.quote_gen)(tee_pk)
267
281
 
268
282
  # 2) Send quote to sender and attest to obtain TEE pk
269
283
  quote_at_sender = mpi.p2p(tee_rank, frm_rank, quote)
270
- tee_pk_at_sender = simp.runAt(frm_rank, tee.attest)(quote_at_sender)
284
+ tee_pk_at_sender = simp.runAt(frm_rank, tee.attest)(quote_at_sender, platform)
271
285
 
272
286
  # 3) Sender generates its ephemeral keypair and sends its pk to TEE
273
287
  v_sk, v_pk = simp.runAt(frm_rank, crypto.kem_keygen)(_TEE_KEM_SUITE)
mplang/kernels/context.py CHANGED
@@ -91,7 +91,7 @@ _DEFAULT_BINDINGS: dict[str, str] = {
91
91
  # generic SQL op; backend-specific kernel id for duckdb
92
92
  "sql.run": "duckdb.run_sql",
93
93
  # tee
94
- # "tee.quote": "mock_tee.quote",
94
+ # "tee.quote_gen": "mock_tee.quote_gen",
95
95
  # "tee.attest": "mock_tee.attest",
96
96
  }
97
97
 
mplang/kernels/mock_tee.py CHANGED
@@ -45,10 +45,10 @@ def _quote_from_pk(pk: np.ndarray) -> NDArray[np.uint8]:
45
45
  return out
46
46
 
47
47
 
48
- @kernel_def("mock_tee.quote")
49
- def _tee_quote(pfunc: PFunction, pk: object) -> NDArray[np.uint8]:
48
+ @kernel_def("mock_tee.quote_gen")
49
+ def _tee_quote_gen(pfunc: PFunction, pk: object) -> NDArray[np.uint8]:
50
50
  warnings.warn(
51
- "Insecure mock TEE kernel 'mock_tee.quote' in use. NOT secure; for local testing only.",
51
+ "Insecure mock TEE kernel 'mock_tee.quote_gen' in use. NOT secure; for local testing only.",
52
52
  stacklevel=3,
53
53
  )
54
54
  pk = np.asarray(pk, dtype=np.uint8)
@@ -64,6 +64,10 @@ def _tee_attest(pfunc: PFunction, quote: object) -> NDArray[np.uint8]:
64
64
  stacklevel=3,
65
65
  )
66
66
  quote = np.asarray(quote, dtype=np.uint8)
67
+ platform = pfunc.attrs.get("platform")
68
+ if platform is None:
69
+ raise ValueError("missing required 'platform' attribute in PFunction")
70
+
67
71
  if quote.size != 33:
68
72
  raise ValueError("mock quote must be 33 bytes (1 header + 32 pk)")
69
73
  return quote[1:33].astype(np.uint8)
mplang/ops/tee.py CHANGED
@@ -14,7 +14,11 @@
14
14
 
15
15
  from __future__ import annotations
16
16
 
17
+ from jax.tree_util import PyTreeDef, tree_flatten
18
+
17
19
  from mplang.core.dtype import UINT8
20
+ from mplang.core.mpobject import MPObject
21
+ from mplang.core.pfunc import PFunction
18
22
  from mplang.core.tensor import TensorType
19
23
  from mplang.ops.base import stateless_mod
20
24
 
@@ -22,21 +26,26 @@ _TEE_MOD = stateless_mod("tee")
22
26
 
23
27
 
24
28
  @_TEE_MOD.simple_op()
25
- def quote(pk: TensorType) -> TensorType:
26
- """TEE quote generation binding the provided ephemeral public key.
27
-
28
- API (mock): quote(pk: u8[32]) -> (quote: u8[33])
29
- The mock encodes a 1-byte header + 32-byte pk.
30
- """
29
+ def quote_gen(pk: TensorType) -> TensorType:
30
+ """TEE quote generation binding the provided ephemeral public key."""
31
31
  _ = pk # Mark as used for the decorator
32
- return TensorType(UINT8, (33,))
33
-
34
-
35
- @_TEE_MOD.simple_op()
36
- def attest(quote: TensorType) -> TensorType:
37
- """TEE quote verification returning the attested TEE public key.
38
-
39
- API (mock): attest(quote: u8[33]) -> tee_pk: u8[32]
40
- """
41
- _ = quote # Mark as used for the decorator
42
- return TensorType(UINT8, (32,))
32
+ return TensorType(UINT8, (-1,))
33
+
34
+
35
+ @_TEE_MOD.op_def()
36
+ def attest(
37
+ quote: MPObject, platform: str
38
+ ) -> tuple[PFunction, list[MPObject], PyTreeDef]:
39
+ """TEE quote verification returning the attested TEE public key."""
40
+
41
+ ins_info = [TensorType.from_obj(quote)]
42
+ outs_info = [TensorType(UINT8, (32,))] # pk is always 32 bytes for x25519
43
+ pfunc = PFunction(
44
+ fn_type="tee.attest",
45
+ ins_info=ins_info,
46
+ outs_info=outs_info,
47
+ platform=platform,
48
+ )
49
+ _, treedef = tree_flatten(outs_info[0])
50
+
51
+ return pfunc, [quote], treedef
mplang_nightly-0.1.dev158.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mplang-nightly
3
- Version: 0.1.dev156
3
+ Version: 0.1.dev158
4
4
  Summary: Multi-Party Programming Language
5
5
  Author-email: SecretFlow Team <secretflow-contact@service.alipay.com>
6
6
  License: Apache License
mplang_nightly-0.1.dev158.dist-info/RECORD CHANGED
@@ -1,6 +1,6 @@
1
1
  mplang/__init__.py,sha256=ofO-F-CNoVIxpMpTJtTJoQtKegJcHwcOJLzoVispiyc,1852
2
2
  mplang/api.py,sha256=ssmv0_CyZPFORhOUJ84Jo6NwRJSK7_Ono3n7ZjEg4sA,3058
3
- mplang/device.py,sha256=RmjnhzHxJkkNmtBKtYMEbpQYBZpuC43qlllkCOp-QD8,12548
3
+ mplang/device.py,sha256=UxfsgaGrJaQyAOzrQBWYArqU6CHIvyONWxgAJHXVQ_0,13041
4
4
  mplang/analysis/__init__.py,sha256=CTHFvRsi-nFngojqjn08UaR3RY9i7CJ7T2UdR95kCrk,1056
5
5
  mplang/analysis/diagram.py,sha256=ffwgD12gL1_KH1uJ_EYkjmIlDrfxYJJkWj-wHl09_Xk,19520
6
6
  mplang/core/__init__.py,sha256=lWxlEKfRwX7FNDzgyKZ1fiDMaCiqkyg0j5mKlZD_v7g,2244
@@ -29,9 +29,9 @@ mplang/core/expr/walk.py,sha256=lXkGJEEuvKGDqQihbxXPxfz2RfR1Q1zYUlt11iooQW0,1188
29
29
  mplang/kernels/__init__.py,sha256=2WE4cmW96Xkzyq2yRRYNww4kZ5o6u6NbPV0BxqZG698,581
30
30
  mplang/kernels/base.py,sha256=-YV4Aj5fs6GT4ehS6Tyi8WQ-amxn5edHTFJRQzyjHXY,3826
31
31
  mplang/kernels/builtin.py,sha256=nSuM79cn7M6M27A6Y8ycilXT_qAlB1ktkwkRX6dv_VQ,7052
32
- mplang/kernels/context.py,sha256=yJjQUyQmBBl6btLb7KlmA-Ejf4-cgK3KmtC-m0sBbb8,13594
32
+ mplang/kernels/context.py,sha256=n-Z7fz7HjHb3UY380iZcasmn2sK-OQUGEIWJk2-fT18,13602
33
33
  mplang/kernels/crypto.py,sha256=s7R0yd4Fk5cI2Qd3LpLc-kmbVuk8fFsbKbfKi43R0aE,3892
34
- mplang/kernels/mock_tee.py,sha256=173QSzPgkrLo0zn0jsx6nNmq1WvfUhIM67FM4Dn30aA,2297
34
+ mplang/kernels/mock_tee.py,sha256=wzXgiCpDWqQFS1HJHa5z_GxRJM7-r5l6T7KBpNcIZlg,2457
35
35
  mplang/kernels/phe.py,sha256=8-_1IFPOaGECGj9mbYja8XoqbMYnYqfpDNVyMJd8J1Y,65247
36
36
  mplang/kernels/spu.py,sha256=LkM8tNzhwTa8lufNgClHfnI4LNu25cdWLQZdJsMDEO8,9301
37
37
  mplang/kernels/sql_duckdb.py,sha256=UN1Ev6-MxF_-65zMExUsLScC9PlmEIEcN8YziIoX_rY,1724
@@ -45,7 +45,7 @@ mplang/ops/jax_cc.py,sha256=42czYg3hNQbI_nUebXnshlU8ULwM-oBDe_TQoApLNVA,7802
45
45
  mplang/ops/phe.py,sha256=SatswExjZWPed8y3qA33BCwIWbvsgHCuCAz_pv2RLLw,6790
46
46
  mplang/ops/spu.py,sha256=UHr5DSoqG08xDYER_11OsMVjGGNXXxsvkFoVvXU8uik,4989
47
47
  mplang/ops/sql.py,sha256=HyY2i5aGC5W7r62JryFSjQCUDXH3kQz82YADwn4z5uc,2015
48
- mplang/ops/tee.py,sha256=gwzP81y2idH-d-Du84H6oNZpLaGD-3fEgm8G1uxWpUA,1388
48
+ mplang/ops/tee.py,sha256=1yoaFFF5NI9gJge1-bhbT2lsphjBXErfDsleYlEMoWs,1664
49
49
  mplang/protos/v1alpha1/mpir_pb2.py,sha256=Bros37t-4LMJbuUYVSM65rImUYTtZDhNTIADGbZCKp0,7522
50
50
  mplang/protos/v1alpha1/mpir_pb2.pyi,sha256=GwXR4wPB_kB_36iYS9x-cGI9KDKFMq89KhdLhW_xmvE,19342
51
51
  mplang/protos/v1alpha1/mpir_pb2_grpc.py,sha256=xYOs94SXiNYAlFodACnsXW5QovLsHY5tCk3p76RH5Zc,158
@@ -70,8 +70,8 @@ mplang/utils/crypto.py,sha256=rvPomBFtznRHc3RPi6Aip9lsU8zW2oxBqGv1K3vn7Rs,1052
70
70
  mplang/utils/func_utils.py,sha256=vCJcZmu0bEbqhOQKdpttV2_MBllIcPSN0b8U4WjNGGo,5164
71
71
  mplang/utils/spu_utils.py,sha256=S3L9RBkBe2AvSuMSQQ12cBY5Y1NPthubvErSX_7nj1A,4158
72
72
  mplang/utils/table_utils.py,sha256=aC-IZOKkSmFkpr3NZchLM0Wt0GOn-rg_xHBHREWBwAU,2202
73
- mplang_nightly-0.1.dev156.dist-info/METADATA,sha256=iaTA5SM0ALRdvM5h40B44CV8WMxsMQXP6DqsPrqeCY4,16547
74
- mplang_nightly-0.1.dev156.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
75
- mplang_nightly-0.1.dev156.dist-info/entry_points.txt,sha256=mG1oJT-GAjQR834a62_QIWb7litzWPPyVnwFqm-rWuY,55
76
- mplang_nightly-0.1.dev156.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
77
- mplang_nightly-0.1.dev156.dist-info/RECORD,,
73
+ mplang_nightly-0.1.dev158.dist-info/METADATA,sha256=LCrLu9AfQJfLUeapTmp0dAVGjYUjl92YfGLJh4WXKXQ,16547
74
+ mplang_nightly-0.1.dev158.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
75
+ mplang_nightly-0.1.dev158.dist-info/entry_points.txt,sha256=mG1oJT-GAjQR834a62_QIWb7litzWPPyVnwFqm-rWuY,55
76
+ mplang_nightly-0.1.dev158.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
77
+ mplang_nightly-0.1.dev158.dist-info/RECORD,,