mplang-nightly 0.1.dev147__tar.gz → 0.1.dev148__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164) hide show
  1. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/PKG-INFO +1 -1
  2. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/backend/base.py +21 -47
  3. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/backend/context.py +67 -26
  4. mplang_nightly-0.1.dev148/tests/backend/test_kernel_binding.py +102 -0
  5. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/.gitignore +0 -0
  6. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/LICENSE +0 -0
  7. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/README.md +0 -0
  8. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/examples/conf/3pc.yaml +0 -0
  9. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/examples/stax_nn/README.md +0 -0
  10. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/examples/stax_nn/models.py +0 -0
  11. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/examples/stax_nn/stax_nn.py +0 -0
  12. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/examples/xgboost/hist_jax.py +0 -0
  13. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/examples/xgboost/hist_jax_test.py +0 -0
  14. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/examples/xgboost/naive_np.py +0 -0
  15. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/examples/xgboost/readme.md +0 -0
  16. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/examples/xgboost/sgb.py +0 -0
  17. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/examples/xgboost/sgb_test.py +0 -0
  18. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/hatch_build.py +0 -0
  19. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/__init__.py +0 -0
  20. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/analysis/__init__.py +0 -0
  21. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/analysis/diagram.py +0 -0
  22. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/api.py +0 -0
  23. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/backend/__init__.py +0 -0
  24. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/backend/builtin.py +0 -0
  25. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/backend/crypto.py +0 -0
  26. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/backend/phe.py +0 -0
  27. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/backend/spu.py +0 -0
  28. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/backend/sql_duckdb.py +0 -0
  29. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/backend/stablehlo.py +0 -0
  30. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/backend/tee.py +0 -0
  31. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/core/__init__.py +0 -0
  32. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/core/cluster.py +0 -0
  33. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/core/comm.py +0 -0
  34. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/core/context_mgr.py +0 -0
  35. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/core/dtype.py +0 -0
  36. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/core/expr/__init__.py +0 -0
  37. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/core/expr/ast.py +0 -0
  38. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/core/expr/evaluator.py +0 -0
  39. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/core/expr/printer.py +0 -0
  40. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/core/expr/transformer.py +0 -0
  41. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/core/expr/utils.py +0 -0
  42. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/core/expr/visitor.py +0 -0
  43. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/core/expr/walk.py +0 -0
  44. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/core/interp.py +0 -0
  45. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/core/mask.py +0 -0
  46. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/core/mpir.py +0 -0
  47. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/core/mpobject.py +0 -0
  48. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/core/mptype.py +0 -0
  49. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/core/pfunc.py +0 -0
  50. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/core/primitive.py +0 -0
  51. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/core/table.py +0 -0
  52. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/core/tensor.py +0 -0
  53. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/core/tracer.py +0 -0
  54. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/device.py +0 -0
  55. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/frontend/__init__.py +0 -0
  56. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/frontend/base.py +0 -0
  57. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/frontend/builtin.py +0 -0
  58. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/frontend/crypto.py +0 -0
  59. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/frontend/ibis_cc.py +0 -0
  60. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/frontend/jax_cc.py +0 -0
  61. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/frontend/phe.py +0 -0
  62. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/frontend/spu.py +0 -0
  63. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/frontend/sql.py +0 -0
  64. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/frontend/tee.py +0 -0
  65. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/protos/v1alpha1/mpir_pb2.py +0 -0
  66. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/protos/v1alpha1/mpir_pb2.pyi +0 -0
  67. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -0
  68. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/runtime/__init__.py +0 -0
  69. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/runtime/cli.py +0 -0
  70. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/runtime/client.py +0 -0
  71. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/runtime/communicator.py +0 -0
  72. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/runtime/data_providers.py +0 -0
  73. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/runtime/driver.py +0 -0
  74. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/runtime/exceptions.py +0 -0
  75. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/runtime/http_api.md +0 -0
  76. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/runtime/link_comm.py +0 -0
  77. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/runtime/resource.py +0 -0
  78. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/runtime/server.py +0 -0
  79. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/runtime/simulation.py +0 -0
  80. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/simp/__init__.py +0 -0
  81. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/simp/mpi.py +0 -0
  82. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/simp/random.py +0 -0
  83. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/simp/smpc.py +0 -0
  84. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/utils/__init__.py +0 -0
  85. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/utils/crypto.py +0 -0
  86. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/utils/func_utils.py +0 -0
  87. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/utils/spu_utils.py +0 -0
  88. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/mplang/utils/table_utils.py +0 -0
  89. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/pyproject.toml +0 -0
  90. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/__init__.py +0 -0
  91. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/analysis/test_diagram.py +0 -0
  92. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/backend/test_builtin.py +0 -0
  93. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/backend/test_debug_print.py +0 -0
  94. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/backend/test_phe.py +0 -0
  95. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/backend/test_spu.py +0 -0
  96. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/backend/test_sql_duckdb.py +0 -0
  97. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/backend/test_stablehlo.py +0 -0
  98. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/core/__init__.py +0 -0
  99. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/core/expr/__init__.py +0 -0
  100. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/core/expr/conftest.py +0 -0
  101. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/core/expr/test_ast.py +0 -0
  102. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/core/expr/test_printer.py +0 -0
  103. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/core/expr/test_utils.py +0 -0
  104. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/core/expr/test_walk.py +0 -0
  105. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/core/test_cluster.py +0 -0
  106. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/core/test_dtype.py +0 -0
  107. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/core/test_mask.py +0 -0
  108. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/core/test_mpir.py +0 -0
  109. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/core/test_mptype.py +0 -0
  110. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/core/test_primitive.py +0 -0
  111. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/core/test_table.py +0 -0
  112. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/core/test_tensor.py +0 -0
  113. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/core/test_tracer.py +0 -0
  114. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/device/__init__.py +0 -0
  115. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/device/test_device_basic.py +0 -0
  116. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/frontend/__init__.py +0 -0
  117. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/frontend/dummy.py +0 -0
  118. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/frontend/test_builtin_pack.py +0 -0
  119. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/frontend/test_crypto_tee.py +0 -0
  120. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/frontend/test_feop_base.py +0 -0
  121. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/frontend/test_ibis.py +0 -0
  122. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/frontend/test_ibis_cc.py +0 -0
  123. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/frontend/test_jax_cc.py +0 -0
  124. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/frontend/test_phe.py +0 -0
  125. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/frontend/test_spu.py +0 -0
  126. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/frontend/test_spu_defensive.py +0 -0
  127. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/frontend/test_sql.py +0 -0
  128. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/frontend/test_table_tensor_conversion.py +0 -0
  129. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/integration/README.md +0 -0
  130. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/integration/test_crypto_roundtrip.py +0 -0
  131. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/integration/test_http_e2e.py +0 -0
  132. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/integration/test_symbols_roundtrip.py +0 -0
  133. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/integration/test_tutorials.py +0 -0
  134. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/runtime/__init__.py +0 -0
  135. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/runtime/test_cli.py +0 -0
  136. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/runtime/test_communicator.py +0 -0
  137. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/runtime/test_driver.py +0 -0
  138. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/runtime/test_server.py +0 -0
  139. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/runtime/test_simulation.py +0 -0
  140. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/simp/test_mpi.py +0 -0
  141. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/simp/test_random.py +0 -0
  142. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/simp/test_simp.py +0 -0
  143. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/simp/test_smpc.py +0 -0
  144. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/simp/test_sugar.py +0 -0
  145. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/utils/__init__.py +0 -0
  146. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/utils/server_fixtures.py +0 -0
  147. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/utils/test_func_utils.py +0 -0
  148. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/utils/test_spu_utils.py +0 -0
  149. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tests/utils/test_table_utils.py +0 -0
  150. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tutorials/0_basic.py +0 -0
  151. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tutorials/10_analysis.py +0 -0
  152. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tutorials/1_condition.py +0 -0
  153. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tutorials/2_whileloop.py +0 -0
  154. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tutorials/3_device.py +0 -0
  155. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tutorials/4_simulation.py +0 -0
  156. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tutorials/5_ir_dump.py +0 -0
  157. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tutorials/6_advanced.py +0 -0
  158. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tutorials/7_stdio.py +0 -0
  159. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tutorials/8_phe.py +0 -0
  160. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tutorials/9_tee.py +0 -0
  161. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tutorials/__init__.py +0 -0
  162. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tutorials/pitfalls/late_binding.py +0 -0
  163. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tutorials/pitfalls/rand.py +0 -0
  164. {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev148}/tutorials/run.sh +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mplang-nightly
