mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev271__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. mplang/__init__.py +391 -17
  2. mplang/{v2/backends → backends}/__init__.py +9 -7
  3. mplang/{v2/backends → backends}/bfv_impl.py +6 -6
  4. mplang/{v2/backends → backends}/crypto_impl.py +6 -6
  5. mplang/{v2/backends → backends}/field_impl.py +5 -5
  6. mplang/{v2/backends → backends}/func_impl.py +4 -4
  7. mplang/{v2/backends → backends}/phe_impl.py +3 -3
  8. mplang/{v2/backends → backends}/simp_design.md +1 -1
  9. mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
  10. mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
  11. mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
  12. mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
  13. mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
  14. mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
  15. mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
  16. mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
  17. mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
  18. mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
  19. mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
  20. mplang/{v2/backends → backends}/spu_impl.py +8 -8
  21. mplang/{v2/backends → backends}/spu_state.py +4 -4
  22. mplang/{v2/backends → backends}/store_impl.py +3 -3
  23. mplang/{v2/backends → backends}/table_impl.py +8 -8
  24. mplang/{v2/backends → backends}/tee_impl.py +6 -6
  25. mplang/{v2/backends → backends}/tensor_impl.py +6 -6
  26. mplang/{v2/cli.py → cli.py} +9 -9
  27. mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
  28. mplang/{v2/dialects → dialects}/__init__.py +5 -5
  29. mplang/{v2/dialects → dialects}/bfv.py +6 -6
  30. mplang/{v2/dialects → dialects}/crypto.py +5 -5
  31. mplang/{v2/dialects → dialects}/dtypes.py +2 -2
  32. mplang/{v2/dialects → dialects}/field.py +3 -3
  33. mplang/{v2/dialects → dialects}/func.py +2 -2
  34. mplang/{v2/dialects → dialects}/phe.py +6 -6
  35. mplang/{v2/dialects → dialects}/simp.py +6 -6
  36. mplang/{v2/dialects → dialects}/spu.py +7 -7
  37. mplang/{v2/dialects → dialects}/store.py +2 -2
  38. mplang/{v2/dialects → dialects}/table.py +3 -3
  39. mplang/{v2/dialects → dialects}/tee.py +6 -6
  40. mplang/{v2/dialects → dialects}/tensor.py +5 -5
  41. mplang/{v2/edsl → edsl}/__init__.py +3 -3
  42. mplang/{v2/edsl → edsl}/context.py +6 -6
  43. mplang/{v2/edsl → edsl}/graph.py +5 -5
  44. mplang/{v2/edsl → edsl}/jit.py +2 -2
  45. mplang/{v2/edsl → edsl}/object.py +1 -1
  46. mplang/{v2/edsl → edsl}/primitive.py +5 -5
  47. mplang/{v2/edsl → edsl}/printer.py +1 -1
  48. mplang/{v2/edsl → edsl}/serde.py +1 -1
  49. mplang/{v2/edsl → edsl}/tracer.py +7 -7
  50. mplang/{v2/edsl → edsl}/typing.py +1 -1
  51. mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
  52. mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
  53. mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
  54. mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
  55. mplang/{v2/libs → libs}/collective.py +5 -5
  56. mplang/{v2/libs → libs}/device/__init__.py +1 -1
  57. mplang/{v2/libs → libs}/device/api.py +12 -12
  58. mplang/{v2/libs → libs}/ml/__init__.py +1 -1
  59. mplang/{v2/libs → libs}/ml/sgb.py +4 -4
  60. mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
  61. mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
  62. mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
  63. mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
  64. mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
  65. mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
  66. mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
  67. mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
  68. mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
  69. mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
  70. mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
  71. mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
  72. mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
  73. mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
  74. mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
  75. mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
  76. mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
  77. mplang/{v2/runtime → runtime}/interpreter.py +11 -11
  78. mplang/{v2/runtime → runtime}/value.py +2 -2
  79. mplang/{v1/runtime → utils}/__init__.py +18 -15
  80. mplang/{v1/utils → utils}/func_utils.py +1 -1
  81. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/METADATA +2 -2
  82. mplang_nightly-0.1.dev271.dist-info/RECORD +102 -0
  83. mplang/v1/__init__.py +0 -157
  84. mplang/v1/_device.py +0 -602
  85. mplang/v1/analysis/__init__.py +0 -37
  86. mplang/v1/analysis/diagram.py +0 -567
  87. mplang/v1/core/__init__.py +0 -157
  88. mplang/v1/core/cluster.py +0 -343
  89. mplang/v1/core/comm.py +0 -281
  90. mplang/v1/core/context_mgr.py +0 -50
  91. mplang/v1/core/dtypes.py +0 -335
  92. mplang/v1/core/expr/__init__.py +0 -80
  93. mplang/v1/core/expr/ast.py +0 -542
  94. mplang/v1/core/expr/evaluator.py +0 -581
  95. mplang/v1/core/expr/printer.py +0 -285
  96. mplang/v1/core/expr/transformer.py +0 -141
  97. mplang/v1/core/expr/utils.py +0 -78
  98. mplang/v1/core/expr/visitor.py +0 -85
  99. mplang/v1/core/expr/walk.py +0 -387
  100. mplang/v1/core/interp.py +0 -160
  101. mplang/v1/core/mask.py +0 -325
  102. mplang/v1/core/mpir.py +0 -965
  103. mplang/v1/core/mpobject.py +0 -117
  104. mplang/v1/core/mptype.py +0 -407
  105. mplang/v1/core/pfunc.py +0 -130
  106. mplang/v1/core/primitive.py +0 -877
  107. mplang/v1/core/table.py +0 -218
  108. mplang/v1/core/tensor.py +0 -75
  109. mplang/v1/core/tracer.py +0 -383
  110. mplang/v1/host.py +0 -130
  111. mplang/v1/kernels/__init__.py +0 -41
  112. mplang/v1/kernels/base.py +0 -125
  113. mplang/v1/kernels/basic.py +0 -240
  114. mplang/v1/kernels/context.py +0 -369
  115. mplang/v1/kernels/crypto.py +0 -122
  116. mplang/v1/kernels/fhe.py +0 -858
  117. mplang/v1/kernels/mock_tee.py +0 -72
  118. mplang/v1/kernels/phe.py +0 -1864
  119. mplang/v1/kernels/spu.py +0 -341
  120. mplang/v1/kernels/sql_duckdb.py +0 -44
  121. mplang/v1/kernels/stablehlo.py +0 -90
  122. mplang/v1/kernels/value.py +0 -626
  123. mplang/v1/ops/__init__.py +0 -35
  124. mplang/v1/ops/base.py +0 -424
  125. mplang/v1/ops/basic.py +0 -294
  126. mplang/v1/ops/crypto.py +0 -262
  127. mplang/v1/ops/fhe.py +0 -272
  128. mplang/v1/ops/jax_cc.py +0 -147
  129. mplang/v1/ops/nnx_cc.py +0 -168
  130. mplang/v1/ops/phe.py +0 -216
  131. mplang/v1/ops/spu.py +0 -151
  132. mplang/v1/ops/sql_cc.py +0 -303
  133. mplang/v1/ops/tee.py +0 -36
  134. mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
  135. mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
  136. mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
  137. mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
  138. mplang/v1/runtime/channel.py +0 -230
  139. mplang/v1/runtime/cli.py +0 -451
  140. mplang/v1/runtime/client.py +0 -456
  141. mplang/v1/runtime/communicator.py +0 -131
  142. mplang/v1/runtime/data_providers.py +0 -303
  143. mplang/v1/runtime/driver.py +0 -324
  144. mplang/v1/runtime/exceptions.py +0 -27
  145. mplang/v1/runtime/http_api.md +0 -56
  146. mplang/v1/runtime/link_comm.py +0 -196
  147. mplang/v1/runtime/server.py +0 -501
  148. mplang/v1/runtime/session.py +0 -270
  149. mplang/v1/runtime/simulation.py +0 -324
  150. mplang/v1/simp/__init__.py +0 -13
  151. mplang/v1/simp/api.py +0 -353
  152. mplang/v1/simp/mpi.py +0 -131
  153. mplang/v1/simp/party.py +0 -225
  154. mplang/v1/simp/random.py +0 -120
  155. mplang/v1/simp/smpc.py +0 -238
  156. mplang/v1/utils/__init__.py +0 -13
  157. mplang/v1/utils/crypto.py +0 -32
  158. mplang/v1/utils/spu_utils.py +0 -130
  159. mplang/v1/utils/table_utils.py +0 -185
  160. mplang/v2/__init__.py +0 -424
  161. mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
  162. /mplang/{v2/backends → backends}/channel.py +0 -0
  163. /mplang/{v2/edsl → edsl}/README.md +0 -0
  164. /mplang/{v2/edsl → edsl}/registry.py +0 -0
  165. /mplang/{v2/kernels → kernels}/Makefile +0 -0
  166. /mplang/{v2/kernels → kernels}/__init__.py +0 -0
  167. /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
  168. /mplang/{v2/libs → libs}/device/cluster.py +0 -0
  169. /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
  170. /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
  171. /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
  172. /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
  173. /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
  174. /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
  175. /mplang/{v2/runtime → runtime}/__init__.py +0 -0
  176. /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
  177. /mplang/{v2/runtime → runtime}/object_store.py +0 -0
  178. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/WHEEL +0 -0
  179. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/entry_points.txt +0 -0
  180. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/licenses/LICENSE +0 -0
