mplang-nightly 0.1.dev156__tar.gz → 0.1.dev157__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166) hide show
  1. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/PKG-INFO +1 -1
  2. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/device.py +19 -5
  3. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/kernels/context.py +1 -1
  4. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/kernels/mock_tee.py +7 -3
  5. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/ops/tee.py +26 -17
  6. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/ops/test_crypto_tee.py +8 -5
  7. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/simp/test_sugar.py +1 -1
  8. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tutorials/3_device.py +4 -1
  9. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tutorials/9_tee.py +13 -6
  10. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/.gitignore +0 -0
  11. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/LICENSE +0 -0
  12. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/README.md +0 -0
  13. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/examples/conf/3pc.yaml +0 -0
  14. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/examples/stax_nn/README.md +0 -0
  15. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/examples/stax_nn/models.py +0 -0
  16. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/examples/stax_nn/stax_nn.py +0 -0
  17. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/examples/xgboost/hist_jax.py +0 -0
  18. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/examples/xgboost/hist_jax_test.py +0 -0
  19. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/examples/xgboost/naive_np.py +0 -0
  20. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/examples/xgboost/readme.md +0 -0
  21. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/examples/xgboost/sgb.py +0 -0
  22. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/examples/xgboost/sgb_test.py +0 -0
  23. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/hatch_build.py +0 -0
  24. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/__init__.py +0 -0
  25. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/analysis/__init__.py +0 -0
  26. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/analysis/diagram.py +0 -0
  27. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/api.py +0 -0
  28. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/core/__init__.py +0 -0
  29. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/core/cluster.py +0 -0
  30. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/core/comm.py +0 -0
  31. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/core/context_mgr.py +0 -0
  32. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/core/dtype.py +0 -0
  33. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/core/expr/__init__.py +0 -0
  34. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/core/expr/ast.py +0 -0
  35. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/core/expr/evaluator.py +0 -0
  36. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/core/expr/printer.py +0 -0
  37. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/core/expr/transformer.py +0 -0
  38. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/core/expr/utils.py +0 -0
  39. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/core/expr/visitor.py +0 -0
  40. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/core/expr/walk.py +0 -0
  41. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/core/interp.py +0 -0
  42. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/core/mask.py +0 -0
  43. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/core/mpir.py +0 -0
  44. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/core/mpobject.py +0 -0
  45. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/core/mptype.py +0 -0
  46. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/core/pfunc.py +0 -0
  47. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/core/primitive.py +0 -0
  48. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/core/table.py +0 -0
  49. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/core/tensor.py +0 -0
  50. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/core/tracer.py +0 -0
  51. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/kernels/__init__.py +0 -0
  52. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/kernels/base.py +0 -0
  53. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/kernels/builtin.py +0 -0
  54. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/kernels/crypto.py +0 -0
  55. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/kernels/phe.py +0 -0
  56. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/kernels/spu.py +0 -0
  57. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/kernels/sql_duckdb.py +0 -0
  58. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/kernels/stablehlo.py +0 -0
  59. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/ops/__init__.py +0 -0
  60. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/ops/base.py +0 -0
  61. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/ops/builtin.py +0 -0
  62. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/ops/crypto.py +0 -0
  63. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/ops/ibis_cc.py +0 -0
  64. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/ops/jax_cc.py +0 -0
  65. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/ops/phe.py +0 -0
  66. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/ops/spu.py +0 -0
  67. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/ops/sql.py +0 -0
  68. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/protos/v1alpha1/mpir_pb2.py +0 -0
  69. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/protos/v1alpha1/mpir_pb2.pyi +0 -0
  70. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -0
  71. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/runtime/__init__.py +0 -0
  72. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/runtime/cli.py +0 -0
  73. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/runtime/client.py +0 -0
  74. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/runtime/communicator.py +0 -0
  75. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/runtime/data_providers.py +0 -0
  76. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/runtime/driver.py +0 -0
  77. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/runtime/exceptions.py +0 -0
  78. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/runtime/http_api.md +0 -0
  79. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/runtime/link_comm.py +0 -0
  80. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/runtime/server.py +0 -0
  81. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/runtime/session.py +0 -0
  82. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/runtime/simulation.py +0 -0
  83. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/simp/__init__.py +0 -0
  84. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/simp/mpi.py +0 -0
  85. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/simp/random.py +0 -0
  86. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/simp/smpc.py +0 -0
  87. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/utils/__init__.py +0 -0
  88. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/utils/crypto.py +0 -0
  89. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/utils/func_utils.py +0 -0
  90. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/utils/spu_utils.py +0 -0
  91. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/mplang/utils/table_utils.py +0 -0
  92. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/pyproject.toml +0 -0
  93. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/__init__.py +0 -0
  94. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/analysis/test_diagram.py +0 -0
  95. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/conftest.py +0 -0
  96. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/core/__init__.py +0 -0
  97. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/core/expr/__init__.py +0 -0
  98. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/core/expr/conftest.py +0 -0
  99. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/core/expr/test_ast.py +0 -0
  100. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/core/expr/test_printer.py +0 -0
  101. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/core/expr/test_utils.py +0 -0
  102. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/core/expr/test_walk.py +0 -0
  103. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/core/test_cluster.py +0 -0
  104. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/core/test_dtype.py +0 -0
  105. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/core/test_mask.py +0 -0
  106. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/core/test_mpir.py +0 -0
  107. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/core/test_mptype.py +0 -0
  108. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/core/test_primitive.py +0 -0
  109. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/core/test_table.py +0 -0
  110. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/core/test_tensor.py +0 -0
  111. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/core/test_tracer.py +0 -0
  112. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/device/__init__.py +0 -0
  113. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/device/test_device_basic.py +0 -0
  114. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/integration/README.md +0 -0
  115. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/integration/test_crypto_roundtrip.py +0 -0
  116. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/integration/test_http_e2e.py +0 -0
  117. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/integration/test_symbols_roundtrip.py +0 -0
  118. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/integration/test_tutorials.py +0 -0
  119. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/integration/test_unused_param_integration.py +0 -0
  120. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/kernels/test_builtin.py +0 -0
  121. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/kernels/test_debug_print.py +0 -0
  122. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/kernels/test_kernel_binding.py +0 -0
  123. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/kernels/test_phe.py +0 -0
  124. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/kernels/test_spu.py +0 -0
  125. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/kernels/test_sql_duckdb.py +0 -0
  126. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/kernels/test_stablehlo.py +0 -0
  127. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/ops/__init__.py +0 -0
  128. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/ops/dummy.py +0 -0
  129. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/ops/test_builtin_pack.py +0 -0
  130. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/ops/test_feop_base.py +0 -0
  131. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/ops/test_ibis.py +0 -0
  132. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/ops/test_ibis_cc.py +0 -0
  133. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/ops/test_jax_cc.py +0 -0
  134. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/ops/test_phe.py +0 -0
  135. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/ops/test_spu.py +0 -0
  136. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/ops/test_spu_defensive.py +0 -0
  137. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/ops/test_sql.py +0 -0
  138. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/ops/test_table_tensor_conversion.py +0 -0
  139. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/runtime/__init__.py +0 -0
  140. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/runtime/test_cli.py +0 -0
  141. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/runtime/test_communicator.py +0 -0
  142. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/runtime/test_driver.py +0 -0
  143. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/runtime/test_server.py +0 -0
  144. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/runtime/test_simulation.py +0 -0
  145. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/simp/test_mpi.py +0 -0
  146. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/simp/test_random.py +0 -0
  147. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/simp/test_simp.py +0 -0
  148. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/simp/test_smpc.py +0 -0
  149. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/utils/__init__.py +0 -0
  150. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/utils/server_fixtures.py +0 -0
  151. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/utils/test_func_utils.py +0 -0
  152. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/utils/test_spu_utils.py +0 -0
  153. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tests/utils/test_table_utils.py +0 -0
  154. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tutorials/0_basic.py +0 -0
  155. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tutorials/10_analysis.py +0 -0
  156. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tutorials/1_condition.py +0 -0
  157. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tutorials/2_whileloop.py +0 -0
  158. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tutorials/4_simulation.py +0 -0
  159. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tutorials/5_ir_dump.py +0 -0
  160. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tutorials/6_advanced.py +0 -0
  161. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tutorials/7_stdio.py +0 -0
  162. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tutorials/8_phe.py +0 -0
  163. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tutorials/__init__.py +0 -0
  164. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tutorials/pitfalls/late_binding.py +0 -0
  165. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tutorials/pitfalls/rand.py +0 -0
  166. {mplang_nightly-0.1.dev156 → mplang_nightly-0.1.dev157}/tutorials/run.sh +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mplang-nightly
