mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev270__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. mplang/__init__.py +391 -17
  2. mplang/{v2/backends → backends}/__init__.py +9 -7
  3. mplang/{v2/backends → backends}/bfv_impl.py +6 -6
  4. mplang/{v2/backends → backends}/crypto_impl.py +6 -6
  5. mplang/{v2/backends → backends}/field_impl.py +5 -5
  6. mplang/{v2/backends → backends}/func_impl.py +4 -4
  7. mplang/{v2/backends → backends}/phe_impl.py +3 -3
  8. mplang/{v2/backends → backends}/simp_design.md +1 -1
  9. mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
  10. mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
  11. mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
  12. mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
  13. mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
  14. mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
  15. mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
  16. mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
  17. mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
  18. mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
  19. mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
  20. mplang/{v2/backends → backends}/spu_impl.py +8 -8
  21. mplang/{v2/backends → backends}/spu_state.py +4 -4
  22. mplang/{v2/backends → backends}/store_impl.py +3 -3
  23. mplang/{v2/backends → backends}/table_impl.py +8 -8
  24. mplang/{v2/backends → backends}/tee_impl.py +6 -6
  25. mplang/{v2/backends → backends}/tensor_impl.py +6 -6
  26. mplang/{v2/cli.py → cli.py} +9 -9
  27. mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
  28. mplang/{v2/dialects → dialects}/__init__.py +5 -5
  29. mplang/{v2/dialects → dialects}/bfv.py +6 -6
  30. mplang/{v2/dialects → dialects}/crypto.py +5 -5
  31. mplang/{v2/dialects → dialects}/dtypes.py +2 -2
  32. mplang/{v2/dialects → dialects}/field.py +3 -3
  33. mplang/{v2/dialects → dialects}/func.py +2 -2
  34. mplang/{v2/dialects → dialects}/phe.py +6 -6
  35. mplang/{v2/dialects → dialects}/simp.py +6 -6
  36. mplang/{v2/dialects → dialects}/spu.py +7 -7
  37. mplang/{v2/dialects → dialects}/store.py +2 -2
  38. mplang/{v2/dialects → dialects}/table.py +3 -3
  39. mplang/{v2/dialects → dialects}/tee.py +6 -6
  40. mplang/{v2/dialects → dialects}/tensor.py +5 -5
  41. mplang/{v2/edsl → edsl}/__init__.py +3 -3
  42. mplang/{v2/edsl → edsl}/context.py +6 -6
  43. mplang/{v2/edsl → edsl}/graph.py +5 -5
  44. mplang/{v2/edsl → edsl}/jit.py +2 -2
  45. mplang/{v2/edsl → edsl}/object.py +1 -1
  46. mplang/{v2/edsl → edsl}/primitive.py +5 -5
  47. mplang/{v2/edsl → edsl}/printer.py +1 -1
  48. mplang/{v2/edsl → edsl}/serde.py +1 -1
  49. mplang/{v2/edsl → edsl}/tracer.py +7 -7
  50. mplang/{v2/edsl → edsl}/typing.py +1 -1
  51. mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
  52. mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
  53. mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
  54. mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
  55. mplang/{v2/libs → libs}/collective.py +5 -5
  56. mplang/{v2/libs → libs}/device/__init__.py +1 -1
  57. mplang/{v2/libs → libs}/device/api.py +12 -12
  58. mplang/{v2/libs → libs}/ml/__init__.py +1 -1
  59. mplang/{v2/libs → libs}/ml/sgb.py +4 -4
  60. mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
  61. mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
  62. mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
  63. mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
  64. mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
  65. mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
  66. mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
  67. mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
  68. mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
  69. mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
  70. mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
  71. mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
  72. mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
  73. mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
  74. mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
  75. mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
  76. mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
  77. mplang/{v2/runtime → runtime}/interpreter.py +11 -11
  78. mplang/{v2/runtime → runtime}/value.py +2 -2
  79. mplang/{v1/runtime → utils}/__init__.py +18 -15
  80. mplang/{v1/utils → utils}/func_utils.py +1 -1
  81. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/METADATA +2 -2
  82. mplang_nightly-0.1.dev270.dist-info/RECORD +102 -0
  83. mplang/v1/__init__.py +0 -157
  84. mplang/v1/_device.py +0 -602
  85. mplang/v1/analysis/__init__.py +0 -37
  86. mplang/v1/analysis/diagram.py +0 -567
  87. mplang/v1/core/__init__.py +0 -157
  88. mplang/v1/core/cluster.py +0 -343
  89. mplang/v1/core/comm.py +0 -281
  90. mplang/v1/core/context_mgr.py +0 -50
  91. mplang/v1/core/dtypes.py +0 -335
  92. mplang/v1/core/expr/__init__.py +0 -80
  93. mplang/v1/core/expr/ast.py +0 -542
  94. mplang/v1/core/expr/evaluator.py +0 -581
  95. mplang/v1/core/expr/printer.py +0 -285
  96. mplang/v1/core/expr/transformer.py +0 -141
  97. mplang/v1/core/expr/utils.py +0 -78
  98. mplang/v1/core/expr/visitor.py +0 -85
  99. mplang/v1/core/expr/walk.py +0 -387
  100. mplang/v1/core/interp.py +0 -160
  101. mplang/v1/core/mask.py +0 -325
  102. mplang/v1/core/mpir.py +0 -965
  103. mplang/v1/core/mpobject.py +0 -117
  104. mplang/v1/core/mptype.py +0 -407
  105. mplang/v1/core/pfunc.py +0 -130
  106. mplang/v1/core/primitive.py +0 -877
  107. mplang/v1/core/table.py +0 -218
  108. mplang/v1/core/tensor.py +0 -75
  109. mplang/v1/core/tracer.py +0 -383
  110. mplang/v1/host.py +0 -130
  111. mplang/v1/kernels/__init__.py +0 -41
  112. mplang/v1/kernels/base.py +0 -125
  113. mplang/v1/kernels/basic.py +0 -240
  114. mplang/v1/kernels/context.py +0 -369
  115. mplang/v1/kernels/crypto.py +0 -122
  116. mplang/v1/kernels/fhe.py +0 -858
  117. mplang/v1/kernels/mock_tee.py +0 -72
  118. mplang/v1/kernels/phe.py +0 -1864
  119. mplang/v1/kernels/spu.py +0 -341
  120. mplang/v1/kernels/sql_duckdb.py +0 -44
  121. mplang/v1/kernels/stablehlo.py +0 -90
  122. mplang/v1/kernels/value.py +0 -626
  123. mplang/v1/ops/__init__.py +0 -35
  124. mplang/v1/ops/base.py +0 -424
  125. mplang/v1/ops/basic.py +0 -294
  126. mplang/v1/ops/crypto.py +0 -262
  127. mplang/v1/ops/fhe.py +0 -272
  128. mplang/v1/ops/jax_cc.py +0 -147
  129. mplang/v1/ops/nnx_cc.py +0 -168
  130. mplang/v1/ops/phe.py +0 -216
  131. mplang/v1/ops/spu.py +0 -151
  132. mplang/v1/ops/sql_cc.py +0 -303
  133. mplang/v1/ops/tee.py +0 -36
  134. mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
  135. mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
  136. mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
  137. mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
  138. mplang/v1/runtime/channel.py +0 -230
  139. mplang/v1/runtime/cli.py +0 -451
  140. mplang/v1/runtime/client.py +0 -456
  141. mplang/v1/runtime/communicator.py +0 -131
  142. mplang/v1/runtime/data_providers.py +0 -303
  143. mplang/v1/runtime/driver.py +0 -324
  144. mplang/v1/runtime/exceptions.py +0 -27
  145. mplang/v1/runtime/http_api.md +0 -56
  146. mplang/v1/runtime/link_comm.py +0 -196
  147. mplang/v1/runtime/server.py +0 -501
  148. mplang/v1/runtime/session.py +0 -270
  149. mplang/v1/runtime/simulation.py +0 -324
  150. mplang/v1/simp/__init__.py +0 -13
  151. mplang/v1/simp/api.py +0 -353
  152. mplang/v1/simp/mpi.py +0 -131
  153. mplang/v1/simp/party.py +0 -225
  154. mplang/v1/simp/random.py +0 -120
  155. mplang/v1/simp/smpc.py +0 -238
  156. mplang/v1/utils/__init__.py +0 -13
  157. mplang/v1/utils/crypto.py +0 -32
  158. mplang/v1/utils/spu_utils.py +0 -130
  159. mplang/v1/utils/table_utils.py +0 -185
  160. mplang/v2/__init__.py +0 -424
  161. mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
  162. /mplang/{v2/backends → backends}/channel.py +0 -0
  163. /mplang/{v2/edsl → edsl}/README.md +0 -0
  164. /mplang/{v2/edsl → edsl}/registry.py +0 -0
  165. /mplang/{v2/kernels → kernels}/Makefile +0 -0
  166. /mplang/{v2/kernels → kernels}/__init__.py +0 -0
  167. /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
  168. /mplang/{v2/libs → libs}/device/cluster.py +0 -0
  169. /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
  170. /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
  171. /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
  172. /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
  173. /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
  174. /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
  175. /mplang/{v2/runtime → runtime}/__init__.py +0 -0
  176. /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
  177. /mplang/{v2/runtime → runtime}/object_store.py +0 -0
  178. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/WHEEL +0 -0
  179. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/entry_points.txt +0 -0
  180. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/licenses/LICENSE +0 -0
