mplang-nightly 0.1.dev139__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75) hide show
  1. mplang/__init__.py +70 -0
  2. mplang/analysis/__init__.py +37 -0
  3. mplang/analysis/diagram.py +569 -0
  4. mplang/api.py +98 -0
  5. mplang/backend/__init__.py +20 -0
  6. mplang/backend/base.py +287 -0
  7. mplang/backend/builtin.py +212 -0
  8. mplang/backend/crypto.py +114 -0
  9. mplang/backend/phe.py +301 -0
  10. mplang/backend/spu.py +257 -0
  11. mplang/backend/sql_duckdb.py +42 -0
  12. mplang/backend/stablehlo.py +73 -0
  13. mplang/backend/tee.py +59 -0
  14. mplang/core/__init__.py +92 -0
  15. mplang/core/cluster.py +258 -0
  16. mplang/core/comm.py +277 -0
  17. mplang/core/context_mgr.py +50 -0
  18. mplang/core/dtype.py +293 -0
  19. mplang/core/expr/__init__.py +80 -0
  20. mplang/core/expr/ast.py +543 -0
  21. mplang/core/expr/evaluator.py +523 -0
  22. mplang/core/expr/printer.py +279 -0
  23. mplang/core/expr/transformer.py +141 -0
  24. mplang/core/expr/utils.py +78 -0
  25. mplang/core/expr/visitor.py +85 -0
  26. mplang/core/expr/walk.py +387 -0
  27. mplang/core/interp.py +160 -0
  28. mplang/core/mask.py +325 -0
  29. mplang/core/mpir.py +958 -0
  30. mplang/core/mpobject.py +117 -0
  31. mplang/core/mptype.py +438 -0
  32. mplang/core/pfunc.py +130 -0
  33. mplang/core/primitive.py +942 -0
  34. mplang/core/table.py +190 -0
  35. mplang/core/tensor.py +75 -0
  36. mplang/core/tracer.py +383 -0
  37. mplang/device.py +326 -0
  38. mplang/frontend/__init__.py +46 -0
  39. mplang/frontend/base.py +427 -0
  40. mplang/frontend/builtin.py +205 -0
  41. mplang/frontend/crypto.py +109 -0
  42. mplang/frontend/ibis_cc.py +137 -0
  43. mplang/frontend/jax_cc.py +150 -0
  44. mplang/frontend/phe.py +67 -0
  45. mplang/frontend/spu.py +152 -0
  46. mplang/frontend/sql.py +60 -0
  47. mplang/frontend/tee.py +42 -0
  48. mplang/protos/v1alpha1/mpir_pb2.py +63 -0
  49. mplang/protos/v1alpha1/mpir_pb2.pyi +507 -0
  50. mplang/protos/v1alpha1/mpir_pb2_grpc.py +3 -0
  51. mplang/runtime/__init__.py +32 -0
  52. mplang/runtime/cli.py +436 -0
  53. mplang/runtime/client.py +405 -0
  54. mplang/runtime/communicator.py +87 -0
  55. mplang/runtime/driver.py +311 -0
  56. mplang/runtime/exceptions.py +27 -0
  57. mplang/runtime/http_api.md +56 -0
  58. mplang/runtime/link_comm.py +131 -0
  59. mplang/runtime/resource.py +338 -0
  60. mplang/runtime/server.py +326 -0
  61. mplang/runtime/simulation.py +297 -0
  62. mplang/simp/__init__.py +351 -0
  63. mplang/simp/mpi.py +132 -0
  64. mplang/simp/random.py +121 -0
  65. mplang/simp/smpc.py +201 -0
  66. mplang/utils/__init__.py +13 -0
  67. mplang/utils/crypto.py +32 -0
  68. mplang/utils/func_utils.py +147 -0
  69. mplang/utils/spu_utils.py +130 -0
  70. mplang/utils/table_utils.py +73 -0
  71. mplang_nightly-0.1.dev139.dist-info/METADATA +290 -0
  72. mplang_nightly-0.1.dev139.dist-info/RECORD +75 -0
  73. mplang_nightly-0.1.dev139.dist-info/WHEEL +4 -0
  74. mplang_nightly-0.1.dev139.dist-info/entry_points.txt +2 -0
  75. mplang_nightly-0.1.dev139.dist-info/licenses/LICENSE +201 -0
