mplang-nightly 0.1.dev286__tar.gz → 0.1.dev288__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/PKG-INFO +1 -1
- mplang_nightly-0.1.dev288/mplang/backends/simp_worker/collective_algorithms.py +228 -0
- mplang_nightly-0.1.dev288/mplang/backends/simp_worker/collectives.py +275 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/backends/simp_worker/ops.py +5 -2
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/backends/table_impl.py +1 -1
- mplang_nightly-0.1.dev288/tests/backends/simp_worker/test_collectives.py +141 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/backends/test_simp_integration.py +27 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/backends/test_table_impl.py +37 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/.gitignore +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/LICENSE +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/README.md +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/examples/.gitkeep +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/hatch_build.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/__init__.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/backends/__init__.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/backends/bfv_impl.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/backends/channel.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/backends/crypto_impl.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/backends/field_impl.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/backends/func_impl.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/backends/phe_impl.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/backends/simp_design.md +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/backends/simp_driver/__init__.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/backends/simp_driver/http.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/backends/simp_driver/mem.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/backends/simp_driver/ops.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/backends/simp_driver/state.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/backends/simp_driver/values.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/backends/simp_worker/__init__.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/backends/simp_worker/http.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/backends/simp_worker/mem.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/backends/simp_worker/state.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/backends/spu_impl.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/backends/spu_state.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/backends/store_impl.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/backends/tee_impl.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/backends/tensor_impl.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/cli.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/cli_guide.md +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/dialects/__init__.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/dialects/bfv.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/dialects/crypto.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/dialects/dtypes.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/dialects/field.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/dialects/func.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/dialects/phe.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/dialects/simp.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/dialects/spu.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/dialects/store.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/dialects/table.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/dialects/tee.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/dialects/tensor.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/edsl/README.md +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/edsl/__init__.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/edsl/context.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/edsl/graph.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/edsl/jit.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/edsl/object.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/edsl/primitive.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/edsl/printer.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/edsl/program.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/edsl/registry.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/edsl/serde.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/edsl/tracer.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/edsl/typing.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/kernels/Makefile +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/kernels/__init__.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/kernels/gf128.cpp +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/kernels/ldpc.cpp +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/kernels/okvs.cpp +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/kernels/okvs_opt.cpp +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/kernels/py_kernels.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/libs/collective.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/libs/device/__init__.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/libs/device/api.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/libs/device/cluster.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/libs/ml/__init__.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/libs/ml/sgb.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/libs/mpc/__init__.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/libs/mpc/_utils.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/libs/mpc/analytics/__init__.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/libs/mpc/analytics/aggregation.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/libs/mpc/analytics/groupby.md +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/libs/mpc/analytics/groupby.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/libs/mpc/analytics/permutation.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/libs/mpc/common/constants.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/libs/mpc/ot/__init__.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/libs/mpc/ot/base.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/libs/mpc/ot/extension.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/libs/mpc/ot/silent.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/libs/mpc/psi/__init__.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/libs/mpc/psi/cuckoo.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/libs/mpc/psi/okvs.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/libs/mpc/psi/okvs_gct.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/libs/mpc/psi/oprf.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/libs/mpc/psi/rr22.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/libs/mpc/psi/unbalanced.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/libs/mpc/vole/__init__.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/libs/mpc/vole/gilboa.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/libs/mpc/vole/ldpc.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/libs/mpc/vole/silver.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/py.typed +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/runtime/__init__.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/runtime/dialect_state.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/runtime/interpreter.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/runtime/object_store.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/runtime/value.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/tool/__init__.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/tool/program.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/utils/__init__.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/utils/func_utils.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/mplang/utils/logging.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/pyproject.toml +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/__init__.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/backends/__init__.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/backends/simp_driver/__init__.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/backends/simp_driver/test_http.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/backends/simp_worker/__init__.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/backends/simp_worker/test_http.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/backends/simp_worker/test_mem.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/backends/simp_worker/test_shuffle_exec_id_key.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/backends/test_bfv_impl.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/backends/test_channel.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/backends/test_crypto_impl.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/backends/test_okvs_binding.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/backends/test_simp_object_store.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/backends/test_spu_impl.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/backends/test_tee_impl.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/backends/test_tensor_impl.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/backends/test_verify_clean.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/conftest.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/dialects/__init__.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/dialects/test_bfv.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/dialects/test_crypto.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/dialects/test_dtypes.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/dialects/test_field.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/dialects/test_func.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/dialects/test_okvs.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/dialects/test_okvs_bench.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/dialects/test_phe.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/dialects/test_simp.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/dialects/test_simp_comm.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/dialects/test_spu.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/dialects/test_store.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/dialects/test_table.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/dialects/test_tee.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/dialects/test_tensor.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/edsl/__init__.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/edsl/test_compiled_program_artifact.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/edsl/test_context.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/edsl/test_graph.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/edsl/test_primitive.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/edsl/test_primitive_multi_output.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/edsl/test_printer.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/edsl/test_serde.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/edsl/test_tracer.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/edsl/test_typing.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/edsl/test_typing_graph_serde.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/libs/device/__init__.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/libs/device/conftest.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/libs/device/test_device_api_errors.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/libs/device/test_device_dialects.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/libs/device/test_device_layouts.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/libs/device/test_device_tee.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/libs/ml/__init__.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/libs/ml/test_sgb.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/libs/ml/test_sgb_bench.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/libs/mpc/__init__.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/libs/mpc/analytics/__init__.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/libs/mpc/analytics/test_aggregation.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/libs/mpc/analytics/test_groupby.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/libs/mpc/analytics/test_permutation.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/libs/mpc/ot/__init__.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/libs/mpc/ot/test_ot.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/libs/mpc/ot/test_ot_extension.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/libs/mpc/ot/test_silent_ot.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/libs/mpc/psi/__init__.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/libs/mpc/psi/test_okvs_gct.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/libs/mpc/psi/test_oprf.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/libs/mpc/psi/test_psi.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/libs/mpc/psi/test_psi_bench.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/libs/mpc/psi/test_rr22.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/libs/mpc/psi/verify_psi_okvs_logic.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/libs/mpc/test_field_gf128.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/libs/mpc/test_utils.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/libs/mpc/vole/__init__.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/libs/mpc/vole/test_gilboa_manual.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/libs/mpc/vole/test_ldpc.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/libs/mpc/vole/test_silver_vole.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/libs/mpc/vole/test_vole.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/libs/mpc/vole/test_vole_bench.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/libs/mpc/vole/verify_vole_logic.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/libs/test_collective.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/libs/test_simple_guide.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/runtime/test_interpreter_async.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/runtime/test_object_store.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/runtime/test_object_store_fs.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/test_fetch_semantics.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/test_pytree_io.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/utils/__init__.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/utils/tensor_patch.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/utils/test_func_utils.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tests/utils/test_logging.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tutorials/00_device_basics.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tutorials/01_function_decorator.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tutorials/02_simulation_and_driver.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tutorials/03_run_jax.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tutorials/04_ir_dump_and_analysis.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tutorials/05_run_sql.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tutorials/06_pipeline.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tutorials/07_stax_nn.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tutorials/08_logging.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tutorials/MIGRATION.md +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tutorials/README.md +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tutorials/__init__.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tutorials/data/alice.csv +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tutorials/data/bob.csv +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tutorials/data/prepare_vertical_iris.py +0 -0
- {mplang_nightly-0.1.dev286 → mplang_nightly-0.1.dev288}/tutorials/run.sh +0 -0
|
@@ -0,0 +1,228 @@
|
|
|
1
|
+
# Copyright 2026 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Collective communication algorithms (communicator-only).
|
|
16
|
+
|
|
17
|
+
This module contains *pure* collective algorithms implemented only in terms of
|
|
18
|
+
(a) a communicator and (b) an explicit participant set.
|
|
19
|
+
|
|
20
|
+
It intentionally does NOT depend on:
|
|
21
|
+
- Interpreter execution IDs / graph keys
|
|
22
|
+
- SimpWorker current_parties
|
|
23
|
+
- Operation objects
|
|
24
|
+
|
|
25
|
+
Callers are expected to provide a collision-free key prefix for each collective
|
|
26
|
+
instance.
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
from __future__ import annotations
|
|
30
|
+
|
|
31
|
+
import operator
|
|
32
|
+
from collections.abc import Callable, Sequence
|
|
33
|
+
from typing import Any, Protocol
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class Communicator(Protocol):
    """Structural interface for the point-to-point transport used here.

    Any object exposing ``rank``, ``world_size`` and the ``send``/``recv``
    methods below satisfies this protocol; no inheritance is required.
    """

    # Rank of the local party.
    rank: int
    # Number of ranks; valid ranks are 0 <= r < world_size.
    world_size: int

    def send(
        self,
        to: int,
        key: str,
        data: Any,
        *,
        is_raw_bytes: bool = False,
    ) -> None:
        """Send ``data`` to rank ``to`` tagged with message ``key``."""

    def recv(self, frm: int, key: str) -> Any:
        """Receive the message sent by rank ``frm`` tagged with ``key``."""
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def normalize_participants(
    comm: Communicator, participants: Sequence[int]
) -> tuple[int, ...]:
    """Canonicalize and validate a participant set.

    Each rank is coerced with ``int()``; duplicates are dropped and the result
    is sorted ascending, so every caller sees the same canonical ordering.

    Args:
        comm: Communicator supplying ``rank`` and ``world_size``.
        participants: Candidate participant ranks (any order, may repeat).

    Returns:
        Sorted tuple of unique participant ranks.

    Raises:
        ValueError: if the set is empty, contains a rank outside
            ``[0, comm.world_size)``, or does not include ``comm.rank``.
    """
    unique_ranks = set(map(int, participants))
    ordered = tuple(sorted(unique_ranks))
    if len(ordered) == 0:
        raise ValueError("participants must be non-empty")

    world = comm.world_size
    out_of_range = [r for r in ordered if not 0 <= r < world]
    if out_of_range:
        raise ValueError(
            f"participants out of range: {ordered}, world_size={comm.world_size}"
        )

    if comm.rank not in ordered:
        raise ValueError(f"rank {comm.rank} is not in participants {ordered}")
    return ordered
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def barrier(
    comm: Communicator, *, participants: Sequence[int], key_prefix: str
) -> None:
    """Synchronize all participants (root gather + root release).

    The lowest-ranked participant is the coordinator. It first receives an
    "arrive" token from every other participant, in ascending rank order,
    and then sends each of them a "release" token. Every other participant
    sends its arrive token to the coordinator and then receives the release
    token.

    Args:
        comm: Point-to-point communicator.
        participants: Participating ranks; canonicalized by
            ``normalize_participants``.
        key_prefix: Message-key prefix that must be unique to this barrier
            instance.
    """
    ranks = normalize_participants(comm, participants)
    coordinator, *others = ranks
    arrive_key = f"{key_prefix}_arrive"
    release_key = f"{key_prefix}_release"

    if comm.rank == coordinator:
        # Phase 1: collect a check-in from everyone else; payloads are ignored.
        for peer in others:
            comm.recv(peer, arrive_key)
        # Phase 2: release everyone.
        for peer in others:
            comm.send(peer, release_key, True)
    else:
        comm.send(coordinator, arrive_key, True)
        comm.recv(coordinator, release_key)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def broadcast(
    comm: Communicator,
    value: Any,
    *,
    root: int,
    participants: Sequence[int],
    key_prefix: str,
) -> Any:
    """Send ``value`` from ``root`` to every other participant.

    Args:
        comm: Point-to-point communicator.
        value: Payload to distribute; only meaningful on ``root``.
        root: Rank that owns ``value``; must be a participant.
        participants: Participating ranks; canonicalized by
            ``normalize_participants``.
        key_prefix: Message-key prefix unique to this broadcast instance.

    Returns:
        On ``root``, ``value`` itself; on other ranks, the value received
        from ``root``.

    Raises:
        ValueError: from participant normalization, or if ``root`` is not
            among the participants.
    """
    ranks = normalize_participants(comm, participants)
    if root not in ranks:
        raise ValueError(f"root {root} must be in participants {ranks}")

    bcast_key = f"{key_prefix}_bcast"
    if comm.rank != root:
        return comm.recv(root, bcast_key)

    # Root: fan the value out to every other participant in ascending rank order.
    for peer in (r for r in ranks if r != root):
        comm.send(peer, bcast_key, value)
    return value
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def allgather(
    comm: Communicator, value: Any, *, participants: Sequence[int], key_prefix: str
) -> list[Any]:
    """Gather one value per participant and share the full list with all.

    Implemented as gather-to-root followed by a root broadcast. The root is
    the lowest-ranked participant.

    Args:
        comm: Point-to-point communicator.
        value: This participant's contribution.
        participants: Participating ranks; canonicalized by
            ``normalize_participants``.
        key_prefix: Message-key prefix unique to this allgather instance.

    Returns:
        List of contributions ordered by ascending participant rank. The root
        sends the same list object it returns.

    Raises:
        TypeError: on a non-root rank, if the root's broadcast is not a list.
    """
    ranks = normalize_participants(comm, participants)
    root, *peers = ranks
    gather_key = f"{key_prefix}_gather"
    bcast_key = f"{key_prefix}_bcast"

    if comm.rank == root:
        # Receive in ascending rank order, then lay values out by rank.
        received = {peer: comm.recv(peer, gather_key) for peer in peers}
        received[root] = value
        gathered = [received[r] for r in ranks]
        for peer in peers:
            comm.send(peer, bcast_key, gathered)
        return gathered

    comm.send(root, gather_key, value)
    result = comm.recv(root, bcast_key)
    if not isinstance(result, list):
        raise TypeError(f"expected list from root broadcast, got {type(result)}")
    return result
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def allreduce_bool_and(
    comm: Communicator,
    value: bool,
    *,
    participants: Sequence[int],
    key_prefix: str,
) -> bool:
    """Return the logical AND of ``value`` across all participants.

    The reduction is done at the lowest-ranked participant and the result is
    sent back, so every participant returns the same boolean.
    """
    return _allreduce_bool(
        comm,
        value,
        combine=operator.and_,
        key_prefix=key_prefix,
        participants=participants,
    )
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def allreduce_bool_or(
    comm: Communicator,
    value: bool,
    *,
    participants: Sequence[int],
    key_prefix: str,
) -> bool:
    """Return the logical OR of ``value`` across all participants.

    The reduction is done at the lowest-ranked participant and the result is
    sent back, so every participant returns the same boolean.
    """
    return _allreduce_bool(
        comm,
        value,
        combine=operator.or_,
        key_prefix=key_prefix,
        participants=participants,
    )
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def allreduce_bool_xor(
    comm: Communicator,
    value: bool,
    *,
    participants: Sequence[int],
    key_prefix: str,
) -> bool:
    """Return the XOR (parity) of ``value`` across all participants.

    The result is True iff an odd number of participants contributed True.
    Every participant returns the same boolean.
    """
    return _allreduce_bool(
        comm,
        value,
        combine=operator.xor,
        key_prefix=key_prefix,
        participants=participants,
    )
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def _allreduce_bool(
    comm: Communicator,
    value: bool,
    *,
    participants: Sequence[int],
    key_prefix: str,
    combine: Callable[[bool, bool], bool],
) -> bool:
    """Reduce one boolean per participant with ``combine`` and share the result.

    The lowest-ranked participant is the root. It starts from its own value,
    folds in each peer's value in ascending rank order
    (``acc = combine(acc, peer_value)``), and then sends the final result to
    every peer. All values are coerced with ``bool()`` both before sending
    and after receiving.

    Args:
        comm: Point-to-point communicator.
        value: This participant's boolean contribution.
        participants: Participating ranks; canonicalized by
            ``normalize_participants``.
        key_prefix: Message-key prefix unique to this reduction instance.
        combine: Binary boolean operator used for the fold.

    Returns:
        The reduced boolean, identical on every participant.
    """
    ranks = normalize_participants(comm, participants)
    root, *peers = ranks
    gather_key = f"{key_prefix}_gather"
    bcast_key = f"{key_prefix}_bcast"

    if comm.rank != root:
        comm.send(root, gather_key, bool(value))
        return bool(comm.recv(root, bcast_key))

    result = bool(value)
    for peer in peers:
        incoming = bool(comm.recv(peer, gather_key))
        result = combine(result, incoming)

    for peer in peers:
        comm.send(peer, bcast_key, result)
    return result
