mplang-nightly 0.1.dev139__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +70 -0
- mplang/analysis/__init__.py +37 -0
- mplang/analysis/diagram.py +569 -0
- mplang/api.py +98 -0
- mplang/backend/__init__.py +20 -0
- mplang/backend/base.py +287 -0
- mplang/backend/builtin.py +212 -0
- mplang/backend/crypto.py +114 -0
- mplang/backend/phe.py +301 -0
- mplang/backend/spu.py +257 -0
- mplang/backend/sql_duckdb.py +42 -0
- mplang/backend/stablehlo.py +73 -0
- mplang/backend/tee.py +59 -0
- mplang/core/__init__.py +92 -0
- mplang/core/cluster.py +258 -0
- mplang/core/comm.py +277 -0
- mplang/core/context_mgr.py +50 -0
- mplang/core/dtype.py +293 -0
- mplang/core/expr/__init__.py +80 -0
- mplang/core/expr/ast.py +543 -0
- mplang/core/expr/evaluator.py +523 -0
- mplang/core/expr/printer.py +279 -0
- mplang/core/expr/transformer.py +141 -0
- mplang/core/expr/utils.py +78 -0
- mplang/core/expr/visitor.py +85 -0
- mplang/core/expr/walk.py +387 -0
- mplang/core/interp.py +160 -0
- mplang/core/mask.py +325 -0
- mplang/core/mpir.py +958 -0
- mplang/core/mpobject.py +117 -0
- mplang/core/mptype.py +438 -0
- mplang/core/pfunc.py +130 -0
- mplang/core/primitive.py +942 -0
- mplang/core/table.py +190 -0
- mplang/core/tensor.py +75 -0
- mplang/core/tracer.py +383 -0
- mplang/device.py +326 -0
- mplang/frontend/__init__.py +46 -0
- mplang/frontend/base.py +427 -0
- mplang/frontend/builtin.py +205 -0
- mplang/frontend/crypto.py +109 -0
- mplang/frontend/ibis_cc.py +137 -0
- mplang/frontend/jax_cc.py +150 -0
- mplang/frontend/phe.py +67 -0
- mplang/frontend/spu.py +152 -0
- mplang/frontend/sql.py +60 -0
- mplang/frontend/tee.py +42 -0
- mplang/protos/v1alpha1/mpir_pb2.py +63 -0
- mplang/protos/v1alpha1/mpir_pb2.pyi +507 -0
- mplang/protos/v1alpha1/mpir_pb2_grpc.py +3 -0
- mplang/runtime/__init__.py +32 -0
- mplang/runtime/cli.py +436 -0
- mplang/runtime/client.py +405 -0
- mplang/runtime/communicator.py +87 -0
- mplang/runtime/driver.py +311 -0
- mplang/runtime/exceptions.py +27 -0
- mplang/runtime/http_api.md +56 -0
- mplang/runtime/link_comm.py +131 -0
- mplang/runtime/resource.py +338 -0
- mplang/runtime/server.py +326 -0
- mplang/runtime/simulation.py +297 -0
- mplang/simp/__init__.py +351 -0
- mplang/simp/mpi.py +132 -0
- mplang/simp/random.py +121 -0
- mplang/simp/smpc.py +201 -0
- mplang/utils/__init__.py +13 -0
- mplang/utils/crypto.py +32 -0
- mplang/utils/func_utils.py +147 -0
- mplang/utils/spu_utils.py +130 -0
- mplang/utils/table_utils.py +73 -0
- mplang_nightly-0.1.dev139.dist-info/METADATA +290 -0
- mplang_nightly-0.1.dev139.dist-info/RECORD +75 -0
- mplang_nightly-0.1.dev139.dist-info/WHEEL +4 -0
- mplang_nightly-0.1.dev139.dist-info/entry_points.txt +2 -0
- mplang_nightly-0.1.dev139.dist-info/licenses/LICENSE +201 -0
mplang/__init__.py
ADDED
@@ -0,0 +1,70 @@
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
"""Multi-Party Programming Language for Secure Computation."""
|
16
|
+
|
17
|
+
# Version is managed by hatch-vcs and available after package installation.
# The same sources are published under more than one distribution name (the
# nightly wheel is "mplang-nightly"), so query each known name in turn and
# only fall back to a placeholder for development/editable checkouts where no
# distribution metadata is installed.
from importlib.metadata import PackageNotFoundError, version