mplang/v1/core/mpir.py DELETED
@@ -1,965 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """MPIR (Multi-Party Intermediate Representation) serialization module.
16
-
17
- This module provides functionality for serializing and deserializing
18
- expression-based computation graphs to and from protobuf representations.
19
- It serves as the bridge between in-memory expression trees and their
20
- serialized form for storage or transmission.
21
-
22
- Key components:
23
- - Writer: Serializes Expr objects to GraphProto
24
- - Reader: Deserializes GraphProto back to Expr objects
25
- - Conversion functions: Handle mapping between Python types and protobuf types
26
- """
27
-
28
- from __future__ import annotations
29
-
30
- from typing import Any
31
-
32
- import numpy as np
33
- import spu.libspu as spu_api
34
-
35
- from mplang.v1.core.dtypes import DATE, JSON, STRING, TIME, TIMESTAMP, DType
36
- from mplang.v1.core.expr import Expr, FuncDefExpr
37
- from mplang.v1.core.expr.ast import (
38
- AccessExpr,
39
- CallExpr,
40
- CondExpr,
41
- ConvExpr,
42
- EvalExpr,
43
- ShflExpr,
44
- ShflSExpr,
45
- TupleExpr,
46
- VariableExpr,
47
- WhileExpr,
48
- )
49
- from mplang.v1.core.expr.walk import walk
50
- from mplang.v1.core.mask import Mask
51
- from mplang.v1.core.mptype import MPType
52
- from mplang.v1.core.pfunc import PFunction
53
- from mplang.v1.core.table import TableType
54
- from mplang.v1.core.tensor import TensorType
55
- from mplang.v1.protos.v1alpha1 import mpir_pb2
56
-
57
- # Single mapping table for dtype conversion
58
- DTYPE_MAPPING = {
59
- np.float32: mpir_pb2.DataType.F32,
60
- np.uint8: mpir_pb2.DataType.U8,
61
- np.int8: mpir_pb2.DataType.I8,
62
- np.uint16: mpir_pb2.DataType.U16,
63
- np.int16: mpir_pb2.DataType.I16,
64
- np.int32: mpir_pb2.DataType.I32,
65
- np.int64: mpir_pb2.DataType.I64,
66
- np.str_: mpir_pb2.DataType.STRING,
67
- np.bool_: mpir_pb2.DataType.BOOL,
68
- np.float16: mpir_pb2.DataType.F16,
69
- np.float64: mpir_pb2.DataType.F64,
70
- np.uint32: mpir_pb2.DataType.U32,
71
- np.uint64: mpir_pb2.DataType.U64,
72
- np.complex64: mpir_pb2.DataType.COMPLEX64,
73
- np.complex128: mpir_pb2.DataType.COMPLEX128,
74
- }
75
-
76
- # Additional mapping for table-only DType constants
77
- DTYPE_TO_PROTO_MAPPING = {
78
- # Map DType constants to protobuf enums
79
- STRING: mpir_pb2.DataType.STRING,
80
- DATE: mpir_pb2.DataType.DATE,
81
- TIME: mpir_pb2.DataType.TIME,
82
- TIMESTAMP: mpir_pb2.DataType.TIMESTAMP,
83
- JSON: mpir_pb2.DataType.JSON,
84
- }
85
-
86
-
87
- def dtype_to_proto(dtype_like: Any) -> Any:
88
- """Convert dtype (DType, NumPy dtype, or type) to protobuf DataType.
89
-
90
- Args:
91
- dtype_like: A DType, NumPy dtype, or Python type to convert.
92
-
93
- Returns:
94
- The corresponding protobuf DataType enum value.
95
-
96
- Raises:
97
- ValueError: If the dtype is not supported for conversion.
98
- """
99
- # If it's already a DType, check for direct mapping first
100
- if isinstance(dtype_like, DType):
101
- # Check for table-only types first
102
- if dtype_like in DTYPE_TO_PROTO_MAPPING:
103
- return DTYPE_TO_PROTO_MAPPING[dtype_like]
104
-
105
- # For regular types, convert to numpy for protobuf mapping
106
- try:
107
- numpy_dtype = dtype_like.to_numpy()
108
- key_type = numpy_dtype.type
109
- except ValueError as e:
110
- # Handle table-only types that can't be converted to numpy
111
- raise ValueError(
112
- f"Unsupported dtype for proto conversion: {dtype_like}. This is likely a table-only type that cannot be converted to a numpy dtype. Please ensure the dtype is supported for proto conversion."
113
- ) from e
114
- else:
115
- # Handle NumPy dtypes and other types
116
- try:
117
- key_type = np.dtype(dtype_like).type
118
- except TypeError:
119
- # Handle cases where dtype_like might already be a type object
120
- # that np.dtype() can't process but is a valid key.
121
- if isinstance(dtype_like, type) and issubclass(dtype_like, np.generic):
122
- key_type = dtype_like
123
- else:
124
- raise ValueError(f"Invalid dtype: {dtype_like}") from None
125
-
126
- if key_type in DTYPE_MAPPING:
127
- return DTYPE_MAPPING[key_type]
128
- else:
129
- raise ValueError(f"Unsupported dtype: {dtype_like}")
130
-
131
-
132
- def proto_to_dtype(dtype_enum: int) -> DType:
133
- """Convert protobuf DataType enum to DType.
134
-
135
- Args:
136
- dtype_enum: The protobuf DataType enum value to convert.
137
-
138
- Returns:
139
- The corresponding DType object.
140
-
141
- Raises:
142
- ValueError: If the enum value is not supported.
143
- """
144
- # Check for table-only types first
145
- for dtype_obj, proto_enum in DTYPE_TO_PROTO_MAPPING.items():
146
- if proto_enum == dtype_enum:
147
- return dtype_obj
148
-
149
- # Find the numpy type for the given enum by searching the mapping
150
- for numpy_type, proto_enum in DTYPE_MAPPING.items():
151
- if proto_enum == dtype_enum:
152
- # Convert numpy type to dtype
153
- try:
154
- np_dtype = np.dtype(numpy_type)
155
- except TypeError as e:
156
- raise ValueError(f"Cannot create numpy dtype from {numpy_type}") from e
157
-
158
- # Special handling for string types since DType.from_numpy doesn't support them
159
- if np_dtype.kind == "U": # Unicode string
160
- # Return the STRING constant for table-only string types
161
- return STRING
162
- else:
163
- try:
164
- return DType.from_numpy(np_dtype)
165
- except ValueError as e:
166
- raise ValueError(
167
- f"Cannot convert numpy dtype {np_dtype} to DType"
168
- ) from e
169
-
170
- # If we get here, the enum was not found
171
- raise ValueError(f"Unsupported dtype enum: {dtype_enum}")
172
-
173
-
174
- def attr_to_proto(py_value: Any) -> mpir_pb2.AttrProto:
175
- """Convert a Python attribute value to an AttrProto."""
176
- attr_proto = mpir_pb2.AttrProto()
177
- if isinstance(py_value, int):
178
- attr_proto.type = mpir_pb2.AttrProto.INT
179
- attr_proto.i = py_value
180
- elif isinstance(py_value, float):
181
- attr_proto.type = mpir_pb2.AttrProto.FLOAT
182
- attr_proto.f = py_value
183
- elif isinstance(py_value, str):
184
- attr_proto.type = mpir_pb2.AttrProto.STRING
185
- attr_proto.s = py_value
186
- elif isinstance(py_value, bytes):
187
- attr_proto.type = mpir_pb2.AttrProto.BYTES
188
- attr_proto.raw_bytes = py_value
189
- elif isinstance(py_value, tuple | list):
190
- if all(isinstance(item, int) for item in py_value):
191
- attr_proto.type = mpir_pb2.AttrProto.INTS
192
- attr_proto.ints.extend(list(py_value))
193
- elif all(isinstance(item, float) for item in py_value):
194
- attr_proto.type = mpir_pb2.AttrProto.FLOATS
195
- attr_proto.floats.extend(list(py_value))
196
- elif all(isinstance(item, str) for item in py_value):
197
- attr_proto.type = mpir_pb2.AttrProto.STRINGS
198
- attr_proto.strs.extend(list(py_value))
199
- elif all(isinstance(item, spu_api.Visibility) for item in py_value):
200
- # Handle list of enum types (like [Visibility.VIS_SECRET, Visibility.VIS_SECRET])
201
- attr_proto.type = mpir_pb2.AttrProto.INTS
202
- attr_proto.ints.extend([int(item) for item in py_value])
203
- else:
204
- raise TypeError(f"Unsupported tuple/list type: {type(py_value)}")
205
- elif isinstance(py_value, FuncDefExpr):
206
- # Convert FuncDefExpr to GraphProto
207
- graph = IrWriter().dumps(py_value)
208
- attr_proto.type = mpir_pb2.AttrProto.GRAPH
209
- attr_proto.graph.CopyFrom(graph)
210
- elif isinstance(py_value, PFunction):
211
- attr_proto.type = mpir_pb2.AttrProto.FUNCTION
212
- attr_proto.func.type = py_value.fn_type
213
- attr_proto.func.name = py_value.fn_name or ""
214
- if py_value.fn_text is not None:
215
- attr_proto.func.body = str(py_value.fn_text)
216
-
217
- # Serialize attrs dictionary
218
- if py_value.attrs:
219
- for attr_name, attr_value in py_value.attrs.items():
220
- # Skip None-valued attributes to align with top-level attr handling
221
- if attr_value is not None:
222
- attr_proto.func.attrs[attr_name].CopyFrom(attr_to_proto(attr_value))
223
-
224
- # Note: We don't serialize ins_info and outs_info since they can be
225
- # inferred from the input expressions during deserialization
226
- elif isinstance(py_value, spu_api.Visibility):
227
- # Handle enum types (like spu.libspu.Visibility) by storing as int
228
- attr_proto.type = mpir_pb2.AttrProto.INT
229
- attr_proto.i = int(py_value)
230
- elif isinstance(py_value, Mask):
231
- # Handle Mask objects by storing as int
232
- attr_proto.type = mpir_pb2.AttrProto.INT
233
- attr_proto.i = int(py_value)
234
- else:
235
- raise TypeError(f"Unsupported attribute type: {type(py_value)}")
236
- return attr_proto
237
-
238
-
239
- class IrWriter:
240
- """Writer for serializing Expr-based expressions to GraphProto.
241
-
242
- This class traverses an expression tree and converts it into a serialized
243
- GraphProto representation. It handles various expression types and ensures
244
- that all dependencies are properly serialized before the expressions that
245
- depend on them.
246
- """
247
-
248
- def __init__(self, var_name_mapping: dict[str, str] | None = None):
249
- """Initialize the Writer.
250
-
251
- Args:
252
- var_name_mapping: Optional mapping of variable names to replace during serialization.
253
- """
254
- self._counter = 0
255
- self._expr_ids: dict[int, str] = {} # Use expr id instead of Node
256
- self._nodes: list[mpir_pb2.NodeProto] = []
257
- self._var_name_mapping = var_name_mapping or {}
258
-
259
- def expr_name(self, expr: Expr) -> str:
260
- """Get or create a name for an expression.
261
-
262
- Args:
263
- expr: The expression to name.
264
-
265
- Returns:
266
- A unique name for the expression.
267
- """
268
- expr_id = id(expr)
269
- if expr_id not in self._expr_ids:
270
- self._expr_ids[expr_id] = f"%{self._counter}"
271
- self._counter += 1
272
- return self._expr_ids[expr_id]
273
-
274
- def value_name(self, expr: Expr, out_idx: int = 0) -> str:
275
- """Get value name for expression output.
276
-
277
- Args:
278
- expr: The expression.
279
- out_idx: The output index for multi-output expressions.
280
-
281
- Returns:
282
- A name for the specific output of the expression.
283
- """
284
- if len(expr.mptypes) == 1:
285
- return self.expr_name(expr)
286
- else:
287
- return f"{self.expr_name(expr)}:{out_idx}"
288
-
289
- # ------------------------- traversal and deps helpers -------------------------
290
- @staticmethod
291
- def _writer_deps(node: Expr) -> list[Expr]:
292
- """Dependencies for serialization order.
293
-
294
- Similar to dataflow deps, but with two important differences:
295
- - CallExpr: include the function value (fn) so we emit a func_def node
296
- in the outer graph before the call node.
297
- - FuncDefExpr: include body so we emit body producers before func_def.
298
- """
299
- if isinstance(node, EvalExpr):
300
- return list(node.args)
301
- if isinstance(node, TupleExpr):
302
- return list(node.args)
303
- if isinstance(node, CondExpr):
304
- # pred and actual args only; functions are serialized via attrs (nested graphs)
305
- return [node.pred, *node.args]
306
- if isinstance(node, WhileExpr):
307
- # initial state args only; functions are serialized via attrs (nested graphs)
308
- return list(node.args)
309
- if isinstance(node, ConvExpr):
310
- return list(node.vars)
311
- if isinstance(node, ShflSExpr):
312
- return [node.src_val]
313
- if isinstance(node, ShflExpr):
314
- return [node.src, node.index]
315
- if isinstance(node, AccessExpr):
316
- return [node.src]
317
- if isinstance(node, VariableExpr):
318
- return []
319
- if isinstance(node, FuncDefExpr):
320
- # ensure body producers are serialized first
321
- return [node.body]
322
- if isinstance(node, CallExpr):
323
- # include fn and args as deps so func_def appears before call
324
- return [node.fn, *node.args]
325
- return []
326
-
327
- def reset(self) -> None:
328
- """Reset writer state.
329
-
330
- Clears all internal state, allowing the writer to be reused for
331
- serializing a new expression tree.
332
- """
333
- self._counter = 0
334
- self._expr_ids.clear()
335
- self._nodes.clear()
336
-
337
- def _create_node_proto(self, expr: Expr, op_type: str) -> mpir_pb2.NodeProto:
338
- """Helper: Create a basic NodeProto with common fields set.
339
-
340
- Args:
341
- expr: The expression this node represents.
342
- op_type: The operation type for this node.
343
-
344
- Returns:
345
- A new NodeProto with basic fields set.
346
- """
347
- op = mpir_pb2.NodeProto()
348
- op.op_type = op_type
349
- op.name = self.expr_name(expr)
350
- return op
351
-
352
- def _add_output_info(self, op: mpir_pb2.NodeProto, expr: Expr) -> None:
353
- """Helper: Add output type information to a NodeProto.
354
-
355
- This method populates the output type information for a node based
356
- on the expression's mptypes.
357
-
358
- Args:
359
- op: The NodeProto to populate.
360
- expr: The expression providing the type information.
361
- """
362
- for out_info in expr.mptypes:
363
- out_proto = op.outs_info.add()
364
-
365
- if out_info.is_tensor:
366
- # Handle tensor type
367
- tensor_type = out_proto.tensor_type
368
- tensor_type.dtype = dtype_to_proto(out_info.dtype)
369
- tensor_type.shape_dims.extend(list(out_info.shape))
370
- elif out_info.is_table:
371
- # Handle table type
372
- table_type = out_proto.table_type
373
- for col_name, col_dtype in out_info.schema.columns:
374
- column = table_type.columns.add()
375
- column.name = col_name
376
- column.dtype = dtype_to_proto(col_dtype)
377
-
378
- # Set pmask (now int64, -1 for dynamic mask)
379
- if out_info.pmask is not None:
380
- out_proto.pmask = int(out_info.pmask)
381
- else:
382
- out_proto.pmask = -1 # Dynamic mask
383
-
384
- def _add_expr_inputs(self, op: mpir_pb2.NodeProto, *exprs: Expr) -> None:
385
- """Helper: Add expression inputs to NodeProto.
386
-
387
- For multi-output expressions, this adds all outputs as inputs.
388
-
389
- Args:
390
- op: The NodeProto to add inputs to.
391
- exprs: The expressions to add as inputs.
392
- """
393
- for expr in exprs:
394
- op.inputs.extend([
395
- self.value_name(expr, i) for i in range(len(expr.mptypes))
396
- ])
397
-
398
- def _add_single_expr_inputs(self, op: mpir_pb2.NodeProto, *exprs: Expr) -> None:
399
- """Helper: Add single-output expression inputs to NodeProto.
400
-
401
- For expressions, this adds only the first (primary) output as input.
402
-
403
- Args:
404
- op: The NodeProto to add inputs to.
405
- exprs: The expressions to add as inputs.
406
- """
407
- for expr in exprs:
408
- op.inputs.append(self.value_name(expr, 0))
409
-
410
- def _add_attrs(self, op: mpir_pb2.NodeProto, **attrs: Any) -> None:
411
- """Helper: Add attributes to NodeProto.
412
-
413
- Args:
414
- op: The NodeProto to add attributes to.
415
- **attrs: The attributes to add (key-value pairs).
416
- """
417
- for key, value in attrs.items():
418
- if value is not None: # Skip None values
419
- op.attrs[key].CopyFrom(attr_to_proto(value))
420
-
421
- def _finalize_node(self, op: mpir_pb2.NodeProto, expr: Expr) -> str:
422
- """Helper: Add output info, append to nodes, and return expr name.
423
-
424
- This method completes the node creation process by adding output
425
- information, appending the node to the list of nodes, and returning
426
- the expression name.
427
-
428
- Args:
429
- op: The completed NodeProto.
430
- expr: The expression the node represents.
431
-
432
- Returns:
433
- The name of the expression.
434
- """
435
- self._add_output_info(op, expr)
436
- self._nodes.append(op)
437
- return self.expr_name(expr)
438
-
439
- def dumps(self, expr: Expr) -> mpir_pb2.GraphProto:
440
- """Dump an expression to GraphProto using iterative walk traversal."""
441
- self.reset()
442
-
443
- # Walk in post-order so deps are serialized before users
444
- for node in walk(expr, get_deps=self._writer_deps, traversal="dfs_post_iter"):
445
- # Avoid double-emit if the same Expr object appears multiple times
446
- node_id = id(node)
447
- if node_id in self._expr_ids:
448
- continue
449
- # Emit node
450
- self._serialize_node(node)
451
-
452
- # Create graph metadata
453
- graph_attrs = {}
454
- if isinstance(expr, FuncDefExpr):
455
- graph_attrs["name"] = attr_to_proto(f"function_{id(expr)}")
456
- # For function definitions, the outputs should be the FuncDefExpr itself
457
- outputs = [self.value_name(expr, i) for i in range(len(expr.mptypes))]
458
- else:
459
- # For regular expressions, outputs are the expression outputs
460
- outputs = [self.value_name(expr, i) for i in range(len(expr.mptypes))]
461
-
462
- return mpir_pb2.GraphProto(
463
- version=mpir_pb2.VersionInfo(major=1, minor=0, patch=0),
464
- nodes=self._nodes,
465
- outputs=outputs,
466
- attrs=graph_attrs,
467
- )
468
-
469
- # ------------------------------- emitters --------------------------------
470
- def _serialize_node(self, expr: Expr) -> None:
471
- """Create and append a NodeProto for the given expr."""
472
- if isinstance(expr, EvalExpr):
473
- op = self._create_node_proto(expr, "eval")
474
- self._add_expr_inputs(op, *expr.args)
475
- self._add_attrs(op, pfunc=expr.pfunc, rmask=expr.rmask)
476
- self._finalize_node(op, expr)
477
- elif isinstance(expr, VariableExpr):
478
- op = self._create_node_proto(expr, "variable")
479
- mapped_name = self._var_name_mapping.get(expr.name, expr.name)
480
- self._add_attrs(op, name=mapped_name)
481
- self._finalize_node(op, expr)
482
- elif isinstance(expr, TupleExpr):
483
- op = self._create_node_proto(expr, "tuple")
484
- self._add_single_expr_inputs(op, *expr.args)
485
- self._finalize_node(op, expr)
486
- elif isinstance(expr, CondExpr):
487
- op = self._create_node_proto(expr, "cond")
488
- self._add_single_expr_inputs(op, expr.pred)
489
- self._add_expr_inputs(op, *expr.args)
490
- self._add_attrs(op, then_fn=expr.then_fn, else_fn=expr.else_fn)
491
- self._finalize_node(op, expr)
492
- elif isinstance(expr, CallExpr):
493
- op = self._create_node_proto(expr, "call")
494
- self._add_single_expr_inputs(op, expr.fn)
495
- self._add_expr_inputs(op, *expr.args)
496
- self._add_attrs(op, name=expr.name)
497
- self._finalize_node(op, expr)
498
- elif isinstance(expr, WhileExpr):
499
- op = self._create_node_proto(expr, "while")
500
- self._add_expr_inputs(op, *expr.args)
501
- self._add_attrs(op, cond_fn=expr.cond_fn, body_fn=expr.body_fn)
502
- self._finalize_node(op, expr)
503
- elif isinstance(expr, ConvExpr):
504
- op = self._create_node_proto(expr, "conv")
505
- self._add_expr_inputs(op, *expr.vars)
506
- self._finalize_node(op, expr)
507
- elif isinstance(expr, ShflSExpr):
508
- op = self._create_node_proto(expr, "shfl_s")
509
- self._add_single_expr_inputs(op, expr.src_val)
510
- self._add_attrs(op, pmask=expr.pmask, src_ranks=expr.src_ranks)
511
- self._finalize_node(op, expr)
512
- elif isinstance(expr, ShflExpr):
513
- op = self._create_node_proto(expr, "shfl")
514
- self._add_single_expr_inputs(op, expr.src, expr.index)
515
- self._finalize_node(op, expr)
516
- elif isinstance(expr, AccessExpr):
517
- op = self._create_node_proto(expr, "access")
518
- op.inputs.append(self.value_name(expr.src, expr.index))
519
- self._add_attrs(op, index=expr.index)
520
- self._finalize_node(op, expr)
521
- elif isinstance(expr, FuncDefExpr):
522
- op = self._create_node_proto(expr, "func_def")
523
- self._add_expr_inputs(op, expr.body)
524
- self._add_attrs(op, params=expr.params)
525
- self._finalize_node(op, expr)
526
- else:
527
- raise TypeError(f"Unsupported expr type for serialization: {type(expr)}")
528
-
529
-
530
- class IrReader:
531
- """Reader for deserializing GraphProto back to Expr-based expressions.
532
-
533
- This class is responsible for converting serialized GraphProto representations
534
- back into executable expression trees. It handles the deserialization of
535
- various node types and manages dependencies between nodes to ensure proper
536
- reconstruction of the expression graph.
537
- """
538
-
539
- def __init__(self) -> None:
540
- self._value_cache: dict[str, Expr] = {}
541
-
542
def loads(self, graph_proto: mpir_pb2.GraphProto) -> Expr | None:
    """Load an expression from a GraphProto.

    Nodes are materialized in dependency order. Before a node is converted,
    every input that names another node of the same graph is converted
    first. The walk uses an explicit stack, so long dependency chains do not
    hit Python's recursion limit. A dependency cycle is reported as a
    ValueError instead of recursing forever.

    Args:
        graph_proto: The protobuf graph to deserialize.

    Returns:
        The expression for the first entry of ``graph_proto.outputs``, or
        None if the graph declares no outputs.

    Raises:
        ValueError: On duplicate node names, a dependency cycle, a node whose
            conversion fails (the original error is chained), or a first
            output that does not name a processed node.
    """
    # Cache of node name -> converted expression; reset on every call.
    self._value_cache.clear()

    # Index nodes by name. Names are the edge keys, so duplicates are fatal.
    node_map: dict[str, mpir_pb2.NodeProto] = {}
    for node in graph_proto.nodes:
        if node.name in node_map:
            raise ValueError(
                f"Duplicate node name detected in graph: '{node.name}'"
            )
        node_map[node.name] = node

    processed_nodes: set[str] = set()
    # Nodes currently on the DFS stack; reaching one again means a cycle.
    in_progress: set[str] = set()

    def build_node(node_proto: mpir_pb2.NodeProto) -> None:
        """Convert one node whose in-graph dependencies are already cached."""
        try:
            expr = self._create_expr_from_proto(node_proto)
        except Exception as e:
            raise ValueError(
                f"Error processing node '{node_proto.name}' "
                f"of type '{node_proto.op_type}': {e!s}"
            ) from e
        processed_nodes.add(node_proto.name)
        self._value_cache[node_proto.name] = expr

    def process_node(root: mpir_pb2.NodeProto) -> None:
        """Convert ``root`` after converting all of its in-graph dependencies."""
        if root.name in processed_nodes:
            return
        in_progress.add(root.name)
        # Each frame: (node, iterator over the inputs not yet examined).
        stack = [(root, iter(root.inputs))]
        while stack:
            node_proto, pending_inputs = stack[-1]
            for input_name in pending_inputs:
                # Inputs look like "name:suffix" (suffix presumably an output
                # index); only the node name matters for ordering.
                dep_name = input_name.split(":")[0]
                # Names not in this graph are left for the creator to report.
                if dep_name not in node_map or dep_name in processed_nodes:
                    continue
                if dep_name in in_progress:
                    raise ValueError(
                        f"Cycle detected in graph: node '{node_proto.name}' "
                        f"depends on '{dep_name}', which is still being processed"
                    )
                dep_node = node_map[dep_name]
                in_progress.add(dep_name)
                stack.append((dep_node, iter(dep_node.inputs)))
                break
            else:
                # Every input is handled, so this node can be built now.
                stack.pop()
                in_progress.discard(node_proto.name)
                build_node(node_proto)

    for node_proto in graph_proto.nodes:
        process_node(node_proto)

    # Extract outputs - for now, just return the first output expression.
    if graph_proto.outputs:
        output_name = graph_proto.outputs[0].split(":")[0]
        if output_name in self._value_cache:
            return self._value_cache[output_name]
        raise ValueError(f"Output {output_name} not found in processed nodes")

    return None
