mplang-nightly 0.1.dev169__tar.gz → 0.1.dev170__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (171)
  1. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/PKG-INFO +1 -1
  2. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/core/expr/ast.py +2 -1
  3. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/core/expr/evaluator.py +2 -2
  4. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/core/expr/printer.py +16 -6
  5. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/core/expr/transformer.py +1 -1
  6. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/core/mpir.py +6 -1
  7. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/core/primitive.py +93 -21
  8. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/core/expr/test_ast.py +1 -1
  9. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/core/expr/test_printer.py +8 -8
  10. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/core/test_mpir.py +2 -2
  11. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/core/test_primitive.py +6 -6
  12. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/.gitignore +0 -0
  13. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/LICENSE +0 -0
  14. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/README.md +0 -0
  15. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/examples/conf/3pc.yaml +0 -0
  16. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/examples/stax_nn/README.md +0 -0
  17. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/examples/stax_nn/models.py +0 -0
  18. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/examples/stax_nn/stax_nn.py +0 -0
  19. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/examples/xgboost/hist_jax.py +0 -0
  20. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/examples/xgboost/hist_jax_test.py +0 -0
  21. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/examples/xgboost/naive_np.py +0 -0
  22. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/examples/xgboost/readme.md +0 -0
  23. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/examples/xgboost/sgb.py +0 -0
  24. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/examples/xgboost/sgb_test.py +0 -0
  25. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/hatch_build.py +0 -0
  26. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/__init__.py +0 -0
  27. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/analysis/__init__.py +0 -0
  28. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/analysis/diagram.py +0 -0
  29. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/api.py +0 -0
  30. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/core/__init__.py +0 -0
  31. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/core/cluster.py +0 -0
  32. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/core/comm.py +0 -0
  33. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/core/context_mgr.py +0 -0
  34. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/core/dtype.py +0 -0
  35. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/core/expr/__init__.py +0 -0
  36. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/core/expr/utils.py +0 -0
  37. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/core/expr/visitor.py +0 -0
  38. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/core/expr/walk.py +0 -0
  39. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/core/interp.py +0 -0
  40. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/core/mask.py +0 -0
  41. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/core/mpobject.py +0 -0
  42. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/core/mptype.py +0 -0
  43. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/core/pfunc.py +0 -0
  44. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/core/table.py +0 -0
  45. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/core/tensor.py +0 -0
  46. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/core/tracer.py +0 -0
  47. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/device.py +0 -0
  48. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/kernels/__init__.py +0 -0
  49. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/kernels/base.py +0 -0
  50. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/kernels/builtin.py +0 -0
  51. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/kernels/context.py +0 -0
  52. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/kernels/crypto.py +0 -0
  53. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/kernels/mock_tee.py +0 -0
  54. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/kernels/phe.py +0 -0
  55. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/kernels/spu.py +0 -0
  56. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/kernels/sql_duckdb.py +0 -0
  57. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/kernels/stablehlo.py +0 -0
  58. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/kernels/value.py +0 -0
  59. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/ops/__init__.py +0 -0
  60. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/ops/base.py +0 -0
  61. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/ops/builtin.py +0 -0
  62. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/ops/crypto.py +0 -0
  63. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/ops/ibis_cc.py +0 -0
  64. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/ops/jax_cc.py +0 -0
  65. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/ops/phe.py +0 -0
  66. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/ops/spu.py +0 -0
  67. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/ops/sql_cc.py +0 -0
  68. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/ops/tee.py +0 -0
  69. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/protos/v1alpha1/mpir_pb2.py +0 -0
  70. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/protos/v1alpha1/mpir_pb2.pyi +0 -0
  71. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -0
  72. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/protos/v1alpha1/value_pb2.py +0 -0
  73. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/protos/v1alpha1/value_pb2.pyi +0 -0
  74. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/runtime/__init__.py +0 -0
  75. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/runtime/cli.py +0 -0
  76. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/runtime/client.py +0 -0
  77. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/runtime/communicator.py +0 -0
  78. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/runtime/data_providers.py +0 -0
  79. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/runtime/driver.py +0 -0
  80. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/runtime/exceptions.py +0 -0
  81. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/runtime/http_api.md +0 -0
  82. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/runtime/link_comm.py +0 -0
  83. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/runtime/server.py +0 -0
  84. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/runtime/session.py +0 -0
  85. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/runtime/simulation.py +0 -0
  86. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/simp/__init__.py +0 -0
  87. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/simp/mpi.py +0 -0
  88. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/simp/random.py +0 -0
  89. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/simp/smpc.py +0 -0
  90. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/utils/__init__.py +0 -0
  91. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/utils/crypto.py +0 -0
  92. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/utils/func_utils.py +0 -0
  93. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/utils/spu_utils.py +0 -0
  94. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/mplang/utils/table_utils.py +0 -0
  95. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/pyproject.toml +0 -0
  96. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/__init__.py +0 -0
  97. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/analysis/test_diagram.py +0 -0
  98. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/conftest.py +0 -0
  99. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/core/__init__.py +0 -0
  100. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/core/expr/__init__.py +0 -0
  101. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/core/expr/conftest.py +0 -0
  102. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/core/expr/test_utils.py +0 -0
  103. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/core/expr/test_walk.py +0 -0
  104. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/core/test_cluster.py +0 -0
  105. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/core/test_dtype.py +0 -0
  106. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/core/test_mask.py +0 -0
  107. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/core/test_mptype.py +0 -0
  108. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/core/test_table.py +0 -0
  109. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/core/test_tensor.py +0 -0
  110. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/core/test_tracer.py +0 -0
  111. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/device/__init__.py +0 -0
  112. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/device/test_device_basic.py +0 -0
  113. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/integration/README.md +0 -0
  114. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/integration/test_crypto_roundtrip.py +0 -0
  115. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/integration/test_http_e2e.py +0 -0
  116. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/integration/test_symbols_roundtrip.py +0 -0
  117. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/integration/test_tutorials.py +0 -0
  118. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/integration/test_unused_param_integration.py +0 -0
  119. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/kernels/test_builtin.py +0 -0
  120. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/kernels/test_debug_print.py +0 -0
  121. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/kernels/test_kernel_binding.py +0 -0
  122. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/kernels/test_phe.py +0 -0
  123. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/kernels/test_spu.py +0 -0
  124. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/kernels/test_sql_duckdb.py +0 -0
  125. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/kernels/test_stablehlo.py +0 -0
  126. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/kernels/test_value.py +0 -0
  127. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/kernels/test_value_serde.py +0 -0
  128. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/ops/__init__.py +0 -0
  129. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/ops/dummy.py +0 -0
  130. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/ops/test_builtin_pack.py +0 -0
  131. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/ops/test_crypto_tee.py +0 -0
  132. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/ops/test_feop_base.py +0 -0
  133. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/ops/test_ibis.py +0 -0
  134. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/ops/test_ibis_cc.py +0 -0
  135. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/ops/test_jax_cc.py +0 -0
  136. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/ops/test_phe.py +0 -0
  137. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/ops/test_spu.py +0 -0
  138. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/ops/test_spu_defensive.py +0 -0
  139. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/ops/test_sql.py +0 -0
  140. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/ops/test_table_tensor_conversion.py +0 -0
  141. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/runtime/__init__.py +0 -0
  142. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/runtime/test_cli.py +0 -0
  143. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/runtime/test_communicator.py +0 -0
  144. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/runtime/test_driver.py +0 -0
  145. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/runtime/test_server.py +0 -0
  146. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/runtime/test_simulation.py +0 -0
  147. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/simp/test_mpi.py +0 -0
  148. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/simp/test_random.py +0 -0
  149. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/simp/test_simp.py +0 -0
  150. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/simp/test_smpc.py +0 -0
  151. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/simp/test_sugar.py +0 -0
  152. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/utils/__init__.py +0 -0
  153. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/utils/server_fixtures.py +0 -0
  154. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/utils/test_func_utils.py +0 -0
  155. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/utils/test_spu_utils.py +0 -0
  156. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tests/utils/test_table_utils.py +0 -0
  157. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tutorials/0_basic.py +0 -0
  158. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tutorials/10_analysis.py +0 -0
  159. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tutorials/1_condition.py +0 -0
  160. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tutorials/2_whileloop.py +0 -0
  161. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tutorials/3_device.py +0 -0
  162. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tutorials/4_simulation.py +0 -0
  163. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tutorials/5_ir_dump.py +0 -0
  164. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tutorials/6_advanced.py +0 -0
  165. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tutorials/7_stdio.py +0 -0
  166. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tutorials/8_phe.py +0 -0
  167. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tutorials/9_tee.py +0 -0
  168. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tutorials/__init__.py +0 -0
  169. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tutorials/pitfalls/late_binding.py +0 -0
  170. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tutorials/pitfalls/rand.py +0 -0
  171. {mplang_nightly-0.1.dev169 → mplang_nightly-0.1.dev170}/tutorials/run.sh +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mplang-nightly
