mplang-nightly 0.1.dev146__tar.gz → 0.1.dev147__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163) hide show
  1. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/PKG-INFO +1 -1
  2. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/core/expr/evaluator.py +26 -8
  3. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/core/primitive.py +30 -0
  4. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/runtime/test_simulation.py +182 -1
  5. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/.gitignore +0 -0
  6. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/LICENSE +0 -0
  7. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/README.md +0 -0
  8. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/examples/conf/3pc.yaml +0 -0
  9. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/examples/stax_nn/README.md +0 -0
  10. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/examples/stax_nn/models.py +0 -0
  11. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/examples/stax_nn/stax_nn.py +0 -0
  12. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/examples/xgboost/hist_jax.py +0 -0
  13. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/examples/xgboost/hist_jax_test.py +0 -0
  14. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/examples/xgboost/naive_np.py +0 -0
  15. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/examples/xgboost/readme.md +0 -0
  16. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/examples/xgboost/sgb.py +0 -0
  17. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/examples/xgboost/sgb_test.py +0 -0
  18. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/hatch_build.py +0 -0
  19. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/__init__.py +0 -0
  20. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/analysis/__init__.py +0 -0
  21. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/analysis/diagram.py +0 -0
  22. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/api.py +0 -0
  23. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/backend/__init__.py +0 -0
  24. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/backend/base.py +0 -0
  25. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/backend/builtin.py +0 -0
  26. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/backend/context.py +0 -0
  27. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/backend/crypto.py +0 -0
  28. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/backend/phe.py +0 -0
  29. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/backend/spu.py +0 -0
  30. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/backend/sql_duckdb.py +0 -0
  31. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/backend/stablehlo.py +0 -0
  32. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/backend/tee.py +0 -0
  33. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/core/__init__.py +0 -0
  34. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/core/cluster.py +0 -0
  35. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/core/comm.py +0 -0
  36. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/core/context_mgr.py +0 -0
  37. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/core/dtype.py +0 -0
  38. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/core/expr/__init__.py +0 -0
  39. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/core/expr/ast.py +0 -0
  40. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/core/expr/printer.py +0 -0
  41. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/core/expr/transformer.py +0 -0
  42. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/core/expr/utils.py +0 -0
  43. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/core/expr/visitor.py +0 -0
  44. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/core/expr/walk.py +0 -0
  45. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/core/interp.py +0 -0
  46. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/core/mask.py +0 -0
  47. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/core/mpir.py +0 -0
  48. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/core/mpobject.py +0 -0
  49. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/core/mptype.py +0 -0
  50. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/core/pfunc.py +0 -0
  51. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/core/table.py +0 -0
  52. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/core/tensor.py +0 -0
  53. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/core/tracer.py +0 -0
  54. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/device.py +0 -0
  55. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/frontend/__init__.py +0 -0
  56. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/frontend/base.py +0 -0
  57. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/frontend/builtin.py +0 -0
  58. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/frontend/crypto.py +0 -0
  59. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/frontend/ibis_cc.py +0 -0
  60. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/frontend/jax_cc.py +0 -0
  61. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/frontend/phe.py +0 -0
  62. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/frontend/spu.py +0 -0
  63. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/frontend/sql.py +0 -0
  64. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/frontend/tee.py +0 -0
  65. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/protos/v1alpha1/mpir_pb2.py +0 -0
  66. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/protos/v1alpha1/mpir_pb2.pyi +0 -0
  67. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -0
  68. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/runtime/__init__.py +0 -0
  69. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/runtime/cli.py +0 -0
  70. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/runtime/client.py +0 -0
  71. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/runtime/communicator.py +0 -0
  72. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/runtime/data_providers.py +0 -0
  73. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/runtime/driver.py +0 -0
  74. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/runtime/exceptions.py +0 -0
  75. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/runtime/http_api.md +0 -0
  76. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/runtime/link_comm.py +0 -0
  77. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/runtime/resource.py +0 -0
  78. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/runtime/server.py +0 -0
  79. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/runtime/simulation.py +0 -0
  80. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/simp/__init__.py +0 -0
  81. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/simp/mpi.py +0 -0
  82. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/simp/random.py +0 -0
  83. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/simp/smpc.py +0 -0
  84. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/utils/__init__.py +0 -0
  85. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/utils/crypto.py +0 -0
  86. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/utils/func_utils.py +0 -0
  87. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/utils/spu_utils.py +0 -0
  88. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/mplang/utils/table_utils.py +0 -0
  89. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/pyproject.toml +0 -0
  90. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/__init__.py +0 -0
  91. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/analysis/test_diagram.py +0 -0
  92. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/backend/test_builtin.py +0 -0
  93. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/backend/test_debug_print.py +0 -0
  94. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/backend/test_phe.py +0 -0
  95. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/backend/test_spu.py +0 -0
  96. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/backend/test_sql_duckdb.py +0 -0
  97. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/backend/test_stablehlo.py +0 -0
  98. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/core/__init__.py +0 -0
  99. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/core/expr/__init__.py +0 -0
  100. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/core/expr/conftest.py +0 -0
  101. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/core/expr/test_ast.py +0 -0
  102. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/core/expr/test_printer.py +0 -0
  103. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/core/expr/test_utils.py +0 -0
  104. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/core/expr/test_walk.py +0 -0
  105. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/core/test_cluster.py +0 -0
  106. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/core/test_dtype.py +0 -0
  107. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/core/test_mask.py +0 -0
  108. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/core/test_mpir.py +0 -0
  109. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/core/test_mptype.py +0 -0
  110. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/core/test_primitive.py +0 -0
  111. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/core/test_table.py +0 -0
  112. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/core/test_tensor.py +0 -0
  113. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/core/test_tracer.py +0 -0
  114. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/device/__init__.py +0 -0
  115. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/device/test_device_basic.py +0 -0
  116. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/frontend/__init__.py +0 -0
  117. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/frontend/dummy.py +0 -0
  118. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/frontend/test_builtin_pack.py +0 -0
  119. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/frontend/test_crypto_tee.py +0 -0
  120. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/frontend/test_feop_base.py +0 -0
  121. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/frontend/test_ibis.py +0 -0
  122. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/frontend/test_ibis_cc.py +0 -0
  123. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/frontend/test_jax_cc.py +0 -0
  124. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/frontend/test_phe.py +0 -0
  125. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/frontend/test_spu.py +0 -0
  126. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/frontend/test_spu_defensive.py +0 -0
  127. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/frontend/test_sql.py +0 -0
  128. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/frontend/test_table_tensor_conversion.py +0 -0
  129. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/integration/README.md +0 -0
  130. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/integration/test_crypto_roundtrip.py +0 -0
  131. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/integration/test_http_e2e.py +0 -0
  132. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/integration/test_symbols_roundtrip.py +0 -0
  133. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/integration/test_tutorials.py +0 -0
  134. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/runtime/__init__.py +0 -0
  135. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/runtime/test_cli.py +0 -0
  136. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/runtime/test_communicator.py +0 -0
  137. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/runtime/test_driver.py +0 -0
  138. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/runtime/test_server.py +0 -0
  139. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/simp/test_mpi.py +0 -0
  140. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/simp/test_random.py +0 -0
  141. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/simp/test_simp.py +0 -0
  142. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/simp/test_smpc.py +0 -0
  143. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/simp/test_sugar.py +0 -0
  144. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/utils/__init__.py +0 -0
  145. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/utils/server_fixtures.py +0 -0
  146. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/utils/test_func_utils.py +0 -0
  147. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/utils/test_spu_utils.py +0 -0
  148. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tests/utils/test_table_utils.py +0 -0
  149. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tutorials/0_basic.py +0 -0
  150. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tutorials/10_analysis.py +0 -0
  151. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tutorials/1_condition.py +0 -0
  152. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tutorials/2_whileloop.py +0 -0
  153. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tutorials/3_device.py +0 -0
  154. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tutorials/4_simulation.py +0 -0
  155. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tutorials/5_ir_dump.py +0 -0
  156. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tutorials/6_advanced.py +0 -0
  157. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tutorials/7_stdio.py +0 -0
  158. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tutorials/8_phe.py +0 -0
  159. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tutorials/9_tee.py +0 -0
  160. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tutorials/__init__.py +0 -0
  161. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tutorials/pitfalls/late_binding.py +0 -0
  162. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tutorials/pitfalls/rand.py +0 -0
  163. {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev147}/tutorials/run.sh +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mplang-nightly
