mplang-nightly 0.1.dev148__tar.gz → 0.1.dev149__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/PKG-INFO +1 -1
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/backend/stablehlo.py +8 -1
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/frontend/jax_cc.py +39 -7
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/frontend/test_jax_cc.py +26 -10
- mplang_nightly-0.1.dev149/tests/integration/test_unused_param_integration.py +191 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/.gitignore +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/LICENSE +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/README.md +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/examples/conf/3pc.yaml +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/examples/stax_nn/README.md +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/examples/stax_nn/models.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/examples/stax_nn/stax_nn.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/examples/xgboost/hist_jax.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/examples/xgboost/hist_jax_test.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/examples/xgboost/naive_np.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/examples/xgboost/readme.md +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/examples/xgboost/sgb.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/examples/xgboost/sgb_test.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/hatch_build.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/__init__.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/analysis/__init__.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/analysis/diagram.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/api.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/backend/__init__.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/backend/base.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/backend/builtin.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/backend/context.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/backend/crypto.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/backend/phe.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/backend/spu.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/backend/sql_duckdb.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/backend/tee.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/__init__.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/cluster.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/comm.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/context_mgr.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/dtype.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/expr/__init__.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/expr/ast.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/expr/evaluator.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/expr/printer.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/expr/transformer.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/expr/utils.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/expr/visitor.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/expr/walk.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/interp.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/mask.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/mpir.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/mpobject.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/mptype.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/pfunc.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/primitive.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/table.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/tensor.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/tracer.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/device.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/frontend/__init__.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/frontend/base.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/frontend/builtin.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/frontend/crypto.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/frontend/ibis_cc.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/frontend/phe.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/frontend/spu.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/frontend/sql.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/frontend/tee.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/protos/v1alpha1/mpir_pb2.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/protos/v1alpha1/mpir_pb2.pyi +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/runtime/__init__.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/runtime/cli.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/runtime/client.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/runtime/communicator.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/runtime/data_providers.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/runtime/driver.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/runtime/exceptions.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/runtime/http_api.md +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/runtime/link_comm.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/runtime/resource.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/runtime/server.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/runtime/simulation.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/simp/__init__.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/simp/mpi.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/simp/random.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/simp/smpc.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/utils/__init__.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/utils/crypto.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/utils/func_utils.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/utils/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/utils/table_utils.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/pyproject.toml +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/__init__.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/analysis/test_diagram.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/backend/test_builtin.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/backend/test_debug_print.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/backend/test_kernel_binding.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/backend/test_phe.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/backend/test_spu.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/backend/test_sql_duckdb.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/backend/test_stablehlo.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/core/__init__.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/core/expr/__init__.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/core/expr/conftest.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/core/expr/test_ast.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/core/expr/test_printer.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/core/expr/test_utils.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/core/expr/test_walk.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/core/test_cluster.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/core/test_dtype.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/core/test_mask.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/core/test_mpir.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/core/test_mptype.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/core/test_primitive.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/core/test_table.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/core/test_tensor.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/core/test_tracer.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/device/__init__.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/device/test_device_basic.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/frontend/__init__.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/frontend/dummy.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/frontend/test_builtin_pack.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/frontend/test_crypto_tee.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/frontend/test_feop_base.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/frontend/test_ibis.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/frontend/test_ibis_cc.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/frontend/test_phe.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/frontend/test_spu.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/frontend/test_spu_defensive.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/frontend/test_sql.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/frontend/test_table_tensor_conversion.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/integration/README.md +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/integration/test_crypto_roundtrip.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/integration/test_http_e2e.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/integration/test_symbols_roundtrip.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/integration/test_tutorials.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/runtime/__init__.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/runtime/test_cli.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/runtime/test_communicator.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/runtime/test_driver.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/runtime/test_server.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/runtime/test_simulation.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/simp/test_mpi.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/simp/test_random.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/simp/test_simp.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/simp/test_smpc.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/simp/test_sugar.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/utils/__init__.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/utils/server_fixtures.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/utils/test_func_utils.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/utils/test_spu_utils.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/utils/test_table_utils.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tutorials/0_basic.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tutorials/10_analysis.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tutorials/1_condition.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tutorials/2_whileloop.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tutorials/3_device.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tutorials/4_simulation.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tutorials/5_ir_dump.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tutorials/6_advanced.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tutorials/7_stdio.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tutorials/8_phe.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tutorials/9_tee.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tutorials/__init__.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tutorials/pitfalls/late_binding.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tutorials/pitfalls/rand.py +0 -0
- {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tutorials/run.sh +0 -0
@@ -51,8 +51,15 @@ def _stablehlo_exec(pfunc: PFunction, *args: Any) -> Any:
|
|
51
51
|
raise RuntimeError(f"StableHLO compile failed: {e}") from e
|
52
52
|
cache[mlir_text] = compiled
|
53
53
|
|
54
|
+
# Handle JAX's unused parameter elimination via arg_keep_map
|
55
|
+
runtime_args = args
|
56
|
+
if "arg_keep_map" in pfunc.attrs:
|
57
|
+
keep_indices = pfunc.attrs["arg_keep_map"]
|
58
|
+
# Filter out arguments that were eliminated by JAX during compilation
|
59
|
+
runtime_args = tuple(args[i] for i in keep_indices)
|
60
|
+
|
54
61
|
jax_args = []
|
55
|
-
for arg in
|
62
|
+
for arg in runtime_args:
|
56
63
|
if hasattr(arg, "numpy"):
|
57
64
|
jax_arg = jnp.array(arg.numpy()) # type: ignore
|
58
65
|
else:
|
@@ -106,14 +106,46 @@ def jax2stablehlo(
|
|
106
106
|
out_info_flat, out_tree = tree_flatten(lowered.out_info)
|
107
107
|
out_info_flat = [TensorType.from_obj(info) for info in out_info_flat]
|
108
108
|
|
109
|
+
# Extract argument keep mapping to handle JAX's unused parameter elimination
|
110
|
+
# JAX can eliminate unused parameters during compilation, but the runtime still
|
111
|
+
# receives all original arguments. We need the mapping to filter them correctly.
|
112
|
+
arg_keep_map = None
|
113
|
+
original_arg_count = len(in_vars)
|
114
|
+
|
115
|
+
try:
|
116
|
+
# Access JAX internal kept_var_idx - the authoritative source
|
117
|
+
# This tells us exactly which original parameters survived compilation
|
118
|
+
compile_args = lowered._lowering.compile_args
|
119
|
+
kept_var_idx = compile_args["kept_var_idx"]
|
120
|
+
|
121
|
+
kept_indices = sorted(kept_var_idx)
|
122
|
+
if len(kept_indices) < original_arg_count:
|
123
|
+
arg_keep_map = kept_indices
|
124
|
+
|
125
|
+
except (AttributeError, KeyError, TypeError) as e:
|
126
|
+
# JAX internal API is not available or changed
|
127
|
+
# This is a hard error - we cannot reliably handle unused parameters
|
128
|
+
# without knowing exactly which ones were kept
|
129
|
+
raise RuntimeError(
|
130
|
+
f"Cannot access JAX's kept_var_idx to handle unused parameter elimination. "
|
131
|
+
f"This function may have unused parameters that JAX optimized away, "
|
132
|
+
f"but we cannot determine which ones without the internal API. "
|
133
|
+
f"Original error: {e}"
|
134
|
+
) from e
|
135
|
+
|
109
136
|
# This format tells JaxRT how to handle the compiled result
|
110
|
-
|
111
|
-
fn_type
|
112
|
-
ins_info
|
113
|
-
outs_info
|
114
|
-
fn_name
|
115
|
-
fn_text
|
116
|
-
|
137
|
+
pfn_kwargs: dict[str, Any] = {
|
138
|
+
"fn_type": "mlir.stablehlo", # Key: specify StableHLO MLIR format
|
139
|
+
"ins_info": tuple(TensorType.from_obj(x) for x in in_vars),
|
140
|
+
"outs_info": tuple(out_info_flat),
|
141
|
+
"fn_name": get_fn_name(flat_fn),
|
142
|
+
"fn_text": mlir_text, # MLIR text, serializable for transmission
|
143
|
+
}
|
144
|
+
|
145
|
+
if arg_keep_map is not None:
|
146
|
+
pfn_kwargs["arg_keep_map"] = arg_keep_map
|
147
|
+
|
148
|
+
pfn = PFunction(**pfn_kwargs)
|
117
149
|
return pfn, in_vars, out_tree
|
118
150
|
|
119
151
|
|
@@ -244,21 +244,37 @@ class TestJax2StableHLO:
|
|
244
244
|
assert cfunc.fn_text is not None
|
245
245
|
|
246
246
|
def test_multiple_outputs(self):
|
247
|
-
"""Test
|
247
|
+
"""Test functions with multiple outputs."""
|
248
248
|
|
249
249
|
def multi_output(x, y):
|
250
250
|
return x + y, x - y, x * y
|
251
251
|
|
252
|
-
|
253
|
-
|
252
|
+
pfunc, out_tree = self._compile_with_transformer(
|
253
|
+
multi_output, jnp.array([1, 2]), jnp.array([3, 4])
|
254
|
+
)
|
254
255
|
|
255
|
-
|
256
|
+
assert len(pfunc.outs_info) == 3
|
257
|
+
assert out_tree is not None
|
256
258
|
|
257
|
-
|
258
|
-
|
259
|
+
def test_unused_parameter_elimination(self):
|
260
|
+
"""Test that unused parameters are handled correctly via arg_keep_map."""
|
259
261
|
|
260
|
-
|
261
|
-
|
262
|
-
assert out_info.shape == x.shape
|
262
|
+
def func_with_unused(x, unused, z):
|
263
|
+
return x + z # unused parameter eliminated by JAX
|
263
264
|
|
264
|
-
|
265
|
+
x = jnp.array(1, dtype=jnp.int32)
|
266
|
+
unused = jnp.array(999, dtype=jnp.int32)
|
267
|
+
z = jnp.array(3, dtype=jnp.int32)
|
268
|
+
|
269
|
+
pfunc, _ = self._compile_with_transformer(func_with_unused, x, unused, z)
|
270
|
+
|
271
|
+
# Check that compilation succeeded
|
272
|
+
assert pfunc.fn_type == "mlir.stablehlo"
|
273
|
+
assert len(pfunc.ins_info) == 3 # Original input count
|
274
|
+
|
275
|
+
# If JAX eliminated unused parameters, arg_keep_map should be present
|
276
|
+
if "arg_keep_map" in pfunc.attrs:
|
277
|
+
keep_map = pfunc.attrs["arg_keep_map"]
|
278
|
+
assert isinstance(keep_map, list)
|
279
|
+
assert len(keep_map) < 3 # Should be fewer than original 3 params
|
280
|
+
assert 1 not in keep_map # Index 1 (unused) should not be kept
|
@@ -0,0 +1,191 @@
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
"""
|
16
|
+
Test unused parameter handling with mplang integration.
|
17
|
+
This test verifies that functions with unused parameters work correctly
|
18
|
+
after the arg_keep_map implementation.
|
19
|
+
"""
|
20
|
+
|
21
|
+
import jax.numpy as jnp
|
22
|
+
|
23
|
+
import mplang
|
24
|
+
import mplang.simp as simp
|
25
|
+
|
26
|
+
|
27
|
+
def func_with_unused_params(a, unused_param, b, c):
|
28
|
+
"""Function with unused parameter in the middle."""
|
29
|
+
return a + b + c
|
30
|
+
|
31
|
+
|
32
|
+
def func_all_unused_returns_constant(a, unused1, unused2):
|
33
|
+
"""Function where all parameters are unused - returns constant."""
|
34
|
+
return 42
|
35
|
+
|
36
|
+
|
37
|
+
def func_first_last_unused(unused1, b, c, unused2):
|
38
|
+
"""Function with unused parameters at start and end."""
|
39
|
+
return b * c
|
40
|
+
|
41
|
+
|
42
|
+
class TestUnusedParameterHandling:
|
43
|
+
"""Test suite for JAX unused parameter elimination handling."""
|
44
|
+
|
45
|
+
@staticmethod
|
46
|
+
def _extract_scalar(output):
|
47
|
+
"""Extract scalar value from potentially wrapped output."""
|
48
|
+
if hasattr(output, "__iter__") and len(output) == 1:
|
49
|
+
output = output[0]
|
50
|
+
if hasattr(output, "item"): # JAX array
|
51
|
+
output = output.item()
|
52
|
+
return output
|
53
|
+
|
54
|
+
def test_basic_unused_param(self):
|
55
|
+
"""Test function with one unused parameter in middle position."""
|
56
|
+
sim = mplang.Simulator.simple(1)
|
57
|
+
|
58
|
+
# Create traced function
|
59
|
+
@mplang.function
|
60
|
+
def test_func():
|
61
|
+
# Test values - create inside traced context
|
62
|
+
a = simp.constant(1)
|
63
|
+
unused = simp.constant(999) # This should be eliminated by JAX
|
64
|
+
b = simp.constant(2)
|
65
|
+
c = simp.constant(3)
|
66
|
+
return simp.run(func_with_unused_params)(a, unused, b, c)
|
67
|
+
|
68
|
+
expected = 6 # 1 + 2 + 3
|
69
|
+
|
70
|
+
# Compile and check that compilation succeeds
|
71
|
+
compiled = mplang.compile(sim, test_func)
|
72
|
+
|
73
|
+
# The function should compile successfully
|
74
|
+
assert compiled is not None
|
75
|
+
|
76
|
+
# Execute and verify result
|
77
|
+
result = mplang.evaluate(sim, test_func)
|
78
|
+
output = mplang.fetch(sim, result)
|
79
|
+
|
80
|
+
output = self._extract_scalar(output)
|
81
|
+
|
82
|
+
assert output == expected, f"Expected {expected}, got {output}"
|
83
|
+
|
84
|
+
def test_multiple_unused_params(self):
|
85
|
+
"""Test function with multiple unused parameters."""
|
86
|
+
sim = mplang.Simulator.simple(1)
|
87
|
+
|
88
|
+
b_val = 5
|
89
|
+
c_val = 7
|
90
|
+
expected = b_val * c_val # 35
|
91
|
+
|
92
|
+
@mplang.function
|
93
|
+
def test_func():
|
94
|
+
unused1 = simp.constant(100)
|
95
|
+
b = simp.constant(b_val)
|
96
|
+
c = simp.constant(c_val)
|
97
|
+
unused2 = simp.constant(200)
|
98
|
+
return simp.run(func_first_last_unused)(unused1, b, c, unused2)
|
99
|
+
|
100
|
+
result = mplang.evaluate(sim, test_func)
|
101
|
+
output = mplang.fetch(sim, result)
|
102
|
+
output = self._extract_scalar(output)
|
103
|
+
|
104
|
+
assert output == expected, f"Expected {expected}, got {output}"
|
105
|
+
|
106
|
+
def test_all_params_unused(self):
|
107
|
+
"""Test function where all parameters are unused (returns constant)."""
|
108
|
+
sim = mplang.Simulator.simple(1)
|
109
|
+
expected = 42
|
110
|
+
|
111
|
+
@mplang.function
|
112
|
+
def test_func():
|
113
|
+
a = simp.constant(1)
|
114
|
+
unused1 = simp.constant(10)
|
115
|
+
unused2 = simp.constant(20)
|
116
|
+
return simp.run(func_all_unused_returns_constant)(a, unused1, unused2)
|
117
|
+
|
118
|
+
result = mplang.evaluate(sim, test_func)
|
119
|
+
output = mplang.fetch(sim, result)
|
120
|
+
output = self._extract_scalar(output)
|
121
|
+
|
122
|
+
assert output == expected, f"Expected {expected}, got {output}"
|
123
|
+
|
124
|
+
def test_no_unused_params(self):
|
125
|
+
"""Test function with no unused parameters (regression test)."""
|
126
|
+
sim = mplang.Simulator.simple(1)
|
127
|
+
|
128
|
+
def func_all_used(a, b, c):
|
129
|
+
return a + b + c
|
130
|
+
|
131
|
+
@mplang.function
|
132
|
+
def test_func():
|
133
|
+
a = simp.constant(10)
|
134
|
+
b = simp.constant(20)
|
135
|
+
c = simp.constant(30)
|
136
|
+
return simp.run(func_all_used)(a, b, c)
|
137
|
+
|
138
|
+
result = mplang.evaluate(sim, test_func)
|
139
|
+
output = mplang.fetch(sim, result)
|
140
|
+
output = self._extract_scalar(output)
|
141
|
+
|
142
|
+
assert output == 60, f"Expected 60, got {output}"
|
143
|
+
|
144
|
+
def test_arg_keep_map_in_pfunc(self):
|
145
|
+
"""Test that arg_keep_map is correctly stored in PFunction when needed."""
|
146
|
+
from mplang.frontend.jax_cc import jax2stablehlo
|
147
|
+
|
148
|
+
def func_with_unused(a, unused, b):
|
149
|
+
return a * b
|
150
|
+
|
151
|
+
# Create test inputs
|
152
|
+
a = jnp.array(2, dtype=jnp.int32)
|
153
|
+
unused = jnp.array(999, dtype=jnp.int32)
|
154
|
+
b = jnp.array(3, dtype=jnp.int32)
|
155
|
+
|
156
|
+
# Mock is_variable function
|
157
|
+
def is_variable(arg):
|
158
|
+
return True # Treat all as variables for this test
|
159
|
+
|
160
|
+
# Call jax2stablehlo directly
|
161
|
+
pfunc, _, _ = jax2stablehlo(is_variable, func_with_unused, a, unused, b)
|
162
|
+
|
163
|
+
# Check that arg_keep_map is present when parameters are eliminated
|
164
|
+
if "arg_keep_map" in pfunc.attrs:
|
165
|
+
keep_map = pfunc.attrs["arg_keep_map"]
|
166
|
+
assert isinstance(keep_map, list)
|
167
|
+
assert len(keep_map) < 3 # Should be fewer than original 3 params
|
168
|
+
assert 1 not in keep_map # Index 1 (unused) should not be in keep_map
|
169
|
+
else:
|
170
|
+
# If no elimination happened (possible with different JAX versions/optimizations)
|
171
|
+
pass
|
172
|
+
|
173
|
+
def test_different_dtypes_unused(self):
|
174
|
+
"""Test unused parameter elimination with different data types."""
|
175
|
+
sim = mplang.Simulator.simple(1)
|
176
|
+
|
177
|
+
def func_mixed_types(int_used, float_unused, int_used2):
|
178
|
+
return int_used + int_used2 # float_unused is not used
|
179
|
+
|
180
|
+
@mplang.function
|
181
|
+
def test_func():
|
182
|
+
a = simp.constant(5)
|
183
|
+
unused_float = simp.constant(3.14) # Different dtype, unused
|
184
|
+
c = simp.constant(7)
|
185
|
+
return simp.run(func_mixed_types)(a, unused_float, c)
|
186
|
+
|
187
|
+
result = mplang.evaluate(sim, test_func)
|
188
|
+
output = mplang.fetch(sim, result)
|
189
|
+
output = self._extract_scalar(output)
|
190
|
+
|
191
|
+
assert output == 12, f"Expected 12, got {output}"
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/protos/v1alpha1/mpir_pb2_grpc.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/backend/test_kernel_binding.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/frontend/test_spu_defensive.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/integration/test_crypto_roundtrip.py
RENAMED
File without changes
|
File without changes
|
{mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/integration/test_symbols_roundtrip.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|