3
- Version: 0.1.dev169
3
+ Version: 0.1.dev170
4
4
  Summary: Multi-Party Programming Language
5
5
  Author-email: SecretFlow Team <secretflow-contact@service.alipay.com>
6
6
  License: Apache License
@@ -528,8 +528,9 @@ class FuncDefExpr(Expr):
528
528
  class CallExpr(Expr):
529
529
  """Expression for function call."""
530
530
 
531
- def __init__(self, fn: FuncDefExpr, args: list[Expr]):
531
+ def __init__(self, name: str, fn: FuncDefExpr, args: list[Expr]):
532
532
  super().__init__()
533
+ self.name = name
533
534
  self.fn = fn
534
535
  self.args = args
535
536
 
@@ -341,10 +341,10 @@ class RecursiveEvaluator(EvalSemantic, ExprVisitor):
341
341
 
342
342
  # Only evaluate selected branch locally
343
343
  if bool(pred):
344
- then_call = CallExpr(expr.then_fn, expr.args)
344
+ then_call = CallExpr("then", expr.then_fn, expr.args)
345
345
  return self._values(then_call)
346
346
  else:
347
- else_call = CallExpr(expr.else_fn, expr.args)
347
+ else_call = CallExpr("else", expr.else_fn, expr.args)
348
348
  return self._values(else_call)
