mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188):
  1. mplang/__init__.py +21 -130
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +4 -4
  7. mplang/{core → v1/core}/__init__.py +20 -14
  8. mplang/{core → v1/core}/cluster.py +6 -1
  9. mplang/{core → v1/core}/comm.py +1 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core → v1/core}/dtypes.py +38 -0
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +11 -13
  14. mplang/{core → v1/core}/expr/evaluator.py +8 -8
  15. mplang/{core → v1/core}/expr/printer.py +6 -6
  16. mplang/{core → v1/core}/expr/transformer.py +2 -2
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +13 -11
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +2 -2
  25. mplang/{core → v1/core}/primitive.py +12 -12
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{host.py → v1/host.py} +5 -5
  30. mplang/{kernels → v1/kernels}/__init__.py +1 -1
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/{kernels → v1/kernels}/basic.py +15 -15
  33. mplang/{kernels → v1/kernels}/context.py +19 -16
  34. mplang/{kernels → v1/kernels}/crypto.py +8 -10
  35. mplang/{kernels → v1/kernels}/fhe.py +9 -7
  36. mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
  37. mplang/{kernels → v1/kernels}/phe.py +26 -18
  38. mplang/{kernels → v1/kernels}/spu.py +5 -5
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
  40. mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
  41. mplang/{kernels → v1/kernels}/value.py +2 -2
  42. mplang/{ops → v1/ops}/__init__.py +3 -3
  43. mplang/{ops → v1/ops}/base.py +1 -1
  44. mplang/{ops → v1/ops}/basic.py +6 -5
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/{ops → v1/ops}/fhe.py +2 -2
  47. mplang/{ops → v1/ops}/jax_cc.py +26 -59
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -3
  50. mplang/{ops → v1/ops}/spu.py +3 -3
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +2 -2
  53. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  54. mplang/v1/runtime/channel.py +230 -0
  55. mplang/{runtime → v1/runtime}/cli.py +3 -3
  56. mplang/{runtime → v1/runtime}/client.py +1 -1
  57. mplang/{runtime → v1/runtime}/communicator.py +39 -15
  58. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  59. mplang/{runtime → v1/runtime}/driver.py +4 -4
  60. mplang/v1/runtime/link_comm.py +196 -0
  61. mplang/{runtime → v1/runtime}/server.py +22 -9
  62. mplang/{runtime → v1/runtime}/session.py +24 -51
  63. mplang/{runtime → v1/runtime}/simulation.py +36 -14
  64. mplang/{simp → v1/simp}/api.py +72 -14
  65. mplang/{simp → v1/simp}/mpi.py +1 -1
  66. mplang/{simp → v1/simp}/party.py +5 -5
  67. mplang/{simp → v1/simp}/random.py +2 -2
  68. mplang/v1/simp/smpc.py +238 -0
  69. mplang/v1/utils/table_utils.py +185 -0
  70. mplang/v2/__init__.py +424 -0
  71. mplang/v2/backends/__init__.py +57 -0
  72. mplang/v2/backends/bfv_impl.py +705 -0
  73. mplang/v2/backends/channel.py +217 -0
  74. mplang/v2/backends/crypto_impl.py +723 -0
  75. mplang/v2/backends/field_impl.py +454 -0
  76. mplang/v2/backends/func_impl.py +107 -0
  77. mplang/v2/backends/phe_impl.py +148 -0
  78. mplang/v2/backends/simp_design.md +136 -0
  79. mplang/v2/backends/simp_driver/__init__.py +41 -0
  80. mplang/v2/backends/simp_driver/http.py +168 -0
  81. mplang/v2/backends/simp_driver/mem.py +280 -0
  82. mplang/v2/backends/simp_driver/ops.py +135 -0
  83. mplang/v2/backends/simp_driver/state.py +60 -0
  84. mplang/v2/backends/simp_driver/values.py +52 -0
  85. mplang/v2/backends/simp_worker/__init__.py +29 -0
  86. mplang/v2/backends/simp_worker/http.py +354 -0
  87. mplang/v2/backends/simp_worker/mem.py +102 -0
  88. mplang/v2/backends/simp_worker/ops.py +167 -0
  89. mplang/v2/backends/simp_worker/state.py +49 -0
  90. mplang/v2/backends/spu_impl.py +275 -0
  91. mplang/v2/backends/spu_state.py +187 -0
  92. mplang/v2/backends/store_impl.py +62 -0
  93. mplang/v2/backends/table_impl.py +838 -0
  94. mplang/v2/backends/tee_impl.py +215 -0
  95. mplang/v2/backends/tensor_impl.py +519 -0
  96. mplang/v2/cli.py +603 -0
  97. mplang/v2/cli_guide.md +122 -0
  98. mplang/v2/dialects/__init__.py +36 -0
  99. mplang/v2/dialects/bfv.py +665 -0
  100. mplang/v2/dialects/crypto.py +689 -0
  101. mplang/v2/dialects/dtypes.py +378 -0
  102. mplang/v2/dialects/field.py +210 -0
  103. mplang/v2/dialects/func.py +135 -0
  104. mplang/v2/dialects/phe.py +723 -0
  105. mplang/v2/dialects/simp.py +944 -0
  106. mplang/v2/dialects/spu.py +349 -0
  107. mplang/v2/dialects/store.py +63 -0
  108. mplang/v2/dialects/table.py +407 -0
  109. mplang/v2/dialects/tee.py +346 -0
  110. mplang/v2/dialects/tensor.py +1175 -0
  111. mplang/v2/edsl/README.md +279 -0
  112. mplang/v2/edsl/__init__.py +99 -0
  113. mplang/v2/edsl/context.py +311 -0
  114. mplang/v2/edsl/graph.py +463 -0
  115. mplang/v2/edsl/jit.py +62 -0
  116. mplang/v2/edsl/object.py +53 -0
  117. mplang/v2/edsl/primitive.py +284 -0
  118. mplang/v2/edsl/printer.py +119 -0
  119. mplang/v2/edsl/registry.py +207 -0
  120. mplang/v2/edsl/serde.py +375 -0
  121. mplang/v2/edsl/tracer.py +614 -0
  122. mplang/v2/edsl/typing.py +816 -0
  123. mplang/v2/kernels/Makefile +30 -0
  124. mplang/v2/kernels/__init__.py +23 -0
  125. mplang/v2/kernels/gf128.cpp +148 -0
  126. mplang/v2/kernels/ldpc.cpp +82 -0
  127. mplang/v2/kernels/okvs.cpp +283 -0
  128. mplang/v2/kernels/okvs_opt.cpp +291 -0
  129. mplang/v2/kernels/py_kernels.py +398 -0
  130. mplang/v2/libs/collective.py +330 -0
  131. mplang/v2/libs/device/__init__.py +51 -0
  132. mplang/v2/libs/device/api.py +813 -0
  133. mplang/v2/libs/device/cluster.py +352 -0
  134. mplang/v2/libs/ml/__init__.py +23 -0
  135. mplang/v2/libs/ml/sgb.py +1861 -0
  136. mplang/v2/libs/mpc/__init__.py +41 -0
  137. mplang/v2/libs/mpc/_utils.py +99 -0
  138. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  139. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  140. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  141. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  142. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  143. mplang/v2/libs/mpc/common/constants.py +39 -0
  144. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  145. mplang/v2/libs/mpc/ot/base.py +222 -0
  146. mplang/v2/libs/mpc/ot/extension.py +477 -0
  147. mplang/v2/libs/mpc/ot/silent.py +217 -0
  148. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  149. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  150. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  151. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  152. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  153. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  154. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  155. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  156. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  157. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  158. mplang/v2/libs/mpc/vole/silver.py +336 -0
  159. mplang/v2/runtime/__init__.py +15 -0
  160. mplang/v2/runtime/dialect_state.py +41 -0
  161. mplang/v2/runtime/interpreter.py +871 -0
  162. mplang/v2/runtime/object_store.py +194 -0
  163. mplang/v2/runtime/value.py +141 -0
  164. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
  165. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  166. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  167. mplang/device.py +0 -327
  168. mplang/ops/crypto.py +0 -108
  169. mplang/ops/ibis_cc.py +0 -136
  170. mplang/ops/sql_cc.py +0 -62
  171. mplang/runtime/link_comm.py +0 -78
  172. mplang/simp/smpc.py +0 -201
  173. mplang/utils/table_utils.py +0 -85
  174. mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
  175. /mplang/{core → v1/core}/mask.py +0 -0
  176. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  177. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
  178. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
  179. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
  180. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  181. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  182. /mplang/{simp → v1/simp}/__init__.py +0 -0
  183. /mplang/{utils → v1/utils}/__init__.py +0 -0
  184. /mplang/{utils → v1/utils}/crypto.py +0 -0
  185. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  186. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  187. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  188. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