mplang/__init__.py ADDED
@@ -0,0 +1,70 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Multi-Party Programming Language for Secure Computation."""
16
+
17
# Version is managed by hatch-vcs and available after package installation.
# The same package is published under more than one distribution name (the
# nightly wheel is "mplang-nightly"), so look each one up before falling back.
from importlib.metadata import PackageNotFoundError, version

__version__ = "0.0.0-dev"  # fallback for development checkouts that are not installed
for _dist_name in ("mplang", "mplang-nightly"):
    try:
        __version__ = version(_dist_name)
    except PackageNotFoundError:
        continue
    break
del _dist_name
25
+
26
+ from mplang import analysis
27
+ from mplang.api import CompileOptions, compile, evaluate, fetch
28
+ from mplang.core import (
29
+ DType,
30
+ InterpContext,
31
+ Mask,
32
+ MPContext,
33
+ MPObject,
34
+ MPType,
35
+ TableType,
36
+ TensorType,
37
+ function,
38
+ )
39
+ from mplang.core.cluster import ClusterSpec, Device, Node, RuntimeInfo
40
+ from mplang.core.context_mgr import cur_ctx, set_ctx, with_ctx
41
+ from mplang.runtime.driver import Driver
42
+ from mplang.runtime.simulation import Simulator
43
+
44
+ # Public API
45
# Public API re-exported at package level. Entries are kept in plain ASCII
# sort order: capitalised names first, then the dunder, then lowercase names.
__all__ = [
    "ClusterSpec", "CompileOptions", "DType", "Device", "Driver",
    "InterpContext", "MPContext", "MPObject", "MPType", "Mask",
    "Node", "RuntimeInfo", "Simulator", "TableType", "TensorType",
    "__version__",
    "analysis", "compile", "cur_ctx", "evaluate", "fetch",
    "function", "set_ctx", "with_ctx",
]
@@ -0,0 +1,37 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Analysis and visualization utilities for mplang.
16
+
17
+ This subpackage hosts non-core developer aids: diagram rendering, IR dumps,
18
+ profiling helpers (future), etc.
19
+ """
20
+
21
+ from mplang.analysis.diagram import (
22
+ DumpResult,
23
+ FlowchartOptions,
24
+ SequenceDiagramOptions,
25
+ dump,
26
+ to_flowchart,
27
+ to_sequence_diagram,
28
+ )
29
+
30
# Names re-exported from mplang.analysis.diagram (sorted, types first).
__all__ = [
    "DumpResult", "FlowchartOptions", "SequenceDiagramOptions",
    "dump", "to_flowchart", "to_sequence_diagram",
]
@@ -0,0 +1,569 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Diagram rendering (Mermaid) and markdown dump helpers.
16
+
17
+ Moved from mplang.utils.mermaid to dedicated analysis namespace.
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ from dataclasses import dataclass
23
+ from pathlib import Path
24
+ from typing import TypedDict
25
+
26
+ from mplang.core import TracedFunction
27
+ from mplang.core.cluster import ClusterSpec
28
+ from mplang.core.mask import Mask
29
+ from mplang.core.mpir import Writer, get_graph_statistics
30
+ from mplang.protos.v1alpha1 import mpir_pb2
31
+
32
+ # ----------------------------- Core helpers (copied) -----------------------------
33
+
34
+
35
@dataclass
class Event:
    """One rendered sequence-diagram event.

    ``kind`` is a free-form tag chosen by the emitter (the sequence renderer
    in this module uses "note", "local" and "comm"); ``lines`` holds the
    Mermaid lines written out verbatim, in order.
    """

    kind: str  # category tag of the event
    lines: list[str]  # raw Mermaid lines belonging to this event
39
+
40
+
41
def _pmask_to_ranks(pmask: int) -> list[int]:
    """Expand a party bitmask into a list of ranks.

    A negative ``pmask`` is treated as "no known owners" and yields an empty
    list; otherwise the result is whatever ``Mask(pmask)`` iterates over.
    """
    if pmask < 0:
        return []
    return list(Mask(pmask))