__version__ = "0.0.0-dev"
for _dist_name in ("mplang", "mplang-nightly"):
    try:
        __version__ = version(_dist_name)
    except PackageNotFoundError:
        continue
    break
del _dist_name
|
25
|
+
|
26
|
+
from mplang import analysis
|
27
|
+
from mplang.api import CompileOptions, compile, evaluate, fetch
|
28
|
+
from mplang.core import (
|
29
|
+
DType,
|
30
|
+
InterpContext,
|
31
|
+
Mask,
|
32
|
+
MPContext,
|
33
|
+
MPObject,
|
34
|
+
MPType,
|
35
|
+
TableType,
|
36
|
+
TensorType,
|
37
|
+
function,
|
38
|
+
)
|
39
|
+
from mplang.core.cluster import ClusterSpec, Device, Node, RuntimeInfo
|
40
|
+
from mplang.core.context_mgr import cur_ctx, set_ctx, with_ctx
|
41
|
+
from mplang.runtime.driver import Driver
|
42
|
+
from mplang.runtime.simulation import Simulator
|
43
|
+
|
44
|
+
# Public API
|
45
|
+
# Entries are kept in ASCII (case-sensitive) sorted order.
__all__ = [
    # Classes and types
    "ClusterSpec", "CompileOptions", "DType", "Device", "Driver",
    "InterpContext", "MPContext", "MPObject", "MPType", "Mask",
    "Node", "RuntimeInfo", "Simulator", "TableType", "TensorType",
    # Version string, subpackage and functions
    "__version__", "analysis", "compile", "cur_ctx", "evaluate",
    "fetch", "function", "set_ctx", "with_ctx",
]
|
@@ -0,0 +1,37 @@
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
"""Analysis and visualization utilities for mplang.
|
16
|
+
|
17
|
+
This subpackage hosts non-core developer aids: diagram rendering, IR dumps,
|
18
|
+
profiling helpers (future), etc.
|
19
|
+
"""
|
20
|
+
|
21
|
+
from mplang.analysis.diagram import (
|
22
|
+
DumpResult,
|
23
|
+
FlowchartOptions,
|
24
|
+
SequenceDiagramOptions,
|
25
|
+
dump,
|
26
|
+
to_flowchart,
|
27
|
+
to_sequence_diagram,
|
28
|
+
)
|
29
|
+
|
30
|
+
# Re-exported public API of the analysis subpackage (mirrors diagram.__all__).
__all__ = [
    # Result / option types
    "DumpResult", "FlowchartOptions", "SequenceDiagramOptions",
    # Entry points
    "dump", "to_flowchart", "to_sequence_diagram",
]
|
@@ -0,0 +1,569 @@
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
"""Diagram rendering (Mermaid) and markdown dump helpers.
|
16
|
+
|
17
|
+
Moved from mplang.utils.mermaid to dedicated analysis namespace.
|
18
|
+
"""
|
19
|
+
|
20
|
+
from __future__ import annotations
|
21
|
+
|
22
|
+
from dataclasses import dataclass
|
23
|
+
from pathlib import Path
|
24
|
+
from typing import TypedDict
|
25
|
+
|
26
|
+
from mplang.core import TracedFunction
|
27
|
+
from mplang.core.cluster import ClusterSpec
|
28
|
+
from mplang.core.mask import Mask
|
29
|
+
from mplang.core.mpir import Writer, get_graph_statistics
|
30
|
+
from mplang.protos.v1alpha1 import mpir_pb2
|
31
|
+
|
32
|
+
# ----------------------------- Core helpers (copied) -----------------------------
|
33
|
+
|
34
|
+
|
35
|
+
@dataclass
class Event:
    """One entry of a sequence diagram under construction.

    Attributes:
        kind: Category tag chosen by the emitter (this module emits "note",
            "local" and "comm").
        lines: Mermaid source lines contributed by this entry, in output order.
    """

    kind: str
    lines: list[str]
|
39
|
+
|
40
|
+
|
41
|
+
def _pmask_to_ranks(pmask: int) -> list[int]:
    """Expand a party bitmask into the list of ranks it selects.

    A negative mask (treated throughout this module as "owners not statically
    known") yields an empty list instead of being passed to ``Mask``.
    """
    if pmask < 0:
        return []
    return list(Mask(pmask))
