mplang-nightly 0.1.dev149__py3-none-any.whl → 0.1.dev151__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34):
  1. mplang/core/expr/evaluator.py +1 -1
  2. mplang/core/primitive.py +1 -1
  3. mplang/device.py +4 -4
  4. mplang/{backend → kernels}/builtin.py +1 -1
  5. mplang/{backend → kernels}/context.py +18 -13
  6. mplang/{backend → kernels}/crypto.py +1 -1
  7. mplang/{backend/tee.py → kernels/mock_tee.py} +12 -3
  8. mplang/{backend → kernels}/phe.py +1 -1
  9. mplang/{backend → kernels}/spu.py +1 -1
  10. mplang/{backend → kernels}/sql_duckdb.py +1 -1
  11. mplang/{backend → kernels}/stablehlo.py +1 -1
  12. mplang/{frontend → ops}/__init__.py +11 -11
  13. mplang/{frontend → ops}/builtin.py +1 -1
  14. mplang/{frontend → ops}/crypto.py +1 -1
  15. mplang/{frontend → ops}/ibis_cc.py +1 -1
  16. mplang/{frontend → ops}/jax_cc.py +1 -1
  17. mplang/{frontend → ops}/phe.py +1 -1
  18. mplang/{frontend → ops}/spu.py +1 -1
  19. mplang/{frontend → ops}/sql.py +1 -1
  20. mplang/{frontend → ops}/tee.py +1 -1
  21. mplang/runtime/data_providers.py +1 -1
  22. mplang/runtime/resource.py +2 -2
  23. mplang/runtime/server.py +1 -1
  24. mplang/runtime/simulation.py +35 -30
  25. mplang/simp/__init__.py +7 -7
  26. mplang/simp/smpc.py +1 -1
  27. {mplang_nightly-0.1.dev149.dist-info → mplang_nightly-0.1.dev151.dist-info}/METADATA +1 -1
  28. {mplang_nightly-0.1.dev149.dist-info → mplang_nightly-0.1.dev151.dist-info}/RECORD +34 -34
  29. /mplang/{backend → kernels}/__init__.py +0 -0
  30. /mplang/{backend → kernels}/base.py +0 -0
  31. /mplang/{frontend → ops}/base.py +0 -0
  32. {mplang_nightly-0.1.dev149.dist-info → mplang_nightly-0.1.dev151.dist-info}/WHEEL +0 -0
  33. {mplang_nightly-0.1.dev149.dist-info → mplang_nightly-0.1.dev151.dist-info}/entry_points.txt +0 -0
  34. {mplang_nightly-0.1.dev149.dist-info → mplang_nightly-0.1.dev151.dist-info}/licenses/LICENSE +0 -0
@@ -27,7 +27,6 @@ from __future__ import annotations
27
27
  from dataclasses import dataclass
28
28
  from typing import Any, Protocol
29
29
 
30
- from mplang.backend.context import RuntimeContext
31
30
  from mplang.core.comm import ICommunicator
