mplang-nightly 0.1.dev155__tar.gz → 0.1.dev157__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/PKG-INFO +1 -1
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/device.py +19 -5
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/kernels/base.py +11 -35
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/kernels/context.py +71 -18
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/kernels/crypto.py +11 -7
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/kernels/mock_tee.py +11 -6
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/kernels/spu.py +14 -18
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/kernels/stablehlo.py +8 -5
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/ops/tee.py +26 -17
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/runtime/data_providers.py +13 -19
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/ops/test_crypto_tee.py +8 -5
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/simp/test_sugar.py +1 -1
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tutorials/3_device.py +4 -1
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tutorials/9_tee.py +13 -6
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/.gitignore +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/LICENSE +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/README.md +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/examples/conf/3pc.yaml +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/examples/stax_nn/README.md +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/examples/stax_nn/models.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/examples/stax_nn/stax_nn.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/examples/xgboost/hist_jax.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/examples/xgboost/hist_jax_test.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/examples/xgboost/naive_np.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/examples/xgboost/readme.md +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/examples/xgboost/sgb.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/examples/xgboost/sgb_test.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/hatch_build.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/__init__.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/analysis/__init__.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/analysis/diagram.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/api.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/core/__init__.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/core/cluster.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/core/comm.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/core/context_mgr.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/core/dtype.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/core/expr/__init__.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/core/expr/ast.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/core/expr/evaluator.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/core/expr/printer.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/core/expr/transformer.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/core/expr/utils.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/core/expr/visitor.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/core/expr/walk.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/core/interp.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/core/mask.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/core/mpir.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/core/mpobject.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/core/mptype.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/core/pfunc.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/core/primitive.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/core/table.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/core/tensor.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/core/tracer.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/kernels/__init__.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/kernels/builtin.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/kernels/phe.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/kernels/sql_duckdb.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/ops/__init__.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/ops/base.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/ops/builtin.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/ops/crypto.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/ops/ibis_cc.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/ops/jax_cc.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/ops/phe.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/ops/spu.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/ops/sql.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/protos/v1alpha1/mpir_pb2.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/protos/v1alpha1/mpir_pb2.pyi +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/runtime/__init__.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/runtime/cli.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/runtime/client.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/runtime/communicator.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/runtime/driver.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/runtime/exceptions.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/runtime/http_api.md +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/runtime/link_comm.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/runtime/server.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/runtime/session.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/runtime/simulation.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/simp/__init__.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/simp/mpi.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/simp/random.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/simp/smpc.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/utils/__init__.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/utils/crypto.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/utils/func_utils.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/utils/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/mplang/utils/table_utils.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/pyproject.toml +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/__init__.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/analysis/test_diagram.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/conftest.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/core/__init__.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/core/expr/__init__.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/core/expr/conftest.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/core/expr/test_ast.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/core/expr/test_printer.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/core/expr/test_utils.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/core/expr/test_walk.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/core/test_cluster.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/core/test_dtype.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/core/test_mask.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/core/test_mpir.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/core/test_mptype.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/core/test_primitive.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/core/test_table.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/core/test_tensor.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/core/test_tracer.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/device/__init__.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/device/test_device_basic.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/integration/README.md +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/integration/test_crypto_roundtrip.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/integration/test_http_e2e.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/integration/test_symbols_roundtrip.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/integration/test_tutorials.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/integration/test_unused_param_integration.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/kernels/test_builtin.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/kernels/test_debug_print.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/kernels/test_kernel_binding.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/kernels/test_phe.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/kernels/test_spu.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/kernels/test_sql_duckdb.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/kernels/test_stablehlo.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/ops/__init__.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/ops/dummy.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/ops/test_builtin_pack.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/ops/test_feop_base.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/ops/test_ibis.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/ops/test_ibis_cc.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/ops/test_jax_cc.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/ops/test_phe.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/ops/test_spu.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/ops/test_spu_defensive.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/ops/test_sql.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/ops/test_table_tensor_conversion.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/runtime/__init__.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/runtime/test_cli.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/runtime/test_communicator.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/runtime/test_driver.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/runtime/test_server.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/runtime/test_simulation.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/simp/test_mpi.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/simp/test_random.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/simp/test_simp.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/simp/test_smpc.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/utils/__init__.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/utils/server_fixtures.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/utils/test_func_utils.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/utils/test_spu_utils.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tests/utils/test_table_utils.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tutorials/0_basic.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tutorials/10_analysis.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tutorials/1_condition.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tutorials/2_whileloop.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tutorials/4_simulation.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tutorials/5_ir_dump.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tutorials/6_advanced.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tutorials/7_stdio.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tutorials/8_phe.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tutorials/__init__.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tutorials/pitfalls/late_binding.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tutorials/pitfalls/rand.py +0 -0
- {mplang_nightly-0.1.dev155 → mplang_nightly-0.1.dev157}/tutorials/run.sh +0 -0
@@ -207,8 +207,15 @@ def _d2d(to_dev_id: str, obj: MPObject) -> MPObject:
|
|
207
207
|
assert len(frm_dev.members) == 1 and len(to_dev.members) == 1
|
208
208
|
frm_rank = frm_dev.members[0].rank
|
209
209
|
tee_rank = to_dev.members[0].rank
|
210
|
+
platform = to_dev.config.get("platform")
|
211
|
+
if not platform:
|
212
|
+
raise ValueError(
|
213
|
+
f"TEE device '{to_dev_id}' is missing 'platform' in its config."
|
214
|
+
)
|
210
215
|
# Ensure sessions (both directions) exist for this PPU<->TEE pair
|
211
|
-
sess_p, sess_t = _ensure_tee_session(
|
216
|
+
sess_p, sess_t = _ensure_tee_session(
|
217
|
+
frm_dev_id, to_dev_id, frm_rank, tee_rank, platform
|
218
|
+
)
|
212
219
|
# Bytes-only path: pack -> enc -> p2p -> dec -> unpack (with static out type)
|
213
220
|
obj_ty = TensorType.from_obj(obj)
|
214
221
|
b = simp.runAt(frm_rank, builtin.pack)(obj)
|
@@ -222,8 +229,15 @@ def _d2d(to_dev_id: str, obj: MPObject) -> MPObject:
|
|
222
229
|
assert len(frm_dev.members) == 1 and len(to_dev.members) == 1
|
223
230
|
tee_rank = frm_dev.members[0].rank
|
224
231
|
ppu_rank = to_dev.members[0].rank
|
232
|
+
platform = to_dev.config.get("platform")
|
233
|
+
if not platform:
|
234
|
+
raise ValueError(
|
235
|
+
f"TEE device '{to_dev_id}' is missing 'platform' in its config."
|
236
|
+
)
|
225
237
|
# Ensure bidirectional session established for this pair
|
226
|
-
sess_p, sess_t = _ensure_tee_session(
|
238
|
+
sess_p, sess_t = _ensure_tee_session(
|
239
|
+
to_dev_id, frm_dev_id, ppu_rank, tee_rank, platform
|
240
|
+
)
|
227
241
|
obj_ty = TensorType.from_obj(obj)
|
228
242
|
b = simp.runAt(tee_rank, builtin.pack)(obj)
|
229
243
|
ct = simp.runAt(tee_rank, crypto.enc)(b, sess_t)
|
@@ -245,7 +259,7 @@ def _d2d(to_dev_id: str, obj: MPObject) -> MPObject:
|
|
245
259
|
|
246
260
|
|
247
261
|
def _ensure_tee_session(
|
248
|
-
frm_dev_id: str, to_dev_id: str, frm_rank: int, tee_rank: int
|
262
|
+
frm_dev_id: str, to_dev_id: str, frm_rank: int, tee_rank: int, platform: str
|
249
263
|
) -> tuple[MPObject, MPObject]:
|
250
264
|
"""Ensure a TEE session (sess_p at sender, sess_t at TEE) exists.
|
251
265
|
|
@@ -263,11 +277,11 @@ def _ensure_tee_session(
|
|
263
277
|
# 1) TEE generates (sk, pk) and quote(pk)
|
264
278
|
# KEM suite currently constant; future: read from tee device config (e.g. cluster_spec.devices[dev_id].config)
|
265
279
|
tee_sk, tee_pk = simp.runAt(tee_rank, crypto.kem_keygen)(_TEE_KEM_SUITE)
|
266
|
-
quote = simp.runAt(tee_rank, tee.
|
280
|
+
quote = simp.runAt(tee_rank, tee.quote_gen)(tee_pk)
|
267
281
|
|
268
282
|
# 2) Send quote to sender and attest to obtain TEE pk
|
269
283
|
quote_at_sender = mpi.p2p(tee_rank, frm_rank, quote)
|
270
|
-
tee_pk_at_sender = simp.runAt(frm_rank, tee.attest)(quote_at_sender)
|
284
|
+
tee_pk_at_sender = simp.runAt(frm_rank, tee.attest)(quote_at_sender, platform)
|
271
285
|
|
272
286
|
# 3) Sender generates its ephemeral keypair and sends its pk to TEE
|
273
287
|
v_sk, v_pk = simp.runAt(frm_rank, crypto.kem_keygen)(_TEE_KEM_SUITE)
|
@@ -34,7 +34,10 @@ from __future__ import annotations
|
|
34
34
|
import contextvars
|
35
35
|
from collections.abc import Callable
|
36
36
|
from dataclasses import dataclass
|
37
|
-
from typing import Any
|
37
|
+
from typing import TYPE_CHECKING, Any
|
38
|
+
|
39
|
+
if TYPE_CHECKING:
|
40
|
+
from mplang.kernels.context import RuntimeContext
|
38
41
|
|
39
42
|
__all__ = [
|
40
43
|
"KernelContext",
|
@@ -48,12 +51,15 @@ __all__ = [
|
|
48
51
|
|
49
52
|
@dataclass
|
50
53
|
class KernelContext:
|
51
|
-
"""Ephemeral
|
54
|
+
"""Ephemeral per-kernel invocation context.
|
55
|
+
|
56
|
+
Cross-kernel persistent state (RNGs, compiled artifacts, environment handles)
|
57
|
+
should be stored in RuntimeContext.
|
58
|
+
"""
|
52
59
|
|
53
60
|
rank: int
|
54
61
|
world_size: int
|
55
|
-
|
56
|
-
cache: dict[str, Any] # runtime-level shared cache (per BackendRuntime)
|
62
|
+
runtime: RuntimeContext
|
57
63
|
|
58
64
|
|
59
65
|
_CTX_VAR: contextvars.ContextVar[KernelContext | None] = contextvars.ContextVar(
|
@@ -62,37 +68,7 @@ _CTX_VAR: contextvars.ContextVar[KernelContext | None] = contextvars.ContextVar(
|
|
62
68
|
|
63
69
|
|
64
70
|
def cur_kctx() -> KernelContext:
|
65
|
-
"""Return
|
66
|
-
|
67
|
-
Two storages:
|
68
|
-
- state: namespaced pockets (dict[str, dict]) for backend-local mutable helpers
|
69
|
-
- cache: global (per runtime) shared dict; prefer state unless truly cross-backend
|
70
|
-
|
71
|
-
Examples:
|
72
|
-
1) Compile cache::
|
73
|
-
@kernel_def("mlir.stablehlo")
|
74
|
-
def _exec(pfunc, args):
|
75
|
-
ctx = cur_kctx()
|
76
|
-
pocket = ctx.state.setdefault("stablehlo", {})
|
77
|
-
cache = pocket.setdefault("compile_cache", {})
|
78
|
-
text = pfunc.fn_text
|
79
|
-
mod = cache.get(text)
|
80
|
-
if mod is None:
|
81
|
-
mod = compile_mlir(text)
|
82
|
-
cache[text] = mod
|
83
|
-
return run(mod, args)
|
84
|
-
|
85
|
-
2) Deterministic RNG::
|
86
|
-
@kernel_def("crypto.keygen")
|
87
|
-
def _keygen(pfunc, args):
|
88
|
-
ctx = cur_kctx()
|
89
|
-
pocket = ctx.state.setdefault("crypto", {})
|
90
|
-
rng = pocket.get("rng")
|
91
|
-
if rng is None:
|
92
|
-
rng = np.random.default_rng(1234 + ctx.rank * 7919)
|
93
|
-
pocket["rng"] = rng
|
94
|
-
return (rng.integers(0, 256, size=(32,), dtype=np.uint8),)
|
95
|
-
"""
|
71
|
+
"""Return current kernel execution context (only valid inside kernel)."""
|
96
72
|
ctx = _CTX_VAR.get()
|
97
73
|
if ctx is None:
|
98
74
|
raise RuntimeError("cur_kctx() called outside backend kernel execution")
|
@@ -91,7 +91,7 @@ _DEFAULT_BINDINGS: dict[str, str] = {
|
|
91
91
|
# generic SQL op; backend-specific kernel id for duckdb
|
92
92
|
"sql.run": "duckdb.run_sql",
|
93
93
|
# tee
|
94
|
-
# "tee.
|
94
|
+
# "tee.quote_gen": "mock_tee.quote_gen",
|
95
95
|
# "tee.attest": "mock_tee.attest",
|
96
96
|
}
|
97
97
|
|
@@ -118,12 +118,21 @@ class RuntimeContext:
|
|
118
118
|
op_type -> kernel_id and form a *template* for dispatch. After
|
119
119
|
initialization, all (re)binding must go through ``bind_op`` /
|
120
120
|
``rebind_op`` on this context (scoped to THIS runtime only).
|
121
|
-
state
|
122
|
-
Mutable
|
123
|
-
|
121
|
+
state : dict, optional
|
122
|
+
Mutable per-runtime key/value storage for kernels. Flat key space;
|
123
|
+
callers SHOULD use dotted prefixes (e.g. "stablehlo.compile_cache").
|
124
|
+
Kernels own their *state* (functional correctness data, caches,
|
125
|
+
handles, compiled objects, RNGs, etc.). Runtime does not interpret
|
126
|
+
structure—values may themselves be dicts if a kernel wants its own
|
127
|
+
pocket. Created empty when omitted.
|
128
|
+
stats : dict, optional
|
129
|
+
Mutable statistics/telemetry owned by the runtime (usage counters,
|
130
|
+
timings, profiling aids). Kernels may increment counters but should
|
131
|
+
avoid storing functional state here. A default "op_calls" mapping is
|
132
|
+
ensured. Created empty when omitted.
|
124
133
|
"""
|
125
134
|
|
126
|
-
__slots__ = ("_ibindings", "
|
135
|
+
__slots__ = ("_ibindings", "rank", "state", "stats", "world_size")
|
127
136
|
|
128
137
|
def __init__(
|
129
138
|
self,
|
@@ -131,8 +140,7 @@ class RuntimeContext:
|
|
131
140
|
world_size: int,
|
132
141
|
initial_bindings: Mapping[str, str] | None = None,
|
133
142
|
*,
|
134
|
-
state: dict[str,
|
135
|
-
cache: dict[str, Any] | None = None,
|
143
|
+
state: dict[str, Any] | None = None,
|
136
144
|
stats: dict[str, Any] | None = None,
|
137
145
|
) -> None:
|
138
146
|
_ensure_impl_imported()
|
@@ -144,7 +152,6 @@ class RuntimeContext:
|
|
144
152
|
**(initial_bindings or {}),
|
145
153
|
}
|
146
154
|
self.state = state if state is not None else {}
|
147
|
-
self.cache = cache if cache is not None else {}
|
148
155
|
self.stats = stats if stats is not None else {}
|
149
156
|
self.stats.setdefault("op_calls", {})
|
150
157
|
|
@@ -168,19 +175,15 @@ class RuntimeContext:
|
|
168
175
|
if isinstance(ins_spec, TensorType):
|
169
176
|
_validate_tensor_arg(fn_type, idx, ins_spec, val)
|
170
177
|
continue
|
178
|
+
|
171
179
|
# install kernel context
|
172
|
-
kctx = KernelContext(
|
173
|
-
|
174
|
-
world_size=self.world_size,
|
175
|
-
state=self.state,
|
176
|
-
cache=self.cache,
|
177
|
-
)
|
178
|
-
token = base._CTX_VAR.set(kctx) # type: ignore[attr-defined]
|
180
|
+
kctx = KernelContext(rank=self.rank, world_size=self.world_size, runtime=self)
|
181
|
+
token = base._CTX_VAR.set(kctx)
|
179
182
|
try:
|
180
183
|
raw = fn(pfunc, *arg_list)
|
181
184
|
finally:
|
182
|
-
base._CTX_VAR.reset(token)
|
183
|
-
|
185
|
+
base._CTX_VAR.reset(token)
|
186
|
+
|
184
187
|
try:
|
185
188
|
op_calls = self.stats.setdefault("op_calls", {})
|
186
189
|
op_calls[fn_type] = op_calls.get(fn_type, 0) + 1
|
@@ -213,7 +216,57 @@ class RuntimeContext:
|
|
213
216
|
|
214
217
|
def reset(self) -> None:
|
215
218
|
self.state.clear()
|
216
|
-
|
219
|
+
|
220
|
+
# ---- runtime state API (flat key space) ----
|
221
|
+
# Keys are treated atomically; convention encourages dotted prefixes
|
222
|
+
# (e.g. 'stablehlo.compile_cache.hash', 'crypto.rng'). Implementation
|
223
|
+
# does NOT parse or create hierarchical dicts—any grouping is purely
|
224
|
+
# by string prefix. Values themselves MAY be dicts if callers want a
|
225
|
+
# manual pocket. This keeps semantics simple and predictable.
|
226
|
+
|
227
|
+
def ensure_state(self, key: str, factory: type | Any = dict) -> Any:
|
228
|
+
"""Return value for key; if absent create via factory and store.
|
229
|
+
|
230
|
+
Key is not parsed; dotted forms are allowed but treated as a single
|
231
|
+
map key. Use consistent prefixes for grouping (e.g. 'spu.config').
|
232
|
+
"""
|
233
|
+
if not key:
|
234
|
+
raise ValueError("empty state key")
|
235
|
+
val = self.state.get(key)
|
236
|
+
if val is None:
|
237
|
+
val = factory()
|
238
|
+
self.state[key] = val
|
239
|
+
return val
|
240
|
+
|
241
|
+
def get_state(self, key: str, default: Any | None = None) -> Any:
|
242
|
+
if not key:
|
243
|
+
raise ValueError("empty state key")
|
244
|
+
return self.state.get(key, default)
|
245
|
+
|
246
|
+
def set_state(self, key: str, value: Any) -> None:
|
247
|
+
if not key:
|
248
|
+
raise ValueError("empty state key")
|
249
|
+
self.state[key] = value
|
250
|
+
|
251
|
+
def del_state(self, key: str) -> None:
|
252
|
+
if not key:
|
253
|
+
raise ValueError("empty state key")
|
254
|
+
self.state.pop(key, None)
|
255
|
+
|
256
|
+
def list_state(self, prefix: str = "") -> dict[str, Any]:
|
257
|
+
"""Return mapping of key -> value; optional prefix filter.
|
258
|
+
|
259
|
+
Prefix match is string-based; if prefix is non-empty include keys
|
260
|
+
where key == prefix or key starts with prefix + '.'.
|
261
|
+
"""
|
262
|
+
if not prefix:
|
263
|
+
return dict(self.state)
|
264
|
+
pref = prefix if prefix.endswith(".") else prefix + "."
|
265
|
+
out: dict[str, Any] = {}
|
266
|
+
for k, v in self.state.items():
|
267
|
+
if k == prefix or k.startswith(pref):
|
268
|
+
out[k] = v
|
269
|
+
return out
|
217
270
|
|
218
271
|
# ---- explicit (re)binding API ----
|
219
272
|
def bind_op(self, op_type: str, kernel_id: str, *, force: bool = False) -> None:
|
@@ -27,15 +27,19 @@ __all__: list[str] = [] # flat kernels only
|
|
27
27
|
|
28
28
|
|
29
29
|
def _get_rng() -> np.random.Generator:
|
30
|
-
"""Get (and lazily create) per-rank RNG for crypto kernels.
|
30
|
+
"""Get (and lazily create) per-rank RNG for crypto kernels.
|
31
|
+
|
32
|
+
Runtime state is untyped, so we narrow the type explicitly for mypy.
|
33
|
+
"""
|
31
34
|
kctx = cur_kctx()
|
32
|
-
|
33
|
-
|
34
|
-
if
|
35
|
+
rt = kctx.runtime
|
36
|
+
rng_obj = rt.get_state("crypto.rng")
|
37
|
+
if rng_obj is None:
|
35
38
|
seed = int(os.environ.get("MPLANG_CRYPTO_SEED", "0")) + kctx.rank * 7919
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
+
rng_obj = np.random.default_rng(seed)
|
40
|
+
rt.set_state("crypto.rng", rng_obj)
|
41
|
+
assert isinstance(rng_obj, np.random.Generator) # narrow
|
42
|
+
return rng_obj
|
39
43
|
|
40
44
|
|
41
45
|
def _keystream(key: bytes, nonce: bytes, length: int) -> bytes:
|
@@ -28,12 +28,13 @@ __all__: list[str] = []
|
|
28
28
|
|
29
29
|
def _rng() -> np.random.Generator:
|
30
30
|
kctx = cur_kctx()
|
31
|
-
|
32
|
-
r =
|
31
|
+
rt = kctx.runtime
|
32
|
+
r = rt.get_state("tee.rng")
|
33
33
|
if r is None:
|
34
34
|
seed = int(os.environ.get("MPLANG_TEE_SEED", "0")) + kctx.rank * 10007
|
35
35
|
r = np.random.default_rng(seed)
|
36
|
-
|
36
|
+
rt.set_state("tee.rng", r)
|
37
|
+
assert isinstance(r, np.random.Generator) # type narrowing for mypy
|
37
38
|
return r
|
38
39
|
|
39
40
|
|
@@ -44,10 +45,10 @@ def _quote_from_pk(pk: np.ndarray) -> NDArray[np.uint8]:
|
|
44
45
|
return out
|
45
46
|
|
46
47
|
|
47
|
-
@kernel_def("mock_tee.
|
48
|
-
def
|
48
|
+
@kernel_def("mock_tee.quote_gen")
|
49
|
+
def _tee_quote_gen(pfunc: PFunction, pk: object) -> NDArray[np.uint8]:
|
49
50
|
warnings.warn(
|
50
|
-
"Insecure mock TEE kernel 'mock_tee.
|
51
|
+
"Insecure mock TEE kernel 'mock_tee.quote_gen' in use. NOT secure; for local testing only.",
|
51
52
|
stacklevel=3,
|
52
53
|
)
|
53
54
|
pk = np.asarray(pk, dtype=np.uint8)
|
@@ -63,6 +64,10 @@ def _tee_attest(pfunc: PFunction, quote: object) -> NDArray[np.uint8]:
|
|
63
64
|
stacklevel=3,
|
64
65
|
)
|
65
66
|
quote = np.asarray(quote, dtype=np.uint8)
|
67
|
+
platform = pfunc.attrs.get("platform")
|
68
|
+
if platform is None:
|
69
|
+
raise ValueError("missing required 'platform' attribute in PFunction")
|
70
|
+
|
66
71
|
if quote.size != 33:
|
67
72
|
raise ValueError("mock quote must be 33 bytes (1 header + 32 pk)")
|
68
73
|
return quote[1:33].astype(np.uint8)
|
@@ -63,14 +63,10 @@ class SpuValue:
|
|
63
63
|
return f"SpuValue({self.shape},{self.dtype},{self.vtype})"
|
64
64
|
|
65
65
|
|
66
|
-
def _get_spu_pocket() -> dict[str, Any]:
|
67
|
-
return cur_kctx().state.setdefault("spu", {})
|
68
|
-
|
69
|
-
|
70
66
|
def _get_spu_config_and_world() -> tuple[libspu.RuntimeConfig, int]:
|
71
|
-
|
72
|
-
cfg =
|
73
|
-
world =
|
67
|
+
kctx = cur_kctx()
|
68
|
+
cfg = kctx.runtime.get_state("spu.config")
|
69
|
+
world = kctx.runtime.get_state("spu.world")
|
74
70
|
if cfg is None or world is None:
|
75
71
|
raise RuntimeError("SPU kernel state not initialized (config/world)")
|
76
72
|
return cfg, int(world)
|
@@ -84,12 +80,12 @@ def _register_spu_env(
|
|
84
80
|
Idempotent: if config/world already set, they must match; link is recorded per rank.
|
85
81
|
This replaces previous global fallback seeding logic.
|
86
82
|
"""
|
87
|
-
|
88
|
-
prev_cfg =
|
89
|
-
prev_world =
|
83
|
+
kctx = cur_kctx()
|
84
|
+
prev_cfg = kctx.runtime.get_state("spu.config")
|
85
|
+
prev_world = kctx.runtime.get_state("spu.world")
|
90
86
|
if prev_cfg is None:
|
91
|
-
|
92
|
-
|
87
|
+
kctx.runtime.set_state("spu.config", config)
|
88
|
+
kctx.runtime.set_state("spu.world", world_size)
|
93
89
|
else:
|
94
90
|
# libspu RuntimeConfig may not implement __eq__; compare serialized repr
|
95
91
|
same_cfg = (
|
@@ -102,7 +98,7 @@ def _register_spu_env(
|
|
102
98
|
raise RuntimeError("Conflicting SPU env registration")
|
103
99
|
# Store single link per runtime (one runtime per rank)
|
104
100
|
if link_ctx is not None:
|
105
|
-
|
101
|
+
kctx.runtime.set_state("spu.link", link_ctx)
|
106
102
|
|
107
103
|
|
108
104
|
@kernel_def("spu.seed_env")
|
@@ -197,16 +193,16 @@ def _spu_run_mlir(pfunc: PFunction, *args: Any) -> Any:
|
|
197
193
|
)
|
198
194
|
|
199
195
|
cfg, _ = _get_spu_config_and_world()
|
200
|
-
|
201
|
-
link_ctx
|
196
|
+
kctx = cur_kctx()
|
197
|
+
link_ctx = kctx.runtime.get_state("spu.link")
|
202
198
|
if link_ctx is None:
|
203
199
|
raise RuntimeError("Rank not participating in SPU; no link set via seed_env")
|
204
200
|
|
205
|
-
# Lazy runtime cache
|
206
|
-
spu_rt =
|
201
|
+
# Lazy runtime cache under key spu.runtime
|
202
|
+
spu_rt = kctx.runtime.get_state("spu.runtime")
|
207
203
|
if spu_rt is None:
|
208
204
|
spu_rt = spu_api.Runtime(link_ctx.get_lctx(), cfg)
|
209
|
-
|
205
|
+
kctx.runtime.set_state("spu.runtime", spu_rt)
|
210
206
|
|
211
207
|
# Validate that all inputs are SpuValue objects
|
212
208
|
for i, arg in enumerate(args):
|
@@ -36,11 +36,14 @@ def _stablehlo_exec(pfunc: PFunction, *args: Any) -> Any:
|
|
36
36
|
if isinstance(mlir_text, bytes):
|
37
37
|
mlir_text = mlir_text.decode("utf-8")
|
38
38
|
|
39
|
-
#
|
39
|
+
# Flat-key compile cache: stablehlo.compile_cache.<hash>
|
40
40
|
ctx = cur_kctx()
|
41
|
-
|
42
|
-
|
43
|
-
|
41
|
+
rt = ctx.runtime
|
42
|
+
import hashlib
|
43
|
+
|
44
|
+
h = hashlib.sha256(mlir_text.encode("utf-8")).hexdigest()[:16]
|
45
|
+
key = f"stablehlo.compile_cache.{h}"
|
46
|
+
compiled = rt.get_state(key)
|
44
47
|
if compiled is None:
|
45
48
|
backend = jax.default_backend()
|
46
49
|
client = xla_bridge.get_backend(backend)
|
@@ -49,7 +52,7 @@ def _stablehlo_exec(pfunc: PFunction, *args: Any) -> Any:
|
|
49
52
|
compiled = client.compile(mlir_text, compile_options)
|
50
53
|
except Exception as e: # pragma: no cover
|
51
54
|
raise RuntimeError(f"StableHLO compile failed: {e}") from e
|
52
|
-
|
55
|
+
rt.set_state(key, compiled)
|
53
56
|
|
54
57
|
# Handle JAX's unused parameter elimination via arg_keep_map
|
55
58
|
runtime_args = args
|
@@ -14,7 +14,11 @@
|
|
14
14
|
|
15
15
|
from __future__ import annotations
|
16
16
|
|
17
|
+
from jax.tree_util import PyTreeDef, tree_flatten
|
18
|
+
|
17
19
|
from mplang.core.dtype import UINT8
|
20
|
+
from mplang.core.mpobject import MPObject
|
21
|
+
from mplang.core.pfunc import PFunction
|
18
22
|
from mplang.core.tensor import TensorType
|
19
23
|
from mplang.ops.base import stateless_mod
|
20
24
|
|
@@ -22,21 +26,26 @@ _TEE_MOD = stateless_mod("tee")
|
|
22
26
|
|
23
27
|
|
24
28
|
@_TEE_MOD.simple_op()
|
25
|
-
def
|
26
|
-
"""TEE quote generation binding the provided ephemeral public key.
|
27
|
-
|
28
|
-
API (mock): quote(pk: u8[32]) -> (quote: u8[33])
|
29
|
-
The mock encodes a 1-byte header + 32-byte pk.
|
30
|
-
"""
|
29
|
+
def quote_gen(pk: TensorType) -> TensorType:
|
30
|
+
"""TEE quote generation binding the provided ephemeral public key."""
|
31
31
|
_ = pk # Mark as used for the decorator
|
32
|
-
return TensorType(UINT8, (
|
33
|
-
|
34
|
-
|
35
|
-
@_TEE_MOD.
|
36
|
-
def attest(
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
32
|
+
return TensorType(UINT8, (-1,))
|
33
|
+
|
34
|
+
|
35
|
+
@_TEE_MOD.op_def()
|
36
|
+
def attest(
|
37
|
+
quote: MPObject, platform: str
|
38
|
+
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
|
39
|
+
"""TEE quote verification returning the attested TEE public key."""
|
40
|
+
|
41
|
+
ins_info = [TensorType.from_obj(quote)]
|
42
|
+
outs_info = [TensorType(UINT8, (32,))] # pk is always 32 bytes for x25519
|
43
|
+
pfunc = PFunction(
|
44
|
+
fn_type="tee.attest",
|
45
|
+
ins_info=ins_info,
|
46
|
+
outs_info=outs_info,
|
47
|
+
platform=platform,
|
48
|
+
)
|
49
|
+
_, treedef = tree_flatten(outs_info[0])
|
50
|
+
|
51
|
+
return pfunc, [quote], treedef
|
@@ -173,38 +173,32 @@ class FileProvider(DataProvider):
|
|
173
173
|
np.save(path, np.asarray(value))
|
174
174
|
|
175
175
|
|
176
|
-
class _KeyedPocket:
|
177
|
-
"""Small helper to keep a dict in KernelContext.state under a namespaced key."""
|
178
|
-
|
179
|
-
def __init__(self, ns: str):
|
180
|
-
self.ns = ns
|
181
|
-
|
182
|
-
def get_map(self, ctx: KernelContext) -> dict[str, Any]:
|
183
|
-
pocket = ctx.state.setdefault("resource.providers", {})
|
184
|
-
store = pocket.get(self.ns)
|
185
|
-
if store is None:
|
186
|
-
store = {}
|
187
|
-
pocket[self.ns] = store
|
188
|
-
return store # type: ignore[return-value]
|
189
|
-
|
190
|
-
|
191
176
|
class MemProvider(DataProvider):
|
192
177
|
"""In-memory per-runtime KV provider (per rank, per session/runtime)."""
|
193
178
|
|
194
|
-
|
195
|
-
|
179
|
+
STATE_KEY = "resource.providers.mem"
|
180
|
+
|
181
|
+
@staticmethod
|
182
|
+
def _store(ctx: KernelContext) -> dict[str, Any]:
|
183
|
+
# Use ensure_state so creation is atomic & centralized; enforce dict.
|
184
|
+
store = ctx.runtime.ensure_state(MemProvider.STATE_KEY, dict)
|
185
|
+
if not isinstance(store, dict): # pragma: no cover - defensive
|
186
|
+
raise TypeError(
|
187
|
+
f"runtime state key '{MemProvider.STATE_KEY}' expected dict, got {type(store).__name__}"
|
188
|
+
)
|
189
|
+
return store # type: ignore[return-value]
|
196
190
|
|
197
191
|
def read(
|
198
192
|
self, uri: ResolvedURI, out_spec: TensorType | TableType, *, ctx: KernelContext
|
199
193
|
) -> Any:
|
200
|
-
store = self.
|
194
|
+
store = self._store(ctx)
|
201
195
|
key = uri.raw
|
202
196
|
if key not in store:
|
203
197
|
raise FileNotFoundError(f"mem resource not found: {key}")
|
204
198
|
return store[key]
|
205
199
|
|
206
200
|
def write(self, uri: ResolvedURI, value: Any, *, ctx: KernelContext) -> None:
|
207
|
-
store = self.
|
201
|
+
store = self._store(ctx)
|
208
202
|
store[uri.raw] = value
|
209
203
|
|
210
204
|
|
@@ -25,14 +25,14 @@ def _demo_flow():
|
|
25
25
|
# TEE generates two ephemeral keypairs and quotes binding their pk
|
26
26
|
t_sk0, t_pk0 = simp.runAt(P2, crypto.kem_keygen)("x25519")
|
27
27
|
t_sk1, t_pk1 = simp.runAt(P2, crypto.kem_keygen)("x25519")
|
28
|
-
q0 = simp.runAt(P2, tee.
|
29
|
-
q1 = simp.runAt(P2, tee.
|
28
|
+
q0 = simp.runAt(P2, tee.quote_gen)(t_pk0)
|
29
|
+
q1 = simp.runAt(P2, tee.quote_gen)(t_pk1)
|
30
30
|
|
31
31
|
# Send quotes to P0/P1 and attest to obtain TEE public keys
|
32
32
|
q0_for_p0 = simp.p2p(P2, P0, q0)
|
33
33
|
q1_for_p1 = simp.p2p(P2, P1, q1)
|
34
|
-
t_pk0_for_p0 = simp.runAt(P0, tee.attest)(q0_for_p0)
|
35
|
-
t_pk1_for_p1 = simp.runAt(P1, tee.attest)(q1_for_p1)
|
34
|
+
t_pk0_for_p0 = simp.runAt(P0, tee.attest)(q0_for_p0, "TDX")
|
35
|
+
t_pk1_for_p1 = simp.runAt(P1, tee.attest)(q1_for_p1, "TDX")
|
36
36
|
|
37
37
|
# Each party generates its own ephemeral keypair and shares pk with TEE
|
38
38
|
v_sk0, v_pk0 = simp.runAt(P0, crypto.kem_keygen)("x25519")
|
@@ -75,7 +75,10 @@ def _demo_flow():
|
|
75
75
|
|
76
76
|
def test_crypto_enc_dec_and_tee_quote_attest_roundtrip():
|
77
77
|
# Create simulator with TEE bindings using the new initial_bindings parameter
|
78
|
-
tee_bindings = {
|
78
|
+
tee_bindings = {
|
79
|
+
"tee.quote_gen": "mock_tee.quote_gen",
|
80
|
+
"tee.attest": "mock_tee.attest",
|
81
|
+
}
|
79
82
|
sim = mplang.Simulator.simple(3, op_bindings=tee_bindings)
|
80
83
|
p0, p1 = mplang.evaluate(sim, _demo_flow)
|
81
84
|
a = mplang.fetch(sim, p0)
|
@@ -35,7 +35,7 @@ def test_basic_callable_and_namespace():
|
|
35
35
|
a, b = simp.P0(crypto.kem_keygen, "x25519")
|
36
36
|
# namespace form (tee side key, then quote)
|
37
37
|
t_sk, t_pk = simp.P[2].crypto.kem_keygen("x25519")
|
38
|
-
_ = simp.P[2].tee.
|
38
|
+
_ = simp.P[2].tee.quote_gen(t_pk)
|
39
39
|
# derive something simple at party 0 to ensure run path works
|
40
40
|
_ = simp.P0(lambda x: x + 1, 41)
|
41
41
|
return a, b, t_sk, t_pk
|
@@ -101,7 +101,10 @@ def run_tee():
|
|
101
101
|
print("-" * 10, "millionaire (TEE)", "-" * 10)
|
102
102
|
|
103
103
|
# TEE operations need explicit binding for security
|
104
|
-
tee_bindings = {
|
104
|
+
tee_bindings = {
|
105
|
+
"tee.quote_gen": "mock_tee.quote_gen",
|
106
|
+
"tee.attest": "mock_tee.attest",
|
107
|
+
}
|
105
108
|
# Apply tee bindings across nodes before constructing simulator
|
106
109
|
for n in cluster_spec.nodes.values():
|
107
110
|
n.runtime_info.op_bindings.update(tee_bindings)
|
@@ -46,7 +46,11 @@ cluster_spec = ClusterSpec.from_dict({
|
|
46
46
|
},
|
47
47
|
"P0": {"kind": "PPU", "members": ["node_0"], "config": {}},
|
48
48
|
"P1": {"kind": "PPU", "members": ["node_1"], "config": {}},
|
49
|
-
"TEE0": {
|
49
|
+
"TEE0": {
|
50
|
+
"kind": "TEE",
|
51
|
+
"members": ["node_2"],
|
52
|
+
"config": {"platform": "TDX"},
|
53
|
+
},
|
50
54
|
},
|
51
55
|
})
|
52
56
|
|
@@ -72,8 +76,8 @@ def millionaire_manual():
|
|
72
76
|
|
73
77
|
# P0 <-> TEE handshake and transfer x (using sugar)
|
74
78
|
tee_sk0, tee_pk0 = P2.crypto.kem_keygen("x25519")
|
75
|
-
quote0 = P2.tee.
|
76
|
-
tee_pk0_at_p0 = P0.tee.attest(P2P(P2, P0, quote0))
|
79
|
+
quote0 = P2.tee.quote_gen(tee_pk0)
|
80
|
+
tee_pk0_at_p0 = P0.tee.attest(P2P(P2, P0, quote0), "TDX")
|
77
81
|
v_sk0, v_pk0 = P0.crypto.kem_keygen("x25519")
|
78
82
|
shared0_p = P0.crypto.kem_derive(v_sk0, tee_pk0_at_p0, "x25519")
|
79
83
|
shared0_t = P2.crypto.kem_derive(tee_sk0, P2P(P0, P2, v_pk0), "x25519")
|
@@ -88,8 +92,8 @@ def millionaire_manual():
|
|
88
92
|
|
89
93
|
# P1 <-> TEE handshake and transfer y (still show original style for contrast)
|
90
94
|
tee_sk1, tee_pk1 = P2.crypto.kem_keygen("x25519")
|
91
|
-
quote1 = P2.tee.
|
92
|
-
tee_pk1_at_p1 = P1.tee.attest(P2P(P2, P1, quote1))
|
95
|
+
quote1 = P2.tee.quote_gen(tee_pk1)
|
96
|
+
tee_pk1_at_p1 = P1.tee.attest(P2P(P2, P1, quote1), "TDX")
|
93
97
|
v_sk1, v_pk1 = P1.crypto.kem_keygen("x25519")
|
94
98
|
shared1_p = P1.crypto.kem_derive(v_sk1, tee_pk1_at_p1, "x25519")
|
95
99
|
shared1_t = P2.crypto.kem_derive(tee_sk1, P2P(P1, P2, v_pk1), "x25519")
|
@@ -117,7 +121,10 @@ def millionaire_manual():
|
|
117
121
|
def main():
|
118
122
|
print("-" * 10, "TEE millionaire: device vs manual (end-to-end IR)", "-" * 10)
|
119
123
|
# Create simulator with TEE bindings
|
120
|
-
tee_bindings = {
|
124
|
+
tee_bindings = {
|
125
|
+
"tee.quote_gen": "mock_tee.quote_gen",
|
126
|
+
"tee.attest": "mock_tee.attest",
|
127
|
+
}
|
121
128
|
# Apply tee_bindings per-node (preferred) then construct Simulator
|
122
129
|
for n in cluster_spec.nodes.values():
|
123
130
|
n.runtime_info.op_bindings.update(tee_bindings)
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|