mplang/v1/ops/fhe.py DELETED
@@ -1,272 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from mplang.v1.core import UINT8, TensorType
16
- from mplang.v1.ops.base import stateless_mod
17
-
18
# Frontend op module created with the name "fhe"; every FHE operation in this
# file is declared through its ``simple_op()`` decorator.
_fhe_MOD = stateless_mod("fhe")
19
-
20
-
21
@_fhe_MOD.simple_op()
def keygen(
    *,
    scheme: str = "CKKS",
    poly_modulus_degree: int = 8192,
    coeff_mod_bit_sizes: tuple[int, ...] | None = None,
    global_scale: int | None = None,
    plain_modulus: int | None = None,
) -> tuple[TensorType, TensorType, TensorType]:
    """Validate FHE key-generation options and return the three context types.

    This function only checks its arguments and describes the outputs as
    ``TensorType`` values; it does not produce key material itself.

    Args:
        scheme: FHE scheme to use ("CKKS" for approximate, "BFV" for exact integer).
        poly_modulus_degree: Polynomial modulus degree (default: 8192).
        coeff_mod_bit_sizes: Coefficient modulus bit sizes for CKKS (optional).
        global_scale: Global scale for CKKS (optional).
        plain_modulus: Plain modulus for BFV (optional). Must stay ``None``
            when ``scheme`` is "CKKS".

    Returns:
        Tuple of (private_context, public_context, evaluation_context). Each
        entry is the sentinel TensorType UINT8[(-1, 0)], which marks a
        non-structural, backend-only handle.

    Raises:
        ValueError: If ``scheme`` is neither "CKKS" nor "BFV", or if
            ``plain_modulus`` is supplied together with the CKKS scheme.

    Note: Vector backend only supports 1D data. For multi-dimensional tensors,
    use mplang.ops.fhe instead.
    """
    if scheme not in ("CKKS", "BFV"):
        raise ValueError("Unsupported scheme. Choose either 'CKKS' or 'BFV'.")
    # Raise explicitly instead of asserting: assertions are removed under
    # ``python -O`` and the neighbouring scheme check already uses ValueError.
    if scheme == "CKKS" and plain_modulus is not None:
        raise ValueError("plain_modulus is not used in CKKS scheme.")
    # poly_modulus_degree, coeff_mod_bit_sizes and global_scale are not
    # inspected here; presumably they are forwarded to the backend by the op
    # machinery -- TODO confirm.
    context_spec = TensorType(UINT8, (-1, 0))
    return context_spec, context_spec, context_spec