|
|
@@ -0,0 +1,275 @@
|
|
|
1
|
+
# Copyright 2026 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Simp worker-side collectives (wrapper layer).
|
|
16
|
+
|
|
17
|
+
This module is the *context-aware wrapper* on top of
|
|
18
|
+
`mplang.backends.simp_worker.collective_algorithms`.
|
|
19
|
+
|
|
20
|
+
Responsibilities here:
|
|
21
|
+
- Resolve "participants" from (explicit arg / op.attrs["parties"] /
|
|
22
|
+
worker.current_parties / world).
|
|
23
|
+
- Build collision-free `key_prefix` using interpreter execution IDs.
|
|
24
|
+
|
|
25
|
+
The underlying algorithms only depend on the communicator interface.
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
from __future__ import annotations
|
|
29
|
+
|
|
30
|
+
from collections.abc import Sequence
|
|
31
|
+
from typing import Any, Protocol
|
|
32
|
+
|
|
33
|
+
from mplang.backends.simp_worker import collective_algorithms as algo
|
|
34
|
+
from mplang.edsl.graph import Operation
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class _ExecContext(Protocol):
    """Interpreter hooks used to build per-collective message-key prefixes."""

    def current_op_exec_id(self) -> int:
        """Return the current operation's execution id (used in key prefixes)."""

    def current_graph_exec_key(self) -> str:
        """Return the current graph execution key (used in key prefixes)."""
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class _Worker(Protocol):
    """Worker attributes read by the collective wrappers."""

    # Rank of this worker.
    rank: int
    # Total number of ranks; participants default to range(world_size).
    world_size: int
    # Transport handed to the collective algorithms.
    communicator: algo.Communicator
    # Active party set (pcall_static dynamic scope), or None when unset.
    current_parties: tuple[int, ...] | None
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def resolve_participants(
    worker: _Worker,
    *,
    op: Operation | None = None,
    participants: Sequence[int] | None = None,
) -> Sequence[int]:
    """Resolve participant ranks.

    Priority:
      1) explicit ``participants`` argument (returned as-is)
      2) ``op.attrs["parties"]`` if present (coerced to a tuple of ints)
      3) ``worker.current_parties`` if set (pcall_static dynamic scope)
      4) all ranks ``[0, world_size)``

    Note:
        Normalization/validation (sorting, emptiness, range checks, rank
        inclusion) is intentionally delegated to the lower-level algorithms in
        `collective_algorithms`.

    Raises:
        TypeError: if ``op.attrs["parties"]`` is not a sequence of ranks. A
            ``str`` is rejected explicitly even though it is a ``Sequence``.
    """

    if participants is not None:
        return participants

    if op is not None:
        parties = op.attrs.get("parties")
        if parties is not None:
            # A str is a Sequence, but iterating it yields characters, so e.g.
            # "12" would silently become ranks (1, 2). Reject it explicitly.
            if isinstance(parties, str) or not isinstance(parties, Sequence):
                raise TypeError(
                    "op.attrs['parties'] must be a sequence of rank integers"
                )
            return tuple(int(r) for r in parties)

    if worker.current_parties is not None:
        return worker.current_parties

    return tuple(range(worker.world_size))
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _collective_prefix(
    interpreter: _ExecContext, *, op: Operation | None, name: str
) -> str:
    """Build the message-key prefix for one collective invocation.

    Format: ``coll_<graph_exec_key>_<op name, or "_" if no op>_<op_exec_id>_<name>``.
    """
    exec_id = interpreter.current_op_exec_id()
    graph_key = interpreter.current_graph_exec_key()
    op_label = "_" if op is None else op.name
    segments = [f"{part}" for part in ("coll", graph_key, op_label, exec_id, name)]
    return "_".join(segments)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def barrier(
    interpreter: _ExecContext,
    worker: _Worker,
    *,
    op: Operation | None = None,
    participants: Sequence[int] | None = None,
    name: str = "barrier",
) -> None:
    """Run ``algo.barrier`` over the resolved participants.

    Participant ranks come from ``resolve_participants``; message keys are
    namespaced with ``_collective_prefix``.
    """
    ranks = resolve_participants(worker, op=op, participants=participants)
    key_prefix = _collective_prefix(interpreter, op=op, name=name)
    comm = worker.communicator
    algo.barrier(comm, participants=ranks, key_prefix=key_prefix)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def broadcast(
    interpreter: _ExecContext,
    worker: _Worker,
    value: Any,
    *,
    root: int,
    op: Operation | None = None,
    participants: Sequence[int] | None = None,
    name: str = "broadcast",
) -> Any:
    """Broadcast ``value`` from ``root`` via ``algo.broadcast``.

    ``root`` is coerced with ``int()``; the result of ``algo.broadcast`` is
    returned unchanged.
    """
    ranks = resolve_participants(worker, op=op, participants=participants)
    key_prefix = _collective_prefix(interpreter, op=op, name=name)
    comm = worker.communicator
    root_rank = int(root)
    return algo.broadcast(
        comm, value, root=root_rank, participants=ranks, key_prefix=key_prefix
    )
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def allgather_obj(
    interpreter: _ExecContext,
    worker: _Worker,
    value: Any,
    *,
    op: Operation | None = None,
    participants: Sequence[int] | None = None,
    name: str = "allgather_obj",
) -> list[Any]:
    """Gather ``value`` from every resolved participant via ``algo.allgather``.

    Entry ordering is whatever ``algo.allgather`` produces.
    """
    ranks = resolve_participants(worker, op=op, participants=participants)
    key_prefix = _collective_prefix(interpreter, op=op, name=name)
    comm = worker.communicator
    return algo.allgather(comm, value, participants=ranks, key_prefix=key_prefix)
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def allgather_bool(
    interpreter: _ExecContext,
    worker: _Worker,
    value: bool,
    *,
    op: Operation | None = None,
    participants: Sequence[int] | None = None,
    name: str = "allgather_bool",
) -> list[bool]:
    """Allgather a boolean, coercing both the input and every gathered entry."""
    flags = allgather_obj(
        interpreter,
        worker,
        bool(value),
        op=op,
        participants=participants,
        name=name,
    )
    return list(map(bool, flags))
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def allreduce_bool_and(
    interpreter: _ExecContext,
    worker: _Worker,
    value: bool,
    *,
    op: Operation | None = None,
    participants: Sequence[int] | None = None,
    name: str = "allreduce_bool_and",
) -> bool:
    """Logical-AND all-reduce of ``bool(value)`` via ``algo.allreduce_bool_and``."""
    ranks = resolve_participants(worker, op=op, participants=participants)
    key_prefix = _collective_prefix(interpreter, op=op, name=name)
    comm = worker.communicator
    flag = bool(value)
    return algo.allreduce_bool_and(
        comm, flag, participants=ranks, key_prefix=key_prefix
    )
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def allreduce_bool_or(
    interpreter: _ExecContext,
    worker: _Worker,
    value: bool,
    *,
    op: Operation | None = None,
    participants: Sequence[int] | None = None,
    name: str = "allreduce_bool_or",
) -> bool:
    """Logical-OR all-reduce of ``bool(value)`` via ``algo.allreduce_bool_or``."""
    ranks = resolve_participants(worker, op=op, participants=participants)
    key_prefix = _collective_prefix(interpreter, op=op, name=name)
    comm = worker.communicator
    flag = bool(value)
    return algo.allreduce_bool_or(
        comm, flag, participants=ranks, key_prefix=key_prefix
    )
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def allreduce_bool_xor(
    interpreter: _ExecContext,
    worker: _Worker,
    value: bool,
    *,
    op: Operation | None = None,
    participants: Sequence[int] | None = None,
    name: str = "allreduce_bool_xor",
) -> bool:
    """Logical-XOR all-reduce of ``bool(value)`` via ``algo.allreduce_bool_xor``."""
    ranks = resolve_participants(worker, op=op, participants=participants)
    key_prefix = _collective_prefix(interpreter, op=op, name=name)
    comm = worker.communicator
    flag = bool(value)
    return algo.allreduce_bool_xor(
        comm, flag, participants=ranks, key_prefix=key_prefix
    )
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def verify_uniform_predicate(
    interpreter: _ExecContext,
    worker: _Worker,
    pred: bool,
    *,
    op: Operation | None = None,
    participants: Sequence[int] | None = None,
    name: str = "uniform_predicate",
) -> bool:
    """Check that ``pred`` has the same value on every participant.

    An AND all-reduce followed by an OR all-reduce detects disagreement: the
    two results differ exactly when some participants hold True and others
    False. Because both results are identical on every participant, either
    all of them take the error path or none do. In the error path they all
    run an extra allgather, so the communication steps stay matched. The
    gathered values are used only to build the RuntimeError message.

    Returns:
        ``bool(pred)`` when the predicate is uniform.

    Raises:
        RuntimeError: if participants disagree on the predicate.
    """
    ranks = resolve_participants(worker, op=op, participants=participants)

    every = allreduce_bool_and(
        interpreter, worker, bool(pred), op=op, participants=ranks, name=f"{name}_and"
    )
    some = allreduce_bool_or(
        interpreter, worker, bool(pred), op=op, participants=ranks, name=f"{name}_or"
    )

    if every == some:
        return bool(pred)

    # Mismatch: every participant reaches this branch, so the gather is safe.
    flags = allgather_bool(
        interpreter,
        worker,
        bool(pred),
        op=op,
        participants=ranks,
        name=f"{name}_gather",
    )
    ordered_ranks = algo.normalize_participants(worker.communicator, ranks)
    by_rank = dict(zip(ordered_ranks, flags, strict=True))
    raise RuntimeError(
        "simp.uniform_cond predicate is not uniform across participants: "
        f"participants={ordered_ranks}, values={by_rank}"
    )