|
43
|
+
|
44
|
+
|
45
|
+
def _collect_world_size(
    graph: mpir_pb2.GraphProto, explicit_world_size: int | None
) -> int:
    """Return the number of parties to draw for ``graph``.

    An explicit value always wins. Otherwise the size is one past the highest
    rank found in any node output's pmask; outputs with a negative pmask are
    ignored, and a graph with no known owners yields 0.
    """
    if explicit_world_size is not None:
        return explicit_world_size
    seen_ranks = [
        rank
        for node in graph.nodes
        for out in node.outs_info
        if out.pmask >= 0
        for rank in _pmask_to_ranks(out.pmask)
    ]
    return max(seen_ranks) + 1 if seen_ranks else 0
|
57
|
+
|
58
|
+
|
59
|
+
def _node_output_pmasks(node: mpir_pb2.NodeProto) -> list[int]:
    """Return the pmask of each output of ``node``, in output order."""
    pmasks: list[int] = []
    for info in node.outs_info:
        pmasks.append(info.pmask)
    return pmasks
|
61
|
+
|
62
|
+
|
63
|
+
def _value_producers(graph: mpir_pb2.GraphProto) -> dict[str, mpir_pb2.NodeProto]:
    """Map every value reference in ``graph`` to the node that produces it.

    Nodes with zero or one output are keyed by their bare name; nodes with
    several outputs get one ``"<name>:<index>"`` key per output.
    """
    producers: dict[str, mpir_pb2.NodeProto] = {}
    for node in graph.nodes:
        n_outs = len(node.outs_info)
        if n_outs <= 1:
            producers[node.name] = node
            continue
        for idx in range(n_outs):
            producers[f"{node.name}:{idx}"] = node
    return producers
|
72
|
+
|
73
|
+
|
74
|
+
# ----------------------------- Option Types -----------------------------
|
75
|
+
|
76
|
+
|
77
|
+
class SequenceDiagramOptions(TypedDict, total=False):
    """Optional keyword overrides that ``dump`` forwards to ``to_sequence_diagram``.

    Every key is optional (``total=False``); omitted keys fall back to the
    defaults of ``to_sequence_diagram``.
    """

    # Batch consecutive local ops into a single summary note.
    collapse_local: bool
    # Keep listing local ops individually even when collapse_local is set.
    show_compute: bool
    # Maximum op names shown in a collapsed note before "+N more".
    max_local_batch: int
    # Include structural ops such as tuple / func_def.
    show_meta: bool
|
82
|
+
|
83
|
+
|
84
|
+
class FlowchartOptions(TypedDict, total=False):
    """Optional overrides that ``dump`` reads when calling ``to_flowchart``.

    Every key is optional (``total=False``); ``dump`` substitutes the
    ``to_flowchart`` defaults for omitted keys.
    """

    # Reserved for future local grouping.
    group_local: bool
    # Party count override; dump only uses it when no cluster_spec is given.
    world_size: int
    # Include structural ops such as tuple / func_def.
    show_meta: bool
    # Mermaid layout direction (LR, RL, TB/TD, BT).
    direction: str
    # CSS color used to highlight cross-party edges.
    cross_edge_color: str
    # Wrap nodes in per-party subgraphs plus a shared cluster.
    cluster_by_party: bool
    # Title of the shared cluster when cluster_by_party is set.
    shared_cluster_name: str
|
92
|
+
|
93
|
+
|
94
|
+
@dataclass
class DumpResult:
    """Artifacts produced by ``dump``.

    Optional fields are ``None`` when the matching section was disabled;
    ``markdown`` always holds the assembled report.
    """

    # Compiler IR text, or None when include_ir=False.
    ir: str | None
    # Graph statistics text (always generated).
    stats: str
    # Mermaid sequence diagram text, or None when sequence=False.
    sequence: str | None
    # Mermaid flowchart text, or None when flow=False.
    flow: str | None
    # Complete markdown report (also written to report_path when set).
    markdown: str