54
-
55
-
56
@_fhe_MOD.simple_op()
def encrypt(plaintext: TensorType, context: TensorType) -> TensorType:
    """Derive the ciphertext type produced by FHE Vector-backend encryption.

    The ciphertext keeps the semantic type of the plaintext, so the plaintext
    type is returned unchanged once its rank has been checked.

    Args:
        plaintext: Type of the data to encrypt (scalar or 1D vector only).
        context: FHE context (private or public). Not consulted when deriving
            the result type.

    Returns:
        The plaintext type, reused as the ciphertext type.

    Raises:
        ValueError: If ``plaintext`` has more than one dimension.

    Note: Vector backend only supports scalars (shape=()) and 1D vectors
    (shape=(n,)). For multi-dimensional data, use mplang.ops.fhe.encrypt instead.
    """
    del context  # part of the op signature only
    rank = len(plaintext.shape)
    if rank <= 1:
        return plaintext
    raise ValueError(
        f"FHE Vector backend only supports 1D data. Got shape {plaintext.shape}. "
        "Use mplang.ops.fhe for multi-dimensional tensors."
    )
80
-
81
-
82
@_fhe_MOD.simple_op()
def decrypt(ciphertext: TensorType, context: TensorType) -> TensorType:
    """Derive the plaintext type produced by FHE Vector-backend decryption.

    Args:
        ciphertext: Type of the encrypted data (scalar or 1D vector).
        context: FHE context (must be the private context holding the secret
            key). Not consulted when deriving the result type.

    Returns:
        The ciphertext type, reused unchanged as the plaintext type.

    Note: Ciphertext encrypted with a public context can be decrypted with the
    corresponding private context.
    """
    del context  # part of the op signature only
    return ciphertext