3
- Version: 0.1.dev147
3
+ Version: 0.1.dev148
4
4
  Summary: Multi-Party Programming Language
5
5
  Author-email: SecretFlow Team <secretflow-contact@service.alipay.com>
6
6
  License: Apache License
@@ -12,20 +12,21 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- """Backend kernel registry & per-participant runtime (explicit op->kernel binding).
15
+ """Backend kernel registry: mapping kernel_id -> implementation.
16
16
 
17
- This version decouples *kernel implementation registration* from *operation binding*.
17
+ This module provides a lightweight registry for backend kernel implementations.
18
+ It does not track or decide which kernel handles a given semantic operation;
19
+ that policy (op -> kernel_id) is managed externally by each ``RuntimeContext``.
18
20
 
19
- Concepts:
20
- * kernel_id: unique identifier of a concrete backend implementation.
21
- * op_type: semantic operation name carried by ``PFunction.fn_type``.
22
- * bind_op(op_type, kernel_id): performed by higher layer (see ``backend.context``)
23
- to select which implementation handles an op. Runtime dispatch is now a 2-step:
24
- pfunc.fn_type -> active kernel_id -> KernelSpec.fn
21
+ Exposed primitives:
22
+ * ``@kernel_def(kernel_id)``: decorator to register a kernel implementation.
23
+ * ``get_kernel_spec(kernel_id)``: look up a previously registered kernel.
24
+ * ``cur_kctx()`` / ``KernelContext``: execution context available only
25
+ inside a kernel body (rank, world_size, per-backend state pockets, and a
26
+ runtime-wide cache shared by kernels of the same runtime instance).
25
27
 
