mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. mplang/__init__.py +21 -45
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +5 -7
  7. mplang/v1/core/__init__.py +157 -0
  8. mplang/{core → v1/core}/cluster.py +30 -14
  9. mplang/{core → v1/core}/comm.py +5 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +13 -14
  14. mplang/{core → v1/core}/expr/evaluator.py +65 -24
  15. mplang/{core → v1/core}/expr/printer.py +24 -18
  16. mplang/{core → v1/core}/expr/transformer.py +3 -3
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +23 -16
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +4 -4
  25. mplang/{core → v1/core}/primitive.py +106 -201
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{api.py → v1/host.py} +38 -6
  30. mplang/v1/kernels/__init__.py +41 -0
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/v1/kernels/basic.py +240 -0
  33. mplang/{kernels → v1/kernels}/context.py +42 -27
  34. mplang/{kernels → v1/kernels}/crypto.py +44 -37
  35. mplang/v1/kernels/fhe.py +858 -0
  36. mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
  37. mplang/{kernels → v1/kernels}/phe.py +263 -57
  38. mplang/{kernels → v1/kernels}/spu.py +137 -48
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
  40. mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
  41. mplang/v1/kernels/value.py +626 -0
  42. mplang/{ops → v1/ops}/__init__.py +5 -16
  43. mplang/{ops → v1/ops}/base.py +2 -5
  44. mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/v1/ops/fhe.py +272 -0
  47. mplang/{ops → v1/ops}/jax_cc.py +33 -68
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -4
  50. mplang/{ops → v1/ops}/spu.py +3 -5
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +9 -24
  53. mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
  54. mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
  55. mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
  56. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  57. mplang/v1/runtime/channel.py +230 -0
  58. mplang/{runtime → v1/runtime}/cli.py +35 -20
  59. mplang/{runtime → v1/runtime}/client.py +19 -8
  60. mplang/{runtime → v1/runtime}/communicator.py +59 -15
  61. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  62. mplang/{runtime → v1/runtime}/driver.py +30 -12
  63. mplang/v1/runtime/link_comm.py +196 -0
  64. mplang/{runtime → v1/runtime}/server.py +58 -42
  65. mplang/{runtime → v1/runtime}/session.py +57 -71
  66. mplang/{runtime → v1/runtime}/simulation.py +55 -28
  67. mplang/v1/simp/api.py +353 -0
  68. mplang/{simp → v1/simp}/mpi.py +8 -9
  69. mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
  70. mplang/{simp → v1/simp}/random.py +21 -22
  71. mplang/v1/simp/smpc.py +238 -0
  72. mplang/v1/utils/table_utils.py +185 -0
  73. mplang/v2/__init__.py +424 -0
  74. mplang/v2/backends/__init__.py +57 -0
  75. mplang/v2/backends/bfv_impl.py +705 -0
  76. mplang/v2/backends/channel.py +217 -0
  77. mplang/v2/backends/crypto_impl.py +723 -0
  78. mplang/v2/backends/field_impl.py +454 -0
  79. mplang/v2/backends/func_impl.py +107 -0
  80. mplang/v2/backends/phe_impl.py +148 -0
  81. mplang/v2/backends/simp_design.md +136 -0
  82. mplang/v2/backends/simp_driver/__init__.py +41 -0
  83. mplang/v2/backends/simp_driver/http.py +168 -0
  84. mplang/v2/backends/simp_driver/mem.py +280 -0
  85. mplang/v2/backends/simp_driver/ops.py +135 -0
  86. mplang/v2/backends/simp_driver/state.py +60 -0
  87. mplang/v2/backends/simp_driver/values.py +52 -0
  88. mplang/v2/backends/simp_worker/__init__.py +29 -0
  89. mplang/v2/backends/simp_worker/http.py +354 -0
  90. mplang/v2/backends/simp_worker/mem.py +102 -0
  91. mplang/v2/backends/simp_worker/ops.py +167 -0
  92. mplang/v2/backends/simp_worker/state.py +49 -0
  93. mplang/v2/backends/spu_impl.py +275 -0
  94. mplang/v2/backends/spu_state.py +187 -0
  95. mplang/v2/backends/store_impl.py +62 -0
  96. mplang/v2/backends/table_impl.py +838 -0
  97. mplang/v2/backends/tee_impl.py +215 -0
  98. mplang/v2/backends/tensor_impl.py +519 -0
  99. mplang/v2/cli.py +603 -0
  100. mplang/v2/cli_guide.md +122 -0
  101. mplang/v2/dialects/__init__.py +36 -0
  102. mplang/v2/dialects/bfv.py +665 -0
  103. mplang/v2/dialects/crypto.py +689 -0
  104. mplang/v2/dialects/dtypes.py +378 -0
  105. mplang/v2/dialects/field.py +210 -0
  106. mplang/v2/dialects/func.py +135 -0
  107. mplang/v2/dialects/phe.py +723 -0
  108. mplang/v2/dialects/simp.py +944 -0
  109. mplang/v2/dialects/spu.py +349 -0
  110. mplang/v2/dialects/store.py +63 -0
  111. mplang/v2/dialects/table.py +407 -0
  112. mplang/v2/dialects/tee.py +346 -0
  113. mplang/v2/dialects/tensor.py +1175 -0
  114. mplang/v2/edsl/README.md +279 -0
  115. mplang/v2/edsl/__init__.py +99 -0
  116. mplang/v2/edsl/context.py +311 -0
  117. mplang/v2/edsl/graph.py +463 -0
  118. mplang/v2/edsl/jit.py +62 -0
  119. mplang/v2/edsl/object.py +53 -0
  120. mplang/v2/edsl/primitive.py +284 -0
  121. mplang/v2/edsl/printer.py +119 -0
  122. mplang/v2/edsl/registry.py +207 -0
  123. mplang/v2/edsl/serde.py +375 -0
  124. mplang/v2/edsl/tracer.py +614 -0
  125. mplang/v2/edsl/typing.py +816 -0
  126. mplang/v2/kernels/Makefile +30 -0
  127. mplang/v2/kernels/__init__.py +23 -0
  128. mplang/v2/kernels/gf128.cpp +148 -0
  129. mplang/v2/kernels/ldpc.cpp +82 -0
  130. mplang/v2/kernels/okvs.cpp +283 -0
  131. mplang/v2/kernels/okvs_opt.cpp +291 -0
  132. mplang/v2/kernels/py_kernels.py +398 -0
  133. mplang/v2/libs/collective.py +330 -0
  134. mplang/v2/libs/device/__init__.py +51 -0
  135. mplang/v2/libs/device/api.py +813 -0
  136. mplang/v2/libs/device/cluster.py +352 -0
  137. mplang/v2/libs/ml/__init__.py +23 -0
  138. mplang/v2/libs/ml/sgb.py +1861 -0
  139. mplang/v2/libs/mpc/__init__.py +41 -0
  140. mplang/v2/libs/mpc/_utils.py +99 -0
  141. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  142. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  143. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  144. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  145. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  146. mplang/v2/libs/mpc/common/constants.py +39 -0
  147. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  148. mplang/v2/libs/mpc/ot/base.py +222 -0
  149. mplang/v2/libs/mpc/ot/extension.py +477 -0
  150. mplang/v2/libs/mpc/ot/silent.py +217 -0
  151. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  152. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  153. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  154. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  155. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  156. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  157. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  158. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  159. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  160. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  161. mplang/v2/libs/mpc/vole/silver.py +336 -0
  162. mplang/v2/runtime/__init__.py +15 -0
  163. mplang/v2/runtime/dialect_state.py +41 -0
  164. mplang/v2/runtime/interpreter.py +871 -0
  165. mplang/v2/runtime/object_store.py +194 -0
  166. mplang/v2/runtime/value.py +141 -0
  167. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
  168. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  169. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  170. mplang/core/__init__.py +0 -92
  171. mplang/device.py +0 -340
  172. mplang/kernels/builtin.py +0 -207
  173. mplang/ops/crypto.py +0 -109
  174. mplang/ops/ibis_cc.py +0 -139
  175. mplang/ops/sql.py +0 -61
  176. mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
  177. mplang/runtime/link_comm.py +0 -131
  178. mplang/simp/smpc.py +0 -201
  179. mplang/utils/table_utils.py +0 -73
  180. mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
  181. /mplang/{core → v1/core}/mask.py +0 -0
  182. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  183. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  184. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  185. /mplang/{kernels → v1/simp}/__init__.py +0 -0
  186. /mplang/{utils → v1/utils}/__init__.py +0 -0
  187. /mplang/{utils → v1/utils}/crypto.py +0 -0
  188. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  189. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  190. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  191. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -14,18 +14,18 @@
