mplang-nightly 0.1.dev203__tar.gz → 0.1.dev266__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (380)
  1. {mplang_nightly-0.1.dev203 → mplang_nightly-0.1.dev266}/.gitignore +1 -0
  2. {mplang_nightly-0.1.dev203 → mplang_nightly-0.1.dev266}/PKG-INFO +11 -5
  3. {mplang_nightly-0.1.dev203 → mplang_nightly-0.1.dev266}/README.md +1 -1
  4. {mplang_nightly-0.1.dev203/examples → mplang_nightly-0.1.dev266/examples/v1}/stax_nn/stax_nn.py +2 -2
  5. mplang_nightly-0.1.dev266/examples/v1/xgboost/bench_fhe_hist.py +615 -0
  6. {mplang_nightly-0.1.dev203/examples → mplang_nightly-0.1.dev266/examples/v1}/xgboost/sgb.py +304 -218
  7. {mplang_nightly-0.1.dev203/examples → mplang_nightly-0.1.dev266/examples/v1}/xgboost/sgb_test.py +279 -70
  8. mplang_nightly-0.1.dev266/mplang/__init__.py +46 -0
  9. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/__init__.py +11 -11
  10. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/_device.py +63 -13
  11. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/analysis/__init__.py +1 -1
  12. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/analysis/diagram.py +4 -4
  13. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/__init__.py +20 -14
  14. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/comm.py +1 -1
  15. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/context_mgr.py +1 -1
  16. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/expr/__init__.py +7 -7
  17. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/expr/ast.py +10 -10
  18. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/expr/evaluator.py +8 -8
  19. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/expr/printer.py +6 -6
  20. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/expr/transformer.py +2 -2
  21. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/expr/utils.py +2 -2
  22. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/expr/visitor.py +1 -1
  23. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/expr/walk.py +1 -1
  24. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/interp.py +6 -6
  25. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/mpir.py +13 -11
  26. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/mpobject.py +6 -6
  27. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/mptype.py +7 -7
  28. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/pfunc.py +2 -2
  29. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/primitive.py +10 -10
  30. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/table.py +1 -1
  31. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/tensor.py +1 -1
  32. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/tracer.py +9 -9
  33. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/host.py +2 -2
  34. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/kernels/__init__.py +1 -1
  35. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/kernels/base.py +1 -1
  36. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/kernels/basic.py +13 -13
  37. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/kernels/context.py +14 -14
  38. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/kernels/crypto.py +4 -4
  39. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/kernels/fhe.py +9 -7
  40. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/kernels/mock_tee.py +3 -3
  41. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/kernels/phe.py +18 -14
  42. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/kernels/spu.py +5 -5
  43. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/kernels/sql_duckdb.py +5 -3
  44. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/kernels/stablehlo.py +18 -17
  45. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/kernels/value.py +1 -1
  46. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/ops/__init__.py +3 -2
  47. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/ops/base.py +1 -1
  48. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/ops/basic.py +3 -3
  49. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/ops/crypto.py +4 -4
  50. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/ops/fhe.py +2 -2
  51. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/ops/jax_cc.py +26 -59
  52. mplang_nightly-0.1.dev266/mplang/v1/ops/nnx_cc.py +168 -0
  53. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/ops/phe.py +16 -3
  54. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/ops/spu.py +3 -3
  55. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/ops/sql_cc.py +55 -48
  56. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/ops/tee.py +2 -2
  57. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/runtime/__init__.py +2 -2
  58. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/runtime/cli.py +3 -3
  59. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/runtime/client.py +1 -1
  60. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/runtime/communicator.py +2 -2
  61. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/runtime/data_providers.py +77 -15
  62. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/runtime/driver.py +4 -4
  63. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/runtime/server.py +12 -8
  64. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/runtime/session.py +13 -13
  65. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/runtime/simulation.py +6 -6
  66. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/simp/api.py +72 -5
  67. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/simp/mpi.py +1 -1
  68. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/simp/party.py +5 -5
  69. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/simp/random.py +2 -2
  70. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/simp/smpc.py +7 -7
  71. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/utils/table_utils.py +1 -1
  72. mplang_nightly-0.1.dev266/mplang/v2/__init__.py +424 -0
  73. mplang_nightly-0.1.dev266/mplang/v2/backends/__init__.py +57 -0
  74. mplang_nightly-0.1.dev266/mplang/v2/backends/bfv_impl.py +705 -0
  75. mplang_nightly-0.1.dev266/mplang/v2/backends/crypto_impl.py +723 -0
  76. mplang_nightly-0.1.dev266/mplang/v2/backends/field_impl.py +454 -0
  77. mplang_nightly-0.1.dev266/mplang/v2/backends/func_impl.py +107 -0
  78. mplang_nightly-0.1.dev266/mplang/v2/backends/phe_impl.py +148 -0
  79. mplang_nightly-0.1.dev266/mplang/v2/backends/simp_design.md +136 -0
  80. mplang_nightly-0.1.dev266/mplang/v2/backends/simp_driver/__init__.py +41 -0
  81. mplang_nightly-0.1.dev266/mplang/v2/backends/simp_driver/http.py +168 -0
  82. mplang_nightly-0.1.dev266/mplang/v2/backends/simp_driver/mem.py +280 -0
  83. mplang_nightly-0.1.dev266/mplang/v2/backends/simp_driver/ops.py +135 -0
  84. mplang_nightly-0.1.dev266/mplang/v2/backends/simp_driver/state.py +60 -0
  85. mplang_nightly-0.1.dev266/mplang/v2/backends/simp_driver/values.py +52 -0
  86. mplang_nightly-0.1.dev266/mplang/v2/backends/simp_worker/__init__.py +29 -0
  87. mplang_nightly-0.1.dev266/mplang/v2/backends/simp_worker/http.py +323 -0
  88. mplang_nightly-0.1.dev266/mplang/v2/backends/simp_worker/mem.py +99 -0
  89. mplang_nightly-0.1.dev266/mplang/v2/backends/simp_worker/ops.py +167 -0
  90. mplang_nightly-0.1.dev266/mplang/v2/backends/simp_worker/state.py +49 -0
  91. mplang_nightly-0.1.dev266/mplang/v2/backends/spu_impl.py +262 -0
  92. mplang_nightly-0.1.dev266/mplang/v2/backends/spu_state.py +124 -0
  93. mplang_nightly-0.1.dev266/mplang/v2/backends/store_impl.py +62 -0
  94. mplang_nightly-0.1.dev266/mplang/v2/backends/table_impl.py +838 -0
  95. mplang_nightly-0.1.dev266/mplang/v2/backends/tee_impl.py +215 -0
  96. mplang_nightly-0.1.dev266/mplang/v2/backends/tensor_impl.py +519 -0
  97. mplang_nightly-0.1.dev266/mplang/v2/cli.py +603 -0
  98. mplang_nightly-0.1.dev266/mplang/v2/cli_guide.md +122 -0
  99. mplang_nightly-0.1.dev266/mplang/v2/dialects/__init__.py +36 -0
  100. mplang_nightly-0.1.dev266/mplang/v2/dialects/bfv.py +665 -0
  101. mplang_nightly-0.1.dev266/mplang/v2/dialects/crypto.py +689 -0
  102. mplang_nightly-0.1.dev266/mplang/v2/dialects/dtypes.py +378 -0
  103. mplang_nightly-0.1.dev266/mplang/v2/dialects/field.py +210 -0
  104. mplang_nightly-0.1.dev266/mplang/v2/dialects/func.py +135 -0
  105. mplang_nightly-0.1.dev266/mplang/v2/dialects/phe.py +723 -0
  106. mplang_nightly-0.1.dev266/mplang/v2/dialects/simp.py +944 -0
  107. mplang_nightly-0.1.dev266/mplang/v2/dialects/spu.py +349 -0
  108. mplang_nightly-0.1.dev266/mplang/v2/dialects/store.py +63 -0
  109. mplang_nightly-0.1.dev266/mplang/v2/dialects/table.py +407 -0
  110. mplang_nightly-0.1.dev266/mplang/v2/dialects/tee.py +346 -0
  111. mplang_nightly-0.1.dev266/mplang/v2/dialects/tensor.py +1175 -0
  112. mplang_nightly-0.1.dev266/mplang/v2/edsl/README.md +279 -0
  113. mplang_nightly-0.1.dev266/mplang/v2/edsl/__init__.py +99 -0
  114. mplang_nightly-0.1.dev266/mplang/v2/edsl/context.py +311 -0
  115. mplang_nightly-0.1.dev266/mplang/v2/edsl/graph.py +463 -0
  116. mplang_nightly-0.1.dev266/mplang/v2/edsl/jit.py +62 -0
  117. mplang_nightly-0.1.dev266/mplang/v2/edsl/object.py +53 -0
  118. mplang_nightly-0.1.dev266/mplang/v2/edsl/primitive.py +284 -0
  119. mplang_nightly-0.1.dev266/mplang/v2/edsl/printer.py +119 -0
  120. mplang_nightly-0.1.dev266/mplang/v2/edsl/registry.py +207 -0
  121. mplang_nightly-0.1.dev266/mplang/v2/edsl/serde.py +375 -0
  122. mplang_nightly-0.1.dev266/mplang/v2/edsl/tracer.py +614 -0
  123. mplang_nightly-0.1.dev266/mplang/v2/edsl/typing.py +816 -0
  124. mplang_nightly-0.1.dev266/mplang/v2/kernels/Makefile +30 -0
  125. mplang_nightly-0.1.dev266/mplang/v2/kernels/__init__.py +23 -0
  126. mplang_nightly-0.1.dev266/mplang/v2/kernels/gf128.cpp +148 -0
  127. mplang_nightly-0.1.dev266/mplang/v2/kernels/ldpc.cpp +82 -0
  128. mplang_nightly-0.1.dev266/mplang/v2/kernels/okvs.cpp +283 -0
  129. mplang_nightly-0.1.dev266/mplang/v2/kernels/okvs_opt.cpp +291 -0
  130. mplang_nightly-0.1.dev266/mplang/v2/kernels/py_kernels.py +398 -0
  131. mplang_nightly-0.1.dev266/mplang/v2/libs/collective.py +330 -0
  132. mplang_nightly-0.1.dev266/mplang/v2/libs/device/__init__.py +51 -0
  133. mplang_nightly-0.1.dev266/mplang/v2/libs/device/api.py +813 -0
  134. mplang_nightly-0.1.dev266/mplang/v2/libs/device/cluster.py +352 -0
  135. mplang_nightly-0.1.dev266/mplang/v2/libs/ml/__init__.py +23 -0
  136. mplang_nightly-0.1.dev266/mplang/v2/libs/ml/sgb.py +1873 -0
  137. mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/__init__.py +41 -0
  138. mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/_utils.py +99 -0
  139. mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  140. mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  141. mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  142. mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  143. mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  144. mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/common/constants.py +39 -0
  145. mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/ot/__init__.py +32 -0
  146. mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/ot/base.py +222 -0
  147. mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/ot/extension.py +477 -0
  148. mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/ot/silent.py +217 -0
  149. mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/psi/__init__.py +40 -0
  150. mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  151. mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/psi/okvs.py +49 -0
  152. mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  153. mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/psi/oprf.py +310 -0
  154. mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/psi/rr22.py +344 -0
  155. mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  156. mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/vole/__init__.py +31 -0
  157. mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  158. mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  159. mplang_nightly-0.1.dev266/mplang/v2/libs/mpc/vole/silver.py +336 -0
  160. mplang_nightly-0.1.dev266/mplang/v2/runtime/__init__.py +15 -0
  161. mplang_nightly-0.1.dev266/mplang/v2/runtime/dialect_state.py +41 -0
  162. mplang_nightly-0.1.dev266/mplang/v2/runtime/interpreter.py +871 -0
  163. mplang_nightly-0.1.dev266/mplang/v2/runtime/object_store.py +194 -0
  164. mplang_nightly-0.1.dev266/mplang/v2/runtime/value.py +141 -0
  165. {mplang_nightly-0.1.dev203 → mplang_nightly-0.1.dev266}/pyproject.toml +39 -12
  166. mplang_nightly-0.1.dev266/tests/__init__.py +15 -0
  167. {mplang_nightly-0.1.dev203 → mplang_nightly-0.1.dev266}/tests/conftest.py +1 -1
  168. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/analysis/test_diagram.py +2 -2
  169. mplang_nightly-0.1.dev266/tests/v1/conftest.py +17 -0
  170. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/core/expr/conftest.py +4 -4
  171. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/core/expr/test_ast.py +4 -4
  172. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/core/expr/test_printer.py +7 -7
  173. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/core/expr/test_utils.py +3 -3
  174. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/core/expr/test_walk.py +3 -3
  175. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/core/test_cluster.py +1 -1
  176. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/core/test_dtype.py +5 -5
  177. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/core/test_mask.py +1 -1
  178. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/core/test_mpir.py +29 -29
  179. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/core/test_mptype.py +4 -4
  180. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/core/test_primitive.py +14 -14
  181. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/core/test_table.py +2 -2
  182. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/core/test_tensor.py +3 -3
  183. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/core/test_tracer.py +25 -19
  184. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/device/test_device_basic.py +1 -1
  185. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/integration/test_crypto_roundtrip.py +2 -2
  186. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/integration/test_http_e2e.py +1 -1
  187. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/integration/test_symbols_roundtrip.py +8 -6
  188. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/integration/test_tee_workflow.py +2 -2
  189. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/integration/test_tutorials.py +1 -1
  190. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/integration/test_unused_param_integration.py +2 -2
  191. mplang_nightly-0.1.dev266/tests/v1/kernels/__init__.py +13 -0
  192. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/kernels/test_basic.py +15 -7
  193. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/kernels/test_debug_print.py +5 -5
  194. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/kernels/test_fhe.py +39 -30
  195. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/kernels/test_kernel_binding.py +10 -8
  196. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/kernels/test_phe.py +7 -7
  197. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/kernels/test_spu.py +12 -12
  198. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/kernels/test_sql_duckdb.py +4 -4
  199. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/kernels/test_stablehlo.py +9 -9
  200. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/kernels/test_value.py +5 -5
  201. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/kernels/test_value_serde.py +5 -5
  202. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/ops/dummy.py +5 -5
  203. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/ops/test_basic_pack.py +6 -6
  204. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/ops/test_crypto.py +6 -6
  205. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/ops/test_feop_base.py +5 -5
  206. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/ops/test_jax_cc.py +1 -1
  207. mplang_nightly-0.1.dev266/tests/v1/ops/test_nnx_cc.py +265 -0
  208. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/ops/test_phe.py +3 -3
  209. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/ops/test_spu.py +2 -2
  210. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/ops/test_spu_defensive.py +3 -3
  211. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/ops/test_sql.py +8 -6
  212. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/ops/test_sql_cc.py +9 -9
  213. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/ops/test_table_tensor_conversion.py +7 -7
  214. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/runtime/test_cli.py +12 -6
  215. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/runtime/test_communicator.py +7 -7
  216. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/runtime/test_driver.py +3 -3
  217. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/runtime/test_server.py +4 -4
  218. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/runtime/test_simulation.py +6 -6
  219. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/simp/test_mpi.py +4 -4
  220. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/simp/test_random.py +1 -1
  221. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/simp/test_smpc.py +4 -4
  222. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/simp/test_sugar.py +2 -2
  223. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/utils/server_fixtures.py +1 -1
  224. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/utils/test_func_utils.py +1 -1
  225. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/utils/test_spu_utils.py +1 -1
  226. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/utils/test_table_utils.py +1 -1
  227. mplang_nightly-0.1.dev266/tests/v2/__init__.py +13 -0
  228. mplang_nightly-0.1.dev266/tests/v2/backends/__init__.py +13 -0
  229. mplang_nightly-0.1.dev266/tests/v2/backends/simp_driver/__init__.py +15 -0
  230. mplang_nightly-0.1.dev266/tests/v2/backends/simp_driver/test_http.py +215 -0
  231. mplang_nightly-0.1.dev266/tests/v2/backends/simp_worker/__init__.py +15 -0
  232. mplang_nightly-0.1.dev266/tests/v2/backends/simp_worker/test_http.py +225 -0
  233. mplang_nightly-0.1.dev266/tests/v2/backends/simp_worker/test_mem.py +102 -0
  234. mplang_nightly-0.1.dev266/tests/v2/backends/test_bfv_impl.py +344 -0
  235. mplang_nightly-0.1.dev266/tests/v2/backends/test_crypto_impl.py +541 -0
  236. mplang_nightly-0.1.dev266/tests/v2/backends/test_okvs_binding.py +115 -0
  237. mplang_nightly-0.1.dev266/tests/v2/backends/test_simp_integration.py +196 -0
  238. mplang_nightly-0.1.dev266/tests/v2/backends/test_simp_object_store.py +106 -0
  239. mplang_nightly-0.1.dev266/tests/v2/backends/test_spu_impl.py +115 -0
  240. mplang_nightly-0.1.dev266/tests/v2/backends/test_table_impl.py +436 -0
  241. mplang_nightly-0.1.dev266/tests/v2/backends/test_tee_impl.py +148 -0
  242. mplang_nightly-0.1.dev266/tests/v2/backends/test_tensor_impl.py +177 -0
  243. mplang_nightly-0.1.dev266/tests/v2/backends/test_verify_clean.py +123 -0
  244. mplang_nightly-0.1.dev266/tests/v2/conftest.py +54 -0
  245. mplang_nightly-0.1.dev266/tests/v2/dialects/__init__.py +13 -0
  246. mplang_nightly-0.1.dev266/tests/v2/dialects/test_bfv.py +178 -0
  247. mplang_nightly-0.1.dev266/tests/v2/dialects/test_crypto.py +214 -0
  248. mplang_nightly-0.1.dev266/tests/v2/dialects/test_dtypes.py +219 -0
  249. mplang_nightly-0.1.dev266/tests/v2/dialects/test_field.py +49 -0
  250. mplang_nightly-0.1.dev266/tests/v2/dialects/test_func.py +130 -0
  251. mplang_nightly-0.1.dev266/tests/v2/dialects/test_okvs.py +56 -0
  252. mplang_nightly-0.1.dev266/tests/v2/dialects/test_okvs_bench.py +55 -0
  253. mplang_nightly-0.1.dev266/tests/v2/dialects/test_phe.py +531 -0
  254. mplang_nightly-0.1.dev266/tests/v2/dialects/test_simp.py +564 -0
  255. mplang_nightly-0.1.dev266/tests/v2/dialects/test_simp_comm.py +190 -0
  256. mplang_nightly-0.1.dev266/tests/v2/dialects/test_spu.py +214 -0
  257. mplang_nightly-0.1.dev266/tests/v2/dialects/test_table.py +60 -0
  258. mplang_nightly-0.1.dev266/tests/v2/dialects/test_tee.py +156 -0
  259. mplang_nightly-0.1.dev266/tests/v2/dialects/test_tensor.py +196 -0
  260. mplang_nightly-0.1.dev266/tests/v2/edsl/__init__.py +15 -0
  261. mplang_nightly-0.1.dev266/tests/v2/edsl/test_context.py +164 -0
  262. mplang_nightly-0.1.dev266/tests/v2/edsl/test_graph.py +664 -0
  263. mplang_nightly-0.1.dev266/tests/v2/edsl/test_primitive.py +252 -0
  264. mplang_nightly-0.1.dev266/tests/v2/edsl/test_primitive_multi_output.py +269 -0
  265. mplang_nightly-0.1.dev266/tests/v2/edsl/test_printer.py +100 -0
  266. mplang_nightly-0.1.dev266/tests/v2/edsl/test_serde.py +279 -0
  267. mplang_nightly-0.1.dev266/tests/v2/edsl/test_tracer.py +309 -0
  268. mplang_nightly-0.1.dev266/tests/v2/edsl/test_typing.py +836 -0
  269. mplang_nightly-0.1.dev266/tests/v2/edsl/test_typing_graph_serde.py +346 -0
  270. mplang_nightly-0.1.dev266/tests/v2/libs/device/__init__.py +13 -0
  271. mplang_nightly-0.1.dev266/tests/v2/libs/device/conftest.py +245 -0
  272. mplang_nightly-0.1.dev266/tests/v2/libs/device/test_device_api_errors.py +355 -0
  273. mplang_nightly-0.1.dev266/tests/v2/libs/device/test_device_dialects.py +345 -0
  274. mplang_nightly-0.1.dev266/tests/v2/libs/device/test_device_layouts.py +357 -0
  275. mplang_nightly-0.1.dev266/tests/v2/libs/device/test_device_tee.py +220 -0
  276. mplang_nightly-0.1.dev266/tests/v2/libs/ml/__init__.py +15 -0
  277. mplang_nightly-0.1.dev266/tests/v2/libs/ml/test_sgb.py +164 -0
  278. mplang_nightly-0.1.dev266/tests/v2/libs/ml/test_sgb_bench.py +401 -0
  279. mplang_nightly-0.1.dev266/tests/v2/libs/mpc/__init__.py +13 -0
  280. mplang_nightly-0.1.dev266/tests/v2/libs/mpc/analytics/__init__.py +15 -0
  281. mplang_nightly-0.1.dev266/tests/v2/libs/mpc/analytics/test_aggregation.py +77 -0
  282. mplang_nightly-0.1.dev266/tests/v2/libs/mpc/analytics/test_groupby.py +207 -0
  283. mplang_nightly-0.1.dev266/tests/v2/libs/mpc/analytics/test_permutation.py +108 -0
  284. mplang_nightly-0.1.dev266/tests/v2/libs/mpc/ot/__init__.py +15 -0
  285. mplang_nightly-0.1.dev266/tests/v2/libs/mpc/ot/test_ot.py +196 -0
  286. mplang_nightly-0.1.dev266/tests/v2/libs/mpc/ot/test_ot_extension.py +58 -0
  287. mplang_nightly-0.1.dev266/tests/v2/libs/mpc/ot/test_silent_ot.py +101 -0
  288. mplang_nightly-0.1.dev266/tests/v2/libs/mpc/psi/__init__.py +15 -0
  289. mplang_nightly-0.1.dev266/tests/v2/libs/mpc/psi/test_okvs_gct.py +112 -0
  290. mplang_nightly-0.1.dev266/tests/v2/libs/mpc/psi/test_oprf.py +79 -0
  291. mplang_nightly-0.1.dev266/tests/v2/libs/mpc/psi/test_psi.py +115 -0
  292. mplang_nightly-0.1.dev266/tests/v2/libs/mpc/psi/test_psi_bench.py +155 -0
  293. mplang_nightly-0.1.dev266/tests/v2/libs/mpc/psi/test_rr22.py +288 -0
  294. mplang_nightly-0.1.dev266/tests/v2/libs/mpc/psi/verify_psi_okvs_logic.py +164 -0
  295. mplang_nightly-0.1.dev266/tests/v2/libs/mpc/test_field_gf128.py +97 -0
  296. mplang_nightly-0.1.dev266/tests/v2/libs/mpc/test_utils.py +82 -0
  297. mplang_nightly-0.1.dev266/tests/v2/libs/mpc/vole/__init__.py +15 -0
  298. mplang_nightly-0.1.dev266/tests/v2/libs/mpc/vole/test_gilboa_manual.py +137 -0
  299. mplang_nightly-0.1.dev266/tests/v2/libs/mpc/vole/test_ldpc.py +179 -0
  300. mplang_nightly-0.1.dev266/tests/v2/libs/mpc/vole/test_silver_vole.py +183 -0
  301. mplang_nightly-0.1.dev266/tests/v2/libs/mpc/vole/test_vole.py +87 -0
  302. mplang_nightly-0.1.dev266/tests/v2/libs/mpc/vole/test_vole_bench.py +165 -0
  303. mplang_nightly-0.1.dev266/tests/v2/libs/mpc/vole/verify_vole_logic.py +171 -0
  304. mplang_nightly-0.1.dev266/tests/v2/libs/test_collective.py +251 -0
  305. mplang_nightly-0.1.dev266/tests/v2/libs/test_simple_guide.py +60 -0
  306. mplang_nightly-0.1.dev266/tests/v2/runtime/test_interpreter_async.py +113 -0
  307. mplang_nightly-0.1.dev266/tests/v2/runtime/test_object_store.py +55 -0
  308. mplang_nightly-0.1.dev266/tests/v2/runtime/test_object_store_fs.py +104 -0
  309. mplang_nightly-0.1.dev266/tests/v2/test_fetch_semantics.py +129 -0
  310. mplang_nightly-0.1.dev266/tests/v2/test_pytree_io.py +231 -0
  311. mplang_nightly-0.1.dev266/tests/v2/test_store.py +89 -0
  312. mplang_nightly-0.1.dev266/tests/v2/utils/__init__.py +13 -0
  313. mplang_nightly-0.1.dev266/tests/v2/utils/tensor_patch.py +131 -0
  314. mplang_nightly-0.1.dev266/tutorials/MIGRATION.md +191 -0
  315. {mplang_nightly-0.1.dev203 → mplang_nightly-0.1.dev266}/tutorials/run.sh +18 -15
  316. {mplang_nightly-0.1.dev203/tutorials → mplang_nightly-0.1.dev266/tutorials/v1}/device/00_device_basics.py +1 -1
  317. {mplang_nightly-0.1.dev203/tutorials → mplang_nightly-0.1.dev266/tutorials/v1}/device/01_function_decorator.py +1 -1
  318. {mplang_nightly-0.1.dev203/tutorials → mplang_nightly-0.1.dev266/tutorials/v1}/device/02_simulation_and_driver.py +6 -3
  319. {mplang_nightly-0.1.dev203/tutorials → mplang_nightly-0.1.dev266/tutorials/v1}/device/03_run_jax.py +1 -1
  320. {mplang_nightly-0.1.dev203/tutorials → mplang_nightly-0.1.dev266/tutorials/v1}/device/04_run_sql.py +8 -8
  321. {mplang_nightly-0.1.dev203/tutorials → mplang_nightly-0.1.dev266/tutorials/v1}/device/05_pipeline.py +3 -3
  322. {mplang_nightly-0.1.dev203/tutorials → mplang_nightly-0.1.dev266/tutorials/v1}/device/06_ir_dump_and_analysis.py +1 -1
  323. mplang_nightly-0.1.dev266/tutorials/v1/device/07_run_nnx.py +627 -0
  324. {mplang_nightly-0.1.dev203/tutorials → mplang_nightly-0.1.dev266/tutorials/v1}/pitfalls/late_binding.py +1 -1
  325. {mplang_nightly-0.1.dev203/tutorials → mplang_nightly-0.1.dev266/tutorials/v1}/pitfalls/rand.py +1 -1
  326. {mplang_nightly-0.1.dev203/tutorials → mplang_nightly-0.1.dev266/tutorials/v1}/simp/00_basic.py +1 -1
  327. {mplang_nightly-0.1.dev203/tutorials → mplang_nightly-0.1.dev266/tutorials/v1}/simp/01_condition.py +1 -1
  328. {mplang_nightly-0.1.dev203/tutorials → mplang_nightly-0.1.dev266/tutorials/v1}/simp/02_whileloop.py +1 -1
  329. {mplang_nightly-0.1.dev203/tutorials → mplang_nightly-0.1.dev266/tutorials/v1}/simp/03_stdio.py +2 -2
  330. {mplang_nightly-0.1.dev203/tutorials → mplang_nightly-0.1.dev266/tutorials/v1}/simp/04_phe.py +2 -2
  331. {mplang_nightly-0.1.dev203/tutorials → mplang_nightly-0.1.dev266/tutorials/v1}/simp/05_tee.py +48 -1
  332. {mplang_nightly-0.1.dev203/tutorials → mplang_nightly-0.1.dev266/tutorials/v1}/simp/06_fhe.py +2 -2
  333. {mplang_nightly-0.1.dev203/tutorials → mplang_nightly-0.1.dev266/tutorials/v1}/simp/07_advanced.py +1 -1
  334. mplang_nightly-0.1.dev266/tutorials/v1/simp/08_simple_secret_sharing.py +291 -0
  335. mplang_nightly-0.1.dev266/tutorials/v2/00_device_basics.py +190 -0
  336. mplang_nightly-0.1.dev266/tutorials/v2/01_function_decorator.py +135 -0
  337. mplang_nightly-0.1.dev266/tutorials/v2/02_simulation_and_driver.py +125 -0
  338. mplang_nightly-0.1.dev266/tutorials/v2/03_run_jax.py +131 -0
  339. mplang_nightly-0.1.dev266/tutorials/v2/04_ir_dump_and_analysis.py +115 -0
  340. mplang_nightly-0.1.dev266/tutorials/v2/05_run_sql.py +193 -0
  341. mplang_nightly-0.1.dev266/tutorials/v2/06_pipeline.py +279 -0
  342. mplang_nightly-0.1.dev266/tutorials/v2/07_stax_nn.py +272 -0
  343. mplang_nightly-0.1.dev266/tutorials/v2/__init__.py +15 -0
  344. {mplang_nightly-0.1.dev203 → mplang_nightly-0.1.dev266}/LICENSE +0 -0
  345. {mplang_nightly-0.1.dev203/examples → mplang_nightly-0.1.dev266/examples/v1}/conf/3pc.yaml +0 -0
  346. {mplang_nightly-0.1.dev203/examples → mplang_nightly-0.1.dev266/examples/v1}/stax_nn/README.md +0 -0
  347. {mplang_nightly-0.1.dev203/examples → mplang_nightly-0.1.dev266/examples/v1}/stax_nn/models.py +0 -0
  348. {mplang_nightly-0.1.dev203/examples → mplang_nightly-0.1.dev266/examples/v1}/xgboost/hist_jax.py +0 -0
  349. {mplang_nightly-0.1.dev203/examples → mplang_nightly-0.1.dev266/examples/v1}/xgboost/hist_jax_test.py +0 -0
  350. {mplang_nightly-0.1.dev203/examples → mplang_nightly-0.1.dev266/examples/v1}/xgboost/naive_np.py +0 -0
  351. {mplang_nightly-0.1.dev203/examples → mplang_nightly-0.1.dev266/examples/v1}/xgboost/readme.md +0 -0
  352. {mplang_nightly-0.1.dev203 → mplang_nightly-0.1.dev266}/hatch_build.py +0 -0
  353. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/cluster.py +0 -0
  354. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/dtypes.py +0 -0
  355. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/core/mask.py +0 -0
  356. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/protos/v1alpha1/mpir_pb2.py +0 -0
  357. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/protos/v1alpha1/mpir_pb2.pyi +0 -0
  358. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/protos/v1alpha1/value_pb2.py +0 -0
  359. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/protos/v1alpha1/value_pb2.pyi +0 -0
  360. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/runtime/exceptions.py +0 -0
  361. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/runtime/http_api.md +0 -0
  362. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/runtime/link_comm.py +0 -0
  363. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/simp/__init__.py +0 -0
  364. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/utils/__init__.py +0 -0
  365. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/utils/crypto.py +0 -0
  366. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/utils/func_utils.py +0 -0
  367. {mplang_nightly-0.1.dev203/mplang → mplang_nightly-0.1.dev266/mplang/v1}/utils/spu_utils.py +0 -0
  368. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/__init__.py +0 -0
  369. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/core/__init__.py +0 -0
  370. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/core/expr/__init__.py +0 -0
  371. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/device/__init__.py +0 -0
  372. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/integration/README.md +0 -0
  373. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/ops/__init__.py +0 -0
  374. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/runtime/__init__.py +0 -0
  375. {mplang_nightly-0.1.dev203/tests → mplang_nightly-0.1.dev266/tests/v1}/utils/__init__.py +0 -0
  376. {mplang_nightly-0.1.dev203 → mplang_nightly-0.1.dev266}/tutorials/README.md +0 -0
  377. {mplang_nightly-0.1.dev203 → mplang_nightly-0.1.dev266}/tutorials/__init__.py +0 -0
  378. {mplang_nightly-0.1.dev203 → mplang_nightly-0.1.dev266}/tutorials/data/alice.csv +0 -0
  379. {mplang_nightly-0.1.dev203 → mplang_nightly-0.1.dev266}/tutorials/data/bob.csv +0 -0
  380. {mplang_nightly-0.1.dev203 → mplang_nightly-0.1.dev266}/tutorials/data/prepare_vertical_iris.py +0 -0
@@ -208,6 +208,7 @@ __marimo__/
208
208
 
209
209
  # mplang
210
210
  app.log
211
+ .mpl/
211
212
 
212
213
  # Ignore pylint config files (we use ruff instead)
213
214
  .pylintrc
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mplang-nightly
3
- Version: 0.1.dev203
3
+ Version: 0.1.dev266
4
4
  Summary: Multi-Party Programming Language
5
5
  Author-email: SecretFlow Team <secretflow-contact@service.alipay.com>
6
6
  License: Apache License
@@ -205,15 +205,21 @@ License: Apache License
205
205
  See the License for the specific language governing permissions and
206
206
  limitations under the License.
207
207
  License-File: LICENSE
208
- Requires-Python: >=3.10
208
+ Requires-Python: <3.13,>=3.11
209
+ Requires-Dist: coincurve>=20.0.0
210
+ Requires-Dist: cryptography>=43.0.0
209
211
  Requires-Dist: duckdb>=1.0.0
210
212
  Requires-Dist: fastapi
211
- Requires-Dist: httpx>=0.27.0
213
+ Requires-Dist: flax>=0.12.0
214
+ Requires-Dist: httpx<1.0.0,>=0.27.0
215
+ Requires-Dist: jax[cpu]==0.8.0
212
216
  Requires-Dist: lightphe<0.1.0,>=0.0.15
217
+ Requires-Dist: numpy>=2.0.0
213
218
  Requires-Dist: pandas>=2.0.0
214
219
  Requires-Dist: protobuf<6.0,>=5.0
215
220
  Requires-Dist: pyarrow>=14.0.0
216
- Requires-Dist: spu==0.9.4.dev20250827
221
+ Requires-Dist: pyyaml>=6.0
222
+ Requires-Dist: spu>=0.10.0.dev20251208
217
223
  Requires-Dist: sqlglot>=23.0.0
218
224
  Requires-Dist: tenseal==0.3.16
219
225
  Requires-Dist: typing-extensions
@@ -242,7 +248,7 @@ multiple parties in a synchronous, SPMD (Single Program, Multiple Data) fashion.
242
248
 
243
249
  ### Installation
244
250
 
245
- You'll need a modern Python environment (3.10+). We recommend using `uv` for fast installation.
251
+ You'll need a modern Python environment (3.11+). We recommend using `uv` for fast installation.
246
252
 
247
253
  ```bash
248
254
  # Install uv (if not already installed)
@@ -20,7 +20,7 @@ multiple parties in a synchronous, SPMD (Single Program, Multiple Data) fashion.
20
20
 
21
21
  ### Installation
22
22
 
23
- You'll need a modern Python environment (3.10+). We recommend using `uv` for fast installation.
23
+ You'll need a modern Python environment (3.11+). We recommend using `uv` for fast installation.
24
24
 
25
25
  ```bash
26
26
  # Install uv (if not already installed)
@@ -27,11 +27,11 @@ import yaml
27
27
  from jax.example_libraries import stax
28
28
  from sklearn.metrics import accuracy_score
29
29
 
30
- import mplang as mp
30
+ import mplang.v1 as mp
31
31
 
32
32
  parser = argparse.ArgumentParser(description="distributed driver.")
33
33
  parser.add_argument("--model", default="network_a", type=str)
34
- parser.add_argument("-c", "--config", default="examples/conf/3pc.yaml", type=str)
34
+ parser.add_argument("-c", "--config", default="examples/v1/conf/3pc.yaml", type=str)
35
35
  parser.add_argument("-e", "--epoch", default=5, type=int)
36
36
  parser.add_argument("-b", "--batch_size", default=128, type=int)
37
37
  parser.add_argument("-o", "--optimizer", default="SGD", type=str)
@@ -0,0 +1,615 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Microbenchmark for FHE(BFV) histogram path in SecureBoost.
17
+
18
+ This script measures the time to compute PP-side cumulative bucket sums
19
+ via encrypted ct·ct dot products using TenSEAL/SEAL BFV vector backend.
20
+
21
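+ It provides the following modes (``classic_cached`` / ``interleaved_cached`` additionally pre-encrypt all bucket masks once, before the timed repetitions):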
22
+ - classic: separate g/h ciphertexts + ct·ct dot (current training path)
23
+ - interleaved: interleave g/h into one ct, do one ct·ct mul + two ct·pt dots (even/odd)
24
+
25
+ Usage examples:
26
+ uv run -q python examples/v1/xgboost/bench_fhe_hist.py --world-size 2 --m 4096 --n-total 16 --n-ap 4 --k 16 --t 4 --reps 3 --mode classic
27
+ uv run -q python examples/v1/xgboost/bench_fhe_hist.py --world-size 2 --m 4096 --n-total 16 --n-ap 4 --k 16 --t 4 --reps 3 --mode interleaved
28
+ """
29
+
30
+ from __future__ import annotations
31
+
32
+ import argparse
33
+ import time
34
+ from functools import partial
35
+
36
+ import jax
37
+ import jax.numpy as jnp
38
+ import numpy as np
39
+ from examples.xgboost.sgb import (
40
+ DEFAULT_FXP_BITS,
41
+ batch_feature_wise_bucket_sum_fhe_vector,
42
+ build_bins_equi_width,
43
+ compute_bin_indices,
44
+ compute_gh,
45
+ compute_init_pred,
46
+ quantize_gh,
47
+ )
48
+
49
+ import mplang.v1 as mp
50
+ from mplang.v1.ops import fhe
51
+
52
+
53
+ def _gen_data(n_samples: int, n_total_features: int, n_features_ap: int, seed: int):
54
+ rng = np.random.default_rng(seed)
55
+ X = rng.normal(size=(n_samples, n_total_features)).astype(np.float32)
56
+ # make a simple linear label with noise
57
+ w = rng.normal(size=(n_total_features,)).astype(np.float32)
58
+ z = X @ w + 0.1 * rng.normal(size=(n_samples,)).astype(np.float32)
59
+ p = 1 / (1 + np.exp(-z))
60
+ y = (p > 0.5).astype(np.float32)
61
+
62
+ X_ap = X[:, :n_features_ap]
63
+ X_pp = X[:, n_features_ap:]
64
+ return X_ap, X_pp, y
65
+
66
+
67
# Histogram modes implemented by ``_bench_once`` (the same set the CLI offers).
_BENCH_MODES = ("classic", "classic_cached", "interleaved", "interleaved_cached")
# Recognised but not implemented: nothing builds parity selectors or a fused
# kernel for them, so they are rejected up front instead of failing later when
# ``fhe.dot`` receives a ``None`` selector.
_UNIMPLEMENTED_MODES = ("interleaved_fused", "interleaved_fused_cached")


def _identity(x):
    """Return ``x`` unchanged (used to place host data on a party)."""
    return x


def _sub(a, b):
    """Return ``a - b`` (elapsed time between two timestamps)."""
    return a - b


def _take(arr, idx):
    """Return ``arr[idx]``."""
    return arr[idx]


def _stack(*xs):
    """Stack the positional arrays along a new leading axis."""
    return jnp.stack(xs)


def _sum_all(*xs):
    """Return the elementwise sum of the (non-empty) positional arguments."""
    total = xs[0]
    for x in xs[1:]:
        total = total + x
    return total


def _zero_seconds():
    """Return a float64 zero duration."""
    return jnp.array(0.0, dtype=jnp.float64)


def _interleave(a, b):
    """Interleave two equal-length 1-D arrays as ``[a0, b0, a1, b1, ...]`` (int64)."""
    out = jnp.empty((a.shape[0] * 2,), dtype=jnp.int64)
    return out.at[0::2].set(a).at[1::2].set(b)


def _dup2(mask):
    """Duplicate every element so a mask lines up with an interleaved g/h vector."""
    return _interleave(mask, mask)


def _mk_subgroup_map(bt_level, group_size):
    """Return a ``(group_size, len(bt_level))`` int8 one-hot map of sample -> group."""
    group_indices = jnp.arange(group_size)[:, None]
    return (group_indices == bt_level).astype(jnp.int8)


def _parity_selectors(m_samples):
    """Return int64 selectors of the even and odd slots of a length-``2*m`` vector."""
    n = m_samples * 2
    even = jnp.zeros((n,), dtype=jnp.int64).at[0::2].set(1)
    odd = jnp.zeros((n,), dtype=jnp.int64).at[1::2].set(1)
    return even, odd


def _masked_order_map(group_mask, order_map):
    """Keep ``order_map`` rows whose ``group_mask`` is 1 and set the other rows to -1."""
    mask_full = jnp.broadcast_to(jnp.expand_dims(group_mask, axis=1), order_map.shape)
    return jnp.where(mask_full == 1, order_map, -1)


def _cumulative_bucket_mask(gom, f_idx, b_idx):
    """Return an int64 0/1 mask of samples with ``0 <= gom[:, f_idx] <= b_idx``."""
    fb = gom[:, f_idx]
    return ((fb >= 0) & (fb <= b_idx)).astype(jnp.int64)


def _cumulative_bucket_masks(gom, f_idx, num_buckets):
    """Return all cumulative bucket masks of one feature, stacked as ``(num_buckets, M)``."""
    bs = jnp.arange(num_buckets, dtype=jnp.int64)
    return jax.vmap(_cumulative_bucket_mask, in_axes=(None, None, 0))(gom, f_idx, bs)


def _stamp(rank):
    """Record a wall-clock timestamp (float64 seconds) at ``rank``.

    A new lambda is built on every call, as the original inline lambdas were,
    so no two timestamps share a function object. That matters in case
    jax/mplang cache traces per function.
    NOTE(review): time.time() runs when the lambda is traced. If run_jax_at
    traces/lowers during mp.function tracing rather than at execution, these
    stamps measure tracing rather than FHE execution -- confirm against
    mplang's execution model. float64 presumably relies on JAX x64 mode
    (float32 would quantize epoch seconds to ~128 s).
    """
    return mp.run_jax_at(rank, lambda: jnp.array(time.time(), dtype=jnp.float64))


@mp.function
def _bench_once(
    ap_id: int,
    pp_ids: list[int],
    X_ap: np.ndarray,
    X_pp_splits: list[np.ndarray],
    y: np.ndarray,
    k: int,
    t: int,
    reps: int,
    mode: str,
    include_precompute: bool,
    breakdown: bool,
):
    """Run the FHE histogram benchmark and return per-repetition timings.

    Args:
        ap_id: Rank of the AP (holds labels, FHE keys and the g/h ciphertexts).
        pp_ids: Ranks of the PPs, aligned with ``X_pp_splits``.
        X_ap: Host array of the AP's features.
        X_pp_splits: Host arrays of each PP's features.
        y: Host label vector.
        k: Bins per feature.
        t: Number of sample groups (nodes at the level).
        reps: Number of timed repetitions; must be >= 1.
        mode: One of ``_BENCH_MODES``.
        include_precompute: Add the cached-mask precompute time to repetition 0.
        breakdown: Also return compute-only and decrypt-only timings.

    Returns:
        An AP-resident vector of per-repetition total times. When ``breakdown``
        is set, the return value is instead those totals stacked with the
        compute and decrypt times as ``[total, compute, decrypt]``.

    Raises:
        ValueError: If ``mode`` is unknown or unimplemented, or ``reps < 1``.
    """
    if mode in _UNIMPLEMENTED_MODES:
        raise ValueError(f"Mode not implemented: {mode}")
    if mode not in _BENCH_MODES:
        raise ValueError(f"Unknown mode: {mode}")
    if reps < 1:
        # Stacking zero per-rep timings would fail with an opaque jnp.stack error.
        raise ValueError(f"reps must be >= 1, got {reps}")
    interleaved = mode.startswith("interleaved")
    cached = mode.endswith("_cached")

    # Place data
    X_ap_j = mp.run_jax_at(ap_id, _identity, jnp.array(X_ap, dtype=jnp.float32))
    X_pp_j = [
        mp.run_jax_at(pp, _identity, jnp.array(xpp, dtype=jnp.float32))
        for pp, xpp in zip(pp_ids, X_pp_splits, strict=True)
    ]
    y_j = mp.run_jax_at(ap_id, _identity, jnp.array(y, dtype=jnp.float32))
    n_samples = y_j.shape[0]

    # Binning per party (the AP's own bin indices are computed but not used).
    build_bins_vmapped = jax.vmap(partial(build_bins_equi_width, max_bin=k), in_axes=1)
    compute_indices_vmapped = jax.vmap(compute_bin_indices, in_axes=(1, 0), out_axes=1)

    ap_bins = mp.run_jax_at(ap_id, build_bins_vmapped, X_ap_j)
    _ = mp.run_jax_at(ap_id, compute_indices_vmapped, X_ap_j, ap_bins)

    pp_bins = [
        mp.run_jax_at(pp, build_bins_vmapped, X_pp_j[i]) for i, pp in enumerate(pp_ids)
    ]
    pp_idx = [
        mp.run_jax_at(pp, compute_indices_vmapped, X_pp_j[i], pp_bins[i])
        for i, pp in enumerate(pp_ids)
    ]

    # AP: gradients/hessians, fixed-point quantization, key generation.
    init_pred = mp.run_jax_at(ap_id, compute_init_pred, y_j)
    logits0 = mp.run_jax_at(ap_id, lambda p, m=n_samples: p * jnp.ones(m), init_pred)
    GH = mp.run_jax_at(ap_id, compute_gh, y_j, logits0)

    fxp_scale = 1 << DEFAULT_FXP_BITS
    Q = mp.run_jax_at(ap_id, quantize_gh, GH, fxp_scale)
    qg = mp.run_jax_at(ap_id, lambda a: a[:, 0].astype(jnp.int64), Q)
    qh = mp.run_jax_at(ap_id, lambda a: a[:, 1].astype(jnp.int64), Q)

    priv_ctx, pub_ctx, _ = mp.run_at(ap_id, fhe.keygen, scheme="BFV")

    # Prepare ciphertext(s): one interleaved [g0, h0, g1, h1, ...] ct, or separate g/h cts.
    g_ct = h_ct = gh_ct = None
    if interleaved:
        qgh = mp.run_jax_at(ap_id, _interleave, qg, qh)
        gh_ct = mp.run_at(ap_id, fhe.encrypt, qgh, pub_ctx)
    else:
        g_ct = mp.run_at(ap_id, fhe.encrypt, qg, pub_ctx)
        h_ct = mp.run_at(ap_id, fhe.encrypt, qh, pub_ctx)

    # Fixed pseudo-random assignment of each sample to one of t groups (seed 0).
    group_assign = mp.run_jax_at(
        ap_id,
        lambda m: jnp.array(
            np.random.default_rng(0).integers(0, t, size=m), dtype=jnp.int64
        ),
        n_samples,
    )

    # Per-PP public context and subgroup map, computed once (assignment is fixed).
    subgroup_maps = []
    for pp in pp_ids:
        pub_ctx_pp = mp.p2p(ap_id, pp, pub_ctx)
        assign_pp = mp.p2p(ap_id, pp, group_assign)
        subgroup_map_pp = mp.run_jax_at(pp, _mk_subgroup_map, assign_pp, t)
        subgroup_maps.append((pp, pub_ctx_pp, subgroup_map_pp))

    even_sel = odd_sel = None
    if interleaved:
        even_sel, odd_sel = mp.run_jax_at(ap_id, _parity_selectors, n_samples)

    def _group_order_map(i, pp, subgroup_map_pp, grp):
        """Bin indices of PP ``i`` with samples outside group ``grp`` set to -1 (at pp)."""
        gom = mp.run_jax_at(pp, _take, subgroup_map_pp, grp)
        return mp.run_jax_at(pp, _masked_order_map, gom, pp_idx[i])

    # Cached modes: encrypt every (group, feature, bucket) mask once, up front.
    # Layout: cached_masks[pp][group][feature][bucket] -> mask ciphertext at the AP.
    cached_masks = None
    pre_dt = mp.run_jax_at(ap_id, _zero_seconds)
    if cached:
        cached_masks = []
        tpre0 = _stamp(ap_id)
        for i, (pp, pub_ctx_pp, subgroup_map_pp) in enumerate(subgroup_maps):
            feature_size = pp_idx[i].shape[1]
            grp_masks = []
            for grp in range(t):
                gom_map = _group_order_map(i, pp, subgroup_map_pp, grp)
                feat_masks = []
                for feature_idx in range(feature_size):
                    bucket_masks = mp.run_jax_at(
                        pp, _cumulative_bucket_masks, gom_map, feature_idx, k
                    )
                    masks_ct = []
                    for b in range(k):
                        mask = mp.run_jax_at(pp, _take, bucket_masks, b)
                        if interleaved:
                            mask = mp.run_jax_at(pp, _dup2, mask)
                        mask_ct_pp = mp.run_at(pp, fhe.encrypt, mask, pub_ctx_pp)
                        masks_ct.append(mp.p2p(pp, ap_id, mask_ct_pp))
                    feat_masks.append(masks_ct)
                grp_masks.append(feat_masks)
            cached_masks.append(grp_masks)
        tpre1 = _stamp(ap_id)
        pre_dt = mp.run_jax_at(ap_id, _sub, tpre1, tpre0)

    def _fresh_mask_cts(i, pp, pub_ctx_pp, subgroup_map_pp, grp):
        """Lazily yield, per (feature, bucket), a newly built interleaved mask ct at the AP.

        Being a generator keeps the original op order: each mask is built,
        encrypted and sent right before the AP consumes it.
        """
        gom_map = _group_order_map(i, pp, subgroup_map_pp, grp)
        for feature_idx in range(pp_idx[i].shape[1]):
            for bucket_idx in range(k):
                bucket_mask = mp.run_jax_at(
                    pp, _cumulative_bucket_mask, gom_map, feature_idx, bucket_idx
                )
                inter_mask = mp.run_jax_at(pp, _dup2, bucket_mask)
                mask_ct_pp = mp.run_at(pp, fhe.encrypt, inter_mask, pub_ctx_pp)
                yield mp.p2p(pp, ap_id, mask_ct_pp)

    def _bucket_sums(i, pp, pub_ctx_pp, subgroup_map_pp):
        """Return per-group lists of encrypted (g, h) cumulative bucket sums for PP ``i``."""
        if mode == "classic":
            return batch_feature_wise_bucket_sum_fhe_vector(
                g_ct,
                h_ct,
                subgroup_map_pp,
                pp_idx[i],
                k,
                t,
                rank=pp,
                ap_rank=ap_id,
            )
        g_lists = [[] for _ in range(t)]
        h_lists = [[] for _ in range(t)]
        for grp in range(t):
            if cached:
                # Flatten [feature][bucket] in feature-major order.
                mask_cts = (ct for per_feat in cached_masks[i][grp] for ct in per_feat)
            else:
                mask_cts = _fresh_mask_cts(i, pp, pub_ctx_pp, subgroup_map_pp, grp)
            for mask_ct_ap in mask_cts:
                if interleaved:
                    # One ct*ct mul, then the even/odd slots give the g and h sums.
                    prod_ct = mp.run_at(ap_id, fhe.mul, gh_ct, mask_ct_ap)
                    g_sum_ct = mp.run_at(ap_id, fhe.dot, prod_ct, even_sel)
                    h_sum_ct = mp.run_at(ap_id, fhe.dot, prod_ct, odd_sel)
                else:
                    g_sum_ct = mp.run_at(ap_id, fhe.dot, g_ct, mask_ct_ap)
                    h_sum_ct = mp.run_at(ap_id, fhe.dot, h_ct, mask_ct_ap)
                g_lists[grp].append(g_sum_ct)
                h_lists[grp].append(h_sum_ct)
        return g_lists, h_lists

    def _decrypt_all(g_lists, h_lists):
        """Decrypt every group's g/h sums at the AP and stack them per group."""
        for grp in range(t):
            dec_g = [mp.run_at(ap_id, fhe.decrypt, ct, priv_ctx) for ct in g_lists[grp]]
            dec_h = [mp.run_at(ap_id, fhe.decrypt, ct, priv_ctx) for ct in h_lists[grp]]
            # jnp.stack() rejects an empty sequence (a PP with no features).
            if dec_g:
                mp.run_jax_at(ap_id, _stack, *dec_g)
            if dec_h:
                mp.run_jax_at(ap_id, _stack, *dec_h)

    def _sum_or_zero(parts):
        """Sum per-PP durations at the AP, or return a zero duration if there are none."""
        if parts:
            return mp.run_jax_at(ap_id, _sum_all, *parts)
        return mp.run_jax_at(ap_id, _zero_seconds)

    # Run reps and time compute + decrypt assembly across all PPs.
    times_total, times_comp, times_dec = [], [], []
    for rep_i in range(reps):
        t0 = _stamp(ap_id)
        comp_parts, dec_parts = [], []
        for i, (pp, pub_ctx_pp, subgroup_map_pp) in enumerate(subgroup_maps):
            tcomp0 = _stamp(ap_id)
            g_lists, h_lists = _bucket_sums(i, pp, pub_ctx_pp, subgroup_map_pp)
            tcomp1 = _stamp(ap_id)
            comp_parts.append(mp.run_jax_at(ap_id, _sub, tcomp1, tcomp0))

            tdec0 = _stamp(ap_id)
            _decrypt_all(g_lists, h_lists)
            tdec1 = _stamp(ap_id)
            dec_parts.append(mp.run_jax_at(ap_id, _sub, tdec1, tdec0))

        t1 = _stamp(ap_id)
        dt = mp.run_jax_at(ap_id, _sub, t1, t0)
        # Optionally charge the one-off precompute cost to the first repetition.
        if include_precompute and rep_i == 0:
            dt = mp.run_jax_at(ap_id, lambda x, add: x + add, dt, pre_dt)
        times_total.append(dt)
        times_comp.append(_sum_or_zero(comp_parts))
        times_dec.append(_sum_or_zero(dec_parts))

    # Stack per-rep durations into vectors at the AP for a single fetch.
    total_vec = mp.run_jax_at(ap_id, _stack, *times_total)
    if not breakdown:
        return total_vec
    comp_vec = mp.run_jax_at(ap_id, _stack, *times_comp)
    dec_vec = mp.run_jax_at(ap_id, _stack, *times_dec)
    return mp.run_jax_at(ap_id, _stack, total_vec, comp_vec, dec_vec)
480
+
481
+
482
+ def main():
483
+ parser = argparse.ArgumentParser(description="FHE histogram microbenchmark")
484
+ parser.add_argument(
485
+ "--world-size", type=int, default=2, help="Total parties (AP=1+PPs)"
486
+ )
487
+ parser.add_argument("--m", type=int, default=4096, help="Samples")
488
+ parser.add_argument("--n-total", type=int, default=16, help="Total features")
489
+ parser.add_argument("--n-ap", type=int, default=4, help="AP feature count")
490
+ parser.add_argument("--k", type=int, default=16, help="Bins per feature")
491
+ parser.add_argument("--t", type=int, default=4, help="Groups (nodes at level)")
492
+ parser.add_argument("--reps", type=int, default=3, help="Repetitions")
493
+ parser.add_argument("--seed", type=int, default=0, help="Random seed")
494
+ parser.add_argument(
495
+ "--mode",
496
+ type=str,
497
+ default="classic",
498
+ choices=[
499
+ "classic",
500
+ "classic_cached",
501
+ "interleaved",
502
+ "interleaved_cached",
503
+ ],
504
+ help="Histogram mode to benchmark",
505
+ )
506
+ parser.add_argument(
507
+ "--include-precompute",
508
+ action="store_true",
509
+ help="Include precompute (mask generation+encryption) time in the first repetition",
510
+ )
511
+ parser.add_argument(
512
+ "--breakdown",
513
+ action="store_true",
514
+ help="Report timing breakdown (total, compute, decrypt)",
515
+ )
516
+ args = parser.parse_args()
517
+
518
+ assert args.world_size >= 2, "world-size must be >= 2"
519
+ pp_parties = args.world_size - 1
520
+ assert args.n_total >= args.n_ap
521
+
522
+ # Split PP features evenly
523
+ n_pp_total = args.n_total - args.n_ap
524
+ n_pp_each = [n_pp_total // pp_parties] * pp_parties
525
+ n_pp_each[-1] += n_pp_total - sum(n_pp_each)
526
+
527
+ X_ap, X_pp_all, y = _gen_data(args.m, args.n_total, args.n_ap, args.seed)
528
+ offset = 0
529
+ X_pp_splits = []
530
+ for c in n_pp_each:
531
+ X_pp_splits.append(X_pp_all[:, offset : offset + c])
532
+ offset += c
533
+
534
+ sim = mp.Simulator.simple(args.world_size)
535
+
536
+ ap_id = 0
537
+ pp_ids = list(range(1, args.world_size))
538
+
539
+ print("\n=== FHE Histogram Microbenchmark ===")
540
+ print(
541
+ f"world-size={args.world_size} (AP+{pp_parties} PPs), m={args.m}, n_total={args.n_total} (AP={args.n_ap}, PP={n_pp_total}), k={args.k}, t={args.t}, reps={args.reps}, mode={args.mode}"
542
+ )
543
+
544
+ out = mp.evaluate(
545
+ sim,
546
+ _bench_once,
547
+ ap_id,
548
+ pp_ids,
549
+ X_ap,
550
+ X_pp_splits,
551
+ y,
552
+ args.k,
553
+ args.t,
554
+ args.reps,
555
+ args.mode,
556
+ args.include_precompute,
557
+ args.breakdown,
558
+ )
559
+ times_raw = mp.fetch(sim, out)
560
+
561
+ # Expected: [times_at_ap, None, ...] in 2PC; extract first non-None
562
+ if isinstance(times_raw, list) and len(times_raw) >= 1 and times_raw[-1] is None:
563
+ times_nodes = times_raw[0]
564
+ else:
565
+ times_nodes = times_raw
566
+
567
+ if args.breakdown:
568
+ times_arr = np.asarray(times_nodes, dtype=float)
569
+ # Expect shape (3, reps): [total, compute, decrypt]
570
+ if times_arr.ndim == 1:
571
+ # Fallback if flattened; try to split into 3 roughly equal parts
572
+ n = times_arr.size
573
+ r = n // 3
574
+ total = times_arr[:r]
575
+ comp = times_arr[r : 2 * r]
576
+ dec = times_arr[2 * r : 3 * r]
577
+ else:
578
+ total, comp, dec = (
579
+ times_arr[0].ravel(),
580
+ times_arr[1].ravel(),
581
+ times_arr[2].ravel(),
582
+ )
583
+
584
+ print(f"Per-rep total (s): {', '.join(f'{t:.4f}' for t in total.tolist())}")
585
+ print(
586
+ f"Per-rep compute-only (s): {', '.join(f'{t:.4f}' for t in comp.tolist())}"
587
+ )
588
+ print(
589
+ f"Per-rep decrypt-only (s): {', '.join(f'{t:.4f}' for t in dec.tolist())}"
590
+ )
591
+ print(
592
+ f"Averages — total: {float(total.mean()):.4f}s, compute: {float(comp.mean()):.4f}s, decrypt: {float(dec.mean()):.4f}s"
593
+ )
594
+ else:
595
+ # Convert to numpy array of floats (handle scalar, list, or numpy array)
596
+ if isinstance(times_nodes, list):
597
+ # elements are likely [val, None] pairs; take first
598
+ times_arr = np.array(
599
+ [
600
+ float(np.array(e[0]))
601
+ if isinstance(e, (list, tuple))
602
+ else float(np.array(e))
603
+ for e in times_nodes
604
+ ],
605
+ dtype=float,
606
+ )
607
+ else:
608
+ times_arr = np.asarray(times_nodes, dtype=float).ravel()
609
+ avg = float(times_arr.mean())
610
+ print(f"Per-rep time (s): {', '.join(f'{t:.4f}' for t in times_arr.tolist())}")
611
+ print(f"Average time (s): {avg:.4f}")
612
+
613
+
614
+ if __name__ == "__main__":
615
+ main()