tilelang-devkit 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,743 @@
1
+ """IR Lower Trace — zero-intrusion debug tool for visualizing tilelang compilation passes.
2
+
3
+ This module is the core of the ``tilelang-devkit`` package. It monkey-patches
4
+ ``tvm.ir.transform.Pass.__call__`` and tilelang's phase functions to
5
+ automatically capture IR before/after every compilation pass, then prints a
6
+ unified diff to the terminal and saves the raw IR to disk.
7
+
8
+ Architecture
9
+ ============
10
+ ``enable()`` installs three layers of monkey-patch hooks::
11
+
12
+ Layer 1: Pass.__call__ = _traced_pass_call
13
+ Intercept all TVM Pass calls, snapshot before/after IR → diff → record
14
+
15
+ Layer 2: phase functions = _wrap_phase(...)
16
+ Set _current_phase so Layer 1 knows which phase each pass belongs to
17
+
18
+ Layer 3: codegen FFI = _wrap_codegen_ffi(...)
19
+ Intercept TIR→C++ conversion, capture final IR and generated source
20
+
21
+ During kernel compilation::
22
+
23
+ phase1_LowerAndLegalize() ← _wrap_phase sets _current_phase
24
+ ├─ Simplify(mod) ← _traced_pass_call intercepts
25
+ ├─ LowerTileOp(mod) ← _traced_pass_call intercepts
26
+ └─ ...
27
+ phase2_OptimizeForTarget()
28
+ ├─ FlattenBuffer(mod) ← _traced_pass_call intercepts
29
+ └─ ...
30
+ codegen(mod) ← _wrap_codegen_ffi intercepts
31
+
32
+ Usage
33
+ =====
34
+ ::
35
+
36
+ from lower_trace import enable, disable
37
+
38
+ enable() # install hooks, start tracing
39
+ # ... compile your kernel ...
40
+ disable() # uninstall hooks, restore original behavior
41
+
42
+ Or use the context manager::
43
+
44
+ from lower_trace import trace
45
+
46
+ with trace():
47
+ # ... compile your kernel ...
48
+
49
+ Output
50
+ ======
51
+ Terminal diff is printed for each pass, and raw IR is saved to::
52
+
53
+ ./tmp/lower_trace/run_<timestamp>_<pid>/
54
+ ├── phase1_LowerAndLegalize/
55
+ │ ├── 00_Simplify_before.tir
56
+ │ ├── 00_Simplify_after.tir
57
+ │ └── ...
58
+ ├── phase2_OptimizeForTarget/...
59
+ └── codegen/
60
+ └── NN_codegen_before.tir
61
+ └── NN_codegen_after.cpp
62
+
63
+ Dependencies
64
+ ============
65
+ This module imports cleanly without tvm/tilelang installed (so ``pip install``
66
+ works on any environment). The actual tracing requires tvm + tilelang at
67
+ runtime: ``enable()`` raises ``RuntimeError`` if tvm is missing, and degrades
68
+ gracefully (no phase labels) if tilelang's engine module is unavailable.
69
+ """
70
+
71
+ from __future__ import annotations
72
+
73
+ import contextlib
74
+ import difflib
75
+ import dis
76
+ import functools
77
+ import inspect
78
+ import os
79
+ import re
80
+ import threading
81
+ from dataclasses import dataclass
82
+ from typing import TYPE_CHECKING
83
+
84
+ from .diff import (
85
+ _ANSI_BLUE,
86
+ _ANSI_DIM,
87
+ _ANSI_GREEN,
88
+ _ANSI_RED,
89
+ _ANSI_RESET,
90
+ print_diff,
91
+ )
92
+
93
+ if TYPE_CHECKING:
94
+ from collections.abc import Callable
95
+
96
# Prefix used on every line this module prints to the terminal.
_TAG = "[lower_trace]"

# Values stored in ``LowerRecord.status``.
STATUS_COMPLETED = "completed"
STATUS_FAILED = "failed"
STATUS_SKIPPED = "skipped"
STATUS_CODEGEN = "codegen"