|
101
|
+
|
102
|
+
|
103
|
+
# ----------------------------- Public APIs -----------------------------
|
104
|
+
|
105
|
+
|
106
|
+
def to_sequence_diagram(
    graph: mpir_pb2.GraphProto,
    *,
    world_size: int | None = None,
    collapse_local: bool = False,
    show_compute: bool = False,
    max_local_batch: int = 8,
    show_meta: bool = False,
) -> str:
    """Render a MPIR graph as a Mermaid sequenceDiagram.

    Parameters:
        graph: The MPIR graph proto to visualize.
        world_size: Optional explicit number of parties (participants). If None it
            is inferred from the union of output pmasks in the graph.
        collapse_local: If True, consecutive pure-local operations are batched into
            a single note summarizing their op names (up to max_local_batch shown).
        show_compute: When True and collapse_local is also True, local operations
            are still listed individually (disable batching display logic).
        max_local_batch: Maximum number of individual local op names to show inside
            a collapsed note before summarizing with a "+N more" suffix.
        show_meta: If True, include structural/meta ops (e.g. tuple, func_def).

    Returns:
        Mermaid sequenceDiagram text.
    """
    wsize = _collect_world_size(graph, world_size)
    value_producer = _value_producers(graph)

    # Participant span for notes covering every party. Clamped so a graph with
    # no statically known owners (wsize == 0) yields "P0,P0" instead of the
    # invalid participant reference "P0,P-1".
    all_parties = f"P0,P{max(wsize, 1) - 1}"

    participants: list[str] = [f"participant P{i}" for i in range(wsize)]
    events: list[Event] = []

    def emit(kind: str, *lines: str) -> None:
        events.append(Event(kind=kind, lines=list(lines)))

    # Local (non-communicating) lines are buffered so they can optionally be
    # collapsed into one note; the buffer is flushed before any comm/control event.
    local_buffer: list[str] = []

    def flush_local_buffer() -> None:
        if not local_buffer:
            return
        if collapse_local and not show_compute:

            def _strip(s: str) -> str:
                # Drop the "<prefix>: " part of a buffered line, keep the description.
                parts = s.split(": ", 1)
                return parts[1] if len(parts) == 2 else s

            sample = [_strip(s) for s in local_buffer[:max_local_batch]]
            more = (
                ""
                if len(local_buffer) <= max_local_batch
                else f" (+{len(local_buffer) - max_local_batch} more)"
            )
            emit("note", f"note over {all_parties}: {', '.join(sample)}{more}")
        else:
            for line in local_buffer:
                emit("local", line)
        local_buffer.clear()

    def owners_of(node: mpir_pb2.NodeProto) -> set[int]:
        ranks: set[int] = set()
        for pm in _node_output_pmasks(node):
            if pm >= 0:
                ranks.update(_pmask_to_ranks(pm))
        return ranks

    def producer_of(val: str) -> mpir_pb2.NodeProto | None:
        # Multi-output producers are keyed "name:i" and single-output ones by the
        # bare name, so try the exact reference first, then the bare node name.
        prod = value_producer.get(val)
        if prod is None:
            prod = value_producer.get(val.split(":")[0])
        return prod

    for node in graph.nodes:
        # access nodes are not drawn; meta ops are hidden unless requested.
        if node.op_type == "access":
            continue
        if not show_meta and node.op_type in {"tuple", "func_def"}:
            continue

        own = owners_of(node)
        input_ranks: set[int] = set()
        for val in node.inputs:
            prod = producer_of(val)
            if prod is not None:
                input_ranks.update(owners_of(prod))

        pfunc = node.attrs.get("pfunc") if node.op_type == "eval" else None
        if pfunc:
            fn_name = pfunc.func.name or pfunc.func.type or "eval"
            label = f"{fn_name} {node.name}"
        else:
            label = f"{node.op_type} {node.name}"
        value_suffix = ""
        if len(node.outs_info) > 1:
            value_suffix = f" -> {len(node.outs_info)} outs"

        if node.op_type == "shfl_s":
            flush_local_buffer()
            pmask_attr = node.attrs.get("pmask")
            src_ranks_attr = node.attrs.get("src_ranks")
            if pmask_attr and src_ranks_attr:
                dst_ranks = _pmask_to_ranks(pmask_attr.i)
                src_ranks = list(src_ranks_attr.ints)
                for dst, src in zip(dst_ranks, src_ranks, strict=True):
                    emit("comm", f"P{src}->>P{dst}: {node.name}")
            else:
                emit("comm", f"note over {all_parties}: send {node.name}")
            continue

        if node.op_type in {"cond", "while"}:
            flush_local_buffer()
            emit("note", f"note over {all_parties}: {node.op_type} {node.name}")
            continue

        # Output owners differing from input owners means data crosses parties.
        if own and input_ranks and own != input_ranks:
            flush_local_buffer()
            for s in sorted(input_ranks):
                for t in sorted(own):
                    if s != t:
                        emit("comm", f"P{s}->>P{t}: {label}{value_suffix}")
            continue

        local_desc = f"{label}{value_suffix}".strip()
        if not local_desc:
            continue
        if own:
            repr_rank = min(own)
            local_line = f"P{repr_rank}-->>P{repr_rank}: {local_desc}"
        else:
            local_line = f"note over {all_parties}: {local_desc} (dyn)"
        local_buffer.append(local_line)

    flush_local_buffer()

    out_lines: list[str] = ["sequenceDiagram", *participants]
    for ev in events:
        out_lines.extend(ev.lines)
    return "\n".join(out_lines)
