mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188) hide show
  1. mplang/__init__.py +21 -130
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +4 -4
  7. mplang/{core → v1/core}/__init__.py +20 -14
  8. mplang/{core → v1/core}/cluster.py +6 -1
  9. mplang/{core → v1/core}/comm.py +1 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core → v1/core}/dtypes.py +38 -0
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +11 -13
  14. mplang/{core → v1/core}/expr/evaluator.py +8 -8
  15. mplang/{core → v1/core}/expr/printer.py +6 -6
  16. mplang/{core → v1/core}/expr/transformer.py +2 -2
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +13 -11
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +2 -2
  25. mplang/{core → v1/core}/primitive.py +12 -12
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{host.py → v1/host.py} +5 -5
  30. mplang/{kernels → v1/kernels}/__init__.py +1 -1
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/{kernels → v1/kernels}/basic.py +15 -15
  33. mplang/{kernels → v1/kernels}/context.py +19 -16
  34. mplang/{kernels → v1/kernels}/crypto.py +8 -10
  35. mplang/{kernels → v1/kernels}/fhe.py +9 -7
  36. mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
  37. mplang/{kernels → v1/kernels}/phe.py +26 -18
  38. mplang/{kernels → v1/kernels}/spu.py +5 -5
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
  40. mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
  41. mplang/{kernels → v1/kernels}/value.py +2 -2
  42. mplang/{ops → v1/ops}/__init__.py +3 -3
  43. mplang/{ops → v1/ops}/base.py +1 -1
  44. mplang/{ops → v1/ops}/basic.py +6 -5
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/{ops → v1/ops}/fhe.py +2 -2
  47. mplang/{ops → v1/ops}/jax_cc.py +26 -59
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -3
  50. mplang/{ops → v1/ops}/spu.py +3 -3
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +2 -2
  53. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  54. mplang/v1/runtime/channel.py +230 -0
  55. mplang/{runtime → v1/runtime}/cli.py +3 -3
  56. mplang/{runtime → v1/runtime}/client.py +1 -1
  57. mplang/{runtime → v1/runtime}/communicator.py +39 -15
  58. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  59. mplang/{runtime → v1/runtime}/driver.py +4 -4
  60. mplang/v1/runtime/link_comm.py +196 -0
  61. mplang/{runtime → v1/runtime}/server.py +22 -9
  62. mplang/{runtime → v1/runtime}/session.py +24 -51
  63. mplang/{runtime → v1/runtime}/simulation.py +36 -14
  64. mplang/{simp → v1/simp}/api.py +72 -14
  65. mplang/{simp → v1/simp}/mpi.py +1 -1
  66. mplang/{simp → v1/simp}/party.py +5 -5
  67. mplang/{simp → v1/simp}/random.py +2 -2
  68. mplang/v1/simp/smpc.py +238 -0
  69. mplang/v1/utils/table_utils.py +185 -0
  70. mplang/v2/__init__.py +424 -0
  71. mplang/v2/backends/__init__.py +57 -0
  72. mplang/v2/backends/bfv_impl.py +705 -0
  73. mplang/v2/backends/channel.py +217 -0
  74. mplang/v2/backends/crypto_impl.py +723 -0
  75. mplang/v2/backends/field_impl.py +454 -0
  76. mplang/v2/backends/func_impl.py +107 -0
  77. mplang/v2/backends/phe_impl.py +148 -0
  78. mplang/v2/backends/simp_design.md +136 -0
  79. mplang/v2/backends/simp_driver/__init__.py +41 -0
  80. mplang/v2/backends/simp_driver/http.py +168 -0
  81. mplang/v2/backends/simp_driver/mem.py +280 -0
  82. mplang/v2/backends/simp_driver/ops.py +135 -0
  83. mplang/v2/backends/simp_driver/state.py +60 -0
  84. mplang/v2/backends/simp_driver/values.py +52 -0
  85. mplang/v2/backends/simp_worker/__init__.py +29 -0
  86. mplang/v2/backends/simp_worker/http.py +354 -0
  87. mplang/v2/backends/simp_worker/mem.py +102 -0
  88. mplang/v2/backends/simp_worker/ops.py +167 -0
  89. mplang/v2/backends/simp_worker/state.py +49 -0
  90. mplang/v2/backends/spu_impl.py +275 -0
  91. mplang/v2/backends/spu_state.py +187 -0
  92. mplang/v2/backends/store_impl.py +62 -0
  93. mplang/v2/backends/table_impl.py +838 -0
  94. mplang/v2/backends/tee_impl.py +215 -0
  95. mplang/v2/backends/tensor_impl.py +519 -0
  96. mplang/v2/cli.py +603 -0
  97. mplang/v2/cli_guide.md +122 -0
  98. mplang/v2/dialects/__init__.py +36 -0
  99. mplang/v2/dialects/bfv.py +665 -0
  100. mplang/v2/dialects/crypto.py +689 -0
  101. mplang/v2/dialects/dtypes.py +378 -0
  102. mplang/v2/dialects/field.py +210 -0
  103. mplang/v2/dialects/func.py +135 -0
  104. mplang/v2/dialects/phe.py +723 -0
  105. mplang/v2/dialects/simp.py +944 -0
  106. mplang/v2/dialects/spu.py +349 -0
  107. mplang/v2/dialects/store.py +63 -0
  108. mplang/v2/dialects/table.py +407 -0
  109. mplang/v2/dialects/tee.py +346 -0
  110. mplang/v2/dialects/tensor.py +1175 -0
  111. mplang/v2/edsl/README.md +279 -0
  112. mplang/v2/edsl/__init__.py +99 -0
  113. mplang/v2/edsl/context.py +311 -0
  114. mplang/v2/edsl/graph.py +463 -0
  115. mplang/v2/edsl/jit.py +62 -0
  116. mplang/v2/edsl/object.py +53 -0
  117. mplang/v2/edsl/primitive.py +284 -0
  118. mplang/v2/edsl/printer.py +119 -0
  119. mplang/v2/edsl/registry.py +207 -0
  120. mplang/v2/edsl/serde.py +375 -0
  121. mplang/v2/edsl/tracer.py +614 -0
  122. mplang/v2/edsl/typing.py +816 -0
  123. mplang/v2/kernels/Makefile +30 -0
  124. mplang/v2/kernels/__init__.py +23 -0
  125. mplang/v2/kernels/gf128.cpp +148 -0
  126. mplang/v2/kernels/ldpc.cpp +82 -0
  127. mplang/v2/kernels/okvs.cpp +283 -0
  128. mplang/v2/kernels/okvs_opt.cpp +291 -0
  129. mplang/v2/kernels/py_kernels.py +398 -0
  130. mplang/v2/libs/collective.py +330 -0
  131. mplang/v2/libs/device/__init__.py +51 -0
  132. mplang/v2/libs/device/api.py +813 -0
  133. mplang/v2/libs/device/cluster.py +352 -0
  134. mplang/v2/libs/ml/__init__.py +23 -0
  135. mplang/v2/libs/ml/sgb.py +1861 -0
  136. mplang/v2/libs/mpc/__init__.py +41 -0
  137. mplang/v2/libs/mpc/_utils.py +99 -0
  138. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  139. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  140. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  141. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  142. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  143. mplang/v2/libs/mpc/common/constants.py +39 -0
  144. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  145. mplang/v2/libs/mpc/ot/base.py +222 -0
  146. mplang/v2/libs/mpc/ot/extension.py +477 -0
  147. mplang/v2/libs/mpc/ot/silent.py +217 -0
  148. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  149. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  150. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  151. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  152. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  153. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  154. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  155. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  156. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  157. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  158. mplang/v2/libs/mpc/vole/silver.py +336 -0
  159. mplang/v2/runtime/__init__.py +15 -0
  160. mplang/v2/runtime/dialect_state.py +41 -0
  161. mplang/v2/runtime/interpreter.py +871 -0
  162. mplang/v2/runtime/object_store.py +194 -0
  163. mplang/v2/runtime/value.py +141 -0
  164. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
  165. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  166. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  167. mplang/device.py +0 -327
  168. mplang/ops/crypto.py +0 -108
  169. mplang/ops/ibis_cc.py +0 -136
  170. mplang/ops/sql_cc.py +0 -62
  171. mplang/runtime/link_comm.py +0 -78
  172. mplang/simp/smpc.py +0 -201
  173. mplang/utils/table_utils.py +0 -85
  174. mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
  175. /mplang/{core → v1/core}/mask.py +0 -0
  176. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  177. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
  178. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
  179. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
  180. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  181. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  182. /mplang/{simp → v1/simp}/__init__.py +0 -0
  183. /mplang/{utils → v1/utils}/__init__.py +0 -0
  184. /mplang/{utils → v1/utils}/crypto.py +0 -0
  185. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  186. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  187. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  188. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