26
- The previous implicit "import == register+bind" coupling is removed. Kernel
27
- modules only call ``@kernel_def(kernel_id)``. Default bindings are established
28
- centrally (lazy) the first time a runtime executes a kernel.
28
+ No global op binding table exists here; callers resolve an op to a kernel_id
29
+ before invoking the kernel function.
29
30
  """
30
31
 
31
32
  from __future__ import annotations
@@ -38,12 +39,10 @@ from typing import Any
38
39
  __all__ = [
39
40
  "KernelContext",
40
41
  "KernelSpec",
41
- "bind_op",
42
42
  "cur_kctx",
43
- "get_kernel_for_op",
43
+ "get_kernel_spec",
44
+ "kernel_exists",
44
45
  "list_kernels",
45
- "list_ops",
46
- "unbind_op",
47
46
  ]
48
47
 
49
48
 
@@ -116,9 +115,6 @@ class KernelSpec:
116
115
  # All registered kernel implementations: kernel_id -> spec
117
116
  _KERNELS: dict[str, KernelSpec] = {}
118
117
 
119
- # Active op bindings: op_type -> kernel_id
120
- _BINDINGS: dict[str, str] = {}
121
-
122
118
 
123
119
  def kernel_def(kernel_id: str, /, **meta: Any) -> Callable[[KernelFn], KernelFn]:
124
120
  """Decorator to register a concrete kernel implementation.
@@ -136,34 +132,11 @@ def kernel_def(kernel_id: str, /, **meta: Any) -> Callable[[KernelFn], KernelFn]
136
132
  return _decorator
137
133
 
138
134
 
