mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. mplang/__init__.py +21 -45
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +5 -7
  7. mplang/v1/core/__init__.py +157 -0
  8. mplang/{core → v1/core}/cluster.py +30 -14
  9. mplang/{core → v1/core}/comm.py +5 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +13 -14
  14. mplang/{core → v1/core}/expr/evaluator.py +65 -24
  15. mplang/{core → v1/core}/expr/printer.py +24 -18
  16. mplang/{core → v1/core}/expr/transformer.py +3 -3
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +23 -16
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +4 -4
  25. mplang/{core → v1/core}/primitive.py +106 -201
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{api.py → v1/host.py} +38 -6
  30. mplang/v1/kernels/__init__.py +41 -0
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/v1/kernels/basic.py +240 -0
  33. mplang/{kernels → v1/kernels}/context.py +42 -27
  34. mplang/{kernels → v1/kernels}/crypto.py +44 -37
  35. mplang/v1/kernels/fhe.py +858 -0
  36. mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
  37. mplang/{kernels → v1/kernels}/phe.py +263 -57
  38. mplang/{kernels → v1/kernels}/spu.py +137 -48
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
  40. mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
  41. mplang/v1/kernels/value.py +626 -0
  42. mplang/{ops → v1/ops}/__init__.py +5 -16
  43. mplang/{ops → v1/ops}/base.py +2 -5
  44. mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/v1/ops/fhe.py +272 -0
  47. mplang/{ops → v1/ops}/jax_cc.py +33 -68
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -4
  50. mplang/{ops → v1/ops}/spu.py +3 -5
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +9 -24
  53. mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
  54. mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
  55. mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
  56. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  57. mplang/v1/runtime/channel.py +230 -0
  58. mplang/{runtime → v1/runtime}/cli.py +35 -20
  59. mplang/{runtime → v1/runtime}/client.py +19 -8
  60. mplang/{runtime → v1/runtime}/communicator.py +59 -15
  61. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  62. mplang/{runtime → v1/runtime}/driver.py +30 -12
  63. mplang/v1/runtime/link_comm.py +196 -0
  64. mplang/{runtime → v1/runtime}/server.py +58 -42
  65. mplang/{runtime → v1/runtime}/session.py +57 -71
  66. mplang/{runtime → v1/runtime}/simulation.py +55 -28
  67. mplang/v1/simp/api.py +353 -0
  68. mplang/{simp → v1/simp}/mpi.py +8 -9
  69. mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
  70. mplang/{simp → v1/simp}/random.py +21 -22
  71. mplang/v1/simp/smpc.py +238 -0
  72. mplang/v1/utils/table_utils.py +185 -0
  73. mplang/v2/__init__.py +424 -0
  74. mplang/v2/backends/__init__.py +57 -0
  75. mplang/v2/backends/bfv_impl.py +705 -0
  76. mplang/v2/backends/channel.py +217 -0
  77. mplang/v2/backends/crypto_impl.py +723 -0
  78. mplang/v2/backends/field_impl.py +454 -0
  79. mplang/v2/backends/func_impl.py +107 -0
  80. mplang/v2/backends/phe_impl.py +148 -0
  81. mplang/v2/backends/simp_design.md +136 -0
  82. mplang/v2/backends/simp_driver/__init__.py +41 -0
  83. mplang/v2/backends/simp_driver/http.py +168 -0
  84. mplang/v2/backends/simp_driver/mem.py +280 -0
  85. mplang/v2/backends/simp_driver/ops.py +135 -0
  86. mplang/v2/backends/simp_driver/state.py +60 -0
  87. mplang/v2/backends/simp_driver/values.py +52 -0
  88. mplang/v2/backends/simp_worker/__init__.py +29 -0
  89. mplang/v2/backends/simp_worker/http.py +354 -0
  90. mplang/v2/backends/simp_worker/mem.py +102 -0
  91. mplang/v2/backends/simp_worker/ops.py +167 -0
  92. mplang/v2/backends/simp_worker/state.py +49 -0
  93. mplang/v2/backends/spu_impl.py +275 -0
  94. mplang/v2/backends/spu_state.py +187 -0
  95. mplang/v2/backends/store_impl.py +62 -0
  96. mplang/v2/backends/table_impl.py +838 -0
  97. mplang/v2/backends/tee_impl.py +215 -0
  98. mplang/v2/backends/tensor_impl.py +519 -0
  99. mplang/v2/cli.py +603 -0
  100. mplang/v2/cli_guide.md +122 -0
  101. mplang/v2/dialects/__init__.py +36 -0
  102. mplang/v2/dialects/bfv.py +665 -0
  103. mplang/v2/dialects/crypto.py +689 -0
  104. mplang/v2/dialects/dtypes.py +378 -0
  105. mplang/v2/dialects/field.py +210 -0
  106. mplang/v2/dialects/func.py +135 -0
  107. mplang/v2/dialects/phe.py +723 -0
  108. mplang/v2/dialects/simp.py +944 -0
  109. mplang/v2/dialects/spu.py +349 -0
  110. mplang/v2/dialects/store.py +63 -0
  111. mplang/v2/dialects/table.py +407 -0
  112. mplang/v2/dialects/tee.py +346 -0
  113. mplang/v2/dialects/tensor.py +1175 -0
  114. mplang/v2/edsl/README.md +279 -0
  115. mplang/v2/edsl/__init__.py +99 -0
  116. mplang/v2/edsl/context.py +311 -0
  117. mplang/v2/edsl/graph.py +463 -0
  118. mplang/v2/edsl/jit.py +62 -0
  119. mplang/v2/edsl/object.py +53 -0
  120. mplang/v2/edsl/primitive.py +284 -0
  121. mplang/v2/edsl/printer.py +119 -0
  122. mplang/v2/edsl/registry.py +207 -0
  123. mplang/v2/edsl/serde.py +375 -0
  124. mplang/v2/edsl/tracer.py +614 -0
  125. mplang/v2/edsl/typing.py +816 -0
  126. mplang/v2/kernels/Makefile +30 -0
  127. mplang/v2/kernels/__init__.py +23 -0
  128. mplang/v2/kernels/gf128.cpp +148 -0
  129. mplang/v2/kernels/ldpc.cpp +82 -0
  130. mplang/v2/kernels/okvs.cpp +283 -0
  131. mplang/v2/kernels/okvs_opt.cpp +291 -0
  132. mplang/v2/kernels/py_kernels.py +398 -0
  133. mplang/v2/libs/collective.py +330 -0
  134. mplang/v2/libs/device/__init__.py +51 -0
  135. mplang/v2/libs/device/api.py +813 -0
  136. mplang/v2/libs/device/cluster.py +352 -0
  137. mplang/v2/libs/ml/__init__.py +23 -0
  138. mplang/v2/libs/ml/sgb.py +1861 -0
  139. mplang/v2/libs/mpc/__init__.py +41 -0
  140. mplang/v2/libs/mpc/_utils.py +99 -0
  141. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  142. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  143. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  144. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  145. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  146. mplang/v2/libs/mpc/common/constants.py +39 -0
  147. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  148. mplang/v2/libs/mpc/ot/base.py +222 -0
  149. mplang/v2/libs/mpc/ot/extension.py +477 -0
  150. mplang/v2/libs/mpc/ot/silent.py +217 -0
  151. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  152. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  153. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  154. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  155. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  156. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  157. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  158. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  159. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  160. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  161. mplang/v2/libs/mpc/vole/silver.py +336 -0
  162. mplang/v2/runtime/__init__.py +15 -0
  163. mplang/v2/runtime/dialect_state.py +41 -0
  164. mplang/v2/runtime/interpreter.py +871 -0
  165. mplang/v2/runtime/object_store.py +194 -0
  166. mplang/v2/runtime/value.py +141 -0
  167. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
  168. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  169. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  170. mplang/core/__init__.py +0 -92
  171. mplang/device.py +0 -340
  172. mplang/kernels/builtin.py +0 -207
  173. mplang/ops/crypto.py +0 -109
  174. mplang/ops/ibis_cc.py +0 -139
  175. mplang/ops/sql.py +0 -61
  176. mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
  177. mplang/runtime/link_comm.py +0 -131
  178. mplang/simp/smpc.py +0 -201
  179. mplang/utils/table_utils.py +0 -73
  180. mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
  181. /mplang/{core → v1/core}/mask.py +0 -0
  182. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  183. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  184. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  185. /mplang/{kernels → v1/simp}/__init__.py +0 -0
  186. /mplang/{utils → v1/utils}/__init__.py +0 -0
  187. /mplang/{utils → v1/utils}/crypto.py +0 -0
  188. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  189. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  190. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  191. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,944 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """SIMP dialect: SPMD multi-party primitives for EDSL.
16
+
17
+ Provides control flow and communication primitives:
18
+ - pcall_static: Party call with explicit static parties
19
+ - pcall_dynamic: Party call where all parties attempt execution (output always dynamic)
20
+ - shuffle_dynamic, shuffle: Data redistribution
21
+ - converge: Merge disjoint partitions
22
+ - uniform_cond: Uniform conditional (eager mode)
23
+ - while_loop: While loop (eager mode)
24
+
25
+ Primitive definition guideline:
26
+ - Simple ops (add, mul) → use def_abstract_eval
27
+ - Complex ops (control flow, fork tracer) → use def_trace
28
+
29
+ See individual primitive docstrings for detailed documentation.
30
+ """
31
+
32
+ from __future__ import annotations
33
+
34
+ from collections.abc import Callable, Sequence
35
+ from typing import Any, cast
36
+
37
+ from jax.tree_util import tree_flatten, tree_unflatten
38
+
39
+ import mplang.v2.edsl as el
40
+ import mplang.v2.edsl.typing as elt
41
+
42
# ---------------------------------------------------------------------------
# Global configuration
# ---------------------------------------------------------------------------