|
|
@@ -22,6 +22,7 @@ from __future__ import annotations
|
|
|
22
22
|
|
|
23
23
|
from typing import Any
|
|
24
24
|
|
|
25
|
+
from mplang.backends.simp_worker.collectives import verify_uniform_predicate
|
|
25
26
|
from mplang.dialects import simp
|
|
26
27
|
from mplang.edsl.graph import Operation
|
|
27
28
|
from mplang.runtime.interpreter import Interpreter
|
|
@@ -117,12 +118,14 @@ def _uniform_cond_worker_impl(
|
|
|
117
118
|
"""Worker implementation of simp.uniform_cond."""
|
|
118
119
|
from mplang.backends.tensor_impl import TensorValue
|
|
119
120
|
|
|
120
|
-
|
|
121
|
-
pass # TODO: Implement AllReduce verification
|
|
121
|
+
worker = _ensure_worker_context(interpreter, "uniform_cond_impl")
|
|
122
122
|
|
|
123
123
|
if isinstance(pred, TensorValue):
|
|
124
124
|
pred = bool(pred.unwrap())
|
|
125
125
|
|
|
126
|
+
if op.attrs.get("verify_uniform", True):
|
|
127
|
+
pred = verify_uniform_predicate(interpreter, worker, bool(pred), op=op)
|
|
128
|
+
|
|
126
129
|
if pred:
|
|
127
130
|
result = interpreter.evaluate_graph(op.regions[0], list(args))
|
|
128
131
|
else:
|
|
@@ -346,7 +346,7 @@ class ParquetReader(pa.RecordBatchReader):
|
|
|
346
346
|
pass
|
|
347
347
|
if batches:
|
|
348
348
|
return pa.Table.from_batches(batches)
|
|
349
|
-
return pa.Table.from_batches([])
|
|
349
|
+
return pa.Table.from_batches([], schema=self._schema)
|
|
350
350
|
|
|
351
351
|
def read_next_batch(self) -> pa.RecordBatch:
|
|
352
352
|
batch = next(self._iter)
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
# Copyright 2026 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import concurrent.futures
|
|
18
|
+
from dataclasses import dataclass
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
from mplang.backends.simp_worker.collectives import (
|
|
22
|
+
allgather_obj,
|
|
23
|
+
allreduce_bool_and,
|
|
24
|
+
allreduce_bool_or,
|
|
25
|
+
barrier,
|
|
26
|
+
broadcast,
|
|
27
|
+
)
|
|
28
|
+
from mplang.backends.simp_worker.mem import LocalMesh
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class _FakeInterpreter:
|
|
32
|
+
def current_op_exec_id(self) -> int:
|
|
33
|
+
return 1
|
|
34
|
+
|
|
35
|
+
def current_graph_exec_key(self) -> str:
|
|
36
|
+
return "test_graph"
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@dataclass
class _Worker:
    """Plain worker record handed to the collective wrappers under test."""

    # Field order matters: the tests construct this positionally as
    # _Worker(rank, world_size, communicator).
    rank: int
    world_size: int
    communicator: Any  # one of LocalMesh.comms, indexed by rank in the tests
    current_parties: tuple[int, ...] | None = None
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def test_broadcast_roundtrip() -> None:
    """Rank 0 broadcasts a dict; every rank must receive the same value."""
    mesh = LocalMesh(world_size=3)
    interp: Any = _FakeInterpreter()

    def run_rank(rank: int) -> Any:
        worker = _Worker(rank, 3, mesh.comms[rank])
        return broadcast(
            interp,
            worker,
            {"x": 123},
            root=0,
            participants=(0, 1, 2),
            name="test_broadcast",
        )

    try:
        with concurrent.futures.ThreadPoolExecutor(max_workers=3) as ex:
            results = list(ex.map(run_rank, range(3)))
    finally:
        # Always shut the mesh down, even if a rank raises, so a failure here
        # does not leave the mesh running into later tests.
        mesh.shutdown()

    assert results == [{"x": 123}, {"x": 123}, {"x": 123}]
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def test_allgather_obj_order() -> None:
    """Each rank contributes its own rank; every rank sees them in rank order."""
    mesh = LocalMesh(world_size=3)
    interp: Any = _FakeInterpreter()

    def run_rank(rank: int) -> Any:
        worker = _Worker(rank, 3, mesh.comms[rank])
        gathered = allgather_obj(
            interp,
            worker,
            rank,
            participants=(0, 1, 2),
            name="test_allgather",
        )
        return tuple(gathered)

    try:
        with concurrent.futures.ThreadPoolExecutor(max_workers=3) as ex:
            results = list(ex.map(run_rank, range(3)))
    finally:
        # Shut down even on failure so the mesh does not outlive this test.
        mesh.shutdown()

    assert results == [(0, 1, 2), (0, 1, 2), (0, 1, 2)]
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def test_allreduce_bool_and_or() -> None:
    """Mixed inputs (True, True, False) reduce to AND=False, OR=True everywhere."""
    mesh = LocalMesh(world_size=3)
    interp: Any = _FakeInterpreter()

    inputs = {0: True, 1: True, 2: False}

    def run_rank(rank: int) -> tuple[bool, bool]:
        worker = _Worker(rank, 3, mesh.comms[rank])
        v = inputs[rank]
        r_and = allreduce_bool_and(
            interp,
            worker,
            v,
            participants=(0, 1, 2),
            name="test_allreduce_and",
        )
        r_or = allreduce_bool_or(
            interp,
            worker,
            v,
            participants=(0, 1, 2),
            name="test_allreduce_or",
        )
        return r_and, r_or

    try:
        with concurrent.futures.ThreadPoolExecutor(max_workers=3) as ex:
            results = list(ex.map(run_rank, range(3)))
    finally:
        # Shut down even on failure so the mesh does not outlive this test.
        mesh.shutdown()

    assert results == [(False, True), (False, True), (False, True)]
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def test_barrier_completes() -> None:
    """All three ranks must pass the barrier and return their rank."""
    mesh = LocalMesh(world_size=3)
    interp: Any = _FakeInterpreter()

    def run_rank(rank: int) -> int:
        worker = _Worker(rank, 3, mesh.comms[rank])
        barrier(
            interp,
            worker,
            participants=(0, 1, 2),
            name="test_barrier",
        )
        return rank

    try:
        with concurrent.futures.ThreadPoolExecutor(max_workers=3) as ex:
            results = list(ex.map(run_rank, range(3)))
    finally:
        # Shut down even on failure so the mesh does not outlive this test.
        mesh.shutdown()

    assert sorted(results) == [0, 1, 2]
|