mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191) hide show
  1. mplang/__init__.py +21 -45
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +5 -7
  7. mplang/v1/core/__init__.py +157 -0
  8. mplang/{core → v1/core}/cluster.py +30 -14
  9. mplang/{core → v1/core}/comm.py +5 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +13 -14
  14. mplang/{core → v1/core}/expr/evaluator.py +65 -24
  15. mplang/{core → v1/core}/expr/printer.py +24 -18
  16. mplang/{core → v1/core}/expr/transformer.py +3 -3
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +23 -16
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +4 -4
  25. mplang/{core → v1/core}/primitive.py +106 -201
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{api.py → v1/host.py} +38 -6
  30. mplang/v1/kernels/__init__.py +41 -0
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/v1/kernels/basic.py +240 -0
  33. mplang/{kernels → v1/kernels}/context.py +42 -27
  34. mplang/{kernels → v1/kernels}/crypto.py +44 -37
  35. mplang/v1/kernels/fhe.py +858 -0
  36. mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
  37. mplang/{kernels → v1/kernels}/phe.py +263 -57
  38. mplang/{kernels → v1/kernels}/spu.py +137 -48
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
  40. mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
  41. mplang/v1/kernels/value.py +626 -0
  42. mplang/{ops → v1/ops}/__init__.py +5 -16
  43. mplang/{ops → v1/ops}/base.py +2 -5
  44. mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/v1/ops/fhe.py +272 -0
  47. mplang/{ops → v1/ops}/jax_cc.py +33 -68
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -4
  50. mplang/{ops → v1/ops}/spu.py +3 -5
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +9 -24
  53. mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
  54. mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
  55. mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
  56. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  57. mplang/v1/runtime/channel.py +230 -0
  58. mplang/{runtime → v1/runtime}/cli.py +35 -20
  59. mplang/{runtime → v1/runtime}/client.py +19 -8
  60. mplang/{runtime → v1/runtime}/communicator.py +59 -15
  61. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  62. mplang/{runtime → v1/runtime}/driver.py +30 -12
  63. mplang/v1/runtime/link_comm.py +196 -0
  64. mplang/{runtime → v1/runtime}/server.py +58 -42
  65. mplang/{runtime → v1/runtime}/session.py +57 -71
  66. mplang/{runtime → v1/runtime}/simulation.py +55 -28
  67. mplang/v1/simp/api.py +353 -0
  68. mplang/{simp → v1/simp}/mpi.py +8 -9
  69. mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
  70. mplang/{simp → v1/simp}/random.py +21 -22
  71. mplang/v1/simp/smpc.py +238 -0
  72. mplang/v1/utils/table_utils.py +185 -0
  73. mplang/v2/__init__.py +424 -0
  74. mplang/v2/backends/__init__.py +57 -0
  75. mplang/v2/backends/bfv_impl.py +705 -0
  76. mplang/v2/backends/channel.py +217 -0
  77. mplang/v2/backends/crypto_impl.py +723 -0
  78. mplang/v2/backends/field_impl.py +454 -0
  79. mplang/v2/backends/func_impl.py +107 -0
  80. mplang/v2/backends/phe_impl.py +148 -0
  81. mplang/v2/backends/simp_design.md +136 -0
  82. mplang/v2/backends/simp_driver/__init__.py +41 -0
  83. mplang/v2/backends/simp_driver/http.py +168 -0
  84. mplang/v2/backends/simp_driver/mem.py +280 -0
  85. mplang/v2/backends/simp_driver/ops.py +135 -0
  86. mplang/v2/backends/simp_driver/state.py +60 -0
  87. mplang/v2/backends/simp_driver/values.py +52 -0
  88. mplang/v2/backends/simp_worker/__init__.py +29 -0
  89. mplang/v2/backends/simp_worker/http.py +354 -0
  90. mplang/v2/backends/simp_worker/mem.py +102 -0
  91. mplang/v2/backends/simp_worker/ops.py +167 -0
  92. mplang/v2/backends/simp_worker/state.py +49 -0
  93. mplang/v2/backends/spu_impl.py +275 -0
  94. mplang/v2/backends/spu_state.py +187 -0
  95. mplang/v2/backends/store_impl.py +62 -0
  96. mplang/v2/backends/table_impl.py +838 -0
  97. mplang/v2/backends/tee_impl.py +215 -0
  98. mplang/v2/backends/tensor_impl.py +519 -0
  99. mplang/v2/cli.py +603 -0
  100. mplang/v2/cli_guide.md +122 -0
  101. mplang/v2/dialects/__init__.py +36 -0
  102. mplang/v2/dialects/bfv.py +665 -0
  103. mplang/v2/dialects/crypto.py +689 -0
  104. mplang/v2/dialects/dtypes.py +378 -0
  105. mplang/v2/dialects/field.py +210 -0
  106. mplang/v2/dialects/func.py +135 -0
  107. mplang/v2/dialects/phe.py +723 -0
  108. mplang/v2/dialects/simp.py +944 -0
  109. mplang/v2/dialects/spu.py +349 -0
  110. mplang/v2/dialects/store.py +63 -0
  111. mplang/v2/dialects/table.py +407 -0
  112. mplang/v2/dialects/tee.py +346 -0
  113. mplang/v2/dialects/tensor.py +1175 -0
  114. mplang/v2/edsl/README.md +279 -0
  115. mplang/v2/edsl/__init__.py +99 -0
  116. mplang/v2/edsl/context.py +311 -0
  117. mplang/v2/edsl/graph.py +463 -0
  118. mplang/v2/edsl/jit.py +62 -0
  119. mplang/v2/edsl/object.py +53 -0
  120. mplang/v2/edsl/primitive.py +284 -0
  121. mplang/v2/edsl/printer.py +119 -0
  122. mplang/v2/edsl/registry.py +207 -0
  123. mplang/v2/edsl/serde.py +375 -0
  124. mplang/v2/edsl/tracer.py +614 -0
  125. mplang/v2/edsl/typing.py +816 -0
  126. mplang/v2/kernels/Makefile +30 -0
  127. mplang/v2/kernels/__init__.py +23 -0
  128. mplang/v2/kernels/gf128.cpp +148 -0
  129. mplang/v2/kernels/ldpc.cpp +82 -0
  130. mplang/v2/kernels/okvs.cpp +283 -0
  131. mplang/v2/kernels/okvs_opt.cpp +291 -0
  132. mplang/v2/kernels/py_kernels.py +398 -0
  133. mplang/v2/libs/collective.py +330 -0
  134. mplang/v2/libs/device/__init__.py +51 -0
  135. mplang/v2/libs/device/api.py +813 -0
  136. mplang/v2/libs/device/cluster.py +352 -0
  137. mplang/v2/libs/ml/__init__.py +23 -0
  138. mplang/v2/libs/ml/sgb.py +1861 -0
  139. mplang/v2/libs/mpc/__init__.py +41 -0
  140. mplang/v2/libs/mpc/_utils.py +99 -0
  141. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  142. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  143. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  144. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  145. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  146. mplang/v2/libs/mpc/common/constants.py +39 -0
  147. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  148. mplang/v2/libs/mpc/ot/base.py +222 -0
  149. mplang/v2/libs/mpc/ot/extension.py +477 -0
  150. mplang/v2/libs/mpc/ot/silent.py +217 -0
  151. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  152. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  153. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  154. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  155. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  156. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  157. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  158. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  159. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  160. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  161. mplang/v2/libs/mpc/vole/silver.py +336 -0
  162. mplang/v2/runtime/__init__.py +15 -0
  163. mplang/v2/runtime/dialect_state.py +41 -0
  164. mplang/v2/runtime/interpreter.py +871 -0
  165. mplang/v2/runtime/object_store.py +194 -0
  166. mplang/v2/runtime/value.py +141 -0
  167. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
  168. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  169. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  170. mplang/core/__init__.py +0 -92
  171. mplang/device.py +0 -340
  172. mplang/kernels/builtin.py +0 -207
  173. mplang/ops/crypto.py +0 -109
  174. mplang/ops/ibis_cc.py +0 -139
  175. mplang/ops/sql.py +0 -61
  176. mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
  177. mplang/runtime/link_comm.py +0 -131
  178. mplang/simp/smpc.py +0 -201
  179. mplang/utils/table_utils.py +0 -73
  180. mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
  181. /mplang/{core → v1/core}/mask.py +0 -0
  182. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  183. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  184. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  185. /mplang/{kernels → v1/simp}/__init__.py +0 -0
  186. /mplang/{utils → v1/utils}/__init__.py +0 -0
  187. /mplang/{utils → v1/utils}/crypto.py +0 -0
  188. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  189. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  190. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  191. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,1861 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # mypy: disable-error-code="no-untyped-def,no-any-return,var-annotated"
16
+
17
+ """SecureBoost v2: Optimized implementation using mplang.v2 low-level BFV APIs.
18
+
19
+ This implementation improves upon v1 by leveraging BFV SIMD slots and the
20
+ groupby primitives for efficient histogram computation.
21
+
22
+ Key optimizations:
23
+ 1. SIMD slot packing for parallel histogram bucket computation
24
+ 2. Rotation-based aggregation for efficient slot summation
25
+ 3. Reduced communication via packed ciphertext results
26
+
27
+ See design/sgb_v2.md for detailed architecture documentation.
28
+
29
+ Usage:
30
+ from examples.v2.sgb import SecureBoost
31
+
32
+ model = SecureBoost(n_estimators=10, max_depth=3)
33
+ model.fit([X_ap, X_pp], y)
34
+ predictions = model.predict([X_ap_test, X_pp_test])
35
+ """
36
+
37
+ from __future__ import annotations
38
+
39
+ from collections import deque
40
+ from dataclasses import dataclass
41
+ from functools import partial
42
+ from typing import Any
43
+
44
+ import jax
45
+ import jax.numpy as jnp
46
+ import numpy as np
47
+ from jax.ops import segment_sum
48
+
49
+ from mplang.v2.dialects import bfv, simp, tensor
50
+ from mplang.v2.libs.mpc.analytics import aggregation
51
+
52
# ==============================================================================
# Configuration
# ==============================================================================