# Whether to verify predicate uniformity at runtime in uniform_cond.
# This value is copied into the ``verify_uniform`` attribute of every
# ``simp.uniform_cond`` op at the moment that op is traced, so changing it
# only affects ops traced afterwards.
# Set to False to disable runtime checks (useful for testing or when
# uniformity is guaranteed).
VERIFY_UNIFORM_DEFAULT = True
50
+
51
+ # ---------------------------------------------------------------------------
52
+ # Helper utilities
53
+ # ---------------------------------------------------------------------------
54
+
55
+
56
+ def _validate_scalar_predicate(value: el.graph.Value, context: str) -> None:
57
+ """Validate that a graph value represents a scalar predicate."""
58
+ shape = getattr(value.type, "shape", None)
59
+ if shape is not None and shape != ():
60
+ raise TypeError(
61
+ f"{context} must be scalar, got shape {shape} with type {value.type}"
62
+ )
63
+
64
+
65
+ def _merge_captures(*capture_lists: list[el.Object]) -> list[el.Object]:
66
+ """Merge capture lists while preserving first-seen order and deduplicating by id."""
67
+ seen: dict[int, el.Object] = {}
68
+ for obj in (o for lst in capture_lists for o in lst):
69
+ seen.setdefault(id(obj), obj)
70
+ return list(seen.values())
71
+
72
+
73
def _deduce_parties(types: Sequence[elt.BaseType]) -> tuple[int, ...] | None:
    """Return the sorted intersection of the party sets of all MP-typed entries.

    Non-MP types are ignored. The result is ``None`` when no entry is an
    MPType, or when any MPType has dynamic (``None``) parties, because the
    common party set is then unknown at trace time. An empty tuple means the
    known party sets are disjoint.
    """
    party_sets = [tp.parties for tp in types if isinstance(tp, elt.MPType)]
    if not party_sets or any(p is None for p in party_sets):
        return None

    head, *tail = party_sets
    common = set(head)
    for parties in tail:
        common.intersection_update(parties)
    return tuple(sorted(common))