98
-
99
-
100
@_fhe_MOD.simple_op()
def add(operand1: TensorType, operand2: TensorType) -> TensorType:
    """Add two FHE operands (ciphertext + ciphertext or ciphertext + plaintext).

    Args:
        operand1: First operand (ciphertext or plaintext, scalar or 1D vector).
        operand2: Second operand (ciphertext or plaintext, scalar or 1D vector).

    Returns:
        The type of the homomorphic sum, which is the type of ``operand1``.

    Raises:
        ValueError: If the operands differ in dtype or in shape.

    Note: At least one operand must be ciphertext (not checked here). Both
    operands must have the same shape (no broadcasting in Vector backend).
    """
    # Explicit raises keep the documented ValueError contract and still fire
    # under ``python -O``, where assert statements are stripped.
    if operand1.dtype != operand2.dtype:
        raise ValueError(
            f"Operand dtypes must match, got {operand1.dtype} and {operand2.dtype}."
        )
    if operand1.shape != operand2.shape:
        raise ValueError(
            f"Operand shapes must match, got {operand1.shape} and {operand2.shape}."
        )
    return operand1
124
-
125
-
126
@_fhe_MOD.simple_op()
def sub(operand1: TensorType, operand2: TensorType) -> TensorType:
    """Subtract two FHE operands (ciphertext - ciphertext or ciphertext - plaintext).

    Args:
        operand1: First operand (ciphertext or plaintext, scalar or 1D vector).
        operand2: Second operand (ciphertext or plaintext, scalar or 1D vector).

    Returns:
        The type of the homomorphic difference, which is the type of ``operand1``.

    Raises:
        ValueError: If the operands differ in dtype or in shape.

    Note: At least one operand must be ciphertext (not checked here). Both
    operands must have the same shape (no broadcasting in Vector backend).
    """
    # Explicit raises keep the documented ValueError contract and still fire
    # under ``python -O``, where assert statements are stripped.
    if operand1.dtype != operand2.dtype:
        raise ValueError(
            f"Operand dtypes must match, got {operand1.dtype} and {operand2.dtype}."
        )
    if operand1.shape != operand2.shape:
        raise ValueError(
            f"Operand shapes must match, got {operand1.shape} and {operand2.shape}."
        )
    return operand1
150
-
151
-
152
@_fhe_MOD.simple_op()
def mul(operand1: TensorType, operand2: TensorType) -> TensorType:
    """Multiply two FHE operands (ciphertext * ciphertext or ciphertext * plaintext).

    Args:
        operand1: First operand (ciphertext or plaintext, scalar or 1D vector).
        operand2: Second operand (ciphertext or plaintext, scalar or 1D vector).

    Returns:
        The type of the homomorphic product, which is the type of ``operand1``.

    Raises:
        ValueError: If the operands differ in dtype or in shape.

    Note: At least one operand must be ciphertext (not checked here). Both
    operands must have the same shape (no broadcasting in Vector backend).
    For BFV scheme, plaintext operands must be integers (not checked here).
    """
    # Explicit raises keep the documented ValueError contract and still fire
    # under ``python -O``, where assert statements are stripped.
    if operand1.dtype != operand2.dtype:
        raise ValueError(
            f"Operand dtypes must match, got {operand1.dtype} and {operand2.dtype}."
        )
    if operand1.shape != operand2.shape:
        raise ValueError(
            f"Operand shapes must match, got {operand1.shape} and {operand2.shape}."
        )
    return operand1
