mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191):
  1. mplang/__init__.py +21 -45
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +5 -7
  7. mplang/v1/core/__init__.py +157 -0
  8. mplang/{core → v1/core}/cluster.py +30 -14
  9. mplang/{core → v1/core}/comm.py +5 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +13 -14
  14. mplang/{core → v1/core}/expr/evaluator.py +65 -24
  15. mplang/{core → v1/core}/expr/printer.py +24 -18
  16. mplang/{core → v1/core}/expr/transformer.py +3 -3
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +23 -16
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +4 -4
  25. mplang/{core → v1/core}/primitive.py +106 -201
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{api.py → v1/host.py} +38 -6
  30. mplang/v1/kernels/__init__.py +41 -0
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/v1/kernels/basic.py +240 -0
  33. mplang/{kernels → v1/kernels}/context.py +42 -27
  34. mplang/{kernels → v1/kernels}/crypto.py +44 -37
  35. mplang/v1/kernels/fhe.py +858 -0
  36. mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
  37. mplang/{kernels → v1/kernels}/phe.py +263 -57
  38. mplang/{kernels → v1/kernels}/spu.py +137 -48
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
  40. mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
  41. mplang/v1/kernels/value.py +626 -0
  42. mplang/{ops → v1/ops}/__init__.py +5 -16
  43. mplang/{ops → v1/ops}/base.py +2 -5
  44. mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/v1/ops/fhe.py +272 -0
  47. mplang/{ops → v1/ops}/jax_cc.py +33 -68
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -4
  50. mplang/{ops → v1/ops}/spu.py +3 -5
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +9 -24
  53. mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
  54. mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
  55. mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
  56. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  57. mplang/v1/runtime/channel.py +230 -0
  58. mplang/{runtime → v1/runtime}/cli.py +35 -20
  59. mplang/{runtime → v1/runtime}/client.py +19 -8
  60. mplang/{runtime → v1/runtime}/communicator.py +59 -15
  61. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  62. mplang/{runtime → v1/runtime}/driver.py +30 -12
  63. mplang/v1/runtime/link_comm.py +196 -0
  64. mplang/{runtime → v1/runtime}/server.py +58 -42
  65. mplang/{runtime → v1/runtime}/session.py +57 -71
  66. mplang/{runtime → v1/runtime}/simulation.py +55 -28
  67. mplang/v1/simp/api.py +353 -0
  68. mplang/{simp → v1/simp}/mpi.py +8 -9
  69. mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
  70. mplang/{simp → v1/simp}/random.py +21 -22
  71. mplang/v1/simp/smpc.py +238 -0
  72. mplang/v1/utils/table_utils.py +185 -0
  73. mplang/v2/__init__.py +424 -0
  74. mplang/v2/backends/__init__.py +57 -0
  75. mplang/v2/backends/bfv_impl.py +705 -0
  76. mplang/v2/backends/channel.py +217 -0
  77. mplang/v2/backends/crypto_impl.py +723 -0
  78. mplang/v2/backends/field_impl.py +454 -0
  79. mplang/v2/backends/func_impl.py +107 -0
  80. mplang/v2/backends/phe_impl.py +148 -0
  81. mplang/v2/backends/simp_design.md +136 -0
  82. mplang/v2/backends/simp_driver/__init__.py +41 -0
  83. mplang/v2/backends/simp_driver/http.py +168 -0
  84. mplang/v2/backends/simp_driver/mem.py +280 -0
  85. mplang/v2/backends/simp_driver/ops.py +135 -0
  86. mplang/v2/backends/simp_driver/state.py +60 -0
  87. mplang/v2/backends/simp_driver/values.py +52 -0
  88. mplang/v2/backends/simp_worker/__init__.py +29 -0
  89. mplang/v2/backends/simp_worker/http.py +354 -0
  90. mplang/v2/backends/simp_worker/mem.py +102 -0
  91. mplang/v2/backends/simp_worker/ops.py +167 -0
  92. mplang/v2/backends/simp_worker/state.py +49 -0
  93. mplang/v2/backends/spu_impl.py +275 -0
  94. mplang/v2/backends/spu_state.py +187 -0
  95. mplang/v2/backends/store_impl.py +62 -0
  96. mplang/v2/backends/table_impl.py +838 -0
  97. mplang/v2/backends/tee_impl.py +215 -0
  98. mplang/v2/backends/tensor_impl.py +519 -0
  99. mplang/v2/cli.py +603 -0
  100. mplang/v2/cli_guide.md +122 -0
  101. mplang/v2/dialects/__init__.py +36 -0
  102. mplang/v2/dialects/bfv.py +665 -0
  103. mplang/v2/dialects/crypto.py +689 -0
  104. mplang/v2/dialects/dtypes.py +378 -0
  105. mplang/v2/dialects/field.py +210 -0
  106. mplang/v2/dialects/func.py +135 -0
  107. mplang/v2/dialects/phe.py +723 -0
  108. mplang/v2/dialects/simp.py +944 -0
  109. mplang/v2/dialects/spu.py +349 -0
  110. mplang/v2/dialects/store.py +63 -0
  111. mplang/v2/dialects/table.py +407 -0
  112. mplang/v2/dialects/tee.py +346 -0
  113. mplang/v2/dialects/tensor.py +1175 -0
  114. mplang/v2/edsl/README.md +279 -0
  115. mplang/v2/edsl/__init__.py +99 -0
  116. mplang/v2/edsl/context.py +311 -0
  117. mplang/v2/edsl/graph.py +463 -0
  118. mplang/v2/edsl/jit.py +62 -0
  119. mplang/v2/edsl/object.py +53 -0
  120. mplang/v2/edsl/primitive.py +284 -0
  121. mplang/v2/edsl/printer.py +119 -0
  122. mplang/v2/edsl/registry.py +207 -0
  123. mplang/v2/edsl/serde.py +375 -0
  124. mplang/v2/edsl/tracer.py +614 -0
  125. mplang/v2/edsl/typing.py +816 -0
  126. mplang/v2/kernels/Makefile +30 -0
  127. mplang/v2/kernels/__init__.py +23 -0
  128. mplang/v2/kernels/gf128.cpp +148 -0
  129. mplang/v2/kernels/ldpc.cpp +82 -0
  130. mplang/v2/kernels/okvs.cpp +283 -0
  131. mplang/v2/kernels/okvs_opt.cpp +291 -0
  132. mplang/v2/kernels/py_kernels.py +398 -0
  133. mplang/v2/libs/collective.py +330 -0
  134. mplang/v2/libs/device/__init__.py +51 -0
  135. mplang/v2/libs/device/api.py +813 -0
  136. mplang/v2/libs/device/cluster.py +352 -0
  137. mplang/v2/libs/ml/__init__.py +23 -0
  138. mplang/v2/libs/ml/sgb.py +1861 -0
  139. mplang/v2/libs/mpc/__init__.py +41 -0
  140. mplang/v2/libs/mpc/_utils.py +99 -0
  141. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  142. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  143. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  144. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  145. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  146. mplang/v2/libs/mpc/common/constants.py +39 -0
  147. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  148. mplang/v2/libs/mpc/ot/base.py +222 -0
  149. mplang/v2/libs/mpc/ot/extension.py +477 -0
  150. mplang/v2/libs/mpc/ot/silent.py +217 -0
  151. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  152. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  153. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  154. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  155. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  156. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  157. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  158. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  159. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  160. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  161. mplang/v2/libs/mpc/vole/silver.py +336 -0
  162. mplang/v2/runtime/__init__.py +15 -0
  163. mplang/v2/runtime/dialect_state.py +41 -0
  164. mplang/v2/runtime/interpreter.py +871 -0
  165. mplang/v2/runtime/object_store.py +194 -0
  166. mplang/v2/runtime/value.py +141 -0
  167. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
  168. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  169. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  170. mplang/core/__init__.py +0 -92
  171. mplang/device.py +0 -340
  172. mplang/kernels/builtin.py +0 -207
  173. mplang/ops/crypto.py +0 -109
  174. mplang/ops/ibis_cc.py +0 -139
  175. mplang/ops/sql.py +0 -61
  176. mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
  177. mplang/runtime/link_comm.py +0 -131
  178. mplang/simp/smpc.py +0 -201
  179. mplang/utils/table_utils.py +0 -73
  180. mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
  181. /mplang/{core → v1/core}/mask.py +0 -0
  182. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  183. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  184. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  185. /mplang/{kernels → v1/simp}/__init__.py +0 -0
  186. /mplang/{utils → v1/utils}/__init__.py +0 -0
  187. /mplang/{utils → v1/utils}/crypto.py +0 -0
  188. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  189. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  190. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  191. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -28,10 +28,11 @@ from typing import Any, ParamSpec, TypeVar, cast