98
+
99
+
100
class _LocalMPTracer(el.Tracer):
    """Tracer used to trace the body of a single-party region under MP context.

    Its only customization is how outer objects are typed inside the region:
    MP-typed objects are seen through to their ``value_type``, while every
    other type (e.g. a plain TensorType) is kept unchanged and therefore acts
    like a public/replicated value.
    """

    def _lift_type(self, obj: el.Object) -> elt.BaseType:
        """Compute the in-region type of an object lifted into this tracer.

        Args:
            obj: Object being lifted into the local region.

        Returns:
            ``obj.type.value_type`` when ``obj.type`` is an MPType, otherwise
            ``obj.type`` itself.
        """
        outer_type = obj.type
        inner_type = (
            outer_type.value_type
            if isinstance(outer_type, elt.MPType)
            else outer_type
        )
        return cast(elt.BaseType, inner_type)
120
+
121
+
122
+ # ---------------------------------------------------------------------------
123
+ # Control flow (scaffold)
124
+ # ---------------------------------------------------------------------------
125
+
126
+
127
uniform_cond_p: el.Primitive[Any] = el.Primitive("simp.uniform_cond")


@uniform_cond_p.def_trace
def _uniform_cond_trace(
    pred: el.Object,
    then_fn: Callable[..., Any],
    else_fn: Callable[..., Any],
    *args: Any,
    **kwargs: Any,
) -> Any:
    """Implementation for uniform_cond in trace mode.

    Uses def_trace (not def_abstract_eval) because uniform_cond is a complex
    control flow primitive that requires forking tracers for both branches.
    Both branches are traced, and a single ``simp.uniform_cond`` op is emitted
    with the two traced graphs as its regions. Its inputs are
    ``[pred, *branch_args, *merged_captures]``.

    Args:
        pred: Boolean scalar TraceObject (must be uniform across all parties)
        then_fn: Callable accepting (*args, **kwargs) to execute when pred is True
        else_fn: Callable accepting (*args, **kwargs) to execute when pred is False
        *args: Positional arguments to pass to branch functions
        **kwargs: Keyword arguments to pass to branch functions

    Returns:
        Result from tracing both branches (TraceObject or tuple of TraceObjects)

    Raises:
        TypeError: If pred is not a TraceObject, pred is not scalar, branches
            are not callable, or branch outputs have mismatched types/counts.
        RuntimeError: If the number of traced argument variables does not
            match the number of TraceObject parameters.

    Note:
        The op's ``verify_uniform`` attribute is taken from the module-level
        ``VERIFY_UNIFORM_DEFAULT`` at the moment the op is traced. Changing it
        therefore only affects ops traced afterwards. To disable the check, set
        ``mplang.v2.dialects.simp.VERIFY_UNIFORM_DEFAULT = False`` before
        tracing.

    Example:
        >>> def then_fn(x, y):
        ...     return x + y
        >>> def else_fn(x, y):
        ...     return x - y
        >>> result = uniform_cond(pred, then_fn, else_fn, x, y)
    """
    cur_ctx = el.get_current_context()
    assert isinstance(cur_ctx, el.Tracer)

    if not isinstance(pred, el.TraceObject):
        raise TypeError(f"predicate must be TraceObject, got {type(pred)}")
    _validate_scalar_predicate(pred._graph_value, "uniform_cond predicate")
    if not callable(then_fn) or not callable(else_fn):
        raise TypeError("In trace mode, both branches must be callable functions")

    then_traced = el.trace(then_fn, *args, **kwargs)
    else_traced = el.trace(else_fn, *args, **kwargs)
    if not then_traced.is_output_signature_match(else_traced):
        then_types = [v.type for v in then_traced.graph.outputs]
        else_types = [v.type for v in else_traced.graph.outputs]
        raise TypeError(
            "uniform_cond branch output signature mismatch: "
            f"then={then_types}, else={else_types}"
        )

    num_arg_vars = len(then_traced.in_var_pos)

    # Get outer graph values for parameters
    # then_traced.params contains the original TraceObjects from args/kwargs
    outer_arg_values = [
        param._graph_value
        for param in then_traced.params
        if isinstance(param, el.TraceObject)
    ]

    if len(outer_arg_values) != num_arg_vars:
        raise RuntimeError(
            f"uniform_cond: argument count mismatch. Expected {num_arg_vars} variables, "
            f"got {len(outer_arg_values)} from params."
        )

    # Both branches get the same merged capture list. The op's inputs below
    # are laid out as (pred, args, captures) to match it.
    all_captures = _merge_captures(then_traced.captured, else_traced.captured)

    then_traced.align_region_inputs(num_arg_vars, all_captures)
    else_traced.align_region_inputs(num_arg_vars, all_captures)

    capture_trace_objs = [cur_ctx.lift(obj) for obj in all_captures]
    capture_values = [obj._graph_value for obj in capture_trace_objs]

    output_types = [v.type for v in then_traced.graph.outputs]
    cond_inputs = [pred._graph_value, *outer_arg_values, *capture_values]

    result_values = cur_ctx.graph.add_op(
        opcode="simp.uniform_cond",
        inputs=cond_inputs,
        output_types=output_types,
        # Snapshot of the global flag at trace time.
        attrs={"verify_uniform": VERIFY_UNIFORM_DEFAULT},
        regions=[then_traced.graph, else_traced.graph],
    )

    return cur_ctx.reconstruct_outputs(
        then_traced.out_var_pos,
        then_traced.out_imms,
        then_traced.out_tree,
        result_values,
    )