|
238
|
+
|
239
|
+
|
240
|
+
def to_flowchart(
    graph: mpir_pb2.GraphProto,
    *,
    group_local: bool = True,
    world_size: int | None = None,
    show_meta: bool = False,
    direction: str = "LR",
    cross_edge_color: str = "#ff6a00",
    cluster_by_party: bool = False,
    shared_cluster_name: str = "Shared",
) -> str:
    """Render a MPIR graph as a Mermaid flowchart (DAG view).

    Parameters:
        graph: The MPIR graph proto to visualize.
        group_local: (Reserved) placeholder for future local grouping in the
            non-cluster view; currently ignored.
        world_size: Optional explicit party count override (inferred if None).
            Only used when cluster_by_party=True to order per-party clusters; a
            node owned by a rank outside ``range(world_size)`` still gets its own
            cluster rather than being dropped.
        show_meta: Include meta/structural nodes when True.
        direction: Mermaid layout direction (LR, RL, TB, BT). Accepts TD synonym
            for TB; any other value falls back to LR.
        cross_edge_color: CSS color used to highlight cross-party data edges.
        cluster_by_party: If True, wrap single-owner nodes in per-party subgraphs
            and every other node in a shared cluster.
        shared_cluster_name: Title for the shared subgraph cluster when
            cluster_by_party=True.

    Returns:
        Mermaid flowchart text.
    """
    _ = group_local  # reserved; accepted for API compatibility

    meta_ops = {"tuple", "func_def"}

    def is_hidden(node: mpir_pb2.NodeProto) -> bool:
        # access nodes are looked through when resolving edges; meta ops are
        # hidden unless show_meta is set.
        return node.op_type == "access" or (
            not show_meta and node.op_type in meta_ops
        )

    value_to_node = _value_producers(graph)
    node_map: dict[str, mpir_pb2.NodeProto] = {n.name: n for n in graph.nodes}

    def owners_of(node: mpir_pb2.NodeProto) -> set[int]:
        rs: set[int] = set()
        for out in node.outs_info:
            if out.pmask >= 0:
                rs.update(_pmask_to_ranks(out.pmask))
        return rs

    # ---- node declarations ----
    node_labels: list[str] = []
    per_party_nodes: dict[int, list[str]] = {}
    shared_nodes: list[str] = []
    node_id_map: dict[str, str] = {}
    id_to_owners: dict[str, set[int]] = {}
    for n in graph.nodes:
        if is_hidden(n):
            continue
        # Mermaid id: drop the first character of the node name (assumed to be
        # a sigil such as "%" — TODO confirm against the IR writer).
        node_id = f"n{n.name[1:]}"
        node_id_map[n.name] = node_id
        pfunc = n.attrs.get("pfunc") if n.op_type == "eval" else None
        if pfunc:
            op_label = pfunc.func.name or pfunc.func.type or n.op_type
        else:
            op_label = n.op_type
        # A raw double quote would terminate the Mermaid ["..."] label early.
        op_label = op_label.replace('"', "#quot;")
        arity = len(n.outs_info)
        arity_suffix = f"/{arity}" if arity > 1 else ""
        owners = owners_of(n)
        owners_str = (
            (" @" + ",".join(f"P{r}" for r in sorted(owners))) if owners else ""
        )
        label_line = f'{node_id}["{op_label}{arity_suffix}{owners_str}"]'
        if not cluster_by_party:
            node_labels.append(label_line)
        elif len(owners) == 1:
            (owner_rank,) = owners
            per_party_nodes.setdefault(owner_rank, []).append(label_line)
        else:
            shared_nodes.append(label_line)
        id_to_owners[node_id] = owners

    def resolve_sources(val: str, seen: set[str] | None = None) -> set[str]:
        """Map a value reference to the visible node name(s) that produce it,
        following access nodes upstream (cycle-safe via ``seen``)."""
        if seen is None:
            seen = set()
        node = value_to_node.get(val)
        if node is None:
            node = value_to_node.get(val.split(":")[0])
        if node is None or node.name in seen:
            return set()
        seen.add(node.name)
        if node.op_type == "access":
            srcs: set[str] = set()
            for upstream in node.inputs:
                srcs |= resolve_sources(upstream, seen)
            return srcs
        return {node.name}

    # ---- edges ----
    edge_set: set[tuple[str, str]] = set()
    for n in graph.nodes:
        if is_hidden(n):
            continue
        dst_id = node_id_map.get(n.name)
        if not dst_id:
            continue
        for inp in n.inputs:
            for src_name in resolve_sources(inp):
                if not show_meta and node_map[src_name].op_type in meta_ops:
                    continue
                src_id = node_id_map.get(src_name)
                if not src_id or src_id == dst_id:
                    continue
                edge_set.add((src_id, dst_id))

    ordered_edges = sorted(edge_set)
    edges = [f"{s} --> {t}" for s, t in ordered_edges]
    # linkStyle addresses edges by declaration index, so remember the positions
    # of edges whose endpoints have different (known) owner sets.
    cross_indices: list[int] = []
    for idx, (s, t) in enumerate(ordered_edges):
        so = id_to_owners.get(s, set())
        to = id_to_owners.get(t, set())
        if so and to and so != to:
            cross_indices.append(idx)

    dir_norm = direction.upper()
    if dir_norm == "TD":
        dir_norm = "TB"
    if dir_norm not in {"LR", "TB", "RL", "BT"}:
        dir_norm = "LR"

    result_lines = [f"graph {dir_norm};", ""]
    if cluster_by_party:
        wsize = _collect_world_size(graph, world_size)
        # Parties 0..wsize-1 in order, plus any owner rank outside that range
        # (e.g. an explicit world_size smaller than the graph's pmasks), which
        # would otherwise be silently omitted from the diagram.
        for r in sorted(set(range(wsize)) | set(per_party_nodes)):
            nodes = per_party_nodes.get(r)
            if not nodes:
                continue
            result_lines.append(f"  subgraph P{r}")
            for ln in nodes:
                result_lines.append(f"    {ln}")
            result_lines.append("  end")
            result_lines.append("")
        if shared_nodes:
            result_lines.append(f"  subgraph {shared_cluster_name}")
            for ln in shared_nodes:
                result_lines.append(f"    {ln}")
            result_lines.append("  end")
            result_lines.append("")
    else:
        for lbl in node_labels:
            if lbl:
                result_lines.append("  " + lbl)
        if node_labels:
            result_lines.append("")
    for edge in edges:
        result_lines.append("  " + edge)
    for ci in cross_indices:
        result_lines.append(
            f"  linkStyle {ci} stroke:{cross_edge_color},stroke-width:2px;"
        )
    return "\n".join(result_lines)