43
+
44
+
45
def _collect_world_size(
    graph: mpir_pb2.GraphProto, explicit_world_size: int | None
) -> int:
    """Return the number of parties to draw for ``graph``.

    An explicit size always wins. Otherwise the size is one more than the
    highest rank appearing in any non-negative output pmask, or 0 when no
    output carries a non-negative pmask.
    """
    if explicit_world_size is not None:
        return explicit_world_size
    highest = max(
        (
            rank
            for node in graph.nodes
            for out in node.outs_info
            if out.pmask >= 0
            for rank in _pmask_to_ranks(out.pmask)
        ),
        default=-1,
    )
    return highest + 1 if highest >= 0 else 0
57
+
58
+
59
def _node_output_pmasks(node: mpir_pb2.NodeProto) -> list[int]:
    """Return the ``pmask`` of each output of ``node``, in output order."""
    return [info.pmask for info in node.outs_info]
61
+
62
+
63
def _value_producers(graph: mpir_pb2.GraphProto) -> dict[str, mpir_pb2.NodeProto]:
    """Index graph nodes by the value references they produce.

    A node with zero or one output is registered under its bare name; a node
    with ``k > 1`` outputs is registered under ``"name:0"`` .. ``"name:k-1"``.
    On a name clash the later node overwrites the earlier one.
    """
    index: dict[str, mpir_pb2.NodeProto] = {}
    for node in graph.nodes:
        n_outs = len(node.outs_info)
        if n_outs <= 1:
            index[f"{node.name}"] = node
        else:
            index.update((f"{node.name}:{i}", node) for i in range(n_outs))
    return index
72
+
73
+
74
+ # ----------------------------- Option Types -----------------------------
75
+
76
+
77
class SequenceDiagramOptions(TypedDict, total=False):
    """Optional keyword settings for the sequence-diagram section of ``dump``.

    ``dump`` splats this dict into ``to_sequence_diagram``, so every key is
    that function's same-named keyword argument and omitted keys use its
    defaults.
    """

    collapse_local: bool  # batch consecutive local ops into one note
    show_compute: bool  # list local ops individually even when collapsing
    max_local_batch: int  # op names shown per collapsed note before "+N more"
    show_meta: bool  # include tuple/func_def nodes
82
+
83
+
84
class FlowchartOptions(TypedDict, total=False):
    """Optional keyword settings for the flowchart section of ``dump``.

    Each key mirrors the same-named parameter of ``to_flowchart``; ``dump``
    substitutes that function's default for any omitted key. ``world_size``
    is only consulted when ``dump`` is not given a ``cluster_spec``.
    """

    group_local: bool  # reserved for future local grouping
    world_size: int  # explicit party count
    show_meta: bool  # include tuple/func_def nodes
    direction: str  # Mermaid layout direction (LR, RL, TB/TD, BT)
    cross_edge_color: str  # stroke color used for cross-party edges
    cluster_by_party: bool  # group nodes into per-party subgraphs
    shared_cluster_name: str  # title of the shared (multi-owner) subgraph
92
+
93
+
94
@dataclass
class DumpResult:
    """Artifacts produced by ``dump``.

    Fields for a section that was disabled are ``None``.
    """

    ir: str | None  # compiler IR text, None when include_ir is False
    stats: str  # graph statistics text (always produced)
    sequence: str | None  # Mermaid sequence diagram, None when sequence is False
    flow: str | None  # Mermaid flowchart, None when flow is False
    markdown: str  # the fully assembled markdown report