229
+
230
+
231
def uniform_cond(
    pred: el.Object,
    then_fn: Callable[..., Any],
    else_fn: Callable[..., Any],
    *args: Any,
    **kwargs: Any,
) -> Any:
    """Branch on a party-uniform predicate; only the chosen branch runs at runtime.

    Args:
        pred: Boolean scalar TraceObject whose value is the same on every party.
        then_fn: Branch evaluated when ``pred`` is True.
        else_fn: Branch evaluated when ``pred`` is False.
        *args: Extra positional arguments passed to whichever branch runs.
        **kwargs: Extra keyword arguments passed to whichever branch runs.

    Returns:
        The PyTree returned by the selected branch.

    Raises:
        TypeError: If the predicate or branches are invalid, or the two
            branches produce mismatched outputs.
    """
    branch_operands = (pred, then_fn, else_fn, *args)
    return uniform_cond_p.bind(*branch_operands, **kwargs)
255
+
256
+
257
+ # ---------------------------------------------------------------------------
258
+ # While loop (scaffold)
259
+ # ---------------------------------------------------------------------------
260
+
261
while_loop_p: el.Primitive[Any] = el.Primitive("simp.while_loop")


@while_loop_p.def_trace
def _while_loop_trace(
    cond_fn: Callable[[Any], Any],
    body_fn: Callable[[Any], Any],
    init: Any,
) -> Any:
    """Trace-mode implementation for SIMP while_loop.

    ``cond_fn`` and ``body_fn`` are each traced once against ``init``. A single
    ``simp.while_loop`` op is emitted with regions ``[cond, body]``. Its inputs
    are the flattened loop state followed by the values captured by either
    region.

    Args:
        cond_fn: Maps the loop state to a single scalar traced predicate.
        body_fn: Maps the loop state to the next state (same leaf count and
            per-leaf types as ``init``; all leaves must be traced Objects).
        init: PyTree of TraceObjects forming the initial loop state.

    Returns:
        PyTree with ``init``'s structure whose leaves are the op's results.

    Raises:
        TypeError: If ``cond_fn``/``body_fn`` are not callable, ``init`` is
            empty or has non-TraceObject leaves, ``cond_fn`` does not return
            exactly one traced scalar, or ``body_fn``'s outputs do not match
            the loop state in count or type.
    """
    cur_ctx = el.get_current_context()
    assert isinstance(cur_ctx, el.Tracer)
    # Explicit check (not ``assert``) so it survives ``python -O``; mirrors the
    # TypeError uniform_cond raises for non-callable branches.
    if not callable(cond_fn) or not callable(body_fn):
        raise TypeError("while_loop cond_fn and body_fn must be callable")

    state_flat, state_treedef = tree_flatten(init)
    assert state_treedef is not None
    if not state_flat:
        raise TypeError("while_loop init must contain at least one Object")

    # Validate all leaves are TraceObjects
    for leaf in state_flat:
        if not isinstance(leaf, el.TraceObject):
            raise TypeError(
                f"while_loop init leaves must be TraceObject, got {type(leaf)}"
            )

    cond_traced = el.trace(cond_fn, init)
    body_traced = el.trace(body_fn, init)

    # Use params from traced function (same as state_flat filtered to Objects)
    # These are TraceObjects since we're in trace mode
    state_trace_objs = cast(list[el.TraceObject], cond_traced.params)
    state_values = [obj._graph_value for obj in state_trace_objs]
    state_types = [obj.type for obj in state_trace_objs]
    state_count = len(state_trace_objs)

    cond_output_count = len(cond_traced.out_var_pos) + len(cond_traced.out_imms)
    if cond_output_count != 1:
        raise TypeError(
            "while_loop cond_fn must return exactly one output, "
            f"got {cond_output_count}"
        )
    # Only cond_traced.graph is attached to the op. An immediate result lives
    # in out_imms, outside that graph, so the cond region would have no output
    # to evaluate at runtime. Reject it here instead.
    if not cond_traced.out_var_pos:
        raise TypeError(
            "while_loop cond_fn must return a traced Object, not an immediate value"
        )
    _validate_scalar_predicate(
        cond_traced.graph.outputs[0], "while_loop cond_fn output"
    )

    body_output_count = len(body_traced.out_var_pos) + len(body_traced.out_imms)
    if body_output_count != state_count:
        raise TypeError(
            "while_loop body_fn must return same number of values as init state: "
            f"{state_count} expected, got {body_output_count}"
        )
    body_outputs = body_traced.graph.outputs
    if len(body_outputs) != state_count:
        raise TypeError(
            "while_loop body_fn must return all Variables "
            "(no immediates allowed in loop state), "
            f"expected {state_count} Variables, got {len(body_outputs)}"
        )
    for idx, (out_val, state_obj) in enumerate(
        zip(body_outputs, state_trace_objs, strict=True)
    ):
        if out_val.type != state_obj.type:
            raise TypeError(
                "while_loop body_fn output type mismatch at index "
                f"{idx}: {out_val.type} vs {state_obj.type}"
            )

    # Both regions share one capture list, appended after the loop state.
    all_captures = _merge_captures(cond_traced.captured, body_traced.captured)

    cond_traced.align_region_inputs(state_count, all_captures)
    body_traced.align_region_inputs(state_count, all_captures)

    capture_trace_objs = [cur_ctx.lift(obj) for obj in all_captures]
    capture_values = [obj._graph_value for obj in capture_trace_objs]

    loop_inputs = [*state_values, *capture_values]
    result_values = cur_ctx.graph.add_op(
        opcode="simp.while_loop",
        inputs=loop_inputs,
        output_types=state_types,
        regions=[cond_traced.graph, body_traced.graph],
    )

    result_trace_objs = [el.TraceObject(val, cur_ctx) for val in result_values]
    return tree_unflatten(state_treedef, result_trace_objs)
