mplang-nightly 0.1.dev148__tar.gz → 0.1.dev149__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165)
  1. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/PKG-INFO +1 -1
  2. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/backend/stablehlo.py +8 -1
  3. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/frontend/jax_cc.py +39 -7
  4. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/frontend/test_jax_cc.py +26 -10
  5. mplang_nightly-0.1.dev149/tests/integration/test_unused_param_integration.py +191 -0
  6. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/.gitignore +0 -0
  7. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/LICENSE +0 -0
  8. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/README.md +0 -0
  9. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/examples/conf/3pc.yaml +0 -0
  10. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/examples/stax_nn/README.md +0 -0
  11. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/examples/stax_nn/models.py +0 -0
  12. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/examples/stax_nn/stax_nn.py +0 -0
  13. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/examples/xgboost/hist_jax.py +0 -0
  14. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/examples/xgboost/hist_jax_test.py +0 -0
  15. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/examples/xgboost/naive_np.py +0 -0
  16. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/examples/xgboost/readme.md +0 -0
  17. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/examples/xgboost/sgb.py +0 -0
  18. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/examples/xgboost/sgb_test.py +0 -0
  19. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/hatch_build.py +0 -0
  20. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/__init__.py +0 -0
  21. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/analysis/__init__.py +0 -0
  22. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/analysis/diagram.py +0 -0
  23. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/api.py +0 -0
  24. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/backend/__init__.py +0 -0
  25. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/backend/base.py +0 -0
  26. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/backend/builtin.py +0 -0
  27. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/backend/context.py +0 -0
  28. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/backend/crypto.py +0 -0
  29. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/backend/phe.py +0 -0
  30. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/backend/spu.py +0 -0
  31. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/backend/sql_duckdb.py +0 -0
  32. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/backend/tee.py +0 -0
  33. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/__init__.py +0 -0
  34. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/cluster.py +0 -0
  35. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/comm.py +0 -0
  36. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/context_mgr.py +0 -0
  37. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/dtype.py +0 -0
  38. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/expr/__init__.py +0 -0
  39. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/expr/ast.py +0 -0
  40. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/expr/evaluator.py +0 -0
  41. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/expr/printer.py +0 -0
  42. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/expr/transformer.py +0 -0
  43. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/expr/utils.py +0 -0
  44. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/expr/visitor.py +0 -0
  45. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/expr/walk.py +0 -0
  46. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/interp.py +0 -0
  47. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/mask.py +0 -0
  48. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/mpir.py +0 -0
  49. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/mpobject.py +0 -0
  50. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/mptype.py +0 -0
  51. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/pfunc.py +0 -0
  52. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/primitive.py +0 -0
  53. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/table.py +0 -0
  54. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/tensor.py +0 -0
  55. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/core/tracer.py +0 -0
  56. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/device.py +0 -0
  57. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/frontend/__init__.py +0 -0
  58. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/frontend/base.py +0 -0
  59. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/frontend/builtin.py +0 -0
  60. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/frontend/crypto.py +0 -0
  61. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/frontend/ibis_cc.py +0 -0
  62. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/frontend/phe.py +0 -0
  63. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/frontend/spu.py +0 -0
  64. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/frontend/sql.py +0 -0
  65. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/frontend/tee.py +0 -0
  66. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/protos/v1alpha1/mpir_pb2.py +0 -0
  67. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/protos/v1alpha1/mpir_pb2.pyi +0 -0
  68. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -0
  69. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/runtime/__init__.py +0 -0
  70. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/runtime/cli.py +0 -0
  71. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/runtime/client.py +0 -0
  72. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/runtime/communicator.py +0 -0
  73. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/runtime/data_providers.py +0 -0
  74. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/runtime/driver.py +0 -0
  75. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/runtime/exceptions.py +0 -0
  76. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/runtime/http_api.md +0 -0
  77. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/runtime/link_comm.py +0 -0
  78. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/runtime/resource.py +0 -0
  79. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/runtime/server.py +0 -0
  80. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/runtime/simulation.py +0 -0
  81. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/simp/__init__.py +0 -0
  82. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/simp/mpi.py +0 -0
  83. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/simp/random.py +0 -0
  84. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/simp/smpc.py +0 -0
  85. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/utils/__init__.py +0 -0
  86. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/utils/crypto.py +0 -0
  87. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/utils/func_utils.py +0 -0
  88. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/utils/spu_utils.py +0 -0
  89. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/mplang/utils/table_utils.py +0 -0
  90. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/pyproject.toml +0 -0
  91. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/__init__.py +0 -0
  92. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/analysis/test_diagram.py +0 -0
  93. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/backend/test_builtin.py +0 -0
  94. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/backend/test_debug_print.py +0 -0
  95. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/backend/test_kernel_binding.py +0 -0
  96. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/backend/test_phe.py +0 -0
  97. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/backend/test_spu.py +0 -0
  98. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/backend/test_sql_duckdb.py +0 -0
  99. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/backend/test_stablehlo.py +0 -0
  100. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/core/__init__.py +0 -0
  101. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/core/expr/__init__.py +0 -0
  102. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/core/expr/conftest.py +0 -0
  103. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/core/expr/test_ast.py +0 -0
  104. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/core/expr/test_printer.py +0 -0
  105. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/core/expr/test_utils.py +0 -0
  106. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/core/expr/test_walk.py +0 -0
  107. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/core/test_cluster.py +0 -0
  108. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/core/test_dtype.py +0 -0
  109. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/core/test_mask.py +0 -0
  110. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/core/test_mpir.py +0 -0
  111. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/core/test_mptype.py +0 -0
  112. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/core/test_primitive.py +0 -0
  113. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/core/test_table.py +0 -0
  114. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/core/test_tensor.py +0 -0
  115. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/core/test_tracer.py +0 -0
  116. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/device/__init__.py +0 -0
  117. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/device/test_device_basic.py +0 -0
  118. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/frontend/__init__.py +0 -0
  119. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/frontend/dummy.py +0 -0
  120. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/frontend/test_builtin_pack.py +0 -0
  121. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/frontend/test_crypto_tee.py +0 -0
  122. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/frontend/test_feop_base.py +0 -0
  123. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/frontend/test_ibis.py +0 -0
  124. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/frontend/test_ibis_cc.py +0 -0
  125. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/frontend/test_phe.py +0 -0
  126. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/frontend/test_spu.py +0 -0
  127. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/frontend/test_spu_defensive.py +0 -0
  128. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/frontend/test_sql.py +0 -0
  129. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/frontend/test_table_tensor_conversion.py +0 -0
  130. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/integration/README.md +0 -0
  131. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/integration/test_crypto_roundtrip.py +0 -0
  132. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/integration/test_http_e2e.py +0 -0
  133. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/integration/test_symbols_roundtrip.py +0 -0
  134. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/integration/test_tutorials.py +0 -0
  135. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/runtime/__init__.py +0 -0
  136. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/runtime/test_cli.py +0 -0
  137. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/runtime/test_communicator.py +0 -0
  138. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/runtime/test_driver.py +0 -0
  139. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/runtime/test_server.py +0 -0
  140. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/runtime/test_simulation.py +0 -0
  141. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/simp/test_mpi.py +0 -0
  142. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/simp/test_random.py +0 -0
  143. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/simp/test_simp.py +0 -0
  144. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/simp/test_smpc.py +0 -0
  145. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/simp/test_sugar.py +0 -0
  146. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/utils/__init__.py +0 -0
  147. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/utils/server_fixtures.py +0 -0
  148. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/utils/test_func_utils.py +0 -0
  149. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/utils/test_spu_utils.py +0 -0
  150. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tests/utils/test_table_utils.py +0 -0
  151. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tutorials/0_basic.py +0 -0
  152. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tutorials/10_analysis.py +0 -0
  153. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tutorials/1_condition.py +0 -0
  154. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tutorials/2_whileloop.py +0 -0
  155. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tutorials/3_device.py +0 -0
  156. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tutorials/4_simulation.py +0 -0
  157. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tutorials/5_ir_dump.py +0 -0
  158. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tutorials/6_advanced.py +0 -0
  159. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tutorials/7_stdio.py +0 -0
  160. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tutorials/8_phe.py +0 -0
  161. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tutorials/9_tee.py +0 -0
  162. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tutorials/__init__.py +0 -0
  163. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tutorials/pitfalls/late_binding.py +0 -0
  164. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tutorials/pitfalls/rand.py +0 -0
  165. {mplang_nightly-0.1.dev148 → mplang_nightly-0.1.dev149}/tutorials/run.sh +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mplang-nightly
