mplang-nightly 0.1.dev156__py3-none-any.whl → 0.1.dev158__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/device.py +19 -5
- mplang/kernels/context.py +1 -1
- mplang/kernels/mock_tee.py +7 -3
- mplang/ops/tee.py +26 -17
- {mplang_nightly-0.1.dev156.dist-info → mplang_nightly-0.1.dev158.dist-info}/METADATA +1 -1
- {mplang_nightly-0.1.dev156.dist-info → mplang_nightly-0.1.dev158.dist-info}/RECORD +9 -9
- {mplang_nightly-0.1.dev156.dist-info → mplang_nightly-0.1.dev158.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev156.dist-info → mplang_nightly-0.1.dev158.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev156.dist-info → mplang_nightly-0.1.dev158.dist-info}/licenses/LICENSE +0 -0
mplang/device.py
CHANGED
@@ -207,8 +207,15 @@ def _d2d(to_dev_id: str, obj: MPObject) -> MPObject:
|
|
207
207
|
assert len(frm_dev.members) == 1 and len(to_dev.members) == 1
|
208
208
|
frm_rank = frm_dev.members[0].rank
|
209
209
|
tee_rank = to_dev.members[0].rank
|
210
|
+
platform = to_dev.config.get("platform")
|
211
|
+
if not platform:
|
212
|
+
raise ValueError(
|
213
|
+
f"TEE device '{to_dev_id}' is missing 'platform' in its config."
|
214
|
+
)
|
210
215
|
# Ensure sessions (both directions) exist for this PPU<->TEE pair
|
211
|
-
sess_p, sess_t = _ensure_tee_session(
|
216
|
+
sess_p, sess_t = _ensure_tee_session(
|
217
|
+
frm_dev_id, to_dev_id, frm_rank, tee_rank, platform
|
218
|
+
)
|
212
219
|
# Bytes-only path: pack -> enc -> p2p -> dec -> unpack (with static out type)
|
213
220
|
obj_ty = TensorType.from_obj(obj)
|
214
221
|
b = simp.runAt(frm_rank, builtin.pack)(obj)
|
@@ -222,8 +229,15 @@ def _d2d(to_dev_id: str, obj: MPObject) -> MPObject:
|
|
222
229
|
assert len(frm_dev.members) == 1 and len(to_dev.members) == 1
|
223
230
|
tee_rank = frm_dev.members[0].rank
|
224
231
|
ppu_rank = to_dev.members[0].rank
|
232
|
+
platform = frm_dev.config.get("platform")
|
233
|
+
if not platform:
|
234
|
+
raise ValueError(
|
235
|
+
f"TEE device '{frm_dev_id}' is missing 'platform' in its config."
|
236
|
+
)
|
225
237
|
# Ensure bidirectional session established for this pair
|
226
|
-
sess_p, sess_t = _ensure_tee_session(
|
238
|
+
sess_p, sess_t = _ensure_tee_session(
|
239
|
+
to_dev_id, frm_dev_id, ppu_rank, tee_rank, platform
|
240
|
+
)
|
227
241
|
obj_ty = TensorType.from_obj(obj)
|
228
242
|
b = simp.runAt(tee_rank, builtin.pack)(obj)
|
229
243
|
ct = simp.runAt(tee_rank, crypto.enc)(b, sess_t)
|
@@ -245,7 +259,7 @@ def _d2d(to_dev_id: str, obj: MPObject) -> MPObject:
|
|
245
259
|
|
246
260
|
|
247
261
|
def _ensure_tee_session(
|
248
|
-
frm_dev_id: str, to_dev_id: str, frm_rank: int, tee_rank: int
|
262
|
+
frm_dev_id: str, to_dev_id: str, frm_rank: int, tee_rank: int, platform: str
|
249
263
|
) -> tuple[MPObject, MPObject]:
|
250
264
|
"""Ensure a TEE session (sess_p at sender, sess_t at TEE) exists.
|
251
265
|
|
@@ -263,11 +277,11 @@ def _ensure_tee_session(
|
|
263
277
|
# 1) TEE generates (sk, pk) and quote(pk)
|
264
278
|
# KEM suite currently constant; future: read from tee device config (e.g. cluster_spec.devices[dev_id].config)
|
265
279
|
tee_sk, tee_pk = simp.runAt(tee_rank, crypto.kem_keygen)(_TEE_KEM_SUITE)
|
266
|
-
quote = simp.runAt(tee_rank, tee.
|
280
|
+
quote = simp.runAt(tee_rank, tee.quote_gen)(tee_pk)
|
267
281
|
|
268
282
|
# 2) Send quote to sender and attest to obtain TEE pk
|
269
283
|
quote_at_sender = mpi.p2p(tee_rank, frm_rank, quote)
|
270
|
-
tee_pk_at_sender = simp.runAt(frm_rank, tee.attest)(quote_at_sender)
|
284
|
+
tee_pk_at_sender = simp.runAt(frm_rank, tee.attest)(quote_at_sender, platform)
|
271
285
|
|
272
286
|
# 3) Sender generates its ephemeral keypair and sends its pk to TEE
|
273
287
|
v_sk, v_pk = simp.runAt(frm_rank, crypto.kem_keygen)(_TEE_KEM_SUITE)
|
mplang/kernels/context.py
CHANGED
@@ -91,7 +91,7 @@ _DEFAULT_BINDINGS: dict[str, str] = {
|
|
91
91
|
# generic SQL op; backend-specific kernel id for duckdb
|
92
92
|
"sql.run": "duckdb.run_sql",
|
93
93
|
# tee
|
94
|
-
# "tee.
|
94
|
+
# "tee.quote_gen": "mock_tee.quote_gen",
|
95
95
|
# "tee.attest": "mock_tee.attest",
|
96
96
|
}
|
97
97
|
|
mplang/kernels/mock_tee.py
CHANGED
@@ -45,10 +45,10 @@ def _quote_from_pk(pk: np.ndarray) -> NDArray[np.uint8]:
|
|
45
45
|
return out
|
46
46
|
|
47
47
|
|
48
|
-
@kernel_def("mock_tee.
|
49
|
-
def
|
48
|
+
@kernel_def("mock_tee.quote_gen")
|
49
|
+
def _tee_quote_gen(pfunc: PFunction, pk: object) -> NDArray[np.uint8]:
|
50
50
|
warnings.warn(
|
51
|
-
"Insecure mock TEE kernel 'mock_tee.
|
51
|
+
"Insecure mock TEE kernel 'mock_tee.quote_gen' in use. NOT secure; for local testing only.",
|
52
52
|
stacklevel=3,
|
53
53
|
)
|
54
54
|
pk = np.asarray(pk, dtype=np.uint8)
|
@@ -64,6 +64,10 @@ def _tee_attest(pfunc: PFunction, quote: object) -> NDArray[np.uint8]:
|
|
64
64
|
stacklevel=3,
|
65
65
|
)
|
66
66
|
quote = np.asarray(quote, dtype=np.uint8)
|
67
|
+
platform = pfunc.attrs.get("platform")
|
68
|
+
if platform is None:
|
69
|
+
raise ValueError("missing required 'platform' attribute in PFunction")
|
70
|
+
|
67
71
|
if quote.size != 33:
|
68
72
|
raise ValueError("mock quote must be 33 bytes (1 header + 32 pk)")
|
69
73
|
return quote[1:33].astype(np.uint8)
|
mplang/ops/tee.py
CHANGED
@@ -14,7 +14,11 @@
|
|
14
14
|
|
15
15
|
from __future__ import annotations
|
16
16
|
|
17
|
+
from jax.tree_util import PyTreeDef, tree_flatten
|
18
|
+
|
17
19
|
from mplang.core.dtype import UINT8
|
20
|
+
from mplang.core.mpobject import MPObject
|
21
|
+
from mplang.core.pfunc import PFunction
|
18
22
|
from mplang.core.tensor import TensorType
|
19
23
|
from mplang.ops.base import stateless_mod
|
20
24
|
|
@@ -22,21 +26,26 @@ _TEE_MOD = stateless_mod("tee")
|
|
22
26
|
|
23
27
|
|
24
28
|
@_TEE_MOD.simple_op()
|
25
|
-
def
|
26
|
-
"""TEE quote generation binding the provided ephemeral public key.
|
27
|
-
|
28
|
-
API (mock): quote(pk: u8[32]) -> (quote: u8[33])
|
29
|
-
The mock encodes a 1-byte header + 32-byte pk.
|
30
|
-
"""
|
29
|
+
def quote_gen(pk: TensorType) -> TensorType:
|
30
|
+
"""TEE quote generation binding the provided ephemeral public key."""
|
31
31
|
_ = pk # Mark as used for the decorator
|
32
|
-
return TensorType(UINT8, (
|
33
|
-
|
34
|
-
|
35
|
-
@_TEE_MOD.
|
36
|
-
def attest(
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
32
|
+
return TensorType(UINT8, (-1,))
|
33
|
+
|
34
|
+
|
35
|
+
@_TEE_MOD.op_def()
|
36
|
+
def attest(
|
37
|
+
quote: MPObject, platform: str
|
38
|
+
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
|
39
|
+
"""TEE quote verification returning the attested TEE public key."""
|
40
|
+
|
41
|
+
ins_info = [TensorType.from_obj(quote)]
|
42
|
+
outs_info = [TensorType(UINT8, (32,))] # pk is always 32 bytes for x25519
|
43
|
+
pfunc = PFunction(
|
44
|
+
fn_type="tee.attest",
|
45
|
+
ins_info=ins_info,
|
46
|
+
outs_info=outs_info,
|
47
|
+
platform=platform,
|
48
|
+
)
|
49
|
+
_, treedef = tree_flatten(outs_info[0])
|
50
|
+
|
51
|
+
return pfunc, [quote], treedef
|
@@ -1,6 +1,6 @@
|
|
1
1
|
mplang/__init__.py,sha256=ofO-F-CNoVIxpMpTJtTJoQtKegJcHwcOJLzoVispiyc,1852
|
2
2
|
mplang/api.py,sha256=ssmv0_CyZPFORhOUJ84Jo6NwRJSK7_Ono3n7ZjEg4sA,3058
|
3
|
-
mplang/device.py,sha256=
|
3
|
+
mplang/device.py,sha256=UxfsgaGrJaQyAOzrQBWYArqU6CHIvyONWxgAJHXVQ_0,13041
|
4
4
|
mplang/analysis/__init__.py,sha256=CTHFvRsi-nFngojqjn08UaR3RY9i7CJ7T2UdR95kCrk,1056
|
5
5
|
mplang/analysis/diagram.py,sha256=ffwgD12gL1_KH1uJ_EYkjmIlDrfxYJJkWj-wHl09_Xk,19520
|
6
6
|
mplang/core/__init__.py,sha256=lWxlEKfRwX7FNDzgyKZ1fiDMaCiqkyg0j5mKlZD_v7g,2244
|
@@ -29,9 +29,9 @@ mplang/core/expr/walk.py,sha256=lXkGJEEuvKGDqQihbxXPxfz2RfR1Q1zYUlt11iooQW0,1188
|
|
29
29
|
mplang/kernels/__init__.py,sha256=2WE4cmW96Xkzyq2yRRYNww4kZ5o6u6NbPV0BxqZG698,581
|
30
30
|
mplang/kernels/base.py,sha256=-YV4Aj5fs6GT4ehS6Tyi8WQ-amxn5edHTFJRQzyjHXY,3826
|
31
31
|
mplang/kernels/builtin.py,sha256=nSuM79cn7M6M27A6Y8ycilXT_qAlB1ktkwkRX6dv_VQ,7052
|
32
|
-
mplang/kernels/context.py,sha256=
|
32
|
+
mplang/kernels/context.py,sha256=n-Z7fz7HjHb3UY380iZcasmn2sK-OQUGEIWJk2-fT18,13602
|
33
33
|
mplang/kernels/crypto.py,sha256=s7R0yd4Fk5cI2Qd3LpLc-kmbVuk8fFsbKbfKi43R0aE,3892
|
34
|
-
mplang/kernels/mock_tee.py,sha256=
|
34
|
+
mplang/kernels/mock_tee.py,sha256=wzXgiCpDWqQFS1HJHa5z_GxRJM7-r5l6T7KBpNcIZlg,2457
|
35
35
|
mplang/kernels/phe.py,sha256=8-_1IFPOaGECGj9mbYja8XoqbMYnYqfpDNVyMJd8J1Y,65247
|
36
36
|
mplang/kernels/spu.py,sha256=LkM8tNzhwTa8lufNgClHfnI4LNu25cdWLQZdJsMDEO8,9301
|
37
37
|
mplang/kernels/sql_duckdb.py,sha256=UN1Ev6-MxF_-65zMExUsLScC9PlmEIEcN8YziIoX_rY,1724
|
@@ -45,7 +45,7 @@ mplang/ops/jax_cc.py,sha256=42czYg3hNQbI_nUebXnshlU8ULwM-oBDe_TQoApLNVA,7802
|
|
45
45
|
mplang/ops/phe.py,sha256=SatswExjZWPed8y3qA33BCwIWbvsgHCuCAz_pv2RLLw,6790
|
46
46
|
mplang/ops/spu.py,sha256=UHr5DSoqG08xDYER_11OsMVjGGNXXxsvkFoVvXU8uik,4989
|
47
47
|
mplang/ops/sql.py,sha256=HyY2i5aGC5W7r62JryFSjQCUDXH3kQz82YADwn4z5uc,2015
|
48
|
-
mplang/ops/tee.py,sha256=
|
48
|
+
mplang/ops/tee.py,sha256=1yoaFFF5NI9gJge1-bhbT2lsphjBXErfDsleYlEMoWs,1664
|
49
49
|
mplang/protos/v1alpha1/mpir_pb2.py,sha256=Bros37t-4LMJbuUYVSM65rImUYTtZDhNTIADGbZCKp0,7522
|
50
50
|
mplang/protos/v1alpha1/mpir_pb2.pyi,sha256=GwXR4wPB_kB_36iYS9x-cGI9KDKFMq89KhdLhW_xmvE,19342
|
51
51
|
mplang/protos/v1alpha1/mpir_pb2_grpc.py,sha256=xYOs94SXiNYAlFodACnsXW5QovLsHY5tCk3p76RH5Zc,158
|
@@ -70,8 +70,8 @@ mplang/utils/crypto.py,sha256=rvPomBFtznRHc3RPi6Aip9lsU8zW2oxBqGv1K3vn7Rs,1052
|
|
70
70
|
mplang/utils/func_utils.py,sha256=vCJcZmu0bEbqhOQKdpttV2_MBllIcPSN0b8U4WjNGGo,5164
|
71
71
|
mplang/utils/spu_utils.py,sha256=S3L9RBkBe2AvSuMSQQ12cBY5Y1NPthubvErSX_7nj1A,4158
|
72
72
|
mplang/utils/table_utils.py,sha256=aC-IZOKkSmFkpr3NZchLM0Wt0GOn-rg_xHBHREWBwAU,2202
|
73
|
-
mplang_nightly-0.1.
|
74
|
-
mplang_nightly-0.1.
|
75
|
-
mplang_nightly-0.1.
|
76
|
-
mplang_nightly-0.1.
|
77
|
-
mplang_nightly-0.1.
|
73
|
+
mplang_nightly-0.1.dev158.dist-info/METADATA,sha256=LCrLu9AfQJfLUeapTmp0dAVGjYUjl92YfGLJh4WXKXQ,16547
|
74
|
+
mplang_nightly-0.1.dev158.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
75
|
+
mplang_nightly-0.1.dev158.dist-info/entry_points.txt,sha256=mG1oJT-GAjQR834a62_QIWb7litzWPPyVnwFqm-rWuY,55
|
76
|
+
mplang_nightly-0.1.dev158.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
77
|
+
mplang_nightly-0.1.dev158.dist-info/RECORD,,
|
File without changes
|
{mplang_nightly-0.1.dev156.dist-info → mplang_nightly-0.1.dev158.dist-info}/entry_points.txt
RENAMED
File without changes
|
{mplang_nightly-0.1.dev156.dist-info → mplang_nightly-0.1.dev158.dist-info}/licenses/LICENSE
RENAMED
File without changes
|