mplang-nightly 0.1.dev164__tar.gz → 0.1.dev166__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (173)
  1. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/PKG-INFO +1 -1
  2. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/core/expr/evaluator.py +55 -15
  3. mplang_nightly-0.1.dev166/mplang/kernels/__init__.py +41 -0
  4. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/kernels/builtin.py +91 -56
  5. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/kernels/crypto.py +39 -30
  6. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/kernels/mock_tee.py +10 -8
  7. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/kernels/phe.py +238 -39
  8. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/kernels/spu.py +134 -45
  9. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/kernels/sql_duckdb.py +8 -13
  10. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/kernels/stablehlo.py +15 -9
  11. mplang_nightly-0.1.dev166/mplang/kernels/value.py +626 -0
  12. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/protos/v1alpha1/mpir_pb2.pyi +71 -21
  13. mplang_nightly-0.1.dev166/mplang/protos/v1alpha1/value_pb2.py +34 -0
  14. mplang_nightly-0.1.dev166/mplang/protos/v1alpha1/value_pb2.pyi +169 -0
  15. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/runtime/client.py +19 -8
  16. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/runtime/communicator.py +11 -4
  17. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/runtime/driver.py +16 -1
  18. mplang_nightly-0.1.dev166/mplang/runtime/link_comm.py +78 -0
  19. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/runtime/server.py +30 -29
  20. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/runtime/session.py +9 -0
  21. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/runtime/simulation.py +4 -5
  22. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/simp/__init__.py +1 -1
  23. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/pyproject.toml +5 -0
  24. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/conftest.py +2 -2
  25. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/integration/test_symbols_roundtrip.py +5 -1
  26. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/kernels/test_builtin.py +34 -20
  27. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/kernels/test_debug_print.py +8 -3
  28. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/kernels/test_kernel_binding.py +41 -15
  29. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/kernels/test_phe.py +18 -3
  30. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/kernels/test_spu.py +11 -10
  31. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/kernels/test_sql_duckdb.py +5 -1
  32. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/kernels/test_stablehlo.py +8 -2
  33. mplang_nightly-0.1.dev166/tests/kernels/test_value.py +324 -0
  34. mplang_nightly-0.1.dev166/tests/kernels/test_value_serde.py +377 -0
  35. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/ops/test_builtin_pack.py +7 -4
  36. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/ops/test_crypto_tee.py +5 -9
  37. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/runtime/test_communicator.py +22 -13
  38. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/runtime/test_server.py +12 -10
  39. mplang_nightly-0.1.dev164/mplang/runtime/link_comm.py +0 -131
  40. mplang_nightly-0.1.dev164/tests/runtime/__init__.py +0 -13
  41. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/.gitignore +0 -0
  42. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/LICENSE +0 -0
  43. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/README.md +0 -0
  44. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/examples/conf/3pc.yaml +0 -0
  45. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/examples/stax_nn/README.md +0 -0
  46. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/examples/stax_nn/models.py +0 -0
  47. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/examples/stax_nn/stax_nn.py +0 -0
  48. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/examples/xgboost/hist_jax.py +0 -0
  49. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/examples/xgboost/hist_jax_test.py +0 -0
  50. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/examples/xgboost/naive_np.py +0 -0
  51. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/examples/xgboost/readme.md +0 -0
  52. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/examples/xgboost/sgb.py +0 -0
  53. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/examples/xgboost/sgb_test.py +0 -0
  54. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/hatch_build.py +0 -0
  55. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/__init__.py +0 -0
  56. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/analysis/__init__.py +0 -0
  57. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/analysis/diagram.py +0 -0
  58. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/api.py +0 -0
  59. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/core/__init__.py +0 -0
  60. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/core/cluster.py +0 -0
  61. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/core/comm.py +0 -0
  62. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/core/context_mgr.py +0 -0
  63. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/core/dtype.py +0 -0
  64. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/core/expr/__init__.py +0 -0
  65. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/core/expr/ast.py +0 -0
  66. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/core/expr/printer.py +0 -0
  67. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/core/expr/transformer.py +0 -0
  68. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/core/expr/utils.py +0 -0
  69. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/core/expr/visitor.py +0 -0
  70. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/core/expr/walk.py +0 -0
  71. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/core/interp.py +0 -0
  72. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/core/mask.py +0 -0
  73. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/core/mpir.py +0 -0
  74. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/core/mpobject.py +0 -0
  75. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/core/mptype.py +0 -0
  76. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/core/pfunc.py +0 -0
  77. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/core/primitive.py +0 -0
  78. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/core/table.py +0 -0
  79. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/core/tensor.py +0 -0
  80. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/core/tracer.py +0 -0
  81. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/device.py +0 -0
  82. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/kernels/base.py +0 -0
  83. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/kernels/context.py +0 -0
  84. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/ops/__init__.py +0 -0
  85. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/ops/base.py +0 -0
  86. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/ops/builtin.py +0 -0
  87. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/ops/crypto.py +0 -0
  88. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/ops/ibis_cc.py +0 -0
  89. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/ops/jax_cc.py +0 -0
  90. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/ops/phe.py +0 -0
  91. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/ops/spu.py +0 -0
  92. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/ops/sql.py +0 -0
  93. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/ops/tee.py +0 -0
  94. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/protos/v1alpha1/mpir_pb2.py +0 -0
  95. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -0
  96. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/runtime/__init__.py +0 -0
  97. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/runtime/cli.py +0 -0
  98. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/runtime/data_providers.py +0 -0
  99. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/runtime/exceptions.py +0 -0
  100. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/runtime/http_api.md +0 -0
  101. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/simp/mpi.py +0 -0
  102. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/simp/random.py +0 -0
  103. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/simp/smpc.py +0 -0
  104. {mplang_nightly-0.1.dev164/mplang/kernels → mplang_nightly-0.1.dev166/mplang/utils}/__init__.py +0 -0
  105. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/utils/crypto.py +0 -0
  106. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/utils/func_utils.py +0 -0
  107. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/utils/spu_utils.py +0 -0
  108. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/mplang/utils/table_utils.py +0 -0
  109. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/__init__.py +0 -0
  110. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/analysis/test_diagram.py +0 -0
  111. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/core/__init__.py +0 -0
  112. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/core/expr/__init__.py +0 -0
  113. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/core/expr/conftest.py +0 -0
  114. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/core/expr/test_ast.py +0 -0
  115. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/core/expr/test_printer.py +0 -0
  116. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/core/expr/test_utils.py +0 -0
  117. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/core/expr/test_walk.py +0 -0
  118. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/core/test_cluster.py +0 -0
  119. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/core/test_dtype.py +0 -0
  120. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/core/test_mask.py +0 -0
  121. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/core/test_mpir.py +0 -0
  122. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/core/test_mptype.py +0 -0
  123. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/core/test_primitive.py +0 -0
  124. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/core/test_table.py +0 -0
  125. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/core/test_tensor.py +0 -0
  126. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/core/test_tracer.py +0 -0
  127. {mplang_nightly-0.1.dev164/mplang/utils → mplang_nightly-0.1.dev166/tests/device}/__init__.py +0 -0
  128. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/device/test_device_basic.py +0 -0
  129. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/integration/README.md +0 -0
  130. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/integration/test_crypto_roundtrip.py +0 -0
  131. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/integration/test_http_e2e.py +0 -0
  132. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/integration/test_tutorials.py +0 -0
  133. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/integration/test_unused_param_integration.py +0 -0
  134. {mplang_nightly-0.1.dev164/tests/device → mplang_nightly-0.1.dev166/tests/ops}/__init__.py +0 -0
  135. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/ops/dummy.py +0 -0
  136. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/ops/test_feop_base.py +0 -0
  137. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/ops/test_ibis.py +0 -0
  138. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/ops/test_ibis_cc.py +0 -0
  139. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/ops/test_jax_cc.py +0 -0
  140. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/ops/test_phe.py +0 -0
  141. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/ops/test_spu.py +0 -0
  142. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/ops/test_spu_defensive.py +0 -0
  143. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/ops/test_sql.py +0 -0
  144. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/ops/test_table_tensor_conversion.py +0 -0
  145. {mplang_nightly-0.1.dev164/tests/ops → mplang_nightly-0.1.dev166/tests/runtime}/__init__.py +0 -0
  146. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/runtime/test_cli.py +0 -0
  147. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/runtime/test_driver.py +0 -0
  148. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/runtime/test_simulation.py +0 -0
  149. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/simp/test_mpi.py +0 -0
  150. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/simp/test_random.py +0 -0
  151. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/simp/test_simp.py +0 -0
  152. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/simp/test_smpc.py +0 -0
  153. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/simp/test_sugar.py +0 -0
  154. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/utils/__init__.py +0 -0
  155. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/utils/server_fixtures.py +0 -0
  156. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/utils/test_func_utils.py +0 -0
  157. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/utils/test_spu_utils.py +0 -0
  158. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tests/utils/test_table_utils.py +0 -0
  159. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tutorials/0_basic.py +0 -0
  160. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tutorials/10_analysis.py +0 -0
  161. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tutorials/1_condition.py +0 -0
  162. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tutorials/2_whileloop.py +0 -0
  163. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tutorials/3_device.py +0 -0
  164. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tutorials/4_simulation.py +0 -0
  165. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tutorials/5_ir_dump.py +0 -0
  166. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tutorials/6_advanced.py +0 -0
  167. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tutorials/7_stdio.py +0 -0
  168. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tutorials/8_phe.py +0 -0
  169. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tutorials/9_tee.py +0 -0
  170. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tutorials/__init__.py +0 -0
  171. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tutorials/pitfalls/late_binding.py +0 -0
  172. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tutorials/pitfalls/rand.py +0 -0
  173. {mplang_nightly-0.1.dev164 → mplang_nightly-0.1.dev166}/tutorials/run.sh +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mplang-nightly