601
-
602
def _create_expr_from_proto(self, node_proto: mpir_pb2.NodeProto) -> Expr:
    """Build the expression for one node by dispatching on its ``op_type``.

    Each supported ``op_type`` maps to a dedicated ``_create_*_expr``
    method, which receives the node unchanged.

    Args:
        node_proto: The node to convert.

    Returns:
        Whatever the matching creation method returns.

    Raises:
        ValueError: If ``op_type`` has no creation method.
    """
    builders = {
        "eval": self._create_eval_expr,
        "variable": self._create_variable_expr,
        "tuple": self._create_tuple_expr,
        "cond": self._create_cond_expr,
        "while": self._create_while_expr,
        "access": self._create_access_expr,
        "func_def": self._create_func_def_expr,
        "shfl_s": self._create_shfl_s_expr,
        "shfl": self._create_shfl_expr,
        "conv": self._create_conv_expr,
        "call": self._create_call_expr,
    }

    builder = builders.get(node_proto.op_type)
    if builder is None:
        raise ValueError(f"Unsupported node type: {node_proto.op_type}")
    return builder(node_proto)
626
-
627
def _create_eval_expr(self, node_proto: mpir_pb2.NodeProto) -> EvalExpr:
    """Create an EvalExpr from a NodeProto.

    The PFunction produced by decoding ``pfunc`` has no signature. This
    method fills it in: ``ins_info`` comes from the already-cached input
    expressions' ``mptype`` and ``outs_info`` from ``node_proto.outs_info``.

    Args:
        node_proto: An ``eval`` node. Its ``inputs`` name cached nodes;
            ``attrs`` must contain ``pfunc`` and may contain ``rmask``.

    Returns:
        An EvalExpr wrapping a PFunction with ins_info/outs_info populated.

    Raises:
        ValueError: If an input is not cached, ``pfunc`` is missing, an input
            type is neither tensor nor table, or an ``outs_info`` entry sets
            neither ``tensor_type`` nor ``table_type``.
    """
    # Resolve inputs from the per-graph cache.
    input_exprs = []
    for input_name in node_proto.inputs:
        dep_name = input_name.split(":")[0]
        if dep_name not in self._value_cache:
            raise ValueError(f"Input {input_name} not found for eval node")
        input_exprs.append(self._value_cache[dep_name])

    # Check explicitly. On a protobuf message-valued map, indexing a missing
    # key inserts a default entry rather than raising. That would only
    # surface later as a confusing decoding error.
    if "pfunc" not in node_proto.attrs:
        raise ValueError("Eval node missing 'pfunc' attribute")
    pfunc = self._proto_to_attr(node_proto.attrs["pfunc"])
    rmask = None
    if "rmask" in node_proto.attrs:
        rmask = self._proto_to_attr(node_proto.attrs["rmask"])

    # ins_info: one entry per input, taken from its single MPType.
    ins_info: list[TensorType | TableType] = []
    for input_expr in input_exprs:
        mptype = input_expr.mptype
        if mptype.is_tensor:
            ins_info.append(TensorType(mptype.dtype, mptype.shape))
        elif mptype.is_table:
            ins_info.append(mptype.schema)
        else:
            raise ValueError(f"unsupported type: {mptype}")

    # outs_info: decoded from the node's serialized output types.
    outs_info: list[TensorType | TableType] = []
    for out_proto in node_proto.outs_info:
        if out_proto.HasField("tensor_type"):
            tensor_type_proto = out_proto.tensor_type
            dtype = proto_to_dtype(tensor_type_proto.dtype)
            shape = tuple(tensor_type_proto.shape_dims)
            outs_info.append(TensorType(dtype, shape))
        elif out_proto.HasField("table_type"):
            columns = [
                (col.name, proto_to_dtype(col.dtype))
                for col in out_proto.table_type.columns
            ]
            outs_info.append(TableType.from_pairs(columns))
        else:
            # Both tensor and table outputs are handled above, so the old
            # "only supports tensor types" message was misleading.
            raise ValueError(
                "Eval node outs_info entry must specify tensor_type or table_type"
            )

    # Rebuild the PFunction with full type information, keeping its attrs.
    complete_pfunc = PFunction(
        fn_type=pfunc.fn_type,
        ins_info=ins_info,
        outs_info=outs_info,
        fn_name=pfunc.fn_name,
        fn_text=pfunc.fn_text,
        **pfunc.attrs,
    )

    return EvalExpr(complete_pfunc, input_exprs, rmask)