349
+
350
+
351
def while_loop(
    cond_fn: Callable[[Any], Any],
    body_fn: Callable[[Any], Any],
    init: Any,
) -> Any:
    """Run a SIMP while loop that stays synchronized across parties.

    Args:
        cond_fn: Maps the current loop state to a boolean scalar.
        body_fn: Maps the current loop state to the next state. The result
            must have the same PyTree structure and per-leaf types as ``init``.
        init: Initial loop state, a PyTree of Objects shared by all parties.

    Returns:
        The loop state once ``cond_fn`` evaluates to False.

    Raises:
        TypeError: If the outputs of ``cond_fn`` or ``body_fn`` break the
            required shape or type constraints.
    """
    loop_operands = (cond_fn, body_fn, init)
    return while_loop_p.bind(*loop_operands)
373
+
374
+
375
# Core primitives with clear semantic names.
# Trace rule: _pcall_static_trace (output parties = the explicit `parties` arg).
pcall_static_p = el.Primitive[Any]("simp.pcall_static")
# Trace rule: _pcall_dynamic_trace (output parties are always dynamic / None).
pcall_dynamic_p = el.Primitive[Any]("simp.pcall_dynamic")
# Abstract eval: _shuffle_dynamic_ae (runtime index, dynamic output parties).
shuffle_dynamic_p = el.Primitive[el.Object]("simp.shuffle_dynamic")
# Abstract eval: _shuffle_ae (static routing dict target -> source rank).
shuffle_static_p = el.Primitive[el.Object]("simp.shuffle")
# Abstract eval: _converge_ae (merge of disjoint partitions).
converge_p = el.Primitive[el.Object]("simp.converge")
381
+
382
+
383
@pcall_static_p.def_trace
def _pcall_static_trace(
    parties: tuple[int, ...],
    local_fn: Callable[..., Any],
    *args: Any,
    **kwargs: Any,
) -> Any:
    """Trace a local single-party region with explicit static parties.

    Args:
        parties: Required tuple of participating party ranks. Duplicates are
            removed and the ranks are sorted.
        local_fn: Callable representing the single-party function body.
        *args: Positional arguments forming a PyTree of MPObjects /
            TraceObjects / immediates passed to the region.
        **kwargs: Keyword arguments forwarded to ``local_fn``.

    Returns:
        PyTree of TraceObjects with static parties mask.

    Raises:
        TypeError: If ``local_fn`` is not callable or arguments contain invalid types.
        ValueError: If ``parties`` is None, or the requested parties are not
            covered by the parties of the MP-typed inputs.
    """
    cur_ctx = el.get_current_context()
    assert isinstance(cur_ctx, el.Tracer)
    # Raise the documented TypeError explicitly; an ``assert`` would surface as
    # AssertionError and disappear entirely under ``python -O``.
    if not callable(local_fn):
        raise TypeError(f"pcall_static local_fn must be callable, got {type(local_fn)}")

    if parties is None:
        raise ValueError("pcall_static requires explicit parties, got None")

    requested_parties = tuple(sorted(set(parties)))

    local_tracer = _LocalMPTracer()
    local_traced = local_tracer.run(local_fn, *args, **kwargs)

    # Get all input objects: params (function arguments) + captured (closures)
    # TracedFunction guarantees: graph.inputs = [*params_inputs, *captured_inputs]
    all_input_objs = local_traced.params + local_traced.captured

    # Outer input types are not necessarily MPType: _LocalMPTracer passes
    # non-MP types through unchanged (treated as public/replicated), and
    # _deduce_parties ignores them when intersecting party sets.
    all_input_types: list[elt.BaseType] = [obj.type for obj in all_input_objs]
    deduced_parties = _deduce_parties(all_input_types)

    if deduced_parties is not None:
        if not set(requested_parties).issubset(set(deduced_parties)):
            raise ValueError(
                f"Requested parties {requested_parties} not covered by "
                f"input argument parties {deduced_parties}"
            )

    # Re-capture all input objects in outer context
    recaptured_objs = [cur_ctx.lift(obj) for obj in all_input_objs]
    region_inputs = [obj._graph_value for obj in recaptured_objs]
    result_types: list[elt.BaseType] = [
        elt.MPType(value.type, requested_parties)
        for value in local_traced.graph.outputs
    ]

    result_values = cur_ctx.graph.add_op(
        opcode="simp.pcall_static",
        inputs=region_inputs,
        output_types=result_types,
        attrs={
            "fn_name": local_traced.name,
            "parties": list(requested_parties),
        },
        regions=[local_traced.graph],
    )

    return cur_ctx.reconstruct_outputs(
        local_traced.out_var_pos,
        local_traced.out_imms,
        local_traced.out_tree,
        result_values,
    )
458
+
459
+
460
@pcall_dynamic_p.def_trace
def _pcall_dynamic_trace(
    local_fn: Callable[..., Any],
    *args: Any,
    **kwargs: Any,
) -> Any:
    """Trace a party call with dynamic execution.

    All parties attempt to execute. Runtime behavior: each party executes
    if all inputs are present, otherwise outputs None. Output always has
    dynamic parties (None).

    Args:
        local_fn: Callable representing the single-party function body.
        *args: Positional arguments forming a PyTree of MPObjects /
            TraceObjects / immediates passed to the region.
        **kwargs: Keyword arguments forwarded to ``local_fn``.

    Returns:
        PyTree of TraceObjects with dynamic parties (None).

    Raises:
        TypeError: If ``local_fn`` is not callable or arguments contain invalid types.
    """
    cur_ctx = el.get_current_context()
    assert isinstance(cur_ctx, el.Tracer)
    # Raise the documented TypeError explicitly; an ``assert`` would surface as
    # AssertionError and disappear entirely under ``python -O``.
    if not callable(local_fn):
        raise TypeError(
            f"pcall_dynamic local_fn must be callable, got {type(local_fn)}"
        )

    local_tracer = _LocalMPTracer()
    local_traced = local_tracer.run(local_fn, *args, **kwargs)

    # Get all input objects: params (function arguments) + captured (closures)
    # TracedFunction guarantees: graph.inputs = [*params_inputs, *captured_inputs]
    all_input_objs = local_traced.params + local_traced.captured

    recaptured_objs = [cur_ctx.lift(obj) for obj in all_input_objs]
    region_inputs = [obj._graph_value for obj in recaptured_objs]

    # Output always has dynamic parties (None)
    result_types: list[elt.BaseType] = [
        elt.MPType(value.type, None) for value in local_traced.graph.outputs
    ]

    result_values = cur_ctx.graph.add_op(
        opcode="simp.pcall_dynamic",
        inputs=region_inputs,
        output_types=result_types,
        attrs={
            "fn_name": local_traced.name,
        },
        regions=[local_traced.graph],
    )

    return cur_ctx.reconstruct_outputs(
        local_traced.out_var_pos,
        local_traced.out_imms,
        local_traced.out_tree,
        result_values,
    )