3
- Version: 0.1.dev164
3
+ Version: 0.1.dev166
4
4
  Summary: Multi-Party Programming Language
5
5
  Author-email: SecretFlow Team <secretflow-contact@service.alipay.com>
6
6
  License: Apache License
@@ -27,6 +27,8 @@ from __future__ import annotations
27
27
  from dataclasses import dataclass
28
28
  from typing import Any, Protocol
29
29
 
30
+ import numpy as np
31
+
30
32
  from mplang.core.comm import ICommunicator
31
33
  from mplang.core.expr.ast import (
32
34
  AccessExpr,
@@ -47,6 +49,7 @@ from mplang.core.expr.walk import walk_dataflow
47
49
  from mplang.core.mask import Mask
48
50
  from mplang.core.pfunc import PFunction
49
51
  from mplang.kernels.context import RuntimeContext
52
+ from mplang.kernels.value import Value
50
53
 
51
54
 
52
55
  class IEvaluator(Protocol):
@@ -149,12 +152,12 @@ class EvalSemantic:
149
152
  def _as_optional_int(val: Any) -> int | None:
150
153
  """Convert a value to int if possible, preserving None.
151
154
 
152
- Handles Python ints, numpy scalars with .item(), and None.
155
+ Handles Python ints, floats, numpy scalar types (e.g., np.int32, np.float64), and None.
156
+ Uses int(val) for conversion which works with numpy scalars via __int__().
153
157
  """
158
+ val = EvalSemantic._unwrap_value(val)
154
159
  if val is None:
155
160
  return None
156
- if hasattr(val, "item"):
157
- return int(val.item())
158
161
  return int(val)
159
162
 
160
163
  def _simple_allgather(self, value: Any) -> list[Any]:
@@ -167,6 +170,7 @@ class EvalSemantic:
167
170
  Returns a list of length world_size with entries ordered by rank.
168
171
  """
169
172
  ws = self.comm.world_size
173
+ value = self._unwrap_value(value)
170
174
  # Trivial fast-path
171
175
  if ws == 1:
172
176
  return [value]
@@ -185,7 +189,12 @@ class EvalSemantic:
185
189
 
186
190
  def _verify_uniform_predicate(self, pred: Any) -> None:
187
191
  # Runtime uniformity check (O(P^2) send/recv emulation).
188
- vals = self._simple_allgather(bool(pred))
192
+ # Use Value.to_bool() if available, otherwise unwrap and convert
193
+ if isinstance(pred, Value):
194
+ pred_bool = pred.to_bool()
195
+ else:
196
+ pred_bool = bool(self._unwrap_value(pred))
197
+ vals = self._simple_allgather(pred_bool)
189
198
  if not vals:
190
199
  raise ValueError("uniform_cond: empty gather for predicate")
191
200
  first = vals[0]
@@ -209,13 +218,33 @@ class EvalSemantic:
209
218
  assert len(cond_result) == 1, (
210
219
  f"Condition function must return a single value, got {cond_result}"
211
220
  )
212
- cond_value = cond_result[0]
213
- if cond_value is None:
221
+ cond_val = cond_result[0]
222
+ if cond_val is None:
214
223
  raise RuntimeError(
215
224
  "while_loop condition produced None on rank "
216
225
  f"{self.rank}; ensure the predicate yields a boolean for every party."
217
226
  )
218
- return cond_value
227
+ # Use Value.to_bool() if available for cleaner conversion
228
+ if isinstance(cond_val, Value):
229
+ return cond_val.to_bool()
230
+ return bool(self._unwrap_value(cond_val))
231
+
232
+ @staticmethod
233
+ def _unwrap_value(value: Any) -> Any:
234
+ """Convert Value payloads to numpy/python equivalents when possible."""
235
+ if value is None:
236
+ return None
237
+ if isinstance(value, Value):
238
+ # Try to_numpy first for broader compatibility
239
+ to_numpy = getattr(value, "to_numpy", None)
240
+ if callable(to_numpy):
241
+ arr = to_numpy()
242
+ if isinstance(arr, np.ndarray):
243
+ if arr.size == 1:
244
+ return arr.item()
245
+ return arr
246
+ return arr
247
+ return value
219
248
 
220
249
 
221
250
  class RecursiveEvaluator(EvalSemantic, ExprVisitor):
@@ -296,15 +325,21 @@ class RecursiveEvaluator(EvalSemantic, ExprVisitor):
296
325
  * Add optional static uniform inference (data provenance) to elide the
297
326
  runtime check when predicate uniformity is provable at trace time.
298
327
  """
299
- pred = self._value(expr.pred)
300
- if pred is None:
328
+ pred_val = self._value(expr.pred)
329
+ if pred_val is None:
301
330
  return [None] * len(expr.mptypes)
302
331
 
303
332
  if expr.verify_uniform:
304
- self._verify_uniform_predicate(pred)
333
+ self._verify_uniform_predicate(pred_val)
334
+
335
+ # Convert to bool using Value.to_bool() if available
336
+ if isinstance(pred_val, Value):
337
+ pred = pred_val.to_bool()
338
+ else:
339
+ pred = bool(self._unwrap_value(pred_val))
305
340
 
306
341
  # Only evaluate selected branch locally
307
- if pred:
342
+ if bool(pred):
308
343
  then_call = CallExpr(expr.then_fn, expr.args)
309
344
  return self._values(then_call)
310
345
  else:
@@ -435,15 +470,20 @@ class IterativeEvaluator(EvalSemantic):
435
470
  res = self._iter_eval_graph(node.fn.body, {**env, **sub_env})
436
471
  symbols[id(node)] = res
437
472
  elif isinstance(node, CondExpr):
438
- pred_v = self._first(symbols[id(node.pred)])
473
+ pred_val = self._first(symbols[id(node.pred)])
439
474
  arg_vals = [self._first(symbols[id(a)]) for a in node.args]
440
- if pred_v is None:
475
+ if pred_val is None:
441
476
  symbols[id(node)] = [None] * len(node.mptypes)
442
477
  else:
443
478
  # Optional uniform verification identical to recursive evaluator (DRY helper).
444
479
  if node.verify_uniform:
445
- self._verify_uniform_predicate(pred_v)
446
- if bool(pred_v):
480
+ self._verify_uniform_predicate(pred_val)
481
+ # Convert to bool using Value.to_bool() if available
482
+ if isinstance(pred_val, Value):
483
+ pred = pred_val.to_bool()
484
+ else:
485
+ pred = bool(self._unwrap_value(pred_val))
486
+ if pred:
447
487
  sub_env = dict(zip(node.then_fn.params, arg_vals, strict=True))
448
488
  res = self._iter_eval_graph(
449
489
  node.then_fn.body, {**env, **sub_env}
@@ -0,0 +1,41 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from mplang.kernels.value import (
16
+ BytesBlob,
17
+ TableValue,
18
+ TensorValue,
19
+ Value,
20
+ ValueDecodeError,
21
+ ValueError,
22
+ decode_value,
23
+ encode_value,
24
+ is_value_envelope,
25
+ list_value_kinds,
26
+ register_value,
27
+ )
28
+
29
+ __all__ = [
30
+ "BytesBlob",
31
+ "TableValue",
32
+ "TensorValue",
33
+ "Value",
34
+ "ValueDecodeError",
35
+ "ValueError",
36
+ "decode_value",
37
+ "encode_value",
38
+ "is_value_envelope",
39
+ "list_value_kinds",
40
+ "register_value",
41
+ ]
@@ -14,38 +14,25 @@
14
14
 
15
15
  from __future__ import annotations
16
16
 
17
- from typing import Any
18
-
19
17
  import numpy as np
20
- import pandas as pd
21
18
 
22
19
  from mplang.core.pfunc import PFunction
23
20
  from mplang.core.table import TableType
24
21
  from mplang.core.tensor import TensorType
25
22
  from mplang.kernels.base import cur_kctx, kernel_def
23
+ from mplang.kernels.value import TableValue, TensorValue, Value
26
24
  from mplang.runtime.data_providers import get_provider, resolve_uri
27
25
  from mplang.utils import table_utils
28
26
 
29
27
 
30
- def _to_numpy(obj: Any) -> np.ndarray: # minimal helper to avoid duplicating logic
31
- if isinstance(obj, np.ndarray):
32
- return obj
33
- if hasattr(obj, "numpy"):
34
- try:
35
- return np.asarray(obj.numpy()) # type: ignore
36
- except Exception:
37
- pass
38
- return np.asarray(obj)
39
-
40
-
41
28
  @kernel_def("builtin.identity")
42
- def _identity(pfunc: PFunction, value: Any) -> Any:
29
+ def _identity(pfunc: PFunction, value: Value) -> Value:
43
30
  # Runtime guarantees exactly one argument; no extra arity checks here.
44
31
  return value
45
32
 
46
33
 
47
34
  @kernel_def("builtin.read")
48
- def _read(pfunc: PFunction) -> Any:
35
+ def _read(pfunc: PFunction) -> Value:
49
36
  path = pfunc.attrs.get("path")
50
37
  if path is None:
51
38
  raise ValueError("missing path attr for builtin.read")
@@ -56,13 +43,25 @@ def _read(pfunc: PFunction) -> Any:
56
43
  raise NotImplementedError(f"no resource provider for scheme: {uri.scheme}")
57
44
  ctx = cur_kctx()
58
45
  try:
59
- return prov.read(uri, out_t, ctx=ctx)
46
+ data = prov.read(uri, out_t, ctx=ctx)
60
47
  except Exception as e: # pragma: no cover - provider errors
61
48
  raise RuntimeError(f"builtin.read failed: {e}") from e
62
49
 
50
+ if isinstance(out_t, TableType):
51
+ if isinstance(data, TableValue):
52
+ return data
53
+ return TableValue(data)
54
+ if isinstance(out_t, TensorType):
55
+ if isinstance(data, TensorValue):
56
+ return data
57
+ return TensorValue(np.asarray(data))
58
+ raise TypeError(
59
+ f"builtin.read only supports TableType/TensorType outputs, got {type(out_t).__name__}"
60
+ )
61
+
63
62
 
64
63
  @kernel_def("builtin.write")
65
- def _write(pfunc: PFunction, obj: Any) -> Any:
64
+ def _write(pfunc: PFunction, obj: Value) -> Value:
66
65
  path = pfunc.attrs.get("path")
67
66
  if path is None:
68
67
  raise ValueError("missing path attr for builtin.write")
@@ -70,16 +69,18 @@ def _write(pfunc: PFunction, obj: Any) -> Any:
70
69
  prov = get_provider(uri.scheme)
71
70
  if prov is None:
72
71
  raise NotImplementedError(f"no resource provider for scheme: {uri.scheme}")
72
+ # Pass Value object directly to provider - let provider decide how to handle it
73
73
  ctx = cur_kctx()
74
74
  try:
75
75
  prov.write(uri, obj, ctx=ctx)
76
- return obj
77
76
  except Exception as e: # pragma: no cover
78
77
  raise RuntimeError(f"builtin.write failed: {e}") from e
78
+ return obj
79
79
 
80
80
 
81
81
  @kernel_def("builtin.constant")
82
- def _constant(pfunc: PFunction) -> Any:
82
+ def _constant(pfunc: PFunction) -> Value:
83
+ """Return constants as Value types (TensorValue or TableValue)."""
83
84
  data_bytes = pfunc.attrs.get("data_bytes")
84
85
  if data_bytes is None:
85
86
  raise ValueError("missing data_bytes attr for builtin.constant")
@@ -89,69 +90,86 @@ def _constant(pfunc: PFunction) -> Any:
89
90
  if fmt != "bytes[csv]":
90
91
  raise ValueError(f"unsupported table constant format {fmt}")
91
92
  df = table_utils.csv_to_dataframe(data_bytes)
92
- return df
93
+ return TableValue(df)
93
94
  # tensor path
94
95
  shape = out_t.shape # type: ignore[attr-defined,union-attr]
95
96
  dtype = out_t.dtype.numpy_dtype() # type: ignore[attr-defined,union-attr]
96
97
  arr = np.frombuffer(data_bytes, dtype=dtype).reshape(shape)
97
- return arr
98
+ return TensorValue(arr)
98
99
 
99
100
 
100
101
  @kernel_def("builtin.rank")
101
- def _rank(pfunc: PFunction) -> Any:
102
+ def _rank(pfunc: PFunction) -> TensorValue:
103
+ """Return rank as TensorValue."""
102
104
  ctx = cur_kctx()
103
- return np.array(ctx.rank, dtype=np.uint64)
105
+ arr = np.array(ctx.rank, dtype=np.uint64)
106
+ return TensorValue(arr)
104
107
 
105
108
 
106
109
  @kernel_def("builtin.prand")
107
- def _prand(pfunc: PFunction) -> Any:
110
+ def _prand(pfunc: PFunction) -> TensorValue:
111
+ """Return random data as TensorValue."""
108
112
  shape = pfunc.attrs.get("shape", ())
109
113
  rng = np.random.default_rng()
110
114
  info = np.iinfo(np.uint64)
111
115
  data = rng.integers(
112
116
  low=info.min, high=info.max, size=shape, dtype=np.uint64, endpoint=True
113
117
  )
114
- return data
118
+ return TensorValue(data)
115
119
 
116
120
 
117
121
  @kernel_def("builtin.table_to_tensor")
118
- def _table_to_tensor(pfunc: PFunction, table: Any) -> Any:
119
- if not isinstance(table, pd.DataFrame):
120
- raise TypeError("expected pandas DataFrame")
121
- if table.shape[1] == 0:
122
+ def _table_to_tensor(pfunc: PFunction, table: TableValue) -> TensorValue:
123
+ """Convert table to tensor, return as TensorValue."""
124
+ arrow_table = table.to_arrow()
125
+ if arrow_table.num_columns == 0:
122
126
  raise ValueError("cannot pack empty table")
123
- mat = np.column_stack([table[col].to_numpy() for col in table.columns])
124
- return mat
127
+ # Convert Arrow columns to numpy arrays and stack
128
+ mat = np.column_stack([
129
+ arrow_table.column(i).to_numpy() for i in range(arrow_table.num_columns)
130
+ ])
131
+ return TensorValue(mat)
125
132
 
126
133
 
127
134
  @kernel_def("builtin.tensor_to_table")
128
- def _tensor_to_table(pfunc: PFunction, tensor: Any) -> Any:
129
- arr = _to_numpy(tensor)
135
+ def _tensor_to_table(pfunc: PFunction, tensor: TensorValue) -> TableValue:
136
+ """Convert tensor to table, return as TableValue."""
137
+ import pyarrow as pa # type: ignore
138
+
139
+ arr = tensor.to_numpy()
130
140
  if arr.ndim != 2:
131
141
  raise ValueError("tensor_to_table expects rank-2 array")
132
142
  col_names = pfunc.attrs.get("column_names")
133
143
  if col_names is None:
134
144
  raise ValueError("missing column_names attr")
135
- df = pd.DataFrame(arr, columns=list(col_names))
136
- return df
145
+ # Create Arrow table directly from numpy array columns
146
+ arrays = [pa.array(arr[:, i]) for i in range(arr.shape[1])]
147
+ arrow_table = pa.table(dict(zip(col_names, arrays, strict=True)))
148
+ return TableValue(arrow_table)
137
149
 
138
150
 
139
- def _summ(v: Any) -> str:
151
+ def _summ(v: Value) -> str:
140
152
  try:
141
- if isinstance(v, pd.DataFrame):
142
- return str(v.head(8).to_string(index=False))
143
- arr = _to_numpy(v)
144
- return str(
145
- np.array2string(
146
- arr, threshold=64, edgeitems=3, precision=6, suppress_small=True
153
+ if isinstance(v, TableValue):
154
+ # Use Arrow's native string representation (more efficient)
155
+ arrow_table = v.to_arrow()
156
+ # Show first 8 rows
157
+ preview = arrow_table.slice(0, min(8, arrow_table.num_rows))
158
+ return str(preview)
159
+ if isinstance(v, TensorValue):
160
+ arr = v.to_numpy()
161
+ return str(
162
+ np.array2string(
163
+ arr, threshold=64, edgeitems=3, precision=6, suppress_small=True
164
+ )
147
165
  )
148
- )
166
+ return repr(v)
149
167
  except Exception as e: # pragma: no cover
150
168
  return f"<unprintable {type(v).__name__}: {e}>"
151
169
 
152
170
 
153
171
  @kernel_def("builtin.debug_print")
154
- def _debug_print(pfunc: PFunction, val: Any) -> Any:
172
+ def _debug_print(pfunc: PFunction, val: Value) -> Value:
155
173
  prefix = pfunc.attrs.get("prefix", "")
156
174
  ctx = cur_kctx()
157
175
  print(f"[debug_print][rank={ctx.rank}] {prefix}{_summ(val)}")
@@ -159,7 +177,7 @@ def _debug_print(pfunc: PFunction, val: Any) -> Any:
159
177
 
160
178
 
161
179
  @kernel_def("builtin.pack")
162
- def _pack(pfunc: PFunction, value: Any) -> Any:
180
+ def _pack(pfunc: PFunction, value: Value) -> TensorValue:
163
181
  outs_info = pfunc.outs_info
164
182
  if len(outs_info) != 1:
165
183
  raise ValueError("builtin.pack expects single output type")
@@ -169,22 +187,33 @@ def _pack(pfunc: PFunction, value: Any) -> Any:
169
187
  if out_ty.dtype.numpy_dtype() != np.uint8:
170
188
  raise TypeError("builtin.pack output dtype must be uint8")
171
189
 
172
- if isinstance(value, pd.DataFrame):
173
- csv_bytes = table_utils.dataframe_to_csv(value)
174
- return np.frombuffer(csv_bytes, dtype=np.uint8)
190
+ if isinstance(value, TableValue):
191
+ # Serialize Arrow table using IPC stream for consistency with Value serde
192
+ import pyarrow as pa # type: ignore
193
+ import pyarrow.ipc as pa_ipc # type: ignore
194
+
195
+ arrow_table = value.to_arrow()
196
+ sink = pa.BufferOutputStream()
197
+ with pa_ipc.new_stream(sink, arrow_table.schema) as writer: # type: ignore[arg-type]
198
+ writer.write_table(arrow_table) # type: ignore[arg-type]
199
+ ipc_bytes = sink.getvalue().to_pybytes()
200
+ return TensorValue(np.frombuffer(ipc_bytes, dtype=np.uint8))
201
+
202
+ if isinstance(value, TensorValue):
203
+ arr = value.to_numpy()
204
+ return TensorValue(np.frombuffer(arr.tobytes(order="C"), dtype=np.uint8))
175
205
 
176
- arr = _to_numpy(value)
177
- return np.frombuffer(arr.tobytes(order="C"), dtype=np.uint8)
206
+ raise TypeError(f"builtin.pack does not support Value type {type(value).__name__}")
178
207
 
179
208
 
180
209
  @kernel_def("builtin.unpack")
181
- def _unpack(pfunc: PFunction, packed: Any) -> Any:
210
+ def _unpack(pfunc: PFunction, packed: TensorValue) -> Value:
182
211
  outs_info = pfunc.outs_info
183
212
  if len(outs_info) != 1:
184
213
  raise ValueError("builtin.unpack expects single output type")
185
214
  out_ty = outs_info[0]
186
215
 
187
- b = np.asarray(packed, dtype=np.uint8).reshape(-1)
216
+ b = packed.to_numpy().astype(np.uint8, copy=False).reshape(-1)
188
217
 
189
218
  if isinstance(out_ty, TensorType):
190
219
  np_dtype = out_ty.dtype.numpy_dtype()
@@ -198,10 +227,16 @@ def _unpack(pfunc: PFunction, packed: Any) -> Any:
198
227
  f"unpack size mismatch: got {b.size} bytes, expect {expected} for {np_dtype} {shape}"
199
228
  )
200
229
  arr = np.frombuffer(b.tobytes(), dtype=np_dtype)
201
- return arr.reshape(shape)
230
+ return TensorValue(arr.reshape(shape))
202
231
 
203
232
  if isinstance(out_ty, TableType):
204
- csv_bytes = b.tobytes()
205
- return table_utils.csv_to_dataframe(csv_bytes)
233
+ # Deserialize Arrow IPC stream back to TableValue
234
+ import pyarrow as pa # type: ignore
235
+ import pyarrow.ipc as pa_ipc # type: ignore
236
+
237
+ buf = pa.py_buffer(b.tobytes())
238
+ reader = pa_ipc.open_stream(buf)
239
+ table = reader.read_all()
240
+ return TableValue(table)
206
241
 
207
242
  raise TypeError("builtin.unpack output type must be TensorType or TableType")
@@ -15,15 +15,15 @@
15
15
  from __future__ import annotations
16
16
 
17
17
  import os
18
- from typing import Any
19
18
 
20
19
  import numpy as np
21
20
 
22
21
  from mplang.core.pfunc import PFunction
23
22
  from mplang.kernels.base import cur_kctx, kernel_def
23
+ from mplang.kernels.value import TensorValue
24
24
  from mplang.utils.crypto import blake2b
25
25
 
26
- __all__: list[str] = [] # flat kernels only
26
+ __all__: list[str] = [] # No public exports currently
27
27
 
28
28
 
29
29
  def _get_rng() -> np.random.Generator:
@@ -54,62 +54,71 @@ def _keystream(key: bytes, nonce: bytes, length: int) -> bytes:
54
54
 
55
55
 
56
56
  @kernel_def("crypto.keygen")
57
- def _crypto_keygen(pfunc: PFunction) -> Any:
57
+ def _crypto_keygen(pfunc: PFunction) -> TensorValue:
58
58
  length = int(pfunc.attrs.get("length", 32))
59
59
  rng = _get_rng()
60
60
  key = rng.integers(0, 256, size=(length,), dtype=np.uint8)
61
- return key
61
+ return TensorValue(key)
62
62
 
63
63
 
64
64
  @kernel_def("crypto.enc")
65
- def _crypto_encrypt(pfunc: PFunction, pt_bytes: Any, key: Any) -> Any:
66
- pt_bytes = np.asarray(pt_bytes, dtype=np.uint8)
67
- key = np.asarray(key, dtype=np.uint8)
65
+ def _crypto_encrypt(
66
+ pfunc: PFunction, pt_bytes: TensorValue, key: TensorValue
67
+ ) -> TensorValue:
68
+ pt_bytes_np = pt_bytes.to_numpy().astype(np.uint8, copy=False)
69
+ key_np = key.to_numpy().astype(np.uint8, copy=False)
68
70
  rng = _get_rng()
69
71
  nonce = rng.integers(0, 256, size=(12,), dtype=np.uint8)
70
72
  stream = np.frombuffer(
71
- _keystream(key.tobytes(), nonce.tobytes(), pt_bytes.size), dtype=np.uint8
73
+ _keystream(key_np.tobytes(), nonce.tobytes(), pt_bytes_np.size), dtype=np.uint8
72
74
  )
73
- ct = (pt_bytes ^ stream).astype(np.uint8)
75
+ ct = (pt_bytes_np ^ stream).astype(np.uint8)
74
76
  out = np.concatenate([nonce, ct]).astype(np.uint8)
75
- return out
77
+ return TensorValue(out)
76
78
 
77
79
 
78
80
  @kernel_def("crypto.dec")
79
- def _crypto_decrypt(pfunc: PFunction, ct_with_nonce: Any, key: Any) -> Any:
80
- ct_with_nonce = np.asarray(ct_with_nonce, dtype=np.uint8)
81
- key = np.asarray(key, dtype=np.uint8)
82
- nonce = ct_with_nonce[:12]
83
- ct = ct_with_nonce[12:]
81
+ def _crypto_decrypt(
82
+ pfunc: PFunction, ct_with_nonce: TensorValue, key: TensorValue
83
+ ) -> TensorValue:
84
+ ct_np = ct_with_nonce.to_numpy().astype(np.uint8, copy=False)
85
+ key_np = key.to_numpy().astype(np.uint8, copy=False)
86
+ nonce = ct_np[:12]
87
+ ct = ct_np[12:]
84
88
  stream = np.frombuffer(
85
- _keystream(key.tobytes(), nonce.tobytes(), len(ct)), dtype=np.uint8
89
+ _keystream(key_np.tobytes(), nonce.tobytes(), len(ct)), dtype=np.uint8
86
90
  )
87
91
  pt_bytes = (ct ^ stream).astype(np.uint8)
88
- return pt_bytes
92
+ return TensorValue(pt_bytes)
89
93
 
90
94
 
91
95
  @kernel_def("crypto.kem_keygen")
92
- def _crypto_kem_keygen(pfunc: PFunction) -> Any:
96
+ def _crypto_kem_keygen(pfunc: PFunction) -> tuple[TensorValue, TensorValue]:
93
97
  rng = _get_rng()
94
98
  sk = rng.integers(0, 256, size=(32,), dtype=np.uint8)
95
- pk = np.frombuffer(blake2b(sk.tobytes())[:32], dtype=np.uint8)
96
- return (sk, pk)
99
+ pk_bytes = blake2b(sk.tobytes())[:32]
100
+ pk = np.frombuffer(pk_bytes, dtype=np.uint8)
101
+ return (TensorValue(sk), TensorValue(pk))
97
102
 
98
103
 
99
104
  @kernel_def("crypto.kem_derive")
100
- def _crypto_kem_derive(pfunc: PFunction, sk: Any, peer_pk: Any) -> Any:
101
- sk = np.asarray(sk, dtype=np.uint8)
102
- peer_pk = np.asarray(peer_pk, dtype=np.uint8)
103
- self_pk = np.frombuffer(blake2b(sk.tobytes())[:32], dtype=np.uint8)
104
- xored = (self_pk ^ peer_pk).astype(np.uint8)
105
+ def _crypto_kem_derive(
106
+ pfunc: PFunction, sk: TensorValue, peer_pk: TensorValue
107
+ ) -> TensorValue:
108
+ sk_np = sk.to_numpy().astype(np.uint8, copy=False)
109
+ peer_pk_np = peer_pk.to_numpy().astype(np.uint8, copy=False)
110
+
111
+ self_pk_bytes = blake2b(sk_np.tobytes())[:32]
112
+ self_pk_arr = np.frombuffer(self_pk_bytes, dtype=np.uint8)
113
+ xored = (self_pk_arr ^ peer_pk_np).astype(np.uint8)
105
114
  secret = np.frombuffer(blake2b(xored.tobytes())[:32], dtype=np.uint8)
106
- return secret
115
+ return TensorValue(secret)
107
116
 
108
117
 
109
118
  @kernel_def("crypto.hkdf")
110
- def _crypto_hkdf(pfunc: PFunction, secret: Any) -> Any:
111
- secret = np.asarray(secret, dtype=np.uint8)
119
+ def _crypto_hkdf(pfunc: PFunction, secret: TensorValue) -> TensorValue:
120
+ secret_np = secret.to_numpy().astype(np.uint8, copy=False)
112
121
  info_str = str(pfunc.attrs.get("info", ""))
113
122
  info = info_str.encode("utf-8")
114
- out = np.frombuffer(blake2b(secret.tobytes() + info)[:32], dtype=np.uint8)
115
- return out
123
+ out = np.frombuffer(blake2b(secret_np.tobytes() + info)[:32], dtype=np.uint8)
124
+ return TensorValue(out)
@@ -22,6 +22,7 @@ from numpy.typing import NDArray
22
22
 
23
23
  from mplang.core.pfunc import PFunction
24
24
  from mplang.kernels.base import cur_kctx, kernel_def
25
+ from mplang.kernels.value import TensorValue
25
26
 
26
27
  __all__: list[str] = []
27
28
 
@@ -46,25 +47,26 @@ def _quote_from_pk(pk: np.ndarray) -> NDArray[np.uint8]:
46
47
 
47
48
 
48
49
  @kernel_def("mock_tee.quote_gen")
49
- def _tee_quote_gen(pfunc: PFunction, pk: object) -> NDArray[np.uint8]:
50
+ def _tee_quote_gen(pfunc: PFunction, pk: TensorValue) -> TensorValue:
50
51
  warnings.warn(
51
52
  "Insecure mock TEE kernel 'mock_tee.quote_gen' in use. NOT secure; for local testing only.",
52
53
  stacklevel=3,
53
54
  )
54
- pk = np.asarray(pk, dtype=np.uint8)
55
+ pk_arr = pk.to_numpy().astype(np.uint8, copy=False)
55
56
  # rng access ensures deterministic seeding per rank even if unused now
56
57
  _rng()
57
- return _quote_from_pk(pk)
58
+ quote = _quote_from_pk(pk_arr)
59
+ return TensorValue(np.array(quote, copy=True))
58
60
 
59
61
 
60
62
  @kernel_def("mock_tee.attest")
61
- def _tee_attest(pfunc: PFunction, quote: object) -> NDArray[np.uint8]:
63
+ def _tee_attest(pfunc: PFunction, quote: TensorValue) -> TensorValue:
62
64
  warnings.warn(
63
65
  "Insecure mock TEE kernel 'mock_tee.attest' in use. NOT secure; for local testing only.",
64
66
  stacklevel=3,
65
67
  )
66
- quote = np.asarray(quote, dtype=np.uint8)
67
-
68
- if quote.size != 33:
68
+ quote_arr = quote.to_numpy().astype(np.uint8, copy=False)
69
+ if quote_arr.size != 33:
69
70
  raise ValueError("mock quote must be 33 bytes (1 header + 32 pk)")
70
- return quote[1:33].astype(np.uint8)
71
+ attest = quote_arr[1:33].astype(np.uint8, copy=True)
72
+ return TensorValue(attest)