101
+
102
+
103
+ # ----------------------------- Public APIs -----------------------------
104
+
105
+
106
def to_sequence_diagram(
    graph: mpir_pb2.GraphProto,
    *,
    world_size: int | None = None,
    collapse_local: bool = False,
    show_compute: bool = False,
    max_local_batch: int = 8,
    show_meta: bool = False,
) -> str:
    """Render a MPIR graph as a Mermaid sequenceDiagram.

    Nodes are visited in graph order and turned into events:

    * ``access`` nodes are skipped, as are ``tuple``/``func_def`` unless
      ``show_meta`` is set.
    * ``shfl_s`` nodes become one arrow per (src, dst) pair taken from their
      ``src_ranks``/``pmask`` attributes, or a spanning note if either
      attribute is missing.
    * ``cond``/``while`` nodes become a spanning note.
    * When both the node's owners and its inputs' owners are non-empty and
      differ, the node becomes arrows from every input owner ``s`` to every
      output owner ``t`` with ``s != t``.
    * Anything else is a local step; local steps are buffered so that runs of
      them can optionally be collapsed into one note.

    Parameters:
        graph: The MPIR graph proto to visualize.
        world_size: Optional explicit number of parties (participants). If None it
            is inferred from the union of output pmasks in the graph.
        collapse_local: If True, consecutive pure-local operations are batched into
            a single note summarizing their op names (up to max_local_batch shown).
        show_compute: When True and collapse_local is also True, local operations
            are still listed individually (disable batching display logic).
        max_local_batch: Maximum number of individual local op names to show inside
            a collapsed note before summarizing with a "+N more" suffix.
        show_meta: If True, include structural/meta ops (e.g. tuple, func_def).

    Returns:
        Mermaid sequenceDiagram text.
    """
    wsize = _collect_world_size(graph, world_size)
    value_producer = _value_producers(graph)
    # Participant span for notes covering all parties. Clamped so that an
    # empty or fully dynamic graph (wsize == 0) renders "P0,P0" instead of
    # referencing the invalid participant "P-1".
    all_parties = f"P0,P{max(wsize - 1, 0)}"
    meta_ops = {"tuple", "func_def"}

    participants = [f"participant P{i}" for i in range(wsize)]
    events: list[Event] = []
    local_buffer: list[str] = []

    def emit(kind: str, *lines: str) -> None:
        events.append(Event(kind=kind, lines=list(lines)))

    def flush_local_buffer() -> None:
        # Emit pending local steps, either one by one or as a summary note.
        if not local_buffer:
            return
        if collapse_local and not show_compute:
            # Keep only the text after "<prefix>: " so the note lists op labels.
            sample = [
                s.split(": ", 1)[1] if ": " in s else s
                for s in local_buffer[:max_local_batch]
            ]
            hidden = len(local_buffer) - max_local_batch
            more = f" (+{hidden} more)" if hidden > 0 else ""
            emit("note", f"note over {all_parties}: {', '.join(sample)}{more}")
        else:
            for line in local_buffer:
                emit("local", line)
        local_buffer.clear()

    def owners_of(node: mpir_pb2.NodeProto) -> set[int]:
        ranks: set[int] = set()
        for pm in _node_output_pmasks(node):
            if pm >= 0:
                ranks.update(_pmask_to_ranks(pm))
        return ranks

    def producer_of(ref: str) -> mpir_pb2.NodeProto | None:
        # Outputs of multi-output nodes are indexed as "name:i" while
        # single-output nodes are indexed by bare name, so try the exact
        # reference first and only then its base name.
        prod = value_producer.get(ref)
        if prod is None:
            prod = value_producer.get(ref.split(":")[0])
        return prod

    for node in graph.nodes:
        op = node.op_type
        if op == "access" or (not show_meta and op in meta_ops):
            continue

        if op == "shfl_s":
            flush_local_buffer()
            pmask_attr = node.attrs.get("pmask")
            src_ranks_attr = node.attrs.get("src_ranks")
            if pmask_attr and src_ranks_attr:
                dst_ranks = _pmask_to_ranks(pmask_attr.i)
                src_ranks = list(src_ranks_attr.ints)
                # strict=True: a src/dst length mismatch is a malformed node.
                for dst, src in zip(dst_ranks, src_ranks, strict=True):
                    emit("comm", f"P{src}->>P{dst}: {node.name}")
            else:
                emit("comm", f"note over {all_parties}: send {node.name}")
            continue

        if op in {"cond", "while"}:
            flush_local_buffer()
            emit("note", f"note over {all_parties}: {op} {node.name}")
            continue

        pfunc = node.attrs.get("pfunc") if op == "eval" else None
        if pfunc:
            fn_name = pfunc.func.name or pfunc.func.type or "eval"
            label = f"{fn_name} {node.name}"
        else:
            label = f"{op} {node.name}"
        n_outs = len(node.outs_info)
        value_suffix = f" -> {n_outs} outs" if n_outs > 1 else ""

        own = owners_of(node)
        input_ranks: set[int] = set()
        for ref in node.inputs:
            prod = producer_of(ref)
            if prod is not None:
                input_ranks |= owners_of(prod)

        if own and input_ranks and own != input_ranks:
            # Data crosses party boundaries: one arrow per (src, dst) pair.
            flush_local_buffer()
            for s in sorted(input_ranks):
                for t in sorted(own):
                    if s != t:
                        emit("comm", f"P{s}->>P{t}: {label}{value_suffix}")
            continue

        local_desc = f"{label}{value_suffix}".strip()
        if not local_desc:
            continue
        if own:
            repr_rank = min(own)
            local_buffer.append(f"P{repr_rank}-->>P{repr_rank}: {local_desc}")
        else:
            # No statically known owner (all output pmasks negative, or no outputs).
            local_buffer.append(f"note over {all_parties}: {local_desc} (dyn)")

    flush_local_buffer()

    out_lines = ["sequenceDiagram", *participants]
    for ev in events:
        out_lines.extend(ev.lines)
    return "\n".join(out_lines)