32
31
  from mplang.core.expr.ast import (
33
32
  AccessExpr,
@@ -47,6 +46,7 @@ from mplang.core.expr.visitor import ExprVisitor
47
46
  from mplang.core.expr.walk import walk_dataflow
48
47
  from mplang.core.mask import Mask
49
48
  from mplang.core.pfunc import PFunction
49
+ from mplang.kernels.context import RuntimeContext
50
50
 
51
51
 
52
52
  class IEvaluator(Protocol):
mplang/core/primitive.py CHANGED
@@ -47,7 +47,7 @@ from mplang.core.pfunc import PFunction
47
47
  from mplang.core.table import TableLike
48
48
  from mplang.core.tensor import ScalarType, Shape, TensorLike
49
49
  from mplang.core.tracer import TraceContext, TraceVar, trace
50
- from mplang.frontend import builtin
50
+ from mplang.ops import builtin
51
51
  from mplang.utils.func_utils import var_demorph, var_morph
52
52
 
53
53
 
mplang/device.py CHANGED
@@ -35,10 +35,10 @@ from mplang.core import InterpContext, MPObject, primitive
35
35
  from mplang.core.cluster import ClusterSpec, Device
36
36
  from mplang.core.context_mgr import cur_ctx
37
37
  from mplang.core.tensor import TensorType
38
- from mplang.frontend import builtin, crypto, ibis_cc, jax_cc, tee
39
- from mplang.frontend.base import FeOperation
40
- from mplang.frontend.ibis_cc import IbisCompiler
41
- from mplang.frontend.jax_cc import JaxCompiler
38
+ from mplang.ops import builtin, crypto, ibis_cc, jax_cc, tee
39
+ from mplang.ops.base import FeOperation
40
+ from mplang.ops.ibis_cc import IbisCompiler
41
+ from mplang.ops.jax_cc import JaxCompiler
42
42
  from mplang.simp import mpi, smpc
43
43
 
44
44
  # Automatic transfer between devices when parameter is not on the target device.
@@ -19,10 +19,10 @@ from typing import Any
19
19
  import numpy as np
20
20
  import pandas as pd
21
21
 
22
- from mplang.backend.base import cur_kctx, kernel_def
23
22
  from mplang.core.pfunc import PFunction
24
23
  from mplang.core.table import TableType
25
24
  from mplang.core.tensor import TensorType
25
+ from mplang.kernels.base import cur_kctx, kernel_def
26
26
  from mplang.runtime.data_providers import get_provider, resolve_uri
27
27
  from mplang.utils import table_utils
28
28
 
@@ -17,12 +17,12 @@ from __future__ import annotations
17
17
  from collections.abc import Mapping
18
18
  from typing import Any
19
19
 
20
- from mplang.backend import base
21
- from mplang.backend.base import KernelContext, get_kernel_spec, kernel_exists
22
20
  from mplang.core.dtype import UINT8, DType
23
21
  from mplang.core.pfunc import PFunction
24
22
  from mplang.core.table import TableLike, TableType
25
23
  from mplang.core.tensor import TensorLike, TensorType
24
+ from mplang.kernels import base
25
+ from mplang.kernels.base import KernelContext, get_kernel_spec, kernel_exists
26
26
 
27
27
  # Default bindings
28
28
  # Import kernel implementation modules explicitly so their @kernel_def entries
@@ -35,13 +35,13 @@ def _ensure_impl_imported() -> None:
35
35
  global _IMPL_IMPORTED
36
36
  if _IMPL_IMPORTED:
37
37
  return
38
- from mplang.backend import builtin as _impl_builtin # noqa: F401
39
- from mplang.backend import crypto as _impl_crypto # noqa: F401
40
- from mplang.backend import phe as _impl_phe # noqa: F401
41
- from mplang.backend import spu as _impl_spu # noqa: F401
42
- from mplang.backend import sql_duckdb as _impl_sql_duckdb # noqa: F401
43
- from mplang.backend import stablehlo as _impl_stablehlo # noqa: F401
44
- from mplang.backend import tee as _impl_tee # noqa: F401
38
+ from mplang.kernels import builtin as _impl_builtin # noqa: F401
39
+ from mplang.kernels import crypto as _impl_crypto # noqa: F401
40
+ from mplang.kernels import mock_tee as _impl_tee # noqa: F401
41
+ from mplang.kernels import phe as _impl_phe # noqa: F401
42
+ from mplang.kernels import spu as _impl_spu # noqa: F401
43
+ from mplang.kernels import sql_duckdb as _impl_sql_duckdb # noqa: F401
44
+ from mplang.kernels import stablehlo as _impl_stablehlo # noqa: F401
45
45
 
46
46
  _IMPL_IMPORTED = True
47
47
 
@@ -91,8 +91,8 @@ _DEFAULT_BINDINGS: dict[str, str] = {
91
91
  # generic SQL op; backend-specific kernel id for duckdb
92
92
  "sql.run": "duckdb.run_sql",
93
93
  # tee
94
- "tee.quote": "tee.quote",
95
- "tee.attest": "tee.attest",
94
+ # "tee.quote": "mock_tee.quote",
95
+ # "tee.attest": "mock_tee.attest",
96
96
  }
97
97
 
98
98
 
@@ -102,6 +102,10 @@ _DEFAULT_BINDINGS: dict[str, str] = {
102
102
  class RuntimeContext:
103
103
  """Per-runtime execution context with isolated op->kernel bindings.
104
104
 
105
+ This object owns ONLY static dispatch metadata ("op bindings") and mutable
106
+ per-rank kernel side state/cache/stats. It does NOT store per-evaluation
107
+ variable bindings (those are provided to the evaluator at evaluation time).
108
+
105
109
  Parameters
106
110
  ----------
107
111
  rank : int
@@ -110,9 +114,10 @@ class RuntimeContext:
110
114
  Total number of participants.
111
115
  initial_bindings : Mapping[str, str] | None, optional
112
116
  Optional partial overrides applied on top of the default binding table
113
- during construction (override semantics, not replace). After
117
+ during construction (override semantics, not replace). These map
118
+ op_type -> kernel_id and form a *template* for dispatch. After
114
119
  initialization, all (re)binding must go through ``bind_op`` /
115
- ``rebind_op``.
120
+ ``rebind_op`` on this context (scoped to THIS runtime only).
116
121
  state / cache / stats : dict, optional
117
122
  Mutable pockets reused across kernel invocations. If omitted, new
118
123
  dictionaries are created.
@@ -19,8 +19,8 @@ from typing import Any
19
19
 
20
20
  import numpy as np
21
21
 
22
- from mplang.backend.base import cur_kctx, kernel_def
23
22
  from mplang.core.pfunc import PFunction
23
+ from mplang.kernels.base import cur_kctx, kernel_def
24
24
  from mplang.utils.crypto import blake2b
25
25
 
26
26
  __all__: list[str] = [] # flat kernels only
@@ -15,12 +15,13 @@
15
15
  from __future__ import annotations
16
16
 
17
17
  import os
18
+ import warnings
18
19
 
19
20
  import numpy as np
20
21
  from numpy.typing import NDArray
21
22
 
22
- from mplang.backend.base import cur_kctx, kernel_def
23
23
  from mplang.core.pfunc import PFunction
24
+ from mplang.kernels.base import cur_kctx, kernel_def
24
25
 
25
26
  __all__: list[str] = []
26
27
 
@@ -43,16 +44,24 @@ def _quote_from_pk(pk: np.ndarray) -> NDArray[np.uint8]:
43
44
  return out
44
45
 
45
46
 
46
- @kernel_def("tee.quote")
47
+ @kernel_def("mock_tee.quote")
47
48
  def _tee_quote(pfunc: PFunction, pk: object) -> NDArray[np.uint8]:
49
+ warnings.warn(
50
+ "Insecure mock TEE kernel 'mock_tee.quote' in use. NOT secure; for local testing only.",
51
+ stacklevel=3,
52
+ )
48
53
  pk = np.asarray(pk, dtype=np.uint8)
49
54
  # rng access ensures deterministic seeding per rank even if unused now
50
55
  _rng()
51
56
  return _quote_from_pk(pk)
52
57
 
53
58
 
54
- @kernel_def("tee.attest")
59
+ @kernel_def("mock_tee.attest")
55
60
  def _tee_attest(pfunc: PFunction, quote: object) -> NDArray[np.uint8]:
61
+ warnings.warn(
62
+ "Insecure mock TEE kernel 'mock_tee.attest' in use. NOT secure; for local testing only.",
63
+ stacklevel=3,
64
+ )
56
65
  quote = np.asarray(quote, dtype=np.uint8)
57
66
  if quote.size != 33:
58
67
  raise ValueError("mock quote must be 33 bytes (1 header + 32 pk)")
@@ -19,10 +19,10 @@ from typing import Any
19
19
  import numpy as np
20
20
  from lightphe import LightPHE
21
21
 
22
- from mplang.backend.base import kernel_def
23
22
  from mplang.core.dtype import DType
24
23
  from mplang.core.mptype import TensorLike
25
24
  from mplang.core.pfunc import PFunction
25
+ from mplang.kernels.base import kernel_def
26
26
 
27
27
  # This controls the decimal precision used in lightPHE for float operations
28
28
  # we force it to 0 to only support integer operations
@@ -21,9 +21,9 @@ import numpy as np
21
21
  import spu.api as spu_api
22
22
  import spu.libspu as libspu
23
23
 
24
- from mplang.backend.base import cur_kctx, kernel_def
25
24
  from mplang.core.mptype import TensorLike
26
25
  from mplang.core.pfunc import PFunction
26
+ from mplang.kernels.base import cur_kctx, kernel_def
27
27
  from mplang.runtime.link_comm import LinkCommunicator
28
28
 
29
29
 
@@ -16,8 +16,8 @@ from __future__ import annotations
16
16
 
17
17
  from typing import Any
18
18
 
19
- from mplang.backend.base import kernel_def
20
19
  from mplang.core.pfunc import PFunction
20
+ from mplang.kernels.base import kernel_def
21
21
 
22
22
 
23
23
  @kernel_def("duckdb.run_sql")
@@ -21,8 +21,8 @@ import jax.numpy as jnp
21
21
  from jax._src import xla_bridge
22
22
  from jax.lib import xla_client as xc
23
23
 
24
- from mplang.backend.base import cur_kctx, kernel_def
25
24
  from mplang.core.pfunc import PFunction
25
+ from mplang.kernels.base import cur_kctx, kernel_def
26
26
 
27
27
 
28
28
  @kernel_def("mlir.stablehlo")
@@ -19,17 +19,17 @@ This module contains compilers that transform high-level functions into
19
19
  portable, serializable intermediate representations.
20
20
  """
21
21
 
22
- from mplang.frontend import builtin as builtin
23
- from mplang.frontend import crypto as crypto
24
- from mplang.frontend import ibis_cc as ibis_cc
25
- from mplang.frontend import jax_cc as jax_cc
26
- from mplang.frontend import phe as phe
27
- from mplang.frontend import spu as spu
28
- from mplang.frontend import tee as tee
29
- from mplang.frontend.base import FeOperation as FeOperation
30
- from mplang.frontend.ibis_cc import ibis_compile as ibis_compile
31
- from mplang.frontend.jax_cc import jax_compile as jax_compile
32
- from mplang.frontend.sql import sql_run as sql_run
22
+ from mplang.ops import builtin as builtin
23
+ from mplang.ops import crypto as crypto
24
+ from mplang.ops import ibis_cc as ibis_cc
25
+ from mplang.ops import jax_cc as jax_cc
26
+ from mplang.ops import phe as phe
27
+ from mplang.ops import spu as spu
28
+ from mplang.ops import tee as tee
29
+ from mplang.ops.base import FeOperation as FeOperation
30
+ from mplang.ops.ibis_cc import ibis_compile as ibis_compile
31
+ from mplang.ops.jax_cc import jax_compile as jax_compile
32
+ from mplang.ops.sql import sql_run as sql_run
33
33
 
34
34
  __all__ = [
35
35
  "FeOperation",
@@ -20,7 +20,7 @@ from mplang.core.mpobject import MPObject # Needed for constant() triad return
20
20
  from mplang.core.pfunc import PFunction
21
21
  from mplang.core.table import TableLike, TableType
22
22
  from mplang.core.tensor import ScalarType, Shape, TensorLike, TensorType
23
- from mplang.frontend.base import stateless_mod
23
+ from mplang.ops.base import stateless_mod
24
24
  from mplang.utils import table_utils
25
25
 
26
26
  _BUILTIN_MOD = stateless_mod("builtin")
@@ -28,7 +28,7 @@ from __future__ import annotations
28
28
 
29
29
  from mplang.core.dtype import UINT8
30
30
  from mplang.core.tensor import TensorType
31
- from mplang.frontend.base import stateless_mod
31
+ from mplang.ops.base import stateless_mod
32
32
 
33
33
  _CRYPTO_MOD = stateless_mod("crypto")
34
34
 
@@ -24,7 +24,7 @@ from mplang.core import dtype
24
24
  from mplang.core.mpobject import MPObject
25
25
  from mplang.core.pfunc import PFunction
26
26
  from mplang.core.table import TableType
27
- from mplang.frontend.base import FeOperation, stateless_mod
27
+ from mplang.ops.base import FeOperation, stateless_mod
28
28
  from mplang.utils.func_utils import normalize_fn
29
29
 
30
30
 
@@ -24,7 +24,7 @@ from jax.tree_util import PyTreeDef, tree_flatten
24
24
  from mplang.core.mpobject import MPObject
25
25
  from mplang.core.pfunc import PFunction, get_fn_name
26
26
  from mplang.core.tensor import TensorType
27
- from mplang.frontend.base import FeOperation, stateless_mod
27
+ from mplang.ops.base import FeOperation, stateless_mod
28
28
  from mplang.utils.func_utils import normalize_fn
29
29
 
30
30
  # Enable 64-bit precision for JAX to match tensor types
@@ -16,7 +16,7 @@
16
16
 
17
17
  from mplang.core.dtype import UINT8
18
18
  from mplang.core.tensor import TensorType
19
- from mplang.frontend.base import stateless_mod
19
+ from mplang.ops.base import stateless_mod
20
20
 
21
21
  _PHE_MOD = stateless_mod("phe")
22
22
 
@@ -26,7 +26,7 @@ from jax.tree_util import PyTreeDef, tree_flatten
26
26
  from mplang.core.mpobject import MPObject
27
27
  from mplang.core.pfunc import PFunction, get_fn_name
28
28
  from mplang.core.tensor import TensorType
29
- from mplang.frontend.base import stateless_mod
29
+ from mplang.ops.base import stateless_mod
30
30
  from mplang.utils.func_utils import normalize_fn
31
31
 
32
32
 
@@ -17,7 +17,7 @@ from jax.tree_util import PyTreeDef, tree_flatten
17
17
  from mplang.core.mpobject import MPObject
18
18
  from mplang.core.pfunc import PFunction
19
19
  from mplang.core.table import TableType
20
- from mplang.frontend.base import FeOperation, stateless_mod
20
+ from mplang.ops.base import FeOperation, stateless_mod
21
21
 
22
22
  _SQL_MOD = stateless_mod("sql")
23
23
 
@@ -16,7 +16,7 @@ from __future__ import annotations
16
16
 
17
17
  from mplang.core.dtype import UINT8
18
18
  from mplang.core.tensor import TensorType
19
- from mplang.frontend.base import stateless_mod
19
+ from mplang.ops.base import stateless_mod
20
20
 
21
21
  _TEE_MOD = stateless_mod("tee")
22
22
 
@@ -21,9 +21,9 @@ from urllib.parse import ParseResult, urlparse
21
21
  import numpy as np
22
22
  import pandas as pd
23
23
 
24
- from mplang.backend.base import KernelContext
25
24
  from mplang.core.table import TableType
26
25
  from mplang.core.tensor import TensorType
26
+ from mplang.kernels.base import KernelContext
27
27
  from mplang.utils import table_utils
28
28
 
29
29
 
@@ -26,11 +26,11 @@ from urllib.parse import urlparse
26
26
  import cloudpickle as pickle
27
27
  import spu.libspu as libspu
28
28
 
29
- from mplang.backend.context import RuntimeContext
30
- from mplang.backend.spu import PFunction # type: ignore
31
29
  from mplang.core.expr.ast import Expr
32
30
  from mplang.core.expr.evaluator import IEvaluator, create_evaluator
33
31
  from mplang.core.mask import Mask
32
+ from mplang.kernels.context import RuntimeContext
33
+ from mplang.kernels.spu import PFunction # type: ignore
34
34
  from mplang.runtime.communicator import HttpCommunicator
35
35
  from mplang.runtime.exceptions import InvalidRequestError, ResourceNotFound
36
36
  from mplang.runtime.link_comm import LinkCommunicator
mplang/runtime/server.py CHANGED
@@ -27,10 +27,10 @@ from fastapi import FastAPI, HTTPException, Request
27
27
  from fastapi.responses import JSONResponse
28
28
  from pydantic import BaseModel
29
29
 
30
- from mplang.backend.base import KernelContext
31
30
  from mplang.core.mpir import Reader
32
31
  from mplang.core.table import TableType
33
32
  from mplang.core.tensor import TensorType
33
+ from mplang.kernels.base import KernelContext
34
34
  from mplang.protos.v1alpha1 import mpir_pb2
35
35
  from mplang.runtime import resource
36
36
  from mplang.runtime.data_providers import DataProvider, ResolvedURI, register_provider
@@ -24,10 +24,6 @@ from typing import Any, cast
24
24
 
25
25
  import spu.libspu as libspu
26
26
 
27
- # New explicit binding model: we only need RuntimeContext which ensures
28
- # bindings via bind_all_ops() on creation; per-module side-effect imports
29
- # are no longer required here.
30
- from mplang.backend.context import RuntimeContext
31
27
  from mplang.core.cluster import ClusterSpec
32
28
  from mplang.core.comm import CollectiveMixin, CommunicatorBase
33
29
  from mplang.core.expr.ast import Expr
@@ -38,6 +34,7 @@ from mplang.core.mpir import Reader, Writer
38
34
  from mplang.core.mpobject import MPObject
39
35
  from mplang.core.mptype import MPType, TensorLike
40
36
  from mplang.core.pfunc import PFunction # for spu.seed_env kernel seeding
37
+ from mplang.kernels.context import RuntimeContext
41
38
  from mplang.runtime.link_comm import LinkCommunicator
42
39
  from mplang.utils.spu_utils import parse_field, parse_protocol
43
40
 
@@ -89,14 +86,20 @@ class Simulator(InterpContext):
89
86
  cluster_spec: ClusterSpec,
90
87
  *,
91
88
  trace_ranks: list[int] | None = None,
89
+ op_bindings: dict[str, str] | None = None,
92
90
  ) -> None:
93
91
  """Initialize a simulator with the given cluster specification.
94
92
 
95
93
  Args:
96
94
  cluster_spec: The cluster specification defining the simulation environment.
97
95
  trace_ranks: List of ranks to trace execution for debugging.
96
+ op_bindings: Optional op->kernel binding template applied to all
97
+ RuntimeContexts. These are static dispatch overrides (merged
98
+ with project defaults) and are orthogonal to the per-evaluate
99
+ variable ``bindings`` dict passed into ``evaluate``.
98
100
  """
99
101
  super().__init__(cluster_spec)
102
+ self._op_bindings_template = op_bindings or {}
100
103
  self._trace_ranks = trace_ranks or []
101
104
 
102
105
  spu_devices = cluster_spec.get_devices_by_kind("SPU")
@@ -140,20 +143,18 @@ class Simulator(InterpContext):
140
143
  self._spu_world = spu_mask.num_parties()
141
144
  self._spu_mask = spu_mask
142
145
 
143
- # No per-backend handlers needed anymore (all flat kernels)
144
- self._handlers: list[list[Any]] = [[] for _ in range(self.world_size())]
145
-
146
- self._evaluators: list[IEvaluator] = []
147
- for rank in range(self.world_size()):
148
- runtime = RuntimeContext(rank=rank, world_size=self.world_size())
149
- ev = create_evaluator(
150
- rank,
151
- {}, # the global environment for this rank
152
- self._comms[rank],
153
- runtime,
154
- None,
146
+ # Persistent per-rank RuntimeContext instances (reused across evaluates).
147
+ # We no longer pre-create evaluators since each evaluate has different env bindings.
148
+ self._runtimes: list[RuntimeContext] = [
149
+ RuntimeContext(
150
+ rank=rank,
151
+ world_size=self.world_size(),
152
+ # Static op bindings template cloned into each runtime. These are kernel
153
+ # dispatch mappings, not per-evaluate variable bindings.
154
+ initial_bindings=self._op_bindings_template,
155
155
  )
156
- self._evaluators.append(ev)
156
+ for rank in range(self.world_size())
157
+ ]
157
158
 
158
159
  @classmethod
159
160
  def simple(
@@ -214,10 +215,10 @@ class Simulator(InterpContext):
214
215
  for rank in range(self.world_size())
215
216
  ]
216
217
 
217
- # Build per-rank evaluators with the per-party environment
218
+ # Build per-rank evaluators with the per-party environment (runtime reused)
218
219
  pts_evaluators: list[IEvaluator] = []
219
220
  for rank in range(self.world_size()):
220
- runtime = RuntimeContext(rank=rank, world_size=self.world_size())
221
+ runtime = self._runtimes[rank]
221
222
  ev = create_evaluator(
222
223
  rank,
223
224
  pts_env[rank],
@@ -225,17 +226,21 @@ class Simulator(InterpContext):
225
226
  runtime,
226
227
  None,
227
228
  )
228
- link_ctx = self._spu_link_ctxs[rank]
229
- seed_fn = PFunction(
230
- fn_type="spu.seed_env",
231
- ins_info=(),
232
- outs_info=(),
233
- config=self._spu_runtime_cfg,
234
- world=self._spu_world,
235
- link=link_ctx,
236
- )
237
- # Seed SPU backend environment explicitly via runtime (no evaluator fast-path)
238
- ev.runtime.run_kernel(seed_fn, []) # type: ignore[arg-type]
229
+ # Seed SPU once per runtime (idempotent logical requirement)
230
+ # Use setdefault to both retrieve and create metadata dict in one step.
231
+ spu_meta = runtime.state.setdefault("_spu", {})
232
+ if not spu_meta.get("inited", False):
233
+ link_ctx = self._spu_link_ctxs[rank]
234
+ seed_fn = PFunction(
235
+ fn_type="spu.seed_env",
236
+ ins_info=(),
237
+ outs_info=(),
238
+ config=self._spu_runtime_cfg,
239
+ world=self._spu_world,
240
+ link=link_ctx,
241
+ )
242
+ ev.runtime.run_kernel(seed_fn, []) # type: ignore[arg-type]
243
+ spu_meta["inited"] = True
239
244
  pts_evaluators.append(ev)
240
245
 
241
246
  # Collect evaluation results from all parties
mplang/simp/__init__.py CHANGED
@@ -36,8 +36,8 @@ from mplang.core.primitive import (
36
36
  uniform_cond,
37
37
  while_loop,
38
38
  )
39
- from mplang.frontend import ibis_cc, jax_cc
40
- from mplang.frontend.base import FeOperation
39
+ from mplang.ops import ibis_cc, jax_cc
40
+ from mplang.ops.base import FeOperation
41
41
  from mplang.simp.mpi import allgather_m, bcast_m, gather_m, p2p, scatter_m
42
42
  from mplang.simp.random import key_split, pperm, prandint, ukey, urandint
43
43
  from mplang.simp.smpc import reveal, revealTo, seal, sealFrom, srun
@@ -189,7 +189,7 @@ def P2P(src: Party, dst: Party, value: Any) -> Any:
189
189
  This module provides a light-weight mechanism to expose *module-like* groups
190
190
  of callable operations bound to a specific party (rank) via attribute access:
191
191
 
192
- load_module("mplang.frontend.crypto", alias="crypto")
192
+ load_module("mplang.ops.crypto", alias="crypto")
193
193
  P0.crypto.encrypt(x) # executes encrypt() with pmask = {rank 0}
194
194
 
195
195
  Core concepts:
@@ -283,13 +283,13 @@ def _load_prelude_modules() -> None:
283
283
  """Auto-register public frontend submodules for party namespace access.
284
284
 
285
285
  Implementation detail: we treat every non-underscore immediate child of
286
- ``mplang.frontend`` as public and make it available as ``P0.<name>``.
286
+ ``mplang.ops`` as public and make it available as ``P0.<name>``.
287
287
  This keeps user ergonomics high (no manual load_module calls for core
288
288
  frontends) but slightly increases implicit surface area. If this grows
289
289
  unwieldy we can switch to an allowlist.
290
290
  """
291
291
  try:
292
- import mplang.frontend as _fe # type: ignore
292
+ import mplang.ops as _fe # type: ignore
293
293
  except (ImportError, ModuleNotFoundError): # pragma: no cover
294
294
  # Frontend package not present (minimal install); safe to skip.
295
295
  return
@@ -299,7 +299,7 @@ def _load_prelude_modules() -> None:
299
299
  if m.name.startswith("_"):
300
300
  continue
301
301
  if m.name not in _NAMESPACE_REGISTRY:
302
- _NAMESPACE_REGISTRY[m.name] = f"mplang.frontend.{m.name}"
302
+ _NAMESPACE_REGISTRY[m.name] = f"mplang.ops.{m.name}"
303
303
 
304
304
 
305
305
  def load_module(module: str, alias: str | None = None) -> None:
@@ -333,7 +333,7 @@ def load_module(module: str, alias: str | None = None) -> None:
333
333
 
334
334
  Examples
335
335
  --------
336
- >>> load_module("mplang.frontend.crypto", alias="crypto")
336
+ >>> load_module("mplang.ops.crypto", alias="crypto")
337
337
  >>> # Now call an op on party 0
338
338
  >>> P0.crypto.encrypt(data)
339
339
  """
mplang/simp/smpc.py CHANGED
@@ -23,7 +23,7 @@ from jax.tree_util import tree_unflatten
23
23
 
24
24
  from mplang.core import Mask, MPObject, Rank, peval, psize
25
25
  from mplang.core.context_mgr import cur_ctx
26
- from mplang.frontend import spu
26
+ from mplang.ops import spu
27
27
  from mplang.simp import mpi
28
28
 
29
29
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mplang-nightly
3
- Version: 0.1.dev149
3
+ Version: 0.1.dev151
4
4
  Summary: Multi-Party Programming Language
5
5
  Author-email: SecretFlow Team <secretflow-contact@service.alipay.com>
6
6
  License: Apache License
@@ -1,18 +1,8 @@
1
1
  mplang/__init__.py,sha256=ofO-F-CNoVIxpMpTJtTJoQtKegJcHwcOJLzoVispiyc,1852
2
2
  mplang/api.py,sha256=ssmv0_CyZPFORhOUJ84Jo6NwRJSK7_Ono3n7ZjEg4sA,3058
3
- mplang/device.py,sha256=Iz_YFKkrbTFKtTQdGqkQZfc0NQH9dIxXP7-fUkIQOa4,12568
3
+ mplang/device.py,sha256=RmjnhzHxJkkNmtBKtYMEbpQYBZpuC43qlllkCOp-QD8,12548
4
4
  mplang/analysis/__init__.py,sha256=CTHFvRsi-nFngojqjn08UaR3RY9i7CJ7T2UdR95kCrk,1056
5
5
  mplang/analysis/diagram.py,sha256=ffwgD12gL1_KH1uJ_EYkjmIlDrfxYJJkWj-wHl09_Xk,19520
6
- mplang/backend/__init__.py,sha256=2WE4cmW96Xkzyq2yRRYNww4kZ5o6u6NbPV0BxqZG698,581
7
- mplang/backend/base.py,sha256=eizxj16sWkUvBvXWS0Zl-S9uuqalJmNUzB1xLhBg8S8,4920
8
- mplang/backend/builtin.py,sha256=Mk1uUO2Vpw3meqZ0B7B0hG-wndea6cmFv2Uk1vM_uTg,7052
9
- mplang/backend/context.py,sha256=fVJ0w0cw15JEqJO048dncWg7DGNWqbHSUjq42Jsyvos,10952
10
- mplang/backend/crypto.py,sha256=H_s5HI7lUP7g0xz-a9qMbSn6dhJStUilKbn3-7SIh0I,3812
11
- mplang/backend/phe.py,sha256=uNqmrbDAbd97TWS_O6D5sopastHy6J20R7knFE4M4uc,65247
12
- mplang/backend/spu.py,sha256=QT1q5uv-5P_nBGtTvtA_yI2h3h3zIqNSnvzGT7Shua4,9307
13
- mplang/backend/sql_duckdb.py,sha256=U_KzEUinxrBRDoUz2Vh597-N4I3hPOBT0RT3tX-ZqKE,1502
14
- mplang/backend/stablehlo.py,sha256=RhKf6TUvjLrRvgtdVY2HxcRDGtjpKBobuBFMfsvZQOI,2937
15
- mplang/backend/tee.py,sha256=6kc7qTe8nWc3pr6iYtozEGLO8Umg-UBQLDiz6p3pdVg,1918
16
6
  mplang/core/__init__.py,sha256=lWxlEKfRwX7FNDzgyKZ1fiDMaCiqkyg0j5mKlZD_v7g,2244
17
7
  mplang/core/cluster.py,sha256=gqMJenvXUfHhE181Dd5JiUkD4nT07RLoicBnvsGmRkE,8598
18
8
  mplang/core/comm.py,sha256=MByyu3etlQh_TkP1vKCFLIAPPuJOpl9Kjs6hOj6m4Yc,8843
@@ -24,28 +14,38 @@ mplang/core/mpir.py,sha256=V6S9RqegaI0yojhLkHla5nGBi27ASoxlrEs1k4tGubM,37980
24
14
  mplang/core/mpobject.py,sha256=0pHSd7SrAFTScCFcB9ziDztElYQn-oIZOKBx47B3QX0,3732
25
15
  mplang/core/mptype.py,sha256=09LbMyJp68W0IkbD0s9YLeVssPg3Rl-rcwjTaCfidIQ,15243
26
16
  mplang/core/pfunc.py,sha256=PAr8qRhVveWO5HOI0TgdsWjpi4PFi2iEyuTlr9UVKSY,5106
27
- mplang/core/primitive.py,sha256=C1HMbqmkAvLbdgXiHrWPTQ2v2t1uwC_vsGCtI0TEqHY,40574
17
+ mplang/core/primitive.py,sha256=-IkGqdbwtbMkLEOOTghXfuFtFvxu5jFQBupm5nPV-RI,40569
28
18
  mplang/core/table.py,sha256=BqTBZn7Tfwce4vzl3XYhaX5hVmKagVq9-YoERDta6d8,5892
29
19
  mplang/core/tensor.py,sha256=86u6DogSZMoL0w5XjtTmQm2PhA_VjwybN1b6U4Zzphg,2361
30
20
  mplang/core/tracer.py,sha256=dVMfUeCMmPz4o6tLXewGMW1Kpy5gpZORvr9w4MhwDtM,14288
31
21
  mplang/core/expr/__init__.py,sha256=qwiSTUOcanFJLyK8HZ13_L1ZDrybqpPXIlTHAyeumE8,1988
32
22
  mplang/core/expr/ast.py,sha256=KE46KTtlH9RA2V_EzWVKCKolsycgTmt7SotUrOc8Qxs,20923
33
- mplang/core/expr/evaluator.py,sha256=OYmxkr4Lf2qMHnHK-aca-dfMsAAzGRVWuXrxNk_M_3U,21675
23
+ mplang/core/expr/evaluator.py,sha256=UezuvGY65Xq-QJwqhQ9PzsK-RBmmKJjHPQZYlWgYvnc,21675
34
24
  mplang/core/expr/printer.py,sha256=VblKGnO0OUfzH7EBkszwRNxQUB8QyyC7BlJWJEUv9so,9546
35
25
  mplang/core/expr/transformer.py,sha256=TyL-8FjrVvDq_C9X7kAuKkiqt2XdZM-okjzVQj0A33s,4893
36
26
  mplang/core/expr/utils.py,sha256=VDTJ_-CsdHtVy9wDaGa7XdFxQ7o5lYYaeqcgsAhkbNI,2625
37
27
  mplang/core/expr/visitor.py,sha256=2Ge-I5N-wH8VVXy8d2WyNaEv8x6seiRx9peyH9S2BYU,2044
38
28
  mplang/core/expr/walk.py,sha256=lXkGJEEuvKGDqQihbxXPxfz2RfR1Q1zYUlt11iooQW0,11889
39
- mplang/frontend/__init__.py,sha256=3ZBFX_acM96tZ2mtJaxJm150n1cf0LnnCRmkrAc4uBw,1463
40
- mplang/frontend/base.py,sha256=rGtfBejcDh9mTRxOdJK5VUlG5vYiVJSir8X72X0Huvc,18264
41
- mplang/frontend/builtin.py,sha256=8qrlbe_SSy6QTXTnMG6_ADB8jSklVZGFBrkoR-p02FE,9368
42
- mplang/frontend/crypto.py,sha256=Nf8zT4Eko7MIs4R2tgZecKVd7d6Hvd_CGGmANhs3Ghs,3651
43
- mplang/frontend/ibis_cc.py,sha256=CTTbPPZ9hFnHuFDDIfgJHie1EdNnHmi5Ha1KsX0iYh8,4235
44
- mplang/frontend/jax_cc.py,sha256=lMqaYD1tyM5DsStTNYifAXzhzsNM5nDiG3a61ygbWyc,7807
45
- mplang/frontend/phe.py,sha256=tDsCvStjVJ1Fs07yF3idkFnugUCA1zdFApPx7Uuulik,6795
46
- mplang/frontend/spu.py,sha256=7G6DaEfC5APSDhfeWSISTG_8tEcVbWth3XmjL8QUrVA,4994
47
- mplang/frontend/sql.py,sha256=DFdvjEPQX28VCRgUMeHYR0rwwOaoCH15bpvvlclLtHA,1999
48
- mplang/frontend/tee.py,sha256=EigmlbYDGvXkZCMHSYRAiOboSl9TG0ewoudbgl3_V6M,1393
29
+ mplang/kernels/__init__.py,sha256=2WE4cmW96Xkzyq2yRRYNww4kZ5o6u6NbPV0BxqZG698,581
30
+ mplang/kernels/base.py,sha256=eizxj16sWkUvBvXWS0Zl-S9uuqalJmNUzB1xLhBg8S8,4920
31
+ mplang/kernels/builtin.py,sha256=nSuM79cn7M6M27A6Y8ycilXT_qAlB1ktkwkRX6dv_VQ,7052
32
+ mplang/kernels/context.py,sha256=EiqsYq9Cw4Z_2lLvRNc8xNsu5EHkG_738QKFJ3Wf0zE,11329
33
+ mplang/kernels/crypto.py,sha256=TWixli1uRQ_7OjA49qQXUXa2ldHDEwaCMXXPSHdAPi8,3812
34
+ mplang/kernels/mock_tee.py,sha256=ifVqb03FgQ5EDK8r14PMpZKMMaSeqMDorcpGLiy00mM,2233
35
+ mplang/kernels/phe.py,sha256=8-_1IFPOaGECGj9mbYja8XoqbMYnYqfpDNVyMJd8J1Y,65247
36
+ mplang/kernels/spu.py,sha256=Kkg1ZQbmklXl7YkIeBfxqs3o4wX7ygBE8hXLpx90L9Y,9307
37
+ mplang/kernels/sql_duckdb.py,sha256=C2XdNLoE2Apma-Fs7OYzDzkBAXAVuShuROxhCWCHDG4,1502
38
+ mplang/kernels/stablehlo.py,sha256=jDsu-lIHRAm94FcUcxZgK02c6BhFJpbO8cf2hP2DFgk,2937
39
+ mplang/ops/__init__.py,sha256=dpe7WWiYapOFzJeGoKFYBr5mnd6P5SdOyvdYaM2Nhm0,1408
40
+ mplang/ops/base.py,sha256=rGtfBejcDh9mTRxOdJK5VUlG5vYiVJSir8X72X0Huvc,18264
41
+ mplang/ops/builtin.py,sha256=D7T8rRF9g05VIw9T72lsncF5cDQqaT37eapBieRKvRI,9363
42
+ mplang/ops/crypto.py,sha256=9CeFJrYmvjmgx-3WQl6jHXh8VafRpT4QBunbzsPF8Uc,3646
43
+ mplang/ops/ibis_cc.py,sha256=bWKN1dL8Nluwvu5TLi8iUwytcnpXtWakZDCL793zBRA,4230
44
+ mplang/ops/jax_cc.py,sha256=42czYg3hNQbI_nUebXnshlU8ULwM-oBDe_TQoApLNVA,7802
45
+ mplang/ops/phe.py,sha256=SatswExjZWPed8y3qA33BCwIWbvsgHCuCAz_pv2RLLw,6790
46
+ mplang/ops/spu.py,sha256=UHr5DSoqG08xDYER_11OsMVjGGNXXxsvkFoVvXU8uik,4989
47
+ mplang/ops/sql.py,sha256=p-u0wQPk9KlgveltYvQcF1UefScJoqBCqhzYPeLBB5Y,1994
48
+ mplang/ops/tee.py,sha256=gwzP81y2idH-d-Du84H6oNZpLaGD-3fEgm8G1uxWpUA,1388
49
49
  mplang/protos/v1alpha1/mpir_pb2.py,sha256=Bros37t-4LMJbuUYVSM65rImUYTtZDhNTIADGbZCKp0,7522
50
50
  mplang/protos/v1alpha1/mpir_pb2.pyi,sha256=GwXR4wPB_kB_36iYS9x-cGI9KDKFMq89KhdLhW_xmvE,19342
51
51
  mplang/protos/v1alpha1/mpir_pb2_grpc.py,sha256=xYOs94SXiNYAlFodACnsXW5QovLsHY5tCk3p76RH5Zc,158
@@ -53,25 +53,25 @@ mplang/runtime/__init__.py,sha256=IRPP3TtpFC4iSt7_uaq-S4dL7CwrXL0XBMeaBoEYLlg,94
53
53
  mplang/runtime/cli.py,sha256=WehDodeVB4AukSWx1LJxxtKUqGmLPY4qjayrPlOg3bE,14438
54
54
  mplang/runtime/client.py,sha256=w8sPuQzqaJI5uS_3JHu2mf0tLaFmZH3f6-SeUBfMLMY,15737
55
55
  mplang/runtime/communicator.py,sha256=Lek6_h_Wmr_W-_JpT-vMxL3CHxcVZdtf7jdaLGuxPgQ,3199
56
- mplang/runtime/data_providers.py,sha256=TPAJSko_2J95oiHCxAKALICVM_LvnxzfgcM48ubhnKU,8226
56
+ mplang/runtime/data_providers.py,sha256=hH2butEOYNGq2rRZjVBDfXLxe3YUin2ftAF6htbTfLA,8226
57
57
  mplang/runtime/driver.py,sha256=Ok1jY301ctN1_KTb4jwSxOdB0lI_xhx9AwhtEGJ-VLQ,11300
58
58
  mplang/runtime/exceptions.py,sha256=c18U0xK20dRmgZo0ogTf5vXlkix9y3VAFuzkHxaXPEk,981
59
59
  mplang/runtime/http_api.md,sha256=-re1DhEqMplAkv_wnqEU-PSs8tTzf4-Ml0Gq0f3Go6s,4883
60
60
  mplang/runtime/link_comm.py,sha256=uNqTCGZVwWeuHAb7yXXQf0DUsMXLa8leHCkrcZdzYMU,4559
61
- mplang/runtime/resource.py,sha256=-B9kSM7xhocc6mpXHmV9xTdpVR2duiUCepJKS7QuLqA,11688
62
- mplang/runtime/server.py,sha256=gTPqAux1EdefaBFnserYIXamoi7pbEsQrFX6cXbOjik,14716
63
- mplang/runtime/simulation.py,sha256=kuFXWuJLGcmy4OvLCBby4K5QbXaQZmKSb4qrCJ2stBY,10957
64
- mplang/simp/__init__.py,sha256=DmSMcKvHVXWS2pYsuHazEmwOWWpZeKOJQsNU6VxC10U,11614
61
+ mplang/runtime/resource.py,sha256=xNke4UpNDjsjWcr09oXWNBXsMfSZFOwsKD7FWdCVPbc,11688
62
+ mplang/runtime/server.py,sha256=LQ5uJi95tYrKmgHwZaxUQi-aiqwSsT3W4z7pZ9dQaUQ,14716
63
+ mplang/runtime/simulation.py,sha256=_cmUsYL58mvc6msHZ2fDjFAEHHLdJ-TRzJV8BxOP_WA,11473
64
+ mplang/simp/__init__.py,sha256=xNXnA8-jZAANa2A1W39b3lYO7D02zdCXl0TpivkTGS4,11579
65
65
  mplang/simp/mpi.py,sha256=Wv_Q16TQ3rdLam6OzqXiefIGSMmagGkso09ycyOkHEs,4774
66
66
  mplang/simp/random.py,sha256=7PVgWNL1j7Sf3MqT5PRiWplUu-0dyhF3Ub566iqX86M,3898
67
- mplang/simp/smpc.py,sha256=9upqlozUko5jSDWJdNQoVXTlibUSVyx-Uu8op4buLCM,7124
67
+ mplang/simp/smpc.py,sha256=tdH54aU4T-GIDPhpmf9NCeJC0G67PdOYc04cyUkOnwE,7119
68
68
  mplang/utils/__init__.py,sha256=2WE4cmW96Xkzyq2yRRYNww4kZ5o6u6NbPV0BxqZG698,581
69
69
  mplang/utils/crypto.py,sha256=rvPomBFtznRHc3RPi6Aip9lsU8zW2oxBqGv1K3vn7Rs,1052
70
70
  mplang/utils/func_utils.py,sha256=vCJcZmu0bEbqhOQKdpttV2_MBllIcPSN0b8U4WjNGGo,5164
71
71
  mplang/utils/spu_utils.py,sha256=S3L9RBkBe2AvSuMSQQ12cBY5Y1NPthubvErSX_7nj1A,4158
72
72
  mplang/utils/table_utils.py,sha256=aC-IZOKkSmFkpr3NZchLM0Wt0GOn-rg_xHBHREWBwAU,2202
73
- mplang_nightly-0.1.dev149.dist-info/METADATA,sha256=eZ_qGx1500gbGY5Ms8smHjQsyZ2Wxfa11z2mSeF0CSk,16547
74
- mplang_nightly-0.1.dev149.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
75
- mplang_nightly-0.1.dev149.dist-info/entry_points.txt,sha256=mG1oJT-GAjQR834a62_QIWb7litzWPPyVnwFqm-rWuY,55
76
- mplang_nightly-0.1.dev149.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
77
- mplang_nightly-0.1.dev149.dist-info/RECORD,,
73
+ mplang_nightly-0.1.dev151.dist-info/METADATA,sha256=sCQECTJOQoKyr3XXAf8Kma7lrB5KEt6toJbQA9a5nEA,16547
74
+ mplang_nightly-0.1.dev151.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
75
+ mplang_nightly-0.1.dev151.dist-info/entry_points.txt,sha256=mG1oJT-GAjQR834a62_QIWb7litzWPPyVnwFqm-rWuY,55
76
+ mplang_nightly-0.1.dev151.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
77
+ mplang_nightly-0.1.dev151.dist-info/RECORD,,
File without changes
File without changes
File without changes