519
+
520
+
521
def pcall_static(
    parties: tuple[int, ...],
    local_fn: Callable[..., Any],
    *call_args: Any,
    **call_kwargs: Any,
) -> Any:
    """Execute a function on explicitly specified parties (static).

    This primitive requires explicit party specification and always produces
    static party masks in the output. Use this when the execution parties
    are known at compile time.

    Args:
        parties: Required tuple of party ranks (must be explicit, not None).
        local_fn: Callable representing the single-party computation.
        *call_args: Positional arguments forwarded to ``local_fn``.
        **call_kwargs: Keyword arguments forwarded to ``local_fn``.

    Returns:
        Result with static parties mask matching the parties argument.

    Raises:
        TypeError: If ``local_fn`` is not callable or arguments are invalid.
        ValueError: If ``parties`` is None or not covered by the input parties.

    Example:
        >>> # Compute on parties 0 and 1 (static). ``parties`` and ``local_fn``
        >>> # are passed positionally because extra positional args follow them.
        >>> result = pcall_static((0, 1), lambda x: x + 1, x)
    """
    return pcall_static_p.bind(parties, local_fn, *call_args, **call_kwargs)
547
+
548
+
549
def pcall_dynamic(
    local_fn: Callable[..., Any],
    *call_args: Any,
    **call_kwargs: Any,
) -> Any:
    """Execute a function on all parties with runtime-determined execution.

    All parties attempt to execute the function. At runtime, each party executes
    if all inputs are present, otherwise outputs None. Output always has dynamic
    party mask (None).

    Args:
        local_fn: Callable representing the single-party computation.
        *call_args: Positional arguments forwarded to ``local_fn``.
        **call_kwargs: Keyword arguments forwarded to ``local_fn``.

    Returns:
        Result with dynamic parties (None). At runtime, parties with all inputs
        execute, others output None.

    Raises:
        TypeError: If ``local_fn`` is not callable or arguments are invalid.

    Example:
        >>> # All parties attempt execution based on input availability.
        >>> # ``local_fn`` is positional because extra positional args follow it.
        >>> result = pcall_dynamic(lambda x: x + 1, x)
    """
    return pcall_dynamic_p.bind(local_fn, *call_args, **call_kwargs)
574
+
575
+
576
@shuffle_dynamic_p.def_abstract_eval
def _shuffle_dynamic_ae(src_t: elt.BaseType, index_t: elt.BaseType) -> elt.BaseType:
    """Infer the result type of a shuffle whose routing is decided at runtime.

    Args:
        src_t: Type of the data being redistributed; must be an MPType.
        index_t: Type of the per-party index; must be an MPType whose value
            type is scalar (or has no ``shape`` attribute).

    Returns:
        An MPType over ``src_t.value_type`` with dynamic parties (``None``).

    Raises:
        TypeError: If either input is not MP-typed, or the index is not scalar.
    """
    if not isinstance(src_t, elt.MPType):
        raise TypeError(f"shuffle_dynamic requires MP-typed src, got {src_t}")
    if not isinstance(index_t, elt.MPType):
        raise TypeError(f"shuffle_dynamic requires MP-typed index, got {index_t}")

    index_value_type = index_t.value_type
    index_shape = getattr(index_value_type, "shape", None)
    index_is_scalar = index_shape is None or index_shape == ()
    if not index_is_scalar:
        raise TypeError(
            f"shuffle_dynamic index must be scalar, got shape {index_shape} "
            f"with type {index_t.value_type}"
        )

    # Which parties end up holding data is only known at runtime.
    return elt.MPType(src_t.value_type, None)
605
+
606
+
607
@shuffle_static_p.def_abstract_eval
def _shuffle_ae(src_t: elt.BaseType, routing: dict[int, int]) -> elt.BaseType:
    """Infer the output type of a static (trace-time routed) shuffle.

    Args:
        src_t: Type of the value being redistributed; must be an MPType.
        routing: Receiver-oriented table ``{target_party: source_rank}``.

    Returns:
        MPType carrying ``src_t``'s value type, placed on the routing keys in
        ascending order.

    Raises:
        TypeError: If ``src_t`` is not MP-typed or ``routing`` is not a dict.
        ValueError: If ``routing`` is empty, or (when ``src_t.parties`` is
            known) a source rank in ``routing`` is not one of ``src_t.parties``.
    """
    if not isinstance(src_t, elt.MPType):
        raise TypeError(f"shuffle_static requires MP-typed src, got {src_t}")
    if not isinstance(routing, dict):
        raise TypeError(f"shuffle_static requires routing dict, got {type(routing)}")
    if not routing:
        raise ValueError("shuffle_static requires non-empty routing dict")

    # Receivers are exactly the routing keys; sorting gives a canonical mask.
    receivers = tuple(sorted(routing.keys()))

    src_parties = src_t.parties
    if src_parties is not None:
        # Only a statically known source placement can be checked here.
        for dst, src_rank in routing.items():
            if src_rank in src_parties:
                continue
            raise ValueError(
                f"shuffle_static: routing[{dst}]={src_rank} not in "
                f"src.parties {src_t.parties}"
            )

    return elt.MPType(src_t.value_type, receivers)
