mplang-nightly 0.1.dev268__py3-none-any.whl → 0.1.dev270__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +391 -17
- mplang/{v2/backends → backends}/__init__.py +9 -7
- mplang/{v2/backends → backends}/bfv_impl.py +6 -6
- mplang/{v2/backends → backends}/crypto_impl.py +6 -6
- mplang/{v2/backends → backends}/field_impl.py +5 -5
- mplang/{v2/backends → backends}/func_impl.py +4 -4
- mplang/{v2/backends → backends}/phe_impl.py +3 -3
- mplang/{v2/backends → backends}/simp_design.md +1 -1
- mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
- mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
- mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
- mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
- mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
- mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
- mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
- mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
- mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
- mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
- mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
- mplang/{v2/backends → backends}/spu_impl.py +8 -8
- mplang/{v2/backends → backends}/spu_state.py +4 -4
- mplang/{v2/backends → backends}/store_impl.py +3 -3
- mplang/{v2/backends → backends}/table_impl.py +8 -8
- mplang/{v2/backends → backends}/tee_impl.py +6 -6
- mplang/{v2/backends → backends}/tensor_impl.py +6 -6
- mplang/{v2/cli.py → cli.py} +9 -9
- mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
- mplang/{v2/dialects → dialects}/__init__.py +5 -5
- mplang/{v2/dialects → dialects}/bfv.py +6 -6
- mplang/{v2/dialects → dialects}/crypto.py +5 -5
- mplang/{v2/dialects → dialects}/dtypes.py +2 -2
- mplang/{v2/dialects → dialects}/field.py +3 -3
- mplang/{v2/dialects → dialects}/func.py +2 -2
- mplang/{v2/dialects → dialects}/phe.py +6 -6
- mplang/{v2/dialects → dialects}/simp.py +6 -6
- mplang/{v2/dialects → dialects}/spu.py +7 -7
- mplang/{v2/dialects → dialects}/store.py +2 -2
- mplang/{v2/dialects → dialects}/table.py +3 -3
- mplang/{v2/dialects → dialects}/tee.py +6 -6
- mplang/{v2/dialects → dialects}/tensor.py +5 -5
- mplang/{v2/edsl → edsl}/__init__.py +3 -3
- mplang/{v2/edsl → edsl}/context.py +6 -6
- mplang/{v2/edsl → edsl}/graph.py +5 -5
- mplang/{v2/edsl → edsl}/jit.py +2 -2
- mplang/{v2/edsl → edsl}/object.py +1 -1
- mplang/{v2/edsl → edsl}/primitive.py +5 -5
- mplang/{v2/edsl → edsl}/printer.py +1 -1
- mplang/{v2/edsl → edsl}/serde.py +1 -1
- mplang/{v2/edsl → edsl}/tracer.py +7 -7
- mplang/{v2/edsl → edsl}/typing.py +1 -1
- mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
- mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
- mplang/{v2/kernels → kernels}/okvs_opt.cpp +46 -31
- mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
- mplang/{v2/libs → libs}/collective.py +5 -5
- mplang/{v2/libs → libs}/device/__init__.py +1 -1
- mplang/{v2/libs → libs}/device/api.py +12 -12
- mplang/{v2/libs → libs}/ml/__init__.py +1 -1
- mplang/{v2/libs → libs}/ml/sgb.py +4 -4
- mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
- mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
- mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
- mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
- mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
- mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +19 -13
- mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
- mplang/libs/mpc/psi/rr22.py +303 -0
- mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
- mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
- mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
- mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
- mplang/{v2/runtime → runtime}/interpreter.py +11 -11
- mplang/{v2/runtime → runtime}/value.py +2 -2
- mplang/{v1/runtime → utils}/__init__.py +18 -15
- mplang/{v1/utils → utils}/func_utils.py +1 -1
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/METADATA +2 -2
- mplang_nightly-0.1.dev270.dist-info/RECORD +102 -0
- mplang/v1/__init__.py +0 -157
- mplang/v1/_device.py +0 -602
- mplang/v1/analysis/__init__.py +0 -37
- mplang/v1/analysis/diagram.py +0 -567
- mplang/v1/core/__init__.py +0 -157
- mplang/v1/core/cluster.py +0 -343
- mplang/v1/core/comm.py +0 -281
- mplang/v1/core/context_mgr.py +0 -50
- mplang/v1/core/dtypes.py +0 -335
- mplang/v1/core/expr/__init__.py +0 -80
- mplang/v1/core/expr/ast.py +0 -542
- mplang/v1/core/expr/evaluator.py +0 -581
- mplang/v1/core/expr/printer.py +0 -285
- mplang/v1/core/expr/transformer.py +0 -141
- mplang/v1/core/expr/utils.py +0 -78
- mplang/v1/core/expr/visitor.py +0 -85
- mplang/v1/core/expr/walk.py +0 -387
- mplang/v1/core/interp.py +0 -160
- mplang/v1/core/mask.py +0 -325
- mplang/v1/core/mpir.py +0 -965
- mplang/v1/core/mpobject.py +0 -117
- mplang/v1/core/mptype.py +0 -407
- mplang/v1/core/pfunc.py +0 -130
- mplang/v1/core/primitive.py +0 -877
- mplang/v1/core/table.py +0 -218
- mplang/v1/core/tensor.py +0 -75
- mplang/v1/core/tracer.py +0 -383
- mplang/v1/host.py +0 -130
- mplang/v1/kernels/__init__.py +0 -41
- mplang/v1/kernels/base.py +0 -125
- mplang/v1/kernels/basic.py +0 -240
- mplang/v1/kernels/context.py +0 -369
- mplang/v1/kernels/crypto.py +0 -122
- mplang/v1/kernels/fhe.py +0 -858
- mplang/v1/kernels/mock_tee.py +0 -72
- mplang/v1/kernels/phe.py +0 -1864
- mplang/v1/kernels/spu.py +0 -341
- mplang/v1/kernels/sql_duckdb.py +0 -44
- mplang/v1/kernels/stablehlo.py +0 -90
- mplang/v1/kernels/value.py +0 -626
- mplang/v1/ops/__init__.py +0 -35
- mplang/v1/ops/base.py +0 -424
- mplang/v1/ops/basic.py +0 -294
- mplang/v1/ops/crypto.py +0 -262
- mplang/v1/ops/fhe.py +0 -272
- mplang/v1/ops/jax_cc.py +0 -147
- mplang/v1/ops/nnx_cc.py +0 -168
- mplang/v1/ops/phe.py +0 -216
- mplang/v1/ops/spu.py +0 -151
- mplang/v1/ops/sql_cc.py +0 -303
- mplang/v1/ops/tee.py +0 -36
- mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
- mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
- mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
- mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
- mplang/v1/runtime/channel.py +0 -230
- mplang/v1/runtime/cli.py +0 -451
- mplang/v1/runtime/client.py +0 -456
- mplang/v1/runtime/communicator.py +0 -131
- mplang/v1/runtime/data_providers.py +0 -303
- mplang/v1/runtime/driver.py +0 -324
- mplang/v1/runtime/exceptions.py +0 -27
- mplang/v1/runtime/http_api.md +0 -56
- mplang/v1/runtime/link_comm.py +0 -196
- mplang/v1/runtime/server.py +0 -501
- mplang/v1/runtime/session.py +0 -270
- mplang/v1/runtime/simulation.py +0 -324
- mplang/v1/simp/__init__.py +0 -13
- mplang/v1/simp/api.py +0 -353
- mplang/v1/simp/mpi.py +0 -131
- mplang/v1/simp/party.py +0 -225
- mplang/v1/simp/random.py +0 -120
- mplang/v1/simp/smpc.py +0 -238
- mplang/v1/utils/__init__.py +0 -13
- mplang/v1/utils/crypto.py +0 -32
- mplang/v1/utils/spu_utils.py +0 -130
- mplang/v1/utils/table_utils.py +0 -185
- mplang/v2/__init__.py +0 -424
- mplang/v2/libs/mpc/psi/rr22.py +0 -344
- mplang_nightly-0.1.dev268.dist-info/RECORD +0 -180
- /mplang/{v2/backends → backends}/channel.py +0 -0
- /mplang/{v2/edsl → edsl}/README.md +0 -0
- /mplang/{v2/edsl → edsl}/registry.py +0 -0
- /mplang/{v2/kernels → kernels}/Makefile +0 -0
- /mplang/{v2/kernels → kernels}/__init__.py +0 -0
- /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
- /mplang/{v2/libs → libs}/device/cluster.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
- /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
- /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
- /mplang/{v2/runtime → runtime}/object_store.py +0 -0
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/licenses/LICENSE +0 -0
mplang/v1/host.py
DELETED
|
@@ -1,130 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
from __future__ import annotations
|
|
16
|
-
|
|
17
|
-
from collections.abc import Callable
|
|
18
|
-
from typing import Any
|
|
19
|
-
|
|
20
|
-
from jax.tree_util import tree_map
|
|
21
|
-
|
|
22
|
-
from mplang.v1.core import (
|
|
23
|
-
ClusterSpec,
|
|
24
|
-
InterpContext,
|
|
25
|
-
MPContext,
|
|
26
|
-
MPObject,
|
|
27
|
-
TraceContext,
|
|
28
|
-
TracedFunction,
|
|
29
|
-
trace,
|
|
30
|
-
)
|
|
31
|
-
from mplang.v1.core.context_mgr import cur_ctx, with_ctx
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
def evaluate(
    interp: InterpContext, mpfn: Callable[..., Any], *args: Any, **kwargs: Any
) -> Any:  # type: ignore[misc]
    """Call ``mpfn(*args, **kwargs)`` with ``interp`` installed as the current context.

    Everything is typed as ``Any`` on purpose: this is a generic entry point
    that accepts any multi-party function and any arguments.

    Args:
        interp: Interpreter context made current (via ``with_ctx``) for the
            duration of the call.
        mpfn: Multi-party function to invoke.
        *args: Positional arguments forwarded to ``mpfn``.
        **kwargs: Keyword arguments forwarded to ``mpfn``.

    Returns:
        Any: Whatever ``mpfn`` returns.
    """
    assert isinstance(interp, InterpContext), f"Expect InterpContext, got {interp}"
    active_ctx = with_ctx(interp)
    with active_ctx:
        return mpfn(*args, **kwargs)
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
def fetch(interp: InterpContext | None, objs: Any) -> Any:  # type: ignore[misc]
    """Fetch computed results from MPObject instances in nested data structures.

    This function uses tree_map to handle arbitrary nested structures,
    so it needs to accept and return Any type.

    Args:
        interp: The interpreter context for fetching results. If None, the
            current context from ``cur_ctx()`` is used.
        objs: The objects containing MPObject instances to fetch. Can be any
            nested structure.

    Returns:
        Any: The fetched results with the same structure as the input objects,
        but with MPObject instances replaced by their computed values. Leaves
        that are not MPObject instances are returned unchanged.

    Raises:
        AssertionError: If the resolved context is not an ``InterpContext``.
    """
    # Fall back to the ambient context only when no interpreter was passed;
    # ``interp or cur_ctx()`` would also discard a falsy-but-valid context.
    ctx = interp if interp is not None else cur_ctx()
    assert isinstance(ctx, InterpContext), f"Expect InterpContext, got {ctx}"

    # Pass objs through an identity function with ``ctx`` installed.
    evaluated = evaluate(ctx, lambda x: x, objs)

    def fetch_impl(arg: MPObject | Any) -> Any:
        # Non-MPObject leaves pass through untouched.
        if not isinstance(arg, MPObject):
            return arg
        return ctx.fetch(arg)

    return tree_map(fetch_impl, evaluated)
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
class CompileOptions(MPContext):
    """
    Lightweight ``MPContext`` used for ahead-of-time (AOT) compilation.

    It carries nothing beyond a cluster specification, which is forwarded
    unchanged to ``MPContext``. An instance can be passed as the ``mctx``
    argument of :func:`compile`.

    Args:
        cluster_spec: Cluster specification describing the participating
            parties.
    """

    def __init__(self, cluster_spec: Any) -> None:
        super().__init__(cluster_spec)

    @classmethod
    def simple(cls, world_size: int) -> CompileOptions:
        """Create a CompileOptions backed by ``ClusterSpec.simple(world_size)``.

        Args:
            world_size: Number of participating parties.

        Returns:
            A CompileOptions instance.
        """
        cluster_spec = ClusterSpec.simple(world_size)
        return cls(cluster_spec)
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
def compile(
    mctx: MPContext, fn: Callable[..., Any], *args: Any, **kwargs: Any
) -> TracedFunction:
    """Trace ``fn`` against the cluster of ``mctx`` and return the traced form.

    A fresh ``TraceContext`` is built from ``mctx.cluster_spec`` and handed to
    ``trace`` together with ``fn`` and the sample arguments.

    Args:
        mctx: Multi-party context whose cluster spec drives tracing.
        fn: The function to compile.
        *args: Positional arguments used while tracing.
        **kwargs: Keyword arguments used while tracing.

    Returns:
        TracedFunction: Compiled representation that can be evaluated in
        multi-party contexts.
    """
    tracer_ctx = TraceContext(mctx.cluster_spec)
    traced = trace(tracer_ctx, fn, *args, **kwargs)
    return traced