3
- Version: 0.1.dev146
3
+ Version: 0.1.dev147
4
4
  Summary: Multi-Party Programming Language
5
5
  Author-email: SecretFlow Team <secretflow-contact@service.alipay.com>
6
6
  License: Apache License
@@ -196,6 +196,28 @@ class EvalSemantic:
196
196
  "uniform_cond: predicate is not uniform across parties"
197
197
  )
198
198
 
199
+ # ------------------------------ While helpers ------------------------------
200
+ def _check_while_predicate(self, cond_result: list[Any]) -> Any:
201
+ """Validate while_loop predicate evaluation result.
202
+
203
+ Ensures the condition function returns exactly one value and that value
204
+ is non-None. Returns the boolean predicate value for convenience.
205
+
206
+ Raises:
207
+ AssertionError: If condition function returns != 1 value.
208
+ RuntimeError: If the single predicate value is None.
209
+ """
210
+ assert len(cond_result) == 1, (
211
+ f"Condition function must return a single value, got {cond_result}"
212
+ )
213
+ cond_value = cond_result[0]
214
+ if cond_value is None:
215
+ raise RuntimeError(
216
+ "while_loop condition produced None on rank "
217
+ f"{self.rank}; ensure the predicate yields a boolean for every party."
218
+ )
219
+ return cond_value
220
+
199
221
 