645
+
646
+
647
@converge_p.def_abstract_eval
def _converge_ae(in_types: list[elt.BaseType], *, mask: int = -1) -> elt.BaseType:
    """Type inference for converge operation (merge disjoint partitions).

    Args:
        in_types: Types of the inputs being merged. Every entry must be an
            MPType, and all entries must share the same ``value_type``.
        mask: Keyword attribute accepted by the primitive; this type rule does
            not read it.

    Returns:
        MPType with the shared value type. Its parties are ``None`` (dynamic)
        if any input is dynamic; otherwise they are the sorted union of the
        inputs' parties. An all-static union that is empty is also reported as
        ``None``.

    Raises:
        TypeError: If ``in_types`` is empty, an input is not MP-typed, or the
            inputs' value types differ.
        ValueError: If two statically placed inputs share a party.
    """
    if not in_types:
        raise TypeError("converge requires at least one input")

    # Validate and narrow in a single pass; no separate re-filtering needed.
    mp_types: list[elt.MPType] = []
    for i, t in enumerate(in_types):
        if not isinstance(t, elt.MPType):
            raise TypeError(f"converge input {i} must be MP-typed, got {t}")
        mp_types.append(t)

    # All inputs must carry the same payload type.
    first_vtype = mp_types[0].value_type
    for i, mt in enumerate(mp_types[1:], 1):
        if mt.value_type != first_vtype:
            raise TypeError(
                f"converge value type mismatch at input {i}: "
                f"{mt.value_type} vs {first_vtype}"
            )

    # A single dynamic input makes the merged placement dynamic, and
    # disjointness cannot be checked statically in that case.
    static_parties: list[tuple[int, ...]] = []
    for mt in mp_types:
        if mt.parties is None:
            return elt.MPType(first_vtype, None)
        static_parties.append(mt.parties)

    # Static case: inputs must be pairwise disjoint.
    party_sets = [set(p) for p in static_parties]
    for i, lhs in enumerate(party_sets):
        for j in range(i + 1, len(party_sets)):
            overlap = lhs & party_sets[j]
            if overlap:
                raise ValueError(
                    f"converge requires disjoint parties, inputs {i} and {j} "
                    f"overlap: {overlap}"
                )

    all_parties = {rank for ranks in party_sets for rank in ranks}
    # NOTE(review): an empty union is reported as dynamic (None) rather than
    # as an empty static mask. Kept for compatibility; confirm it is intended.
    output_parties = tuple(sorted(all_parties)) if all_parties else None
    return elt.MPType(first_vtype, output_parties)
706
+
707
+
708
def shuffle_dynamic(src: el.Object, index: el.Object) -> el.Object:
    """Redistribute ``src`` across parties according to runtime index values.

    Every party reads its own local ``index`` value and fetches ``src`` from
    the source party that value selects. Because the resulting placement is
    only known at runtime, the output carries a dynamic party mask
    (``parties=None``).

    This is the most flexible shuffle. The cost is that the communication
    pattern can only be determined at runtime.

    Args:
        src: MP-typed value to redistribute.
        index: MP-typed scalar that selects, per party, the source party to
            read from.

    Returns:
        The shuffled value, typed with a dynamic mask (``parties=None``).

    Example:
        >>> # P0, P1, P2 each hold different index values at runtime
        >>> result = shuffle_dynamic(src, index)
        >>> # result.type.parties == None (dynamic)
    """
    shuffled: el.Object = shuffle_dynamic_p.bind(src, index)
    return shuffled
731
+
732
+
733
def shuffle_static(src: el.Object, routing: dict[int, int]) -> el.Object:
    """Redistribute ``src`` using a routing table fixed at trace time.

    In contrast to ``shuffle_dynamic``, the routing is known when the program
    is built. As a result, the output has a static party mask: the sorted keys
    of ``routing``.

    The table is receiver-oriented (``{target_party: source_rank}``). This
    form expresses both of the following directly:
      - permutations, e.g. ``{0: 1, 1: 0}`` swaps parties 0 and 1;
      - broadcasts, e.g. ``{0: 1, 2: 1}`` sends rank 1's value to 0 and 2.

    Each MP input still maps to exactly one MP output.

    Args:
        src: MP-typed value to redistribute.
        routing: Mapping ``target_party -> source_rank``. For example,
            ``{0: 1, 2: 0}`` means party 0 receives from rank 1 and party 2
            receives from rank 0.

    Returns:
        The shuffled value with static mask ``parties=sorted(routing keys)``.

    Example:
        >>> # Party 0 gets data from rank 1
        >>> result = shuffle_static(src, routing={0: 1})
        >>> # result.type.parties == (0,)
        >>>
        >>> # Multiple parties
        >>> result = shuffle_static(src, routing={0: 1, 2: 0})
        >>> # result.type.parties == (0, 2)
    """
    shuffled: el.Object = shuffle_static_p.bind(src, routing=routing)
    return shuffled
767
+
768
+
769
def converge(*inputs: el.Object) -> el.Object:
    """Converge multiple disjoint-partitioned variables into one.

    Merges values held by different parties into one logical variable.
    - Static case: the inputs' parties must be disjoint, and the result is
      placed on their union.
    - Dynamic case: if any input is dynamic, the result's mask is dynamic too.

    This is the basic operation for combining results computed by different
    parties.

    Args:
        *inputs: MP-typed values with pairwise-disjoint parties. (Named
            ``inputs`` rather than ``vars`` to avoid shadowing the builtin;
            varargs cannot be passed by keyword, so callers are unaffected.)

    Returns:
        Converged variable placed on the union of the inputs' parties, or with
        a dynamic mask (``None``) if any input is dynamic.

    Raises:
        TypeError: Raised by the ``converge_p`` type rule if no inputs are
            given, an input is not MP-typed, or the value types differ.
        ValueError: Raised by the ``converge_p`` type rule if static parties
            are not disjoint.

    Example:
        >>> # P0 has x, P1 has y (disjoint)
        >>> result = converge(x, y)
        >>> # result.type.parties == (0, 1)
    """
    return converge_p.bind(*inputs)
