mplang-nightly 0.1.dev147__tar.gz → 0.1.dev149__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/PKG-INFO +1 -1
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/backend/base.py +21 -47
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/backend/context.py +67 -26
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/backend/stablehlo.py +8 -1
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/frontend/jax_cc.py +39 -7
- mplang_nightly-0.1.dev149/tests/backend/test_kernel_binding.py +102 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/frontend/test_jax_cc.py +26 -10
- mplang_nightly-0.1.dev149/tests/integration/test_unused_param_integration.py +191 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/.gitignore +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/LICENSE +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/README.md +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/examples/conf/3pc.yaml +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/examples/stax_nn/README.md +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/examples/stax_nn/models.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/examples/stax_nn/stax_nn.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/examples/xgboost/hist_jax.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/examples/xgboost/hist_jax_test.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/examples/xgboost/naive_np.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/examples/xgboost/readme.md +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/examples/xgboost/sgb.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/examples/xgboost/sgb_test.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/hatch_build.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/__init__.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/analysis/__init__.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/analysis/diagram.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/api.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/backend/__init__.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/backend/builtin.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/backend/crypto.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/backend/phe.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/backend/spu.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/backend/sql_duckdb.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/backend/tee.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/core/__init__.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/core/cluster.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/core/comm.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/core/context_mgr.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/core/dtype.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/core/expr/__init__.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/core/expr/ast.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/core/expr/evaluator.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/core/expr/printer.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/core/expr/transformer.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/core/expr/utils.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/core/expr/visitor.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/core/expr/walk.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/core/interp.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/core/mask.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/core/mpir.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/core/mpobject.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/core/mptype.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/core/pfunc.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/core/primitive.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/core/table.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/core/tensor.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/core/tracer.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/device.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/frontend/__init__.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/frontend/base.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/frontend/builtin.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/frontend/crypto.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/frontend/ibis_cc.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/frontend/phe.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/frontend/spu.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/frontend/sql.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/frontend/tee.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/protos/v1alpha1/mpir_pb2.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/protos/v1alpha1/mpir_pb2.pyi +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/runtime/__init__.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/runtime/cli.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/runtime/client.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/runtime/communicator.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/runtime/data_providers.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/runtime/driver.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/runtime/exceptions.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/runtime/http_api.md +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/runtime/link_comm.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/runtime/resource.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/runtime/server.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/runtime/simulation.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/simp/__init__.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/simp/mpi.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/simp/random.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/simp/smpc.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/utils/__init__.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/utils/crypto.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/utils/func_utils.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/utils/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/mplang/utils/table_utils.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/pyproject.toml +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/__init__.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/analysis/test_diagram.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/backend/test_builtin.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/backend/test_debug_print.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/backend/test_phe.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/backend/test_spu.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/backend/test_sql_duckdb.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/backend/test_stablehlo.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/core/__init__.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/core/expr/__init__.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/core/expr/conftest.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/core/expr/test_ast.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/core/expr/test_printer.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/core/expr/test_utils.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/core/expr/test_walk.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/core/test_cluster.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/core/test_dtype.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/core/test_mask.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/core/test_mpir.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/core/test_mptype.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/core/test_primitive.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/core/test_table.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/core/test_tensor.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/core/test_tracer.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/device/__init__.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/device/test_device_basic.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/frontend/__init__.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/frontend/dummy.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/frontend/test_builtin_pack.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/frontend/test_crypto_tee.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/frontend/test_feop_base.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/frontend/test_ibis.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/frontend/test_ibis_cc.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/frontend/test_phe.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/frontend/test_spu.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/frontend/test_spu_defensive.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/frontend/test_sql.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/frontend/test_table_tensor_conversion.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/integration/README.md +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/integration/test_crypto_roundtrip.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/integration/test_http_e2e.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/integration/test_symbols_roundtrip.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/integration/test_tutorials.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/runtime/__init__.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/runtime/test_cli.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/runtime/test_communicator.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/runtime/test_driver.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/runtime/test_server.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/runtime/test_simulation.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/simp/test_mpi.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/simp/test_random.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/simp/test_simp.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/simp/test_smpc.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/simp/test_sugar.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/utils/__init__.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/utils/server_fixtures.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/utils/test_func_utils.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/utils/test_spu_utils.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tests/utils/test_table_utils.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tutorials/0_basic.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tutorials/10_analysis.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tutorials/1_condition.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tutorials/2_whileloop.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tutorials/3_device.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tutorials/4_simulation.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tutorials/5_ir_dump.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tutorials/6_advanced.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tutorials/7_stdio.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tutorials/8_phe.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tutorials/9_tee.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tutorials/__init__.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tutorials/pitfalls/late_binding.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tutorials/pitfalls/rand.py +0 -0
- {mplang_nightly-0.1.dev147 → mplang_nightly-0.1.dev149}/tutorials/run.sh +0 -0
@@ -12,20 +12,21 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
"""Backend kernel registry
|
15
|
+
"""Backend kernel registry: mapping kernel_id -> implementation.
|
16
16
|
|
17
|
-
This
|
17
|
+
This module provides a lightweight registry for backend kernel implementations.
|
18
|
+
It does not track or decide which kernel handles a given semantic operation;
|
19
|
+
that policy (op -> kernel_id) is managed externally by each ``RuntimeContext``.
|
18
20
|
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
21
|
+
Exposed primitives:
|
22
|
+
* ``@kernel_def(kernel_id)``: decorator to register a kernel implementation.
|
23
|
+
* ``get_kernel_spec(kernel_id)``: look up a previously registered kernel.
|
24
|
+
* ``cur_kctx()`` / ``KernelContext``: execution context available only
|
25
|
+
inside a kernel body (rank, world_size, per-backend state pockets, and a
|
26
|
+
runtime-wide cache shared by kernels of the same runtime instance).
|
25
27
|
|
26
|
-
|
27
|
-
|
28
|
-
centrally (lazy) the first time a runtime executes a kernel.
|
28
|
+
No global op binding table exists here; callers resolve an op to a kernel_id
|
29
|
+
before invoking the kernel function.
|
29
30
|
"""
|
30
31
|
|
31
32
|
from __future__ import annotations
|
@@ -38,12 +39,10 @@ from typing import Any
|
|
38
39
|
__all__ = [
|
39
40
|
"KernelContext",
|
40
41
|
"KernelSpec",
|
41
|
-
"bind_op",
|
42
42
|
"cur_kctx",
|
43
|
-
"
|
43
|
+
"get_kernel_spec",
|
44
|
+
"kernel_exists",
|
44
45
|
"list_kernels",
|
45
|
-
"list_ops",
|
46
|
-
"unbind_op",
|
47
46
|
]
|
48
47
|
|
49
48
|
|
@@ -116,9 +115,6 @@ class KernelSpec:
|
|
116
115
|
# All registered kernel implementations: kernel_id -> spec
|
117
116
|
_KERNELS: dict[str, KernelSpec] = {}
|
118
117
|
|
119
|
-
# Active op bindings: op_type -> kernel_id
|
120
|
-
_BINDINGS: dict[str, str] = {}
|
121
|
-
|
122
118
|
|
123
119
|
def kernel_def(kernel_id: str, /, **meta: Any) -> Callable[[KernelFn], KernelFn]:
|
124
120
|
"""Decorator to register a concrete kernel implementation.
|
@@ -136,34 +132,11 @@ def kernel_def(kernel_id: str, /, **meta: Any) -> Callable[[KernelFn], KernelFn]
|
|
136
132
|
return _decorator
|
137
133
|
|
138
134
|
|
139
|
-
def
|
140
|
-
"""
|
141
|
-
|
142
|
-
|
143
|
-
op_type: Semantic operation name.
|
144
|
-
kernel_id: Previously registered kernel identifier.
|
145
|
-
force: If False and op_type already bound, keep existing binding.
|
146
|
-
If True (default), overwrite.
|
147
|
-
"""
|
148
|
-
if kernel_id not in _KERNELS:
|
135
|
+
def get_kernel_spec(kernel_id: str) -> KernelSpec:
|
136
|
+
"""Return KernelSpec for a registered kernel_id (no op binding lookup)."""
|
137
|
+
spec = _KERNELS.get(kernel_id)
|
138
|
+
if spec is None:
|
149
139
|
raise KeyError(f"kernel_id {kernel_id} not registered")
|
150
|
-
if not force and op_type in _BINDINGS:
|
151
|
-
return
|
152
|
-
_BINDINGS[op_type] = kernel_id
|
153
|
-
|
154
|
-
|
155
|
-
def unbind_op(op_type: str) -> None:
|
156
|
-
_BINDINGS.pop(op_type, None)
|
157
|
-
|
158
|
-
|
159
|
-
def get_kernel_for_op(op_type: str) -> KernelSpec:
|
160
|
-
kid = _BINDINGS.get(op_type)
|
161
|
-
if kid is None:
|
162
|
-
# Tests expect NotImplementedError for unsupported operations
|
163
|
-
raise NotImplementedError(f"no backend kernel registered for op {op_type}")
|
164
|
-
spec = _KERNELS.get(kid)
|
165
|
-
if spec is None: # inconsistent state
|
166
|
-
raise RuntimeError(f"active kernel_id {kid} missing spec")
|
167
140
|
return spec
|
168
141
|
|
169
142
|
|
@@ -171,5 +144,6 @@ def list_kernels() -> list[str]:
|
|
171
144
|
return sorted(_KERNELS.keys())
|
172
145
|
|
173
146
|
|
174
|
-
def
|
175
|
-
|
147
|
+
def kernel_exists(kernel_id: str) -> bool:
|
148
|
+
"""Return True if a kernel_id has been registered."""
|
149
|
+
return kernel_id in _KERNELS
|
@@ -15,11 +15,10 @@
|
|
15
15
|
from __future__ import annotations
|
16
16
|
|
17
17
|
from collections.abc import Mapping
|
18
|
-
from dataclasses import dataclass, field
|
19
18
|
from typing import Any
|
20
19
|
|
21
20
|
from mplang.backend import base
|
22
|
-
from mplang.backend.base import KernelContext,
|
21
|
+
from mplang.backend.base import KernelContext, get_kernel_spec, kernel_exists
|
23
22
|
from mplang.core.dtype import UINT8, DType
|
24
23
|
from mplang.core.pfunc import PFunction
|
25
24
|
from mplang.core.table import TableLike, TableType
|
@@ -100,30 +99,57 @@ _DEFAULT_BINDINGS: dict[str, str] = {
|
|
100
99
|
# --- RuntimeContext ---
|
101
100
|
|
102
101
|
|
103
|
-
@dataclass
|
104
102
|
class RuntimeContext:
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
103
|
+
"""Per-runtime execution context with isolated op->kernel bindings.
|
104
|
+
|
105
|
+
Parameters
|
106
|
+
----------
|
107
|
+
rank : int
|
108
|
+
Local rank of this participant.
|
109
|
+
world_size : int
|
110
|
+
Total number of participants.
|
111
|
+
initial_bindings : Mapping[str, str] | None, optional
|
112
|
+
Optional partial overrides applied on top of the default binding table
|
113
|
+
during construction (override semantics, not replace). After
|
114
|
+
initialization, all (re)binding must go through ``bind_op`` /
|
115
|
+
``rebind_op``.
|
116
|
+
state / cache / stats : dict, optional
|
117
|
+
Mutable pockets reused across kernel invocations. If omitted, new
|
118
|
+
dictionaries are created.
|
119
|
+
"""
|
120
|
+
|
121
|
+
__slots__ = ("_ibindings", "cache", "rank", "state", "stats", "world_size")
|
122
|
+
|
123
|
+
def __init__(
|
124
|
+
self,
|
125
|
+
rank: int,
|
126
|
+
world_size: int,
|
127
|
+
initial_bindings: Mapping[str, str] | None = None,
|
128
|
+
*,
|
129
|
+
state: dict[str, dict[str, Any]] | None = None,
|
130
|
+
cache: dict[str, Any] | None = None,
|
131
|
+
stats: dict[str, Any] | None = None,
|
132
|
+
) -> None:
|
113
133
|
_ensure_impl_imported()
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
134
|
+
self.rank = rank
|
135
|
+
self.world_size = world_size
|
136
|
+
# Merge defaults with user overrides (override semantics)
|
137
|
+
self._ibindings: dict[str, str] = {
|
138
|
+
**_DEFAULT_BINDINGS,
|
139
|
+
**(initial_bindings or {}),
|
140
|
+
}
|
141
|
+
self.state = state if state is not None else {}
|
142
|
+
self.cache = cache if cache is not None else {}
|
143
|
+
self.stats = stats if stats is not None else {}
|
121
144
|
self.stats.setdefault("op_calls", {})
|
122
145
|
|
123
146
|
def run_kernel(self, pfunc: PFunction, arg_list: list[Any]) -> list[Any]:
|
124
147
|
fn_type = pfunc.fn_type
|
125
|
-
|
126
|
-
|
148
|
+
kid = self._ibindings.get(fn_type)
|
149
|
+
if kid is None:
|
150
|
+
raise NotImplementedError(f"no backend kernel registered for op {fn_type}")
|
151
|
+
spec = get_kernel_spec(kid)
|
152
|
+
fn = spec.fn # kernel implementation
|
127
153
|
if len(arg_list) != len(pfunc.ins_info):
|
128
154
|
raise ValueError(
|
129
155
|
f"kernel {fn_type} arg count mismatch: got {len(arg_list)}, expect {len(pfunc.ins_info)}"
|
@@ -186,17 +212,32 @@ class RuntimeContext:
|
|
186
212
|
|
187
213
|
# ---- explicit (re)binding API ----
|
188
214
|
def bind_op(self, op_type: str, kernel_id: str, *, force: bool = False) -> None:
|
189
|
-
"""Bind an operation to a kernel
|
215
|
+
"""Bind an operation to a kernel for THIS context only.
|
190
216
|
|
191
|
-
force=False (default)
|
192
|
-
silent overrides. Use ``rebind_op`` or ``force=True`` to intentionally
|
193
|
-
change a binding.
|
217
|
+
force=False (default) keeps existing binding (no silent override).
|
194
218
|
"""
|
195
|
-
|
219
|
+
if not kernel_exists(kernel_id):
|
220
|
+
raise KeyError(f"kernel_id {kernel_id} not registered")
|
221
|
+
if not force and op_type in self._ibindings:
|
222
|
+
return
|
223
|
+
self._ibindings[op_type] = kernel_id
|
196
224
|
|
197
225
|
def rebind_op(self, op_type: str, kernel_id: str) -> None:
|
198
226
|
"""Force rebind an operation to a different kernel (shorthand)."""
|
199
|
-
|
227
|
+
self.bind_op(op_type, kernel_id, force=True)
|
228
|
+
|
229
|
+
# Introspection helpers
|
230
|
+
def list_bound_ops(self) -> list[str]: # pragma: no cover - convenience
|
231
|
+
return sorted(self._ibindings.keys())
|
232
|
+
|
233
|
+
def get_binding(self, op_type: str) -> str | None: # pragma: no cover
|
234
|
+
return self._ibindings.get(op_type)
|
235
|
+
|
236
|
+
def __repr__(self) -> str: # pragma: no cover - debug aid
|
237
|
+
return (
|
238
|
+
f"RuntimeContext(rank={self.rank}, world_size={self.world_size}, "
|
239
|
+
f"bound_ops={len(self._ibindings)})"
|
240
|
+
)
|
200
241
|
|
201
242
|
|
202
243
|
def _validate_table_arg(
|
@@ -51,8 +51,15 @@ def _stablehlo_exec(pfunc: PFunction, *args: Any) -> Any:
|
|
51
51
|
raise RuntimeError(f"StableHLO compile failed: {e}") from e
|
52
52
|
cache[mlir_text] = compiled
|
53
53
|
|
54
|
+
# Handle JAX's unused parameter elimination via arg_keep_map
|
55
|
+
runtime_args = args
|
56
|
+
if "arg_keep_map" in pfunc.attrs:
|
57
|
+
keep_indices = pfunc.attrs["arg_keep_map"]
|
58
|
+
# Filter out arguments that were eliminated by JAX during compilation
|
59
|
+
runtime_args = tuple(args[i] for i in keep_indices)
|
60
|
+
|
54
61
|
jax_args = []
|
55
|
-
for arg in
|
62
|
+
for arg in runtime_args:
|
56
63
|
if hasattr(arg, "numpy"):
|
57
64
|
jax_arg = jnp.array(arg.numpy()) # type: ignore
|
58
65
|
else:
|
@@ -106,14 +106,46 @@ def jax2stablehlo(
|
|
106
106
|
out_info_flat, out_tree = tree_flatten(lowered.out_info)
|
107
107
|
out_info_flat = [TensorType.from_obj(info) for info in out_info_flat]
|
108
108
|
|
109
|
+
# Extract argument keep mapping to handle JAX's unused parameter elimination
|
110
|
+
# JAX can eliminate unused parameters during compilation, but the runtime still
|
111
|
+
# receives all original arguments. We need the mapping to filter them correctly.
|
112
|
+
arg_keep_map = None
|
113
|
+
original_arg_count = len(in_vars)
|
114
|
+
|
115
|
+
try:
|
116
|
+
# Access JAX internal kept_var_idx - the authoritative source
|
117
|
+
# This tells us exactly which original parameters survived compilation
|
118
|
+
compile_args = lowered._lowering.compile_args
|
119
|
+
kept_var_idx = compile_args["kept_var_idx"]
|
120
|
+
|
121
|
+
kept_indices = sorted(kept_var_idx)
|
122
|
+
if len(kept_indices) < original_arg_count:
|
123
|
+
arg_keep_map = kept_indices
|
124
|
+
|
125
|
+
except (AttributeError, KeyError, TypeError) as e:
|
126
|
+
# JAX internal API is not available or changed
|
127
|
+
# This is a hard error - we cannot reliably handle unused parameters
|
128
|
+
# without knowing exactly which ones were kept
|
129
|
+
raise RuntimeError(
|
130
|
+
f"Cannot access JAX's kept_var_idx to handle unused parameter elimination. "
|
131
|
+
f"This function may have unused parameters that JAX optimized away, "
|
132
|
+
f"but we cannot determine which ones without the internal API. "
|
133
|
+
f"Original error: {e}"
|
134
|
+
) from e
|
135
|
+
|
109
136
|
# This format tells JaxRT how to handle the compiled result
|
110
|
-
|
111
|
-
fn_type
|
112
|
-
ins_info
|
113
|
-
outs_info
|
114
|
-
fn_name
|
115
|
-
fn_text
|
116
|
-
|
137
|
+
pfn_kwargs: dict[str, Any] = {
|
138
|
+
"fn_type": "mlir.stablehlo", # Key: specify StableHLO MLIR format
|
139
|
+
"ins_info": tuple(TensorType.from_obj(x) for x in in_vars),
|
140
|
+
"outs_info": tuple(out_info_flat),
|
141
|
+
"fn_name": get_fn_name(flat_fn),
|
142
|
+
"fn_text": mlir_text, # MLIR text, serializable for transmission
|
143
|
+
}
|
144
|
+
|
145
|
+
if arg_keep_map is not None:
|
146
|
+
pfn_kwargs["arg_keep_map"] = arg_keep_map
|
147
|
+
|
148
|
+
pfn = PFunction(**pfn_kwargs)
|
117
149
|
return pfn, in_vars, out_tree
|
118
150
|
|
119
151
|
|
@@ -0,0 +1,102 @@
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Tests for per-RuntimeContext binding isolation
|
16
|
+
from __future__ import annotations
|
17
|
+
|
18
|
+
import pytest
|
19
|
+
|
20
|
+
from mplang.backend import base
|
21
|
+
from mplang.backend.context import RuntimeContext
|
22
|
+
from mplang.core.dtype import INT64 # switched from INT32 to INT64 to match Python int
|
23
|
+
from mplang.core.pfunc import PFunction
|
24
|
+
from mplang.core.tensor import TensorType
|
25
|
+
|
26
|
+
# We'll register two fake kernels for an op to test rebinding.
|
27
|
+
# If they already exist due to other tests, we guard with try/except.
|
28
|
+
|
29
|
+
|
30
|
+
@base.kernel_def("test.echo.v1")
|
31
|
+
def _echo_v1(
|
32
|
+
pfunc: PFunction, x: int
|
33
|
+
) -> tuple[int,]: # pragma: no cover - executed in test
|
34
|
+
return (x + 1,)
|
35
|
+
|
36
|
+
|
37
|
+
@base.kernel_def("test.echo.v2")
|
38
|
+
def _echo_v2(
|
39
|
+
pfunc: PFunction, x: int
|
40
|
+
) -> tuple[int,]: # pragma: no cover - executed in test
|
41
|
+
return (x + 2,)
|
42
|
+
|
43
|
+
|
44
|
+
def make_pfunc(op_type: str) -> PFunction:
|
45
|
+
# Minimal PFunction stub compatible with backend.run_kernel expectations.
|
46
|
+
# shape info matters only for validation; use scalar INT64 (Python int maps to int64).
|
47
|
+
return PFunction(
|
48
|
+
fn_type=op_type,
|
49
|
+
fn_text="",
|
50
|
+
ins_info=[TensorType(shape=(), dtype=INT64)],
|
51
|
+
outs_info=[TensorType(shape=(), dtype=INT64)],
|
52
|
+
)
|
53
|
+
|
54
|
+
|
55
|
+
def test_isolated_rebind():
|
56
|
+
# ctx1 binds op -> v1, ctx2 binds op -> v2; they should not interfere.
|
57
|
+
op = "test.echo"
|
58
|
+
ctx1 = RuntimeContext(rank=0, world_size=1, initial_bindings={op: "test.echo.v1"})
|
59
|
+
ctx2 = RuntimeContext(rank=0, world_size=1, initial_bindings={op: "test.echo.v2"})
|
60
|
+
|
61
|
+
pfunc = make_pfunc(op)
|
62
|
+
out1 = ctx1.run_kernel(pfunc, [10])[0]
|
63
|
+
out2 = ctx2.run_kernel(pfunc, [10])[0]
|
64
|
+
|
65
|
+
assert out1 == 11
|
66
|
+
assert out2 == 12
|
67
|
+
|
68
|
+
|
69
|
+
def test_rebind_only_affects_context():
|
70
|
+
op = "test.echo"
|
71
|
+
ctx = RuntimeContext(rank=0, world_size=1, initial_bindings={op: "test.echo.v1"})
|
72
|
+
pfunc = make_pfunc(op)
|
73
|
+
assert ctx.run_kernel(pfunc, [5])[0] == 6
|
74
|
+
ctx.rebind_op(op, "test.echo.v2")
|
75
|
+
assert ctx.run_kernel(pfunc, [5])[0] == 7
|
76
|
+
|
77
|
+
|
78
|
+
def test_force_flag():
|
79
|
+
op = "test.echo"
|
80
|
+
ctx = RuntimeContext(rank=0, world_size=1, initial_bindings={op: "test.echo.v1"})
|
81
|
+
# Attempt non-force bind (should keep v1)
|
82
|
+
ctx.bind_op(op, "test.echo.v2", force=False)
|
83
|
+
pfunc = make_pfunc(op)
|
84
|
+
assert ctx.run_kernel(pfunc, [1])[0] == 2 # still v1 (+1)
|
85
|
+
# Now force
|
86
|
+
ctx.bind_op(op, "test.echo.v2", force=True)
|
87
|
+
assert ctx.run_kernel(pfunc, [1])[0] == 3
|
88
|
+
|
89
|
+
|
90
|
+
def test_unknown_kernel_id():
|
91
|
+
ctx = RuntimeContext(rank=0, world_size=1)
|
92
|
+
with pytest.raises(KeyError):
|
93
|
+
ctx.bind_op("some.op", "non.existent.kernel")
|
94
|
+
|
95
|
+
|
96
|
+
def test_missing_binding():
|
97
|
+
# Pick an op name unlikely in defaults
|
98
|
+
op = "unit.test.unbound"
|
99
|
+
ctx = RuntimeContext(rank=0, world_size=1)
|
100
|
+
pfunc = make_pfunc(op)
|
101
|
+
with pytest.raises(NotImplementedError):
|
102
|
+
ctx.run_kernel(pfunc, [0])
|
@@ -244,21 +244,37 @@ class TestJax2StableHLO:
|
|
244
244
|
assert cfunc.fn_text is not None
|
245
245
|
|
246
246
|
def test_multiple_outputs(self):
|
247
|
-
"""Test
|
247
|
+
"""Test functions with multiple outputs."""
|
248
248
|
|
249
249
|
def multi_output(x, y):
|
250
250
|
return x + y, x - y, x * y
|
251
251
|
|
252
|
-
|
253
|
-
|
252
|
+
pfunc, out_tree = self._compile_with_transformer(
|
253
|
+
multi_output, jnp.array([1, 2]), jnp.array([3, 4])
|
254
|
+
)
|
254
255
|
|
255
|
-
|
256
|
+
assert len(pfunc.outs_info) == 3
|
257
|
+
assert out_tree is not None
|
256
258
|
|
257
|
-
|
258
|
-
|
259
|
+
def test_unused_parameter_elimination(self):
|
260
|
+
"""Test that unused parameters are handled correctly via arg_keep_map."""
|
259
261
|
|
260
|
-
|
261
|
-
|
262
|
-
assert out_info.shape == x.shape
|
262
|
+
def func_with_unused(x, unused, z):
|
263
|
+
return x + z # unused parameter eliminated by JAX
|
263
264
|
|
264
|
-
|
265
|
+
x = jnp.array(1, dtype=jnp.int32)
|
266
|
+
unused = jnp.array(999, dtype=jnp.int32)
|
267
|
+
z = jnp.array(3, dtype=jnp.int32)
|
268
|
+
|
269
|
+
pfunc, _ = self._compile_with_transformer(func_with_unused, x, unused, z)
|
270
|
+
|
271
|
+
# Check that compilation succeeded
|
272
|
+
assert pfunc.fn_type == "mlir.stablehlo"
|
273
|
+
assert len(pfunc.ins_info) == 3 # Original input count
|
274
|
+
|
275
|
+
# If JAX eliminated unused parameters, arg_keep_map should be present
|
276
|
+
if "arg_keep_map" in pfunc.attrs:
|
277
|
+
keep_map = pfunc.attrs["arg_keep_map"]
|
278
|
+
assert isinstance(keep_map, list)
|
279
|
+
assert len(keep_map) < 3 # Should be fewer than original 3 params
|
280
|
+
assert 1 not in keep_map # Index 1 (unused) should not be kept
|
@@ -0,0 +1,191 @@
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
"""
|
16
|
+
Test unused parameter handling with mplang integration.
|
17
|
+
This test verifies that functions with unused parameters work correctly
|
18
|
+
after the arg_keep_map implementation.
|
19
|
+
"""
|
20
|
+
|
21
|
+
import jax.numpy as jnp
|
22
|
+
|
23
|
+
import mplang
|
24
|
+
import mplang.simp as simp
|
25
|
+
|
26
|
+
|
27
|
+
def func_with_unused_params(a, unused_param, b, c):
|
28
|
+
"""Function with unused parameter in the middle."""
|
29
|
+
return a + b + c
|
30
|
+
|
31
|
+
|
32
|
+
def func_all_unused_returns_constant(a, unused1, unused2):
|
33
|
+
"""Function where all parameters are unused - returns constant."""
|
34
|
+
return 42
|
35
|
+
|
36
|
+
|
37
|
+
def func_first_last_unused(unused1, b, c, unused2):
|
38
|
+
"""Function with unused parameters at start and end."""
|
39
|
+
return b * c
|
40
|
+
|
41
|
+
|
42
|
+
class TestUnusedParameterHandling:
|
43
|
+
"""Test suite for JAX unused parameter elimination handling."""
|
44
|
+
|
45
|
+
@staticmethod
|
46
|
+
def _extract_scalar(output):
|
47
|
+
"""Extract scalar value from potentially wrapped output."""
|
48
|
+
if hasattr(output, "__iter__") and len(output) == 1:
|
49
|
+
output = output[0]
|
50
|
+
if hasattr(output, "item"): # JAX array
|
51
|
+
output = output.item()
|
52
|
+
return output
|
53
|
+
|
54
|
+
def test_basic_unused_param(self):
|
55
|
+
"""Test function with one unused parameter in middle position."""
|
56
|
+
sim = mplang.Simulator.simple(1)
|
57
|
+
|
58
|
+
# Create traced function
|
59
|
+
@mplang.function
|
60
|
+
def test_func():
|
61
|
+
# Test values - create inside traced context
|
62
|
+
a = simp.constant(1)
|
63
|
+
unused = simp.constant(999) # This should be eliminated by JAX
|
64
|
+
b = simp.constant(2)
|
65
|
+
c = simp.constant(3)
|
66
|
+
return simp.run(func_with_unused_params)(a, unused, b, c)
|
67
|
+
|
68
|
+
expected = 6 # 1 + 2 + 3
|
69
|
+
|
70
|
+
# Compile and check that compilation succeeds
|
71
|
+
compiled = mplang.compile(sim, test_func)
|
72
|
+
|
73
|
+
# The function should compile successfully
|
74
|
+
assert compiled is not None
|
75
|
+
|
76
|
+
# Execute and verify result
|
77
|
+
result = mplang.evaluate(sim, test_func)
|
78
|
+
output = mplang.fetch(sim, result)
|
79
|
+
|
80
|
+
output = self._extract_scalar(output)
|
81
|
+
|
82
|
+
assert output == expected, f"Expected {expected}, got {output}"
|
83
|
+
|
84
|
+
def test_multiple_unused_params(self):
|
85
|
+
"""Test function with multiple unused parameters."""
|
86
|
+
sim = mplang.Simulator.simple(1)
|
87
|
+
|
88
|
+
b_val = 5
|
89
|
+
c_val = 7
|
90
|
+
expected = b_val * c_val # 35
|
91
|
+
|
92
|
+
@mplang.function
|
93
|
+
def test_func():
|
94
|
+
unused1 = simp.constant(100)
|
95
|
+
b = simp.constant(b_val)
|
96
|
+
c = simp.constant(c_val)
|
97
|
+
unused2 = simp.constant(200)
|
98
|
+
return simp.run(func_first_last_unused)(unused1, b, c, unused2)
|
99
|
+
|
100
|
+
result = mplang.evaluate(sim, test_func)
|
101
|
+
output = mplang.fetch(sim, result)
|
102
|
+
output = self._extract_scalar(output)
|
103
|
+
|
104
|
+
assert output == expected, f"Expected {expected}, got {output}"
|
105
|
+
|
106
|
+
def test_all_params_unused(self):
|
107
|
+
"""Test function where all parameters are unused (returns constant)."""
|
108
|
+
sim = mplang.Simulator.simple(1)
|
109
|
+
expected = 42
|
110
|
+
|
111
|
+
@mplang.function
|
112
|
+
def test_func():
|
113
|
+
a = simp.constant(1)
|
114
|
+
unused1 = simp.constant(10)
|
115
|
+
unused2 = simp.constant(20)
|
116
|
+
return simp.run(func_all_unused_returns_constant)(a, unused1, unused2)
|
117
|
+
|
118
|
+
result = mplang.evaluate(sim, test_func)
|
119
|
+
output = mplang.fetch(sim, result)
|
120
|
+
output = self._extract_scalar(output)
|
121
|
+
|
122
|
+
assert output == expected, f"Expected {expected}, got {output}"
|
123
|
+
|
124
|
+
def test_no_unused_params(self):
|
125
|
+
"""Test function with no unused parameters (regression test)."""
|
126
|
+
sim = mplang.Simulator.simple(1)
|
127
|
+
|
128
|
+
def func_all_used(a, b, c):
|
129
|
+
return a + b + c
|
130
|
+
|
131
|
+
@mplang.function
|
132
|
+
def test_func():
|
133
|
+
a = simp.constant(10)
|
134
|
+
b = simp.constant(20)
|
135
|
+
c = simp.constant(30)
|
136
|
+
return simp.run(func_all_used)(a, b, c)
|
137
|
+
|
138
|
+
result = mplang.evaluate(sim, test_func)
|
139
|
+
output = mplang.fetch(sim, result)
|
140
|
+
output = self._extract_scalar(output)
|
141
|
+
|
142
|
+
assert output == 60, f"Expected 60, got {output}"
|
143
|
+
|
144
|
+
def test_arg_keep_map_in_pfunc(self):
|
145
|
+
"""Test that arg_keep_map is correctly stored in PFunction when needed."""
|
146
|
+
from mplang.frontend.jax_cc import jax2stablehlo
|
147
|
+
|
148
|
+
def func_with_unused(a, unused, b):
|
149
|
+
return a * b
|
150
|
+
|
151
|
+
# Create test inputs
|
152
|
+
a = jnp.array(2, dtype=jnp.int32)
|
153
|
+
unused = jnp.array(999, dtype=jnp.int32)
|
154
|
+
b = jnp.array(3, dtype=jnp.int32)
|
155
|
+
|
156
|
+
# Mock is_variable function
|
157
|
+
def is_variable(arg):
|
158
|
+
return True # Treat all as variables for this test
|
159
|
+
|
160
|
+
# Call jax2stablehlo directly
|
161
|
+
pfunc, _, _ = jax2stablehlo(is_variable, func_with_unused, a, unused, b)
|
162
|
+
|
163
|
+
# Check that arg_keep_map is present when parameters are eliminated
|
164
|
+
if "arg_keep_map" in pfunc.attrs:
|
165
|
+
keep_map = pfunc.attrs["arg_keep_map"]
|
166
|
+
assert isinstance(keep_map, list)
|
167
|
+
assert len(keep_map) < 3 # Should be fewer than original 3 params
|
168
|
+
assert 1 not in keep_map # Index 1 (unused) should not be in keep_map
|
169
|
+
else:
|
170
|
+
# If no elimination happened (possible with different JAX versions/optimizations)
|
171
|
+
pass
|
172
|
+
|
173
|
+
def test_different_dtypes_unused(self):
|
174
|
+
"""Test unused parameter elimination with different data types."""
|
175
|
+
sim = mplang.Simulator.simple(1)
|
176
|
+
|
177
|
+
def func_mixed_types(int_used, float_unused, int_used2):
|
178
|
+
return int_used + int_used2 # float_unused is not used
|
179
|
+
|
180
|
+
@mplang.function
|
181
|
+
def test_func():
|
182
|
+
a = simp.constant(5)
|
183
|
+
unused_float = simp.constant(3.14) # Different dtype, unused
|
184
|
+
c = simp.constant(7)
|
185
|
+
return simp.run(func_mixed_types)(a, unused_float, c)
|
186
|
+
|
187
|
+
result = mplang.evaluate(sim, test_func)
|
188
|
+
output = mplang.fetch(sim, result)
|
189
|
+
output = self._extract_scalar(output)
|
190
|
+
|
191
|
+
assert output == 12, f"Expected 12, got {output}"
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|