mplang-nightly 0.1.dev143__tar.gz → 0.1.dev144__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/PKG-INFO +1 -1
- mplang_nightly-0.1.dev144/mplang/backend/base.py +175 -0
- mplang_nightly-0.1.dev144/mplang/backend/context.py +255 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/backend/spu.py +6 -4
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/backend/sql_duckdb.py +1 -1
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/core/expr/evaluator.py +6 -6
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/frontend/base.py +1 -1
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/frontend/ibis_cc.py +2 -1
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/frontend/spu.py +4 -3
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/runtime/resource.py +39 -62
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/runtime/simulation.py +6 -13
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/backend/test_builtin.py +4 -4
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/backend/test_phe.py +5 -4
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/backend/test_spu.py +13 -15
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/backend/test_sql_duckdb.py +4 -5
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/backend/test_stablehlo.py +3 -3
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/frontend/test_spu.py +1 -1
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/frontend/test_spu_defensive.py +1 -1
- mplang_nightly-0.1.dev144/tests/runtime/__init__.py +13 -0
- mplang_nightly-0.1.dev143/mplang/backend/__init__.py +0 -20
- mplang_nightly-0.1.dev143/mplang/backend/base.py +0 -287
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/.gitignore +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/LICENSE +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/README.md +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/examples/conf/3pc.yaml +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/examples/stax_nn/README.md +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/examples/stax_nn/models.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/examples/stax_nn/stax_nn.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/examples/xgboost/hist_jax.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/examples/xgboost/hist_jax_test.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/examples/xgboost/naive_np.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/examples/xgboost/readme.md +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/examples/xgboost/sgb.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/examples/xgboost/sgb_test.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/hatch_build.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/__init__.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/analysis/__init__.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/analysis/diagram.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/api.py +0 -0
- {mplang_nightly-0.1.dev143/mplang/utils → mplang_nightly-0.1.dev144/mplang/backend}/__init__.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/backend/builtin.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/backend/crypto.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/backend/phe.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/backend/stablehlo.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/backend/tee.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/core/__init__.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/core/cluster.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/core/comm.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/core/context_mgr.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/core/dtype.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/core/expr/__init__.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/core/expr/ast.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/core/expr/printer.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/core/expr/transformer.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/core/expr/utils.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/core/expr/visitor.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/core/expr/walk.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/core/interp.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/core/mask.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/core/mpir.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/core/mpobject.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/core/mptype.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/core/pfunc.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/core/primitive.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/core/table.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/core/tensor.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/core/tracer.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/device.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/frontend/__init__.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/frontend/builtin.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/frontend/crypto.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/frontend/jax_cc.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/frontend/phe.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/frontend/sql.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/frontend/tee.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/protos/v1alpha1/mpir_pb2.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/protos/v1alpha1/mpir_pb2.pyi +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/runtime/__init__.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/runtime/cli.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/runtime/client.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/runtime/communicator.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/runtime/data_providers.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/runtime/driver.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/runtime/exceptions.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/runtime/http_api.md +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/runtime/link_comm.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/runtime/server.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/simp/__init__.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/simp/mpi.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/simp/random.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/simp/smpc.py +0 -0
- {mplang_nightly-0.1.dev143/tests/device → mplang_nightly-0.1.dev144/mplang/utils}/__init__.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/utils/crypto.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/utils/func_utils.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/utils/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/mplang/utils/table_utils.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/pyproject.toml +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/__init__.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/analysis/test_diagram.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/backend/test_debug_print.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/core/__init__.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/core/expr/__init__.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/core/expr/conftest.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/core/expr/test_ast.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/core/expr/test_printer.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/core/expr/test_utils.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/core/expr/test_walk.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/core/test_cluster.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/core/test_dtype.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/core/test_mask.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/core/test_mpir.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/core/test_mptype.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/core/test_primitive.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/core/test_table.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/core/test_tensor.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/core/test_tracer.py +0 -0
- {mplang_nightly-0.1.dev143/tests/frontend → mplang_nightly-0.1.dev144/tests/device}/__init__.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/device/test_device_basic.py +0 -0
- {mplang_nightly-0.1.dev143/tests/runtime → mplang_nightly-0.1.dev144/tests/frontend}/__init__.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/frontend/dummy.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/frontend/test_builtin_pack.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/frontend/test_crypto_tee.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/frontend/test_feop_base.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/frontend/test_ibis.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/frontend/test_ibis_cc.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/frontend/test_jax_cc.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/frontend/test_phe.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/frontend/test_sql.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/frontend/test_table_tensor_conversion.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/integration/README.md +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/integration/test_crypto_roundtrip.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/integration/test_http_e2e.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/integration/test_symbols_roundtrip.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/integration/test_tutorials.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/runtime/test_cli.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/runtime/test_communicator.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/runtime/test_driver.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/runtime/test_server.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/runtime/test_simulation.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/simp/test_mpi.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/simp/test_random.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/simp/test_simp.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/simp/test_smpc.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/simp/test_sugar.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/utils/__init__.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/utils/server_fixtures.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/utils/test_func_utils.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/utils/test_spu_utils.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tests/utils/test_table_utils.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tutorials/0_basic.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tutorials/10_analysis.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tutorials/1_condition.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tutorials/2_whileloop.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tutorials/3_device.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tutorials/4_simulation.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tutorials/5_ir_dump.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tutorials/6_advanced.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tutorials/7_stdio.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tutorials/8_phe.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tutorials/9_tee.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tutorials/__init__.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tutorials/pitfalls/late_binding.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tutorials/pitfalls/rand.py +0 -0
- {mplang_nightly-0.1.dev143 → mplang_nightly-0.1.dev144}/tutorials/run.sh +0 -0
@@ -0,0 +1,175 @@
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
"""Backend kernel registry & per-participant runtime (explicit op->kernel binding).
|
16
|
+
|
17
|
+
This version decouples *kernel implementation registration* from *operation binding*.
|
18
|
+
|
19
|
+
Concepts:
|
20
|
+
* kernel_id: unique identifier of a concrete backend implementation.
|
21
|
+
* op_type: semantic operation name carried by ``PFunction.fn_type``.
|
22
|
+
* bind_op(op_type, kernel_id): performed by higher layer (see ``backend.context``)
|
23
|
+
to select which implementation handles an op. Runtime dispatch is now a two-step lookup:
|
24
|
+
pfunc.fn_type -> active kernel_id -> KernelSpec.fn
|
25
|
+
|
26
|
+
The previous implicit "import == register+bind" coupling is removed. Kernel
|
27
|
+
modules only call ``@kernel_def(kernel_id)``. Default bindings are established
|
28
|
+
centrally (lazily, in ``backend.context``) when a ``RuntimeContext`` is constructed.
|
29
|
+
"""
|
30
|
+
|
31
|
+
from __future__ import annotations
|
32
|
+
|
33
|
+
import contextvars
|
34
|
+
from collections.abc import Callable
|
35
|
+
from dataclasses import dataclass
|
36
|
+
from typing import Any
|
37
|
+
|
38
|
+
# Public API of the kernel registry. ``kernel_def`` is exported because kernel
# implementation modules use it to register their implementations.
__all__ = [
    "KernelContext",
    "KernelSpec",
    "bind_op",
    "cur_kctx",
    "get_kernel_for_op",
    "kernel_def",
    "list_kernels",
    "list_ops",
    "unbind_op",
]
|
48
|
+
|
49
|
+
|
50
|
+
@dataclass
|
51
|
+
class KernelContext:
|
52
|
+
"""Ephemeral call context set via contextvar while a kernel runs."""
|
53
|
+
|
54
|
+
rank: int
|
55
|
+
world_size: int
|
56
|
+
state: dict[str, dict[str, Any]] # backend namespace -> pocket
|
57
|
+
cache: dict[str, Any] # runtime-level shared cache (per BackendRuntime)
|
58
|
+
|
59
|
+
|
60
|
+
_CTX_VAR: contextvars.ContextVar[KernelContext | None] = contextvars.ContextVar(
|
61
|
+
"_flat_backend_ctx", default=None
|
62
|
+
)
|
63
|
+
|
64
|
+
|
65
|
+
def cur_kctx() -> KernelContext:
|
66
|
+
"""Return the current kernel execution context (only valid inside a kernel).
|
67
|
+
|
68
|
+
Two storages:
|
69
|
+
- state: namespaced pockets (dict[str, dict]) for backend-local mutable helpers
|
70
|
+
- cache: global (per runtime) shared dict; prefer state unless truly cross-backend
|
71
|
+
|
72
|
+
Examples:
|
73
|
+
1) Compile cache::
|
74
|
+
@kernel_def("mlir.stablehlo")
|
75
|
+
def _exec(pfunc, args):
|
76
|
+
ctx = cur_kctx()
|
77
|
+
pocket = ctx.state.setdefault("stablehlo", {})
|
78
|
+
cache = pocket.setdefault("compile_cache", {})
|
79
|
+
text = pfunc.fn_text
|
80
|
+
mod = cache.get(text)
|
81
|
+
if mod is None:
|
82
|
+
mod = compile_mlir(text)
|
83
|
+
cache[text] = mod
|
84
|
+
return run(mod, args)
|
85
|
+
|
86
|
+
2) Deterministic RNG::
|
87
|
+
@kernel_def("crypto.keygen")
|
88
|
+
def _keygen(pfunc, args):
|
89
|
+
ctx = cur_kctx()
|
90
|
+
pocket = ctx.state.setdefault("crypto", {})
|
91
|
+
rng = pocket.get("rng")
|
92
|
+
if rng is None:
|
93
|
+
rng = np.random.default_rng(1234 + ctx.rank * 7919)
|
94
|
+
pocket["rng"] = rng
|
95
|
+
return (rng.integers(0, 256, size=(32,), dtype=np.uint8),)
|
96
|
+
"""
|
97
|
+
ctx = _CTX_VAR.get()
|
98
|
+
if ctx is None:
|
99
|
+
raise RuntimeError("cur_kctx() called outside backend kernel execution")
|
100
|
+
return ctx
|
101
|
+
|
102
|
+
|
103
|
+
# ---------------- Registry ----------------