139
- def bind_op(op_type: str, kernel_id: str, *, force: bool = True) -> None:
140
- """Bind an op_type to a registered kernel implementation.
141
-
142
- Args:
143
- op_type: Semantic operation name.
144
- kernel_id: Previously registered kernel identifier.
145
- force: If False and op_type already bound, keep existing binding.
146
- If True (default), overwrite.
147
- """
148
- if kernel_id not in _KERNELS:
135
+ def get_kernel_spec(kernel_id: str) -> KernelSpec:
136
+ """Return KernelSpec for a registered kernel_id (no op binding lookup)."""
137
+ spec = _KERNELS.get(kernel_id)
138
+ if spec is None:
149
139
  raise KeyError(f"kernel_id {kernel_id} not registered")
150
- if not force and op_type in _BINDINGS:
151
- return
152
- _BINDINGS[op_type] = kernel_id
153
-
154
-
155
- def unbind_op(op_type: str) -> None:
156
- _BINDINGS.pop(op_type, None)
157
-
158
-
159
- def get_kernel_for_op(op_type: str) -> KernelSpec:
160
- kid = _BINDINGS.get(op_type)
161
- if kid is None:
162
- # Tests expect NotImplementedError for unsupported operations
163
- raise NotImplementedError(f"no backend kernel registered for op {op_type}")
164
- spec = _KERNELS.get(kid)
165
- if spec is None: # inconsistent state
166
- raise RuntimeError(f"active kernel_id {kid} missing spec")
167
140
  return spec
168
141
 
169
142
 
@@ -171,5 +144,6 @@ def list_kernels() -> list[str]:
171
144
  return sorted(_KERNELS.keys())
172
145
 
173
146
 
174
- def list_ops() -> list[str]:
175
- return sorted(_BINDINGS.keys())
147
+ def kernel_exists(kernel_id: str) -> bool:
148
+ """Return True if a kernel_id has been registered."""
149
+ return kernel_id in _KERNELS
@@ -15,11 +15,10 @@
15
15
  from __future__ import annotations
16
16
 
17
17
  from collections.abc import Mapping
18
- from dataclasses import dataclass, field
19
18
  from typing import Any
20
19
 
21
20
  from mplang.backend import base
22
- from mplang.backend.base import KernelContext, bind_op, get_kernel_for_op
21
+ from mplang.backend.base import KernelContext, get_kernel_spec, kernel_exists
23
22
  from mplang.core.dtype import UINT8, DType
24
23
  from mplang.core.pfunc import PFunction
25
24
  from mplang.core.table import TableLike, TableType