238
+
239
+
240
def to_flowchart(
    graph: mpir_pb2.GraphProto,
    *,
    group_local: bool = True,
    world_size: int | None = None,
    show_meta: bool = False,
    direction: str = "LR",
    cross_edge_color: str = "#ff6a00",
    cluster_by_party: bool = False,
    shared_cluster_name: str = "Shared",
) -> str:
    """Render a MPIR graph as a Mermaid flowchart (DAG view).

    ``access`` nodes are never drawn; edges are traced through them to the
    underlying producer. Structural nodes (``tuple``, ``func_def``) are hidden
    unless ``show_meta`` is set. Edges whose endpoints have different,
    non-empty owner sets are highlighted with ``cross_edge_color``.

    Parameters:
        graph: The MPIR graph proto to visualize.
        group_local: (Reserved) placeholder for future local grouping in non-cluster view.
        world_size: Optional explicit party count. Kept for API compatibility:
            with ``cluster_by_party`` every party that owns a node gets its own
            subgraph, so a value smaller than the highest owner rank no longer
            drops those nodes from the diagram.
        show_meta: Include meta/structural nodes when True.
        direction: Mermaid layout direction (LR, RL, TB, BT). Accepts TD as a
            synonym for TB; any other value falls back to LR.
        cross_edge_color: CSS color used to highlight cross-party data edges.
        cluster_by_party: If True, wrap single-owner nodes in per-party
            subgraphs and all remaining nodes in a shared cluster.
        shared_cluster_name: Title for the shared subgraph cluster when cluster_by_party=True.

    Returns:
        Mermaid flowchart text.
    """
    meta_ops = {"tuple", "func_def"}

    def is_hidden(node: mpir_pb2.NodeProto) -> bool:
        # access nodes are folded into their producers; meta nodes are opt-in.
        return node.op_type == "access" or (
            not show_meta and node.op_type in meta_ops
        )

    def owners_of(node: mpir_pb2.NodeProto) -> set[int]:
        owners: set[int] = set()
        for pm in _node_output_pmasks(node):
            if pm >= 0:
                owners.update(_pmask_to_ranks(pm))
        return owners

    # Same "name" / "name:i" value index as the sequence-diagram renderer.
    value_to_node = _value_producers(graph)
    node_map: dict[str, mpir_pb2.NodeProto] = {n.name: n for n in graph.nodes}

    node_labels: list[str] = []
    per_party_nodes: dict[int, list[str]] = {}
    shared_nodes: list[str] = []
    node_id_map: dict[str, str] = {}
    id_to_owners: dict[str, set[int]] = {}
    for n in graph.nodes:
        if is_hidden(n):
            continue
        # Drop the name's first character (presumably a sigil such as "%")
        # and prefix "n" to form the Mermaid node id.
        node_id = f"n{n.name[1:]}"
        node_id_map[n.name] = node_id
        pfunc = n.attrs.get("pfunc") if n.op_type == "eval" else None
        if pfunc:
            op_label = pfunc.func.name or pfunc.func.type or n.op_type
        else:
            op_label = n.op_type
        arity = len(n.outs_info)
        arity_suffix = f"/{arity}" if arity > 1 else ""
        owners = owners_of(n)
        owners_str = (
            (" @" + ",".join(f"P{r}" for r in sorted(owners))) if owners else ""
        )
        label_line = f'{node_id}["{op_label}{arity_suffix}{owners_str}"]'
        id_to_owners[node_id] = owners
        if not cluster_by_party:
            node_labels.append(label_line)
        elif len(owners) == 1:
            (owner_rank,) = owners
            per_party_nodes.setdefault(owner_rank, []).append(label_line)
        else:
            shared_nodes.append(label_line)

    def resolve_sources(val: str, seen: set[str]) -> set[str]:
        # Map a value reference to the drawn node(s) producing it, looking
        # through access nodes; ``seen`` guards against reference cycles.
        node = value_to_node.get(val) or value_to_node.get(val.split(":")[0])
        if not node or node.name in seen:
            return set()
        seen.add(node.name)
        if node.op_type != "access":
            return {node.name}
        srcs: set[str] = set()
        for upstream in node.inputs:
            srcs |= resolve_sources(upstream, seen)
        return srcs

    edge_set: set[tuple[str, str]] = set()
    for n in graph.nodes:
        if is_hidden(n):
            continue
        dst_id = node_id_map.get(n.name)
        if not dst_id:
            continue
        for inp in n.inputs:
            for src_name in resolve_sources(inp, set()):
                if not show_meta and node_map[src_name].op_type in meta_ops:
                    continue
                src_id = node_id_map.get(src_name)
                if src_id and src_id != dst_id:
                    edge_set.add((src_id, dst_id))

    # linkStyle addresses links by definition order, so the indices below
    # must be computed over the same ordering the edges are emitted in.
    ordered_edges = sorted(edge_set)
    cross_indices: list[int] = []
    for idx, (s, t) in enumerate(ordered_edges):
        so = id_to_owners.get(s, set())
        to = id_to_owners.get(t, set())
        if so and to and so != to:
            cross_indices.append(idx)

    _ = group_local, world_size  # reserved / compatibility-only parameters

    dir_norm = direction.upper()
    if dir_norm == "TD":
        dir_norm = "TB"
    if dir_norm not in {"LR", "TB", "RL", "BT"}:
        dir_norm = "LR"

    result_lines = [f"graph {dir_norm};", ""]
    if cluster_by_party:
        # Iterate every owning rank so no labeled node is dropped.
        for rank in sorted(per_party_nodes):
            result_lines.append(f" subgraph P{rank}")
            result_lines.extend(f" {ln}" for ln in per_party_nodes[rank])
            result_lines.append(" end")
            result_lines.append("")
        if shared_nodes:
            result_lines.append(f" subgraph {shared_cluster_name}")
            result_lines.extend(f" {ln}" for ln in shared_nodes)
            result_lines.append(" end")
            result_lines.append("")
    else:
        result_lines.extend(" " + lbl for lbl in node_labels if lbl)
        if node_labels:
            result_lines.append("")
    result_lines.extend(f" {s} --> {t}" for s, t in ordered_edges)
    result_lines.extend(
        f" linkStyle {ci} stroke:{cross_edge_color},stroke-width:2px;"
        for ci in cross_indices
    )
    return "\n".join(result_lines)
