mplang-nightly 0.1.dev143__tar.gz → 0.1.dev145__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165)
  1. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/PKG-INFO +1 -1
  2. mplang_nightly-0.1.dev145/mplang/backend/base.py +175 -0
  3. mplang_nightly-0.1.dev145/mplang/backend/context.py +255 -0
  4. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/backend/spu.py +6 -4
  5. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/backend/sql_duckdb.py +1 -1
  6. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/core/expr/evaluator.py +6 -6
  7. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/frontend/base.py +1 -1
  8. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/frontend/ibis_cc.py +2 -1
  9. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/frontend/spu.py +4 -3
  10. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/runtime/resource.py +39 -62
  11. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/runtime/simulation.py +6 -13
  12. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/backend/test_builtin.py +4 -4
  13. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/backend/test_phe.py +5 -4
  14. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/backend/test_spu.py +13 -15
  15. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/backend/test_sql_duckdb.py +4 -5
  16. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/backend/test_stablehlo.py +3 -3
  17. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/frontend/test_spu.py +1 -1
  18. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/frontend/test_spu_defensive.py +1 -1
  19. mplang_nightly-0.1.dev145/tests/runtime/__init__.py +13 -0
  20. mplang_nightly-0.1.dev143/mplang/backend/__init__.py +0 -20
  21. mplang_nightly-0.1.dev143/mplang/backend/base.py +0 -287
  22. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/.gitignore +0 -0
  23. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/LICENSE +0 -0
  24. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/README.md +0 -0
  25. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/examples/conf/3pc.yaml +0 -0
  26. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/examples/stax_nn/README.md +0 -0
  27. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/examples/stax_nn/models.py +0 -0
  28. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/examples/stax_nn/stax_nn.py +0 -0
  29. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/examples/xgboost/hist_jax.py +0 -0
  30. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/examples/xgboost/hist_jax_test.py +0 -0
  31. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/examples/xgboost/naive_np.py +0 -0
  32. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/examples/xgboost/readme.md +0 -0
  33. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/examples/xgboost/sgb.py +0 -0
  34. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/examples/xgboost/sgb_test.py +0 -0
  35. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/hatch_build.py +0 -0
  36. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/__init__.py +0 -0
  37. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/analysis/__init__.py +0 -0
  38. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/analysis/diagram.py +0 -0
  39. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/api.py +0 -0
  40. {mplang_nightly-0.1.dev143/mplang/utils → mplang_nightly-0.1.dev145/mplang/backend}/__init__.py +0 -0
  41. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/backend/builtin.py +0 -0
  42. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/backend/crypto.py +0 -0
  43. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/backend/phe.py +0 -0
  44. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/backend/stablehlo.py +0 -0
  45. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/backend/tee.py +0 -0
  46. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/core/__init__.py +0 -0
  47. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/core/cluster.py +0 -0
  48. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/core/comm.py +0 -0
  49. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/core/context_mgr.py +0 -0
  50. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/core/dtype.py +0 -0
  51. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/core/expr/__init__.py +0 -0
  52. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/core/expr/ast.py +0 -0
  53. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/core/expr/printer.py +0 -0
  54. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/core/expr/transformer.py +0 -0
  55. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/core/expr/utils.py +0 -0
  56. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/core/expr/visitor.py +0 -0
  57. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/core/expr/walk.py +0 -0
  58. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/core/interp.py +0 -0
  59. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/core/mask.py +0 -0
  60. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/core/mpir.py +0 -0
  61. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/core/mpobject.py +0 -0
  62. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/core/mptype.py +0 -0
  63. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/core/pfunc.py +0 -0
  64. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/core/primitive.py +0 -0
  65. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/core/table.py +0 -0
  66. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/core/tensor.py +0 -0
  67. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/core/tracer.py +0 -0
  68. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/device.py +0 -0
  69. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/frontend/__init__.py +0 -0
  70. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/frontend/builtin.py +0 -0
  71. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/frontend/crypto.py +0 -0
  72. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/frontend/jax_cc.py +0 -0
  73. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/frontend/phe.py +0 -0
  74. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/frontend/sql.py +0 -0
  75. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/frontend/tee.py +0 -0
  76. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/protos/v1alpha1/mpir_pb2.py +0 -0
  77. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/protos/v1alpha1/mpir_pb2.pyi +0 -0
  78. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -0
  79. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/runtime/__init__.py +0 -0
  80. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/runtime/cli.py +0 -0
  81. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/runtime/client.py +0 -0
  82. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/runtime/communicator.py +0 -0
  83. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/runtime/data_providers.py +0 -0
  84. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/runtime/driver.py +0 -0
  85. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/runtime/exceptions.py +0 -0
  86. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/runtime/http_api.md +0 -0
  87. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/runtime/link_comm.py +0 -0
  88. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/runtime/server.py +0 -0
  89. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/simp/__init__.py +0 -0
  90. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/simp/mpi.py +0 -0
  91. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/simp/random.py +0 -0
  92. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/simp/smpc.py +0 -0
  93. {mplang_nightly-0.1.dev143/tests/device → mplang_nightly-0.1.dev145/mplang/utils}/__init__.py +0 -0
  94. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/utils/crypto.py +0 -0
  95. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/utils/func_utils.py +0 -0
  96. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/utils/spu_utils.py +0 -0
  97. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/mplang/utils/table_utils.py +0 -0
  98. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/pyproject.toml +0 -0
  99. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/__init__.py +0 -0
  100. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/analysis/test_diagram.py +0 -0
  101. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/backend/test_debug_print.py +0 -0
  102. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/core/__init__.py +0 -0
  103. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/core/expr/__init__.py +0 -0
  104. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/core/expr/conftest.py +0 -0
  105. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/core/expr/test_ast.py +0 -0
  106. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/core/expr/test_printer.py +0 -0
  107. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/core/expr/test_utils.py +0 -0
  108. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/core/expr/test_walk.py +0 -0
  109. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/core/test_cluster.py +0 -0
  110. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/core/test_dtype.py +0 -0
  111. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/core/test_mask.py +0 -0
  112. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/core/test_mpir.py +0 -0
  113. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/core/test_mptype.py +0 -0
  114. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/core/test_primitive.py +0 -0
  115. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/core/test_table.py +0 -0
  116. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/core/test_tensor.py +0 -0
  117. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/core/test_tracer.py +0 -0
  118. {mplang_nightly-0.1.dev143/tests/frontend → mplang_nightly-0.1.dev145/tests/device}/__init__.py +0 -0
  119. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/device/test_device_basic.py +0 -0
  120. {mplang_nightly-0.1.dev143/tests/runtime → mplang_nightly-0.1.dev145/tests/frontend}/__init__.py +0 -0
  121. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/frontend/dummy.py +0 -0
  122. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/frontend/test_builtin_pack.py +0 -0
  123. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/frontend/test_crypto_tee.py +0 -0
  124. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/frontend/test_feop_base.py +0 -0
  125. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/frontend/test_ibis.py +0 -0
  126. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/frontend/test_ibis_cc.py +0 -0
  127. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/frontend/test_jax_cc.py +0 -0
  128. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/frontend/test_phe.py +0 -0
  129. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/frontend/test_sql.py +0 -0
  130. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/frontend/test_table_tensor_conversion.py +0 -0
  131. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/integration/README.md +0 -0
  132. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/integration/test_crypto_roundtrip.py +0 -0
  133. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/integration/test_http_e2e.py +0 -0
  134. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/integration/test_symbols_roundtrip.py +0 -0
  135. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/integration/test_tutorials.py +0 -0
  136. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/runtime/test_cli.py +0 -0
  137. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/runtime/test_communicator.py +0 -0
  138. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/runtime/test_driver.py +0 -0
  139. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/runtime/test_server.py +0 -0
  140. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/runtime/test_simulation.py +0 -0
  141. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/simp/test_mpi.py +0 -0
  142. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/simp/test_random.py +0 -0
  143. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/simp/test_simp.py +0 -0
  144. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/simp/test_smpc.py +0 -0
  145. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/simp/test_sugar.py +0 -0
  146. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/utils/__init__.py +0 -0
  147. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/utils/server_fixtures.py +0 -0
  148. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/utils/test_func_utils.py +0 -0
  149. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/utils/test_spu_utils.py +0 -0
  150. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tests/utils/test_table_utils.py +0 -0
  151. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tutorials/0_basic.py +0 -0
  152. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tutorials/10_analysis.py +0 -0
  153. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tutorials/1_condition.py +0 -0
  154. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tutorials/2_whileloop.py +0 -0
  155. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tutorials/3_device.py +0 -0
  156. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tutorials/4_simulation.py +0 -0
  157. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tutorials/5_ir_dump.py +0 -0
  158. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tutorials/6_advanced.py +0 -0
  159. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tutorials/7_stdio.py +0 -0
  160. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tutorials/8_phe.py +0 -0
  161. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tutorials/9_tee.py +0 -0
  162. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tutorials/__init__.py +0 -0
  163. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tutorials/pitfalls/late_binding.py +0 -0
  164. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tutorials/pitfalls/rand.py +0 -0
  165. {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev145}/tutorials/run.sh +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mplang-nightly