# Global-function names that enable() tries to wrap with _wrap_codegen_ffi.
_CODEGEN_FFI_NAMES: list[str] = [
    "target.build.tilelang_cuda",
    "target.build.tilelang_cuda_without_compile",
    "target.build.tilelang_cutedsl",
    "target.build.tilelang_cutedsl_without_compile",
    "target.build.tilelang_hip",
    "target.build.tilelang_hip_without_compile",
    "target.build.tilelang_metal",
    "target.build.tilelang_c",
    "target.build.tilelang_c_host",
    "target.build.tilelang_ascend",
    "target.build.tilelang_ascend_pto",
    "target.build.llvm",
    "target.build.webgpu",
    "target.build.tilelang_cpp",
    "target.build.tilelang_webgpu",
]

# ── Global state ──────────────────────────────────────────────────────────────
# _records: one entry per traced pass; cleared when disable() runs
_records: list[LowerRecord] = []

# Original callables: saved by enable(), restored by disable() (keeps the monkey-patching reversible)
_original_pass_call: Callable | None = None
_original_codegen_ffis: dict[str, Callable] = {}
_legacy_patched: bool = False
# (target, attr_name, original_or_MISSING, is_dict) — what disable() uses to restore
_legacy_phase_originals: list[tuple[object, str, object, bool]] = []
_MISSING: object = object()

# Current phase context: written by _wrap_phase, read by _traced_pass_call
_current_phase: str | None = None
_pass_index: int = 0
_run_dir: str | None = None
_lock = threading.RLock()
_run_counter: int = 0

_UNSCOPED_PHASE = "unscoped"
_DEFAULT_TRACE_DIR = os.path.join(".", "tmp", "lower_trace")
_trace_dir: str = _DEFAULT_TRACE_DIR
143
+
144
+
145
@dataclass
class LowerRecord:
    """Record of a single pass execution.

    Attributes
    ----------
    phase : str
        Phase the pass belongs to (e.g. ``phase1_LowerAndLegalize``,
        ``codegen``, ``unscoped``).
    name : str
        Display name of the pass (e.g. ``Simplify``, ``FlattenBuffer``).
    index : int
        Global sequence number (keeps increasing across phases).
    before_text : str
        IR text before the pass ran (``str(mod)``).
    after_text : str
        IR text after the pass ran (``str(result)``); for codegen records this
        is the generated C++ source.
    changed : bool
        Whether ``before_text != after_text``.
    add_lines, del_lines : int
        Number of added / deleted lines in the diff.
    status : str
        ``STATUS_COMPLETED`` / ``STATUS_FAILED`` / ``STATUS_CODEGEN``.
    error_msg : str
        Exception message when the pass failed.
    """

    phase: str
    name: str
    index: int
    before_text: str
    after_text: str
    changed: bool
    add_lines: int = 0
    del_lines: int = 0
    status: str = STATUS_COMPLETED
    error_msg: str = ""
181
+
182
+
183
+ # ── 辅助函数 ──────────────────────────────────────────────────────────────────
184
+
185
+
186
+ def _get_tvm_ffi():
187
+ """返回统一的 FFI 接口(``get_global_func`` / ``register_global_func``)。
188
+
189
+ 优先用新版 ``tvm.ffi``,回退到 ``3rdparty/tvm`` 的 legacy ``tvm._ffi``
190
+ (注册函数名为 ``register_func`` 而非 ``register_global_func``)。
191
+ """
192
+ try:
193
+ import tvm.ffi as _ffi
194
+
195
+ if hasattr(_ffi, "register_global_func") and hasattr(_ffi, "get_global_func"):
196
+ return _ffi
197
+ except ImportError:
198
+ pass
199
+ import tvm._ffi as _ffi
200
+
201
+ class _LegacyFFI:
202
+ """Adapter: 在 legacy tvm._ffi 上暴露 register_global_func API。"""
203
+
204
+ get_global_func = staticmethod(_ffi.get_global_func)
205
+
206
+ @staticmethod
207
+ def register_global_func(name, func=None, override=False):
208
+ return _ffi.register_func(name, func, override=override)
209
+
210
+ return _LegacyFFI()
211
+
212
+
213
+ def _inspect_module_source(mod):
214
+ """获取 ``tvm.runtime.Module`` 的源码文本。
215
+
216
+ 优先 ``inspect_source``(新版),回退 ``get_source``(3rdparty/tvm)。
217
+ 用于 codegen 后捕获生成的 C++/CUDA 源码。
218
+ """
219
+ for _attr in ("inspect_source", "get_source"):
220
+ _fn = getattr(mod, _attr, None)
221
+ if callable(_fn):
222
+ return _fn() or ""
223
+ return ""
224
+
225
+
226
+ def _get_pass_display_name(pass_obj) -> str:
227
+ """从 ``pass_info.name`` 提取显示名,如 ``tir.Simplify`` -> ``Simplify``。"""
228
+ try:
229
+ name = str(pass_obj.info.name)
230
+ return name.split(".")[-1] if "." in name else name
231
+ except Exception:
232
+ return type(pass_obj).__name__
233
+
234
+
235
+ def _safe_filename_component(name: str) -> str:
236
+ """净化字符串使其可作为路径组件(防止路径穿越 CWE-22)。"""
237
+ return re.sub(r"[^A-Za-z0-9._-]", "_", str(name))
238
+
239
+
240
def _ensure_run_dir() -> str:
    """Return the output directory of the current run, creating it on first use.

    The directory is ``<_trace_dir>/run_<timestamp>_<pid>/``. Its path is
    cached in the module-level ``_run_dir``; while that cache is set, it is
    returned unchanged, so files of one run land in one directory.
    """
    global _run_dir

    if _run_dir is None:
        from datetime import datetime

        stamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        run_name = "run_{}_{}".format(stamp, os.getpid())
        _run_dir = os.path.join(_trace_dir, run_name)
        os.makedirs(_run_dir, exist_ok=True)
    return _run_dir