# Fixed-point quantization precision for G/H values before BFV encryption.
DEFAULT_FXP_BITS = 15  # Fixed-point scale = 2^15 = 32768
# BFV slot count (Increased for depth)
# NOTE: For 1M samples, the sum of gradients can reach ~3.2e10 (2^35).
# The default plain_modulus (1032193 ~ 2^20) will cause overflow.
# For large datasets, you MUST increase plain_modulus (e.g. to a 40-bit prime).
DEFAULT_POLY_MODULUS_DEGREE = 8192
62
+
63
+
64
+ # ==============================================================================
65
+ # Data Structures
66
+ # ==============================================================================
67
+
68
+
69
@dataclass
class Tree:
    """One boosted decision tree, stored as flat arrays indexed by node id."""

    feature: list[Any]  # Split feature index per node, one array per party; (n_nodes,)
    threshold: list[Any]  # Split threshold per node, one array per party; (n_nodes,)
    value: Any  # Leaf output values held at AP; (n_nodes,)
    is_leaf: Any  # Boolean leaf indicator; (n_nodes,)
    owned_party_id: Any  # Which party owns each node's split; (n_nodes,)
78
+
79
+
80
@dataclass
class TreeEnsemble:
    """A fitted XGBoost-style model: the boosted trees plus the base score."""

    max_depth: int  # Depth bound used when growing each tree
    trees: list[Tree]  # Trees in boosting order
    initial_prediction: Any  # Base prediction held at AP
87
+
88
+
89
+ # ==============================================================================
90
+ # JAX Mathematical Functions
91
+ # ==============================================================================
92
+
93
+
94
@jax.jit
def compute_init_pred(y: jnp.ndarray) -> jnp.ndarray:
    """Return the base prediction: log-odds of the positive-class rate.

    The rate is clipped away from {0, 1} so the log-odds stay finite.
    """
    positive_rate = jnp.clip(jnp.mean(y), 1e-15, 1 - 1e-15)
    odds = positive_rate / (1 - positive_rate)
    return jnp.log(odds)
99
+
100
+
101
@jax.jit
def sigmoid(x: jnp.ndarray) -> jnp.ndarray:
    """Elementwise logistic function: 1 / (1 + e^(-x))."""
    denom = 1 + jnp.exp(-x)
    return 1 / denom
105
+
106
+
107
@jax.jit
def compute_gh(y_true: jnp.ndarray, y_pred_logits: jnp.ndarray) -> jnp.ndarray:
    """Per-sample gradient and hessian of the log loss w.r.t. logits.

    Returns an (m, 2) array whose columns are (gradient, hessian).
    """
    # Inlined logistic function (identical to the module-level sigmoid).
    prob = 1 / (1 + jnp.exp(-y_pred_logits))
    grad = prob - y_true
    hess = prob * (1 - prob)
    return jnp.column_stack([grad, hess])
114
+
115
+
116
@jax.jit
def quantize_gh(gh: jnp.ndarray, scale: int) -> jnp.ndarray:
    """Map float G/H values onto the int64 fixed-point grid used for BFV."""
    scaled = gh * scale
    return jnp.round(scaled).astype(jnp.int64)
120
+
121
+
122
@jax.jit
def dequantize(arr: jnp.ndarray, scale: int) -> jnp.ndarray:
    """Inverse of quantize_gh: fixed-point integers back to float32."""
    as_float = arr.astype(jnp.float32)
    return as_float / scale
126
+
127
+
128
+ # ==============================================================================
129
+ # Binning Functions
130
+ # ==============================================================================
131
+
132
+
133
def build_bins_equi_width(x: jnp.ndarray, max_bin: int) -> jnp.ndarray:
    """Return `max_bin - 1` equi-width split points for one feature.

    Degenerate inputs — fewer than two samples, or a (nearly) constant
    feature — yield all-+inf splits so every sample falls in bucket 0.
    """
    fallback = jnp.full(max_bin - 1, jnp.inf, dtype=x.dtype)

    def make_splits():
        lo = jnp.min(x)
        hi = jnp.max(x)

        def interior_points():
            # max_bin + 1 evenly spaced points; drop both endpoints.
            return jnp.linspace(lo, hi, num=max_bin + 1)[1:-1]

        degenerate = (hi - lo) < 1e-9
        return jax.lax.cond(degenerate, lambda: fallback, interior_points)

    return jax.lax.cond(x.shape[0] >= 2, make_splits, lambda: fallback)
149
+
150
+
151
@jax.jit
def compute_bin_indices(x: jnp.ndarray, bins: jnp.ndarray) -> jnp.ndarray:
    """Assign every sample of one feature to a bucket.

    Uses right-inclusive digitize: a sample equal to a boundary falls in the
    lower bucket.
    """
    return jnp.digitize(x, bins, right=True)
155
+
156
+
157
+ # ==============================================================================
158
+ # Local Histogram (AP, no FHE needed)
159
+ # ==============================================================================
160
+
161
+
162
def make_local_build_histogram(n_nodes: int, n_buckets: int):
    """Return a jitted histogram kernel specialized to (n_nodes, n_buckets).

    Baking the sizes in as Python constants keeps them static for JAX.
    """

    @jax.jit
    def local_build_histogram(
        gh: jnp.ndarray,
        bt_local: jnp.ndarray,
        bin_indices: jnp.ndarray,
    ) -> jnp.ndarray:
        """G/H histogram via segment_sum; output (n_features, n_nodes, n_buckets, 2)."""
        n_segments = n_nodes * n_buckets
        # Samples with bt_local < 0 belong to no node; zero their contribution.
        active = bt_local >= 0
        gh_active = gh * active[:, None]

        def per_feature(feature_bins: jnp.ndarray) -> jnp.ndarray:
            segment_ids = bt_local * n_buckets + feature_bins
            return segment_sum(gh_active, segment_ids, num_segments=n_segments)

        stacked = jax.vmap(per_feature, in_axes=1, out_axes=0)(bin_indices)
        return stacked.reshape((bin_indices.shape[1], n_nodes, n_buckets, 2))

    return local_build_histogram
