mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -45
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +5 -7
- mplang/v1/core/__init__.py +157 -0
- mplang/{core → v1/core}/cluster.py +30 -14
- mplang/{core → v1/core}/comm.py +5 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +13 -14
- mplang/{core → v1/core}/expr/evaluator.py +65 -24
- mplang/{core → v1/core}/expr/printer.py +24 -18
- mplang/{core → v1/core}/expr/transformer.py +3 -3
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +23 -16
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +4 -4
- mplang/{core → v1/core}/primitive.py +106 -201
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{api.py → v1/host.py} +38 -6
- mplang/v1/kernels/__init__.py +41 -0
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/v1/kernels/basic.py +240 -0
- mplang/{kernels → v1/kernels}/context.py +42 -27
- mplang/{kernels → v1/kernels}/crypto.py +44 -37
- mplang/v1/kernels/fhe.py +858 -0
- mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
- mplang/{kernels → v1/kernels}/phe.py +263 -57
- mplang/{kernels → v1/kernels}/spu.py +137 -48
- mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
- mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
- mplang/v1/kernels/value.py +626 -0
- mplang/{ops → v1/ops}/__init__.py +5 -16
- mplang/{ops → v1/ops}/base.py +2 -5
- mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
- mplang/v1/ops/crypto.py +262 -0
- mplang/v1/ops/fhe.py +272 -0
- mplang/{ops → v1/ops}/jax_cc.py +33 -68
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -4
- mplang/{ops → v1/ops}/spu.py +3 -5
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +9 -24
- mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
- mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
- mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +35 -20
- mplang/{runtime → v1/runtime}/client.py +19 -8
- mplang/{runtime → v1/runtime}/communicator.py +59 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +30 -12
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +58 -42
- mplang/{runtime → v1/runtime}/session.py +57 -71
- mplang/{runtime → v1/runtime}/simulation.py +55 -28
- mplang/v1/simp/api.py +353 -0
- mplang/{simp → v1/simp}/mpi.py +8 -9
- mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
- mplang/{simp → v1/simp}/random.py +21 -22
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/core/__init__.py +0 -92
- mplang/device.py +0 -340
- mplang/kernels/builtin.py +0 -207
- mplang/ops/crypto.py +0 -109
- mplang/ops/ibis_cc.py +0 -139
- mplang/ops/sql.py +0 -61
- mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
- mplang/runtime/link_comm.py +0 -131
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -73
- mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{kernels → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
mplang/v2/libs/ml/sgb.py
ADDED
|
@@ -0,0 +1,1861 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
# mypy: disable-error-code="no-untyped-def,no-any-return,var-annotated"
|
|
16
|
+
|
|
17
|
+
"""SecureBoost v2: Optimized implementation using mplang.v2 low-level BFV APIs.
|
|
18
|
+
|
|
19
|
+
This implementation improves upon v1 by leveraging BFV SIMD slots and the
|
|
20
|
+
groupby primitives for efficient histogram computation.
|
|
21
|
+
|
|
22
|
+
Key optimizations:
|
|
23
|
+
1. SIMD slot packing for parallel histogram bucket computation
|
|
24
|
+
2. Rotation-based aggregation for efficient slot summation
|
|
25
|
+
3. Reduced communication via packed ciphertext results
|
|
26
|
+
|
|
27
|
+
See design/sgb_v2.md for detailed architecture documentation.
|
|
28
|
+
|
|
29
|
+
Usage:
|
|
30
|
+
from examples.v2.sgb import SecureBoost
|
|
31
|
+
|
|
32
|
+
model = SecureBoost(n_estimators=10, max_depth=3)
|
|
33
|
+
model.fit([X_ap, X_pp], y)
|
|
34
|
+
predictions = model.predict([X_ap_test, X_pp_test])
|
|
35
|
+
"""
|
|
36
|
+
|
|
37
|
+
from __future__ import annotations
|
|
38
|
+
|
|
39
|
+
from collections import deque
|
|
40
|
+
from dataclasses import dataclass
|
|
41
|
+
from functools import partial
|
|
42
|
+
from typing import Any
|
|
43
|
+
|
|
44
|
+
import jax
|
|
45
|
+
import jax.numpy as jnp
|
|
46
|
+
import numpy as np
|
|
47
|
+
from jax.ops import segment_sum
|
|
48
|
+
|
|
49
|
+
from mplang.v2.dialects import bfv, simp, tensor
|
|
50
|
+
from mplang.v2.libs.mpc.analytics import aggregation
|
|
51
|
+
|
|
52
|
+
# ==============================================================================
|
|
53
|
+
# Configuration
|
|
54
|
+
# ==============================================================================
|
|
55
|
+
|
|
56
|
+
# Number of fractional bits used to quantize G/H into integers;
# the fixed-point scale is 2**DEFAULT_FXP_BITS = 32768.
DEFAULT_FXP_BITS: int = 15

# BFV polynomial modulus degree (2**13 = 8192). Functions in this module also
# use it as the default number of slots per ciphertext (`slot_count`).
# NOTE: For 1M samples, the sum of gradients can reach ~3.2e10 (2^35).
# The default plain_modulus (1032193 ~ 2^20) will cause overflow.
# For large datasets, you MUST increase plain_modulus (e.g. to a 40-bit prime).
DEFAULT_POLY_MODULUS_DEGREE: int = 2**13
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
# ==============================================================================
|
|
65
|
+
# Data Structures
|
|
66
|
+
# ==============================================================================
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@dataclass
class Tree:
    """Single decision tree in flat array representation.

    Every array field carries one entry per node (``n_nodes``). The mapping
    from node index to tree position is not defined by this class —
    presumably a level-order binary-tree layout; confirm against the code
    that builds and evaluates trees.
    """

    # Per-party feature indices, shape (n_nodes,). Presumably one array per
    # party — TODO confirm the list order matches party ranks.
    feature: list[Any]
    # Per-party split thresholds, shape (n_nodes,); same per-party layout as
    # `feature`.
    threshold: list[Any]
    # Leaf values held at the active party (AP), shape (n_nodes,)
    value: Any
    # Leaf mask, shape (n_nodes,); presumably nonzero/True for leaf nodes
    is_leaf: Any
    # Owner of each node's split, shape (n_nodes,)
    owned_party_id: Any
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
@dataclass
class TreeEnsemble:
    """XGBoost ensemble model: the fitted trees plus the base prediction."""

    # Maximum tree depth the ensemble was configured with
    max_depth: int
    # Fitted trees (presumably in boosting order — confirm at the call site)
    trees: list[Tree]
    # Base prediction at AP (presumably the log-odds from compute_init_pred;
    # confirm at the call site)
    initial_prediction: Any
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
# ==============================================================================
|
|
90
|
+
# JAX Mathematical Functions
|
|
91
|
+
# ==============================================================================
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
@jax.jit
def compute_init_pred(y: jnp.ndarray) -> jnp.ndarray:
    """Compute initial prediction for binary classification (log-odds).

    Returns ``log(p / (1 - p))`` where ``p`` is the mean label. ``p`` is
    clipped to ``[eps, 1 - eps]`` so the result stays finite even when every
    label is 0 or every label is 1.

    ``eps`` is ``max(1e-15, finfo(dtype).eps)`` for the dtype of the mean.
    A fixed 1e-15 is not enough in float32 (JAX's default): ``1 - 1e-15``
    rounds to exactly ``1.0`` there, so all-positive labels used to produce
    ``log(1 / 0) = +inf``. In float64 the bound is unchanged (1e-15).

    Args:
        y: Binary labels (0/1).

    Returns:
        Scalar log-odds of the positive class.
    """
    mean = jnp.mean(y)
    # The dtype is static under jit, so eps is a Python float chosen at trace time.
    eps = max(1e-15, float(jnp.finfo(mean.dtype).eps))
    p_base = jnp.clip(mean, eps, 1 - eps)
    return jnp.log(p_base / (1 - p_base))
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
@jax.jit
def sigmoid(x: jnp.ndarray) -> jnp.ndarray:
    """Elementwise logistic function, ``1 / (1 + e^(-x))``."""
    denominator = 1 + jnp.exp(-x)
    return 1 / denominator
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
@jax.jit
def compute_gh(y_true: jnp.ndarray, y_pred_logits: jnp.ndarray) -> jnp.ndarray:
    """Log-loss gradient and hessian with respect to the logits.

    With ``p = sigmoid(y_pred_logits)``, column 0 holds the gradient
    ``p - y_true`` and column 1 the hessian ``p * (1 - p)``. For 1-D inputs
    of length m the result has shape (m, 2).
    """
    prob = sigmoid(y_pred_logits)
    grad = prob - y_true
    hess = prob * (1 - prob)
    return jnp.column_stack((grad, hess))
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
@jax.jit
def quantize_gh(gh: jnp.ndarray, scale: int) -> jnp.ndarray:
    """Map float G/H to fixed-point integers for BFV encryption.

    Each entry becomes ``round(gh * scale)`` (jnp.round: half-to-even),
    cast to ``jnp.int64``. NOTE: JAX silently narrows this to int32 unless
    64-bit mode is enabled.
    """
    fixed_point = jnp.round(jnp.multiply(gh, scale))
    return fixed_point.astype(jnp.int64)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
@jax.jit
def dequantize(arr: jnp.ndarray, scale: int) -> jnp.ndarray:
    """Convert fixed-point integers back to float32 by dividing by ``scale``."""
    as_float = arr.astype(jnp.float32)
    return as_float / scale
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
# ==============================================================================
|
|
129
|
+
# Binning Functions
|
|
130
|
+
# ==============================================================================
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def build_bins_equi_width(x: jnp.ndarray, max_bin: int) -> jnp.ndarray:
    """Build equi-width bin boundaries for a single feature.

    Splits ``[min(x), max(x)]`` evenly into ``max_bin`` bins and returns the
    ``max_bin - 1`` interior split points. If the feature cannot be split
    (fewer than two samples, or ``max(x) - min(x) < 1e-9``), every split
    point is ``+inf``.

    Args:
        x: 1-D feature column. Non-floating inputs are cast to float32.
        max_bin: Number of bins (static Python int).

    Returns:
        Array of shape ``(max_bin - 1,)`` in x's floating dtype.
    """
    n_splits = max_bin - 1
    # Both the +inf sentinel and linspace need a floating dtype, and
    # lax.cond requires both branches to agree on it. Integer columns
    # would otherwise fail with a branch type mismatch.
    if not jnp.issubdtype(x.dtype, jnp.floating):
        x = x.astype(jnp.float32)
    inf_splits = jnp.full(n_splits, jnp.inf, dtype=x.dtype)

    # The sample count is static, so branch in Python. lax.cond would trace
    # both branches, and jnp.min/jnp.max raise on an empty array at trace
    # time, so the old `lax.cond(n_samples >= 2, ...)` guard did not protect
    # the empty case.
    if x.shape[0] < 2:
        return inf_splits

    min_val, max_val = jnp.min(x), jnp.max(x)
    is_constant = (max_val - min_val) < 1e-9

    def gen_splits():
        return jnp.linspace(min_val, max_val, num=max_bin + 1)[1:-1]

    return jax.lax.cond(is_constant, lambda: inf_splits, gen_splits)
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
@jax.jit
def compute_bin_indices(x: jnp.ndarray, bins: jnp.ndarray) -> jnp.ndarray:
    """Assign every sample of one feature to its bin.

    Uses right-closed intervals: index i satisfies ``bins[i-1] < x <= bins[i]``,
    so with all-``+inf`` boundaries every sample lands in bin 0.
    """
    bucket_ids = jnp.digitize(x, bins, right=True)
    return bucket_ids
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
# ==============================================================================
|
|
158
|
+
# Local Histogram (AP, no FHE needed)
|
|
159
|
+
# ==============================================================================
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def make_local_build_histogram(n_nodes: int, n_buckets: int):
    """Return a jitted plaintext histogram builder for fixed node/bucket counts.

    The returned ``local_build_histogram(gh, bt_local, bin_indices)`` sums the
    (G, H) rows of ``gh`` into ``(node, bucket)`` cells per feature. Samples
    whose ``bt_local`` is negative contribute nothing. The result has shape
    ``(n_features, n_nodes, n_buckets, 2)``.
    """
    n_segments = n_nodes * n_buckets

    @jax.jit
    def local_build_histogram(
        gh: jnp.ndarray,
        bt_local: jnp.ndarray,
        bin_indices: jnp.ndarray,
    ) -> jnp.ndarray:
        """Build G/H histogram using segment_sum. Returns (n_features, n_nodes, n_buckets, 2)."""

        def per_feature(feature_bins: jnp.ndarray) -> jnp.ndarray:
            # Flattened cell id: node * n_buckets + bucket.
            segment_ids = bt_local * n_buckets + feature_bins
            keep = bt_local >= 0
            masked_gh = gh * keep[:, None]
            return segment_sum(masked_gh, segment_ids, num_segments=n_segments)

        # Map over feature columns (axis 1 of bin_indices).
        hist_flat = jax.vmap(per_feature, in_axes=1, out_axes=0)(bin_indices)
        n_features = bin_indices.shape[1]
        return hist_flat.reshape((n_features, n_nodes, n_buckets, 2))

    return local_build_histogram
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
@jax.jit
def compute_best_split_from_hist(
    gh_hist: jnp.ndarray,  # (n_features, n_buckets, 2) for one node
    reg_lambda: float,
    gamma: float,
    min_child_weight: float,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Pick the highest-gain split for one node from its per-bucket G/H sums.

    A split after bucket ``t`` sends buckets ``0..t`` left and the rest right.
    Gain is ``(S_left + S_right - S_parent) / 2 - gamma`` with
    ``S = G^2 / (H + reg_lambda + 1e-9)``. Splits where either side has
    ``H < min_child_weight`` get gain ``-inf``.

    Returns:
        ``(best_gain, feature_index, threshold_bucket)``.
    """

    def structure_score(g, h):
        return jnp.square(g) / (h + reg_lambda + 1e-9)

    totals = jnp.sum(gh_hist, axis=1)  # (n_features, 2)
    # Prefix sums, dropping the last bucket (it would leave the right side empty).
    prefix = jnp.cumsum(gh_hist, axis=1)[:, :-1, :]  # (n_features, n_buckets-1, 2)

    g_all = totals[..., 0]
    h_all = totals[..., 1]
    g_left = prefix[..., 0]
    h_left = prefix[..., 1]
    g_right = g_all[:, None] - g_left
    h_right = h_all[:, None] - h_left

    parent = structure_score(g_all, h_all)
    split_gain = (
        structure_score(g_left, h_left)
        + structure_score(g_right, h_right)
        - parent[:, None]
    ) / 2.0

    admissible = (h_left >= min_child_weight) & (h_right >= min_child_weight)
    split_gain = jnp.where(admissible, split_gain - gamma, -jnp.inf)

    best_flat = jnp.argmax(split_gain)
    feat_idx, bucket_idx = jnp.unravel_index(best_flat, split_gain.shape)
    return jnp.max(split_gain), feat_idx, bucket_idx
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def local_compute_best_splits(
    gh_hist: jnp.ndarray,  # (n_features, n_nodes, n_buckets, 2)
    reg_lambda: float,
    gamma: float,
    min_child_weight: float,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Best split per node from a full local histogram.

    Applies ``compute_best_split_from_hist`` to each node's
    ``(n_features, n_buckets, 2)`` slice. Returns ``(gain, feature, threshold)``,
    each of shape ``(n_nodes,)``.
    """

    def best_for_node(node_hist):
        return compute_best_split_from_hist(
            node_hist,
            reg_lambda=reg_lambda,
            gamma=gamma,
            min_child_weight=min_child_weight,
        )

    # Map directly over the node axis (axis 1) rather than transposing first.
    return jax.vmap(best_for_node, in_axes=1)(gh_hist)
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
# ==============================================================================
|
|
234
|
+
# FHE Histogram (PP, using low-level BFV)
|
|
235
|
+
# ==============================================================================
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def _build_packed_mask_jit(node_mask, feat_bins, n_buckets, stride, slot_count):
    """Build one 0/1 plaintext mask of length ``slot_count`` for a (node, feature, chunk).

    Each sample with ``node_mask == 1`` is given a position
    ``bucket * stride + rank``. Here ``rank`` is the number of earlier in-node
    samples (in sample order) that fell into the same bucket. A 1 is written
    at each such position.

    Args:
        node_mask: Per-sample node membership for this chunk (1 = member).
        feat_bins: Per-sample bucket index of one feature for this chunk.
        n_buckets, stride, slot_count: Static layout parameters.

    Returns:
        int64 vector of length ``slot_count``.

    NOTE(review): the 1s land at packed positions (bucket * stride + rank),
    whereas the encrypted G/H chunk built by fhe_encrypt_gh holds sample j in
    slot j. So ``ct * mask`` selects G/H from those slot positions, not from
    the bucket's own samples. Confirm this is what
    aggregation.batch_bucket_aggregate expects.
    NOTE(review): if a bucket holds >= stride in-node samples, its positions
    spill into the next bucket's window. Positions >= slot_count fall outside
    segment_sum's range and are dropped (JAX drops out-of-range segment ids by
    default) — confirm stride is sized accordingly.
    """
    valid = node_mask == 1
    # (n_samples, n_buckets) one-hot of each sample's bucket, zeroed for out-of-node samples.
    bucket_onehot = (jnp.arange(n_buckets)[None, :] == feat_bins[:, None]) & valid[
        :, None
    ]
    # Exclusive prefix count per bucket: row i = counts over samples [0, i).
    running_counts = jnp.cumsum(bucket_onehot, axis=0)
    shifted_counts = jnp.zeros_like(running_counts)
    shifted_counts = shifted_counts.at[1:].set(running_counts[:-1])
    # Rank of each sample within its own bucket.
    sample_offsets = jnp.take_along_axis(
        shifted_counts, feat_bins[:, None], axis=1
    ).squeeze(-1)

    # Out-of-node samples are tagged -1.
    scatter_indices = jnp.where(valid, feat_bins * stride + sample_offsets, -1)

    valid_mask = scatter_indices >= 0
    # Redirect invalid samples to index 0 with value 0 so they add nothing.
    valid_indices = jnp.where(valid_mask, scatter_indices, 0).astype(jnp.int32)
    valid_ones = jnp.where(valid_mask, 1, 0).astype(jnp.int64)
    output = segment_sum(valid_ones, valid_indices, num_segments=slot_count)
    return output
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def compute_all_masks(
    subgroup_map, bin_indices, n_buckets, stride, slot_count, n_chunks
):
    """Build every (node, feature, chunk) packed mask in one vectorized pass.

    Args:
        subgroup_map: Node membership, shape (n_nodes, m).
        bin_indices: Binned features, shape (m, n_features).
        n_buckets, stride, slot_count: Static layout parameters for
            ``_build_packed_mask_jit``.
        n_chunks: Number of ``slot_count``-sized sample chunks. Samples are
            zero-padded up to ``n_chunks * slot_count``.

    Returns:
        A single 2-D array of shape ``(n_nodes * n_features * n_chunks, slot_count)``.
        Rows are ordered node-major, then feature, then chunk.
    """
    n_samples, n_features = bin_indices.shape
    n_nodes = subgroup_map.shape[0]

    # Zero-pad the sample axis up to a whole number of chunks.
    extra = n_chunks * slot_count - n_samples
    if extra > 0:
        subgroup_map = jnp.pad(subgroup_map, ((0, 0), (0, extra)))
        bin_indices = jnp.pad(bin_indices, ((0, extra), (0, 0)))

    node_chunks = subgroup_map.reshape(n_nodes, n_chunks, slot_count)
    # (n_chunks, slot_count, n_features) -> (n_features, n_chunks, slot_count)
    feat_chunks = jnp.transpose(
        bin_indices.reshape(n_chunks, slot_count, n_features), (2, 0, 1)
    )

    def one_mask(node_mask, feat_bins):
        return _build_packed_mask_jit(node_mask, feat_bins, n_buckets, stride, slot_count)

    # Innermost: pair node-chunk k with feature-chunk k.
    per_chunk = jax.vmap(one_mask, in_axes=(0, 0))
    # Then: same node, every feature.
    per_feature = jax.vmap(per_chunk, in_axes=(None, 0))
    # Outermost: every node against all features.
    per_node = jax.vmap(per_feature, in_axes=(0, None))

    masks = per_node(node_chunks, feat_chunks)  # (n_nodes, n_features, n_chunks, slot_count)
    return masks.reshape(-1, slot_count)
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
def _compute_histogram_chunk_batch(
    subgroup_map,
    bin_indices,
    g_cts,
    h_cts,
    encoder,
    relin_keys,
    galois_keys,
    n_nodes,
    n_features,
    n_chunks,
    n_buckets,
    slot_count,
    stride,
    max_samples_per_bucket,
    m,
):
    """Build packed, encrypted G/H bucket sums for every node.

    Steps:
      1. ``compute_all_masks`` yields one mask per (node, feature, chunk) in
         that nesting order. The masks are batch-encoded and consumed strictly
         sequentially via ``mask_iter``, so the loop order below must stay
         node -> feature -> chunk.
      2. For each (node, feature), the G/H chunk ciphertexts are multiplied by
         their masks and summed across chunks. The sum is then passed once
         through ``aggregation.batch_bucket_aggregate``.
      3. Up to ``stride`` features are packed per ciphertext. Each aggregated
         result is multiplied by a mask with 1s at slots ``b * stride``,
         rotated by ``-i`` (its index within the batch), and summed.

    Args:
        subgroup_map: Node membership, presumably (n_nodes, m) as expected by
            compute_all_masks.
        bin_indices: Binned features, presumably (m, n_features).
        g_cts, h_cts: Encrypted G/H chunks, indexed by chunk (length n_chunks).
        encoder, relin_keys, galois_keys: BFV handles. relin_keys also serves
            as the context provider for batch_encode.
        m: Not referenced in this body.

    Returns:
        ``(g_packed_flat, h_packed_flat)``: one ciphertext per
        (node, feature batch), node-major, i.e.
        ``n_nodes * ceil(n_features / stride)`` entries each.
    """
    # Precompute all masks in one go
    compute_all_masks_jit = partial(
        compute_all_masks,
        n_buckets=n_buckets,
        stride=stride,
        slot_count=slot_count,
        n_chunks=n_chunks,
    )
    all_masks_tensor = tensor.run_jax(
        compute_all_masks_jit,
        subgroup_map,
        bin_indices,
    )

    # Batch encode all masks at once to avoid scheduler bottleneck
    # Pass relin_keys as context provider (it holds the SEALContext)
    all_masks_pt = bfv.batch_encode(all_masks_tensor, encoder, key=relin_keys)
    mask_iter = iter(all_masks_pt)

    # ==========================================================================
    # Optimization: Incremental Packing to reduce peak memory
    # ==========================================================================
    # Instead of accumulating all features and then packing, we pack incrementally.
    # This reduces peak memory from O(n_features) to O(stride).

    # Create mask for valid slots (0, stride, 2*stride, ...)
    m_np = np.zeros(slot_count, dtype=np.int64)
    idx_np = np.arange(n_buckets) * stride
    m_np[idx_np] = 1
    mask_arr = tensor.constant(m_np)
    mask_pt_pack = bfv.encode(mask_arr, encoder)

    g_packed_flat = []
    h_packed_flat = []

    # Optimization 2: Tree Reduction
    # Helper to sum a list of ciphertexts using a binary tree structure.
    # This reduces the dependency chain depth from O(N) to O(log N),
    # allowing the scheduler to parallelize additions.
    def tree_sum(items):
        if not items:
            return None
        if len(items) == 1:
            return items[0]

        queue = deque(items)
        while len(queue) > 1:
            # Process in pairs; an odd leftover is carried to the next round.
            for _ in range(len(queue) // 2):
                left = queue.popleft()
                right = queue.popleft()
                queue.append(bfv.add(left, right))

        return queue[0] if queue else None

    for _node_idx in range(n_nodes):
        # Process features in batches of 'stride'
        for batch_start in range(0, n_features, stride):
            batch_end = min(batch_start + stride, n_features)

            g_rot_list = []
            h_rot_list = []

            for i, _feat_idx in enumerate(range(batch_start, batch_end)):
                # 1. Compute Histogram for this feature (across chunks)
                g_masked_list = []
                h_masked_list = []

                for chunk_idx in range(n_chunks):
                    # Must be consumed in (node, feature, chunk) order to match
                    # compute_all_masks' row order.
                    mask_pt = next(mask_iter)
                    # mask_pt is already encoded via batch_encode

                    g_ct_chunk = g_cts[chunk_idx]
                    h_ct_chunk = h_cts[chunk_idx]

                    # NOTE(review): this is a ciphertext x plaintext product; confirm
                    # whether the following relinearize is required by the backend.
                    g_masked = bfv.relinearize(bfv.mul(g_ct_chunk, mask_pt), relin_keys)
                    h_masked = bfv.relinearize(bfv.mul(h_ct_chunk, mask_pt), relin_keys)

                    g_masked_list.append(g_masked)
                    h_masked_list.append(h_masked)

                g_masked_acc = tree_sum(g_masked_list)
                h_masked_acc = tree_sum(h_masked_list)

                # Lazy Aggregation: Aggregate once after summing all chunks
                # This reduces rotations by a factor of n_chunks
                g_feat_acc = aggregation.batch_bucket_aggregate(
                    g_masked_acc,
                    n_buckets,
                    max_samples_per_bucket,
                    galois_keys,
                    slot_count,
                )
                h_feat_acc = aggregation.batch_bucket_aggregate(
                    h_masked_acc,
                    n_buckets,
                    max_samples_per_bucket,
                    galois_keys,
                    slot_count,
                )

                assert g_feat_acc is not None
                assert h_feat_acc is not None

                # 2. Pack immediately
                # Relative offset = i
                # Mask valid slots
                g_masked_pack = bfv.relinearize(
                    bfv.mul(g_feat_acc, mask_pt_pack), relin_keys
                )
                h_masked_pack = bfv.relinearize(
                    bfv.mul(h_feat_acc, mask_pt_pack), relin_keys
                )

                # Rotate to position
                g_rot = bfv.rotate(g_masked_pack, -i, galois_keys)
                h_rot = bfv.rotate(h_masked_pack, -i, galois_keys)

                g_rot_list.append(g_rot)
                h_rot_list.append(h_rot)

            g_packed_acc = tree_sum(g_rot_list)
            h_packed_acc = tree_sum(h_rot_list)

            g_packed_flat.append(g_packed_acc)
            h_packed_flat.append(h_packed_acc)

    return g_packed_flat, h_packed_flat
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
def _process_decrypted_jit(
    g_vecs, h_vecs, scale, n_nodes, n_features, n_buckets, stride
):
    """Unpack decrypted, packed G/H vectors into a dense cumulative histogram.

    Layout of the inputs: ciphertext ``node * cts_per_node + f // stride`` holds
    feature ``f`` of ``node``, with bucket ``b`` at slot ``b * stride + f % stride``.
    Here ``cts_per_node = ceil(n_features / stride)``.

    The gather is done with one vectorized advanced-indexing op. The previous
    version issued n_nodes * n_features separate gathers plus a Python-level
    stack at trace time. Results are identical.

    Args:
        g_vecs, h_vecs: Non-empty lists of decoded plaintext vectors
            (length slot_count each), one per packed ciphertext.
        scale: Fixed-point scale; values are divided by it after a float32 cast.
        n_nodes, n_features, n_buckets, stride: Static layout parameters.

    Returns:
        float32 array of shape ``(n_nodes, n_features, n_buckets, 2)``
        ([..., 0] = G, [..., 1] = H). NOTE: values are *cumulative* over the
        bucket axis (prefix sums), unlike make_local_build_histogram's raw
        per-bucket sums — confirm downstream split finding expects this.
    """
    g_stack = jnp.stack(g_vecs)
    h_stack = jnp.stack(h_vecs)

    cts_per_node = (n_features + stride - 1) // stride
    feat_ids = np.arange(n_features)
    # (n_nodes, n_features): which ciphertext holds each (node, feature).
    ct_idx = (
        np.arange(n_nodes)[:, None] * cts_per_node + (feat_ids // stride)[None, :]
    )
    # (n_features, n_buckets): slot of bucket b for feature f.
    slot_idx = np.arange(n_buckets)[None, :] * stride + (feat_ids % stride)[:, None]

    # Broadcast to (n_nodes, n_features, n_buckets).
    rows = ct_idx[:, :, None]
    cols = slot_idx[None, :, :]

    def unpack(stack):
        buckets = stack[rows, cols].astype(jnp.float32) / scale
        return jnp.cumsum(buckets, axis=-1)

    return jnp.stack([unpack(g_stack), unpack(h_stack)], axis=-1)
|
|
491
|
+
|
|
492
|
+
|
|
493
|
+
def _decrypt_batch(
    g_enc_flat,
    h_enc_flat,
    sk,
    encoder,
    fxp_scale,
    n_nodes,
    n_features,
    n_buckets,
    stride,
):
    """Decrypt packed G/H histogram ciphertexts and unpack them.

    Every ciphertext is decrypted with ``sk`` and decoded with ``encoder``.
    All G ciphertexts are processed first, then all H, each list in its
    original order. The plaintext vectors and ``fxp_scale`` are then handed
    to ``_process_decrypted_jit`` through ``tensor.run_jax``.
    """

    def open_all(ciphertexts):
        return [bfv.decode(bfv.decrypt(ct, sk), encoder) for ct in ciphertexts]

    g_plain = open_all(g_enc_flat)
    h_plain = open_all(h_enc_flat)

    unpack = partial(
        _process_decrypted_jit,
        n_nodes=n_nodes,
        n_features=n_features,
        n_buckets=n_buckets,
        stride=stride,
    )
    return tensor.run_jax(unpack, g_plain, h_plain, fxp_scale)
|
|
520
|
+
|
|
521
|
+
|
|
522
|
+
def fhe_encrypt_gh(
    qg: Any,
    qh: Any,
    pk: Any,
    encoder: Any,
    ap_rank: int,
    n_samples: int,
    slot_count: int = DEFAULT_POLY_MODULUS_DEGREE,
) -> tuple[list[Any], list[Any], int]:
    """Encrypt quantized G/H vectors at AP, splitting into chunks if m > slot_count.

    When m > slot_count, the vectors are split into ceil(m / slot_count) chunks,
    each encrypted as a separate ciphertext. This enables processing arbitrarily
    large datasets with a fixed poly_modulus_degree. The last chunk is
    zero-padded to ``slot_count`` entries. All slicing, encoding and encryption
    runs at ``ap_rank`` via ``simp.pcall_static``.

    Args:
        qg: Quantized G vector, shape (m,)
        qh: Quantized H vector, shape (m,)
        pk: BFV public key
        encoder: BFV encoder
        ap_rank: Active party rank
        n_samples: Number of samples (m); must be a static Python int
        slot_count: Number of slots per ciphertext (defaults to
            DEFAULT_POLY_MODULUS_DEGREE, i.e. 8192)

    Returns:
        (g_cts, h_cts, n_chunks): Lists of encrypted G/H chunks and chunk count.
        Chunk k covers samples [k * slot_count, min((k + 1) * slot_count, m)).
    """
    # Calculate n_chunks at trace time (known statically)
    n_chunks = (n_samples + slot_count - 1) // slot_count

    g_cts: list[Any] = []
    h_cts: list[Any] = []

    for chunk_idx in range(n_chunks):
        start = chunk_idx * slot_count
        end = min((chunk_idx + 1) * slot_count, n_samples)
        chunk_size = end - start

        # Extract, pad, encode and encrypt both G and H chunks together.
        # Chunk bounds are bound as default args so each closure keeps its
        # own values instead of late-binding to the loop's last iteration.
        def slice_pad_encode_encrypt(
            g_vec, h_vec, enc, key, s=start, e=end, cs=chunk_size, sc=slot_count
        ):
            # Slice and pad using JAX
            def slice_and_pad_both(gv, hv):
                g_chunk = gv[s:e]
                h_chunk = hv[s:e]
                if cs < sc:
                    g_chunk = jnp.pad(g_chunk, (0, sc - cs))
                    h_chunk = jnp.pad(h_chunk, (0, sc - cs))
                return g_chunk, h_chunk

            g_chunk, h_chunk = tensor.run_jax(slice_and_pad_both, g_vec, h_vec)
            # Encode and encrypt
            g_pt = bfv.encode(g_chunk, enc)
            h_pt = bfv.encode(h_chunk, enc)
            return bfv.encrypt(g_pt, key), bfv.encrypt(h_pt, key)

        g_ct, h_ct = simp.pcall_static(
            (ap_rank,), slice_pad_encode_encrypt, qg, qh, encoder, pk
        )

        g_cts.append(g_ct)
        h_cts.append(h_ct)

    return g_cts, h_cts, n_chunks
|
|
587
|
+
|
|
588
|
+
|
|
589
|
+
def fhe_histogram_optimized(
    g_cts: list[Any],  # List of encrypted G chunks at PP
    h_cts: list[Any],  # List of encrypted H chunks at PP
    subgroup_map: Any,  # (n_nodes, m) binary node membership
    bin_indices: Any,  # (m, n_features) binned features
    n_buckets: int,
    n_nodes: int,
    n_features: int,
    pp_rank: int,
    ap_rank: int,
    encoder: Any,
    relin_keys: Any,
    galois_keys: Any,
    m: int,
    n_chunks: int = 1,
    slot_count: int = DEFAULT_POLY_MODULUS_DEGREE,
) -> tuple[list[Any], list[Any]]:
    """Compute encrypted histogram sums using SIMD bucket packing.

    **Multi-CT Support**

    When m > slot_count, data is split into n_chunks ciphertexts:
    - Chunk 0: samples [0, slot_count)
    - Chunk 1: samples [slot_count, 2*slot_count)
    - ...

    For each chunk, the histogram is computed separately, and the results are
    added in the FHE domain.

    **SIMD Bucket Packing** (per chunk)

    1. Divide slot_count into n_buckets regions, each with `stride` slots
    2. Build scatter mask placing sample i at slot (bucket[i] * stride + offset[i])
    3. Single CT x packed_mask multiplication
    4. Single rotate_and_sum aggregates ALL buckets simultaneously
    5. Add chunk results together

    The encrypted computation runs at ``pp_rank`` inside
    ``_compute_histogram_chunk_batch``. The resulting ciphertexts are then
    transferred to ``ap_rank``.

    Args:
        g_cts: Encrypted G chunks, already located at ``pp_rank``.
        h_cts: Encrypted H chunks, already located at ``pp_rank``.
        subgroup_map: Node-membership mask for the nodes being histogrammed.
        bin_indices: The passive party's binned features.
        n_buckets: Number of histogram buckets per feature.
        n_nodes: Number of nodes (rows of ``subgroup_map``).
        n_features: Number of features held by the passive party.
        pp_rank: Rank of the passive party that evaluates the histogram.
        ap_rank: Rank of the active party that receives the results.
        encoder: BFV encoder (at ``pp_rank``).
        relin_keys: Relinearization keys (at ``pp_rank``).
        galois_keys: Galois keys for rotations (at ``pp_rank``).
        m: Total number of samples.
        n_chunks: Number of ciphertext chunks the samples were split into.
        slot_count: Slots per ciphertext. It must match the value later used
            to decrypt, because it determines ``stride``.

    Returns:
        ``(g_packed_ap, h_packed_ap)``: two FLAT lists of packed ciphertexts,
        already transferred to ``ap_rank``. Each entry packs all buckets of one
        node/feature pair, with chunk contributions already summed. The
        ordering is the one produced by ``_compute_histogram_chunk_batch``.

    Raises:
        ValueError: If ``n_chunks < 1``, or if ``n_buckets`` is outside
            ``[1, slot_count]``. Outside that range the per-bucket ``stride``
            would be zero or undefined, and all buckets would collide in the
            packed slots.
    """
    if n_chunks < 1:
        raise ValueError(f"n_chunks must be >= 1, got {n_chunks}")
    if not 1 <= n_buckets <= slot_count:
        raise ValueError(
            f"n_buckets must be in [1, slot_count={slot_count}], got {n_buckets}"
        )

    stride = slot_count // n_buckets
    # Estimate max samples per bucket per chunk.
    # NOTE(review): this is a heuristic (about twice the uniform share, at
    # least 64, capped at stride). How a bucket that receives more samples
    # than this bound in one chunk is handled is decided inside
    # _compute_histogram_chunk_batch; confirm it cannot drop samples on
    # skewed features.
    samples_per_chunk = (m + n_chunks - 1) // n_chunks
    max_samples_per_bucket = min(stride, max(samples_per_chunk // n_buckets * 2, 64))

    # Use partial to bake in static arguments (integers) so JAX treats them as static.
    fn = partial(
        _compute_histogram_chunk_batch,
        n_nodes=n_nodes,
        n_features=n_features,
        n_chunks=n_chunks,
        n_buckets=n_buckets,
        slot_count=slot_count,
        stride=stride,
        max_samples_per_bucket=max_samples_per_bucket,
        m=m,
    )

    g_results_flat, h_results_flat = simp.pcall_static(
        (pp_rank,),
        fn,
        subgroup_map,
        bin_indices,
        g_cts,
        h_cts,
        encoder,
        relin_keys,
        galois_keys,
    )

    # Ship every packed result from the passive party to the active party.
    g_packed_ap = [
        simp.shuffle_static(obj, {ap_rank: pp_rank}) for obj in g_results_flat
    ]
    h_packed_ap = [
        simp.shuffle_static(obj, {ap_rank: pp_rank}) for obj in h_results_flat
    ]

    return g_packed_ap, h_packed_ap
|
|
670
|
+
|
|
671
|
+
|
|
672
|
+
def decrypt_histogram_results(
    g_enc_flat: Any,
    h_enc_flat: Any,
    sk: Any,
    encoder: Any,
    fxp_scale: int,
    n_nodes: int,
    n_features: int,
    n_buckets: int,
    ap_rank: int,
    slot_count: int = DEFAULT_POLY_MODULUS_DEGREE,
) -> Any:
    """Decrypt and assemble histogram results at the active party.

    **SIMD Bucket Packing Format**

    With SIMD bucket packing, each ciphertext contains ALL buckets for one feature:
    - g_enc_flat / h_enc_flat are flat lists of packed CTs (one per node/feature)
    - slot[b * stride] contains histogram[b] for bucket b
    - stride = slot_count // n_buckets

    ``slot_count`` must be the same value that was used when the ciphertexts
    were packed. Otherwise ``stride`` differs and the wrong slots are read.

    Decryption, extraction of the strided bucket slots, fixed-point decoding
    with ``fxp_scale`` and the cumulative sum are delegated to
    ``_decrypt_batch``, executed at ``ap_rank``.

    Args:
        g_enc_flat: Packed encrypted G histograms (at ``ap_rank``).
        h_enc_flat: Packed encrypted H histograms (at ``ap_rank``).
        sk: BFV secret key (at ``ap_rank``).
        encoder: BFV encoder.
        fxp_scale: Fixed-point scale used when the gradients were quantized.
        n_nodes: Number of nodes covered by the ciphertexts.
        n_features: Number of features per node.
        n_buckets: Number of buckets per feature.
        ap_rank: Rank of the active party that holds the secret key.
        slot_count: Slots per ciphertext used during packing.

    Returns:
        A SINGLE object at ``ap_rank`` (not a list). It holds one array of
        shape ``(n_nodes, n_features, n_buckets, 2)`` that combines the G and
        H results. The histograms are CUMULATIVE over buckets: entry ``b`` is
        the sum of all bins ``<= b``.
    """
    stride = slot_count // n_buckets

    fn = partial(
        _decrypt_batch,
        fxp_scale=fxp_scale,
        n_nodes=n_nodes,
        n_features=n_features,
        n_buckets=n_buckets,
        stride=stride,
    )

    # The whole (n_nodes, n_features, n_buckets, 2) tensor is returned as one
    # object. Callers consume it directly rather than iterating per node.
    combined_results = simp.pcall_static(
        (ap_rank,),
        fn,
        g_enc_flat,
        h_enc_flat,
        sk,
        encoder,
    )
    return combined_results
|
|
737
|
+
|
|
738
|
+
|
|
739
|
+
# ==============================================================================
|
|
740
|
+
# Tree Update Functions
|
|
741
|
+
# ==============================================================================
|
|
742
|
+
|
|
743
|
+
|
|
744
|
+
def make_get_subgroup_map(n_nodes: int):
    """Build a jitted one-hot node-membership function for a fixed node count.

    ``n_nodes`` is closed over, so it is a static Python int inside the
    compiled function.
    """

    @jax.jit
    def get_subgroup_map(bt_level: jnp.ndarray) -> jnp.ndarray:
        """Return an int8 mask whose entry ``[k, i]`` is 1 iff ``bt_level[i] == k``.

        For a 1-D ``bt_level`` of length m the result is (n_nodes, m). Samples
        whose value lies outside ``[0, n_nodes)`` get an all-zero column.
        """
        node_ids = jnp.arange(n_nodes).reshape(-1, 1)
        membership = jnp.equal(node_ids, bt_level)
        return membership.astype(jnp.int8)

    return get_subgroup_map
|
|
753
|
+
|
|
754
|
+
|
|
755
|
+
@jax.jit
def update_is_leaf(
    is_leaf: jnp.ndarray,
    gains: jnp.ndarray,
    indices: jnp.ndarray,
) -> jnp.ndarray:
    """Write leaf flags for the nodes listed in ``indices``.

    A node gets flag 1 (leaf) when its gain is non-positive or not finite
    (NaN or +/-inf). Otherwise it gets 0. Entries of ``is_leaf`` that are not
    listed in ``indices`` keep their previous value.
    """
    no_gain = gains <= 0.0
    bad_gain = jnp.logical_not(jnp.isfinite(gains))
    flags = jnp.logical_or(no_gain, bad_gain).astype(jnp.int64)
    return is_leaf.at[indices].set(flags)
|
|
764
|
+
|
|
765
|
+
|
|
766
|
+
@jax.jit
def update_bt(
    bt: jnp.ndarray,
    bt_level: jnp.ndarray,
    is_leaf: jnp.ndarray,
    bin_indices: jnp.ndarray,
    best_feature: jnp.ndarray,
    best_thresh_idx: jnp.ndarray,
) -> jnp.ndarray:
    """Move every sample from its node to the chosen child after a split.

    For sample ``i`` the split of its level-local node ``bt_level[i]`` is
    looked up (feature and bin threshold). The sample goes to the left child
    ``2 * bt + 1`` when its bin for that feature is ``<=`` the threshold,
    otherwise to the right child ``2 * bt + 2``. Samples whose current node is
    flagged in ``is_leaf`` stay where they are.
    """
    rows = jnp.arange(bt.shape[0])
    split_feat = best_feature[bt_level]
    split_bin = best_thresh_idx[bt_level]
    observed_bin = bin_indices[rows, split_feat]

    left_child = 2 * bt + 1
    child = jnp.where(observed_bin <= split_bin, left_child, left_child + 1)

    frozen = is_leaf[bt].astype(bool)
    return jnp.where(frozen, bt, child)
|
|
784
|
+
|
|
785
|
+
|
|
786
|
+
def make_compute_leaf_values(n_nodes: int):
    """Build a jitted leaf-value function for a fixed number of tree nodes.

    ``n_nodes`` is closed over, so ``segment_sum`` receives a static segment
    count.
    """

    @jax.jit
    def compute_leaf_values(
        gh: jnp.ndarray,
        bt: jnp.ndarray,
        is_leaf: jnp.ndarray,
        reg_lambda: float,
    ) -> jnp.ndarray:
        """Return per-node values ``-G / (H + reg_lambda)`` for populated leaves.

        Column 0 (G) and column 1 (H) of ``gh`` are summed per node id in
        ``bt``. Nodes that are not leaves, or whose H sum is exactly zero,
        get 0.
        """
        totals = segment_sum(gh, bt, num_segments=n_nodes)
        grad_sum = totals[:, 0]
        hess_sum = totals[:, 1]
        populated = hess_sum != 0

        # Use 1.0 in place of a zero H sum so the division stays finite; such
        # nodes are masked to 0 below anyway.
        denom = jnp.where(populated, hess_sum, 1.0) + reg_lambda
        raw = -grad_sum / denom

        keep = jnp.logical_and(is_leaf.astype(bool), populated)
        return jnp.where(keep, raw, 0.0)

    return compute_leaf_values
|
|
806
|
+
|
|
807
|
+
|
|
808
|
+
# ==============================================================================
|
|
809
|
+
# Tree Building Helpers
|
|
810
|
+
# ==============================================================================
|
|
811
|
+
|
|
812
|
+
|
|
813
|
+
def _find_splits_ap(
    ap_rank: int,
    n_level: int,
    n_buckets: int,
    gh: Any,
    bt_level: Any,
    bin_indices: Any,
    reg_lambda: float,
    gamma: float,
    min_child_weight: float,
) -> tuple[Any, Any, Any]:
    """Find the active party's best split for every node of the current level.

    Both steps run only at ``ap_rank`` on plaintext data:

    1. A local histogram is built with
       ``make_local_build_histogram(n_level, n_buckets)`` over ``gh``,
       ``bt_level`` and ``bin_indices``.
    2. ``local_compute_best_splits`` scores it with ``reg_lambda``,
       ``gamma`` and ``min_child_weight``.

    Returns:
        ``(gains, features, thresholds)`` as produced by
        ``local_compute_best_splits``, all held at ``ap_rank``.
    """
    build_hist = make_local_build_histogram(n_level, n_buckets)

    def histogram_job(fn=build_hist):
        return tensor.run_jax(fn, gh, bt_level, bin_indices)

    hist = simp.pcall_static((ap_rank,), histogram_job)

    def best_split_job(rl=reg_lambda, gm=gamma, mcw=min_child_weight):
        return tensor.run_jax(local_compute_best_splits, hist, rl, gm, mcw)

    gains, feats, threshs = simp.pcall_static((ap_rank,), best_split_job)
    return gains, feats, threshs
|
|
837
|
+
|
|
838
|
+
|
|
839
|
+
def _find_splits_pps(
    level: int,
    pp_ranks: list[int],
    ap_rank: int,
    g_cts_pps: dict[int, list[Any]],
    h_cts_pps: dict[int, list[Any]],
    bt_level: Any,
    all_bin_indices: list[Any],
    n_features_per_party: list[int],
    last_level_hists: list[Any],
    encoder: Any,
    relin_keys: Any,
    galois_keys: Any,
    sk: Any,
    fxp_scale: int,
    m: int,
    n_chunks: int,
    slot_count: int,
    n_buckets: int,
    reg_lambda: float,
    gamma: float,
    min_child_weight: float,
) -> tuple[list[Any], list[Any], list[Any]]:
    """Compute remote histograms via FHE and find each passive party's best splits.

    For every passive party (PP) in ``pp_ranks``:

    1. ``bt_level`` and the BFV encoder, relinearization keys and Galois keys
       are sent from ``ap_rank`` to the PP. The PP then builds its
       node-membership map for the ``2**level`` nodes of this level.
    2. The encrypted histograms are computed at the PP and decrypted at the AP.
       - At ``level == 0`` this is done for all nodes.
       - At deeper levels it is done only for left children (even local node
         indices). Right children are derived at the AP as ``parent - left``
         (histogram subtraction).
    3. The resulting per-level histograms are stored in
       ``last_level_hists[pp_idx + 1]``. **The list is mutated in place** so
       the next level can subtract from them. ``compute_best_split_from_hist``
       is then applied per node at the AP.

    ``slot_count`` is used consistently for both packing and decryption, so
    the bucket stride matches on both sides.

    Returns:
        ``(gains, features, thresholds)`` lists with one entry per PP (same
        order as ``pp_ranks``), each held at ``ap_rank``.
    """
    pp_gains_list: list[Any] = []
    pp_feats_list: list[Any] = []
    pp_threshs_list: list[Any] = []

    n_level = 2**level
    n_left = n_level // 2

    # The membership-map builder depends only on the level, so it is built
    # once rather than once per passive party.
    subgroup_map_fn = make_get_subgroup_map(n_level)

    def slice_left(sm):
        # Left children sit at even local indices (0, 2, 4, ...).
        return sm[0::2]

    def derive_right_and_combine(l_hists, p_hists):
        # Each parent corresponds to exactly one left child: right = parent - left.
        r_hists = p_hists - l_hists
        # Stack to (n_left, 2, ...) then flatten, interleaving [L0, R0, L1, R1, ...]
        # back into local node order.
        combined = jnp.stack([l_hists, r_hists], axis=1)
        return combined.reshape((-1, *l_hists.shape[1:]))

    def find_splits(hists, rl=reg_lambda, gm=gamma, mcw=min_child_weight):
        # hists is (n_nodes, n_feat, n_buck, 2); one best split per node.
        return jax.vmap(lambda h: compute_best_split_from_hist(h, rl, gm, mcw))(
            hists
        )

    def encrypted_histograms(
        party, pp_rank, g_cts_pp, h_cts_pp, node_map, n_nodes_here, enc_pp, rk_pp, gk_pp
    ):
        """Run the FHE histogram at ``pp_rank`` and decrypt it at ``ap_rank``."""
        n_pp_features = n_features_per_party[party]
        g_enc, h_enc = fhe_histogram_optimized(
            g_cts_pp,
            h_cts_pp,
            node_map,
            all_bin_indices[party],
            n_buckets,
            n_nodes_here,
            n_pp_features,
            pp_rank,
            ap_rank,
            enc_pp,
            rk_pp,
            gk_pp,
            m,
            n_chunks,
            slot_count,
        )
        # Decrypt with the same slot_count that was used for packing. Relying
        # on the default would read the wrong slots whenever a non-default
        # slot_count is in use.
        return decrypt_histogram_results(
            g_enc,
            h_enc,
            sk,
            encoder,
            fxp_scale,
            n_nodes_here,
            n_pp_features,
            n_buckets,
            ap_rank,
            slot_count=slot_count,
        )

    for pp_idx, pp_rank in enumerate(pp_ranks):
        party = pp_idx + 1  # index 0 of the per-party lists belongs to the AP

        # Encrypted CT chunks were transferred to this PP once, before the level loop.
        g_cts_pp = g_cts_pps[pp_rank]
        h_cts_pp = h_cts_pps[pp_rank]

        # NOTE(review): the keys are re-sent for every level and every PP;
        # hoisting them like the ciphertexts would save bandwidth.
        bt_level_pp = simp.shuffle_static(bt_level, {pp_rank: ap_rank})
        encoder_pp = simp.shuffle_static(encoder, {pp_rank: ap_rank})
        rk_pp = simp.shuffle_static(relin_keys, {pp_rank: ap_rank})
        gk_pp = simp.shuffle_static(galois_keys, {pp_rank: ap_rank})

        subgroup_map = simp.pcall_static(
            (pp_rank,),
            lambda fn=subgroup_map_fn, bt_lv=bt_level_pp: tensor.run_jax(fn, bt_lv),
        )

        if level == 0:
            # Root level: full FHE histogram for every node.
            pp_hists = encrypted_histograms(
                party,
                pp_rank,
                g_cts_pp,
                h_cts_pp,
                subgroup_map,
                n_level,
                encoder_pp,
                rk_pp,
                gk_pp,
            )
        else:
            # Histogram subtraction: FHE only for left children.
            subgroup_map_left = simp.pcall_static(
                (pp_rank,),
                lambda sm=subgroup_map: tensor.run_jax(slice_left, sm),
            )
            left_hists = encrypted_histograms(
                party,
                pp_rank,
                g_cts_pp,
                h_cts_pp,
                subgroup_map_left,
                n_left,
                encoder_pp,
                rk_pp,
                gk_pp,
            )
            parent_hists = last_level_hists[party]
            pp_hists = simp.pcall_static(
                (ap_rank,),
                lambda lh=left_hists, ph=parent_hists: tensor.run_jax(
                    derive_right_and_combine, lh, ph
                ),
            )

        # Keep this level's histograms for the next level's subtraction
        # (harmless if this is the last level).
        last_level_hists[party] = pp_hists

        pp_gains, pp_feats, pp_threshs = simp.pcall_static(
            (ap_rank,),
            lambda h=pp_hists: tensor.run_jax(find_splits, h),
        )

        pp_gains_list.append(pp_gains)
        pp_feats_list.append(pp_feats)
        pp_threshs_list.append(pp_threshs)

    return pp_gains_list, pp_feats_list, pp_threshs_list
|
|
1009
|
+
|
|
1010
|
+
|
|
1011
|
+
def _update_tree_state(
    ap_rank: int,
    pp_ranks: list[int],
    all_ranks: list[int],
    all_feats: list[Any],
    all_thresholds: list[Any],
    bt: Any,
    bt_level: Any,
    is_leaf: Any,
    owned_party: Any,
    cur_indices: Any,
    best_party: Any,
    best_gains: Any,
    all_feats_level: list[Any],
    all_threshs_level: list[Any],
    all_bins: list[Any],
    all_bin_indices: list[Any],
) -> tuple[Any, Any, list[Any], list[Any], Any]:
    """Apply one level's winning splits to the tree state.

    Steps, in trace order:

    1. At ``ap_rank``, the current-level nodes (``cur_indices``) are marked
       as leaves via ``update_is_leaf`` and ``best_party`` is recorded as
       their owner. Each result is broadcast to every passive party.
    2. For every party in ``all_ranks`` (index 0 is the AP):
       - its winning feature and bin-derived threshold are written into its
         own arrays, masked to the nodes it owns (``-1`` / ``inf`` elsewhere
         and at leaves);
       - a candidate sample assignment is computed with ``update_bt`` on its
         private ``bin_indices``.
    3. At ``ap_rank``, the candidates are merged so that each node's samples
       follow the party that won that node. The new ``bt`` is then broadcast.

    Side effects: ``all_feats`` and ``all_thresholds`` are updated in place.
    The passive-party entries of ``all_feats_level`` and
    ``all_threshs_level`` are replaced in place by copies moved to their
    owning party.

    Returns:
        ``(is_leaf, owned_party, all_feats, all_thresholds, bt)``.
    """

    def from_ap(obj, idx, rank):
        # Index 0 is the active party itself; other parties need a transfer.
        if idx == 0:
            return obj
        return simp.shuffle_static(obj, {rank: ap_rank})

    def broadcast(obj):
        # Keep the AP copy, send one copy to each PP, then merge them.
        if not pp_ranks:
            return obj
        copies = [obj] + [simp.shuffle_static(obj, {r: ap_rank}) for r in pp_ranks]
        return simp.converge(*copies)

    # 1. Leaf flags and node ownership, decided at the AP.
    def leaf_job(il=is_leaf, gains=best_gains, ci=cur_indices):
        return tensor.run_jax(update_is_leaf, il, gains, ci)

    is_leaf = broadcast(simp.pcall_static((ap_rank,), leaf_job))

    def owner_job(op=owned_party, bp=best_party, ci=cur_indices):
        return tensor.run_jax(lambda o, b, c: o.at[c].set(b), op, bp, ci)

    owned_party = broadcast(simp.pcall_static((ap_rank,), owner_job))

    # 2. Per-party feature/threshold bookkeeping and candidate assignments.
    candidates: list[Any] = []

    for idx, rank in enumerate(all_ranks):
        # Level results for PPs currently live at the AP; move them back.
        all_feats_level[idx] = from_ap(all_feats_level[idx], idx, rank)
        all_threshs_level[idx] = from_ap(all_threshs_level[idx], idx, rank)
        ci_here = from_ap(cur_indices, idx, rank)
        owner_here = from_ap(owned_party, idx, rank)
        leaf_here = from_ap(is_leaf, idx, rank)

        def masked_feats(feats, best_feat, indices, owned, leaf, pid=idx):
            written = feats.at[indices].set(best_feat)
            written = jnp.where(leaf.astype(bool), jnp.int64(-1), written)
            return jnp.where(owned == pid, written, jnp.int64(-1))

        def feats_job(
            fn=masked_feats,
            pf=all_feats[idx],
            bf=all_feats_level[idx],
            ci=ci_here,
            op=owner_here,
            il=leaf_here,
        ):
            return tensor.run_jax(fn, pf, bf, ci, op, il)

        all_feats[idx] = simp.pcall_static((rank,), feats_job)

        def masked_thresholds(
            thresholds,
            bins_arr,
            best_feat,
            best_thresh_idx,
            indices,
            owned,
            leaf,
            pid=idx,
        ):
            # Translate the winning bin index into a real-valued split point.
            chosen = bins_arr[best_feat, best_thresh_idx]
            written = thresholds.at[indices].set(chosen)
            written = jnp.where(leaf.astype(bool), jnp.float32(jnp.inf), written)
            return jnp.where(owned == pid, written, jnp.float32(jnp.inf))

        def thresholds_job(
            fn=masked_thresholds,
            pt=all_thresholds[idx],
            b=all_bins[idx],
            bf=all_feats_level[idx],
            bt_idx=all_threshs_level[idx],
            ci=ci_here,
            op=owner_here,
            il=leaf_here,
        ):
            return tensor.run_jax(fn, pt, b, bf, bt_idx, ci, op, il)

        all_thresholds[idx] = simp.pcall_static((rank,), thresholds_job)

        # Candidate assignment: where samples would go if this party's split won.
        bt_here = from_ap(bt, idx, rank)
        bt_level_here = from_ap(bt_level, idx, rank)

        def candidate_job(
            bi=all_bin_indices[idx],
            bf=all_feats_level[idx],
            bt_idx=all_threshs_level[idx],
            bt_arr=bt_here,
            bt_lv=bt_level_here,
            il=leaf_here,
        ):
            return tensor.run_jax(update_bt, bt_arr, bt_lv, il, bi, bf, bt_idx)

        candidate = simp.pcall_static((rank,), candidate_job)
        if idx > 0:
            # Bring the PP's candidate back to the AP for merging.
            candidate = simp.shuffle_static(candidate, {ap_rank: rank})
        candidates.append(candidate)

    # 3. Merge candidates node by node according to best_party.
    def merge_bt_updates(current_bt, all_tmp, best_party_arr, level_indices):
        stacked = jnp.stack(all_tmp, axis=0)  # (n_parties, m)

        def step(acc, i):
            in_node = current_bt == level_indices[i]
            winner_bt = stacked[best_party_arr[i]]
            return jnp.where(in_node, winner_bt, acc), None

        merged, _ = jax.lax.scan(step, current_bt, jnp.arange(len(level_indices)))
        return merged

    def merge_job(b=bt, tmp=candidates, bp=best_party, ci=cur_indices):
        return tensor.run_jax(merge_bt_updates, b, tmp, bp, ci)

    bt = broadcast(simp.pcall_static((ap_rank,), merge_job))

    return is_leaf, owned_party, all_feats, all_thresholds, bt
|
|
1198
|
+
|
|
1199
|
+
|
|
1200
|
+
def build_tree(
    gh: Any,  # Plaintext G/H at AP, shape (m, 2)
    g_cts: list[Any],  # Encrypted G chunks at AP
    h_cts: list[Any],  # Encrypted H chunks at AP
    n_chunks: int,  # Number of CT chunks
    all_bins: list[Any],  # Bin boundaries per party
    all_bin_indices: list[Any],  # Binned features per party
    sk: Any,  # Secret key at AP
    pk: Any,  # Public key at AP
    encoder: Any,  # BFV encoder
    relin_keys: Any,  # Relinearization keys
    galois_keys: Any,  # Galois keys for rotation
    fxp_scale: int,
    ap_rank: int,
    pp_ranks: list[int],
    max_depth: int,
    reg_lambda: float,
    gamma: float,
    min_child_weight: float,
    n_samples: int,
    n_buckets: int,
    n_features_per_party: list[int],  # Number of features for each party
    slot_count: int = DEFAULT_POLY_MODULUS_DEGREE,
) -> Tree:
    """Build a single decision tree level by level.

    The algorithm proceeds breadth-first:
    1. For each level, compute histograms (local at AP, FHE at PPs)
    2. Find best split per node across all parties
    3. Update tree structure and sample assignments
    4. Repeat until max_depth reached; nodes of the last level are then
       forced to be leaves and leaf values are computed at the AP.

    **Multi-CT Support**: When n_samples > slot_count, data is split into
    n_chunks ciphertexts. Each chunk is processed separately and results
    are accumulated in the FHE domain.

    ``is_leaf`` and ``owned_party`` are made available at every party before
    returning, including when ``max_depth == 0`` and no level is processed.
    ``pk`` is not used here, because the gradients arrive already encrypted.
    """
    m = n_samples
    n_nodes = 2 ** (max_depth + 1) - 1
    all_ranks = [ap_rank, *pp_ranks]

    def init_array(rank, shape, dtype, fill):
        return simp.pcall_static(
            (rank,),
            lambda: tensor.constant(np.full(shape, fill, dtype=dtype)),
        )

    def broadcast_from_ap(obj):
        # Keep the AP copy, send one copy to each PP, then merge them.
        if not pp_ranks:
            return obj
        parts = [obj]
        for r in pp_ranks:
            parts.append(simp.shuffle_static(obj, {r: ap_rank}))
        return simp.converge(*parts)

    def find_global_best(*gains):
        stacked = jnp.stack(gains, axis=0)  # (n_parties, n_nodes)
        best_party = jnp.argmax(stacked, axis=0)  # ties go to the lowest index (AP)
        best_gains = jnp.take_along_axis(
            stacked, best_party[None, :], axis=0
        ).squeeze(0)
        return best_gains, best_party

    # Initialize tree arrays (leaf values are computed once at the end).
    all_feats = [init_array(r, n_nodes, np.int64, -1) for r in all_ranks]
    all_thresholds = [init_array(r, n_nodes, np.float32, np.inf) for r in all_ranks]
    is_leaf = init_array(ap_rank, n_nodes, np.int64, 0)
    owned_party = init_array(ap_rank, n_nodes, np.int64, -1)
    bt = init_array(ap_rank, m, np.int64, 0)

    # Parent histograms for the subtraction optimization, one stacked array
    # per party for the previous level. Index 0 is the AP (unused); 1..k are PPs.
    last_level_hists: list[Any] = [None] * (len(pp_ranks) + 1)

    # Transfer the encrypted gradients to every PP once, before the level loop.
    g_cts_pps: dict[int, list[Any]] = {}
    h_cts_pps: dict[int, list[Any]] = {}
    for pp_rank in pp_ranks:
        g_cts_pps[pp_rank] = [
            simp.shuffle_static(ct, {pp_rank: ap_rank}) for ct in g_cts
        ]
        h_cts_pps[pp_rank] = [
            simp.shuffle_static(ct, {pp_rank: ap_rank}) for ct in h_cts
        ]

    for level in range(max_depth):
        n_level = 2**level
        level_offset = 2**level - 1

        # Global node ids of this level.
        cur_indices = simp.pcall_static(
            (ap_rank,),
            lambda off=level_offset, nl=n_level: tensor.constant(
                np.arange(nl, dtype=np.int64) + off
            ),
        )

        # Level-local node id of every sample (negative for samples already
        # sitting in a leaf of an earlier level).
        bt_level = simp.pcall_static(
            (ap_rank,),
            lambda off=level_offset, b=bt: tensor.run_jax(lambda x: x - off, b),
        )

        # === AP: local histogram computation ===
        ap_gains, ap_feats, ap_threshs = _find_splits_ap(
            ap_rank,
            n_level,
            n_buckets,
            gh,
            bt_level,
            all_bin_indices[0],
            reg_lambda,
            gamma,
            min_child_weight,
        )

        # === PP: FHE histogram computation ===
        pp_gains_list, pp_feats_list, pp_threshs_list = _find_splits_pps(
            level,
            pp_ranks,
            ap_rank,
            g_cts_pps,
            h_cts_pps,
            bt_level,
            all_bin_indices,
            n_features_per_party,
            last_level_hists,
            encoder,
            relin_keys,
            galois_keys,
            sk,
            fxp_scale,
            m,
            n_chunks,
            slot_count,
            n_buckets,
            reg_lambda,
            gamma,
            min_child_weight,
        )

        all_gains = [ap_gains, *pp_gains_list]
        all_feats_level = [ap_feats, *pp_feats_list]
        all_threshs_level = [ap_threshs, *pp_threshs_list]

        # === Find global best split across all parties ===
        best_gains, best_party = simp.pcall_static(
            (ap_rank,),
            lambda gains=all_gains: tensor.run_jax(find_global_best, *gains),
        )

        # === Update Tree State ===
        is_leaf, owned_party, all_feats, all_thresholds, bt = _update_tree_state(
            ap_rank,
            pp_ranks,
            all_ranks,
            all_feats,
            all_thresholds,
            bt,
            bt_level,
            is_leaf,
            owned_party,
            cur_indices,
            best_party,
            best_gains,
            all_feats_level,
            all_threshs_level,
            all_bins,
            all_bin_indices,
        )

    # Force final level nodes to be leaves.
    final_start = 2**max_depth - 1
    final_end = 2 ** (max_depth + 1) - 1
    final_indices = simp.pcall_static(
        (ap_rank,),
        lambda: tensor.constant(np.arange(final_start, final_end, dtype=np.int64)),
    )
    is_leaf = simp.pcall_static(
        (ap_rank,),
        lambda: tensor.run_jax(
            lambda il, fi: il.at[fi].set(1),
            is_leaf,
            final_indices,
        ),
    )

    # Broadcast final is_leaf to all parties (needed for prediction).
    is_leaf = broadcast_from_ap(is_leaf)
    if max_depth <= 0:
        # The level loop never ran, so owned_party still lives only at the AP.
        # Broadcast it so that, as for max_depth > 0, every party holds it.
        owned_party = broadcast_from_ap(owned_party)

    # Compute final leaf values.
    leaf_val_fn = make_compute_leaf_values(n_nodes)
    values = simp.pcall_static(
        (ap_rank,),
        lambda fn=leaf_val_fn: tensor.run_jax(fn, gh, bt, is_leaf, reg_lambda),
    )

    return Tree(
        feature=all_feats,
        threshold=all_thresholds,
        value=values,
        is_leaf=is_leaf,
        owned_party_id=owned_party,
    )
|
|
1407
|
+
|
|
1408
|
+
|
|
1409
|
+
# ==============================================================================
|
|
1410
|
+
# Prediction
|
|
1411
|
+
# ==============================================================================
|
|
1412
|
+
|
|
1413
|
+
|
|
1414
|
+
def predict_tree_single_party(
    data: Any,
    feature: Any,
    threshold: Any,
    is_leaf: Any,
    owned_party_id: Any,
    party_id: int,
    n_nodes: int,
) -> Any:
    """Traverse one tree locally, resolving only the splits this party owns.

    The result is an integer location matrix of shape (m, n_nodes), where m is
    ``data.shape[0]``. A sample may be marked in several nodes. When a split
    belongs to another party, this party cannot evaluate it and forwards the
    sample to BOTH children. When the split is owned by ``party_id``, the
    sample goes left if ``data[:, feature[node]] <= threshold[node]`` and
    right otherwise.

    Args:
        data: This party's feature matrix; rows are samples.
        feature: Per-node feature index into ``data`` (read only for owned splits).
        threshold: Per-node split threshold (read only for owned splits).
        is_leaf: Per-node flag; ``0`` means the node is an internal split.
        owned_party_id: Per-node id of the party that owns the split.
        party_id: Id of the party running this traversal.
        n_nodes: Total number of nodes in the (complete binary) tree layout.

    Returns:
        The location matrix produced by ``tensor.run_jax``.
    """

    def _kernel(x, feats, threshs, leaves, owners):
        num_rows = x.shape[0]
        # Every sample begins at the root (node 0).
        start = jnp.zeros((num_rows, n_nodes), dtype=jnp.int64)
        start = start.at[:, 0].set(1)

        def _step(node, loc):
            internal = leaves[node] == 0
            mine = internal & (owners[node] == party_id)

            def _route_known(cur):
                # This party owns the split: send each sample to exactly one child.
                here = cur[:, node]
                left_bits = (x[:, feats[node]] <= threshs[node]).astype(jnp.int64)
                cur = cur.at[:, 2 * node + 1].add(here * left_bits)
                cur = cur.at[:, 2 * node + 2].add(here * (1 - left_bits))
                return cur.at[:, node].set(0)

            def _route_both(cur):
                # Foreign split: the outcome is unknown here, so mark both children.
                here = cur[:, node]
                cur = cur.at[:, 2 * node + 1].add(here)
                cur = cur.at[:, 2 * node + 2].add(here)
                return cur.at[:, node].set(0)

            def _route_unknown(cur):
                # Leaves keep their samples; only internal foreign splits fan out.
                return jax.lax.cond(internal, _route_both, lambda c: c, cur)

            return jax.lax.cond(mine, _route_known, _route_unknown, loc)

        # In the 0-based heap layout, only indices below n_nodes // 2 have
        # children, so those are the only nodes that can route samples.
        return jax.lax.fori_loop(0, n_nodes // 2, _step, start)

    return tensor.run_jax(_kernel, data, feature, threshold, is_leaf, owned_party_id)
|
|
1472
|
+
|
|
1473
|
+
|
|
1474
|
+
def predict_tree(
    tree: Tree,
    all_datas: list[Any],
    ap_rank: int,
    pp_ranks: list[int],
    n_nodes: int,
) -> Any:
    """Predict with one tree by intersecting every party's local traversal.

    Parties are processed in the order ``[ap_rank, *pp_ranks]``. The k-th party
    in that order runs ``predict_tree_single_party`` on ``all_datas[k]``,
    ``tree.feature[k]`` and ``tree.threshold[k]``, with ``k`` as its party id.
    Masks from non-AP parties are shuffled to the AP. On the AP, a sample's
    node is one that every party marked (value > 0). The prediction is the
    ``tree.value`` entry of the first such leaf along the node axis.

    Args:
        tree: Fitted tree; ``feature``/``threshold`` are indexed per party.
        all_datas: Feature tensors, one per party, AP first.
        ap_rank: Rank of the active party (receives the result).
        pp_ranks: Ranks of the passive parties.
        n_nodes: Number of nodes in the tree layout.

    Returns:
        Per-sample leaf values, held at the AP.
    """
    ranks = [ap_rank, *pp_ranks]
    masks_at_ap: list[Any] = []

    for party_idx, rank in enumerate(ranks):
        # Resolve this party's operands now and bind them through default
        # args, so each traced call sees its own party's inputs.
        data_i = all_datas[party_idx]
        feat_i = tree.feature[party_idx]
        thresh_i = tree.threshold[party_idx]
        local = simp.pcall_static(
            (rank,),
            lambda d=data_i, f=feat_i, t=thresh_i, pid=party_idx: (
                predict_tree_single_party(
                    d, f, t, tree.is_leaf, tree.owned_party_id, pid, n_nodes
                )
            ),
        )
        if rank != ap_rank:
            # Ship this party's location mask to the AP for aggregation.
            local = simp.shuffle_static(local, {ap_rank: rank})
        masks_at_ap.append(local)

    def _pick_leaf_values(*masks, leaf_arr, values_arr):
        # Stacked shape: (n_parties, m, n_nodes). A sample is at a node only
        # if every party reports it there.
        agreed = jnp.all(jnp.stack(masks, axis=0) > 0, axis=0)
        at_leaf = agreed & leaf_arr.astype(bool)
        # NOTE(review): if no leaf is agreed for a sample, argmax yields 0 and
        # the root's value is used.
        return values_arr[jnp.argmax(at_leaf, axis=1)]

    return simp.pcall_static(
        (ap_rank,),
        lambda: tensor.run_jax(
            _pick_leaf_values,
            *masks_at_ap,
            leaf_arr=tree.is_leaf,
            values_arr=tree.value,
        ),
    )
|
|
1527
|
+
|
|
1528
|
+
|
|
1529
|
+
def predict_ensemble(
    model: TreeEnsemble,
    all_datas: list[Any],
    ap_rank: int,
    pp_ranks: list[int],
    learning_rate: float,
    n_samples: int,
    n_nodes: int,
) -> Any:
    """Predict class-1 probabilities with the whole boosted ensemble.

    The logits start at ``model.initial_prediction`` for each of the
    ``n_samples`` rows. Each tree's output, scaled by ``learning_rate``, is
    added to them, and ``sigmoid`` is applied at the end. Everything after the
    per-tree traversal runs on the active party.

    Args:
        model: Fitted ensemble.
        all_datas: Feature tensors, one per party, AP first.
        ap_rank: Rank of the active party.
        pp_ranks: Ranks of the passive parties.
        learning_rate: Shrinkage applied to each tree's contribution.
        n_samples: Number of rows being predicted.
        n_nodes: Number of nodes in each tree layout.

    Returns:
        Probability tensor held at the AP.
    """

    def _add_scaled(acc, contrib, lr=learning_rate):
        return acc + lr * contrib

    # Broadcast the base prediction to one logit per sample.
    logits = simp.pcall_static(
        (ap_rank,),
        lambda size=n_samples: tensor.run_jax(
            lambda base: base * jnp.ones(size), model.initial_prediction
        ),
    )

    for member in model.trees:
        contribution = predict_tree(member, all_datas, ap_rank, pp_ranks, n_nodes)
        # Default args pin this round's operands into the traced call.
        logits = simp.pcall_static(
            (ap_rank,),
            lambda acc=logits, contrib=contribution: tensor.run_jax(
                _add_scaled, acc, contrib
            ),
        )

    return simp.pcall_static(
        (ap_rank,),
        lambda: tensor.run_jax(sigmoid, logits),
    )
|
|
1568
|
+
|
|
1569
|
+
|
|
1570
|
+
# ==============================================================================
|
|
1571
|
+
# Training API
|
|
1572
|
+
# ==============================================================================
|
|
1573
|
+
|
|
1574
|
+
|
|
1575
|
+
def fit_tree_ensemble(
    all_datas: list[Any],
    y_data: Any,
    all_bins: list[Any],
    all_bin_indices: list[Any],
    initial_pred: Any,
    n_samples: int,
    n_buckets: int,
    n_features_per_party: list[int],
    n_estimators: int,
    learning_rate: float,
    max_depth: int,
    reg_lambda: float,
    gamma: float,
    min_child_weight: float,
    ap_rank: int,
    pp_ranks: list[int],
) -> TreeEnsemble:
    """Fit a SecureBoost tree ensemble.

    Trees are built sequentially over ``n_estimators`` rounds. Each round:

    1. At the AP, compute gradients/hessians from the current logits and
       quantize them with scale ``1 << DEFAULT_FXP_BITS``.
    2. If there are passive parties, BFV-encrypt the quantized values.
    3. Build one tree with ``build_tree``.
    4. Add ``learning_rate`` times that tree's predictions to the logits.

    BFV keys are generated once, at the AP, and only when ``pp_ranks`` is
    non-empty.

    Args:
        all_datas: Feature tensors, one per party, AP first.
        y_data: Labels held at the AP.
        all_bins: Per-party bin edges, forwarded to ``build_tree``.
        all_bin_indices: Per-party bin indices, forwarded to ``build_tree``.
        initial_pred: Base prediction at the AP, broadcast to ``n_samples`` logits.
        n_samples: Number of training rows.
        n_buckets: Forwarded to ``build_tree``.
        n_features_per_party: Forwarded to ``build_tree``.
        n_estimators: Number of boosting rounds (trees).
        learning_rate: Shrinkage applied to each tree's predictions.
        max_depth: Tree depth; also fixes the node count used for prediction.
        reg_lambda: Forwarded to ``build_tree``.
        gamma: Forwarded to ``build_tree``.
        min_child_weight: Forwarded to ``build_tree``.
        ap_rank: Rank of the active party.
        pp_ranks: Ranks of the passive parties (may be empty).

    Returns:
        A ``TreeEnsemble`` with ``max_depth``, the fitted trees and ``initial_pred``.
    """
    fxp_scale = 1 << DEFAULT_FXP_BITS
    # Node count of a complete binary tree of depth ``max_depth``. It is the
    # same in every round, so compute it once rather than per tree.
    n_nodes = 2 ** (max_depth + 1) - 1

    # Per-round kernels are loop-invariant. Defining them once means every
    # round hands the same function objects to ``tensor.run_jax``.
    def compute_gh_quantized(y_true, y_pred_logits, scale):
        """Compute G/H, quantize, and split into (gh, qg, qh) in one call."""
        gh = compute_gh(y_true, y_pred_logits)
        qgh = quantize_gh(gh, scale)
        return gh, qgh[:, 0], qgh[:, 1]

    def update_pred_fn(curr_y, t_pred, lr=learning_rate):
        return curr_y + lr * t_pred

    y_pred = simp.pcall_static(
        (ap_rank,),
        lambda n=n_samples: tensor.run_jax(
            lambda init: init * jnp.ones(n), initial_pred
        ),
    )

    # BFV key generation at AP (only if we have passive parties)
    pk, sk, relin_keys, galois_keys, encoder = None, None, None, None, None
    if pp_ranks:

        def keygen_fn():
            pub, sec = bfv.keygen(poly_modulus_degree=DEFAULT_POLY_MODULUS_DEGREE)
            rk = bfv.make_relin_keys(sec)
            gk = bfv.make_galois_keys(sec)
            enc = bfv.create_encoder(poly_modulus_degree=DEFAULT_POLY_MODULUS_DEGREE)
            return pub, sec, rk, gk, enc

        pk, sk, relin_keys, galois_keys, encoder = simp.pcall_static(
            (ap_rank,), keygen_fn
        )

    trees: list[Tree] = []

    for _tree_idx in range(n_estimators):
        # Default arg pins this round's logits into the traced call.
        gh, qg, qh = simp.pcall_static(
            (ap_rank,),
            lambda yp=y_pred: tensor.run_jax(
                compute_gh_quantized, y_data, yp, fxp_scale
            ),
        )

        # FHE encrypt only if we have passive parties
        g_cts, h_cts, n_chunks = [], [], 1
        if pp_ranks:
            g_cts, h_cts, n_chunks = fhe_encrypt_gh(
                qg, qh, pk, encoder, ap_rank, n_samples
            )

        tree = build_tree(
            gh,
            g_cts,
            h_cts,
            n_chunks,
            all_bins,
            all_bin_indices,
            sk,
            pk,
            encoder,
            relin_keys,
            galois_keys,
            fxp_scale,
            ap_rank,
            pp_ranks,
            max_depth,
            reg_lambda,
            gamma,
            min_child_weight,
            n_samples,
            n_buckets,
            n_features_per_party,
        )
        trees.append(tree)

        # Predict with the new tree and fold it into the running logits.
        tree_pred = predict_tree(tree, all_datas, ap_rank, pp_ranks, n_nodes)
        y_pred = simp.pcall_static(
            (ap_rank,),
            lambda yp=y_pred, tp=tree_pred: tensor.run_jax(update_pred_fn, yp, tp),
        )

    return TreeEnsemble(
        max_depth=max_depth,
        trees=trees,
        initial_prediction=initial_pred,
    )
|
|
1682
|
+
|
|
1683
|
+
|
|
1684
|
+
# ==============================================================================
|
|
1685
|
+
# SecureBoost Class
|
|
1686
|
+
# ==============================================================================
|
|
1687
|
+
|
|
1688
|
+
|
|
1689
|
+
class SecureBoost:
    """SecureBoost classifier using mplang.v2 low-level BFV APIs.

    This is an optimized implementation that uses BFV SIMD slots for
    efficient histogram computation.

    Example:
        model = SecureBoost(n_estimators=10, max_depth=3)
        model.fit([X_ap, X_pp], y, n_samples=m, n_features_per_party=[f_ap, f_pp])
        probs = model.predict([X_ap_test, X_pp_test], n_samples=m_test)
    """

    def __init__(
        self,
        n_estimators: int = 10,
        learning_rate: float = 0.1,
        max_depth: int = 3,
        max_bin: int = 8,
        reg_lambda: float = 1.0,
        gamma: float = 0.0,
        min_child_weight: float = 1.0,
        ap_rank: int = 0,
        pp_ranks: list[int] | None = None,
    ):
        """Initialize SecureBoost model.

        Args:
            n_estimators: Number of trees to train
            learning_rate: Shrinkage factor for updates
            max_depth: Maximum tree depth
            max_bin: Maximum number of bins per feature (must be >= 2)
            reg_lambda: L2 regularization on leaf weights
            gamma: Minimum gain required to split
            min_child_weight: Minimum hessian sum in children
            ap_rank: Active party rank (holds labels)
            pp_ranks: Passive party ranks (hold features); defaults to ``[1]``

        Raises:
            ValueError: If ``max_bin < 2``.
        """
        if max_bin < 2:
            raise ValueError(f"max_bin must be >= 2, got {max_bin}")

        self.n_estimators = n_estimators
        self.learning_rate = learning_rate
        self.max_depth = max_depth
        self.max_bin = max_bin
        self.reg_lambda = reg_lambda
        self.gamma = gamma
        self.min_child_weight = min_child_weight
        self.ap_rank = ap_rank
        self.pp_ranks = pp_ranks if pp_ranks is not None else [1]
        self.model: TreeEnsemble | None = None
        # Training metadata; populated by fit().
        self.n_samples: int | None = None
        self.n_features_per_party: list[int] | None = None

    def _check_per_party(self, name: str, values: list[Any]) -> None:
        """Raise ValueError unless ``values`` has exactly one entry per party."""
        n_parties = 1 + len(self.pp_ranks)
        if len(values) != n_parties:
            raise ValueError(
                f"{name} must have one entry per party ({n_parties}: AP + "
                f"{len(self.pp_ranks)} PP), got {len(values)}"
            )

    def fit(
        self,
        all_datas: list[Any],
        y_data: Any,
        n_samples: int,
        n_features_per_party: list[int],
    ) -> SecureBoost:
        """Fit the SecureBoost model.

        Args:
            all_datas: List of feature tensors, one per party.
                First element is AP's features, rest are PPs'.
            y_data: Labels tensor at AP.
            n_samples: Number of training samples.
            n_features_per_party: Number of features for each party.

        Returns:
            self for method chaining

        Raises:
            ValueError: If ``all_datas`` or ``n_features_per_party`` does not
                have one entry per party (``1 + len(pp_ranks)``).
        """
        # Fail early and clearly instead of an IndexError deep in tracing, or
        # silently dropping extra party inputs.
        self._check_per_party("all_datas", all_datas)
        self._check_per_party("n_features_per_party", n_features_per_party)

        self.n_samples = n_samples
        self.n_features_per_party = n_features_per_party
        # Build bins for each party
        all_ranks = [self.ap_rank, *self.pp_ranks]

        # Vectorize per-feature binning over columns (axis 1) of the data.
        build_bins_vmap = jax.vmap(
            partial(build_bins_equi_width, max_bin=self.max_bin), in_axes=1
        )
        compute_indices_vmap = jax.vmap(compute_bin_indices, in_axes=(1, 0), out_axes=1)

        all_bins: list[Any] = []
        all_bin_indices: list[Any] = []

        for i, rank in enumerate(all_ranks):
            data = all_datas[i]
            # Binning runs locally at the party that owns the features.
            bins = simp.pcall_static(
                (rank,),
                lambda d=data: tensor.run_jax(build_bins_vmap, d),
            )
            indices = simp.pcall_static(
                (rank,),
                lambda d=data, b=bins: tensor.run_jax(compute_indices_vmap, d, b),
            )
            all_bins.append(bins)
            all_bin_indices.append(indices)

        # Initial prediction
        initial_pred = simp.pcall_static(
            (self.ap_rank,),
            lambda: tensor.run_jax(compute_init_pred, y_data),
        )

        # Calculate metadata
        n_buckets = self.max_bin + 1

        self.model = fit_tree_ensemble(
            all_datas,
            y_data,
            all_bins,
            all_bin_indices,
            initial_pred,
            self.n_samples,
            n_buckets,
            self.n_features_per_party,
            self.n_estimators,
            self.learning_rate,
            self.max_depth,
            self.reg_lambda,
            self.gamma,
            self.min_child_weight,
            self.ap_rank,
            self.pp_ranks,
        )

        return self

    def predict(self, all_datas: list[Any], n_samples: int) -> Any:
        """Predict probabilities for new data.

        Args:
            all_datas: List of feature tensors, one per party (AP first).
            n_samples: Number of samples.

        Returns:
            Predicted probabilities at AP.

        Raises:
            RuntimeError: If the model has not been fitted.
            ValueError: If ``all_datas`` does not have one entry per party.
        """
        if self.model is None:
            raise RuntimeError("Model not fitted. Call fit() first.")
        self._check_per_party("all_datas", all_datas)

        n_nodes = 2 ** (self.max_depth + 1) - 1
        return predict_ensemble(
            self.model,
            all_datas,
            self.ap_rank,
            self.pp_ranks,
            self.learning_rate,
            n_samples,
            n_nodes,
        )

    def predict_proba(self, all_datas: list[Any], n_samples: int) -> Any:
        """Alias for predict()."""
        return self.predict(all_datas, n_samples)

    def evaluate(self, all_datas: list[Any], y_data: Any, n_samples: int) -> Any:
        """Evaluate model on test data.

        Returns:
            Accuracy tensor at AP (needs to be fetched after graph execution).
        """
        y_prob = self.predict(all_datas, n_samples)

        def compute_metrics(y_pred, y_true):
            # Flatten both sides so column-vector labels (m, 1) are compared
            # element-wise with the (m,) predictions instead of broadcasting
            # into an (m, m) comparison matrix.
            y_class = (y_pred.reshape(-1) > 0.5).astype(jnp.float32)
            accuracy = jnp.mean(y_class == y_true.reshape(-1))
            return accuracy

        accuracy = simp.pcall_static(
            (self.ap_rank,),
            lambda: tensor.run_jax(compute_metrics, y_prob, y_data),
        )
        return accuracy
|