177
-
178
-
179
@_fhe_MOD.simple_op()
def dot(operand1: TensorType, operand2: TensorType) -> TensorType:
    """Derive the result type of an FHE dot product.

    Covers ciphertext · ciphertext and ciphertext · plaintext.

    Args:
        operand1: First operand (ciphertext or plaintext, must be a 1D vector).
        operand2: Second operand (ciphertext or plaintext, must be a 1D vector).

    Returns:
        A scalar type (shape=()) carrying the dtype of ``operand1``.

    Raises:
        ValueError: If either operand is not a 1D vector, or if the two
            vectors have different lengths.

    Note: Both operands must be 1D vectors (not scalars). For scalar
    multiplication, use mul() instead. This operation always returns a scalar.
    """
    # Collect the first failing condition, checked in the order
    # operand1 rank -> operand2 rank -> length agreement.
    problem = None
    if len(operand1.shape) != 1:
        problem = (
            f"Dot product requires 1D vectors, got shape {operand1.shape} for operand1"
        )
    elif len(operand2.shape) != 1:
        problem = (
            f"Dot product requires 1D vectors, got shape {operand2.shape} for operand2"
        )
    elif operand1.shape[0] != operand2.shape[0]:
        problem = (
            f"Dot product dimension mismatch: {operand1.shape[0]} vs {operand2.shape[0]}"
        )
    if problem is not None:
        raise ValueError(problem)

    # Two 1D vectors reduce to a single scalar.
    return TensorType(operand1.dtype, ())
211
-
212
-
213
@_fhe_MOD.simple_op()
def polyval(ciphertext: TensorType, coeffs: TensorType) -> TensorType:
    """Evaluate a polynomial on encrypted data with plaintext coefficients.

    Args:
        ciphertext: Encrypted data (scalar or 1D vector).
        coeffs: Plaintext polynomial coefficients as a 1D array
            [c0, c1, c2, ...] representing c0 + c1*x + c2*x^2 + ...

    Returns:
        The type of the evaluation result, which is the type of ``ciphertext``.

    Raises:
        ValueError: If ``coeffs`` is not 1D, or if it is known to hold fewer
            than 2 elements.

    Note: Polynomial must have degree >= 1 (at least 2 coefficients required).
    Constant polynomials (degree 0, single coefficient) are NOT supported due to
    TenSEAL limitation. For constant values, use: ct * 0 + constant instead.
    For BFV scheme, coefficients must be integers.

    Common use case - Sigmoid approximation:
        sigmoid_coeffs = [0.5, 0.15012, 0.0, -0.0018027]
        result = polyval(ciphertext, sigmoid_coeffs)
    """
    if len(coeffs.shape) != 1:
        raise ValueError(
            f"Polynomial coefficients must be 1D array, got shape {coeffs.shape}"
        )
    num_coeffs = coeffs.shape[0]
    # Enforce the documented minimum of two coefficients. Only concrete,
    # non-negative extents are rejected: keygen in this file uses -1 inside a
    # sentinel shape, so a negative extent is presumably a placeholder rather
    # than a real length -- TODO confirm.
    if isinstance(num_coeffs, int) and 0 <= num_coeffs < 2:
        raise ValueError(
            f"Polynomial requires at least 2 coefficients, got {num_coeffs}"
        )
    return ciphertext
243
-
244
-
245
@_fhe_MOD.simple_op()
def negate(ciphertext: TensorType) -> TensorType:
    """Derive the result type of negating encrypted data (unary minus).

    Args:
        ciphertext: Type of the encrypted data (scalar or 1D vector).

    Returns:
        The input type unchanged: negation keeps shape and dtype.

    Note: Equivalent to multiplying by -1.
    """
    return ciphertext
258
-
259
-
260
@_fhe_MOD.simple_op()
def square(ciphertext: TensorType) -> TensorType:
    """Derive the result type of squaring encrypted data element-wise.

    Args:
        ciphertext: Type of the encrypted data (scalar or 1D vector).

    Returns:
        The input type unchanged: squaring keeps shape and dtype.

    Note: More efficient than mul(ciphertext, ciphertext) in some FHE schemes.
    """
    return ciphertext
