mplang-nightly 0.1.dev268__py3-none-any.whl → 0.1.dev270__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. mplang/__init__.py +391 -17
  2. mplang/{v2/backends → backends}/__init__.py +9 -7
  3. mplang/{v2/backends → backends}/bfv_impl.py +6 -6
  4. mplang/{v2/backends → backends}/crypto_impl.py +6 -6
  5. mplang/{v2/backends → backends}/field_impl.py +5 -5
  6. mplang/{v2/backends → backends}/func_impl.py +4 -4
  7. mplang/{v2/backends → backends}/phe_impl.py +3 -3
  8. mplang/{v2/backends → backends}/simp_design.md +1 -1
  9. mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
  10. mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
  11. mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
  12. mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
  13. mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
  14. mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
  15. mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
  16. mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
  17. mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
  18. mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
  19. mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
  20. mplang/{v2/backends → backends}/spu_impl.py +8 -8
  21. mplang/{v2/backends → backends}/spu_state.py +4 -4
  22. mplang/{v2/backends → backends}/store_impl.py +3 -3
  23. mplang/{v2/backends → backends}/table_impl.py +8 -8
  24. mplang/{v2/backends → backends}/tee_impl.py +6 -6
  25. mplang/{v2/backends → backends}/tensor_impl.py +6 -6
  26. mplang/{v2/cli.py → cli.py} +9 -9
  27. mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
  28. mplang/{v2/dialects → dialects}/__init__.py +5 -5
  29. mplang/{v2/dialects → dialects}/bfv.py +6 -6
  30. mplang/{v2/dialects → dialects}/crypto.py +5 -5
  31. mplang/{v2/dialects → dialects}/dtypes.py +2 -2
  32. mplang/{v2/dialects → dialects}/field.py +3 -3
  33. mplang/{v2/dialects → dialects}/func.py +2 -2
  34. mplang/{v2/dialects → dialects}/phe.py +6 -6
  35. mplang/{v2/dialects → dialects}/simp.py +6 -6
  36. mplang/{v2/dialects → dialects}/spu.py +7 -7
  37. mplang/{v2/dialects → dialects}/store.py +2 -2
  38. mplang/{v2/dialects → dialects}/table.py +3 -3
  39. mplang/{v2/dialects → dialects}/tee.py +6 -6
  40. mplang/{v2/dialects → dialects}/tensor.py +5 -5
  41. mplang/{v2/edsl → edsl}/__init__.py +3 -3
  42. mplang/{v2/edsl → edsl}/context.py +6 -6
  43. mplang/{v2/edsl → edsl}/graph.py +5 -5
  44. mplang/{v2/edsl → edsl}/jit.py +2 -2
  45. mplang/{v2/edsl → edsl}/object.py +1 -1
  46. mplang/{v2/edsl → edsl}/primitive.py +5 -5
  47. mplang/{v2/edsl → edsl}/printer.py +1 -1
  48. mplang/{v2/edsl → edsl}/serde.py +1 -1
  49. mplang/{v2/edsl → edsl}/tracer.py +7 -7
  50. mplang/{v2/edsl → edsl}/typing.py +1 -1
  51. mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
  52. mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
  53. mplang/{v2/kernels → kernels}/okvs_opt.cpp +46 -31
  54. mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
  55. mplang/{v2/libs → libs}/collective.py +5 -5
  56. mplang/{v2/libs → libs}/device/__init__.py +1 -1
  57. mplang/{v2/libs → libs}/device/api.py +12 -12
  58. mplang/{v2/libs → libs}/ml/__init__.py +1 -1
  59. mplang/{v2/libs → libs}/ml/sgb.py +4 -4
  60. mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
  61. mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
  62. mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
  63. mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
  64. mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
  65. mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
  66. mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
  67. mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
  68. mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
  69. mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
  70. mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +19 -13
  71. mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
  72. mplang/libs/mpc/psi/rr22.py +303 -0
  73. mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
  74. mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
  75. mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
  76. mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
  77. mplang/{v2/runtime → runtime}/interpreter.py +11 -11
  78. mplang/{v2/runtime → runtime}/value.py +2 -2
  79. mplang/{v1/runtime → utils}/__init__.py +18 -15
  80. mplang/{v1/utils → utils}/func_utils.py +1 -1
  81. {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/METADATA +2 -2
  82. mplang_nightly-0.1.dev270.dist-info/RECORD +102 -0
  83. mplang/v1/__init__.py +0 -157
  84. mplang/v1/_device.py +0 -602
  85. mplang/v1/analysis/__init__.py +0 -37
  86. mplang/v1/analysis/diagram.py +0 -567
  87. mplang/v1/core/__init__.py +0 -157
  88. mplang/v1/core/cluster.py +0 -343
  89. mplang/v1/core/comm.py +0 -281
  90. mplang/v1/core/context_mgr.py +0 -50
  91. mplang/v1/core/dtypes.py +0 -335
  92. mplang/v1/core/expr/__init__.py +0 -80
  93. mplang/v1/core/expr/ast.py +0 -542
  94. mplang/v1/core/expr/evaluator.py +0 -581
  95. mplang/v1/core/expr/printer.py +0 -285
  96. mplang/v1/core/expr/transformer.py +0 -141
  97. mplang/v1/core/expr/utils.py +0 -78
  98. mplang/v1/core/expr/visitor.py +0 -85
  99. mplang/v1/core/expr/walk.py +0 -387
  100. mplang/v1/core/interp.py +0 -160
  101. mplang/v1/core/mask.py +0 -325
  102. mplang/v1/core/mpir.py +0 -965
  103. mplang/v1/core/mpobject.py +0 -117
  104. mplang/v1/core/mptype.py +0 -407
  105. mplang/v1/core/pfunc.py +0 -130
  106. mplang/v1/core/primitive.py +0 -877
  107. mplang/v1/core/table.py +0 -218
  108. mplang/v1/core/tensor.py +0 -75
  109. mplang/v1/core/tracer.py +0 -383
  110. mplang/v1/host.py +0 -130
  111. mplang/v1/kernels/__init__.py +0 -41
  112. mplang/v1/kernels/base.py +0 -125
  113. mplang/v1/kernels/basic.py +0 -240
  114. mplang/v1/kernels/context.py +0 -369
  115. mplang/v1/kernels/crypto.py +0 -122
  116. mplang/v1/kernels/fhe.py +0 -858
  117. mplang/v1/kernels/mock_tee.py +0 -72
  118. mplang/v1/kernels/phe.py +0 -1864
  119. mplang/v1/kernels/spu.py +0 -341
  120. mplang/v1/kernels/sql_duckdb.py +0 -44
  121. mplang/v1/kernels/stablehlo.py +0 -90
  122. mplang/v1/kernels/value.py +0 -626
  123. mplang/v1/ops/__init__.py +0 -35
  124. mplang/v1/ops/base.py +0 -424
  125. mplang/v1/ops/basic.py +0 -294
  126. mplang/v1/ops/crypto.py +0 -262
  127. mplang/v1/ops/fhe.py +0 -272
  128. mplang/v1/ops/jax_cc.py +0 -147
  129. mplang/v1/ops/nnx_cc.py +0 -168
  130. mplang/v1/ops/phe.py +0 -216
  131. mplang/v1/ops/spu.py +0 -151
  132. mplang/v1/ops/sql_cc.py +0 -303
  133. mplang/v1/ops/tee.py +0 -36
  134. mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
  135. mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
  136. mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
  137. mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
  138. mplang/v1/runtime/channel.py +0 -230
  139. mplang/v1/runtime/cli.py +0 -451
  140. mplang/v1/runtime/client.py +0 -456
  141. mplang/v1/runtime/communicator.py +0 -131
  142. mplang/v1/runtime/data_providers.py +0 -303
  143. mplang/v1/runtime/driver.py +0 -324
  144. mplang/v1/runtime/exceptions.py +0 -27
  145. mplang/v1/runtime/http_api.md +0 -56
  146. mplang/v1/runtime/link_comm.py +0 -196
  147. mplang/v1/runtime/server.py +0 -501
  148. mplang/v1/runtime/session.py +0 -270
  149. mplang/v1/runtime/simulation.py +0 -324
  150. mplang/v1/simp/__init__.py +0 -13
  151. mplang/v1/simp/api.py +0 -353
  152. mplang/v1/simp/mpi.py +0 -131
  153. mplang/v1/simp/party.py +0 -225
  154. mplang/v1/simp/random.py +0 -120
  155. mplang/v1/simp/smpc.py +0 -238
  156. mplang/v1/utils/__init__.py +0 -13
  157. mplang/v1/utils/crypto.py +0 -32
  158. mplang/v1/utils/spu_utils.py +0 -130
  159. mplang/v1/utils/table_utils.py +0 -185
  160. mplang/v2/__init__.py +0 -424
  161. mplang/v2/libs/mpc/psi/rr22.py +0 -344
  162. mplang_nightly-0.1.dev268.dist-info/RECORD +0 -180
  163. /mplang/{v2/backends → backends}/channel.py +0 -0
  164. /mplang/{v2/edsl → edsl}/README.md +0 -0
  165. /mplang/{v2/edsl → edsl}/registry.py +0 -0
  166. /mplang/{v2/kernels → kernels}/Makefile +0 -0
  167. /mplang/{v2/kernels → kernels}/__init__.py +0 -0
  168. /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
  169. /mplang/{v2/libs → libs}/device/cluster.py +0 -0
  170. /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
  171. /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
  172. /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
  173. /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
  174. /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
  175. /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
  176. /mplang/{v2/runtime → runtime}/__init__.py +0 -0
  177. /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
  178. /mplang/{v2/runtime → runtime}/object_store.py +0 -0
  179. {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/WHEEL +0 -0
  180. {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/entry_points.txt +0 -0
  181. {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/licenses/LICENSE +0 -0
mplang/__init__.py CHANGED
@@ -12,35 +12,409 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- """Multi-Party Programming Language for Secure Computation.
15
+ """MPLang2: Next generation EDSL for multi-party computation.
16
16
 
17
- This module provides access to both v1 and v2 APIs:
17
+ This is the temporary home for the new EDSL implementation during migration.
18
+ Once migration is complete, this will replace the original mplang package.
18
19
 
19
- # Use v1 (current default, will be deprecated)
20
- import mplang as mp
21
- # or explicitly:
22
- import mplang.v1 as mp
20
+ Public API is designed to be compatible with mplang v1 where possible.
21
+ """
23
22
 
24
- # Use v2 (recommended for new code)
25
- import mplang.v2 as mp
23
+ from __future__ import annotations
26
24
 
27
- v1 API will be deprecated in a future release. Please migrate to v2.
28
- """
25
+ from collections.abc import Callable
26
+ from typing import Any
29
27
 
30
28
  # Version is managed by hatch-vcs and available after package installation
31
29
  try:
32
- from importlib.metadata import PackageNotFoundError, version
30
+ from importlib.metadata import version
33
31
 
34
32
  __version__ = version("mplang")
35
- except PackageNotFoundError:
33
+ except Exception:
36
34
  # Fallback for development/editable installs when package is not installed
37
35
  __version__ = "0.0.0-dev"
38
36
 
37
+ from mplang import dialects
38
+ from mplang.backends.simp_driver.ops import DRIVER_HANDLERS
39
+ from mplang.backends.simp_worker import SimpWorker
40
+ from mplang.backends.simp_worker.mem import LocalMesh
41
+ from mplang.backends.simp_worker.ops import WORKER_HANDLERS
42
+ from mplang.dialects.simp import make_driver, make_simulator
43
+ from mplang.edsl import (
44
+ Graph,
45
+ GraphPrinter,
46
+ Object,
47
+ Operation,
48
+ Primitive,
49
+ TracedFunction,
50
+ Tracer,
51
+ Value,
52
+ find_context,
53
+ find_context_with_state,
54
+ find_interpreter,
55
+ format_graph,
56
+ get_current_context,
57
+ get_default_context,
58
+ is_tracing,
59
+ jit,
60
+ pop_context,
61
+ primitive,
62
+ push_context,
63
+ register_default_context_factory,
64
+ set_root_context,
65
+ trace,
66
+ )
67
+ from mplang.edsl.registry import get_profiler
68
+
69
+ # Type system
70
+ from mplang.edsl.typing import (
71
+ MPType,
72
+ ScalarType,
73
+ SSType,
74
+ TableType,
75
+ TensorType,
76
+ VectorType,
77
+ )
78
+
79
+ # =============================================================================
80
+ # Device API (compatible with mplang v1)
81
+ # =============================================================================
82
+ from mplang.libs.device import (
83
+ ClusterSpec,
84
+ Device,
85
+ Node,
86
+ device,
87
+ get_dev_attr,
88
+ is_device_obj,
89
+ jax_fn,
90
+ put,
91
+ set_dev_attr,
92
+ )
93
+ from mplang.libs.device import fetch as device_fetch
94
+ from mplang.runtime.interpreter import Interpreter
95
+
96
+ # =============================================================================
97
+ # Context Management API (JAX-like pattern)
98
+ # =============================================================================
99
+
100
+
101
+ def _get_context(context: Interpreter | None) -> Interpreter:
102
+ """Get context from parameter or context stack."""
103
+ if context is not None:
104
+ return context
105
+ ctx = get_current_context()
106
+ if ctx is None:
107
+ raise RuntimeError(
108
+ "No context available. Either pass context explicitly or use "
109
+ "set_context()/push_context() to set a default context."
110
+ )
111
+ if not isinstance(ctx, Interpreter):
112
+ raise RuntimeError(
113
+ f"Current context is not an Interpreter: {type(ctx).__name__}. "
114
+ "Use mp.set_context(interpreter) to set the execution context."
115
+ )
116
+ return ctx
117
+
118
+
119
+ # =============================================================================
120
+ # Meta-APIs (compile, evaluate, fetch)
121
+ # =============================================================================
122
+
123
+
124
+ def evaluate(
125
+ fn: Callable[..., Any] | TracedFunction,
126
+ *args: Any,
127
+ context: Interpreter | None = None,
128
+ **kwargs: Any,
129
+ ) -> Any:
130
+ """Evaluate a function or traced function.
131
+
132
+ Args:
133
+ fn: The function or TracedFunction to evaluate.
134
+ *args: Positional arguments to pass to the function.
135
+ context: Optional interpreter context. If None, uses current context.
136
+ **kwargs: Keyword arguments to pass to the function.
137
+
138
+ Returns:
139
+ The result of the function evaluation.
140
+
141
+ Example:
142
+ >>> with mp.make_simulator(3) as sim:
143
+ ... result = mp.evaluate(traced) # uses sim from context
144
+ >>> # Or explicitly:
145
+ >>> result = mp.evaluate(traced, context=sim)
146
+ """
147
+ from mplang.edsl.tracer import TracedFunction
148
+ from mplang.runtime.interpreter import InterpObject
149
+
150
+ interp = _get_context(context)
151
+
152
+ def unwrap_if_interp(val: Any) -> Any:
153
+ """Unwrap InterpObject to runtime value at execution boundary."""
154
+ if isinstance(val, InterpObject):
155
+ return val.runtime_obj
156
+ return val
157
+
158
+ with interp:
159
+ if isinstance(fn, TracedFunction):
160
+ inputs = fn.prepare_inputs(*args, **kwargs)
161
+ inputs = [unwrap_if_interp(v) for v in inputs]
162
+ raw_result = interp.evaluate_graph(fn.graph, inputs)
163
+ wrapped = [
164
+ InterpObject(v, fn.graph.outputs[i].type, interp)
165
+ for i, v in enumerate(raw_result)
166
+ ]
167
+ return fn.reconstruct_outputs(wrapped)
168
+
169
+ return fn(*args, **kwargs)
170
+
171
+
172
+ def fetch(
173
+ result: Any,
174
+ *,
175
+ follow_device: bool = True,
176
+ context: Interpreter | None = None,
177
+ ) -> Any:
178
+ """Fetch results from interpreter context to Python.
179
+
180
+ This is a meta-function that operates at execution boundaries, not a traced
181
+ dialect operation. It brings data from the distributed/MPC runtime back to
182
+ the Python host.
183
+
184
+ Behavior in different contexts:
185
+ - **Tracing (compile)**: Returns the input unchanged (identity). The graph
186
+ outputs are determined by the function's return statement, not fetch calls.
187
+ - **Execution (evaluate)**: Actually fetches data from workers/parties.
188
+
189
+ Design Note (A vs B tradeoff):
190
+ Two designs were considered for fetch behavior during tracing:
191
+
192
+ - **Design A (chosen)**: fetch = identity during tracing. Graph outputs
193
+ are determined solely by the return statement. This is simpler and
194
+ avoids ambiguity when fetch and return reference different values.
195
+
196
+ - **Design B (alternative)**: fetch marks output points in the graph.
197
+ This would allow fetch(a), fetch(b), return b to output both a and b.
198
+ However, it complicates the semantics and requires tracking fetch
199
+ points separately from return values.
200
+
201
+ Design A was chosen for simplicity. If a value needs to be an output,
202
+ it should be returned. fetch's role is purely for execution-time I/O.
203
+
204
+ Args:
205
+ result: Object(s) to fetch. Can be a single InterpObject, DriverVar,
206
+ or nested structure containing them.
207
+ follow_device: If True and object has device attribute, dispatch to
208
+ device.fetch which fetches from the correct rank based on device.
209
+ If False, fetch from all parties.
210
+ context: Interpreter context. If None, uses current context.
211
+
212
+ Returns:
213
+ Fetched Python values. For device objects with follow_device=True,
214
+ returns single value from the device's rank(s). Otherwise returns
215
+ list of values (one per party) or single value for world_size=1.
216
+ During tracing, returns the input unchanged.
217
+ """
218
+ from jax.tree_util import tree_map
219
+
220
+ from mplang.backends.simp_driver.values import DriverVar
221
+ from mplang.edsl.context import is_tracing
222
+ from mplang.runtime.interpreter import InterpObject
223
+ from mplang.runtime.value import WrapValue
224
+
225
+ # Check if we are in tracing context - if so, return identity
226
+ if is_tracing():
227
+ # Design A: fetch = identity during tracing
228
+ # Graph outputs are determined by return statement, not fetch calls
229
+ return result
230
+
231
+ # Execution context - actually fetch data
232
+ interp = _get_context(context)
233
+
234
+ def _fetch_single(var: Any) -> Any:
235
+ """Fetch a single value from InterpObject."""
236
+ # InterpObject (from mp.evaluate) - extract runtime_obj
237
+ if isinstance(var, InterpObject):
238
+ if follow_device and is_device_obj(var):
239
+ return device_fetch(var)
240
+ var = var.runtime_obj # extract and continue processing
241
+
242
+ # DriverVar (simp dialect) - remote fetch from workers
243
+ if isinstance(var, DriverVar):
244
+ from mplang.backends.simp_driver.state import SimpDriver
245
+
246
+ simp_state = interp.get_dialect_state("simp")
247
+ assert isinstance(simp_state, SimpDriver), "DriverVar requires simp state"
248
+
249
+ resolved: list[Any] = []
250
+ for rank, uri in enumerate(var.values):
251
+ if uri is None:
252
+ resolved.append(None)
253
+ else:
254
+ fetched = simp_state.fetch(rank, uri).result()
255
+ if isinstance(fetched, WrapValue):
256
+ fetched = fetched.data
257
+ resolved.append(fetched)
258
+
259
+ return resolved[0] if len(resolved) == 1 else resolved
260
+
261
+ # WrapValue (TensorValue, TableValue, etc.) - unwrap
262
+ if isinstance(var, WrapValue):
263
+ return var.data
264
+
265
+ # Plain values pass through
266
+ return var
267
+
268
+ with interp:
269
+ return tree_map(_fetch_single, result)
270
+
271
+
272
+ # Alias for compatibility
273
+ def function(fn: Callable[..., Any] | None = None) -> Callable[..., Any]:
274
+ """Decorator defining a Multi-Party Function (MP Program).
275
+
276
+ This decorator "lifts" a local function into a distributed program by
277
+ automatically wrapping it in a `simp.pcall_static` that targets ALL available
278
+ parties in the current context.
279
+
280
+ Semantics: f(args) -> pcall(ALL, f, args)
281
+
282
+ Args:
283
+ fn: The function to decorate.
284
+
285
+ Returns:
286
+ A wrapper function that, when called, executes the original function
287
+ on all workers.
288
+ """
289
+ import functools
290
+
291
+ from mplang.dialects import simp
292
+
293
+ if fn is None:
294
+ return function
295
+
296
+ def has_simp_state(ctx: Any) -> bool:
297
+ if hasattr(ctx, "get_dialect_state"):
298
+ state = ctx.get_dialect_state("simp")
299
+ return state is not None and hasattr(state, "world_size")
300
+ return False
301
+
302
+ @functools.wraps(fn)
303
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
304
+ from mplang.edsl.context import find_context
305
+
306
+ # Find context with simp dialect state
307
+ ctx = find_context(has_simp_state)
308
+ if ctx is None:
309
+ raise RuntimeError(
310
+ "mp.function requires a context with world_size information "
311
+ "(e.g. SimpSimulator or Driver initialized)."
312
+ )
313
+
314
+ # ctx found by predicate so we know it has get_dialect_state
315
+ simp_state = ctx.get_dialect_state("simp") # type: ignore[attr-defined]
316
+ world_size = simp_state.world_size # type: ignore
317
+
318
+ all_parties = tuple(range(world_size))
319
+ return simp.pcall_static(all_parties, fn, *args, **kwargs)
320
+
321
+ return wrapper
322
+
323
+
324
+ def compile(
325
+ fn: Callable[..., Any],
326
+ *args: Any,
327
+ context: Interpreter | None = None,
328
+ **kwargs: Any,
329
+ ) -> TracedFunction:
330
+ """Compile a function to get its IR without executing it.
331
+
332
+ Args:
333
+ fn: The function to compile.
334
+ *args: Example arguments for tracing.
335
+ context: Optional interpreter context. If None, uses current context.
336
+ **kwargs: Example keyword arguments for tracing.
337
+
338
+ Returns:
339
+ TracedFunction with the compiled graph.
340
+
341
+ Example:
342
+ >>> with mp.make_simulator(3) as sim:
343
+ ... traced = mp.compile(job) # uses sim from context
344
+ """
345
+ # If a context is explicitly provided, push it before tracing
346
+ # so that _resolve_cluster() can find it.
347
+ if context is not None:
348
+ with context:
349
+ return trace(fn, *args, **kwargs)
350
+
351
+ # Otherwise, rely on the caller having pushed an interpreter context.
352
+ # _resolve_cluster() will traverse the stack to find the interpreter.
353
+ return trace(fn, *args, **kwargs)
354
+
355
+
39
356
  # =============================================================================
40
- # Default: Re-export v1 API for backward compatibility
357
+ # Public API
41
358
  # =============================================================================
42
- # This will be changed to v2 in a future release
359
+ __all__ = [ # noqa: RUF022
360
+ # Version
361
+ "__version__",
362
+ # Device API
363
+ "ClusterSpec",
364
+ "Device",
365
+ "Node",
366
+ "device",
367
+ "get_dev_attr",
368
+ "is_device_obj",
369
+ "jax_fn",
370
+ "put",
371
+ "set_dev_attr",
372
+ # Core EDSL
373
+ "Graph",
374
+ "GraphPrinter",
375
+ "Object",
376
+ "Operation",
377
+ "Primitive",
378
+ "TracedFunction",
379
+ "Tracer",
380
+ "Value",
381
+ "compile",
382
+ "evaluate",
383
+ "fetch",
384
+ "find_context",
385
+ "find_context_with_state",
386
+ "find_interpreter",
387
+ "format_graph",
388
+ "function",
389
+ "get_current_context",
390
+ "get_default_context",
391
+ "is_tracing",
392
+ "jit",
393
+ "pop_context",
394
+ "primitive",
395
+ "push_context",
396
+ "set_root_context",
397
+ "trace",
398
+ # Type system
399
+ "MPType",
400
+ "ScalarType",
401
+ "SSType",
402
+ "TableType",
403
+ "TensorType",
404
+ "VectorType",
405
+ # Backend / Runtime
406
+ "DRIVER_HANDLERS",
407
+ "Interpreter",
408
+ "LocalMesh",
409
+ "SimpWorker",
410
+ "WORKER_HANDLERS",
411
+ "make_driver",
412
+ "make_simulator",
413
+ # Dialects
414
+ "dialects",
415
+ "register_default_context_factory",
416
+ "get_profiler",
417
+ ]
43
418
 
44
- # Make v1 and v2 subpackages directly accessible
45
- from mplang import v1, v2 # noqa: F401
46
- from mplang.v1 import * # noqa: F403
419
+ # Register Interpreter as default context factory
420
+ register_default_context_factory(Interpreter)
@@ -41,13 +41,15 @@ def load_builtins() -> None:
41
41
  """Load all built-in backend implementations."""
42
42
  # Core backends that are expected to be present
43
43
  builtin_backends = [
44
- "mplang.v2.backends.spu_impl",
45
- "mplang.v2.backends.tensor_impl",
46
- "mplang.v2.backends.table_impl",
47
- "mplang.v2.backends.crypto_impl",
48
- "mplang.v2.backends.tee_impl",
49
- "mplang.v2.backends.bfv_impl",
50
- "mplang.v2.backends.store_impl",
44
+ "mplang.backends.spu_impl",
45
+ "mplang.backends.tensor_impl",
46
+ "mplang.backends.table_impl",
47
+ "mplang.backends.crypto_impl",
48
+ "mplang.backends.tee_impl",
49
+ "mplang.backends.bfv_impl",
50
+ "mplang.backends.store_impl",
51
+ "mplang.backends.func_impl",
52
+ "mplang.backends.field_impl",
51
53
  ]
52
54
 
53
55
  for module_name in builtin_backends:
@@ -29,12 +29,12 @@ import numpy as np
29
29
  import tenseal as ts
30
30
  import tenseal.sealapi as sealapi
31
31
 
32
- from mplang.v2.backends.tensor_impl import TensorValue
33
- from mplang.v2.dialects import bfv
34
- from mplang.v2.edsl import serde
35
- from mplang.v2.edsl.graph import Operation
36
- from mplang.v2.runtime.interpreter import Interpreter
37
- from mplang.v2.runtime.value import Value, WrapValue
32
+ from mplang.backends.tensor_impl import TensorValue
33
+ from mplang.dialects import bfv
34
+ from mplang.edsl import serde
35
+ from mplang.edsl.graph import Operation
36
+ from mplang.runtime.interpreter import Interpreter
37
+ from mplang.runtime.value import Value, WrapValue
38
38
 
39
39
  # =============================================================================
40
40
  # Helper for SEAL serialization
@@ -26,12 +26,12 @@ import coincurve
26
26
  import numpy as np
27
27
  from cryptography.hazmat.primitives.ciphers.aead import AESGCM
28
28
 
29
- from mplang.v2.backends.tensor_impl import TensorValue
30
- from mplang.v2.dialects import crypto
31
- from mplang.v2.edsl import serde
32
- from mplang.v2.edsl.graph import Operation
33
- from mplang.v2.runtime.interpreter import Interpreter
34
- from mplang.v2.runtime.value import Value, WrapValue
29
+ from mplang.backends.tensor_impl import TensorValue
30
+ from mplang.dialects import crypto
31
+ from mplang.edsl import serde
32
+ from mplang.edsl.graph import Operation
33
+ from mplang.runtime.interpreter import Interpreter
34
+ from mplang.runtime.value import Value, WrapValue
35
35
 
36
36
  # =============================================================================
37
37
  # BytesValue - Wrapper for raw bytes (keys, hashes, ciphertexts)
@@ -29,11 +29,11 @@ import threading
29
29
  import jax.numpy as jnp
30
30
  import numpy as np
31
31
 
32
- from mplang.v2.backends.tensor_impl import TensorValue, _unwrap, _wrap
33
- from mplang.v2.dialects import field
34
- from mplang.v2.edsl.graph import Operation
35
- from mplang.v2.kernels import py_kernels
36
- from mplang.v2.runtime.interpreter import Interpreter
32
+ from mplang.backends.tensor_impl import TensorValue, _unwrap, _wrap
33
+ from mplang.dialects import field
34
+ from mplang.edsl.graph import Operation
35
+ from mplang.kernels import py_kernels
36
+ from mplang.runtime.interpreter import Interpreter
37
37
 
38
38
  # =============================================================================
39
39
  # Kernel Loading
@@ -23,10 +23,10 @@ from __future__ import annotations
23
23
 
24
24
  from typing import TYPE_CHECKING, Any, ClassVar
25
25
 
26
- from mplang.v2.dialects.func import call_p, func_def_p
27
- from mplang.v2.edsl import serde
28
- from mplang.v2.edsl.graph import Graph, Operation
29
- from mplang.v2.runtime.value import Value
26
+ from mplang.dialects.func import call_p, func_def_p
27
+ from mplang.edsl import serde
28
+ from mplang.edsl.graph import Graph, Operation
29
+ from mplang.runtime.value import Value
30
30
 
31
31
  if TYPE_CHECKING:
32
32
  from typing import Self
@@ -22,9 +22,9 @@ from typing import Any, cast
22
22
  from lightphe import LightPHE
23
23
  from lightphe.models.Ciphertext import Ciphertext
24
24
 
25
- from mplang.v2.dialects import phe
26
- from mplang.v2.edsl.graph import Operation
27
- from mplang.v2.runtime.interpreter import Interpreter
25
+ from mplang.dialects import phe
26
+ from mplang.edsl.graph import Operation
27
+ from mplang.runtime.interpreter import Interpreter
28
28
 
29
29
 
30
30
  class PHEContext:
@@ -67,7 +67,7 @@ class SimpDriver(DialectState, ABC):
67
67
  """Abstract interface for simp drivers."""
68
68
  dialect_name = "simp"
69
69
  world_size: int
70
-
70
+
71
71
  @abstractmethod
72
72
  def submit(self, rank, graph, inputs, job_id=None) -> Future: ...
73
73
  @abstractmethod
@@ -17,16 +17,16 @@
17
17
  Provides Driver-side state, values, and ops for the simp dialect.
18
18
  """
19
19
 
20
- from mplang.v2.backends.simp_driver.http import SimpHttpDriver, make_driver
21
- from mplang.v2.backends.simp_driver.mem import (
20
+ from mplang.backends.simp_driver.http import SimpHttpDriver, make_driver
21
+ from mplang.backends.simp_driver.mem import (
22
22
  LocalCluster,
23
23
  MemCluster,
24
24
  SimpMemDriver,
25
25
  make_simulator,
26
26
  )
27
- from mplang.v2.backends.simp_driver.ops import DRIVER_HANDLERS
28
- from mplang.v2.backends.simp_driver.state import SimpDriver
29
- from mplang.v2.backends.simp_driver.values import DriverVar
27
+ from mplang.backends.simp_driver.ops import DRIVER_HANDLERS
28
+ from mplang.backends.simp_driver.state import SimpDriver
29
+ from mplang.backends.simp_driver.values import DriverVar
30
30
 
31
31
  __all__ = [
32
32
  "DRIVER_HANDLERS",
@@ -23,16 +23,16 @@ from typing import TYPE_CHECKING, Any
23
23
 
24
24
  import httpx
25
25
 
26
- from mplang.v2.backends.simp_driver.state import SimpDriver
27
- from mplang.v2.edsl import serde
28
- from mplang.v2.runtime.interpreter import Interpreter
29
- from mplang.v2.runtime.object_store import ObjectStore
26
+ from mplang.backends.simp_driver.state import SimpDriver
27
+ from mplang.edsl import serde
28
+ from mplang.runtime.interpreter import Interpreter
29
+ from mplang.runtime.object_store import ObjectStore
30
30
 
31
31
  if TYPE_CHECKING:
32
32
  from concurrent.futures import Future
33
33
 
34
- from mplang.v2.edsl.graph import Graph
35
- from mplang.v2.libs.device import ClusterSpec
34
+ from mplang.edsl.graph import Graph
35
+ from mplang.libs.device import ClusterSpec
36
36
 
37
37
 
38
38
  class SimpHttpDriver(SimpDriver):
@@ -142,10 +142,10 @@ def make_driver(endpoints: list[str], *, cluster_spec: Any = None) -> Interprete
142
142
  >>> with interp:
143
143
  ... result = my_func()
144
144
  """
145
- from mplang.v2.backends.simp_driver.ops import DRIVER_HANDLERS
145
+ from mplang.backends.simp_driver.ops import DRIVER_HANDLERS
146
146
 
147
147
  if cluster_spec is None:
148
- from mplang.v2.libs.device import ClusterSpec
148
+ from mplang.libs.device import ClusterSpec
149
149
 
150
150
  cluster_spec = ClusterSpec.simple(
151
151
  world_size=len(endpoints), endpoints=endpoints
@@ -22,17 +22,17 @@ import pathlib
22
22
  from collections.abc import Callable
23
23
  from typing import TYPE_CHECKING, Any, cast
24
24
 
25
- from mplang.v2.backends.simp_driver.state import SimpDriver
26
- from mplang.v2.backends.simp_worker import WORKER_HANDLERS, SimpWorker
27
- from mplang.v2.backends.simp_worker.mem import LocalMesh
28
- from mplang.v2.runtime.interpreter import ExecutionTracer, Interpreter
29
- from mplang.v2.runtime.object_store import ObjectStore
25
+ from mplang.backends.simp_driver.state import SimpDriver
26
+ from mplang.backends.simp_worker import WORKER_HANDLERS, SimpWorker
27
+ from mplang.backends.simp_worker.mem import LocalMesh
28
+ from mplang.runtime.interpreter import ExecutionTracer, Interpreter
29
+ from mplang.runtime.object_store import ObjectStore
30
30
 
31
31
  if TYPE_CHECKING:
32
32
  from concurrent.futures import Future
33
33
 
34
- from mplang.v2.edsl.graph import Graph
35
- from mplang.v2.libs.device import ClusterSpec
34
+ from mplang.edsl.graph import Graph
35
+ from mplang.libs.device import ClusterSpec
36
36
 
37
37
 
38
38
  class MemCluster:
@@ -245,10 +245,10 @@ def make_simulator(
245
245
  >>> with interp:
246
246
  ... result = my_func()
247
247
  """
248
- from mplang.v2.backends.simp_driver.ops import DRIVER_HANDLERS
248
+ from mplang.backends.simp_driver.ops import DRIVER_HANDLERS
249
249
 
250
250
  if cluster_spec is None:
251
- from mplang.v2.libs.device import ClusterSpec
251
+ from mplang.libs.device import ClusterSpec
252
252
 
253
253
  cluster_spec = ClusterSpec.simple(world_size)
254
254
 
@@ -23,10 +23,10 @@ from __future__ import annotations
23
23
 
24
24
  from typing import Any
25
25
 
26
- from mplang.v2.backends.simp_driver.values import DriverVar
27
- from mplang.v2.dialects import simp
28
- from mplang.v2.edsl.graph import Graph, Operation
29
- from mplang.v2.edsl.typing import CustomType
26
+ from mplang.backends.simp_driver.values import DriverVar
27
+ from mplang.dialects import simp
28
+ from mplang.edsl.graph import Graph, Operation
29
+ from mplang.edsl.typing import CustomType
30
30
 
31
31
 
32
32
  def _get_driver_context(interpreter: Any) -> Any:
@@ -19,12 +19,12 @@ from __future__ import annotations
19
19
  from abc import ABC, abstractmethod
20
20
  from typing import TYPE_CHECKING, Any
21
21
 
22
- from mplang.v2.runtime.dialect_state import DialectState
22
+ from mplang.runtime.dialect_state import DialectState
23
23
 
24
24
  if TYPE_CHECKING:
25
25
  from concurrent.futures import Future
26
26
 
27
- from mplang.v2.edsl.graph import Graph
27
+ from mplang.edsl.graph import Graph
28
28
 
29
29
 
30
30
  class SimpDriver(DialectState, ABC):