# Kernel callable signature: (pfunc, *args) -> Any | sequence (no **kwargs)
KernelFn = Callable[..., Any]


@dataclass
class KernelSpec:
    """A registered kernel implementation."""

    # Unique identifier the implementation was registered under.
    kernel_id: str
    # The implementation itself.
    fn: KernelFn
    # Copy of the keyword metadata given to ``kernel_def``.
    meta: dict[str, Any]


# Every registered implementation, keyed by kernel_id.
_KERNELS: dict[str, KernelSpec] = {}

# Currently active bindings, op_type -> kernel_id.
_BINDINGS: dict[str, str] = {}


def kernel_def(kernel_id: str, /, **meta: Any) -> Callable[[KernelFn], KernelFn]:
    """Build a decorator that registers a kernel implementation.

    Registration is all it does: no op_type is bound. A higher layer must call
    ``bind_op(op_type, kernel_id)`` before the kernel can be dispatched to.

    Args:
        kernel_id: Unique id for the implementation.
        **meta: Arbitrary metadata stored (as a copy) on the ``KernelSpec``.

    Returns:
        A decorator that registers the function and returns it unchanged.

    Raises:
        ValueError: when the decorator is applied and ``kernel_id`` is taken.
    """

    def _register(impl: KernelFn) -> KernelFn:
        if kernel_id in _KERNELS:
            raise ValueError(f"duplicate kernel_id={kernel_id}")
        _KERNELS[kernel_id] = KernelSpec(kernel_id, impl, dict(meta))
        return impl

    return _register