349
349
 
350
350
  def visit_call(self, expr: CallExpr) -> Any:
@@ -50,11 +50,13 @@ class Printer(ExprVisitor):
50
50
  compact_format: bool = True,
51
51
  *,
52
52
  verbose_peval: bool = False,
53
+ inline_pcall: bool = True,
53
54
  ):
54
55
  super().__init__() # Initialize MemorizedVisitor
55
56
  self.indent_size = indent_size
56
57
  self.compact_format = compact_format
57
58
  self.verbose_peval = verbose_peval
59
+ self.inline_pcall = inline_pcall
58
60
  self._cur_indent = 0
59
61
  self._output: list[str] = []
60
62
  self._visited: dict[Expr, str] = {}
@@ -92,6 +94,7 @@ class Printer(ExprVisitor):
92
94
  body_printer = Printer(
93
95
  indent_size=self.indent_size,
94
96
  compact_format=self.compact_format,
97
+ inline_pcall=self.inline_pcall,
95
98
  )
96
99
  func_def_expr.accept(body_printer)
97
100
  regions_str += f"{indent}{r_name}: "
@@ -209,12 +212,19 @@ class Printer(ExprVisitor):
209
212
 
210
213
  def visit_call(self, expr: CallExpr) -> str:
211
214
  arg_names = [self._var_name(arg) for arg in expr.args]
