mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -130
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +4 -4
- mplang/{core → v1/core}/__init__.py +20 -14
- mplang/{core → v1/core}/cluster.py +6 -1
- mplang/{core → v1/core}/comm.py +1 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core → v1/core}/dtypes.py +38 -0
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +11 -13
- mplang/{core → v1/core}/expr/evaluator.py +8 -8
- mplang/{core → v1/core}/expr/printer.py +6 -6
- mplang/{core → v1/core}/expr/transformer.py +2 -2
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +13 -11
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +2 -2
- mplang/{core → v1/core}/primitive.py +12 -12
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{host.py → v1/host.py} +5 -5
- mplang/{kernels → v1/kernels}/__init__.py +1 -1
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/{kernels → v1/kernels}/basic.py +15 -15
- mplang/{kernels → v1/kernels}/context.py +19 -16
- mplang/{kernels → v1/kernels}/crypto.py +8 -10
- mplang/{kernels → v1/kernels}/fhe.py +9 -7
- mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
- mplang/{kernels → v1/kernels}/phe.py +26 -18
- mplang/{kernels → v1/kernels}/spu.py +5 -5
- mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
- mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
- mplang/{kernels → v1/kernels}/value.py +2 -2
- mplang/{ops → v1/ops}/__init__.py +3 -3
- mplang/{ops → v1/ops}/base.py +1 -1
- mplang/{ops → v1/ops}/basic.py +6 -5
- mplang/v1/ops/crypto.py +262 -0
- mplang/{ops → v1/ops}/fhe.py +2 -2
- mplang/{ops → v1/ops}/jax_cc.py +26 -59
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -3
- mplang/{ops → v1/ops}/spu.py +3 -3
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +2 -2
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +3 -3
- mplang/{runtime → v1/runtime}/client.py +1 -1
- mplang/{runtime → v1/runtime}/communicator.py +39 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +4 -4
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +22 -9
- mplang/{runtime → v1/runtime}/session.py +24 -51
- mplang/{runtime → v1/runtime}/simulation.py +36 -14
- mplang/{simp → v1/simp}/api.py +72 -14
- mplang/{simp → v1/simp}/mpi.py +1 -1
- mplang/{simp → v1/simp}/party.py +5 -5
- mplang/{simp → v1/simp}/random.py +2 -2
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/device.py +0 -327
- mplang/ops/crypto.py +0 -108
- mplang/ops/ibis_cc.py +0 -136
- mplang/ops/sql_cc.py +0 -62
- mplang/runtime/link_comm.py +0 -78
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -85
- mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{simp → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Function dialect: generic region-based call + definition primitives.
|
|
16
|
+
|
|
17
|
+
Design: Function as Value
|
|
18
|
+
- func.func: Defines a function and returns a function handle (TraceObject).
|
|
19
|
+
- func.call: Invokes a function using the handle.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
from __future__ import annotations
|
|
23
|
+
|
|
24
|
+
from collections.abc import Callable
|
|
25
|
+
from typing import Any
|
|
26
|
+
|
|
27
|
+
import mplang.v2.edsl as el
|
|
28
|
+
import mplang.v2.edsl.typing as elt
|
|
29
|
+
|
|
30
|
+
# Primitive named "func.func"; its trace rule (registered below with
# ``def_trace``) defines a function and yields a handle as a TraceObject.
func_def_p = el.Primitive[el.TraceObject]("func.func")
|
|
31
|
+
# Primitive named "func.call"; its trace rule invokes a function through a
# handle. The result shape depends on the callee, hence ``Any``.
call_p = el.Primitive[Any]("func.call")
|
|
32
|
+
# Custom type named "function"; used as the output type of the func.func op.
FuncType = elt.CustomType("function")
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _current_tracer() -> el.Tracer:
    """Return the current EDSL context, which must be a ``Tracer``.

    Raises:
        TypeError: If the current context is not an ``el.Tracer`` instance.
    """
    active = el.get_current_context()
    if isinstance(active, el.Tracer):
        return active
    raise TypeError(f"Expected Tracer context, got {type(active)}")
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@func_def_p.def_trace
def _func_trace(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> el.TraceObject:
    """Trace rule for ``func.func``: record a function definition in the graph.

    ``fn`` is traced with ``args``/``kwargs`` through ``el.trace``. The traced
    graph becomes the single region of a new ``func.func`` op appended to the
    current tracer's graph. The op has no inputs and one output of
    ``FuncType``; that output is wrapped in a ``TraceObject`` and returned as
    the function handle.

    Raises:
        TypeError: If the current context is not a Tracer.
    """
    # Resolve the enclosing tracer first so a wrong context fails before any
    # tracing work is done.
    outer = _current_tracer()
    traced_fn = el.trace(fn, *args, **kwargs)

    # Metadata copied from the traced function onto the op. ``func.call``
    # later reads "output_types" and "out_tree" back from these attributes.
    op_attrs: dict[str, Any] = {"sym_name": traced_fn.name}
    for field in (
        "in_var_pos",
        "in_imms",
        "in_tree",
        "out_var_pos",
        "out_imms",
        "out_tree",
    ):
        op_attrs[field] = getattr(traced_fn, field)
    op_attrs["output_types"] = [out.type for out in traced_fn.graph.outputs]

    # The op's single result is the function handle.
    produced = outer.graph.add_op(
        opcode="func.func",
        inputs=[],
        output_types=[FuncType],
        attrs=op_attrs,
        regions=[traced_fn.graph],
    )
    return el.TraceObject(produced[0], outer)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
@call_p.def_trace
def _call_trace(fn_handle: el.TraceObject, *args: Any) -> Any:
    """Trace rule for ``func.call``: invoke a function through its handle.

    Output types and the output PyTree are read from the attributes of the op
    that produced ``fn_handle``. A ``func.call`` op is added to the current
    tracer's graph with the handle as first input, followed by ``args``.

    Returns:
        The results restructured by ``out_tree`` when it is present and
        unflattening succeeds. Otherwise the single ``TraceObject`` when there
        is exactly one result, or the flat list of ``TraceObject`` results.

    Raises:
        TypeError: If the current context is not a Tracer, if ``fn_handle`` is
            not a ``TraceObject``, or if any argument is not a ``TraceObject``.
        ValueError: If the handle's graph value has no defining operation.
    """
    tracer = _current_tracer()

    if not isinstance(fn_handle, el.TraceObject):
        raise TypeError(
            f"func.call expects TraceObject as function handle, got {type(fn_handle)}"
        )
    for candidate in args:
        if not isinstance(candidate, el.TraceObject):
            raise TypeError("func.call arguments must be TraceObjects")

    # The op that produced the handle carries the callee's signature metadata.
    handle_value = fn_handle._graph_value
    defining_op = handle_value.defining_op
    if defining_op is None:
        raise ValueError("Function handle has no defining operation")
    callee_attrs = defining_op.attrs
    # Falls back to a single scalar i64 output when the op has no
    # "output_types" attribute.
    output_types = callee_attrs.get("output_types", [elt.TensorType(elt.i64, ())])
    out_tree = callee_attrs.get("out_tree")

    # Handle first, then the call arguments in order.
    operands = [handle_value]
    operands.extend(operand._graph_value for operand in args)
    result_values = tracer.graph.add_op(
        opcode="func.call",
        inputs=operands,
        output_types=output_types,
        attrs={},
        regions=[],
    )

    traced_results = [el.TraceObject(value, tracer) for value in result_values]

    # Restructure outputs using the PyTree if available. On failure this is
    # best-effort: warn and fall through to the flat return below.
    if out_tree is not None:
        try:
            structured = out_tree.unflatten(traced_results)
        except ValueError as exc:
            import warnings

            warnings.warn(
                f"Failed to unflatten PyTree for func.call: {exc}", stacklevel=2
            )
        else:
            return structured

    # A single result is returned bare rather than in a one-element list.
    if len(traced_results) == 1:
        return traced_results[0]
    return traced_results
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def func(
    fn: Callable[..., Any],
    *args: el.TraceObject,
    **kwargs: Any,
) -> el.TraceObject:
    """Define ``fn`` as a ``func.func`` and return its handle.

    All arguments are forwarded unchanged to ``func_def_p.bind``.

    Args:
        fn: Python callable to define as a function.
        *args: Positional arguments forwarded together with ``fn``.
        **kwargs: Keyword arguments forwarded together with ``fn``.

    Returns:
        The function handle produced by binding ``func_def_p``.
    """
    handle = func_def_p.bind(fn, *args, **kwargs)
    return handle
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def call(fn_handle: el.TraceObject, *args: el.TraceObject) -> Any:
    """Call a function through its handle.

    The handle and arguments are forwarded unchanged to ``call_p.bind``.

    Args:
        fn_handle: Function handle, as returned by ``func``.
        *args: Arguments passed to the called function.

    Returns:
        Whatever binding ``call_p`` yields for this invocation.
    """
    outcome = call_p.bind(fn_handle, *args)
    return outcome
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
# Public API of this module: the two user-facing wrappers and their primitives.
__all__ = ["call", "call_p", "func", "func_def_p"]
|