3
- Version: 0.1.dev143
3
+ Version: 0.1.dev145
4
4
  Summary: Multi-Party Programming Language
5
5
  Author-email: SecretFlow Team <secretflow-contact@service.alipay.com>
6
6
  License: Apache License
@@ -0,0 +1,175 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Backend kernel registry & per-participant runtime (explicit op->kernel binding).
16
+
17
+ This version decouples *kernel implementation registration* from *operation binding*.
18
+
19
+ Concepts:
20
+ * kernel_id: unique identifier of a concrete backend implementation.
21
+ * op_type: semantic operation name carried by ``PFunction.fn_type``.
22
+ * bind_op(op_type, kernel_id): performed by higher layer (see ``backend.context``)
23
+ to select which implementation handles an op. Runtime dispatch is now a 2-step:
24
+ pfunc.fn_type -> active kernel_id -> KernelSpec.fn
25
+
26
+ The previous implicit "import == register+bind" coupling is removed. Kernel
27
+ modules only call ``@kernel_def(kernel_id)``. Default bindings are established
28
+ centrally when a ``RuntimeContext`` (``backend.context``) is created without overrides.
29
+ """
30
+
31
+ from __future__ import annotations
32
+
33
+ import contextvars
34
+ from collections.abc import Callable
35
+ from dataclasses import dataclass
36
+ from typing import Any
37
+
38
+ __all__ = [
39
+ "KernelContext",
40
+ "KernelSpec",
41
+ "bind_op",
42
+ "cur_kctx",
43
+ "get_kernel_for_op",
44
+ "list_kernels",
45
+ "list_ops",
46
+ "unbind_op",
47
+ ]
48
+
49
+
50
+ @dataclass
51
+ class KernelContext:
52
+ """Ephemeral call context set via contextvar while a kernel runs."""
53
+
54
+ rank: int
55
+ world_size: int
56
+ state: dict[str, dict[str, Any]] # backend namespace -> pocket
57
+ cache: dict[str, Any] # runtime-level shared cache (per BackendRuntime)
58
+
59
+
60
+ _CTX_VAR: contextvars.ContextVar[KernelContext | None] = contextvars.ContextVar(
61
+ "_flat_backend_ctx", default=None
62
+ )
63
+
64
+
65
+ def cur_kctx() -> KernelContext:
66
+ """Return the current kernel execution context (only valid inside a kernel).
67
+
68
+ Two storages:
69
+ - state: namespaced pockets (dict[str, dict]) for backend-local mutable helpers
70
+ - cache: global (per runtime) shared dict; prefer state unless truly cross-backend
71
+
72
+ Examples:
73
+ 1) Compile cache::
74
+ @kernel_def("mlir.stablehlo")
75
+ def _exec(pfunc, args):
76
+ ctx = cur_kctx()
77
+ pocket = ctx.state.setdefault("stablehlo", {})
78
+ cache = pocket.setdefault("compile_cache", {})
79
+ text = pfunc.fn_text
80
+ mod = cache.get(text)
81
+ if mod is None:
82
+ mod = compile_mlir(text)
83
+ cache[text] = mod
84
+ return run(mod, args)
85
+
86
+ 2) Deterministic RNG::
87
+ @kernel_def("crypto.keygen")
88
+ def _keygen(pfunc, args):
89
+ ctx = cur_kctx()
90
+ pocket = ctx.state.setdefault("crypto", {})
91
+ rng = pocket.get("rng")
92
+ if rng is None:
93
+ rng = np.random.default_rng(1234 + ctx.rank * 7919)
94
+ pocket["rng"] = rng
95
+ return (rng.integers(0, 256, size=(32,), dtype=np.uint8),)
96
+ """
97
+ ctx = _CTX_VAR.get()
98
+ if ctx is None:
99
+ raise RuntimeError("cur_kctx() called outside backend kernel execution")
100
+ return ctx
101
+
102
+
103
+ # ---------------- Registry ----------------
104
+
105
+ # Kernel callable signature: (pfunc, *args) -> Any | sequence (no **kwargs)
106
+ KernelFn = Callable[..., Any]
107
+
108
+
109
+ @dataclass
110
+ class KernelSpec:
111
+ kernel_id: str
112
+ fn: KernelFn
113
+ meta: dict[str, Any]
114
+
115
+
116
+ # All registered kernel implementations: kernel_id -> spec
117
+ _KERNELS: dict[str, KernelSpec] = {}
118
+
119
+ # Active op bindings: op_type -> kernel_id
120
+ _BINDINGS: dict[str, str] = {}
121
+
122
+
123
+ def kernel_def(kernel_id: str, /, **meta: Any) -> Callable[[KernelFn], KernelFn]:
124
+ """Decorator to register a concrete kernel implementation.
125
+
126
+ This ONLY registers the implementation (kernel_id -> fn). It does NOT bind
127
+ any op. Higher layer must call ``bind_op(op_type, kernel_id)`` explicitly.
128
+ """
129
+
130
+ def _decorator(fn: KernelFn) -> KernelFn:
131
+ if kernel_id in _KERNELS:
132
+ raise ValueError(f"duplicate kernel_id={kernel_id}")
133
+ _KERNELS[kernel_id] = KernelSpec(kernel_id=kernel_id, fn=fn, meta=dict(meta))
134
+ return fn
135
+
136
+ return _decorator
137
+
138
+
139
+ def bind_op(op_type: str, kernel_id: str, *, force: bool = True) -> None:
140
+ """Bind an op_type to a registered kernel implementation.
141
+
142
+ Args:
143
+ op_type: Semantic operation name.
144
+ kernel_id: Previously registered kernel identifier.
145
+ force: If False and op_type already bound, keep existing binding.
146
+ If True (default), overwrite.
147
+ """
148
+ if kernel_id not in _KERNELS:
149
+ raise KeyError(f"kernel_id {kernel_id} not registered")
150
+ if not force and op_type in _BINDINGS:
151
+ return
152
+ _BINDINGS[op_type] = kernel_id
153
+
154
+
155
+ def unbind_op(op_type: str) -> None:
156
+ _BINDINGS.pop(op_type, None)
157
+
158
+
159
+ def get_kernel_for_op(op_type: str) -> KernelSpec:
160
+ kid = _BINDINGS.get(op_type)
161
+ if kid is None:
162
+ # Tests expect NotImplementedError for unsupported operations
163
+ raise NotImplementedError(f"no backend kernel registered for op {op_type}")
164
+ spec = _KERNELS.get(kid)
165
+ if spec is None: # inconsistent state
166
+ raise RuntimeError(f"active kernel_id {kid} missing spec")
167
+ return spec
168
+
169
+
170
+ def list_kernels() -> list[str]:
171
+ return sorted(_KERNELS.keys())
172
+
173
+
174
+ def list_ops() -> list[str]:
175
+ return sorted(_BINDINGS.keys())
@@ -0,0 +1,255 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ from collections.abc import Mapping
18
+ from dataclasses import dataclass, field
19
+ from typing import Any
20
+
21
+ from mplang.backend import base
22
+ from mplang.backend.base import KernelContext, bind_op, get_kernel_for_op
23
+ from mplang.core.dtype import UINT8, DType
24
+ from mplang.core.pfunc import PFunction
25
+ from mplang.core.table import TableLike, TableType
26
+ from mplang.core.tensor import TensorLike, TensorType
27
+
28
+ # Default bindings
29
+ # Kernel implementation modules are imported lazily by _ensure_impl_imported(),
30
+ # so their @kernel_def entries register no later than the first RuntimeContext
31
+ # construction. Underscore aliases plus ``noqa: F401`` mark them as side-effect imports.
32
+ _IMPL_IMPORTED = False
33
+
34
+
35
+ def _ensure_impl_imported() -> None:
36
+ global _IMPL_IMPORTED
37
+ if _IMPL_IMPORTED:
38
+ return
39
+ from mplang.backend import builtin as _impl_builtin # noqa: F401
40
+ from mplang.backend import crypto as _impl_crypto # noqa: F401
41
+ from mplang.backend import phe as _impl_phe # noqa: F401
42
+ from mplang.backend import spu as _impl_spu # noqa: F401
43
+ from mplang.backend import sql_duckdb as _impl_sql_duckdb # noqa: F401
44
+ from mplang.backend import stablehlo as _impl_stablehlo # noqa: F401
45
+ from mplang.backend import tee as _impl_tee # noqa: F401
46
+
47
+ _IMPL_IMPORTED = True
48
+
49
+
50
+ # imports consolidated above
51
+
52
+ _DEFAULT_BINDINGS: dict[str, str] = {
53
+ # builtin
54
+ "builtin.identity": "builtin.identity",
55
+ "builtin.read": "builtin.read",
56
+ "builtin.write": "builtin.write",
57
+ "builtin.constant": "builtin.constant",
58
+ "builtin.rank": "builtin.rank",
59
+ "builtin.prand": "builtin.prand",
60
+ "builtin.table_to_tensor": "builtin.table_to_tensor",
61
+ "builtin.tensor_to_table": "builtin.tensor_to_table",
62
+ "builtin.debug_print": "builtin.debug_print",
63
+ "builtin.pack": "builtin.pack",
64
+ "builtin.unpack": "builtin.unpack",
65
+ # crypto
66
+ "crypto.keygen": "crypto.keygen",
67
+ "crypto.enc": "crypto.enc",
68
+ "crypto.dec": "crypto.dec",
69
+ "crypto.kem_keygen": "crypto.kem_keygen",
70
+ "crypto.kem_derive": "crypto.kem_derive",
71
+ "crypto.hkdf": "crypto.hkdf",
72
+ # phe
73
+ "phe.keygen": "phe.keygen",
74
+ "phe.encrypt": "phe.encrypt",
75
+ "phe.mul": "phe.mul",
76
+ "phe.add": "phe.add",
77
+ "phe.decrypt": "phe.decrypt",
78
+ "phe.dot": "phe.dot",
79
+ "phe.gather": "phe.gather",
80
+ "phe.scatter": "phe.scatter",
81
+ "phe.concat": "phe.concat",
82
+ "phe.reshape": "phe.reshape",
83
+ "phe.transpose": "phe.transpose",
84
+ # spu
85
+ "spu.seed_env": "spu.seed_env",
86
+ "spu.makeshares": "spu.makeshares",
87
+ "spu.reconstruct": "spu.reconstruct",
88
+ "spu.run_pphlo": "spu.run_pphlo",
89
+ # stablehlo
90
+ "mlir.stablehlo": "mlir.stablehlo",
91
+ # sql
92
+ # generic SQL op; backend-specific kernel id for duckdb
93
+ "sql.run": "duckdb.run_sql",
94
+ # tee
95
+ "tee.quote": "tee.quote",
96
+ "tee.attest": "tee.attest",
97
+ }
98
+
99
+
100
+ # --- RuntimeContext ---
101
+
102
+
103
+ @dataclass
104
+ class RuntimeContext:
105
+ rank: int
106
+ world_size: int
107
+ bindings: Mapping[str, str] | None = None # optional overrides
108
+ state: dict[str, dict[str, Any]] = field(default_factory=dict)
109
+ cache: dict[str, Any] = field(default_factory=dict)
110
+ stats: dict[str, Any] = field(default_factory=dict)
111
+
112
+ def __post_init__(self) -> None:
113
+ _ensure_impl_imported()
114
+ if self.bindings is not None:
115
+ for op, kid in self.bindings.items():
116
+ bind_op(op, kid)
117
+ else:
118
+ for op, kid in _DEFAULT_BINDINGS.items():
119
+ bind_op(op, kid)
120
+ # Initialize stats pocket
121
+ self.stats.setdefault("op_calls", {})
122
+
123
+ def run_kernel(self, pfunc: PFunction, arg_list: list[Any]) -> list[Any]:
124
+ fn_type = pfunc.fn_type
125
+ spec = get_kernel_for_op(fn_type)
126
+ fn = spec.fn
127
+ if len(arg_list) != len(pfunc.ins_info):
128
+ raise ValueError(
129
+ f"kernel {fn_type} arg count mismatch: got {len(arg_list)}, expect {len(pfunc.ins_info)}"
130
+ )
131
+ for idx, (ins_spec, val) in enumerate(
132
+ zip(pfunc.ins_info, arg_list, strict=True)
133
+ ):
134
+ if isinstance(ins_spec, TableType):
135
+ _validate_table_arg(fn_type, idx, ins_spec, val)
136
+ continue
137
+ if isinstance(ins_spec, TensorType):
138
+ _validate_tensor_arg(fn_type, idx, ins_spec, val)
139
+ continue
140
+ # install kernel context
141
+ kctx = KernelContext(
142
+ rank=self.rank,
143
+ world_size=self.world_size,
144
+ state=self.state,
145
+ cache=self.cache,
146
+ )
147
+ token = base._CTX_VAR.set(kctx) # type: ignore[attr-defined]
148
+ try:
149
+ raw = fn(pfunc, *arg_list)
150
+ finally:
151
+ base._CTX_VAR.reset(token) # type: ignore[attr-defined]
152
+ # Stats (best effort)
153
+ try:
154
+ op_calls = self.stats.setdefault("op_calls", {})
155
+ op_calls[fn_type] = op_calls.get(fn_type, 0) + 1
156
+ except Exception: # pragma: no cover - never raise due to stats
157
+ pass
158
+ expected = len(pfunc.outs_info)
159
+ if expected == 0:
160
+ if raw in (None, (), []):
161
+ return []
162
+ raise ValueError(
163
+ f"kernel {fn_type} should return no values; got {type(raw).__name__}"
164
+ )
165
+ if expected == 1:
166
+ if isinstance(raw, (tuple, list)):
167
+ if len(raw) != 1:
168
+ raise ValueError(
169
+ f"kernel {fn_type} produced {len(raw)} outputs, expected 1"
170
+ )
171
+ return [raw[0]]
172
+ return [raw]
173
+ if not isinstance(raw, (tuple, list)):
174
+ raise TypeError(
175
+ f"kernel {fn_type} must return sequence (len={expected}), got {type(raw).__name__}"
176
+ )
177
+ if len(raw) != expected:
178
+ raise ValueError(
179
+ f"kernel {fn_type} produced {len(raw)} outputs, expected {expected}"
180
+ )
181
+ return list(raw)
182
+
183
+ def reset(self) -> None:
184
+ self.state.clear()
185
+ self.cache.clear()
186
+
187
+ # ---- explicit (re)binding API ----
188
+ def bind_op(self, op_type: str, kernel_id: str, *, force: bool = False) -> None:
189
+ """Bind an operation to a kernel at runtime.
190
+
191
+ force=False (default) preserves any existing binding to avoid accidental
192
+ silent overrides. Use ``rebind_op`` or ``force=True`` to intentionally
193
+ change a binding.
194
+ """
195
+ base.bind_op(op_type, kernel_id, force=force)
196
+
197
+ def rebind_op(self, op_type: str, kernel_id: str) -> None:
198
+ """Force rebind an operation to a different kernel (shorthand)."""
199
+ base.bind_op(op_type, kernel_id, force=True)
200
+
201
+
202
+ def _validate_table_arg(
203
+ fn_type: str, arg_index: int, spec: TableType, value: Any
204
+ ) -> None:
205
+ if not isinstance(value, TableLike):
206
+ raise TypeError(
207
+ f"kernel {fn_type} input[{arg_index}] expects TableLike, got {type(value).__name__}"
208
+ )
209
+ if len(value.columns) != len(spec.columns):
210
+ raise ValueError(
211
+ f"kernel {fn_type} input[{arg_index}] column count mismatch: got {len(value.columns)}, expected {len(spec.columns)}"
212
+ )
213
+
214
+
215
+ def _validate_tensor_arg(
216
+ fn_type: str, arg_index: int, spec: TensorType, value: Any
217
+ ) -> None:
218
+ # Backend-only handle sentinel (e.g., PHE keys) bypasses all structural checks
219
+ if tuple(spec.shape) == (-1, 0) and spec.dtype == UINT8:
220
+ return
221
+
222
+ if isinstance(value, (int, float, bool, complex)):
223
+ val_shape: tuple[Any, ...] = ()
224
+ duck_dtype: Any = type(value)
225
+ else:
226
+ if not isinstance(value, TensorLike):
227
+ raise TypeError(
228
+ f"kernel {fn_type} input[{arg_index}] expects TensorLike, got {type(value).__name__}"
229
+ )
230
+ val_shape = getattr(value, "shape", ())
231
+ duck_dtype = getattr(value, "dtype", None)
232
+
233
+ if len(spec.shape) != len(val_shape):
234
+ raise ValueError(
235
+ f"kernel {fn_type} input[{arg_index}] rank mismatch: got {val_shape}, expected {spec.shape}"
236
+ )
237
+
238
+ for dim_idx, (spec_dim, val_dim) in enumerate(
239
+ zip(spec.shape, val_shape, strict=True)
240
+ ):
241
+ if spec_dim >= 0 and spec_dim != val_dim:
242
+ raise ValueError(
243
+ f"kernel {fn_type} input[{arg_index}] shape mismatch at dim {dim_idx}: got {val_dim}, expected {spec_dim}"
244
+ )
245
+
246
+ try:
247
+ val_dtype = DType.from_any(duck_dtype)
248
+ except (ValueError, TypeError): # pragma: no cover
249
+ raise TypeError(
250
+ f"kernel {fn_type} input[{arg_index}] has unsupported dtype object {duck_dtype!r}"
251
+ ) from None
252
+ if val_dtype != spec.dtype:
253
+ raise ValueError(
254
+ f"kernel {fn_type} input[{arg_index}] dtype mismatch: got {val_dtype}, expected {spec.dtype}"
255
+ )
@@ -186,16 +186,18 @@ def _spu_reconstruct(pfunc: PFunction, *args: Any) -> Any:
186
186
  return reconstructed
187
187
 
188
188
 
189
- @kernel_def("mlir.pphlo")
189
+ @kernel_def("spu.run_pphlo")
190
190
  def _spu_run_mlir(pfunc: PFunction, *args: Any) -> Any:
191
- """Execute compiled SPU function (mlir.pphlo) and return SpuValue outputs.
191
+ """Execute compiled SPU function (spu.run_pphlo) and return SpuValue outputs.
192
192
 
193
193
  Participation rule: a rank participates iff its entry in the stored
194
194
  link_ctx list is non-None. This allows us to allocate a world-sized list
195
195
  (indexed by global rank) and simply assign None for non-SPU parties.
196
196
  """