mplang/v1/ops/jax_cc.py DELETED
@@ -1,147 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from __future__ import annotations
16
-
17
- import logging
18
- from collections.abc import Callable
19
- from typing import Any
20
-
21
- import jax
22
- import jax.numpy as jnp
23
- from jax import export
24
- from jax.tree_util import PyTreeDef, tree_flatten
25
-
26
- from mplang.v1.core import MPObject, PFunction, TensorType, get_fn_name
27
- from mplang.v1.ops.base import FeOperation, stateless_mod
28
- from mplang.v1.utils.func_utils import normalize_fn
29
-
30
# Enable 64-bit precision for JAX to match tensor types.
# NOTE: this is a module-level statement, so it takes effect as soon as the
# module is imported and changes the JAX configuration for the whole process.
jax.config.update("jax_enable_x64", True)
32
-
33
-
34
def jax2stablehlo(
    is_variable: Callable[[Any], bool], flat_fn: Any, *args: Any, **kwargs: Any
) -> tuple[PFunction, list, PyTreeDef]:
    """Compile a JAX function to StableHLO MLIR text wrapped in a PFunction.

    The function is lowered with the regular ``jax.jit`` trace/lower pipeline
    to obtain the MLIR text and output types. ``jax.export`` is used in
    addition, on a best-effort basis, solely to learn which input parameters
    survived compilation.

    Args:
        is_variable: Predicate classifying parameters as variables vs.
            constants. Returns True for parameters that should become
            PFunction inputs.
        flat_fn: JAX function to be compiled into StableHLO format.
        *args: Positional arguments passed to the function during compilation.
        **kwargs: Keyword arguments passed to the function during compilation.

    Returns:
        tuple[PFunction, list, PyTreeDef]: Compilation artifacts containing:
        - PFunction: function description with the MLIR text and the
          input/output type metadata. It carries ``arg_keep_map`` only when
          ``jax.export`` reports fewer kept parameters than there are inputs.
        - list: the variable parameters extracted by ``normalize_fn`` (those
          satisfying ``is_variable``).
        - PyTreeDef: tree structure for rebuilding the nested output value.
    """
    # Split (args, kwargs) into runtime variables and captured immediates.
    normalized_fn, in_vars = normalize_fn(flat_fn, args, kwargs, is_variable)

    # Abstract shape/dtype stand-ins so JAX can trace without concrete values.
    jax_params = [
        jax.ShapeDtypeStruct(arg.shape, jnp.dtype(arg.dtype.name)) for arg in in_vars
    ]

    # Standard JAX pipeline: jit -> trace -> lower.
    jitted_fn = jax.jit(normalized_fn)
    traced = jitted_fn.trace(jax_params)
    lowered = traced.lower()

    # StableHLO MLIR as text, so it can be serialized.
    stablehlo_mlir = lowered.compiler_ir("stablehlo")
    mlir_text = str(stablehlo_mlir)

    # Flatten the output description and convert each leaf to a TensorType.
    out_info_flat, out_tree = tree_flatten(lowered.out_info)
    out_info_flat = [TensorType.from_obj(info) for info in out_info_flat]

    # jax.export is used only to read kept_var_idx, not for the compilation.
    arg_keep_map = None
    try:
        export_fn = export.export(jitted_fn)
        exported = export_fn(jax_params)
        kept_var_idx = exported.module_kept_var_idx
        if kept_var_idx is not None and len(kept_var_idx) < len(in_vars):
            # JAX eliminated some unused parameters during compilation;
            # keep the surviving indices sorted for a consistent mapping.
            arg_keep_map = sorted(kept_var_idx)
    except Exception as e:
        # Deliberate best-effort: the compiled result is still usable without
        # parameter tracking, so the failure is logged rather than raised.
        # Lazy %-style arguments defer formatting to the logging framework.
        logging.warning(
            "jax.export failed to get kept_var_idx, proceeding without it. Error: %s",
            e,
        )

    # This format tells JaxRT how to handle the compiled result.
    pfn_kwargs: dict[str, Any] = {
        "fn_type": "mlir.stablehlo",  # Key: specify StableHLO MLIR format
        "ins_info": tuple(TensorType.from_obj(x) for x in in_vars),
        "outs_info": tuple(out_info_flat),
        "fn_name": get_fn_name(flat_fn),
        "fn_text": mlir_text,  # MLIR text, serializable for transmission
    }

    if arg_keep_map is not None:
        pfn_kwargs["arg_keep_map"] = arg_keep_map

    pfn = PFunction(**pfn_kwargs)
    return pfn, in_vars, out_tree