|
402
|
+
|
403
|
+
|
404
|
+
# ----------------------------- Markdown dump -----------------------------
|
405
|
+
|
406
|
+
|
407
|
+
def dump(
    traced: TracedFunction,
    *,
    cluster_spec: ClusterSpec | None = None,
    sequence: bool = True,
    flow: bool = True,
    include_ir: bool = True,
    report_path: str | Path | None = None,
    mpir_path: str | Path | None = None,
    title: str | None = None,
    seq_opts: SequenceDiagramOptions | None = None,
    flow_opts: FlowchartOptions | None = None,
) -> DumpResult:
    """Generate a composite analysis report (markdown + structured fields).

    Sections (conditionally) included in the markdown:
      - Title (if provided)
      - Cluster Specification (if cluster_spec provided)
      - Compiler IR (if include_ir)
      - Graph Structure Analysis (always)
      - Mermaid Sequence Diagram (if sequence=True)
      - Mermaid Flowchart (if flow=True)

    Parameters:
        traced: TracedFunction object produced by the compilation pipeline.
        cluster_spec: Optional cluster topology; when provided, world size (the
            number of cluster nodes) and a JSON summary block are derived from it.
        sequence: Whether to render a sequence diagram section.
        flow: Whether to render a flowchart diagram section.
        include_ir: Include textual compiler IR section when True.
        report_path: If set (non-empty), write the assembled markdown to this path.
        mpir_path: If set (non-empty), write the raw MPIR proto text to this path.
        title: Optional top-level markdown title.
        seq_opts: Options forwarded to ``to_sequence_diagram``.
        flow_opts: Options controlling flowchart rendering. Its ``world_size``
            is only used when ``cluster_spec`` is not given.

    Returns:
        DumpResult containing individual textual artifacts and the combined markdown.

    Raises:
        ValueError: If neither report_path nor mpir_path is set. ``None`` and
            empty strings both count as unset, matching the checks that decide
            whether each file is actually written.
    """
    import json as _json

    # Use the same truthiness test as the writes at the end, so an empty path
    # cannot pass validation and then silently produce no output.
    if not report_path and not mpir_path:
        raise ValueError(
            "dump() requires at least one output path: report_path for markdown or mpir_path for raw IR"
        )

    # Build the graph once; every section below renders from the same proto.
    expr = traced.make_expr()
    graph_proto = Writer().dumps(expr)

    # world_size is defined as the number of physical nodes (ranks) in the cluster.
    derived_world_size: int | None = (
        len(cluster_spec.nodes) if cluster_spec is not None else None
    )

    parts: list[str] = []
    if title:
        parts.append(f"# {title}\n")

    if cluster_spec is not None:
        spec_summary = {
            "nodes": [
                {"name": n.name, "rank": n.rank, "endpoint": n.endpoint}
                for n in sorted(cluster_spec.nodes.values(), key=lambda x: x.rank)
            ],
            "devices": {
                name: {"kind": dev.kind, "members": [m.name for m in dev.members]}
                for name, dev in sorted(cluster_spec.devices.items())
            },
        }
        parts.append("## Cluster Specification\n")
        parts.append("```json")
        parts.append(_json.dumps(spec_summary, indent=2))
        parts.append("```")

    ir_text: str | None = None
    if include_ir:
        ir_text = traced.compiler_ir()
        parts.append("## Compiler IR (text)\n")
        parts.append("```")
        parts.append(ir_text)
        parts.append("```")

    stats = get_graph_statistics(graph_proto)
    parts.append("## Graph Structure Analysis\n")
    parts.append("```")
    parts.append(stats)
    parts.append("```")

    seq_text: str | None = None
    if sequence:
        seq_text = to_sequence_diagram(
            graph_proto,
            world_size=derived_world_size,
            **(seq_opts or {}),
        )
        parts.append("## Mermaid Sequence Diagram")
        parts.append("```mermaid")
        parts.append(seq_text)
        parts.append("```")

    flow_text: str | None = None
    if flow:
        flow_opts = flow_opts or {}
        # cluster_spec takes precedence over an explicit flow_opts world_size.
        effective_world_size = derived_world_size
        if effective_world_size is None and "world_size" in flow_opts:
            effective_world_size = flow_opts["world_size"]  # type: ignore[assignment]
        flow_text = to_flowchart(
            graph_proto,
            world_size=effective_world_size,
            group_local=flow_opts.get("group_local", True),
            show_meta=flow_opts.get("show_meta", False),
            direction=flow_opts.get("direction", "LR"),
            cross_edge_color=flow_opts.get("cross_edge_color", "#ff6a00"),
            cluster_by_party=flow_opts.get("cluster_by_party", False),
            shared_cluster_name=flow_opts.get("shared_cluster_name", "Shared"),
        )
        parts.append("## Mermaid Flowchart (DAG)")
        parts.append("```mermaid")
        parts.append(flow_text)
        parts.append("```")

    markdown = "\n\n".join(parts) + "\n"

    if mpir_path:
        Path(mpir_path).write_text(str(graph_proto), encoding="utf-8")
    if report_path:
        Path(report_path).write_text(markdown, encoding="utf-8")

    return DumpResult(
        ir=ir_text,
        stats=stats,
        sequence=seq_text,
        flow=flow_text,
        markdown=markdown,
    )
|
560
|
+
|
561
|
+
|
562
|
+
# Public surface of the diagram module (kept in ASCII-sorted order).
__all__ = [
    # Result / option types
    "DumpResult", "FlowchartOptions", "SequenceDiagramOptions",
    # Rendering entry points
    "dump", "to_flowchart", "to_sequence_diagram",
]
|