mplang-nightly 0.1.dev148__py3-none-any.whl → 0.1.dev149__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/backend/stablehlo.py +8 -1
- mplang/frontend/jax_cc.py +39 -7
- {mplang_nightly-0.1.dev148.dist-info → mplang_nightly-0.1.dev149.dist-info}/METADATA +1 -1
- {mplang_nightly-0.1.dev148.dist-info → mplang_nightly-0.1.dev149.dist-info}/RECORD +7 -7
- {mplang_nightly-0.1.dev148.dist-info → mplang_nightly-0.1.dev149.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev148.dist-info → mplang_nightly-0.1.dev149.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev148.dist-info → mplang_nightly-0.1.dev149.dist-info}/licenses/LICENSE +0 -0
mplang/backend/stablehlo.py
CHANGED
@@ -51,8 +51,15 @@ def _stablehlo_exec(pfunc: PFunction, *args: Any) -> Any:
|
|
51
51
|
raise RuntimeError(f"StableHLO compile failed: {e}") from e
|
52
52
|
cache[mlir_text] = compiled
|
53
53
|
|
54
|
+
# Handle JAX's unused parameter elimination via arg_keep_map
|
55
|
+
runtime_args = args
|
56
|
+
if "arg_keep_map" in pfunc.attrs:
|
57
|
+
keep_indices = pfunc.attrs["arg_keep_map"]
|
58
|
+
# Filter out arguments that were eliminated by JAX during compilation
|
59
|
+
runtime_args = tuple(args[i] for i in keep_indices)
|
60
|
+
|
54
61
|
jax_args = []
|
55
|
-
for arg in
|
62
|
+
for arg in runtime_args:
|
56
63
|
if hasattr(arg, "numpy"):
|
57
64
|
jax_arg = jnp.array(arg.numpy()) # type: ignore
|
58
65
|
else:
|
mplang/frontend/jax_cc.py
CHANGED
@@ -106,14 +106,46 @@ def jax2stablehlo(
|
|
106
106
|
out_info_flat, out_tree = tree_flatten(lowered.out_info)
|
107
107
|
out_info_flat = [TensorType.from_obj(info) for info in out_info_flat]
|
108
108
|
|
109
|
+
# Extract argument keep mapping to handle JAX's unused parameter elimination
|
110
|
+
# JAX can eliminate unused parameters during compilation, but the runtime still
|
111
|
+
# receives all original arguments. We need the mapping to filter them correctly.
|
112
|
+
arg_keep_map = None
|
113
|
+
original_arg_count = len(in_vars)
|
114
|
+
|
115
|
+
try:
|
116
|
+
# Access JAX internal kept_var_idx - the authoritative source
|
117
|
+
# This tells us exactly which original parameters survived compilation
|
118
|
+
compile_args = lowered._lowering.compile_args
|
119
|
+
kept_var_idx = compile_args["kept_var_idx"]
|
120
|
+
|
121
|
+
kept_indices = sorted(kept_var_idx)
|
122
|
+
if len(kept_indices) < original_arg_count:
|
123
|
+
arg_keep_map = kept_indices
|
124
|
+
|
125
|
+
except (AttributeError, KeyError, TypeError) as e:
|
126
|
+
# JAX internal API is not available or changed
|
127
|
+
# This is a hard error - we cannot reliably handle unused parameters
|
128
|
+
# without knowing exactly which ones were kept
|
129
|
+
raise RuntimeError(
|
130
|
+
f"Cannot access JAX's kept_var_idx to handle unused parameter elimination. "
|
131
|
+
f"This function may have unused parameters that JAX optimized away, "
|
132
|
+
f"but we cannot determine which ones without the internal API. "
|
133
|
+
f"Original error: {e}"
|
134
|
+
) from e
|
135
|
+
|
109
136
|
# This format tells JaxRT how to handle the compiled result
|
110
|
-
|
111
|
-
fn_type
|
112
|
-
ins_info
|
113
|
-
outs_info
|
114
|
-
fn_name
|
115
|
-
fn_text
|
116
|
-
|
137
|
+
pfn_kwargs: dict[str, Any] = {
|
138
|
+
"fn_type": "mlir.stablehlo", # Key: specify StableHLO MLIR format
|
139
|
+
"ins_info": tuple(TensorType.from_obj(x) for x in in_vars),
|
140
|
+
"outs_info": tuple(out_info_flat),
|
141
|
+
"fn_name": get_fn_name(flat_fn),
|
142
|
+
"fn_text": mlir_text, # MLIR text, serializable for transmission
|
143
|
+
}
|
144
|
+
|
145
|
+
if arg_keep_map is not None:
|
146
|
+
pfn_kwargs["arg_keep_map"] = arg_keep_map
|
147
|
+
|
148
|
+
pfn = PFunction(**pfn_kwargs)
|
117
149
|
return pfn, in_vars, out_tree
|
118
150
|
|
119
151
|
|
@@ -11,7 +11,7 @@ mplang/backend/crypto.py,sha256=H_s5HI7lUP7g0xz-a9qMbSn6dhJStUilKbn3-7SIh0I,3812
|
|
11
11
|
mplang/backend/phe.py,sha256=uNqmrbDAbd97TWS_O6D5sopastHy6J20R7knFE4M4uc,65247
|
12
12
|
mplang/backend/spu.py,sha256=QT1q5uv-5P_nBGtTvtA_yI2h3h3zIqNSnvzGT7Shua4,9307
|
13
13
|
mplang/backend/sql_duckdb.py,sha256=U_KzEUinxrBRDoUz2Vh597-N4I3hPOBT0RT3tX-ZqKE,1502
|
14
|
-
mplang/backend/stablehlo.py,sha256=
|
14
|
+
mplang/backend/stablehlo.py,sha256=RhKf6TUvjLrRvgtdVY2HxcRDGtjpKBobuBFMfsvZQOI,2937
|
15
15
|
mplang/backend/tee.py,sha256=6kc7qTe8nWc3pr6iYtozEGLO8Umg-UBQLDiz6p3pdVg,1918
|
16
16
|
mplang/core/__init__.py,sha256=lWxlEKfRwX7FNDzgyKZ1fiDMaCiqkyg0j5mKlZD_v7g,2244
|
17
17
|
mplang/core/cluster.py,sha256=gqMJenvXUfHhE181Dd5JiUkD4nT07RLoicBnvsGmRkE,8598
|
@@ -41,7 +41,7 @@ mplang/frontend/base.py,sha256=rGtfBejcDh9mTRxOdJK5VUlG5vYiVJSir8X72X0Huvc,18264
|
|
41
41
|
mplang/frontend/builtin.py,sha256=8qrlbe_SSy6QTXTnMG6_ADB8jSklVZGFBrkoR-p02FE,9368
|
42
42
|
mplang/frontend/crypto.py,sha256=Nf8zT4Eko7MIs4R2tgZecKVd7d6Hvd_CGGmANhs3Ghs,3651
|
43
43
|
mplang/frontend/ibis_cc.py,sha256=CTTbPPZ9hFnHuFDDIfgJHie1EdNnHmi5Ha1KsX0iYh8,4235
|
44
|
-
mplang/frontend/jax_cc.py,sha256=
|
44
|
+
mplang/frontend/jax_cc.py,sha256=lMqaYD1tyM5DsStTNYifAXzhzsNM5nDiG3a61ygbWyc,7807
|
45
45
|
mplang/frontend/phe.py,sha256=tDsCvStjVJ1Fs07yF3idkFnugUCA1zdFApPx7Uuulik,6795
|
46
46
|
mplang/frontend/spu.py,sha256=7G6DaEfC5APSDhfeWSISTG_8tEcVbWth3XmjL8QUrVA,4994
|
47
47
|
mplang/frontend/sql.py,sha256=DFdvjEPQX28VCRgUMeHYR0rwwOaoCH15bpvvlclLtHA,1999
|
@@ -70,8 +70,8 @@ mplang/utils/crypto.py,sha256=rvPomBFtznRHc3RPi6Aip9lsU8zW2oxBqGv1K3vn7Rs,1052
|
|
70
70
|
mplang/utils/func_utils.py,sha256=vCJcZmu0bEbqhOQKdpttV2_MBllIcPSN0b8U4WjNGGo,5164
|
71
71
|
mplang/utils/spu_utils.py,sha256=S3L9RBkBe2AvSuMSQQ12cBY5Y1NPthubvErSX_7nj1A,4158
|
72
72
|
mplang/utils/table_utils.py,sha256=aC-IZOKkSmFkpr3NZchLM0Wt0GOn-rg_xHBHREWBwAU,2202
|
73
|
-
mplang_nightly-0.1.
|
74
|
-
mplang_nightly-0.1.
|
75
|
-
mplang_nightly-0.1.
|
76
|
-
mplang_nightly-0.1.
|
77
|
-
mplang_nightly-0.1.
|
73
|
+
mplang_nightly-0.1.dev149.dist-info/METADATA,sha256=eZ_qGx1500gbGY5Ms8smHjQsyZ2Wxfa11z2mSeF0CSk,16547
|
74
|
+
mplang_nightly-0.1.dev149.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
75
|
+
mplang_nightly-0.1.dev149.dist-info/entry_points.txt,sha256=mG1oJT-GAjQR834a62_QIWb7litzWPPyVnwFqm-rWuY,55
|
76
|
+
mplang_nightly-0.1.dev149.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
77
|
+
mplang_nightly-0.1.dev149.dist-info/RECORD,,
|
{mplang_nightly-0.1.dev148.dist-info → mplang_nightly-0.1.dev149.dist-info}/WHEEL
RENAMED
File without changes
|
{mplang_nightly-0.1.dev148.dist-info → mplang_nightly-0.1.dev149.dist-info}/entry_points.txt
RENAMED
File without changes
|
{mplang_nightly-0.1.dev148.dist-info → mplang_nightly-0.1.dev149.dist-info}/licenses/LICENSE
RENAMED
File without changes
|