115
-
116
-
117
class JaxRunner(FeOperation):
    """Frontend operation that runs JAX functions."""

    def trace(
        self, jax_fn: Callable, *args: Any, **kwargs: Any
    ) -> tuple[PFunction, list[MPObject], PyTreeDef]:
        """Compile ``jax_fn`` to StableHLO through ``jax2stablehlo``.

        Every argument that is an ``MPObject`` is treated as a variable input;
        classification of the remaining arguments is left to ``jax2stablehlo``.

        Args:
            jax_fn: The JAX function to compile.
            *args: Positional arguments to the function.
            **kwargs: Keyword arguments to the function.

        Returns:
            tuple[PFunction, list[MPObject], PyTreeDef]: The compiled
            PFunction, the input variables, and the output tree.
        """

        def is_variable(arg: Any) -> bool:
            return isinstance(arg, MPObject)

        compiled, variables, result_tree = jax2stablehlo(
            is_variable, jax_fn, *args, **kwargs
        )
        return compiled, variables, result_tree
143
-
144
-
145
# Frontend op module created with the name "jax".
_JAX_MOD = stateless_mod("jax")

# Module-level JaxRunner instance bound to _JAX_MOD; "run" is presumably the
# operation name within that module -- TODO confirm against FeOperation.
run_jax = JaxRunner(_JAX_MOD, "run")
mplang/v1/ops/nnx_cc.py DELETED
@@ -1,168 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from __future__ import annotations
16
-
17
- import logging
18
- from collections.abc import Callable
19
- from typing import Any
20
-
21
- import jax
22
- import jax.numpy as jnp
23
- from flax import nnx
24
- from jax import export
25
- from jax.tree_util import PyTreeDef, tree_flatten
26
-
27
- from mplang.v1.core import MPObject, PFunction, TensorType, get_fn_name
28
- from mplang.v1.ops.base import FeOperation, stateless_mod
29
- from mplang.v1.utils.func_utils import normalize_fn
30
-
31
# Enable 64-bit precision for JAX to match tensor types.
# NOTE: this is a module-level statement, so it takes effect as soon as the
# module is imported and changes the JAX configuration for the whole process.
jax.config.update("jax_enable_x64", True)
33
-
34
-
35
def nnx2stablehlo(
    is_variable: Callable[[Any], bool], flat_fn: Any, *args: Any, **kwargs: Any
) -> tuple[PFunction, list[Any], PyTreeDef]:
    """Compile an NNX function to StableHLO MLIR text wrapped in a PFunction.

    The function is lowered with the ``nnx.jit`` trace/lower pipeline to
    obtain the MLIR text and output types. ``jax.export`` is used in addition,
    on a best-effort basis, solely to learn which input parameters survived
    compilation.

    Args:
        is_variable: Predicate classifying parameters as variables vs.
            constants. Returns True for parameters that should become
            PFunction inputs.
        flat_fn: NNX function to be compiled into StableHLO format.
        *args: Positional arguments passed to the function during compilation.
        **kwargs: Keyword arguments passed to the function during compilation.

    Returns:
        tuple[PFunction, list, PyTreeDef]: Compilation artifacts containing:
        - PFunction: function description with the MLIR text and the
          input/output type metadata. It carries ``arg_keep_map`` only when
          ``jax.export`` reports fewer kept parameters than there are inputs.
        - list: the variable parameters extracted by ``normalize_fn`` (those
          satisfying ``is_variable``).
        - PyTreeDef: tree structure for rebuilding the nested output value.
    """
    # Split (args, kwargs) into runtime variables and captured immediates.
    normalized_fn, in_vars = normalize_fn(flat_fn, args, kwargs, is_variable)

    # Abstract shape/dtype stand-ins so JAX can trace without concrete values.
    jax_params = [
        jax.ShapeDtypeStruct(arg.shape, jnp.dtype(arg.dtype.name)) for arg in in_vars
    ]

    # nnx.jit provides the NNX-specific handling for the main compilation.
    nnx_jitted = nnx.jit(normalized_fn)

    # NOTE(review): ``.fun`` here and ``.lowered`` below are attributes of the
    # nnx.jit wrapper objects -- confirm they exist in the installed flax
    # version before relying on this path.
    underlying_jax_fn = nnx_jitted.fun

    # Main compilation: nnx.jit -> trace -> lower.
    nnx_traced = nnx_jitted.trace(jax_params)
    nnx_lowered = nnx_traced.lower()

    # Reach the inner JAX lowered object to get StableHLO MLIR as text.
    jax_lowered = nnx_lowered.lowered
    stablehlo_mlir = jax_lowered.compiler_ir("stablehlo")
    mlir_text = str(stablehlo_mlir)

    # NNX reports outputs as (args, kwargs, result); keep only the result.
    # NOTE(review): any 3-tuple out_info is taken to be that NNX format, so a
    # plain 3-tuple result in the direct format would be misread -- confirm
    # the direct format cannot occur here.
    raw_out_info = jax_lowered.out_info
    if isinstance(raw_out_info, tuple) and len(raw_out_info) == 3:
        _, _, actual_out_info = raw_out_info
        out_info_flat, out_tree = tree_flatten(actual_out_info)
    else:
        # Fallback to the direct format (not expected with NNX).
        out_info_flat, out_tree = tree_flatten(raw_out_info)

    out_info_flat = [TensorType.from_obj(info) for info in out_info_flat]

    # jax.export, applied to the underlying JAX function, is used only to
    # read kept_var_idx, not for the compilation.
    arg_keep_map = None
    try:
        export_fn = export.export(jax.jit(underlying_jax_fn))
        exported = export_fn(jax_params)
        kept_var_idx = exported.module_kept_var_idx
        if kept_var_idx is not None and len(kept_var_idx) < len(in_vars):
            # JAX eliminated some unused parameters during compilation;
            # keep the surviving indices sorted for a consistent mapping.
            arg_keep_map = sorted(kept_var_idx)
    except Exception as e:
        # Deliberate best-effort: the compiled result is still usable without
        # parameter tracking, so the failure is logged rather than raised.
        # Lazy %-style arguments defer formatting to the logging framework.
        logging.warning(
            "jax.export failed to get kept_var_idx, proceeding without it. Error: %s",
            e,
        )

    # This format tells JaxRT how to handle the compiled result.
    # Same format as the JAX path since NNX compiles to the same backend.
    pfn_kwargs: dict[str, Any] = {
        "fn_type": "mlir.stablehlo",  # Key: specify StableHLO MLIR format
        "ins_info": tuple(TensorType.from_obj(x) for x in in_vars),
        "outs_info": tuple(out_info_flat),
        "fn_name": get_fn_name(flat_fn),
        "fn_text": mlir_text,  # MLIR text, serializable for transmission
    }

    if arg_keep_map is not None:
        pfn_kwargs["arg_keep_map"] = arg_keep_map

    pfn = PFunction(**pfn_kwargs)
    return pfn, in_vars, out_tree
