mplang-nightly 0.1.dev268__py3-none-any.whl → 0.1.dev270__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. mplang/__init__.py +391 -17
  2. mplang/{v2/backends → backends}/__init__.py +9 -7
  3. mplang/{v2/backends → backends}/bfv_impl.py +6 -6
  4. mplang/{v2/backends → backends}/crypto_impl.py +6 -6
  5. mplang/{v2/backends → backends}/field_impl.py +5 -5
  6. mplang/{v2/backends → backends}/func_impl.py +4 -4
  7. mplang/{v2/backends → backends}/phe_impl.py +3 -3
  8. mplang/{v2/backends → backends}/simp_design.md +1 -1
  9. mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
  10. mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
  11. mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
  12. mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
  13. mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
  14. mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
  15. mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
  16. mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
  17. mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
  18. mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
  19. mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
  20. mplang/{v2/backends → backends}/spu_impl.py +8 -8
  21. mplang/{v2/backends → backends}/spu_state.py +4 -4
  22. mplang/{v2/backends → backends}/store_impl.py +3 -3
  23. mplang/{v2/backends → backends}/table_impl.py +8 -8
  24. mplang/{v2/backends → backends}/tee_impl.py +6 -6
  25. mplang/{v2/backends → backends}/tensor_impl.py +6 -6
  26. mplang/{v2/cli.py → cli.py} +9 -9
  27. mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
  28. mplang/{v2/dialects → dialects}/__init__.py +5 -5
  29. mplang/{v2/dialects → dialects}/bfv.py +6 -6
  30. mplang/{v2/dialects → dialects}/crypto.py +5 -5
  31. mplang/{v2/dialects → dialects}/dtypes.py +2 -2
  32. mplang/{v2/dialects → dialects}/field.py +3 -3
  33. mplang/{v2/dialects → dialects}/func.py +2 -2
  34. mplang/{v2/dialects → dialects}/phe.py +6 -6
  35. mplang/{v2/dialects → dialects}/simp.py +6 -6
  36. mplang/{v2/dialects → dialects}/spu.py +7 -7
  37. mplang/{v2/dialects → dialects}/store.py +2 -2
  38. mplang/{v2/dialects → dialects}/table.py +3 -3
  39. mplang/{v2/dialects → dialects}/tee.py +6 -6
  40. mplang/{v2/dialects → dialects}/tensor.py +5 -5
  41. mplang/{v2/edsl → edsl}/__init__.py +3 -3
  42. mplang/{v2/edsl → edsl}/context.py +6 -6
  43. mplang/{v2/edsl → edsl}/graph.py +5 -5
  44. mplang/{v2/edsl → edsl}/jit.py +2 -2
  45. mplang/{v2/edsl → edsl}/object.py +1 -1
  46. mplang/{v2/edsl → edsl}/primitive.py +5 -5
  47. mplang/{v2/edsl → edsl}/printer.py +1 -1
  48. mplang/{v2/edsl → edsl}/serde.py +1 -1
  49. mplang/{v2/edsl → edsl}/tracer.py +7 -7
  50. mplang/{v2/edsl → edsl}/typing.py +1 -1
  51. mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
  52. mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
  53. mplang/{v2/kernels → kernels}/okvs_opt.cpp +46 -31
  54. mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
  55. mplang/{v2/libs → libs}/collective.py +5 -5
  56. mplang/{v2/libs → libs}/device/__init__.py +1 -1
  57. mplang/{v2/libs → libs}/device/api.py +12 -12
  58. mplang/{v2/libs → libs}/ml/__init__.py +1 -1
  59. mplang/{v2/libs → libs}/ml/sgb.py +4 -4
  60. mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
  61. mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
  62. mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
  63. mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
  64. mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
  65. mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
  66. mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
  67. mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
  68. mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
  69. mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
  70. mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +19 -13
  71. mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
  72. mplang/libs/mpc/psi/rr22.py +303 -0
  73. mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
  74. mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
  75. mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
  76. mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
  77. mplang/{v2/runtime → runtime}/interpreter.py +11 -11
  78. mplang/{v2/runtime → runtime}/value.py +2 -2
  79. mplang/{v1/runtime → utils}/__init__.py +18 -15
  80. mplang/{v1/utils → utils}/func_utils.py +1 -1
  81. {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/METADATA +2 -2
  82. mplang_nightly-0.1.dev270.dist-info/RECORD +102 -0
  83. mplang/v1/__init__.py +0 -157
  84. mplang/v1/_device.py +0 -602
  85. mplang/v1/analysis/__init__.py +0 -37
  86. mplang/v1/analysis/diagram.py +0 -567
  87. mplang/v1/core/__init__.py +0 -157
  88. mplang/v1/core/cluster.py +0 -343
  89. mplang/v1/core/comm.py +0 -281
  90. mplang/v1/core/context_mgr.py +0 -50
  91. mplang/v1/core/dtypes.py +0 -335
  92. mplang/v1/core/expr/__init__.py +0 -80
  93. mplang/v1/core/expr/ast.py +0 -542
  94. mplang/v1/core/expr/evaluator.py +0 -581
  95. mplang/v1/core/expr/printer.py +0 -285
  96. mplang/v1/core/expr/transformer.py +0 -141
  97. mplang/v1/core/expr/utils.py +0 -78
  98. mplang/v1/core/expr/visitor.py +0 -85
  99. mplang/v1/core/expr/walk.py +0 -387
  100. mplang/v1/core/interp.py +0 -160
  101. mplang/v1/core/mask.py +0 -325
  102. mplang/v1/core/mpir.py +0 -965
  103. mplang/v1/core/mpobject.py +0 -117
  104. mplang/v1/core/mptype.py +0 -407
  105. mplang/v1/core/pfunc.py +0 -130
  106. mplang/v1/core/primitive.py +0 -877
  107. mplang/v1/core/table.py +0 -218
  108. mplang/v1/core/tensor.py +0 -75
  109. mplang/v1/core/tracer.py +0 -383
  110. mplang/v1/host.py +0 -130
  111. mplang/v1/kernels/__init__.py +0 -41
  112. mplang/v1/kernels/base.py +0 -125
  113. mplang/v1/kernels/basic.py +0 -240
  114. mplang/v1/kernels/context.py +0 -369
  115. mplang/v1/kernels/crypto.py +0 -122
  116. mplang/v1/kernels/fhe.py +0 -858
  117. mplang/v1/kernels/mock_tee.py +0 -72
  118. mplang/v1/kernels/phe.py +0 -1864
  119. mplang/v1/kernels/spu.py +0 -341
  120. mplang/v1/kernels/sql_duckdb.py +0 -44
  121. mplang/v1/kernels/stablehlo.py +0 -90
  122. mplang/v1/kernels/value.py +0 -626
  123. mplang/v1/ops/__init__.py +0 -35
  124. mplang/v1/ops/base.py +0 -424
  125. mplang/v1/ops/basic.py +0 -294
  126. mplang/v1/ops/crypto.py +0 -262
  127. mplang/v1/ops/fhe.py +0 -272
  128. mplang/v1/ops/jax_cc.py +0 -147
  129. mplang/v1/ops/nnx_cc.py +0 -168
  130. mplang/v1/ops/phe.py +0 -216
  131. mplang/v1/ops/spu.py +0 -151
  132. mplang/v1/ops/sql_cc.py +0 -303
  133. mplang/v1/ops/tee.py +0 -36
  134. mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
  135. mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
  136. mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
  137. mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
  138. mplang/v1/runtime/channel.py +0 -230
  139. mplang/v1/runtime/cli.py +0 -451
  140. mplang/v1/runtime/client.py +0 -456
  141. mplang/v1/runtime/communicator.py +0 -131
  142. mplang/v1/runtime/data_providers.py +0 -303
  143. mplang/v1/runtime/driver.py +0 -324
  144. mplang/v1/runtime/exceptions.py +0 -27
  145. mplang/v1/runtime/http_api.md +0 -56
  146. mplang/v1/runtime/link_comm.py +0 -196
  147. mplang/v1/runtime/server.py +0 -501
  148. mplang/v1/runtime/session.py +0 -270
  149. mplang/v1/runtime/simulation.py +0 -324
  150. mplang/v1/simp/__init__.py +0 -13
  151. mplang/v1/simp/api.py +0 -353
  152. mplang/v1/simp/mpi.py +0 -131
  153. mplang/v1/simp/party.py +0 -225
  154. mplang/v1/simp/random.py +0 -120
  155. mplang/v1/simp/smpc.py +0 -238
  156. mplang/v1/utils/__init__.py +0 -13
  157. mplang/v1/utils/crypto.py +0 -32
  158. mplang/v1/utils/spu_utils.py +0 -130
  159. mplang/v1/utils/table_utils.py +0 -185
  160. mplang/v2/__init__.py +0 -424
  161. mplang/v2/libs/mpc/psi/rr22.py +0 -344
  162. mplang_nightly-0.1.dev268.dist-info/RECORD +0 -180
  163. /mplang/{v2/backends → backends}/channel.py +0 -0
  164. /mplang/{v2/edsl → edsl}/README.md +0 -0
  165. /mplang/{v2/edsl → edsl}/registry.py +0 -0
  166. /mplang/{v2/kernels → kernels}/Makefile +0 -0
  167. /mplang/{v2/kernels → kernels}/__init__.py +0 -0
  168. /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
  169. /mplang/{v2/libs → libs}/device/cluster.py +0 -0
  170. /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
  171. /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
  172. /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
  173. /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
  174. /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
  175. /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
  176. /mplang/{v2/runtime → runtime}/__init__.py +0 -0
  177. /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
  178. /mplang/{v2/runtime → runtime}/object_store.py +0 -0
  179. {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/WHEEL +0 -0
  180. {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/entry_points.txt +0 -0
  181. {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/licenses/LICENSE +0 -0
mplang/v2/__init__.py DELETED
@@ -1,424 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """MPLang2: Next generation EDSL for multi-party computation.
16
-
17
- This is the temporary home for the new EDSL implementation during migration.
18
- Once migration is complete, this will replace the original mplang package.
19
-
20
- Public API is designed to be compatible with mplang v1 where possible.
21
- """
22
-
23
- from __future__ import annotations
24
-
25
- from collections.abc import Callable
26
- from typing import Any
27
-
28
- __version__ = "0.1.0"
29
-
30
- # =============================================================================
31
- # Core EDSL components
32
- # =============================================================================
33
- # =============================================================================
34
- # Dialects
35
- # =============================================================================
36
- # =============================================================================
37
- # Backend / Runtime
38
- # =============================================================================
39
- import mplang.v2.backends.func_impl # Register func handlers
40
- from mplang.v2 import dialects
41
- from mplang.v2.backends.simp_driver.ops import DRIVER_HANDLERS
42
- from mplang.v2.backends.simp_worker import SimpWorker
43
- from mplang.v2.backends.simp_worker.mem import LocalMesh
44
- from mplang.v2.backends.simp_worker.ops import WORKER_HANDLERS
45
- from mplang.v2.dialects.simp import make_driver, make_simulator
46
- from mplang.v2.edsl import (
47
- Graph,
48
- GraphPrinter,
49
- Object,
50
- Operation,
51
- Primitive,
52
- TracedFunction,
53
- Tracer,
54
- Value,
55
- find_context,
56
- find_context_with_state,
57
- find_interpreter,
58
- format_graph,
59
- get_current_context,
60
- get_default_context,
61
- is_tracing,
62
- jit,
63
- pop_context,
64
- primitive,
65
- push_context,
66
- register_default_context_factory,
67
- set_root_context,
68
- trace,
69
- )
70
- from mplang.v2.edsl.registry import get_profiler
71
-
72
- # Type system
73
- from mplang.v2.edsl.typing import (
74
- MPType,
75
- ScalarType,
76
- SSType,
77
- TableType,
78
- TensorType,
79
- VectorType,
80
- )
81
-
82
- # =============================================================================
83
- # Device API (compatible with mplang v1)
84
- # =============================================================================
85
- from mplang.v2.libs.device import (
86
- ClusterSpec,
87
- Device,
88
- Node,
89
- device,
90
- get_dev_attr,
91
- is_device_obj,
92
- jax_fn,
93
- put,
94
- set_dev_attr,
95
- )
96
- from mplang.v2.libs.device import fetch as device_fetch
97
- from mplang.v2.runtime.interpreter import Interpreter
98
-
99
- # =============================================================================
100
- # Context Management API (JAX-like pattern)
101
- # =============================================================================
102
-
103
-
104
- def _get_context(context: Interpreter | None) -> Interpreter:
105
- """Get context from parameter or context stack."""
106
- if context is not None:
107
- return context
108
- ctx = get_current_context()
109
- if ctx is None:
110
- raise RuntimeError(
111
- "No context available. Either pass context explicitly or use "
112
- "set_context()/push_context() to set a default context."
113
- )
114
- if not isinstance(ctx, Interpreter):
115
- raise RuntimeError(
116
- f"Current context is not an Interpreter: {type(ctx).__name__}. "
117
- "Use mp.set_context(interpreter) to set the execution context."
118
- )
119
- return ctx
120
-
121
-
122
- # =============================================================================
123
- # Meta-APIs (compile, evaluate, fetch)
124
- # =============================================================================
125
-
126
-
127
- def evaluate(
128
- fn: Callable[..., Any] | TracedFunction,
129
- *args: Any,
130
- context: Interpreter | None = None,
131
- **kwargs: Any,
132
- ) -> Any:
133
- """Evaluate a function or traced function.
134
-
135
- Args:
136
- fn: The function or TracedFunction to evaluate.
137
- *args: Positional arguments to pass to the function.
138
- context: Optional interpreter context. If None, uses current context.
139
- **kwargs: Keyword arguments to pass to the function.
140
-
141
- Returns:
142
- The result of the function evaluation.
143
-
144
- Example:
145
- >>> with mp.make_simulator(3) as sim:
146
- ... result = mp.evaluate(traced) # uses sim from context
147
- >>> # Or explicitly:
148
- >>> result = mp.evaluate(traced, context=sim)
149
- """
150
- from mplang.v2.edsl.tracer import TracedFunction
151
- from mplang.v2.runtime.interpreter import InterpObject
152
-
153
- interp = _get_context(context)
154
-
155
- def unwrap_if_interp(val: Any) -> Any:
156
- """Unwrap InterpObject to runtime value at execution boundary."""
157
- if isinstance(val, InterpObject):
158
- return val.runtime_obj
159
- return val
160
-
161
- with interp:
162
- if isinstance(fn, TracedFunction):
163
- inputs = fn.prepare_inputs(*args, **kwargs)
164
- inputs = [unwrap_if_interp(v) for v in inputs]
165
- raw_result = interp.evaluate_graph(fn.graph, inputs)
166
- wrapped = [
167
- InterpObject(v, fn.graph.outputs[i].type, interp)
168
- for i, v in enumerate(raw_result)
169
- ]
170
- return fn.reconstruct_outputs(wrapped)
171
-
172
- return fn(*args, **kwargs)
173
-
174
-
175
- def fetch(
176
- result: Any,
177
- *,
178
- follow_device: bool = True,
179
- context: Interpreter | None = None,
180
- ) -> Any:
181
- """Fetch results from interpreter context to Python.
182
-
183
- This is a meta-function that operates at execution boundaries, not a traced
184
- dialect operation. It brings data from the distributed/MPC runtime back to
185
- the Python host.
186
-
187
- Behavior in different contexts:
188
- - **Tracing (compile)**: Returns the input unchanged (identity). The graph
189
- outputs are determined by the function's return statement, not fetch calls.
190
- - **Execution (evaluate)**: Actually fetches data from workers/parties.
191
-
192
- Design Note (A vs B tradeoff):
193
- Two designs were considered for fetch behavior during tracing:
194
-
195
- - **Design A (chosen)**: fetch = identity during tracing. Graph outputs
196
- are determined solely by the return statement. This is simpler and
197
- avoids ambiguity when fetch and return reference different values.
198
-
199
- - **Design B (alternative)**: fetch marks output points in the graph.
200
- This would allow fetch(a), fetch(b), return b to output both a and b.
201
- However, it complicates the semantics and requires tracking fetch
202
- points separately from return values.
203
-
204
- Design A was chosen for simplicity. If a value needs to be an output,
205
- it should be returned. fetch's role is purely for execution-time I/O.
206
-
207
- Args:
208
- result: Object(s) to fetch. Can be a single InterpObject, DriverVar,
209
- or nested structure containing them.
210
- follow_device: If True and object has device attribute, dispatch to
211
- device.fetch which fetches from the correct rank based on device.
212
- If False, fetch from all parties.
213
- context: Interpreter context. If None, uses current context.
214
-
215
- Returns:
216
- Fetched Python values. For device objects with follow_device=True,
217
- returns single value from the device's rank(s). Otherwise returns
218
- list of values (one per party) or single value for world_size=1.
219
- During tracing, returns the input unchanged.
220
- """
221
- from jax.tree_util import tree_map
222
-
223
- from mplang.v2.backends.simp_driver.values import DriverVar
224
- from mplang.v2.edsl.context import is_tracing
225
- from mplang.v2.runtime.interpreter import InterpObject
226
- from mplang.v2.runtime.value import WrapValue
227
-
228
- # Check if we are in tracing context - if so, return identity
229
- if is_tracing():
230
- # Design A: fetch = identity during tracing
231
- # Graph outputs are determined by return statement, not fetch calls
232
- return result
233
-
234
- # Execution context - actually fetch data
235
- interp = _get_context(context)
236
-
237
- def _fetch_single(var: Any) -> Any:
238
- """Fetch a single value from InterpObject."""
239
- # InterpObject (from mp.evaluate) - extract runtime_obj
240
- if isinstance(var, InterpObject):
241
- if follow_device and is_device_obj(var):
242
- return device_fetch(var)
243
- var = var.runtime_obj # extract and continue processing
244
-
245
- # DriverVar (simp dialect) - remote fetch from workers
246
- if isinstance(var, DriverVar):
247
- from mplang.v2.backends.simp_driver.state import SimpDriver
248
-
249
- simp_state = interp.get_dialect_state("simp")
250
- assert isinstance(simp_state, SimpDriver), "DriverVar requires simp state"
251
-
252
- resolved: list[Any] = []
253
- for rank, uri in enumerate(var.values):
254
- if uri is None:
255
- resolved.append(None)
256
- else:
257
- fetched = simp_state.fetch(rank, uri).result()
258
- if isinstance(fetched, WrapValue):
259
- fetched = fetched.data
260
- resolved.append(fetched)
261
-
262
- return resolved[0] if len(resolved) == 1 else resolved
263
-
264
- # WrapValue (TensorValue, TableValue, etc.) - unwrap
265
- if isinstance(var, WrapValue):
266
- return var.data
267
-
268
- # Plain values pass through
269
- return var
270
-
271
- with interp:
272
- return tree_map(_fetch_single, result)
273
-
274
-
275
- # Alias for compatibility
276
- def function(fn: Callable[..., Any] | None = None) -> Callable[..., Any]:
277
- """Decorator defining a Multi-Party Function (MP Program).
278
-
279
- This decorator "lifts" a local function into a distributed program by
280
- automatically wrapping it in a `simp.pcall_static` that targets ALL available
281
- parties in the current context.
282
-
283
- Semantics: f(args) -> pcall(ALL, f, args)
284
-
285
- Args:
286
- fn: The function to decorate.
287
-
288
- Returns:
289
- A wrapper function that, when called, executes the original function
290
- on all workers.
291
- """
292
- import functools
293
-
294
- from mplang.v2.dialects import simp
295
-
296
- if fn is None:
297
- return function
298
-
299
- def has_simp_state(ctx: Any) -> bool:
300
- if hasattr(ctx, "get_dialect_state"):
301
- state = ctx.get_dialect_state("simp")
302
- return state is not None and hasattr(state, "world_size")
303
- return False
304
-
305
- @functools.wraps(fn)
306
- def wrapper(*args: Any, **kwargs: Any) -> Any:
307
- from mplang.v2.edsl.context import find_context
308
-
309
- # Find context with simp dialect state
310
- ctx = find_context(has_simp_state)
311
- if ctx is None:
312
- raise RuntimeError(
313
- "mp.function requires a context with world_size information "
314
- "(e.g. SimpSimulator or Driver initialized)."
315
- )
316
-
317
- # ctx found by predicate so we know it has get_dialect_state
318
- simp_state = ctx.get_dialect_state("simp") # type: ignore[attr-defined]
319
- world_size = simp_state.world_size # type: ignore
320
-
321
- all_parties = tuple(range(world_size))
322
- return simp.pcall_static(all_parties, fn, *args, **kwargs)
323
-
324
- return wrapper
325
-
326
-
327
- def compile(
328
- fn: Callable[..., Any],
329
- *args: Any,
330
- context: Interpreter | None = None,
331
- **kwargs: Any,
332
- ) -> TracedFunction:
333
- """Compile a function to get its IR without executing it.
334
-
335
- Args:
336
- fn: The function to compile.
337
- *args: Example arguments for tracing.
338
- context: Optional interpreter context. If None, uses current context.
339
- **kwargs: Example keyword arguments for tracing.
340
-
341
- Returns:
342
- TracedFunction with the compiled graph.
343
-
344
- Example:
345
- >>> with mp.make_simulator(3) as sim:
346
- ... traced = mp.compile(job) # uses sim from context
347
- """
348
- # If a context is explicitly provided, push it before tracing
349
- # so that _resolve_cluster() can find it.
350
- if context is not None:
351
- with context:
352
- return trace(fn, *args, **kwargs)
353
-
354
- # Otherwise, rely on the caller having pushed an interpreter context.
355
- # _resolve_cluster() will traverse the stack to find the interpreter.
356
- return trace(fn, *args, **kwargs)
357
-
358
-
359
- # =============================================================================
360
- # Public API
361
- # =============================================================================
362
- __all__ = [ # noqa: RUF022
363
- # Version
364
- "__version__",
365
- # Device API
366
- "ClusterSpec",
367
- "Device",
368
- "Node",
369
- "device",
370
- "get_dev_attr",
371
- "is_device_obj",
372
- "jax_fn",
373
- "put",
374
- "set_dev_attr",
375
- # Core EDSL
376
- "Graph",
377
- "GraphPrinter",
378
- "Object",
379
- "Operation",
380
- "Primitive",
381
- "TracedFunction",
382
- "Tracer",
383
- "Value",
384
- "compile",
385
- "evaluate",
386
- "fetch",
387
- "find_context",
388
- "find_context_with_state",
389
- "find_interpreter",
390
- "format_graph",
391
- "function",
392
- "get_current_context",
393
- "get_default_context",
394
- "is_tracing",
395
- "jit",
396
- "mplang",
397
- "pop_context",
398
- "primitive",
399
- "push_context",
400
- "set_root_context",
401
- "trace",
402
- # Type system
403
- "MPType",
404
- "ScalarType",
405
- "SSType",
406
- "TableType",
407
- "TensorType",
408
- "VectorType",
409
- # Backend / Runtime
410
- "DRIVER_HANDLERS",
411
- "Interpreter",
412
- "LocalMesh",
413
- "SimpWorker",
414
- "WORKER_HANDLERS",
415
- "make_driver",
416
- "make_simulator",
417
- # Dialects
418
- "dialects",
419
- "register_default_context_factory",
420
- "get_profiler",
421
- ]
422
-
423
- # Register Interpreter as default context factory
424
- register_default_context_factory(Interpreter)