14
14
 
15
15
  from __future__ import annotations
16
16
 
17
+ import logging
17
18
  from collections.abc import Callable
18
19
  from typing import Any
19
20
 
20
21
  import jax
21
22
  import jax.numpy as jnp
23
+ from jax import export
22
24
  from jax.tree_util import PyTreeDef, tree_flatten
23
25
 
24
- from mplang.core.mpobject import MPObject
25
- from mplang.core.pfunc import PFunction, get_fn_name
26
- from mplang.core.tensor import TensorType
27
- from mplang.ops.base import FeOperation, stateless_mod
28
- from mplang.utils.func_utils import normalize_fn
26
+ from mplang.v1.core import MPObject, PFunction, TensorType, get_fn_name
27
+ from mplang.v1.ops.base import FeOperation, stateless_mod
28
+ from mplang.v1.utils.func_utils import normalize_fn
29
29
 
30
30
  # Enable 64-bit precision for JAX to match tensor types
31
31
  jax.config.update("jax_enable_x64", True)
@@ -38,7 +38,8 @@ def jax2stablehlo(
38
38
 
39
39
  Translates high-level JAX functions into StableHLO MLIR representations,
40
40
  enabling execution on JAX backends across different processes and platforms.
41
- Uses the standard JAX compilation pipeline: jit → trace → lower → StableHLO MLIR.
41
+ Uses a hybrid approach: traditional JAX trace/lower for compilation compatibility,
42
+ with stable jax.export API for parameter tracking.
42
43
 
43
44
  Args:
44
45
  is_variable: Predicate function to classify parameters as variables vs. constants.
@@ -54,34 +55,6 @@ def jax2stablehlo(
54
55
  Non-variable parameters are captured as compile-time constants within
55
56
  the PFunction body, while variables become runtime input parameters.
56
57
  - PyTreeDef: Tree structure template for reconstructing nested output values
57
-
58
- Rationale:
59
- JAX Serialization Options Analysis:
60
- 1. jax.export (JAX ≥0.4.35) - Official export API with StableHLO backend
61
- 2. HLO protobuf - Raw XLA HloModule serialization
62
- 3. HLO text - Human-readable HLO representation
63
- 4. StableHLO MLIR - Portable intermediate representation
64
- 5. JAX compiled object pickling - Limited to same-process execution
65
-
66
- Current Choice: StableHLO MLIR
67
- Advantages:
68
- - ✅ Available in current JAX version (0.4.34)
69
- - ✅ Cross-version compatibility guaranteed by StableHLO design
70
- - ✅ Direct compilation support via XLA client.compile(mlir_string)
71
- - ✅ Handles complex functions (multi-input/output, control flow)
72
- - ✅ Preserves numerical precision
73
- - ✅ Platform-independent representation
74
-
75
- Alternative Options Issues:
76
- - jax.export: Not available in JAX 0.4.34
77
- - HLO protobuf: Version compatibility issues with StableHLO parser
78
- - HLO text: Parser compatibility issues with XLA client
79
- - Pickle: Cannot serialize XLA LoadedExecutable objects
80
-
81
- Future Migration Path:
82
- - JAX ≥0.4.35: Migrate to jax.export.export() + jax.export.deserialize()
83
- - JAX ≥0.5.x: Consider new portable formats if available
84
- - Long-term: Adopt official JAX serialization standards as they mature
85
58
  """
86
59
  # Flatten (args, kwargs) and capture immediates using the moved logic from primitive.py
87
60
  normalized_fn, in_vars = normalize_fn(flat_fn, args, kwargs, is_variable)
@@ -91,47 +64,39 @@ def jax2stablehlo(
91
64
  jax.ShapeDtypeStruct(arg.shape, jnp.dtype(arg.dtype.name)) for arg in in_vars
92
65
  ]
93
66
 
94
- # Standard JAX serialization pipeline: jit → trace → lower → StableHLO MLIR
67
+ # Hybrid approach: Use standard JAX trace/lower for compatibility, but jax.export for parameter tracking
95
68
  jitted_fn = jax.jit(normalized_fn)
96
69
  traced = jitted_fn.trace(jax_params)
97
70
  lowered = traced.lower()
98
71
 
99
- # Get StableHLO MLIR representation - the portable format
100
- # compiler_ir("stablehlo") returns jaxlib.mlir.ir.Module object
101
- # str() converts to serializable text format
72
+ # Get StableHLO MLIR representation using traditional approach
102
73
  stablehlo_mlir = lowered.compiler_ir("stablehlo")
103
74
  mlir_text = str(stablehlo_mlir)
104
75
 
105
- # Get output info and tree structure for result reconstruction after remote execution
76
+ # Get output info using traditional approach
106
77
  out_info_flat, out_tree = tree_flatten(lowered.out_info)
107
78
  out_info_flat = [TensorType.from_obj(info) for info in out_info_flat]
108
79
 
109
- # Extract argument keep mapping to handle JAX's unused parameter elimination
110
- # JAX can eliminate unused parameters during compilation, but the runtime still
111
- # receives all original arguments. We need the mapping to filter them correctly.
80
+ # Extract argument keep mapping using stable jax.export API for parameter tracking
81
+ # We use jax.export only for getting the kept_var_idx information, not for the main compilation
112
82
  arg_keep_map = None
113
83
  original_arg_count = len(in_vars)
114
84
 
115
85
  try:
116
- # Access JAX internal kept_var_idx - the authoritative source
117
- # This tells us exactly which original parameters survived compilation
118
- compile_args = lowered._lowering.compile_args
119
- kept_var_idx = compile_args["kept_var_idx"]
120
-
121
- kept_indices = sorted(kept_var_idx)
122
- if len(kept_indices) < original_arg_count:
123
- arg_keep_map = kept_indices
124
-
125
- except (AttributeError, KeyError, TypeError) as e:
126
- # JAX internal API is not available or changed
127
- # This is a hard error - we cannot reliably handle unused parameters
128
- # without knowing exactly which ones were kept
129
- raise RuntimeError(
130
- f"Cannot access JAX's kept_var_idx to handle unused parameter elimination. "
131
- f"This function may have unused parameters that JAX optimized away, "
132
- f"but we cannot determine which ones without the internal API. "
133
- f"Original error: {e}"
134
- ) from e
86
+ # Use jax.export just to get the stable parameter tracking information
87
+ export_fn = export.export(jitted_fn)
88
+ exported = export_fn(jax_params)
89
+ kept_var_idx = exported.module_kept_var_idx
90
+ if kept_var_idx is not None and len(kept_var_idx) < original_arg_count:
91
+ # JAX eliminated some unused parameters during compilation
92
+ # Keep the indices in sorted order for consistent mapping
93
+ arg_keep_map = sorted(kept_var_idx)
94
+ except Exception as e:
95
+ # Fallback: if jax.export fails, we can still use the compiled result without parameter tracking
96
+ # This ensures backward compatibility even if export has issues
97
+ logging.warning(
98
+ f"jax.export failed to get kept_var_idx, proceeding without it. Error: {e}"
99
+ )
135
100
 
136
101
  # This format tells JaxRT how to handle the compiled result
137
102
  pfn_kwargs: dict[str, Any] = {
@@ -149,11 +114,11 @@ def jax2stablehlo(
149
114
  return pfn, in_vars, out_tree
150
115
 
151
116
 
152
- class JaxCompiler(FeOperation):
153
- """JAX compiler frontend operation."""
117
+ class JaxRunner(FeOperation):
118
+ """JAX function runner frontend operation."""
154
119
 
155
120
  def trace(
156
- self, func: Callable, *args: Any, **kwargs: Any
121
+ self, jax_fn: Callable, *args: Any, **kwargs: Any
157
122
  ) -> tuple[PFunction, list[MPObject], PyTreeDef]:
158
123
  """
159
124
  JAX compilation helper function.
@@ -162,21 +127,21 @@ class JaxCompiler(FeOperation):
162
127
  along with variable arguments for evaluation.
163
128
 
164
129
  Args:
165
- func: The JAX function to compile
130
+ jax_fn: The JAX function to compile
166
131
  *args: Positional arguments to the function
167
132
  **kwargs: Keyword arguments to the function
168
133
 
169
134
  Returns:
170
- tuple[PFunction, list[MPObject], Any]: The compiled PFunction, input variables, and output tree
135
+ tuple[PFunction, list[MPObject], PyTreeDef]: The compiled PFunction, input variables, and output tree
171
136
  """
172
137
 
173
138
  def is_variable(arg: Any) -> bool:
174
139
  return isinstance(arg, MPObject)
175
140
 
176
- pfunc, in_vars, out_tree = jax2stablehlo(is_variable, func, *args, **kwargs)
141
+ pfunc, in_vars, out_tree = jax2stablehlo(is_variable, jax_fn, *args, **kwargs)
177
142
  return pfunc, in_vars, out_tree
178
143
 
179
144
 
180
145
  _JAX_MOD = stateless_mod("jax")
181
146
 
182
- jax_compile = JaxCompiler(_JAX_MOD, "compile")
147
+ run_jax = JaxRunner(_JAX_MOD, "run")
@@ -0,0 +1,168 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import logging
18
+ from collections.abc import Callable
19
+ from typing import Any
20
+
21
+ import jax
22
+ import jax.numpy as jnp
23
+ from flax import nnx
24
+ from jax import export
25
+ from jax.tree_util import PyTreeDef, tree_flatten
26
+
27
+ from mplang.v1.core import MPObject, PFunction, TensorType, get_fn_name
28
+ from mplang.v1.ops.base import FeOperation, stateless_mod
29
+ from mplang.v1.utils.func_utils import normalize_fn
30
+
31
+ # Enable 64-bit precision for JAX to match tensor types
32
+ jax.config.update("jax_enable_x64", True)
33
+
34
+
35
+ def nnx2stablehlo(
36
+ is_variable: Callable[[Any], bool], flat_fn: Any, *args: Any, **kwargs: Any
37
+ ) -> tuple[PFunction, list[Any], PyTreeDef]:
38
+ """Compile NNX function to StableHLO MLIR format for remote execution.
39
+
40
+ Translates high-level NNX functions into StableHLO MLIR representations,
41
+ enabling execution on JAX backends across different processes and platforms.
42
+ Uses a hybrid approach: traditional NNX trace/lower for compilation compatibility,
43
+ with stable jax.export API for parameter tracking.
44
+
45
+ Args:
46
+ is_variable: Predicate function to classify parameters as variables vs. constants.
47
+ Returns True for parameters that should be treated as PFunction inputs.
48
+ flat_fn: NNX function to be compiled into StableHLO format
49
+ *args: Positional arguments passed to the function during compilation
50
+ **kwargs: Keyword arguments passed to the function during compilation
51
+
52
+ Returns:
53
+ tuple[PFunction, list, PyTreeDef]: Compilation artifacts containing:
54
+ - PFunction: Serialized function with embedded MLIR text and type metadata
55
+ - list: Extracted variable parameters (those satisfying is_variable predicate).
56
+ Non-variable parameters are captured as compile-time constants within
57
+ the PFunction body, while variables become runtime input parameters.
58
+ - PyTreeDef: Tree structure template for reconstructing nested output values
59
+ """
60
+ # Flatten (args, kwargs) and capture immediates using the moved logic from primitive.py
61
+ normalized_fn, in_vars = normalize_fn(flat_fn, args, kwargs, is_variable)
62
+
63
+ # Convert TensorType in_vars to ShapeDtypeStruct for JAX tracing
64
+ jax_params = [
65
+ jax.ShapeDtypeStruct(arg.shape, jnp.dtype(arg.dtype.name)) for arg in in_vars
66
+ ]
67
+
68
+ # NNX compilation pipeline using JAX export API: nnx.jit → jax.export → StableHLO MLIR
69
+ # Use nnx.jit for NNX-specific functionality, then jax.export for stable parameter handling
70
+ nnx_jitted = nnx.jit(normalized_fn)
71
+
72
+ # Extract the underlying JAX function for jax.export compatibility
73
+ # nnx.jit wraps a JAX function, and we can access it via .fun attribute
74
+ underlying_jax_fn = nnx_jitted.fun
75
+
76
+ # Hybrid approach: Use NNX trace/lower for compilation, but jax.export for parameter tracking
77
+ # Use traditional nnx.jit → trace → lower for compatibility with argument structure
78
+ nnx_traced = nnx_jitted.trace(jax_params)
79
+ nnx_lowered = nnx_traced.lower()
80
+
81
+ # Get StableHLO MLIR representation using traditional NNX approach
82
+ # NNX lowered object wraps JAX lowered, so we access the inner JAX lowered object
83
+ jax_lowered = nnx_lowered.lowered
84
+ stablehlo_mlir = jax_lowered.compiler_ir("stablehlo")
85
+ mlir_text = str(stablehlo_mlir)
86
+
87
+ # Get output info using traditional NNX approach
88
+ # NNX captures output in (args, kwargs, result) format, so we need to extract just the result part
89
+ raw_out_info = jax_lowered.out_info
90
+ if isinstance(raw_out_info, tuple) and len(raw_out_info) == 3:
91
+ # NNX format: (args, kwargs, result) - extract just the result
92
+ _, _, actual_out_info = raw_out_info
93
+ out_info_flat, out_tree = tree_flatten(actual_out_info)
94
+ else:
95
+ # Fallback to direct format (shouldn't happen with NNX, but just in case)
96
+ out_info_flat, out_tree = tree_flatten(raw_out_info)
97
+
98
+ out_info_flat = [TensorType.from_obj(info) for info in out_info_flat]
99
+
100
+ # Extract argument keep mapping using stable jax.export API for parameter tracking
101
+ # We use the underlying JAX function with jax.export only for parameter tracking
102
+ arg_keep_map = None
103
+ original_arg_count = len(in_vars)
104
+
105
+ try:
106
+ # Use jax.export with the underlying JAX function just to get stable parameter tracking
107
+ export_fn = export.export(jax.jit(underlying_jax_fn))
108
+ exported = export_fn(jax_params)
109
+ kept_var_idx = exported.module_kept_var_idx
110
+ if kept_var_idx is not None and len(kept_var_idx) < original_arg_count:
111
+ # JAX eliminated some unused parameters during compilation
112
+ # Keep the indices in sorted order for consistent mapping
113
+ arg_keep_map = sorted(kept_var_idx)
114
+ except Exception as e:
115
+ # Fallback: if jax.export fails, we can still use the compiled result without parameter tracking
116
+ # This ensures backward compatibility even if export has issues
117
+ logging.warning(
118
+ f"jax.export failed to get kept_var_idx, proceeding without it. Error: {e}"
119
+ )
120
+
121
+ # This format tells JaxRT how to handle the compiled result
122
+ # Use the same format as JAX since NNX compiles to the same backend
123
+ pfn_kwargs: dict[str, Any] = {
124
+ "fn_type": "mlir.stablehlo", # Key: specify StableHLO MLIR format
125
+ "ins_info": tuple(TensorType.from_obj(x) for x in in_vars),
126
+ "outs_info": tuple(out_info_flat),
127
+ "fn_name": get_fn_name(flat_fn),
128
+ "fn_text": mlir_text, # MLIR text, serializable for transmission
129
+ }
130
+
131
+ if arg_keep_map is not None:
132
+ pfn_kwargs["arg_keep_map"] = arg_keep_map
133
+
134
+ pfn = PFunction(**pfn_kwargs)
135
+ return pfn, in_vars, out_tree
136
+
137
+
138
+ class NnxRunner(FeOperation):
139
+ """NNX function runner frontend operation."""
140
+
141
+ def trace(
142
+ self, nnx_fn: Callable, *args: Any, **kwargs: Any
143
+ ) -> tuple[PFunction, list[MPObject], PyTreeDef]:
144
+ """
145
+ NNX compilation helper function.
146
+
147
+ Compiles an NNX function to StableHLO format and returns the PFunction
148
+ along with variable arguments for evaluation.
149
+
150
+ Args:
151
+ nnx_fn: The NNX function to compile
152
+ *args: Positional arguments to the function
153
+ **kwargs: Keyword arguments to the function
154
+
155
+ Returns:
156
+ tuple[PFunction, list[MPObject], PyTreeDef]: The compiled PFunction, input variables, and output tree
157
+ """
158
+
159
+ def is_variable(arg: Any) -> bool:
160
+ return isinstance(arg, MPObject)
161
+
162
+ pfunc, in_vars, out_tree = nnx2stablehlo(is_variable, nnx_fn, *args, **kwargs)
163
+ return pfunc, in_vars, out_tree
164
+
165
+
166
+ _NNX_MOD = stateless_mod("nnx")
167
+
168
+ run_nnx = NnxRunner(_NNX_MOD, "run")
@@ -14,22 +14,34 @@
14
14
 
15
15
  """PHE (Partially Homomorphic Encryption) frontend operations."""
16
16
 
17
- from mplang.core.dtype import UINT8
18
- from mplang.core.tensor import TensorType
19
- from mplang.ops.base import stateless_mod
17
+ from mplang.v1.core import UINT8, TensorType
18
+ from mplang.v1.ops.base import stateless_mod
20
19
 
21
20
  _PHE_MOD = stateless_mod("phe")
22
21
 
23
22
 
24
23
  @_PHE_MOD.simple_op()
25
24
  def keygen(
26
- *, scheme: str = "paillier", key_size: int = 2048
25
+ *,
26
+ scheme: str = "paillier",
27
+ key_size: int = 2048,
28
+ max_value: int | None = None,
29
+ fxp_bits: int | None = None,
27
30
  ) -> tuple[TensorType, TensorType]:
28
31
  """Generate a PHE key pair: returns (public_key, private_key).
29
32
 
30
33
  Keys are represented with a sentinel TensorType UINT8[(-1, 0)] to indicate
31
34
  non-structural, backend-only handles. Runtime validation will treat this
32
35
  shape as an opaque placeholder and skip dtype/shape checks.
36
+
37
+ Attributes (forwarded to backend):
38
+ scheme: PHE scheme (default: 'paillier')
39
+ key_size: Modulus size in bits (default: 2048)
40
+ max_value: Optional range-encoding bound B. If provided, the backend will
41
+ encode/decode integers/floats within [-B, B] and treat (B, N-B) as overflow.
42
+ Pick B to exceed the largest intermediate magnitude you expect in homomorphic
43
+ combinations. If omitted, backend default is used (currently 2**32).
44
+ fxp_bits: Optional fixed-point fractional bits for float encoding (default backend value).
33
45
  """
34
46
  key_spec = TensorType(UINT8, (-1, 0))
35
47
  return key_spec, key_spec
@@ -23,11 +23,9 @@ import spu.utils.frontend as spu_fe
23
23
  from jax import ShapeDtypeStruct
24
24
  from jax.tree_util import PyTreeDef, tree_flatten
25
25
 
26
- from mplang.core.mpobject import MPObject
27
- from mplang.core.pfunc import PFunction, get_fn_name
28
- from mplang.core.tensor import TensorType
29
- from mplang.ops.base import stateless_mod
30
- from mplang.utils.func_utils import normalize_fn
26
+ from mplang.v1.core import MPObject, PFunction, TensorType, get_fn_name
27
+ from mplang.v1.ops.base import stateless_mod
28
+ from mplang.v1.utils.func_utils import normalize_fn
31
29
 
32
30
 
33
31
  class Visibility: