mplang-nightly 0.1.dev169__py3-none-any.whl → 0.1.dev170__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mplang/core/expr/ast.py CHANGED
@@ -528,8 +528,9 @@ class FuncDefExpr(Expr):
528
528
  class CallExpr(Expr):
529
529
  """Expression for function call."""
530
530
 
531
- def __init__(self, fn: FuncDefExpr, args: list[Expr]):
531
+ def __init__(self, name: str, fn: FuncDefExpr, args: list[Expr]):
532
532
  super().__init__()
533
+ self.name = name
533
534
  self.fn = fn
534
535
  self.args = args
535
536
 
@@ -341,10 +341,10 @@ class RecursiveEvaluator(EvalSemantic, ExprVisitor):
341
341
 
342
342
  # Only evaluate selected branch locally
343
343
  if bool(pred):
344
- then_call = CallExpr(expr.then_fn, expr.args)
344
+ then_call = CallExpr("then", expr.then_fn, expr.args)
345
345
  return self._values(then_call)
346
346
  else:
347
- else_call = CallExpr(expr.else_fn, expr.args)
347
+ else_call = CallExpr("else", expr.else_fn, expr.args)
348
348
  return self._values(else_call)
349
349
 
350
350
  def visit_call(self, expr: CallExpr) -> Any:
@@ -50,11 +50,13 @@ class Printer(ExprVisitor):
50
50
  compact_format: bool = True,
51
51
  *,
52
52
  verbose_peval: bool = False,
53
+ inline_pcall: bool = True,
53
54
  ):
54
55
  super().__init__() # Initialize MemorizedVisitor
55
56
  self.indent_size = indent_size
56
57
  self.compact_format = compact_format
57
58
  self.verbose_peval = verbose_peval
59
+ self.inline_pcall = inline_pcall
58
60
  self._cur_indent = 0
59
61
  self._output: list[str] = []
60
62
  self._visited: dict[Expr, str] = {}
@@ -92,6 +94,7 @@ class Printer(ExprVisitor):
92
94
  body_printer = Printer(
93
95
  indent_size=self.indent_size,
94
96
  compact_format=self.compact_format,
97
+ inline_pcall=self.inline_pcall,
95
98
  )
96
99
  func_def_expr.accept(body_printer)
97
100
  regions_str += f"{indent}{r_name}: "
@@ -209,12 +212,19 @@ class Printer(ExprVisitor):
209
212
 
210
213
  def visit_call(self, expr: CallExpr) -> str:
211
214
  arg_names = [self._var_name(arg) for arg in expr.args]
212
- return self._do_print(
213
- "pcall",
214
- arg_names,
215
- regions={"fn": expr.fn},
216
- mptypes=expr.mptypes,
217
- )
215
+ if self.inline_pcall:
216
+ return self._do_print(
217
+ expr.name,
218
+ arg_names,
219
+ mptypes=expr.mptypes,
220
+ )
221
+ else:
222
+ return self._do_print(
223
+ "pcall",
224
+ arg_names,
225
+ regions={"fn": expr.fn},
226
+ mptypes=expr.mptypes,
227
+ )
218
228
 
219
229
  def visit_while(self, expr: WhileExpr) -> str:
220
230
  arg_names = [self._var_name(arg) for arg in expr.args]
@@ -79,7 +79,7 @@ class ExprTransformer(ExprVisitor):
79
79
  def visit_call(self, expr: CallExpr) -> Expr:
80
80
  # Transform child expressions first
81
81
  transformed_args = [arg.accept(self) for arg in expr.args]
82
- new_expr = CallExpr(expr.fn, transformed_args)
82
+ new_expr = CallExpr(expr.name, expr.fn, transformed_args)
83
83
 
84
84
  if "call" in self.trans_rules:
85
85
  return self.trans_rules["call"](new_expr)
mplang/core/mpir.py CHANGED
@@ -491,6 +491,7 @@ class Writer:
491
491
  op = self._create_node_proto(expr, "call")
492
492
  self._add_single_expr_inputs(op, expr.fn)
493
493
  self._add_expr_inputs(op, *expr.args)
494
+ self._add_attrs(op, name=expr.name)
494
495
  self._finalize_node(op, expr)
495
496
  elif isinstance(expr, WhileExpr):