mplang/v2/edsl/graph.py ADDED
@@ -0,0 +1,463 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Graph IR: Operation List + SSA Values.
17
+
18
+ This module implements a modern, flat IR representation inspired by torch.fx
19
+ and JAX jaxpr, replacing the tree-based Expr system.
20
+
21
+ Key Design Principles:
22
+ ----------------------
23
+ 1. **Flat Structure**: Operations in a list, not a tree
24
+ 2. **SSA Form**: Each value defined once, use-def chains explicit
25
+ 3. **Easy Traversal**: No visitor pattern needed
26
+ 4. **Optimization-Friendly**: Dead code elimination, fusion, etc.
27
+
28
+ Example:
29
+ --------
30
+ from mplang.v2.edsl.graph import Graph, Operation, Value
31
+ from mplang.v2.edsl.typing import Tensor, f32
32
+
33
+ graph = Graph()
34
+
35
+ # Create values
36
+ x = graph.add_input("x", Tensor[f32, (10,)])
37
+ y = graph.add_input("y", Tensor[f32, (10,)])
38
+
39
+ # Add operations
40
+ z, = graph.add_op("add", [x, y])
41
+ scale, = graph.add_op("tensor.constant", [], output_types=[f32], attrs={"data": 2.0})
42
+ result, = graph.add_op("mul", [z, scale])
43
+
44
+ # Mark outputs
45
+ graph.add_output(result)
46
+
47
+ # Print IR
48
+ print(graph.to_string())
49
+ # Output (verbose=False):
50
+ # x = input
51
+ # y = input
52
+ # %0 = add(x, y)
53
+ # %1 = tensor.constant() {data=2.0}
54
+ # %2 = mul(%0, %1)
55
+ # return %2
56
+ """
57
+
58
+ from __future__ import annotations
59
+
60
+ from collections.abc import Sequence
61
+ from dataclasses import dataclass, field
62
+ from typing import Any, ClassVar
63
+
64
+ from mplang.v2.edsl import serde
65
+ from mplang.v2.edsl.typing import BaseType
66
+
67
+
68
@dataclass
class Value:
    """An SSA value in the graph IR.

    A value is produced exactly once: either by an ``Operation`` (which sets
    ``defining_op`` when it is constructed) or as a free graph input, in which
    case ``defining_op`` stays ``None``.

    Attributes:
        name: SSA name, unique within its graph (e.g. ``"%0"`` or an input name).
        type: IR type of the value (a ``mplang.v2.edsl.typing.BaseType``).
        defining_op: Operation that produces this value, or ``None`` if free.
        uses: Operations consuming this value. Stored as a dict whose values
            are all ``None`` so it acts as an insertion-ordered set.

    Equality and hashing are by object identity: two distinct values with the
    same name and type are never equal.
    """

    name: str
    type: BaseType
    defining_op: Operation | None = None
    uses: dict[Operation, None] = field(default_factory=dict)

    def __repr__(self) -> str:
        return f"Value({self.name}: {self.type})"

    def __str__(self) -> str:
        return self.name

    def __eq__(self, other: object) -> bool:
        # SSA values are unique objects; compare by identity only.
        return other is self

    def __hash__(self) -> int:
        return id(self)

    def add_use(self, op: Operation) -> None:
        """Record ``op`` as a consumer of this value (idempotent)."""
        self.uses[op] = None

    def remove_use(self, op: Operation) -> None:
        """Forget ``op`` as a consumer; does nothing if it was never recorded."""
        self.uses.pop(op, None)

    @property
    def num_uses(self) -> int:
        """Number of distinct operations consuming this value."""
        return len(self.uses)

    @property
    def is_bound(self) -> bool:
        """Whether an operation defines this value."""
        return self.defining_op is not None

    @property
    def is_free(self) -> bool:
        """Whether this value has no defining operation (e.g. a graph input)."""
        return not self.is_bound

    @property
    def is_dead(self) -> bool:
        """Whether this op-defined value has no consuming operations.

        Free values are never reported as dead. Graph outputs are not recorded
        in ``uses``, so an output value with no consuming op also reports True.
        """
        return self.is_bound and not self.uses
127
+
128
+
129
@dataclass
class Operation:
    """A single computation node in the graph IR.

    An operation consumes ``inputs`` and produces ``outputs``. Constructing one
    wires up the def-use chains: every output's ``defining_op`` is set to this
    operation, and this operation is registered in every input's ``uses``.

    Attributes:
        opcode: Operation kind (e.g. ``"add"``, ``"mul"``, ``"cond"``).
        inputs: Values consumed by the operation (the list is stored as given).
        outputs: Values produced by the operation.
        attrs: Extra static attributes (shape, dtype, backend-specific data).
        regions: Nested graphs, e.g. bodies of control-flow ops (cond, while).
        name: Optional label for the operation (empty by default).

    Equality and hashing are by object identity.
    """

    opcode: str
    inputs: list[Value]
    outputs: list[Value]
    attrs: dict[str, Any] = field(default_factory=dict)
    regions: list[Graph] = field(default_factory=list)
    name: str = ""

    def __post_init__(self) -> None:
        """Hook this operation into the def-use chains of its values."""
        for produced in self.outputs:
            produced.defining_op = self
        for consumed in self.inputs:
            consumed.add_use(self)

    def __eq__(self, other: object) -> bool:
        return other is self

    def __hash__(self) -> int:
        return id(self)

    def __repr__(self) -> str:
        inputs_str = ", ".join(map(str, self.inputs))
        outputs_str = ", ".join(map(str, self.outputs))
        return f"Operation({self.opcode}: {inputs_str} -> {outputs_str})"

    def replace_input(self, old: Value, new: Value) -> None:
        """Swap every occurrence of ``old`` in ``inputs`` for ``new``.

        The def-use chains follow along: this op is dropped from ``old``'s uses
        and recorded in ``new``'s. Nothing happens if ``old`` is not an input.
        """
        for idx, current in enumerate(self.inputs):
            if current is not old:
                continue
            self.inputs[idx] = new
            old.remove_use(self)
            new.add_use(self)

    def erase(self) -> None:
        """Detach this operation from the def-use chains.

        Unregisters it from every input's uses and clears ``defining_op`` on
        its outputs. The operation is NOT removed from any graph's operation
        list; callers are responsible for that.
        """
        for consumed in self.inputs:
            consumed.remove_use(self)
        for produced in self.outputs:
            produced.defining_op = None
186
+
187
+
188
@serde.register_class
class Graph:
    """Computation graph stored as a flat list of operations in SSA form.

    A graph holds:
        - inputs: free values supplied by the caller, in declaration order;
        - operations: the computations, in insertion order;
        - outputs: values returned from the graph;
        - values: every SSA value of the graph, keyed by its name.

    Example:
        graph = Graph()
        x = graph.add_input("x", Tensor[f32, (10,)])
        y, = graph.add_op("tensor.constant", [], output_types=[f32], attrs={"data": 1.0})
        z, = graph.add_op("add", [x, y])
        graph.add_output(z)
    """

    _serde_kind: ClassVar[str] = "mplang.Graph"

    def __init__(self) -> None:
        self.operations: list[Operation] = []
        self.values: dict[str, Value] = {}
        self.inputs: list[Value] = []
        self.outputs: list[Value] = []
        # Next candidate suffix for auto-generated "%N" value names.
        self._value_counter = 0
        # Next suffix for auto-generated "opN" operation names.
        self._op_counter = 0

    def _gen_value_name(self) -> str:
        """Generate a fresh SSA value name of the form ``"%N"``.

        Names already used in this graph (e.g. an input explicitly named
        ``"%0"``, or names restored by ``from_json``) are skipped, so the
        returned name never collides with an existing value.
        """
        while True:
            name = f"%{self._value_counter}"
            self._value_counter += 1
            if name not in self.values:
                return name

    def add_value(self, type: BaseType, name: str | None = None) -> Value:
        """Create a new SSA value registered in this graph.

        Args:
            type: Type of the value.
            name: Optional explicit name; a fresh ``"%N"`` name is generated
                when ``None``.

        Returns:
            The new ``Value`` (not yet defined by any operation).

        Raises:
            ValueError: if ``name`` is already used by a value of this graph.
        """
        if name is None:
            name = self._gen_value_name()

        if name in self.values:
            raise ValueError(f"Value {name} already exists")

        value = Value(name, type)
        self.values[name] = value
        return value

    def add_input(self, name: str, type: BaseType) -> Value:
        """Add a graph input.

        Args:
            name: Input parameter name (also used as the SSA value name).
            type: Type of the input.

        Returns:
            The input value.

        Raises:
            ValueError: if a value with this name already exists.
        """
        value = self.add_value(type, name=name)
        self.inputs.append(value)
        return value

    def _append_op(
        self,
        opcode: str,
        inputs: Sequence[Value],
        outputs: list[Value],
        attrs: dict[str, Any],
        regions: list[Graph],
        name: str | None = None,
    ) -> Operation:
        """Wrap existing values in a new Operation and append it to the graph.

        Every call consumes one ``"opN"`` counter slot; a truthy ``name``
        overrides the generated label.
        """
        default_name = f"op{self._op_counter}"
        self._op_counter += 1
        op = Operation(
            opcode=opcode,
            # Copy so later in-place edits (e.g. Operation.replace_input) never
            # mutate a list owned by the caller.
            inputs=list(inputs),
            outputs=outputs,
            attrs=attrs,
            regions=regions,
            name=name or default_name,
        )
        self.operations.append(op)
        return op

    def add_op(
        self,
        opcode: str,
        inputs: list[Value],
        output_types: Sequence[BaseType] | None = None,
        attrs: dict[str, Any] | None = None,
        regions: list[Graph] | None = None,
    ) -> list[Value]:
        """Add an operation to the graph.

        Args:
            opcode: Operation name.
            inputs: Input values (copied; the caller's list is not retained).
            output_types: Types of outputs. When ``None``, a single output
                with the type of the first input is assumed.
            attrs: Additional attributes.
            regions: Nested graphs (for control flow).

        Returns:
            List of output values (one entry per output type).

        Raises:
            ValueError: if ``output_types`` is ``None`` and there are no inputs
                to infer the type from.
        """
        # Type inference placeholder (should eventually be backend-specific):
        # inherit the type of the first input.
        if output_types is None:
            if not inputs:
                raise ValueError(f"Cannot infer type for {opcode} with no inputs")
            output_types = [inputs[0].type]

        outputs = [self.add_value(t) for t in output_types]
        self._append_op(opcode, inputs, outputs, attrs or {}, regions or [])
        return outputs

    def add_output(self, value: Value) -> None:
        """Mark a value as a graph output.

        Args:
            value: Value to be returned from the graph.

        Raises:
            ValueError: if ``value`` does not belong to this graph.
        """
        if value not in self.values.values():
            raise ValueError(f"Value {value} not in graph")
        self.outputs.append(value)

    def to_string(self, verbose: bool = False) -> str:
        """Generate a human-readable IR listing.

        Args:
            verbose: Append ``: <type>`` annotations. For multi-output
                operations only the first output's type is shown.

        Returns:
            String representation of the graph (nested regions not shown).
        """
        lines = []

        for inp in self.inputs:
            type_str = f" : {inp.type}" if verbose else ""
            lines.append(f"{inp.name} = input{type_str}")

        for op in self.operations:
            # Guard against output-less "constant" ops, which would otherwise
            # raise IndexError; they fall through to the generic format.
            if op.opcode == "constant" and op.outputs:
                value_str = op.attrs.get("value", "?")
                type_str = f" : {op.outputs[0].type}" if verbose else ""
                lines.append(f"{op.outputs[0].name} = constant {value_str}{type_str}")
                continue

            inputs_str = ", ".join(str(v) for v in op.inputs)
            outputs_str = ", ".join(str(v) for v in op.outputs)
            lhs = str(op.outputs[0]) if len(op.outputs) == 1 else f"[{outputs_str}]"
            type_str = f" : {op.outputs[0].type}" if verbose and op.outputs else ""

            if op.attrs:
                attrs_str = ", ".join(f"{k}={v}" for k, v in op.attrs.items())
                lines.append(
                    f"{lhs} = {op.opcode}({inputs_str}) {{{attrs_str}}}{type_str}"
                )
            else:
                lines.append(f"{lhs} = {op.opcode}({inputs_str}){type_str}")

        if self.outputs:
            outputs_str = ", ".join(str(v) for v in self.outputs)
            lines.append(f"return {outputs_str}")

        return "\n".join(lines)

    def __repr__(self) -> str:
        return f"Graph({len(self.operations)} ops, {len(self.values)} values)"

    def __str__(self) -> str:
        return self.to_string()

    # =========================================================================
    # Serialization
    # =========================================================================

    def to_json(self) -> dict:
        """Serialize the graph to a JSON-compatible dict.

        Types, attribute values and nested regions are encoded via ``serde``;
        values are referenced by their SSA names.
        """

        def _type_to_json(t: BaseType) -> Any:
            return serde.to_json(t)

        def _attrs_to_json(attrs: dict[str, Any]) -> dict[str, Any]:
            return {k: serde.to_json(v) for k, v in attrs.items()}

        return {
            "inputs": [
                {"name": v.name, "type": _type_to_json(v.type)} for v in self.inputs
            ],
            "operations": [
                {
                    "opcode": op.opcode,
                    "inputs": [v.name for v in op.inputs],
                    "outputs": [
                        {"name": v.name, "type": _type_to_json(v.type)}
                        for v in op.outputs
                    ],
                    "attrs": _attrs_to_json(op.attrs),
                    "regions": [serde.to_json(r) for r in op.regions],
                    "name": op.name,
                }
                for op in self.operations
            ],
            "outputs": [v.name for v in self.outputs],
        }

    @classmethod
    def from_json(cls, data: dict) -> Graph:
        """Deserialize a graph from the dict produced by :meth:`to_json`.

        Output values are recreated directly under their serialized names, so
        graphs whose SSA names are not a dense ``%0, %1, ...`` sequence
        round-trip correctly.

        Raises:
            KeyError: if an operation or output refers to an unknown value.
            ValueError: if the data defines the same value name twice.
            TypeError: if a serialized type does not decode to a ``BaseType``.
        """

        def _type_from_json(d: dict) -> BaseType:
            result = serde.from_json(d)
            if not isinstance(result, BaseType):
                raise TypeError(f"Expected BaseType, got {type(result)}")
            return result

        graph = cls()

        for inp_data in data["inputs"]:
            graph.add_input(inp_data["name"], _type_from_json(inp_data["type"]))

        for op_data in data["operations"]:
            inputs = [graph.values[name] for name in op_data["inputs"]]
            # Create outputs under their original names. Auto-naming followed
            # by renaming could clash with names already present and either
            # fail or silently overwrite entries in ``graph.values``.
            outputs = [
                graph.add_value(_type_from_json(out["type"]), name=out["name"])
                for out in op_data["outputs"]
            ]
            regions = [serde.from_json(r) for r in op_data.get("regions", [])]
            attrs = {
                k: serde.from_json(v) for k, v in op_data.get("attrs", {}).items()
            }
            graph._append_op(
                op_data["opcode"],
                inputs,
                outputs,
                attrs,
                regions,
                name=op_data.get("name"),
            )

        for name in data["outputs"]:
            graph.add_output(graph.values[name])

        return graph
mplang/v2/edsl/jit.py ADDED
@@ -0,0 +1,62 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """JIT Decorator: trace functions to Graph IR and lift the results into the active interpreter."""
16
+
17
+ from collections.abc import Callable
18
+ from typing import Any
19
+
20
+ from jax.tree_util import tree_map
21
+
22
+ from mplang.v2.edsl.context import (
23
+ AbstractInterpreter,
24
+ get_current_context,
25
+ get_default_context,
26
+ )
27
+ from mplang.v2.edsl.tracer import Tracer
28
+
29
+
30
def jit(fn: Callable) -> Callable:
    """JIT decorator: trace ``fn`` into Graph IR and lift the traced result.

    On each call, the returned wrapper behaves as follows:

    * If the active context is already a ``Tracer`` (e.g. inside
      ``pcall_static``), ``fn`` is called directly, so its operations are
      inlined into the graph currently being traced.
    * Otherwise ``fn`` is traced under a fresh ``Tracer``. Every leaf of the
      result (walked with ``jax.tree_util.tree_map``) is then passed to
      ``lift`` of the context that was active at call time, or of the default
      context when none was active.

    Note: the traced graph is not cached. ``fn`` is re-traced on every call
    made outside a tracer.

    Example:
        >>> @jit
        ... def compute(x, y):
        ...     return x + y
        >>> result = compute(x_interp, y_interp)

    Raises:
        TypeError: if the context used to lift the result is not an
            ``AbstractInterpreter``.
    """
    import functools

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        cur_ctx = get_current_context()
        if isinstance(cur_ctx, Tracer):
            # Already tracing: inline into the enclosing graph.
            return fn(*args, **kwargs)

        with Tracer():
            result = fn(*args, **kwargs)

        # Prefer the caller's context (e.g. SimpSimulator), else the default.
        ctx = cur_ctx if cur_ctx is not None else get_default_context()
        # Explicit check instead of `assert`, which is stripped under -O.
        if not isinstance(ctx, AbstractInterpreter):
            raise TypeError(
                "JIT execution requires Interpreter context, got "
                f"{type(ctx).__name__}"
            )
        return tree_map(ctx.lift, result)

    return wrapper
mplang/v2/edsl/object.py ADDED
@@ -0,0 +1,53 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Object: Base class for runtime objects.
16
+
17
+ Base abstraction for distinguishing trace-time and interp-time execution.
18
+
19
+ - TraceObject: Defined in mplang.v2.edsl.tracer
20
+ - InterpObject: Defined in mplang.v2.runtime.interpreter
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ from abc import ABC, abstractmethod
26
+ from typing import Generic, TypeVar
27
+
28
+ from mplang.v2.edsl.typing import BaseType
29
+
30
+ T = TypeVar("T", bound=BaseType)
31
+
32
+
33
class Object(ABC, Generic[T]):
    """Base class for MPLang runtime objects.

    This is a driver-side abstraction used to:
    1. distinguish trace-time objects from interp-time objects;
    2. provide uniform operation interfaces (arithmetic, attribute access, ...);
    3. let the Tracer handle objects polymorphically.

    Subclasses:
        - TraceObject: trace-time object holding a ``Value`` of the Graph IR
          (``mplang.v2.edsl.tracer``).
        - InterpObject: interp-time object holding backend-specific runtime
          data (``mplang.v2.runtime.interpreter``).
    """

    @property
    @abstractmethod
    def type(self) -> T:
        """Type of the object (available in both trace and interp modes)."""

    # Note: Arithmetic operators (__add__, __mul__, etc.) are NOT defined here.
    # They should be provided by dialect-specific dispatch mechanisms since
    # different types (Tensor, Vector, SS) require different implementations.