def bind_op(op_type: str, kernel_id: str, *, force: bool = True) -> None:
    """Route ``op_type`` to the registered kernel ``kernel_id``.

    Args:
        op_type: Semantic operation name.
        kernel_id: Identifier previously registered via ``kernel_def``.
        force: When True (default) replace any existing binding; when False
            an existing binding for ``op_type`` is left in place.

    Raises:
        KeyError: ``kernel_id`` has not been registered.
    """
    if kernel_id not in _KERNELS:
        raise KeyError(f"kernel_id {kernel_id} not registered")
    if force or op_type not in _BINDINGS:
        _BINDINGS[op_type] = kernel_id


def unbind_op(op_type: str) -> None:
    """Drop the binding for ``op_type``; silently ignores unbound ops."""
    if op_type in _BINDINGS:
        del _BINDINGS[op_type]


def get_kernel_for_op(op_type: str) -> KernelSpec:
    """Resolve ``op_type`` to the ``KernelSpec`` it is currently bound to.

    Raises:
        NotImplementedError: nothing is bound to ``op_type``.
        RuntimeError: the binding names a kernel_id that is not registered.
    """
    if (kid := _BINDINGS.get(op_type)) is None:
        # Callers (and tests) expect NotImplementedError for unsupported ops.
        raise NotImplementedError(f"no backend kernel registered for op {op_type}")
    if (spec := _KERNELS.get(kid)) is None:
        # Binding points at an unknown kernel id: registry is inconsistent.
        raise RuntimeError(f"active kernel_id {kid} missing spec")
    return spec