685
-
686
def _create_variable_expr(self, node_proto: mpir_pb2.NodeProto) -> VariableExpr:
    """Rebuild a VariableExpr from its ``name`` attr and first output type.

    Raises:
        ValueError: If the node has no ``outs_info`` entry.
    """
    var_name = self._proto_to_attr(node_proto.attrs["name"])

    declared_types = node_proto.outs_info
    if not declared_types:
        raise ValueError("Variable node missing output info")

    # VariableExpr takes a single MPType, so only the first entry is used.
    var_type = self._proto_to_mptype(declared_types[0])
    return VariableExpr(var_name, var_type)
697
-
698
def _create_tuple_expr(self, node_proto: mpir_pb2.NodeProto) -> TupleExpr:
    """Group the cached expressions named by ``node_proto.inputs`` into a TupleExpr.

    Raises:
        ValueError: If any input does not name a cached node.
    """
    cache = self._value_cache
    members = []
    for ref in node_proto.inputs:
        # Only the part before the first ':' identifies the node.
        node_name = ref.partition(":")[0]
        if node_name not in cache:
            raise ValueError(f"Input {ref} not found for tuple node")
        members.append(cache[node_name])
    return TupleExpr(members)
710
-
711
def _create_cond_expr(self, node_proto: mpir_pb2.NodeProto) -> CondExpr:
    """Create a CondExpr from a NodeProto.

    ``inputs[0]`` names the predicate and the remaining inputs name the
    arguments. The ``then_fn`` and ``else_fn`` attributes are decoded with
    ``_proto_to_attr``.

    Raises:
        ValueError: If the node has no inputs, or any input (including the
            predicate) does not name a cached node.
    """
    if not node_proto.inputs:
        raise ValueError("Cond node missing predicate input")

    # Resolve predicate and arguments the same way, so a missing predicate
    # gives the same readable ValueError as a missing argument instead of a
    # bare KeyError.
    resolved = []
    for input_name in node_proto.inputs:
        dep_name = input_name.split(":")[0]
        if dep_name not in self._value_cache:
            raise ValueError(f"Input {input_name} not found for cond node")
        resolved.append(self._value_cache[dep_name])
    pred_expr, arg_exprs = resolved[0], resolved[1:]

    then_fn = self._proto_to_attr(node_proto.attrs["then_fn"])
    else_fn = self._proto_to_attr(node_proto.attrs["else_fn"])

    return CondExpr(pred_expr, then_fn, else_fn, arg_exprs)