3
- Version: 0.1.dev156
3
+ Version: 0.1.dev157
4
4
  Summary: Multi-Party Programming Language
5
5
  Author-email: SecretFlow Team <secretflow-contact@service.alipay.com>
6
6
  License: Apache License
@@ -207,8 +207,15 @@ def _d2d(to_dev_id: str, obj: MPObject) -> MPObject:
207
207
  assert len(frm_dev.members) == 1 and len(to_dev.members) == 1
208
208
  frm_rank = frm_dev.members[0].rank
209
209
  tee_rank = to_dev.members[0].rank
210
+ platform = to_dev.config.get("platform")
211
+ if not platform:
212
+ raise ValueError(
213
+ f"TEE device '{to_dev_id}' is missing 'platform' in its config."
214
+ )
210
215
  # Ensure sessions (both directions) exist for this PPU<->TEE pair
211
- sess_p, sess_t = _ensure_tee_session(frm_dev_id, to_dev_id, frm_rank, tee_rank)
216
+ sess_p, sess_t = _ensure_tee_session(
217
+ frm_dev_id, to_dev_id, frm_rank, tee_rank, platform
218
+ )
212
219
  # Bytes-only path: pack -> enc -> p2p -> dec -> unpack (with static out type)
213
220
  obj_ty = TensorType.from_obj(obj)
