mplang-nightly 0.1.dev163__tar.gz → 0.1.dev165__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/PKG-INFO +1 -1
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/core/expr/evaluator.py +55 -15
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/device.py +4 -18
- mplang_nightly-0.1.dev165/mplang/kernels/__init__.py +41 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/kernels/builtin.py +91 -56
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/kernels/crypto.py +39 -30
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/kernels/mock_tee.py +10 -11
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/kernels/phe.py +238 -39
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/kernels/spu.py +134 -45
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/kernels/sql_duckdb.py +8 -13
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/kernels/stablehlo.py +15 -9
- mplang_nightly-0.1.dev165/mplang/kernels/value.py +626 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/ops/tee.py +7 -21
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/protos/v1alpha1/mpir_pb2.pyi +71 -21
- mplang_nightly-0.1.dev165/mplang/protos/v1alpha1/value_pb2.py +34 -0
- mplang_nightly-0.1.dev165/mplang/protos/v1alpha1/value_pb2.pyi +169 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/runtime/client.py +19 -8
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/runtime/communicator.py +11 -4
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/runtime/driver.py +16 -1
- mplang_nightly-0.1.dev165/mplang/runtime/link_comm.py +78 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/runtime/server.py +30 -29
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/runtime/session.py +9 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/runtime/simulation.py +4 -5
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/simp/__init__.py +1 -1
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/pyproject.toml +5 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/conftest.py +2 -2
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/integration/test_symbols_roundtrip.py +5 -1
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/kernels/test_builtin.py +34 -20
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/kernels/test_debug_print.py +8 -3
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/kernels/test_kernel_binding.py +41 -15
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/kernels/test_phe.py +18 -3
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/kernels/test_spu.py +11 -10
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/kernels/test_sql_duckdb.py +5 -1
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/kernels/test_stablehlo.py +8 -2
- mplang_nightly-0.1.dev165/tests/kernels/test_value.py +324 -0
- mplang_nightly-0.1.dev165/tests/kernels/test_value_serde.py +377 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/ops/test_builtin_pack.py +7 -4
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/ops/test_crypto_tee.py +7 -11
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/runtime/test_communicator.py +22 -13
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/runtime/test_server.py +12 -10
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tutorials/9_tee.py +2 -3
- mplang_nightly-0.1.dev163/mplang/runtime/link_comm.py +0 -131
- mplang_nightly-0.1.dev163/tests/runtime/__init__.py +0 -13
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/.gitignore +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/LICENSE +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/README.md +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/examples/conf/3pc.yaml +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/examples/stax_nn/README.md +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/examples/stax_nn/models.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/examples/stax_nn/stax_nn.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/examples/xgboost/hist_jax.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/examples/xgboost/hist_jax_test.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/examples/xgboost/naive_np.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/examples/xgboost/readme.md +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/examples/xgboost/sgb.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/examples/xgboost/sgb_test.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/hatch_build.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/__init__.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/analysis/__init__.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/analysis/diagram.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/api.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/core/__init__.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/core/cluster.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/core/comm.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/core/context_mgr.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/core/dtype.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/core/expr/__init__.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/core/expr/ast.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/core/expr/printer.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/core/expr/transformer.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/core/expr/utils.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/core/expr/visitor.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/core/expr/walk.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/core/interp.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/core/mask.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/core/mpir.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/core/mpobject.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/core/mptype.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/core/pfunc.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/core/primitive.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/core/table.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/core/tensor.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/core/tracer.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/kernels/base.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/kernels/context.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/ops/__init__.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/ops/base.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/ops/builtin.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/ops/crypto.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/ops/ibis_cc.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/ops/jax_cc.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/ops/phe.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/ops/spu.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/ops/sql.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/protos/v1alpha1/mpir_pb2.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/runtime/__init__.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/runtime/cli.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/runtime/data_providers.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/runtime/exceptions.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/runtime/http_api.md +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/simp/mpi.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/simp/random.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/simp/smpc.py +0 -0
- {mplang_nightly-0.1.dev163/mplang/kernels → mplang_nightly-0.1.dev165/mplang/utils}/__init__.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/utils/crypto.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/utils/func_utils.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/utils/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/mplang/utils/table_utils.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/__init__.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/analysis/test_diagram.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/core/__init__.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/core/expr/__init__.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/core/expr/conftest.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/core/expr/test_ast.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/core/expr/test_printer.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/core/expr/test_utils.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/core/expr/test_walk.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/core/test_cluster.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/core/test_dtype.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/core/test_mask.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/core/test_mpir.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/core/test_mptype.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/core/test_primitive.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/core/test_table.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/core/test_tensor.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/core/test_tracer.py +0 -0
- {mplang_nightly-0.1.dev163/mplang/utils → mplang_nightly-0.1.dev165/tests/device}/__init__.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/device/test_device_basic.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/integration/README.md +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/integration/test_crypto_roundtrip.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/integration/test_http_e2e.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/integration/test_tutorials.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/integration/test_unused_param_integration.py +0 -0
- {mplang_nightly-0.1.dev163/tests/device → mplang_nightly-0.1.dev165/tests/ops}/__init__.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/ops/dummy.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/ops/test_feop_base.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/ops/test_ibis.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/ops/test_ibis_cc.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/ops/test_jax_cc.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/ops/test_phe.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/ops/test_spu.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/ops/test_spu_defensive.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/ops/test_sql.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/ops/test_table_tensor_conversion.py +0 -0
- {mplang_nightly-0.1.dev163/tests/ops → mplang_nightly-0.1.dev165/tests/runtime}/__init__.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/runtime/test_cli.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/runtime/test_driver.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/runtime/test_simulation.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/simp/test_mpi.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/simp/test_random.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/simp/test_simp.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/simp/test_smpc.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/simp/test_sugar.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/utils/__init__.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/utils/server_fixtures.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/utils/test_func_utils.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/utils/test_spu_utils.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tests/utils/test_table_utils.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tutorials/0_basic.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tutorials/10_analysis.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tutorials/1_condition.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tutorials/2_whileloop.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tutorials/3_device.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tutorials/4_simulation.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tutorials/5_ir_dump.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tutorials/6_advanced.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tutorials/7_stdio.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tutorials/8_phe.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tutorials/__init__.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tutorials/pitfalls/late_binding.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tutorials/pitfalls/rand.py +0 -0
- {mplang_nightly-0.1.dev163 → mplang_nightly-0.1.dev165}/tutorials/run.sh +0 -0
@@ -27,6 +27,8 @@ from __future__ import annotations
|
|
27
27
|
from dataclasses import dataclass
|
28
28
|
from typing import Any, Protocol
|
29
29
|
|
30
|
+
import numpy as np
|
31
|
+
|
30
32
|
from mplang.core.comm import ICommunicator
|
31
33
|
from mplang.core.expr.ast import (
|
32
34
|
AccessExpr,
|
@@ -47,6 +49,7 @@ from mplang.core.expr.walk import walk_dataflow
|
|
47
49
|
from mplang.core.mask import Mask
|
48
50
|
from mplang.core.pfunc import PFunction
|
49
51
|
from mplang.kernels.context import RuntimeContext
|
52
|
+
from mplang.kernels.value import Value
|
50
53
|
|
51
54
|
|
52
55
|
class IEvaluator(Protocol):
|
@@ -149,12 +152,12 @@ class EvalSemantic:
|
|
149
152
|
def _as_optional_int(val: Any) -> int | None:
|
150
153
|
"""Convert a value to int if possible, preserving None.
|
151
154
|
|
152
|
-
Handles Python ints, numpy
|
155
|
+
Handles Python ints, floats, numpy scalar types (e.g., np.int32, np.float64), and None.
|
156
|
+
Uses int(val) for conversion which works with numpy scalars via __int__().
|
153
157
|
"""
|
158
|
+
val = EvalSemantic._unwrap_value(val)
|
154
159
|
if val is None:
|
155
160
|
return None
|
156
|
-
if hasattr(val, "item"):
|
157
|
-
return int(val.item())
|
158
161
|
return int(val)
|
159
162
|
|
160
163
|
def _simple_allgather(self, value: Any) -> list[Any]:
|
@@ -167,6 +170,7 @@ class EvalSemantic:
|
|
167
170
|
Returns a list of length world_size with entries ordered by rank.
|
168
171
|
"""
|
169
172
|
ws = self.comm.world_size
|
173
|
+
value = self._unwrap_value(value)
|
170
174
|
# Trivial fast-path
|
171
175
|
if ws == 1:
|
172
176
|
return [value]
|
@@ -185,7 +189,12 @@ class EvalSemantic:
|
|
185
189
|
|
186
190
|
def _verify_uniform_predicate(self, pred: Any) -> None:
|
187
191
|
# Runtime uniformity check (O(P^2) send/recv emulation).
|
188
|
-
|
192
|
+
# Use Value.to_bool() if available, otherwise unwrap and convert
|
193
|
+
if isinstance(pred, Value):
|
194
|
+
pred_bool = pred.to_bool()
|
195
|
+
else:
|
196
|
+
pred_bool = bool(self._unwrap_value(pred))
|
197
|
+
vals = self._simple_allgather(pred_bool)
|
189
198
|
if not vals:
|
190
199
|
raise ValueError("uniform_cond: empty gather for predicate")
|
191
200
|
first = vals[0]
|
@@ -209,13 +218,33 @@ class EvalSemantic:
|
|
209
218
|
assert len(cond_result) == 1, (
|
210
219
|
f"Condition function must return a single value, got {cond_result}"
|
211
220
|
)
|
212
|
-
|
213
|
-
if
|
221
|
+
cond_val = cond_result[0]
|
222
|
+
if cond_val is None:
|
214
223
|
raise RuntimeError(
|
215
224
|
"while_loop condition produced None on rank "
|
216
225
|
f"{self.rank}; ensure the predicate yields a boolean for every party."
|
217
226
|
)
|
218
|
-
|
227
|
+
# Use Value.to_bool() if available for cleaner conversion
|
228
|
+
if isinstance(cond_val, Value):
|
229
|
+
return cond_val.to_bool()
|
230
|
+
return bool(self._unwrap_value(cond_val))
|
231
|
+
|
232
|
+
@staticmethod
|
233
|
+
def _unwrap_value(value: Any) -> Any:
|
234
|
+
"""Convert Value payloads to numpy/python equivalents when possible."""
|
235
|
+
if value is None:
|
236
|
+
return None
|
237
|
+
if isinstance(value, Value):
|
238
|
+
# Try to_numpy first for broader compatibility
|
239
|
+
to_numpy = getattr(value, "to_numpy", None)
|
240
|
+
if callable(to_numpy):
|
241
|
+
arr = to_numpy()
|
242
|
+
if isinstance(arr, np.ndarray):
|
243
|
+
if arr.size == 1:
|
244
|
+
return arr.item()
|
245
|
+
return arr
|
246
|
+
return arr
|
247
|
+
return value
|
219
248
|
|
220
249
|
|
221
250
|
class RecursiveEvaluator(EvalSemantic, ExprVisitor):
|
@@ -296,15 +325,21 @@ class RecursiveEvaluator(EvalSemantic, ExprVisitor):
|
|
296
325
|
* Add optional static uniform inference (data provenance) to elide the
|
297
326
|
runtime check when predicate uniformity is provable at trace time.
|
298
327
|
"""
|
299
|
-
|
300
|
-
if
|
328
|
+
pred_val = self._value(expr.pred)
|
329
|
+
if pred_val is None:
|
301
330
|
return [None] * len(expr.mptypes)
|
302
331
|
|
303
332
|
if expr.verify_uniform:
|
304
|
-
self._verify_uniform_predicate(
|
333
|
+
self._verify_uniform_predicate(pred_val)
|
334
|
+
|
335
|
+
# Convert to bool using Value.to_bool() if available
|
336
|
+
if isinstance(pred_val, Value):
|
337
|
+
pred = pred_val.to_bool()
|
338
|
+
else:
|
339
|
+
pred = bool(self._unwrap_value(pred_val))
|
305
340
|
|
306
341
|
# Only evaluate selected branch locally
|
307
|
-
if pred:
|
342
|
+
if bool(pred):
|
308
343
|
then_call = CallExpr(expr.then_fn, expr.args)
|
309
344
|
return self._values(then_call)
|
310
345
|
else:
|
@@ -435,15 +470,20 @@ class IterativeEvaluator(EvalSemantic):
|
|
435
470
|
res = self._iter_eval_graph(node.fn.body, {**env, **sub_env})
|
436
471
|
symbols[id(node)] = res
|
437
472
|
elif isinstance(node, CondExpr):
|
438
|
-
|
473
|
+
pred_val = self._first(symbols[id(node.pred)])
|
439
474
|
arg_vals = [self._first(symbols[id(a)]) for a in node.args]
|
440
|
-
if
|
475
|
+
if pred_val is None:
|
441
476
|
symbols[id(node)] = [None] * len(node.mptypes)
|
442
477
|
else:
|
443
478
|
# Optional uniform verification identical to recursive evaluator (DRY helper).
|
444
479
|
if node.verify_uniform:
|
445
|
-
self._verify_uniform_predicate(
|
446
|
-
|
480
|
+
self._verify_uniform_predicate(pred_val)
|
481
|
+
# Convert to bool using Value.to_bool() if available
|
482
|
+
if isinstance(pred_val, Value):
|
483
|
+
pred = pred_val.to_bool()
|
484
|
+
else:
|
485
|
+
pred = bool(self._unwrap_value(pred_val))
|
486
|
+
if pred:
|
447
487
|
sub_env = dict(zip(node.then_fn.params, arg_vals, strict=True))
|
448
488
|
res = self._iter_eval_graph(
|
449
489
|
node.then_fn.body, {**env, **sub_env}
|
@@ -207,15 +207,8 @@ def _d2d(to_dev_id: str, obj: MPObject) -> MPObject:
|
|
207
207
|
assert len(frm_dev.members) == 1 and len(to_dev.members) == 1
|
208
208
|
frm_rank = frm_dev.members[0].rank
|
209
209
|
tee_rank = to_dev.members[0].rank
|
210
|
-
platform = to_dev.config.get("platform")
|
211
|
-
if not platform:
|
212
|
-
raise ValueError(
|
213
|
-
f"TEE device '{to_dev_id}' is missing 'platform' in its config."
|
214
|
-
)
|
215
210
|
# Ensure sessions (both directions) exist for this PPU<->TEE pair
|
216
|
-
sess_p, sess_t = _ensure_tee_session(
|
217
|
-
frm_dev_id, to_dev_id, frm_rank, tee_rank, platform
|
218
|
-
)
|
211
|
+
sess_p, sess_t = _ensure_tee_session(frm_dev_id, to_dev_id, frm_rank, tee_rank)
|
219
212
|
# Bytes-only path: pack -> enc -> p2p -> dec -> unpack (with static out type)
|
220
213
|
obj_ty = TensorType.from_obj(obj)
|
221
214
|
b = simp.runAt(frm_rank, builtin.pack)(obj)
|
@@ -229,15 +222,8 @@ def _d2d(to_dev_id: str, obj: MPObject) -> MPObject:
|
|
229
222
|
assert len(frm_dev.members) == 1 and len(to_dev.members) == 1
|
230
223
|
tee_rank = frm_dev.members[0].rank
|
231
224
|
ppu_rank = to_dev.members[0].rank
|
232
|
-
platform = frm_dev.config.get("platform")
|
233
|
-
if not platform:
|
234
|
-
raise ValueError(
|
235
|
-
f"TEE device '{frm_dev_id}' is missing 'platform' in its config."
|
236
|
-
)
|
237
225
|
# Ensure bidirectional session established for this pair
|
238
|
-
sess_p, sess_t = _ensure_tee_session(
|
239
|
-
to_dev_id, frm_dev_id, ppu_rank, tee_rank, platform
|
240
|
-
)
|
226
|
+
sess_p, sess_t = _ensure_tee_session(to_dev_id, frm_dev_id, ppu_rank, tee_rank)
|
241
227
|
obj_ty = TensorType.from_obj(obj)
|
242
228
|
b = simp.runAt(tee_rank, builtin.pack)(obj)
|
243
229
|
ct = simp.runAt(tee_rank, crypto.enc)(b, sess_t)
|
@@ -259,7 +245,7 @@ def _d2d(to_dev_id: str, obj: MPObject) -> MPObject:
|
|
259
245
|
|
260
246
|
|
261
247
|
def _ensure_tee_session(
|
262
|
-
frm_dev_id: str, to_dev_id: str, frm_rank: int, tee_rank: int
|
248
|
+
frm_dev_id: str, to_dev_id: str, frm_rank: int, tee_rank: int
|
263
249
|
) -> tuple[MPObject, MPObject]:
|
264
250
|
"""Ensure a TEE session (sess_p at sender, sess_t at TEE) exists.
|
265
251
|
|
@@ -281,7 +267,7 @@ def _ensure_tee_session(
|
|
281
267
|
|
282
268
|
# 2) Send quote to sender and attest to obtain TEE pk
|
283
269
|
quote_at_sender = mpi.p2p(tee_rank, frm_rank, quote)
|
284
|
-
tee_pk_at_sender = simp.runAt(frm_rank, tee.attest)(quote_at_sender
|
270
|
+
tee_pk_at_sender = simp.runAt(frm_rank, tee.attest)(quote_at_sender)
|
285
271
|
|
286
272
|
# 3) Sender generates its ephemeral keypair and sends its pk to TEE
|
287
273
|
v_sk, v_pk = simp.runAt(frm_rank, crypto.kem_keygen)(_TEE_KEM_SUITE)
|
@@ -0,0 +1,41 @@
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from mplang.kernels.value import (
|
16
|
+
BytesBlob,
|
17
|
+
TableValue,
|
18
|
+
TensorValue,
|
19
|
+
Value,
|
20
|
+
ValueDecodeError,
|
21
|
+
ValueError,
|
22
|
+
decode_value,
|
23
|
+
encode_value,
|
24
|
+
is_value_envelope,
|
25
|
+
list_value_kinds,
|
26
|
+
register_value,
|
27
|
+
)
|
28
|
+
|
29
|
+
__all__ = [
|
30
|
+
"BytesBlob",
|
31
|
+
"TableValue",
|
32
|
+
"TensorValue",
|
33
|
+
"Value",
|
34
|
+
"ValueDecodeError",
|
35
|
+
"ValueError",
|
36
|
+
"decode_value",
|
37
|
+
"encode_value",
|
38
|
+
"is_value_envelope",
|
39
|
+
"list_value_kinds",
|
40
|
+
"register_value",
|
41
|
+
]
|
@@ -14,38 +14,25 @@
|
|
14
14
|
|
15
15
|
from __future__ import annotations
|
16
16
|
|
17
|
-
from typing import Any
|
18
|
-
|
19
17
|
import numpy as np
|
20
|
-
import pandas as pd
|
21
18
|
|
22
19
|
from mplang.core.pfunc import PFunction
|
23
20
|
from mplang.core.table import TableType
|
24
21
|
from mplang.core.tensor import TensorType
|
25
22
|
from mplang.kernels.base import cur_kctx, kernel_def
|
23
|
+
from mplang.kernels.value import TableValue, TensorValue, Value
|
26
24
|
from mplang.runtime.data_providers import get_provider, resolve_uri
|
27
25
|
from mplang.utils import table_utils
|
28
26
|
|
29
27
|
|
30
|
-
def _to_numpy(obj: Any) -> np.ndarray: # minimal helper to avoid duplicating logic
|
31
|
-
if isinstance(obj, np.ndarray):
|
32
|
-
return obj
|
33
|
-
if hasattr(obj, "numpy"):
|
34
|
-
try:
|
35
|
-
return np.asarray(obj.numpy()) # type: ignore
|
36
|
-
except Exception:
|
37
|
-
pass
|
38
|
-
return np.asarray(obj)
|
39
|
-
|
40
|
-
|
41
28
|
@kernel_def("builtin.identity")
|
42
|
-
def _identity(pfunc: PFunction, value:
|
29
|
+
def _identity(pfunc: PFunction, value: Value) -> Value:
|
43
30
|
# Runtime guarantees exactly one argument; no extra arity checks here.
|
44
31
|
return value
|
45
32
|
|
46
33
|
|
47
34
|
@kernel_def("builtin.read")
|
48
|
-
def _read(pfunc: PFunction) ->
|
35
|
+
def _read(pfunc: PFunction) -> Value:
|
49
36
|
path = pfunc.attrs.get("path")
|
50
37
|
if path is None:
|
51
38
|
raise ValueError("missing path attr for builtin.read")
|
@@ -56,13 +43,25 @@ def _read(pfunc: PFunction) -> Any:
|
|
56
43
|
raise NotImplementedError(f"no resource provider for scheme: {uri.scheme}")
|
57
44
|
ctx = cur_kctx()
|
58
45
|
try:
|
59
|
-
|
46
|
+
data = prov.read(uri, out_t, ctx=ctx)
|
60
47
|
except Exception as e: # pragma: no cover - provider errors
|
61
48
|
raise RuntimeError(f"builtin.read failed: {e}") from e
|
62
49
|
|
50
|
+
if isinstance(out_t, TableType):
|
51
|
+
if isinstance(data, TableValue):
|
52
|
+
return data
|
53
|
+
return TableValue(data)
|
54
|
+
if isinstance(out_t, TensorType):
|
55
|
+
if isinstance(data, TensorValue):
|
56
|
+
return data
|
57
|
+
return TensorValue(np.asarray(data))
|
58
|
+
raise TypeError(
|
59
|
+
f"builtin.read only supports TableType/TensorType outputs, got {type(out_t).__name__}"
|
60
|
+
)
|
61
|
+
|
63
62
|
|
64
63
|
@kernel_def("builtin.write")
|
65
|
-
def _write(pfunc: PFunction, obj:
|
64
|
+
def _write(pfunc: PFunction, obj: Value) -> Value:
|
66
65
|
path = pfunc.attrs.get("path")
|
67
66
|
if path is None:
|
68
67
|
raise ValueError("missing path attr for builtin.write")
|
@@ -70,16 +69,18 @@ def _write(pfunc: PFunction, obj: Any) -> Any:
|
|
70
69
|
prov = get_provider(uri.scheme)
|
71
70
|
if prov is None:
|
72
71
|
raise NotImplementedError(f"no resource provider for scheme: {uri.scheme}")
|
72
|
+
# Pass Value object directly to provider - let provider decide how to handle it
|
73
73
|
ctx = cur_kctx()
|
74
74
|
try:
|
75
75
|
prov.write(uri, obj, ctx=ctx)
|
76
|
-
return obj
|
77
76
|
except Exception as e: # pragma: no cover
|
78
77
|
raise RuntimeError(f"builtin.write failed: {e}") from e
|
78
|
+
return obj
|
79
79
|
|
80
80
|
|
81
81
|
@kernel_def("builtin.constant")
|
82
|
-
def _constant(pfunc: PFunction) ->
|
82
|
+
def _constant(pfunc: PFunction) -> Value:
|
83
|
+
"""Return constants as Value types (TensorValue or TableValue)."""
|
83
84
|
data_bytes = pfunc.attrs.get("data_bytes")
|
84
85
|
if data_bytes is None:
|
85
86
|
raise ValueError("missing data_bytes attr for builtin.constant")
|
@@ -89,69 +90,86 @@ def _constant(pfunc: PFunction) -> Any:
|
|
89
90
|
if fmt != "bytes[csv]":
|
90
91
|
raise ValueError(f"unsupported table constant format {fmt}")
|
91
92
|
df = table_utils.csv_to_dataframe(data_bytes)
|
92
|
-
return df
|
93
|
+
return TableValue(df)
|
93
94
|
# tensor path
|
94
95
|
shape = out_t.shape # type: ignore[attr-defined,union-attr]
|
95
96
|
dtype = out_t.dtype.numpy_dtype() # type: ignore[attr-defined,union-attr]
|
96
97
|
arr = np.frombuffer(data_bytes, dtype=dtype).reshape(shape)
|
97
|
-
return arr
|
98
|
+
return TensorValue(arr)
|
98
99
|
|
99
100
|
|
100
101
|
@kernel_def("builtin.rank")
|
101
|
-
def _rank(pfunc: PFunction) ->
|
102
|
+
def _rank(pfunc: PFunction) -> TensorValue:
|
103
|
+
"""Return rank as TensorValue."""
|
102
104
|
ctx = cur_kctx()
|
103
|
-
|
105
|
+
arr = np.array(ctx.rank, dtype=np.uint64)
|
106
|
+
return TensorValue(arr)
|
104
107
|
|
105
108
|
|
106
109
|
@kernel_def("builtin.prand")
|
107
|
-
def _prand(pfunc: PFunction) ->
|
110
|
+
def _prand(pfunc: PFunction) -> TensorValue:
|
111
|
+
"""Return random data as TensorValue."""
|
108
112
|
shape = pfunc.attrs.get("shape", ())
|
109
113
|
rng = np.random.default_rng()
|
110
114
|
info = np.iinfo(np.uint64)
|
111
115
|
data = rng.integers(
|
112
116
|
low=info.min, high=info.max, size=shape, dtype=np.uint64, endpoint=True
|
113
117
|
)
|
114
|
-
return data
|
118
|
+
return TensorValue(data)
|
115
119
|
|
116
120
|
|
117
121
|
@kernel_def("builtin.table_to_tensor")
|
118
|
-
def _table_to_tensor(pfunc: PFunction, table:
|
119
|
-
|
120
|
-
|
121
|
-
if
|
122
|
+
def _table_to_tensor(pfunc: PFunction, table: TableValue) -> TensorValue:
|
123
|
+
"""Convert table to tensor, return as TensorValue."""
|
124
|
+
arrow_table = table.to_arrow()
|
125
|
+
if arrow_table.num_columns == 0:
|
122
126
|
raise ValueError("cannot pack empty table")
|
123
|
-
|
124
|
-
|
127
|
+
# Convert Arrow columns to numpy arrays and stack
|
128
|
+
mat = np.column_stack([
|
129
|
+
arrow_table.column(i).to_numpy() for i in range(arrow_table.num_columns)
|
130
|
+
])
|
131
|
+
return TensorValue(mat)
|
125
132
|
|
126
133
|
|
127
134
|
@kernel_def("builtin.tensor_to_table")
|
128
|
-
def _tensor_to_table(pfunc: PFunction, tensor:
|
129
|
-
|
135
|
+
def _tensor_to_table(pfunc: PFunction, tensor: TensorValue) -> TableValue:
|
136
|
+
"""Convert tensor to table, return as TableValue."""
|
137
|
+
import pyarrow as pa # type: ignore
|
138
|
+
|
139
|
+
arr = tensor.to_numpy()
|
130
140
|
if arr.ndim != 2:
|
131
141
|
raise ValueError("tensor_to_table expects rank-2 array")
|
132
142
|
col_names = pfunc.attrs.get("column_names")
|
133
143
|
if col_names is None:
|
134
144
|
raise ValueError("missing column_names attr")
|
135
|
-
|
136
|
-
|
145
|
+
# Create Arrow table directly from numpy array columns
|
146
|
+
arrays = [pa.array(arr[:, i]) for i in range(arr.shape[1])]
|
147
|
+
arrow_table = pa.table(dict(zip(col_names, arrays, strict=True)))
|
148
|
+
return TableValue(arrow_table)
|
137
149
|
|
138
150
|
|
139
|
-
def _summ(v:
|
151
|
+
def _summ(v: Value) -> str:
|
140
152
|
try:
|
141
|
-
if isinstance(v,
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
153
|
+
if isinstance(v, TableValue):
|
154
|
+
# Use Arrow's native string representation (more efficient)
|
155
|
+
arrow_table = v.to_arrow()
|
156
|
+
# Show first 8 rows
|
157
|
+
preview = arrow_table.slice(0, min(8, arrow_table.num_rows))
|
158
|
+
return str(preview)
|
159
|
+
if isinstance(v, TensorValue):
|
160
|
+
arr = v.to_numpy()
|
161
|
+
return str(
|
162
|
+
np.array2string(
|
163
|
+
arr, threshold=64, edgeitems=3, precision=6, suppress_small=True
|
164
|
+
)
|
147
165
|
)
|
148
|
-
)
|
166
|
+
return repr(v)
|
149
167
|
except Exception as e: # pragma: no cover
|
150
168
|
return f"<unprintable {type(v).__name__}: {e}>"
|
151
169
|
|
152
170
|
|
153
171
|
@kernel_def("builtin.debug_print")
|
154
|
-
def _debug_print(pfunc: PFunction, val:
|
172
|
+
def _debug_print(pfunc: PFunction, val: Value) -> Value:
|
155
173
|
prefix = pfunc.attrs.get("prefix", "")
|
156
174
|
ctx = cur_kctx()
|
157
175
|
print(f"[debug_print][rank={ctx.rank}] {prefix}{_summ(val)}")
|
@@ -159,7 +177,7 @@ def _debug_print(pfunc: PFunction, val: Any) -> Any:
|
|
159
177
|
|
160
178
|
|
161
179
|
@kernel_def("builtin.pack")
|
162
|
-
def _pack(pfunc: PFunction, value:
|
180
|
+
def _pack(pfunc: PFunction, value: Value) -> TensorValue:
|
163
181
|
outs_info = pfunc.outs_info
|
164
182
|
if len(outs_info) != 1:
|
165
183
|
raise ValueError("builtin.pack expects single output type")
|
@@ -169,22 +187,33 @@ def _pack(pfunc: PFunction, value: Any) -> Any:
|
|
169
187
|
if out_ty.dtype.numpy_dtype() != np.uint8:
|
170
188
|
raise TypeError("builtin.pack output dtype must be uint8")
|
171
189
|
|
172
|
-
if isinstance(value,
|
173
|
-
|
174
|
-
|
190
|
+
if isinstance(value, TableValue):
|
191
|
+
# Serialize Arrow table using IPC stream for consistency with Value serde
|
192
|
+
import pyarrow as pa # type: ignore
|
193
|
+
import pyarrow.ipc as pa_ipc # type: ignore
|
194
|
+
|
195
|
+
arrow_table = value.to_arrow()
|
196
|
+
sink = pa.BufferOutputStream()
|
197
|
+
with pa_ipc.new_stream(sink, arrow_table.schema) as writer: # type: ignore[arg-type]
|
198
|
+
writer.write_table(arrow_table) # type: ignore[arg-type]
|
199
|
+
ipc_bytes = sink.getvalue().to_pybytes()
|
200
|
+
return TensorValue(np.frombuffer(ipc_bytes, dtype=np.uint8))
|
201
|
+
|
202
|
+
if isinstance(value, TensorValue):
|
203
|
+
arr = value.to_numpy()
|
204
|
+
return TensorValue(np.frombuffer(arr.tobytes(order="C"), dtype=np.uint8))
|
175
205
|
|
176
|
-
|
177
|
-
return np.frombuffer(arr.tobytes(order="C"), dtype=np.uint8)
|
206
|
+
raise TypeError(f"builtin.pack does not support Value type {type(value).__name__}")
|
178
207
|
|
179
208
|
|
180
209
|
@kernel_def("builtin.unpack")
|
181
|
-
def _unpack(pfunc: PFunction, packed:
|
210
|
+
def _unpack(pfunc: PFunction, packed: TensorValue) -> Value:
|
182
211
|
outs_info = pfunc.outs_info
|
183
212
|
if len(outs_info) != 1:
|
184
213
|
raise ValueError("builtin.unpack expects single output type")
|
185
214
|
out_ty = outs_info[0]
|
186
215
|
|
187
|
-
b = np.
|
216
|
+
b = packed.to_numpy().astype(np.uint8, copy=False).reshape(-1)
|
188
217
|
|
189
218
|
if isinstance(out_ty, TensorType):
|
190
219
|
np_dtype = out_ty.dtype.numpy_dtype()
|
@@ -198,10 +227,16 @@ def _unpack(pfunc: PFunction, packed: Any) -> Any:
|
|
198
227
|
f"unpack size mismatch: got {b.size} bytes, expect {expected} for {np_dtype} {shape}"
|
199
228
|
)
|
200
229
|
arr = np.frombuffer(b.tobytes(), dtype=np_dtype)
|
201
|
-
return arr.reshape(shape)
|
230
|
+
return TensorValue(arr.reshape(shape))
|
202
231
|
|
203
232
|
if isinstance(out_ty, TableType):
|
204
|
-
|
205
|
-
|
233
|
+
# Deserialize Arrow IPC stream back to TableValue
|
234
|
+
import pyarrow as pa # type: ignore
|
235
|
+
import pyarrow.ipc as pa_ipc # type: ignore
|
236
|
+
|
237
|
+
buf = pa.py_buffer(b.tobytes())
|
238
|
+
reader = pa_ipc.open_stream(buf)
|
239
|
+
table = reader.read_all()
|
240
|
+
return TableValue(table)
|
206
241
|
|
207
242
|
raise TypeError("builtin.unpack output type must be TensorType or TableType")
|
@@ -15,15 +15,15 @@
|
|
15
15
|
from __future__ import annotations
|
16
16
|
|
17
17
|
import os
|
18
|
-
from typing import Any
|
19
18
|
|
20
19
|
import numpy as np
|
21
20
|
|
22
21
|
from mplang.core.pfunc import PFunction
|
23
22
|
from mplang.kernels.base import cur_kctx, kernel_def
|
23
|
+
from mplang.kernels.value import TensorValue
|
24
24
|
from mplang.utils.crypto import blake2b
|
25
25
|
|
26
|
-
__all__: list[str] = [] #
|
26
|
+
__all__: list[str] = [] # No public exports currently
|
27
27
|
|
28
28
|
|
29
29
|
def _get_rng() -> np.random.Generator:
|
@@ -54,62 +54,71 @@ def _keystream(key: bytes, nonce: bytes, length: int) -> bytes:
|
|
54
54
|
|
55
55
|
|
56
56
|
@kernel_def("crypto.keygen")
|
57
|
-
def _crypto_keygen(pfunc: PFunction) ->
|
57
|
+
def _crypto_keygen(pfunc: PFunction) -> TensorValue:
|
58
58
|
length = int(pfunc.attrs.get("length", 32))
|
59
59
|
rng = _get_rng()
|
60
60
|
key = rng.integers(0, 256, size=(length,), dtype=np.uint8)
|
61
|
-
return key
|
61
|
+
return TensorValue(key)
|
62
62
|
|
63
63
|
|
64
64
|
@kernel_def("crypto.enc")
|
65
|
-
def _crypto_encrypt(
|
66
|
-
|
67
|
-
|
65
|
+
def _crypto_encrypt(
|
66
|
+
pfunc: PFunction, pt_bytes: TensorValue, key: TensorValue
|
67
|
+
) -> TensorValue:
|
68
|
+
pt_bytes_np = pt_bytes.to_numpy().astype(np.uint8, copy=False)
|
69
|
+
key_np = key.to_numpy().astype(np.uint8, copy=False)
|
68
70
|
rng = _get_rng()
|
69
71
|
nonce = rng.integers(0, 256, size=(12,), dtype=np.uint8)
|
70
72
|
stream = np.frombuffer(
|
71
|
-
_keystream(
|
73
|
+
_keystream(key_np.tobytes(), nonce.tobytes(), pt_bytes_np.size), dtype=np.uint8
|
72
74
|
)
|
73
|
-
ct = (
|
75
|
+
ct = (pt_bytes_np ^ stream).astype(np.uint8)
|
74
76
|
out = np.concatenate([nonce, ct]).astype(np.uint8)
|
75
|
-
return out
|
77
|
+
return TensorValue(out)
|
76
78
|
|
77
79
|
|
78
80
|
@kernel_def("crypto.dec")
|
79
|
-
def _crypto_decrypt(
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
81
|
+
def _crypto_decrypt(
|
82
|
+
pfunc: PFunction, ct_with_nonce: TensorValue, key: TensorValue
|
83
|
+
) -> TensorValue:
|
84
|
+
ct_np = ct_with_nonce.to_numpy().astype(np.uint8, copy=False)
|
85
|
+
key_np = key.to_numpy().astype(np.uint8, copy=False)
|
86
|
+
nonce = ct_np[:12]
|
87
|
+
ct = ct_np[12:]
|
84
88
|
stream = np.frombuffer(
|
85
|
-
_keystream(
|
89
|
+
_keystream(key_np.tobytes(), nonce.tobytes(), len(ct)), dtype=np.uint8
|
86
90
|
)
|
87
91
|
pt_bytes = (ct ^ stream).astype(np.uint8)
|
88
|
-
return pt_bytes
|
92
|
+
return TensorValue(pt_bytes)
|
89
93
|
|
90
94
|
|
91
95
|
@kernel_def("crypto.kem_keygen")
|
92
|
-
def _crypto_kem_keygen(pfunc: PFunction) ->
|
96
|
+
def _crypto_kem_keygen(pfunc: PFunction) -> tuple[TensorValue, TensorValue]:
|
93
97
|
rng = _get_rng()
|
94
98
|
sk = rng.integers(0, 256, size=(32,), dtype=np.uint8)
|
95
|
-
|
96
|
-
|
99
|
+
pk_bytes = blake2b(sk.tobytes())[:32]
|
100
|
+
pk = np.frombuffer(pk_bytes, dtype=np.uint8)
|
101
|
+
return (TensorValue(sk), TensorValue(pk))
|
97
102
|
|
98
103
|
|
99
104
|
@kernel_def("crypto.kem_derive")
|
100
|
-
def _crypto_kem_derive(
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
+
def _crypto_kem_derive(
|
106
|
+
pfunc: PFunction, sk: TensorValue, peer_pk: TensorValue
|
107
|
+
) -> TensorValue:
|
108
|
+
sk_np = sk.to_numpy().astype(np.uint8, copy=False)
|
109
|
+
peer_pk_np = peer_pk.to_numpy().astype(np.uint8, copy=False)
|
110
|
+
|
111
|
+
self_pk_bytes = blake2b(sk_np.tobytes())[:32]
|
112
|
+
self_pk_arr = np.frombuffer(self_pk_bytes, dtype=np.uint8)
|
113
|
+
xored = (self_pk_arr ^ peer_pk_np).astype(np.uint8)
|
105
114
|
secret = np.frombuffer(blake2b(xored.tobytes())[:32], dtype=np.uint8)
|
106
|
-
return secret
|
115
|
+
return TensorValue(secret)
|
107
116
|
|
108
117
|
|
109
118
|
@kernel_def("crypto.hkdf")
|
110
|
-
def _crypto_hkdf(pfunc: PFunction, secret:
|
111
|
-
|
119
|
+
def _crypto_hkdf(pfunc: PFunction, secret: TensorValue) -> TensorValue:
|
120
|
+
secret_np = secret.to_numpy().astype(np.uint8, copy=False)
|
112
121
|
info_str = str(pfunc.attrs.get("info", ""))
|
113
122
|
info = info_str.encode("utf-8")
|
114
|
-
out = np.frombuffer(blake2b(
|
115
|
-
return out
|
123
|
+
out = np.frombuffer(blake2b(secret_np.tobytes() + info)[:32], dtype=np.uint8)
|
124
|
+
return TensorValue(out)
|