mplang-nightly 0.1.dev169__py3-none-any.whl → 0.1.dev170__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/core/expr/ast.py +2 -1
- mplang/core/expr/evaluator.py +2 -2
- mplang/core/expr/printer.py +16 -6
- mplang/core/expr/transformer.py +1 -1
- mplang/core/mpir.py +6 -1
- mplang/core/primitive.py +93 -21
- {mplang_nightly-0.1.dev169.dist-info → mplang_nightly-0.1.dev170.dist-info}/METADATA +1 -1
- {mplang_nightly-0.1.dev169.dist-info → mplang_nightly-0.1.dev170.dist-info}/RECORD +11 -11
- {mplang_nightly-0.1.dev169.dist-info → mplang_nightly-0.1.dev170.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev169.dist-info → mplang_nightly-0.1.dev170.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev169.dist-info → mplang_nightly-0.1.dev170.dist-info}/licenses/LICENSE +0 -0
mplang/core/expr/ast.py
CHANGED
@@ -528,8 +528,9 @@ class FuncDefExpr(Expr):
|
|
528
528
|
class CallExpr(Expr):
|
529
529
|
"""Expression for function call."""
|
530
530
|
|
531
|
-
def __init__(self, fn: FuncDefExpr, args: list[Expr]):
|
531
|
+
def __init__(self, name: str, fn: FuncDefExpr, args: list[Expr]):
|
532
532
|
super().__init__()
|
533
|
+
self.name = name
|
533
534
|
self.fn = fn
|
534
535
|
self.args = args
|
535
536
|
|
mplang/core/expr/evaluator.py
CHANGED
@@ -341,10 +341,10 @@ class RecursiveEvaluator(EvalSemantic, ExprVisitor):
|
|
341
341
|
|
342
342
|
# Only evaluate selected branch locally
|
343
343
|
if bool(pred):
|
344
|
-
then_call = CallExpr(expr.then_fn, expr.args)
|
344
|
+
then_call = CallExpr("then", expr.then_fn, expr.args)
|
345
345
|
return self._values(then_call)
|
346
346
|
else:
|
347
|
-
else_call = CallExpr(expr.else_fn, expr.args)
|
347
|
+
else_call = CallExpr("else", expr.else_fn, expr.args)
|
348
348
|
return self._values(else_call)
|
349
349
|
|
350
350
|
def visit_call(self, expr: CallExpr) -> Any:
|
mplang/core/expr/printer.py
CHANGED
@@ -50,11 +50,13 @@ class Printer(ExprVisitor):
|
|
50
50
|
compact_format: bool = True,
|
51
51
|
*,
|
52
52
|
verbose_peval: bool = False,
|
53
|
+
inline_pcall: bool = True,
|
53
54
|
):
|
54
55
|
super().__init__() # Initialize MemorizedVisitor
|
55
56
|
self.indent_size = indent_size
|
56
57
|
self.compact_format = compact_format
|
57
58
|
self.verbose_peval = verbose_peval
|
59
|
+
self.inline_pcall = inline_pcall
|
58
60
|
self._cur_indent = 0
|
59
61
|
self._output: list[str] = []
|
60
62
|
self._visited: dict[Expr, str] = {}
|
@@ -92,6 +94,7 @@ class Printer(ExprVisitor):
|
|
92
94
|
body_printer = Printer(
|
93
95
|
indent_size=self.indent_size,
|
94
96
|
compact_format=self.compact_format,
|
97
|
+
inline_pcall=self.inline_pcall,
|
95
98
|
)
|
96
99
|
func_def_expr.accept(body_printer)
|
97
100
|
regions_str += f"{indent}{r_name}: "
|
@@ -209,12 +212,19 @@ class Printer(ExprVisitor):
|
|
209
212
|
|
210
213
|
def visit_call(self, expr: CallExpr) -> str:
|
211
214
|
arg_names = [self._var_name(arg) for arg in expr.args]
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
215
|
+
if self.inline_pcall:
|
216
|
+
return self._do_print(
|
217
|
+
expr.name,
|
218
|
+
arg_names,
|
219
|
+
mptypes=expr.mptypes,
|
220
|
+
)
|
221
|
+
else:
|
222
|
+
return self._do_print(
|
223
|
+
"pcall",
|
224
|
+
arg_names,
|
225
|
+
regions={"fn": expr.fn},
|
226
|
+
mptypes=expr.mptypes,
|
227
|
+
)
|
218
228
|
|
219
229
|
def visit_while(self, expr: WhileExpr) -> str:
|
220
230
|
arg_names = [self._var_name(arg) for arg in expr.args]
|
mplang/core/expr/transformer.py
CHANGED
@@ -79,7 +79,7 @@ class ExprTransformer(ExprVisitor):
|
|
79
79
|
def visit_call(self, expr: CallExpr) -> Expr:
|
80
80
|
# Transform child expressions first
|
81
81
|
transformed_args = [arg.accept(self) for arg in expr.args]
|
82
|
-
new_expr = CallExpr(expr.fn, transformed_args)
|
82
|
+
new_expr = CallExpr(expr.name, expr.fn, transformed_args)
|
83
83
|
|
84
84
|
if "call" in self.trans_rules:
|
85
85
|
return self.trans_rules["call"](new_expr)
|
mplang/core/mpir.py
CHANGED
@@ -491,6 +491,7 @@ class Writer:
|
|
491
491
|
op = self._create_node_proto(expr, "call")
|
492
492
|
self._add_single_expr_inputs(op, expr.fn)
|
493
493
|
self._add_expr_inputs(op, *expr.args)
|
494
|
+
self._add_attrs(op, name=expr.name)
|
494
495
|
self._finalize_node(op, expr)
|
495
496
|
elif isinstance(expr, WhileExpr):
|
496
497
|
op = self._create_node_proto(expr, "while")
|
@@ -822,8 +823,12 @@ class Reader:
|
|
822
823
|
arg_exprs.append(self._value_cache[dep_name])
|
823
824
|
else:
|
824
825
|
raise ValueError(f"Input {input_name} not found for call node")
|
826
|
+
# Optional call-site name attribute
|
827
|
+
call_name = None
|
828
|
+
if "name" in node_proto.attrs:
|
829
|
+
call_name = self._proto_to_attr(node_proto.attrs["name"]) # type: ignore[assignment]
|
825
830
|
|
826
|
-
return CallExpr(fn_expr, arg_exprs)
|
831
|
+
return CallExpr(call_name or "", fn_expr, arg_exprs)
|
827
832
|
|
828
833
|
def _proto_to_mptype(self, type_proto: mpir_pb2.MPTypeProto) -> MPType:
|
829
834
|
"""Convert MPTypeProto to MPType."""
|
mplang/core/primitive.py
CHANGED
@@ -32,6 +32,7 @@ from mplang.core.context_mgr import cur_ctx
|
|
32
32
|
from mplang.core.dtype import BOOL
|
33
33
|
from mplang.core.expr.ast import (
|
34
34
|
AccessExpr,
|
35
|
+
CallExpr,
|
35
36
|
CondExpr,
|
36
37
|
ConvExpr,
|
37
38
|
EvalExpr,
|
@@ -87,30 +88,106 @@ P = ParamSpec("P")
|
|
87
88
|
R = TypeVar("R")
|
88
89
|
|
89
90
|
|
90
|
-
def primitive(fn: Callable[P, R]) -> Callable[P, R]:
|
91
|
+
def trace_before_apply(fn: Callable[P, R], make_call: bool) -> Callable[P, R]:
|
91
92
|
"""A decorator to make all primitive call in trace context."""
|
92
93
|
|
93
94
|
@wraps(fn)
|
94
95
|
def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
|
95
96
|
current_ctx = cur_ctx()
|
96
97
|
if isinstance(current_ctx, TraceContext):
|
97
|
-
# If we are in a tracer context
|
98
|
-
|
99
|
-
|
100
|
-
|
98
|
+
# If we are already in a tracer context
|
99
|
+
if make_call:
|
100
|
+
# make a primitive call
|
101
|
+
tracer = current_ctx
|
102
|
+
tfn = trace(tracer.fork(), fn, *args, **kwargs)
|
103
|
+
is_mpobj = lambda x: isinstance(x, MPObject)
|
104
|
+
in_vars, in_imms, in_struct = var_morph((args, kwargs), is_mpobj)
|
105
|
+
assert in_struct == tfn.in_struct and in_imms == tfn.in_imms
|
106
|
+
arg_exprs = [arg.expr for arg in in_vars]
|
107
|
+
# re-capture all captured variables into current context if needed.
|
108
|
+
cap_exprs = [tracer.capture(var).expr for var in tfn.capture_map.keys()]
|
109
|
+
caller_expr = CallExpr(
|
110
|
+
name=fn.__name__, fn=tfn.make_expr(), args=arg_exprs + cap_exprs
|
111
|
+
)
|
112
|
+
out_vars = [
|
113
|
+
TraceVar(tracer, AccessExpr(caller_expr, idx))
|
114
|
+
for idx in range(caller_expr.num_outputs)
|
115
|
+
]
|
116
|
+
return cast(R, var_demorph(out_vars, tfn.out_imms, tfn.out_struct))
|
117
|
+
else:
|
118
|
+
# embed the function call in the current tracer context
|
119
|
+
# Note: switch_ctx will do the capture if needed.
|
120
|
+
args, kwargs = tree_map(
|
121
|
+
partial(_switch_ctx, current_ctx), (args, kwargs)
|
122
|
+
)
|
123
|
+
return fn(*args, **kwargs)
|
101
124
|
elif isinstance(current_ctx, InterpContext):
|
102
125
|
trace_ctx = TraceContext(current_ctx.cluster_spec, parent=current_ctx)
|
103
126
|
# TODO(jint): should we add trace_and_apply to improve the performance?
|
104
|
-
|
127
|
+
tfn = trace(trace_ctx, fn, *args, **kwargs)
|
105
128
|
# Return back to the original context.
|
106
|
-
return cast(R, apply(current_ctx,
|
129
|
+
return cast(R, apply(current_ctx, tfn, *args, **kwargs))
|
107
130
|
else:
|
108
131
|
raise ValueError(f"Unsupported context type: {type(current_ctx)}")
|
109
132
|
|
110
133
|
return wrapped
|
111
134
|
|
112
135
|
|
113
|
-
|
136
|
+
def primitive(fn: Callable[P, R]) -> Callable[P, R]:
|
137
|
+
"""Decorator to trace a Python function as an opaque primitive call (`CallExpr`).
|
138
|
+
|
139
|
+
When a function decorated with `@primitive` is called within a `TraceContext`, it is
|
140
|
+
not inlined. Instead, it is traced separately in a forked context, and a `CallExpr`
|
141
|
+
node is inserted into the main graph. This is useful for encapsulating complex
|
142
|
+
operations or third-party library calls as single, opaque nodes.
|
143
|
+
|
144
|
+
**Implementation Note:**
|
145
|
+
A `CallExpr` represents a call to a single inline lambda (non-recursive, as we don't
|
146
|
+
have Y-combinator support). This single lambda call can be treated as a "primitive call"
|
147
|
+
by the printer/visualizer - hence the name "primitive". The function body is captured
|
148
|
+
once during tracing and represented as an opaque callable unit in the expression graph,
|
149
|
+
maintaining a clear boundary between the caller and callee contexts.
|
150
|
+
|
151
|
+
Args:
|
152
|
+
fn: The function to be traced as a primitive operation.
|
153
|
+
|
154
|
+
Returns:
|
155
|
+
A wrapped function that creates a `CallExpr` node when called in a trace context.
|
156
|
+
|
157
|
+
Example:
|
158
|
+
```python
|
159
|
+
@primitive
|
160
|
+
def my_op(x: MPObject) -> MPObject:
|
161
|
+
# Complex logic traced as a single CallExpr node
|
162
|
+
return x + 1
|
163
|
+
```
|
164
|
+
"""
|
165
|
+
return trace_before_apply(fn, make_call=True)
|
166
|
+
|
167
|
+
|
168
|
+
def function(fn: Callable[P, R]) -> Callable[P, R]:
|
169
|
+
"""Decorator to trace a Python function by inlining its body.
|
170
|
+
|
171
|
+
When a function decorated with `@function` is called within a `TraceContext`, its
|
172
|
+
underlying primitive operations are expanded and inserted directly into the caller's
|
173
|
+
graph. This is the default tracing behavior and is suitable for most pure-Python
|
174
|
+
multi-party functions.
|
175
|
+
|
176
|
+
Args:
|
177
|
+
fn: The function to be traced and inlined.
|
178
|
+
|
179
|
+
Returns:
|
180
|
+
A wrapped function that inlines its operations into the caller's trace context.
|
181
|
+
|
182
|
+
Example:
|
183
|
+
```python
|
184
|
+
@function
|
185
|
+
def my_func(x: MPObject, y: MPObject) -> MPObject:
|
186
|
+
# Operations are inlined into the caller's trace
|
187
|
+
return x + y * constant(2)
|
188
|
+
```
|
189
|
+
"""
|
190
|
+
return trace_before_apply(fn, make_call=False)
|
114
191
|
|
115
192
|
|
116
193
|
# ============================================================================
|
@@ -126,18 +203,15 @@ def _tracer() -> TraceContext:
|
|
126
203
|
return ctx
|
127
204
|
|
128
205
|
|
129
|
-
@primitive
|
130
206
|
def psize() -> int:
|
131
207
|
"""Get the size of the current party world.
|
132
208
|
|
133
209
|
Returns:
|
134
210
|
int: The total number of parties in the current multi-party computation context.
|
135
211
|
"""
|
136
|
-
|
137
|
-
return ctx.world_size()
|
212
|
+
return cur_ctx().world_size()
|
138
213
|
|
139
214
|
|
140
|
-
@primitive
|
141
215
|
def pmask() -> Mask:
|
142
216
|
"""Get the current party mask in this computation context.
|
143
217
|
|
@@ -145,8 +219,7 @@ def pmask() -> Mask:
|
|
145
219
|
Mask: The current party mask indicating which parties are active
|
146
220
|
in the current computation context.
|
147
221
|
"""
|
148
|
-
|
149
|
-
return ctx.mask
|
222
|
+
return _tracer().mask
|
150
223
|
|
151
224
|
|
152
225
|
@primitive
|
@@ -203,7 +276,6 @@ def prand(shape: Shape = ()) -> MPObject:
|
|
203
276
|
return out_tree.unflatten(results) # type: ignore[no-any-return]
|
204
277
|
|
205
278
|
|
206
|
-
@primitive
|
207
279
|
def constant(data: TensorLike | ScalarType | TableLike) -> MPObject:
|
208
280
|
"""Create a constant tensor or table from data.
|
209
281
|
|
@@ -250,7 +322,7 @@ def debug_print(obj: MPObject, prefix: str = "") -> MPObject:
|
|
250
322
|
return out_tree.unflatten(results) # type: ignore[no-any-return]
|
251
323
|
|
252
324
|
|
253
|
-
@primitive
|
325
|
+
@function
|
254
326
|
def peval(
|
255
327
|
pfunc: PFunction,
|
256
328
|
args: list[MPObject],
|
@@ -378,7 +450,7 @@ def set_mask(arg: MPObject, mask: Mask) -> MPObject:
|
|
378
450
|
return out_tree.unflatten(results) # type: ignore[no-any-return]
|
379
451
|
|
380
452
|
|
381
|
-
@primitive
|
453
|
+
@function
|
382
454
|
def uniform_cond(
|
383
455
|
pred: MPObject,
|
384
456
|
then_fn: Callable[..., Any],
|
@@ -588,7 +660,7 @@ def uniform_cond(
|
|
588
660
|
return var_demorph(out_vars, then_tfn.out_imms, then_tfn.out_struct) # type: ignore[no-any-return]
|
589
661
|
|
590
662
|
|
591
|
-
@primitive
|
663
|
+
@function
|
592
664
|
def while_loop(
|
593
665
|
cond_fn: Callable[[Any], MPObject],
|
594
666
|
body_fn: Callable[[Any], Any],
|
@@ -781,7 +853,7 @@ def while_loop(
|
|
781
853
|
return var_demorph(out_vars, body_tfn.out_imms, body_tfn.out_struct)
|
782
854
|
|
783
855
|
|
784
|
-
@primitive
|
856
|
+
@function
|
785
857
|
def pshfl(src: MPObject, index: MPObject) -> MPObject:
|
786
858
|
"""Shuffle the input tensor to the specified index (dynamic version).
|
787
859
|
|
@@ -851,7 +923,7 @@ def pshfl(src: MPObject, index: MPObject) -> MPObject:
|
|
851
923
|
return TraceVar(_tracer(), shfl_expr)
|
852
924
|
|
853
925
|
|
854
|
-
@primitive
|
926
|
+
@function
|
855
927
|
def pshfl_s(src_val: MPObject, pmask: Mask, src_ranks: list[Rank]) -> MPObject:
|
856
928
|
"""Shuffle the input tensor to the specified rank, static version.
|
857
929
|
|
@@ -910,7 +982,7 @@ def pshfl_s(src_val: MPObject, pmask: Mask, src_ranks: list[Rank]) -> MPObject:
|
|
910
982
|
return TraceVar(_tracer(), shfl_s_expr)
|
911
983
|
|
912
984
|
|
913
|
-
@primitive
|
985
|
+
@function
|
914
986
|
def pconv(vars: list[MPObject]) -> MPObject:
|
915
987
|
"""Combine multiple variables that share the same dtype and shape into one.
|
916
988
|
|
{mplang_nightly-0.1.dev169.dist-info → mplang_nightly-0.1.dev170.dist-info}/RECORD
CHANGED
@@ -10,19 +10,19 @@ mplang/core/context_mgr.py,sha256=R0QJAod-1nYduVoOknLfAsxZiy-RtmuQcp-07HABYZU,15
|
|
10
10
|
mplang/core/dtype.py,sha256=0rZqFaFikFu9RxtdO36JLEgFL-E-lo3hH10whwkTVVY,10213
|
11
11
|
mplang/core/interp.py,sha256=JKjKJGWURU5rlHQ2yG5XNKWzN6hLZsmo--hZuveQgxI,5915
|
12
12
|
mplang/core/mask.py,sha256=14DFxaA446lGjN4dzTuQgm9Shcn34rYI87YJHg0YGNQ,10693
|
13
|
-
mplang/core/mpir.py,sha256=
|
13
|
+
mplang/core/mpir.py,sha256=3NyHa1cDnUaw3XWIUgyOMXfZ9JS-30COb29AoXYcRtM,38251
|
14
14
|
mplang/core/mpobject.py,sha256=0pHSd7SrAFTScCFcB9ziDztElYQn-oIZOKBx47B3QX0,3732
|
15
15
|
mplang/core/mptype.py,sha256=7Cp2e58uUX-uqTp6QxuioOMJ8BzLBPXlWG5rRakv2uo,13773
|
16
16
|
mplang/core/pfunc.py,sha256=PAr8qRhVveWO5HOI0TgdsWjpi4PFi2iEyuTlr9UVKSY,5106
|
17
|
-
mplang/core/primitive.py,sha256
|
17
|
+
mplang/core/primitive.py,sha256=MxnHr12BorscFiz_sv-iVAbFWd4mdy-WYSg4ORtspOM,43871
|
18
18
|
mplang/core/table.py,sha256=BqTBZn7Tfwce4vzl3XYhaX5hVmKagVq9-YoERDta6d8,5892
|
19
19
|
mplang/core/tensor.py,sha256=86u6DogSZMoL0w5XjtTmQm2PhA_VjwybN1b6U4Zzphg,2361
|
20
20
|
mplang/core/tracer.py,sha256=dVMfUeCMmPz4o6tLXewGMW1Kpy5gpZORvr9w4MhwDtM,14288
|
21
21
|
mplang/core/expr/__init__.py,sha256=qwiSTUOcanFJLyK8HZ13_L1ZDrybqpPXIlTHAyeumE8,1988
|
22
|
-
mplang/core/expr/ast.py,sha256=
|
23
|
-
mplang/core/expr/evaluator.py,sha256=
|
24
|
-
mplang/core/expr/printer.py,sha256=
|
25
|
-
mplang/core/expr/transformer.py,sha256=
|
22
|
+
mplang/core/expr/ast.py,sha256=K-rNqlpgkdjVzwSrLgunYnL4zWl1USJGLOgfz0qJNO4,20959
|
23
|
+
mplang/core/expr/evaluator.py,sha256=rpzZQPPVtxBvUuCx-9_bFmzr_7tfAQjPlP_rqpWjgIo,23313
|
24
|
+
mplang/core/expr/printer.py,sha256=RIHoG0gTDr6jPke391KekmPLu4AdeIfrhlROWa0XseQ,9883
|
25
|
+
mplang/core/expr/transformer.py,sha256=gez9eedVsWoLasSgWvPmGR8WfQnGXPlldWeVFEjqyYo,4904
|
26
26
|
mplang/core/expr/utils.py,sha256=VDTJ_-CsdHtVy9wDaGa7XdFxQ7o5lYYaeqcgsAhkbNI,2625
|
27
27
|
mplang/core/expr/visitor.py,sha256=2Ge-I5N-wH8VVXy8d2WyNaEv8x6seiRx9peyH9S2BYU,2044
|
28
28
|
mplang/core/expr/walk.py,sha256=lXkGJEEuvKGDqQihbxXPxfz2RfR1Q1zYUlt11iooQW0,11889
|
@@ -73,8 +73,8 @@ mplang/utils/crypto.py,sha256=rvPomBFtznRHc3RPi6Aip9lsU8zW2oxBqGv1K3vn7Rs,1052
|
|
73
73
|
mplang/utils/func_utils.py,sha256=vCJcZmu0bEbqhOQKdpttV2_MBllIcPSN0b8U4WjNGGo,5164
|
74
74
|
mplang/utils/spu_utils.py,sha256=S3L9RBkBe2AvSuMSQQ12cBY5Y1NPthubvErSX_7nj1A,4158
|
75
75
|
mplang/utils/table_utils.py,sha256=aC-IZOKkSmFkpr3NZchLM0Wt0GOn-rg_xHBHREWBwAU,2202
|
76
|
-
mplang_nightly-0.1.
|
77
|
-
mplang_nightly-0.1.
|
78
|
-
mplang_nightly-0.1.
|
79
|
-
mplang_nightly-0.1.
|
80
|
-
mplang_nightly-0.1.
|
76
|
+
mplang_nightly-0.1.dev170.dist-info/METADATA,sha256=xh004xC6U_s5O350oQsAdZ3yFFJ5BvZz1dw2Kw3Tzuw,16547
|
77
|
+
mplang_nightly-0.1.dev170.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
78
|
+
mplang_nightly-0.1.dev170.dist-info/entry_points.txt,sha256=mG1oJT-GAjQR834a62_QIWb7litzWPPyVnwFqm-rWuY,55
|
79
|
+
mplang_nightly-0.1.dev170.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
80
|
+
mplang_nightly-0.1.dev170.dist-info/RECORD,,
|
{mplang_nightly-0.1.dev169.dist-info → mplang_nightly-0.1.dev170.dist-info}/WHEEL
RENAMED
File without changes
|
{mplang_nightly-0.1.dev169.dist-info → mplang_nightly-0.1.dev170.dist-info}/entry_points.txt
RENAMED
File without changes
|
{mplang_nightly-0.1.dev169.dist-info → mplang_nightly-0.1.dev170.dist-info}/licenses/LICENSE
RENAMED
File without changes
|