212
- return self._do_print(
213
- "pcall",
214
- arg_names,
215
- regions={"fn": expr.fn},
216
- mptypes=expr.mptypes,
217
- )
215
+ if self.inline_pcall:
216
+ return self._do_print(
217
+ expr.name,
218
+ arg_names,
219
+ mptypes=expr.mptypes,
220
+ )
221
+ else:
222
+ return self._do_print(
223
+ "pcall",
224
+ arg_names,
225
+ regions={"fn": expr.fn},
226
+ mptypes=expr.mptypes,
227
+ )
218
228
 
219
229
  def visit_while(self, expr: WhileExpr) -> str:
220
230
  arg_names = [self._var_name(arg) for arg in expr.args]
@@ -79,7 +79,7 @@ class ExprTransformer(ExprVisitor):
79
79
  def visit_call(self, expr: CallExpr) -> Expr:
80
80
  # Transform child expressions first
81
81
  transformed_args = [arg.accept(self) for arg in expr.args]
82
- new_expr = CallExpr(expr.fn, transformed_args)
82
+ new_expr = CallExpr(expr.name, expr.fn, transformed_args)
83
83
 
84
84
  if "call" in self.trans_rules:
85
85
  return self.trans_rules["call"](new_expr)
@@ -491,6 +491,7 @@ class Writer:
491
491
  op = self._create_node_proto(expr, "call")
492
492
  self._add_single_expr_inputs(op, expr.fn)
493
493
  self._add_expr_inputs(op, *expr.args)
494
+ self._add_attrs(op, name=expr.name)
494
495
  self._finalize_node(op, expr)
495
496
  elif isinstance(expr, WhileExpr):
496
497
  op = self._create_node_proto(expr, "while")
@@ -822,8 +823,12 @@ class Reader:
822
823
  arg_exprs.append(self._value_cache[dep_name])
823
824
  else:
824
825
  raise ValueError(f"Input {input_name} not found for call node")
826
+ # Optional call-site name attribute
827
+ call_name = None
828
+ if "name" in node_proto.attrs:
829
+ call_name = self._proto_to_attr(node_proto.attrs["name"]) # type: ignore[assignment]
825
830
 
826
- return CallExpr(fn_expr, arg_exprs)
831
+ return CallExpr(call_name or "", fn_expr, arg_exprs)
827
832
 
828
833
  def _proto_to_mptype(self, type_proto: mpir_pb2.MPTypeProto) -> MPType:
829
834
  """Convert MPTypeProto to MPType."""
@@ -32,6 +32,7 @@ from mplang.core.context_mgr import cur_ctx
32
32
  from mplang.core.dtype import BOOL