730
-
731
def _create_while_expr(self, node_proto: mpir_pb2.NodeProto) -> WhileExpr:
    """Rebuild a WhileExpr from its cached inputs and ``cond_fn``/``body_fn`` attrs.

    Raises:
        ValueError: If any input does not name a cached node.
    """
    cache = self._value_cache

    def resolve(ref: str):
        key = ref.partition(":")[0]
        if key not in cache:
            raise ValueError(f"Input {ref} not found for while node")
        return cache[key]

    loop_args = [resolve(ref) for ref in node_proto.inputs]

    fn_attrs = node_proto.attrs
    cond_fn = self._proto_to_attr(fn_attrs["cond_fn"])
    body_fn = self._proto_to_attr(fn_attrs["body_fn"])

    return WhileExpr(cond_fn, body_fn, loop_args)
747
-
748
def _create_access_expr(self, node_proto: mpir_pb2.NodeProto) -> AccessExpr:
    """Create an AccessExpr from a NodeProto.

    ``inputs[0]`` names the cached source expression. The ``index``
    attribute is decoded with ``_proto_to_attr`` and passed to AccessExpr.

    Raises:
        ValueError: If the node has no inputs or the source is not cached.
    """
    if not node_proto.inputs:
        raise ValueError("Access node missing source input")

    input_name = node_proto.inputs[0]
    dep_name = input_name.split(":")[0]
    # Explicit check instead of a bare KeyError, matching the other creators.
    if dep_name not in self._value_cache:
        raise ValueError(f"Input {input_name} not found for access node")
    src_expr = self._value_cache[dep_name]

    index = self._proto_to_attr(node_proto.attrs["index"])

    return AccessExpr(src_expr, index)
