mplang-nightly 0.1.dev146__tar.gz → 0.1.dev148__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/PKG-INFO +1 -1
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/backend/base.py +21 -47
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/backend/context.py +67 -26
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/core/expr/evaluator.py +26 -8
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/core/primitive.py +30 -0
- mplang_nightly-0.1.dev148/tests/backend/test_kernel_binding.py +102 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/runtime/test_simulation.py +182 -1
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/.gitignore +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/LICENSE +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/README.md +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/examples/conf/3pc.yaml +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/examples/stax_nn/README.md +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/examples/stax_nn/models.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/examples/stax_nn/stax_nn.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/examples/xgboost/hist_jax.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/examples/xgboost/hist_jax_test.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/examples/xgboost/naive_np.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/examples/xgboost/readme.md +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/examples/xgboost/sgb.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/examples/xgboost/sgb_test.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/hatch_build.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/__init__.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/analysis/__init__.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/analysis/diagram.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/api.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/backend/__init__.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/backend/builtin.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/backend/crypto.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/backend/phe.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/backend/spu.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/backend/sql_duckdb.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/backend/stablehlo.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/backend/tee.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/core/__init__.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/core/cluster.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/core/comm.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/core/context_mgr.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/core/dtype.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/core/expr/__init__.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/core/expr/ast.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/core/expr/printer.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/core/expr/transformer.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/core/expr/utils.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/core/expr/visitor.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/core/expr/walk.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/core/interp.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/core/mask.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/core/mpir.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/core/mpobject.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/core/mptype.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/core/pfunc.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/core/table.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/core/tensor.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/core/tracer.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/device.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/frontend/__init__.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/frontend/base.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/frontend/builtin.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/frontend/crypto.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/frontend/ibis_cc.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/frontend/jax_cc.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/frontend/phe.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/frontend/spu.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/frontend/sql.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/frontend/tee.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/protos/v1alpha1/mpir_pb2.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/protos/v1alpha1/mpir_pb2.pyi +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/runtime/__init__.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/runtime/cli.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/runtime/client.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/runtime/communicator.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/runtime/data_providers.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/runtime/driver.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/runtime/exceptions.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/runtime/http_api.md +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/runtime/link_comm.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/runtime/resource.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/runtime/server.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/runtime/simulation.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/simp/__init__.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/simp/mpi.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/simp/random.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/simp/smpc.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/utils/__init__.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/utils/crypto.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/utils/func_utils.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/utils/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/mplang/utils/table_utils.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/pyproject.toml +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/__init__.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/analysis/test_diagram.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/backend/test_builtin.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/backend/test_debug_print.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/backend/test_phe.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/backend/test_spu.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/backend/test_sql_duckdb.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/backend/test_stablehlo.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/core/__init__.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/core/expr/__init__.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/core/expr/conftest.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/core/expr/test_ast.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/core/expr/test_printer.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/core/expr/test_utils.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/core/expr/test_walk.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/core/test_cluster.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/core/test_dtype.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/core/test_mask.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/core/test_mpir.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/core/test_mptype.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/core/test_primitive.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/core/test_table.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/core/test_tensor.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/core/test_tracer.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/device/__init__.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/device/test_device_basic.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/frontend/__init__.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/frontend/dummy.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/frontend/test_builtin_pack.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/frontend/test_crypto_tee.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/frontend/test_feop_base.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/frontend/test_ibis.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/frontend/test_ibis_cc.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/frontend/test_jax_cc.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/frontend/test_phe.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/frontend/test_spu.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/frontend/test_spu_defensive.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/frontend/test_sql.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/frontend/test_table_tensor_conversion.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/integration/README.md +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/integration/test_crypto_roundtrip.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/integration/test_http_e2e.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/integration/test_symbols_roundtrip.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/integration/test_tutorials.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/runtime/__init__.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/runtime/test_cli.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/runtime/test_communicator.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/runtime/test_driver.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/runtime/test_server.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/simp/test_mpi.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/simp/test_random.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/simp/test_simp.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/simp/test_smpc.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/simp/test_sugar.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/utils/__init__.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/utils/server_fixtures.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/utils/test_func_utils.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/utils/test_spu_utils.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tests/utils/test_table_utils.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tutorials/0_basic.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tutorials/10_analysis.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tutorials/1_condition.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tutorials/2_whileloop.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tutorials/3_device.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tutorials/4_simulation.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tutorials/5_ir_dump.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tutorials/6_advanced.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tutorials/7_stdio.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tutorials/8_phe.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tutorials/9_tee.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tutorials/__init__.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tutorials/pitfalls/late_binding.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tutorials/pitfalls/rand.py +0 -0
- {mplang_nightly-0.1.dev146 → mplang_nightly-0.1.dev148}/tutorials/run.sh +0 -0
@@ -12,20 +12,21 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
"""Backend kernel registry
|
15
|
+
"""Backend kernel registry: mapping kernel_id -> implementation.
|
16
16
|
|
17
|
-
This
|
17
|
+
This module provides a lightweight registry for backend kernel implementations.
|
18
|
+
It does not track or decide which kernel handles a given semantic operation;
|
19
|
+
that policy (op -> kernel_id) is managed externally by each ``RuntimeContext``.
|
18
20
|
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
21
|
+
Exposed primitives:
|
22
|
+
* ``@kernel_def(kernel_id)``: decorator to register a kernel implementation.
|
23
|
+
* ``get_kernel_spec(kernel_id)``: look up a previously registered kernel.
|
24
|
+
* ``cur_kctx()`` / ``KernelContext``: execution context available only
|
25
|
+
inside a kernel body (rank, world_size, per-backend state pockets, and a
|
26
|
+
runtime-wide cache shared by kernels of the same runtime instance).
|
25
27
|
|
26
|
-
|
27
|
-
|
28
|
-
centrally (lazy) the first time a runtime executes a kernel.
|
28
|
+
No global op binding table exists here; callers resolve an op to a kernel_id
|
29
|
+
before invoking the kernel function.
|
29
30
|
"""
|
30
31
|
|
31
32
|
from __future__ import annotations
|
@@ -38,12 +39,10 @@ from typing import Any
|
|
38
39
|
__all__ = [
|
39
40
|
"KernelContext",
|
40
41
|
"KernelSpec",
|
41
|
-
"bind_op",
|
42
42
|
"cur_kctx",
|
43
|
-
"
|
43
|
+
"get_kernel_spec",
|
44
|
+
"kernel_exists",
|
44
45
|
"list_kernels",
|
45
|
-
"list_ops",
|
46
|
-
"unbind_op",
|
47
46
|
]
|
48
47
|
|
49
48
|
|
@@ -116,9 +115,6 @@ class KernelSpec:
|
|
116
115
|
# All registered kernel implementations: kernel_id -> spec
|
117
116
|
_KERNELS: dict[str, KernelSpec] = {}
|
118
117
|
|
119
|
-
# Active op bindings: op_type -> kernel_id
|
120
|
-
_BINDINGS: dict[str, str] = {}
|
121
|
-
|
122
118
|
|
123
119
|
def kernel_def(kernel_id: str, /, **meta: Any) -> Callable[[KernelFn], KernelFn]:
|
124
120
|
"""Decorator to register a concrete kernel implementation.
|
@@ -136,34 +132,11 @@ def kernel_def(kernel_id: str, /, **meta: Any) -> Callable[[KernelFn], KernelFn]
|
|
136
132
|
return _decorator
|
137
133
|
|
138
134
|
|
139
|
-
def
|
140
|
-
"""
|
141
|
-
|
142
|
-
|
143
|
-
op_type: Semantic operation name.
|
144
|
-
kernel_id: Previously registered kernel identifier.
|
145
|
-
force: If False and op_type already bound, keep existing binding.
|
146
|
-
If True (default), overwrite.
|
147
|
-
"""
|
148
|
-
if kernel_id not in _KERNELS:
|
135
|
+
def get_kernel_spec(kernel_id: str) -> KernelSpec:
|
136
|
+
"""Return KernelSpec for a registered kernel_id (no op binding lookup)."""
|
137
|
+
spec = _KERNELS.get(kernel_id)
|
138
|
+
if spec is None:
|
149
139
|
raise KeyError(f"kernel_id {kernel_id} not registered")
|
150
|
-
if not force and op_type in _BINDINGS:
|
151
|
-
return
|
152
|
-
_BINDINGS[op_type] = kernel_id
|
153
|
-
|
154
|
-
|
155
|
-
def unbind_op(op_type: str) -> None:
|
156
|
-
_BINDINGS.pop(op_type, None)
|
157
|
-
|
158
|
-
|
159
|
-
def get_kernel_for_op(op_type: str) -> KernelSpec:
|
160
|
-
kid = _BINDINGS.get(op_type)
|
161
|
-
if kid is None:
|
162
|
-
# Tests expect NotImplementedError for unsupported operations
|
163
|
-
raise NotImplementedError(f"no backend kernel registered for op {op_type}")
|
164
|
-
spec = _KERNELS.get(kid)
|
165
|
-
if spec is None: # inconsistent state
|
166
|
-
raise RuntimeError(f"active kernel_id {kid} missing spec")
|
167
140
|
return spec
|
168
141
|
|
169
142
|
|
@@ -171,5 +144,6 @@ def list_kernels() -> list[str]:
|
|
171
144
|
return sorted(_KERNELS.keys())
|
172
145
|
|
173
146
|
|
174
|
-
def
|
175
|
-
|
147
|
+
def kernel_exists(kernel_id: str) -> bool:
|
148
|
+
"""Return True if a kernel_id has been registered."""
|
149
|
+
return kernel_id in _KERNELS
|
@@ -15,11 +15,10 @@
|
|
15
15
|
from __future__ import annotations
|
16
16
|
|
17
17
|
from collections.abc import Mapping
|
18
|
-
from dataclasses import dataclass, field
|
19
18
|
from typing import Any
|
20
19
|
|
21
20
|
from mplang.backend import base
|
22
|
-
from mplang.backend.base import KernelContext,
|
21
|
+
from mplang.backend.base import KernelContext, get_kernel_spec, kernel_exists
|
23
22
|
from mplang.core.dtype import UINT8, DType
|
24
23
|
from mplang.core.pfunc import PFunction
|
25
24
|
from mplang.core.table import TableLike, TableType
|
@@ -100,30 +99,57 @@ _DEFAULT_BINDINGS: dict[str, str] = {
|
|
100
99
|
# --- RuntimeContext ---
|
101
100
|
|
102
101
|
|
103
|
-
@dataclass
|
104
102
|
class RuntimeContext:
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
103
|
+
"""Per-runtime execution context with isolated op->kernel bindings.
|
104
|
+
|
105
|
+
Parameters
|
106
|
+
----------
|
107
|
+
rank : int
|
108
|
+
Local rank of this participant.
|
109
|
+
world_size : int
|
110
|
+
Total number of participants.
|
111
|
+
initial_bindings : Mapping[str, str] | None, optional
|
112
|
+
Optional partial overrides applied on top of the default binding table
|
113
|
+
during construction (override semantics, not replace). After
|
114
|
+
initialization, all (re)binding must go through ``bind_op`` /
|
115
|
+
``rebind_op``.
|
116
|
+
state / cache / stats : dict, optional
|
117
|
+
Mutable pockets reused across kernel invocations. If omitted, new
|
118
|
+
dictionaries are created.
|
119
|
+
"""
|
120
|
+
|
121
|
+
__slots__ = ("_ibindings", "cache", "rank", "state", "stats", "world_size")
|
122
|
+
|
123
|
+
def __init__(
|
124
|
+
self,
|
125
|
+
rank: int,
|
126
|
+
world_size: int,
|
127
|
+
initial_bindings: Mapping[str, str] | None = None,
|
128
|
+
*,
|
129
|
+
state: dict[str, dict[str, Any]] | None = None,
|
130
|
+
cache: dict[str, Any] | None = None,
|
131
|
+
stats: dict[str, Any] | None = None,
|
132
|
+
) -> None:
|
113
133
|
_ensure_impl_imported()
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
134
|
+
self.rank = rank
|
135
|
+
self.world_size = world_size
|
136
|
+
# Merge defaults with user overrides (override semantics)
|
137
|
+
self._ibindings: dict[str, str] = {
|
138
|
+
**_DEFAULT_BINDINGS,
|
139
|
+
**(initial_bindings or {}),
|
140
|
+
}
|
141
|
+
self.state = state if state is not None else {}
|
142
|
+
self.cache = cache if cache is not None else {}
|
143
|
+
self.stats = stats if stats is not None else {}
|
121
144
|
self.stats.setdefault("op_calls", {})
|
122
145
|
|
123
146
|
def run_kernel(self, pfunc: PFunction, arg_list: list[Any]) -> list[Any]:
|
124
147
|
fn_type = pfunc.fn_type
|
125
|
-
|
126
|
-
|
148
|
+
kid = self._ibindings.get(fn_type)
|
149
|
+
if kid is None:
|
150
|
+
raise NotImplementedError(f"no backend kernel registered for op {fn_type}")
|
151
|
+
spec = get_kernel_spec(kid)
|
152
|
+
fn = spec.fn # kernel implementation
|
127
153
|
if len(arg_list) != len(pfunc.ins_info):
|
128
154
|
raise ValueError(
|
129
155
|
f"kernel {fn_type} arg count mismatch: got {len(arg_list)}, expect {len(pfunc.ins_info)}"
|
@@ -186,17 +212,32 @@ class RuntimeContext:
|
|
186
212
|
|
187
213
|
# ---- explicit (re)binding API ----
|
188
214
|
def bind_op(self, op_type: str, kernel_id: str, *, force: bool = False) -> None:
|
189
|
-
"""Bind an operation to a kernel
|
215
|
+
"""Bind an operation to a kernel for THIS context only.
|
190
216
|
|
191
|
-
force=False (default)
|
192
|
-
silent overrides. Use ``rebind_op`` or ``force=True`` to intentionally
|
193
|
-
change a binding.
|
217
|
+
force=False (default) keeps existing binding (no silent override).
|
194
218
|
"""
|
195
|
-
|
219
|
+
if not kernel_exists(kernel_id):
|
220
|
+
raise KeyError(f"kernel_id {kernel_id} not registered")
|
221
|
+
if not force and op_type in self._ibindings:
|
222
|
+
return
|
223
|
+
self._ibindings[op_type] = kernel_id
|
196
224
|
|
197
225
|
def rebind_op(self, op_type: str, kernel_id: str) -> None:
|
198
226
|
"""Force rebind an operation to a different kernel (shorthand)."""
|
199
|
-
|
227
|
+
self.bind_op(op_type, kernel_id, force=True)
|
228
|
+
|
229
|
+
# Introspection helpers
|
230
|
+
def list_bound_ops(self) -> list[str]: # pragma: no cover - convenience
|
231
|
+
return sorted(self._ibindings.keys())
|
232
|
+
|
233
|
+
def get_binding(self, op_type: str) -> str | None: # pragma: no cover
|
234
|
+
return self._ibindings.get(op_type)
|
235
|
+
|
236
|
+
def __repr__(self) -> str: # pragma: no cover - debug aid
|
237
|
+
return (
|
238
|
+
f"RuntimeContext(rank={self.rank}, world_size={self.world_size}, "
|
239
|
+
f"bound_ops={len(self._ibindings)})"
|
240
|
+
)
|
200
241
|
|
201
242
|
|
202
243
|
def _validate_table_arg(
|
@@ -196,6 +196,28 @@ class EvalSemantic:
|
|
196
196
|
"uniform_cond: predicate is not uniform across parties"
|
197
197
|
)
|
198
198
|
|
199
|
+
# ------------------------------ While helpers ------------------------------
|
200
|
+
def _check_while_predicate(self, cond_result: list[Any]) -> Any:
|
201
|
+
"""Validate while_loop predicate evaluation result.
|
202
|
+
|
203
|
+
Ensures the condition function returns exactly one value and that value
|
204
|
+
is non-None. Returns the boolean predicate value for convenience.
|
205
|
+
|
206
|
+
Raises:
|
207
|
+
AssertionError: If condition function returns != 1 value.
|
208
|
+
RuntimeError: If the single predicate value is None.
|
209
|
+
"""
|
210
|
+
assert len(cond_result) == 1, (
|
211
|
+
f"Condition function must return a single value, got {cond_result}"
|
212
|
+
)
|
213
|
+
cond_value = cond_result[0]
|
214
|
+
if cond_value is None:
|
215
|
+
raise RuntimeError(
|
216
|
+
"while_loop condition produced None on rank "
|
217
|
+
f"{self.rank}; ensure the predicate yields a boolean for every party."
|
218
|
+
)
|
219
|
+
return cond_value
|
220
|
+
|
199
221
|
|
200
222
|
class RecursiveEvaluator(EvalSemantic, ExprVisitor):
|
201
223
|
"""Recursive visitor-based evaluator."""
|
@@ -307,12 +329,8 @@ class RecursiveEvaluator(EvalSemantic, ExprVisitor):
|
|
307
329
|
cond_env = dict(zip(expr.cond_fn.params, state, strict=True))
|
308
330
|
cond_evaluator = self._fork(cond_env)
|
309
331
|
cond_result = expr.cond_fn.body.accept(cond_evaluator)
|
310
|
-
|
311
|
-
|
312
|
-
f"Condition function must return a single value, got {cond_result}"
|
313
|
-
)
|
314
|
-
|
315
|
-
if not cond_result[0]:
|
332
|
+
cond_value = self._check_while_predicate(cond_result)
|
333
|
+
if not cond_value:
|
316
334
|
break
|
317
335
|
|
318
336
|
# Call body function with same arguments
|
@@ -445,8 +463,8 @@ class IterativeEvaluator(EvalSemantic):
|
|
445
463
|
cond_vals = self._iter_eval_graph(
|
446
464
|
node.cond_fn.body, {**env, **cond_env}
|
447
465
|
)
|
448
|
-
|
449
|
-
if not bool(
|
466
|
+
cond_val = self._check_while_predicate(cond_vals)
|
467
|
+
if not bool(cond_val):
|
450
468
|
break
|
451
469
|
body_env = dict(zip(node.body_fn.params, state, strict=True))
|
452
470
|
new_state = self._iter_eval_graph(
|
@@ -483,6 +483,20 @@ def uniform_cond(
|
|
483
483
|
if pred_ty.dtype != BOOL:
|
484
484
|
raise TypeError(f"uniform_cond predicate must be boolean, got {pred_ty.dtype}")
|
485
485
|
|
486
|
+
# Static pmask rule:
|
487
|
+
# If predicate has a static pmask (not None), it must equal the current trace
|
488
|
+
# context mask. Otherwise some parties would execute a branch without a
|
489
|
+
# defined predicate value (unsafe). To run on a subset either:
|
490
|
+
# 1. Trace the entire uniform_cond under a subset TraceContext (ctx.fork(mask=...))
|
491
|
+
# 2. Broadcast / lift predicate to full mask (e.g. pshfl_s)
|
492
|
+
# Pred pmask None => dynamic: defer to runtime uniformity (if verify_uniform=True).
|
493
|
+
pred_pmask = pred_ty.pmask
|
494
|
+
if pred_pmask is not None and pred_pmask != cur_tracer.mask:
|
495
|
+
raise ValueError(
|
496
|
+
"uniform_cond predicate static pmask mismatch: predicate pmask="
|
497
|
+
f"{pred_pmask} trace mask={cur_tracer.mask}. Trace under a subset "
|
498
|
+
"context (ctx.fork(mask=...)) or broadcast predicate (pshfl_s) to all parties."
|
499
|
+
)
|
486
500
|
# Step 1: Trace both branches in separate contexts
|
487
501
|
then_tracer = cur_tracer.fork()
|
488
502
|
then_tfn = trace(then_tracer, then_fn, *args)
|
@@ -706,6 +720,22 @@ def while_loop(
|
|
706
720
|
f"Condition function must return a boolean scalar, got dtype {cond_out_var.mptype.dtype}"
|
707
721
|
)
|
708
722
|
|
723
|
+
# Static pmask rule:
|
724
|
+
# If the predicate's pmask is statically known it must match the trace context
|
725
|
+
# mask. Otherwise some parties in this context would lack a boolean to drive
|
726
|
+
# control flow (previously could lead to hang via None). To restrict to a subset:
|
727
|
+
# 1. Trace the entire while_loop under a subset context (ctx.fork(mask=submask)), or
|
728
|
+
# 2. Broadcast predicate to full mask (e.g. pshfl_s) before while_loop.
|
729
|
+
# Dynamic predicates (pmask=None) are allowed; runtime guard (evaluator) raises
|
730
|
+
# if any participating party observes None.
|
731
|
+
pred_pmask = cond_out_var.mptype.pmask
|
732
|
+
if pred_pmask is not None and pred_pmask != cur_tracer.mask:
|
733
|
+
raise ValueError(
|
734
|
+
"while_loop predicate static pmask mismatch: predicate pmask="
|
735
|
+
f"{pred_pmask} trace mask={cur_tracer.mask}. Trace under subset context "
|
736
|
+
"or broadcast predicate to all parties."
|
737
|
+
)
|
738
|
+
|
709
739
|
# Validate body returns same number of leaves and same dtype/shape per leaf
|
710
740
|
if len(body_tfn.out_vars) != len(cond_tfn.in_vars):
|
711
741
|
raise ValueError(
|
@@ -0,0 +1,102 @@
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Tests for per-RuntimeContext binding isolation
|
16
|
+
from __future__ import annotations
|
17
|
+
|
18
|
+
import pytest
|
19
|
+
|
20
|
+
from mplang.backend import base
|
21
|
+
from mplang.backend.context import RuntimeContext
|
22
|
+
from mplang.core.dtype import INT64 # switched from INT32 to INT64 to match Python int
|
23
|
+
from mplang.core.pfunc import PFunction
|
24
|
+
from mplang.core.tensor import TensorType
|
25
|
+
|
26
|
+
# We'll register two fake kernels for an op to test rebinding.
|
27
|
+
# If they already exist due to other tests, we guard with try/except.
|
28
|
+
|
29
|
+
|
30
|
+
@base.kernel_def("test.echo.v1")
|
31
|
+
def _echo_v1(
|
32
|
+
pfunc: PFunction, x: int
|
33
|
+
) -> tuple[int,]: # pragma: no cover - executed in test
|
34
|
+
return (x + 1,)
|
35
|
+
|
36
|
+
|
37
|
+
@base.kernel_def("test.echo.v2")
|
38
|
+
def _echo_v2(
|
39
|
+
pfunc: PFunction, x: int
|
40
|
+
) -> tuple[int,]: # pragma: no cover - executed in test
|
41
|
+
return (x + 2,)
|
42
|
+
|
43
|
+
|
44
|
+
def make_pfunc(op_type: str) -> PFunction:
|
45
|
+
# Minimal PFunction stub compatible with backend.run_kernel expectations.
|
46
|
+
# shape info matters only for validation; use scalar INT64 (Python int maps to int64).
|
47
|
+
return PFunction(
|
48
|
+
fn_type=op_type,
|
49
|
+
fn_text="",
|
50
|
+
ins_info=[TensorType(shape=(), dtype=INT64)],
|
51
|
+
outs_info=[TensorType(shape=(), dtype=INT64)],
|
52
|
+
)
|
53
|
+
|
54
|
+
|
55
|
+
def test_isolated_rebind():
|
56
|
+
# ctx1 binds op -> v1, ctx2 binds op -> v2; they should not interfere.
|
57
|
+
op = "test.echo"
|
58
|
+
ctx1 = RuntimeContext(rank=0, world_size=1, initial_bindings={op: "test.echo.v1"})
|
59
|
+
ctx2 = RuntimeContext(rank=0, world_size=1, initial_bindings={op: "test.echo.v2"})
|
60
|
+
|
61
|
+
pfunc = make_pfunc(op)
|
62
|
+
out1 = ctx1.run_kernel(pfunc, [10])[0]
|
63
|
+
out2 = ctx2.run_kernel(pfunc, [10])[0]
|
64
|
+
|
65
|
+
assert out1 == 11
|
66
|
+
assert out2 == 12
|
67
|
+
|
68
|
+
|
69
|
+
def test_rebind_only_affects_context():
|
70
|
+
op = "test.echo"
|
71
|
+
ctx = RuntimeContext(rank=0, world_size=1, initial_bindings={op: "test.echo.v1"})
|
72
|
+
pfunc = make_pfunc(op)
|
73
|
+
assert ctx.run_kernel(pfunc, [5])[0] == 6
|
74
|
+
ctx.rebind_op(op, "test.echo.v2")
|
75
|
+
assert ctx.run_kernel(pfunc, [5])[0] == 7
|
76
|
+
|
77
|
+
|
78
|
+
def test_force_flag():
    """``bind_op`` keeps the existing binding unless ``force=True`` is passed."""
    op = "test.echo"
    ctx = RuntimeContext(rank=0, world_size=1, initial_bindings={op: "test.echo.v1"})

    # A non-forced bind over an existing binding is expected to keep v1.
    ctx.bind_op(op, "test.echo.v2", force=False)
    pfunc = make_pfunc(op)
    unforced_result = ctx.run_kernel(pfunc, [1])[0]
    assert unforced_result == 2  # still v1 (+1)

    # A forced bind replaces the binding with v2.
    ctx.bind_op(op, "test.echo.v2", force=True)
    forced_result = ctx.run_kernel(pfunc, [1])[0]
    assert forced_result == 3
|
88
|
+
|
89
|
+
|
90
|
+
def test_unknown_kernel_id():
    """Binding an op to a kernel id not registered here raises ``KeyError``."""
    runtime = RuntimeContext(rank=0, world_size=1)
    with pytest.raises(KeyError):
        runtime.bind_op("some.op", "non.existent.kernel")
|
94
|
+
|
95
|
+
|
96
|
+
def test_missing_binding():
    """Running an op that has no kernel binding raises ``NotImplementedError``."""
    # The op name is chosen to be unlikely to appear in any default bindings.
    unbound_op = "unit.test.unbound"
    runtime = RuntimeContext(rank=0, world_size=1)
    pfunc = make_pfunc(unbound_op)
    with pytest.raises(NotImplementedError):
        runtime.run_kernel(pfunc, [0])
|
@@ -33,7 +33,14 @@ from mplang.core.dtype import FLOAT32, INT32
|
|
33
33
|
from mplang.core.mask import Mask
|
34
34
|
from mplang.core.mpobject import MPObject
|
35
35
|
from mplang.core.mptype import MPType, Rank
|
36
|
-
from mplang.core.primitive import
|
36
|
+
from mplang.core.primitive import (
|
37
|
+
constant,
|
38
|
+
prank,
|
39
|
+
pshfl_s,
|
40
|
+
set_mask,
|
41
|
+
uniform_cond,
|
42
|
+
while_loop,
|
43
|
+
)
|
37
44
|
from mplang.core.tracer import TraceContext, TraceVar, trace
|
38
45
|
from mplang.runtime.simulation import Simulator, SimVar
|
39
46
|
|
@@ -1392,6 +1399,180 @@ class TestWhileLoop:
|
|
1392
1399
|
assert results[0].values[0] == 6 # Party 0: 3 iterations
|
1393
1400
|
assert results[0].values[1] == 4 # Party 1: 2 iterations
|
1394
1401
|
|
1402
|
+
def test_while_loop_subset_state_mask(self):
    """While-loop whose state lives on parties 0 and 1 of a 3-party cluster.

    The predicate is computed on the subset and then shuffled from rank 0 to
    the full mask before being returned from ``cond_fn``.
    """
    spec = ClusterSpec.simple(world_size=3)
    full_mask = Mask(0b111)
    subset_mask = Mask(0b011)
    trace_ctx = TraceContext(cluster_spec=spec, mask=full_mask)
    simulator = Simulator.simple(world_size=3)

    def subset_loop():
        def subset_const(value):
            # int64 constant restricted to the subset mask.
            return set_mask(constant(np.int64(value)), subset_mask)

        init_state = subset_const(0)
        threshold = subset_const(3)
        step = subset_const(1)

        def cond_fn(state):
            below_threshold = simp.run(lambda val, limit: val < limit)(
                state, threshold
            )
            # Every party in the full mask receives rank 0's predicate.
            sources = [Rank(0) for _ in range(3)]
            return pshfl_s(below_threshold, full_mask, sources)

        def body_fn(state):
            return simp.run(lambda val, inc: val + inc)(state, step)

        return while_loop(cond_fn, body_fn, init_state)

    with with_ctx(trace_ctx):
        traced_fn = trace(trace_ctx, subset_loop)

    func_expr = traced_fn.make_expr()
    assert func_expr is not None
    results = simulator.evaluate(func_expr.body, {})

    assert len(results) == 1
    sim_var = results[0]
    assert isinstance(sim_var, SimVar)
    assert sim_var.mptype.pmask == subset_mask

    per_party = sim_var.values
    assert len(per_party) == 3
    # Parties 0 and 1 hold the final counter; party 2 holds nothing.
    assert per_party[0] == 3
    assert per_party[1] == 3
    assert per_party[2] is None
|
1443
|
+
|
1444
|
+
def test_while_loop_subset_context_mask_success(self):
    """Trace with the context mask equal to the predicate's pmask.

    The trace context mask and the predicate mask are both ``0b11``, so the
    predicate is returned without any broadcast. Per the original author, this
    ensures the static pmask validation (design A) does NOT raise in that case.
    """
    # A 2-party cluster is used because only parties 0 and 1 participate.
    spec = ClusterSpec.simple(world_size=2)
    subset_mask = Mask(0b11)  # parties 0 and 1
    trace_ctx = TraceContext(cluster_spec=spec, mask=subset_mask)
    simulator = Simulator.simple(world_size=2)

    def subset_loop():
        def subset_const(value):
            # int64 constant restricted to the subset mask.
            return set_mask(constant(np.int64(value)), subset_mask)

        init_state = subset_const(0)
        threshold = subset_const(3)
        step = subset_const(1)

        def cond_fn(state):
            # Predicate is returned as computed, with no broadcast.
            return simp.run(lambda val, limit: val < limit)(state, threshold)

        def body_fn(state):
            return simp.run(lambda val, inc: val + inc)(state, step)

        return while_loop(cond_fn, body_fn, init_state)

    with with_ctx(trace_ctx):
        traced_fn = trace(trace_ctx, subset_loop)

    func_expr = traced_fn.make_expr()
    assert func_expr is not None
    results = simulator.evaluate(func_expr.body, {})

    assert len(results) == 1
    sim_var = results[0]
    assert isinstance(sim_var, SimVar)
    assert sim_var.mptype.pmask == subset_mask

    per_party = sim_var.values
    assert len(per_party) == 2
    assert per_party[0] == 3
    assert per_party[1] == 3
|
1486
|
+
|
1487
|
+
def test_while_loop_predicate_static_pmask_mismatch_error(self):
    """Full context mask with a subset-only predicate fails at trace time.

    The subset predicate is deliberately NOT broadcast to the full mask; the
    test expects tracing to raise ``ValueError`` with a message matching
    "while_loop predicate static pmask mismatch".
    """
    spec = ClusterSpec.simple(world_size=3)
    full_mask = Mask(0b111)
    subset_mask = Mask(0b011)
    trace_ctx = TraceContext(cluster_spec=spec, mask=full_mask)

    def bad_loop():
        def subset_const(value):
            # int64 constant restricted to the subset mask.
            return set_mask(constant(np.int64(value)), subset_mask)

        init_state = subset_const(0)
        threshold = subset_const(2)
        step = subset_const(1)

        def cond_fn(state):
            # Predicate is built from subset-masked values only; no broadcast.
            return simp.run(lambda val, limit: val < limit)(state, threshold)

        def body_fn(state):
            return simp.run(lambda val, inc: val + inc)(state, step)

        return while_loop(cond_fn, body_fn, init_state)

    expected_error = pytest.raises(
        ValueError, match=r"while_loop predicate static pmask mismatch"
    )
    with with_ctx(trace_ctx), expected_error:
        trace(trace_ctx, bad_loop)
|
1517
|
+
|
1518
|
+
def test_while_loop_cond_body_with_aux_party(self):
    """Loop over a (subset, aux) state pair where party 2 also does work.

    Parties 0 and 1 carry the counter that drives the predicate; party 2
    carries its own counter that is advanced in the body and also touched in
    the condition (that result is discarded).
    """
    spec = ClusterSpec.simple(world_size=3)
    full_mask = Mask(0b111)
    subset_mask = Mask(0b011)
    aux_mask = Mask(0b100)
    trace_ctx = TraceContext(cluster_spec=spec, mask=full_mask)
    simulator = Simulator.simple(world_size=3)

    def cooperative_loop():
        def masked_const(value, mask):
            # int64 constant restricted to the given party mask.
            return set_mask(constant(np.int64(value)), mask)

        subset_state = masked_const(0, subset_mask)
        aux_state = masked_const(0, aux_mask)

        subset_limit = masked_const(6, subset_mask)
        subset_step = masked_const(2, subset_mask)
        aux_step = masked_const(1, aux_mask)

        def cond_fn(states):
            sub_val, aux_val = states

            # Auxiliary party executes a helper kernel; its result is unused.
            simp.run(lambda val, inc: val + inc)(aux_val, aux_step)
            keep_going = simp.run(lambda val, limit: val < limit)(
                sub_val, subset_limit
            )
            # Broadcast rank 0's predicate so every party sees the same boolean.
            sources = [Rank(0) for _ in range(3)]
            return pshfl_s(keep_going, full_mask, sources)

        def body_fn(states):
            sub_val, aux_val = states
            # Tuple elements evaluate left to right: subset first, then aux.
            return (
                simp.run(lambda val, step: val + step)(sub_val, subset_step),
                simp.run(lambda val, inc: val + inc)(aux_val, aux_step),
            )

        return while_loop(cond_fn, body_fn, (subset_state, aux_state))

    with with_ctx(trace_ctx):
        traced_fn = trace(trace_ctx, cooperative_loop)

    func_expr = traced_fn.make_expr()
    assert func_expr is not None
    results = simulator.evaluate(func_expr.body, {})

    assert len(results) == 2
    subset_result, aux_result = results

    assert isinstance(subset_result, SimVar)
    assert subset_result.mptype.pmask == subset_mask
    assert subset_result.values == [6, 6, None]

    assert isinstance(aux_result, SimVar)
    assert aux_result.mptype.pmask == aux_mask
    assert aux_result.values == [None, None, 3]
|
1575
|
+
|
1395
1576
|
def test_nested_while_with_conditional(self, simulator, trace_context):
|
1396
1577
|
"""Test: While_loop containing conditional operations.
|
1397
1578
|
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|