200
222
  class RecursiveEvaluator(EvalSemantic, ExprVisitor):
201
223
  """Recursive visitor-based evaluator."""
@@ -307,12 +329,8 @@ class RecursiveEvaluator(EvalSemantic, ExprVisitor):
307
329
  cond_env = dict(zip(expr.cond_fn.params, state, strict=True))
308
330
  cond_evaluator = self._fork(cond_env)
309
331
  cond_result = expr.cond_fn.body.accept(cond_evaluator)
310
-
311
- assert len(cond_result) == 1, (
312
- f"Condition function must return a single value, got {cond_result}"
313
- )
314
-
315
- if not cond_result[0]:
332
+ cond_value = self._check_while_predicate(cond_result)
333
+ if not cond_value:
316
334
  break
317
335
 
318
336
  # Call body function with same arguments
@@ -445,8 +463,8 @@ class IterativeEvaluator(EvalSemantic):
445
463
  cond_vals = self._iter_eval_graph(
446
464
  node.cond_fn.body, {**env, **cond_env}
447
465
  )
448
- assert len(cond_vals) == 1
449
- if not bool(cond_vals[0]):
466
+ cond_val = self._check_while_predicate(cond_vals)
467
+ if not bool(cond_val):
450
468
  break
451
469
  body_env = dict(zip(node.body_fn.params, state, strict=True))