759
-
760
def _create_func_def_expr(self, node_proto: mpir_pb2.NodeProto) -> FuncDefExpr:
    """Create a FuncDefExpr from a NodeProto.

    ``inputs[0]`` names the cached body expression. The ``params``
    attribute is decoded with ``_proto_to_attr`` and passed to FuncDefExpr.

    Raises:
        ValueError: If the node has no inputs or the body is not cached.
    """
    input_names = node_proto.inputs
    if not input_names:
        raise ValueError("FuncDef node missing body input")

    body_input = input_names[0]
    body_name = body_input.split(":")[0]
    # Explicit check instead of a bare KeyError, matching the other creators.
    if body_name not in self._value_cache:
        raise ValueError(f"Input {body_input} not found for func_def node")
    body_expr = self._value_cache[body_name]

    params = self._proto_to_attr(node_proto.attrs["params"])

    return FuncDefExpr(params, body_expr)
774
-
775
def _create_shfl_s_expr(self, node_proto: mpir_pb2.NodeProto) -> ShflSExpr:
    """Create a ShflSExpr from a NodeProto.

    ``inputs[0]`` names the cached source value. The ``pmask`` and
    ``src_ranks`` attributes are decoded with ``_proto_to_attr`` and passed
    through unchanged.

    Raises:
        ValueError: If the node has no inputs or the source is not cached.
    """
    if not node_proto.inputs:
        raise ValueError("ShflS node missing source input")

    input_name = node_proto.inputs[0]
    dep_name = input_name.split(":")[0]
    # Explicit check instead of a bare KeyError, matching the other creators.
    if dep_name not in self._value_cache:
        raise ValueError(f"Input {input_name} not found for shfl_s node")
    src_val = self._value_cache[dep_name]

    pmask = self._proto_to_attr(node_proto.attrs["pmask"])
    src_ranks = self._proto_to_attr(node_proto.attrs["src_ranks"])

    return ShflSExpr(src_val, pmask, src_ranks)