|
mplang/v1/kernels/__init__.py
DELETED
|
@@ -1,41 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
from mplang.v1.kernels.value import (
|
|
16
|
-
BytesBlob,
|
|
17
|
-
TableValue,
|
|
18
|
-
TensorValue,
|
|
19
|
-
Value,
|
|
20
|
-
ValueDecodeError,
|
|
21
|
-
ValueError,
|
|
22
|
-
decode_value,
|
|
23
|
-
encode_value,
|
|
24
|
-
is_value_envelope,
|
|
25
|
-
list_value_kinds,
|
|
26
|
-
register_value,
|
|
27
|
-
)
|
|
28
|
-
|
|
29
|
-
# Public API of this package: the names re-exported from
# ``mplang.v1.kernels.value``.
# NOTE(review): ``ValueError`` is the name imported from
# ``mplang.v1.kernels.value`` and shadows the builtin ``ValueError`` inside this
# module and for ``from mplang.v1.kernels import *`` consumers — presumably a
# project-defined class; confirm the shadowing is intentional.
__all__ = [
    "BytesBlob",
    "TableValue",
    "TensorValue",
    "Value",
    "ValueDecodeError",
    "ValueError",
    "decode_value",
    "encode_value",
    "is_value_envelope",
    "list_value_kinds",
    "register_value",
]
|
mplang/v1/kernels/base.py
DELETED
|
@@ -1,125 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
"""Backend kernel registry: mapping kernel_id -> implementation.
|
|
16
|
-
|
|
17
|
-
This module provides a lightweight registry for backend kernel implementations.
|
|
18
|
-
It does not track or decide which kernel handles a given semantic operation;
|
|
19
|
-
that policy (op -> kernel_id) is managed externally by each ``RuntimeContext``.
|
|
20
|
-
|
|
21
|
-
Exposed primitives:
|
|
22
|
-
* ``@kernel_def(kernel_id)``: decorator to register a kernel implementation.
|
|
23
|
-
* ``get_kernel_spec(kernel_id)``: look up a previously registered kernel.
|
|
24
|
-
* ``cur_kctx()`` / ``KernelContext``: execution context available only
|
|
25
|
-
inside a kernel body (rank, world_size, per-backend state pockets, and a
|
|
26
|
-
runtime-wide cache shared by kernels of the same runtime instance).
|
|
27
|
-
|
|
28
|
-
No global op binding table exists here; callers resolve an op to a kernel_id
|
|
29
|
-
before invoking the kernel function.
|
|
30
|
-
"""
|
|
31
|
-
|
|
32
|
-
from __future__ import annotations
|
|
33
|
-
|
|
34
|
-
import contextvars
|
|
35
|
-
from collections.abc import Callable
|
|
36
|
-
from dataclasses import dataclass
|
|
37
|
-
from typing import TYPE_CHECKING, Any
|
|
38
|
-
|
|
39
|
-
if TYPE_CHECKING:
|
|
40
|
-
from mplang.v1.kernels.context import RuntimeContext
|
|
41
|
-
|
|
42
|
-
__all__ = [
|
|
43
|
-
"KernelContext",
|
|
44
|
-
"KernelSpec",
|
|
45
|
-
"cur_kctx",
|
|
46
|
-
"get_kernel_spec",
|
|
47
|
-
"kernel_exists",
|
|
48
|
-
"list_kernels",
|
|
49
|
-
]
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
@dataclass
class KernelContext:
    """Short-lived context describing a single kernel invocation.

    Anything that must outlive one kernel call (RNGs, compiled artifacts,
    environment handles) belongs on the ``RuntimeContext`` referenced by
    ``runtime``, not on this object.
    """

    # Rank of the party executing the kernel.
    rank: int
    # Number of parties in the computation.
    world_size: int
    # Runtime that owns cross-kernel persistent state.
    runtime: RuntimeContext
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
# Context of the kernel being executed, or None when unset. Nothing in this
# module sets it; presumably the runtime does so around each kernel call.
_CTX_VAR: contextvars.ContextVar[KernelContext | None] = contextvars.ContextVar(
    "_flat_backend_ctx", default=None
)


