mplang-nightly 0.1.dev172__py3-none-any.whl → 0.1.dev174__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mplang/__init__.py CHANGED
@@ -31,40 +31,117 @@ from mplang.core import (
31
31
  MPContext,
32
32
  MPObject,
33
33
  MPType,
34
+ Rank,
34
35
  TableType,
35
36
  TensorType,
37
+ TraceContext,
38
+ TracedFunction,
39
+ cur_ctx,
36
40
  function,
41
+ set_ctx,
42
+ trace,
43
+ with_ctx,
37
44
  )
38
45
  from mplang.core.cluster import ClusterSpec, Device, Node, RuntimeInfo
39
- from mplang.core.context_mgr import cur_ctx, set_ctx, with_ctx
46
+ from mplang.core.mpir import IrReader, IrWriter
47
+ from mplang.core.primitive import (
48
+ constant,
49
+ pconv,
50
+ peval,
51
+ prand,
52
+ prank,
53
+ pshfl,
54
+ pshfl_s,
55
+ uniform_cond,
56
+ while_loop,
57
+ )
40
58
  from mplang.host import CompileOptions, compile, evaluate, fetch
41
59
  from mplang.runtime.driver import Driver
42
60
  from mplang.runtime.simulation import Simulator
61
+ from mplang.simp.api import (
62
+ run,
63
+ run_at,
64
+ run_ibis,
65
+ run_ibis_at,
66
+ run_jax,
67
+ run_jax_at,
68
+ run_sql,
69
+ run_sql_at,
70
+ )
71
+ from mplang.simp.mpi import allgather_m, bcast_m, gather_m, p2p, scatter_m
72
+ from mplang.simp.party import P0, P1, P2, P2P, P, Party, load_module
73
+ from mplang.simp.random import key_split, pperm, prandint, ukey, urandint
74
+ from mplang.simp.smpc import reveal, revealTo, seal, sealFrom, srun
43
75
 
44
76
  # Public API
45
77
  __all__ = [
78
+ "P0",
79
+ "P1",
80
+ "P2",
81
+ "P2P",
46
82
  "ClusterSpec",
47
83
  "CompileOptions",
48
84
  "DType",
49
85
  "Device",
50
86
  "Driver",
51
87
  "InterpContext",
88
+ "IrReader",
89
+ "IrWriter",
52
90
  "MPContext",
53
91
  "MPObject",
54
92
  "MPType",
55
93
  "Mask",
56
94
  "Node",
95
+ "P",
96
+ "Party",
97
+ "Rank",
57
98
  "RuntimeInfo",
58
99
  "Simulator",
59
100
  "TableType",
60
101
  "TensorType",
102
+ "TraceContext",
103
+ "TracedFunction",
61
104
  "__version__",
105
+ "allgather_m",
62
106
  "analysis",
107
+ "bcast_m",
63
108
  "compile",
109
+ "constant",
64
110
  "cur_ctx",
65
111
  "evaluate",
66
112
  "fetch",
67
113
  "function",
114
+ "gather_m",
115
+ "key_split",
116
+ "load_module",
117
+ "p2p",
118
+ "pconv",
119
+ "peval",
120
+ "pperm",
121
+ "prand",
122
+ "prandint",
123
+ "prank",
124
+ "pshfl",
125
+ "pshfl_s",
126
+ "reveal",
127
+ "revealTo",
128
+ "run",
129
+ "run_at",
130
+ "run_ibis",
131
+ "run_ibis_at",
132
+ "run_jax",
133
+ "run_jax_at",
134
+ "run_sql",
135
+ "run_sql_at",
136
+ "scatter_m",
137
+ "seal",
138
+ "sealFrom",
68
139
  "set_ctx",
140
+ "srun",
141
+ "trace",
142
+ "ukey",
143
+ "uniform_cond",
144
+ "urandint",
145
+ "while_loop",
69
146
  "with_ctx",
70
147
  ]
@@ -26,7 +26,7 @@ from typing import TypedDict
26
26
  from mplang.core import TracedFunction
27
27
  from mplang.core.cluster import ClusterSpec
28
28
  from mplang.core.mask import Mask
29
- from mplang.core.mpir import Writer, get_graph_statistics
29
+ from mplang.core.mpir import IrWriter, get_graph_statistics
30
30
  from mplang.protos.v1alpha1 import mpir_pb2
31
31
 
32
32
  # ----------------------------- Core helpers (copied) -----------------------------