402
+
403
+
404
+ # ----------------------------- Markdown dump -----------------------------
405
+
406
+
407
def dump(
    traced: TracedFunction,
    *,
    cluster_spec: ClusterSpec | None = None,
    sequence: bool = True,
    flow: bool = True,
    include_ir: bool = True,
    report_path: str | Path | None = None,
    mpir_path: str | Path | None = None,
    title: str | None = None,
    seq_opts: SequenceDiagramOptions | None = None,
    flow_opts: FlowchartOptions | None = None,
) -> DumpResult:
    """Build an analysis report for ``traced`` and write it to disk.

    Markdown sections, in order (each only when applicable):
    title, cluster specification (JSON), compiler IR, graph statistics
    (always), Mermaid sequence diagram, Mermaid flowchart.

    Parameters:
        traced: TracedFunction object produced by the compilation pipeline.
        cluster_spec: Optional cluster topology. When given, the party count
            used by both diagrams is ``len(cluster_spec.nodes)`` and a JSON
            summary of nodes and devices is included.
        sequence: Whether to render a sequence diagram section.
        flow: Whether to render a flowchart diagram section.
        include_ir: Include textual compiler IR section when True.
        report_path: If set, write the assembled markdown to this path.
        mpir_path: If set, write the MPIR graph proto (as ``str(proto)``) here.
        title: Optional top-level markdown title.
        seq_opts: Keyword options forwarded to ``to_sequence_diagram``.
        flow_opts: Options for ``to_flowchart``; its ``world_size`` is used
            only when no ``cluster_spec`` is given.

    Returns:
        DumpResult holding the individual text artifacts and the markdown.

    Raises:
        ValueError: if neither ``report_path`` nor ``mpir_path`` is given.
    """
    if report_path is None and mpir_path is None:
        raise ValueError(
            "dump() requires at least one output path: report_path for markdown or mpir_path for raw IR"
        )

    graph_proto = Writer().dumps(traced.make_expr())
    # World size = number of physical nodes (ranks) in the cluster, if known.
    spec_world_size = None if cluster_spec is None else len(cluster_spec.nodes)

    sections: list[str] = []

    def fenced(lang: str, body: str) -> None:
        # Append a fenced code block as three separate sections.
        sections.append(f"```{lang}")
        sections.append(body)
        sections.append("```")

    if title:
        sections.append(f"# {title}\n")

    if cluster_spec is not None:
        import json

        sections.append("## Cluster Specification\n")
        spec_summary = {
            "nodes": [
                {"name": node.name, "rank": node.rank, "endpoint": node.endpoint}
                for node in sorted(cluster_spec.nodes.values(), key=lambda x: x.rank)
            ],
            "devices": {
                dev_name: {
                    "kind": dev.kind,
                    "members": [m.name for m in dev.members],
                }
                for dev_name, dev in sorted(cluster_spec.devices.items())
            },
        }
        fenced("json", json.dumps(spec_summary, indent=2))

    ir_text: str | None = None
    if include_ir:
        ir_text = traced.compiler_ir()
        sections.append("## Compiler IR (text)\n")
        fenced("", ir_text)

    stats = get_graph_statistics(graph_proto)
    sections.append("## Graph Structure Analysis\n")
    fenced("", stats)

    seq_text: str | None = None
    if sequence:
        seq_text = to_sequence_diagram(
            graph_proto, world_size=spec_world_size, **(seq_opts or {})
        )
        sections.append("## Mermaid Sequence Diagram")
        fenced("mermaid", seq_text)

    flow_text: str | None = None
    if flow:
        opts = flow_opts or {}
        flow_world_size = (
            spec_world_size if spec_world_size is not None else opts.get("world_size")
        )
        flow_text = to_flowchart(
            graph_proto,
            world_size=flow_world_size,
            group_local=opts.get("group_local", True),
            show_meta=opts.get("show_meta", False),
            direction=opts.get("direction", "LR"),
            cross_edge_color=opts.get("cross_edge_color", "#ff6a00"),
            cluster_by_party=opts.get("cluster_by_party", False),
            shared_cluster_name=opts.get("shared_cluster_name", "Shared"),
        )
        sections.append("## Mermaid Flowchart (DAG)")
        fenced("mermaid", flow_text)

    markdown = "\n\n".join(sections) + "\n"

    if mpir_path:
        Path(mpir_path).write_text(str(graph_proto), encoding="utf-8")
    if report_path:
        Path(report_path).write_text(markdown, encoding="utf-8")

    return DumpResult(
        ir=ir_text,
        stats=stats,
        sequence=seq_text,
        flow=flow_text,
        markdown=markdown,
    )
560
+
561
+
562
# Public surface of this module (sorted, types first).
__all__ = [
    "DumpResult", "FlowchartOptions", "SequenceDiagramOptions",
    "dump", "to_flowchart", "to_sequence_diagram",
]