787
-
788
def _create_shfl_expr(self, node_proto: mpir_pb2.NodeProto) -> ShflExpr:
    """Create a ShflExpr from a NodeProto.

    ``inputs[0]`` names the cached source expression and ``inputs[1]`` the
    cached index expression. Any further inputs are ignored, as before.

    Raises:
        ValueError: If fewer than two inputs are present or either is not
            cached.
    """
    if len(node_proto.inputs) < 2:
        raise ValueError(
            "Shfl node requires source and index inputs, "
            f"got {len(node_proto.inputs)}"
        )

    # Explicit checks instead of bare KeyError, matching the other creators.
    resolved = []
    for input_name in (node_proto.inputs[0], node_proto.inputs[1]):
        dep_name = input_name.split(":")[0]
        if dep_name not in self._value_cache:
            raise ValueError(f"Input {input_name} not found for shfl node")
        resolved.append(self._value_cache[dep_name])
    src_expr, index_expr = resolved

    return ShflExpr(src_expr, index_expr)
797
-
798
- def _create_conv_expr(self, node_proto: mpir_pb2.NodeProto) -> ConvExpr:
799
- """Create a ConvExpr from a NodeProto."""
800
- # Parse variable expressions
801
- var_exprs = []
802
- for input_name in node_proto.inputs:
803
- dep_name = input_name.split(":")[0]
804
- if dep_name in self._value_cache:
805
- var_exprs.append(self._value_cache[dep_name])
806
- else:
807
- raise ValueError(f"Input {input_name} not found for conv node")
808
-
809
- return ConvExpr(var_exprs)
810
-
811
def _create_call_expr(self, node_proto: mpir_pb2.NodeProto) -> CallExpr:
    """Create a CallExpr from a NodeProto.

    ``inputs[0]`` names the cached callee, which must be a FuncDefExpr. The
    remaining inputs name the cached arguments. An optional ``name``
    attribute gives the call-site name; when it is absent or empty, ``""``
    is used.

    Raises:
        ValueError: If the node has no inputs, the callee is not cached or is
            not a FuncDefExpr, or an argument is not cached.
    """
    if not node_proto.inputs:
        raise ValueError("Call node missing function input")

    fn_input = node_proto.inputs[0]
    fn_name = fn_input.split(":")[0]
    # Explicit check instead of a bare KeyError, matching the argument lookup.
    if fn_name not in self._value_cache:
        raise ValueError(f"Input {fn_input} not found for call node")
    fn_expr = self._value_cache[fn_name]

    if not isinstance(fn_expr, FuncDefExpr):
        raise ValueError(f"Call function must be FuncDefExpr, got {type(fn_expr)}")

    arg_exprs = []
    for input_name in node_proto.inputs[1:]:
        dep_name = input_name.split(":")[0]
        if dep_name not in self._value_cache:
            raise ValueError(f"Input {input_name} not found for call node")
        arg_exprs.append(self._value_cache[dep_name])

    # Optional call-site name attribute.
    call_name = None
    if "name" in node_proto.attrs:
        call_name = self._proto_to_attr(node_proto.attrs["name"])  # type: ignore[assignment]

    return CallExpr(call_name or "", fn_expr, arg_exprs)