793
+
794
+
795
def constant(parties: tuple[int, ...], data: Any) -> el.Object:
    """Place a constant value on the given parties.

    The dialect constructor is chosen from the Python type of ``data``, and
    ``pcall_static`` then runs it on ``parties``:

    * ``int``/``float``/``bool``, NumPy scalars, ``np.ndarray``, JAX arrays,
      and ``list``/``tuple`` go through ``tensor.constant``. Lists and tuples
      are ambiguous and default to tensors.
    * ``dict`` and ``pandas.DataFrame`` (when pandas is installed) go through
      ``table.constant``.

    Args:
        parties: Ranks of the parties that should hold the constant.
        data: The constant payload.

    Returns:
        MP object, placed on ``parties``, representing the constant.

    Raises:
        TypeError: If ``data`` matches none of the supported types.
    """
    import jax.numpy as jnp
    import numpy as np

    from mplang.v2.dialects import table, tensor

    def _is_dataframe(obj: Any) -> bool:
        # pandas is optional; without it nothing can be a DataFrame.
        try:
            import pandas as pd
        except ImportError:
            return False
        return isinstance(obj, pd.DataFrame)

    scalar_types = (int, float, bool, np.number, np.bool_)
    array_types = (np.ndarray, jnp.ndarray)

    if isinstance(data, scalar_types) or isinstance(data, array_types):
        builder = tensor.constant
    elif _is_dataframe(data) or isinstance(data, dict):
        builder = table.constant
    elif isinstance(data, (list, tuple)):
        builder = tensor.constant
    else:
        raise TypeError(f"Unsupported data type for simp.constant: {type(data)}")

    return cast(el.Object, pcall_static(parties, builder, data))
840
+
841
+
842
# Backward compatibility aliases
def peval(
    parties: tuple[int, ...] | None,
    local_fn: Callable[..., Any],
    *call_args: Any,
    **call_kwargs: Any,
) -> Any:
    """Legacy ``peval`` entry point, kept for backward compatibility.

    Args:
        parties: Explicit party ranks, or ``None`` to let the runtime decide.
        local_fn: Single-party computation to run.
        *call_args: Positional arguments forwarded to ``local_fn``.
        **call_kwargs: Keyword arguments forwarded to ``local_fn``.

    Returns:
        The result of ``pcall_static`` when ``parties`` is given, otherwise the
        result of ``pcall_dynamic``.
    """
    if parties is not None:
        return pcall_static(parties, local_fn, *call_args, **call_kwargs)
    return pcall_dynamic(local_fn, *call_args, **call_kwargs)
857
+
858
+
859
+ # =============================================================================
860
+ # Factory functions for creating configured Interpreters
861
+ # =============================================================================
862
+
863
+
864
def make_simulator(
    world_size: int,
    *,
    cluster_spec: Any = None,
    enable_tracing: bool = False,
    enable_profiling: bool = False,
) -> Any:
    """Build an Interpreter that simulates ``world_size`` SIMP parties locally.

    Delegates to ``mplang.v2.backends.simp_driver.mem.make_simulator``, which
    is imported lazily on each call.

    Args:
        world_size: Number of simulated parties.
        cluster_spec: Optional ClusterSpec carrying cluster metadata.
        enable_tracing: Forwarded to the backend to turn on execution tracing.
        enable_profiling: When True, calls ``registry.enable_profiling()``
            before the simulator is built. It is not forwarded to the backend.

    Returns:
        The Interpreter produced by the in-memory simp driver.

    Example:
        >>> interp = simp.make_simulator(2)
        >>> with interp:
        ...     result = my_func()
    """
    if enable_profiling:
        # NOTE(review): this toggles profiling through the registry module, so
        # it presumably affects the whole process rather than only this
        # interpreter. Confirm before relying on per-interpreter scoping.
        from mplang.v2.edsl import registry

        registry.enable_profiling()

    from mplang.v2.backends.simp_driver.mem import (
        make_simulator as build_mem_simulator,
    )

    simulator = build_mem_simulator(
        world_size,
        cluster_spec=cluster_spec,
        enable_tracing=enable_tracing,
    )
    return simulator
900
+
901
+
902
def make_driver(endpoints: list[str], *, cluster_spec: Any = None) -> Any:
    """Build an Interpreter that drives remote SIMP workers over HTTP.

    Delegates to ``mplang.v2.backends.simp_driver.http.make_driver``, which is
    imported lazily on each call.

    Args:
        endpoints: HTTP endpoints of the worker processes.
        cluster_spec: Optional ClusterSpec carrying cluster metadata.

    Returns:
        The Interpreter produced by the HTTP simp driver.

    Example:
        >>> interp = simp.make_driver(["http://worker1:8000", "http://worker2:8000"])
        >>> with interp:
        ...     result = my_func()
    """
    from mplang.v2.backends.simp_driver.http import (
        make_driver as build_http_driver,
    )

    driver = build_http_driver(endpoints, cluster_spec=cluster_spec)
    return driver
923
+
924
+
925
# Public API of this module, kept in alphabetical order. It covers the
# primitives (``*_p``), their user-facing wrappers, the ``peval`` alias, and
# the interpreter factories. Several names (e.g. ``pcall_static``,
# ``uniform_cond``, ``while_loop``) are defined elsewhere in this file.
__all__ = [
    "constant",
    "converge",
    "converge_p",
    "make_driver",
    "make_simulator",
    "pcall_dynamic",
    "pcall_dynamic_p",
    "pcall_static",
    "pcall_static_p",
    "peval",
    "shuffle_dynamic",
    "shuffle_dynamic_p",
    "shuffle_static",
    "shuffle_static_p",
    "uniform_cond",
    "uniform_cond_p",
    "while_loop",
    "while_loop_p",
]