mplang-nightly 0.1.dev268__py3-none-any.whl → 0.1.dev270__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +391 -17
- mplang/{v2/backends → backends}/__init__.py +9 -7
- mplang/{v2/backends → backends}/bfv_impl.py +6 -6
- mplang/{v2/backends → backends}/crypto_impl.py +6 -6
- mplang/{v2/backends → backends}/field_impl.py +5 -5
- mplang/{v2/backends → backends}/func_impl.py +4 -4
- mplang/{v2/backends → backends}/phe_impl.py +3 -3
- mplang/{v2/backends → backends}/simp_design.md +1 -1
- mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
- mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
- mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
- mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
- mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
- mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
- mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
- mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
- mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
- mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
- mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
- mplang/{v2/backends → backends}/spu_impl.py +8 -8
- mplang/{v2/backends → backends}/spu_state.py +4 -4
- mplang/{v2/backends → backends}/store_impl.py +3 -3
- mplang/{v2/backends → backends}/table_impl.py +8 -8
- mplang/{v2/backends → backends}/tee_impl.py +6 -6
- mplang/{v2/backends → backends}/tensor_impl.py +6 -6
- mplang/{v2/cli.py → cli.py} +9 -9
- mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
- mplang/{v2/dialects → dialects}/__init__.py +5 -5
- mplang/{v2/dialects → dialects}/bfv.py +6 -6
- mplang/{v2/dialects → dialects}/crypto.py +5 -5
- mplang/{v2/dialects → dialects}/dtypes.py +2 -2
- mplang/{v2/dialects → dialects}/field.py +3 -3
- mplang/{v2/dialects → dialects}/func.py +2 -2
- mplang/{v2/dialects → dialects}/phe.py +6 -6
- mplang/{v2/dialects → dialects}/simp.py +6 -6
- mplang/{v2/dialects → dialects}/spu.py +7 -7
- mplang/{v2/dialects → dialects}/store.py +2 -2
- mplang/{v2/dialects → dialects}/table.py +3 -3
- mplang/{v2/dialects → dialects}/tee.py +6 -6
- mplang/{v2/dialects → dialects}/tensor.py +5 -5
- mplang/{v2/edsl → edsl}/__init__.py +3 -3
- mplang/{v2/edsl → edsl}/context.py +6 -6
- mplang/{v2/edsl → edsl}/graph.py +5 -5
- mplang/{v2/edsl → edsl}/jit.py +2 -2
- mplang/{v2/edsl → edsl}/object.py +1 -1
- mplang/{v2/edsl → edsl}/primitive.py +5 -5
- mplang/{v2/edsl → edsl}/printer.py +1 -1
- mplang/{v2/edsl → edsl}/serde.py +1 -1
- mplang/{v2/edsl → edsl}/tracer.py +7 -7
- mplang/{v2/edsl → edsl}/typing.py +1 -1
- mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
- mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
- mplang/{v2/kernels → kernels}/okvs_opt.cpp +46 -31
- mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
- mplang/{v2/libs → libs}/collective.py +5 -5
- mplang/{v2/libs → libs}/device/__init__.py +1 -1
- mplang/{v2/libs → libs}/device/api.py +12 -12
- mplang/{v2/libs → libs}/ml/__init__.py +1 -1
- mplang/{v2/libs → libs}/ml/sgb.py +4 -4
- mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
- mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
- mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
- mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
- mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
- mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +19 -13
- mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
- mplang/libs/mpc/psi/rr22.py +303 -0
- mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
- mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
- mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
- mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
- mplang/{v2/runtime → runtime}/interpreter.py +11 -11
- mplang/{v2/runtime → runtime}/value.py +2 -2
- mplang/{v1/runtime → utils}/__init__.py +18 -15
- mplang/{v1/utils → utils}/func_utils.py +1 -1
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/METADATA +2 -2
- mplang_nightly-0.1.dev270.dist-info/RECORD +102 -0
- mplang/v1/__init__.py +0 -157
- mplang/v1/_device.py +0 -602
- mplang/v1/analysis/__init__.py +0 -37
- mplang/v1/analysis/diagram.py +0 -567
- mplang/v1/core/__init__.py +0 -157
- mplang/v1/core/cluster.py +0 -343
- mplang/v1/core/comm.py +0 -281
- mplang/v1/core/context_mgr.py +0 -50
- mplang/v1/core/dtypes.py +0 -335
- mplang/v1/core/expr/__init__.py +0 -80
- mplang/v1/core/expr/ast.py +0 -542
- mplang/v1/core/expr/evaluator.py +0 -581
- mplang/v1/core/expr/printer.py +0 -285
- mplang/v1/core/expr/transformer.py +0 -141
- mplang/v1/core/expr/utils.py +0 -78
- mplang/v1/core/expr/visitor.py +0 -85
- mplang/v1/core/expr/walk.py +0 -387
- mplang/v1/core/interp.py +0 -160
- mplang/v1/core/mask.py +0 -325
- mplang/v1/core/mpir.py +0 -965
- mplang/v1/core/mpobject.py +0 -117
- mplang/v1/core/mptype.py +0 -407
- mplang/v1/core/pfunc.py +0 -130
- mplang/v1/core/primitive.py +0 -877
- mplang/v1/core/table.py +0 -218
- mplang/v1/core/tensor.py +0 -75
- mplang/v1/core/tracer.py +0 -383
- mplang/v1/host.py +0 -130
- mplang/v1/kernels/__init__.py +0 -41
- mplang/v1/kernels/base.py +0 -125
- mplang/v1/kernels/basic.py +0 -240
- mplang/v1/kernels/context.py +0 -369
- mplang/v1/kernels/crypto.py +0 -122
- mplang/v1/kernels/fhe.py +0 -858
- mplang/v1/kernels/mock_tee.py +0 -72
- mplang/v1/kernels/phe.py +0 -1864
- mplang/v1/kernels/spu.py +0 -341
- mplang/v1/kernels/sql_duckdb.py +0 -44
- mplang/v1/kernels/stablehlo.py +0 -90
- mplang/v1/kernels/value.py +0 -626
- mplang/v1/ops/__init__.py +0 -35
- mplang/v1/ops/base.py +0 -424
- mplang/v1/ops/basic.py +0 -294
- mplang/v1/ops/crypto.py +0 -262
- mplang/v1/ops/fhe.py +0 -272
- mplang/v1/ops/jax_cc.py +0 -147
- mplang/v1/ops/nnx_cc.py +0 -168
- mplang/v1/ops/phe.py +0 -216
- mplang/v1/ops/spu.py +0 -151
- mplang/v1/ops/sql_cc.py +0 -303
- mplang/v1/ops/tee.py +0 -36
- mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
- mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
- mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
- mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
- mplang/v1/runtime/channel.py +0 -230
- mplang/v1/runtime/cli.py +0 -451
- mplang/v1/runtime/client.py +0 -456
- mplang/v1/runtime/communicator.py +0 -131
- mplang/v1/runtime/data_providers.py +0 -303
- mplang/v1/runtime/driver.py +0 -324
- mplang/v1/runtime/exceptions.py +0 -27
- mplang/v1/runtime/http_api.md +0 -56
- mplang/v1/runtime/link_comm.py +0 -196
- mplang/v1/runtime/server.py +0 -501
- mplang/v1/runtime/session.py +0 -270
- mplang/v1/runtime/simulation.py +0 -324
- mplang/v1/simp/__init__.py +0 -13
- mplang/v1/simp/api.py +0 -353
- mplang/v1/simp/mpi.py +0 -131
- mplang/v1/simp/party.py +0 -225
- mplang/v1/simp/random.py +0 -120
- mplang/v1/simp/smpc.py +0 -238
- mplang/v1/utils/__init__.py +0 -13
- mplang/v1/utils/crypto.py +0 -32
- mplang/v1/utils/spu_utils.py +0 -130
- mplang/v1/utils/table_utils.py +0 -185
- mplang/v2/__init__.py +0 -424
- mplang/v2/libs/mpc/psi/rr22.py +0 -344
- mplang_nightly-0.1.dev268.dist-info/RECORD +0 -180
- /mplang/{v2/backends → backends}/channel.py +0 -0
- /mplang/{v2/edsl → edsl}/README.md +0 -0
- /mplang/{v2/edsl → edsl}/registry.py +0 -0
- /mplang/{v2/kernels → kernels}/Makefile +0 -0
- /mplang/{v2/kernels → kernels}/__init__.py +0 -0
- /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
- /mplang/{v2/libs → libs}/device/cluster.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
- /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
- /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
- /mplang/{v2/runtime → runtime}/object_store.py +0 -0
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/licenses/LICENSE +0 -0
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
"""TEE (Trusted Execution Environment) dialect for mplang
|
|
15
|
+
"""TEE (Trusted Execution Environment) dialect for mplang EDSL.
|
|
16
16
|
|
|
17
17
|
This dialect provides primitives for TEE remote attestation, enabling secure
|
|
18
18
|
computation where:
|
|
@@ -38,7 +38,7 @@ Supported Platforms:
|
|
|
38
38
|
|
|
39
39
|
Example:
|
|
40
40
|
```python
|
|
41
|
-
from mplang.
|
|
41
|
+
from mplang.dialects import tee, crypto
|
|
42
42
|
|
|
43
43
|
# On TEE side: generate keypair and quote
|
|
44
44
|
sk, pk = crypto.kem_keygen("x25519")
|
|
@@ -62,10 +62,10 @@ from __future__ import annotations
|
|
|
62
62
|
|
|
63
63
|
from typing import Any, ClassVar, Literal
|
|
64
64
|
|
|
65
|
-
import mplang.
|
|
66
|
-
import mplang.
|
|
67
|
-
from mplang.
|
|
68
|
-
from mplang.
|
|
65
|
+
import mplang.edsl as el
|
|
66
|
+
import mplang.edsl.typing as elt
|
|
67
|
+
from mplang.dialects.crypto import PublicKeyType
|
|
68
|
+
from mplang.edsl import serde
|
|
69
69
|
|
|
70
70
|
# ==============================================================================
|
|
71
71
|
# --- Type Definitions
|
|
@@ -58,10 +58,10 @@ import numpy as np
|
|
|
58
58
|
from jax import ShapeDtypeStruct
|
|
59
59
|
from jax.tree_util import PyTreeDef, tree_flatten
|
|
60
60
|
|
|
61
|
-
import mplang.
|
|
62
|
-
import mplang.
|
|
63
|
-
from mplang.
|
|
64
|
-
from mplang.
|
|
61
|
+
import mplang.edsl as el
|
|
62
|
+
import mplang.edsl.typing as elt
|
|
63
|
+
from mplang.dialects import dtypes
|
|
64
|
+
from mplang.utils import normalize_fn
|
|
65
65
|
|
|
66
66
|
run_jax_p = el.Primitive[Any]("tensor.run_jax")
|
|
67
67
|
constant_p = el.Primitive[el.Object]("tensor.constant")
|
|
@@ -77,7 +77,7 @@ class RunJaxCompilation:
|
|
|
77
77
|
|
|
78
78
|
fn: Callable[..., Any]
|
|
79
79
|
stablehlo: str
|
|
80
|
-
out_tree: PyTreeDef
|
|
80
|
+
out_tree: PyTreeDef # type: ignore
|
|
81
81
|
output_types: list[elt.BaseType]
|
|
82
82
|
arg_keep_map: list[int] | None = None
|
|
83
83
|
|
|
@@ -17,8 +17,8 @@
|
|
|
17
17
|
This module keeps the surface area intentionally small so downstream code can
|
|
18
18
|
simply write::
|
|
19
19
|
|
|
20
|
-
import mplang.
|
|
21
|
-
import mplang.
|
|
20
|
+
import mplang.edsl as el
|
|
21
|
+
import mplang.edsl.typing as elt
|
|
22
22
|
|
|
23
23
|
The `el` namespace re-exports the commonly used building blocks (context,
|
|
24
24
|
graph, tracer, primitives, etc.), while the full type system lives under
|
|
@@ -27,7 +27,7 @@ graph, tracer, primitives, etc.), while the full type system lives under
|
|
|
27
27
|
|
|
28
28
|
from __future__ import annotations
|
|
29
29
|
|
|
30
|
-
# Re-export the typing module so callers can `import mplang.
|
|
30
|
+
# Re-export the typing module so callers can `import mplang.edsl.typing as elt`
|
|
31
31
|
from . import typing as typing
|
|
32
32
|
|
|
33
33
|
# Context management
|
|
@@ -21,7 +21,7 @@ This module defines the Context hierarchy:
|
|
|
21
21
|
|
|
22
22
|
Contexts can be used directly with Python's 'with' statement:
|
|
23
23
|
|
|
24
|
-
from mplang.
|
|
24
|
+
from mplang.edsl import Tracer
|
|
25
25
|
|
|
26
26
|
tracer = Tracer()
|
|
27
27
|
with tracer:
|
|
@@ -46,9 +46,9 @@ from collections.abc import Callable
|
|
|
46
46
|
from typing import TYPE_CHECKING, Any, Self
|
|
47
47
|
|
|
48
48
|
if TYPE_CHECKING:
|
|
49
|
-
from mplang.
|
|
50
|
-
from mplang.
|
|
51
|
-
from mplang.
|
|
49
|
+
from mplang.edsl.graph import Graph
|
|
50
|
+
from mplang.edsl.object import Object
|
|
51
|
+
from mplang.edsl.primitive import Primitive
|
|
52
52
|
|
|
53
53
|
|
|
54
54
|
class Context(ABC):
|
|
@@ -257,7 +257,7 @@ def is_tracing() -> bool:
|
|
|
257
257
|
Returns:
|
|
258
258
|
True if the top of the context stack is a Tracer.
|
|
259
259
|
"""
|
|
260
|
-
from mplang.
|
|
260
|
+
from mplang.edsl.tracer import Tracer
|
|
261
261
|
|
|
262
262
|
return isinstance(get_current_context(), Tracer)
|
|
263
263
|
|
|
@@ -280,7 +280,7 @@ def get_default_context() -> Context:
|
|
|
280
280
|
if _default_context_factory is None:
|
|
281
281
|
raise RuntimeError(
|
|
282
282
|
"No default context factory registered. "
|
|
283
|
-
"Ensure mplang.
|
|
283
|
+
"Ensure mplang.edsl is imported or register a factory manually."
|
|
284
284
|
)
|
|
285
285
|
_default_context = _default_context_factory()
|
|
286
286
|
return _default_context
|
mplang/{v2/edsl → edsl}/graph.py
RENAMED
|
@@ -27,8 +27,8 @@ Key Design Principles:
|
|
|
27
27
|
|
|
28
28
|
Example:
|
|
29
29
|
--------
|
|
30
|
-
from mplang.
|
|
31
|
-
from mplang.
|
|
30
|
+
from mplang.edsl.graph import Graph, Operation, Value
|
|
31
|
+
from mplang.edsl.typing import Tensor, f32
|
|
32
32
|
|
|
33
33
|
graph = Graph()
|
|
34
34
|
|
|
@@ -61,8 +61,8 @@ from collections.abc import Sequence
|
|
|
61
61
|
from dataclasses import dataclass, field
|
|
62
62
|
from typing import Any, ClassVar
|
|
63
63
|
|
|
64
|
-
from mplang.
|
|
65
|
-
from mplang.
|
|
64
|
+
from mplang.edsl import serde
|
|
65
|
+
from mplang.edsl.typing import BaseType
|
|
66
66
|
|
|
67
67
|
|
|
68
68
|
@dataclass
|
|
@@ -74,7 +74,7 @@ class Value:
|
|
|
74
74
|
|
|
75
75
|
Attributes:
|
|
76
76
|
name: Unique SSA name (e.g., "%0", "%1", ...)
|
|
77
|
-
type: Type of this value (from mplang.
|
|
77
|
+
type: Type of this value (from mplang.edsl.typing)
|
|
78
78
|
defining_op: Operation that produces this value (None for inputs)
|
|
79
79
|
uses: List of operations that consume this value
|
|
80
80
|
"""
|
mplang/{v2/edsl → edsl}/jit.py
RENAMED
|
@@ -19,12 +19,12 @@ from typing import Any
|
|
|
19
19
|
|
|
20
20
|
from jax.tree_util import tree_map
|
|
21
21
|
|
|
22
|
-
from mplang.
|
|
22
|
+
from mplang.edsl.context import (
|
|
23
23
|
AbstractInterpreter,
|
|
24
24
|
get_current_context,
|
|
25
25
|
get_default_context,
|
|
26
26
|
)
|
|
27
|
-
from mplang.
|
|
27
|
+
from mplang.edsl.tracer import Tracer
|
|
28
28
|
|
|
29
29
|
|
|
30
30
|
def jit(fn: Callable) -> Callable:
|
|
@@ -27,11 +27,11 @@ from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
|
|
|
27
27
|
|
|
28
28
|
from jax.tree_util import tree_map
|
|
29
29
|
|
|
30
|
-
from mplang.
|
|
31
|
-
from mplang.
|
|
30
|
+
from mplang.edsl.context import get_current_context, get_default_context
|
|
31
|
+
from mplang.edsl.object import Object
|
|
32
32
|
|
|
33
33
|
if TYPE_CHECKING:
|
|
34
|
-
from mplang.
|
|
34
|
+
from mplang.edsl.typing import BaseType
|
|
35
35
|
|
|
36
36
|
T_Ret = TypeVar("T_Ret")
|
|
37
37
|
|
|
@@ -54,7 +54,7 @@ class Primitive(Generic[T_Ret]):
|
|
|
54
54
|
>>>
|
|
55
55
|
>>> @encrypt_p.def_abstract_eval
|
|
56
56
|
>>> def encrypt_abstract(x_type):
|
|
57
|
-
>>> from mplang.
|
|
57
|
+
>>> from mplang.edsl.typing import Vector
|
|
58
58
|
>>> return Vector[x_type.dtype, x_type.shape]
|
|
59
59
|
>>>
|
|
60
60
|
>>> # Execution happens via Graph IR → Backend
|
|
@@ -91,7 +91,7 @@ class Primitive(Generic[T_Ret]):
|
|
|
91
91
|
"""
|
|
92
92
|
self._impl = fn
|
|
93
93
|
# Register with the global interpreter registry
|
|
94
|
-
from mplang.
|
|
94
|
+
from mplang.edsl.registry import register_impl
|
|
95
95
|
|
|
96
96
|
register_impl(self.name, fn)
|
|
97
97
|
return fn
|
mplang/{v2/edsl → edsl}/serde.py
RENAMED
|
@@ -31,14 +31,14 @@ from typing import TYPE_CHECKING, Any, cast
|
|
|
31
31
|
|
|
32
32
|
from jax.tree_util import PyTreeDef, tree_flatten, tree_map
|
|
33
33
|
|
|
34
|
-
from mplang.
|
|
35
|
-
from mplang.
|
|
36
|
-
from mplang.
|
|
37
|
-
from mplang.
|
|
38
|
-
from mplang.
|
|
34
|
+
from mplang.edsl.context import Context
|
|
35
|
+
from mplang.edsl.graph import Graph
|
|
36
|
+
from mplang.edsl.graph import Value as GraphValue
|
|
37
|
+
from mplang.edsl.object import Object
|
|
38
|
+
from mplang.edsl.typing import BaseType
|
|
39
39
|
|
|
40
40
|
if TYPE_CHECKING:
|
|
41
|
-
from mplang.
|
|
41
|
+
from mplang.edsl.primitive import Primitive
|
|
42
42
|
|
|
43
43
|
|
|
44
44
|
class TraceObject(Object):
|
|
@@ -48,7 +48,7 @@ class TraceObject(Object):
|
|
|
48
48
|
All operations delegate to primitives which record into Graph.
|
|
49
49
|
|
|
50
50
|
Example:
|
|
51
|
-
>>> from mplang.
|
|
51
|
+
>>> from mplang.edsl import trace
|
|
52
52
|
>>> def compute(x, y):
|
|
53
53
|
... z = x + y # TraceObject.__add__ → add_p.bind(x, y)
|
|
54
54
|
... return z
|
|
@@ -124,7 +124,7 @@ from __future__ import annotations
|
|
|
124
124
|
|
|
125
125
|
from typing import Any, ClassVar, Generic, TypeVar
|
|
126
126
|
|
|
127
|
-
from mplang.
|
|
127
|
+
from mplang.edsl import serde
|
|
128
128
|
|
|
129
129
|
# ==============================================================================
|
|
130
130
|
# --- Base Type & Type Aliases
|
|
@@ -27,13 +27,13 @@ extern "C" {
|
|
|
27
27
|
|
|
28
28
|
/**
|
|
29
29
|
* @brief LDPC Encoding: Compute Syndrome s = H * x
|
|
30
|
-
*
|
|
30
|
+
*
|
|
31
31
|
* H is a sparse M x N binary matrix (CSR format).
|
|
32
32
|
* x is a dense N-vector of 128-bit blocks (N * 16 bytes).
|
|
33
33
|
* s is a dense M-vector of 128-bit blocks (M * 16 bytes).
|
|
34
|
-
*
|
|
34
|
+
*
|
|
35
35
|
* Logic: For each row i of H, s[i] = XOR(x[j]) for all j where H[i, j] = 1.
|
|
36
|
-
*
|
|
36
|
+
*
|
|
37
37
|
* @param message_ptr Pointer to message x (N * 2 uint64_t)
|
|
38
38
|
* @param indices_ptr Pointer to CSR indices (uint64_t)
|
|
39
39
|
* @param indptr_ptr Pointer to CSR indptr (M+1 uint64_t)
|
|
@@ -41,17 +41,17 @@ extern "C" {
|
|
|
41
41
|
* @param m Number of rows in H (syndrome length)
|
|
42
42
|
* @param n Number of cols in H (message length)
|
|
43
43
|
*/
|
|
44
|
-
void ldpc_encode(const uint64_t* message_ptr,
|
|
45
|
-
const uint64_t* indices_ptr,
|
|
46
|
-
const uint64_t* indptr_ptr,
|
|
47
|
-
uint64_t* output_ptr,
|
|
48
|
-
uint64_t m,
|
|
44
|
+
void ldpc_encode(const uint64_t* message_ptr,
|
|
45
|
+
const uint64_t* indices_ptr,
|
|
46
|
+
const uint64_t* indptr_ptr,
|
|
47
|
+
uint64_t* output_ptr,
|
|
48
|
+
uint64_t m,
|
|
49
49
|
uint64_t n) {
|
|
50
|
-
|
|
50
|
+
|
|
51
51
|
// Check alignment
|
|
52
52
|
// We assume message_ptr and output_ptr are 16-byte aligned for SSE/AVX?
|
|
53
53
|
// JAX/Numpy arrays are usually aligned.
|
|
54
|
-
|
|
54
|
+
|
|
55
55
|
// Cast to __m128i for efficiency
|
|
56
56
|
// But we need to handle potential unaligned access if numpy doesn't align.
|
|
57
57
|
// _mm_loadu_si128 handles unaligned.
|
|
@@ -63,10 +63,10 @@ void ldpc_encode(const uint64_t* message_ptr,
|
|
|
63
63
|
for (uint64_t i = 0; i < m; ++i) {
|
|
64
64
|
// Row i
|
|
65
65
|
__m128i sum = _mm_setzero_si128();
|
|
66
|
-
|
|
66
|
+
|
|
67
67
|
uint64_t start = indptr_ptr[i];
|
|
68
68
|
uint64_t end = indptr_ptr[i+1];
|
|
69
|
-
|
|
69
|
+
|
|
70
70
|
for (uint64_t k = start; k < end; ++k) {
|
|
71
71
|
uint64_t col_idx = indices_ptr[k];
|
|
72
72
|
// XOR accumulation
|
|
@@ -74,7 +74,7 @@ void ldpc_encode(const uint64_t* message_ptr,
|
|
|
74
74
|
__m128i val = _mm_loadu_si128(&x_vec[col_idx]);
|
|
75
75
|
sum = _mm_xor_si128(sum, val);
|
|
76
76
|
}
|
|
77
|
-
|
|
77
|
+
|
|
78
78
|
_mm_storeu_si128(&s_vec[i], sum);
|
|
79
79
|
}
|
|
80
80
|
}
|
|
@@ -86,7 +86,7 @@ extern "C" {
|
|
|
86
86
|
// 3. Build CSR Structure (Flat Arrays) to replace vector<vector>
|
|
87
87
|
// col_start[j] points to start of column j's rows in flat_rows
|
|
88
88
|
std::vector<int> col_start(m + 1, 0);
|
|
89
|
-
|
|
89
|
+
|
|
90
90
|
// Prefix sum to compute start positions
|
|
91
91
|
// col_start[0] = 0
|
|
92
92
|
// col_start[j+1] = col_start[j] + degree[j]
|
|
@@ -96,7 +96,7 @@ extern "C" {
|
|
|
96
96
|
|
|
97
97
|
// Total edges = 3 * N implies flat_rows size
|
|
98
98
|
std::vector<int> flat_rows(n * 3);
|
|
99
|
-
|
|
99
|
+
|
|
100
100
|
// Temporary copy of start indices to use as fill pointers
|
|
101
101
|
std::vector<int> fill_ptr = col_start;
|
|
102
102
|
|
|
@@ -106,7 +106,7 @@ extern "C" {
|
|
|
106
106
|
flat_rows[fill_ptr[r.h2]++] = i;
|
|
107
107
|
flat_rows[fill_ptr[r.h3]++] = i;
|
|
108
108
|
}
|
|
109
|
-
|
|
109
|
+
|
|
110
110
|
// 4. Initialize Peeling
|
|
111
111
|
std::vector<int> peel_stack;
|
|
112
112
|
peel_stack.reserve(m);
|
|
@@ -135,7 +135,7 @@ extern "C" {
|
|
|
135
135
|
int owner_row = -1;
|
|
136
136
|
int start = col_start[j];
|
|
137
137
|
int end = col_start[j+1];
|
|
138
|
-
|
|
138
|
+
|
|
139
139
|
for(int k=start; k<end; ++k) {
|
|
140
140
|
int r_idx = flat_rows[k];
|
|
141
141
|
if(!row_removed[r_idx]) {
|
|
@@ -28,7 +28,7 @@
|
|
|
28
28
|
extern "C" {
|
|
29
29
|
|
|
30
30
|
// Number of Bins for Mega-Binning strategy.
|
|
31
|
-
// 1024 bins implies ~1000 items per bin for N=1M, fitting the working set
|
|
31
|
+
// 1024 bins implies ~1000 items per bin for N=1M, fitting the working set
|
|
32
32
|
// entirely in L1 cache (32KB/48KB) for maximum performance.
|
|
33
33
|
static const uint64_t NUM_BINS = 1024;
|
|
34
34
|
|
|
@@ -36,6 +36,10 @@ extern "C" {
|
|
|
36
36
|
uint64_t h1, h2, h3;
|
|
37
37
|
};
|
|
38
38
|
|
|
39
|
+
// Declaration of the safe (robust) solver implemented in okvs.cpp
|
|
40
|
+
// Signature: solve_okvs(keys, values, output, n, m, seed_ptr)
|
|
41
|
+
void solve_okvs(uint64_t* keys, uint64_t* values, uint64_t* output, uint64_t n, uint64_t m, uint64_t* seed_ptr);
|
|
42
|
+
|
|
39
43
|
// Stateless Bin Selection
|
|
40
44
|
// Maps a key to a deterministic bin index [0, NUM_BINS).
|
|
41
45
|
inline uint64_t get_bin_index(uint64_t key, __m128i seed) {
|
|
@@ -50,14 +54,14 @@ extern "C" {
|
|
|
50
54
|
inline Indices get_bin_local_indices(uint64_t key, uint64_t m_local, __m128i seed) {
|
|
51
55
|
// Use a distinct seed mix to decorrelate from bin selection
|
|
52
56
|
__m128i k = _mm_set_epi64x(0, key);
|
|
53
|
-
__m128i s2 = _mm_add_epi64(seed, _mm_set_epi64x(1, 1));
|
|
57
|
+
__m128i s2 = _mm_add_epi64(seed, _mm_set_epi64x(1, 1));
|
|
54
58
|
__m128i h = _mm_aesenc_si128(k, s2);
|
|
55
59
|
h = _mm_aesenc_si128(h, s2);
|
|
56
60
|
h = _mm_aesenc_si128(h, s2);
|
|
57
61
|
|
|
58
62
|
uint64_t r = _mm_extract_epi64(h, 0);
|
|
59
63
|
Indices idx;
|
|
60
|
-
|
|
64
|
+
|
|
61
65
|
// Fast modulo for local indices
|
|
62
66
|
idx.h1 = r % m_local;
|
|
63
67
|
r = r * 6364136223846793005ULL + 1442695040888963407ULL; // LCG step
|
|
@@ -76,10 +80,10 @@ extern "C" {
|
|
|
76
80
|
|
|
77
81
|
// Core Peeling Solver for a single Bin
|
|
78
82
|
bool solve_bin(
|
|
79
|
-
const std::vector<uint64_t>& keys,
|
|
80
|
-
const std::vector<__m128i>& vals,
|
|
81
|
-
__m128i* P_local,
|
|
82
|
-
uint64_t m,
|
|
83
|
+
const std::vector<uint64_t>& keys,
|
|
84
|
+
const std::vector<__m128i>& vals,
|
|
85
|
+
__m128i* P_local,
|
|
86
|
+
uint64_t m,
|
|
83
87
|
__m128i seed
|
|
84
88
|
) {
|
|
85
89
|
uint64_t n = keys.size();
|
|
@@ -91,7 +95,7 @@ extern "C" {
|
|
|
91
95
|
};
|
|
92
96
|
std::vector<Edge> edges(n);
|
|
93
97
|
std::vector<int> col_degree(m, 0);
|
|
94
|
-
|
|
98
|
+
|
|
95
99
|
// 1. Build Local Graph
|
|
96
100
|
for(uint64_t i=0; i<n; ++i) {
|
|
97
101
|
Indices idx = get_bin_local_indices(keys[i], m, seed);
|
|
@@ -123,14 +127,14 @@ extern "C" {
|
|
|
123
127
|
|
|
124
128
|
std::vector<bool> row_removed(n, false);
|
|
125
129
|
std::vector<bool> col_removed(m, false);
|
|
126
|
-
|
|
130
|
+
|
|
127
131
|
struct Assignment {
|
|
128
132
|
int col;
|
|
129
133
|
int row_idx;
|
|
130
134
|
};
|
|
131
135
|
std::vector<Assignment> assignment_stack;
|
|
132
136
|
assignment_stack.reserve(n);
|
|
133
|
-
|
|
137
|
+
|
|
134
138
|
int head = 0;
|
|
135
139
|
while(head < peel_stack.size()) {
|
|
136
140
|
int j = peel_stack[head++];
|
|
@@ -169,15 +173,15 @@ extern "C" {
|
|
|
169
173
|
for(int i=(int)assignment_stack.size()-1; i>=0; --i) {
|
|
170
174
|
auto a = assignment_stack[i];
|
|
171
175
|
const auto& e = edges[a.row_idx];
|
|
172
|
-
|
|
176
|
+
|
|
173
177
|
__m128i val1 = _mm_loadu_si128(&P_local[e.h1]);
|
|
174
178
|
__m128i val2 = _mm_loadu_si128(&P_local[e.h2]);
|
|
175
179
|
__m128i val3 = _mm_loadu_si128(&P_local[e.h3]);
|
|
176
180
|
__m128i target = vals[e.key_idx];
|
|
177
|
-
|
|
181
|
+
|
|
178
182
|
__m128i current = _mm_xor_si128(_mm_xor_si128(val1, val2), val3);
|
|
179
183
|
__m128i diff = _mm_xor_si128(target, current);
|
|
180
|
-
|
|
184
|
+
|
|
181
185
|
_mm_storeu_si128(&P_local[a.col], diff);
|
|
182
186
|
}
|
|
183
187
|
return true;
|
|
@@ -185,15 +189,15 @@ extern "C" {
|
|
|
185
189
|
|
|
186
190
|
void solve_okvs_opt(uint64_t* keys, uint64_t* values, uint64_t* output, uint64_t n, uint64_t m, uint64_t* seed_ptr) {
|
|
187
191
|
__m128i seed = _mm_loadu_si128((__m128i*)seed_ptr);
|
|
188
|
-
|
|
192
|
+
|
|
189
193
|
// 1. Calculate Bin Boundaries
|
|
190
194
|
// We divide M evenly among bins. The remainder is distributed to the first few bins.
|
|
191
195
|
std::vector<uint64_t> bin_offsets(NUM_BINS + 1);
|
|
192
196
|
std::vector<uint64_t> m_per_bin(NUM_BINS);
|
|
193
|
-
|
|
197
|
+
|
|
194
198
|
uint64_t base_m = m / NUM_BINS;
|
|
195
199
|
uint64_t remainder = m % NUM_BINS;
|
|
196
|
-
|
|
200
|
+
|
|
197
201
|
uint64_t current_offset = 0;
|
|
198
202
|
for(uint64_t b=0; b<NUM_BINS; ++b) {
|
|
199
203
|
bin_offsets[b] = current_offset;
|
|
@@ -204,18 +208,18 @@ extern "C" {
|
|
|
204
208
|
|
|
205
209
|
// 2. Partition Data (Stateless)
|
|
206
210
|
// Note on "Two-Choice Hashing":
|
|
207
|
-
// While Two-Choice Hashing (selecting the lighter of 2 potential bins) would significantly
|
|
211
|
+
// While Two-Choice Hashing (selecting the lighter of 2 potential bins) would significantly
|
|
208
212
|
// reduce max bin load variance, it introduces "Statefulness".
|
|
209
213
|
// The bin assignment for Key K would depend on the load of bins, which depends on other keys.
|
|
210
|
-
// In standard PSI protocols (like RR22), the Decode step must be capable of processing keys
|
|
214
|
+
// In standard PSI protocols (like RR22), the Decode step must be capable of processing keys
|
|
211
215
|
// independently or without knowledge of the full set distribution (Sender/Receiver separation).
|
|
212
216
|
// Therefore, we use **Simple Binning** (Stateless Hash) where Bin(K) = H(K) % Bins.
|
|
213
|
-
// We mitigate the resulting variance ("Balls-in-Bins" problem) by using a slightly larger
|
|
217
|
+
// We mitigate the resulting variance ("Balls-in-Bins" problem) by using a slightly larger
|
|
214
218
|
// expansion factor (epsilon ~ 1.35) which is bandwidth-acceptable and ensures stability.
|
|
215
|
-
|
|
219
|
+
|
|
216
220
|
std::vector<std::vector<uint64_t>> bin_keys(NUM_BINS);
|
|
217
221
|
std::vector<std::vector<__m128i>> bin_vals(NUM_BINS);
|
|
218
|
-
|
|
222
|
+
|
|
219
223
|
// Pre-allocate to reduce reallocation overhead (assume ~uniform distribution)
|
|
220
224
|
// 1.5x margin for pre-allocation safety
|
|
221
225
|
size_t est_size = (n / NUM_BINS) * 3 / 2;
|
|
@@ -223,14 +227,14 @@ extern "C" {
|
|
|
223
227
|
bin_keys[b].reserve(est_size);
|
|
224
228
|
bin_vals[b].reserve(est_size);
|
|
225
229
|
}
|
|
226
|
-
|
|
230
|
+
|
|
227
231
|
const __m128i* V_ptr = (const __m128i*)values;
|
|
228
232
|
for(uint64_t i=0; i<n; ++i) {
|
|
229
233
|
uint64_t b = get_bin_index(keys[i], seed);
|
|
230
234
|
bin_keys[b].push_back(keys[i]);
|
|
231
235
|
bin_vals[b].push_back(_mm_loadu_si128(&V_ptr[i]));
|
|
232
236
|
}
|
|
233
|
-
|
|
237
|
+
|
|
234
238
|
// 3. Parallel Solve
|
|
235
239
|
// Each bin is solved independently. This logic is perfectly parallelizable (embarrassingly parallel).
|
|
236
240
|
// The working set for each bin (~1000 items) stays hot in L1 Cache.
|
|
@@ -240,15 +244,26 @@ extern "C" {
|
|
|
240
244
|
#pragma omp parallel for schedule(dynamic)
|
|
241
245
|
for(uint64_t b=0; b<NUM_BINS; ++b) {
|
|
242
246
|
if(bin_keys[b].empty()) continue;
|
|
243
|
-
|
|
247
|
+
|
|
244
248
|
uint64_t offset = bin_offsets[b];
|
|
245
249
|
uint64_t valid_m = m_per_bin[b];
|
|
246
|
-
|
|
250
|
+
|
|
247
251
|
if(!solve_bin(bin_keys[b], bin_vals[b], &P_vec[offset], valid_m, seed)) {
|
|
252
|
+
// On failure, log and fall back to the robust solver for this bin.
|
|
253
|
+
// The fallback is executed inside a critical section to avoid nested OpenMP
|
|
254
|
+
// regions and to serialize rare fallbacks.
|
|
248
255
|
#pragma omp critical
|
|
249
256
|
{
|
|
250
|
-
fprintf(stderr, "[
|
|
257
|
+
fprintf(stderr, "[WARN] Bin %lu failed optimized peeling; falling back to safe solver. Items: %lu / M: %lu (Ratio: %.2f)\n",
|
|
251
258
|
b, bin_keys[b].size(), valid_m, (double)valid_m / bin_keys[b].size());
|
|
259
|
+
|
|
260
|
+
// Prepare pointers for the safe solver
|
|
261
|
+
uint64_t* keys_ptr = &bin_keys[b][0];
|
|
262
|
+
uint64_t* vals_ptr = (uint64_t*)&bin_vals[b][0]; // Cast __m128i* to uint64_t*
|
|
263
|
+
uint64_t* out_ptr = output + (offset * 2ULL); // each 128-bit slot == 2 uint64_t
|
|
264
|
+
|
|
265
|
+
// Call the safe solver implemented in okvs.cpp
|
|
266
|
+
solve_okvs(keys_ptr, vals_ptr, out_ptr, bin_keys[b].size(), valid_m, seed_ptr);
|
|
252
267
|
}
|
|
253
268
|
}
|
|
254
269
|
}
|
|
@@ -258,10 +273,10 @@ extern "C" {
|
|
|
258
273
|
__m128i seed = _mm_loadu_si128((__m128i*)seed_ptr);
|
|
259
274
|
__m128i* P_vec = (__m128i*)storage;
|
|
260
275
|
__m128i* out_vec = (__m128i*)output;
|
|
261
|
-
|
|
276
|
+
|
|
262
277
|
// Replicate Boundary Logic
|
|
263
278
|
std::vector<uint64_t> bin_offsets(NUM_BINS + 1);
|
|
264
|
-
std::vector<uint64_t> m_per_bin(NUM_BINS);
|
|
279
|
+
std::vector<uint64_t> m_per_bin(NUM_BINS);
|
|
265
280
|
uint64_t base_m = m / NUM_BINS;
|
|
266
281
|
uint64_t remainder = m % NUM_BINS;
|
|
267
282
|
uint64_t current_offset = 0;
|
|
@@ -275,12 +290,12 @@ extern "C" {
|
|
|
275
290
|
#pragma omp parallel for schedule(static)
|
|
276
291
|
for(uint64_t i=0; i<n; ++i) {
|
|
277
292
|
uint64_t b = get_bin_index(keys[i], seed);
|
|
278
|
-
|
|
293
|
+
|
|
279
294
|
uint64_t m_local = m_per_bin[b];
|
|
280
295
|
uint64_t offset = bin_offsets[b];
|
|
281
|
-
|
|
296
|
+
|
|
282
297
|
Indices idx = get_bin_local_indices(keys[i], m_local, seed);
|
|
283
|
-
|
|
298
|
+
|
|
284
299
|
__m128i val = _mm_xor_si128(
|
|
285
300
|
_mm_xor_si128(_mm_loadu_si128(&P_vec[offset + idx.h1]), _mm_loadu_si128(&P_vec[offset + idx.h2])),
|
|
286
301
|
_mm_loadu_si128(&P_vec[offset + idx.h3])
|
|
@@ -30,8 +30,8 @@ Naming Convention:
|
|
|
30
30
|
- collect: gather (N parties → 1 party, stacked)
|
|
31
31
|
|
|
32
32
|
Example:
|
|
33
|
-
>>> from mplang.
|
|
34
|
-
>>> from mplang.
|
|
33
|
+
>>> from mplang.libs.collective import transfer, replicate, distribute, collect
|
|
34
|
+
>>> from mplang.dialects.simp import constant, converge
|
|
35
35
|
>>>
|
|
36
36
|
>>> # Create data on party 0
|
|
37
37
|
>>> x = constant((0,), 42)
|
|
@@ -47,9 +47,9 @@ from __future__ import annotations
|
|
|
47
47
|
|
|
48
48
|
from typing import TYPE_CHECKING
|
|
49
49
|
|
|
50
|
-
from mplang.
|
|
51
|
-
from mplang.
|
|
52
|
-
from mplang.
|
|
50
|
+
from mplang.dialects.simp import converge, shuffle_static
|
|
51
|
+
from mplang.edsl import Object
|
|
52
|
+
from mplang.edsl.typing import MPType
|
|
53
53
|
|
|
54
54
|
if TYPE_CHECKING:
|
|
55
55
|
pass
|