mplang-nightly 0.1.dev203__tar.gz → 0.1.dev266__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mplang_nightly-0.1.dev203 → mplang_nightly-0.1.dev266}/.gitignore +1 -0
- {mplang_nightly-0.1.dev203 → mplang_nightly-0.1.dev266}/PKG-INFO +11 -5
- {mplang_nightly-0.1.dev203 → mplang_nightly-0.1.dev266}/README.md +1 -1
- {mplang_nightly-0.1.dev203/examples → mplang_nightly-0.1.dev266/examples/v1}/stax_nn/stax_nn.py +2 -2
- mplang_nightly-0.1.dev266/examples/v1/xgboost/bench_fhe_hist.py +615 -0
- {mplang_nightly-0.1.dev203/examples → mplang_nightly-0.1.dev266/examples/v1}/xgboost/sgb.py +304 -218
- {mplang_nightly-0.1.dev203/examples → mplang_nightly-0.1.dev266/examples/v1}/xgboost/sgb_test.py +279 -70
- mplang_nightly-0.1.dev266/mplang/__init__.py +46 -0
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/__init__.py +11 -11
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/_device.py +63 -13
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/analysis/__init__.py +1 -1
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/analysis/diagram.py +4 -4
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/__init__.py +20 -14
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/comm.py +1 -1
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/context_mgr.py +1 -1
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/expr/__init__.py +7 -7
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/expr/ast.py +10 -10
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/expr/evaluator.py +8 -8
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/expr/printer.py +6 -6
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/expr/transformer.py +2 -2
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/expr/utils.py +2 -2
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/expr/visitor.py +1 -1
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/expr/walk.py +1 -1
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/interp.py +6 -6
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/mpir.py +13 -11
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/mpobject.py +6 -6
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/mptype.py +7 -7
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/pfunc.py +2 -2
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/primitive.py +10 -10
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/table.py +1 -1
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/tensor.py +1 -1
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/tracer.py +9 -9
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/host.py +2 -2
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/kernels/__init__.py +1 -1
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/kernels/base.py +1 -1
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/kernels/basic.py +13 -13
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/kernels/context.py +14 -14
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/kernels/crypto.py +4 -4
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/kernels/fhe.py +9 -7
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/kernels/mock_tee.py +3 -3
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/kernels/phe.py +18 -14
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/kernels/spu.py +5 -5
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/kernels/sql_duckdb.py +5 -3
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/kernels/stablehlo.py +18 -17
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/kernels/value.py +1 -1
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/ops/__init__.py +3 -2
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/ops/base.py +1 -1
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/ops/basic.py +3 -3
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/ops/crypto.py +4 -4
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/ops/fhe.py +2 -2
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/ops/jax_cc.py +26 -59
- mplang_nightly-0.1.dev266/mplang/v1/ops/nnx_cc.py +168 -0
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/ops/phe.py +16 -3
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/ops/spu.py +3 -3
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/ops/sql_cc.py +55 -48
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/ops/tee.py +2 -2
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/runtime/__init__.py +2 -2
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/runtime/cli.py +3 -3
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/runtime/client.py +1 -1
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/runtime/communicator.py +2 -2
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/runtime/data_providers.py +77 -15
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/runtime/driver.py +4 -4
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/runtime/server.py +12 -8
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/runtime/session.py +13 -13
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/runtime/simulation.py +6 -6
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/simp/api.py +72 -5
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/simp/mpi.py +1 -1
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/simp/party.py +5 -5
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/simp/random.py +2 -2
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/simp/smpc.py +7 -7
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/utils/table_utils.py +1 -1
- mplang_nightly-0.1.dev266/mplang/v2/__init__.py +424 -0
- mplang_nightly-0.1.dev266/mplang/v2/backends/__init__.py +57 -0
- mplang_nightly-0.1.dev266/mplang/v2/backends/bfv_impl.py +705 -0
- mplang_nightly-0.1.dev266/mplang/v2/backends/crypto_impl.py +723 -0
- mplang_nightly-0.1.dev266/mplang/v2/backends/field_impl.py +454 -0
- mplang_nightly-0.1.dev266/mplang/v2/backends/func_impl.py +107 -0
- mplang_nightly-0.1.dev266/mplang/v2/backends/phe_impl.py +148 -0
- mplang_nightly-0.1.dev266/mplang/v2/backends/simp_design.md +136 -0
- mplang_nightly-0.1.dev266/mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang_nightly-0.1.dev266/mplang/v2/backends/simp_driver/http.py +168 -0
- mplang_nightly-0.1.dev266/mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang_nightly-0.1.dev266/mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang_nightly-0.1.dev266/mplang/v2/backends/simp_driver/state.py +60 -0
- mplang_nightly-0.1.dev266/mplang/v2/backends/simp_driver/values.py +52 -0
- mplang_nightly-0.1.dev266/mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang_nightly-0.1.dev266/mplang/v2/backends/simp_worker/http.py +323 -0
- mplang_nightly-0.1.dev266/mplang/v2/backends/simp_worker/mem.py +99 -0
- mplang_nightly-0.1.dev266/mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang_nightly-0.1.dev266/mplang/v2/backends/simp_worker/state.py +49 -0
- mplang_nightly-0.1.dev266/mplang/v2/backends/spu_impl.py +262 -0
- mplang_nightly-0.1.dev266/mplang/v2/backends/spu_state.py +124 -0
- mplang_nightly-0.1.dev266/mplang/v2/backends/store_impl.py +62 -0
- mplang_nightly-0.1.dev266/mplang/v2/backends/table_impl.py +838 -0
- mplang_nightly-0.1.dev266/mplang/v2/backends/tee_impl.py +215 -0
- mplang_nightly-0.1.dev266/mplang/v2/backends/tensor_impl.py +519 -0
- mplang_nightly-0.1.dev266/mplang/v2/cli.py +603 -0
- mplang_nightly-0.1.dev266/mplang/v2/cli_guide.md +122 -0
- mplang_nightly-0.1.dev266/mplang/v2/dialects/__init__.py +36 -0
- mplang_nightly-0.1.dev266/mplang/v2/dialects/bfv.py +665 -0
- mplang_nightly-0.1.dev266/mplang/v2/dialects/crypto.py +689 -0
- mplang_nightly-0.1.dev266/mplang/v2/dialects/dtypes.py +378 -0
- mplang_nightly-0.1.dev266/mplang/v2/dialects/field.py +210 -0
- mplang_nightly-0.1.dev266/mplang/v2/dialects/func.py +135 -0
- mplang_nightly-0.1.dev266/mplang/v2/dialects/phe.py +723 -0
- mplang_nightly-0.1.dev266/mplang/v2/dialects/simp.py +944 -0
- mplang_nightly-0.1.dev266/mplang/v2/dialects/spu.py +349 -0
- mplang_nightly-0.1.dev266/mplang/v2/dialects/store.py +63 -0
- mplang_nightly-0.1.dev266/mplang/v2/dialects/table.py +407 -0
- mplang_nightly-0.1.dev266/mplang/v2/dialects/tee.py +346 -0
- mplang_nightly-0.1.dev266/mplang/v2/dialects/tensor.py +1175 -0
- mplang_nightly-0.1.dev266/mplang/v2/edsl/README.md +279 -0
- mplang_nightly-0.1.dev266/mplang/v2/edsl/__init__.py +99 -0
- mplang_nightly-0.1.dev266/mplang/v2/edsl/context.py +311 -0
- mplang_nightly-0.1.dev266/mplang/v2/edsl/graph.py +463 -0
- mplang_nightly-0.1.dev266/mplang/v2/edsl/jit.py +62 -0
- mplang_nightly-0.1.dev266/mplang/v2/edsl/object.py +53 -0
- mplang_nightly-0.1.dev266/mplang/v2/edsl/primitive.py +284 -0
- mplang_nightly-0.1.dev266/mplang/v2/edsl/printer.py +119 -0
- mplang_nightly-0.1.dev266/mplang/v2/edsl/registry.py +207 -0
- mplang_nightly-0.1.dev266/mplang/v2/edsl/serde.py +375 -0
- mplang_nightly-0.1.dev266/mplang/v2/edsl/tracer.py +614 -0
- mplang_nightly-0.1.dev266/mplang/v2/edsl/typing.py +816 -0
- mplang_nightly-0.1.dev266/mplang/v2/kernels/Makefile +30 -0
- mplang_nightly-0.1.dev266/mplang/v2/kernels/__init__.py +23 -0
- mplang_nightly-0.1.dev266/mplang/v2/kernels/gf128.cpp +148 -0
- mplang_nightly-0.1.dev266/mplang/v2/kernels/ldpc.cpp +82 -0
- mplang_nightly-0.1.dev266/mplang/v2/kernels/okvs.cpp +283 -0
- mplang_nightly-0.1.dev266/mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang_nightly-0.1.dev266/mplang/v2/kernels/py_kernels.py +398 -0
- mplang_nightly-0.1.dev266/mplang/v2/libs/collective.py +330 -0
- mplang_nightly-0.1.dev266/mplang/v2/libs/device/__init__.py +51 -0
- mplang_nightly-0.1.dev266/mplang/v2/libs/device/api.py +813 -0
- mplang_nightly-0.1.dev266/mplang/v2/libs/device/cluster.py +352 -0
- mplang_nightly-0.1.dev266/mplang/v2/libs/ml/__init__.py +23 -0
- mplang_nightly-0.1.dev266/mplang/v2/libs/ml/sgb.py +1873 -0
- mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/__init__.py +41 -0
- mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/_utils.py +99 -0
- mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang_nightly-0.1.dev266/mplang/v2/runtime/__init__.py +15 -0
- mplang_nightly-0.1.dev266/mplang/v2/runtime/dialect_state.py +41 -0
- mplang_nightly-0.1.dev266/mplang/v2/runtime/interpreter.py +871 -0
- mplang_nightly-0.1.dev266/mplang/v2/runtime/object_store.py +194 -0
- mplang_nightly-0.1.dev266/mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev203 → mplang_nightly-0.1.dev266}/pyproject.toml +39 -12
- mplang_nightly-0.1.dev266/tests/__init__.py +15 -0
- {mplang_nightly-0.1.dev203 → mplang_nightly-0.1.dev266}/tests/conftest.py +1 -1
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/analysis/test_diagram.py +2 -2
- mplang_nightly-0.1.dev266/tests/v1/conftest.py +17 -0
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/core/expr/conftest.py +4 -4
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/core/expr/test_ast.py +4 -4
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/core/expr/test_printer.py +7 -7
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/core/expr/test_utils.py +3 -3
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/core/expr/test_walk.py +3 -3
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/core/test_cluster.py +1 -1
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/core/test_dtype.py +5 -5
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/core/test_mask.py +1 -1
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/core/test_mpir.py +29 -29
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/core/test_mptype.py +4 -4
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/core/test_primitive.py +14 -14
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/core/test_table.py +2 -2
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/core/test_tensor.py +3 -3
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/core/test_tracer.py +25 -19
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/device/test_device_basic.py +1 -1
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/integration/test_crypto_roundtrip.py +2 -2
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/integration/test_http_e2e.py +1 -1
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/integration/test_symbols_roundtrip.py +8 -6
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/integration/test_tee_workflow.py +2 -2
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/integration/test_tutorials.py +1 -1
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/integration/test_unused_param_integration.py +2 -2
- mplang_nightly-0.1.dev266/tests/v1/kernels/__init__.py +13 -0
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/kernels/test_basic.py +15 -7
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/kernels/test_debug_print.py +5 -5
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/kernels/test_fhe.py +39 -30
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/kernels/test_kernel_binding.py +10 -8
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/kernels/test_phe.py +7 -7
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/kernels/test_spu.py +12 -12
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/kernels/test_sql_duckdb.py +4 -4
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/kernels/test_stablehlo.py +9 -9
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/kernels/test_value.py +5 -5
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/kernels/test_value_serde.py +5 -5
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/ops/dummy.py +5 -5
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/ops/test_basic_pack.py +6 -6
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/ops/test_crypto.py +6 -6
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/ops/test_feop_base.py +5 -5
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/ops/test_jax_cc.py +1 -1
- mplang_nightly-0.1.dev266/tests/v1/ops/test_nnx_cc.py +265 -0
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/ops/test_phe.py +3 -3
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/ops/test_spu.py +2 -2
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/ops/test_spu_defensive.py +3 -3
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/ops/test_sql.py +8 -6
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/ops/test_sql_cc.py +9 -9
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/ops/test_table_tensor_conversion.py +7 -7
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/runtime/test_cli.py +12 -6
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/runtime/test_communicator.py +7 -7
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/runtime/test_driver.py +3 -3
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/runtime/test_server.py +4 -4
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/runtime/test_simulation.py +6 -6
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/simp/test_mpi.py +4 -4
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/simp/test_random.py +1 -1
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/simp/test_smpc.py +4 -4
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/simp/test_sugar.py +2 -2
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/utils/server_fixtures.py +1 -1
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/utils/test_func_utils.py +1 -1
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/utils/test_spu_utils.py +1 -1
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/utils/test_table_utils.py +1 -1
- mplang_nightly-0.1.dev266/tests/v2/__init__.py +13 -0
- mplang_nightly-0.1.dev266/tests/v2/backends/__init__.py +13 -0
- mplang_nightly-0.1.dev266/tests/v2/backends/simp_driver/__init__.py +15 -0
- mplang_nightly-0.1.dev266/tests/v2/backends/simp_driver/test_http.py +215 -0
- mplang_nightly-0.1.dev266/tests/v2/backends/simp_worker/__init__.py +15 -0
- mplang_nightly-0.1.dev266/tests/v2/backends/simp_worker/test_http.py +225 -0
- mplang_nightly-0.1.dev266/tests/v2/backends/simp_worker/test_mem.py +102 -0
- mplang_nightly-0.1.dev266/tests/v2/backends/test_bfv_impl.py +344 -0
- mplang_nightly-0.1.dev266/tests/v2/backends/test_crypto_impl.py +541 -0
- mplang_nightly-0.1.dev266/tests/v2/backends/test_okvs_binding.py +115 -0
- mplang_nightly-0.1.dev266/tests/v2/backends/test_simp_integration.py +196 -0
- mplang_nightly-0.1.dev266/tests/v2/backends/test_simp_object_store.py +106 -0
- mplang_nightly-0.1.dev266/tests/v2/backends/test_spu_impl.py +115 -0
- mplang_nightly-0.1.dev266/tests/v2/backends/test_table_impl.py +436 -0
- mplang_nightly-0.1.dev266/tests/v2/backends/test_tee_impl.py +148 -0
- mplang_nightly-0.1.dev266/tests/v2/backends/test_tensor_impl.py +177 -0
- mplang_nightly-0.1.dev266/tests/v2/backends/test_verify_clean.py +123 -0
- mplang_nightly-0.1.dev266/tests/v2/conftest.py +54 -0
- mplang_nightly-0.1.dev266/tests/v2/dialects/__init__.py +13 -0
- mplang_nightly-0.1.dev266/tests/v2/dialects/test_bfv.py +178 -0
- mplang_nightly-0.1.dev266/tests/v2/dialects/test_crypto.py +214 -0
- mplang_nightly-0.1.dev266/tests/v2/dialects/test_dtypes.py +219 -0
- mplang_nightly-0.1.dev266/tests/v2/dialects/test_field.py +49 -0
- mplang_nightly-0.1.dev266/tests/v2/dialects/test_func.py +130 -0
- mplang_nightly-0.1.dev266/tests/v2/dialects/test_okvs.py +56 -0
- mplang_nightly-0.1.dev266/tests/v2/dialects/test_okvs_bench.py +55 -0
- mplang_nightly-0.1.dev266/tests/v2/dialects/test_phe.py +531 -0
- mplang_nightly-0.1.dev266/tests/v2/dialects/test_simp.py +564 -0
- mplang_nightly-0.1.dev266/tests/v2/dialects/test_simp_comm.py +190 -0
- mplang_nightly-0.1.dev266/tests/v2/dialects/test_spu.py +214 -0
- mplang_nightly-0.1.dev266/tests/v2/dialects/test_table.py +60 -0
- mplang_nightly-0.1.dev266/tests/v2/dialects/test_tee.py +156 -0
- mplang_nightly-0.1.dev266/tests/v2/dialects/test_tensor.py +196 -0
- mplang_nightly-0.1.dev266/tests/v2/edsl/__init__.py +15 -0
- mplang_nightly-0.1.dev266/tests/v2/edsl/test_context.py +164 -0
- mplang_nightly-0.1.dev266/tests/v2/edsl/test_graph.py +664 -0
- mplang_nightly-0.1.dev266/tests/v2/edsl/test_primitive.py +252 -0
- mplang_nightly-0.1.dev266/tests/v2/edsl/test_primitive_multi_output.py +269 -0
- mplang_nightly-0.1.dev266/tests/v2/edsl/test_printer.py +100 -0
- mplang_nightly-0.1.dev266/tests/v2/edsl/test_serde.py +279 -0
- mplang_nightly-0.1.dev266/tests/v2/edsl/test_tracer.py +309 -0
- mplang_nightly-0.1.dev266/tests/v2/edsl/test_typing.py +836 -0
- mplang_nightly-0.1.dev266/tests/v2/edsl/test_typing_graph_serde.py +346 -0
- mplang_nightly-0.1.dev266/tests/v2/libs/device/__init__.py +13 -0
- mplang_nightly-0.1.dev266/tests/v2/libs/device/conftest.py +245 -0
- mplang_nightly-0.1.dev266/tests/v2/libs/device/test_device_api_errors.py +355 -0
- mplang_nightly-0.1.dev266/tests/v2/libs/device/test_device_dialects.py +345 -0
- mplang_nightly-0.1.dev266/tests/v2/libs/device/test_device_layouts.py +357 -0
- mplang_nightly-0.1.dev266/tests/v2/libs/device/test_device_tee.py +220 -0
- mplang_nightly-0.1.dev266/tests/v2/libs/ml/__init__.py +15 -0
- mplang_nightly-0.1.dev266/tests/v2/libs/ml/test_sgb.py +164 -0
- mplang_nightly-0.1.dev266/tests/v2/libs/ml/test_sgb_bench.py +401 -0
- mplang_nightly-0.1.dev266/tests/v2/libs/mpc/__init__.py +13 -0
- mplang_nightly-0.1.dev266/tests/v2/libs/mpc/analytics/__init__.py +15 -0
- mplang_nightly-0.1.dev266/tests/v2/libs/mpc/analytics/test_aggregation.py +77 -0
- mplang_nightly-0.1.dev266/tests/v2/libs/mpc/analytics/test_groupby.py +207 -0
- mplang_nightly-0.1.dev266/tests/v2/libs/mpc/analytics/test_permutation.py +108 -0
- mplang_nightly-0.1.dev266/tests/v2/libs/mpc/ot/__init__.py +15 -0
- mplang_nightly-0.1.dev266/tests/v2/libs/mpc/ot/test_ot.py +196 -0
- mplang_nightly-0.1.dev266/tests/v2/libs/mpc/ot/test_ot_extension.py +58 -0
- mplang_nightly-0.1.dev266/tests/v2/libs/mpc/ot/test_silent_ot.py +101 -0
- mplang_nightly-0.1.dev266/tests/v2/libs/mpc/psi/__init__.py +15 -0
- mplang_nightly-0.1.dev266/tests/v2/libs/mpc/psi/test_okvs_gct.py +112 -0
- mplang_nightly-0.1.dev266/tests/v2/libs/mpc/psi/test_oprf.py +79 -0
- mplang_nightly-0.1.dev266/tests/v2/libs/mpc/psi/test_psi.py +115 -0
- mplang_nightly-0.1.dev266/tests/v2/libs/mpc/psi/test_psi_bench.py +155 -0
- mplang_nightly-0.1.dev266/tests/v2/libs/mpc/psi/test_rr22.py +288 -0
- mplang_nightly-0.1.dev266/tests/v2/libs/mpc/psi/verify_psi_okvs_logic.py +164 -0
- mplang_nightly-0.1.dev266/tests/v2/libs/mpc/test_field_gf128.py +97 -0
- mplang_nightly-0.1.dev266/tests/v2/libs/mpc/test_utils.py +82 -0
- mplang_nightly-0.1.dev266/tests/v2/libs/mpc/vole/__init__.py +15 -0
- mplang_nightly-0.1.dev266/tests/v2/libs/mpc/vole/test_gilboa_manual.py +137 -0
- mplang_nightly-0.1.dev266/tests/v2/libs/mpc/vole/test_ldpc.py +179 -0
- mplang_nightly-0.1.dev266/tests/v2/libs/mpc/vole/test_silver_vole.py +183 -0
- mplang_nightly-0.1.dev266/tests/v2/libs/mpc/vole/test_vole.py +87 -0
- mplang_nightly-0.1.dev266/tests/v2/libs/mpc/vole/test_vole_bench.py +165 -0
- mplang_nightly-0.1.dev266/tests/v2/libs/mpc/vole/verify_vole_logic.py +171 -0
- mplang_nightly-0.1.dev266/tests/v2/libs/test_collective.py +251 -0
- mplang_nightly-0.1.dev266/tests/v2/libs/test_simple_guide.py +60 -0
- mplang_nightly-0.1.dev266/tests/v2/runtime/test_interpreter_async.py +113 -0
- mplang_nightly-0.1.dev266/tests/v2/runtime/test_object_store.py +55 -0
- mplang_nightly-0.1.dev266/tests/v2/runtime/test_object_store_fs.py +104 -0
- mplang_nightly-0.1.dev266/tests/v2/test_fetch_semantics.py +129 -0
- mplang_nightly-0.1.dev266/tests/v2/test_pytree_io.py +231 -0
- mplang_nightly-0.1.dev266/tests/v2/test_store.py +89 -0
- mplang_nightly-0.1.dev266/tests/v2/utils/__init__.py +13 -0
- mplang_nightly-0.1.dev266/tests/v2/utils/tensor_patch.py +131 -0
- mplang_nightly-0.1.dev266/tutorials/MIGRATION.md +191 -0
- {mplang_nightly-0.1.dev203 → mplang_nightly-0.1.dev266}/tutorials/run.sh +18 -15
- {mplang_nightly-0.1.dev203/tutorials → mplang_nightly-0.1.dev266/tutorials/v1}/device/00_device_basics.py +1 -1
- {mplang_nightly-0.1.dev203/tutorials → mplang_nightly-0.1.dev266/tutorials/v1}/device/01_function_decorator.py +1 -1
- {mplang_nightly-0.1.dev203/tutorials → mplang_nightly-0.1.dev266/tutorials/v1}/device/02_simulation_and_driver.py +6 -3
- {mplang_nightly-0.1.dev203/tutorials → mplang_nightly-0.1.dev266/tutorials/v1}/device/03_run_jax.py +1 -1
- {mplang_nightly-0.1.dev203/tutorials → mplang_nightly-0.1.dev266/tutorials/v1}/device/04_run_sql.py +8 -8
- {mplang_nightly-0.1.dev203/tutorials → mplang_nightly-0.1.dev266/tutorials/v1}/device/05_pipeline.py +3 -3
- {mplang_nightly-0.1.dev203/tutorials → mplang_nightly-0.1.dev266/tutorials/v1}/device/06_ir_dump_and_analysis.py +1 -1
- mplang_nightly-0.1.dev266/tutorials/v1/device/07_run_nnx.py +627 -0
- {mplang_nightly-0.1.dev203/tutorials → mplang_nightly-0.1.dev266/tutorials/v1}/pitfalls/late_binding.py +1 -1
- {mplang_nightly-0.1.dev203/tutorials → mplang_nightly-0.1.dev266/tutorials/v1}/pitfalls/rand.py +1 -1
- {mplang_nightly-0.1.dev203/tutorials → mplang_nightly-0.1.dev266/tutorials/v1}/simp/00_basic.py +1 -1
- {mplang_nightly-0.1.dev203/tutorials → mplang_nightly-0.1.dev266/tutorials/v1}/simp/01_condition.py +1 -1
- {mplang_nightly-0.1.dev203/tutorials → mplang_nightly-0.1.dev266/tutorials/v1}/simp/02_whileloop.py +1 -1
- {mplang_nightly-0.1.dev203/tutorials → mplang_nightly-0.1.dev266/tutorials/v1}/simp/03_stdio.py +2 -2
- {mplang_nightly-0.1.dev203/tutorials → mplang_nightly-0.1.dev266/tutorials/v1}/simp/04_phe.py +2 -2
- {mplang_nightly-0.1.dev203/tutorials → mplang_nightly-0.1.dev266/tutorials/v1}/simp/05_tee.py +48 -1
- {mplang_nightly-0.1.dev203/tutorials → mplang_nightly-0.1.dev266/tutorials/v1}/simp/06_fhe.py +2 -2
- {mplang_nightly-0.1.dev203/tutorials → mplang_nightly-0.1.dev266/tutorials/v1}/simp/07_advanced.py +1 -1
- mplang_nightly-0.1.dev266/tutorials/v1/simp/08_simple_secret_sharing.py +291 -0
- mplang_nightly-0.1.dev266/tutorials/v2/00_device_basics.py +190 -0
- mplang_nightly-0.1.dev266/tutorials/v2/01_function_decorator.py +135 -0
- mplang_nightly-0.1.dev266/tutorials/v2/02_simulation_and_driver.py +125 -0
- mplang_nightly-0.1.dev266/tutorials/v2/03_run_jax.py +131 -0
- mplang_nightly-0.1.dev266/tutorials/v2/04_ir_dump_and_analysis.py +115 -0
- mplang_nightly-0.1.dev266/tutorials/v2/05_run_sql.py +193 -0
- mplang_nightly-0.1.dev266/tutorials/v2/06_pipeline.py +279 -0
- mplang_nightly-0.1.dev266/tutorials/v2/07_stax_nn.py +272 -0
- mplang_nightly-0.1.dev266/tutorials/v2/__init__.py +15 -0
- {mplang_nightly-0.1.dev203 → mplang_nightly-0.1.dev266}/LICENSE +0 -0
- {mplang_nightly-0.1.dev203/examples → mplang_nightly-0.1.dev266/examples/v1}/conf/3pc.yaml +0 -0
- {mplang_nightly-0.1.dev203/examples → mplang_nightly-0.1.dev266/examples/v1}/stax_nn/README.md +0 -0
- {mplang_nightly-0.1.dev203/examples → mplang_nightly-0.1.dev266/examples/v1}/stax_nn/models.py +0 -0
- {mplang_nightly-0.1.dev203/examples → mplang_nightly-0.1.dev266/examples/v1}/xgboost/hist_jax.py +0 -0
- {mplang_nightly-0.1.dev203/examples → mplang_nightly-0.1.dev266/examples/v1}/xgboost/hist_jax_test.py +0 -0
- {mplang_nightly-0.1.dev203/examples → mplang_nightly-0.1.dev266/examples/v1}/xgboost/naive_np.py +0 -0
- {mplang_nightly-0.1.dev203/examples → mplang_nightly-0.1.dev266/examples/v1}/xgboost/readme.md +0 -0
- {mplang_nightly-0.1.dev203 → mplang_nightly-0.1.dev266}/hatch_build.py +0 -0
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/cluster.py +0 -0
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/dtypes.py +0 -0
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/mask.py +0 -0
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/protos/v1alpha1/mpir_pb2.py +0 -0
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/protos/v1alpha1/mpir_pb2.pyi +0 -0
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/protos/v1alpha1/value_pb2.py +0 -0
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/protos/v1alpha1/value_pb2.pyi +0 -0
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/runtime/exceptions.py +0 -0
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/runtime/http_api.md +0 -0
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/runtime/link_comm.py +0 -0
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/simp/__init__.py +0 -0
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/utils/__init__.py +0 -0
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/utils/crypto.py +0 -0
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/utils/func_utils.py +0 -0
- {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/utils/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/__init__.py +0 -0
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/core/__init__.py +0 -0
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/core/expr/__init__.py +0 -0
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/device/__init__.py +0 -0
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/integration/README.md +0 -0
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/ops/__init__.py +0 -0
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/runtime/__init__.py +0 -0
- {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/utils/__init__.py +0 -0
- {mplang_nightly-0.1.dev203 → mplang_nightly-0.1.dev266}/tutorials/README.md +0 -0
- {mplang_nightly-0.1.dev203 → mplang_nightly-0.1.dev266}/tutorials/__init__.py +0 -0
- {mplang_nightly-0.1.dev203 → mplang_nightly-0.1.dev266}/tutorials/data/alice.csv +0 -0
- {mplang_nightly-0.1.dev203 → mplang_nightly-0.1.dev266}/tutorials/data/bob.csv +0 -0
- {mplang_nightly-0.1.dev203 → mplang_nightly-0.1.dev266}/tutorials/data/prepare_vertical_iris.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: mplang-nightly
|
|
3
|
-
Version: 0.1.
|
|
3
|
+
Version: 0.1.dev266
|
|
4
4
|
Summary: Multi-Party Programming Language
|
|
5
5
|
Author-email: SecretFlow Team <secretflow-contact@service.alipay.com>
|
|
6
6
|
License: Apache License
|
|
@@ -205,15 +205,21 @@ License: Apache License
|
|
|
205
205
|
See the License for the specific language governing permissions and
|
|
206
206
|
limitations under the License.
|
|
207
207
|
License-File: LICENSE
|
|
208
|
-
Requires-Python:
|
|
208
|
+
Requires-Python: <3.13,>=3.11
|
|
209
|
+
Requires-Dist: coincurve>=20.0.0
|
|
210
|
+
Requires-Dist: cryptography>=43.0.0
|
|
209
211
|
Requires-Dist: duckdb>=1.0.0
|
|
210
212
|
Requires-Dist: fastapi
|
|
211
|
-
Requires-Dist:
|
|
213
|
+
Requires-Dist: flax>=0.12.0
|
|
214
|
+
Requires-Dist: httpx<1.0.0,>=0.27.0
|
|
215
|
+
Requires-Dist: jax[cpu]==0.8.0
|
|
212
216
|
Requires-Dist: lightphe<0.1.0,>=0.0.15
|
|
217
|
+
Requires-Dist: numpy>=2.0.0
|
|
213
218
|
Requires-Dist: pandas>=2.0.0
|
|
214
219
|
Requires-Dist: protobuf<6.0,>=5.0
|
|
215
220
|
Requires-Dist: pyarrow>=14.0.0
|
|
216
|
-
Requires-Dist:
|
|
221
|
+
Requires-Dist: pyyaml>=6.0
|
|
222
|
+
Requires-Dist: spu>=0.10.0.dev20251208
|
|
217
223
|
Requires-Dist: sqlglot>=23.0.0
|
|
218
224
|
Requires-Dist: tenseal==0.3.16
|
|
219
225
|
Requires-Dist: typing-extensions
|
|
@@ -242,7 +248,7 @@ multiple parties in a synchronous, SPMD (Single Program, Multiple Data) fashion.
|
|
|
242
248
|
|
|
243
249
|
### Installation
|
|
244
250
|
|
|
245
|
-
You'll need a modern Python environment (3.
|
|
251
|
+
You'll need a modern Python environment (3.11+). We recommend using `uv` for fast installation.
|
|
246
252
|
|
|
247
253
|
```bash
|
|
248
254
|
# Install uv (if not already installed)
|
|
@@ -20,7 +20,7 @@ multiple parties in a synchronous, SPMD (Single Program, Multiple Data) fashion.
|
|
|
20
20
|
|
|
21
21
|
### Installation
|
|
22
22
|
|
|
23
|
-
You'll need a modern Python environment (3.
|
|
23
|
+
You'll need a modern Python environment (3.11+). We recommend using `uv` for fast installation.
|
|
24
24
|
|
|
25
25
|
```bash
|
|
26
26
|
# Install uv (if not already installed)
|
{mplang_nightly-0.1.dev203/examples → mplang_nightly-0.1.dev266/examples/v1}/stax_nn/stax_nn.py
RENAMED
|
@@ -27,11 +27,11 @@ import yaml
|
|
|
27
27
|
from jax.example_libraries import stax
|
|
28
28
|
from sklearn.metrics import accuracy_score
|
|
29
29
|
|
|
30
|
-
import mplang as mp
|
|
30
|
+
import mplang.v1 as mp
|
|
31
31
|
|
|
32
32
|
parser = argparse.ArgumentParser(description="distributed driver.")
|
|
33
33
|
parser.add_argument("--model", default="network_a", type=str)
|
|
34
|
-
parser.add_argument("-c", "--config", default="examples/conf/3pc.yaml", type=str)
|
|
34
|
+
parser.add_argument("-c", "--config", default="examples/v1/conf/3pc.yaml", type=str)
|
|
35
35
|
parser.add_argument("-e", "--epoch", default=5, type=int)
|
|
36
36
|
parser.add_argument("-b", "--batch_size", default=128, type=int)
|
|
37
37
|
parser.add_argument("-o", "--optimizer", default="SGD", type=str)
|
|
@@ -0,0 +1,615 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""
|
|
16
|
+
Microbenchmark for FHE(BFV) histogram path in SecureBoost.
|
|
17
|
+
|
|
18
|
+
This script measures the time to compute PP-side cumulative bucket sums
|
|
19
|
+
via encrypted ct·ct dot products using TenSEAL/SEAL BFV vector backend.
|
|
20
|
+
|
|
21
|
+
It provides two modes:
|
|
22
|
+
- classic: separate g/h ciphertexts + ct·ct dot (current training path)
|
|
23
|
+
- interleaved: interleave g/h into one ct, do one ct·ct mul + two ct·pt dots (even/odd)
|
|
24
|
+
|
|
25
|
+
Usage examples:
|
|
26
|
+
uv run -q python examples/v1/xgboost/bench_fhe_hist.py --world-size 2 --m 4096 --n-total 16 --n-ap 4 --k 16 --t 4 --reps 3 --mode classic
|
|
27
|
+
uv run -q python examples/v1/xgboost/bench_fhe_hist.py --world-size 2 --m 4096 --n-total 16 --n-ap 4 --k 16 --t 4 --reps 3 --mode interleaved
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
from __future__ import annotations
|
|
31
|
+
|
|
32
|
+
import argparse
|
|
33
|
+
import time
|
|
34
|
+
from functools import partial
|
|
35
|
+
|
|
36
|
+
import jax
|
|
37
|
+
import jax.numpy as jnp
|
|
38
|
+
import numpy as np
|
|
39
|
+
from examples.xgboost.sgb import (
|
|
40
|
+
DEFAULT_FXP_BITS,
|
|
41
|
+
batch_feature_wise_bucket_sum_fhe_vector,
|
|
42
|
+
build_bins_equi_width,
|
|
43
|
+
compute_bin_indices,
|
|
44
|
+
compute_gh,
|
|
45
|
+
compute_init_pred,
|
|
46
|
+
quantize_gh,
|
|
47
|
+
)
|
|
48
|
+
|
|
49
|
+
import mplang.v1 as mp
|
|
50
|
+
from mplang.v1.ops import fhe
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _gen_data(n_samples: int, n_total_features: int, n_features_ap: int, seed: int):
|
|
54
|
+
rng = np.random.default_rng(seed)
|
|
55
|
+
X = rng.normal(size=(n_samples, n_total_features)).astype(np.float32)
|
|
56
|
+
# make a simple linear label with noise
|
|
57
|
+
w = rng.normal(size=(n_total_features,)).astype(np.float32)
|
|
58
|
+
z = X @ w + 0.1 * rng.normal(size=(n_samples,)).astype(np.float32)
|
|
59
|
+
p = 1 / (1 + np.exp(-z))
|
|
60
|
+
y = (p > 0.5).astype(np.float32)
|
|
61
|
+
|
|
62
|
+
X_ap = X[:, :n_features_ap]
|
|
63
|
+
X_pp = X[:, n_features_ap:]
|
|
64
|
+
return X_ap, X_pp, y
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@mp.function
|
|
68
|
+
def _bench_once(
|
|
69
|
+
ap_id: int,
|
|
70
|
+
pp_ids: list[int],
|
|
71
|
+
X_ap: np.ndarray,
|
|
72
|
+
X_pp_splits: list[np.ndarray],
|
|
73
|
+
y: np.ndarray,
|
|
74
|
+
k: int,
|
|
75
|
+
t: int,
|
|
76
|
+
reps: int,
|
|
77
|
+
mode: str,
|
|
78
|
+
include_precompute: bool,
|
|
79
|
+
breakdown: bool,
|
|
80
|
+
):
|
|
81
|
+
# Place data
|
|
82
|
+
X_ap_j = mp.run_jax_at(ap_id, lambda x: x, jnp.array(X_ap, dtype=jnp.float32))
|
|
83
|
+
X_pp_j = [
|
|
84
|
+
mp.run_jax_at(pp, lambda x: x, jnp.array(xpp, dtype=jnp.float32))
|
|
85
|
+
for pp, xpp in zip(pp_ids, X_pp_splits, strict=True)
|
|
86
|
+
]
|
|
87
|
+
y_j = mp.run_jax_at(ap_id, lambda x: x, jnp.array(y, dtype=jnp.float32))
|
|
88
|
+
|
|
89
|
+
# Binning per party
|
|
90
|
+
build_bins_vmapped = jax.vmap(partial(build_bins_equi_width, max_bin=k), in_axes=1)
|
|
91
|
+
compute_indices_vmapped = jax.vmap(compute_bin_indices, in_axes=(1, 0), out_axes=1)
|
|
92
|
+
|
|
93
|
+
ap_bins = mp.run_jax_at(ap_id, build_bins_vmapped, X_ap_j)
|
|
94
|
+
_ = mp.run_jax_at(ap_id, compute_indices_vmapped, X_ap_j, ap_bins)
|
|
95
|
+
|
|
96
|
+
pp_bins = [
|
|
97
|
+
mp.run_jax_at(pp, build_bins_vmapped, X_pp_j[i]) for i, pp in enumerate(pp_ids)
|
|
98
|
+
]
|
|
99
|
+
pp_idx = [
|
|
100
|
+
mp.run_jax_at(pp, compute_indices_vmapped, X_pp_j[i], pp_bins[i])
|
|
101
|
+
for i, pp in enumerate(pp_ids)
|
|
102
|
+
]
|
|
103
|
+
|
|
104
|
+
# AP GH + quantize + encrypt
|
|
105
|
+
init_pred = mp.run_jax_at(ap_id, compute_init_pred, y_j)
|
|
106
|
+
logits0 = mp.run_jax_at(ap_id, lambda p, m=y_j.shape[0]: p * jnp.ones(m), init_pred)
|
|
107
|
+
GH = mp.run_jax_at(ap_id, compute_gh, y_j, logits0)
|
|
108
|
+
|
|
109
|
+
fxp_scale = 1 << DEFAULT_FXP_BITS
|
|
110
|
+
Q = mp.run_jax_at(ap_id, quantize_gh, GH, fxp_scale)
|
|
111
|
+
qg = mp.run_jax_at(ap_id, lambda a: a[:, 0].astype(jnp.int64), Q)
|
|
112
|
+
qh = mp.run_jax_at(ap_id, lambda a: a[:, 1].astype(jnp.int64), Q)
|
|
113
|
+
|
|
114
|
+
priv_ctx, pub_ctx, _ = mp.run_at(ap_id, fhe.keygen, scheme="BFV")
|
|
115
|
+
|
|
116
|
+
# Prepare ciphertext(s)
|
|
117
|
+
g_ct = None # type: ignore[assignment]
|
|
118
|
+
h_ct = None # type: ignore[assignment]
|
|
119
|
+
gh_ct = None # type: ignore[assignment]
|
|
120
|
+
if mode in ("classic", "classic_cached"):
|
|
121
|
+
g_ct = mp.run_at(ap_id, fhe.encrypt, qg, pub_ctx)
|
|
122
|
+
h_ct = mp.run_at(ap_id, fhe.encrypt, qh, pub_ctx)
|
|
123
|
+
elif mode in (
|
|
124
|
+
"interleaved",
|
|
125
|
+
"interleaved_cached",
|
|
126
|
+
"interleaved_fused",
|
|
127
|
+
"interleaved_fused_cached",
|
|
128
|
+
):
|
|
129
|
+
# Interleave qg and qh into one vector: [g0,h0,g1,h1,...]
|
|
130
|
+
def _interleave(a, b):
|
|
131
|
+
m = a.shape[0]
|
|
132
|
+
out = jnp.empty((m * 2,), dtype=jnp.int64)
|
|
133
|
+
out = out.at[0::2].set(a)
|
|
134
|
+
out = out.at[1::2].set(b)
|
|
135
|
+
return out
|
|
136
|
+
|
|
137
|
+
qgh = mp.run_jax_at(ap_id, _interleave, qg, qh)
|
|
138
|
+
gh_ct = mp.run_at(ap_id, fhe.encrypt, qgh, pub_ctx)
|
|
139
|
+
else:
|
|
140
|
+
raise ValueError(f"Unknown mode: {mode}")
|
|
141
|
+
rng = mp.run_jax_at(
|
|
142
|
+
ap_id,
|
|
143
|
+
lambda m: jnp.array(
|
|
144
|
+
np.random.default_rng(0).integers(0, t, size=m), dtype=jnp.int64
|
|
145
|
+
),
|
|
146
|
+
y_j.shape[0],
|
|
147
|
+
)
|
|
148
|
+
|
|
149
|
+
def mk_subgroup_map(bt_level, group_size):
|
|
150
|
+
group_indices = jnp.arange(group_size)[:, None]
|
|
151
|
+
return (group_indices == bt_level).astype(jnp.int8)
|
|
152
|
+
|
|
153
|
+
# Precompute subgroup maps per-PP once (rng fixed) and parity selectors once if needed
|
|
154
|
+
subgroup_maps = []
|
|
155
|
+
for pp in pp_ids:
|
|
156
|
+
pub_ctx_pp = mp.p2p(ap_id, pp, pub_ctx)
|
|
157
|
+
rng_pp = mp.p2p(ap_id, pp, rng)
|
|
158
|
+
subgroup_map_pp = mp.run_jax_at(pp, mk_subgroup_map, rng_pp, t)
|
|
159
|
+
subgroup_maps.append((pp, pub_ctx_pp, subgroup_map_pp))
|
|
160
|
+
|
|
161
|
+
even_sel = None
|
|
162
|
+
odd_sel = None
|
|
163
|
+
if mode in (
|
|
164
|
+
"interleaved",
|
|
165
|
+
"interleaved_cached",
|
|
166
|
+
):
|
|
167
|
+
|
|
168
|
+
def _build_parity_selectors(m_samples):
|
|
169
|
+
n = m_samples * 2
|
|
170
|
+
even = jnp.zeros((n,), dtype=jnp.int64).at[0::2].set(1)
|
|
171
|
+
odd = jnp.zeros((n,), dtype=jnp.int64).at[1::2].set(1)
|
|
172
|
+
return even, odd
|
|
173
|
+
|
|
174
|
+
even_sel, odd_sel = mp.run_jax_at(ap_id, _build_parity_selectors, y_j.shape[0])
|
|
175
|
+
|
|
176
|
+
# Optional: precompute and encrypt all bucket masks per-PP for cached modes
|
|
177
|
+
cached_masks = None
|
|
178
|
+
pre_dt = mp.run_jax_at(ap_id, lambda: jnp.array(0.0, dtype=jnp.float64))
|
|
179
|
+
if mode in ("interleaved_cached", "classic_cached"):
|
|
180
|
+
# Helper function to duplicate mask to interleaved length (used only for interleaved mode)
|
|
181
|
+
def _dup2(mask):
|
|
182
|
+
n = mask.shape[0]
|
|
183
|
+
out = jnp.empty((n * 2,), dtype=jnp.int64)
|
|
184
|
+
out = out.at[0::2].set(mask)
|
|
185
|
+
out = out.at[1::2].set(mask)
|
|
186
|
+
return out
|
|
187
|
+
|
|
188
|
+
use_interleave = mode == "interleaved_cached"
|
|
189
|
+
|
|
190
|
+
cached_masks = [] # list per PP: [ [list per group: [list per feature: [mask_ct per bucket]]] ]
|
|
191
|
+
tpre0 = mp.run_jax_at(ap_id, lambda: jnp.array(time.time(), dtype=jnp.float64))
|
|
192
|
+
for i, (pp, pub_ctx_pp, subgroup_map_pp) in enumerate(subgroup_maps):
|
|
193
|
+
feature_size = pp_idx[i].shape[1]
|
|
194
|
+
grp_masks = []
|
|
195
|
+
for grp in range(t):
|
|
196
|
+
gom = mp.run_jax_at(pp, lambda m, idx: m[idx], subgroup_map_pp, grp)
|
|
197
|
+
|
|
198
|
+
def create_masked_order_map(m, om):
|
|
199
|
+
mask_expanded = jnp.expand_dims(m, axis=1)
|
|
200
|
+
mask_full = jnp.broadcast_to(mask_expanded, om.shape)
|
|
201
|
+
return jnp.where(mask_full == 1, om, -1)
|
|
202
|
+
|
|
203
|
+
gom_map = mp.run_jax_at(pp, create_masked_order_map, gom, pp_idx[i])
|
|
204
|
+
|
|
205
|
+
feat_masks = []
|
|
206
|
+
for feature_idx in range(feature_size):
|
|
207
|
+
# Build all bucket masks at once: (k, M)
|
|
208
|
+
def build_bucket_masks(gom_, f_idx, num_buckets):
|
|
209
|
+
def mask_for_b(b_idx, gom_i, f_i):
|
|
210
|
+
fb = gom_i[:, f_i]
|
|
211
|
+
valid_and_in_bucket = (fb >= 0) & (fb <= b_idx)
|
|
212
|
+
return valid_and_in_bucket.astype(jnp.int64)
|
|
213
|
+
|
|
214
|
+
bs = jnp.arange(num_buckets, dtype=jnp.int64)
|
|
215
|
+
return jax.vmap(mask_for_b, in_axes=(0, None, None))(
|
|
216
|
+
bs, gom_, f_idx
|
|
217
|
+
)
|
|
218
|
+
|
|
219
|
+
bucket_masks = mp.run_jax_at(
|
|
220
|
+
pp, build_bucket_masks, gom_map, feature_idx, k
|
|
221
|
+
)
|
|
222
|
+
# Encrypt each bucket mask (with optional duplication for interleaved mode)
|
|
223
|
+
masks_ct = []
|
|
224
|
+
for b in range(k):
|
|
225
|
+
row_b = mp.run_jax_at(pp, lambda M, bi: M[bi], bucket_masks, b)
|
|
226
|
+
# Apply _dup2 transformation only for interleaved mode
|
|
227
|
+
if use_interleave:
|
|
228
|
+
mask_to_encrypt = mp.run_jax_at(pp, _dup2, row_b)
|
|
229
|
+
else:
|
|
230
|
+
mask_to_encrypt = row_b
|
|
231
|
+
mask_ct_pp = mp.run_at(
|
|
232
|
+
pp, fhe.encrypt, mask_to_encrypt, pub_ctx_pp
|
|
233
|
+
)
|
|
234
|
+
mask_ct_ap = mp.p2p(pp, ap_id, mask_ct_pp)
|
|
235
|
+
masks_ct.append(mask_ct_ap)
|
|
236
|
+
feat_masks.append(masks_ct)
|
|
237
|
+
grp_masks.append(feat_masks)
|
|
238
|
+
cached_masks.append(grp_masks)
|
|
239
|
+
tpre1 = mp.run_jax_at(ap_id, lambda: jnp.array(time.time(), dtype=jnp.float64))
|
|
240
|
+
pre_dt = mp.run_jax_at(ap_id, lambda a, b: a - b, tpre1, tpre0)
|
|
241
|
+
|
|
242
|
+
# Run reps and time compute + decrypt assembly across all PPs
|
|
243
|
+
times_total = []
|
|
244
|
+
times_comp = []
|
|
245
|
+
times_dec = []
|
|
246
|
+
for rep_i in range(reps):
|
|
247
|
+
t0 = mp.run_jax_at(ap_id, lambda: jnp.array(time.time(), dtype=jnp.float64))
|
|
248
|
+
comp_parts = []
|
|
249
|
+
dec_parts = []
|
|
250
|
+
for i, (pp, pub_ctx_pp, subgroup_map_pp) in enumerate(subgroup_maps):
|
|
251
|
+
if mode == "classic":
|
|
252
|
+
assert g_ct is not None and h_ct is not None
|
|
253
|
+
tcomp0 = mp.run_jax_at(
|
|
254
|
+
ap_id, lambda: jnp.array(time.time(), dtype=jnp.float64)
|
|
255
|
+
)
|
|
256
|
+
g_lists, h_lists = batch_feature_wise_bucket_sum_fhe_vector(
|
|
257
|
+
g_ct,
|
|
258
|
+
h_ct,
|
|
259
|
+
subgroup_map_pp,
|
|
260
|
+
pp_idx[i],
|
|
261
|
+
k,
|
|
262
|
+
t,
|
|
263
|
+
rank=pp,
|
|
264
|
+
ap_rank=ap_id,
|
|
265
|
+
)
|
|
266
|
+
tcomp1 = mp.run_jax_at(
|
|
267
|
+
ap_id, lambda: jnp.array(time.time(), dtype=jnp.float64)
|
|
268
|
+
)
|
|
269
|
+
comp_parts.append(
|
|
270
|
+
mp.run_jax_at(ap_id, lambda a, b: a - b, tcomp1, tcomp0)
|
|
271
|
+
)
|
|
272
|
+
tdec0 = mp.run_jax_at(
|
|
273
|
+
ap_id, lambda: jnp.array(time.time(), dtype=jnp.float64)
|
|
274
|
+
)
|
|
275
|
+
for grp in range(t):
|
|
276
|
+
enc_g_list = g_lists[grp]
|
|
277
|
+
enc_h_list = h_lists[grp]
|
|
278
|
+
dec_g = [
|
|
279
|
+
mp.run_at(ap_id, fhe.decrypt, ct, priv_ctx) for ct in enc_g_list
|
|
280
|
+
]
|
|
281
|
+
dec_h = [
|
|
282
|
+
mp.run_at(ap_id, fhe.decrypt, ct, priv_ctx) for ct in enc_h_list
|
|
283
|
+
]
|
|
284
|
+
|
|
285
|
+
def _stack(*xs):
|
|
286
|
+
return jnp.stack(xs)
|
|
287
|
+
|
|
288
|
+
_ = mp.run_jax_at(ap_id, _stack, *dec_g)
|
|
289
|
+
_ = mp.run_jax_at(ap_id, _stack, *dec_h)
|
|
290
|
+
tdec1 = mp.run_jax_at(
|
|
291
|
+
ap_id, lambda: jnp.array(time.time(), dtype=jnp.float64)
|
|
292
|
+
)
|
|
293
|
+
dec_parts.append(mp.run_jax_at(ap_id, lambda a, b: a - b, tdec1, tdec0))
|
|
294
|
+
elif mode == "classic_cached":
|
|
295
|
+
assert g_ct is not None and h_ct is not None
|
|
296
|
+
assert cached_masks is not None
|
|
297
|
+
feature_size = pp_idx[i].shape[1]
|
|
298
|
+
g_lists = [[] for _ in range(t)]
|
|
299
|
+
h_lists = [[] for _ in range(t)]
|
|
300
|
+
tcomp0 = mp.run_jax_at(
|
|
301
|
+
ap_id, lambda: jnp.array(time.time(), dtype=jnp.float64)
|
|
302
|
+
)
|
|
303
|
+
for grp in range(t):
|
|
304
|
+
feat_masks = cached_masks[i][grp]
|
|
305
|
+
for feature_idx in range(feature_size):
|
|
306
|
+
for bucket_idx in range(k):
|
|
307
|
+
mask_ct_ap = feat_masks[feature_idx][bucket_idx]
|
|
308
|
+
g_sum_ct = mp.run_at(ap_id, fhe.dot, g_ct, mask_ct_ap)
|
|
309
|
+
h_sum_ct = mp.run_at(ap_id, fhe.dot, h_ct, mask_ct_ap)
|
|
310
|
+
g_lists[grp].append(g_sum_ct)
|
|
311
|
+
h_lists[grp].append(h_sum_ct)
|
|
312
|
+
tcomp1 = mp.run_jax_at(
|
|
313
|
+
ap_id, lambda: jnp.array(time.time(), dtype=jnp.float64)
|
|
314
|
+
)
|
|
315
|
+
comp_parts.append(
|
|
316
|
+
mp.run_jax_at(ap_id, lambda a, b: a - b, tcomp1, tcomp0)
|
|
317
|
+
)
|
|
318
|
+
tdec0 = mp.run_jax_at(
|
|
319
|
+
ap_id, lambda: jnp.array(time.time(), dtype=jnp.float64)
|
|
320
|
+
)
|
|
321
|
+
for grp in range(t):
|
|
322
|
+
enc_g_list = g_lists[grp]
|
|
323
|
+
enc_h_list = h_lists[grp]
|
|
324
|
+
dec_g = [
|
|
325
|
+
mp.run_at(ap_id, fhe.decrypt, ct, priv_ctx) for ct in enc_g_list
|
|
326
|
+
]
|
|
327
|
+
dec_h = [
|
|
328
|
+
mp.run_at(ap_id, fhe.decrypt, ct, priv_ctx) for ct in enc_h_list
|
|
329
|
+
]
|
|
330
|
+
|
|
331
|
+
def _stack(*xs):
|
|
332
|
+
return jnp.stack(xs)
|
|
333
|
+
|
|
334
|
+
_ = mp.run_jax_at(ap_id, _stack, *dec_g)
|
|
335
|
+
_ = mp.run_jax_at(ap_id, _stack, *dec_h)
|
|
336
|
+
tdec1 = mp.run_jax_at(
|
|
337
|
+
ap_id, lambda: jnp.array(time.time(), dtype=jnp.float64)
|
|
338
|
+
)
|
|
339
|
+
dec_parts.append(mp.run_jax_at(ap_id, lambda a, b: a - b, tdec1, tdec0))
|
|
340
|
+
else:
|
|
341
|
+
assert gh_ct is not None
|
|
342
|
+
# even_sel/odd_sel were built once before reps
|
|
343
|
+
|
|
344
|
+
def _dup2(mask):
|
|
345
|
+
n = mask.shape[0]
|
|
346
|
+
out = jnp.empty((n * 2,), dtype=jnp.int64)
|
|
347
|
+
out = out.at[0::2].set(mask)
|
|
348
|
+
out = out.at[1::2].set(mask)
|
|
349
|
+
return out
|
|
350
|
+
|
|
351
|
+
feature_size = pp_idx[i].shape[1]
|
|
352
|
+
g_lists = [[] for _ in range(t)]
|
|
353
|
+
h_lists = [[] for _ in range(t)]
|
|
354
|
+
tcomp0 = mp.run_jax_at(
|
|
355
|
+
ap_id, lambda: jnp.array(time.time(), dtype=jnp.float64)
|
|
356
|
+
)
|
|
357
|
+
for grp in range(t):
|
|
358
|
+
if mode in ("interleaved_cached",) and cached_masks is not None:
|
|
359
|
+
# Use precomputed encrypted masks
|
|
360
|
+
feat_masks = cached_masks[i][grp]
|
|
361
|
+
for feature_idx in range(feature_size):
|
|
362
|
+
for bucket_idx in range(k):
|
|
363
|
+
mask_ct_ap = feat_masks[feature_idx][bucket_idx]
|
|
364
|
+
prod_ct = mp.run_at(ap_id, fhe.mul, gh_ct, mask_ct_ap)
|
|
365
|
+
g_sum_ct = mp.run_at(ap_id, fhe.dot, prod_ct, even_sel)
|
|
366
|
+
h_sum_ct = mp.run_at(ap_id, fhe.dot, prod_ct, odd_sel)
|
|
367
|
+
g_lists[grp].append(g_sum_ct)
|
|
368
|
+
h_lists[grp].append(h_sum_ct)
|
|
369
|
+
else:
|
|
370
|
+
# Build on the fly
|
|
371
|
+
gom = mp.run_jax_at(
|
|
372
|
+
pp, lambda m, idx: m[idx], subgroup_map_pp, grp
|
|
373
|
+
)
|
|
374
|
+
|
|
375
|
+
def create_masked_order_map(m, om):
|
|
376
|
+
mask_expanded = jnp.expand_dims(m, axis=1)
|
|
377
|
+
mask_full = jnp.broadcast_to(mask_expanded, om.shape)
|
|
378
|
+
return jnp.where(mask_full == 1, om, -1)
|
|
379
|
+
|
|
380
|
+
gom_map = mp.run_jax_at(
|
|
381
|
+
pp, create_masked_order_map, gom, pp_idx[i]
|
|
382
|
+
)
|
|
383
|
+
for feature_idx in range(feature_size):
|
|
384
|
+
for bucket_idx in range(k):
|
|
385
|
+
|
|
386
|
+
def create_bucket_mask(gom_, f_idx, b_idx):
|
|
387
|
+
fb = gom_[:, f_idx]
|
|
388
|
+
valid_and_in_bucket = (fb >= 0) & (fb <= b_idx)
|
|
389
|
+
return valid_and_in_bucket.astype(jnp.int64)
|
|
390
|
+
|
|
391
|
+
bucket_mask = mp.run_jax_at(
|
|
392
|
+
pp,
|
|
393
|
+
create_bucket_mask,
|
|
394
|
+
gom_map,
|
|
395
|
+
feature_idx,
|
|
396
|
+
bucket_idx,
|
|
397
|
+
)
|
|
398
|
+
inter_mask = mp.run_jax_at(pp, _dup2, bucket_mask)
|
|
399
|
+
mask_ct_pp = mp.run_at(
|
|
400
|
+
pp, fhe.encrypt, inter_mask, pub_ctx_pp
|
|
401
|
+
)
|
|
402
|
+
mask_ct_ap = mp.p2p(pp, ap_id, mask_ct_pp)
|
|
403
|
+
|
|
404
|
+
prod_ct = mp.run_at(ap_id, fhe.mul, gh_ct, mask_ct_ap)
|
|
405
|
+
g_sum_ct = mp.run_at(ap_id, fhe.dot, prod_ct, even_sel)
|
|
406
|
+
h_sum_ct = mp.run_at(ap_id, fhe.dot, prod_ct, odd_sel)
|
|
407
|
+
|
|
408
|
+
g_lists[grp].append(g_sum_ct)
|
|
409
|
+
h_lists[grp].append(h_sum_ct)
|
|
410
|
+
tcomp1 = mp.run_jax_at(
|
|
411
|
+
ap_id, lambda: jnp.array(time.time(), dtype=jnp.float64)
|
|
412
|
+
)
|
|
413
|
+
comp_parts.append(
|
|
414
|
+
mp.run_jax_at(ap_id, lambda a, b: a - b, tcomp1, tcomp0)
|
|
415
|
+
)
|
|
416
|
+
tdec0 = mp.run_jax_at(
|
|
417
|
+
ap_id, lambda: jnp.array(time.time(), dtype=jnp.float64)
|
|
418
|
+
)
|
|
419
|
+
for grp in range(t):
|
|
420
|
+
enc_g_list = g_lists[grp]
|
|
421
|
+
enc_h_list = h_lists[grp]
|
|
422
|
+
dec_g = [
|
|
423
|
+
mp.run_at(ap_id, fhe.decrypt, ct, priv_ctx) for ct in enc_g_list
|
|
424
|
+
]
|
|
425
|
+
dec_h = [
|
|
426
|
+
mp.run_at(ap_id, fhe.decrypt, ct, priv_ctx) for ct in enc_h_list
|
|
427
|
+
]
|
|
428
|
+
|
|
429
|
+
def _stack(*xs):
|
|
430
|
+
return jnp.stack(xs)
|
|
431
|
+
|
|
432
|
+
_ = mp.run_jax_at(ap_id, _stack, *dec_g)
|
|
433
|
+
_ = mp.run_jax_at(ap_id, _stack, *dec_h)
|
|
434
|
+
tdec1 = mp.run_jax_at(
|
|
435
|
+
ap_id, lambda: jnp.array(time.time(), dtype=jnp.float64)
|
|
436
|
+
)
|
|
437
|
+
dec_parts.append(mp.run_jax_at(ap_id, lambda a, b: a - b, tdec1, tdec0))
|
|
438
|
+
|
|
439
|
+
t1 = mp.run_jax_at(ap_id, lambda: jnp.array(time.time(), dtype=jnp.float64))
|
|
440
|
+
dt = mp.run_jax_at(ap_id, lambda a, b: a - b, t1, t0)
|
|
441
|
+
# Optionally include precompute cost once (first repetition only)
|
|
442
|
+
if include_precompute and rep_i == 0:
|
|
443
|
+
dt = mp.run_jax_at(ap_id, lambda x, add: x + add, dt, pre_dt)
|
|
444
|
+
# Sum parts across PPs
|
|
445
|
+
|
|
446
|
+
def _sum_vec(*xs):
|
|
447
|
+
s = xs[0]
|
|
448
|
+
for x in xs[1:]:
|
|
449
|
+
s = s + x
|
|
450
|
+
return s
|
|
451
|
+
|
|
452
|
+
comp_sum = (
|
|
453
|
+
mp.run_jax_at(ap_id, _sum_vec, *comp_parts)
|
|
454
|
+
if comp_parts
|
|
455
|
+
else mp.run_jax_at(ap_id, lambda: jnp.array(0.0, dtype=jnp.float64))
|
|
456
|
+
)
|
|
457
|
+
dec_sum = (
|
|
458
|
+
mp.run_jax_at(ap_id, _sum_vec, *dec_parts)
|
|
459
|
+
if dec_parts
|
|
460
|
+
else mp.run_jax_at(ap_id, lambda: jnp.array(0.0, dtype=jnp.float64))
|
|
461
|
+
)
|
|
462
|
+
times_total.append(dt)
|
|
463
|
+
times_comp.append(comp_sum)
|
|
464
|
+
times_dec.append(dec_sum)
|
|
465
|
+
|
|
466
|
+
# Stack per-rep durations into a vector at AP for robust fetch
|
|
467
|
+
def _stack_times(*xs):
|
|
468
|
+
return jnp.stack(xs)
|
|
469
|
+
|
|
470
|
+
total_vec = mp.run_jax_at(ap_id, _stack_times, *times_total)
|
|
471
|
+
if not breakdown:
|
|
472
|
+
return total_vec
|
|
473
|
+
comp_vec = mp.run_jax_at(ap_id, _stack_times, *times_comp)
|
|
474
|
+
dec_vec = mp.run_jax_at(ap_id, _stack_times, *times_dec)
|
|
475
|
+
|
|
476
|
+
def _stack3(a, b, c):
|
|
477
|
+
return jnp.stack([a, b, c], axis=0)
|
|
478
|
+
|
|
479
|
+
return mp.run_jax_at(ap_id, _stack3, total_vec, comp_vec, dec_vec)
|
|
480
|
+
|
|
481
|
+
|
|
482
|
+
def main():
|
|
483
|
+
parser = argparse.ArgumentParser(description="FHE histogram microbenchmark")
|
|
484
|
+
parser.add_argument(
|
|
485
|
+
"--world-size", type=int, default=2, help="Total parties (AP=1+PPs)"
|
|
486
|
+
)
|
|
487
|
+
parser.add_argument("--m", type=int, default=4096, help="Samples")
|
|
488
|
+
parser.add_argument("--n-total", type=int, default=16, help="Total features")
|
|
489
|
+
parser.add_argument("--n-ap", type=int, default=4, help="AP feature count")
|
|
490
|
+
parser.add_argument("--k", type=int, default=16, help="Bins per feature")
|
|
491
|
+
parser.add_argument("--t", type=int, default=4, help="Groups (nodes at level)")
|
|
492
|
+
parser.add_argument("--reps", type=int, default=3, help="Repetitions")
|
|
493
|
+
parser.add_argument("--seed", type=int, default=0, help="Random seed")
|
|
494
|
+
parser.add_argument(
|
|
495
|
+
"--mode",
|
|
496
|
+
type=str,
|
|
497
|
+
default="classic",
|
|
498
|
+
choices=[
|
|
499
|
+
"classic",
|
|
500
|
+
"classic_cached",
|
|
501
|
+
"interleaved",
|
|
502
|
+
"interleaved_cached",
|
|
503
|
+
],
|
|
504
|
+
help="Histogram mode to benchmark",
|
|
505
|
+
)
|
|
506
|
+
parser.add_argument(
|
|
507
|
+
"--include-precompute",
|
|
508
|
+
action="store_true",
|
|
509
|
+
help="Include precompute (mask generation+encryption) time in the first repetition",
|
|
510
|
+
)
|
|
511
|
+
parser.add_argument(
|
|
512
|
+
"--breakdown",
|
|
513
|
+
action="store_true",
|
|
514
|
+
help="Report timing breakdown (total, compute, decrypt)",
|
|
515
|
+
)
|
|
516
|
+
args = parser.parse_args()
|
|
517
|
+
|
|
518
|
+
assert args.world_size >= 2, "world-size must be >= 2"
|
|
519
|
+
pp_parties = args.world_size - 1
|
|
520
|
+
assert args.n_total >= args.n_ap
|
|
521
|
+
|
|
522
|
+
# Split PP features evenly
|
|
523
|
+
n_pp_total = args.n_total - args.n_ap
|
|
524
|
+
n_pp_each = [n_pp_total // pp_parties] * pp_parties
|
|
525
|
+
n_pp_each[-1] += n_pp_total - sum(n_pp_each)
|
|
526
|
+
|
|
527
|
+
X_ap, X_pp_all, y = _gen_data(args.m, args.n_total, args.n_ap, args.seed)
|
|
528
|
+
offset = 0
|
|
529
|
+
X_pp_splits = []
|
|
530
|
+
for c in n_pp_each:
|
|
531
|
+
X_pp_splits.append(X_pp_all[:, offset : offset + c])
|
|
532
|
+
offset += c
|
|
533
|
+
|
|
534
|
+
sim = mp.Simulator.simple(args.world_size)
|
|
535
|
+
|
|
536
|
+
ap_id = 0
|
|
537
|
+
pp_ids = list(range(1, args.world_size))
|
|
538
|
+
|
|
539
|
+
print("\n=== FHE Histogram Microbenchmark ===")
|
|
540
|
+
print(
|
|
541
|
+
f"world-size={args.world_size} (AP+{pp_parties} PPs), m={args.m}, n_total={args.n_total} (AP={args.n_ap}, PP={n_pp_total}), k={args.k}, t={args.t}, reps={args.reps}, mode={args.mode}"
|
|
542
|
+
)
|
|
543
|
+
|
|
544
|
+
out = mp.evaluate(
|
|
545
|
+
sim,
|
|
546
|
+
_bench_once,
|
|
547
|
+
ap_id,
|
|
548
|
+
pp_ids,
|
|
549
|
+
X_ap,
|
|
550
|
+
X_pp_splits,
|
|
551
|
+
y,
|
|
552
|
+
args.k,
|
|
553
|
+
args.t,
|
|
554
|
+
args.reps,
|
|
555
|
+
args.mode,
|
|
556
|
+
args.include_precompute,
|
|
557
|
+
args.breakdown,
|
|
558
|
+
)
|
|
559
|
+
times_raw = mp.fetch(sim, out)
|
|
560
|
+
|
|
561
|
+
# Expected: [times_at_ap, None, ...] in 2PC; extract first non-None
|
|
562
|
+
if isinstance(times_raw, list) and len(times_raw) >= 1 and times_raw[-1] is None:
|
|
563
|
+
times_nodes = times_raw[0]
|
|
564
|
+
else:
|
|
565
|
+
times_nodes = times_raw
|
|
566
|
+
|
|
567
|
+
if args.breakdown:
|
|
568
|
+
times_arr = np.asarray(times_nodes, dtype=float)
|
|
569
|
+
# Expect shape (3, reps): [total, compute, decrypt]
|
|
570
|
+
if times_arr.ndim == 1:
|
|
571
|
+
# Fallback if flattened; try to split into 3 roughly equal parts
|
|
572
|
+
n = times_arr.size
|
|
573
|
+
r = n // 3
|
|
574
|
+
total = times_arr[:r]
|
|
575
|
+
comp = times_arr[r : 2 * r]
|
|
576
|
+
dec = times_arr[2 * r : 3 * r]
|
|
577
|
+
else:
|
|
578
|
+
total, comp, dec = (
|
|
579
|
+
times_arr[0].ravel(),
|
|
580
|
+
times_arr[1].ravel(),
|
|
581
|
+
times_arr[2].ravel(),
|
|
582
|
+
)
|
|
583
|
+
|
|
584
|
+
print(f"Per-rep total (s): {', '.join(f'{t:.4f}' for t in total.tolist())}")
|
|
585
|
+
print(
|
|
586
|
+
f"Per-rep compute-only (s): {', '.join(f'{t:.4f}' for t in comp.tolist())}"
|
|
587
|
+
)
|
|
588
|
+
print(
|
|
589
|
+
f"Per-rep decrypt-only (s): {', '.join(f'{t:.4f}' for t in dec.tolist())}"
|
|
590
|
+
)
|
|
591
|
+
print(
|
|
592
|
+
f"Averages — total: {float(total.mean()):.4f}s, compute: {float(comp.mean()):.4f}s, decrypt: {float(dec.mean()):.4f}s"
|
|
593
|
+
)
|
|
594
|
+
else:
|
|
595
|
+
# Convert to numpy array of floats (handle scalar, list, or numpy array)
|
|
596
|
+
if isinstance(times_nodes, list):
|
|
597
|
+
# elements are likely [val, None] pairs; take first
|
|
598
|
+
times_arr = np.array(
|
|
599
|
+
[
|
|
600
|
+
float(np.array(e[0]))
|
|
601
|
+
if isinstance(e, (list, tuple))
|
|
602
|
+
else float(np.array(e))
|
|
603
|
+
for e in times_nodes
|
|
604
|
+
],
|
|
605
|
+
dtype=float,
|
|
606
|
+
)
|
|
607
|
+
else:
|
|
608
|
+
times_arr = np.asarray(times_nodes, dtype=float).ravel()
|
|
609
|
+
avg = float(times_arr.mean())
|
|
610
|
+
print(f"Per-rep time (s): {', '.join(f'{t:.4f}' for t in times_arr.tolist())}")
|
|
611
|
+
print(f"Average time (s): {avg:.4f}")
|
|
612
|
+
|
|
613
|
+
|
|
614
|
+
if __name__ == "__main__":
|
|
615
|
+
main()
|