mplang/v2/__init__.py ADDED
@@ -0,0 +1,424 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """MPLang2: Next generation EDSL for multi-party computation.
16
+
17
+ This is the temporary home for the new EDSL implementation during migration.
18
+ Once migration is complete, this will replace the original mplang package.
19
+
20
+ Public API is designed to be compatible with mplang v1 where possible.
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ from collections.abc import Callable
26
+ from typing import Any
27
+
28
+ __version__ = "0.1.0"
29
+
30
+ # =============================================================================
31
+ # Core EDSL components
32
+ # =============================================================================
33
+ # =============================================================================
34
+ # Dialects
35
+ # =============================================================================
36
+ # =============================================================================
37
+ # Backend / Runtime
38
+ # =============================================================================
39
+ import mplang.v2.backends.func_impl # Register func handlers
40
+ from mplang.v2 import dialects
41
+ from mplang.v2.backends.simp_driver.ops import DRIVER_HANDLERS
42
+ from mplang.v2.backends.simp_worker import SimpWorker
43
+ from mplang.v2.backends.simp_worker.mem import LocalMesh
44
+ from mplang.v2.backends.simp_worker.ops import WORKER_HANDLERS
45
+ from mplang.v2.dialects.simp import make_driver, make_simulator
46
+ from mplang.v2.edsl import (
47
+ Graph,
48
+ GraphPrinter,
49
+ Object,
50
+ Operation,
51
+ Primitive,
52
+ TracedFunction,
53
+ Tracer,
54
+ Value,
55
+ find_context,
56
+ find_context_with_state,
57
+ find_interpreter,
58
+ format_graph,
59
+ get_current_context,
60
+ get_default_context,
61
+ is_tracing,
62
+ jit,
63
+ pop_context,
64
+ primitive,
65
+ push_context,
66
+ register_default_context_factory,
67
+ set_root_context,
68
+ trace,
69
+ )
70
+ from mplang.v2.edsl.registry import get_profiler
71
+
72
+ # Type system
73
+ from mplang.v2.edsl.typing import (
74
+ MPType,
75
+ ScalarType,
76
+ SSType,
77
+ TableType,
78
+ TensorType,
79
+ VectorType,
80
+ )
81
+
82
+ # =============================================================================
83
+ # Device API (compatible with mplang v1)
84
+ # =============================================================================
85
+ from mplang.v2.libs.device import (
86
+ ClusterSpec,
87
+ Device,
88
+ Node,
89
+ device,
90
+ get_dev_attr,
91
+ is_device_obj,
92
+ jax_fn,
93
+ put,
94
+ set_dev_attr,
95
+ )
96
+ from mplang.v2.libs.device import fetch as device_fetch
97
+ from mplang.v2.runtime.interpreter import Interpreter
98
+
99
+ # =============================================================================
100
+ # Context Management API (JAX-like pattern)
101
+ # =============================================================================
102
+
103
+
104
+ def _get_context(context: Interpreter | None) -> Interpreter:
105
+ """Get context from parameter or context stack."""
106
+ if context is not None:
107
+ return context
108
+ ctx = get_current_context()
109
+ if ctx is None:
110
+ raise RuntimeError(
111
+ "No context available. Either pass context explicitly or use "
112
+ "set_context()/push_context() to set a default context."
113
+ )
114
+ if not isinstance(ctx, Interpreter):
115
+ raise RuntimeError(
116
+ f"Current context is not an Interpreter: {type(ctx).__name__}. "
117
+ "Use mp.set_context(interpreter) to set the execution context."
118
+ )
119
+ return ctx
120
+
121
+
122
+ # =============================================================================
123
+ # Meta-APIs (compile, evaluate, fetch)
124
+ # =============================================================================
125
+
126
+
127
def evaluate(
    fn: Callable[..., Any] | TracedFunction,
    *args: Any,
    context: Interpreter | None = None,
    **kwargs: Any,
) -> Any:
    """Run a plain callable or a ``TracedFunction`` under an interpreter.

    The interpreter is resolved with ``_get_context`` and entered as a context
    manager for the duration of the call.

    * A ``TracedFunction`` has its graph executed via
      ``interp.evaluate_graph``; each raw output is wrapped in an
      ``InterpObject`` typed after the matching graph output and the result is
      rebuilt with ``fn.reconstruct_outputs``.
    * Any other callable is simply invoked with the given arguments.

    Args:
        fn: The function or TracedFunction to evaluate.
        *args: Positional arguments to pass to the function.
        context: Optional interpreter context. If None, uses current context.
        **kwargs: Keyword arguments to pass to the function.

    Returns:
        The result of the function evaluation.

    Example:
        >>> with mp.make_simulator(3) as sim:
        ...     result = mp.evaluate(traced)  # uses sim from context
        >>> # Or explicitly:
        >>> result = mp.evaluate(traced, context=sim)
    """
    from mplang.v2.edsl.tracer import TracedFunction
    from mplang.v2.runtime.interpreter import InterpObject

    interp = _get_context(context)

    with interp:
        if not isinstance(fn, TracedFunction):
            # Plain callable: call it directly while the interpreter is active.
            return fn(*args, **kwargs)

        # Execution boundary: InterpObject inputs are replaced by the runtime
        # value they carry; everything else is forwarded untouched.
        graph_inputs = [
            item.runtime_obj if isinstance(item, InterpObject) else item
            for item in fn.prepare_inputs(*args, **kwargs)
        ]
        raw_outputs = interp.evaluate_graph(fn.graph, graph_inputs)

        wrapped_outputs = []
        for index, raw in enumerate(raw_outputs):
            output_type = fn.graph.outputs[index].type
            wrapped_outputs.append(InterpObject(raw, output_type, interp))
        return fn.reconstruct_outputs(wrapped_outputs)
173
+
174
+
175
def fetch(
    result: Any,
    *,
    follow_device: bool = True,
    context: Interpreter | None = None,
) -> Any:
    """Fetch results from interpreter context to Python.

    This is a meta-function that operates at execution boundaries, not a traced
    dialect operation. It brings data from the distributed/MPC runtime back to
    the Python host.

    Behavior in different contexts:
    - **Tracing (compile)**: Returns the input unchanged (identity). The graph
      outputs are determined by the function's return statement, not fetch calls.
    - **Execution (evaluate)**: Actually fetches data from workers/parties.

    Design Note (A vs B tradeoff):
        Two designs were considered for fetch behavior during tracing:

        - **Design A (chosen)**: fetch = identity during tracing. Graph outputs
          are determined solely by the return statement. This is simpler and
          avoids ambiguity when fetch and return reference different values.

        - **Design B (alternative)**: fetch marks output points in the graph.
          This would allow fetch(a), fetch(b), return b to output both a and b.
          However, it complicates the semantics and requires tracking fetch
          points separately from return values.

        Design A was chosen for simplicity. If a value needs to be an output,
        it should be returned. fetch's role is purely for execution-time I/O.

    Args:
        result: Object(s) to fetch. Can be a single InterpObject, DriverVar,
            or nested structure containing them (traversed with
            ``jax.tree_util.tree_map``).
        follow_device: If True and object has device attribute, dispatch to
            device.fetch which fetches from the correct rank based on device.
            If False, fetch from all parties.
        context: Interpreter context. If None, uses current context.

    Returns:
        Fetched Python values. For device objects with follow_device=True,
        returns single value from the device's rank(s). Otherwise returns
        list of values (one per party) or single value for world_size=1.
        During tracing, returns the input unchanged.

    Raises:
        RuntimeError: If no usable interpreter context can be resolved
            (raised by ``_get_context``).
        AssertionError: If a ``DriverVar`` is fetched but the interpreter's
            "simp" dialect state is not a ``SimpDriver``.
    """
    from jax.tree_util import tree_map

    from mplang.v2.backends.simp_driver.values import DriverVar
    from mplang.v2.edsl.context import is_tracing
    from mplang.v2.runtime.interpreter import InterpObject
    from mplang.v2.runtime.value import WrapValue

    # Check if we are in tracing context - if so, return identity
    if is_tracing():
        # Design A: fetch = identity during tracing
        # Graph outputs are determined by return statement, not fetch calls
        return result

    # Execution context - actually fetch data
    interp = _get_context(context)

    def _fetch_single(var: Any) -> Any:
        """Fetch a single leaf value and return its plain Python payload."""
        # InterpObject (from mp.evaluate) - extract runtime_obj
        if isinstance(var, InterpObject):
            if follow_device and is_device_obj(var):
                return device_fetch(var)
            var = var.runtime_obj  # extract and continue processing

        # DriverVar (simp dialect) - remote fetch from workers
        if isinstance(var, DriverVar):
            from mplang.v2.backends.simp_driver.state import SimpDriver

            simp_state = interp.get_dialect_state("simp")
            if not isinstance(simp_state, SimpDriver):
                # Explicit raise instead of `assert` so the check is not
                # stripped under `python -O`. AssertionError is kept so callers
                # observe the same exception type and message as before.
                raise AssertionError("DriverVar requires simp state")

            resolved: list[Any] = []
            for rank, uri in enumerate(var.values):
                if uri is None:
                    # No value recorded for this rank; keep the slot as None.
                    resolved.append(None)
                else:
                    fetched = simp_state.fetch(rank, uri).result()
                    if isinstance(fetched, WrapValue):
                        fetched = fetched.data
                    resolved.append(fetched)

            # A single-entry result is returned bare rather than as a list.
            return resolved[0] if len(resolved) == 1 else resolved

        # WrapValue (TensorValue, TableValue, etc.) - unwrap
        if isinstance(var, WrapValue):
            return var.data

        # Plain values pass through
        return var

    with interp:
        return tree_map(_fetch_single, result)
273
+
274
+
275
+ # Alias for compatibility
276
def function(fn: Callable[..., Any] | None = None) -> Callable[..., Any]:
    """Decorator turning a local function into a Multi-Party Function.

    Calling the decorated function looks up a context whose "simp" dialect
    state exposes ``world_size`` and then runs the original function through
    ``simp.pcall_static`` on every party ``0 .. world_size - 1``.

    Semantics: f(args) -> pcall(ALL, f, args)

    Args:
        fn: The function to decorate. When omitted (``@function()``), the
            decorator itself is returned so it can be applied afterwards.

    Returns:
        A wrapper function that, when called, executes the original function
        on all workers.

    Raises:
        RuntimeError: Raised by the wrapper at call time when no context with
            world_size information can be found.
    """
    import functools

    from mplang.v2.dialects import simp

    if fn is None:
        return function

    def _provides_world_size(candidate: Any) -> bool:
        # Predicate for find_context: the context must expose a "simp" dialect
        # state, and that state must carry a world_size attribute.
        if not hasattr(candidate, "get_dialect_state"):
            return False
        simp_state = candidate.get_dialect_state("simp")
        if simp_state is None:
            return False
        return hasattr(simp_state, "world_size")

    @functools.wraps(fn)
    def _run_on_all_parties(*args: Any, **kwargs: Any) -> Any:
        from mplang.v2.edsl.context import find_context

        # Find context with simp dialect state
        owner = find_context(_provides_world_size)
        if owner is None:
            raise RuntimeError(
                "mp.function requires a context with world_size information "
                "(e.g. SimpSimulator or Driver initialized)."
            )

        # The predicate already guaranteed get_dialect_state / world_size exist.
        state = owner.get_dialect_state("simp")  # type: ignore[attr-defined]
        party_count = state.world_size  # type: ignore
        everyone = tuple(range(party_count))
        return simp.pcall_static(everyone, fn, *args, **kwargs)

    return _run_on_all_parties
325
+
326
+
327
def compile(
    fn: Callable[..., Any],
    *args: Any,
    context: Interpreter | None = None,
    **kwargs: Any,
) -> TracedFunction:
    """Trace a function into its IR without executing it.

    Args:
        fn: The function to compile.
        *args: Example arguments for tracing.
        context: Optional interpreter context. When given, it is entered as a
            context manager around the trace; when None, whatever context the
            caller has already pushed is used.
        **kwargs: Example keyword arguments for tracing.

    Returns:
        TracedFunction with the compiled graph.

    Example:
        >>> with mp.make_simulator(3) as sim:
        ...     traced = mp.compile(job)  # uses sim from context
    """
    if context is None:
        # Rely on the caller having pushed an interpreter context;
        # _resolve_cluster() will traverse the stack to find the interpreter.
        return trace(fn, *args, **kwargs)

    # An explicit context is pushed before tracing so that
    # _resolve_cluster() can find it.
    with context:
        return trace(fn, *args, **kwargs)
357
+
358
+
359
+ # =============================================================================
360
+ # Public API
361
+ # =============================================================================
362
+ __all__ = [ # noqa: RUF022
363
+ # Version
364
+ "__version__",
365
+ # Device API
366
+ "ClusterSpec",
367
+ "Device",
368
+ "Node",
369
+ "device",
370
+ "get_dev_attr",
371
+ "is_device_obj",
372
+ "jax_fn",
373
+ "put",
374
+ "set_dev_attr",
375
+ # Core EDSL
376
+ "Graph",
377
+ "GraphPrinter",
378
+ "Object",
379
+ "Operation",
380
+ "Primitive",
381
+ "TracedFunction",
382
+ "Tracer",
383
+ "Value",
384
+ "compile",
385
+ "evaluate",
386
+ "fetch",
387
+ "find_context",
388
+ "find_context_with_state",
389
+ "find_interpreter",
390
+ "format_graph",
391
+ "function",
392
+ "get_current_context",
393
+ "get_default_context",
394
+ "is_tracing",
395
+ "jit",
396
+ "mplang",
397
+ "pop_context",
398
+ "primitive",
399
+ "push_context",
400
+ "set_root_context",
401
+ "trace",
402
+ # Type system
403
+ "MPType",
404
+ "ScalarType",
405
+ "SSType",
406
+ "TableType",
407
+ "TensorType",
408
+ "VectorType",
409
+ # Backend / Runtime
410
+ "DRIVER_HANDLERS",
411
+ "Interpreter",
412
+ "LocalMesh",
413
+ "SimpWorker",
414
+ "WORKER_HANDLERS",
415
+ "make_driver",
416
+ "make_simulator",
417
+ # Dialects
418
+ "dialects",
419
+ "register_default_context_factory",
420
+ "get_profiler",
421
+ ]
422
+
423
+ # Register Interpreter as default context factory
424
+ register_default_context_factory(Interpreter)
@@ -0,0 +1,57 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Runtime implementations for MPLang2 dialects."""
16
+
17
+ import importlib
18
+ import logging
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
def load_backend(module_name: str) -> None:
    """Load a backend implementation module by name.

    This is a helper to avoid 'unused import' warnings when importing backend
    implementations solely for their side effects (registration).

    Args:
        module_name: The dotted name of the module to load.

    Raises:
        ImportError: If the module (or something it imports) cannot be
            imported. The message names ``module_name``; the ``name`` and
            ``path`` attributes are copied from the underlying error, which is
            also chained as ``__cause__``.
    """
    try:
        importlib.import_module(module_name)
    except ImportError as e:
        # We re-raise the error so the user knows their backend failed to load.
        # name/path are forwarded so callers can still tell which module was
        # actually missing (it may be a dependency of the backend).
        raise ImportError(
            f"Failed to load backend '{module_name}': {e}",
            name=e.name,
            path=e.path,
        ) from e
    # Lazy %-style arguments: the message is only built if DEBUG is enabled.
    logger.debug("Loaded backend: %s", module_name)
38
+
39
+
40
def load_builtins() -> None:
    """Load all built-in backend implementations.

    Every module is loaded through ``load_backend``. Loading is best-effort:
    an ``ImportError`` for one module is logged as a warning and the remaining
    modules are still attempted.
    """
    # Core backends that are expected to be present
    builtin_backends = (
        "mplang.v2.backends.spu_impl",
        "mplang.v2.backends.tensor_impl",
        "mplang.v2.backends.table_impl",
        "mplang.v2.backends.crypto_impl",
        "mplang.v2.backends.tee_impl",
        "mplang.v2.backends.bfv_impl",
        "mplang.v2.backends.store_impl",
    )

    for module_name in builtin_backends:
        try:
            load_backend(module_name)
        except ImportError as e:
            # Deliberately swallowed after logging: one missing backend must
            # not prevent the others from loading.
            logger.warning(
                "Could not load built-in backend '%s': %s", module_name, e
            )