183
+
184
+
185
@jax.jit
def compute_best_split_from_hist(
    gh_hist: jnp.ndarray,  # (n_features, n_buckets, 2) for one node
    reg_lambda: float,
    gamma: float,
    min_child_weight: float,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Scan one node's histogram for the split with maximum gain.

    Returns (best_gain, best_feature, best_threshold_bucket).
    """
    eps = 1e-9

    totals = jnp.sum(gh_hist, axis=1)  # (n_features, 2)
    # Left-child prefix sums; the last bucket cannot be a split point.
    left = jnp.cumsum(gh_hist, axis=1)[:, :-1, :]  # (n_features, n_buckets-1, 2)

    g_all, h_all = totals[..., 0], totals[..., 1]
    g_l, h_l = left[..., 0], left[..., 1]
    g_r = g_all[:, None] - g_l
    h_r = h_all[:, None] - h_l

    def score(g, h):
        # Standard XGBoost structure score; eps guards a zero denominator.
        return jnp.square(g) / (h + reg_lambda + eps)

    gain = (score(g_l, h_l) + score(g_r, h_r) - score(g_all, h_all)[:, None]) / 2.0
    feasible = (h_l >= min_child_weight) & (h_r >= min_child_weight)
    gain = jnp.where(feasible, gain - gamma, -jnp.inf)

    best = jnp.argmax(gain)
    feat_idx, bucket_idx = jnp.unravel_index(best, gain.shape)
    return jnp.max(gain), feat_idx, bucket_idx
212
+
213
+
214
def local_compute_best_splits(
    gh_hist: jnp.ndarray,  # (n_features, n_nodes, n_buckets, 2)
    reg_lambda: float,
    gamma: float,
    min_child_weight: float,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Vectorized best-split search over every node.

    Returns three (n_nodes,) arrays: gains, feature indices, threshold buckets.
    """
    # Bring the node axis to the front so vmap maps over nodes.
    per_node_hist = jnp.transpose(gh_hist, (1, 0, 2, 3))
    split_fn = partial(
        compute_best_split_from_hist,
        reg_lambda=reg_lambda,
        gamma=gamma,
        min_child_weight=min_child_weight,
    )
    return jax.vmap(split_fn)(per_node_hist)
231
+
232
+
233
+ # ==============================================================================
234
+ # FHE Histogram (PP, using low-level BFV)
235
+ # ==============================================================================
236
+
237
+
238
def _build_packed_mask_jit(node_mask, feat_bins, n_buckets, stride, slot_count):
    """Build a packed 0/1 scatter mask for one (node, feature, chunk) triple.

    Each valid sample (``node_mask == 1``) is given a unique slot
    ``feat_bins[i] * stride + offset``, where ``offset`` is the number of
    earlier valid samples that landed in the same bucket.  The result is a
    length-``slot_count`` int64 vector with a 1 in every occupied slot, so
    multiplying a ciphertext by it selects exactly those samples.

    NOTE(review): if one bucket receives more than ``stride`` valid samples,
    offsets spill into the next bucket's slot region — presumably the caller's
    choice of stride/chunking prevents this; confirm against the caller.
    """
    valid = node_mask == 1
    # One-hot bucket membership per sample, zeroed for samples outside the node.
    bucket_onehot = (jnp.arange(n_buckets)[None, :] == feat_bins[:, None]) & valid[
        :, None
    ]
    # Exclusive prefix count: how many earlier valid samples share each bucket.
    running_counts = jnp.cumsum(bucket_onehot, axis=0)
    shifted_counts = jnp.zeros_like(running_counts)
    shifted_counts = shifted_counts.at[1:].set(running_counts[:-1])
    # Per-sample offset within its own bucket region.
    sample_offsets = jnp.take_along_axis(
        shifted_counts, feat_bins[:, None], axis=1
    ).squeeze(-1)

    # Invalid samples get index -1 so they can be filtered out below.
    scatter_indices = jnp.where(valid, feat_bins * stride + sample_offsets, -1)

    valid_mask = scatter_indices >= 0
    valid_indices = jnp.where(valid_mask, scatter_indices, 0).astype(jnp.int32)
    valid_ones = jnp.where(valid_mask, 1, 0).astype(jnp.int64)
    # Scatter the 1s into their slots; filtered entries contribute 0 to slot 0.
    output = segment_sum(valid_ones, valid_indices, num_segments=slot_count)
    return output
257
+
258
+
259
def compute_all_masks(
    subgroup_map, bin_indices, n_buckets, stride, slot_count, n_chunks
):
    """Compute packed scatter masks for every (node, feature, chunk) triple.

    Args:
        subgroup_map: (n_nodes, m) binary node-membership matrix.
        bin_indices: (m, n_features) bucket index per sample per feature.
        n_buckets, stride, slot_count, n_chunks: SIMD packing geometry.

    Returns:
        (n_nodes * n_features * n_chunks, slot_count) mask array, flattened in
        node-major, then feature, then chunk order (the order the consumer
        iterates them in).
    """
    # subgroup_map: (n_nodes, m)
    # bin_indices: (m, n_features)

    m = bin_indices.shape[0]
    n_features = bin_indices.shape[1]
    n_nodes = subgroup_map.shape[0]

    # Pad the sample axis up to a whole number of chunks.
    pad_len = n_chunks * slot_count - m
    if pad_len > 0:
        subgroup_map = jnp.pad(subgroup_map, ((0, 0), (0, pad_len)))
        bin_indices = jnp.pad(bin_indices, ((0, pad_len), (0, 0)))

    # Reshape chunks
    # subgroup_map: (n_nodes, n_chunks, slot_count)
    sg_chunks = subgroup_map.reshape(n_nodes, n_chunks, slot_count)

    # bin_indices: (n_chunks, slot_count, n_features) -> (n_features, n_chunks, slot_count)
    bi_chunks = bin_indices.reshape(n_chunks, slot_count, n_features).transpose(2, 0, 1)

    # vmap over chunks
    def process_chunk(nm, fb):
        return _build_packed_mask_jit(nm, fb, n_buckets, stride, slot_count)

    v_chunk = jax.vmap(process_chunk, in_axes=(0, 0))

    # vmap over features (nm fixed, fb varies)
    v_feat = jax.vmap(v_chunk, in_axes=(None, 0))

    # vmap over nodes (nm varies, fb fixed)
    v_node = jax.vmap(v_feat, in_axes=(0, None))

    # Result shape: (n_nodes, n_features, n_chunks, slot_count).
    all_masks = v_node(sg_chunks, bi_chunks)
    # Flatten and convert to tuple of arrays
    return all_masks.reshape(-1, slot_count)
297
+
298
+
299
def _compute_histogram_chunk_batch(
    subgroup_map,
    bin_indices,
    g_cts,
    h_cts,
    encoder,
    relin_keys,
    galois_keys,
    n_nodes,
    n_features,
    n_chunks,
    n_buckets,
    slot_count,
    stride,
    max_samples_per_bucket,
    m,
):
    """Compute packed encrypted G/H histograms for all nodes at the PP.

    Pipeline: build every scatter mask in one JAX call, batch-encode them as
    BFV plaintexts, then for each (node, feature): mask the encrypted G/H
    chunks, sum chunks, aggregate bucket slots via rotations, and pack up to
    ``stride`` features into one ciphertext each.

    Returns:
        (g_packed_flat, h_packed_flat): flat lists of packed ciphertexts, one
        per node per feature-batch, in node-major order.

    NOTE: the consumption order of ``mask_iter`` below (node, feature, chunk)
    must match the flattening order produced by ``compute_all_masks``.
    """
    # Precompute all masks in one go
    compute_all_masks_jit = partial(
        compute_all_masks,
        n_buckets=n_buckets,
        stride=stride,
        slot_count=slot_count,
        n_chunks=n_chunks,
    )
    all_masks_tensor = tensor.run_jax(
        compute_all_masks_jit,
        subgroup_map,
        bin_indices,
    )

    # Batch encode all masks at once to avoid scheduler bottleneck
    # Pass relin_keys as context provider (it holds the SEALContext)
    all_masks_pt = bfv.batch_encode(all_masks_tensor, encoder, key=relin_keys)
    mask_iter = iter(all_masks_pt)

    # ==========================================================================
    # Optimization: Incremental Packing to reduce peak memory
    # ==========================================================================
    # Instead of accumulating all features and then packing, we pack incrementally.
    # This reduces peak memory from O(n_features) to O(stride).

    # Create mask for valid slots (0, stride, 2*stride, ...)
    m_np = np.zeros(slot_count, dtype=np.int64)
    idx_np = np.arange(n_buckets) * stride
    m_np[idx_np] = 1
    mask_arr = tensor.constant(m_np)
    mask_pt_pack = bfv.encode(mask_arr, encoder)

    g_packed_flat = []
    h_packed_flat = []

    # Optimization 2: Tree Reduction
    # Helper to sum a list of ciphertexts using a binary tree structure.
    # This reduces the dependency chain depth from O(N) to O(log N),
    # allowing the scheduler to parallelize additions.
    def tree_sum(items):
        if not items:
            return None
        if len(items) == 1:
            return items[0]

        queue = deque(items)
        while len(queue) > 1:
            # Process in pairs
            for _ in range(len(queue) // 2):
                left = queue.popleft()
                right = queue.popleft()
                queue.append(bfv.add(left, right))

        return queue[0] if queue else None

    for _node_idx in range(n_nodes):
        # Process features in batches of 'stride'
        for batch_start in range(0, n_features, stride):
            batch_end = min(batch_start + stride, n_features)

            g_rot_list = []
            h_rot_list = []

            for i, _feat_idx in enumerate(range(batch_start, batch_end)):
                # 1. Compute Histogram for this feature (across chunks)
                g_masked_list = []
                h_masked_list = []

                for chunk_idx in range(n_chunks):
                    mask_pt = next(mask_iter)
                    # mask_pt is already encoded via batch_encode

                    g_ct_chunk = g_cts[chunk_idx]
                    h_ct_chunk = h_cts[chunk_idx]

                    # Select this feature's samples; relinearize after each mul.
                    g_masked = bfv.relinearize(bfv.mul(g_ct_chunk, mask_pt), relin_keys)
                    h_masked = bfv.relinearize(bfv.mul(h_ct_chunk, mask_pt), relin_keys)

                    g_masked_list.append(g_masked)
                    h_masked_list.append(h_masked)

                g_masked_acc = tree_sum(g_masked_list)
                h_masked_acc = tree_sum(h_masked_list)

                # Lazy Aggregation: Aggregate once after summing all chunks
                # This reduces rotations by a factor of n_chunks
                g_feat_acc = aggregation.batch_bucket_aggregate(
                    g_masked_acc,
                    n_buckets,
                    max_samples_per_bucket,
                    galois_keys,
                    slot_count,
                )
                h_feat_acc = aggregation.batch_bucket_aggregate(
                    h_masked_acc,
                    n_buckets,
                    max_samples_per_bucket,
                    galois_keys,
                    slot_count,
                )

                # n_chunks >= 1, so tree_sum never saw an empty list here.
                assert g_feat_acc is not None
                assert h_feat_acc is not None

                # 2. Pack immediately
                # Relative offset = i
                # Mask valid slots
                g_masked_pack = bfv.relinearize(
                    bfv.mul(g_feat_acc, mask_pt_pack), relin_keys
                )
                h_masked_pack = bfv.relinearize(
                    bfv.mul(h_feat_acc, mask_pt_pack), relin_keys
                )

                # Rotate to position
                g_rot = bfv.rotate(g_masked_pack, -i, galois_keys)
                h_rot = bfv.rotate(h_masked_pack, -i, galois_keys)

                g_rot_list.append(g_rot)
                h_rot_list.append(h_rot)

            g_packed_acc = tree_sum(g_rot_list)
            h_packed_acc = tree_sum(h_rot_list)

            g_packed_flat.append(g_packed_acc)
            h_packed_flat.append(h_packed_acc)

    return g_packed_flat, h_packed_flat
444
+
445
+
446
+ def _process_decrypted_jit(
447
+ g_vecs, h_vecs, scale, n_nodes, n_features, n_buckets, stride
448
+ ):
449
+ # g_vecs is list of packed vectors.
450
+ # Shape of each vector: (slot_count,)
451
+ g_stack = jnp.stack(g_vecs)
452
+ h_stack = jnp.stack(h_vecs)
453
+
454
+ # We need to reconstruct (n_nodes, n_features, n_buckets)
455
+ g_unpacked = []
456
+ h_unpacked = []
457
+
458
+ cts_per_node = (n_features + stride - 1) // stride
459
+
460
+ for node_i in range(n_nodes):
461
+ for feat_i in range(n_features):
462
+ # Which CT?
463
+ ct_idx = node_i * cts_per_node + (feat_i // stride)
464
+ # Which offset in CT?
465
+ offset = feat_i % stride
466
+
467
+ # Indices for buckets: b*stride + offset
468
+ bucket_indices = jnp.arange(n_buckets) * stride + offset
469
+
470
+ g_vals = g_stack[ct_idx, bucket_indices]
471
+ h_vals = h_stack[ct_idx, bucket_indices]
472
+
473
+ g_unpacked.append(g_vals)
474
+ h_unpacked.append(h_vals)
475
+
476
+ # Now we have flat list of (n_buckets,) arrays
477
+ g_flat = jnp.stack(g_unpacked) # (n_nodes*n_features, n_buckets)
478
+ h_flat = jnp.stack(h_unpacked)
479
+
480
+ g_buckets = g_flat.astype(jnp.float32) / scale
481
+ h_buckets = h_flat.astype(jnp.float32) / scale
482
+
483
+ g_cumsum = jnp.cumsum(g_buckets, axis=1)
484
+ h_cumsum = jnp.cumsum(h_buckets, axis=1)
485
+
486
+ g_reshaped = g_cumsum.reshape(n_nodes, n_features, n_buckets)
487
+ h_reshaped = h_cumsum.reshape(n_nodes, n_features, n_buckets)
488
+
489
+ combined = jnp.stack([g_reshaped, h_reshaped], axis=-1)
490
+ return combined
491
+
492
+
493
def _decrypt_batch(
    g_enc_flat,
    h_enc_flat,
    sk,
    encoder,
    fxp_scale,
    n_nodes,
    n_features,
    n_buckets,
    stride,
):
    """Decrypt packed G/H ciphertexts and unpack them into histograms.

    Runs at the party holding ``sk``.  Decrypts/decodes every ciphertext to a
    slot vector, then hands the vectors to ``_process_decrypted_jit`` (shape
    and cumulative-sum logic lives there).  ``fxp_scale`` is forwarded
    positionally as that function's ``scale`` argument.
    """
    g_vecs = [bfv.decode(bfv.decrypt(ct, sk), encoder) for ct in g_enc_flat]
    h_vecs = [bfv.decode(bfv.decrypt(ct, sk), encoder) for ct in h_enc_flat]

    # Bind the static geometry as keywords; only tensors flow through run_jax.
    fn_jit = partial(
        _process_decrypted_jit,
        n_nodes=n_nodes,
        n_features=n_features,
        n_buckets=n_buckets,
        stride=stride,
    )
    return tensor.run_jax(
        fn_jit,
        g_vecs,
        h_vecs,
        fxp_scale,
    )
520
+
521
+
522
def fhe_encrypt_gh(
    qg: Any,
    qh: Any,
    pk: Any,
    encoder: Any,
    ap_rank: int,
    n_samples: int,
    slot_count: int = DEFAULT_POLY_MODULUS_DEGREE,
) -> tuple[list[Any], list[Any], int]:
    """Encrypt quantized G/H vectors at AP, splitting into chunks if m > slot_count.

    When m > slot_count, the vectors are split into ceil(m / slot_count) chunks,
    each encrypted as a separate ciphertext. This enables processing arbitrarily
    large datasets with a fixed poly_modulus_degree.

    Args:
        qg: Quantized G vector, shape (m,)
        qh: Quantized H vector, shape (m,)
        pk: BFV public key
        encoder: BFV encoder
        ap_rank: Active party rank
        n_samples: Number of samples (m)
        slot_count: Number of slots per ciphertext
            (default DEFAULT_POLY_MODULUS_DEGREE = 8192)

    Returns:
        (g_cts, h_cts, n_chunks): Lists of encrypted G/H chunks and chunk count
    """
    # Calculate n_chunks at trace time (known statically)
    n_chunks = (n_samples + slot_count - 1) // slot_count

    g_cts: list[Any] = []
    h_cts: list[Any] = []

    for chunk_idx in range(n_chunks):
        start = chunk_idx * slot_count
        end = min((chunk_idx + 1) * slot_count, n_samples)
        chunk_size = end - start

        # Extract, pad, encode and encrypt both G and H chunks together.
        # The default arguments (s=start, e=end, ...) deliberately freeze the
        # current loop values, avoiding the late-binding closure pitfall.
        def slice_pad_encode_encrypt(
            g_vec, h_vec, enc, key, s=start, e=end, cs=chunk_size, sc=slot_count
        ):
            # Slice and pad using JAX
            def slice_and_pad_both(gv, hv):
                g_chunk = gv[s:e]
                h_chunk = hv[s:e]
                # Last chunk may be short; zero-pad it to a full slot vector.
                if cs < sc:
                    g_chunk = jnp.pad(g_chunk, (0, sc - cs))
                    h_chunk = jnp.pad(h_chunk, (0, sc - cs))
                return g_chunk, h_chunk

            g_chunk, h_chunk = tensor.run_jax(slice_and_pad_both, g_vec, h_vec)
            # Encode and encrypt
            g_pt = bfv.encode(g_chunk, enc)
            h_pt = bfv.encode(h_chunk, enc)
            return bfv.encrypt(g_pt, key), bfv.encrypt(h_pt, key)

        # Execute only on the active party's rank.
        g_ct, h_ct = simp.pcall_static(
            (ap_rank,), slice_pad_encode_encrypt, qg, qh, encoder, pk
        )

        g_cts.append(g_ct)
        h_cts.append(h_ct)

    return g_cts, h_cts, n_chunks
587
+
588
+
589
def fhe_histogram_optimized(
    g_cts: list[Any],  # List of encrypted G chunks at PP
    h_cts: list[Any],  # List of encrypted H chunks at PP
    subgroup_map: Any,  # (n_nodes, m) binary node membership
    bin_indices: Any,  # (m, n_features) binned features
    n_buckets: int,
    n_nodes: int,
    n_features: int,
    pp_rank: int,
    ap_rank: int,
    encoder: Any,
    relin_keys: Any,
    galois_keys: Any,
    m: int,
    n_chunks: int = 1,
    slot_count: int = DEFAULT_POLY_MODULUS_DEGREE,
) -> tuple[list[Any], list[Any]]:
    """Compute encrypted histogram sums using SIMD bucket packing.

    **Multi-CT Support**

    When m > slot_count, data is split into n_chunks ciphertexts:
    - Chunk 0: samples [0, slot_count)
    - Chunk 1: samples [slot_count, 2*slot_count)
    - ...

    For each chunk, we compute the histogram separately, then add results
    in the FHE domain.

    **SIMD Bucket Packing** (per chunk)

    1. Divide slot_count into n_buckets regions, each with `stride` slots
    2. Build scatter mask placing sample i at slot (bucket[i] * stride + offset[i])
    3. Single CT × packed_mask multiplication
    4. Single rotate_and_sum aggregates ALL buckets simultaneously
    5. Add chunk results together

    Returns:
        g_enc[node][feat]: List of packed encrypted G histograms (one CT per feature)
        h_enc[node][feat]: List of packed encrypted H histograms (one CT per feature)
    """
    stride = slot_count // n_buckets
    # Estimate max samples per bucket per chunk
    # Heuristic: twice the uniform expectation, floored at 64, capped by stride.
    samples_per_chunk = (m + n_chunks - 1) // n_chunks
    max_samples_per_bucket = min(stride, max(samples_per_chunk // n_buckets * 2, 64))

    # Use partial to bake in static arguments (integers) so they are treated as static by JAX
    fn = partial(
        _compute_histogram_chunk_batch,
        n_nodes=n_nodes,
        n_features=n_features,
        n_chunks=n_chunks,
        n_buckets=n_buckets,
        slot_count=slot_count,
        stride=stride,
        max_samples_per_bucket=max_samples_per_bucket,
        m=m,
    )

    # All FHE-domain work runs on the passive party.
    g_results_flat, h_results_flat = simp.pcall_static(
        (pp_rank,),
        fn,
        subgroup_map,
        bin_indices,
        g_cts,
        h_cts,
        encoder,
        relin_keys,
        galois_keys,
    )

    # Transfer final packed result to AP
    # g_results_flat is a list of Objects (one per node/feature/chunk accumulation)
    g_packed_ap = [
        simp.shuffle_static(obj, {ap_rank: pp_rank}) for obj in g_results_flat
    ]
    h_packed_ap = [
        simp.shuffle_static(obj, {ap_rank: pp_rank}) for obj in h_results_flat
    ]

    return g_packed_ap, h_packed_ap
670
+
671
+
672
def decrypt_histogram_results(
    g_enc_flat: Any,
    h_enc_flat: Any,
    sk: Any,
    encoder: Any,
    fxp_scale: int,
    n_nodes: int,
    n_features: int,
    n_buckets: int,
    ap_rank: int,
    slot_count: int = DEFAULT_POLY_MODULUS_DEGREE,
) -> Any:
    """Decrypt and assemble histogram results at AP.

    **SIMD Bucket Packing Format**

    With SIMD bucket packing, each ciphertext contains ALL buckets for one
    feature:

    - g_enc_flat / h_enc_flat are flat lists of packed CTs (one per feature
      per node), as produced by ``fhe_histogram_optimized``.
    - slot[b * stride] contains histogram[b] for bucket b, where
      stride = slot_count // n_buckets.

    ``_decrypt_batch`` extracts bucket values from the strided slot
    positions, de-quantizes with ``fxp_scale``, and computes the cumulative
    sum over buckets.

    Args:
        g_enc_flat: Flat list of packed encrypted G histograms.
        h_enc_flat: Flat list of packed encrypted H histograms.
        sk: BFV secret key (held by AP).
        encoder: BFV batch encoder.
        fxp_scale: Fixed-point scale used during quantization.
        n_nodes: Number of tree nodes covered by the flat CT lists.
        n_features: Number of features covered by the flat CT lists.
        n_buckets: Number of histogram buckets per feature.
        ap_rank: Rank of the active party performing the decryption.
        slot_count: Number of BFV slots per ciphertext.

    Returns:
        A single traced Object holding one stacked array of shape
        (n_nodes, n_features, n_buckets, 2) — the last axis is (G, H).
        Histogram values are CUMULATIVE (sum of all bins <= bucket_idx).
        Callers (see ``_find_splits_pps``) consume it as one tensor, e.g.
        via ``jax.vmap`` over the node axis.
    """
    stride = slot_count // n_buckets

    # Bake the static integer arguments into the kernel so only the
    # ciphertexts/keys flow through pcall as traced values.
    fn = partial(
        _decrypt_batch,
        fxp_scale=fxp_scale,
        n_nodes=n_nodes,
        n_features=n_features,
        n_buckets=n_buckets,
        stride=stride,
    )

    combined_results = simp.pcall_static(
        (ap_rank,),
        fn,
        g_enc_flat,
        h_enc_flat,
        sk,
        encoder,
    )

    return combined_results
737
+
738
+
739
+ # ==============================================================================
740
+ # Tree Update Functions
741
+ # ==============================================================================
742
+
743
+
744
def make_get_subgroup_map(n_nodes: int):
    """Build a JIT-compiled one-hot node-membership function with n_nodes fixed."""

    @jax.jit
    def get_subgroup_map(bt_level: jnp.ndarray) -> jnp.ndarray:
        """Return an (n_nodes, m) int8 matrix; row b flags samples in node b."""
        node_ids = jnp.arange(n_nodes)
        membership = node_ids[:, None] == bt_level
        return membership.astype(jnp.int8)

    return get_subgroup_map
753
+
754
+
755
@jax.jit
def update_is_leaf(
    is_leaf: jnp.ndarray,
    gains: jnp.ndarray,
    indices: jnp.ndarray,
) -> jnp.ndarray:
    """Flag the nodes at `indices` as leaves when their best gain is unusable.

    A node becomes a leaf when its gain is non-positive or non-finite
    (NaN/inf). Other entries of `is_leaf` are left untouched.
    """
    # Equivalent to (gains <= 0) | ~isfinite(gains) by De Morgan.
    unsplittable = ~(jnp.isfinite(gains) & (gains > 0.0))
    return is_leaf.at[indices].set(unsplittable.astype(jnp.int64))
764
+
765
+
766
@jax.jit
def update_bt(
    bt: jnp.ndarray,
    bt_level: jnp.ndarray,
    is_leaf: jnp.ndarray,
    bin_indices: jnp.ndarray,
    best_feature: jnp.ndarray,
    best_thresh_idx: jnp.ndarray,
) -> jnp.ndarray:
    """Route each sample to a child node after splitting; leaf samples stay put.

    `best_feature` / `best_thresh_idx` are indexed per level-local node id
    (`bt_level`), while child ids are derived from the global node id (`bt`).
    """
    n_samples = bt.shape[0]
    # Gather each sample's split parameters from its current node.
    split_feat = best_feature[bt_level]
    split_thresh = best_thresh_idx[bt_level]
    own_bin = bin_indices[jnp.arange(n_samples), split_feat]

    # Children of node i are 2i+1 (left) and 2i+2 (right).
    left_child = 2 * bt + 1
    right_child = 2 * bt + 2
    routed = jnp.where(own_bin <= split_thresh, left_child, right_child)

    at_leaf = is_leaf[bt].astype(bool)
    return jnp.where(at_leaf, bt, routed)
784
+
785
+
786
def make_compute_leaf_values(n_nodes: int):
    """Build a JIT-compiled leaf-weight function with the node count baked in."""

    @jax.jit
    def compute_leaf_values(
        gh: jnp.ndarray,
        bt: jnp.ndarray,
        is_leaf: jnp.ndarray,
        reg_lambda: float,
    ) -> jnp.ndarray:
        """Return per-node leaf weights -G / (H + lambda).

        Non-leaf nodes and empty nodes (zero hessian mass) get 0.0.
        """
        totals = segment_sum(gh, bt, num_segments=n_nodes)
        total_g = totals[:, 0]
        total_h = totals[:, 1]
        # Avoid a 0/0 for nodes that received no samples; masked out below.
        denom_h = jnp.where(total_h == 0, 1.0, total_h)
        weights = -total_g / (denom_h + reg_lambda)

        occupied = total_h != 0
        return jnp.where(is_leaf.astype(bool) & occupied, weights, 0.0)

    return compute_leaf_values
806
+
807
+
808
+ # ==============================================================================
809
+ # Tree Building Helpers
810
+ # ==============================================================================
811
+
812
+
813
def _find_splits_ap(
    ap_rank: int,
    n_level: int,
    n_buckets: int,
    gh: Any,
    bt_level: Any,
    bin_indices: Any,
    reg_lambda: float,
    gamma: float,
    min_child_weight: float,
) -> tuple[Any, Any, Any]:
    """Build AP's local histograms and pick its best split for each node.

    Everything runs locally at the active party: histogram construction
    over the level-local node assignments, then split scoring.

    Returns:
        (gains, features, threshold_indices), each indexed per level node.
    """
    hist_fn = make_local_build_histogram(n_level, n_buckets)

    # Step 1: plaintext histogram per node/feature/bucket at AP.
    hist_obj = simp.pcall_static(
        (ap_rank,),
        lambda fn=hist_fn: tensor.run_jax(fn, gh, bt_level, bin_indices),
    )

    # Step 2: score every candidate split and keep the best per node.
    gains, feats, threshs = simp.pcall_static(
        (ap_rank,),
        lambda rl=reg_lambda, gm=gamma, mcw=min_child_weight: tensor.run_jax(
            local_compute_best_splits, hist_obj, rl, gm, mcw
        ),
    )
    return gains, feats, threshs
837
+
838
+
839
def _find_splits_pps(
    level: int,
    pp_ranks: list[int],
    ap_rank: int,
    g_cts_pps: dict[int, list[Any]],
    h_cts_pps: dict[int, list[Any]],
    bt_level: Any,
    all_bin_indices: list[Any],
    n_features_per_party: list[int],
    last_level_hists: list[Any],
    encoder: Any,
    relin_keys: Any,
    galois_keys: Any,
    sk: Any,
    fxp_scale: int,
    m: int,
    n_chunks: int,
    slot_count: int,
    n_buckets: int,
    reg_lambda: float,
    gamma: float,
    min_child_weight: float,
) -> tuple[list[Any], list[Any], list[Any]]:
    """Compute remote histograms via FHE and find best splits at PPs.

    For every passive party: ship the evaluation keys and level-local node
    assignments to the PP, have the PP build encrypted per-node histograms
    over its private features, decrypt them back at AP, and score the best
    split per node at AP.

    At level 0 the full histogram is computed homomorphically. At deeper
    levels the histogram-subtraction optimization is used: only the LEFT
    children are computed via FHE; RIGHT children are derived at AP as
    parent - left from ``last_level_hists``.

    Side effects:
        Mutates ``last_level_hists`` in place (slots 1..k, one per PP) so
        the next level can apply the subtraction optimization.

    Returns:
        Three parallel lists (one entry per PP) of per-node gains, best
        feature indices, and best threshold indices — all resident at AP.
    """
    pp_gains_list = []
    pp_feats_list = []
    pp_threshs_list = []

    n_level = 2**level

    for pp_idx, pp_rank in enumerate(pp_ranks):
        # Retrieve pre-transferred encrypted CT chunks (hoisted by build_tree).
        g_cts_pp = g_cts_pps[pp_rank]
        h_cts_pp = h_cts_pps[pp_rank]

        # Transfer keys and other metadata to PP.
        bt_level_pp = simp.shuffle_static(bt_level, {pp_rank: ap_rank})
        encoder_pp = simp.shuffle_static(encoder, {pp_rank: ap_rank})
        rk_pp = simp.shuffle_static(relin_keys, {pp_rank: ap_rank})
        gk_pp = simp.shuffle_static(galois_keys, {pp_rank: ap_rank})

        # Build subgroup map at PP: (n_level, m) one-hot node membership.
        subgroup_map_fn = make_get_subgroup_map(n_level)
        subgroup_map = simp.pcall_static(
            (pp_rank,),
            lambda fn=subgroup_map_fn, bt_lv=bt_level_pp: tensor.run_jax(fn, bt_lv),
        )

        # Index 0 of the per-party lists is AP, so PPs start at index 1.
        n_pp_features = n_features_per_party[pp_idx + 1]

        if level == 0:
            # Root level: compute the full FHE histogram (no parent to subtract).
            g_enc, h_enc = fhe_histogram_optimized(
                g_cts_pp,
                h_cts_pp,
                subgroup_map,
                all_bin_indices[pp_idx + 1],
                n_buckets,
                n_level,
                n_pp_features,
                pp_rank,
                ap_rank,
                encoder_pp,
                rk_pp,
                gk_pp,
                m,
                n_chunks,
                slot_count,
            )

            pp_hists = decrypt_histogram_results(
                g_enc,
                h_enc,
                sk,
                encoder,
                fxp_scale,
                n_level,
                n_pp_features,
                n_buckets,
                ap_rank,
            )
            # Store for next level's subtraction optimization.
            last_level_hists[pp_idx + 1] = pp_hists

        else:
            # Histogram Subtraction Optimization:
            # 1. Slice subgroup_map to get Left children (even indices 0, 2, ...)
            def slice_left(sm):
                return sm[0::2]

            subgroup_map_left = simp.pcall_static(
                (pp_rank,),
                lambda sm=subgroup_map: tensor.run_jax(slice_left, sm),
            )

            # 2. Run FHE for Left children only (halves the FHE work).
            n_left = n_level // 2
            g_enc, h_enc = fhe_histogram_optimized(
                g_cts_pp,
                h_cts_pp,
                subgroup_map_left,
                all_bin_indices[pp_idx + 1],
                n_buckets,
                n_left,
                n_pp_features,
                pp_rank,
                ap_rank,
                encoder_pp,
                rk_pp,
                gk_pp,
                m,
                n_chunks,
                slot_count,
            )

            # 3. Decrypt Left children's histograms at AP.
            left_hists = decrypt_histogram_results(
                g_enc,
                h_enc,
                sk,
                encoder,
                fxp_scale,
                n_left,
                n_pp_features,
                n_buckets,
                ap_rank,
            )

            # 4. Derive Right children via parent - left, then interleave.
            parent_hists = last_level_hists[pp_idx + 1]

            def derive_right_and_combine(l_hists, p_hists):
                # l_hists: (n_left, ...)
                # p_hists: (n_left, ...) - parents correspond exactly to left children
                r_hists = p_hists - l_hists

                # Interleave [L, R] so node order matches this level's layout:
                # stack on new axis 1 -> (n_left, 2, ...)
                combined = jnp.stack([l_hists, r_hists], axis=1)
                # Reshape -> (2*n_left, ...)
                return combined.reshape((-1, *l_hists.shape[1:]))

            pp_hists = simp.pcall_static(
                (ap_rank,),
                lambda lh=left_hists, ph=parent_hists: tensor.run_jax(
                    derive_right_and_combine, lh, ph
                ),
            )

            # Store for next level (if needed).
            # Note: We don't know max_depth here, but storing it is harmless if not used.
            last_level_hists[pp_idx + 1] = pp_hists

        # Score splits per node at AP from the assembled histograms.
        def find_splits(hists, rl=reg_lambda, gm=gamma, mcw=min_child_weight):
            # hists is already (n_nodes, n_feat, n_buck, 2)
            return jax.vmap(lambda h: compute_best_split_from_hist(h, rl, gm, mcw))(
                hists
            )

        pp_gains, pp_feats, pp_threshs = simp.pcall_static(
            (ap_rank,),
            lambda h=pp_hists: tensor.run_jax(find_splits, h),
        )

        pp_gains_list.append(pp_gains)
        pp_feats_list.append(pp_feats)
        pp_threshs_list.append(pp_threshs)

    return pp_gains_list, pp_feats_list, pp_threshs_list
1009
+
1010
+
1011
def _update_tree_state(
    ap_rank: int,
    pp_ranks: list[int],
    all_ranks: list[int],
    all_feats: list[Any],
    all_thresholds: list[Any],
    bt: Any,
    bt_level: Any,
    is_leaf: Any,
    owned_party: Any,
    cur_indices: Any,
    best_party: Any,
    best_gains: Any,
    all_feats_level: list[Any],
    all_threshs_level: list[Any],
    all_bins: list[Any],
    all_bin_indices: list[Any],
) -> tuple[Any, Any, list[Any], list[Any], Any]:
    """Update tree structure and sample assignments based on best splits.

    Runs once per level. Responsibilities:
      1. Mark freshly unsplittable nodes as leaves and record each node's
         winning party, broadcasting both to all parties.
      2. For every party, write the winning feature/threshold into that
         party's private arrays — masked so each party only keeps entries
         for the nodes it owns (split privacy).
      3. Each party speculatively routes samples with its own candidate
         split (tmp_bt); AP then merges, per node, the routing produced by
         the node's winning party, and broadcasts the merged ``bt``.

    Returns:
        (is_leaf, owned_party, all_feats, all_thresholds, bt) — the updated
        state objects; is_leaf/owned_party/bt are converged across parties.
    """
    # Update is_leaf at AP using the per-node best gains.
    is_leaf = simp.pcall_static(
        (ap_rank,),
        lambda: tensor.run_jax(update_is_leaf, is_leaf, best_gains, cur_indices),
    )

    # Broadcast is_leaf to all parties (keep source, shuffle to each target, then converge)
    if pp_ranks:
        is_leaf_parts = [is_leaf]  # Start with AP's copy
        for r in pp_ranks:
            is_leaf_parts.append(simp.shuffle_static(is_leaf, {r: ap_rank}))
        is_leaf = simp.converge(*is_leaf_parts)

    # Update owned_party: record which party won each level node's split.
    owned_party = simp.pcall_static(
        (ap_rank,),
        lambda: tensor.run_jax(
            lambda op, bp, ci: op.at[ci].set(bp),
            owned_party,
            best_party,
            cur_indices,
        ),
    )

    # Broadcast owned_party to all parties
    if pp_ranks:
        owned_party_parts = [owned_party]
        for r in pp_ranks:
            owned_party_parts.append(simp.shuffle_static(owned_party, {r: ap_rank}))
        owned_party = simp.converge(*owned_party_parts)

    # === Update features and thresholds for each party ===
    # Route best_feats/best_threshs to correct parties based on best_party
    all_tmp_bt: list[Any] = []

    for party_idx, party_rank in enumerate(all_ranks):
        # Transfer data to this party if needed (index 0 is AP itself).
        if party_idx > 0:
            # PP's results are already at AP, send back to PP
            all_feats_level[party_idx] = simp.shuffle_static(
                all_feats_level[party_idx], {party_rank: ap_rank}
            )
            all_threshs_level[party_idx] = simp.shuffle_static(
                all_threshs_level[party_idx], {party_rank: ap_rank}
            )
            # Also need cur_indices, owned_party, is_leaf at PP
            cur_indices_party = simp.shuffle_static(cur_indices, {party_rank: ap_rank})
            owned_party_party = simp.shuffle_static(owned_party, {party_rank: ap_rank})
            is_leaf_party = simp.shuffle_static(is_leaf, {party_rank: ap_rank})
        else:
            cur_indices_party = cur_indices
            owned_party_party = owned_party
            is_leaf_party = is_leaf

        # Update this party's feature array: keep entries only for nodes this
        # party owns; leaves and foreign nodes are reset to -1.
        def update_party_feats(
            feats,
            best_feat,
            indices,
            owned,
            leaf,
            pid=party_idx,
        ):
            tmp = feats.at[indices].set(best_feat)
            tmp = jnp.where(leaf.astype(bool), jnp.int64(-1), tmp)
            mask = owned == pid
            return jnp.where(mask, tmp, jnp.int64(-1))

        all_feats[party_idx] = simp.pcall_static(
            (party_rank,),
            lambda pf=all_feats[party_idx], bf=all_feats_level[party_idx], ci=cur_indices_party, op=owned_party_party, il=is_leaf_party: (
                tensor.run_jax(update_party_feats, pf, bf, ci, op, il)
            ),
        )

        # Same masking for thresholds; non-owned/leaf entries become +inf.
        def update_party_thresholds(
            thresholds,
            bins_arr,
            best_feat,
            best_thresh_idx,
            indices,
            owned,
            leaf,
            pid=party_idx,
        ):
            # Get actual threshold values from bins
            best_thresh = bins_arr[best_feat, best_thresh_idx]
            tmp = thresholds.at[indices].set(best_thresh)
            tmp = jnp.where(leaf.astype(bool), jnp.float32(jnp.inf), tmp)
            mask = owned == pid
            return jnp.where(mask, tmp, jnp.float32(jnp.inf))

        all_thresholds[party_idx] = simp.pcall_static(
            (party_rank,),
            lambda pt=all_thresholds[party_idx], b=all_bins[party_idx], bf=all_feats_level[party_idx], bt_idx=all_threshs_level[party_idx], ci=cur_indices_party, op=owned_party_party, il=is_leaf_party: (
                tensor.run_jax(
                    update_party_thresholds,
                    pt,
                    b,
                    bf,
                    bt_idx,
                    ci,
                    op,
                    il,
                )
            ),
        )

        # Compute temporary bt for this party: speculative routing using this
        # party's best split candidates (only the winning party's is kept).
        # Need bt and bt_level at this party too.
        if party_idx > 0:
            bt_party = simp.shuffle_static(bt, {party_rank: ap_rank})
            bt_level_party = simp.shuffle_static(bt_level, {party_rank: ap_rank})
        else:
            bt_party = bt
            bt_level_party = bt_level

        tmp_bt = simp.pcall_static(
            (party_rank,),
            lambda bi=all_bin_indices[party_idx], bf=all_feats_level[party_idx], bt_idx=all_threshs_level[party_idx], bt_arr=bt_party, bt_lv=bt_level_party, il=is_leaf_party: (
                tensor.run_jax(update_bt, bt_arr, bt_lv, il, bi, bf, bt_idx)
            ),
        )

        # Transfer PP's tmp_bt to AP for merging
        if party_idx > 0:
            tmp_bt = simp.shuffle_static(tmp_bt, {ap_rank: party_rank})

        all_tmp_bt.append(tmp_bt)

    # === Merge bt updates based on best_party ===
    # For each level node, keep the routing computed by the winning party.
    def merge_bt_updates(
        current_bt,
        all_tmp,
        best_party_arr,
        level_indices,
    ):
        stacked = jnp.stack(all_tmp, axis=0)  # (n_parties, m)
        updated_bt = current_bt

        def update_for_node(carry, i):
            bt_arr = carry
            node_idx = level_indices[i]
            winning_party = best_party_arr[i]
            samples_in_node = current_bt == node_idx
            winning_bt = stacked[winning_party]
            return jnp.where(samples_in_node, winning_bt, bt_arr), None

        updated_bt, _ = jax.lax.scan(
            update_for_node, updated_bt, jnp.arange(len(level_indices))
        )
        return updated_bt

    bt = simp.pcall_static(
        (ap_rank,),
        lambda: tensor.run_jax(
            merge_bt_updates, bt, all_tmp_bt, best_party, cur_indices
        ),
    )

    # Broadcast updated bt to all parties
    if pp_ranks:
        bt_parts = [bt]
        for r in pp_ranks:
            bt_parts.append(simp.shuffle_static(bt, {r: ap_rank}))
        bt = simp.converge(*bt_parts)

    return is_leaf, owned_party, all_feats, all_thresholds, bt
1198
+
1199
+
1200
def build_tree(
    gh: Any,  # Plaintext G/H at AP, shape (m, 2)
    g_cts: list[Any],  # Encrypted G chunks at AP
    h_cts: list[Any],  # Encrypted H chunks at AP
    n_chunks: int,  # Number of CT chunks
    all_bins: list[Any],  # Bin boundaries per party
    all_bin_indices: list[Any],  # Binned features per party
    sk: Any,  # Secret key at AP
    pk: Any,  # Public key at AP
    encoder: Any,  # BFV encoder
    relin_keys: Any,  # Relinearization keys
    galois_keys: Any,  # Galois keys for rotation
    fxp_scale: int,
    ap_rank: int,
    pp_ranks: list[int],
    max_depth: int,
    reg_lambda: float,
    gamma: float,
    min_child_weight: float,
    n_samples: int,
    n_buckets: int,
    n_features_per_party: list[int],  # Number of features for each party
    slot_count: int = DEFAULT_POLY_MODULUS_DEGREE,
) -> Tree:
    """Build a single decision tree level by level.

    The algorithm proceeds breadth-first:
    1. For each level, compute histograms (local at AP, FHE at PPs)
    2. Find best split per node across all parties
    3. Update tree structure and sample assignments
    4. Repeat until max_depth reached

    **Multi-CT Support**: When n_samples > slot_count, data is split into
    n_chunks ciphertexts. Each chunk is processed separately and results
    are accumulated in the FHE domain.

    Returns:
        A ``Tree`` whose per-party feature/threshold arrays hold each
        party's own splits (masked elsewhere), with leaf values and
        is_leaf/owned_party converged at/broadcast from AP.
    """
    m = n_samples
    # Full binary tree: 2^(d+1) - 1 nodes, indexed heap-style (root = 0).
    n_nodes = 2 ** (max_depth + 1) - 1
    all_ranks = [ap_rank, *pp_ranks]

    # Initialize tree arrays (constants materialized on the owning rank).
    def init_array(rank, shape, dtype, fill):
        return simp.pcall_static(
            (rank,),
            lambda: tensor.constant(np.full(shape, fill, dtype=dtype)),
        )

    all_feats = [init_array(r, n_nodes, np.int64, -1) for r in all_ranks]
    all_thresholds = [init_array(r, n_nodes, np.float32, np.inf) for r in all_ranks]
    values = init_array(ap_rank, n_nodes, np.float32, 0.0)
    is_leaf = init_array(ap_rank, n_nodes, np.int64, 0)
    owned_party = init_array(ap_rank, n_nodes, np.int64, -1)
    bt = init_array(ap_rank, m, np.int64, 0)  # all samples start at the root

    # Store parent histograms for subtraction optimization
    # List of TraceObjects (JAX arrays) representing stacked histograms for previous level
    # Index 0 is AP (unused), 1..k are PPs
    last_level_hists: list[Any] = [None] * (len(pp_ranks) + 1)

    # Optimization 1: Hoist Ciphertext Transfer
    # Transfer encrypted gradients to all PPs once, before the tree building loop.
    g_cts_pps: dict[int, list[Any]] = {}
    h_cts_pps: dict[int, list[Any]] = {}

    for pp_rank in pp_ranks:
        g_cts_pps[pp_rank] = [
            simp.shuffle_static(ct, {pp_rank: ap_rank}) for ct in g_cts
        ]
        h_cts_pps[pp_rank] = [
            simp.shuffle_static(ct, {pp_rank: ap_rank}) for ct in h_cts
        ]

    for level in range(max_depth):
        n_level = 2**level
        # Global index of the first node at this level (heap layout).
        level_offset = 2**level - 1

        cur_indices = simp.pcall_static(
            (ap_rank,),
            lambda off=level_offset, nl=n_level: tensor.constant(
                np.arange(nl, dtype=np.int64) + off
            ),
        )

        # Local bt for this level: global node id minus the level offset.
        bt_level = simp.pcall_static(
            (ap_rank,),
            lambda off=level_offset, b=bt: tensor.run_jax(lambda x: x - off, b),
        )

        # === AP: Local histogram computation ===
        ap_gains, ap_feats, ap_threshs = _find_splits_ap(
            ap_rank,
            n_level,
            n_buckets,
            gh,
            bt_level,
            all_bin_indices[0],
            reg_lambda,
            gamma,
            min_child_weight,
        )

        all_gains = [ap_gains]
        all_feats_level = [ap_feats]
        all_threshs_level = [ap_threshs]

        # === PP: FHE histogram computation ===
        pp_gains_list, pp_feats_list, pp_threshs_list = _find_splits_pps(
            level,
            pp_ranks,
            ap_rank,
            g_cts_pps,
            h_cts_pps,
            bt_level,
            all_bin_indices,
            n_features_per_party,
            last_level_hists,
            encoder,
            relin_keys,
            galois_keys,
            sk,
            fxp_scale,
            m,
            n_chunks,
            slot_count,
            n_buckets,
            reg_lambda,
            gamma,
            min_child_weight,
        )

        all_gains.extend(pp_gains_list)
        all_feats_level.extend(pp_feats_list)
        all_threshs_level.extend(pp_threshs_list)

        # === Find global best split across all parties ===
        # argmax over the party axis picks, per node, the party with max gain.
        def find_global_best(*gains):
            stacked = jnp.stack(gains, axis=0)  # (n_parties, n_nodes)
            best_party = jnp.argmax(stacked, axis=0)
            best_gains = jnp.take_along_axis(
                stacked, best_party[None, :], axis=0
            ).squeeze(0)
            return best_gains, best_party

        best_gains, best_party = simp.pcall_static(
            (ap_rank,),
            lambda gains=all_gains: tensor.run_jax(find_global_best, *gains),
        )

        # === Update Tree State ===
        is_leaf, owned_party, all_feats, all_thresholds, bt = _update_tree_state(
            ap_rank,
            pp_ranks,
            all_ranks,
            all_feats,
            all_thresholds,
            bt,
            bt_level,
            is_leaf,
            owned_party,
            cur_indices,
            best_party,
            best_gains,
            all_feats_level,
            all_threshs_level,
            all_bins,
            all_bin_indices,
        )

    # Force final level nodes to be leaves (depth budget exhausted).
    final_start = 2**max_depth - 1
    final_end = 2 ** (max_depth + 1) - 1
    final_indices = simp.pcall_static(
        (ap_rank,),
        lambda: tensor.constant(np.arange(final_start, final_end, dtype=np.int64)),
    )
    is_leaf = simp.pcall_static(
        (ap_rank,),
        lambda: tensor.run_jax(
            lambda il, fi: il.at[fi].set(1),
            is_leaf,
            final_indices,
        ),
    )

    # Broadcast final is_leaf to all parties (needed for prediction)
    # Note: owned_party is already converged to all parties during the level loop
    if pp_ranks:
        is_leaf_parts = [is_leaf]
        for r in pp_ranks:
            is_leaf_parts.append(simp.shuffle_static(is_leaf, {r: ap_rank}))
        is_leaf = simp.converge(*is_leaf_parts)

    # Compute final leaf values from the converged sample assignments.
    leaf_val_fn = make_compute_leaf_values(n_nodes)
    values = simp.pcall_static(
        (ap_rank,),
        lambda fn=leaf_val_fn: tensor.run_jax(fn, gh, bt, is_leaf, reg_lambda),
    )

    return Tree(
        feature=all_feats,
        threshold=all_thresholds,
        value=values,
        is_leaf=is_leaf,
        owned_party_id=owned_party,
    )
1407
+
1408
+
1409
+ # ==============================================================================
1410
+ # Prediction
1411
+ # ==============================================================================
1412
+
1413
+
1414
def predict_tree_single_party(
    data: Any,
    feature: Any,
    threshold: Any,
    is_leaf: Any,
    owned_party_id: Any,
    party_id: int,
    n_nodes: int,
) -> Any:
    """Local tree traversal for a single party.

    Returns a location matrix (m, n_nodes) where each sample may be in multiple
    nodes if splits are owned by other parties: for a split this party owns it
    routes samples to the correct child; for a split owned by someone else it
    propagates samples to BOTH children. Intersecting all parties' matrices
    (see ``predict_tree``) recovers the unique leaf per sample.
    """

    def traverse_kernel(
        data_arr,
        feat_arr,
        thresh_arr,
        leaf_arr,
        owner_arr,
    ):
        n_samples = data_arr.shape[0]
        # Start all samples at root
        locations = jnp.zeros((n_samples, n_nodes), dtype=jnp.int64).at[:, 0].set(1)

        def propagate(i, locs):
            # This party can resolve node i only if it is a split it owns.
            is_my_split = (leaf_arr[i] == 0) & (owner_arr[i] == party_id)

            def process_my_split(locs_inner):
                # Route samples currently at node i to child 2i+1 or 2i+2
                # by comparing the owned feature against the owned threshold.
                samples_here = locs_inner[:, i]
                feat_idx = feat_arr[i]
                thresh = thresh_arr[i]
                go_left = data_arr[:, feat_idx] <= thresh
                to_left = samples_here * go_left.astype(jnp.int64)
                to_right = samples_here * (1 - go_left.astype(jnp.int64))
                locs_inner = locs_inner.at[:, 2 * i + 1].add(to_left)
                locs_inner = locs_inner.at[:, 2 * i + 2].add(to_right)
                return locs_inner.at[:, i].set(0)

            def propagate_unknown(locs_inner):
                is_split = leaf_arr[i] == 0

                def propagate_both(loc):
                    # Split owned elsewhere: keep both children as candidates.
                    samples_here = loc[:, i]
                    loc = loc.at[:, 2 * i + 1].add(samples_here)
                    loc = loc.at[:, 2 * i + 2].add(samples_here)
                    return loc.at[:, i].set(0)

                # Leaf nodes pass through unchanged.
                return jax.lax.cond(is_split, propagate_both, lambda x: x, locs_inner)

            return jax.lax.cond(is_my_split, process_my_split, propagate_unknown, locs)

        # n_nodes // 2 covers the internal nodes (presumably n_nodes is a
        # full-tree size 2^(d+1)-1, so internal count is (n_nodes-1)/2 —
        # NOTE(review): confirm callers always pass a full-tree n_nodes).
        return jax.lax.fori_loop(0, n_nodes // 2, propagate, locations)

    return tensor.run_jax(
        traverse_kernel, data, feature, threshold, is_leaf, owned_party_id
    )
1472
+
1473
+
1474
def predict_tree(
    tree: Tree,
    all_datas: list[Any],
    ap_rank: int,
    pp_ranks: list[int],
    n_nodes: int,
) -> Any:
    """Predict using a single tree by aggregating location masks from all parties.

    Each party traverses the tree on its own features, producing an
    (m, n_nodes) candidate-location mask (see ``predict_tree_single_party``).
    AP intersects the masks — a sample's true leaf is the node ALL parties
    agree on — and reads the corresponding leaf value.

    Returns:
        Per-sample predictions (length m), resident at AP.
    """
    all_ranks = [ap_rank, *pp_ranks]

    # Each party computes its local traversal
    all_masks: list[Any] = []

    for i, rank in enumerate(all_ranks):
        # Defaults on the lambda bind the current loop values (avoids the
        # late-binding-closure pitfall across iterations).
        mask = simp.pcall_static(
            (rank,),
            lambda d=all_datas[i], f=tree.feature[i], t=tree.threshold[i], idx=i: (
                predict_tree_single_party(
                    d, f, t, tree.is_leaf, tree.owned_party_id, idx, n_nodes
                )
            ),
        )
        # Transfer to AP
        if rank != ap_rank:
            mask = simp.shuffle_static(mask, {ap_rank: rank})
        all_masks.append(mask)

    # Aggregate masks at AP
    def aggregate_predictions(
        *masks,
        leaf_arr,
        values_arr,
    ):
        stacked = jnp.stack(masks, axis=0)  # (n_parties, m, n_nodes)
        # Consensus: sample is at node only if ALL parties agree
        consensus = jnp.all(stacked > 0, axis=0)  # (m, n_nodes)
        # Find leaf nodes
        final_leaf_mask = consensus * leaf_arr.astype(bool)
        # Get leaf index for each sample (argmax picks the flagged leaf)
        leaf_indices = jnp.argmax(final_leaf_mask, axis=1)
        return values_arr[leaf_indices]

    predictions = simp.pcall_static(
        (ap_rank,),
        lambda: tensor.run_jax(
            aggregate_predictions,
            *all_masks,
            leaf_arr=tree.is_leaf,
            values_arr=tree.value,
        ),
    )

    return predictions
1527
+
1528
+
1529
def predict_ensemble(
    model: TreeEnsemble,
    all_datas: list[Any],
    ap_rank: int,
    pp_ranks: list[int],
    learning_rate: float,
    n_samples: int,
    n_nodes: int,
) -> Any:
    """Run every tree in the ensemble and return sigmoid probabilities at AP."""

    def _accumulate(running, contribution, lr=learning_rate):
        # Gradient-boosting update: each tree's output is shrunk by lr.
        return running + lr * contribution

    # Broadcast the scalar initial prediction to a length-n logit vector.
    logits = simp.pcall_static(
        (ap_rank,),
        lambda n=n_samples: tensor.run_jax(
            lambda init: init * jnp.ones(n), model.initial_prediction
        ),
    )

    # Fold each tree's contribution into the running logits.
    for tree in model.trees:
        contribution = predict_tree(tree, all_datas, ap_rank, pp_ranks, n_nodes)
        logits = simp.pcall_static(
            (ap_rank,),
            lambda acc=logits, tp=contribution: tensor.run_jax(_accumulate, acc, tp),
        )

    # Squash accumulated logits into probabilities.
    return simp.pcall_static(
        (ap_rank,),
        lambda: tensor.run_jax(sigmoid, logits),
    )
1568
+
1569
+
1570
+ # ==============================================================================
1571
+ # Training API
1572
+ # ==============================================================================
1573
+
1574
+
1575
def fit_tree_ensemble(
    all_datas: list[Any],
    y_data: Any,
    all_bins: list[Any],
    all_bin_indices: list[Any],
    initial_pred: Any,
    n_samples: int,
    n_buckets: int,
    n_features_per_party: list[int],
    n_estimators: int,
    learning_rate: float,
    max_depth: int,
    reg_lambda: float,
    gamma: float,
    min_child_weight: float,
    ap_rank: int,
    pp_ranks: list[int],
) -> TreeEnsemble:
    """Fit a SecureBoost tree ensemble.

    Standard gradient-boosting outer loop: per tree, compute gradients
    (G) and hessians (H) from the current predictions at AP, quantize and
    BFV-encrypt them for the passive parties, build one tree via
    ``build_tree``, then fold the tree's (learning-rate-scaled) prediction
    back into the running logits.

    Returns:
        A ``TreeEnsemble`` holding the fitted trees and the initial
        prediction used to seed the logits.
    """
    m = n_samples
    # Fixed-point scale used when quantizing G/H before encryption.
    fxp_scale = 1 << DEFAULT_FXP_BITS

    # Seed the running logits: initial_pred broadcast to length m.
    y_pred = simp.pcall_static(
        (ap_rank,),
        lambda n=m: tensor.run_jax(lambda init: init * jnp.ones(n), initial_pred),
    )

    # BFV key generation at AP (only if we have passive parties)
    pk, sk, relin_keys, galois_keys, encoder = None, None, None, None, None
    if pp_ranks:

        def keygen_fn():
            pub, sec = bfv.keygen(poly_modulus_degree=DEFAULT_POLY_MODULUS_DEGREE)
            rk = bfv.make_relin_keys(sec)
            gk = bfv.make_galois_keys(sec)
            enc = bfv.create_encoder(poly_modulus_degree=DEFAULT_POLY_MODULUS_DEGREE)
            return pub, sec, rk, gk, enc

        pk, sk, relin_keys, galois_keys, encoder = simp.pcall_static(
            (ap_rank,), keygen_fn
        )

    trees: list[Tree] = []

    for _tree_idx in range(n_estimators):
        # Compute G/H, quantize, and split into qg/qh in one call
        def compute_gh_quantized(y_true, y_pred_logits, scale):
            gh = compute_gh(y_true, y_pred_logits)
            qgh = quantize_gh(gh, scale)
            return gh, qgh[:, 0], qgh[:, 1]

        gh, qg, qh = simp.pcall_static(
            (ap_rank,),
            lambda yp=y_pred: tensor.run_jax(
                compute_gh_quantized, y_data, yp, fxp_scale
            ),
        )

        # FHE encrypt only if we have passive parties
        g_cts, h_cts, n_chunks = [], [], 1
        if pp_ranks:
            g_cts, h_cts, n_chunks = fhe_encrypt_gh(
                qg, qh, pk, encoder, ap_rank, n_samples
            )

        tree = build_tree(
            gh,
            g_cts,
            h_cts,
            n_chunks,
            all_bins,
            all_bin_indices,
            sk,
            pk,
            encoder,
            relin_keys,
            galois_keys,
            fxp_scale,
            ap_rank,
            pp_ranks,
            max_depth,
            reg_lambda,
            gamma,
            min_child_weight,
            n_samples,
            n_buckets,
            n_features_per_party,
        )
        trees.append(tree)

        # Predict tree and update y_pred (the boosting residual update).
        n_nodes = 2 ** (max_depth + 1) - 1
        tree_pred = predict_tree(tree, all_datas, ap_rank, pp_ranks, n_nodes)

        def update_pred_fn(curr_y, t_pred, lr=learning_rate):
            return curr_y + lr * t_pred

        y_pred = simp.pcall_static(
            (ap_rank,),
            lambda yp=y_pred, tp=tree_pred: tensor.run_jax(update_pred_fn, yp, tp),
        )

    return TreeEnsemble(
        max_depth=max_depth,
        trees=trees,
        initial_prediction=initial_pred,
    )
1682
+
1683
+
1684
+ # ==============================================================================
1685
+ # SecureBoost Class
1686
+ # ==============================================================================
1687
+
1688
+
1689
class SecureBoost:
    """Vertically-federated SecureBoost classifier (mplang.v2 BFV backend).

    Histogram aggregation runs under BFV homomorphic encryption, packing
    values into SIMD slots so many buckets are processed per ciphertext.

    Example:
        model = SecureBoost(n_estimators=10, max_depth=3)
        model.fit([X_ap, X_pp], y)
        predictions = model.predict([X_ap_test, X_pp_test])
    """

    def __init__(
        self,
        n_estimators: int = 10,
        learning_rate: float = 0.1,
        max_depth: int = 3,
        max_bin: int = 8,
        reg_lambda: float = 1.0,
        gamma: float = 0.0,
        min_child_weight: float = 1.0,
        ap_rank: int = 0,
        pp_ranks: list[int] | None = None,
    ):
        """Configure the model; no computation happens here.

        Args:
            n_estimators: Number of boosting rounds (trees).
            learning_rate: Shrinkage applied to each tree's contribution.
            max_depth: Maximum depth of every tree.
            max_bin: Upper bound on histogram bins per feature (>= 2).
            reg_lambda: L2 penalty on leaf weights.
            gamma: Minimum gain required to accept a split.
            min_child_weight: Minimum hessian mass allowed in a child.
            ap_rank: Rank of the active party (label holder).
            pp_ranks: Ranks of the passive parties; defaults to [1].

        Raises:
            ValueError: If max_bin is below 2.
        """
        if max_bin < 2:
            raise ValueError(f"max_bin must be >= 2, got {max_bin}")

        self.n_estimators = n_estimators
        self.learning_rate = learning_rate
        self.max_depth = max_depth
        self.max_bin = max_bin
        self.reg_lambda = reg_lambda
        self.gamma = gamma
        self.min_child_weight = min_child_weight
        self.ap_rank = ap_rank
        self.pp_ranks = [1] if pp_ranks is None else pp_ranks
        # Populated by fit(); None signals "not fitted yet".
        self.model: TreeEnsemble | None = None

    def fit(
        self,
        all_datas: list[Any],
        y_data: Any,
        n_samples: int,
        n_features_per_party: list[int],
    ) -> SecureBoost:
        """Train the tree ensemble on vertically partitioned data.

        Args:
            all_datas: Per-party feature tensors; element 0 belongs to the
                active party, the rest follow the order of pp_ranks.
            y_data: Label tensor held by the active party.
            n_samples: Number of training rows.
            n_features_per_party: Feature count for each party.

        Returns:
            self, so calls can be chained.
        """
        self.n_samples = n_samples
        self.n_features_per_party = n_features_per_party

        # Each party bins its own features locally: AP first, then PPs.
        participant_ranks = [self.ap_rank, *self.pp_ranks]

        binning_fn = jax.vmap(
            partial(build_bins_equi_width, max_bin=self.max_bin), in_axes=1
        )
        indexing_fn = jax.vmap(compute_bin_indices, in_axes=(1, 0), out_axes=1)

        all_bins: list[Any] = []
        all_bin_indices: list[Any] = []
        for rank, party_data in zip(participant_ranks, all_datas):
            # Lambda defaults pin the loop variables at definition time
            # (avoids the late-binding closure pitfall).
            party_bins = simp.pcall_static(
                (rank,),
                lambda d=party_data: tensor.run_jax(binning_fn, d),
            )
            party_indices = simp.pcall_static(
                (rank,),
                lambda d=party_data, b=party_bins: tensor.run_jax(
                    indexing_fn, d, b
                ),
            )
            all_bins.append(party_bins)
            all_bin_indices.append(party_indices)

        # Base score is derived from the labels at the active party only.
        initial_pred = simp.pcall_static(
            (self.ap_rank,),
            lambda: tensor.run_jax(compute_init_pred, y_data),
        )

        self.model = fit_tree_ensemble(
            all_datas,
            y_data,
            all_bins,
            all_bin_indices,
            initial_pred,
            self.n_samples,
            self.max_bin + 1,  # n_buckets
            self.n_features_per_party,
            self.n_estimators,
            self.learning_rate,
            self.max_depth,
            self.reg_lambda,
            self.gamma,
            self.min_child_weight,
            self.ap_rank,
            self.pp_ranks,
        )
        return self

    def predict(self, all_datas: list[Any], n_samples: int) -> Any:
        """Score new data, producing probabilities at the active party.

        Args:
            all_datas: Per-party feature tensors, one per party.
            n_samples: Number of rows to score.

        Returns:
            Probability tensor located at the active party.

        Raises:
            RuntimeError: If called before fit().
        """
        if self.model is None:
            raise RuntimeError("Model not fitted. Call fit() first.")

        # A full binary tree of depth d holds 2^(d+1) - 1 nodes.
        total_nodes = (1 << (self.max_depth + 1)) - 1
        return predict_ensemble(
            self.model,
            all_datas,
            self.ap_rank,
            self.pp_ranks,
            self.learning_rate,
            n_samples,
            total_nodes,
        )

    def predict_proba(self, all_datas: list[Any], n_samples: int) -> Any:
        """Alias for predict()."""
        return self.predict(all_datas, n_samples)

    def evaluate(self, all_datas: list[Any], y_data: Any, n_samples: int) -> Any:
        """Compute classification accuracy on held-out data.

        Returns:
            Accuracy tensor at AP (fetch it after graph execution).
        """
        probabilities = self.predict(all_datas, n_samples)

        def accuracy_fn(y_pred, y_true):
            # Threshold at 0.5, then average the agreement with the labels.
            predicted_class = (y_pred > 0.5).astype(jnp.float32)
            return jnp.mean(predicted_class == y_true)

        return simp.pcall_static(
            (self.ap_rank,),
            lambda: tensor.run_jax(accuracy_fn, probabilities, y_data),
        )