@@ -100,30 +99,57 @@ _DEFAULT_BINDINGS: dict[str, str] = {
100
99
  # --- RuntimeContext ---
101
100
 
102
101
 
103
- @dataclass
104
102
  class RuntimeContext:
105
- rank: int
106
- world_size: int
107
- bindings: Mapping[str, str] | None = None # optional overrides
108
- state: dict[str, dict[str, Any]] = field(default_factory=dict)
109
- cache: dict[str, Any] = field(default_factory=dict)
110
- stats: dict[str, Any] = field(default_factory=dict)
111
-
112
- def __post_init__(self) -> None:
103
+ """Per-runtime execution context with isolated op->kernel bindings.
104
+
105
+ Parameters
106
+ ----------
107
+ rank : int
108
+ Local rank of this participant.
109
+ world_size : int
110
+ Total number of participants.
111
+ initial_bindings : Mapping[str, str] | None, optional
112
+ Optional partial overrides applied on top of the default binding table
113
+ during construction (override semantics, not replace). After
114
+ initialization, all (re)binding must go through ``bind_op`` /
115
+ ``rebind_op``.
116
+ state / cache / stats : dict, optional
117
+ Mutable pockets reused across kernel invocations. If omitted, new
118
+ dictionaries are created.
119
+ """
120
+
121
+ __slots__ = ("_ibindings", "cache", "rank", "state", "stats", "world_size")
122
+
123
+ def __init__(
124
+ self,
125
+ rank: int,
126
+ world_size: int,
127
+ initial_bindings: Mapping[str, str] | None = None,
128
+ *,
129
+ state: dict[str, dict[str, Any]] | None = None,
130
+ cache: dict[str, Any] | None = None,
131
+ stats: dict[str, Any] | None = None,
132
+ ) -> None:
113
133
  _ensure_impl_imported()
114
- if self.bindings is not None:
115
- for op, kid in self.bindings.items():
116
- bind_op(op, kid)
117
- else:
118
- for op, kid in _DEFAULT_BINDINGS.items():
119
- bind_op(op, kid)
120
- # Initialize stats pocket
134
+ self.rank = rank
135
+ self.world_size = world_size
136
+ # Merge defaults with user overrides (override semantics)
137
+ self._ibindings: dict[str, str] = {
138
+ **_DEFAULT_BINDINGS,
139
+ **(initial_bindings or {}),
140
+ }
141
+ self.state = state if state is not None else {}
142
+ self.cache = cache if cache is not None else {}
143
+ self.stats = stats if stats is not None else {}
121
144
  self.stats.setdefault("op_calls", {})
122
145
 
123
146
  def run_kernel(self, pfunc: PFunction, arg_list: list[Any]) -> list[Any]:
124
147
  fn_type = pfunc.fn_type
125
- spec = get_kernel_for_op(fn_type)
126
- fn = spec.fn
148
+ kid = self._ibindings.get(fn_type)
149
+ if kid is None:
150
+ raise NotImplementedError(f"no backend kernel registered for op {fn_type}")
151
+ spec = get_kernel_spec(kid)
152
+ fn = spec.fn # kernel implementation
127
153
  if len(arg_list) != len(pfunc.ins_info):
128
154
  raise ValueError(
129
155
  f"kernel {fn_type} arg count mismatch: got {len(arg_list)}, expect {len(pfunc.ins_info)}"
@@ -186,17 +212,32 @@ class RuntimeContext:
186
212
 
187
213
  # ---- explicit (re)binding API ----
188
214
  def bind_op(self, op_type: str, kernel_id: str, *, force: bool = False) -> None:
189
- """Bind an operation to a kernel at runtime.
215
+ """Bind an operation to a kernel for THIS context only.
190
216
 
191
- force=False (default) preserves any existing binding to avoid accidental
192
- silent overrides. Use ``rebind_op`` or ``force=True`` to intentionally
193
- change a binding.
217
+ force=False (default) keeps existing binding (no silent override).
194
218
  """
195
- base.bind_op(op_type, kernel_id, force=force)
219
+ if not kernel_exists(kernel_id):
220
+ raise KeyError(f"kernel_id {kernel_id} not registered")
221
+ if not force and op_type in self._ibindings:
222
+ return
223
+ self._ibindings[op_type] = kernel_id
196
224
 
197
225
  def rebind_op(self, op_type: str, kernel_id: str) -> None:
198
226
  """Force rebind an operation to a different kernel (shorthand)."""
199
- base.bind_op(op_type, kernel_id, force=True)
227
+ self.bind_op(op_type, kernel_id, force=True)
228
+
229
+ # Introspection helpers
230
+ def list_bound_ops(self) -> list[str]: # pragma: no cover - convenience
231
+ return sorted(self._ibindings.keys())
232
+
233
+ def get_binding(self, op_type: str) -> str | None: # pragma: no cover
234
+ return self._ibindings.get(op_type)
235
+
236
+ def __repr__(self) -> str: # pragma: no cover - debug aid
237
+ return (
238
+ f"RuntimeContext(rank={self.rank}, world_size={self.world_size}, "
239
+ f"bound_ops={len(self._ibindings)})"
240
+ )
200
241
 
201
242
 
202
243
  def _validate_table_arg(
@@ -0,0 +1,102 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Tests for per-RuntimeContext binding isolation
16
+ from __future__ import annotations
17
+
18
+ import pytest
19
+
20
+ from mplang.backend import base
21
+ from mplang.backend.context import RuntimeContext
22
+ from mplang.core.dtype import INT64 # switched from INT32 to INT64 to match Python int
23
+ from mplang.core.pfunc import PFunction
24
+ from mplang.core.tensor import TensorType
25
+
26
+ # We'll register two fake kernels for an op to test rebinding.
27
+ # If they already exist due to other tests, we guard with try/except.
28
+
29
+
30
+ @base.kernel_def("test.echo.v1")
31
+ def _echo_v1(
32
+ pfunc: PFunction, x: int
33
+ ) -> tuple[int,]: # pragma: no cover - executed in test
34
+ return (x + 1,)
35
+
36
+
37
+ @base.kernel_def("test.echo.v2")
38
+ def _echo_v2(
39
+ pfunc: PFunction, x: int
40
+ ) -> tuple[int,]: # pragma: no cover - executed in test
41
+ return (x + 2,)
42
+
43
+
44
+ def make_pfunc(op_type: str) -> PFunction:
45
+ # Minimal PFunction stub compatible with backend.run_kernel expectations.
46
+ # shape info matters only for validation; use scalar INT64 (Python int maps to int64).
47
+ return PFunction(
48
+ fn_type=op_type,
49
+ fn_text="",
50
+ ins_info=[TensorType(shape=(), dtype=INT64)],
51
+ outs_info=[TensorType(shape=(), dtype=INT64)],
52
+ )
53
+
54
+
55
+ def test_isolated_rebind():
56
+ # ctx1 binds op -> v1, ctx2 binds op -> v2; they should not interfere.
57
+ op = "test.echo"
58
+ ctx1 = RuntimeContext(rank=0, world_size=1, initial_bindings={op: "test.echo.v1"})
59
+ ctx2 = RuntimeContext(rank=0, world_size=1, initial_bindings={op: "test.echo.v2"})
60
+
61
+ pfunc = make_pfunc(op)
62
+ out1 = ctx1.run_kernel(pfunc, [10])[0]
63
+ out2 = ctx2.run_kernel(pfunc, [10])[0]
64
+
65
+ assert out1 == 11
66
+ assert out2 == 12
67
+
68
+
69
+ def test_rebind_only_affects_context():
70
+ op = "test.echo"
71
+ ctx = RuntimeContext(rank=0, world_size=1, initial_bindings={op: "test.echo.v1"})
72
+ pfunc = make_pfunc(op)
73
+ assert ctx.run_kernel(pfunc, [5])[0] == 6
74
+ ctx.rebind_op(op, "test.echo.v2")
75
+ assert ctx.run_kernel(pfunc, [5])[0] == 7
76
+
77
+
78
+ def test_force_flag():
79
+ op = "test.echo"
80
+ ctx = RuntimeContext(rank=0, world_size=1, initial_bindings={op: "test.echo.v1"})
81
+ # Attempt non-force bind (should keep v1)
82
+ ctx.bind_op(op, "test.echo.v2", force=False)
83
+ pfunc = make_pfunc(op)
84
+ assert ctx.run_kernel(pfunc, [1])[0] == 2 # still v1 (+1)
85
+ # Now force
86
+ ctx.bind_op(op, "test.echo.v2", force=True)
87
+ assert ctx.run_kernel(pfunc, [1])[0] == 3
88
+
89
+
90
+ def test_unknown_kernel_id():
91
+ ctx = RuntimeContext(rank=0, world_size=1)
92
+ with pytest.raises(KeyError):
93
+ ctx.bind_op("some.op", "non.existent.kernel")
94
+
95
+
96
+ def test_missing_binding():
97
+ # Pick an op name unlikely in defaults
98
+ op = "unit.test.unbound"
99
+ ctx = RuntimeContext(rank=0, world_size=1)
100
+ pfunc = make_pfunc(op)
101
+ with pytest.raises(NotImplementedError):
102
+ ctx.run_kernel(pfunc, [0])