mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -45
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +5 -7
- mplang/v1/core/__init__.py +157 -0
- mplang/{core → v1/core}/cluster.py +30 -14
- mplang/{core → v1/core}/comm.py +5 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +13 -14
- mplang/{core → v1/core}/expr/evaluator.py +65 -24
- mplang/{core → v1/core}/expr/printer.py +24 -18
- mplang/{core → v1/core}/expr/transformer.py +3 -3
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +23 -16
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +4 -4
- mplang/{core → v1/core}/primitive.py +106 -201
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{api.py → v1/host.py} +38 -6
- mplang/v1/kernels/__init__.py +41 -0
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/v1/kernels/basic.py +240 -0
- mplang/{kernels → v1/kernels}/context.py +42 -27
- mplang/{kernels → v1/kernels}/crypto.py +44 -37
- mplang/v1/kernels/fhe.py +858 -0
- mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
- mplang/{kernels → v1/kernels}/phe.py +263 -57
- mplang/{kernels → v1/kernels}/spu.py +137 -48
- mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
- mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
- mplang/v1/kernels/value.py +626 -0
- mplang/{ops → v1/ops}/__init__.py +5 -16
- mplang/{ops → v1/ops}/base.py +2 -5
- mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
- mplang/v1/ops/crypto.py +262 -0
- mplang/v1/ops/fhe.py +272 -0
- mplang/{ops → v1/ops}/jax_cc.py +33 -68
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -4
- mplang/{ops → v1/ops}/spu.py +3 -5
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +9 -24
- mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
- mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
- mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +35 -20
- mplang/{runtime → v1/runtime}/client.py +19 -8
- mplang/{runtime → v1/runtime}/communicator.py +59 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +30 -12
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +58 -42
- mplang/{runtime → v1/runtime}/session.py +57 -71
- mplang/{runtime → v1/runtime}/simulation.py +55 -28
- mplang/v1/simp/api.py +353 -0
- mplang/{simp → v1/simp}/mpi.py +8 -9
- mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
- mplang/{simp → v1/simp}/random.py +21 -22
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/core/__init__.py +0 -92
- mplang/device.py +0 -340
- mplang/kernels/builtin.py +0 -207
- mplang/ops/crypto.py +0 -109
- mplang/ops/ibis_cc.py +0 -139
- mplang/ops/sql.py +0 -61
- mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
- mplang/runtime/link_comm.py +0 -131
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -73
- mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{kernels → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
mplang/{ops → v1/ops}/jax_cc.py
RENAMED
|
@@ -14,18 +14,18 @@
|
|
|
14
14
|
|
|
15
15
|
from __future__ import annotations
|
|
16
16
|
|
|
17
|
+
import logging
|
|
17
18
|
from collections.abc import Callable
|
|
18
19
|
from typing import Any
|
|
19
20
|
|
|
20
21
|
import jax
|
|
21
22
|
import jax.numpy as jnp
|
|
23
|
+
from jax import export
|
|
22
24
|
from jax.tree_util import PyTreeDef, tree_flatten
|
|
23
25
|
|
|
24
|
-
from mplang.core
|
|
25
|
-
from mplang.
|
|
26
|
-
from mplang.
|
|
27
|
-
from mplang.ops.base import FeOperation, stateless_mod
|
|
28
|
-
from mplang.utils.func_utils import normalize_fn
|
|
26
|
+
from mplang.v1.core import MPObject, PFunction, TensorType, get_fn_name
|
|
27
|
+
from mplang.v1.ops.base import FeOperation, stateless_mod
|
|
28
|
+
from mplang.v1.utils.func_utils import normalize_fn
|
|
29
29
|
|
|
30
30
|
# Enable 64-bit precision for JAX to match tensor types
|
|
31
31
|
jax.config.update("jax_enable_x64", True)
|
|
@@ -38,7 +38,8 @@ def jax2stablehlo(
|
|
|
38
38
|
|
|
39
39
|
Translates high-level JAX functions into StableHLO MLIR representations,
|
|
40
40
|
enabling execution on JAX backends across different processes and platforms.
|
|
41
|
-
Uses
|
|
41
|
+
Uses a hybrid approach: traditional JAX trace/lower for compilation compatibility,
|
|
42
|
+
with stable jax.export API for parameter tracking.
|
|
42
43
|
|
|
43
44
|
Args:
|
|
44
45
|
is_variable: Predicate function to classify parameters as variables vs. constants.
|
|
@@ -54,34 +55,6 @@ def jax2stablehlo(
|
|
|
54
55
|
Non-variable parameters are captured as compile-time constants within
|
|
55
56
|
the PFunction body, while variables become runtime input parameters.
|
|
56
57
|
- PyTreeDef: Tree structure template for reconstructing nested output values
|
|
57
|
-
|
|
58
|
-
Rationale:
|
|
59
|
-
JAX Serialization Options Analysis:
|
|
60
|
-
1. jax.export (JAX ≥0.4.35) - Official export API with StableHLO backend
|
|
61
|
-
2. HLO protobuf - Raw XLA HloModule serialization
|
|
62
|
-
3. HLO text - Human-readable HLO representation
|
|
63
|
-
4. StableHLO MLIR - Portable intermediate representation
|
|
64
|
-
5. JAX compiled object pickling - Limited to same-process execution
|
|
65
|
-
|
|
66
|
-
Current Choice: StableHLO MLIR
|
|
67
|
-
Advantages:
|
|
68
|
-
- ✅ Available in current JAX version (0.4.34)
|
|
69
|
-
- ✅ Cross-version compatibility guaranteed by StableHLO design
|
|
70
|
-
- ✅ Direct compilation support via XLA client.compile(mlir_string)
|
|
71
|
-
- ✅ Handles complex functions (multi-input/output, control flow)
|
|
72
|
-
- ✅ Preserves numerical precision
|
|
73
|
-
- ✅ Platform-independent representation
|
|
74
|
-
|
|
75
|
-
Alternative Options Issues:
|
|
76
|
-
- jax.export: Not available in JAX 0.4.34
|
|
77
|
-
- HLO protobuf: Version compatibility issues with StableHLO parser
|
|
78
|
-
- HLO text: Parser compatibility issues with XLA client
|
|
79
|
-
- Pickle: Cannot serialize XLA LoadedExecutable objects
|
|
80
|
-
|
|
81
|
-
Future Migration Path:
|
|
82
|
-
- JAX ≥0.4.35: Migrate to jax.export.export() + jax.export.deserialize()
|
|
83
|
-
- JAX ≥0.5.x: Consider new portable formats if available
|
|
84
|
-
- Long-term: Adopt official JAX serialization standards as they mature
|
|
85
58
|
"""
|
|
86
59
|
# Flatten (args, kwargs) and capture immediates using the moved logic from primitive.py
|
|
87
60
|
normalized_fn, in_vars = normalize_fn(flat_fn, args, kwargs, is_variable)
|
|
@@ -91,47 +64,39 @@ def jax2stablehlo(
|
|
|
91
64
|
jax.ShapeDtypeStruct(arg.shape, jnp.dtype(arg.dtype.name)) for arg in in_vars
|
|
92
65
|
]
|
|
93
66
|
|
|
94
|
-
#
|
|
67
|
+
# Hybrid approach: Use standard JAX trace/lower for compatibility, but jax.export for parameter tracking
|
|
95
68
|
jitted_fn = jax.jit(normalized_fn)
|
|
96
69
|
traced = jitted_fn.trace(jax_params)
|
|
97
70
|
lowered = traced.lower()
|
|
98
71
|
|
|
99
|
-
# Get StableHLO MLIR representation
|
|
100
|
-
# compiler_ir("stablehlo") returns jaxlib.mlir.ir.Module object
|
|
101
|
-
# str() converts to serializable text format
|
|
72
|
+
# Get StableHLO MLIR representation using traditional approach
|
|
102
73
|
stablehlo_mlir = lowered.compiler_ir("stablehlo")
|
|
103
74
|
mlir_text = str(stablehlo_mlir)
|
|
104
75
|
|
|
105
|
-
# Get output info
|
|
76
|
+
# Get output info using traditional approach
|
|
106
77
|
out_info_flat, out_tree = tree_flatten(lowered.out_info)
|
|
107
78
|
out_info_flat = [TensorType.from_obj(info) for info in out_info_flat]
|
|
108
79
|
|
|
109
|
-
# Extract argument keep mapping
|
|
110
|
-
#
|
|
111
|
-
# receives all original arguments. We need the mapping to filter them correctly.
|
|
80
|
+
# Extract argument keep mapping using stable jax.export API for parameter tracking
|
|
81
|
+
# We use jax.export only for getting the kept_var_idx information, not for the main compilation
|
|
112
82
|
arg_keep_map = None
|
|
113
83
|
original_arg_count = len(in_vars)
|
|
114
84
|
|
|
115
85
|
try:
|
|
116
|
-
#
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
kept_var_idx =
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
arg_keep_map =
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
#
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
f"Cannot access JAX's kept_var_idx to handle unused parameter elimination. "
|
|
131
|
-
f"This function may have unused parameters that JAX optimized away, "
|
|
132
|
-
f"but we cannot determine which ones without the internal API. "
|
|
133
|
-
f"Original error: {e}"
|
|
134
|
-
) from e
|
|
86
|
+
# Use jax.export just to get the stable parameter tracking information
|
|
87
|
+
export_fn = export.export(jitted_fn)
|
|
88
|
+
exported = export_fn(jax_params)
|
|
89
|
+
kept_var_idx = exported.module_kept_var_idx
|
|
90
|
+
if kept_var_idx is not None and len(kept_var_idx) < original_arg_count:
|
|
91
|
+
# JAX eliminated some unused parameters during compilation
|
|
92
|
+
# Keep the indices in sorted order for consistent mapping
|
|
93
|
+
arg_keep_map = sorted(kept_var_idx)
|
|
94
|
+
except Exception as e:
|
|
95
|
+
# Fallback: if jax.export fails, we can still use the compiled result without parameter tracking
|
|
96
|
+
# This ensures backward compatibility even if export has issues
|
|
97
|
+
logging.warning(
|
|
98
|
+
f"jax.export failed to get kept_var_idx, proceeding without it. Error: {e}"
|
|
99
|
+
)
|
|
135
100
|
|
|
136
101
|
# This format tells JaxRT how to handle the compiled result
|
|
137
102
|
pfn_kwargs: dict[str, Any] = {
|
|
@@ -149,11 +114,11 @@ def jax2stablehlo(
|
|
|
149
114
|
return pfn, in_vars, out_tree
|
|
150
115
|
|
|
151
116
|
|
|
152
|
-
class
|
|
153
|
-
"""JAX
|
|
117
|
+
class JaxRunner(FeOperation):
|
|
118
|
+
"""JAX function runner frontend operation."""
|
|
154
119
|
|
|
155
120
|
def trace(
|
|
156
|
-
self,
|
|
121
|
+
self, jax_fn: Callable, *args: Any, **kwargs: Any
|
|
157
122
|
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
|
|
158
123
|
"""
|
|
159
124
|
JAX compilation helper function.
|
|
@@ -162,21 +127,21 @@ class JaxCompiler(FeOperation):
|
|
|
162
127
|
along with variable arguments for evaluation.
|
|
163
128
|
|
|
164
129
|
Args:
|
|
165
|
-
|
|
130
|
+
jax_fn: The JAX function to compile
|
|
166
131
|
*args: Positional arguments to the function
|
|
167
132
|
**kwargs: Keyword arguments to the function
|
|
168
133
|
|
|
169
134
|
Returns:
|
|
170
|
-
tuple[PFunction, list[MPObject],
|
|
135
|
+
tuple[PFunction, list[MPObject], PyTreeDef]: The compiled PFunction, input variables, and output tree
|
|
171
136
|
"""
|
|
172
137
|
|
|
173
138
|
def is_variable(arg: Any) -> bool:
|
|
174
139
|
return isinstance(arg, MPObject)
|
|
175
140
|
|
|
176
|
-
pfunc, in_vars, out_tree = jax2stablehlo(is_variable,
|
|
141
|
+
pfunc, in_vars, out_tree = jax2stablehlo(is_variable, jax_fn, *args, **kwargs)
|
|
177
142
|
return pfunc, in_vars, out_tree
|
|
178
143
|
|
|
179
144
|
|
|
180
145
|
_JAX_MOD = stateless_mod("jax")
|
|
181
146
|
|
|
182
|
-
|
|
147
|
+
run_jax = JaxRunner(_JAX_MOD, "run")
|
mplang/v1/ops/nnx_cc.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
from collections.abc import Callable
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
import jax
|
|
22
|
+
import jax.numpy as jnp
|
|
23
|
+
from flax import nnx
|
|
24
|
+
from jax import export
|
|
25
|
+
from jax.tree_util import PyTreeDef, tree_flatten
|
|
26
|
+
|
|
27
|
+
from mplang.v1.core import MPObject, PFunction, TensorType, get_fn_name
|
|
28
|
+
from mplang.v1.ops.base import FeOperation, stateless_mod
|
|
29
|
+
from mplang.v1.utils.func_utils import normalize_fn
|
|
30
|
+
|
|
31
|
+
# Enable 64-bit precision for JAX to match tensor types
|
|
32
|
+
jax.config.update("jax_enable_x64", True)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def nnx2stablehlo(
|
|
36
|
+
is_variable: Callable[[Any], bool], flat_fn: Any, *args: Any, **kwargs: Any
|
|
37
|
+
) -> tuple[PFunction, list[Any], PyTreeDef]:
|
|
38
|
+
"""Compile NNX function to StableHLO MLIR format for remote execution.
|
|
39
|
+
|
|
40
|
+
Translates high-level NNX functions into StableHLO MLIR representations,
|
|
41
|
+
enabling execution on JAX backends across different processes and platforms.
|
|
42
|
+
Uses a hybrid approach: traditional NNX trace/lower for compilation compatibility,
|
|
43
|
+
with stable jax.export API for parameter tracking.
|
|
44
|
+
|
|
45
|
+
Args:
|
|
46
|
+
is_variable: Predicate function to classify parameters as variables vs. constants.
|
|
47
|
+
Returns True for parameters that should be treated as PFunction inputs.
|
|
48
|
+
flat_fn: NNX function to be compiled into StableHLO format
|
|
49
|
+
*args: Positional arguments passed to the function during compilation
|
|
50
|
+
**kwargs: Keyword arguments passed to the function during compilation
|
|
51
|
+
|
|
52
|
+
Returns:
|
|
53
|
+
tuple[PFunction, list, PyTreeDef]: Compilation artifacts containing:
|
|
54
|
+
- PFunction: Serialized function with embedded MLIR text and type metadata
|
|
55
|
+
- list: Extracted variable parameters (those satisfying is_variable predicate).
|
|
56
|
+
Non-variable parameters are captured as compile-time constants within
|
|
57
|
+
the PFunction body, while variables become runtime input parameters.
|
|
58
|
+
- PyTreeDef: Tree structure template for reconstructing nested output values
|
|
59
|
+
"""
|
|
60
|
+
# Flatten (args, kwargs) and capture immediates using the moved logic from primitive.py
|
|
61
|
+
normalized_fn, in_vars = normalize_fn(flat_fn, args, kwargs, is_variable)
|
|
62
|
+
|
|
63
|
+
# Convert TensorType in_vars to ShapeDtypeStruct for JAX tracing
|
|
64
|
+
jax_params = [
|
|
65
|
+
jax.ShapeDtypeStruct(arg.shape, jnp.dtype(arg.dtype.name)) for arg in in_vars
|
|
66
|
+
]
|
|
67
|
+
|
|
68
|
+
# NNX compilation pipeline using JAX export API: nnx.jit → jax.export → StableHLO MLIR
|
|
69
|
+
# Use nnx.jit for NNX-specific functionality, then jax.export for stable parameter handling
|
|
70
|
+
nnx_jitted = nnx.jit(normalized_fn)
|
|
71
|
+
|
|
72
|
+
# Extract the underlying JAX function for jax.export compatibility
|
|
73
|
+
# nnx.jit wraps a JAX function, and we can access it via .fun attribute
|
|
74
|
+
underlying_jax_fn = nnx_jitted.fun
|
|
75
|
+
|
|
76
|
+
# Hybrid approach: Use NNX trace/lower for compilation, but jax.export for parameter tracking
|
|
77
|
+
# Use traditional nnx.jit → trace → lower for compatibility with argument structure
|
|
78
|
+
nnx_traced = nnx_jitted.trace(jax_params)
|
|
79
|
+
nnx_lowered = nnx_traced.lower()
|
|
80
|
+
|
|
81
|
+
# Get StableHLO MLIR representation using traditional NNX approach
|
|
82
|
+
# NNX lowered object wraps JAX lowered, so we access the inner JAX lowered object
|
|
83
|
+
jax_lowered = nnx_lowered.lowered
|
|
84
|
+
stablehlo_mlir = jax_lowered.compiler_ir("stablehlo")
|
|
85
|
+
mlir_text = str(stablehlo_mlir)
|
|
86
|
+
|
|
87
|
+
# Get output info using traditional NNX approach
|
|
88
|
+
# NNX captures output in (args, kwargs, result) format, so we need to extract just the result part
|
|
89
|
+
raw_out_info = jax_lowered.out_info
|
|
90
|
+
if isinstance(raw_out_info, tuple) and len(raw_out_info) == 3:
|
|
91
|
+
# NNX format: (args, kwargs, result) - extract just the result
|
|
92
|
+
_, _, actual_out_info = raw_out_info
|
|
93
|
+
out_info_flat, out_tree = tree_flatten(actual_out_info)
|
|
94
|
+
else:
|
|
95
|
+
# Fallback to direct format (shouldn't happen with NNX, but just in case)
|
|
96
|
+
out_info_flat, out_tree = tree_flatten(raw_out_info)
|
|
97
|
+
|
|
98
|
+
out_info_flat = [TensorType.from_obj(info) for info in out_info_flat]
|
|
99
|
+
|
|
100
|
+
# Extract argument keep mapping using stable jax.export API for parameter tracking
|
|
101
|
+
# We use the underlying JAX function with jax.export only for parameter tracking
|
|
102
|
+
arg_keep_map = None
|
|
103
|
+
original_arg_count = len(in_vars)
|
|
104
|
+
|
|
105
|
+
try:
|
|
106
|
+
# Use jax.export with the underlying JAX function just to get stable parameter tracking
|
|
107
|
+
export_fn = export.export(jax.jit(underlying_jax_fn))
|
|
108
|
+
exported = export_fn(jax_params)
|
|
109
|
+
kept_var_idx = exported.module_kept_var_idx
|
|
110
|
+
if kept_var_idx is not None and len(kept_var_idx) < original_arg_count:
|
|
111
|
+
# JAX eliminated some unused parameters during compilation
|
|
112
|
+
# Keep the indices in sorted order for consistent mapping
|
|
113
|
+
arg_keep_map = sorted(kept_var_idx)
|
|
114
|
+
except Exception as e:
|
|
115
|
+
# Fallback: if jax.export fails, we can still use the compiled result without parameter tracking
|
|
116
|
+
# This ensures backward compatibility even if export has issues
|
|
117
|
+
logging.warning(
|
|
118
|
+
f"jax.export failed to get kept_var_idx, proceeding without it. Error: {e}"
|
|
119
|
+
)
|
|
120
|
+
|
|
121
|
+
# This format tells JaxRT how to handle the compiled result
|
|
122
|
+
# Use the same format as JAX since NNX compiles to the same backend
|
|
123
|
+
pfn_kwargs: dict[str, Any] = {
|
|
124
|
+
"fn_type": "mlir.stablehlo", # Key: specify StableHLO MLIR format
|
|
125
|
+
"ins_info": tuple(TensorType.from_obj(x) for x in in_vars),
|
|
126
|
+
"outs_info": tuple(out_info_flat),
|
|
127
|
+
"fn_name": get_fn_name(flat_fn),
|
|
128
|
+
"fn_text": mlir_text, # MLIR text, serializable for transmission
|
|
129
|
+
}
|
|
130
|
+
|
|
131
|
+
if arg_keep_map is not None:
|
|
132
|
+
pfn_kwargs["arg_keep_map"] = arg_keep_map
|
|
133
|
+
|
|
134
|
+
pfn = PFunction(**pfn_kwargs)
|
|
135
|
+
return pfn, in_vars, out_tree
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class NnxRunner(FeOperation):
|
|
139
|
+
"""NNX function runner frontend operation."""
|
|
140
|
+
|
|
141
|
+
def trace(
|
|
142
|
+
self, nnx_fn: Callable, *args: Any, **kwargs: Any
|
|
143
|
+
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
|
|
144
|
+
"""
|
|
145
|
+
NNX compilation helper function.
|
|
146
|
+
|
|
147
|
+
Compiles an NNX function to StableHLO format and returns the PFunction
|
|
148
|
+
along with variable arguments for evaluation.
|
|
149
|
+
|
|
150
|
+
Args:
|
|
151
|
+
nnx_fn: The NNX function to compile
|
|
152
|
+
*args: Positional arguments to the function
|
|
153
|
+
**kwargs: Keyword arguments to the function
|
|
154
|
+
|
|
155
|
+
Returns:
|
|
156
|
+
tuple[PFunction, list[MPObject], PyTreeDef]: The compiled PFunction, input variables, and output tree
|
|
157
|
+
"""
|
|
158
|
+
|
|
159
|
+
def is_variable(arg: Any) -> bool:
|
|
160
|
+
return isinstance(arg, MPObject)
|
|
161
|
+
|
|
162
|
+
pfunc, in_vars, out_tree = nnx2stablehlo(is_variable, nnx_fn, *args, **kwargs)
|
|
163
|
+
return pfunc, in_vars, out_tree
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
_NNX_MOD = stateless_mod("nnx")
|
|
167
|
+
|
|
168
|
+
run_nnx = NnxRunner(_NNX_MOD, "run")
|
mplang/{ops → v1/ops}/phe.py
RENAMED
|
@@ -14,22 +14,34 @@
|
|
|
14
14
|
|
|
15
15
|
"""PHE (Partially Homomorphic Encryption) frontend operations."""
|
|
16
16
|
|
|
17
|
-
from mplang.core
|
|
18
|
-
from mplang.
|
|
19
|
-
from mplang.ops.base import stateless_mod
|
|
17
|
+
from mplang.v1.core import UINT8, TensorType
|
|
18
|
+
from mplang.v1.ops.base import stateless_mod
|
|
20
19
|
|
|
21
20
|
_PHE_MOD = stateless_mod("phe")
|
|
22
21
|
|
|
23
22
|
|
|
24
23
|
@_PHE_MOD.simple_op()
|
|
25
24
|
def keygen(
|
|
26
|
-
*,
|
|
25
|
+
*,
|
|
26
|
+
scheme: str = "paillier",
|
|
27
|
+
key_size: int = 2048,
|
|
28
|
+
max_value: int | None = None,
|
|
29
|
+
fxp_bits: int | None = None,
|
|
27
30
|
) -> tuple[TensorType, TensorType]:
|
|
28
31
|
"""Generate a PHE key pair: returns (public_key, private_key).
|
|
29
32
|
|
|
30
33
|
Keys are represented with a sentinel TensorType UINT8[(-1, 0)] to indicate
|
|
31
34
|
non-structural, backend-only handles. Runtime validation will treat this
|
|
32
35
|
shape as an opaque placeholder and skip dtype/shape checks.
|
|
36
|
+
|
|
37
|
+
Attributes (forwarded to backend):
|
|
38
|
+
scheme: PHE scheme (default: 'paillier')
|
|
39
|
+
key_size: Modulus size in bits (default: 2048)
|
|
40
|
+
max_value: Optional range-encoding bound B. If provided, the backend will
|
|
41
|
+
encode/decode integers/floats within [-B, B] and treat (B, N-B) as overflow.
|
|
42
|
+
Pick B to exceed the largest intermediate magnitude you expect in homomorphic
|
|
43
|
+
combinations. If omitted, backend default is used (currently 2**32).
|
|
44
|
+
fxp_bits: Optional fixed-point fractional bits for float encoding (default backend value).
|
|
33
45
|
"""
|
|
34
46
|
key_spec = TensorType(UINT8, (-1, 0))
|
|
35
47
|
return key_spec, key_spec
|
mplang/{ops → v1/ops}/spu.py
RENAMED
|
@@ -23,11 +23,9 @@ import spu.utils.frontend as spu_fe
|
|
|
23
23
|
from jax import ShapeDtypeStruct
|
|
24
24
|
from jax.tree_util import PyTreeDef, tree_flatten
|
|
25
25
|
|
|
26
|
-
from mplang.core
|
|
27
|
-
from mplang.
|
|
28
|
-
from mplang.
|
|
29
|
-
from mplang.ops.base import stateless_mod
|
|
30
|
-
from mplang.utils.func_utils import normalize_fn
|
|
26
|
+
from mplang.v1.core import MPObject, PFunction, TensorType, get_fn_name
|
|
27
|
+
from mplang.v1.ops.base import stateless_mod
|
|
28
|
+
from mplang.v1.utils.func_utils import normalize_fn
|
|
31
29
|
|
|
32
30
|
|
|
33
31
|
class Visibility:
|