3
- Version: 0.1.dev148
3
+ Version: 0.1.dev149
4
4
  Summary: Multi-Party Programming Language
5
5
  Author-email: SecretFlow Team <secretflow-contact@service.alipay.com>
6
6
  License: Apache License
@@ -51,8 +51,15 @@ def _stablehlo_exec(pfunc: PFunction, *args: Any) -> Any:
51
51
  raise RuntimeError(f"StableHLO compile failed: {e}") from e
52
52
  cache[mlir_text] = compiled
53
53
 
54
+ # Handle JAX's unused parameter elimination via arg_keep_map
55
+ runtime_args = args
56
+ if "arg_keep_map" in pfunc.attrs:
57
+ keep_indices = pfunc.attrs["arg_keep_map"]
58
+ # Filter out arguments that were eliminated by JAX during compilation
59
+ runtime_args = tuple(args[i] for i in keep_indices)
60
+
54
61
  jax_args = []
55
- for arg in args:
62
+ for arg in runtime_args:
56
63
  if hasattr(arg, "numpy"):
57
64
  jax_arg = jnp.array(arg.numpy()) # type: ignore
58
65
  else:
@@ -106,14 +106,46 @@ def jax2stablehlo(
106
106
  out_info_flat, out_tree = tree_flatten(lowered.out_info)
107
107
  out_info_flat = [TensorType.from_obj(info) for info in out_info_flat]
108
108
 
109
+ # Extract argument keep mapping to handle JAX's unused parameter elimination
110
+ # JAX can eliminate unused parameters during compilation, but the runtime still
111
+ # receives all original arguments. We need the mapping to filter them correctly.
112
+ arg_keep_map = None
113
+ original_arg_count = len(in_vars)
114
+
115
+ try:
116
+ # Access JAX internal kept_var_idx - the authoritative source
117
+ # This tells us exactly which original parameters survived compilation
118
+ compile_args = lowered._lowering.compile_args
119
+ kept_var_idx = compile_args["kept_var_idx"]
120
+
121
+ kept_indices = sorted(kept_var_idx)
122
+ if len(kept_indices) < original_arg_count:
123
+ arg_keep_map = kept_indices
124
+
125
+ except (AttributeError, KeyError, TypeError) as e:
126
+ # JAX internal API is not available or changed
127
+ # This is a hard error - we cannot reliably handle unused parameters
128
+ # without knowing exactly which ones were kept
129
+ raise RuntimeError(
130
+ f"Cannot access JAX's kept_var_idx to handle unused parameter elimination. "
131
+ f"This function may have unused parameters that JAX optimized away, "
132
+ f"but we cannot determine which ones without the internal API. "
133
+ f"Original error: {e}"
134
+ ) from e
135
+
109
136
  # This format tells JaxRT how to handle the compiled result
110
- pfn = PFunction(
111
- fn_type="mlir.stablehlo", # Key: specify StableHLO MLIR format
112
- ins_info=tuple(TensorType.from_obj(x) for x in in_vars),
113
- outs_info=tuple(out_info_flat),
114
- fn_name=get_fn_name(flat_fn),
115
- fn_text=mlir_text, # MLIR text, serializable for transmission
116
- )
137
+ pfn_kwargs: dict[str, Any] = {
138
+ "fn_type": "mlir.stablehlo", # Key: specify StableHLO MLIR format
139
+ "ins_info": tuple(TensorType.from_obj(x) for x in in_vars),
140
+ "outs_info": tuple(out_info_flat),
141
+ "fn_name": get_fn_name(flat_fn),
142
+ "fn_text": mlir_text, # MLIR text, serializable for transmission
143
+ }
144
+
145
+ if arg_keep_map is not None:
146
+ pfn_kwargs["arg_keep_map"] = arg_keep_map
147
+
148
+ pfn = PFunction(**pfn_kwargs)
117
149
  return pfn, in_vars, out_tree
118
150
 
119
151
 
@@ -244,21 +244,37 @@ class TestJax2StableHLO:
244
244
  assert cfunc.fn_text is not None
245
245
 
246
246
  def test_multiple_outputs(self):
247
- """Test compilation with multiple outputs."""
247
+ """Test functions with multiple outputs."""
248
248
 
249
249
  def multi_output(x, y):
250
250
  return x + y, x - y, x * y
251
251
 
252
- x = jnp.array([1.0, 2.0])
253
- y = jnp.array([3.0, 4.0])
252
+ pfunc, out_tree = self._compile_with_transformer(
253
+ multi_output, jnp.array([1, 2]), jnp.array([3, 4])
254
+ )
254
255
 
255
- cfunc, _out_tree = self._compile_with_transformer(multi_output, x, y)
256
+ assert len(pfunc.outs_info) == 3
257
+ assert out_tree is not None
256
258
 
257
- assert len(cfunc.ins_info) == 2
258
- assert len(cfunc.outs_info) == 3
259
+ def test_unused_parameter_elimination(self):
260
+ """Test that unused parameters are handled correctly via arg_keep_map."""
259
261
 
260
- # All outputs should have the same shape as inputs
261
- for out_info in cfunc.outs_info:
262
- assert out_info.shape == x.shape
262
+ def func_with_unused(x, unused, z):
263
+ return x + z # unused parameter eliminated by JAX
263
264
 
264
- assert cfunc.fn_text is not None
265
+ x = jnp.array(1, dtype=jnp.int32)
266
+ unused = jnp.array(999, dtype=jnp.int32)
267
+ z = jnp.array(3, dtype=jnp.int32)
268
+
269
+ pfunc, _ = self._compile_with_transformer(func_with_unused, x, unused, z)
270
+
271
+ # Check that compilation succeeded
272
+ assert pfunc.fn_type == "mlir.stablehlo"
273
+ assert len(pfunc.ins_info) == 3 # Original input count
274
+
275
+ # If JAX eliminated unused parameters, arg_keep_map should be present
276
+ if "arg_keep_map" in pfunc.attrs:
277
+ keep_map = pfunc.attrs["arg_keep_map"]
278
+ assert isinstance(keep_map, list)
279
+ assert len(keep_map) < 3 # Should be fewer than original 3 params
280
+ assert 1 not in keep_map # Index 1 (unused) should not be kept
@@ -0,0 +1,191 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Test unused parameter handling with mplang integration.
17
+ This test verifies that functions with unused parameters work correctly
18
+ after the arg_keep_map implementation.
19
+ """
20
+
21
+ import jax.numpy as jnp
22
+
23
+ import mplang
24
+ import mplang.simp as simp
25
+
26
+
27
def func_with_unused_params(a, unused_param, b, c):
    """Return ``a + b + c``, deliberately ignoring ``unused_param``.

    The ignored argument sits between used ones, so any positional shift
    introduced by dropping it would be visible in the result.
    """
    partial = a + b
    return partial + c
30
+
31
+
32
def func_all_unused_returns_constant(a, unused1, unused2):
    """Ignore every argument and return the constant ``42``."""
    answer = 42
    return answer
35
+
36
+
37
def func_first_last_unused(unused1, b, c, unused2):
    """Return ``b * c``; the first and last arguments are ignored."""
    product = b * c
    return product
40
+
41
+
42
class TestUnusedParameterHandling:
    """Integration tests for JAX unused-parameter elimination (``arg_keep_map``)."""

    @staticmethod
    def _extract_scalar(output):
        """Unwrap a fetched result down to a plain Python scalar.

        A one-element ``list``/``tuple`` is unwrapped one level, then any value
        exposing ``.item()`` (NumPy/JAX arrays and scalars) is converted.

        Only real sequences are unwrapped: a 0-d array defines ``__iter__``
        but ``len()`` on it raises ``TypeError``, so probing with
        ``hasattr(output, "__iter__")`` is not safe here.
        """
        if isinstance(output, (list, tuple)) and len(output) == 1:
            output = output[0]
        if hasattr(output, "item"):
            output = output.item()
        return output

    @classmethod
    def _evaluate_scalar(cls, sim, fn):
        """Evaluate *fn* on *sim*, fetch the result and unwrap it to a scalar."""
        result = mplang.evaluate(sim, fn)
        output = mplang.fetch(sim, result)
        return cls._extract_scalar(output)

    def test_basic_unused_param(self):
        """Test function with one unused parameter in middle position."""
        sim = mplang.Simulator.simple(1)

        @mplang.function
        def test_func():
            # Constants are created inside the traced context.
            a = simp.constant(1)
            unused = simp.constant(999)  # expected to be pruned by JAX
            b = simp.constant(2)
            c = simp.constant(3)
            return simp.run(func_with_unused_params)(a, unused, b, c)

        expected = 6  # 1 + 2 + 3

        # Compilation alone must succeed before we try to execute.
        compiled = mplang.compile(sim, test_func)
        assert compiled is not None

        output = self._evaluate_scalar(sim, test_func)
        assert output == expected, f"Expected {expected}, got {output}"

    def test_multiple_unused_params(self):
        """Test function with multiple unused parameters."""
        sim = mplang.Simulator.simple(1)

        b_val = 5
        c_val = 7
        expected = b_val * c_val  # 35

        @mplang.function
        def test_func():
            unused1 = simp.constant(100)
            b = simp.constant(b_val)
            c = simp.constant(c_val)
            unused2 = simp.constant(200)
            return simp.run(func_first_last_unused)(unused1, b, c, unused2)

        output = self._evaluate_scalar(sim, test_func)
        assert output == expected, f"Expected {expected}, got {output}"

    def test_all_params_unused(self):
        """Test function where all parameters are unused (returns constant)."""
        sim = mplang.Simulator.simple(1)
        expected = 42

        @mplang.function
        def test_func():
            a = simp.constant(1)
            unused1 = simp.constant(10)
            unused2 = simp.constant(20)
            return simp.run(func_all_unused_returns_constant)(a, unused1, unused2)

        output = self._evaluate_scalar(sim, test_func)
        assert output == expected, f"Expected {expected}, got {output}"

    def test_no_unused_params(self):
        """Test function with no unused parameters (regression test)."""
        sim = mplang.Simulator.simple(1)

        def func_all_used(a, b, c):
            return a + b + c

        @mplang.function
        def test_func():
            a = simp.constant(10)
            b = simp.constant(20)
            c = simp.constant(30)
            return simp.run(func_all_used)(a, b, c)

        output = self._evaluate_scalar(sim, test_func)
        assert output == 60, f"Expected 60, got {output}"

    def test_arg_keep_map_in_pfunc(self):
        """``jax2stablehlo`` records exactly which original arguments survive."""
        from mplang.frontend.jax_cc import jax2stablehlo

        def func_with_unused(a, unused, b):
            return a * b

        a = jnp.array(2, dtype=jnp.int32)
        unused = jnp.array(999, dtype=jnp.int32)
        b = jnp.array(3, dtype=jnp.int32)

        def is_variable(arg):
            # Stub: treat every argument as a runtime variable.
            return True

        pfunc, _, _ = jax2stablehlo(is_variable, func_with_unused, a, unused, b)

        # The PFunction still describes the full original signature ...
        assert len(pfunc.ins_info) == 3

        # ... and, when JAX pruned arguments, the map lists exactly the used
        # ones in sorted order. Whether pruning happens depends on the JAX
        # lowering configuration, so absence of the attribute is tolerated.
        if "arg_keep_map" in pfunc.attrs:
            keep_map = pfunc.attrs["arg_keep_map"]
            assert isinstance(keep_map, list)
            assert keep_map == [0, 2]  # ``a`` and ``b`` kept, ``unused`` dropped

    def test_different_dtypes_unused(self):
        """Test unused parameter elimination with different data types."""
        sim = mplang.Simulator.simple(1)

        def func_mixed_types(int_used, float_unused, int_used2):
            return int_used + int_used2  # float_unused is not used

        @mplang.function
        def test_func():
            a = simp.constant(5)
            unused_float = simp.constant(3.14)  # Different dtype, unused
            c = simp.constant(7)
            return simp.run(func_mixed_types)(a, unused_float, c)

        output = self._evaluate_scalar(sim, test_func)
        assert output == 12, f"Expected 12, got {output}"