33
33
  from mplang.core.expr.ast import (
34
34
  AccessExpr,
35
+ CallExpr,
35
36
  CondExpr,
36
37
  ConvExpr,
37
38
  EvalExpr,
@@ -87,30 +88,106 @@ P = ParamSpec("P")
87
88
  R = TypeVar("R")
88
89
 
89
90
 
90
- def primitive(fn: Callable[P, R]) -> Callable[P, R]:
91
+ def trace_before_apply(fn: Callable[P, R], make_call: bool) -> Callable[P, R]:
91
92
  """A decorator to make all primitive call in trace context."""
92
93
 
93
94
  @wraps(fn)
94
95
  def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
95
96
  current_ctx = cur_ctx()
96
97
  if isinstance(current_ctx, TraceContext):
97
- # If we are in a tracer context, just call the function.
98
- # Note: switch_ctx will do the capture if needed.
99
- args, kwargs = tree_map(partial(_switch_ctx, current_ctx), (args, kwargs))
100
- return fn(*args, **kwargs)
98
+ # If we are already in a tracer context
99
+ if make_call:
100
+ # make a primitive call
101
+ tracer = current_ctx
102
+ tfn = trace(tracer.fork(), fn, *args, **kwargs)
103
+ is_mpobj = lambda x: isinstance(x, MPObject)
104
+ in_vars, in_imms, in_struct = var_morph((args, kwargs), is_mpobj)
105
+ assert in_struct == tfn.in_struct and in_imms == tfn.in_imms
106
+ arg_exprs = [arg.expr for arg in in_vars]
107
+ # re-capture all captured variables into current context if needed.
108
+ cap_exprs = [tracer.capture(var).expr for var in tfn.capture_map.keys()]
109
+ caller_expr = CallExpr(
110
+ name=fn.__name__, fn=tfn.make_expr(), args=arg_exprs + cap_exprs
111
+ )
112
+ out_vars = [
113
+ TraceVar(tracer, AccessExpr(caller_expr, idx))
114
+ for idx in range(caller_expr.num_outputs)
115
+ ]
116
+ return cast(R, var_demorph(out_vars, tfn.out_imms, tfn.out_struct))
117
+ else:
118
+ # embed the function call in the current tracer context
119
+ # Note: switch_ctx will do the capture if needed.
120
+ args, kwargs = tree_map(
121
+ partial(_switch_ctx, current_ctx), (args, kwargs)
122
+ )
123
+ return fn(*args, **kwargs)
101
124
  elif isinstance(current_ctx, InterpContext):
102
125
  trace_ctx = TraceContext(current_ctx.cluster_spec, parent=current_ctx)
103
126
  # TODO(jint): should we add trace_and_apply to improve the performance?
104
- traced_fn = trace(trace_ctx, fn, *args, **kwargs)
127
+ tfn = trace(trace_ctx, fn, *args, **kwargs)
105
128
  # Return back to the original context.
106
- return cast(R, apply(current_ctx, traced_fn, *args, **kwargs))
129
+ return cast(R, apply(current_ctx, tfn, *args, **kwargs))
107
130
  else:
108
131
  raise ValueError(f"Unsupported context type: {type(current_ctx)}")
109
132
 
110
133
  return wrapped
111
134
 
112
135
 
113
- function = primitive
136
+ def primitive(fn: Callable[P, R]) -> Callable[P, R]:
137
+ """Decorator to trace a Python function as an opaque primitive call (`CallExpr`).
138
+
139
+ When a function decorated with `@primitive` is called within a `TraceContext`, it is
140
+ not inlined. Instead, it is traced separately in a forked context, and a `CallExpr`
141
+ node is inserted into the main graph. This is useful for encapsulating complex
142
+ operations or third-party library calls as single, opaque nodes.
143
+
144
+ **Implementation Note:**
145
+ A `CallExpr` represents a call to a single inline lambda (non-recursive, as we don't
146
+ have Y-combinator support). This single lambda call can be treated as a "primitive call"
147
+ by the printer/visualizer - hence the name "primitive". The function body is captured
148
+ once during tracing and represented as an opaque callable unit in the expression graph,
149
+ maintaining a clear boundary between the caller and callee contexts.
150
+
151
+ Args:
152
+ fn: The function to be traced as a primitive operation.
153
+
154
+ Returns:
155
+ A wrapped function that creates a `CallExpr` node when called in a trace context.
156
+
157
+ Example:
158
+ ```python
159
+ @primitive
160
+ def my_op(x: MPObject) -> MPObject:
161
+ # Complex logic traced as a single CallExpr node
162
+ return x + 1
163
+ ```
164
+ """
165
+ return trace_before_apply(fn, make_call=True)
166
+
167
+
168
+ def function(fn: Callable[P, R]) -> Callable[P, R]:
169
+ """Decorator to trace a Python function by inlining its body.
170
+
171
+ When a function decorated with `@function` is called within a `TraceContext`, its
172
+ underlying primitive operations are expanded and inserted directly into the caller's
173
+ graph. This is the default tracing behavior and is suitable for most pure-Python
174
+ multi-party functions.
175
+
176
+ Args:
177
+ fn: The function to be traced and inlined.
178
+
179
+ Returns:
180
+ A wrapped function that inlines its operations into the caller's trace context.
181
+
182
+ Example:
183
+ ```python
184
+ @function
185
+ def my_func(x: MPObject, y: MPObject) -> MPObject:
186
+ # Operations are inlined into the caller's trace
187
+ return x + y * constant(2)
188
+ ```
189
+ """
190
+ return trace_before_apply(fn, make_call=False)
114
191
 
115
192
 
116
193
  # ============================================================================
@@ -126,18 +203,15 @@ def _tracer() -> TraceContext:
126
203
  return ctx
127
204
 
128
205
 
129
- @primitive
130
206
  def psize() -> int:
131
207
  """Get the size of the current party world.
132
208
 
133
209
  Returns:
134
210
  int: The total number of parties in the current multi-party computation context.
135
211
  """
136
- ctx = _tracer()
137
- return ctx.world_size()
212
+ return cur_ctx().world_size()
138
213
 
139
214
 
140
- @primitive
141
215
  def pmask() -> Mask:
142
216
  """Get the current party mask in this computation context.
143
217
 
@@ -145,8 +219,7 @@ def pmask() -> Mask:
145
219
  Mask: The current party mask indicating which parties are active
146
220
  in the current computation context.
147
221
  """
148
- ctx = _tracer()
149
- return ctx.mask
222
+ return _tracer().mask
150
223
 
151
224
 
152
225
  @primitive
@@ -203,7 +276,6 @@ def prand(shape: Shape = ()) -> MPObject:
203
276
  return out_tree.unflatten(results) # type: ignore[no-any-return]
204
277
 
205
278
 
206
- @primitive
207
279
  def constant(data: TensorLike | ScalarType | TableLike) -> MPObject:
208
280
  """Create a constant tensor or table from data.
209
281
 
@@ -250,7 +322,7 @@ def debug_print(obj: MPObject, prefix: str = "") -> MPObject:
250
322
  return out_tree.unflatten(results) # type: ignore[no-any-return]
251
323
 
252
324
 
253
- @primitive
325
+ @function
254
326
  def peval(
255
327
  pfunc: PFunction,
256
328
  args: list[MPObject],
@@ -378,7 +450,7 @@ def set_mask(arg: MPObject, mask: Mask) -> MPObject:
378
450
  return out_tree.unflatten(results) # type: ignore[no-any-return]
379
451
 
380
452
 
381
- @primitive
453
+ @function
382
454
  def uniform_cond(
383
455
  pred: MPObject,
384
456
  then_fn: Callable[..., Any],
@@ -588,7 +660,7 @@ def uniform_cond(
588
660
  return var_demorph(out_vars, then_tfn.out_imms, then_tfn.out_struct) # type: ignore[no-any-return]
589
661
 
590
662
 
591
- @primitive
663
+ @function
592
664
  def while_loop(
593
665
  cond_fn: Callable[[Any], MPObject],
594
666
  body_fn: Callable[[Any], Any],
@@ -781,7 +853,7 @@ def while_loop(
781
853
  return var_demorph(out_vars, body_tfn.out_imms, body_tfn.out_struct)
782
854
 
783
855
 
784
- @primitive
856
+ @function
785
857
  def pshfl(src: MPObject, index: MPObject) -> MPObject:
786
858
  """Shuffle the input tensor to the specified index (dynamic version).
787
859
 
@@ -851,7 +923,7 @@ def pshfl(src: MPObject, index: MPObject) -> MPObject:
851
923
  return TraceVar(_tracer(), shfl_expr)
852
924
 
853
925
 
854
- @primitive
926
+ @function
855
927
  def pshfl_s(src_val: MPObject, pmask: Mask, src_ranks: list[Rank]) -> MPObject:
856
928
  """Shuffle the input tensor to the specified rank, static version.
857
929
 
@@ -910,7 +982,7 @@ def pshfl_s(src_val: MPObject, pmask: Mask, src_ranks: list[Rank]) -> MPObject:
910
982
  return TraceVar(_tracer(), shfl_s_expr)
911
983
 
912
984
 
913
- @primitive
985
+ @function
914
986
  def pconv(vars: list[MPObject]) -> MPObject:
915
987
  """Combine multiple variables that share the same dtype and shape into one.
916
988
 
@@ -233,7 +233,7 @@ class TestCallExpr:
233
233
  arg_mptype = MPType.tensor(FLOAT32, (2, 3), pmask_2p)
234
234
  arg = VariableExpr("arg", arg_mptype)
235
235
 
236
- expr = CallExpr(func, [arg])
236
+ expr = CallExpr("test", func, [arg])
237
237
 
238
238
  # The call should have the correct structure
239
239
  assert isinstance(expr.fn, FuncDefExpr)
@@ -321,7 +321,7 @@ class TestPrinterComplexExpressions:
321
321
 
322
322
  def test_all_expr_types_printing(self, pmask_2p, pfunc_1i1o):
323
323
  """Test printing of a complex expression involving all Expr types with meaningful parameter usage."""
324
- printer = Printer(compact_format=False)
324
+ printer = Printer(compact_format=False, inline_pcall=False)
325
325
 
326
326
  # 1. Variable expressions
327
327
  var1_mptype = MPType.tensor(FLOAT32, (2,), pmask_2p)
@@ -357,7 +357,7 @@ class TestPrinterComplexExpressions:
357
357
  var_expr = VariableExpr("input_data", param_type)
358
358
  func_body = TupleExpr([var_expr, access_expr])
359
359
  func_def = FuncDefExpr(["input_data"], func_body)
360
- call_expr = CallExpr(func_def, [var1])
360
+ call_expr = CallExpr("test", func_def, [var1])
361
361
 
362
362
  # Access the first output of the call expression to get a single-output expr
363
363
  call_expr_first = AccessExpr(call_expr, 0)
@@ -405,7 +405,7 @@ class TestPrinterComplexExpressions:
405
405
  assert result == expected.strip()
406
406
 
407
407
  # Test with optimize_variables=True for comparison
408
- printer_optimized = Printer(compact_format=True)
408
+ printer_optimized = Printer(compact_format=True, inline_pcall=False)
409
409
  result_optimized = printer_optimized.print_expr(final_expr)
410
410
 
411
411
  # Expected output with variable optimization (compact mode uses variable names directly)
@@ -469,7 +469,7 @@ class TestPrinterComplexExpressions:
469
469
 
470
470
  def test_call_expr_printing(self, pmask_2p):
471
471
  """Test printing of CallExpr with nested function that uses its parameters."""
472
- printer = Printer(compact_format=False)
472
+ printer = Printer(compact_format=False, inline_pcall=False)
473
473
 
474
474
  # Create function body that actually uses the parameters
475
475
  x_mptype = MPType.tensor(FLOAT32, (1,), pmask_2p)
@@ -489,7 +489,7 @@ class TestPrinterComplexExpressions:
489
489
  arg2 = VariableExpr("arg2", arg2_mptype)
490
490
 
491
491
  # Create call expression
492
- expr = CallExpr(fn, [arg1, arg2])
492
+ expr = CallExpr("test", fn, [arg1, arg2])
493
493
 
494
494
  result = printer.print_expr(expr)
495
495
 
@@ -617,7 +617,7 @@ class TestPrinterEdgeCases:
617
617
 
618
618
  def test_deep_nesting_indentation(self, pmask_2p):
619
619
  """Test printer with deeply nested expressions that use parameters meaningfully."""
620
- printer = Printer(compact_format=False)
620
+ printer = Printer(compact_format=False, inline_pcall=False)
621
621
 
622
622
  # Create nested function definitions with meaningful parameter usage
623
623
  inner_param_type = MPType.tensor(FLOAT32, (1,), pmask_2p)
@@ -630,13 +630,13 @@ class TestPrinterEdgeCases:
630
630
 
631
631
  # Middle function: takes parameter and calls inner function with it
632
632
  middle_param = VariableExpr("middle_param", middle_param_type)
633
- middle_body = CallExpr(inner_fn, [middle_param])
633
+ middle_body = CallExpr("middle", inner_fn, [middle_param])
634
634
  middle_fn = FuncDefExpr(["middle_param"], middle_body)
635
635
 
636
636
  # Outer expression: call middle function with a variable expression
637
637
  arg_mptype = MPType.tensor(UINT64, (), pmask_2p)
638
638
  arg = VariableExpr("arg", arg_mptype)
639
- outer_expr = CallExpr(middle_fn, [arg])
639
+ outer_expr = CallExpr("outer", middle_fn, [arg])
640
640
 
641
641
  result = printer.print_expr(outer_expr)
642
642
 
@@ -374,7 +374,7 @@ class TestComplexExpressions:
374
374
  # Create arguments using VariableExpr
375
375
  arg = VariableExpr("input", MPType.tensor(FLOAT32, (2,), pmask=Mask(7)))
376
376
 
377
- original = CallExpr(fn, [arg])
377
+ original = CallExpr("original", fn, [arg])
378
378
 
379
379
  writer = Writer()
380
380
  proto = writer.dumps(original)
@@ -887,7 +887,7 @@ class TestComplexExpressionRoundtrip:
887
887
  MPType.tensor(FLOAT32, (2,), pmask=Mask(7)),
888
888
  )