257
+
258
+
259
def _save_raw_files(record: LowerRecord):
    """Write the before/after text of *record* to disk.

    Files are written to ``<run_dir>/<phase>/<index>_<name>_before|after.<ext>``.
    The "before" file always uses ``.tir``; the "after" file uses ``.cpp`` for
    codegen records (their after-text is generated source) and ``.tir``
    otherwise.

    Saving is best-effort: any exception is reported as a warning on the
    terminal and swallowed, so tracing never interrupts compilation.
    """
    try:
        phase_dir = os.path.join(_ensure_run_dir(), _safe_filename_component(record.phase))
        os.makedirs(phase_dir, exist_ok=True)

        stem = f"{record.index:02d}_{_safe_filename_component(record.name)}"
        after_ext = ".cpp" if record.status == STATUS_CODEGEN else ".tir"
        outputs = (
            (f"{stem}_before.tir", record.before_text),
            (f"{stem}_after{after_ext}", record.after_text),
        )
        for filename, text in outputs:
            with open(os.path.join(phase_dir, filename), "w", encoding="utf-8") as handle:
                handle.write(text)
    except Exception as exc:
        print(f" {_ANSI_RED}{_TAG} WARNING: could not save raw trace files: {exc}{_ANSI_RESET}")
285
+
286
+
287
+ def _count_diff(before_text: str, after_text: str) -> tuple[int, int]:
288
+ """用 SequenceMatcher 统计增删行数,返回 (add_count, del_count)。"""
289
+ add_count = del_count = 0
290
+ sm = difflib.SequenceMatcher(None, before_text.splitlines(), after_text.splitlines())
291
+ for tag, i1, i2, j1, j2 in sm.get_opcodes():
292
+ if tag == "insert":
293
+ add_count += j2 - j1
294
+ elif tag == "delete":
295
+ del_count += i2 - i1
296
+ elif tag == "replace":
297
+ add_count += j2 - j1
298
+ del_count += i2 - i1
299
+ return add_count, del_count
300
+
301
+
302
+ # ── Layer 1: 核心 Pass 拦截 ───────────────────────────────────────────────────
303
+
304
+
305
def _traced_pass_call(self, mod):
    """Replacement for ``Pass.__call__`` that records IR before and after the pass.

    ``enable()`` installs this function as ``Pass.__call__``, so every pass
    invoked through that method goes through here.

    Steps:
      1. Snapshot ``str(mod)`` as the "before" text.
      2. Reserve the next global sequence number.
      3. Call the original ``Pass.__call__``.
      4. Snapshot ``str(result)`` as the "after" text.
      5. Compute diff statistics, build a ``LowerRecord``, save it to disk and
         print a one-line status.
      6. If the IR changed, print a unified diff to the terminal.

    The phase label comes from ``_current_phase`` (set by ``_wrap_phase`` /
    ``_wrap_codegen_ffi``); when it is unset the pass is labelled ``unscoped``.
    If the original call raises, a ``STATUS_FAILED`` record is stored and the
    exception is re-raised unchanged.
    """
    global _pass_index

    phase = _current_phase or _UNSCOPED_PHASE
    before_text = str(mod)

    # Reserve the sequence number up front so numbering follows call order.
    with _lock:
        idx, _pass_index = _pass_index, _pass_index + 1

    try:
        result = _original_pass_call(self, mod)
    except Exception as e:
        with _lock:
            failure = LowerRecord(
                phase=phase,
                name=_get_pass_display_name(self),
                index=idx,
                before_text=before_text,
                after_text="",
                changed=False,
                status=STATUS_FAILED,
                error_msg=str(e),
            )
            _records.append(failure)
            _save_raw_files(failure)
            print(f" {_ANSI_RED}{_TAG} {phase}/{idx:02d}_{failure.name}: FAILED ({e}){_ANSI_RESET}")
        raise

    after_text = str(result)
    pass_name = _get_pass_display_name(self)
    changed = before_text != after_text
    if changed:
        add_count, del_count = _count_diff(before_text, after_text)
    else:
        add_count = del_count = 0

    with _lock:
        success = LowerRecord(
            phase=phase,
            name=pass_name,
            index=idx,
            before_text=before_text,
            after_text=after_text,
            changed=changed,
            add_lines=add_count,
            del_lines=del_count,
            status=STATUS_COMPLETED,
        )
        _records.append(success)
        _save_raw_files(success)
        if changed:
            tag, tag_color = "CHANGED", _ANSI_GREEN
        else:
            tag, tag_color = "NO-OP", _ANSI_DIM
        print(f" {_TAG} {phase}/{idx:02d}_{pass_name}: {tag_color}{tag}{_ANSI_RESET}")

    if changed:
        label = f"{phase}/{pass_name}"
        print_diff(before_text, after_text, f"{label} (before)", f"{label} (after)")

    return result