214
221
  b = simp.runAt(frm_rank, builtin.pack)(obj)
@@ -222,8 +229,15 @@ def _d2d(to_dev_id: str, obj: MPObject) -> MPObject:
222
229
  assert len(frm_dev.members) == 1 and len(to_dev.members) == 1
223
230
  tee_rank = frm_dev.members[0].rank
224
231
  ppu_rank = to_dev.members[0].rank
232
+ platform = to_dev.config.get("platform")
233
+ if not platform:
234
+ raise ValueError(
235
+ f"TEE device '{to_dev_id}' is missing 'platform' in its config."
236
+ )
225
237
  # Ensure bidirectional session established for this pair
226
- sess_p, sess_t = _ensure_tee_session(to_dev_id, frm_dev_id, ppu_rank, tee_rank)
238
+ sess_p, sess_t = _ensure_tee_session(
239
+ to_dev_id, frm_dev_id, ppu_rank, tee_rank, platform
240
+ )
227
241
  obj_ty = TensorType.from_obj(obj)
228
242
  b = simp.runAt(tee_rank, builtin.pack)(obj)
229
243
  ct = simp.runAt(tee_rank, crypto.enc)(b, sess_t)
@@ -245,7 +259,7 @@ def _d2d(to_dev_id: str, obj: MPObject) -> MPObject:
245
259
 
246
260
 
247
261
  def _ensure_tee_session(
248
- frm_dev_id: str, to_dev_id: str, frm_rank: int, tee_rank: int
262
+ frm_dev_id: str, to_dev_id: str, frm_rank: int, tee_rank: int, platform: str
249
263
  ) -> tuple[MPObject, MPObject]:
