mplang-nightly 0.1.dev145__py3-none-any.whl → 0.1.dev147__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/core/expr/evaluator.py +26 -8
- mplang/core/primitive.py +30 -0
- {mplang_nightly-0.1.dev145.dist-info → mplang_nightly-0.1.dev147.dist-info}/METADATA +1 -1
- {mplang_nightly-0.1.dev145.dist-info → mplang_nightly-0.1.dev147.dist-info}/RECORD +7 -7
- {mplang_nightly-0.1.dev145.dist-info → mplang_nightly-0.1.dev147.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev145.dist-info → mplang_nightly-0.1.dev147.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev145.dist-info → mplang_nightly-0.1.dev147.dist-info}/licenses/LICENSE +0 -0
mplang/core/expr/evaluator.py
CHANGED
@@ -196,6 +196,28 @@ class EvalSemantic:
|
|
196
196
|
"uniform_cond: predicate is not uniform across parties"
|
197
197
|
)
|
198
198
|
|
199
|
+
# ------------------------------ While helpers ------------------------------
|
200
|
+
def _check_while_predicate(self, cond_result: list[Any]) -> Any:
|
201
|
+
"""Validate while_loop predicate evaluation result.
|
202
|
+
|
203
|
+
Ensures the condition function returns exactly one value and that value
|
204
|
+
is non-None. Returns the boolean predicate value for convenience.
|
205
|
+
|
206
|
+
Raises:
|
207
|
+
AssertionError: If condition function returns != 1 value.
|
208
|
+
RuntimeError: If the single predicate value is None.
|
209
|
+
"""
|
210
|
+
assert len(cond_result) == 1, (
|
211
|
+
f"Condition function must return a single value, got {cond_result}"
|
212
|
+
)
|
213
|
+
cond_value = cond_result[0]
|
214
|
+
if cond_value is None:
|
215
|
+
raise RuntimeError(
|
216
|
+
"while_loop condition produced None on rank "
|
217
|
+
f"{self.rank}; ensure the predicate yields a boolean for every party."
|
218
|
+
)
|
219
|
+
return cond_value
|
220
|
+
|
199
221
|
|
200
222
|
class RecursiveEvaluator(EvalSemantic, ExprVisitor):
|
201
223
|
"""Recursive visitor-based evaluator."""
|
@@ -307,12 +329,8 @@ class RecursiveEvaluator(EvalSemantic, ExprVisitor):
|
|
307
329
|
cond_env = dict(zip(expr.cond_fn.params, state, strict=True))
|
308
330
|
cond_evaluator = self._fork(cond_env)
|
309
331
|
cond_result = expr.cond_fn.body.accept(cond_evaluator)
|
310
|
-
|
311
|
-
|
312
|
-
f"Condition function must return a single value, got {cond_result}"
|
313
|
-
)
|
314
|
-
|
315
|
-
if not cond_result[0]:
|
332
|
+
cond_value = self._check_while_predicate(cond_result)
|
333
|
+
if not cond_value:
|
316
334
|
break
|
317
335
|
|
318
336
|
# Call body function with same arguments
|
@@ -445,8 +463,8 @@ class IterativeEvaluator(EvalSemantic):
|
|
445
463
|
cond_vals = self._iter_eval_graph(
|
446
464
|
node.cond_fn.body, {**env, **cond_env}
|
447
465
|
)
|
448
|
-
|
449
|
-
if not bool(
|
466
|
+
cond_val = self._check_while_predicate(cond_vals)
|
467
|
+
if not bool(cond_val):
|
450
468
|
break
|
451
469
|
body_env = dict(zip(node.body_fn.params, state, strict=True))
|
452
470
|
new_state = self._iter_eval_graph(
|
mplang/core/primitive.py
CHANGED
@@ -483,6 +483,20 @@ def uniform_cond(
|
|
483
483
|
if pred_ty.dtype != BOOL:
|
484
484
|
raise TypeError(f"uniform_cond predicate must be boolean, got {pred_ty.dtype}")
|
485
485
|
|
486
|
+
# Static pmask rule:
|
487
|
+
# If predicate has a static pmask (not None), it must equal the current trace
|
488
|
+
# context mask. Otherwise some parties would execute a branch without a
|
489
|
+
# defined predicate value (unsafe). To run on a subset either:
|
490
|
+
# 1. Trace the entire uniform_cond under a subset TraceContext (ctx.fork(mask=...))
|
491
|
+
# 2. Broadcast / lift predicate to full mask (e.g. pshfl_s)
|
492
|
+
# Pred pmask None => dynamic: defer to runtime uniformity (if verify_uniform=True).
|
493
|
+
pred_pmask = pred_ty.pmask
|
494
|
+
if pred_pmask is not None and pred_pmask != cur_tracer.mask:
|
495
|
+
raise ValueError(
|
496
|
+
"uniform_cond predicate static pmask mismatch: predicate pmask="
|
497
|
+
f"{pred_pmask} trace mask={cur_tracer.mask}. Trace under a subset "
|
498
|
+
"context (ctx.fork(mask=...)) or broadcast predicate (pshfl_s) to all parties."
|
499
|
+
)
|
486
500
|
# Step 1: Trace both branches in separate contexts
|
487
501
|
then_tracer = cur_tracer.fork()
|
488
502
|
then_tfn = trace(then_tracer, then_fn, *args)
|
@@ -706,6 +720,22 @@ def while_loop(
|
|
706
720
|
f"Condition function must return a boolean scalar, got dtype {cond_out_var.mptype.dtype}"
|
707
721
|
)
|
708
722
|
|
723
|
+
# Static pmask rule:
|
724
|
+
# If the predicate's pmask is statically known it must match the trace context
|
725
|
+
# mask. Otherwise some parties in this context would lack a boolean to drive
|
726
|
+
# control flow (previously could lead to hang via None). To restrict to a subset:
|
727
|
+
# 1. Trace the entire while_loop under a subset context (ctx.fork(mask=submask)), or
|
728
|
+
# 2. Broadcast predicate to full mask (e.g. pshfl_s) before while_loop.
|
729
|
+
# Dynamic predicates (pmask=None) are allowed; runtime guard (evaluator) raises
|
730
|
+
# if any participating party observes None.
|
731
|
+
pred_pmask = cond_out_var.mptype.pmask
|
732
|
+
if pred_pmask is not None and pred_pmask != cur_tracer.mask:
|
733
|
+
raise ValueError(
|
734
|
+
"while_loop predicate static pmask mismatch: predicate pmask="
|
735
|
+
f"{pred_pmask} trace mask={cur_tracer.mask}. Trace under subset context "
|
736
|
+
"or broadcast predicate to all parties."
|
737
|
+
)
|
738
|
+
|
709
739
|
# Validate body returns same number of leaves and same dtype/shape per leaf
|
710
740
|
if len(body_tfn.out_vars) != len(cond_tfn.in_vars):
|
711
741
|
raise ValueError(
|
@@ -24,13 +24,13 @@ mplang/core/mpir.py,sha256=V6S9RqegaI0yojhLkHla5nGBi27ASoxlrEs1k4tGubM,37980
|
|
24
24
|
mplang/core/mpobject.py,sha256=0pHSd7SrAFTScCFcB9ziDztElYQn-oIZOKBx47B3QX0,3732
|
25
25
|
mplang/core/mptype.py,sha256=09LbMyJp68W0IkbD0s9YLeVssPg3Rl-rcwjTaCfidIQ,15243
|
26
26
|
mplang/core/pfunc.py,sha256=PAr8qRhVveWO5HOI0TgdsWjpi4PFi2iEyuTlr9UVKSY,5106
|
27
|
-
mplang/core/primitive.py,sha256=
|
27
|
+
mplang/core/primitive.py,sha256=C1HMbqmkAvLbdgXiHrWPTQ2v2t1uwC_vsGCtI0TEqHY,40574
|
28
28
|
mplang/core/table.py,sha256=BqTBZn7Tfwce4vzl3XYhaX5hVmKagVq9-YoERDta6d8,5892
|
29
29
|
mplang/core/tensor.py,sha256=86u6DogSZMoL0w5XjtTmQm2PhA_VjwybN1b6U4Zzphg,2361
|
30
30
|
mplang/core/tracer.py,sha256=dVMfUeCMmPz4o6tLXewGMW1Kpy5gpZORvr9w4MhwDtM,14288
|
31
31
|
mplang/core/expr/__init__.py,sha256=qwiSTUOcanFJLyK8HZ13_L1ZDrybqpPXIlTHAyeumE8,1988
|
32
32
|
mplang/core/expr/ast.py,sha256=KE46KTtlH9RA2V_EzWVKCKolsycgTmt7SotUrOc8Qxs,20923
|
33
|
-
mplang/core/expr/evaluator.py,sha256=
|
33
|
+
mplang/core/expr/evaluator.py,sha256=OYmxkr4Lf2qMHnHK-aca-dfMsAAzGRVWuXrxNk_M_3U,21675
|
34
34
|
mplang/core/expr/printer.py,sha256=VblKGnO0OUfzH7EBkszwRNxQUB8QyyC7BlJWJEUv9so,9546
|
35
35
|
mplang/core/expr/transformer.py,sha256=TyL-8FjrVvDq_C9X7kAuKkiqt2XdZM-okjzVQj0A33s,4893
|
36
36
|
mplang/core/expr/utils.py,sha256=VDTJ_-CsdHtVy9wDaGa7XdFxQ7o5lYYaeqcgsAhkbNI,2625
|
@@ -70,8 +70,8 @@ mplang/utils/crypto.py,sha256=rvPomBFtznRHc3RPi6Aip9lsU8zW2oxBqGv1K3vn7Rs,1052
|
|
70
70
|
mplang/utils/func_utils.py,sha256=vCJcZmu0bEbqhOQKdpttV2_MBllIcPSN0b8U4WjNGGo,5164
|
71
71
|
mplang/utils/spu_utils.py,sha256=S3L9RBkBe2AvSuMSQQ12cBY5Y1NPthubvErSX_7nj1A,4158
|
72
72
|
mplang/utils/table_utils.py,sha256=aC-IZOKkSmFkpr3NZchLM0Wt0GOn-rg_xHBHREWBwAU,2202
|
73
|
-
mplang_nightly-0.1.
|
74
|
-
mplang_nightly-0.1.
|
75
|
-
mplang_nightly-0.1.
|
76
|
-
mplang_nightly-0.1.
|
77
|
-
mplang_nightly-0.1.
|
73
|
+
mplang_nightly-0.1.dev147.dist-info/METADATA,sha256=WxU1KBbckH3JcYQF1jnDD8ph-EE9wKhbompdIg0KTP8,16547
|
74
|
+
mplang_nightly-0.1.dev147.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
75
|
+
mplang_nightly-0.1.dev147.dist-info/entry_points.txt,sha256=mG1oJT-GAjQR834a62_QIWb7litzWPPyVnwFqm-rWuY,55
|
76
|
+
mplang_nightly-0.1.dev147.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
77
|
+
mplang_nightly-0.1.dev147.dist-info/RECORD,,
|
File without changes
|
{mplang_nightly-0.1.dev145.dist-info → mplang_nightly-0.1.dev147.dist-info}/entry_points.txt
RENAMED
File without changes
|
{mplang_nightly-0.1.dev145.dist-info → mplang_nightly-0.1.dev147.dist-info}/licenses/LICENSE
RENAMED
File without changes
|