378
+
379
+
380
+ # ── Layer 2: Phase 发现与包装 ─────────────────────────────────────────────────
381
+
382
+
383
+ def _discover_phases(lower_func) -> list:
384
+ """通过字节码扫描发现旧架构的 phase 函数。
385
+
386
+ 扫描 ``tilelang.engine.lower`` 函数的字节码,找出其引用的
387
+ ``tilelang.engine.phase`` 模块中的 phase 函数
388
+ (如 ``LowerAndLegalize``、``OptimizeForTarget``)。
389
+
390
+ 原理:``tilelang.engine.lower`` 内部调用各 phase 函数,这些调用在
391
+ 字节码中表现为 ``LOAD_GLOBAL <phase_name>``。扫描这些指令即可发现
392
+ 所有被引用的 phase 函数。
393
+
394
+ 按 source line 排序,保证 phase 顺序与定义顺序一致。
395
+ """
396
+ try:
397
+ from tilelang.engine import phase as phase_module
398
+ except ImportError:
399
+ return []
400
+
401
+ phase_funcs = []
402
+ seen_names = set()
403
+ try:
404
+ for instr in dis.get_instructions(lower_func):
405
+ if instr.opname == "LOAD_GLOBAL" and instr.argval not in seen_names:
406
+ name = instr.argval
407
+ seen_names.add(name)
408
+ func = getattr(phase_module, name, None)
409
+ if func is not None and callable(func):
410
+ phase_funcs.append(func)
411
+ except (TypeError, OSError):
412
+ pass
413
+
414
+ # 回退:字节码扫描失败时,取 phase 模块所有公共 callable
415
+ if not phase_funcs:
416
+ phase_funcs = [
417
+ getattr(phase_module, name)
418
+ for name in sorted(dir(phase_module))
419
+ if not name.startswith("_") and callable(getattr(phase_module, name, None))
420
+ ]
421
+
422
+ def _src_line(f):
423
+ try:
424
+ return inspect.getsourcelines(f)[1]
425
+ except (OSError, TypeError):
426
+ return 999999
427
+
428
+ phase_funcs.sort(key=_src_line)
429
+ return phase_funcs
430
+
431
+
432
def _wrap_phase(original_func, phase_index, total_phases):
    """Wrap a phase function so ``_current_phase`` is set while it runs.

    This lets ``_traced_pass_call`` know which phase each intercepted pass
    belongs to. The phase label is ``phase<N>_<func_name>``, prefixed with
    ``run<M>_`` from the second run onwards.

    When ``phase_index == 1`` the wrapper increments ``_run_counter``; from the
    second run on it also clears ``_run_dir`` so the new run is saved into a
    fresh directory.

    On exit (normal return or any exception) the phase label that was active
    *before* the call is restored, so an enclosing context (for example the
    ``codegen`` label set by ``_wrap_codegen_ffi``) is not lost.

    Parameters
    ----------
    original_func : Callable
        The phase function to wrap.
    phase_index : int
        1-based position of this phase in the pipeline.
    total_phases : int
        Number of phases; the wrapper of the last one prints a run summary.

    Returns
    -------
    Callable
        A wrapper with the same call signature and return value as
        *original_func*. Exceptions from *original_func* are re-raised.
    """
    base_phase_name = f"phase{phase_index}_{original_func.__name__}"

    @functools.wraps(original_func)
    def wrapper(*args, **kwargs):
        global _run_counter, _current_phase, _run_dir

        with _lock:
            if phase_index == 1:
                _run_counter += 1
                # From the second run on, force a new output directory.
                if _run_counter > 1:
                    _run_dir = None
            run_prefix = f"run{_run_counter}_" if _run_counter > 1 else ""
            phase_name = f"{run_prefix}{base_phase_name}"
            previous_phase = _current_phase
            _current_phase = phase_name

        try:
            result = original_func(*args, **kwargs)
        except Exception as e:
            print(f" {_ANSI_RED}{_TAG} EXCEPTION in {phase_name}: {e}{_ANSI_RESET}")
            raise
        finally:
            # Runs for every exit path, including KeyboardInterrupt, so the
            # label can never be left stale.
            with _lock:
                _current_phase = previous_phase

        if phase_index == total_phases:
            with _lock:
                print(
                    f" {_TAG} run {_run_counter} ({phase_name}) complete: {len(_records)} total records"
                )

        return result

    return wrapper