@@ -450,7 +450,7 @@ def dump(
450
450
 
451
451
  # Build graph once
452
452
  expr = traced.make_expr()
453
- graph_proto = Writer().dumps(expr)
453
+ graph_proto = IrWriter().dumps(expr)
454
454
 
455
455
  # Derive world_size from cluster_spec if provided
456
456
  derived_world_size: int | None = None
mplang/core/__init__.py CHANGED
@@ -27,6 +27,7 @@ from mplang.core.comm import (
27
27
  ICollective,
28
28
  ICommunicator,
29
29
  )
30
+ from mplang.core.context_mgr import cur_ctx, set_ctx, with_ctx
30
31
  from mplang.core.dtype import DType
31
32
  from mplang.core.interp import InterpContext, InterpVar
32
33
  from mplang.core.mask import Mask
@@ -76,6 +77,7 @@ __all__ = [
76
77
  "TracedFunction",
77
78
  "VarNamer",
78
79
  "constant",
80
+ "cur_ctx",
79
81
  "debug_print",
80
82
  "function",
81
83
  "pconv",
@@ -85,8 +87,10 @@ __all__ = [
85
87
  "pshfl",
86
88
  "pshfl_s",
87
89
  "psize",
90
+ "set_ctx",
88
91
  "set_mask",
89
92
  "trace",
90
93
  "uniform_cond",
91
94
  "while_loop",
95
+ "with_ctx",
92
96
  ]
mplang/core/mpir.py CHANGED
@@ -204,7 +204,7 @@ def attr_to_proto(py_value: Any) -> mpir_pb2.AttrProto:
204
204
  raise TypeError(f"Unsupported tuple/list type: {type(py_value)}")
205
205
  elif isinstance(py_value, FuncDefExpr):
206
206
  # Convert FuncDefExpr to GraphProto
207
- graph = Writer().dumps(py_value)
207
+ graph = IrWriter().dumps(py_value)
208
208
  attr_proto.type = mpir_pb2.AttrProto.GRAPH
209
209
  attr_proto.graph.CopyFrom(graph)
210
210
  elif isinstance(py_value, PFunction):
@@ -234,7 +234,7 @@ def attr_to_proto(py_value: Any) -> mpir_pb2.AttrProto:
234
234
  return attr_proto
235
235
 
236
236
 
237
- class Writer:
237
+ class IrWriter:
238
238
  """Writer for serializing Expr-based expressions to GraphProto.
239
239
 
240
240
  This class traverses an expression tree and converts it into a serialized
@@ -525,7 +525,7 @@ class Writer:
525
525
  raise TypeError(f"Unsupported expr type for serialization: {type(expr)}")
526
526
 
527
527
 
528
- class Reader:
528
+ class IrReader:
529
529
  """Reader for deserializing GraphProto back to Expr-based expressions.
530
530
 
531
531
  This class is responsible for converting serialized GraphProto representations
@@ -902,7 +902,7 @@ class Reader:
902
902
  )
903
903
  elif attr_proto.type == mpir_pb2.AttrProto.GRAPH:
904
904
  # Handle nested expressions (for control flow)
905
- reader = Reader()
905
+ reader = IrReader()
906
906
  return reader.loads(attr_proto.graph)
907
907
  else:
908
908
  raise TypeError(f"Unsupported attribute type: {attr_proto.type}")
mplang/core/primitive.py CHANGED
@@ -49,6 +49,9 @@ from mplang.core.table import TableLike
49
49
  from mplang.core.tensor import ScalarType, Shape, TensorLike
50
50
  from mplang.core.tracer import TraceContext, TraceVar, trace
51
51
  from mplang.ops import basic
52
+ from mplang.ops.base import (
53
+ FeOperation, # TODO(jint), is this a backward dependency?
54
+ )
52
55
  from mplang.utils.func_utils import var_demorph, var_morph
53
56
 
54
57
 
@@ -243,9 +246,7 @@ def prank() -> MPObject:
243
246
  Note:
244
247
  Each party in the current party mask independently produces its own rank value.
245
248
  """
246
- pfunc, eval_args, out_tree = basic.rank()
247
- results = peval(pfunc, eval_args)
248
- return out_tree.unflatten(results) # type: ignore[no-any-return]
249
+ return cast(MPObject, run(None, basic.rank))
249
250
 
250
251
 
251
252
  @bltin_function
@@ -271,9 +272,7 @@ def prand(shape: Shape = ()) -> MPObject:
271
272
  private random values. The randomness is local to each party and is
272
273
  not shared or revealed to other parties.
273
274
  """
274
- pfunc, eval_args, out_tree = basic.prand(shape)
275
- results = peval(pfunc, eval_args)
276
- return out_tree.unflatten(results) # type: ignore[no-any-return]
275
+ return cast(MPObject, run(None, basic.prand, shape))
277
276
 
278
277
 
279
278
  def constant(data: TensorLike | ScalarType | TableLike) -> MPObject:
@@ -305,9 +304,7 @@ def constant(data: TensorLike | ScalarType | TableLike) -> MPObject:
305
304
  Note that the constant primitive is not designed to carry large tables efficiently -
306
305
  consider using dedicated table loading mechanisms for substantial datasets.
307
306
  """
308
- pfunc, eval_args, out_tree = basic.constant(data)
309
- results = peval(pfunc, eval_args)
310
- return out_tree.unflatten(results) # type: ignore[no-any-return]
307
+ return cast(MPObject, run(None, basic.constant, data))
311
308
 
312
309
 
313
310
  @bltin_function
@@ -319,7 +316,7 @@ def debug_print(obj: MPObject, prefix: str = "") -> MPObject:
319
316
  """
320
317
  pfunc, eval_args, out_tree = basic.debug_print(obj, prefix=prefix)
321
318
  results = peval(pfunc, eval_args)
322
- return out_tree.unflatten(results) # type: ignore[no-any-return]
319
+ return cast(MPObject, out_tree.unflatten(results))
323
320
 
324
321
 
325
322
  @function
@@ -386,6 +383,23 @@ def peval(
386
383
  return [TraceVar(ctx, res) for res in ret_exprs]
387
384
 
388
385
 
386
+ def run(
387
+ pmask: Mask | None,
388
+ fe_op: FeOperation,
389
+ *args: Any,
390
+ **kwargs: Any,
391
+ ) -> Any:
392
+ """Run an operation in the current context."""
393
+ pfunc, eval_args, out_tree = fe_op(*args, **kwargs)
394
+ results = peval(pfunc, eval_args, pmask)
395
+ return out_tree.unflatten(results)
396
+
397
+
398
+ def run_at(rank: Rank, op: Any, *args: Any, **kwargs: Any) -> Any:
399
+ """Run an operation at a specific rank."""
400
+ return run(Mask.from_ranks(rank), op, *args, **kwargs)
401
+
402
+
389
403
  def set_mask(arg: MPObject, mask: Mask) -> MPObject:
390
404
  """Set the mask of an MPObject to a new value.
391
405
 
@@ -447,7 +461,7 @@ def set_mask(arg: MPObject, mask: Mask) -> MPObject:
447
461
  """
448
462
  pfunc, eval_args, out_tree = basic.identity(arg)
449
463
  results = peval(pfunc, eval_args, mask)
450
- return out_tree.unflatten(results) # type: ignore[no-any-return]
464
+ return cast(MPObject, out_tree.unflatten(results))
451
465
 
452
466
 
453
467
  @function
mplang/device.py CHANGED
@@ -29,8 +29,7 @@ from typing import Any
29
29
 
30
30
  from jax.tree_util import tree_map
31
31
 
32
- import mplang.host as mapi
33
- from mplang import simp
32
+ import mplang.host as mphost
34
33
  from mplang.core import InterpContext, MPObject, primitive
35
34
  from mplang.core.cluster import ClusterSpec, Device
36
35
  from mplang.core.context_mgr import cur_ctx
@@ -40,6 +39,7 @@ from mplang.ops.base import FeOperation
40
39
  from mplang.ops.ibis_cc import IbisRunner
41
40
  from mplang.ops.jax_cc import JaxRunner
42
41
  from mplang.simp import mpi, smpc
42
+ from mplang.simp.api import run_at
43
43
 
44
44
  # Automatic transfer between devices when parameter is not on the target device.
45
45
  g_auto_trans: bool = True
@@ -96,7 +96,7 @@ def _device_run_tee(
96
96
  raise ValueError("TEE device only supports JAX and Ibis frontend.")
97
97
  assert len(dev_info.members) == 1
98
98
  rank = dev_info.members[0].rank
99
- var = simp.runAt(rank, op)(*args, **kwargs)
99
+ var = run_at(rank, op, *args, **kwargs)
100
100
  return tree_map(partial(_set_devid, dev_id=dev_info.name), var)
101
101
 
102
102
 
@@ -105,13 +105,13 @@ def _device_run_ppu(
105
105
  ) -> Any:
106
106
  assert len(dev_info.members) == 1
107
107
  rank = dev_info.members[0].rank
108
- var = simp.runAt(rank, op)(*args, **kwargs)
108
+ var = run_at(rank, op, *args, **kwargs)
109
109
  return tree_map(partial(_set_devid, dev_id=dev_info.name), var)
110
110
 
111
111
 
112
112
  def _device_run(dev_id: str, op: FeOperation, *args: Any, **kwargs: Any) -> Any:
113
113
  assert isinstance(op, FeOperation)
114
- cluster_spec = mapi.cur_ctx().cluster_spec
114
+ cluster_spec = mphost.cur_ctx().cluster_spec
115
115
  if dev_id not in cluster_spec.devices:
116
116
  raise ValueError(f"Device {dev_id} not found in cluster spec.")
117
117
 
@@ -177,7 +177,7 @@ def _d2d(to_dev_id: str, obj: MPObject) -> MPObject:
177
177
  if frm_dev_id == to_dev_id:
178
178
  return obj
179
179
 
180
- cluster_spec: ClusterSpec = mapi.cur_ctx().cluster_spec
180
+ cluster_spec: ClusterSpec = mphost.cur_ctx().cluster_spec
181
181
  frm_dev = cluster_spec.devices[frm_dev_id]
182
182
  to_dev = cluster_spec.devices[to_dev_id]
183
183
  frm_to_pair = (frm_dev.kind.upper(), to_dev.kind.upper())
@@ -209,11 +209,11 @@ def _d2d(to_dev_id: str, obj: MPObject) -> MPObject:
209
209
  sess_p, sess_t = _ensure_tee_session(frm_dev_id, to_dev_id, frm_rank, tee_rank)
210
210
  # Bytes-only path: pack -> enc -> p2p -> dec -> unpack (with static out type)
211
211
  obj_ty = TensorType.from_obj(obj)
212
- b = simp.runAt(frm_rank, basic.pack)(obj)
213
- ct = simp.runAt(frm_rank, crypto.enc)(b, sess_p)
212
+ b = run_at(frm_rank, basic.pack, obj)
213
+ ct = run_at(frm_rank, crypto.enc, b, sess_p)
214
214
  ct_at_tee = mpi.p2p(frm_rank, tee_rank, ct)
215
- b_at_tee = simp.runAt(tee_rank, crypto.dec)(ct_at_tee, sess_t)
216
- pt_at_tee = simp.runAt(tee_rank, basic.unpack)(b_at_tee, out_ty=obj_ty)
215
+ b_at_tee = run_at(tee_rank, crypto.dec, ct_at_tee, sess_t)
216
+ pt_at_tee = run_at(tee_rank, basic.unpack, b_at_tee, out_ty=obj_ty)
217
217
  return tree_map(partial(_set_devid, dev_id=to_dev_id), pt_at_tee) # type: ignore[no-any-return]
218
218
  elif frm_to_pair == ("TEE", "PPU"):
219
219
  # Transparent encryption from TEE to a specific PPU using the reverse-direction session key
@@ -223,11 +223,11 @@ def _d2d(to_dev_id: str, obj: MPObject) -> MPObject:
223
223
  # Ensure bidirectional session established for this pair
224
224
  sess_p, sess_t = _ensure_tee_session(to_dev_id, frm_dev_id, ppu_rank, tee_rank)
225
225
  obj_ty = TensorType.from_obj(obj)
226
- b = simp.runAt(tee_rank, basic.pack)(obj)
227
- ct = simp.runAt(tee_rank, crypto.enc)(b, sess_t)
226
+ b = run_at(tee_rank, basic.pack, obj)
227
+ ct = run_at(tee_rank, crypto.enc, b, sess_t)
228
228
  ct_at_ppu = mpi.p2p(tee_rank, ppu_rank, ct)
229
- b_at_ppu = simp.runAt(ppu_rank, crypto.dec)(ct_at_ppu, sess_p)
230
- pt_at_ppu = simp.runAt(ppu_rank, basic.unpack)(b_at_ppu, out_ty=obj_ty)
229
+ b_at_ppu = run_at(ppu_rank, crypto.dec, ct_at_ppu, sess_p)
230
+ pt_at_ppu = run_at(ppu_rank, basic.unpack, b_at_ppu, out_ty=obj_ty)
231
231
  return tree_map(partial(_set_devid, dev_id=to_dev_id), pt_at_ppu) # type: ignore[no-any-return]
232
232
  else:
233
233
  supported = [
@@ -260,27 +260,25 @@ def _ensure_tee_session(
260
260
 
261
261
  # 1) TEE generates (sk, pk) and quote(pk)
262
262
  # KEM suite currently constant; future: read from tee device config (e.g. cluster_spec.devices[dev_id].config)
263
- tee_sk, tee_pk = simp.runAt(tee_rank, crypto.kem_keygen)(_TEE_KEM_SUITE)
264
- quote = simp.runAt(tee_rank, tee.quote_gen)(tee_pk)
263
+ tee_sk, tee_pk = run_at(tee_rank, crypto.kem_keygen, _TEE_KEM_SUITE)
264
+ quote = run_at(tee_rank, tee.quote_gen, tee_pk)
265
265
 
266
266
  # 2) Send quote to sender and attest to obtain TEE pk
267
267
  quote_at_sender = mpi.p2p(tee_rank, frm_rank, quote)
268
- tee_pk_at_sender = simp.runAt(frm_rank, tee.attest)(quote_at_sender)
268
+ tee_pk_at_sender = run_at(frm_rank, tee.attest, quote_at_sender)
269
269
 
270
270
  # 3) Sender generates its ephemeral keypair and sends its pk to TEE
271
- v_sk, v_pk = simp.runAt(frm_rank, crypto.kem_keygen)(_TEE_KEM_SUITE)
271
+ v_sk, v_pk = run_at(frm_rank, crypto.kem_keygen, _TEE_KEM_SUITE)
272
272
  v_pk_at_tee = mpi.p2p(frm_rank, tee_rank, v_pk)
273
273
 
274
274
  # 4) Both sides derive the shared secret and session key
275
- shared_p = simp.runAt(frm_rank, crypto.kem_derive)(
276
- v_sk, tee_pk_at_sender, _TEE_KEM_SUITE
277
- )
278
- shared_t = simp.runAt(tee_rank, crypto.kem_derive)(
279
- tee_sk, v_pk_at_tee, _TEE_KEM_SUITE
275
+ shared_p = run_at(
276
+ frm_rank, crypto.kem_derive, v_sk, tee_pk_at_sender, _TEE_KEM_SUITE
280
277
  )
278
+ shared_t = run_at(tee_rank, crypto.kem_derive, tee_sk, v_pk_at_tee, _TEE_KEM_SUITE)
281
279
  # Use a fixed ASCII string literal for HKDF info on both sides
282
- sess_p = simp.runAt(frm_rank, crypto.hkdf)(shared_p, _HKDF_INFO_LITERAL)
283
- sess_t = simp.runAt(tee_rank, crypto.hkdf)(shared_t, _HKDF_INFO_LITERAL)
280
+ sess_p = run_at(frm_rank, crypto.hkdf, shared_p, _HKDF_INFO_LITERAL)
281
+ sess_t = run_at(tee_rank, crypto.hkdf, shared_t, _HKDF_INFO_LITERAL)
284
282
 
285
283
  cache[key] = (sess_p, sess_t)
286
284
  return sess_p, sess_t
@@ -300,25 +298,25 @@ def _fetch(interp: InterpContext, obj: MPObject) -> Any:
300
298
 
301
299
  dev_info = cluster_spec.devices[dev_id]
302
300
  if dev_kind == "SPU":
303
- revealed = mapi.evaluate(interp, smpc.reveal, obj)
304
- result = mapi.fetch(interp, revealed)
301
+ revealed = mphost.evaluate(interp, smpc.reveal, obj)
302
+ result = mphost.fetch(interp, revealed)
305
303
  # now all members have the same value, return the one at rank 0
306
304
  return result[dev_info.members[0].rank]
307
305
  elif dev_kind == "PPU":
308
306
  assert len(dev_info.members) == 1
309
307
  rank = dev_info.members[0].rank
310
- result = mapi.fetch(interp, obj)
308
+ result = mphost.fetch(interp, obj)
311
309
  return result[rank]
312
310
  elif dev_kind == "TEE":
313
311
  assert len(dev_info.members) == 1
314
312
  rank = dev_info.members[0].rank
315
- result = mapi.fetch(interp, obj)
313
+ result = mphost.fetch(interp, obj)
316
314
  return result[rank]
317
315
  else:
318
316
  raise ValueError(f"Unknown device id: {dev_id}")
319
317
 
320
318
 
321
319
  def fetch(interp: InterpContext, objs: Any) -> Any:
322
- ctx = interp or mapi.cur_ctx()
320
+ ctx = interp or mphost.cur_ctx()
323
321
  assert isinstance(ctx, InterpContext), f"Expect InterpContext, got {ctx}"
324
322
  return tree_map(partial(_fetch, ctx), objs)
mplang/runtime/driver.py CHANGED
@@ -15,8 +15,8 @@
15
15
  """
16
16
  HTTP-based driver implementation for distributed execution.
17
17
 
18
- This module provides an HTTP-based alternative to the gRPC Driver,
19
- using REST APIs for distributed multi-party computation coordination.
18
+ This module provides an HTTP-based driver, using REST APIs
19
+ for distributed multi-party computation coordination.
20
20
  """
21
21
 
22
22
  from __future__ import annotations
@@ -33,7 +33,7 @@ from mplang.core.cluster import ClusterSpec
33
33
  from mplang.core.expr.ast import Expr
34
34
  from mplang.core.interp import InterpContext, InterpVar
35
35
  from mplang.core.mask import Mask
36
- from mplang.core.mpir import Writer
36
+ from mplang.core.mpir import IrWriter
37
37
  from mplang.core.mpobject import MPObject
38
38
  from mplang.core.mptype import MPType
39
39
  from mplang.kernels.value import TableValue, TensorValue
@@ -198,7 +198,7 @@ class Driver(InterpContext):
198
198
 
199
199
  var_name_mapping = dict(zip(var_names, party_symbol_names, strict=True))
200
200
 
201
- writer = Writer(var_name_mapping)
201
+ writer = IrWriter(var_name_mapping)
202
202
  program_proto = writer.dumps(expr)
203
203
 
204
204
  output_symbols = [self.new_name() for _ in range(expr.num_outputs)]
mplang/runtime/server.py CHANGED
@@ -30,7 +30,7 @@ from fastapi import (
30
30
  from fastapi.responses import JSONResponse
31
31
  from pydantic import BaseModel
32
32
 
33
- from mplang.core.mpir import Reader
33
+ from mplang.core.mpir import IrReader
34
34
  from mplang.core.table import TableType
35
35
  from mplang.core.tensor import TensorType
36
36
  from mplang.kernels.base import KernelContext
@@ -310,7 +310,7 @@ def create_and_execute_computation(
310
310
  f"Invalid base64 or protobuf for mpprogram: {e!s}"
311
311
  ) from e
312
312
 
313
- reader = Reader()
313
+ reader = IrReader()
314
314
  expr = reader.loads(graph_proto)
315
315
 
316
316
  if expr is None:
@@ -30,7 +30,7 @@ from mplang.core.expr.ast import Expr
30
30
  from mplang.core.expr.evaluator import IEvaluator, create_evaluator
31
31
  from mplang.core.interp import InterpContext, InterpVar
32
32
  from mplang.core.mask import Mask
33
- from mplang.core.mpir import Reader, Writer
33
+ from mplang.core.mpir import IrReader, IrWriter
34
34
  from mplang.core.mpobject import MPObject
35
35
  from mplang.core.mptype import MPType, TensorLike
36
36
  from mplang.core.pfunc import PFunction # for spu.seed_env kernel seeding
@@ -187,10 +187,10 @@ class Simulator(InterpContext):
187
187
  This exposes potential MPIR serialization bugs by forcing expressions
188
188
  to go through the full serialize->deserialize cycle.
189
189
  """
190
- writer = Writer()
190
+ writer = IrWriter()
191
191
  graph_proto = writer.dumps(expr)
192
192
 
193
- reader = Reader()
193
+ reader = IrReader()
194
194
  deserialized_expr = reader.loads(graph_proto)
195
195
 
196
196
  if deserialized_expr is None:
mplang/simp/__init__.py CHANGED
@@ -11,341 +11,3 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
-
15
- from __future__ import annotations
16
-
17
- import importlib
18
- import pathlib
19
- import pkgutil
20
- from collections.abc import Callable
21
- from functools import partial, wraps
22
- from types import ModuleType
23
- from typing import Any
24
-
25
- from mplang.core.mask import Mask
26
- from mplang.core.mpobject import MPObject
27
- from mplang.core.mptype import Rank
28
- from mplang.core.primitive import (
29
- constant,
30
- pconv,
31
- peval,
32
- prand,
33
- prank,
34
- pshfl,
35
- pshfl_s,
36
- uniform_cond,
37
- while_loop,
38
- )
39
- from mplang.ops import ibis_cc, jax_cc
40
- from mplang.ops.base import FeOperation
41
- from mplang.simp.mpi import allgather_m, bcast_m, gather_m, p2p, scatter_m
42
- from mplang.simp.random import key_split, pperm, prandint, ukey, urandint
43
- from mplang.simp.smpc import reveal, revealTo, seal, sealFrom, srun
44
-
45
- # Public exports of the simplified party execution API.
46
- # NOTE: Replaces previous internal __reexport__ (not a Python convention)
47
- # to make star-imports explicit and tooling-friendly.
48
- __all__ = [ # noqa: RUF022
49
- "MPObject",
50
- "P",
51
- "P0",
52
- "P1",
53
- "P2",
54
- "P2P",
55
- "Party",
56
- "allgather_m",
57
- "bcast_m",
58
- "constant",
59
- "gather_m",
60
- "key_split",
61
- "load_module",
62
- "p2p",
63
- "pconv",
64
- "peval",
65
- "pperm",
66
- "prand",
67
- "prandint",
68
- "prank",
69
- "pshfl",
70
- "pshfl_s",
71
- "reveal",
72
- "revealTo",
73
- "run",
74
- "runAt",
75
- "scatter_m",
76
- "seal",
77
- "sealFrom",
78
- "srun",
79
- "ukey",
80
- "uniform_cond",
81
- "urandint",
82
- "while_loop",
83
- ]
84
-
85
-
86
- def run_impl(
87
- pmask: Mask | None,
88
- func: Callable,
89
- *args: Any,
90
- **kwargs: Any,
91
- ) -> Any:
92
- """
93
- Run a function that can be evaluated by the mplang system.
94
-
95
- This function provides a dispatch mechanism based on the first argument
96
- to route different function types to appropriate handlers.
97
-
98
- Args:
99
- pmask: The party mask of this function, None indicates auto deduce parties from args.
100
- func: The function to be dispatched and executed
101
- *args: Positional arguments to pass to the function
102
- **kwargs: Keyword arguments to pass to the function
103
-
104
- Returns:
105
- The result of evaluating the function through the appropriate handler
106
-
107
- Raises:
108
- ValueError: If basic.write is called without required arguments
109
- TypeError: If the function compilation or evaluation fails
110
- RuntimeError: If the underlying peval execution encounters errors
111
-
112
- Examples:
113
- Reading data from a file:
114
-
115
- >>> tensor_info = TensorType(shape=(10, 10), dtype=np.float32)
116
- >>> attrs = {"format": "binary"}
117
- >>> result = run_impl(basic.read, "data/input.bin", tensor_info, attrs)
118
-
119
- Writing data to a file:
120
-
121
- >>> run_impl(basic.write, data, "data/output.bin")
122
-
123
- Running a JAX function:
124
-
125
- >>> def matrix_multiply(a, b):
126
- ... return jnp.dot(a, b)
127
- >>> result = run_impl(matrix_multiply, mat_a, mat_b)
128
-
129
- Running a custom computation function:
130
-
131
- >>> def compute_statistics(data):
132
- ... mean = jnp.mean(data)
133
- ... std = jnp.std(data)
134
- ... return {"mean": mean, "std": std}
135
- >>> stats = run_impl(compute_statistics, dataset)
136
- """
137
-
138
- if isinstance(func, FeOperation):
139
- pfunc, eval_args, out_tree = func(*args, **kwargs)
140
- else:
141
- if ibis_cc.is_ibis_function(func):
142
- pfunc, eval_args, out_tree = ibis_cc.run_ibis(func, *args, **kwargs)
143
- else:
144
- # unknown python callable, treat it as jax function
145
- pfunc, eval_args, out_tree = jax_cc.run_jax(func, *args, **kwargs)
146
- results = peval(pfunc, eval_args, pmask)
147
- return out_tree.unflatten(results)
148
-
149
-
150
- # run :: (a -> a) -> m a -> m a
151
- def run(pyfn: Callable) -> Callable:
152
- return partial(run_impl, None, pyfn)
153
-
154
-
155
- # runAt :: Rank -> (a -> a) -> m a -> m a
156
- def runAt(rank: Rank, pyfn: Callable) -> Callable:
157
- pmask = Mask.from_ranks(rank)
158
- return partial(run_impl, pmask, pyfn)
159
-
160
-
161
- def P2P(src: Party, dst: Party, value: Any) -> Any:
162
- """Point-to-point transfer using Party objects instead of raw ranks.
163
-
164
- Equivalent to ``p2p(src.rank, dst.rank, value)`` but improves readability
165
- and reduces magic numbers in user code / tutorials.
166
-
167
- Parameters
168
- ----------
169
- src : Party
170
- Source party object.
171
- dst : Party
172
- Destination party object.
173
- value : Any
174
- Value to transfer.
175
-
176
- Returns
177
- -------
178
- Any
179
- The same value representation at destination context (as defined by
180
- underlying ``p2p`` primitive semantics).
181
- """
182
- if not isinstance(src, Party) or not isinstance(dst, Party): # defensive
183
- raise TypeError("P2P expects Party objects, e.g. P2P(P0, P2, value)")
184
- return p2p(src.rank, dst.rank, value)
185
-
186
-
187
- """Party-scoped module registration & dispatch.
188
-
189
- This module provides a light-weight mechanism to expose *module-like* groups
190
- of callable operations bound to a specific party (rank) via attribute access:
191
-
192
- load_module("mplang.ops.crypto", alias="crypto")
193
- P0.crypto.encrypt(x) # executes encrypt() with pmask = {rank 0}
194
-
195
- Core concepts:
196
- * Registry (``_NAMESPACE_REGISTRY``): maps alias -> importable module path.
197
- * Lazy import: underlying module is imported on first attribute access.
198
- * Wrapping: fetched callables are wrapped so that invocation automatically
199
- routes through ``run_impl`` with that party's mask.
200
-
201
- Only *callable* attributes are exposed; non-callable attributes raise
202
- ``AttributeError`` to avoid surprising divergent local vs. distributed
203
- semantics.
204
-
205
- The public API surface intentionally stays small (`Party`, `P`, `run`,
206
- `runAt`, and `load_module`). Internal details (proxy class / registry) are
207
- not part of the stability guarantee.
208
- """
209
-
210
- _NAMESPACE_REGISTRY: dict[str, str] = {}
211
-
212
-
213
- class _PartyModuleProxy:
214
- """Lazy module proxy bound to a specific party.
215
-
216
- Attribute access resolves a callable inside the registered module and
217
- returns a wrapped function that executes with the party's mask.
218
- Non-callable attributes are rejected explicitly to keep semantics clear.
219
- """
220
-
221
- def __init__(self, party: Party, name: str):
222
- self._party: Party = party
223
- self._name: str = name
224
- self._module: ModuleType | None = None # loaded lazily
225
-
226
- def _ensure(self) -> None:
227
- if self._module is None:
228
- self._module = importlib.import_module(_NAMESPACE_REGISTRY[self._name])
229
-
230
- def __getattr__(self, item: str) -> Callable[..., Any]:
231
- self._ensure()
232
- target = getattr(self._module, item)
233
- if not callable(target):
234
- raise AttributeError(
235
- f"Attribute '{item}' of party module '{self._name}' is not callable (got {type(target).__name__})"
236
- )
237
-
238
- @wraps(target)
239
- def _wrapped(*args: Any, **kw: Any) -> Any:
240
- # Inline runAt to reduce an extra partial layer while preserving semantics.
241
- return run_impl(Mask.from_ranks(self._party.rank), target, *args, **kw)
242
-
243
- # Provide a party-qualified name for debugging / logs without losing original metadata.
244
- base_name = getattr(target, "__name__", None)
245
- if base_name is None:
246
- # Frontend FeOperation or object without __name__; try .name attribute (FeOperation contract) or fallback to repr
247
- base_name = getattr(target, "name", None) or type(target).__name__
248
- try:
249
- _wrapped.__name__ = f"{base_name}@P{self._party.rank}"
250
- except Exception: # pragma: no cover - assignment may fail for exotic wrappers
251
- pass
252
- return _wrapped
253
-
254
-
255
- class Party:
256
- def __init__(self, rank: int) -> None:
257
- self.rank: int = int(rank)
258
-
259
- def __repr__(self) -> str: # pragma: no cover
260
- return f"Party(rank={self.rank})"
261
-
262
- def __call__(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
263
- if not callable(fn):
264
- raise TypeError(
265
- f"First argument to Party({self.rank}) must be callable, got {fn!r}"
266
- )
267
- return runAt(self.rank, fn)(*args, **kwargs)
268
-
269
- def __getattr__(self, name: str) -> _PartyModuleProxy:
270
- if name in _NAMESPACE_REGISTRY:
271
- return _PartyModuleProxy(self, name)
272
- raise AttributeError(
273
- f"Party has no attribute '{name}'. Registered: {list(_NAMESPACE_REGISTRY)}"
274
- )
275
-
276
-
277
- class _PartyIndex:
278
- def __getitem__(self, rank: int) -> Party:
279
- return Party(rank)
280
-
281
-
282
- def _load_prelude_modules() -> None:
283
- """Auto-register public frontend submodules for party namespace access.
284
-
285
- Implementation detail: we treat every non-underscore immediate child of
286
- ``mplang.ops`` as public and make it available as ``P0.<name>``.
287
- This keeps user ergonomics high (no manual load_module calls for core
288
- frontends) but slightly increases implicit surface area. If this grows
289
- unwieldy we can switch to an allowlist.
290
- """
291
- try:
292
- import mplang.ops as _fe # type: ignore
293
- except (ImportError, ModuleNotFoundError): # pragma: no cover
294
- # Frontend package not present (minimal install); safe to skip.
295
- return
296
-
297
- pkg_path = pathlib.Path(_fe.__file__).parent
298
- for m in pkgutil.iter_modules([str(pkg_path)]):
299
- if m.name.startswith("_"):
300
- continue
301
- if m.name not in _NAMESPACE_REGISTRY:
302
- _NAMESPACE_REGISTRY[m.name] = f"mplang.ops.{m.name}"
303
-
304
-
305
- def load_module(module: str, alias: str | None = None) -> None:
306
- """Register a module for party-scoped (per-rank) callable access.
307
-
308
- After registration, each party object (e.g. ``P0``) can access callable
309
- attributes of the target module through the chosen alias and have them
310
- executed under that party's mask automatically. Non-callable attributes
311
- are intentionally not exposed to avoid ambiguity between local data and
312
- distributed execution semantics.
313
-
314
- Parameters
315
- ----------
316
- module : str
317
- The fully-qualified import path of the module to expose. It must be
318
- importable via ``importlib.import_module``.
319
- alias : str | None, optional
320
- The short name used as an attribute on ``Party``/``P0``/``P1``/... .
321
- If omitted, the last path segment of ``module`` is used.
322
-
323
- Raises
324
- ------
325
- ValueError
326
- If the alias is already registered to a *different* module path.
327
-
328
- Notes
329
- -----
330
- Registration is idempotent when the alias maps to the same module. The
331
- actual module object is imported lazily on first attribute access, so
332
- calling ``load_module`` has negligible upfront cost.
333
-
334
- Examples
335
- --------
336
- >>> load_module("mplang.ops.crypto", alias="crypto")
337
- >>> # Now call an op on party 0
338
- >>> P0.crypto.encrypt(data)
339
- """
340
- if alias is None:
341
- alias = module.rsplit(".", 1)[-1]
342
- prev = _NAMESPACE_REGISTRY.get(alias)
343
- if prev and prev != module:
344
- raise ValueError(f"Alias '{alias}' already registered for '{prev}'")
345
- _NAMESPACE_REGISTRY[alias] = module
346
-
347
-
348
- P = _PartyIndex()
349
- P0, P1, P2 = Party(0), Party(1), Party(2)
350
-
351
- _load_prelude_modules()
mplang/simp/api.py ADDED
@@ -0,0 +1,96 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ from collections.abc import Callable
18
+ from typing import Any
19
+
20
+ from mplang.core.mpobject import MPObject
21
+ from mplang.core.mptype import Rank
22
+ from mplang.core.primitive import run, run_at
23
+ from mplang.ops import ibis_cc, jax_cc, sql_cc
24
+
25
+
26
+ def run_jax(jax_fn: Callable, *args: Any, **kwargs: Any) -> Any:
27
+ """Run a JAX function.
28
+
29
+ Args:
30
+ jax_fn: The JAX function to be executed.
31
+ *args: Positional arguments to pass to the JAX function.
32
+ **kwargs: Keyword arguments to pass to the JAX function.
33
+
34
+ Returns:
35
+ The result of evaluating the JAX function through the mplang system.
36
+
37
+ Raises:
38
+ TypeError: If the function compilation or evaluation fails.
39
+ RuntimeError: If the underlying peval execution encounters errors.
40
+
41
+ Notes:
42
+ Argument binding semantics with respect to JAX static arguments:
43
+
44
+ - If an argument (or any leaf within a PyTree argument) is an
45
+ :class:`~mplang.core.mpobject.MPObject`, it is captured as a runtime
46
+ variable (dynamic value) in the traced program and is not treated as a
47
+ JAX static argument.
48
+ - If an argument contains no :class:`MPObject` leaves, it is treated as a
49
+ constant configuration with respect to JAX; effectively it behaves
50
+ like a static argument and may contribute to JAX compilation cache
51
+ keys (similar to ``static_argnums`` semantics). Changing such constant
52
+ arguments can lead to different compiled variants/cached entries.
53
+
54
+ Examples:
55
+ Defining and running a simple JAX function:
56
+
57
+ >>> import jax.numpy as jnp
58
+ >>> def add_matrices(a, b):
59
+ ... return jnp.add(a, b)
60
+ >>> result = run_jax(add_matrices, matrix_a, matrix_b)
61
+
62
+ Running a more complex JAX function:
63
+
64
+ >>> def compute_statistics(data):
65
+ ... mean = jnp.mean(data)
66
+ ... std = jnp.std(data)
67
+ ... return {"mean": mean, "std": std}
68
+ >>> stats = run_jax(compute_statistics, dataset)
69
+ """
70
+ return run(None, jax_cc.run_jax, jax_fn, *args, **kwargs)
71
+
72
+
73
+ def run_jax_at(rank: Rank, jax_fn: Callable, *args: Any, **kwargs: Any) -> Any:
74
+ return run_at(rank, jax_cc.run_jax, jax_fn, *args, **kwargs)
75
+
76
+
77
+ def run_ibis(ibis_expr: Any, *args: Any, **kwargs: Any) -> Any:
78
+ # TODO(jint): add docstring, add type hints, describe args and kwargs constraints.
79
+ return run(None, ibis_cc.run_ibis, ibis_expr, *args, **kwargs)
80
+
81
+
82
+ def run_ibis_at(rank: Rank, ibis_fn: Any, *args: Any, **kwargs: Any) -> Any:
83
+ return run_at(rank, ibis_cc.run_ibis, ibis_fn, *args, **kwargs)
84
+
85
+
86
+ def run_sql(
87
+ query: str, out_type: Any, in_tables: dict[str, MPObject] | None = None
88
+ ) -> Any:
89
+ # TODO(jint): add docstring, drop out_type.
90
+ return run(None, sql_cc.run_sql, query, out_type, in_tables)
91
+
92
+
93
+ def run_sql_at(
94
+ rank: Rank, query: str, out_type: Any, in_tables: dict[str, MPObject] | None = None
95
+ ) -> Any:
96
+ return run_at(rank, sql_cc.run_sql, query, out_type, in_tables)
mplang/simp/party.py ADDED
@@ -0,0 +1,225 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import importlib
18
+ import pathlib
19
+ import pkgutil
20
+ from collections.abc import Callable
21
+ from functools import wraps
22
+ from types import ModuleType
23
+ from typing import Any
24
+
25
+ from mplang.ops.base import FeOperation
26
+ from mplang.simp.api import run_at, run_jax_at
27
+ from mplang.simp.mpi import p2p
28
+
29
+
30
+ def P2P(src: Party, dst: Party, value: Any) -> Any:
31
+ """Point-to-point transfer using Party objects instead of raw ranks.
32
+
33
+ Equivalent to ``p2p(src.rank, dst.rank, value)`` but improves readability
34
+ and reduces magic numbers in user code / tutorials.
35
+
36
+ Parameters
37
+ ----------
38
+ src : Party
39
+ Source party object.
40
+ dst : Party
41
+ Destination party object.
42
+ value : Any
43
+ Value to transfer.
44
+
45
+ Returns
46
+ -------
47
+ Any
48
+ The same value representation at destination context (as defined by
49
+ underlying ``p2p`` primitive semantics).
50
+ """
51
+ if not isinstance(src, Party) or not isinstance(dst, Party): # defensive
52
+ raise TypeError("P2P expects Party objects, e.g. P2P(P0, P2, value)")
53
+ return p2p(src.rank, dst.rank, value)
54
+
55
+
56
+ """Party-scoped module registration & dispatch.
57
+
58
+ This module provides a light-weight mechanism to expose *module-like* groups
59
+ of callable operations bound to a specific party (rank) via attribute access:
60
+
61
+ load_module("mplang.ops.crypto", alias="crypto")
62
+ P0.crypto.encrypt(x) # executes encrypt() with pmask = {rank 0}
63
+
64
+ Core concepts:
65
+ * Registry (``_NAMESPACE_REGISTRY``): maps alias -> importable module path.
66
+ * Lazy import: underlying module is imported on first attribute access.
67
+ * Wrapping: fetched callables are wrapped so that invocation automatically
68
+ routes through ``run_impl`` with that party's mask.
69
+
70
+ Only *callable* attributes are exposed; non-callable attributes raise
71
+ ``AttributeError`` to avoid surprising divergent local vs. distributed
72
+ semantics.
73
+
74
+ The public API surface intentionally stays small (`Party`, `P`, `run`,
75
+ `runAt`, and `load_module`). Internal details (proxy class / registry) are
76
+ not part of the stability guarantee.
77
+ """
78
+
79
+ _NAMESPACE_REGISTRY: dict[str, str] = {}
80
+
81
+
82
+ class _PartyModuleProxy:
83
+ """Lazy module proxy bound to a specific party.
84
+
85
+ Attribute access resolves a callable inside the registered module and
86
+ returns a wrapped function that executes with the party's mask.
87
+ Non-callable attributes are rejected explicitly to keep semantics clear.
88
+ """
89
+
90
+ def __init__(self, party: Party, name: str):
91
+ self._party: Party = party
92
+ self._name: str = name
93
+ self._module: ModuleType | None = None # loaded lazily
94
+
95
+ def _ensure(self) -> None:
96
+ if self._module is None:
97
+ self._module = importlib.import_module(_NAMESPACE_REGISTRY[self._name])
98
+
99
+ def __getattr__(self, item: str) -> Callable[..., Any]:
100
+ self._ensure()
101
+ op = getattr(self._module, item)
102
+ if not callable(op):
103
+ raise AttributeError(
104
+ f"Attribute '{item}' of party module '{self._name}' is not callable (got {type(op).__name__})"
105
+ )
106
+
107
+ @wraps(op)
108
+ def _wrapped(*args: Any, **kw: Any) -> Any:
109
+ # Inline runAt to reduce an extra partial layer while preserving semantics.
110
+ return run_at(self._party.rank, op, *args, **kw)
111
+
112
+ # Provide a party-qualified name for debugging / logs without losing original metadata.
113
+ base_name = getattr(op, "__name__", None)
114
+ if base_name is None:
115
+ # Frontend FeOperation or object without __name__; try .name attribute (FeOperation contract) or fallback to repr
116
+ base_name = getattr(op, "name", None) or type(op).__name__
117
+ try:
118
+ _wrapped.__name__ = f"{base_name}@P{self._party.rank}"
119
+ except Exception: # pragma: no cover - assignment may fail for exotic wrappers
120
+ pass
121
+ return _wrapped
122
+
123
+
124
+ class Party:
125
+ def __init__(self, rank: int) -> None:
126
+ self.rank: int = int(rank)
127
+
128
+ def __repr__(self) -> str: # pragma: no cover
129
+ return f"Party(rank={self.rank})"
130
+
131
+ def __call__(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
132
+ if not callable(fn):
133
+ raise TypeError(
134
+ f"First argument to Party({self.rank}) must be callable, got {fn!r}"
135
+ )
136
+ # Use run_op_at for FeOperation, run_jax_at for plain callables
137
+ if isinstance(fn, FeOperation):
138
+ return run_at(self.rank, fn, *args, **kwargs)
139
+ else:
140
+ # TODO(jint): implicitly assume non-FeOperation as JAX function is a bit too magical?
141
+ return run_jax_at(self.rank, fn, *args, **kwargs)
142
+
143
+ def __getattr__(self, name: str) -> _PartyModuleProxy:
144
+ if name in _NAMESPACE_REGISTRY:
145
+ return _PartyModuleProxy(self, name)
146
+ raise AttributeError(
147
+ f"Party has no attribute '{name}'. Registered: {list(_NAMESPACE_REGISTRY)}"
148
+ )
149
+
150
+
151
+ class _PartyIndex:
152
+ def __getitem__(self, rank: int) -> Party:
153
+ return Party(rank)
154
+
155
+
156
+ def _load_prelude_modules() -> None:
157
+ """Auto-register public frontend submodules for party namespace access.
158
+
159
+ Implementation detail: we treat every non-underscore immediate child of
160
+ ``mplang.ops`` as public and make it available as ``P0.<name>``.
161
+ This keeps user ergonomics high (no manual load_module calls for core
162
+ frontends) but slightly increases implicit surface area. If this grows
163
+ unwieldy we can switch to an allowlist.
164
+ """
165
+ try:
166
+ import mplang.ops as _fe # type: ignore
167
+ except (ImportError, ModuleNotFoundError): # pragma: no cover
168
+ # Frontend package not present (minimal install); safe to skip.
169
+ return
170
+
171
+ pkg_path = pathlib.Path(_fe.__file__).parent
172
+ for m in pkgutil.iter_modules([str(pkg_path)]):
173
+ if m.name.startswith("_"):
174
+ continue
175
+ if m.name not in _NAMESPACE_REGISTRY:
176
+ _NAMESPACE_REGISTRY[m.name] = f"mplang.ops.{m.name}"
177
+
178
+
179
+ def load_module(module: str, alias: str | None = None) -> None:
180
+ """Register a module for party-scoped (per-rank) callable access.
181
+
182
+ After registration, each party object (e.g. ``P0``) can access callable
183
+ attributes of the target module through the chosen alias and have them
184
+ executed under that party's mask automatically. Non-callable attributes
185
+ are intentionally not exposed to avoid ambiguity between local data and
186
+ distributed execution semantics.
187
+
188
+ Parameters
189
+ ----------
190
+ module : str
191
+ The fully-qualified import path of the module to expose. It must be
192
+ importable via ``importlib.import_module``.
193
+ alias : str | None, optional
194
+ The short name used as an attribute on ``Party``/``P0``/``P1``/... .
195
+ If omitted, the last path segment of ``module`` is used.
196
+
197
+ Raises
198
+ ------
199
+ ValueError
200
+ If the alias is already registered to a *different* module path.
201
+
202
+ Notes
203
+ -----
204
+ Registration is idempotent when the alias maps to the same module. The
205
+ actual module object is imported lazily on first attribute access, so
206
+ calling ``load_module`` has negligible upfront cost.
207
+
208
+ Examples
209
+ --------
210
+ >>> load_module("mplang.ops.crypto", alias="crypto")
211
+ >>> # Now call an op on party 0
212
+ >>> P0.crypto.encrypt(data)
213
+ """
214
+ if alias is None:
215
+ alias = module.rsplit(".", 1)[-1]
216
+ prev = _NAMESPACE_REGISTRY.get(alias)
217
+ if prev and prev != module:
218
+ raise ValueError(f"Alias '{alias}' already registered for '{prev}'")
219
+ _NAMESPACE_REGISTRY[alias] = module
220
+
221
+
222
+ P = _PartyIndex()
223
+ P0, P1, P2 = Party(0), Party(1), Party(2)
224
+
225
+ _load_prelude_modules()
mplang/simp/random.py CHANGED
@@ -22,8 +22,8 @@ import jax.random as jr
22
22
  from jax.typing import ArrayLike
23
23
 
24
24
  import mplang.core.primitive as prim
25
- from mplang import simp
26
25
  from mplang.core import MPObject, Shape
26
+ from mplang.simp.api import run_jax
27
27
 
28
28
 
29
29
  @prim.function
@@ -32,12 +32,12 @@ def key_split(key: MPObject) -> tuple[MPObject, MPObject]:
32
32
 
33
33
  def kernel(key: jax.Array) -> tuple[jax.Array, jax.Array]:
34
34
  # TODO: since MPObject tensor does not implement slicing yet.
35
- # subkey, key = simp.run(jr.split)(key) does not work.
35
+ # subkey, key = run_jax(jr.split, key) does not work.
36
36
  # we workaround it by splitting inside tracer.
37
37
  subkey, key = jr.split(key)
38
38
  return subkey, key
39
39
 
40
- return simp.run(kernel)(key) # type: ignore[no-any-return]
40
+ return run_jax(kernel, key) # type: ignore[no-any-return]
41
41
 
42
42
 
43
43
  @prim.function
@@ -49,7 +49,7 @@ def ukey(seed: int | ArrayLike) -> MPObject:
49
49
  # Note: key.dtype is jax._src.prng.KeyTy, which could not be handled by MPObject.
50
50
  return jax.random.key_data(key)
51
51
 
52
- return simp.run(kernel)() # type: ignore[no-any-return]
52
+ return run_jax(kernel) # type: ignore[no-any-return]
53
53
 
54
54
 
55
55
  @prim.function
@@ -61,7 +61,7 @@ def urandint(
61
61
  ) -> MPObject:
62
62
  """Party uniformly generate a random integer in the range [low, high) with the given shape."""
63
63
 
64
- return simp.run(partial(jr.randint, minval=low, maxval=high, shape=shape))(key) # type: ignore[no-any-return]
64
+ return run_jax(partial(jr.randint, minval=low, maxval=high, shape=shape), key) # type: ignore[no-any-return]
65
65
 
66
66
 
67
67
  # Private(different per-party) related functions begin.
@@ -81,7 +81,7 @@ def prandint(low: int, high: int, shape: Shape = ()) -> MPObject:
81
81
  return result
82
82
 
83
83
  rand_u64 = prim.prand(shape)
84
- return simp.run(kernel)(rand_u64) # type: ignore[no-any-return]
84
+ return run_jax(kernel, rand_u64) # type: ignore[no-any-return]
85
85
 
86
86
 
87
87
  @prim.function
@@ -116,6 +116,6 @@ def pperm(key: MPObject) -> MPObject:
116
116
  def kernel(key: jax.Array) -> jax.Array:
117
117
  return jr.permutation(key, size)
118
118
 
119
- perm = simp.run(kernel)(key)
119
+ perm = run_jax(kernel, key)
120
120
  rank = prim.prank()
121
- return simp.run(lambda perm, rank: perm[rank])(perm, rank) # type: ignore[no-any-return]
121
+ return run_jax(lambda perm, rank: perm[rank], perm, rank) # type: ignore[no-any-return]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mplang-nightly
3
- Version: 0.1.dev172
3
+ Version: 0.1.dev174
4
4
  Summary: Multi-Party Programming Language
5
5
  Author-email: SecretFlow Team <secretflow-contact@service.alipay.com>
6
6
  License: Apache License
@@ -1,20 +1,20 @@
1
- mplang/__init__.py,sha256=tZmMm5LXY5QXKDvyYw1ysSStRky8CHelPQ9T0qgmfGs,1853
2
- mplang/device.py,sha256=b9H1I-MFtL7hvJ38xoq65QdH1vGgPl9fNLQ-IwZWRvE,12479
1
+ mplang/__init__.py,sha256=K5N0yTkRr4uJDCUSD2IXfMSCE81a-Pm-mx07hJ_6QrA,3182
2
+ mplang/device.py,sha256=BWCFgWbhgnuzyt1thopgsBUy75sJDe9s-WgM29gxZqY,12420
3
3
  mplang/host.py,sha256=ssmv0_CyZPFORhOUJ84Jo6NwRJSK7_Ono3n7ZjEg4sA,3058
4
4
  mplang/analysis/__init__.py,sha256=CTHFvRsi-nFngojqjn08UaR3RY9i7CJ7T2UdR95kCrk,1056
5
- mplang/analysis/diagram.py,sha256=ffwgD12gL1_KH1uJ_EYkjmIlDrfxYJJkWj-wHl09_Xk,19520
6
- mplang/core/__init__.py,sha256=lWxlEKfRwX7FNDzgyKZ1fiDMaCiqkyg0j5mKlZD_v7g,2244
5
+ mplang/analysis/diagram.py,sha256=JUJEe6HnR-V8yVrV49VSy2Hbpyvy-i8EbggVlmpJTNw,19524
6
+ mplang/core/__init__.py,sha256=fXTmdyeAeWNmAVWeKmRXWnoz2n3ucoNpSQh_OylXBRA,2353
7
7
  mplang/core/cluster.py,sha256=uARPlcWL0ddrWyOZ7vcX6JYN7a72mqbz2XGI1rSfkjE,11625
8
8
  mplang/core/comm.py,sha256=MByyu3etlQh_TkP1vKCFLIAPPuJOpl9Kjs6hOj6m4Yc,8843
9
9
  mplang/core/context_mgr.py,sha256=R0QJAod-1nYduVoOknLfAsxZiy-RtmuQcp-07HABYZU,1541
10
10
  mplang/core/dtype.py,sha256=0rZqFaFikFu9RxtdO36JLEgFL-E-lo3hH10whwkTVVY,10213
11
11
  mplang/core/interp.py,sha256=JKjKJGWURU5rlHQ2yG5XNKWzN6hLZsmo--hZuveQgxI,5915
12
12
  mplang/core/mask.py,sha256=14DFxaA446lGjN4dzTuQgm9Shcn34rYI87YJHg0YGNQ,10693
13
- mplang/core/mpir.py,sha256=3NyHa1cDnUaw3XWIUgyOMXfZ9JS-30COb29AoXYcRtM,38251
13
+ mplang/core/mpir.py,sha256=01AOZgu8fcoEA9TDKCDnFW03AmSIhsFJGHPMmf-gZt0,38259
14
14
  mplang/core/mpobject.py,sha256=0pHSd7SrAFTScCFcB9ziDztElYQn-oIZOKBx47B3QX0,3732
15
15
  mplang/core/mptype.py,sha256=7Cp2e58uUX-uqTp6QxuioOMJ8BzLBPXlWG5rRakv2uo,13773
16
16
  mplang/core/pfunc.py,sha256=WOGmMr4HCUELED5QxYbhhyQJRDXrA5Bk3tPbZWpwmw8,5102
17
- mplang/core/primitive.py,sha256=vu60-k0fSAWWidcWDC0_FGvrRZww12oGXjB8CR9F6Yo,43889
17
+ mplang/core/primitive.py,sha256=cuQYNK80pNcWuwGOt-djX_p93RKk9tztWexxndpTzuA,44113
18
18
  mplang/core/table.py,sha256=BqTBZn7Tfwce4vzl3XYhaX5hVmKagVq9-YoERDta6d8,5892
19
19
  mplang/core/tensor.py,sha256=86u6DogSZMoL0w5XjtTmQm2PhA_VjwybN1b6U4Zzphg,2361
20
20
  mplang/core/tracer.py,sha256=dVMfUeCMmPz4o6tLXewGMW1Kpy5gpZORvr9w4MhwDtM,14288
@@ -49,7 +49,6 @@ mplang/ops/sql_cc.py,sha256=-9uf75gOxLQlFiKjDm75qIx8Gbun7unvkOxezdSLGlE,2112
49
49
  mplang/ops/tee.py,sha256=bOpS_BXG12D6bONikzdF2yt0oVZj9Jyd0g_3IXP8VgE,1281
50
50
  mplang/protos/v1alpha1/mpir_pb2.py,sha256=Bros37t-4LMJbuUYVSM65rImUYTtZDhNTIADGbZCKp0,7522
51
51
  mplang/protos/v1alpha1/mpir_pb2.pyi,sha256=dLxAtFW7mgFR-HYxC4ExI4jbtEWUGTKBvcKhI3BJ8m0,20972
52
- mplang/protos/v1alpha1/mpir_pb2_grpc.py,sha256=xYOs94SXiNYAlFodACnsXW5QovLsHY5tCk3p76RH5Zc,158
53
52
  mplang/protos/v1alpha1/value_pb2.py,sha256=V9fqQTqXNo2efYmlP9xOhC7EpjBUp5jL-05yrJsAvWU,2785
54
53
  mplang/protos/v1alpha1/value_pb2.pyi,sha256=47GVvuZfiV5oaVemwh0xTfns8OYTVBT8YnluIQeQPbs,7108
55
54
  mplang/runtime/__init__.py,sha256=IRPP3TtpFC4iSt7_uaq-S4dL7CwrXL0XBMeaBoEYLlg,948
@@ -57,24 +56,26 @@ mplang/runtime/cli.py,sha256=WehDodeVB4AukSWx1LJxxtKUqGmLPY4qjayrPlOg3bE,14438
57
56
  mplang/runtime/client.py,sha256=v73cFnwN7ePaGJmPi20waizeIq6dlJTjEs6OkybSR2M,15858
58
57
  mplang/runtime/communicator.py,sha256=P5nl3wxRomUafoAj-1FSgD7phelof399deitVB1JMos,3508
59
58
  mplang/runtime/data_providers.py,sha256=GX10_nch8PmEyok32mSC4p5rDowvmXrJ-4J5-LvY6ig,8206
60
- mplang/runtime/driver.py,sha256=Elo6zx66kvpNW7OwYr-gcb-RJDhcnCbPFQWyBcYM0Uk,11745
59
+ mplang/runtime/driver.py,sha256=c7c4ULmMln5GwkK17cPyc-45XrQbprwSt7HnEaIQfcc,11725
61
60
  mplang/runtime/exceptions.py,sha256=c18U0xK20dRmgZo0ogTf5vXlkix9y3VAFuzkHxaXPEk,981
62
61
  mplang/runtime/http_api.md,sha256=-re1DhEqMplAkv_wnqEU-PSs8tTzf4-Ml0Gq0f3Go6s,4883
63
62
  mplang/runtime/link_comm.py,sha256=ZHNcis8QDu2rcyyF3rhpxMiJDkczoMA_c0iZ2GDW_bA,2793
64
- mplang/runtime/server.py,sha256=CdmBmpbylEl7XeZj26i0rUmTrPTvl2CVdRgbtR02gcg,16543
63
+ mplang/runtime/server.py,sha256=N2R_8KIi_eC5H0gEWEmGOrBW0MWMVFurJKeCxsADcSg,16547
65
64
  mplang/runtime/session.py,sha256=I2711V-pPRCYibNgBhjboUURdubnL6ltCoh5RvFVabs,10641
66
- mplang/runtime/simulation.py,sha256=1I_8dIqxivxtYnQK0ofz0oXk3lXXh3-zN0lmNFnWucA,11615
67
- mplang/simp/__init__.py,sha256=_X1kpq9qhoPUL2gQRVNkSjS5jNSxYRu7-_1GSwJ9PK8,11575
65
+ mplang/runtime/simulation.py,sha256=wz47a1sskgBlt_xVPvZn7vO9-XGJ94CK5muxPZoZaRk,11623
66
+ mplang/simp/__init__.py,sha256=2WE4cmW96Xkzyq2yRRYNww4kZ5o6u6NbPV0BxqZG698,581
67
+ mplang/simp/api.py,sha256=apxrfPbAqtUGUT-C6GMwJDg32KHSkMb9aaxjeUqVksw,3613
68
68
  mplang/simp/mpi.py,sha256=Wv_Q16TQ3rdLam6OzqXiefIGSMmagGkso09ycyOkHEs,4774
69
- mplang/simp/random.py,sha256=7PVgWNL1j7Sf3MqT5PRiWplUu-0dyhF3Ub566iqX86M,3898
69
+ mplang/simp/party.py,sha256=jl-_K4CNj9lA9ZW5hgmoxvqtliXcL3PLWRdu3C7i4fo,8211
70
+ mplang/simp/random.py,sha256=2pCYYZKRbpTLaMWJMj6UxtOR6EsI2Et4trSgiFD4At0,3901
70
71
  mplang/simp/smpc.py,sha256=tdH54aU4T-GIDPhpmf9NCeJC0G67PdOYc04cyUkOnwE,7119
71
72
  mplang/utils/__init__.py,sha256=2WE4cmW96Xkzyq2yRRYNww4kZ5o6u6NbPV0BxqZG698,581
72
73
  mplang/utils/crypto.py,sha256=rvPomBFtznRHc3RPi6Aip9lsU8zW2oxBqGv1K3vn7Rs,1052
73
74
  mplang/utils/func_utils.py,sha256=vCJcZmu0bEbqhOQKdpttV2_MBllIcPSN0b8U4WjNGGo,5164
74
75
  mplang/utils/spu_utils.py,sha256=S3L9RBkBe2AvSuMSQQ12cBY5Y1NPthubvErSX_7nj1A,4158
75
76
  mplang/utils/table_utils.py,sha256=aC-IZOKkSmFkpr3NZchLM0Wt0GOn-rg_xHBHREWBwAU,2202
76
- mplang_nightly-0.1.dev172.dist-info/METADATA,sha256=hjAMloWrHmrUgIZhcHEO268e1Nn2r4Kg52MXkhl927s,16547
77
- mplang_nightly-0.1.dev172.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
78
- mplang_nightly-0.1.dev172.dist-info/entry_points.txt,sha256=mG1oJT-GAjQR834a62_QIWb7litzWPPyVnwFqm-rWuY,55
79
- mplang_nightly-0.1.dev172.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
80
- mplang_nightly-0.1.dev172.dist-info/RECORD,,
77
+ mplang_nightly-0.1.dev174.dist-info/METADATA,sha256=l5YgCuDdKf-x_geEWMkuTSKVmBclYfN2UD5atDj7EM4,16547
78
+ mplang_nightly-0.1.dev174.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
79
+ mplang_nightly-0.1.dev174.dist-info/entry_points.txt,sha256=mG1oJT-GAjQR834a62_QIWb7litzWPPyVnwFqm-rWuY,55
80
+ mplang_nightly-0.1.dev174.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
81
+ mplang_nightly-0.1.dev174.dist-info/RECORD,,
@@ -1,3 +0,0 @@
1
- # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
2
- """Client and server classes corresponding to protobuf-defined services."""
3
- import grpc