250
264
  """Ensure a TEE session (sess_p at sender, sess_t at TEE) exists.
251
265
 
@@ -263,11 +277,11 @@ def _ensure_tee_session(
263
277
  # 1) TEE generates (sk, pk) and quote(pk)
264
278
  # KEM suite currently constant; future: read from tee device config (e.g. cluster_spec.devices[dev_id].config)
265
279
  tee_sk, tee_pk = simp.runAt(tee_rank, crypto.kem_keygen)(_TEE_KEM_SUITE)
266
- quote = simp.runAt(tee_rank, tee.quote)(tee_pk)
280
+ quote = simp.runAt(tee_rank, tee.quote_gen)(tee_pk)
267
281
 
268
282
  # 2) Send quote to sender and attest to obtain TEE pk
269
283
  quote_at_sender = mpi.p2p(tee_rank, frm_rank, quote)
270
- tee_pk_at_sender = simp.runAt(frm_rank, tee.attest)(quote_at_sender)
284
+ tee_pk_at_sender = simp.runAt(frm_rank, tee.attest)(quote_at_sender, platform)
271
285
 
272
286
  # 3) Sender generates its ephemeral keypair and sends its pk to TEE
273
287
  v_sk, v_pk = simp.runAt(frm_rank, crypto.kem_keygen)(_TEE_KEM_SUITE)
@@ -91,7 +91,7 @@ _DEFAULT_BINDINGS: dict[str, str] = {
91
91
  # generic SQL op; backend-specific kernel id for duckdb
92
92
  "sql.run": "duckdb.run_sql",
93
93
  # tee
94
- # "tee.quote": "mock_tee.quote",
94
+ # "tee.quote_gen": "mock_tee.quote_gen",
95
95
  # "tee.attest": "mock_tee.attest",
96
96
  }
97
97
 
@@ -45,10 +45,10 @@ def _quote_from_pk(pk: np.ndarray) -> NDArray[np.uint8]:
45
45
  return out
46
46
 
47
47
 
48
- @kernel_def("mock_tee.quote")
49
- def _tee_quote(pfunc: PFunction, pk: object) -> NDArray[np.uint8]:
48
+ @kernel_def("mock_tee.quote_gen")
49
+ def _tee_quote_gen(pfunc: PFunction, pk: object) -> NDArray[np.uint8]:
50
50
  warnings.warn(
51
- "Insecure mock TEE kernel 'mock_tee.quote' in use. NOT secure; for local testing only.",
51
+ "Insecure mock TEE kernel 'mock_tee.quote_gen' in use. NOT secure; for local testing only.",
52
52
  stacklevel=3,
53
53
  )
54
54
  pk = np.asarray(pk, dtype=np.uint8)
@@ -64,6 +64,10 @@ def _tee_attest(pfunc: PFunction, quote: object) -> NDArray[np.uint8]:
64
64
  stacklevel=3,
65
65
  )
66
66
  quote = np.asarray(quote, dtype=np.uint8)
67
+ platform = pfunc.attrs.get("platform")
68
+ if platform is None:
69
+ raise ValueError("missing required 'platform' attribute in PFunction")
70
+
67
71
  if quote.size != 33:
68
72
  raise ValueError("mock quote must be 33 bytes (1 header + 32 pk)")
69
73
  return quote[1:33].astype(np.uint8)
@@ -14,7 +14,11 @@
14
14
 
15
15
  from __future__ import annotations
16
16
 
17
+ from jax.tree_util import PyTreeDef, tree_flatten
18
+
17
19
  from mplang.core.dtype import UINT8
20
+ from mplang.core.mpobject import MPObject
21
+ from mplang.core.pfunc import PFunction
18
22
  from mplang.core.tensor import TensorType
19
23
  from mplang.ops.base import stateless_mod
20
24
 
@@ -22,21 +26,26 @@ _TEE_MOD = stateless_mod("tee")
22
26
 
23
27
 
24
28
  @_TEE_MOD.simple_op()
25
- def quote(pk: TensorType) -> TensorType:
26
- """TEE quote generation binding the provided ephemeral public key.
27
-
28
- API (mock): quote(pk: u8[32]) -> (quote: u8[33])
29
- The mock encodes a 1-byte header + 32-byte pk.
30
- """
29
+ def quote_gen(pk: TensorType) -> TensorType:
30
+ """TEE quote generation binding the provided ephemeral public key."""
31
31
  _ = pk # Mark as used for the decorator
32
- return TensorType(UINT8, (33,))
33
-
34
-
35
- @_TEE_MOD.simple_op()
36
- def attest(quote: TensorType) -> TensorType:
37
- """TEE quote verification returning the attested TEE public key.
38
-
39
- API (mock): attest(quote: u8[33]) -> tee_pk: u8[32]
40
- """
41
- _ = quote # Mark as used for the decorator
42
- return TensorType(UINT8, (32,))
32
+ return TensorType(UINT8, (-1,))
33
+
34
+
35
+ @_TEE_MOD.op_def()
36
+ def attest(
37
+ quote: MPObject, platform: str
38
+ ) -> tuple[PFunction, list[MPObject], PyTreeDef]:
39
+ """TEE quote verification returning the attested TEE public key."""
40
+
41
+ ins_info = [TensorType.from_obj(quote)]
42
+ outs_info = [TensorType(UINT8, (32,))] # pk is always 32 bytes for x25519
43
+ pfunc = PFunction(
44
+ fn_type="tee.attest",
45
+ ins_info=ins_info,
46
+ outs_info=outs_info,
47
+ platform=platform,
48
+ )
49
+ _, treedef = tree_flatten(outs_info[0])
50
+
51
+ return pfunc, [quote], treedef
@@ -25,14 +25,14 @@ def _demo_flow():
25
25
  # TEE generates two ephemeral keypairs and quotes binding their pk
26
26
  t_sk0, t_pk0 = simp.runAt(P2, crypto.kem_keygen)("x25519")
27
27
  t_sk1, t_pk1 = simp.runAt(P2, crypto.kem_keygen)("x25519")
28
- q0 = simp.runAt(P2, tee.quote)(t_pk0)
29
- q1 = simp.runAt(P2, tee.quote)(t_pk1)
28
+ q0 = simp.runAt(P2, tee.quote_gen)(t_pk0)
29
+ q1 = simp.runAt(P2, tee.quote_gen)(t_pk1)
30
30
 
31
31
  # Send quotes to P0/P1 and attest to obtain TEE public keys
32
32
  q0_for_p0 = simp.p2p(P2, P0, q0)
33
33
  q1_for_p1 = simp.p2p(P2, P1, q1)
34
- t_pk0_for_p0 = simp.runAt(P0, tee.attest)(q0_for_p0)
35
- t_pk1_for_p1 = simp.runAt(P1, tee.attest)(q1_for_p1)
34
+ t_pk0_for_p0 = simp.runAt(P0, tee.attest)(q0_for_p0, "TDX")
35
+ t_pk1_for_p1 = simp.runAt(P1, tee.attest)(q1_for_p1, "TDX")
36
36
 
37
37
  # Each party generates its own ephemeral keypair and shares pk with TEE
38
38
  v_sk0, v_pk0 = simp.runAt(P0, crypto.kem_keygen)("x25519")
@@ -75,7 +75,10 @@ def _demo_flow():
75
75
 
76
76
  def test_crypto_enc_dec_and_tee_quote_attest_roundtrip():
77
77
  # Create simulator with TEE bindings using the new initial_bindings parameter
78
- tee_bindings = {"tee.quote": "mock_tee.quote", "tee.attest": "mock_tee.attest"}
78
+ tee_bindings = {
79
+ "tee.quote_gen": "mock_tee.quote_gen",
80
+ "tee.attest": "mock_tee.attest",
81
+ }
79
82
  sim = mplang.Simulator.simple(3, op_bindings=tee_bindings)
80
83
  p0, p1 = mplang.evaluate(sim, _demo_flow)
81
84
  a = mplang.fetch(sim, p0)
@@ -35,7 +35,7 @@ def test_basic_callable_and_namespace():
35
35
  a, b = simp.P0(crypto.kem_keygen, "x25519")
36
36
  # namespace form (tee side key, then quote)
37
37
  t_sk, t_pk = simp.P[2].crypto.kem_keygen("x25519")
38
- _ = simp.P[2].tee.quote(t_pk)
38
+ _ = simp.P[2].tee.quote_gen(t_pk)
39
39
  # derive something simple at party 0 to ensure run path works
40
40
  _ = simp.P0(lambda x: x + 1, 41)
41
41
  return a, b, t_sk, t_pk
@@ -101,7 +101,10 @@ def run_tee():
101
101
  print("-" * 10, "millionaire (TEE)", "-" * 10)
102
102
 
103
103
  # TEE operations need explicit binding for security
104
- tee_bindings = {"tee.quote": "mock_tee.quote", "tee.attest": "mock_tee.attest"}
104
+ tee_bindings = {
105
+ "tee.quote_gen": "mock_tee.quote_gen",
106
+ "tee.attest": "mock_tee.attest",
107
+ }
105
108
  # Apply tee bindings across nodes before constructing simulator
106
109
  for n in cluster_spec.nodes.values():
107
110
  n.runtime_info.op_bindings.update(tee_bindings)
@@ -46,7 +46,11 @@ cluster_spec = ClusterSpec.from_dict({
46
46
  },
47
47
  "P0": {"kind": "PPU", "members": ["node_0"], "config": {}},
48
48
  "P1": {"kind": "PPU", "members": ["node_1"], "config": {}},
49
- "TEE0": {"kind": "TEE", "members": ["node_2"], "config": {}},
49
+ "TEE0": {
50
+ "kind": "TEE",
51
+ "members": ["node_2"],
52
+ "config": {"platform": "TDX"},
53
+ },
50
54
  },
51
55
  })
52
56
 
@@ -72,8 +76,8 @@ def millionaire_manual():
72
76
 
73
77
  # P0 <-> TEE handshake and transfer x (using sugar)
74
78
  tee_sk0, tee_pk0 = P2.crypto.kem_keygen("x25519")
75
- quote0 = P2.tee.quote(tee_pk0)
76
- tee_pk0_at_p0 = P0.tee.attest(P2P(P2, P0, quote0))
79
+ quote0 = P2.tee.quote_gen(tee_pk0)
80
+ tee_pk0_at_p0 = P0.tee.attest(P2P(P2, P0, quote0), "TDX")
77
81
  v_sk0, v_pk0 = P0.crypto.kem_keygen("x25519")
78
82
  shared0_p = P0.crypto.kem_derive(v_sk0, tee_pk0_at_p0, "x25519")
79
83
  shared0_t = P2.crypto.kem_derive(tee_sk0, P2P(P0, P2, v_pk0), "x25519")
@@ -88,8 +92,8 @@ def millionaire_manual():
88
92
 
89
93
  # P1 <-> TEE handshake and transfer y (still show original style for contrast)
90
94
  tee_sk1, tee_pk1 = P2.crypto.kem_keygen("x25519")
91
- quote1 = P2.tee.quote(tee_pk1)
92
- tee_pk1_at_p1 = P1.tee.attest(P2P(P2, P1, quote1))
95
+ quote1 = P2.tee.quote_gen(tee_pk1)
96
+ tee_pk1_at_p1 = P1.tee.attest(P2P(P2, P1, quote1), "TDX")
93
97
  v_sk1, v_pk1 = P1.crypto.kem_keygen("x25519")
94
98
  shared1_p = P1.crypto.kem_derive(v_sk1, tee_pk1_at_p1, "x25519")
95
99
  shared1_t = P2.crypto.kem_derive(tee_sk1, P2P(P1, P2, v_pk1), "x25519")
@@ -117,7 +121,10 @@ def millionaire_manual():
117
121
  def main():
118
122
  print("-" * 10, "TEE millionaire: device vs manual (end-to-end IR)", "-" * 10)
119
123
  # Create simulator with TEE bindings
120
- tee_bindings = {"tee.quote": "mock_tee.quote", "tee.attest": "mock_tee.attest"}
124
+ tee_bindings = {
125
+ "tee.quote_gen": "mock_tee.quote_gen",
126
+ "tee.attest": "mock_tee.attest",
127
+ }
121
128
  # Apply tee_bindings per-node (preferred) then construct Simulator
122
129
  for n in cluster_spec.nodes.values():
123
130
  n.runtime_info.op_bindings.update(tee_bindings)