475
+
476
+
477
+ # ── Layer 3: Codegen 拦截(简化版,只捕获不编辑)─────────────────────────────
478
+
479
+
480
def _wrap_codegen_ffi(original_build, ffi_name=""):
    """Wrap a codegen FFI so the TIR input and generated source are recorded.

    The wrapper only *captures*: it snapshots ``str(mod)`` before the build and
    the module source after it, stores one ``STATUS_CODEGEN`` record and prints
    a diff. While the build runs, ``_current_phase`` is set to ``"codegen"`` so
    passes invoked inside codegen are attributed to that phase by
    ``_traced_pass_call``; the previous label is restored afterwards on every
    exit path.

    If the build raises, a ``STATUS_FAILED`` record is stored and the exception
    is re-raised. Everything done *after* a successful build (source
    extraction, record keeping, diff printing) is best-effort: a failure there
    is reported as a warning and the build result is still returned.

    When tracing is not active (``_original_pass_call is None``) the wrapper
    simply forwards to *original_build*.

    Parameters
    ----------
    original_build : Callable
        The original codegen FFI function (e.g. ``target.build.tilelang_ascend``).
    ffi_name : str
        Registered FFI name. Accepted for callers; not used by the wrapper.

    Returns
    -------
    Callable
        Wrapper returning whatever *original_build* returns.
    """

    @functools.wraps(original_build)
    def wrapper(*args, **kwargs):
        global _pass_index, _current_phase

        if _original_pass_call is None:
            return original_build(*args, **kwargs)

        mod = args[0] if args else kwargs.get("mod")
        before_text = str(mod)

        with _lock:
            previous_phase = _current_phase
            _current_phase = "codegen"

        try:
            try:
                result = original_build(*args, **kwargs)
            except Exception as e:
                with _lock:
                    idx = _pass_index
                    _pass_index += 1
                    record = LowerRecord(
                        phase="codegen",
                        name=getattr(original_build, "__name__", "codegen"),
                        index=idx,
                        before_text=before_text,
                        after_text="",
                        changed=False,
                        status=STATUS_FAILED,
                        error_msg=str(e),
                    )
                    _records.append(record)
                    _save_raw_files(record)
                print(f" {_ANSI_RED}{_TAG} codegen/{idx:02d}_codegen: FAILED ({e}){_ANSI_RESET}")
                raise

            # Post-build tracing must never turn a successful build into a
            # failure, so the whole section (including the diff) is guarded.
            try:
                with _lock:
                    idx = _pass_index
                    _pass_index += 1
                after_text = _inspect_module_source(result)
                add_count, del_count = _count_diff(before_text, after_text)

                with _lock:
                    record = LowerRecord(
                        phase="codegen",
                        name="codegen",
                        index=idx,
                        before_text=before_text,
                        after_text=after_text,
                        changed=True,
                        add_lines=add_count,
                        del_lines=del_count,
                        status=STATUS_CODEGEN,
                    )
                    _records.append(record)
                    _save_raw_files(record)
                print(
                    f" {_TAG} codegen/{idx:02d}_codegen: {_ANSI_BLUE}CODEGEN{_ANSI_RESET} (+{add_count}/-{del_count})"
                )
                # Only diff when the source was actually captured.
                print_diff(before_text, after_text, "codegen (TIR before)", "codegen (C++ after)")
            except Exception as exc:
                print(f" {_ANSI_RED}{_TAG} WARNING: post-codegen tracing failed: {exc}{_ANSI_RESET}")
        finally:
            with _lock:
                _current_phase = previous_phase

        return result

    return wrapper