889
889
 
890
- original = CallExpr(fn, [arg])
890
+ original = CallExpr("original", fn, [arg])
891
891
 
892
892
  # Test roundtrip
893
893
  writer = Writer()
@@ -37,11 +37,11 @@ from mplang.core.mptype import Rank
37
37
  from mplang.core.primitive import (
38
38
  _switch_ctx,
39
39
  constant,
40
+ function,
40
41
  pconv,
41
42
  peval,
42
43
  prand,
43
44
  prank,
44
- primitive,
45
45
  pshfl,
46
46
  pshfl_s,
47
47
  set_mask,
@@ -76,7 +76,7 @@ class TestPrimitiveDecorator:
76
76
  def test_primitive_decorator_basic(self, trace_context):
77
77
  """Test basic primitive decorator functionality."""
78
78
 
79
- @primitive
79
+ @function
80
80
  def simple_func():
81
81
  return constant(42)
82
82
 
@@ -1156,7 +1156,7 @@ class TestCompleteExample:
1156
1156
  def test_complete_primitive_example(self, trace_context):
1157
1157
  """Complete example showing the complete testing pattern."""
1158
1158
 
1159
- @primitive
1159
+ @function
1160
1160
  def example_computation():
1161
1161
  """Example multi-party computation function."""
1162
1162
  my_rank = prank()
@@ -1193,7 +1193,7 @@ class TestComplexExpressions:
1193
1193
  def test_complex_function(self, trace_context):
1194
1194
  """Test function combining multiple primitives."""
1195
1195
 
1196
- @primitive
1196
+ @function
1197
1197
  def complex_func():
1198
1198
  prank()
1199
1199
  prand((2, 2))
@@ -1226,9 +1226,9 @@ class TestComplexExpressions:
1226
1226
  def test_nested_primitives(self, trace_context):
1227
1227
  """Test nested primitive calls."""
1228
1228
 
1229
- @primitive
1229
+ @function
1230
1230
  def outer_func():
1231
- @primitive
1231
+ @function
1232
1232
  def inner_func():
1233
1233
  return constant(1)
1234
1234