mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -45
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +5 -7
- mplang/v1/core/__init__.py +157 -0
- mplang/{core → v1/core}/cluster.py +30 -14
- mplang/{core → v1/core}/comm.py +5 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +13 -14
- mplang/{core → v1/core}/expr/evaluator.py +65 -24
- mplang/{core → v1/core}/expr/printer.py +24 -18
- mplang/{core → v1/core}/expr/transformer.py +3 -3
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +23 -16
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +4 -4
- mplang/{core → v1/core}/primitive.py +106 -201
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{api.py → v1/host.py} +38 -6
- mplang/v1/kernels/__init__.py +41 -0
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/v1/kernels/basic.py +240 -0
- mplang/{kernels → v1/kernels}/context.py +42 -27
- mplang/{kernels → v1/kernels}/crypto.py +44 -37
- mplang/v1/kernels/fhe.py +858 -0
- mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
- mplang/{kernels → v1/kernels}/phe.py +263 -57
- mplang/{kernels → v1/kernels}/spu.py +137 -48
- mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
- mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
- mplang/v1/kernels/value.py +626 -0
- mplang/{ops → v1/ops}/__init__.py +5 -16
- mplang/{ops → v1/ops}/base.py +2 -5
- mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
- mplang/v1/ops/crypto.py +262 -0
- mplang/v1/ops/fhe.py +272 -0
- mplang/{ops → v1/ops}/jax_cc.py +33 -68
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -4
- mplang/{ops → v1/ops}/spu.py +3 -5
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +9 -24
- mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
- mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
- mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +35 -20
- mplang/{runtime → v1/runtime}/client.py +19 -8
- mplang/{runtime → v1/runtime}/communicator.py +59 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +30 -12
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +58 -42
- mplang/{runtime → v1/runtime}/session.py +57 -71
- mplang/{runtime → v1/runtime}/simulation.py +55 -28
- mplang/v1/simp/api.py +353 -0
- mplang/{simp → v1/simp}/mpi.py +8 -9
- mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
- mplang/{simp → v1/simp}/random.py +21 -22
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/core/__init__.py +0 -92
- mplang/device.py +0 -340
- mplang/kernels/builtin.py +0 -207
- mplang/ops/crypto.py +0 -109
- mplang/ops/ibis_cc.py +0 -139
- mplang/ops/sql.py +0 -61
- mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
- mplang/runtime/link_comm.py +0 -131
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -73
- mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{kernels → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,454 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Field Backend Implementation.
|
|
16
|
+
|
|
17
|
+
Implements runtime execution logic for Field dialect primitives,
|
|
18
|
+
including bindings to C++ kernels (libmplang_kernels.so) and
|
|
19
|
+
NumPy fallbacks where appropriate.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
from __future__ import annotations
|
|
23
|
+
|
|
24
|
+
import ctypes
|
|
25
|
+
import os
|
|
26
|
+
import threading
|
|
27
|
+
|
|
28
|
+
# print("DEBUG: Importing field_impl.py")
|
|
29
|
+
import jax.numpy as jnp
|
|
30
|
+
import numpy as np
|
|
31
|
+
|
|
32
|
+
from mplang.v2.backends.tensor_impl import TensorValue, _unwrap, _wrap
|
|
33
|
+
from mplang.v2.dialects import field
|
|
34
|
+
from mplang.v2.edsl.graph import Operation
|
|
35
|
+
from mplang.v2.kernels import py_kernels
|
|
36
|
+
from mplang.v2.runtime.interpreter import Interpreter
|
|
37
|
+
|
|
38
|
+
# =============================================================================
|
|
39
|
+
# Kernel Loading
|
|
40
|
+
# =============================================================================
|
|
41
|
+
|
|
42
|
+
# Load Kernel Library
|
|
43
|
+
# In a real package, this path would be resolved robustly
|
|
44
|
+
_KERNEL_LIB_PATH = os.path.join(
|
|
45
|
+
os.path.dirname(__file__), "..", "kernels", "libmplang_kernels.so"
|
|
46
|
+
)
|
|
47
|
+
_LIB = None
|
|
48
|
+
_LIB_LOCK = threading.Lock()
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _get_lib() -> ctypes.CDLL | None:
|
|
52
|
+
global _LIB
|
|
53
|
+
with _LIB_LOCK:
|
|
54
|
+
if _LIB is None:
|
|
55
|
+
try:
|
|
56
|
+
_LIB = ctypes.CDLL(_KERNEL_LIB_PATH)
|
|
57
|
+
# Define signatures
|
|
58
|
+
_LIB.gf128_mul.argtypes = [
|
|
59
|
+
ctypes.POINTER(ctypes.c_uint64),
|
|
60
|
+
ctypes.POINTER(ctypes.c_uint64),
|
|
61
|
+
ctypes.POINTER(ctypes.c_uint64),
|
|
62
|
+
]
|
|
63
|
+
_LIB.gf128_mul_batch.argtypes = [
|
|
64
|
+
ctypes.POINTER(ctypes.c_uint64),
|
|
65
|
+
ctypes.POINTER(ctypes.c_uint64),
|
|
66
|
+
ctypes.POINTER(ctypes.c_uint64),
|
|
67
|
+
ctypes.c_int64,
|
|
68
|
+
]
|
|
69
|
+
_LIB.solve_okvs.argtypes = [
|
|
70
|
+
ctypes.POINTER(ctypes.c_uint64), # keys
|
|
71
|
+
ctypes.POINTER(ctypes.c_uint64), # values
|
|
72
|
+
ctypes.POINTER(ctypes.c_uint64), # output
|
|
73
|
+
ctypes.c_uint64, # n
|
|
74
|
+
ctypes.c_uint64, # m
|
|
75
|
+
ctypes.POINTER(ctypes.c_uint64), # seed
|
|
76
|
+
]
|
|
77
|
+
_LIB.decode_okvs.argtypes = [
|
|
78
|
+
ctypes.POINTER(ctypes.c_uint64), # keys
|
|
79
|
+
ctypes.POINTER(ctypes.c_uint64), # storage
|
|
80
|
+
ctypes.POINTER(ctypes.c_uint64), # output
|
|
81
|
+
ctypes.c_uint64, # n
|
|
82
|
+
ctypes.c_uint64, # m
|
|
83
|
+
ctypes.POINTER(ctypes.c_uint64), # seed
|
|
84
|
+
]
|
|
85
|
+
# Optimized Mega-Binning Versions
|
|
86
|
+
_LIB.solve_okvs_opt.argtypes = _LIB.solve_okvs.argtypes
|
|
87
|
+
_LIB.decode_okvs_opt.argtypes = _LIB.decode_okvs.argtypes
|
|
88
|
+
|
|
89
|
+
_LIB.aes_128_expand.argtypes = [
|
|
90
|
+
ctypes.POINTER(ctypes.c_uint64), # seeds
|
|
91
|
+
ctypes.POINTER(ctypes.c_uint64), # output
|
|
92
|
+
ctypes.c_uint64, # num_seeds
|
|
93
|
+
ctypes.c_uint64, # length
|
|
94
|
+
]
|
|
95
|
+
_LIB.ldpc_encode.argtypes = [
|
|
96
|
+
ctypes.POINTER(ctypes.c_uint64), # message
|
|
97
|
+
ctypes.POINTER(ctypes.c_uint64), # indices
|
|
98
|
+
ctypes.POINTER(ctypes.c_uint64), # indptr
|
|
99
|
+
ctypes.POINTER(ctypes.c_uint64), # output
|
|
100
|
+
ctypes.c_uint64, # m
|
|
101
|
+
ctypes.c_uint64, # n
|
|
102
|
+
]
|
|
103
|
+
except OSError:
|
|
104
|
+
print(f"WARNING: Could not load kernels from {_KERNEL_LIB_PATH}")
|
|
105
|
+
return _LIB
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
# =============================================================================
|
|
109
|
+
# Helper Implementations (C++ Wrappers)
|
|
110
|
+
# =============================================================================
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def _gf128_mul_impl(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Element-wise GF(2^128) multiplication of two uint64 word arrays.

    Each field element is a pair of uint64 words, so inputs typically have
    shape ``(N, 2)``. The C++ ``gf128_mul_batch`` kernel is used when the
    library is available; otherwise ``py_kernels.gf128_mul_batch`` is used.

    Returns:
        On the C++ path, a uint64 array with the same shape as ``a``.

    Raises:
        ValueError: on the C++ path, if ``a`` and ``b`` differ in shape or
            hold an odd number of uint64 words. The kernel trusts the element
            count it is given, so either case would otherwise read or write
            out of bounds.
    """
    lib = _get_lib()
    if lib is None:
        # Use pure Python fallback
        return py_kernels.gf128_mul_batch(a, b)

    # ctypes needs C-contiguous uint64 buffers. ascontiguousarray avoids a
    # copy when the input already qualifies.
    a_contig = np.ascontiguousarray(a, dtype=np.uint64)
    b_contig = np.ascontiguousarray(b, dtype=np.uint64)
    if a_contig.shape != b_contig.shape:
        raise ValueError(
            f"gf128 mul: operand shapes differ: {a_contig.shape} vs {b_contig.shape}"
        )
    if a_contig.size % 2 != 0:
        raise ValueError(
            "gf128 mul: expected an even number of uint64 words "
            f"(pairs per element), got {a_contig.size}"
        )

    out = np.zeros_like(a_contig)
    # Two uint64 words per GF(2^128) element.
    n_elements = a_contig.size // 2

    u64_ptr = ctypes.POINTER(ctypes.c_uint64)
    lib.gf128_mul_batch(
        a_contig.ctypes.data_as(u64_ptr),
        b_contig.ctypes.data_as(u64_ptr),
        out.ctypes.data_as(u64_ptr),
        n_elements,
    )
    return out
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def _okvs_solve_opt_impl(
    keys: np.ndarray, values: np.ndarray, m: int, seed: np.ndarray
) -> np.ndarray:
    """Solve an OKVS with the mega-binning C++ kernel when it is safe to.

    Delegates to :func:`_okvs_solve_impl` (the naive solver) when the kernel
    library is unavailable (there is no Python version of the optimized
    solver), when fewer than 200,000 keys are given (mega-binning is unstable
    there), or when the expansion ``m / n`` is below 1.32 (mega-binning needs
    roughly 1.35x, while the naive solver works with about 1.25x).

    NOTE(review): ``_okvs_decode_opt_impl`` applies the same thresholds using
    the number of keys being *decoded*. If that differs from the number
    encoded here, solver and decoder may pick different algorithms, which
    presumably breaks decoding. Confirm callers rule this out.

    Returns:
        On the optimized C++ path, an ``(m, 2)`` uint64 storage array.
    """
    lib = _get_lib()
    flat_seed = seed if seed.ndim <= 1 else seed.flatten()

    if lib is None:
        return _okvs_solve_impl(keys, values, m, flat_seed)

    n_keys = keys.shape[0]
    # Short-circuit keeps n_keys == 0 from reaching the division.
    if n_keys < 200_000 or m / n_keys < 1.32:
        return _okvs_solve_impl(keys, values, m, flat_seed)

    as_ptr = ctypes.POINTER(ctypes.c_uint64)
    keys_buf = np.ascontiguousarray(keys, dtype=np.uint64)
    values_buf = np.ascontiguousarray(values, dtype=np.uint64)
    seed_buf = np.ascontiguousarray(flat_seed, dtype=np.uint64)
    storage = np.zeros((m, 2), dtype=np.uint64)

    lib.solve_okvs_opt(
        keys_buf.ctypes.data_as(as_ptr),
        values_buf.ctypes.data_as(as_ptr),
        storage.ctypes.data_as(as_ptr),
        n_keys,
        m,
        seed_buf.ctypes.data_as(as_ptr),
    )
    return storage
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def _okvs_decode_opt_impl(
    keys: np.ndarray, storage: np.ndarray, m: int, seed: np.ndarray
) -> np.ndarray:
    """Decode OKVS values with the mega-binning C++ kernel when applicable.

    Delegates to :func:`_okvs_decode_impl` when the kernel library is
    unavailable, for fewer than 200,000 keys, or when ``m / n < 1.32``.
    Mega-binning uses 1024 bins and needs about 1000 items per bin to be
    stable at epsilon ~1.3; below 200k keys the naive path is fast (<50 ms).

    NOTE(review): ``n`` here is the number of keys being *decoded*, whereas
    ``_okvs_solve_opt_impl`` decides from the number of keys it *encoded*.
    If they differ (e.g. unbalanced set sizes), the two sides may choose
    different algorithms. Confirm callers guarantee consistency.

    Returns:
        On the optimized C++ path, an ``(n, 2)`` uint64 array of decoded values.
    """
    lib = _get_lib()
    flat_seed = seed if seed.ndim <= 1 else seed.flatten()

    if lib is None:
        return _okvs_decode_impl(keys, storage, m, flat_seed)

    n_keys = keys.shape[0]
    # Short-circuit keeps n_keys == 0 from reaching the division.
    if n_keys < 200_000 or m / n_keys < 1.32:
        return _okvs_decode_impl(keys, storage, m, flat_seed)

    as_ptr = ctypes.POINTER(ctypes.c_uint64)
    keys_buf = np.ascontiguousarray(keys, dtype=np.uint64)
    storage_buf = np.ascontiguousarray(storage, dtype=np.uint64)
    seed_buf = np.ascontiguousarray(flat_seed, dtype=np.uint64)
    decoded = np.zeros((n_keys, 2), dtype=np.uint64)

    lib.decode_okvs_opt(
        keys_buf.ctypes.data_as(as_ptr),
        storage_buf.ctypes.data_as(as_ptr),
        decoded.ctypes.data_as(as_ptr),
        n_keys,
        m,
        seed_buf.ctypes.data_as(as_ptr),
    )
    return decoded
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def _okvs_solve_impl(
    keys: np.ndarray, values: np.ndarray, m: int, seed: np.ndarray
) -> np.ndarray:
    """Encode ``values`` under ``keys`` into an OKVS of ``m`` slots (naive solver).

    Uses the C++ ``solve_okvs`` kernel when the library is loaded. Otherwise
    it calls ``py_kernels.okvs_solve`` with the keys flattened to 1-D and the
    seed passed as a pair of Python ints. A multi-dimensional ``seed`` is
    flattened first, and its first two words must exist.

    Returns:
        On the C++ path, an ``(m, 2)`` uint64 storage array.
    """
    lib = _get_lib()
    flat_seed = seed if seed.ndim <= 1 else seed.flatten()
    seed_pair = (int(flat_seed[0]), int(flat_seed[1]))

    if lib is None:
        flat_keys = keys if keys.ndim <= 1 else keys.flatten()
        return py_kernels.okvs_solve(flat_keys, values, m, seed=seed_pair)

    n_keys = keys.shape[0]
    as_ptr = ctypes.POINTER(ctypes.c_uint64)
    keys_buf = np.ascontiguousarray(keys, dtype=np.uint64)
    values_buf = np.ascontiguousarray(values, dtype=np.uint64)
    seed_buf = np.ascontiguousarray(flat_seed, dtype=np.uint64)
    storage = np.zeros((m, 2), dtype=np.uint64)

    lib.solve_okvs(
        keys_buf.ctypes.data_as(as_ptr),
        values_buf.ctypes.data_as(as_ptr),
        storage.ctypes.data_as(as_ptr),
        n_keys,
        m,
        seed_buf.ctypes.data_as(as_ptr),
    )
    return storage
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def _okvs_decode_impl(
    keys: np.ndarray, storage: np.ndarray, m: int, seed: np.ndarray
) -> np.ndarray:
    """Look up ``keys`` in an ``m``-slot OKVS ``storage`` (naive decoder).

    Uses the C++ ``decode_okvs`` kernel when the library is loaded. Otherwise
    it calls ``py_kernels.okvs_decode`` with the keys flattened to 1-D and
    the seed passed as a pair of Python ints. A multi-dimensional ``seed`` is
    flattened first, and its first two words must exist.

    Returns:
        On the C++ path, an ``(n, 2)`` uint64 array (``n = keys.shape[0]``).
    """
    lib = _get_lib()
    flat_seed = seed if seed.ndim <= 1 else seed.flatten()
    seed_pair = (int(flat_seed[0]), int(flat_seed[1]))

    if lib is None:
        flat_keys = keys if keys.ndim <= 1 else keys.flatten()
        return py_kernels.okvs_decode(flat_keys, storage, m, seed=seed_pair)

    n_keys = keys.shape[0]
    as_ptr = ctypes.POINTER(ctypes.c_uint64)
    keys_buf = np.ascontiguousarray(keys, dtype=np.uint64)
    storage_buf = np.ascontiguousarray(storage, dtype=np.uint64)
    seed_buf = np.ascontiguousarray(flat_seed, dtype=np.uint64)
    decoded = np.zeros((n_keys, 2), dtype=np.uint64)

    lib.decode_okvs(
        keys_buf.ctypes.data_as(as_ptr),
        storage_buf.ctypes.data_as(as_ptr),
        decoded.ctypes.data_as(as_ptr),
        n_keys,
        m,
        seed_buf.ctypes.data_as(as_ptr),
    )
    return decoded
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def ldpc_encode_impl(
    message: np.ndarray, h_indices: np.ndarray, h_indptr: np.ndarray, m: int
) -> np.ndarray:
    """LDPC-encode ``message`` with a sparse matrix given as (indices, indptr).

    The index arrays look CSR-style (presumably describing ``m`` rows;
    confirm with the kernel). The C++ ``ldpc_encode`` kernel is used when
    available, with all inputs cast to contiguous uint64. Otherwise
    ``py_kernels.ldpc_encode`` runs with the index arrays flattened to 1-D.

    Returns:
        On the C++ path, an ``(m, 2)`` uint64 array.
    """
    lib = _get_lib()
    if lib is None:
        # Pure-Python fallback.
        flat_indices = h_indices if h_indices.ndim <= 1 else h_indices.flatten()
        flat_indptr = h_indptr if h_indptr.ndim <= 1 else h_indptr.flatten()
        return py_kernels.ldpc_encode(message, flat_indices, flat_indptr, m)

    as_ptr = ctypes.POINTER(ctypes.c_uint64)
    message_buf = np.ascontiguousarray(message, dtype=np.uint64)
    indices_buf = np.ascontiguousarray(h_indices, dtype=np.uint64)
    indptr_buf = np.ascontiguousarray(h_indptr, dtype=np.uint64)
    encoded = np.zeros((m, 2), dtype=np.uint64)

    # The kernel's n is the message length (leading axis).
    message_len = message.shape[0]

    lib.ldpc_encode(
        message_buf.ctypes.data_as(as_ptr),
        indices_buf.ctypes.data_as(as_ptr),
        indptr_buf.ctypes.data_as(as_ptr),
        encoded.ctypes.data_as(as_ptr),
        m,
        message_len,
    )
    return encoded
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
# =============================================================================
|
|
310
|
+
# Primitive Implementations
|
|
311
|
+
# =============================================================================
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
@field.ldpc_encode_p.def_impl
def _ldpc_encode_impl_prim(
    interpreter: Interpreter,
    op: Operation,
    message_val: TensorValue,
    indices_val: TensorValue,
    indptr_val: TensorValue,
) -> TensorValue:
    """Interpreter hook for ``field.ldpc_encode``.

    The output row count ``m`` comes from the op attributes. The tensor
    operands are unwrapped, encoded by :func:`ldpc_encode_impl`, and rewrapped.
    """
    row_count = op.attrs["m"]
    message, indices, indptr = (
        _unwrap(message_val),
        _unwrap(indices_val),
        _unwrap(indptr_val),
    )
    encoded = ldpc_encode_impl(message, indices, indptr, row_count)
    return _wrap(encoded)
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
@field.aes_expand_p.def_impl
def _aes_expand_impl_prim(
    interpreter: Interpreter, op: Operation, seeds_val: TensorValue
) -> TensorValue:
    """Expand each 128-bit seed into ``op.attrs["length"]`` 128-bit blocks.

    Seeds may arrive as uint8 with a trailing axis of 16 bytes, or as 64-bit
    words. They are normalized to ``(num_seeds, 2)`` uint64 pairs, and the
    result has shape ``(num_seeds, length, 2)`` (uint64), returned as a JAX
    array.

    The C++ ``aes_128_expand`` kernel is used when available. Otherwise each
    pair seeds NumPy's default generator. NOTE(review): that fallback is not
    AES, so its output almost certainly differs from the C++ path. Parties
    must not mix the two.
    """
    length = op.attrs["length"]

    # Work on a host-side, C-contiguous NumPy array. ``.view`` below needs a
    # contiguous last axis (a strided uint8 array would make it raise), and
    # the C++ kernel needs a flat buffer anyway.
    seeds = np.ascontiguousarray(_unwrap(seeds_val))

    if seeds.dtype == np.uint8 and seeds.shape[-1] == 16:
        # Reinterpret each 16-byte seed as two uint64 words.
        seeds = seeds.view(np.uint64)

    # Always flatten to (num_seeds, 2) so num_seeds equals the number of
    # 128-bit seeds actually present. Skipping this when the last axis is
    # already 2 would miscount (a, b, 2) inputs, and for a single (2,) seed
    # would make the kernel read past the end of the buffer.
    seeds = seeds.reshape(-1, 2)

    num_seeds = seeds.shape[0]
    output = np.zeros((num_seeds, length, 2), dtype=np.uint64)

    lib = _get_lib()
    if lib is not None:
        # Fast C++ path.
        seeds_c = np.ascontiguousarray(seeds, dtype=np.uint64)
        u64_ptr = ctypes.POINTER(ctypes.c_uint64)
        lib.aes_128_expand(
            seeds_c.ctypes.data_as(u64_ptr),
            output.ctypes.data_as(u64_ptr),
            num_seeds,
            length,
        )
    else:
        # Slow Python fallback (NumPy PRG; a JAX PRG fallback crashed).
        # ``integers`` excludes ``high``, so 2**64 - 1 is never produced. The
        # bound is kept as-is so the fallback stream stays unchanged.
        for i in range(num_seeds):
            seed_val = [int(seeds[i, 0]), int(seeds[i, 1])]
            rng = np.random.default_rng(seed_val)
            output[i] = rng.integers(
                0, 0xFFFFFFFFFFFFFFFF, size=(length, 2), dtype=np.uint64
            )

    # Return as JAX array to keep downstream happy.
    # NOTE(review): uint64 survives only if JAX x64 mode is enabled; confirm.
    return _wrap(jnp.array(output))
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
@field.mul_p.def_impl
def _mul_impl(
    interpreter: Interpreter, op: Operation, a_val: TensorValue, b_val: TensorValue
) -> TensorValue:
    """Interpreter hook for ``field.mul``: element-wise GF(2^128) product.

    Operands are unwrapped and rewrapped with the same ``_unwrap``/``_wrap``
    helpers every other field primitive here uses. The multiplication itself
    is done by :func:`_gf128_mul_impl`.
    """
    a = _unwrap(a_val)
    b = _unwrap(b_val)
    res = _gf128_mul_impl(a, b)
    return _wrap(res)
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
@field.solve_okvs_p.def_impl
def _solve_okvs_impl(
    interpreter: Interpreter,
    op: Operation,
    keys_val: TensorValue,
    values_val: TensorValue,
    seed_val: TensorValue,
) -> TensorValue:
    """Interpreter hook for ``field.solve_okvs`` (naive solver).

    Reads the storage size ``m`` from the op attributes and returns the
    storage produced by :func:`_okvs_solve_impl`.
    """
    storage_size = op.attrs["m"]
    keys, values, seed = _unwrap(keys_val), _unwrap(values_val), _unwrap(seed_val)
    storage = _okvs_solve_impl(keys, values, storage_size, seed)
    return _wrap(storage)
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
@field.decode_okvs_p.def_impl
def _decode_okvs_impl(
    interpreter: Interpreter,
    op: Operation,
    keys_val: TensorValue,
    store_val: TensorValue,
    seed_val: TensorValue,
) -> TensorValue:
    """Interpreter hook for ``field.decode_okvs`` (naive decoder).

    The storage size ``m`` is taken from the leading axis of the storage
    tensor rather than from op attributes. Decoding is done by
    :func:`_okvs_decode_impl`.
    """
    keys = _unwrap(keys_val)
    storage = _unwrap(store_val)
    seed = _unwrap(seed_val)
    m = storage.shape[0]
    res = _okvs_decode_impl(keys, storage, m, seed)
    return _wrap(res)
|
|
423
|
+
|
|
424
|
+
|
|
425
|
+
@field.solve_okvs_opt_p.def_impl
def _solve_okvs_opt_impl_prim(
    interpreter: Interpreter,
    op: Operation,
    keys_val: TensorValue,
    values_val: TensorValue,
    seed_val: TensorValue,
) -> TensorValue:
    """Interpreter hook for ``field.solve_okvs_opt`` (mega-binning solver).

    Reads ``m`` from the op attributes. :func:`_okvs_solve_opt_impl` may still
    fall back to the naive solver for small or tightly expanded inputs.
    """
    storage_size = op.attrs["m"]
    keys, values, seed = _unwrap(keys_val), _unwrap(values_val), _unwrap(seed_val)
    storage = _okvs_solve_opt_impl(keys, values, storage_size, seed)
    return _wrap(storage)
|
|
439
|
+
|
|
440
|
+
|
|
441
|
+
@field.decode_okvs_opt_p.def_impl
def _decode_okvs_opt_impl_prim(
    interpreter: Interpreter,
    op: Operation,
    keys_val: TensorValue,
    store_val: TensorValue,
    seed_val: TensorValue,
) -> TensorValue:
    """Interpreter hook for ``field.decode_okvs_opt`` (mega-binning decoder).

    ``m`` is the leading dimension of the storage tensor.
    :func:`_okvs_decode_opt_impl` may still fall back to the naive decoder.
    """
    keys, storage, seed = _unwrap(keys_val), _unwrap(store_val), _unwrap(seed_val)
    storage_size = storage.shape[0]
    decoded = _okvs_decode_opt_impl(keys, storage, storage_size, seed)
    return _wrap(decoded)
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Generic kernel implementations for the `func` dialect.
|
|
16
|
+
|
|
17
|
+
Design: Function as Value
|
|
18
|
+
- func.func impl: Returns a FunctionValue wrapping the traced graph.
|
|
19
|
+
- func.call impl: Executes the function graph with provided arguments.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
from __future__ import annotations
|
|
23
|
+
|
|
24
|
+
from typing import TYPE_CHECKING, Any, ClassVar
|
|
25
|
+
|
|
26
|
+
from mplang.v2.dialects.func import call_p, func_def_p
|
|
27
|
+
from mplang.v2.edsl import serde
|
|
28
|
+
from mplang.v2.edsl.graph import Graph, Operation
|
|
29
|
+
from mplang.v2.runtime.value import Value
|
|
30
|
+
|
|
31
|
+
if TYPE_CHECKING:
|
|
32
|
+
from typing import Self
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@serde.register_class
class FunctionValue(Value):
    """A traced function graph carried around as a first-class runtime Value.

    ``func.func`` creates one and ``func.call`` executes it. Values are the
    data that flow between Operations in the interpreter's model. Treating a
    function body (a Graph) as just another Value lets it be passed, stored,
    and invoked like any other data.
    """

    _serde_kind: ClassVar[str] = "func.FunctionValue"

    def __init__(self, graph: Graph, name: str = "anonymous") -> None:
        self._graph = graph
        self._name = name

    @property
    def graph(self) -> Graph:
        """The wrapped function body."""
        return self._graph

    @property
    def name(self) -> str:
        """Symbolic name of the function (``"anonymous"`` when unnamed)."""
        return self._name

    def __repr__(self) -> str:
        op_count = len(self._graph.operations)
        return f"FunctionValue({self._name!r}, ops={op_count})"

    def to_json(self) -> dict[str, Any]:
        """Serialize to a dict holding the serialized ``graph`` and the ``name``."""
        payload: dict[str, Any] = {"graph": serde.to_json(self._graph)}
        payload["name"] = self._name
        return payload

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> Self:
        """Rebuild a FunctionValue from the dict produced by ``to_json``.

        A missing ``name`` key defaults to ``"anonymous"``.
        """
        restored_graph = serde.from_json(data["graph"])
        restored_name = data.get("name", "anonymous")
        return cls(graph=restored_graph, name=restored_name)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
@func_def_p.def_impl
def _func_def_impl(interpreter: Any, op: Operation, *args: Any) -> FunctionValue:
    """Execute ``func.func`` by wrapping its first region in a FunctionValue.

    The function name comes from the ``sym_name`` attribute and defaults to
    ``"anonymous"``.

    Raises:
        ValueError: if the operation has no body region.
    """
    regions = op.regions
    if not regions:
        raise ValueError("func.func operation missing body region")

    fn_name = op.attrs.get("sym_name", "anonymous")
    body = regions[0]
    return FunctionValue(graph=body, name=fn_name)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
@call_p.def_impl
def _call_impl(
    interpreter: Any, op: Operation, fn_obj: FunctionValue, *args: Any
) -> Any:
    """Execute ``func.call`` by evaluating the callee's graph on ``args``.

    Args:
        interpreter: Interpreter used to evaluate the function body.
        op: The func.call operation.
        fn_obj: The FunctionValue produced by func.func.
        *args: Positional arguments bound to the graph's inputs.

    Returns:
        The sole result when the graph has exactly one output; otherwise the
        full result sequence from ``evaluate_graph``.

    Raises:
        TypeError: if ``fn_obj`` is not a FunctionValue.
    """
    if not isinstance(fn_obj, FunctionValue):
        raise TypeError(f"func.call expects FunctionValue, got {type(fn_obj)}")

    body = fn_obj.graph
    results = interpreter.evaluate_graph(body, [*args])
    if len(body.outputs) == 1:
        return results[0]
    return results
|
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""PHE Runtime Implementation using LightPHE."""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
from dataclasses import dataclass
|
|
20
|
+
from typing import Any, cast
|
|
21
|
+
|
|
22
|
+
from lightphe import LightPHE
|
|
23
|
+
from lightphe.models.Ciphertext import Ciphertext
|
|
24
|
+
|
|
25
|
+
from mplang.v2.dialects import phe
|
|
26
|
+
from mplang.v2.edsl.graph import Operation
|
|
27
|
+
from mplang.v2.runtime.interpreter import Interpreter
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class PHEContext:
    """Thin wrapper around a LightPHE cryptosystem.

    The LightPHE instance is stored as ``self.cs``; ``encrypt`` and ``decrypt``
    delegate to it directly.
    """

    # Canonical spellings of LightPHE's scheme names, keyed by lower-case form.
    # LightPHE matches ``algorithm_name`` case-sensitively, and the previous
    # ``str.capitalize()`` normalization mangled every mixed-case name
    # (e.g. "RSA" -> "Rsa", "Exponential-ElGamal" -> "Exponential-elgamal").
    _CANONICAL_ALGORITHMS = {
        name.lower(): name
        for name in (
            "RSA",
            "ElGamal",
            "Exponential-ElGamal",
            "EllipticCurve-ElGamal",
            "Paillier",
            "Damgard-Jurik",
            "Okamoto-Uchiyama",
            "Benaloh",
            "Naccache-Stern",
            "Goldwasser-Micali",
        )
    }

    def __init__(self, algorithm_name: str = "Paillier", key_size: int = 2048):
        """Create the underlying LightPHE cryptosystem.

        Args:
            algorithm_name: Scheme name, matched case-insensitively against
                LightPHE's canonical names (e.g. "paillier" -> "Paillier",
                "rsa" -> "RSA").
            key_size: Key size forwarded to LightPHE.
        """
        self.cs = LightPHE(
            algorithm_name=self._normalize_algorithm_name(algorithm_name),
            key_size=key_size,
        )

    @staticmethod
    def _normalize_algorithm_name(algorithm_name: str) -> str:
        """Map a scheme name to LightPHE's exact spelling.

        Known names are matched case-insensitively. Unknown names fall back to
        the historical ``str.capitalize()`` normalization, so behaviour for
        them is unchanged.
        """
        canonical = PHEContext._CANONICAL_ALGORITHMS.get(algorithm_name.lower())
        return canonical if canonical is not None else algorithm_name.capitalize()

    def encrypt(self, value: int) -> Ciphertext:
        """Encrypt an integer plaintext with the wrapped cryptosystem."""
        return self.cs.encrypt(value)

    def decrypt(self, ct: Ciphertext) -> int:
        """Decrypt a LightPHE ciphertext to its integer plaintext."""
        return cast(int, self.cs.decrypt(ct))
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@dataclass
class PHEEncoder:
    """Fixed-point encoding parameters for PHE plaintexts.

    A real value is turned into an integer by multiplying it by ``scale``,
    and turned back by dividing by ``scale``.
    """

    # Multiplier between a real value and its integer fixed-point encoding.
    scale: float
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@dataclass
class WrappedCiphertext:
    """A LightPHE ciphertext paired with the PHEContext that produced it.

    Supported operators:
      * ``ct + ct``: homomorphic addition of the two ciphertexts.
      * ``ct + int`` / ``int + ct``: the int is encrypted under ``ctx`` first.
      * ``ct * int`` / ``int * ct``: scalar multiplication.

    Any other operand yields ``NotImplemented``, so Python raises its usual
    ``TypeError``. The reflected forms also let ``sum(ciphertexts)`` work,
    because ``sum`` starts from the int ``0``.
    """

    ct: Ciphertext  # underlying LightPHE ciphertext
    ctx: PHEContext  # context used to encrypt int operands in ``ct + int``

    def __add__(self, other: Any) -> WrappedCiphertext:
        if isinstance(other, WrappedCiphertext):
            # ct + ct
            return WrappedCiphertext(self.ct + other.ct, self.ctx)
        if isinstance(other, int):
            # ct + int -> ct + encrypt(int)
            ct_other = self.ctx.encrypt(other)
            return WrappedCiphertext(self.ct + ct_other, self.ctx)
        return NotImplemented

    def __radd__(self, other: Any) -> WrappedCiphertext:
        # int + ct: addition is commutative, so reuse ``__add__``.
        return self.__add__(other)

    def __mul__(self, other: Any) -> WrappedCiphertext:
        if isinstance(other, int):
            # ct * int
            return WrappedCiphertext(self.ct * other, self.ctx)
        return NotImplemented

    def __rmul__(self, other: Any) -> WrappedCiphertext:
        # int * ct: scalar multiplication is commutative, so reuse ``__mul__``.
        return self.__mul__(other)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
@phe.keygen_p.def_impl
def keygen_impl(
    interpreter: Interpreter, op: Operation, *args: Any
) -> tuple[PHEContext, PHEContext]:
    """Generate PHE keys from the op's ``key_size`` and ``scheme`` attributes.

    Missing attributes default to ``key_size=2048`` and ``scheme="Paillier"``.

    NOTE(review): both returned elements are the *same* ``PHEContext`` object,
    and that object exposes ``decrypt``. Whichever output is treated as the
    public key therefore also carries decryption capability; confirm this is
    acceptable wherever that output is shared.
    """
    attrs = op.attrs
    bits = attrs.get("key_size", 2048)
    context = PHEContext(algorithm_name=attrs.get("scheme", "Paillier"), key_size=bits)
    return (context, context)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
@phe.create_encoder_p.def_impl
def create_encoder_impl(
    interpreter: Interpreter, op: Operation, *args: Any
) -> PHEEncoder:
    """Build a fixed-point encoder with ``scale = 2 ** fxp_bits``.

    ``fxp_bits`` is read from the op attributes and defaults to 16.
    """
    fractional_bits = op.attrs.get("fxp_bits", 16)
    return PHEEncoder(scale=2.0**fractional_bits)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
@phe.encode_p.def_impl
def encode_impl(
    interpreter: Interpreter, op: Operation, value: float, encoder: PHEEncoder
) -> int:
    """Encode a real value as a fixed-point integer, ``round(value * scale)``.

    Rounding to nearest, rather than truncating toward zero with ``int()``,
    bounds the quantization error by ``0.5 / scale`` instead of ``1 / scale``.
    It also removes the systematic bias toward zero. For example, ``0.3`` with
    16 fractional bits (``0.3 * 65536 == 19660.8``) now encodes to 19661,
    not 19660.
    """
    # The outer ``int`` guarantees a Python int even if ``value`` is a scalar
    # type whose ``__round__`` returns a float.
    return int(round(value * encoder.scale))
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
@phe.decode_p.def_impl
def decode_impl(
    interpreter: Interpreter, op: Operation, value: int, encoder: PHEEncoder
) -> float:
    """Decode a fixed-point integer back to a real value, ``value / scale``.

    NOTE(review): if decrypted plaintexts come back reduced modulo the scheme's
    plaintext modulus, negative results would reach this function as large
    positive integers and decode incorrectly. Confirm how LightPHE represents
    negative values.
    """
    scale = encoder.scale
    as_float = float(value)
    return as_float / scale
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
@phe.encrypt_p.def_impl
def encrypt_impl(
    interpreter: Interpreter, op: Operation, value: int, pk: PHEContext
) -> WrappedCiphertext:
    """Encrypt an encoded integer under ``pk`` and keep ``pk`` with the result.

    The context is stored on the wrapper so that a later ``ct + int`` can
    encrypt its plaintext operand under the same context.
    """
    return WrappedCiphertext(ct=pk.encrypt(value), ctx=pk)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
@phe.decrypt_p.def_impl
def decrypt_impl(
    interpreter: Interpreter, op: Operation, wct: WrappedCiphertext, sk: PHEContext
) -> int:
    """Decrypt the ciphertext inside ``wct`` with ``sk``; returns the integer plaintext."""
    raw_ciphertext = wct.ct
    plaintext = sk.decrypt(raw_ciphertext)
    return plaintext
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
@phe.add_cc_p.def_impl
def add_cc_impl(
    interpreter: Interpreter,
    op: Operation,
    lhs: WrappedCiphertext,
    rhs: WrappedCiphertext,
) -> WrappedCiphertext:
    """Homomorphically add two ciphertexts using ``WrappedCiphertext`` addition."""
    total = lhs + rhs
    return total
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
@phe.add_cp_p.def_impl
def add_cp_impl(
    interpreter: Interpreter, op: Operation, lhs: WrappedCiphertext, rhs: int
) -> WrappedCiphertext:
    """Add an integer plaintext to a ciphertext.

    ``WrappedCiphertext`` addition encrypts ``rhs`` under the ciphertext's
    context before adding.
    """
    total = lhs + rhs
    return total
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
@phe.mul_cp_p.def_impl
def mul_cp_impl(
    interpreter: Interpreter, op: Operation, lhs: WrappedCiphertext, rhs: int
) -> WrappedCiphertext:
    """Multiply a ciphertext by an integer plaintext scalar."""
    product = lhs * rhs
    return product
|