def list_kernels() -> list[str]:
    """Return all registered kernel ids in sorted order."""
    return sorted(_KERNELS)


def list_ops() -> list[str]:
    """Return all currently bound op types in sorted order."""
    return sorted(_BINDINGS)
|
@@ -0,0 +1,255 @@
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from __future__ import annotations
|
16
|
+
|
17
|
+
from collections.abc import Mapping
|
18
|
+
from dataclasses import dataclass, field
|
19
|
+
from typing import Any
|
20
|
+
|
21
|
+
from mplang.backend import base
|
22
|
+
from mplang.backend.base import KernelContext, bind_op, get_kernel_for_op
|
23
|
+
from mplang.core.dtype import UINT8, DType
|
24
|
+
from mplang.core.pfunc import PFunction
|
25
|
+
from mplang.core.table import TableLike, TableType
|
26
|
+
from mplang.core.tensor import TensorLike, TensorType
|
27
|
+
|
28
|
+
# Kernel implementation modules register their kernels (via @kernel_def) as an
# import side effect. Those imports are deferred until _ensure_impl_imported()
# is called; this flag records that they have all completed.
_IMPL_IMPORTED = False


def _ensure_impl_imported() -> None:
    """Import every backend kernel module so its kernels are registered.

    The flag is only set after all imports succeed, so if one raises, the
    exception propagates and a later call retries the imports.
    """
    global _IMPL_IMPORTED
    if not _IMPL_IMPORTED:
        # Imported only for their registration side effects; the aliases are
        # never referenced, hence the F401 suppressions.
        from mplang.backend import builtin as _impl_builtin  # noqa: F401
        from mplang.backend import crypto as _impl_crypto  # noqa: F401
        from mplang.backend import phe as _impl_phe  # noqa: F401
        from mplang.backend import spu as _impl_spu  # noqa: F401
        from mplang.backend import sql_duckdb as _impl_sql_duckdb  # noqa: F401
        from mplang.backend import stablehlo as _impl_stablehlo  # noqa: F401
        from mplang.backend import tee as _impl_tee  # noqa: F401

        _IMPL_IMPORTED = True
|
48
|
+
|
49
|
+
|
50
|
+
# Default op_type -> kernel_id bindings. Almost every op is served by a kernel
# registered under the same id; "sql.run" is the one exception. Insertion
# order is kept stable (builtin, crypto, phe, spu, stablehlo, sql, tee).
_DEFAULT_BINDINGS: dict[str, str] = {
    **{
        op: op
        for op in (
            # builtin
            "builtin.identity",
            "builtin.read",
            "builtin.write",
            "builtin.constant",
            "builtin.rank",
            "builtin.prand",
            "builtin.table_to_tensor",
            "builtin.tensor_to_table",
            "builtin.debug_print",
            "builtin.pack",
            "builtin.unpack",
            # crypto
            "crypto.keygen",
            "crypto.enc",
            "crypto.dec",
            "crypto.kem_keygen",
            "crypto.kem_derive",
            "crypto.hkdf",
            # phe
            "phe.keygen",
            "phe.encrypt",
            "phe.mul",
            "phe.add",
            "phe.decrypt",
            "phe.dot",
            "phe.gather",
            "phe.scatter",
            "phe.concat",
            "phe.reshape",
            "phe.transpose",
            # spu
            "spu.seed_env",
            "spu.makeshares",
            "spu.reconstruct",
            "spu.run_pphlo",
            # stablehlo
            "mlir.stablehlo",
        )
    },
    # sql: the generic SQL op is served by the DuckDB-specific kernel id.
    "sql.run": "duckdb.run_sql",
    # tee
    **{op: op for op in ("tee.quote", "tee.attest")},
}
|
98
|
+
|
99
|
+
|
100
|
+
# --- RuntimeContext ---


