mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev270__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +391 -17
- mplang/{v2/backends → backends}/__init__.py +9 -7
- mplang/{v2/backends → backends}/bfv_impl.py +6 -6
- mplang/{v2/backends → backends}/crypto_impl.py +6 -6
- mplang/{v2/backends → backends}/field_impl.py +5 -5
- mplang/{v2/backends → backends}/func_impl.py +4 -4
- mplang/{v2/backends → backends}/phe_impl.py +3 -3
- mplang/{v2/backends → backends}/simp_design.md +1 -1
- mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
- mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
- mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
- mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
- mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
- mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
- mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
- mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
- mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
- mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
- mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
- mplang/{v2/backends → backends}/spu_impl.py +8 -8
- mplang/{v2/backends → backends}/spu_state.py +4 -4
- mplang/{v2/backends → backends}/store_impl.py +3 -3
- mplang/{v2/backends → backends}/table_impl.py +8 -8
- mplang/{v2/backends → backends}/tee_impl.py +6 -6
- mplang/{v2/backends → backends}/tensor_impl.py +6 -6
- mplang/{v2/cli.py → cli.py} +9 -9
- mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
- mplang/{v2/dialects → dialects}/__init__.py +5 -5
- mplang/{v2/dialects → dialects}/bfv.py +6 -6
- mplang/{v2/dialects → dialects}/crypto.py +5 -5
- mplang/{v2/dialects → dialects}/dtypes.py +2 -2
- mplang/{v2/dialects → dialects}/field.py +3 -3
- mplang/{v2/dialects → dialects}/func.py +2 -2
- mplang/{v2/dialects → dialects}/phe.py +6 -6
- mplang/{v2/dialects → dialects}/simp.py +6 -6
- mplang/{v2/dialects → dialects}/spu.py +7 -7
- mplang/{v2/dialects → dialects}/store.py +2 -2
- mplang/{v2/dialects → dialects}/table.py +3 -3
- mplang/{v2/dialects → dialects}/tee.py +6 -6
- mplang/{v2/dialects → dialects}/tensor.py +5 -5
- mplang/{v2/edsl → edsl}/__init__.py +3 -3
- mplang/{v2/edsl → edsl}/context.py +6 -6
- mplang/{v2/edsl → edsl}/graph.py +5 -5
- mplang/{v2/edsl → edsl}/jit.py +2 -2
- mplang/{v2/edsl → edsl}/object.py +1 -1
- mplang/{v2/edsl → edsl}/primitive.py +5 -5
- mplang/{v2/edsl → edsl}/printer.py +1 -1
- mplang/{v2/edsl → edsl}/serde.py +1 -1
- mplang/{v2/edsl → edsl}/tracer.py +7 -7
- mplang/{v2/edsl → edsl}/typing.py +1 -1
- mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
- mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
- mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
- mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
- mplang/{v2/libs → libs}/collective.py +5 -5
- mplang/{v2/libs → libs}/device/__init__.py +1 -1
- mplang/{v2/libs → libs}/device/api.py +12 -12
- mplang/{v2/libs → libs}/ml/__init__.py +1 -1
- mplang/{v2/libs → libs}/ml/sgb.py +4 -4
- mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
- mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
- mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
- mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
- mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
- mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
- mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
- mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
- mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
- mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
- mplang/{v2/runtime → runtime}/interpreter.py +11 -11
- mplang/{v2/runtime → runtime}/value.py +2 -2
- mplang/{v1/runtime → utils}/__init__.py +18 -15
- mplang/{v1/utils → utils}/func_utils.py +1 -1
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/METADATA +2 -2
- mplang_nightly-0.1.dev270.dist-info/RECORD +102 -0
- mplang/v1/__init__.py +0 -157
- mplang/v1/_device.py +0 -602
- mplang/v1/analysis/__init__.py +0 -37
- mplang/v1/analysis/diagram.py +0 -567
- mplang/v1/core/__init__.py +0 -157
- mplang/v1/core/cluster.py +0 -343
- mplang/v1/core/comm.py +0 -281
- mplang/v1/core/context_mgr.py +0 -50
- mplang/v1/core/dtypes.py +0 -335
- mplang/v1/core/expr/__init__.py +0 -80
- mplang/v1/core/expr/ast.py +0 -542
- mplang/v1/core/expr/evaluator.py +0 -581
- mplang/v1/core/expr/printer.py +0 -285
- mplang/v1/core/expr/transformer.py +0 -141
- mplang/v1/core/expr/utils.py +0 -78
- mplang/v1/core/expr/visitor.py +0 -85
- mplang/v1/core/expr/walk.py +0 -387
- mplang/v1/core/interp.py +0 -160
- mplang/v1/core/mask.py +0 -325
- mplang/v1/core/mpir.py +0 -965
- mplang/v1/core/mpobject.py +0 -117
- mplang/v1/core/mptype.py +0 -407
- mplang/v1/core/pfunc.py +0 -130
- mplang/v1/core/primitive.py +0 -877
- mplang/v1/core/table.py +0 -218
- mplang/v1/core/tensor.py +0 -75
- mplang/v1/core/tracer.py +0 -383
- mplang/v1/host.py +0 -130
- mplang/v1/kernels/__init__.py +0 -41
- mplang/v1/kernels/base.py +0 -125
- mplang/v1/kernels/basic.py +0 -240
- mplang/v1/kernels/context.py +0 -369
- mplang/v1/kernels/crypto.py +0 -122
- mplang/v1/kernels/fhe.py +0 -858
- mplang/v1/kernels/mock_tee.py +0 -72
- mplang/v1/kernels/phe.py +0 -1864
- mplang/v1/kernels/spu.py +0 -341
- mplang/v1/kernels/sql_duckdb.py +0 -44
- mplang/v1/kernels/stablehlo.py +0 -90
- mplang/v1/kernels/value.py +0 -626
- mplang/v1/ops/__init__.py +0 -35
- mplang/v1/ops/base.py +0 -424
- mplang/v1/ops/basic.py +0 -294
- mplang/v1/ops/crypto.py +0 -262
- mplang/v1/ops/fhe.py +0 -272
- mplang/v1/ops/jax_cc.py +0 -147
- mplang/v1/ops/nnx_cc.py +0 -168
- mplang/v1/ops/phe.py +0 -216
- mplang/v1/ops/spu.py +0 -151
- mplang/v1/ops/sql_cc.py +0 -303
- mplang/v1/ops/tee.py +0 -36
- mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
- mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
- mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
- mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
- mplang/v1/runtime/channel.py +0 -230
- mplang/v1/runtime/cli.py +0 -451
- mplang/v1/runtime/client.py +0 -456
- mplang/v1/runtime/communicator.py +0 -131
- mplang/v1/runtime/data_providers.py +0 -303
- mplang/v1/runtime/driver.py +0 -324
- mplang/v1/runtime/exceptions.py +0 -27
- mplang/v1/runtime/http_api.md +0 -56
- mplang/v1/runtime/link_comm.py +0 -196
- mplang/v1/runtime/server.py +0 -501
- mplang/v1/runtime/session.py +0 -270
- mplang/v1/runtime/simulation.py +0 -324
- mplang/v1/simp/__init__.py +0 -13
- mplang/v1/simp/api.py +0 -353
- mplang/v1/simp/mpi.py +0 -131
- mplang/v1/simp/party.py +0 -225
- mplang/v1/simp/random.py +0 -120
- mplang/v1/simp/smpc.py +0 -238
- mplang/v1/utils/__init__.py +0 -13
- mplang/v1/utils/crypto.py +0 -32
- mplang/v1/utils/spu_utils.py +0 -130
- mplang/v1/utils/table_utils.py +0 -185
- mplang/v2/__init__.py +0 -424
- mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
- /mplang/{v2/backends → backends}/channel.py +0 -0
- /mplang/{v2/edsl → edsl}/README.md +0 -0
- /mplang/{v2/edsl → edsl}/registry.py +0 -0
- /mplang/{v2/kernels → kernels}/Makefile +0 -0
- /mplang/{v2/kernels → kernels}/__init__.py +0 -0
- /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
- /mplang/{v2/libs → libs}/device/cluster.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
- /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
- /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
- /mplang/{v2/runtime → runtime}/object_store.py +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/licenses/LICENSE +0 -0
mplang/v1/core/mpir.py
DELETED
|
@@ -1,965 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
"""MPIR (Multi-Party Intermediate Representation) serialization module.
|
|
16
|
-
|
|
17
|
-
This module provides functionality for serializing and deserializing
|
|
18
|
-
expression-based computation graphs to and from protobuf representations.
|
|
19
|
-
It serves as the bridge between in-memory expression trees and their
|
|
20
|
-
serialized form for storage or transmission.
|
|
21
|
-
|
|
22
|
-
Key components:
|
|
23
|
-
- IrWriter: Serializes Expr objects to GraphProto
|
|
24
|
-
- IrReader: Deserializes GraphProto back to Expr objects
|
|
25
|
-
- Conversion functions: Handle mapping between Python types and protobuf types
|
|
26
|
-
"""
|
|
27
|
-
|
|
28
|
-
from __future__ import annotations
|
|
29
|
-
|
|
30
|
-
from typing import Any
|
|
31
|
-
|
|
32
|
-
import numpy as np
|
|
33
|
-
import spu.libspu as spu_api
|
|
34
|
-
|
|
35
|
-
from mplang.v1.core.dtypes import DATE, JSON, STRING, TIME, TIMESTAMP, DType
|
|
36
|
-
from mplang.v1.core.expr import Expr, FuncDefExpr
|
|
37
|
-
from mplang.v1.core.expr.ast import (
|
|
38
|
-
AccessExpr,
|
|
39
|
-
CallExpr,
|
|
40
|
-
CondExpr,
|
|
41
|
-
ConvExpr,
|
|
42
|
-
EvalExpr,
|
|
43
|
-
ShflExpr,
|
|
44
|
-
ShflSExpr,
|
|
45
|
-
TupleExpr,
|
|
46
|
-
VariableExpr,
|
|
47
|
-
WhileExpr,
|
|
48
|
-
)
|
|
49
|
-
from mplang.v1.core.expr.walk import walk
|
|
50
|
-
from mplang.v1.core.mask import Mask
|
|
51
|
-
from mplang.v1.core.mptype import MPType
|
|
52
|
-
from mplang.v1.core.pfunc import PFunction
|
|
53
|
-
from mplang.v1.core.table import TableType
|
|
54
|
-
from mplang.v1.core.tensor import TensorType
|
|
55
|
-
from mplang.v1.protos.v1alpha1 import mpir_pb2
|
|
56
|
-
|
|
57
|
-
# Single mapping table for dtype conversion: NumPy scalar type -> protobuf
# DataType enum. Keys are NumPy scalar *types* (e.g. ``np.float32``), which is
# what ``np.dtype(...).type`` returns, so lookups must use that form.
# Entry order is preserved deliberately: reverse lookups iterate this dict and
# use the first entry whose enum value matches.
DTYPE_MAPPING = {
    np.float32: mpir_pb2.DataType.F32,
    np.uint8: mpir_pb2.DataType.U8,
    np.int8: mpir_pb2.DataType.I8,
    np.uint16: mpir_pb2.DataType.U16,
    np.int16: mpir_pb2.DataType.I16,
    np.int32: mpir_pb2.DataType.I32,
    np.int64: mpir_pb2.DataType.I64,
    # Unicode strings; the reverse direction maps this enum to the STRING DType.
    np.str_: mpir_pb2.DataType.STRING,
    np.bool_: mpir_pb2.DataType.BOOL,
    np.float16: mpir_pb2.DataType.F16,
    np.float64: mpir_pb2.DataType.F64,
    np.uint32: mpir_pb2.DataType.U32,
    np.uint64: mpir_pb2.DataType.U64,
    np.complex64: mpir_pb2.DataType.COMPLEX64,
    np.complex128: mpir_pb2.DataType.COMPLEX128,
}
|
|
75
|
-
|
|
76
|
-
# Additional mapping for table-only DType constants (types that have no NumPy
# scalar equivalent, or that should round-trip as the DType constant itself).
# Both directions of the conversion consult this table before DTYPE_MAPPING.
DTYPE_TO_PROTO_MAPPING = {
    # Map DType constants to protobuf enums
    STRING: mpir_pb2.DataType.STRING,
    DATE: mpir_pb2.DataType.DATE,
    TIME: mpir_pb2.DataType.TIME,
    TIMESTAMP: mpir_pb2.DataType.TIMESTAMP,
    JSON: mpir_pb2.DataType.JSON,
}
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
def dtype_to_proto(dtype_like: Any) -> Any:
    """Convert dtype (DType, NumPy dtype, or type) to protobuf DataType.

    ``DType`` inputs are first looked up in ``DTYPE_TO_PROTO_MAPPING``
    (table-only constants such as STRING, DATE, JSON). Any other ``DType`` is
    converted with ``DType.to_numpy()``. Non-``DType`` inputs are normalized
    with ``np.dtype(...)``. The resulting NumPy scalar type is then looked up
    in ``DTYPE_MAPPING``.

    Args:
        dtype_like: A DType, NumPy dtype, or Python/NumPy type to convert.

    Returns:
        The corresponding protobuf DataType enum value.

    Raises:
        ValueError: If ``dtype_like`` is ``None`` or otherwise not a valid
            dtype, or if the dtype has no protobuf counterpart.
    """
    # np.dtype(None) silently means "default dtype" (float64). Without this
    # guard a missing dtype would be serialized as F64 instead of rejected.
    if dtype_like is None:
        raise ValueError(f"Invalid dtype: {dtype_like}")

    if isinstance(dtype_like, DType):
        # Table-only types have a direct mapping and no NumPy equivalent.
        if dtype_like in DTYPE_TO_PROTO_MAPPING:
            return DTYPE_TO_PROTO_MAPPING[dtype_like]

        # For regular types, convert to numpy for protobuf mapping.
        try:
            key_type = dtype_like.to_numpy().type
        except ValueError as e:
            # Handle table-only types that can't be converted to numpy.
            raise ValueError(
                f"Unsupported dtype for proto conversion: {dtype_like}. This is likely a table-only type that cannot be converted to a numpy dtype. Please ensure the dtype is supported for proto conversion."
            ) from e
    else:
        # Handle NumPy dtypes, dtype strings and Python/NumPy type objects.
        try:
            key_type = np.dtype(dtype_like).type
        except TypeError:
            # np.dtype() rejects some np.generic subclasses; fall back to
            # using the type object itself as the lookup key.
            if isinstance(dtype_like, type) and issubclass(dtype_like, np.generic):
                key_type = dtype_like
            else:
                raise ValueError(f"Invalid dtype: {dtype_like}") from None

    if key_type in DTYPE_MAPPING:
        return DTYPE_MAPPING[key_type]
    raise ValueError(f"Unsupported dtype: {dtype_like}")
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
def proto_to_dtype(dtype_enum: int) -> DType:
    """Map a protobuf ``DataType`` enum value back to a ``DType``.

    Lookup proceeds in two stages:

    1. ``DTYPE_TO_PROTO_MAPPING``: a table-only DType constant whose enum
       equals ``dtype_enum`` is returned as-is.
    2. ``DTYPE_MAPPING``: the first NumPy scalar type whose enum matches is
       converted with ``DType.from_numpy``. A NumPy unicode dtype maps to the
       ``STRING`` constant instead, because ``from_numpy`` does not accept it.

    Args:
        dtype_enum: The protobuf DataType enum value to convert.

    Returns:
        The corresponding DType object.

    Raises:
        ValueError: If no mapping contains ``dtype_enum``, or if the matched
            NumPy type cannot be turned into a ``DType``.
    """
    for table_dtype, table_enum in DTYPE_TO_PROTO_MAPPING.items():
        if table_enum == dtype_enum:
            return table_dtype

    not_found = object()
    matched_type = next(
        (np_type for np_type, np_enum in DTYPE_MAPPING.items() if np_enum == dtype_enum),
        not_found,
    )
    if matched_type is not_found:
        raise ValueError(f"Unsupported dtype enum: {dtype_enum}")

    try:
        np_dtype = np.dtype(matched_type)
    except TypeError as e:
        raise ValueError(f"Cannot create numpy dtype from {matched_type}") from e

    # Unicode strings have no DType.from_numpy counterpart.
    if np_dtype.kind == "U":
        return STRING

    try:
        return DType.from_numpy(np_dtype)
    except ValueError as e:
        raise ValueError(f"Cannot convert numpy dtype {np_dtype} to DType") from e
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
def attr_to_proto(py_value: Any) -> mpir_pb2.AttrProto:
    """Convert a Python attribute value to an AttrProto.

    Supported values and the AttrProto type they produce:

    * ``int`` (including ``bool``) or NumPy integer scalar -> ``INT``
    * ``float`` or NumPy floating scalar -> ``FLOAT``
    * ``str`` -> ``STRING``; ``bytes`` -> ``BYTES``
    * tuple/list whose elements are all ints, all floats, all strs, or all
      ``spu.libspu.Visibility`` -> ``INTS`` / ``FLOATS`` / ``STRINGS`` /
      ``INTS``. An empty sequence becomes an empty ``INTS``.
    * ``FuncDefExpr`` -> ``GRAPH``, serialized with a fresh ``IrWriter``
    * ``PFunction`` -> ``FUNCTION``, with its non-None attrs serialized
      recursively
    * ``spu.libspu.Visibility`` or ``Mask`` -> ``INT`` via ``int()``

    Raises:
        TypeError: If the value, or the element mix of a tuple/list, is not
            supported.
    """
    attr_proto = mpir_pb2.AttrProto()
    # NumPy integer scalars are not subclasses of int, and np.float32 is not a
    # subclass of float. Accept them explicitly and normalize them to built-in
    # Python numbers before handing them to protobuf.
    if isinstance(py_value, (int, np.integer)):
        attr_proto.type = mpir_pb2.AttrProto.INT
        attr_proto.i = int(py_value)
    elif isinstance(py_value, (float, np.floating)):
        attr_proto.type = mpir_pb2.AttrProto.FLOAT
        attr_proto.f = float(py_value)
    elif isinstance(py_value, str):
        attr_proto.type = mpir_pb2.AttrProto.STRING
        attr_proto.s = py_value
    elif isinstance(py_value, bytes):
        attr_proto.type = mpir_pb2.AttrProto.BYTES
        attr_proto.raw_bytes = py_value
    elif isinstance(py_value, tuple | list):
        if all(isinstance(item, (int, np.integer)) for item in py_value):
            attr_proto.type = mpir_pb2.AttrProto.INTS
            attr_proto.ints.extend([int(item) for item in py_value])
        elif all(isinstance(item, (float, np.floating)) for item in py_value):
            attr_proto.type = mpir_pb2.AttrProto.FLOATS
            attr_proto.floats.extend([float(item) for item in py_value])
        elif all(isinstance(item, str) for item in py_value):
            attr_proto.type = mpir_pb2.AttrProto.STRINGS
            attr_proto.strs.extend(list(py_value))
        elif all(isinstance(item, spu_api.Visibility) for item in py_value):
            # Handle list of enum types (like [Visibility.VIS_SECRET, Visibility.VIS_SECRET])
            attr_proto.type = mpir_pb2.AttrProto.INTS
            attr_proto.ints.extend([int(item) for item in py_value])
        else:
            # Report the element types: the container type alone (list/tuple)
            # is supported and says nothing about what went wrong.
            element_types = sorted({type(item).__name__ for item in py_value})
            raise TypeError(
                f"Unsupported tuple/list type: {type(py_value)} "
                f"(element types: {', '.join(element_types)})"
            )
    elif isinstance(py_value, FuncDefExpr):
        # Convert FuncDefExpr to a nested GraphProto.
        graph = IrWriter().dumps(py_value)
        attr_proto.type = mpir_pb2.AttrProto.GRAPH
        attr_proto.graph.CopyFrom(graph)
    elif isinstance(py_value, PFunction):
        attr_proto.type = mpir_pb2.AttrProto.FUNCTION
        attr_proto.func.type = py_value.fn_type
        attr_proto.func.name = py_value.fn_name or ""
        if py_value.fn_text is not None:
            attr_proto.func.body = str(py_value.fn_text)

        # Serialize attrs dictionary
        if py_value.attrs:
            for attr_name, attr_value in py_value.attrs.items():
                # Skip None-valued attributes to align with top-level attr handling
                if attr_value is not None:
                    attr_proto.func.attrs[attr_name].CopyFrom(attr_to_proto(attr_value))

        # Note: We don't serialize ins_info and outs_info since they can be
        # inferred from the input expressions during deserialization
    elif isinstance(py_value, spu_api.Visibility):
        # Handle enum types (like spu.libspu.Visibility) by storing as int
        attr_proto.type = mpir_pb2.AttrProto.INT
        attr_proto.i = int(py_value)
    elif isinstance(py_value, Mask):
        # Handle Mask objects by storing as int
        attr_proto.type = mpir_pb2.AttrProto.INT
        attr_proto.i = int(py_value)
    else:
        raise TypeError(f"Unsupported attribute type: {type(py_value)}")
    return attr_proto
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
class IrWriter:
    """Writer for serializing Expr-based expressions to GraphProto.

    This class traverses an expression tree in post-order and converts it into
    a serialized GraphProto. Every dependency is therefore emitted before the
    expressions that use it. Each distinct Expr object, tracked by ``id()``,
    is emitted once and named from a running counter (``%0``, ``%1``, ...).
    """

    def __init__(self, var_name_mapping: dict[str, str] | None = None):
        """Initialize the writer.

        Args:
            var_name_mapping: Optional mapping of variable names to replace
                during serialization. It is applied to ``VariableExpr`` nodes
                emitted by this writer. Names missing from the mapping are
                written unchanged.
        """
        self._counter = 0
        # id(expr) -> assigned node name. Keyed by object identity rather than
        # Expr equality, so shared sub-expressions are emitted only once.
        self._expr_ids: dict[int, str] = {}
        # Nodes emitted so far, in dependency-first order.
        self._nodes: list[mpir_pb2.NodeProto] = []
        self._var_name_mapping = var_name_mapping or {}

    def expr_name(self, expr: Expr) -> str:
        """Get or create a name for an expression.

        The first call for a given Expr object assigns ``%<counter>``.
        Later calls return the same name.

        Args:
            expr: The expression to name.

        Returns:
            A unique name for the expression.
        """
        expr_id = id(expr)
        if expr_id not in self._expr_ids:
            self._expr_ids[expr_id] = f"%{self._counter}"
            self._counter += 1
        return self._expr_ids[expr_id]

    def value_name(self, expr: Expr, out_idx: int = 0) -> str:
        """Get value name for expression output.

        A single-output expression is referenced by its bare node name
        (``%3``). Otherwise the output index is appended (``%3:1``).

        Args:
            expr: The expression.
            out_idx: The output index for multi-output expressions.

        Returns:
            A name for the specific output of the expression.
        """
        if len(expr.mptypes) == 1:
            return self.expr_name(expr)
        else:
            return f"{self.expr_name(expr)}:{out_idx}"

    # ------------------------- traversal and deps helpers -------------------------
    @staticmethod
    def _writer_deps(node: Expr) -> list[Expr]:
        """Dependencies for serialization order.

        Similar to dataflow deps, but with two important differences:
        - CallExpr: include the function value (fn) so we emit a func_def node
          in the outer graph before the call node.
        - FuncDefExpr: include body so we emit body producers before func_def.

        CondExpr/WhileExpr branch and loop functions are not deps. They are
        serialized as nested graphs through node attributes.
        """
        if isinstance(node, EvalExpr):
            return list(node.args)
        if isinstance(node, TupleExpr):
            return list(node.args)
        if isinstance(node, CondExpr):
            # pred and actual args only; functions are serialized via attrs (nested graphs)
            return [node.pred, *node.args]
        if isinstance(node, WhileExpr):
            # initial state args only; functions are serialized via attrs (nested graphs)
            return list(node.args)
        if isinstance(node, ConvExpr):
            return list(node.vars)
        if isinstance(node, ShflSExpr):
            return [node.src_val]
        if isinstance(node, ShflExpr):
            return [node.src, node.index]
        if isinstance(node, AccessExpr):
            return [node.src]
        if isinstance(node, VariableExpr):
            return []
        if isinstance(node, FuncDefExpr):
            # ensure body producers are serialized first
            return [node.body]
        if isinstance(node, CallExpr):
            # include fn and args as deps so func_def appears before call
            return [node.fn, *node.args]
        return []

    def reset(self) -> None:
        """Reset writer state.

        Clears all internal state, allowing the writer to be reused for
        serializing a new expression tree.
        """
        self._counter = 0
        self._expr_ids.clear()
        self._nodes.clear()

    def _create_node_proto(self, expr: Expr, op_type: str) -> mpir_pb2.NodeProto:
        """Helper: Create a basic NodeProto with common fields set.

        Args:
            expr: The expression this node represents.
            op_type: The operation type for this node.

        Returns:
            A new NodeProto with ``op_type`` and ``name`` set.
        """
        op = mpir_pb2.NodeProto()
        op.op_type = op_type
        op.name = self.expr_name(expr)
        return op

    def _add_output_info(self, op: mpir_pb2.NodeProto, expr: Expr) -> None:
        """Helper: Add output type information to a NodeProto.

        Appends one ``outs_info`` entry per element of ``expr.mptypes``. Each
        entry records the tensor dtype/shape or the table column schema, plus
        the party mask (``-1`` when the mask is dynamic, i.e. ``None``).

        Args:
            op: The NodeProto to populate.
            expr: The expression providing the type information.
        """
        for out_info in expr.mptypes:
            out_proto = op.outs_info.add()

            if out_info.is_tensor:
                # Handle tensor type
                tensor_type = out_proto.tensor_type
                tensor_type.dtype = dtype_to_proto(out_info.dtype)
                tensor_type.shape_dims.extend(list(out_info.shape))
            elif out_info.is_table:
                # Handle table type
                table_type = out_proto.table_type
                for col_name, col_dtype in out_info.schema.columns:
                    column = table_type.columns.add()
                    column.name = col_name
                    column.dtype = dtype_to_proto(col_dtype)

            # Set pmask (now int64, -1 for dynamic mask)
            if out_info.pmask is not None:
                out_proto.pmask = int(out_info.pmask)
            else:
                out_proto.pmask = -1  # Dynamic mask

    def _add_expr_inputs(self, op: mpir_pb2.NodeProto, *exprs: Expr) -> None:
        """Helper: Add expression inputs to NodeProto.

        For multi-output expressions, this adds all outputs as inputs.

        Args:
            op: The NodeProto to add inputs to.
            exprs: The expressions to add as inputs.
        """
        for expr in exprs:
            op.inputs.extend([
                self.value_name(expr, i) for i in range(len(expr.mptypes))
            ])

    def _add_single_expr_inputs(self, op: mpir_pb2.NodeProto, *exprs: Expr) -> None:
        """Helper: Add single-output expression inputs to NodeProto.

        For each expression, only the first (primary) output is added as an
        input.

        Args:
            op: The NodeProto to add inputs to.
            exprs: The expressions to add as inputs.
        """
        for expr in exprs:
            op.inputs.append(self.value_name(expr, 0))

    def _add_attrs(self, op: mpir_pb2.NodeProto, **attrs: Any) -> None:
        """Helper: Add attributes to NodeProto.

        Attributes whose value is ``None`` are skipped.

        Args:
            op: The NodeProto to add attributes to.
            **attrs: The attributes to add (key-value pairs).
        """
        for key, value in attrs.items():
            if value is not None:  # Skip None values
                op.attrs[key].CopyFrom(attr_to_proto(value))

    def _finalize_node(self, op: mpir_pb2.NodeProto, expr: Expr) -> str:
        """Helper: Add output info, append to nodes, and return expr name.

        This method completes the node creation process by adding output
        information, appending the node to the list of nodes, and returning
        the expression name.

        Args:
            op: The completed NodeProto.
            expr: The expression the node represents.

        Returns:
            The name of the expression.
        """
        self._add_output_info(op, expr)
        self._nodes.append(op)
        return self.expr_name(expr)

    def dumps(self, expr: Expr) -> mpir_pb2.GraphProto:
        """Dump an expression to GraphProto using iterative walk traversal.

        Writer state is reset first, so one instance can serialize several
        trees in sequence.

        Args:
            expr: Root expression to serialize.

        Returns:
            A GraphProto (version 1.0.0) containing every emitted node. Its
            ``outputs`` name each output of ``expr``. When ``expr`` is a
            FuncDefExpr, the graph attr ``"name"`` is set to
            ``"function_<id(expr)>"``.
        """
        self.reset()

        # Walk in post-order so deps are serialized before users
        for node in walk(expr, get_deps=self._writer_deps, traversal="dfs_post_iter"):
            # Avoid double-emit if the same Expr object appears multiple times
            if id(node) in self._expr_ids:
                continue
            self._serialize_node(node)

        # Create graph metadata
        graph_attrs = {}
        if isinstance(expr, FuncDefExpr):
            # NOTE(review): id() differs between processes, so this name is
            # not reproducible across runs.
            graph_attrs["name"] = attr_to_proto(f"function_{id(expr)}")

        # Outputs are the root expression's outputs. This also covers
        # FuncDefExpr roots, where the root is the function definition itself.
        outputs = [self.value_name(expr, i) for i in range(len(expr.mptypes))]

        return mpir_pb2.GraphProto(
            version=mpir_pb2.VersionInfo(major=1, minor=0, patch=0),
            nodes=self._nodes,
            outputs=outputs,
            attrs=graph_attrs,
        )

    # ------------------------------- emitters --------------------------------
    def _serialize_node(self, expr: Expr) -> None:
        """Create and append a NodeProto for the given expr.

        Raises:
            TypeError: If ``expr`` is not one of the supported Expr kinds.
        """
        if isinstance(expr, EvalExpr):
            op = self._create_node_proto(expr, "eval")
            self._add_expr_inputs(op, *expr.args)
            self._add_attrs(op, pfunc=expr.pfunc, rmask=expr.rmask)
            self._finalize_node(op, expr)
        elif isinstance(expr, VariableExpr):
            op = self._create_node_proto(expr, "variable")
            mapped_name = self._var_name_mapping.get(expr.name, expr.name)
            self._add_attrs(op, name=mapped_name)
            self._finalize_node(op, expr)
        elif isinstance(expr, TupleExpr):
            op = self._create_node_proto(expr, "tuple")
            self._add_single_expr_inputs(op, *expr.args)
            self._finalize_node(op, expr)
        elif isinstance(expr, CondExpr):
            op = self._create_node_proto(expr, "cond")
            self._add_single_expr_inputs(op, expr.pred)
            self._add_expr_inputs(op, *expr.args)
            # Branch functions become nested graphs inside the attrs.
            self._add_attrs(op, then_fn=expr.then_fn, else_fn=expr.else_fn)
            self._finalize_node(op, expr)
        elif isinstance(expr, CallExpr):
            op = self._create_node_proto(expr, "call")
            # The callee's func_def node is the first input.
            self._add_single_expr_inputs(op, expr.fn)
            self._add_expr_inputs(op, *expr.args)
            self._add_attrs(op, name=expr.name)
            self._finalize_node(op, expr)
        elif isinstance(expr, WhileExpr):
            op = self._create_node_proto(expr, "while")
            self._add_expr_inputs(op, *expr.args)
            # Loop functions become nested graphs inside the attrs.
            self._add_attrs(op, cond_fn=expr.cond_fn, body_fn=expr.body_fn)
            self._finalize_node(op, expr)
        elif isinstance(expr, ConvExpr):
            op = self._create_node_proto(expr, "conv")
            self._add_expr_inputs(op, *expr.vars)
            self._finalize_node(op, expr)
        elif isinstance(expr, ShflSExpr):
            op = self._create_node_proto(expr, "shfl_s")
            self._add_single_expr_inputs(op, expr.src_val)
            self._add_attrs(op, pmask=expr.pmask, src_ranks=expr.src_ranks)
            self._finalize_node(op, expr)
        elif isinstance(expr, ShflExpr):
            op = self._create_node_proto(expr, "shfl")
            self._add_single_expr_inputs(op, expr.src, expr.index)
            self._finalize_node(op, expr)
        elif isinstance(expr, AccessExpr):
            op = self._create_node_proto(expr, "access")
            # Reference exactly the accessed output of the source.
            op.inputs.append(self.value_name(expr.src, expr.index))
            self._add_attrs(op, index=expr.index)
            self._finalize_node(op, expr)
        elif isinstance(expr, FuncDefExpr):
            op = self._create_node_proto(expr, "func_def")
            self._add_expr_inputs(op, expr.body)
            self._add_attrs(op, params=expr.params)
            self._finalize_node(op, expr)
        else:
            raise TypeError(f"Unsupported expr type for serialization: {type(expr)}")
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
class IrReader:
    """Reader for deserializing GraphProto back to Expr-based expressions.

    This class converts serialized GraphProto representations back into
    expression trees. Nodes are materialized in dependency order: all of a
    node's inputs are built (and cached under their node name) before the
    node itself, so each ``_create_*_expr`` helper can look its inputs up in
    ``self._value_cache``.

    ``loads`` clears and refills ``self._value_cache``, so a single instance
    should not be used for overlapping ``loads`` calls.
    """

    def __init__(self) -> None:
        # node name -> Expr built for that node during the current ``loads``.
        self._value_cache: dict[str, Expr] = {}

    def loads(self, graph_proto: mpir_pb2.GraphProto) -> Expr | None:
        """Load an expression from a GraphProto.

        Args:
            graph_proto: The protobuf graph to deserialize

        Returns:
            The expression of the first graph output, or None if the graph
            declares no outputs.

        Raises:
            ValueError: On duplicate node names, dependency cycles, a node
                that fails to deserialize, or an output that names no node.
        """
        self._value_cache.clear()

        # Create a mapping for faster node lookup, checking for duplicate node names
        node_map: dict[str, mpir_pb2.NodeProto] = {}
        for node in graph_proto.nodes:
            if node.name in node_map:
                raise ValueError(
                    f"Duplicate node name detected in graph: '{node.name}'"
                )
            node_map[node.name] = node

        self._process_nodes_in_order(graph_proto, node_map)

        # Extract outputs - for now, just return the first output expression
        if graph_proto.outputs:
            output_name = graph_proto.outputs[0].split(":")[0]
            if output_name in self._value_cache:
                return self._value_cache[output_name]
            raise ValueError(f"Output {output_name} not found in processed nodes")

        return None

    def _process_nodes_in_order(
        self,
        graph_proto: mpir_pb2.GraphProto,
        node_map: dict[str, mpir_pb2.NodeProto],
    ) -> None:
        """Build every node of ``graph_proto`` after all of its dependencies.

        Uses an explicit stack rather than recursion so long dependency
        chains cannot exceed Python's recursion limit. Nodes are built in
        the same post-order a recursive DFS would produce: roots in graph
        order, dependencies in input order.

        Raises:
            ValueError: If the dependency graph contains a cycle, or if a
                node fails to deserialize.
        """
        processed: set[str] = set()
        for root in graph_proto.nodes:
            if root.name in processed:
                continue
            # Names on the current DFS path; meeting one again means a cycle.
            on_path = {root.name}
            stack = [(root, iter(root.inputs))]
            while stack:
                node, pending_inputs = stack[-1]
                for input_name in pending_inputs:
                    dep_name = input_name.split(":")[0]
                    # Inputs naming no node in this graph are left for the
                    # matching _create_* helper to report.
                    if dep_name not in node_map or dep_name in processed:
                        continue
                    if dep_name in on_path:
                        raise ValueError(
                            f"Cycle detected in graph: node '{node.name}' "
                            f"depends on '{dep_name}', which is still being built"
                        )
                    dep_node = node_map[dep_name]
                    on_path.add(dep_name)
                    stack.append((dep_node, iter(dep_node.inputs)))
                    break
                else:
                    # All dependencies of ``node`` are built; build it now.
                    stack.pop()
                    on_path.discard(node.name)
                    self._build_node(node)
                    processed.add(node.name)

    def _build_node(self, node_proto: mpir_pb2.NodeProto) -> None:
        """Create the Expr for one node and cache it under the node's name.

        Raises:
            ValueError: Wrapping any error raised while creating the Expr,
                annotated with the node name and op type.
        """
        try:
            expr = self._create_expr_from_proto(node_proto)
        except Exception as e:
            raise ValueError(
                f"Error processing node '{node_proto.name}' "
                f"of type '{node_proto.op_type}': {e!s}"
            ) from e
        self._value_cache[node_proto.name] = expr

    def _resolve_input(self, input_name: str, op_kind: str) -> Expr:
        """Return the cached Expr referenced by ``input_name``.

        ``input_name`` is either ``"node"`` or ``"node:<suffix>"``; only the
        part before the first ``:`` identifies the producing node.

        Raises:
            ValueError: If the referenced node has not been built.
        """
        dep_name = input_name.split(":")[0]
        try:
            return self._value_cache[dep_name]
        except KeyError:
            raise ValueError(
                f"Input {input_name} not found for {op_kind} node"
            ) from None

    def _resolve_input_at(
        self, node_proto: mpir_pb2.NodeProto, position: int, op_kind: str
    ) -> Expr:
        """Return the Expr for the input at ``position`` of ``node_proto``.

        Raises:
            ValueError: If the node has too few inputs or the input's node
                has not been built.
        """
        if len(node_proto.inputs) <= position:
            raise ValueError(
                f"{op_kind} node missing input at position {position}"
            )
        return self._resolve_input(node_proto.inputs[position], op_kind)

    def _create_expr_from_proto(self, node_proto: mpir_pb2.NodeProto) -> Expr:
        """Create an Expression from a NodeProto.

        This method delegates to specific creation methods based on the node type.

        Raises:
            ValueError: If ``node_proto.op_type`` has no registered creator.
        """
        # Dispatch to appropriate creation method based on op_type
        creation_methods = {
            "eval": self._create_eval_expr,
            "variable": self._create_variable_expr,
            "tuple": self._create_tuple_expr,
            "cond": self._create_cond_expr,
            "while": self._create_while_expr,
            "access": self._create_access_expr,
            "func_def": self._create_func_def_expr,
            "shfl_s": self._create_shfl_s_expr,
            "shfl": self._create_shfl_expr,
            "conv": self._create_conv_expr,
            "call": self._create_call_expr,
        }

        if node_proto.op_type in creation_methods:
            return creation_methods[node_proto.op_type](node_proto)
        else:
            raise ValueError(f"Unsupported node type: {node_proto.op_type}")

    def _create_eval_expr(self, node_proto: mpir_pb2.NodeProto) -> EvalExpr:
        """Create an EvalExpr from a NodeProto.

        The serialized PFunction carries no ins/outs info, so it is rebuilt
        here: ``ins_info`` from the input expressions' MPTypes and
        ``outs_info`` from ``node_proto.outs_info``.
        """
        # Parse inputs
        input_exprs = [
            self._resolve_input(input_name, "eval")
            for input_name in node_proto.inputs
        ]

        # Parse function
        pfunc = self._proto_to_attr(node_proto.attrs["pfunc"])
        rmask = None
        if "rmask" in node_proto.attrs:
            rmask = self._proto_to_attr(node_proto.attrs["rmask"])

        # Fill in ins_info and outs_info for PFunction
        # ins_info from input expressions (use mptype for single type per value)
        ins_info: list[TensorType | TableType] = []
        for input_expr in input_exprs:
            # Use mptype directly for single MPType
            mptype = input_expr.mptype
            if mptype.is_tensor:
                ins_info.append(TensorType(mptype.dtype, mptype.shape))
            elif mptype.is_table:
                ins_info.append(mptype.schema)
            else:
                raise ValueError(f"unsupported type: {mptype}")

        # outs_info from NodeProto.outs_info
        outs_info: list[TensorType | TableType] = []
        for out_proto in node_proto.outs_info:
            if out_proto.HasField("tensor_type"):
                tensor_type_proto = out_proto.tensor_type
                dtype = proto_to_dtype(tensor_type_proto.dtype)
                shape = tuple(tensor_type_proto.shape_dims)
                outs_info.append(TensorType(dtype, shape))
            elif out_proto.HasField("table_type"):
                columns = [
                    (col.name, proto_to_dtype(col.dtype))
                    for col in out_proto.table_type.columns
                ]
                outs_info.append(TableType.from_pairs(columns))
            else:
                # Both tensor and table outputs are handled above.
                raise ValueError("Eval node output must be a tensor or table type")

        # Create a complete PFunction with proper type information
        complete_pfunc = PFunction(
            fn_type=pfunc.fn_type,
            ins_info=ins_info,
            outs_info=outs_info,
            fn_name=pfunc.fn_name,
            fn_text=pfunc.fn_text,
            **pfunc.attrs,  # Restore attributes
        )

        return EvalExpr(complete_pfunc, input_exprs, rmask)

    def _create_variable_expr(self, node_proto: mpir_pb2.NodeProto) -> VariableExpr:
        """Create a VariableExpr from a NodeProto.

        Raises:
            ValueError: If the node has no output type info.
        """
        # Parse variable name
        name = self._proto_to_attr(node_proto.attrs["name"])

        # Parse type info from output info (VariableExpr needs a single MPType)
        if not node_proto.outs_info:
            raise ValueError("Variable node missing output info")

        mptype = self._proto_to_mptype(node_proto.outs_info[0])
        return VariableExpr(name, mptype)

    def _create_tuple_expr(self, node_proto: mpir_pb2.NodeProto) -> TupleExpr:
        """Create a TupleExpr from a NodeProto."""
        input_exprs = [
            self._resolve_input(input_name, "tuple")
            for input_name in node_proto.inputs
        ]
        return TupleExpr(input_exprs)

    def _create_cond_expr(self, node_proto: mpir_pb2.NodeProto) -> CondExpr:
        """Create a CondExpr from a NodeProto.

        Input 0 is the predicate; the remaining inputs are branch arguments.
        """
        pred_expr = self._resolve_input_at(node_proto, 0, "cond")
        arg_exprs = [
            self._resolve_input(input_name, "cond")
            for input_name in node_proto.inputs[1:]
        ]

        # Parse functions
        then_fn = self._proto_to_attr(node_proto.attrs["then_fn"])
        else_fn = self._proto_to_attr(node_proto.attrs["else_fn"])

        return CondExpr(pred_expr, then_fn, else_fn, arg_exprs)

    def _create_while_expr(self, node_proto: mpir_pb2.NodeProto) -> WhileExpr:
        """Create a WhileExpr from a NodeProto."""
        arg_exprs = [
            self._resolve_input(input_name, "while")
            for input_name in node_proto.inputs
        ]

        # Parse functions
        cond_fn = self._proto_to_attr(node_proto.attrs["cond_fn"])
        body_fn = self._proto_to_attr(node_proto.attrs["body_fn"])

        return WhileExpr(cond_fn, body_fn, arg_exprs)

    def _create_access_expr(self, node_proto: mpir_pb2.NodeProto) -> AccessExpr:
        """Create an AccessExpr from a NodeProto.

        The source is input 0; the accessed position comes from the
        ``index`` attribute, not from the input name's ``:`` suffix.
        """
        src_expr = self._resolve_input_at(node_proto, 0, "access")
        index = self._proto_to_attr(node_proto.attrs["index"])
        return AccessExpr(src_expr, index)

    def _create_func_def_expr(self, node_proto: mpir_pb2.NodeProto) -> FuncDefExpr:
        """Create a FuncDefExpr from a NodeProto.

        Raises:
            ValueError: If the node has no body input.
        """
        if not node_proto.inputs:
            raise ValueError("FuncDef node missing body input")

        body_expr = self._resolve_input(node_proto.inputs[0], "func_def")

        # Parse parameters
        params = self._proto_to_attr(node_proto.attrs["params"])

        return FuncDefExpr(params, body_expr)

    def _create_shfl_s_expr(self, node_proto: mpir_pb2.NodeProto) -> ShflSExpr:
        """Create a ShflSExpr from a NodeProto."""
        src_val = self._resolve_input_at(node_proto, 0, "shfl_s")

        # Parse attributes
        pmask = self._proto_to_attr(node_proto.attrs["pmask"])
        src_ranks = self._proto_to_attr(node_proto.attrs["src_ranks"])

        return ShflSExpr(src_val, pmask, src_ranks)

    def _create_shfl_expr(self, node_proto: mpir_pb2.NodeProto) -> ShflExpr:
        """Create a ShflExpr from a NodeProto (input 0: source, input 1: index)."""
        src_expr = self._resolve_input_at(node_proto, 0, "shfl")
        index_expr = self._resolve_input_at(node_proto, 1, "shfl")
        return ShflExpr(src_expr, index_expr)

    def _create_conv_expr(self, node_proto: mpir_pb2.NodeProto) -> ConvExpr:
        """Create a ConvExpr from a NodeProto."""
        var_exprs = [
            self._resolve_input(input_name, "conv")
            for input_name in node_proto.inputs
        ]
        return ConvExpr(var_exprs)

    def _create_call_expr(self, node_proto: mpir_pb2.NodeProto) -> CallExpr:
        """Create a CallExpr from a NodeProto.

        Input 0 must resolve to a FuncDefExpr; the remaining inputs are the
        call arguments. The optional ``name`` attribute is the call-site name.

        Raises:
            ValueError: If input 0 is not a FuncDefExpr.
        """
        fn_expr = self._resolve_input_at(node_proto, 0, "call")

        # Ensure function is FuncDefExpr
        if not isinstance(fn_expr, FuncDefExpr):
            raise ValueError(f"Call function must be FuncDefExpr, got {type(fn_expr)}")

        arg_exprs = [
            self._resolve_input(input_name, "call")
            for input_name in node_proto.inputs[1:]
        ]

        # Optional call-site name attribute
        call_name = None
        if "name" in node_proto.attrs:
            call_name = self._proto_to_attr(node_proto.attrs["name"])  # type: ignore[assignment]

        return CallExpr(call_name or "", fn_expr, arg_exprs)

    def _proto_to_mptype(self, type_proto: mpir_pb2.MPTypeProto) -> MPType:
        """Convert MPTypeProto to MPType.

        Raises:
            ValueError: If neither ``tensor_type`` nor ``table_type`` is set.
        """
        # Convert pmask (now int64, -1 means dynamic mask (None))
        pmask_int = type_proto.pmask
        pmask = None if pmask_int == -1 else Mask(pmask_int)

        # Convert attributes
        attrs = {
            attr_name: self._proto_to_attr(attr_proto)
            for attr_name, attr_proto in type_proto.attrs.items()
        }

        # Handle tensor type
        if type_proto.HasField("tensor_type"):
            tensor_type_proto = type_proto.tensor_type
            dtype = proto_to_dtype(tensor_type_proto.dtype)
            shape = tuple(tensor_type_proto.shape_dims)
            tensor_type = TensorType(dtype, shape)
            return MPType(tensor_type, pmask, attrs)

        # Handle table type
        elif type_proto.HasField("table_type"):
            columns = [
                (column_proto.name, proto_to_dtype(column_proto.dtype))
                for column_proto in type_proto.table_type.columns
            ]
            table_type = TableType(tuple(columns))
            return MPType(table_type, pmask, attrs)

        else:
            raise ValueError(
                "MPTypeProto must specify either tensor_type or table_type"
            )

    def _proto_to_attr(self, attr_proto: mpir_pb2.AttrProto) -> Any:
        """Convert AttrProto to Python value.

        FUNCTION attributes become a PFunction with empty ins/outs info.
        GRAPH attributes are decoded with a fresh IrReader, so the outer
        reader's cache is not disturbed.

        Raises:
            TypeError: If ``attr_proto.type`` is not a handled kind.
        """
        if attr_proto.type == mpir_pb2.AttrProto.INT:
            return attr_proto.i
        elif attr_proto.type == mpir_pb2.AttrProto.FLOAT:
            return attr_proto.f
        elif attr_proto.type == mpir_pb2.AttrProto.STRING:
            return attr_proto.s
        elif attr_proto.type == mpir_pb2.AttrProto.BYTES:
            return attr_proto.raw_bytes
        elif attr_proto.type == mpir_pb2.AttrProto.INTS:
            return list(attr_proto.ints)
        elif attr_proto.type == mpir_pb2.AttrProto.FLOATS:
            return list(attr_proto.floats)
        elif attr_proto.type == mpir_pb2.AttrProto.STRINGS:
            return list(attr_proto.strs)
        elif attr_proto.type == mpir_pb2.AttrProto.FUNCTION:
            # Reconstruct PFunction - since Expr already contains MPType information,
            # we don't need to reconstruct ins_info and outs_info from serialized data.
            # The type information will be inferred from the actual input expressions.

            # Deserialize attrs dictionary
            attrs = {
                attr_name: self._proto_to_attr(attr_value_proto)
                for attr_name, attr_value_proto in attr_proto.func.attrs.items()
            }

            return PFunction(
                fn_type=attr_proto.func.type,
                ins_info=[],  # Will be inferred from input expressions
                outs_info=[],  # Will be inferred from context
                fn_name=attr_proto.func.name or None,
                fn_text=attr_proto.func.body if attr_proto.func.body else None,
                **attrs,  # Restore serialized attributes
            )
        elif attr_proto.type == mpir_pb2.AttrProto.GRAPH:
            # Handle nested expressions (for control flow)
            reader = IrReader()
            return reader.loads(attr_proto.graph)
        else:
            raise TypeError(f"Unsupported attribute type: {attr_proto.type}")
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
def get_graph_statistics(graph_proto: mpir_pb2.GraphProto) -> str:
    """Get statistics about a GraphProto structure.

    Args:
        graph_proto: The protobuf GraphProto to analyze

    Returns:
        A newline-joined report with:
        - Graph version information (plus a warning if the major version is not 1)
        - Node, output, and graph-attribute counts
        - Node counts per operation type, sorted by op type name
        - Each output name, in declaration order
    """
    from collections import Counter

    lines = ["GraphProto structure analysis:"]

    # Version information with compatibility check. A proto whose schema has
    # no ``version`` (or an incomplete one) is reported as unknown.
    try:
        version = graph_proto.version
        version_line = f"- Version: {version.major}.{version.minor}.{version.patch}"
    except AttributeError:
        lines.append("- Version: Unknown (missing version info)")
    else:
        lines.append(version_line)
        if version.major != 1:
            lines.append(f" WARNING: Expected major version 1, got {version.major}")

    # Node and output counts
    lines.append(f"- Number of nodes: {len(graph_proto.nodes)}")
    lines.append(f"- Number of outputs: {len(graph_proto.outputs)}")
    lines.append(f"- Graph attributes: {len(graph_proto.attrs)}")
    lines.append("")

    # Node breakdown by operation type
    lines.append("Node breakdown by operation type:")
    op_counts = Counter(node.op_type for node in graph_proto.nodes)
    lines.extend(
        f"- {op_type}: {count} nodes" for op_type, count in sorted(op_counts.items())
    )
    lines.append("")

    # Output variables
    lines.append("Output variables:")
    lines.extend(
        f"- Output {i}: {output}" for i, output in enumerate(graph_proto.outputs)
    )

    return "\n".join(lines)
|