197
- if pfunc.fn_type != "mlir.pphlo":
198
- raise ValueError(f"Unsupported format: {pfunc.fn_type}. Expected 'mlir.pphlo'")
197
+ if pfunc.fn_type != "spu.run_pphlo":
198
+ raise ValueError(
199
+ f"Unsupported format: {pfunc.fn_type}. Expected 'spu.run_pphlo'"
200
+ )
199
201
 
200
202
  cfg, _ = _get_spu_config_and_world()
201
203
  pocket = _get_spu_pocket()
@@ -20,7 +20,7 @@ from mplang.backend.base import kernel_def
20
20
  from mplang.core.pfunc import PFunction
21
21
 
22
22
 
23
- @kernel_def("sql[duckdb]")
23
+ @kernel_def("duckdb.run_sql")
24
24
  def _duckdb_sql(pfunc: PFunction, *args: Any) -> Any:
25
25
  import duckdb
26
26
  import pandas as pd
@@ -27,7 +27,7 @@ from __future__ import annotations
27
27
  from dataclasses import dataclass
28
28
  from typing import Any, Protocol
29
29
 
30
- from mplang.backend.base import BackendRuntime
30
+ from mplang.backend.context import RuntimeContext
31
31
  from mplang.core.comm import ICommunicator
32
32
  from mplang.core.expr.ast import (
33
33
  AccessExpr,
@@ -56,7 +56,7 @@ class IEvaluator(Protocol):
56
56
  backend state via evaluator.runtime.run_kernel(...).
57
57
  """
58
58
 
59
- runtime: BackendRuntime
59
+ runtime: RuntimeContext
60
60
 
61
61
  def evaluate(self, root: Expr, env: dict[str, Any] | None = None) -> list[Any]: ...
62
62
 
@@ -72,7 +72,7 @@ class EvalSemantic:
72
72
  rank: int
73
73
  env: dict[str, Any]
74
74
  comm: ICommunicator
75
- runtime: BackendRuntime
75
+ runtime: RuntimeContext
76
76
 
77
77
  # ------------------------------ Shared helpers (semantics) ------------------------------
78
78
  def _should_run(self, rmask: Mask | None, args: list[Any]) -> bool:
@@ -205,7 +205,7 @@ class RecursiveEvaluator(EvalSemantic, ExprVisitor):
205
205
  rank: int,
206
206
  env: dict[str, Any],
207
207
  comm: ICommunicator,
208
- runtime: BackendRuntime,
208
+ runtime: RuntimeContext,
209
209
  ) -> None:
210
210
  super().__init__(rank, env, comm, runtime)
211
211
  self._cache: dict[int, Any] = {} # Cache based on expr id
@@ -380,7 +380,7 @@ class IterativeEvaluator(EvalSemantic):
380
380
  rank: int,
381
381
  env: dict[str, Any],
382
382
  comm: ICommunicator,
383
- runtime: BackendRuntime,
383
+ runtime: RuntimeContext,
384
384
  ) -> None:
385
385
  super().__init__(rank, env, comm, runtime)
386
386
 
@@ -501,7 +501,7 @@ def create_evaluator(
501
501
  rank: int,
502
502
  env: dict[str, Any],
503
503
  comm: ICommunicator,
504
- runtime: BackendRuntime,
504
+ runtime: RuntimeContext,
505
505
  kind: str | None = "iterative",
506
506
  ) -> IEvaluator:
507
507
  """Factory to create an evaluator engine.
@@ -129,7 +129,7 @@ class FeModule(ABC):
129
129
  - You need compilation/stateful behavior/dynamic routing, multiple PFunctions, or complex capture flows.
130
130
 
131
131
  Tips:
132
- - Keep routing information in PFunction.fn_type (e.g., "builtin.read", "sql[duckdb]", "mlir.stablehlo").
132
+ - Keep routing information in PFunction.fn_type (e.g., "builtin.read", "sql.run", "mlir.stablehlo").
133
133
  - Avoid backend-specific logic in kernels; only validate and shape types.
134
134
  - Prefer keyword-only attributes in typed_op kernels for clarity (def op(x: MPObject, *, attr: int)).
135
135
  """
@@ -57,8 +57,9 @@ def ibis2sql(
57
57
  outs_info = [_convert(expr.schema())]
58
58
 
59
59
  sql = ibis.to_sql(expr, dialect="duckdb")
60
+ # Emit generic sql.run op; runtime maps to backend-specific kernel.
60
61
  pfn = PFunction(
61
- fn_type="sql[duckdb]",
62
+ fn_type="sql.run",
62
63
  fn_name=fn_name,
63
64
  fn_text=sql,
64
65
  ins_info=tuple(ins_info),
@@ -94,9 +94,10 @@ def _compile_jax(
94
94
  *args: Any,
95
95
  **kwargs: Any,
96
96
  ) -> tuple[PFunction, list[MPObject], PyTreeDef]:
97
- """Compile a JAX function into SPU pphlo MLIR representation."""
97
+ """Compile a JAX function into SPU pphlo MLIR and wrap as PFunction.
98
98
 
99
- """Compile a JAX function into SPU pphlo MLIR representation."""
99
+ Resulting PFunction uses fn_type 'spu.run_pphlo'.
100
+ """
100
101
 
101
102
  def is_variable(arg: Any) -> bool:
102
103
  return isinstance(arg, MPObject)
@@ -132,7 +133,7 @@ def _compile_jax(
132
133
  executable_code = executable_code.decode("utf-8")
133
134
 
134
135
  pfunc = PFunction(
135
- fn_type="mlir.pphlo",
136
+ fn_type="spu.run_pphlo",
136
137
  ins_info=tuple(TensorType.from_obj(x) for x in in_vars),
137
138
  outs_info=tuple(output_tensor_infos),
138
139
  fn_name=get_fn_name(fn),