136
-
137
-
138
class NnxRunner(FeOperation):
    """Frontend operation that runs NNX functions."""

    def trace(
        self, nnx_fn: Callable, *args: Any, **kwargs: Any
    ) -> tuple[PFunction, list[MPObject], PyTreeDef]:
        """Compile ``nnx_fn`` to StableHLO through ``nnx2stablehlo``.

        Every argument that is an ``MPObject`` is treated as a variable input;
        classification of the remaining arguments is left to ``nnx2stablehlo``.

        Args:
            nnx_fn: The NNX function to compile.
            *args: Positional arguments to the function.
            **kwargs: Keyword arguments to the function.

        Returns:
            tuple[PFunction, list[MPObject], PyTreeDef]: The compiled
            PFunction, the input variables, and the output tree.
        """

        def is_variable(arg: Any) -> bool:
            return isinstance(arg, MPObject)

        compiled, variables, result_tree = nnx2stablehlo(
            is_variable, nnx_fn, *args, **kwargs
        )
        return compiled, variables, result_tree
164
-
165
-
166
# Frontend op module created with the name "nnx".
_NNX_MOD = stateless_mod("nnx")

# Module-level NnxRunner instance bound to _NNX_MOD; "run" is presumably the
# operation name within that module -- TODO confirm against FeOperation.
run_nnx = NnxRunner(_NNX_MOD, "run")