mplang-nightly 0.1.dev152__tar.gz → 0.1.dev153__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/PKG-INFO +1 -1
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/core/cluster.py +95 -24
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/runtime/client.py +4 -18
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/runtime/driver.py +1 -6
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/runtime/server.py +95 -49
- mplang_nightly-0.1.dev153/mplang/runtime/session.py +285 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/runtime/simulation.py +15 -13
- mplang_nightly-0.1.dev153/tests/conftest.py +17 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/core/test_cluster.py +29 -24
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/integration/test_http_e2e.py +2 -3
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/integration/test_symbols_roundtrip.py +1 -3
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/kernels/test_spu.py +1 -1
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/ops/test_table_tensor_conversion.py +1 -1
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/runtime/test_communicator.py +45 -30
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/runtime/test_driver.py +19 -38
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/runtime/test_server.py +30 -21
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tutorials/3_device.py +4 -1
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tutorials/9_tee.py +4 -1
- mplang_nightly-0.1.dev152/mplang/runtime/resource.py +0 -365
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/.gitignore +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/LICENSE +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/README.md +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/examples/conf/3pc.yaml +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/examples/stax_nn/README.md +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/examples/stax_nn/models.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/examples/stax_nn/stax_nn.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/examples/xgboost/hist_jax.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/examples/xgboost/hist_jax_test.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/examples/xgboost/naive_np.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/examples/xgboost/readme.md +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/examples/xgboost/sgb.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/examples/xgboost/sgb_test.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/hatch_build.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/__init__.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/analysis/__init__.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/analysis/diagram.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/api.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/core/__init__.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/core/comm.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/core/context_mgr.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/core/dtype.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/core/expr/__init__.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/core/expr/ast.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/core/expr/evaluator.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/core/expr/printer.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/core/expr/transformer.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/core/expr/utils.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/core/expr/visitor.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/core/expr/walk.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/core/interp.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/core/mask.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/core/mpir.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/core/mpobject.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/core/mptype.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/core/pfunc.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/core/primitive.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/core/table.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/core/tensor.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/core/tracer.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/device.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/kernels/__init__.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/kernels/base.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/kernels/builtin.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/kernels/context.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/kernels/crypto.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/kernels/mock_tee.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/kernels/phe.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/kernels/spu.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/kernels/sql_duckdb.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/kernels/stablehlo.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/ops/__init__.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/ops/base.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/ops/builtin.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/ops/crypto.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/ops/ibis_cc.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/ops/jax_cc.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/ops/phe.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/ops/spu.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/ops/sql.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/ops/tee.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/protos/v1alpha1/mpir_pb2.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/protos/v1alpha1/mpir_pb2.pyi +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/runtime/__init__.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/runtime/cli.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/runtime/communicator.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/runtime/data_providers.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/runtime/exceptions.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/runtime/http_api.md +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/runtime/link_comm.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/simp/__init__.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/simp/mpi.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/simp/random.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/simp/smpc.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/utils/__init__.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/utils/crypto.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/utils/func_utils.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/utils/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/mplang/utils/table_utils.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/pyproject.toml +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/__init__.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/analysis/test_diagram.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/core/__init__.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/core/expr/__init__.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/core/expr/conftest.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/core/expr/test_ast.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/core/expr/test_printer.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/core/expr/test_utils.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/core/expr/test_walk.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/core/test_dtype.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/core/test_mask.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/core/test_mpir.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/core/test_mptype.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/core/test_primitive.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/core/test_table.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/core/test_tensor.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/core/test_tracer.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/device/__init__.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/device/test_device_basic.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/integration/README.md +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/integration/test_crypto_roundtrip.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/integration/test_tutorials.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/integration/test_unused_param_integration.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/kernels/test_builtin.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/kernels/test_debug_print.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/kernels/test_kernel_binding.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/kernels/test_phe.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/kernels/test_sql_duckdb.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/kernels/test_stablehlo.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/ops/__init__.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/ops/dummy.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/ops/test_builtin_pack.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/ops/test_crypto_tee.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/ops/test_feop_base.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/ops/test_ibis.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/ops/test_ibis_cc.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/ops/test_jax_cc.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/ops/test_phe.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/ops/test_spu.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/ops/test_spu_defensive.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/ops/test_sql.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/runtime/__init__.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/runtime/test_cli.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/runtime/test_simulation.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/simp/test_mpi.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/simp/test_random.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/simp/test_simp.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/simp/test_smpc.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/simp/test_sugar.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/utils/__init__.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/utils/server_fixtures.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/utils/test_func_utils.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/utils/test_spu_utils.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tests/utils/test_table_utils.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tutorials/0_basic.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tutorials/10_analysis.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tutorials/1_condition.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tutorials/2_whileloop.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tutorials/4_simulation.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tutorials/5_ir_dump.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tutorials/6_advanced.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tutorials/7_stdio.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tutorials/8_phe.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tutorials/__init__.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tutorials/pitfalls/late_binding.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tutorials/pitfalls/rand.py +0 -0
- {mplang_nightly-0.1.dev152 → mplang_nightly-0.1.dev153}/tutorials/run.sh +0 -0
@@ -25,23 +25,28 @@ from typing import Any
|
|
25
25
|
|
26
26
|
@dataclass(frozen=True)
|
27
27
|
class RuntimeInfo:
|
28
|
-
"""
|
29
|
-
|
28
|
+
"""Per-physical-node runtime configuration.
|
29
|
+
|
30
|
+
``op_bindings`` is a per-node override map (logical_op -> kernel_id) merged
|
31
|
+
into that node's ``RuntimeContext``. Unknown future / auxiliary fields are
|
32
|
+
preserved in ``extra``.
|
30
33
|
"""
|
31
34
|
|
32
35
|
version: str
|
33
36
|
platform: str
|
34
|
-
|
37
|
+
# Per-node partial override dispatch table (merged over project defaults).
|
38
|
+
op_bindings: dict[str, str] = field(default_factory=dict)
|
35
39
|
|
36
|
-
# A catch-all for any other custom or future properties
|
40
|
+
# A catch-all for any other custom or future properties (must not collide
|
41
|
+
# with reserved keys: version, platform, op_bindings).
|
37
42
|
extra: dict[str, Any] = field(default_factory=dict)
|
38
43
|
|
39
44
|
def to_dict(self) -> dict[str, Any]:
|
40
|
-
"""Convert RuntimeInfo to a dictionary."""
|
45
|
+
"""Convert RuntimeInfo to a dictionary (stable field names)."""
|
41
46
|
result = {
|
42
47
|
"version": self.version,
|
43
48
|
"platform": self.platform,
|
44
|
-
"
|
49
|
+
"op_bindings": self.op_bindings,
|
45
50
|
}
|
46
51
|
result.update(self.extra)
|
47
52
|
return result
|
@@ -175,7 +180,8 @@ class ClusterSpec:
|
|
175
180
|
|
176
181
|
# 2. Parse Physical Nodes, using the list index as the rank
|
177
182
|
nodes_map: dict[str, Node] = {}
|
178
|
-
|
183
|
+
# Reserved runtime info keys we recognize explicitly.
|
184
|
+
known_runtime_fields = {"version", "platform", "op_bindings"}
|
179
185
|
for i, node_cfg in enumerate(config["nodes"]):
|
180
186
|
if "rank" in node_cfg:
|
181
187
|
# Optionally, we can log a warning that the explicit 'rank' is ignored.
|
@@ -187,11 +193,12 @@ class ClusterSpec:
|
|
187
193
|
for k, v in runtime_info_cfg.items()
|
188
194
|
if k not in known_runtime_fields
|
189
195
|
}
|
190
|
-
|
196
|
+
# Gracefully ignore legacy 'backends' if present (treated as extra)
|
197
|
+
# for backward compatibility.
|
191
198
|
runtime_info = RuntimeInfo(
|
192
199
|
version=runtime_info_cfg.get("version", "N/A"),
|
193
200
|
platform=runtime_info_cfg.get("platform", "N/A"),
|
194
|
-
|
201
|
+
op_bindings=runtime_info_cfg.get("op_bindings", {}) or {},
|
195
202
|
extra=extra_runtime_info,
|
196
203
|
)
|
197
204
|
|
@@ -227,32 +234,96 @@ class ClusterSpec:
|
|
227
234
|
return cls(nodes=nodes_map, devices=devices_map)
|
228
235
|
|
229
236
|
@classmethod
|
230
|
-
def simple(
|
231
|
-
|
232
|
-
|
233
|
-
|
237
|
+
def simple(
|
238
|
+
cls,
|
239
|
+
world_size: int,
|
240
|
+
*,
|
241
|
+
endpoints: list[str] | None = None,
|
242
|
+
spu_protocol: str = "SEMI2K",
|
243
|
+
spu_field: str = "FM128",
|
244
|
+
runtime_version: str = "simulated",
|
245
|
+
runtime_platform: str = "simulated",
|
246
|
+
op_bindings: list[dict[str, str]] | None = None,
|
247
|
+
enable_local_device: bool = True,
|
248
|
+
enable_spu_device: bool = True,
|
249
|
+
) -> ClusterSpec:
|
250
|
+
"""Convenience constructor used heavily in tests.
|
251
|
+
|
252
|
+
Parameters
|
253
|
+
----------
|
254
|
+
world_size:
|
255
|
+
Number of parties (physical nodes).
|
256
|
+
endpoints:
|
257
|
+
Optional explicit endpoint list of length ``world_size``. Each element may
|
258
|
+
include scheme (``http://``) or not; stored verbatim. If not provided we
|
259
|
+
synthesize ``localhost:{5000 + i}`` (5000 is a fixed default; pass explicit
|
260
|
+
endpoints for control). Deprecated ``base_port`` legacy kwarg can adjust it.
|
261
|
+
spu_protocol / spu_field:
|
262
|
+
SPU device config values.
|
263
|
+
runtime_version / runtime_platform:
|
264
|
+
Populated into each node's ``RuntimeInfo``.
|
265
|
+
op_bindings:
|
266
|
+
Optional list of length ``world_size`` supplying per-node op_bindings
|
267
|
+
override dicts (defaults to empty dicts).
|
268
|
+
enable_local_device:
|
269
|
+
If True (default), create one ``local_{rank}`` device per node.
|
270
|
+
enable_spu_device:
|
271
|
+
If True (default) create a shared SPU device named ``SP0``.
|
272
|
+
"""
|
273
|
+
base_port = 5000
|
274
|
+
|
275
|
+
if endpoints is not None and len(endpoints) != world_size:
|
276
|
+
raise ValueError(
|
277
|
+
"len(endpoints) must equal world_size when provided: "
|
278
|
+
f"{len(endpoints)} != {world_size}"
|
279
|
+
)
|
280
|
+
|
281
|
+
if op_bindings is not None and len(op_bindings) != world_size:
|
282
|
+
raise ValueError(
|
283
|
+
"len(op_bindings) must equal world_size when provided: "
|
284
|
+
f"{len(op_bindings)} != {world_size}"
|
285
|
+
)
|
286
|
+
|
287
|
+
if not enable_local_device and not enable_spu_device:
|
288
|
+
raise ValueError(
|
289
|
+
"At least one of enable_local_device or enable_spu_device must be True"
|
290
|
+
)
|
291
|
+
|
292
|
+
nodes: dict[str, Node] = {}
|
293
|
+
for i in range(world_size):
|
294
|
+
ep = endpoints[i] if endpoints is not None else f"localhost:{base_port + i}"
|
295
|
+
node_op_bindings = op_bindings[i] if op_bindings is not None else {}
|
296
|
+
nodes[f"node{i}"] = Node(
|
234
297
|
name=f"node{i}",
|
235
298
|
rank=i,
|
236
|
-
endpoint=
|
299
|
+
endpoint=ep,
|
237
300
|
runtime_info=RuntimeInfo(
|
238
|
-
version=
|
239
|
-
platform=
|
240
|
-
|
301
|
+
version=runtime_version,
|
302
|
+
platform=runtime_platform,
|
303
|
+
op_bindings=node_op_bindings,
|
241
304
|
),
|
242
305
|
)
|
243
|
-
for i in range(world_size)
|
244
|
-
}
|
245
306
|
|
246
|
-
devices = {
|
247
|
-
|
307
|
+
devices: dict[str, Device] = {}
|
308
|
+
# Optional per-node local devices
|
309
|
+
if enable_local_device:
|
310
|
+
for i in range(world_size):
|
311
|
+
devices[f"local_{i}"] = Device(
|
312
|
+
name=f"local_{i}",
|
313
|
+
kind="local",
|
314
|
+
members=[nodes[f"node{i}"]],
|
315
|
+
)
|
316
|
+
|
317
|
+
# Shared SPU device
|
318
|
+
if enable_spu_device:
|
319
|
+
devices["SP0"] = Device(
|
248
320
|
name="SP0",
|
249
321
|
kind="SPU",
|
250
322
|
members=list(nodes.values()),
|
251
323
|
config={
|
252
|
-
"protocol":
|
253
|
-
"field":
|
324
|
+
"protocol": spu_protocol,
|
325
|
+
"field": spu_field,
|
254
326
|
},
|
255
327
|
)
|
256
|
-
}
|
257
328
|
|
258
329
|
return cls(nodes=nodes, devices=devices)
|
@@ -81,21 +81,14 @@ class HttpExecutorClient:
|
|
81
81
|
self,
|
82
82
|
name: str,
|
83
83
|
rank: int,
|
84
|
-
|
85
|
-
*,
|
86
|
-
spu_mask: int = 0,
|
87
|
-
spu_protocol: str = "SEMI2K",
|
88
|
-
spu_field: str = "FM64",
|
84
|
+
cluster_spec: dict,
|
89
85
|
) -> str:
|
90
86
|
"""Create a new session.
|
91
87
|
|
92
88
|
Args:
|
93
89
|
name: Session name/ID.
|
94
|
-
rank:
|
95
|
-
|
96
|
-
spu_mask: SPU mask for the session, 0 means no SPU.
|
97
|
-
spu_protocol: SPU protocol for the session (e.g., "SEMI2K", "ABY3").
|
98
|
-
spu_field: SPU field for the session (e.g., "FM64", "FM128").
|
90
|
+
rank: This party's rank.
|
91
|
+
cluster_spec: Full cluster specification dict (ClusterSpec.to_dict()).
|
99
92
|
|
100
93
|
Returns:
|
101
94
|
The session name/ID
|
@@ -104,14 +97,7 @@ class HttpExecutorClient:
|
|
104
97
|
RuntimeError: If session creation fails
|
105
98
|
"""
|
106
99
|
url = f"/sessions/{name}"
|
107
|
-
|
108
|
-
payload: dict[str, Any] = {
|
109
|
-
"rank": rank,
|
110
|
-
"endpoints": endpoints,
|
111
|
-
"spu_mask": spu_mask,
|
112
|
-
"spu_protocol": spu_protocol,
|
113
|
-
"spu_field": spu_field,
|
114
|
-
}
|
100
|
+
payload: dict[str, Any] = {"rank": rank, "cluster_spec": cluster_spec}
|
115
101
|
|
116
102
|
try:
|
117
103
|
response = await self._client.put(url, json=payload)
|
@@ -145,8 +145,6 @@ class Driver(InterpContext):
|
|
145
145
|
"""Get existing session or create a new one across all HTTP servers."""
|
146
146
|
if self._session_id is None:
|
147
147
|
new_session_id = new_uuid()
|
148
|
-
endpoints_list = list(self.node_addrs.values())
|
149
|
-
|
150
148
|
# Create temporary clients for session creation
|
151
149
|
clients = self._create_clients()
|
152
150
|
try:
|
@@ -158,10 +156,7 @@ class Driver(InterpContext):
|
|
158
156
|
task = client.create_session(
|
159
157
|
name=new_session_id,
|
160
158
|
rank=rank,
|
161
|
-
|
162
|
-
spu_mask=self.spu_mask_int,
|
163
|
-
spu_protocol=self.spu_protocol_str,
|
164
|
-
spu_field=self.spu_field_str,
|
159
|
+
cluster_spec=self.cluster_spec.to_dict(),
|
165
160
|
)
|
166
161
|
tasks.append(task)
|
167
162
|
|
@@ -32,14 +32,30 @@ from mplang.core.table import TableType
|
|
32
32
|
from mplang.core.tensor import TensorType
|
33
33
|
from mplang.kernels.base import KernelContext
|
34
34
|
from mplang.protos.v1alpha1 import mpir_pb2
|
35
|
-
from mplang.runtime import resource
|
36
35
|
from mplang.runtime.data_providers import DataProvider, ResolvedURI, register_provider
|
37
36
|
from mplang.runtime.exceptions import InvalidRequestError, ResourceNotFound
|
37
|
+
from mplang.runtime.session import (
|
38
|
+
Computation,
|
39
|
+
Session,
|
40
|
+
Symbol,
|
41
|
+
)
|
38
42
|
|
39
43
|
logger = logging.getLogger(__name__)
|
40
44
|
|
41
45
|
app = FastAPI()
|
42
46
|
|
47
|
+
# per-server global state
|
48
|
+
_sessions: dict[str, Session] = {}
|
49
|
+
_global_symbols: dict[str, Symbol] = {}
|
50
|
+
|
51
|
+
|
52
|
+
def register_session(session: Session) -> Session: # pragma: no cover - test helper
|
53
|
+
existing = _sessions.get(session.name)
|
54
|
+
if existing:
|
55
|
+
return existing
|
56
|
+
_sessions[session.name] = session
|
57
|
+
return session
|
58
|
+
|
43
59
|
|
44
60
|
class _SymbolsProvider(DataProvider):
|
45
61
|
"""Server-local symbols provider backed by BackendRuntime.state."""
|
@@ -83,7 +99,7 @@ class _SymbolsProvider(DataProvider):
|
|
83
99
|
ctx: KernelContext,
|
84
100
|
) -> Any: # type: ignore[override]
|
85
101
|
name = self._symbol_name(uri)
|
86
|
-
sym =
|
102
|
+
sym = _global_symbols.get(name)
|
87
103
|
if sym is None:
|
88
104
|
raise ResourceNotFound(f"Global symbol '{name}' not found")
|
89
105
|
return sym.data
|
@@ -102,8 +118,13 @@ class _SymbolsProvider(DataProvider):
|
|
102
118
|
raise InvalidRequestError(
|
103
119
|
f"Failed to encode value for symbols:// write: {e!s}"
|
104
120
|
) from e
|
105
|
-
|
106
|
-
|
121
|
+
try:
|
122
|
+
obj = pickle.loads(base64.b64decode(data_b64))
|
123
|
+
except Exception as e: # pragma: no cover - defensive
|
124
|
+
raise InvalidRequestError(
|
125
|
+
f"Failed to decode value for symbols:// write: {e!s}"
|
126
|
+
) from e
|
127
|
+
_global_symbols[name] = Symbol(name=name, mptype={}, data=obj)
|
107
128
|
|
108
129
|
|
109
130
|
# Register symbols provider explicitly for server runtime
|
@@ -168,11 +189,7 @@ def validate_name(name: str, name_type: str) -> None:
|
|
168
189
|
# Request/Response Models
|
169
190
|
class CreateSessionRequest(BaseModel):
|
170
191
|
rank: int
|
171
|
-
|
172
|
-
# SPU related
|
173
|
-
spu_mask: int
|
174
|
-
spu_protocol: str
|
175
|
-
spu_field: str
|
192
|
+
cluster_spec: dict
|
176
193
|
|
177
194
|
|
178
195
|
class SessionResponse(BaseModel):
|
@@ -229,7 +246,7 @@ async def health_check() -> dict[str, str]:
|
|
229
246
|
@app.get("/sessions", response_model=SessionListResponse)
|
230
247
|
def list_sessions() -> SessionListResponse:
|
231
248
|
"""List all session names."""
|
232
|
-
return SessionListResponse(sessions=
|
249
|
+
return SessionListResponse(sessions=list(_sessions.keys()))
|
233
250
|
|
234
251
|
|
235
252
|
# List all computations in a session
|
@@ -238,39 +255,44 @@ def list_sessions() -> SessionListResponse:
|
|
238
255
|
)
|
239
256
|
def list_session_computations(session_name: str) -> ComputationListResponse:
|
240
257
|
"""List all computation names in a session."""
|
241
|
-
|
242
|
-
if not
|
258
|
+
sess = _sessions.get(session_name)
|
259
|
+
if not sess:
|
243
260
|
raise ResourceNotFound(f"Session '{session_name}' not found")
|
244
|
-
return ComputationListResponse(computations=
|
261
|
+
return ComputationListResponse(computations=sess.list_computations())
|
245
262
|
|
246
263
|
|
247
264
|
# Session endpoints
|
248
265
|
@app.put("/sessions/{session_name}", response_model=SessionResponse)
|
249
266
|
def create_session(session_name: str, request: CreateSessionRequest) -> SessionResponse:
|
250
267
|
validate_name(session_name, "session")
|
251
|
-
session
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
268
|
+
# Delegate cluster spec parsing & session construction to resource layer
|
269
|
+
from mplang.core.cluster import ClusterSpec # local import to avoid cycles
|
270
|
+
|
271
|
+
if session_name in _sessions:
|
272
|
+
sess = _sessions[session_name]
|
273
|
+
else:
|
274
|
+
spec = ClusterSpec.from_dict(request.cluster_spec)
|
275
|
+
if len(spec.get_devices_by_kind("SPU")) == 0:
|
276
|
+
raise InvalidRequestError("No SPU device found in cluster_spec for session")
|
277
|
+
sess = Session(name=session_name, rank=request.rank, cluster_spec=spec)
|
278
|
+
_sessions[session_name] = sess
|
279
|
+
return SessionResponse(name=sess.name)
|
260
280
|
|
261
281
|
|
262
282
|
@app.get("/sessions/{session_name}", response_model=SessionResponse)
|
263
283
|
def get_session(session_name: str) -> SessionResponse:
|
264
|
-
|
265
|
-
if not
|
284
|
+
sess = _sessions.get(session_name)
|
285
|
+
if not sess:
|
266
286
|
raise ResourceNotFound(f"Session '{session_name}' not found")
|
267
|
-
return SessionResponse(name=
|
287
|
+
return SessionResponse(name=sess.name)
|
268
288
|
|
269
289
|
|
270
290
|
@app.delete("/sessions/{session_name}")
|
271
291
|
def delete_session(session_name: str) -> dict[str, str]:
|
272
292
|
"""Delete a session and all its associated resources."""
|
273
|
-
if
|
293
|
+
if session_name in _sessions:
|
294
|
+
del _sessions[session_name]
|
295
|
+
logging.info(f"Session {session_name} deleted successfully")
|
274
296
|
return {"message": f"Session '{session_name}' deleted successfully"}
|
275
297
|
else:
|
276
298
|
raise ResourceNotFound(f"Session '{session_name}' not found")
|
@@ -299,18 +321,25 @@ def create_and_execute_computation(
|
|
299
321
|
raise InvalidRequestError("Failed to parse expression from protobuf")
|
300
322
|
|
301
323
|
# Create the computation resource
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
324
|
+
sess = _sessions.get(session_name)
|
325
|
+
if not sess:
|
326
|
+
raise ResourceNotFound(f"Session '{session_name}' not found.")
|
327
|
+
comp = sess.get_computation(computation_id)
|
328
|
+
if not comp:
|
329
|
+
comp = Computation(name=computation_id, expr=expr)
|
330
|
+
sess.add_computation(comp)
|
331
|
+
sess.execute(comp, request.input_names, request.output_names)
|
332
|
+
return ComputationResponse(name=computation_id)
|
308
333
|
|
309
334
|
|
310
335
|
@app.delete("/sessions/{session_name}/computations/{computation_id}")
|
311
336
|
def delete_computation(session_name: str, computation_id: str) -> dict[str, str]:
|
312
337
|
"""Delete a specific computation."""
|
313
|
-
|
338
|
+
sess = _sessions.get(session_name)
|
339
|
+
if sess and sess.delete_computation(computation_id):
|
340
|
+
logging.info(
|
341
|
+
f"Computation {computation_id} deleted from session {session_name}"
|
342
|
+
)
|
314
343
|
return {"message": f"Computation '{computation_id}' deleted successfully"}
|
315
344
|
else:
|
316
345
|
raise ResourceNotFound(
|
@@ -326,9 +355,15 @@ def create_session_symbol(
|
|
326
355
|
session_name: str, symbol_name: str, request: CreateSymbolRequest
|
327
356
|
) -> SymbolResponse:
|
328
357
|
"""Create a symbol in a session."""
|
329
|
-
|
330
|
-
|
331
|
-
|
358
|
+
sess = _sessions.get(session_name)
|
359
|
+
if not sess:
|
360
|
+
raise ResourceNotFound(f"Session '{session_name}' not found.")
|
361
|
+
try:
|
362
|
+
obj = pickle.loads(base64.b64decode(request.data))
|
363
|
+
except Exception as e:
|
364
|
+
raise InvalidRequestError(f"Invalid symbol data: {e!s}") from e
|
365
|
+
symbol = Symbol(name=symbol_name, mptype=request.mptype, data=obj)
|
366
|
+
sess.add_symbol(symbol)
|
332
367
|
# Return the base64 data back to client; server stores Python object
|
333
368
|
return SymbolResponse(
|
334
369
|
name=symbol.name,
|
@@ -346,8 +381,8 @@ def get_session_symbol(session_name: str, symbol_name: str) -> SymbolResponse:
|
|
346
381
|
logger.debug(
|
347
382
|
f"Looking for symbol: '{symbol_name}' in session: '{session_name}'"
|
348
383
|
)
|
349
|
-
|
350
|
-
symbol =
|
384
|
+
sess = _sessions.get(session_name)
|
385
|
+
symbol = sess.get_symbol(symbol_name) if sess else None
|
351
386
|
if not symbol:
|
352
387
|
raise HTTPException(
|
353
388
|
status_code=404, detail=f"Symbol {symbol_name} not found"
|
@@ -368,14 +403,19 @@ def get_session_symbol(session_name: str, symbol_name: str) -> SymbolResponse:
|
|
368
403
|
@app.get("/sessions/{session_name}/symbols")
|
369
404
|
def list_session_symbols(session_name: str) -> dict[str, list[str]]:
|
370
405
|
"""List all symbols in a session."""
|
371
|
-
|
406
|
+
sess = _sessions.get(session_name)
|
407
|
+
if not sess:
|
408
|
+
raise ResourceNotFound(f"Session '{session_name}' not found.")
|
409
|
+
symbols = sess.list_symbols()
|
372
410
|
return {"symbols": symbols}
|
373
411
|
|
374
412
|
|
375
413
|
@app.delete("/sessions/{session_name}/symbols/{symbol_name}")
|
376
414
|
def delete_symbol(session_name: str, symbol_name: str) -> dict[str, str]:
|
377
415
|
"""Delete a specific symbol."""
|
378
|
-
|
416
|
+
sess = _sessions.get(session_name)
|
417
|
+
if sess and sess.delete_symbol(symbol_name):
|
418
|
+
logging.info(f"Symbol {symbol_name} deleted from session {session_name}")
|
379
419
|
return {"message": f"Symbol '{symbol_name}' deleted successfully"}
|
380
420
|
else:
|
381
421
|
raise ResourceNotFound(
|
@@ -389,13 +429,18 @@ def create_global_symbol(
|
|
389
429
|
symbol_name: str, request: CreateSymbolRequest
|
390
430
|
) -> GlobalSymbolResponse:
|
391
431
|
validate_name(symbol_name, "symbol")
|
392
|
-
|
432
|
+
try:
|
433
|
+
obj = pickle.loads(base64.b64decode(request.data))
|
434
|
+
except Exception as e:
|
435
|
+
raise InvalidRequestError(f"Invalid global symbol data: {e!s}") from e
|
436
|
+
sym = Symbol(name=symbol_name, mptype=request.mptype, data=obj)
|
437
|
+
_global_symbols[symbol_name] = sym
|
393
438
|
return GlobalSymbolResponse(name=sym.name, mptype=sym.mptype, data=request.data)
|
394
439
|
|
395
440
|
|
396
441
|
@app.get("/api/v1/symbols/{symbol_name}", response_model=GlobalSymbolResponse)
|
397
|
-
def get_global_symbol(symbol_name: str) -> GlobalSymbolResponse:
|
398
|
-
sym =
|
442
|
+
def get_global_symbol(symbol_name: str) -> GlobalSymbolResponse: # route handler
|
443
|
+
sym = _global_symbols.get(symbol_name)
|
399
444
|
if not sym:
|
400
445
|
raise ResourceNotFound(f"Global symbol '{symbol_name}' not found")
|
401
446
|
data_bytes = pickle.dumps(sym.data)
|
@@ -405,12 +450,13 @@ def get_global_symbol(symbol_name: str) -> GlobalSymbolResponse:
|
|
405
450
|
|
406
451
|
@app.get("/api/v1/symbols")
|
407
452
|
def list_global_symbols() -> dict[str, list[str]]:
|
408
|
-
return {"symbols":
|
453
|
+
return {"symbols": list(_global_symbols.keys())}
|
409
454
|
|
410
455
|
|
411
456
|
@app.delete("/api/v1/symbols/{symbol_name}")
|
412
|
-
def delete_global_symbol(symbol_name: str) -> dict[str, str]:
|
413
|
-
if
|
457
|
+
def delete_global_symbol(symbol_name: str) -> dict[str, str]: # route handler
|
458
|
+
if symbol_name in _global_symbols:
|
459
|
+
del _global_symbols[symbol_name]
|
414
460
|
return {"message": f"Global symbol '{symbol_name}' deleted successfully"}
|
415
461
|
else:
|
416
462
|
raise ResourceNotFound(f"Global symbol '{symbol_name}' not found")
|
@@ -426,8 +472,8 @@ def comm_send(
|
|
426
472
|
Receive a message from another party and deliver it to the session's communicator.
|
427
473
|
This endpoint runs on the receiver's server.
|
428
474
|
"""
|
429
|
-
|
430
|
-
if not
|
475
|
+
sess = _sessions.get(session_name)
|
476
|
+
if not sess or not sess.communicator:
|
431
477
|
logger.error(f"Session or communicator not found: session={session_name}")
|
432
478
|
raise HTTPException(status_code=404, detail="Session or communicator not found")
|
433
479
|
|
@@ -435,5 +481,5 @@ def comm_send(
|
|
435
481
|
# We don't need to validate to_rank since the request is coming to this server
|
436
482
|
|
437
483
|
# Use the proper onSent mechanism from CommunicatorBase
|
438
|
-
|
484
|
+
sess.communicator.onSent(from_rank, key, request.data)
|
439
485
|
return {"status": "ok"}
|