mplang-nightly 0.1.dev268__py3-none-any.whl → 0.1.dev270__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. mplang/__init__.py +391 -17
  2. mplang/{v2/backends → backends}/__init__.py +9 -7
  3. mplang/{v2/backends → backends}/bfv_impl.py +6 -6
  4. mplang/{v2/backends → backends}/crypto_impl.py +6 -6
  5. mplang/{v2/backends → backends}/field_impl.py +5 -5
  6. mplang/{v2/backends → backends}/func_impl.py +4 -4
  7. mplang/{v2/backends → backends}/phe_impl.py +3 -3
  8. mplang/{v2/backends → backends}/simp_design.md +1 -1
  9. mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
  10. mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
  11. mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
  12. mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
  13. mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
  14. mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
  15. mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
  16. mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
  17. mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
  18. mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
  19. mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
  20. mplang/{v2/backends → backends}/spu_impl.py +8 -8
  21. mplang/{v2/backends → backends}/spu_state.py +4 -4
  22. mplang/{v2/backends → backends}/store_impl.py +3 -3
  23. mplang/{v2/backends → backends}/table_impl.py +8 -8
  24. mplang/{v2/backends → backends}/tee_impl.py +6 -6
  25. mplang/{v2/backends → backends}/tensor_impl.py +6 -6
  26. mplang/{v2/cli.py → cli.py} +9 -9
  27. mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
  28. mplang/{v2/dialects → dialects}/__init__.py +5 -5
  29. mplang/{v2/dialects → dialects}/bfv.py +6 -6
  30. mplang/{v2/dialects → dialects}/crypto.py +5 -5
  31. mplang/{v2/dialects → dialects}/dtypes.py +2 -2
  32. mplang/{v2/dialects → dialects}/field.py +3 -3
  33. mplang/{v2/dialects → dialects}/func.py +2 -2
  34. mplang/{v2/dialects → dialects}/phe.py +6 -6
  35. mplang/{v2/dialects → dialects}/simp.py +6 -6
  36. mplang/{v2/dialects → dialects}/spu.py +7 -7
  37. mplang/{v2/dialects → dialects}/store.py +2 -2
  38. mplang/{v2/dialects → dialects}/table.py +3 -3
  39. mplang/{v2/dialects → dialects}/tee.py +6 -6
  40. mplang/{v2/dialects → dialects}/tensor.py +5 -5
  41. mplang/{v2/edsl → edsl}/__init__.py +3 -3
  42. mplang/{v2/edsl → edsl}/context.py +6 -6
  43. mplang/{v2/edsl → edsl}/graph.py +5 -5
  44. mplang/{v2/edsl → edsl}/jit.py +2 -2
  45. mplang/{v2/edsl → edsl}/object.py +1 -1
  46. mplang/{v2/edsl → edsl}/primitive.py +5 -5
  47. mplang/{v2/edsl → edsl}/printer.py +1 -1
  48. mplang/{v2/edsl → edsl}/serde.py +1 -1
  49. mplang/{v2/edsl → edsl}/tracer.py +7 -7
  50. mplang/{v2/edsl → edsl}/typing.py +1 -1
  51. mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
  52. mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
  53. mplang/{v2/kernels → kernels}/okvs_opt.cpp +46 -31
  54. mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
  55. mplang/{v2/libs → libs}/collective.py +5 -5
  56. mplang/{v2/libs → libs}/device/__init__.py +1 -1
  57. mplang/{v2/libs → libs}/device/api.py +12 -12
  58. mplang/{v2/libs → libs}/ml/__init__.py +1 -1
  59. mplang/{v2/libs → libs}/ml/sgb.py +4 -4
  60. mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
  61. mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
  62. mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
  63. mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
  64. mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
  65. mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
  66. mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
  67. mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
  68. mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
  69. mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
  70. mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +19 -13
  71. mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
  72. mplang/libs/mpc/psi/rr22.py +303 -0
  73. mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
  74. mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
  75. mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
  76. mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
  77. mplang/{v2/runtime → runtime}/interpreter.py +11 -11
  78. mplang/{v2/runtime → runtime}/value.py +2 -2
  79. mplang/{v1/runtime → utils}/__init__.py +18 -15
  80. mplang/{v1/utils → utils}/func_utils.py +1 -1
  81. {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/METADATA +2 -2
  82. mplang_nightly-0.1.dev270.dist-info/RECORD +102 -0
  83. mplang/v1/__init__.py +0 -157
  84. mplang/v1/_device.py +0 -602
  85. mplang/v1/analysis/__init__.py +0 -37
  86. mplang/v1/analysis/diagram.py +0 -567
  87. mplang/v1/core/__init__.py +0 -157
  88. mplang/v1/core/cluster.py +0 -343
  89. mplang/v1/core/comm.py +0 -281
  90. mplang/v1/core/context_mgr.py +0 -50
  91. mplang/v1/core/dtypes.py +0 -335
  92. mplang/v1/core/expr/__init__.py +0 -80
  93. mplang/v1/core/expr/ast.py +0 -542
  94. mplang/v1/core/expr/evaluator.py +0 -581
  95. mplang/v1/core/expr/printer.py +0 -285
  96. mplang/v1/core/expr/transformer.py +0 -141
  97. mplang/v1/core/expr/utils.py +0 -78
  98. mplang/v1/core/expr/visitor.py +0 -85
  99. mplang/v1/core/expr/walk.py +0 -387
  100. mplang/v1/core/interp.py +0 -160
  101. mplang/v1/core/mask.py +0 -325
  102. mplang/v1/core/mpir.py +0 -965
  103. mplang/v1/core/mpobject.py +0 -117
  104. mplang/v1/core/mptype.py +0 -407
  105. mplang/v1/core/pfunc.py +0 -130
  106. mplang/v1/core/primitive.py +0 -877
  107. mplang/v1/core/table.py +0 -218
  108. mplang/v1/core/tensor.py +0 -75
  109. mplang/v1/core/tracer.py +0 -383
  110. mplang/v1/host.py +0 -130
  111. mplang/v1/kernels/__init__.py +0 -41
  112. mplang/v1/kernels/base.py +0 -125
  113. mplang/v1/kernels/basic.py +0 -240
  114. mplang/v1/kernels/context.py +0 -369
  115. mplang/v1/kernels/crypto.py +0 -122
  116. mplang/v1/kernels/fhe.py +0 -858
  117. mplang/v1/kernels/mock_tee.py +0 -72
  118. mplang/v1/kernels/phe.py +0 -1864
  119. mplang/v1/kernels/spu.py +0 -341
  120. mplang/v1/kernels/sql_duckdb.py +0 -44
  121. mplang/v1/kernels/stablehlo.py +0 -90
  122. mplang/v1/kernels/value.py +0 -626
  123. mplang/v1/ops/__init__.py +0 -35
  124. mplang/v1/ops/base.py +0 -424
  125. mplang/v1/ops/basic.py +0 -294
  126. mplang/v1/ops/crypto.py +0 -262
  127. mplang/v1/ops/fhe.py +0 -272
  128. mplang/v1/ops/jax_cc.py +0 -147
  129. mplang/v1/ops/nnx_cc.py +0 -168
  130. mplang/v1/ops/phe.py +0 -216
  131. mplang/v1/ops/spu.py +0 -151
  132. mplang/v1/ops/sql_cc.py +0 -303
  133. mplang/v1/ops/tee.py +0 -36
  134. mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
  135. mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
  136. mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
  137. mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
  138. mplang/v1/runtime/channel.py +0 -230
  139. mplang/v1/runtime/cli.py +0 -451
  140. mplang/v1/runtime/client.py +0 -456
  141. mplang/v1/runtime/communicator.py +0 -131
  142. mplang/v1/runtime/data_providers.py +0 -303
  143. mplang/v1/runtime/driver.py +0 -324
  144. mplang/v1/runtime/exceptions.py +0 -27
  145. mplang/v1/runtime/http_api.md +0 -56
  146. mplang/v1/runtime/link_comm.py +0 -196
  147. mplang/v1/runtime/server.py +0 -501
  148. mplang/v1/runtime/session.py +0 -270
  149. mplang/v1/runtime/simulation.py +0 -324
  150. mplang/v1/simp/__init__.py +0 -13
  151. mplang/v1/simp/api.py +0 -353
  152. mplang/v1/simp/mpi.py +0 -131
  153. mplang/v1/simp/party.py +0 -225
  154. mplang/v1/simp/random.py +0 -120
  155. mplang/v1/simp/smpc.py +0 -238
  156. mplang/v1/utils/__init__.py +0 -13
  157. mplang/v1/utils/crypto.py +0 -32
  158. mplang/v1/utils/spu_utils.py +0 -130
  159. mplang/v1/utils/table_utils.py +0 -185
  160. mplang/v2/__init__.py +0 -424
  161. mplang/v2/libs/mpc/psi/rr22.py +0 -344
  162. mplang_nightly-0.1.dev268.dist-info/RECORD +0 -180
  163. /mplang/{v2/backends → backends}/channel.py +0 -0
  164. /mplang/{v2/edsl → edsl}/README.md +0 -0
  165. /mplang/{v2/edsl → edsl}/registry.py +0 -0
  166. /mplang/{v2/kernels → kernels}/Makefile +0 -0
  167. /mplang/{v2/kernels → kernels}/__init__.py +0 -0
  168. /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
  169. /mplang/{v2/libs → libs}/device/cluster.py +0 -0
  170. /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
  171. /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
  172. /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
  173. /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
  174. /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
  175. /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
  176. /mplang/{v2/runtime → runtime}/__init__.py +0 -0
  177. /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
  178. /mplang/{v2/runtime → runtime}/object_store.py +0 -0
  179. {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/WHEEL +0 -0
  180. {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/entry_points.txt +0 -0
  181. {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/licenses/LICENSE +0 -0
@@ -1,567 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """Diagram rendering (Mermaid) and markdown dump helpers.
16
-
17
- Moved from mplang.v1.utils.mermaid to dedicated analysis namespace.
18
- """
19
-
20
- from __future__ import annotations
21
-
22
- from dataclasses import dataclass
23
- from pathlib import Path
24
- from typing import TypedDict
25
-
26
- from mplang.v1.core import ClusterSpec, IrWriter, Mask, TracedFunction
27
- from mplang.v1.core.mpir import get_graph_statistics
28
- from mplang.v1.protos.v1alpha1 import mpir_pb2
29
-
30
- # ----------------------------- Core helpers (copied) -----------------------------
31
-
32
-
33
@dataclass
class Event:
    """One rendered step of a Mermaid sequence diagram.

    ``kind`` tags the step (``to_sequence_diagram`` emits "comm", "local" and
    "note") and ``lines`` holds the Mermaid source lines the step contributes.
    """

    kind: str  # category tag of the step
    lines: list[str]  # Mermaid lines, emitted verbatim and in order
37
-
38
-
39
- def _pmask_to_ranks(pmask: int) -> list[int]:
40
- return list(Mask(pmask)) if pmask >= 0 else []
41
-
42
-
43
- def _collect_world_size(
44
- graph: mpir_pb2.GraphProto, explicit_world_size: int | None
45
- ) -> int:
46
- if explicit_world_size is not None:
47
- return explicit_world_size
48
- max_rank = -1
49
- for node in graph.nodes:
50
- for out in node.outs_info:
51
- if out.pmask >= 0:
52
- for r in _pmask_to_ranks(out.pmask):
53
- max_rank = max(max_rank, r)
54
- return max_rank + 1 if max_rank >= 0 else 0
55
-
56
-
57
- def _node_output_pmasks(node: mpir_pb2.NodeProto) -> list[int]:
58
- return [out.pmask for out in node.outs_info]
59
-
60
-
61
- def _value_producers(graph: mpir_pb2.GraphProto) -> dict[str, mpir_pb2.NodeProto]:
62
- prod: dict[str, mpir_pb2.NodeProto] = {}
63
- for n in graph.nodes:
64
- base = n.name
65
- arity = max(1, len(n.outs_info))
66
- for i in range(arity):
67
- key = f"{base}" if arity == 1 else f"{base}:{i}"
68
- prod[key] = n
69
- return prod
70
-
71
-
72
- # ----------------------------- Option Types -----------------------------
73
-
74
-
75
class SequenceDiagramOptions(TypedDict, total=False):
    """Optional keyword arguments that ``dump`` forwards to ``to_sequence_diagram``.

    Every key may be omitted (``total=False``); missing keys fall back to the
    renderer's own defaults.
    """

    collapse_local: bool  # batch consecutive local ops into a single note
    show_compute: bool  # list local ops individually even when collapsing
    max_local_batch: int  # op names shown in a collapsed note before "+N more"
    show_meta: bool  # include structural ops such as tuple / func_def
80
-
81
-
82
class FlowchartOptions(TypedDict, total=False):
    """Optional settings that ``dump`` uses when calling ``to_flowchart``.

    Every key may be omitted (``total=False``). ``world_size`` is consulted
    by ``dump`` only when no cluster specification supplies one.
    """

    group_local: bool  # reserved; forwarded as-is
    world_size: int  # party-count override
    show_meta: bool  # include structural ops such as tuple / func_def
    direction: str  # Mermaid layout direction (LR, RL, TB/TD, BT)
    cross_edge_color: str  # stroke colour for cross-party edges
    cluster_by_party: bool  # wrap nodes in per-party subgraphs
    shared_cluster_name: str  # title of the multi-owner subgraph
90
-
91
-
92
@dataclass
class DumpResult:
    """Artifacts produced by ``dump``.

    Optional fields are ``None`` when the corresponding section was not
    requested; ``markdown`` always holds the assembled report.
    """

    ir: str | None  # compiler IR text (include_ir)
    stats: str  # graph statistics text
    sequence: str | None  # Mermaid sequence diagram (sequence=True)
    flow: str | None  # Mermaid flowchart (flow=True)
    markdown: str  # combined markdown report
99
-
100
-
101
- # ----------------------------- Public APIs -----------------------------
102
-
103
-
104
def to_sequence_diagram(
    graph: mpir_pb2.GraphProto,
    *,
    world_size: int | None = None,
    collapse_local: bool = False,
    show_compute: bool = False,
    max_local_batch: int = 8,
    show_meta: bool = False,
) -> str:
    """Render a MPIR graph as a Mermaid sequenceDiagram.

    Parameters:
        graph: The MPIR graph proto to visualize.
        world_size: Optional explicit number of parties (participants). If None
            it is inferred as one more than the highest rank found in any
            non-negative output pmask of the graph.
        collapse_local: If True, consecutive pure-local operations are batched
            into a single note summarizing their op names (up to
            max_local_batch shown).
        show_compute: When True, local operations are always listed
            individually, even if collapse_local is True.
        max_local_batch: Maximum number of individual local op names shown
            inside a collapsed note before summarizing with a "+N more"
            suffix. Negative values are treated as 0.
        show_meta: If True, include structural/meta ops (tuple, func_def).

    Returns:
        Mermaid sequenceDiagram text.

    Raises:
        ValueError: If a ``shfl_s`` node's destination pmask and its
            ``src_ranks`` attribute list a different number of ranks.
    """
    wsize = _collect_world_size(graph, world_size)
    value_producer = _value_producers(graph)
    # Participant span for notes that cover every party. With no known
    # participant, "P0,P-1" would name a nonexistent party, so use P0 alone.
    all_parties = f"P0,P{wsize - 1}" if wsize > 0 else "P0"
    batch_limit = max(0, max_local_batch)

    events: list[Event] = []
    # Pending local-only lines; flushed before every communication/control
    # event so the diagram keeps program order.
    local_buffer: list[str] = []

    def emit(kind: str, *lines: str) -> None:
        events.append(Event(kind=kind, lines=list(lines)))

    def flush_local_buffer() -> None:
        if not local_buffer:
            return
        if collapse_local and not show_compute:

            def _strip(s: str) -> str:
                # Keep only the description after the "<target>: " prefix.
                parts = s.split(": ", 1)
                return parts[1] if len(parts) == 2 else s

            sample = [_strip(s) for s in local_buffer[:batch_limit]]
            hidden = len(local_buffer) - batch_limit
            more = f" (+{hidden} more)" if hidden > 0 else ""
            emit("note", f"note over {all_parties}: {', '.join(sample)}{more}")
        else:
            for line in local_buffer:
                emit("local", line)
        local_buffer.clear()

    def owners_of(node: mpir_pb2.NodeProto) -> set[int]:
        ranks: set[int] = set()
        for pm in _node_output_pmasks(node):
            if pm >= 0:
                ranks.update(_pmask_to_ranks(pm))
        return ranks

    def producer_of(val: str) -> mpir_pb2.NodeProto | None:
        # Multi-output producers are registered as "name:i" and single-output
        # ones as "name": try the exact reference first, then the base name.
        prod = value_producer.get(val)
        if prod is None:
            prod = value_producer.get(val.split(":", 1)[0])
        return prod

    for node in graph.nodes:
        op_type = node.op_type
        # "access" nodes are never drawn; consumers see through them via the
        # access node's own output pmasks.
        if op_type == "access":
            continue
        if not show_meta and op_type in {"tuple", "func_def"}:
            continue

        if op_type == "shfl_s":
            flush_local_buffer()
            pmask_attr = node.attrs.get("pmask")
            src_ranks_attr = node.attrs.get("src_ranks")
            if pmask_attr and src_ranks_attr:
                dst_ranks = _pmask_to_ranks(pmask_attr.i)
                # strict=True: each destination must pair with exactly one source.
                for dst, src in zip(dst_ranks, src_ranks_attr.ints, strict=True):
                    emit("comm", f"P{src}->>P{dst}: {node.name}")
            else:
                emit("comm", f"note over {all_parties}: send {node.name}")
            continue

        if op_type in {"cond", "while"}:
            flush_local_buffer()
            emit("note", f"note over {all_parties}: {op_type} {node.name}")
            continue

        pfunc = node.attrs.get("pfunc") if op_type == "eval" else None
        if pfunc:
            fn_name = pfunc.func.name or pfunc.func.type or "eval"
            label = f"{fn_name} {node.name}"
        else:
            label = f"{op_type} {node.name}"
        n_outs = len(node.outs_info)
        desc = f"{label} -> {n_outs} outs" if n_outs > 1 else label

        own = owners_of(node)
        input_ranks: set[int] = set()
        for val in node.inputs:
            prod = producer_of(val)
            if prod is not None:
                input_ranks |= owners_of(prod)

        # Data crosses party boundaries when producers and consumer differ.
        if own and input_ranks and own != input_ranks:
            flush_local_buffer()
            for s in sorted(input_ranks):
                for t in sorted(own):
                    if s != t:
                        emit("comm", f"P{s}->>P{t}: {desc}")
            continue

        local_desc = desc.strip()
        if not local_desc:
            continue
        if own:
            repr_rank = min(own)
            local_buffer.append(f"P{repr_rank}-->>P{repr_rank}: {local_desc}")
        else:
            # No statically known owner: attribute to every party.
            local_buffer.append(f"note over {all_parties}: {local_desc} (dyn)")

    flush_local_buffer()

    out_lines: list[str] = ["sequenceDiagram"]
    out_lines.extend(f"participant P{i}" for i in range(wsize))
    for ev in events:
        out_lines.extend(ev.lines)
    return "\n".join(out_lines)
236
-
237
-
238
def to_flowchart(
    graph: mpir_pb2.GraphProto,
    *,
    group_local: bool = True,
    world_size: int | None = None,
    show_meta: bool = False,
    direction: str = "LR",
    cross_edge_color: str = "#ff6a00",
    cluster_by_party: bool = False,
    shared_cluster_name: str = "Shared",
) -> str:
    """Render a MPIR graph as a Mermaid flowchart (DAG view).

    Parameters:
        graph: The MPIR graph proto to visualize.
        group_local: (Reserved) placeholder for future local grouping in the
            non-cluster view; currently ignored.
        world_size: Optional explicit party count. Accepted for API symmetry
            with to_sequence_diagram. Per-party clusters are emitted for every
            rank that owns at least one node, so nodes are never dropped even
            if this value is too small.
        show_meta: Include meta/structural nodes (tuple, func_def) when True.
        direction: Mermaid layout direction (LR, RL, TB, BT). TD is accepted
            as a synonym for TB; anything else falls back to LR.
        cross_edge_color: CSS color used to highlight cross-party data edges.
        cluster_by_party: If True, wrap single-owner nodes in per-party
            subgraphs and all other nodes in a shared cluster.
        shared_cluster_name: Title for the shared subgraph cluster when
            cluster_by_party=True.

    Returns:
        Mermaid flowchart text.
    """
    _ = (group_local, world_size)
    meta_ops = {"tuple", "func_def"}
    # Same value-name -> producer index used by the sequence diagram.
    value_to_node = _value_producers(graph)

    def owners_of(node: mpir_pb2.NodeProto) -> set[int]:
        rs: set[int] = set()
        for out in node.outs_info:
            if out.pmask >= 0:
                rs.update(_pmask_to_ranks(out.pmask))
        return rs

    def is_rendered(node: mpir_pb2.NodeProto) -> bool:
        # "access" nodes are folded into their consumers' edges; meta ops are
        # hidden unless show_meta is set.
        if node.op_type == "access":
            return False
        return show_meta or node.op_type not in meta_ops

    node_labels: list[str] = []
    per_party_nodes: dict[int, list[str]] = {}
    shared_nodes: list[str] = []
    node_id_map: dict[str, str] = {}
    id_to_owners: dict[str, set[int]] = {}
    for n in graph.nodes:
        if not is_rendered(n):
            continue
        # Drop the first character of the node name to build the Mermaid id
        # (presumably a sigil such as "%"; NOTE(review): confirm naming scheme).
        node_id = f"n{n.name[1:]}"
        node_id_map[n.name] = node_id
        pfunc = n.attrs.get("pfunc") if n.op_type == "eval" else None
        if pfunc:
            op_label = pfunc.func.name or pfunc.func.type or n.op_type
        else:
            op_label = n.op_type
        arity = len(n.outs_info)
        arity_suffix = f"/{arity}" if arity > 1 else ""
        owners = owners_of(n)
        owners_str = (
            (" @" + ",".join(f"P{r}" for r in sorted(owners))) if owners else ""
        )
        # Escape double quotes so a function name cannot terminate the quoted
        # Mermaid label early (Mermaid entity code #quot;).
        text = f"{op_label}{arity_suffix}{owners_str}".replace('"', "#quot;")
        label_line = f'{node_id}["{text}"]'
        if not cluster_by_party:
            node_labels.append(label_line)
        elif len(owners) == 1:
            (owner_rank,) = owners
            per_party_nodes.setdefault(owner_rank, []).append(label_line)
        else:
            shared_nodes.append(label_line)
        id_to_owners[node_id] = owners

    def resolve_sources(val: str, seen: set[str] | None = None) -> set[str]:
        """Map a value reference to the names of the real producing nodes,
        looking through "access" nodes (with cycle protection via ``seen``)."""
        if seen is None:
            seen = set()
        base = val.split(":")[0]
        node = value_to_node.get(val) or value_to_node.get(base)
        if not node:
            return set()
        if node.name in seen:
            return set()
        seen.add(node.name)
        if node.op_type == "access":
            srcs: set[str] = set()
            for upstream in node.inputs:
                srcs |= resolve_sources(upstream, seen)
            return srcs
        return {node.name}

    edge_set: set[tuple[str, str]] = set()
    for n in graph.nodes:
        if not is_rendered(n):
            continue
        dst_id = node_id_map.get(n.name)
        if dst_id is None:
            continue
        for inp in n.inputs:
            for src_name in resolve_sources(inp):
                # Sources that were not rendered (e.g. hidden meta ops) have
                # no id and therefore produce no edge.
                src_id = node_id_map.get(src_name)
                if src_id is None or src_id == dst_id:
                    continue
                edge_set.add((src_id, dst_id))

    ordered_edges = sorted(edge_set)
    edges = [f"{s} --> {t}" for s, t in ordered_edges]
    # linkStyle indices follow edge emission order, i.e. ordered_edges order.
    cross_indices: list[int] = []
    for idx, (s, t) in enumerate(ordered_edges):
        so = id_to_owners.get(s, set())
        to = id_to_owners.get(t, set())
        if so and to and so != to:
            cross_indices.append(idx)

    dir_norm = direction.upper()
    if dir_norm == "TD":
        dir_norm = "TB"
    if dir_norm not in {"LR", "TB", "RL", "BT"}:
        dir_norm = "LR"

    result_lines = [f"graph {dir_norm};", ""]
    if cluster_by_party:
        # Iterate the ranks that actually own nodes, in ascending order.
        # Bounding this by world_size would silently drop node definitions for
        # higher ranks while edges still referenced them.
        for rank in sorted(per_party_nodes):
            result_lines.append(f" subgraph P{rank}")
            result_lines.extend(f" {ln}" for ln in per_party_nodes[rank])
            result_lines.append(" end")
            result_lines.append("")
        if shared_nodes:
            result_lines.append(f" subgraph {shared_cluster_name}")
            result_lines.extend(f" {ln}" for ln in shared_nodes)
            result_lines.append(" end")
            result_lines.append("")
    else:
        result_lines.extend(" " + lbl for lbl in node_labels if lbl)
        if node_labels:
            result_lines.append("")
    result_lines.extend(" " + edge for edge in edges)
    for ci in cross_indices:
        result_lines.append(
            f" linkStyle {ci} stroke:{cross_edge_color},stroke-width:2px;"
        )
    return "\n".join(result_lines)
400
-
401
-
402
- # ----------------------------- Markdown dump -----------------------------
403
-
404
-
405
def dump(
    traced: TracedFunction,
    *,
    cluster_spec: ClusterSpec | None = None,
    sequence: bool = True,
    flow: bool = True,
    include_ir: bool = True,
    report_path: str | Path | None = None,
    mpir_path: str | Path | None = None,
    title: str | None = None,
    seq_opts: SequenceDiagramOptions | None = None,
    flow_opts: FlowchartOptions | None = None,
) -> DumpResult:
    """Generate a composite analysis report (markdown + structured fields).

    Sections (conditionally) included in the markdown:
        - Title (if provided)
        - Cluster Specification (if cluster_spec provided)
        - Compiler IR (if include_ir)
        - Graph Structure Analysis (always)
        - Mermaid Sequence Diagram (if sequence=True)
        - Mermaid Flowchart (if flow=True)

    Parameters:
        traced: TracedFunction object produced by the compilation pipeline.
        cluster_spec: Optional cluster topology. When provided, the world size
            (number of physical nodes) and a JSON summary block are derived
            from it.
        sequence: Whether to render a sequence diagram section.
        flow: Whether to render a flowchart diagram section.
        include_ir: Include textual compiler IR section when True.
        report_path: If set (non-empty), write the assembled markdown here.
        mpir_path: If set (non-empty), write the raw MPIR proto text here.
        title: Optional top-level markdown title.
        seq_opts: Options forwarded to to_sequence_diagram.
        flow_opts: Options controlling flowchart rendering. Its world_size is
            used only when cluster_spec does not provide one.

    Returns:
        DumpResult containing individual textual artifacts and the combined
        markdown.

    Raises:
        ValueError: If neither report_path nor mpir_path is a non-empty path.
    """
    import json

    # Outputs are written only for truthy paths, so validate with the same
    # test; otherwise an empty-string path passes validation yet nothing is
    # written.
    if not report_path and not mpir_path:
        raise ValueError(
            "dump() requires at least one output path: report_path for markdown or mpir_path for raw IR"
        )

    # Build the graph once; every section below reuses it.
    graph_proto = IrWriter().dumps(traced.make_expr())

    # World size is the number of physical nodes (ranks) in the cluster.
    derived_world_size: int | None = (
        len(cluster_spec.nodes) if cluster_spec is not None else None
    )

    parts: list[str] = []
    if title:
        parts.append(f"# {title}\n")

    if cluster_spec is not None:
        spec_summary = {
            "nodes": [
                {"name": n.name, "rank": n.rank, "endpoint": n.endpoint}
                for n in sorted(cluster_spec.nodes.values(), key=lambda x: x.rank)
            ],
            "devices": {
                name: {"kind": dev.kind, "members": [m.name for m in dev.members]}
                for name, dev in sorted(cluster_spec.devices.items())
            },
        }
        parts.extend(
            [
                "## Cluster Specification\n",
                "```json",
                json.dumps(spec_summary, indent=2),
                "```",
            ]
        )

    ir_text: str | None = None
    if include_ir:
        ir_text = traced.compiler_ir()
        parts.extend(["## Compiler IR (text)\n", "```", ir_text, "```"])

    stats = get_graph_statistics(graph_proto)
    parts.extend(["## Graph Structure Analysis\n", "```", stats, "```"])

    seq_text: str | None = None
    if sequence:
        seq_text = to_sequence_diagram(
            graph_proto,
            world_size=derived_world_size,
            **(seq_opts or {}),
        )
        parts.extend(["## Mermaid Sequence Diagram", "```mermaid", seq_text, "```"])

    flow_text: str | None = None
    if flow:
        fopts = flow_opts or {}
        effective_world_size = derived_world_size
        if effective_world_size is None:
            effective_world_size = fopts.get("world_size")
        flow_text = to_flowchart(
            graph_proto,
            world_size=effective_world_size,
            group_local=fopts.get("group_local", True),
            show_meta=fopts.get("show_meta", False),
            direction=fopts.get("direction", "LR"),
            cross_edge_color=fopts.get("cross_edge_color", "#ff6a00"),
            cluster_by_party=fopts.get("cluster_by_party", False),
            shared_cluster_name=fopts.get("shared_cluster_name", "Shared"),
        )
        parts.extend(["## Mermaid Flowchart (DAG)", "```mermaid", flow_text, "```"])

    markdown = "\n\n".join(parts) + "\n"

    if mpir_path:
        Path(mpir_path).write_text(str(graph_proto), encoding="utf-8")
    if report_path:
        Path(report_path).write_text(markdown, encoding="utf-8")

    return DumpResult(
        ir=ir_text,
        stats=stats,
        sequence=seq_text,
        flow=flow_text,
        markdown=markdown,
    )
558
-
559
-
560
# Public API: result/option types first, then the rendering entry points.
__all__ = [
    "DumpResult", "FlowchartOptions", "SequenceDiagramOptions",
    "dump", "to_flowchart", "to_sequence_diagram",
]