28
28
 
29
29
  from jax.tree_util import tree_map
30
30
 
31
- from mplang.core.context_mgr import cur_ctx
32
- from mplang.core.dtype import BOOL
33
- from mplang.core.expr.ast import (
31
+ from mplang.v1.core.context_mgr import cur_ctx
32
+ from mplang.v1.core.dtypes import BOOL
33
+ from mplang.v1.core.expr.ast import (
34
34
  AccessExpr,
35
+ CallExpr,
35
36
  CondExpr,
36
37
  ConvExpr,
37
38
  EvalExpr,
@@ -39,16 +40,13 @@ from mplang.core.expr.ast import (
39
40
  ShflSExpr,
40
41
  WhileExpr,
41
42
  )
42
- from mplang.core.interp import InterpContext, InterpVar, apply
43
- from mplang.core.mask import Mask
44
- from mplang.core.mpobject import MPContext, MPObject
45
- from mplang.core.mptype import Rank
46
- from mplang.core.pfunc import PFunction
47
- from mplang.core.table import TableLike
48
- from mplang.core.tensor import ScalarType, Shape, TensorLike
49
- from mplang.core.tracer import TraceContext, TraceVar, trace
50
- from mplang.ops import builtin
51
- from mplang.utils.func_utils import var_demorph, var_morph
43
+ from mplang.v1.core.interp import InterpContext, InterpVar, apply
44
+ from mplang.v1.core.mask import Mask
45
+ from mplang.v1.core.mpobject import MPContext, MPObject
46
+ from mplang.v1.core.mptype import Rank
47
+ from mplang.v1.core.pfunc import PFunction
48
+ from mplang.v1.core.tracer import TraceContext, TraceVar, trace
49
+ from mplang.v1.utils.func_utils import var_demorph, var_morph
52
50
 
53
51
 
54
52
  def _switch_ctx(ctx: MPContext, obj: MPObject | Any) -> MPObject | Any:
@@ -87,30 +85,106 @@ P = ParamSpec("P")
87
85
  R = TypeVar("R")
88
86
 
89
87
 
90
- def primitive(fn: Callable[P, R]) -> Callable[P, R]:
88
def trace_before_apply(fn: Callable[P, R], make_call: bool) -> Callable[P, R]:
    """Wrap ``fn`` so that each call is traced before it is applied.

    Dispatch depends on the context returned by ``cur_ctx()`` at call time:

    * ``TraceContext`` with ``make_call=True``: ``fn`` is traced in a forked
      tracer and a single ``CallExpr`` node is added to the current graph; each
      output is exposed as a ``TraceVar`` over an ``AccessExpr`` of that call.
    * ``TraceContext`` with ``make_call=False``: ``fn`` is called directly, so
      the operations it records are inlined into the current graph.
    * ``InterpContext``: ``fn`` is traced in a new ``TraceContext`` whose parent
      is the interpreter context, and the traced function is run via ``apply``.

    Args:
        fn: The function to wrap.
        make_call: When invoked inside a ``TraceContext``, emit an opaque
            ``CallExpr`` (True) or inline ``fn``'s body (False).

    Returns:
        A wrapper with the same signature as ``fn``.

    Raises:
        ValueError: If the wrapper is called under any other context type.
    """

    def _is_mpobj(obj: Any) -> bool:
        return isinstance(obj, MPObject)

    @wraps(fn)
    def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
        current_ctx = cur_ctx()
        if isinstance(current_ctx, TraceContext):
            # Bring every MPObject argument into the current tracer first
            # (_switch_ctx does the capture if needed). The call branch below
            # reads ``arg.expr`` directly, so it needs the arguments to belong
            # to this tracer just as much as the inline branch does.
            args, kwargs = tree_map(
                partial(_switch_ctx, current_ctx), (args, kwargs)
            )
            if not make_call:
                # Inline: fn's primitives are recorded straight into current_ctx.
                return fn(*args, **kwargs)

            # Trace fn in a forked tracer and emit one CallExpr for it.
            tracer = current_ctx
            tfn = trace(tracer.fork(), fn, *args, **kwargs)
            in_vars, in_imms, in_struct = var_morph((args, kwargs), _is_mpobj)
            assert in_struct == tfn.in_struct and in_imms == tfn.in_imms, (
                "argument structure differs between tracing and call emission"
            )
            arg_exprs = [arg.expr for arg in in_vars]
            # Re-capture the callee's captured variables into this tracer and
            # pass them as trailing call arguments.
            cap_exprs = [tracer.capture(var).expr for var in tfn.capture_map.keys()]
            caller_expr = CallExpr(
                name=fn.__name__, fn=tfn.make_expr(), args=arg_exprs + cap_exprs
            )
            out_vars = [
                TraceVar(tracer, AccessExpr(caller_expr, idx))
                for idx in range(caller_expr.num_outputs)
            ]
            return cast(R, var_demorph(out_vars, tfn.out_imms, tfn.out_struct))
        elif isinstance(current_ctx, InterpContext):
            trace_ctx = TraceContext(current_ctx.cluster_spec, parent=current_ctx)
            # TODO(jint): should we add trace_and_apply to improve the performance?
            tfn = trace(trace_ctx, fn, *args, **kwargs)
            # Return back to the original context.
            return cast(R, apply(current_ctx, tfn, *args, **kwargs))
        else:
            raise ValueError(f"Unsupported context type: {type(current_ctx)}")

    return wrapped
111
131
 
112
132
 
113
- function = primitive
133
def builtin_function(fn: Callable[P, R]) -> Callable[P, R]:
    """Decorator to trace a Python function as an opaque call (`CallExpr`).

    When a function decorated with `@builtin_function` is called within a
    `TraceContext`, it is not inlined. Instead, it is traced separately in a
    forked tracer and a single `CallExpr` node is inserted into the caller's
    graph. Each output is read back through an `AccessExpr` on that node.
    Variables the body captured from enclosing scopes are re-captured into the
    caller and passed as extra call arguments. This is useful for encapsulating
    complex operations or third-party library calls as single, opaque nodes.

    Under an `InterpContext` the function is traced in a fresh `TraceContext`
    and the traced result is applied immediately, the same as `@function`.

    **Implementation Note:**
    A `CallExpr` represents a call to a single inline lambda (non-recursive, as
    we don't have Y-combinator support). Printers/visualizers can therefore show
    it as one primitive-like call, which is the intent behind the name
    `builtin_function`. The body is traced at each call site and kept as an
    opaque callable unit in the expression graph. This keeps a clear boundary
    between the caller and callee contexts.

    Args:
        fn: The function to be traced as a single opaque call.

    Returns:
        A wrapped function that creates a `CallExpr` node when called in a
        trace context.

    Example:
        ```python
        @builtin_function
        def my_op(x: MPObject) -> MPObject:
            # Complex logic traced as a single CallExpr node
            return x + 1
        ```
    """
    return trace_before_apply(fn, make_call=True)
163
+
164
+
165
def function(fn: Callable[P, R]) -> Callable[P, R]:
    """Decorator to trace a Python function by inlining its body.

    When a function decorated with `@function` is called within a
    `TraceContext`, its MPObject arguments are first switched into the current
    tracer, and then the function is called directly. Every primitive operation
    it issues is therefore inserted into the caller's graph, with no `CallExpr`
    boundary. Under an `InterpContext` the function is traced in a fresh
    `TraceContext` and the traced result is applied immediately.

    This is the default tracing behavior and is suitable for most pure-Python
    multi-party functions. Use `@builtin_function` instead to keep the body as
    a single opaque call node.

    Args:
        fn: The function to be traced and inlined.

    Returns:
        A wrapped function that inlines its operations into the caller's trace
        context.

    Example:
        ```python
        @function
        def my_func(x: MPObject, y: MPObject) -> MPObject:
            # Operations are inlined into the caller's trace
            return x + y
        ```
    """
    return trace_before_apply(fn, make_call=False)
114
188
 
115
189
 
116
190
  # ============================================================================
@@ -126,18 +200,15 @@ def _tracer() -> TraceContext:
126
200
  return ctx
127
201
 
128
202
 
129
- @primitive
130
203
def psize() -> int:
    """Return how many parties make up the current multi-party world.

    The size is taken from whatever context `cur_ctx()` currently returns.

    Returns:
        int: The total number of parties in the current multi-party
        computation context.
    """
    active_ctx = cur_ctx()
    world = active_ctx.world_size()
    return world
138
210
 
139
211
 
140
- @primitive
141
212
  def pmask() -> Mask:
142
213
  """Get the current party mask in this computation context.
143
214
 
@@ -145,112 +216,10 @@ def pmask() -> Mask:
145
216
  Mask: The current party mask indicating which parties are active
146
217
  in the current computation context.
147
218
  """
148
- ctx = _tracer()
149
- return ctx.mask
150
-
151
-
152
- @primitive
153
- def prank() -> MPObject:
154
- """Multi-party get the rank (party identifier) of each party.
155
-
156
- This function returns a scalar tensor containing the rank (party identifier)
157
- for each party in the current party mask. Each party independently produces
158
- its own rank value, which serves as a unique identifier within the multi-party
159
- computation context.
160
-
161
- The rank values range from 0 to world_size-1, where world_size is the total
162
- number of parties in the computation. Each party's rank is private to that
163
- party and represents its position in the multi-party protocol.
164
-
165
- Returns:
166
- MPObject: A variable representing a scalar tensor with:
167
- - dtype: UINT64
168
- - shape: () (scalar)
169
-
170
- Note:
171
- Each party in the current party mask independently produces its own rank value.
172
- """
173
- pfunc, eval_args, out_tree = builtin.rank()
174
- results = peval(pfunc, eval_args)
175
- return out_tree.unflatten(results) # type: ignore[no-any-return]
176
-
177
-
178
- @primitive
179
- def prand(shape: Shape = ()) -> MPObject:
180
- """Multi-party generate a private random (uint64) tensor with the given shape.
181
-
182
- This function creates a private random tensor where each party independently
183
- generates its own local random values. Each party's random values are private
184
- and unknown to other parties. The output tensor contains 64-bit unsigned
185
- integers, with each party holding its own privately generated values.
186
-
187
- Args:
188
- shape: The shape of the random tensor to generate.
189
- Must be a tuple of positive integers. Defaults to () for scalar.
190
-
191
- Returns:
192
- MPObject: A variable representing the generated private random tensor with:
193
- - dtype: UINT64
194
- - shape: As specified by the shape parameter
195
-
196
- Note:
197
- Each party in the current party mask independently generates its own
198
- private random values. The randomness is local to each party and is
199
- not shared or revealed to other parties.
200
- """
201
- pfunc, eval_args, out_tree = builtin.prand(shape)
202
- results = peval(pfunc, eval_args)
203
- return out_tree.unflatten(results) # type: ignore[no-any-return]
204
-
205
-
206
- @primitive
207
- def constant(data: TensorLike | ScalarType | TableLike) -> MPObject:
208
- """Create a constant tensor or table from data.
209
-
210
- This function creates a constant that can be used in multi-party
211
- computations. The constant value is embedded directly into the computation
212
- graph and is available to all parties in the current party mask.
213
-
214
- Args:
215
- data: The constant data to embed. Can be:
216
- - A scalar value (int, float, bool)
217
- - A numpy array or other tensor-like object
218
- - A pandas DataFrame or other table-like object
219
- - Any object that can be converted to tensor
220
-
221
- Returns:
222
- MPObject: A variable representing the constant tensor or table with:
223
- - dtype: Inferred from the input data
224
- - shape: Inferred from the input data (for tensors)
225
- - schema: Inferred from the input data (for tables)
226
- - data: The embedded constant values
227
-
228
- Note:
229
- The constant data is embedded at graph construction time and is available
230
- to all parties during execution. Large constants may impact graph size.
231
-
232
- For table-like objects (e.g., pandas DataFrame), JSON serialization is used.
233
- Note that the constant primitive is not designed to carry large tables efficiently -
234
- consider using dedicated table loading mechanisms for substantial datasets.
235
- """
236
- pfunc, eval_args, out_tree = builtin.constant(data)
237
- results = peval(pfunc, eval_args)
238
- return out_tree.unflatten(results) # type: ignore[no-any-return]
239
-
240
-
241
- @primitive
242
- def debug_print(obj: MPObject, prefix: str = "") -> MPObject:
243
- """Print local value of obj on owning parties and pass it through.
244
-
245
- Returns the same MPObject value to keep it alive against DCE and to
246
- support usage like: x = debug_print(x, prefix="x=").
247
- """
248
- pfunc, eval_args, out_tree = builtin.debug_print(obj, prefix=prefix)
249
- results = peval(pfunc, eval_args)
250
- return out_tree.unflatten(results) # type: ignore[no-any-return]
219
+ return _tracer().mask
251
220
 
252
221
 
253
- @primitive
222
+ @function
254
223
  def peval(
255
224
  pfunc: PFunction,
256
225
  args: list[MPObject],
@@ -314,71 +283,7 @@ def peval(
314
283
  return [TraceVar(ctx, res) for res in ret_exprs]
315
284
 
316
285
 
317
- def set_mask(arg: MPObject, mask: Mask) -> MPObject:
318
- """Set the mask of an MPObject to a new value.
319
-
320
- This function allows changing the party mask of an existing MPObject variable.
321
- The behavior depends on whether the input MPObject has a dynamic or static pmask:
322
-
323
- **Case 1: Dynamic pmask (arg.pmask is None)**
324
- - The input MPObject has a runtime-determined pmask
325
- - The return value's pmask will be exactly the specified mask
326
- - No validation is performed at compile time
327
-
328
- **Case 2: Static pmask (arg.pmask is not None)**
329
- - If mask is a subset of arg.pmask: return_var.pmask == arg.pmask (unchanged)
330
- - If mask is NOT a subset of arg.pmask: raises ValueError at compile time
331
-
332
- Args:
333
- arg: The MPObject whose mask needs to be changed.
334
- mask: The target mask to apply. Must be a valid party mask.
335
-
336
- Returns:
337
- MPObject: A new variable with the specified mask behavior:
338
- - For dynamic inputs: pmask = mask
339
- - For static inputs (valid subset): pmask = arg.pmask
340
-
341
- Raises:
342
- ValueError: When arg has a static pmask and mask is not a subset of arg.pmask.
343
- This validation occurs at compile time during graph construction.
344
-
345
- Examples:
346
- **Example 1: Dynamic pmask - mask assignment**
347
- P0 P1 P2
348
- -- -- --
349
- Input: ? ? ? (pmask=None, runtime-determined)
350
- mask: [0,2] (target mask)
351
- -----------------------------------------------------------
352
- Output: x0 - x2 (pmask=[0,2])
353
-
354
- **Example 2: Static pmask - valid subset**
355
- P0 P1 P2
356
- -- -- --
357
- Input: x0 x1 x2 (pmask=[0,1,2])
358
- mask: [0,2] (subset of input pmask)
359
- -----------------------------------------------------------
360
- Output: x0 - x2 (pmask=[0,2])
361
-
362
- **Example 3: Static pmask - invalid subset (compile error)**
363
- P0 P1 P2
364
- -- -- --
365
- Input: x0 - x2 (pmask=[0,2])
366
- mask: [1,2] (NOT subset of [0,2])
367
- -----------------------------------------------------------
368
- Result: ValueError at compile time
369
-
370
- Note:
371
- This function is typically used for constraining the execution scope
372
- of variables or for type casting between different pmask contexts.
373
- The underlying implementation uses JAX identity function with the
374
- specified execution mask.
375
- """
376
- pfunc, eval_args, out_tree = builtin.identity(arg)
377
- results = peval(pfunc, eval_args, mask)
378
- return out_tree.unflatten(results) # type: ignore[no-any-return]
379
-
380
-
381
- @primitive
286
+ @function
382
287
  def uniform_cond(
383
288
  pred: MPObject,
384
289
  then_fn: Callable[..., Any],
@@ -393,7 +298,7 @@ def uniform_cond(
393
298
 
394
299
  1. ``pred`` is a boolean scalar whose runtime value is identical for every enabled party.
395
300
  2. At least one branch contains multi-party primitives (``seal`` / ``reveal`` /
396
- ``srun`` / ``pshfl`` / mask transformations) whose cost or side-effects you
301
+ ``srun_jax`` / ``pshfl`` / mask transformations) whose cost or side-effects you
397
302
  want to avoid if the branch is not taken.
398
303
  3. You require the semantic guarantee that the *non-selected* branch does **not**
399
304
  perform communication, allocate intermediate buffers, or leak timing/side-effects.
@@ -588,7 +493,7 @@ def uniform_cond(
588
493
  return var_demorph(out_vars, then_tfn.out_imms, then_tfn.out_struct) # type: ignore[no-any-return]
589
494
 
590
495
 
591
- @primitive
496
+ @function
592
497
  def while_loop(
593
498
  cond_fn: Callable[[Any], MPObject],
594
499
  body_fn: Callable[[Any], Any],
@@ -654,7 +559,7 @@ def while_loop(
654
559
  secret-shared reduction).
655
560
 
656
561
  cond_fn::
657
- sealed_sum = smpc.reveal(smpc.srun(lambda x: jnp.sum(x))(smpc.seal(x)))
562
+ sealed_sum = smpc.reveal(smpc.srun_jax(lambda x: jnp.sum(x), smpc.seal(x)))
658
563
  return sealed_sum < constant(10)
659
564
 
660
565
  body_fn::
@@ -781,7 +686,7 @@ def while_loop(
781
686
  return var_demorph(out_vars, body_tfn.out_imms, body_tfn.out_struct)
782
687
 
783
688
 
784
- @primitive
689
+ @function
785
690
  def pshfl(src: MPObject, index: MPObject) -> MPObject:
786
691
  """Shuffle the input tensor to the specified index (dynamic version).
787
692
 
@@ -813,7 +718,7 @@ def pshfl(src: MPObject, index: MPObject) -> MPObject:
813
718
 
814
719
  Raises:
815
720
  ValueError: If the index tensor is not a scalar.
816
- RuntimeError: If src[index[i]] is None for any valid index[i] (i.e/,
721
+ RuntimeError: If src[index[i]] is None for any valid index[i] (i.e.,
817
722
  trying to fetch from a party that doesn't hold the data).
818
723
 
819
724
  Examples:
@@ -851,7 +756,7 @@ def pshfl(src: MPObject, index: MPObject) -> MPObject:
851
756
  return TraceVar(_tracer(), shfl_expr)
852
757
 
853
758
 
854
- @primitive
759
+ @function
855
760
  def pshfl_s(src_val: MPObject, pmask: Mask, src_ranks: list[Rank]) -> MPObject:
856
761
  """Shuffle the input tensor to the specified rank, static version.
857
762
 
@@ -910,7 +815,7 @@ def pshfl_s(src_val: MPObject, pmask: Mask, src_ranks: list[Rank]) -> MPObject:
910
815
  return TraceVar(_tracer(), shfl_s_expr)
911
816
 
912
817
 
913
- @primitive
818
+ @function
914
819
  def pconv(vars: list[MPObject]) -> MPObject:
915
820
  """Combine multiple variables that share the same dtype and shape into one.
916
821
 
@@ -18,16 +18,16 @@ from collections.abc import Iterator
18
18
  from dataclasses import dataclass, field
19
19
  from typing import Any, Protocol, runtime_checkable
20
20
 
21
- from mplang.core.dtype import DType
21
+ from mplang.v1.core.dtypes import DType
22
22
 
23
23
  __all__ = ["TableLike", "TableType"]
24
24
 
25
25
 
26
26
  @runtime_checkable
27
- class TableLike(Protocol):
27
+ class PandasTableLike(Protocol):
28
28
  """
29
29
  Protocol for objects structurally resembling tables from common libraries
30
- (pandas DataFrame, pyarrow Table, etc.), focusing on dtypes and columns attributes.
30
+ (pandas DataFrame, polars DataFrame, etc.), focusing on dtypes and columns attributes.
31
31
  """
32
32
 
33
33
  @property
@@ -37,6 +37,26 @@ class TableLike(Protocol):
37
37
  def columns(self) -> Any: ...
38
38
 
39
39
 
40
@runtime_checkable
class ArrowSchema(Protocol):
    """Structural type for Arrow-style schemas (e.g. ``pyarrow.Schema``).

    An object matches when it exposes both ``names`` and ``types``. Because
    this is a runtime-checkable protocol, ``isinstance`` only verifies that the
    attributes exist; it does not inspect their values.
    """

    @property
    def names(self) -> list[str]:
        """Field names, in schema order."""
        ...

    @property
    def types(self) -> list[Any]:
        """Field types, expected to line up one-to-one with ``names``."""
        ...
46
+
47
+
48
@runtime_checkable
class ArrowTableLike(Protocol):
    """Structural type for Arrow-style tables (e.g. ``pyarrow.Table``).

    Matching objects expose ``column_names`` and a ``schema``. As with any
    runtime-checkable protocol, ``isinstance`` checks only that these
    attributes exist, not what they contain.
    """

    @property
    def column_names(self) -> list[str]:
        """Names of the table's columns."""
        ...

    @property
    def schema(self) -> ArrowSchema:
        """Schema giving the columns' names and types."""
        ...
55
+
56
+
57
# Anything accepted as a table: either the pandas/polars-style protocol
# (``columns`` + ``dtypes``) or the Arrow-style protocol (``column_names`` +
# ``schema``). As an ``X | Y`` union of runtime-checkable protocols it can be
# passed to ``isinstance`` directly (Python 3.10+).
TableLike = PandasTableLike | ArrowTableLike
58
+
59
+
40
60
  @dataclass(frozen=True)
41
61
  class TableType:
42
62
  """Table schema: ordered list of column name-type pairs.
@@ -109,11 +129,19 @@ class TableType:
109
129
  Returns:
110
130
  TableType instance
111
131
  """
112
- columns = [
113
- (name, DType.from_any(dtype))
114
- for name, dtype in zip(table.columns, table.dtypes, strict=True)
115
- ]
116
- return cls(tuple(columns))
132
+ if isinstance(table, PandasTableLike):
133
+ columns = [
134
+ (name, DType.from_any(dtype))
135
+ for name, dtype in zip(table.columns, table.dtypes, strict=True)
136
+ ]
137
+ return cls(tuple(columns))
138
+ elif isinstance(table, ArrowTableLike):
139
+ schema = table.schema
140
+ columns = [
141
+ (name, DType.from_any(dtype))
142
+ for name, dtype in zip(schema.names, schema.types, strict=True)
143
+ ]
144
+ return cls(tuple(columns))
117
145
 
118
146
  def column_names(self) -> tuple[str, ...]:
119
147
  """Get all column names."""
@@ -19,7 +19,7 @@ from typing import Any, Protocol, runtime_checkable
19
19
 
20
20
  import numpy as np
21
21
 
22
- from mplang.core.dtype import DType
22
+ from mplang.v1.core.dtypes import DType
23
23
 
24
24
  # basic type aliases
25
25
  Shape = tuple[int, ...]
@@ -60,15 +60,15 @@ from collections.abc import Callable
60
60
  from dataclasses import dataclass
61
61
  from typing import Any, cast
62
62
 
63
- from mplang.core.cluster import ClusterSpec
64
- from mplang.core.context_mgr import with_ctx
65
- from mplang.core.expr.ast import Expr, FuncDefExpr, TupleExpr, VariableExpr
66
- from mplang.core.expr.printer import Printer
67
- from mplang.core.mask import Mask
68
- from mplang.core.mpobject import MPContext, MPObject
69
- from mplang.core.mptype import MPType
70
- from mplang.core.pfunc import get_fn_name
71
- from mplang.utils.func_utils import MorphStruct, var_demorph, var_morph
63
+ from mplang.v1.core.cluster import ClusterSpec
64
+ from mplang.v1.core.context_mgr import with_ctx
65
+ from mplang.v1.core.expr.ast import Expr, FuncDefExpr, TupleExpr, VariableExpr
66
+ from mplang.v1.core.expr.printer import Printer
67
+ from mplang.v1.core.mask import Mask
68
+ from mplang.v1.core.mpobject import MPContext, MPObject
69
+ from mplang.v1.core.mptype import MPType
70
+ from mplang.v1.core.pfunc import get_fn_name
71
+ from mplang.v1.utils.func_utils import MorphStruct, var_demorph, var_morph
72
72
 
73
73
 
74
74
  class VarNamer:
@@ -19,7 +19,8 @@ from typing import Any
19
19
 
20
20
  from jax.tree_util import tree_map
21
21
 
22
- from mplang.core import (
22
+ from mplang.v1.core import (
23
+ ClusterSpec,
23
24
  InterpContext,
24
25
  MPContext,
25
26
  MPObject,
@@ -27,8 +28,7 @@ from mplang.core import (
27
28
  TracedFunction,
28
29
  trace,
29
30
  )
30
- from mplang.core.cluster import ClusterSpec
31
- from mplang.core.context_mgr import cur_ctx, with_ctx
31
+ from mplang.v1.core.context_mgr import cur_ctx, with_ctx
32
32
 
33
33
 
34
34
  def evaluate(
@@ -38,6 +38,16 @@ def evaluate(
38
38
 
39
39
  This function accepts arbitrary types as it's designed to handle
40
40
  any multi-party computation function and arguments.
41
+
42
+ Args:
43
+ interp: The interpreter context for evaluating the multi-party function.
44
+ mpfn: The multi-party function to evaluate.
45
+ *args: Positional arguments to pass to the function.
46
+ **kwargs: Keyword arguments to pass to the function.
47
+
48
+ Returns:
49
+ Any: The result of evaluating the multi-party function, which can be
50
+ any type depending on the function's return type.
41
51
  """
42
52
  assert isinstance(interp, InterpContext), f"Expect InterpContext, got {interp}"
43
53
  with with_ctx(interp):
@@ -49,6 +59,16 @@ def fetch(interp: InterpContext | None, objs: Any) -> Any: # type: ignore[misc]
49
59
 
50
60
  This function uses tree_map to handle arbitrary nested structures,
51
61
  so it needs to accept and return Any type.
62
+
63
+ Args:
64
+ interp: The interpreter context for fetching results. If None, uses the
65
+ current context from cur_ctx().
66
+ objs: The objects containing MPObject instances to fetch. Can be any
67
+ nested structure.
68
+
69
+ Returns:
70
+ Any: The fetched results with the same structure as the input objects,
71
+ but with MPObject instances replaced by their computed values.
52
72
  """
53
73
  ctx = interp or cur_ctx()
54
74
  assert isinstance(ctx, InterpContext), f"Expect MPExecutor, got {ctx}"
@@ -56,11 +76,11 @@ def fetch(interp: InterpContext | None, objs: Any) -> Any: # type: ignore[misc]
56
76
  evaluated = evaluate(ctx, lambda x: x, objs)
57
77
 
58
78
  def fetch_impl(arg: MPObject | Any) -> Any:
59
- if isinstance(arg, MPObject):
60
- return ctx.fetch(arg)
61
- else:
79
+ if not isinstance(arg, MPObject):
62
80
  return arg
63
81
 
82
+ return ctx.fetch(arg)
83
+
64
84
  return tree_map(fetch_impl, evaluated)
65
85
 
66
86
 
@@ -94,5 +114,17 @@ class CompileOptions(MPContext):
94
114
def compile(
    mctx: MPContext, fn: Callable[..., Any], *args: Any, **kwargs: Any
) -> TracedFunction:
    """Trace a multi-party function into a TracedFunction without running it.

    A fresh TraceContext is created from ``mctx.cluster_spec``, and ``fn`` is
    traced in it with the given example arguments. Note that this module-level
    name shadows the built-in ``compile()`` for code in this module.

    Args:
        mctx: The multi-party context whose cluster spec drives tracing.
        fn: The function to compile.
        *args: Positional arguments used while tracing.
        **kwargs: Keyword arguments used while tracing.

    Returns:
        TracedFunction: The traced representation, which can later be
        evaluated in a multi-party context.
    """
    return trace(TraceContext(mctx.cluster_spec), fn, *args, **kwargs)
@@ -0,0 +1,41 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from mplang.v1.kernels.value import (
16
+ BytesBlob,
17
+ TableValue,
18
+ TensorValue,
19
+ Value,
20
+ ValueDecodeError,
21
+ ValueError,
22
+ decode_value,
23
+ encode_value,
24
+ is_value_envelope,
25
+ list_value_kinds,
26
+ register_value,
27
+ )
28
+
29
# Public names of this package; every entry is re-exported from
# mplang.v1.kernels.value (imported above).
__all__ = [
    "BytesBlob",
    "TableValue",
    "TensorValue",
    "Value",
    "ValueDecodeError",
    # NOTE(review): this re-exports a name ``ValueError`` from
    # mplang.v1.kernels.value. It shadows the built-in ValueError in this
    # package's namespace and in any module that does a star-import of this
    # package. Consider a distinct name — confirm with callers before renaming.
    "ValueError",
    "decode_value",
    "encode_value",
    "is_value_envelope",
    "list_value_kinds",
    "register_value",
]
@@ -37,7 +37,7 @@ from dataclasses import dataclass
37
37
  from typing import TYPE_CHECKING, Any
38
38
 
39
39
  if TYPE_CHECKING:
40
- from mplang.kernels.context import RuntimeContext
40
+ from mplang.v1.kernels.context import RuntimeContext
41
41
 
42
42
  __all__ = [
43
43
  "KernelContext",