def cur_kctx() -> KernelContext:
    """Return the active kernel execution context.

    Only meaningful while a kernel body is running.

    Raises:
        RuntimeError: If no kernel context is currently set.
    """
    active = _CTX_VAR.get()
    if active is not None:
        return active
    raise RuntimeError("cur_kctx() called outside backend kernel execution")
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
# ---------------- Registry ----------------

# Kernel callable signature: (pfunc, *args) -> Any | sequence (no **kwargs)
KernelFn = Callable[..., Any]


@dataclass
class KernelSpec:
    """Registry entry for one concrete kernel implementation."""

    # Identifier the implementation was registered under.
    kernel_id: str
    # The implementation itself.
    fn: KernelFn
    # Copy of the keyword metadata given to ``kernel_def``.
    meta: dict[str, Any]


# Every registered implementation, keyed by kernel_id.
_KERNELS: dict[str, KernelSpec] = {}


def kernel_def(kernel_id: str, /, **meta: Any) -> Callable[[KernelFn], KernelFn]:
    """Return a decorator that registers a function under ``kernel_id``.

    Registration only records ``kernel_id -> fn`` (plus a copy of ``meta``);
    it does not bind any op. Binding an op to a kernel id is left to a higher
    layer (e.g. an explicit ``bind_op(op_type, kernel_id)`` call, which is not
    defined in this module).

    Args:
        kernel_id: Unique identifier for the implementation (positional-only).
        **meta: Arbitrary metadata stored on the resulting ``KernelSpec``.

    Returns:
        A decorator that registers the function and returns it unchanged.

    Raises:
        ValueError: Raised by the decorator if ``kernel_id`` is already taken.
    """

    def register(impl: KernelFn) -> KernelFn:
        if kernel_id in _KERNELS:
            raise ValueError(f"duplicate kernel_id={kernel_id}")
        _KERNELS[kernel_id] = KernelSpec(
            kernel_id=kernel_id, fn=impl, meta=dict(meta)
        )
        return impl

    return register


