mplang-nightly 0.1.dev148__py3-none-any.whl → 0.1.dev149__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -51,8 +51,15 @@ def _stablehlo_exec(pfunc: PFunction, *args: Any) -> Any:
51
51
  raise RuntimeError(f"StableHLO compile failed: {e}") from e
52
52
  cache[mlir_text] = compiled
53
53
 
54
+ # Handle JAX's unused parameter elimination via arg_keep_map
55
+ runtime_args = args
56
+ if "arg_keep_map" in pfunc.attrs:
57
+ keep_indices = pfunc.attrs["arg_keep_map"]
58
+ # Filter out arguments that were eliminated by JAX during compilation
59
+ runtime_args = tuple(args[i] for i in keep_indices)
60
+
54
61
  jax_args = []
55
- for arg in args:
62
+ for arg in runtime_args:
56
63
  if hasattr(arg, "numpy"):
57
64
  jax_arg = jnp.array(arg.numpy()) # type: ignore
58
65
  else:
mplang/frontend/jax_cc.py CHANGED
@@ -106,14 +106,46 @@ def jax2stablehlo(
106
106
  out_info_flat, out_tree = tree_flatten(lowered.out_info)
107
107
  out_info_flat = [TensorType.from_obj(info) for info in out_info_flat]
108
108
 
109
+ # Extract argument keep mapping to handle JAX's unused parameter elimination
110
+ # JAX can eliminate unused parameters during compilation, but the runtime still
111
+ # receives all original arguments. We need the mapping to filter them correctly.
112
+ arg_keep_map = None
113
+ original_arg_count = len(in_vars)
114
+
115
+ try:
116
+ # Access JAX internal kept_var_idx - the authoritative source
117
+ # This tells us exactly which original parameters survived compilation
118
+ compile_args = lowered._lowering.compile_args
119
+ kept_var_idx = compile_args["kept_var_idx"]
120
+
121
+ kept_indices = sorted(kept_var_idx)
122
+ if len(kept_indices) < original_arg_count:
123
+ arg_keep_map = kept_indices
124
+
125
+ except (AttributeError, KeyError, TypeError) as e:
126
+ # JAX internal API is not available or changed
127
+ # This is a hard error - we cannot reliably handle unused parameters
128
+ # without knowing exactly which ones were kept
129
+ raise RuntimeError(
130
+ f"Cannot access JAX's kept_var_idx to handle unused parameter elimination. "
131
+ f"This function may have unused parameters that JAX optimized away, "
132
+ f"but we cannot determine which ones without the internal API. "
133
+ f"Original error: {e}"
134
+ ) from e
135
+
109
136
  # This format tells JaxRT how to handle the compiled result
110
- pfn = PFunction(
111
- fn_type="mlir.stablehlo", # Key: specify StableHLO MLIR format
112
- ins_info=tuple(TensorType.from_obj(x) for x in in_vars),
113
- outs_info=tuple(out_info_flat),
114
- fn_name=get_fn_name(flat_fn),
115
- fn_text=mlir_text, # MLIR text, serializable for transmission
116
- )
137
+ pfn_kwargs: dict[str, Any] = {
138
+ "fn_type": "mlir.stablehlo", # Key: specify StableHLO MLIR format
139
+ "ins_info": tuple(TensorType.from_obj(x) for x in in_vars),
140
+ "outs_info": tuple(out_info_flat),
141
+ "fn_name": get_fn_name(flat_fn),
142
+ "fn_text": mlir_text, # MLIR text, serializable for transmission
143
+ }
144
+
145
+ if arg_keep_map is not None:
146
+ pfn_kwargs["arg_keep_map"] = arg_keep_map
147
+
148
+ pfn = PFunction(**pfn_kwargs)
117
149
  return pfn, in_vars, out_tree
118
150
 
119
151
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mplang-nightly
3
- Version: 0.1.dev148
3
+ Version: 0.1.dev149
4
4
  Summary: Multi-Party Programming Language
5
5
  Author-email: SecretFlow Team <secretflow-contact@service.alipay.com>
6
6
  License: Apache License
@@ -11,7 +11,7 @@ mplang/backend/crypto.py,sha256=H_s5HI7lUP7g0xz-a9qMbSn6dhJStUilKbn3-7SIh0I,3812
11
11
  mplang/backend/phe.py,sha256=uNqmrbDAbd97TWS_O6D5sopastHy6J20R7knFE4M4uc,65247
12
12
  mplang/backend/spu.py,sha256=QT1q5uv-5P_nBGtTvtA_yI2h3h3zIqNSnvzGT7Shua4,9307
13
13
  mplang/backend/sql_duckdb.py,sha256=U_KzEUinxrBRDoUz2Vh597-N4I3hPOBT0RT3tX-ZqKE,1502
14
- mplang/backend/stablehlo.py,sha256=GOxy-qgOxyEdtBkt6LASKzaZnPewZhvHYSPOgFFXgIM,2612
14
+ mplang/backend/stablehlo.py,sha256=RhKf6TUvjLrRvgtdVY2HxcRDGtjpKBobuBFMfsvZQOI,2937
15
15
  mplang/backend/tee.py,sha256=6kc7qTe8nWc3pr6iYtozEGLO8Umg-UBQLDiz6p3pdVg,1918
16
16
  mplang/core/__init__.py,sha256=lWxlEKfRwX7FNDzgyKZ1fiDMaCiqkyg0j5mKlZD_v7g,2244
17
17
  mplang/core/cluster.py,sha256=gqMJenvXUfHhE181Dd5JiUkD4nT07RLoicBnvsGmRkE,8598
@@ -41,7 +41,7 @@ mplang/frontend/base.py,sha256=rGtfBejcDh9mTRxOdJK5VUlG5vYiVJSir8X72X0Huvc,18264
41
41
  mplang/frontend/builtin.py,sha256=8qrlbe_SSy6QTXTnMG6_ADB8jSklVZGFBrkoR-p02FE,9368
42
42
  mplang/frontend/crypto.py,sha256=Nf8zT4Eko7MIs4R2tgZecKVd7d6Hvd_CGGmANhs3Ghs,3651
43
43
  mplang/frontend/ibis_cc.py,sha256=CTTbPPZ9hFnHuFDDIfgJHie1EdNnHmi5Ha1KsX0iYh8,4235
44
- mplang/frontend/jax_cc.py,sha256=ssP6rCvyWQ5VAr80-7z9QZUE2mWXyozJCGpq1dYQYY8,6374
44
+ mplang/frontend/jax_cc.py,sha256=lMqaYD1tyM5DsStTNYifAXzhzsNM5nDiG3a61ygbWyc,7807
45
45
  mplang/frontend/phe.py,sha256=tDsCvStjVJ1Fs07yF3idkFnugUCA1zdFApPx7Uuulik,6795
46
46
  mplang/frontend/spu.py,sha256=7G6DaEfC5APSDhfeWSISTG_8tEcVbWth3XmjL8QUrVA,4994
47
47
  mplang/frontend/sql.py,sha256=DFdvjEPQX28VCRgUMeHYR0rwwOaoCH15bpvvlclLtHA,1999
@@ -70,8 +70,8 @@ mplang/utils/crypto.py,sha256=rvPomBFtznRHc3RPi6Aip9lsU8zW2oxBqGv1K3vn7Rs,1052
70
70
  mplang/utils/func_utils.py,sha256=vCJcZmu0bEbqhOQKdpttV2_MBllIcPSN0b8U4WjNGGo,5164
71
71
  mplang/utils/spu_utils.py,sha256=S3L9RBkBe2AvSuMSQQ12cBY5Y1NPthubvErSX_7nj1A,4158
72
72
  mplang/utils/table_utils.py,sha256=aC-IZOKkSmFkpr3NZchLM0Wt0GOn-rg_xHBHREWBwAU,2202
73
- mplang_nightly-0.1.dev148.dist-info/METADATA,sha256=3hkF4x8KZwg7PaGyfuAiDvp_qH9T1bhWS41rYWwl2Zs,16547
74
- mplang_nightly-0.1.dev148.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
75
- mplang_nightly-0.1.dev148.dist-info/entry_points.txt,sha256=mG1oJT-GAjQR834a62_QIWb7litzWPPyVnwFqm-rWuY,55
76
- mplang_nightly-0.1.dev148.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
77
- mplang_nightly-0.1.dev148.dist-info/RECORD,,
73
+ mplang_nightly-0.1.dev149.dist-info/METADATA,sha256=eZ_qGx1500gbGY5Ms8smHjQsyZ2Wxfa11z2mSeF0CSk,16547
74
+ mplang_nightly-0.1.dev149.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
75
+ mplang_nightly-0.1.dev149.dist-info/entry_points.txt,sha256=mG1oJT-GAjQR834a62_QIWb7litzWPPyVnwFqm-rWuY,55
76
+ mplang_nightly-0.1.dev149.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
77
+ mplang_nightly-0.1.dev149.dist-info/RECORD,,