452
470
  new_state = self._iter_eval_graph(
@@ -483,6 +483,20 @@ def uniform_cond(
483
483
  if pred_ty.dtype != BOOL:
484
484
  raise TypeError(f"uniform_cond predicate must be boolean, got {pred_ty.dtype}")
485
485
 
486
+ # Static pmask rule:
487
+ # If predicate has a static pmask (not None), it must equal the current trace
488
+ # context mask. Otherwise some parties would execute a branch without a
489
+ # defined predicate value (unsafe). To run on a subset either:
490
+ # 1. Trace the entire uniform_cond under a subset TraceContext (ctx.fork(mask=...))
491
+ # 2. Broadcast / lift predicate to full mask (e.g. pshfl_s)
492
+ # Pred pmask None => dynamic: defer to runtime uniformity (if verify_uniform=True).
493
+ pred_pmask = pred_ty.pmask
494
+ if pred_pmask is not None and pred_pmask != cur_tracer.mask:
495
+ raise ValueError(
496
+ "uniform_cond predicate static pmask mismatch: predicate pmask="
497
+ f"{pred_pmask} trace mask={cur_tracer.mask}. Trace under a subset "
498
+ "context (ctx.fork(mask=...)) or broadcast predicate (pshfl_s) to all parties."
499
+ )
486
500
  # Step 1: Trace both branches in separate contexts
487
501
  then_tracer = cur_tracer.fork()
488
502
  then_tfn = trace(then_tracer, then_fn, *args)
@@ -706,6 +720,22 @@ def while_loop(
706
720
  f"Condition function must return a boolean scalar, got dtype {cond_out_var.mptype.dtype}"
707
721
  )
708
722
 
723
+ # Static pmask rule:
724
+ # If the predicate's pmask is statically known it must match the trace context
725
+ # mask. Otherwise some parties in this context would lack a boolean to drive
726
+ # control flow (previously could lead to hang via None). To restrict to a subset:
727
+ # 1. Trace the entire while_loop under a subset context (ctx.fork(mask=submask)), or
728
+ # 2. Broadcast predicate to full mask (e.g. pshfl_s) before while_loop.
729
+ # Dynamic predicates (pmask=None) are allowed; runtime guard (evaluator) raises
730
+ # if any participating party observes None.
731
+ pred_pmask = cond_out_var.mptype.pmask
732
+ if pred_pmask is not None and pred_pmask != cur_tracer.mask:
733
+ raise ValueError(
734
+ "while_loop predicate static pmask mismatch: predicate pmask="
735
+ f"{pred_pmask} trace mask={cur_tracer.mask}. Trace under subset context "
736
+ "or broadcast predicate to all parties."
737
+ )
738
+
709
739
  # Validate body returns same number of leaves and same dtype/shape per leaf
710
740
  if len(body_tfn.out_vars) != len(cond_tfn.in_vars):
711
741
  raise ValueError(
@@ -33,7 +33,14 @@ from mplang.core.dtype import FLOAT32, INT32
33
33
  from mplang.core.mask import Mask
34
34
  from mplang.core.mpobject import MPObject
35
35
  from mplang.core.mptype import MPType, Rank
36
- from mplang.core.primitive import constant, prank, pshfl_s, uniform_cond, while_loop
36
+ from mplang.core.primitive import (
37
+ constant,
38
+ prank,
39
+ pshfl_s,
40
+ set_mask,
41
+ uniform_cond,
42
+ while_loop,
43
+ )
37
44
  from mplang.core.tracer import TraceContext, TraceVar, trace
38
45
  from mplang.runtime.simulation import Simulator, SimVar
39
46
 
@@ -1392,6 +1399,180 @@ class TestWhileLoop:
1392
1399
  assert results[0].values[0] == 6 # Party 0: 3 iterations
1393
1400
  assert results[0].values[1] == 4 # Party 1: 2 iterations
1394
1401
 
1402
+ def test_while_loop_subset_state_mask(self):
1403
+ """Loop state and control stay on subset of parties."""
1404
+
1405
+ cluster_spec = ClusterSpec.simple(world_size=3)
1406
+ full_mask = Mask(0b111)
1407
+ subset_mask = Mask(0b011)
1408
+ trace_ctx = TraceContext(cluster_spec=cluster_spec, mask=full_mask)
1409
+ simulator = Simulator.simple(world_size=3)
1410
+
1411
+ def subset_loop():
1412
+ init_state = set_mask(constant(np.int64(0)), subset_mask)
1413
+ threshold = set_mask(constant(np.int64(3)), subset_mask)
1414
+ step = set_mask(constant(np.int64(1)), subset_mask)
1415
+
1416
+ def cond_fn(state):
1417
+ subset_pred = simp.run(lambda val, limit: val < limit)(state, threshold)
1418
+ return pshfl_s(subset_pred, full_mask, [Rank(0), Rank(0), Rank(0)])
1419
+
1420
+ def body_fn(state):
1421
+ return simp.run(lambda val, inc: val + inc)(state, step)
1422
+
1423
+ return while_loop(cond_fn, body_fn, init_state)
1424
+
1425
+ with with_ctx(trace_ctx):
1426
+ traced_fn = trace(trace_ctx, subset_loop)
1427
+
1428
+ func_expr = traced_fn.make_expr()
1429
+ assert func_expr is not None
1430
+ expr = func_expr.body
1431
+ results = simulator.evaluate(expr, {})
1432
+
1433
+ assert len(results) == 1
1434
+ sim_var = results[0]
1435
+ assert isinstance(sim_var, SimVar)
1436
+ assert sim_var.mptype.pmask == subset_mask
1437
+
1438
+ values = sim_var.values
1439
+ assert len(values) == 3
1440
+ assert values[0] == 3
1441
+ assert values[1] == 3
1442
+ assert values[2] is None
1443
+
1444
+ def test_while_loop_subset_context_mask_success(self):
1445
+ """Trace under subset context mask; predicate pmask==context mask so no broadcast needed.
1446
+
1447
+ Ensures static pmask validation (design A) does NOT raise when the trace context
1448
+ itself is the subset. Predicate pmask equals the context mask.
1449
+ """
1450
+ # Use a 2-party cluster because only parties 0 and 1 participate.
1451
+ cluster_spec = ClusterSpec.simple(world_size=2)
1452
+ subset_mask = Mask(0b11) # parties 0 and 1
1453
+ trace_ctx = TraceContext(cluster_spec=cluster_spec, mask=subset_mask)
1454
+ simulator = Simulator.simple(world_size=2)
1455
+
1456
+ def subset_loop():
1457
+ init_state = set_mask(constant(np.int64(0)), subset_mask)
1458
+ threshold = set_mask(constant(np.int64(3)), subset_mask)
1459
+ step = set_mask(constant(np.int64(1)), subset_mask)
1460
+
1461
+ def cond_fn(state):
1462
+ # Returns bool with pmask=subset_mask (no broadcast)
1463
+ return simp.run(lambda val, limit: val < limit)(state, threshold)
1464
+
1465
+ def body_fn(state):
1466
+ return simp.run(lambda val, inc: val + inc)(state, step)
1467
+
1468
+ return while_loop(cond_fn, body_fn, init_state)
1469
+
1470
+ with with_ctx(trace_ctx):
1471
+ traced_fn = trace(trace_ctx, subset_loop)
1472
+
1473
+ func_expr = traced_fn.make_expr()
1474
+ assert func_expr is not None
1475
+ expr = func_expr.body
1476
+ results = simulator.evaluate(expr, {})
1477
+
1478
+ assert len(results) == 1
1479
+ sim_var = results[0]
1480
+ assert isinstance(sim_var, SimVar)
1481
+ assert sim_var.mptype.pmask == subset_mask
1482
+ values = sim_var.values
1483
+ assert len(values) == 2
1484
+ assert values[0] == 3
1485
+ assert values[1] == 3
1486
+
1487
+ def test_while_loop_predicate_static_pmask_mismatch_error(self):
1488
+ """Full context mask but predicate has smaller static pmask -> trace-time ValueError.
1489
+
1490
+ We purposely do NOT broadcast the subset predicate to full mask, expecting the
1491
+ new static pmask validation in while_loop to raise.
1492
+ """
1493
+ cluster_spec = ClusterSpec.simple(world_size=3)
1494
+ full_mask = Mask(0b111)
1495
+ subset_mask = Mask(0b011)
1496
+ trace_ctx = TraceContext(cluster_spec=cluster_spec, mask=full_mask)
1497
+
1498
+ def bad_loop():
1499
+ init_state = set_mask(constant(np.int64(0)), subset_mask)
1500
+ threshold = set_mask(constant(np.int64(2)), subset_mask)
1501
+ step = set_mask(constant(np.int64(1)), subset_mask)
1502
+
1503
+ def cond_fn(state):
1504
+ # Returns bool with pmask=subset_mask only; no broadcast.
1505
+ return simp.run(lambda val, limit: val < limit)(state, threshold)
1506
+
1507
+ def body_fn(state):
1508
+ return simp.run(lambda val, inc: val + inc)(state, step)
1509
+
1510
+ return while_loop(cond_fn, body_fn, init_state)
1511
+
1512
+ with with_ctx(trace_ctx):
1513
+ with pytest.raises(
1514
+ ValueError, match=r"while_loop predicate static pmask mismatch"
1515
+ ):
1516
+ trace(trace_ctx, bad_loop)
1517
+
1518
+ def test_while_loop_cond_body_with_aux_party(self):
1519
+ """Loop state on subset while cond/body still invoke a third party."""
1520
+
1521
+ cluster_spec = ClusterSpec.simple(world_size=3)
1522
+ full_mask = Mask(0b111)
1523
+ subset_mask = Mask(0b011)
1524
+ aux_mask = Mask(0b100)
1525
+ trace_ctx = TraceContext(cluster_spec=cluster_spec, mask=full_mask)
1526
+ simulator = Simulator.simple(world_size=3)
1527
+
1528
+ def cooperative_loop():
1529
+ subset_state = set_mask(constant(np.int64(0)), subset_mask)
1530
+ aux_state = set_mask(constant(np.int64(0)), aux_mask)
1531
+
1532
+ subset_limit = set_mask(constant(np.int64(6)), subset_mask)
1533
+ subset_step = set_mask(constant(np.int64(2)), subset_mask)
1534
+ aux_step = set_mask(constant(np.int64(1)), aux_mask)
1535
+
1536
+ def cond_fn(states):
1537
+ sub_val, aux_val = states
1538
+
1539
+ # Auxiliary party executes a helper kernel (result ignored by others)
1540
+ _ = simp.run(lambda val, inc: val + inc)(aux_val, aux_step)
1541
+ subset_pred = simp.run(lambda val, limit: val < limit)(
1542
+ sub_val, subset_limit
1543
+ )
1544
+ # Broadcast predicate so every party observes the same boolean
1545
+ return pshfl_s(subset_pred, full_mask, [Rank(0), Rank(0), Rank(0)])
1546
+
1547
+ def body_fn(states):
1548
+ sub_val, aux_val = states
1549
+
1550
+ next_sub = simp.run(lambda val, step: val + step)(sub_val, subset_step)
1551
+ next_aux = simp.run(lambda val, inc: val + inc)(aux_val, aux_step)
1552
+
1553
+ return (next_sub, next_aux)
1554
+
1555
+ return while_loop(cond_fn, body_fn, (subset_state, aux_state))
1556
+
1557
+ with with_ctx(trace_ctx):
1558
+ traced_fn = trace(trace_ctx, cooperative_loop)
1559
+
1560
+ func_expr = traced_fn.make_expr()
1561
+ assert func_expr is not None
1562
+ expr = func_expr.body
1563
+ results = simulator.evaluate(expr, {})
1564
+
1565
+ assert len(results) == 2
1566
+ subset_result, aux_result = results
1567
+
1568
+ assert isinstance(subset_result, SimVar)
1569
+ assert subset_result.mptype.pmask == subset_mask
1570
+ assert subset_result.values == [6, 6, None]
1571
+
1572
+ assert isinstance(aux_result, SimVar)
1573
+ assert aux_result.mptype.pmask == aux_mask
1574
+ assert aux_result.values == [None, None, 3]
1575
+
1395
1576
  def test_nested_while_with_conditional(self, simulator, trace_context):
1396
1577
  """Test: While_loop containing conditional operations.
1397
1578