mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -45
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +5 -7
- mplang/v1/core/__init__.py +157 -0
- mplang/{core → v1/core}/cluster.py +30 -14
- mplang/{core → v1/core}/comm.py +5 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +13 -14
- mplang/{core → v1/core}/expr/evaluator.py +65 -24
- mplang/{core → v1/core}/expr/printer.py +24 -18
- mplang/{core → v1/core}/expr/transformer.py +3 -3
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +23 -16
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +4 -4
- mplang/{core → v1/core}/primitive.py +106 -201
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{api.py → v1/host.py} +38 -6
- mplang/v1/kernels/__init__.py +41 -0
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/v1/kernels/basic.py +240 -0
- mplang/{kernels → v1/kernels}/context.py +42 -27
- mplang/{kernels → v1/kernels}/crypto.py +44 -37
- mplang/v1/kernels/fhe.py +858 -0
- mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
- mplang/{kernels → v1/kernels}/phe.py +263 -57
- mplang/{kernels → v1/kernels}/spu.py +137 -48
- mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
- mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
- mplang/v1/kernels/value.py +626 -0
- mplang/{ops → v1/ops}/__init__.py +5 -16
- mplang/{ops → v1/ops}/base.py +2 -5
- mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
- mplang/v1/ops/crypto.py +262 -0
- mplang/v1/ops/fhe.py +272 -0
- mplang/{ops → v1/ops}/jax_cc.py +33 -68
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -4
- mplang/{ops → v1/ops}/spu.py +3 -5
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +9 -24
- mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
- mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
- mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +35 -20
- mplang/{runtime → v1/runtime}/client.py +19 -8
- mplang/{runtime → v1/runtime}/communicator.py +59 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +30 -12
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +58 -42
- mplang/{runtime → v1/runtime}/session.py +57 -71
- mplang/{runtime → v1/runtime}/simulation.py +55 -28
- mplang/v1/simp/api.py +353 -0
- mplang/{simp → v1/simp}/mpi.py +8 -9
- mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
- mplang/{simp → v1/simp}/random.py +21 -22
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/core/__init__.py +0 -92
- mplang/device.py +0 -340
- mplang/kernels/builtin.py +0 -207
- mplang/ops/crypto.py +0 -109
- mplang/ops/ibis_cc.py +0 -139
- mplang/ops/sql.py +0 -61
- mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
- mplang/runtime/link_comm.py +0 -131
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -73
- mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{kernels → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
mplang/v2/edsl/graph.py
ADDED
|
@@ -0,0 +1,463 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""
|
|
16
|
+
Graph IR: Operation List + SSA Values.
|
|
17
|
+
|
|
18
|
+
This module implements a modern, flat IR representation inspired by torch.fx
|
|
19
|
+
and JAX jaxpr, replacing the tree-based Expr system.
|
|
20
|
+
|
|
21
|
+
Key Design Principles:
|
|
22
|
+
----------------------
|
|
23
|
+
1. **Flat Structure**: Operations in a list, not a tree
|
|
24
|
+
2. **SSA Form**: Each value defined once, use-def chains explicit
|
|
25
|
+
3. **Easy Traversal**: No visitor pattern needed
|
|
26
|
+
4. **Optimization-Friendly**: Dead code elimination, fusion, etc.
|
|
27
|
+
|
|
28
|
+
Example:
|
|
29
|
+
--------
|
|
30
|
+
from mplang.v2.edsl.graph import Graph, Operation, Value
|
|
31
|
+
from mplang.v2.edsl.typing import Tensor, f32
|
|
32
|
+
|
|
33
|
+
graph = Graph()
|
|
34
|
+
|
|
35
|
+
# Create values
|
|
36
|
+
x = graph.add_input("x", Tensor[f32, (10,)])
|
|
37
|
+
y = graph.add_input("y", Tensor[f32, (10,)])
|
|
38
|
+
|
|
39
|
+
# Add operations
|
|
40
|
+
z, = graph.add_op("add", [x, y])
|
|
41
|
+
scale, = graph.add_op("tensor.constant", [], output_types=[f32], attrs={"data": 2.0})
|
|
42
|
+
result, = graph.add_op("mul", [z, scale])
|
|
43
|
+
|
|
44
|
+
# Mark outputs
|
|
45
|
+
graph.add_output(result)
|
|
46
|
+
|
|
47
|
+
# Print IR
|
|
48
|
+
print(graph.to_string())
|
|
49
|
+
# Output:
|
|
50
|
+
# x = input
|
|
51
|
+
# y = input
|
|
52
|
+
# %0 = add(x, y)
|
|
53
|
+
# %1 = tensor.constant() {data=2.0}
|
|
54
|
+
# %2 = mul(%0, %1)
|
|
55
|
+
# return %2
|
|
56
|
+
"""
|
|
57
|
+
|
|
58
|
+
from __future__ import annotations
|
|
59
|
+
|
|
60
|
+
from collections.abc import Sequence
|
|
61
|
+
from dataclasses import dataclass, field
|
|
62
|
+
from typing import Any, ClassVar
|
|
63
|
+
|
|
64
|
+
from mplang.v2.edsl import serde
|
|
65
|
+
from mplang.v2.edsl.typing import BaseType
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@dataclass
class Value:
    """A single SSA value of the graph IR.

    A value is produced by at most one operation (graph inputs have no
    producer) and remembers every operation that consumes it, so def-use
    chains can be walked in both directions.

    Equality and hashing are identity-based: two distinct ``Value`` objects
    never compare equal, even if they carry the same name and type.

    Attributes:
        name: SSA name of the value (e.g., "%0", "%1", ...)
        type: Type of the value (from mplang.v2.edsl.typing)
        defining_op: Operation producing this value (None for inputs)
        uses: Consumers of this value, stored as the keys of an
            insertion-ordered dict (the mapped values are always None)
    """

    name: str
    type: BaseType
    defining_op: Operation | None = None
    uses: dict[Operation, None] = field(default_factory=dict)

    def __repr__(self) -> str:
        return f"Value({self.name}: {self.type})"

    def __str__(self) -> str:
        return self.name

    def __hash__(self) -> int:
        # Identity hash, consistent with the identity-based __eq__ below.
        return id(self)

    def __eq__(self, other: object) -> bool:
        return other is self

    def add_use(self, op: Operation) -> None:
        """Record ``op`` as a consumer of this value (idempotent)."""
        # Every stored mapping value is None, so setdefault both inserts a
        # new consumer and leaves an already-registered one untouched.
        self.uses.setdefault(op)

    def remove_use(self, op: Operation) -> None:
        """Forget ``op`` as a consumer; a no-op if it was never registered."""
        self.uses.pop(op, None)

    @property
    def num_uses(self) -> int:
        """How many distinct operations consume this value."""
        return len(self.uses)

    @property
    def is_dead(self) -> bool:
        """True if an operation defines this value but nothing consumes it."""
        return self.is_bound and not self.uses

    @property
    def is_bound(self) -> bool:
        """True if some operation defines this value."""
        return self.defining_op is not None

    @property
    def is_free(self) -> bool:
        """True if no operation defines this value (e.g. a graph input)."""
        return not self.is_bound
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
@dataclass
class Operation:
    """One computation step of the graph IR.

    An operation reads a list of input values and defines a list of output
    values. Constructing an operation wires it into the def-use chains of
    those values (see ``__post_init__``).

    Equality and hashing are identity-based, which lets operations be used
    as dictionary keys (e.g. in ``Value.uses``).

    Attributes:
        opcode: Operation name (e.g., "add", "mul", "cond")
        inputs: Values consumed by this operation
        outputs: Values produced by this operation
        attrs: Additional attributes (e.g., shape, dtype, backend-specific)
        regions: Nested graphs (for control flow: cond, while)
        name: Label of the operation (empty string by default)
    """

    opcode: str
    inputs: list[Value]
    outputs: list[Value]
    attrs: dict[str, Any] = field(default_factory=dict)
    regions: list[Graph] = field(default_factory=list)
    name: str = ""

    def __eq__(self, other: object) -> bool:
        return other is self

    def __hash__(self) -> int:
        # Identity hash, consistent with the identity-based __eq__ above.
        return id(self)

    def __post_init__(self) -> None:
        """Hook this operation into the def-use chains of its values."""
        # Every output now has this operation as its producer ...
        for produced in self.outputs:
            produced.defining_op = self

        # ... and every input gains this operation as a consumer.
        for consumed in self.inputs:
            consumed.add_use(self)

    def __repr__(self) -> str:
        lhs = ", ".join(map(str, self.inputs))
        rhs = ", ".join(map(str, self.outputs))
        return f"Operation({self.opcode}: {lhs} -> {rhs})"

    def replace_input(self, old: Value, new: Value) -> None:
        """Swap every occurrence of ``old`` among the inputs for ``new``.

        The use lists of both values are updated for each replaced slot.
        Nothing happens if ``old`` is not an input of this operation.
        """
        for slot, current in enumerate(self.inputs):
            if current is not old:
                continue
            self.inputs[slot] = new
            old.remove_use(self)
            new.add_use(self)

    def erase(self) -> None:
        """Detach this operation from the def-use chains of its values.

        The operation is removed from the use list of each input, and each
        output loses its defining operation. Only the values are touched;
        this method does not modify any containing graph.
        """
        for consumed in self.inputs:
            consumed.remove_use(self)
        for produced in self.outputs:
            produced.defining_op = None
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
@serde.register_class
class Graph:
    """Computation graph as a flat list of operations.

    A graph contains:
    - Inputs: Named input values
    - Operations: Flat list of computations
    - Outputs: Values returned from the graph
    - Values: All SSA values in the graph, keyed by name

    Example:
        graph = Graph()
        x = graph.add_input("x", Tensor[f32, (10,)])
        y, = graph.add_op("tensor.constant", [], output_types=[f32], attrs={"data": 1.0})
        z, = graph.add_op("add", [x, y])
        graph.add_output(z)
    """

    _serde_kind: ClassVar[str] = "mplang.Graph"

    def __init__(self) -> None:
        self.operations: list[Operation] = []
        self.values: dict[str, Value] = {}
        self.inputs: list[Value] = []
        self.outputs: list[Value] = []
        self._value_counter = 0
        self._op_counter = 0

    def _gen_value_name(self) -> str:
        """Generate an SSA value name (``%N``) that is not yet in the graph.

        Names that are already taken (for example because an input was
        explicitly named ``%0``, or because names were restored by
        ``from_json``) are skipped, so the returned name never collides.
        """
        while True:
            name = f"%{self._value_counter}"
            self._value_counter += 1
            if name not in self.values:
                return name

    def add_value(self, type: BaseType, name: str | None = None) -> Value:
        """Create a new SSA value and register it in ``self.values``.

        Args:
            type: Type of the value
            name: Optional custom name (auto-generated if None)

        Returns:
            New Value instance

        Raises:
            ValueError: If a value with the same name already exists.
        """
        if name is None:
            name = self._gen_value_name()

        if name in self.values:
            raise ValueError(f"Value {name} already exists")

        value = Value(name, type)
        self.values[name] = value
        return value

    def add_input(self, name: str, type: BaseType) -> Value:
        """Add a graph input.

        Args:
            name: Input parameter name
            type: Type of the input

        Returns:
            Input value

        Raises:
            ValueError: If a value with the same name already exists.
        """
        value = self.add_value(type, name=name)
        self.inputs.append(value)
        return value

    def _append_op(
        self,
        opcode: str,
        inputs: list[Value],
        outputs: list[Value],
        attrs: dict[str, Any],
        regions: list[Graph],
    ) -> Operation:
        """Build an operation named ``op<counter>`` and append it to the graph."""
        op = Operation(
            opcode=opcode,
            inputs=inputs,
            outputs=outputs,
            attrs=attrs,
            regions=regions,
            name=f"op{self._op_counter}",
        )
        self._op_counter += 1
        self.operations.append(op)
        return op

    def add_op(
        self,
        opcode: str,
        inputs: list[Value],
        output_types: Sequence[BaseType] | None = None,
        attrs: dict[str, Any] | None = None,
        regions: list[Graph] | None = None,
    ) -> list[Value]:
        """Add an operation to the graph.

        Args:
            opcode: Operation name
            inputs: Input values
            output_types: Types of outputs. If None, a single output with
                the type of the first input is created.
            attrs: Additional attributes
            regions: Nested graphs (for control flow)

        Returns:
            List of output values (one entry per output)

        Raises:
            ValueError: If ``output_types`` is None and there are no inputs
                to inherit a type from.
        """
        # Type inference (placeholder - should be backend-specific)
        if output_types is None:
            # Simple rule: inherit from first input
            if not inputs:
                raise ValueError(f"Cannot infer type for {opcode} with no inputs")
            output_types = [inputs[0].type]

        # Output values get auto-generated SSA names.
        outputs = [self.add_value(t) for t in output_types]
        self._append_op(opcode, inputs, outputs, attrs or {}, regions or [])
        return outputs

    def add_output(self, value: Value) -> None:
        """Mark a value as a graph output.

        Args:
            value: Value to be returned from the graph

        Raises:
            ValueError: If the value is not registered in this graph.
        """
        # Values are registered under their name and compare by identity, so
        # a keyed lookup replaces a linear scan over all values.
        if self.values.get(getattr(value, "name", None)) is not value:
            raise ValueError(f"Value {value} not in graph")
        self.outputs.append(value)

    def to_string(self, verbose: bool = False) -> str:
        """Generate human-readable IR representation.

        Inputs are printed first (``<name> = input``), then one line per
        operation, then a final ``return`` line if the graph has outputs.

        Args:
            verbose: Include type annotations. For operations only the type
                of the first output is printed.

        Returns:
            String representation of the graph
        """
        lines = []

        # Print inputs
        for inp in self.inputs:
            type_str = f" : {inp.type}" if verbose else ""
            lines.append(f"{inp.name} = input{type_str}")

        # Print operations
        for op in self.operations:
            if op.opcode == "constant":
                # Ops with the exact opcode "constant" get a compact form
                # that shows their "value" attribute ("?" if absent).
                value_str = op.attrs.get("value", "?")
                type_str = f" : {op.outputs[0].type}" if verbose else ""
                lines.append(f"{op.outputs[0].name} = constant {value_str}{type_str}")
            else:
                inputs_str = ", ".join(str(v) for v in op.inputs)
                outputs_str = ", ".join(str(v) for v in op.outputs)

                # Handle single vs multiple outputs
                if len(op.outputs) == 1:
                    lhs = str(op.outputs[0])
                else:
                    lhs = f"[{outputs_str}]"

                type_str = f" : {op.outputs[0].type}" if verbose and op.outputs else ""

                if op.attrs:
                    attrs_str = ", ".join(f"{k}={v}" for k, v in op.attrs.items())
                    lines.append(
                        f"{lhs} = {op.opcode}({inputs_str}) {{{attrs_str}}}{type_str}"
                    )
                else:
                    lines.append(f"{lhs} = {op.opcode}({inputs_str}){type_str}")

        # Print outputs
        if self.outputs:
            outputs_str = ", ".join(str(v) for v in self.outputs)
            lines.append(f"return {outputs_str}")

        return "\n".join(lines)

    def __repr__(self) -> str:
        return f"Graph({len(self.operations)} ops, {len(self.values)} values)"

    def __str__(self) -> str:
        return self.to_string()

    # =========================================================================
    # Serialization
    # =========================================================================

    def to_json(self) -> dict:
        """Serialize graph to JSON-compatible dict.

        Value types, attribute values and nested region graphs are each
        encoded with ``serde.to_json``; values are referenced by name.

        Returns:
            Dict with the keys "inputs", "operations" and "outputs".
        """
        return {
            "inputs": [
                {"name": v.name, "type": serde.to_json(v.type)} for v in self.inputs
            ],
            "operations": [
                {
                    "opcode": op.opcode,
                    "inputs": [v.name for v in op.inputs],
                    "outputs": [
                        {"name": v.name, "type": serde.to_json(v.type)}
                        for v in op.outputs
                    ],
                    "attrs": {k: serde.to_json(v) for k, v in op.attrs.items()},
                    "regions": [serde.to_json(r) for r in op.regions],
                    "name": op.name,
                }
                for op in self.operations
            ],
            "outputs": [v.name for v in self.outputs],
        }

    @classmethod
    def from_json(cls, data: dict) -> Graph:
        """Deserialize graph from JSON-compatible dict.

        Output values are created directly under their serialized names.
        They are not auto-named and renamed afterwards, so graphs whose
        value names are not the consecutive sequence ``%0, %1, ...`` (for
        example after operations were removed, or with inputs named ``%N``)
        round-trip without name collisions.

        Args:
            data: Dict in the layout produced by ``to_json``. The
                per-operation keys "attrs", "regions" and "name" are
                optional.

        Returns:
            The reconstructed graph.

        Raises:
            TypeError: If a serialized type does not decode to a BaseType.
            ValueError: If two values share a name, or a graph output is not
                part of the graph.
            KeyError: If an operation input or graph output refers to a
                value name that has not been defined.
        """

        def _type_from_json(d: dict) -> BaseType:
            result = serde.from_json(d)
            if not isinstance(result, BaseType):
                raise TypeError(f"Expected BaseType, got {type(result)}")
            return result

        graph = cls()

        # Reconstruct inputs
        for inp_data in data["inputs"]:
            graph.add_input(inp_data["name"], _type_from_json(inp_data["type"]))

        # Reconstruct operations
        for op_data in data["operations"]:
            # Resolve input values by name
            inputs = [graph.values[name] for name in op_data["inputs"]]

            # Recreate outputs under their original names.
            outputs = []
            for out_data in op_data["outputs"]:
                out_val = graph.add_value(
                    _type_from_json(out_data["type"]), name=out_data["name"]
                )
                outputs.append(out_val)
                # Keep the auto-name counter moving as if the value had been
                # auto-named; _gen_value_name skips any name already taken.
                graph._value_counter += 1

            # Deserialize nested graphs (regions)
            regions = [serde.from_json(r) for r in op_data.get("regions", [])]
            attrs = {k: serde.from_json(v) for k, v in op_data.get("attrs", {}).items()}

            op = graph._append_op(op_data["opcode"], inputs, outputs, attrs, regions)

            # Set operation name if provided
            if op_data.get("name"):
                op.name = op_data["name"]

        # Reconstruct outputs
        for name in data["outputs"]:
            graph.add_output(graph.values[name])

        return graph
|
mplang/v2/edsl/jit.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""JIT Decorator: Compile and cache Graph IR."""
|
|
16
|
+
|
|
17
|
+
from collections.abc import Callable
|
|
18
|
+
from typing import Any
|
|
19
|
+
|
|
20
|
+
from jax.tree_util import tree_map
|
|
21
|
+
|
|
22
|
+
from mplang.v2.edsl.context import (
|
|
23
|
+
AbstractInterpreter,
|
|
24
|
+
get_current_context,
|
|
25
|
+
get_default_context,
|
|
26
|
+
)
|
|
27
|
+
from mplang.v2.edsl.tracer import Tracer
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def jit(fn: Callable) -> Callable:
    """JIT compilation decorator.

    Behavior of the returned wrapper:

    - If the current context is already a ``Tracer`` (e.g. pcall_static),
      ``fn`` is simply called so it is traced into the current graph.
    - Otherwise ``fn`` is run inside a fresh ``Tracer()`` context and every
      leaf of its result is passed through the ``lift`` method of the
      current context (or the default context if there is no current one).

    The wrapper keeps the metadata of ``fn`` (``__name__``, ``__doc__``,
    ``__wrapped__``, ...).

    NOTE(review): the wrapper itself holds no cache; ``fn`` is re-traced on
    every call outside a Tracer. If graph caching is intended it presumably
    happens elsewhere (confirm).

    Args:
        fn: Function to wrap.

    Returns:
        The wrapping callable, accepting the same arguments as ``fn``.

    Raises:
        AssertionError: If the context used for execution is not an
            ``AbstractInterpreter``.

    Example:
        >>> @jit
        ... def compute(x, y):
        ...     return x + y
        >>> result = compute(x_interp, y_interp)
    """
    # Function-scope import so this change needs no edit to the module's
    # import block.
    import functools

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # If we are already inside a Tracer (e.g. pcall_static), just inline
        # the function to trace it into the current graph.
        cur_ctx = get_current_context()
        if isinstance(cur_ctx, Tracer):
            return fn(*args, **kwargs)

        # otherwise trace for JIT compilation
        with Tracer():
            result = fn(*args, **kwargs)

        # Use current context if available (e.g., SimpSimulator), otherwise use default
        cur_ctx = cur_ctx or get_default_context()
        assert isinstance(cur_ctx, AbstractInterpreter), (
            "JIT execution requires Interpreter context"
        )
        return tree_map(cur_ctx.lift, result)

    return wrapper
|
mplang/v2/edsl/object.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Object: Base class for runtime objects.
|
|
16
|
+
|
|
17
|
+
Base abstraction for distinguishing trace-time and interp-time execution.
|
|
18
|
+
|
|
19
|
+
- TraceObject: Defined in mplang.v2.edsl.tracer
|
|
20
|
+
- InterpObject: Defined in mplang.v2.runtime.interpreter
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from __future__ import annotations
|
|
24
|
+
|
|
25
|
+
from abc import ABC, abstractmethod
|
|
26
|
+
from typing import Generic, TypeVar
|
|
27
|
+
|
|
28
|
+
from mplang.v2.edsl.typing import BaseType
|
|
29
|
+
|
|
30
|
+
T = TypeVar("T", bound=BaseType)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class Object(ABC, Generic[T]):
    """Abstract root of MPLang runtime objects.

    A Driver-side abstraction whose purpose is to:
    1. Tell trace-time objects apart from interp-time objects
    2. Offer a uniform operation interface (arithmetic, attribute access, etc.)
    3. Let the Tracer handle both kinds polymorphically

    The class is generic in ``T``, the type returned by the ``type``
    property, which is the only member concrete subclasses must implement.

    Subclasses (defined outside this module):
    - TraceObject: Trace-time object (holds a Value in Graph IR)
    - InterpObject: Interp-time object (holds backend-specific runtime data)

    NOTE(review): the original notes located these in mplang.edsl.tracer and
    mplang.edsl.interpreter; this file imports from mplang.v2.*, so those
    module paths are presumably outdated (confirm).
    """

    @property
    @abstractmethod
    def type(self) -> T:
        """Return the object's type; defined in both trace and interp modes."""

    # Note: no arithmetic operators (__add__, __mul__, etc.) are declared on
    # this base class. Dialect-specific dispatch mechanisms are expected to
    # supply them, because different types (Tensor, Vector, SS) need
    # different implementations.
|