def get_kernel_spec(kernel_id: str) -> KernelSpec:
    """Look up the KernelSpec registered for ``kernel_id`` (no op binding lookup).

    Raises:
        KeyError: If ``kernel_id`` was never registered.
    """
    if (found := _KERNELS.get(kernel_id)) is None:
        raise KeyError(f"kernel_id {kernel_id} not registered")
    return found


def list_kernels() -> list[str]:
    """Return all registered kernel ids in sorted order."""
    return sorted(_KERNELS)


def kernel_exists(kernel_id: str) -> bool:
    """Return True if a kernel_id has been registered."""
    return kernel_id in _KERNELS
|
mplang/v1/kernels/basic.py
DELETED
|
@@ -1,240 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
from __future__ import annotations
|
|
16
|
-
|
|
17
|
-
import numpy as np
|
|
18
|
-
|
|
19
|
-
from mplang.v1.core import PFunction, TableType, TensorType
|
|
20
|
-
from mplang.v1.kernels.base import cur_kctx, kernel_def
|
|
21
|
-
from mplang.v1.kernels.value import TableValue, TensorValue, Value
|
|
22
|
-
from mplang.v1.runtime.data_providers import get_provider, resolve_uri
|
|
23
|
-
from mplang.v1.utils import table_utils
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
@kernel_def("basic.identity")
def _identity(pfunc: PFunction, value: Value) -> Value:
    """Pass the single input through unchanged.

    The runtime guarantees exactly one argument, so arity is not checked here.
    """
    result = value
    return result
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
@kernel_def("basic.read")
def _read(pfunc: PFunction) -> Value:
    """Load a resource through the provider registered for the path's URI scheme.

    A provider result that is already a ``Value`` is returned as-is; otherwise
    it is wrapped according to the declared output type.

    Raises:
        ValueError: If the ``path`` attribute is missing.
        NotImplementedError: If no provider handles the URI scheme.
        RuntimeError: If the provider's ``read`` raises.
        TypeError: If the result must be wrapped and the declared output is
            neither TableType nor TensorType.
    """
    raw_path = pfunc.attrs.get("path")
    if raw_path is None:
        raise ValueError("missing path attr for basic.read")
    declared = pfunc.outs_info[0]
    uri = resolve_uri(str(raw_path))
    provider = get_provider(uri.scheme)
    if provider is None:
        raise NotImplementedError(f"no resource provider for scheme: {uri.scheme}")
    kctx = cur_kctx()
    try:
        loaded = provider.read(uri, declared, ctx=kctx)
    except Exception as e:  # pragma: no cover - provider errors
        raise RuntimeError(f"basic.read failed: {e}") from e

    if isinstance(loaded, Value):
        return loaded
    if isinstance(declared, TableType):
        return TableValue(loaded)
    if isinstance(declared, TensorType):
        return TensorValue(np.asarray(loaded))
    raise TypeError(
        f"basic.read only supports TableType/TensorType outputs, got {type(declared).__name__}"
    )
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
@kernel_def("basic.write")
def _write(pfunc: PFunction, obj: Value) -> Value:
    """Persist ``obj`` via the provider for the ``path`` URI, then return it.

    Raises:
        ValueError: If the ``path`` attribute is missing.
        NotImplementedError: If no provider handles the URI scheme.
        RuntimeError: If the provider's ``write`` raises.
    """
    target = pfunc.attrs.get("path")
    if target is None:
        raise ValueError("missing path attr for basic.write")
    uri = resolve_uri(str(target))
    provider = get_provider(uri.scheme)
    if provider is None:
        raise NotImplementedError(f"no resource provider for scheme: {uri.scheme}")
    # The Value object is handed over unconverted; the provider decides how
    # to serialize it.
    kctx = cur_kctx()
    try:
        provider.write(uri, obj, ctx=kctx)
    except Exception as e:  # pragma: no cover
        raise RuntimeError(f"basic.write failed: {e}") from e
    return obj
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
@kernel_def("basic.constant")
def _constant(pfunc: PFunction) -> Value:
    """Return constants as Value types (TensorValue or TableValue).

    Table outputs are decoded from parquet bytes (``data_format`` must be
    ``"bytes[parquet]"``). Any other output is treated as a tensor: the raw
    bytes are reinterpreted with the declared dtype and reshaped to the
    declared shape.

    Raises:
        ValueError: If ``data_bytes`` is missing or a table constant uses an
            unsupported format.
    """
    payload = pfunc.attrs.get("data_bytes")
    if payload is None:
        raise ValueError("missing data_bytes attr for basic.constant")
    out_t = pfunc.outs_info[0]
    fmt = pfunc.attrs.get("data_format")

    if not isinstance(out_t, TableType):
        # Tensor path.
        target_shape = out_t.shape  # type: ignore[attr-defined,union-attr]
        target_dtype = out_t.dtype.numpy_dtype()  # type: ignore[attr-defined,union-attr]
        flat = np.frombuffer(payload, dtype=target_dtype)
        return TensorValue(flat.reshape(target_shape))

    if fmt != "bytes[parquet]":
        raise ValueError(f"unsupported table constant format {fmt}")
    return TableValue(table_utils.decode_table(payload, format="parquet"))
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
@kernel_def("basic.rank")
def _rank(pfunc: PFunction) -> TensorValue:
    """Return the executing party's rank as a 0-d uint64 TensorValue."""
    my_rank = cur_kctx().rank
    return TensorValue(np.array(my_rank, dtype=np.uint64))
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
@kernel_def("basic.prand")
def _prand(pfunc: PFunction) -> TensorValue:
    """Return random uint64 data shaped by the ``shape`` attribute (default scalar).

    Values span the full uint64 range; ``endpoint=True`` makes the maximum
    inclusive. ``default_rng()`` is called without a seed, so fresh OS
    entropy is drawn on every call. This is not a cryptographically secure
    generator.
    """
    out_shape = pfunc.attrs.get("shape", ())
    generator = np.random.default_rng()
    bounds = np.iinfo(np.uint64)
    samples = generator.integers(
        low=bounds.min,
        high=bounds.max,
        size=out_shape,
        dtype=np.uint64,
        endpoint=True,
    )
    return TensorValue(samples)
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
@kernel_def("basic.table_to_tensor")
def _table_to_tensor(pfunc: PFunction, table: TableValue) -> TensorValue:
    """Stack the table's columns side by side into a 2-D TensorValue.

    Each Arrow column becomes one tensor column. ``np.column_stack`` promotes
    mixed column dtypes to a common dtype.

    Raises:
        ValueError: If the table has no columns.
    """
    arrow_table = table.to_arrow()
    n_cols = arrow_table.num_columns
    if n_cols == 0:
        raise ValueError("cannot pack empty table")
    as_numpy = [arrow_table.column(idx).to_numpy() for idx in range(n_cols)]
    return TensorValue(np.column_stack(as_numpy))
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
@kernel_def("basic.tensor_to_table")
def _tensor_to_table(pfunc: PFunction, tensor: TensorValue) -> TableValue:
    """Convert a rank-2 tensor into a table with one column per tensor column.

    Column ``i`` of the tensor becomes the table column named
    ``column_names[i]``.

    Raises:
        ValueError: If the tensor is not rank-2, the ``column_names`` attribute
            is missing, the number of names differs from the number of tensor
            columns, or the names are not unique.
    """
    import pyarrow as pa  # type: ignore

    arr = tensor.to_numpy()
    if arr.ndim != 2:
        raise ValueError("tensor_to_table expects rank-2 array")
    col_names = pfunc.attrs.get("column_names")
    if col_names is None:
        raise ValueError("missing column_names attr")
    col_names = list(col_names)
    n_cols = arr.shape[1]
    if len(col_names) != n_cols:
        raise ValueError(
            f"tensor_to_table got {len(col_names)} column names for a tensor "
            f"with {n_cols} columns"
        )
    # The dict built below would silently keep only the last column for any
    # repeated name, dropping data, so reject duplicates explicitly.
    if len(set(col_names)) != len(col_names):
        raise ValueError(f"tensor_to_table column_names must be unique: {col_names}")
    # Create Arrow table directly from numpy array columns
    arrays = [pa.array(arr[:, i]) for i in range(n_cols)]
    arrow_table = pa.table(dict(zip(col_names, arrays, strict=True)))
    return TableValue(arrow_table)
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
def _summ(v: Value) -> str:
    """Render a short, human-readable preview of ``v`` for debug output.

    Tables show at most their first 8 rows, using Arrow's own string form.
    Tensors are abbreviated by ``np.array2string``. Other values fall back to
    ``repr``. Any error while rendering is reported inline instead of raised.
    """
    try:
        if isinstance(v, TableValue):
            table = v.to_arrow()
            head = table.slice(0, min(8, table.num_rows))
            return str(head)
        if isinstance(v, TensorValue):
            return np.array2string(
                v.to_numpy(),
                threshold=64,
                edgeitems=3,
                precision=6,
                suppress_small=True,
            )
        return repr(v)
    except Exception as e:  # pragma: no cover
        return f"<unprintable {type(v).__name__}: {e}>"
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
@kernel_def("basic.debug_print")
def _debug_print(pfunc: PFunction, val: Value) -> Value:
    """Print a rank-tagged preview of ``val`` to stdout and return it unchanged.

    The optional ``prefix`` attribute is inserted before the preview.
    """
    label = pfunc.attrs.get("prefix", "")
    rank = cur_kctx().rank
    print(f"[debug_print][rank={rank}] {label}{_summ(val)}")
    return val
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
@kernel_def("basic.pack")
def _pack(pfunc: PFunction, value: Value) -> TensorValue:
    """Serialize ``value`` into a flat uint8 TensorValue.

    Tables are written as an Arrow IPC stream. Tensors contribute their raw
    C-order bytes; dtype and shape are not stored, so the unpacking side must
    know the original type.

    Raises:
        ValueError: If not exactly one output type is declared.
        TypeError: If the declared output is not a uint8 TensorType, or
            ``value`` is neither a TableValue nor a TensorValue.
    """
    declared = pfunc.outs_info
    if len(declared) != 1:
        raise ValueError("basic.pack expects single output type")
    out_ty = declared[0]
    if not isinstance(out_ty, TensorType):
        raise TypeError("basic.pack must return TensorType")
    if out_ty.dtype.numpy_dtype() != np.uint8:
        raise TypeError("basic.pack output dtype must be uint8")

    if isinstance(value, TableValue):
        # Use the Arrow IPC stream format, matching the Value serde.
        import pyarrow as pa  # type: ignore
        import pyarrow.ipc as pa_ipc  # type: ignore

        arrow_table = value.to_arrow()
        sink = pa.BufferOutputStream()
        with pa_ipc.new_stream(sink, arrow_table.schema) as writer:  # type: ignore[arg-type]
            writer.write_table(arrow_table)  # type: ignore[arg-type]
        return TensorValue(np.frombuffer(sink.getvalue().to_pybytes(), dtype=np.uint8))

    if isinstance(value, TensorValue):
        raw = value.to_numpy().tobytes(order="C")
        return TensorValue(np.frombuffer(raw, dtype=np.uint8))

    raise TypeError(f"basic.pack does not support Value type {type(value).__name__}")
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
@kernel_def("basic.unpack")
def _unpack(pfunc: PFunction, packed: TensorValue) -> Value:
    """Rebuild a tensor or table from a flat byte tensor.

    For a TensorType output, the bytes are reinterpreted with the declared
    dtype and reshaped to the declared (static) shape. For a TableType output,
    the bytes are read as an Arrow IPC stream.

    Raises:
        ValueError: If not exactly one output type is declared, the tensor
            shape is dynamic, or the byte count does not match the declared
            tensor size.
        TypeError: If the declared output is neither TensorType nor TableType.
    """
    declared = pfunc.outs_info
    if len(declared) != 1:
        raise ValueError("basic.unpack expects single output type")
    out_ty = declared[0]

    # NOTE: astype(np.uint8) converts values rather than reinterpreting bytes;
    # the input is expected to already be uint8 (as produced by basic.pack).
    flat = packed.to_numpy().astype(np.uint8, copy=False).reshape(-1)

    if isinstance(out_ty, TensorType):
        np_dtype = out_ty.dtype.numpy_dtype()
        shape = tuple(out_ty.shape)
        if any(dim < 0 for dim in shape):
            raise ValueError("basic.unpack does not support dynamic tensor shapes")
        expected = int(np.prod(shape)) * np.dtype(np_dtype).itemsize
        if flat.size != expected:
            raise ValueError(
                f"unpack size mismatch: got {flat.size} bytes, expect {expected} for {np_dtype} {shape}"
            )
        restored = np.frombuffer(flat.tobytes(), dtype=np_dtype)
        return TensorValue(restored.reshape(shape))

    if isinstance(out_ty, TableType):
        import pyarrow as pa  # type: ignore
        import pyarrow.ipc as pa_ipc  # type: ignore

        stream = pa_ipc.open_stream(pa.py_buffer(flat.tobytes()))
        return TableValue(stream.read_all())

    raise TypeError("basic.unpack output type must be TensorType or TableType")
|