@dataclass
class RuntimeContext:
    """Per-participant backend runtime that dispatches ``PFunction``s to kernels.

    Construction imports the kernel implementation modules (registering their
    kernels) and installs op -> kernel bindings. ``run_kernel`` does three
    things:

    - validates the arguments against ``pfunc.ins_info``;
    - runs the bound kernel with a ``KernelContext`` installed;
    - normalizes the kernel's result against ``pfunc.outs_info``.

    NOTE: bindings are written to the process-global registry in
    ``mplang.backend.base`` (via ``bind_op``). Binding through one
    RuntimeContext therefore affects every RuntimeContext in the process.
    """

    # Participant rank, forwarded to kernels as KernelContext.rank.
    rank: int
    # Number of participants, forwarded as KernelContext.world_size.
    world_size: int
    # Optional op_type -> kernel_id overrides, layered on top of the defaults.
    bindings: Mapping[str, str] | None = None
    # Backend namespace -> pocket; shared with kernels as KernelContext.state.
    state: dict[str, dict[str, Any]] = field(default_factory=dict)
    # Runtime-level cache; shared with kernels as KernelContext.cache.
    cache: dict[str, Any] = field(default_factory=dict)
    # Execution statistics. "op_calls" maps fn_type to the number of kernel
    # calls that returned without raising.
    stats: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Register all kernels and bind every default op plus any overrides.

        Raises:
            KeyError: if a binding refers to an unregistered kernel_id.
        """
        _ensure_impl_imported()
        # ``bindings`` are overrides. Layer them on top of the defaults so a
        # partial mapping does not leave the remaining default ops unbound.
        # With ``bindings=None`` this binds exactly the defaults.
        effective = dict(_DEFAULT_BINDINGS)
        if self.bindings is not None:
            effective.update(self.bindings)
        for op, kid in effective.items():
            bind_op(op, kid)
        # Initialize stats pocket
        self.stats.setdefault("op_calls", {})

    def run_kernel(self, pfunc: PFunction, arg_list: list[Any]) -> list[Any]:
        """Execute the kernel bound to ``pfunc.fn_type`` and return its outputs.

        Args:
            pfunc: Function descriptor. ``fn_type`` selects the kernel;
                ``ins_info`` / ``outs_info`` describe the expected inputs and
                outputs.
            arg_list: Positional arguments, one per entry of ``pfunc.ins_info``.

        Returns:
            A list with exactly ``len(pfunc.outs_info)`` values.

        Raises:
            NotImplementedError: no kernel is bound to ``pfunc.fn_type``.
            ValueError: argument count, argument shape/dtype/columns, or
                output count does not match ``pfunc``.
            TypeError: an argument or the result has the wrong kind of value.
        """
        fn_type = pfunc.fn_type
        spec = get_kernel_for_op(fn_type)
        fn = spec.fn
        if len(arg_list) != len(pfunc.ins_info):
            raise ValueError(
                f"kernel {fn_type} arg count mismatch: got {len(arg_list)}, expect {len(pfunc.ins_info)}"
            )
        for idx, (ins_spec, val) in enumerate(
            zip(pfunc.ins_info, arg_list, strict=True)
        ):
            # Other spec kinds are passed through without validation.
            if isinstance(ins_spec, TableType):
                _validate_table_arg(fn_type, idx, ins_spec, val)
            elif isinstance(ins_spec, TensorType):
                _validate_tensor_arg(fn_type, idx, ins_spec, val)
        # install kernel context so the kernel can call cur_kctx()
        kctx = KernelContext(
            rank=self.rank,
            world_size=self.world_size,
            state=self.state,
            cache=self.cache,
        )
        token = base._CTX_VAR.set(kctx)  # type: ignore[attr-defined]
        try:
            raw = fn(pfunc, *arg_list)
        finally:
            # Always restore the previous context, even if the kernel raised.
            base._CTX_VAR.reset(token)  # type: ignore[attr-defined]
        # Stats (best effort)
        try:
            op_calls = self.stats.setdefault("op_calls", {})
            op_calls[fn_type] = op_calls.get(fn_type, 0) + 1
        except Exception:  # pragma: no cover - never raise due to stats
            pass
        expected = len(pfunc.outs_info)
        if expected == 0:
            # Structural check instead of ``raw in (None, (), [])``: that
            # membership test compares with ``==``. For array-like results
            # (e.g. numpy arrays) it raises an "ambiguous truth value" error
            # instead of reporting the arity mismatch below.
            if raw is None or (isinstance(raw, (tuple, list)) and not raw):
                return []
            raise ValueError(
                f"kernel {fn_type} should return no values; got {type(raw).__name__}"
            )
        if expected == 1:
            # A single output may be returned bare or wrapped in a 1-sequence.
            if isinstance(raw, (tuple, list)):
                if len(raw) != 1:
                    raise ValueError(
                        f"kernel {fn_type} produced {len(raw)} outputs, expected 1"
                    )
                return [raw[0]]
            return [raw]
        if not isinstance(raw, (tuple, list)):
            raise TypeError(
                f"kernel {fn_type} must return sequence (len={expected}), got {type(raw).__name__}"
            )
        if len(raw) != expected:
            raise ValueError(
                f"kernel {fn_type} produced {len(raw)} outputs, expected {expected}"
            )
        return list(raw)

    def reset(self) -> None:
        """Clear kernel state and cache; stats and op bindings are untouched."""
        self.state.clear()
        self.cache.clear()

    # ---- explicit (re)binding API ----
    def bind_op(self, op_type: str, kernel_id: str, *, force: bool = False) -> None:
        """Bind an operation to a kernel at runtime.

        force=False (default) preserves any existing binding to avoid accidental
        silent overrides. Use ``rebind_op`` or ``force=True`` to intentionally
        change a binding.
        """
        base.bind_op(op_type, kernel_id, force=force)

    def rebind_op(self, op_type: str, kernel_id: str) -> None:
        """Force rebind an operation to a different kernel (shorthand)."""
        base.bind_op(op_type, kernel_id, force=True)
|
200
|
+
|
201
|
+
|
202
|
+
def _validate_table_arg(
    fn_type: str, arg_index: int, spec: TableType, value: Any
) -> None:
    """Check a table-typed kernel argument against its declared ``TableType``.

    Only the column *count* is compared; column names and types are not
    inspected here.

    Raises:
        TypeError: ``value`` is not ``TableLike``.
        ValueError: ``value`` has a different number of columns than ``spec``.
    """
    where = f"kernel {fn_type} input[{arg_index}]"
    if not isinstance(value, TableLike):
        raise TypeError(f"{where} expects TableLike, got {type(value).__name__}")
    got, want = len(value.columns), len(spec.columns)
    if got != want:
        raise ValueError(
            f"{where} column count mismatch: got {got}, expected {want}"
        )
|
213
|
+
|
214
|
+
|
215
|
+
def _validate_tensor_arg(
    fn_type: str, arg_index: int, spec: TensorType, value: Any
) -> None:
    """Check a tensor-typed kernel argument against its declared ``TensorType``.

    Python scalars (int/float/bool/complex) are treated as rank-0 values whose
    dtype is their Python type; anything else must be ``TensorLike``. The rank
    must match exactly. Each spec dimension >= 0 must match the value's
    dimension, while negative spec dimensions act as wildcards. The dtype
    (after ``DType.from_any``) must equal ``spec.dtype``.

    A spec of shape ``(-1, 0)`` with dtype ``UINT8`` marks a backend-only
    handle (e.g. PHE keys) and bypasses every check.

    Raises:
        TypeError: ``value`` is neither a Python scalar nor ``TensorLike``, or
            its dtype object cannot be converted by ``DType.from_any``.
        ValueError: rank, a fixed dimension, or the dtype does not match.
    """
    # Backend-only handle sentinel: nothing structural to validate.
    if tuple(spec.shape) == (-1, 0) and spec.dtype == UINT8:
        return

    where = f"kernel {fn_type} input[{arg_index}]"
    got_shape: tuple[Any, ...]
    if isinstance(value, (int, float, bool, complex)):
        got_shape, raw_dtype = (), type(value)
    elif isinstance(value, TensorLike):
        got_shape = getattr(value, "shape", ())
        raw_dtype = getattr(value, "dtype", None)
    else:
        raise TypeError(f"{where} expects TensorLike, got {type(value).__name__}")

    if len(spec.shape) != len(got_shape):
        raise ValueError(
            f"{where} rank mismatch: got {got_shape}, expected {spec.shape}"
        )

    for dim_idx, (want, got) in enumerate(zip(spec.shape, got_shape, strict=True)):
        if want >= 0 and want != got:
            raise ValueError(
                f"{where} shape mismatch at dim {dim_idx}: got {got}, expected {want}"
            )

    try:
        got_dtype = DType.from_any(raw_dtype)
    except (ValueError, TypeError):  # pragma: no cover
        raise TypeError(
            f"{where} has unsupported dtype object {raw_dtype!r}"
        ) from None
    if got_dtype != spec.dtype:
        raise ValueError(
            f"{where} dtype mismatch: got {got_dtype}, expected {spec.dtype}"
        )
|
@@ -186,16 +186,18 @@ def _spu_reconstruct(pfunc: PFunction, *args: Any) -> Any:
|
|
186
186
|
return reconstructed
|
187
187
|
|
188
188
|
|
189
|
-
@kernel_def("
|
189
|
+
@kernel_def("spu.run_pphlo")
|
190
190
|
def _spu_run_mlir(pfunc: PFunction, *args: Any) -> Any:
|
191
|
-
"""Execute compiled SPU function (
|
191
|
+
"""Execute compiled SPU function (spu.run_pphlo) and return SpuValue outputs.
|
192
192
|
|
193
193
|
Participation rule: a rank participates iff its entry in the stored
|
194
194
|
link_ctx list is non-None. This allows us to allocate a world-sized list
|
195
195
|
(indexed by global rank) and simply assign None for non-SPU parties.
|
196
196
|
"""
|
197
|
-
if pfunc.fn_type != "
|
198
|
-
raise ValueError(
|
197
|
+
if pfunc.fn_type != "spu.run_pphlo":
|
198
|
+
raise ValueError(
|
199
|
+
f"Unsupported format: {pfunc.fn_type}. Expected 'spu.run_pphlo'"
|
200
|
+
)
|
199
201
|
|
200
202
|
cfg, _ = _get_spu_config_and_world()
|
201
203
|
pocket = _get_spu_pocket()
|
@@ -27,7 +27,7 @@ from __future__ import annotations
|
|
27
27
|
from dataclasses import dataclass
|
28
28
|
from typing import Any, Protocol
|
29
29
|
|
30
|
-
from mplang.backend.
|
30
|
+
from mplang.backend.context import RuntimeContext
|
31
31
|
from mplang.core.comm import ICommunicator
|
32
32
|
from mplang.core.expr.ast import (
|
33
33
|
AccessExpr,
|
@@ -56,7 +56,7 @@ class IEvaluator(Protocol):
|
|
56
56
|
backend state via evaluator.runtime.run_kernel(...).
|
57
57
|
"""
|
58
58
|
|
59
|
-
runtime:
|
59
|
+
runtime: RuntimeContext
|
60
60
|
|
61
61
|
def evaluate(self, root: Expr, env: dict[str, Any] | None = None) -> list[Any]: ...
|
62
62
|
|
@@ -72,7 +72,7 @@ class EvalSemantic:
|
|
72
72
|
rank: int
|
73
73
|
env: dict[str, Any]
|
74
74
|
comm: ICommunicator
|
75
|
-
runtime:
|
75
|
+
runtime: RuntimeContext
|
76
76
|
|
77
77
|
# ------------------------------ Shared helpers (semantics) ------------------------------
|
78
78
|
def _should_run(self, rmask: Mask | None, args: list[Any]) -> bool:
|
@@ -205,7 +205,7 @@ class RecursiveEvaluator(EvalSemantic, ExprVisitor):
|
|
205
205
|
rank: int,
|
206
206
|
env: dict[str, Any],
|
207
207
|
comm: ICommunicator,
|
208
|
-
runtime:
|
208
|
+
runtime: RuntimeContext,
|
209
209
|
) -> None:
|
210
210
|
super().__init__(rank, env, comm, runtime)
|
211
211
|
self._cache: dict[int, Any] = {} # Cache based on expr id
|
@@ -380,7 +380,7 @@ class IterativeEvaluator(EvalSemantic):
|
|
380
380
|
rank: int,
|
381
381
|
env: dict[str, Any],
|
382
382
|
comm: ICommunicator,
|
383
|
-
runtime:
|
383
|
+
runtime: RuntimeContext,
|
384
384
|
) -> None:
|
385
385
|
super().__init__(rank, env, comm, runtime)
|
386
386
|
|
@@ -501,7 +501,7 @@ def create_evaluator(
|
|
501
501
|
rank: int,
|
502
502
|
env: dict[str, Any],
|
503
503
|
comm: ICommunicator,
|
504
|
-
runtime:
|
504
|
+
runtime: RuntimeContext,
|
505
505
|
kind: str | None = "iterative",
|
506
506
|
) -> IEvaluator:
|
507
507
|
"""Factory to create an evaluator engine.
|
@@ -129,7 +129,7 @@ class FeModule(ABC):
|
|
129
129
|
- You need compilation/stateful behavior/dynamic routing, multiple PFunctions, or complex capture flows.
|
130
130
|
|
131
131
|
Tips:
|
132
|
-
- Keep routing information in PFunction.fn_type (e.g., "builtin.read", "sql
|
132
|
+
- Keep routing information in PFunction.fn_type (e.g., "builtin.read", "sql.run", "mlir.stablehlo").
|
133
133
|
- Avoid backend-specific logic in kernels; only validate and shape types.
|
134
134
|
- Prefer keyword-only attributes in typed_op kernels for clarity (def op(x: MPObject, *, attr: int)).
|
135
135
|
"""
|
@@ -57,8 +57,9 @@ def ibis2sql(
|
|
57
57
|
outs_info = [_convert(expr.schema())]
|
58
58
|
|
59
59
|
sql = ibis.to_sql(expr, dialect="duckdb")
|
60
|
+
# Emit generic sql.run op; runtime maps to backend-specific kernel.
|
60
61
|
pfn = PFunction(
|
61
|
-
fn_type="sql
|
62
|
+
fn_type="sql.run",
|
62
63
|
fn_name=fn_name,
|
63
64
|
fn_text=sql,
|
64
65
|
ins_info=tuple(ins_info),
|
@@ -94,9 +94,10 @@ def _compile_jax(
|
|
94
94
|
*args: Any,
|
95
95
|
**kwargs: Any,
|
96
96
|
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
|
97
|
-
"""Compile a JAX function into SPU pphlo MLIR
|
97
|
+
"""Compile a JAX function into SPU pphlo MLIR and wrap as PFunction.
|
98
98
|
|
99
|
-
|
99
|
+
Resulting PFunction uses fn_type 'spu.run_pphlo'.
|
100
|
+
"""
|
100
101
|
|
101
102
|
def is_variable(arg: Any) -> bool:
|
102
103
|
return isinstance(arg, MPObject)
|
@@ -132,7 +133,7 @@ def _compile_jax(
|
|
132
133
|
executable_code = executable_code.decode("utf-8")
|
133
134
|
|
134
135
|
pfunc = PFunction(
|
135
|
-
fn_type="
|
136
|
+
fn_type="spu.run_pphlo",
|
136
137
|
ins_info=tuple(TensorType.from_obj(x) for x in in_vars),
|
137
138
|
outs_info=tuple(output_tensor_infos),
|
138
139
|
fn_name=get_fn_name(fn),
|