834
-
835
def _proto_to_mptype(self, type_proto: mpir_pb2.MPTypeProto) -> MPType:
    """Decode an MPTypeProto into an MPType.

    A stored ``pmask`` of -1 marks a dynamic mask and decodes to None. Any
    other value is wrapped in ``Mask``. Attributes are decoded with
    ``_proto_to_attr``. ``tensor_type`` is checked first; if it is unset,
    ``table_type`` must be set.

    Raises:
        ValueError: If neither ``tensor_type`` nor ``table_type`` is set.
    """
    raw_pmask = type_proto.pmask
    pmask = Mask(raw_pmask) if raw_pmask != -1 else None

    attrs = {
        attr_name: self._proto_to_attr(attr_value)
        for attr_name, attr_value in type_proto.attrs.items()
    }

    if type_proto.HasField("tensor_type"):
        tensor_proto = type_proto.tensor_type
        tensor_type = TensorType(
            proto_to_dtype(tensor_proto.dtype), tuple(tensor_proto.shape_dims)
        )
        return MPType(tensor_type, pmask, attrs)

    if not type_proto.HasField("table_type"):
        raise ValueError(
            "MPTypeProto must specify either tensor_type or table_type"
        )

    column_pairs = tuple(
        (column.name, proto_to_dtype(column.dtype))
        for column in type_proto.table_type.columns
    )
    return MPType(TableType(column_pairs), pmask, attrs)
870
-
871
def _proto_to_attr(self, attr_proto: mpir_pb2.AttrProto) -> Any:
    """Decode an AttrProto into the corresponding Python value.

    Scalar kinds return the matching field, and list kinds return a new
    ``list``. FUNCTION rebuilds a PFunction with empty ``ins_info`` and
    ``outs_info``; callers such as the eval-node creator fill those in from
    context. GRAPH is decoded by a fresh IrReader.

    Raises:
        TypeError: If ``attr_proto.type`` is not a handled kind.
    """
    kinds = mpir_pb2.AttrProto
    kind = attr_proto.type

    if kind == kinds.INT:
        return attr_proto.i
    if kind == kinds.FLOAT:
        return attr_proto.f
    if kind == kinds.STRING:
        return attr_proto.s
    if kind == kinds.BYTES:
        return attr_proto.raw_bytes
    if kind == kinds.INTS:
        return list(attr_proto.ints)
    if kind == kinds.FLOATS:
        return list(attr_proto.floats)
    if kind == kinds.STRINGS:
        return list(attr_proto.strs)

    if kind == kinds.FUNCTION:
        func_proto = attr_proto.func
        # Only identity and attrs are serialized; the signature is left empty
        # because it can be recovered from the surrounding expressions.
        fn_attrs = {
            key: self._proto_to_attr(value)
            for key, value in func_proto.attrs.items()
        }
        return PFunction(
            fn_type=func_proto.type,
            ins_info=[],
            outs_info=[],
            fn_name=func_proto.name or None,
            fn_text=func_proto.body or None,
            **fn_attrs,
        )

    if kind == kinds.GRAPH:
        # Nested graphs (control-flow bodies) get their own reader and cache.
        return IrReader().loads(attr_proto.graph)

    raise TypeError(f"Unsupported attribute type: {attr_proto.type}")
911
-
912
-
913
def get_graph_statistics(graph_proto: mpir_pb2.GraphProto) -> str:
    """Get statistics about a GraphProto structure.

    Args:
        graph_proto: The protobuf GraphProto to analyze.

    Returns:
        A newline-joined report containing:
        - Graph version, plus a warning line if the major version is not 1
          ("Unknown" if the object has no ``version`` attribute)
        - Node, output, and graph-attribute counts
        - Node counts per operation type, sorted by operation type
        - One line per output name
    """
    from collections import Counter

    lines = ["GraphProto structure analysis:"]

    # Version information with compatibility check.
    try:
        version = graph_proto.version
        version_str = f"{version.major}.{version.minor}.{version.patch}"
    except AttributeError:
        lines.append("- Version: Unknown (missing version info)")
    else:
        lines.append(f"- Version: {version_str}")
        if version.major != 1:
            lines.append(f" WARNING: Expected major version 1, got {version.major}")

    # Node, output, and attribute counts.
    lines.append(f"- Number of nodes: {len(graph_proto.nodes)}")
    lines.append(f"- Number of outputs: {len(graph_proto.outputs)}")
    lines.append(f"- Graph attributes: {len(graph_proto.attrs)}")
    lines.append("")

    # Node breakdown by operation type.
    lines.append("Node breakdown by operation type:")
    op_counts = Counter(node.op_type for node in graph_proto.nodes)
    lines.extend(
        f"- {op_type}: {count} nodes" for op_type, count in sorted(op_counts.items())
    )
    lines.append("")

    # Output variables.
    lines.append("Output variables:")
    lines.extend(
        f"- Output {i}: {output}" for i, output in enumerate(graph_proto.outputs)
    )

    return "\n".join(lines)