568
+
569
+
570
+ # ── 启停控制 ──────────────────────────────────────────────────────────────────
571
+
572
+
573
+ def enable(*, trace_dir: str | None = None):
574
+ """安装三层 monkey-patch 钩子,开始跟踪编译 Pass。
575
+
576
+ 幂等:重复调用不会重复安装。
577
+
578
+ Parameters
579
+ ----------
580
+ trace_dir : str, optional
581
+ 输出根目录,默认 ``./tmp/lower_trace``。
582
+
583
+ Raises
584
+ ------
585
+ RuntimeError
586
+ 如果 tvm 未安装。tvm 是 tracing 的硬依赖(Pass.__call__ 拦截需要)。
587
+ 请先安装 tilelang(会带 tvm)。
588
+ """
589
+ global _trace_dir, _original_pass_call, _legacy_patched
590
+
591
+ if trace_dir is not None:
592
+ _trace_dir = str(trace_dir)
593
+
594
+ # tvm 是硬依赖:没有 tvm 无法拦截 Pass.__call__
595
+ try:
596
+ from tvm.ir.transform import Pass
597
+ except ImportError as exc:
598
+ raise RuntimeError(
599
+ "lower_trace requires tvm to enable tracing. "
600
+ "Install tilelang first: pip install tilelang (or use the tilelang-ascend environment)."
601
+ ) from exc
602
+
603
+ # Layer 1: 拦截所有 Pass.__call__
604
+ if _original_pass_call is None:
605
+ _original_pass_call = Pass.__call__
606
+ Pass.__call__ = _traced_pass_call
607
+
608
+ # Layer 3: 拦截 codegen FFI
609
+ if not _original_codegen_ffis:
610
+ _ffi = _get_tvm_ffi()
611
+ for ffi_name in _CODEGEN_FFI_NAMES:
612
+ try:
613
+ orig = _ffi.get_global_func(ffi_name)
614
+ if orig is not None:
615
+ wrapped = _wrap_codegen_ffi(orig, ffi_name)
616
+ _original_codegen_ffis[ffi_name] = orig
617
+ _ffi.register_global_func(ffi_name, wrapped, override=True)
618
+ except Exception as exc:
619
+ print(f"{_TAG} WARNING: could not wrap codegen FFI {ffi_name}: {exc}")
620
+
621
+ if _legacy_patched:
622
+ return
623
+
624
+ # Layer 2: 包装 phase 函数(tilelang 专用)
625
+ # tilelang 缺失时降级:Pass 跟踪仍可用(全标 unscoped),但无 phase 标签
626
+ try:
627
+ import tilelang.engine.lower as lower_mod
628
+
629
+ lower_func = lower_mod.lower
630
+ patch_mod = lower_mod
631
+ except (ImportError, AttributeError):
632
+ try:
633
+ from tilelang.engine import lower as lower_func
634
+
635
+ import tilelang.engine as patch_mod
636
+ except (ImportError, AttributeError) as e:
637
+ print(
638
+ f"{_TAG} WARNING: tilelang engine not found ({e}); phase tracing disabled (passes will be tagged 'unscoped')."
639
+ )
640
+ return
641
+
642
+ phase_funcs = _discover_phases(lower_func)
643
+ for i, phase_func in enumerate(phase_funcs):
644
+ wrapped = _wrap_phase(phase_func, i + 1, len(phase_funcs))
645
+ name = phase_func.__name__
646
+
647
+ # 在三个位置保存原始引用,disable() 时据此恢复
648
+ _legacy_phase_originals.append((patch_mod, name, getattr(patch_mod, name, _MISSING), False))
649
+ setattr(patch_mod, name, wrapped)
650
+ try:
651
+ from tilelang.engine import phase as phase_module
652
+
653
+ if hasattr(phase_module, name):
654
+ _legacy_phase_originals.append(
655
+ (phase_module, name, getattr(phase_module, name, _MISSING), False)
656
+ )
657
+ setattr(phase_module, name, wrapped)
658
+ except ImportError:
659
+ pass
660
+ glbls = getattr(lower_func, "__globals__", None)
661
+ if glbls is not None and name in glbls:
662
+ _legacy_phase_originals.append((glbls, name, glbls[name], True))
663
+ glbls[name] = wrapped
664
+
665
+ _legacy_patched = True
666
+ print(f"{_TAG} IR pass tracing enabled ({len(phase_funcs)} phases, trace_dir={_trace_dir}).")
667
+
668
+
669
def disable():
    """Uninstall every monkey-patch hook and restore the original behaviour.

    Restores ``Pass.__call__``, the codegen FFIs and the phase functions, then
    clears the collected records and run state (run counter, run directory,
    trace directory, and everything ``reset()`` clears).

    Phase originals are restored in reverse installation order. When two
    patched locations alias the same namespace, the later entry recorded the
    wrapper as its "original"; undoing last-in-first-out ensures the true
    original is written last and wins.
    """
    global _original_pass_call, _legacy_patched, _legacy_phase_originals
    global _run_counter, _run_dir, _trace_dir

    if _original_pass_call is not None:
        from tvm.ir.transform import Pass

        Pass.__call__ = _original_pass_call
        _original_pass_call = None

    # Restore codegen FFIs (the FFI layer is only needed if something was wrapped).
    if _original_codegen_ffis:
        _ffi = _get_tvm_ffi()
        for ffi_name, orig in _original_codegen_ffis.items():
            with contextlib.suppress(Exception):
                _ffi.register_global_func(ffi_name, orig, override=True)
        _original_codegen_ffis.clear()

    # Restore phase functions, newest patch first (see docstring).
    for target, name, original, is_dict in reversed(_legacy_phase_originals):
        with contextlib.suppress(Exception):
            if original is _MISSING:
                # The name did not exist before patching: remove it again.
                if is_dict:
                    del target[name]
                else:
                    delattr(target, name)
            elif is_dict:
                target[name] = original
            else:
                setattr(target, name, original)
    _legacy_phase_originals = []

    _legacy_patched = False
    _run_counter = 0
    _run_dir = None
    _trace_dir = _DEFAULT_TRACE_DIR
    reset()
712
+
713
+
714
def reset():
    """Drop all collected records and restart pass numbering.

    Sets ``_records`` to a new empty list, ``_current_phase`` to ``None`` and
    ``_pass_index`` to ``0``.

    ``_run_dir`` is deliberately left alone: clearing it here would scatter
    the passes of one run over several directories (passes that run before the
    first phase already create ``_run_dir``). A new run gets a new directory
    from ``_wrap_phase`` once ``_run_counter > 1``.
    """
    global _records, _current_phase, _pass_index
    _pass_index = 0
    _current_phase = None
    _records = []
725
+
726
+
727
+ # ── 便捷上下文管理器 ──────────────────────────────────────────────────────────
728
+
729
+
730
@contextlib.contextmanager
def trace(*, trace_dir: str | None = None):
    """Context manager that calls ``enable()`` on entry and ``disable()`` on exit.

    ``disable()`` runs whether the body finishes normally or raises; an
    exception from the body is propagated unchanged.

    Example::

        with trace():
            kernel = tilelang.jit(func)()
    """
    enable(trace_dir=trace_dir)
    with contextlib.ExitStack() as cleanup:
        # Registered only after enable() succeeded, mirroring try/finally.
        cleanup.callback(disable)
        yield