496
497
  op = self._create_node_proto(expr, "while")
@@ -822,8 +823,12 @@ class Reader:
822
823
  arg_exprs.append(self._value_cache[dep_name])
823
824
  else:
824
825
  raise ValueError(f"Input {input_name} not found for call node")
826
+ # Optional call-site name attribute
827
+ call_name = None
828
+ if "name" in node_proto.attrs:
829
+ call_name = self._proto_to_attr(node_proto.attrs["name"]) # type: ignore[assignment]
825
830
 
826
- return CallExpr(fn_expr, arg_exprs)
831
+ return CallExpr(call_name or "", fn_expr, arg_exprs)
827
832
 
828
833
  def _proto_to_mptype(self, type_proto: mpir_pb2.MPTypeProto) -> MPType:
829
834
  """Convert MPTypeProto to MPType."""
mplang/core/primitive.py CHANGED
@@ -32,6 +32,7 @@ from mplang.core.context_mgr import cur_ctx
32
32
  from mplang.core.dtype import BOOL
33
33
  from mplang.core.expr.ast import (
34
34
  AccessExpr,
35
+ CallExpr,
35
36
  CondExpr,
36
37
  ConvExpr,
37
38
  EvalExpr,
@@ -87,30 +88,106 @@ P = ParamSpec("P")
87
88
  R = TypeVar("R")
88
89
 
89
90
 
90
- def primitive(fn: Callable[P, R]) -> Callable[P, R]:
91
+ def trace_before_apply(fn: Callable[P, R], make_call: bool) -> Callable[P, R]:
91
92
  """A decorator to make all primitive call in trace context."""
92
93
 
93
94
  @wraps(fn)
94
95
  def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
95
96
  current_ctx = cur_ctx()
96
97
  if isinstance(current_ctx, TraceContext):
97
- # If we are in a tracer context, just call the function.
98
- # Note: switch_ctx will do the capture if needed.
99
- args, kwargs = tree_map(partial(_switch_ctx, current_ctx), (args, kwargs))
100
- return fn(*args, **kwargs)
98
+ # If we are already in a tracer context
99
+ if make_call:
100
+ # make a primitive call
101
+ tracer = current_ctx
102
+ tfn = trace(tracer.fork(), fn, *args, **kwargs)
103
+ is_mpobj = lambda x: isinstance(x, MPObject)
104
+ in_vars, in_imms, in_struct = var_morph((args, kwargs), is_mpobj)
105
+ assert in_struct == tfn.in_struct and in_imms == tfn.in_imms
106
+ arg_exprs = [arg.expr for arg in in_vars]
107
+ # re-capture all captured variables into current context if needed.
108
+ cap_exprs = [tracer.capture(var).expr for var in tfn.capture_map.keys()]
109
+ caller_expr = CallExpr(
110
+ name=fn.__name__, fn=tfn.make_expr(), args=arg_exprs + cap_exprs
111
+ )
112
+ out_vars = [
113
+ TraceVar(tracer, AccessExpr(caller_expr, idx))
114
+ for idx in range(caller_expr.num_outputs)
115
+ ]
116
+ return cast(R, var_demorph(out_vars, tfn.out_imms, tfn.out_struct))
117
+ else:
118
+ # embed the function call in the current tracer context
119
+ # Note: switch_ctx will do the capture if needed.
120
+ args, kwargs = tree_map(
121
+ partial(_switch_ctx, current_ctx), (args, kwargs)
122
+ )
123
+ return fn(*args, **kwargs)
101
124
  elif isinstance(current_ctx, InterpContext):
102
125
  trace_ctx = TraceContext(current_ctx.cluster_spec, parent=current_ctx)
103
126
  # TODO(jint): should we add trace_and_apply to improve the performance?
104
- traced_fn = trace(trace_ctx, fn, *args, **kwargs)
127
+ tfn = trace(trace_ctx, fn, *args, **kwargs)
105
128
  # Return back to the original context.
106
- return cast(R, apply(current_ctx, traced_fn, *args, **kwargs))
129
+ return cast(R, apply(current_ctx, tfn, *args, **kwargs))
107
130
  else:
108
131
  raise ValueError(f"Unsupported context type: {type(current_ctx)}")
109
132
 
110
133
  return wrapped
111
134
 
112
135
 
113
- function = primitive
136
+ def primitive(fn: Callable[P, R]) -> Callable[P, R]:
137
+ """Decorator to trace a Python function as an opaque primitive call (`CallExpr`).
138
+
139
+ When a function decorated with `@primitive` is called within a `TraceContext`, it is
140
+ not inlined. Instead, it is traced separately in a forked context, and a `CallExpr`
141
+ node is inserted into the main graph. This is useful for encapsulating complex
142
+ operations or third-party library calls as single, opaque nodes.
143
+
144
+ **Implementation Note:**
145
+ A `CallExpr` represents a call to a single inline lambda (non-recursive, as we don't
146
+ have Y-combinator support). This single lambda call can be treated as a "primitive call"
147
+ by the printer/visualizer - hence the name "primitive". The function body is captured
148
+ once during tracing and represented as an opaque callable unit in the expression graph,
149
+ maintaining a clear boundary between the caller and callee contexts.
150
+
151
+ Args:
152
+ fn: The function to be traced as a primitive operation.
153
+
154
+ Returns:
155
+ A wrapped function that creates a `CallExpr` node when called in a trace context.
156
+
157
+ Example:
158
+ ```python
159
+ @primitive
160
+ def my_op(x: MPObject) -> MPObject:
161
+ # Complex logic traced as a single CallExpr node
162
+ return x + 1
163
+ ```
164
+ """
165
+ return trace_before_apply(fn, make_call=True)
166
+
167
+
168
+ def function(fn: Callable[P, R]) -> Callable[P, R]:
169
+ """Decorator to trace a Python function by inlining its body.
170
+
171
+ When a function decorated with `@function` is called within a `TraceContext`, its
172
+ underlying primitive operations are expanded and inserted directly into the caller's
173
+ graph. This is the default tracing behavior and is suitable for most pure-Python
174
+ multi-party functions.
175
+
176
+ Args:
177
+ fn: The function to be traced and inlined.
178
+
179
+ Returns:
180
+ A wrapped function that inlines its operations into the caller's trace context.
181
+
182
+ Example:
183
+ ```python
184
+ @function
185
+ def my_func(x: MPObject, y: MPObject) -> MPObject:
186
+ # Operations are inlined into the caller's trace
187
+ return x + y * constant(2)
188
+ ```
189
+ """
190
+ return trace_before_apply(fn, make_call=False)
114
191
 
115
192
 
116
193
  # ============================================================================
@@ -126,18 +203,15 @@ def _tracer() -> TraceContext:
126
203
  return ctx
127
204
 
128
205
 
129
- @primitive
130
206
  def psize() -> int:
131
207
  """Get the size of the current party world.
132
208
 
133
209
  Returns:
134
210
  int: The total number of parties in the current multi-party computation context.
135
211
  """
136
- ctx = _tracer()
137
- return ctx.world_size()
212
+ return cur_ctx().world_size()
138
213
 
139
214
 
140
- @primitive
141
215
  def pmask() -> Mask:
142
216
  """Get the current party mask in this computation context.
143
217
 
@@ -145,8 +219,7 @@ def pmask() -> Mask:
145
219
  Mask: The current party mask indicating which parties are active
146
220
  in the current computation context.
147
221
  """
148
- ctx = _tracer()
149
- return ctx.mask
222
+ return _tracer().mask
150
223
 
151
224
 
152
225
  @primitive
@@ -203,7 +276,6 @@ def prand(shape: Shape = ()) -> MPObject:
203
276
  return out_tree.unflatten(results) # type: ignore[no-any-return]
204
277
 
205
278
 
206
- @primitive
207
279
  def constant(data: TensorLike | ScalarType | TableLike) -> MPObject:
208
280
  """Create a constant tensor or table from data.
209
281
 
@@ -250,7 +322,7 @@ def debug_print(obj: MPObject, prefix: str = "") -> MPObject:
250
322
  return out_tree.unflatten(results) # type: ignore[no-any-return]
251
323
 
252
324
 
253
- @primitive
325
+ @function
254
326
  def peval(
255
327
  pfunc: PFunction,
256
328
  args: list[MPObject],
@@ -378,7 +450,7 @@ def set_mask(arg: MPObject, mask: Mask) -> MPObject:
378
450
  return out_tree.unflatten(results) # type: ignore[no-any-return]
379
451
 
380
452
 
381
- @primitive
453
+ @function
382
454
  def uniform_cond(
383
455
  pred: MPObject,
384
456
  then_fn: Callable[..., Any],
@@ -588,7 +660,7 @@ def uniform_cond(
588
660
  return var_demorph(out_vars, then_tfn.out_imms, then_tfn.out_struct) # type: ignore[no-any-return]
589
661
 
590
662
 
591
- @primitive
663
+ @function
592
664
  def while_loop(
593
665
  cond_fn: Callable[[Any], MPObject],
594
666
  body_fn: Callable[[Any], Any],
@@ -781,7 +853,7 @@ def while_loop(
781
853
  return var_demorph(out_vars, body_tfn.out_imms, body_tfn.out_struct)
782
854
 
783
855
 
784
- @primitive
856
+ @function
785
857
  def pshfl(src: MPObject, index: MPObject) -> MPObject:
786
858
  """Shuffle the input tensor to the specified index (dynamic version).
787
859
 
@@ -851,7 +923,7 @@ def pshfl(src: MPObject, index: MPObject) -> MPObject:
851
923
  return TraceVar(_tracer(), shfl_expr)
852
924
 
853
925
 
854
- @primitive
926
+ @function
855
927
  def pshfl_s(src_val: MPObject, pmask: Mask, src_ranks: list[Rank]) -> MPObject:
856
928
  """Shuffle the input tensor to the specified rank, static version.
857
929
 
@@ -910,7 +982,7 @@ def pshfl_s(src_val: MPObject, pmask: Mask, src_ranks: list[Rank]) -> MPObject:
910
982
  return TraceVar(_tracer(), shfl_s_expr)
911
983
 
912
984
 
913
- @primitive
985
+ @function
914
986
  def pconv(vars: list[MPObject]) -> MPObject:
915
987
  """Combine multiple variables that share the same dtype and shape into one.
916
988
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mplang-nightly
3
- Version: 0.1.dev169
3
+ Version: 0.1.dev170
4
4
  Summary: Multi-Party Programming Language
5
5
  Author-email: SecretFlow Team <secretflow-contact@service.alipay.com>
6
6
  License: Apache License
@@ -10,19 +10,19 @@ mplang/core/context_mgr.py,sha256=R0QJAod-1nYduVoOknLfAsxZiy-RtmuQcp-07HABYZU,15
10
10
  mplang/core/dtype.py,sha256=0rZqFaFikFu9RxtdO36JLEgFL-E-lo3hH10whwkTVVY,10213
11
11
  mplang/core/interp.py,sha256=JKjKJGWURU5rlHQ2yG5XNKWzN6hLZsmo--hZuveQgxI,5915
12
12
  mplang/core/mask.py,sha256=14DFxaA446lGjN4dzTuQgm9Shcn34rYI87YJHg0YGNQ,10693
13
- mplang/core/mpir.py,sha256=V6S9RqegaI0yojhLkHla5nGBi27ASoxlrEs1k4tGubM,37980
13
+ mplang/core/mpir.py,sha256=3NyHa1cDnUaw3XWIUgyOMXfZ9JS-30COb29AoXYcRtM,38251
14
14
  mplang/core/mpobject.py,sha256=0pHSd7SrAFTScCFcB9ziDztElYQn-oIZOKBx47B3QX0,3732
15
15
  mplang/core/mptype.py,sha256=7Cp2e58uUX-uqTp6QxuioOMJ8BzLBPXlWG5rRakv2uo,13773
16
16
  mplang/core/pfunc.py,sha256=PAr8qRhVveWO5HOI0TgdsWjpi4PFi2iEyuTlr9UVKSY,5106
17
- mplang/core/primitive.py,sha256=-IkGqdbwtbMkLEOOTghXfuFtFvxu5jFQBupm5nPV-RI,40569
17
+ mplang/core/primitive.py,sha256=MxnHr12BorscFiz_sv-iVAbFWd4mdy-WYSg4ORtspOM,43871
18
18
  mplang/core/table.py,sha256=BqTBZn7Tfwce4vzl3XYhaX5hVmKagVq9-YoERDta6d8,5892
19
19
  mplang/core/tensor.py,sha256=86u6DogSZMoL0w5XjtTmQm2PhA_VjwybN1b6U4Zzphg,2361
20
20
  mplang/core/tracer.py,sha256=dVMfUeCMmPz4o6tLXewGMW1Kpy5gpZORvr9w4MhwDtM,14288
21
21
  mplang/core/expr/__init__.py,sha256=qwiSTUOcanFJLyK8HZ13_L1ZDrybqpPXIlTHAyeumE8,1988
22
- mplang/core/expr/ast.py,sha256=KE46KTtlH9RA2V_EzWVKCKolsycgTmt7SotUrOc8Qxs,20923
23
- mplang/core/expr/evaluator.py,sha256=EFy71vYUL2xLHCtMkWlYJpyGyujDdVSAx8ByET-62qQ,23297
24
- mplang/core/expr/printer.py,sha256=VblKGnO0OUfzH7EBkszwRNxQUB8QyyC7BlJWJEUv9so,9546
25
- mplang/core/expr/transformer.py,sha256=TyL-8FjrVvDq_C9X7kAuKkiqt2XdZM-okjzVQj0A33s,4893
22
+ mplang/core/expr/ast.py,sha256=K-rNqlpgkdjVzwSrLgunYnL4zWl1USJGLOgfz0qJNO4,20959
23
+ mplang/core/expr/evaluator.py,sha256=rpzZQPPVtxBvUuCx-9_bFmzr_7tfAQjPlP_rqpWjgIo,23313
24
+ mplang/core/expr/printer.py,sha256=RIHoG0gTDr6jPke391KekmPLu4AdeIfrhlROWa0XseQ,9883
25
+ mplang/core/expr/transformer.py,sha256=gez9eedVsWoLasSgWvPmGR8WfQnGXPlldWeVFEjqyYo,4904
26
26
  mplang/core/expr/utils.py,sha256=VDTJ_-CsdHtVy9wDaGa7XdFxQ7o5lYYaeqcgsAhkbNI,2625
27
27
  mplang/core/expr/visitor.py,sha256=2Ge-I5N-wH8VVXy8d2WyNaEv8x6seiRx9peyH9S2BYU,2044
28
28
  mplang/core/expr/walk.py,sha256=lXkGJEEuvKGDqQihbxXPxfz2RfR1Q1zYUlt11iooQW0,11889
@@ -73,8 +73,8 @@ mplang/utils/crypto.py,sha256=rvPomBFtznRHc3RPi6Aip9lsU8zW2oxBqGv1K3vn7Rs,1052
73
73
  mplang/utils/func_utils.py,sha256=vCJcZmu0bEbqhOQKdpttV2_MBllIcPSN0b8U4WjNGGo,5164
74
74
  mplang/utils/spu_utils.py,sha256=S3L9RBkBe2AvSuMSQQ12cBY5Y1NPthubvErSX_7nj1A,4158
75
75
  mplang/utils/table_utils.py,sha256=aC-IZOKkSmFkpr3NZchLM0Wt0GOn-rg_xHBHREWBwAU,2202
76
- mplang_nightly-0.1.dev169.dist-info/METADATA,sha256=3Ml9Mvi3n9iBnvcVp7dp7lFJABk1hUoIP5AB5BbeQFE,16547
77
- mplang_nightly-0.1.dev169.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
78
- mplang_nightly-0.1.dev169.dist-info/entry_points.txt,sha256=mG1oJT-GAjQR834a62_QIWb7litzWPPyVnwFqm-rWuY,55
79
- mplang_nightly-0.1.dev169.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
80
- mplang_nightly-0.1.dev169.dist-info/RECORD,,
76
+ mplang_nightly-0.1.dev170.dist-info/METADATA,sha256=xh004xC6U_s5O350oQsAdZ3yFFJ5BvZz1dw2Kw3Tzuw,16547
77
+ mplang_nightly-0.1.dev170.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
78
+ mplang_nightly-0.1.dev170.dist-info/entry_points.txt,sha256=mG1oJT-GAjQR834a62_QIWb7litzWPPyVnwFqm-rWuY,55
79
+ mplang_nightly-0.1.dev170.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
80
+ mplang_nightly-0.1.dev170.dist-info/RECORD,,