tilelang-devkit 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tilelang_devkit/__init__.py +47 -0
- tilelang_devkit/_version.py +6 -0
- tilelang_devkit/lower_trace/__init__.py +37 -0
- tilelang_devkit/lower_trace/core.py +743 -0
- tilelang_devkit/lower_trace/diff.py +101 -0
- tilelang_devkit/py.typed +0 -0
- tilelang_devkit-0.1.0.dist-info/METADATA +192 -0
- tilelang_devkit-0.1.0.dist-info/RECORD +13 -0
- tilelang_devkit-0.1.0.dist-info/WHEEL +5 -0
- tilelang_devkit-0.1.0.dist-info/licenses/LICENSE +201 -0
- tilelang_devkit-0.1.0.dist-info/scm_file_list.json +15 -0
- tilelang_devkit-0.1.0.dist-info/scm_version.json +8 -0
- tilelang_devkit-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,743 @@
|
|
|
1
|
+
"""IR Lower Trace — zero-intrusion debug tool for visualizing tilelang compilation passes.

This module is the core of the ``tilelang-devkit`` package. It monkey-patches
``tvm.ir.transform.Pass.__call__`` and tilelang's phase functions to
automatically capture IR before/after every compilation pass, then prints a
unified diff to the terminal for every pass that changed the IR and saves the
raw IR to disk.

Architecture
============

``enable()`` installs three layers of monkey-patch hooks::

    Layer 1: Pass.__call__ = _traced_pass_call
        Intercept all TVM Pass calls, snapshot before/after IR → diff → record

    Layer 2: phase functions = _wrap_phase(...)
        Set _current_phase so Layer 1 knows which phase each pass belongs to

    Layer 3: codegen FFI = _wrap_codegen_ffi(...)
        Intercept TIR→C++ conversion, capture final IR and generated source

During kernel compilation::

    phase1_LowerAndLegalize()       ← _wrap_phase sets _current_phase
    ├─ Simplify(mod)                ← _traced_pass_call intercepts
    ├─ LowerTileOp(mod)             ← _traced_pass_call intercepts
    └─ ...
    phase2_OptimizeForTarget()
    ├─ FlattenBuffer(mod)           ← _traced_pass_call intercepts
    └─ ...
    codegen(mod)                    ← _wrap_codegen_ffi intercepts

Usage
=====
::

    from tilelang_devkit.lower_trace.core import enable, disable

    enable()    # install hooks, start tracing
    # ... compile your kernel ...
    disable()   # uninstall hooks, restore original behavior

Or use the context manager::

    from tilelang_devkit.lower_trace.core import trace

    with trace():
        # ... compile your kernel ...

Output
======
A one-line status is printed for each pass, followed by a terminal diff when
the pass changed the IR. The raw IR is saved to the directory below (the root
is configurable with ``enable(trace_dir=...)``)::

    ./tmp/lower_trace/run_<timestamp>_<pid>/
    ├── phase1_LowerAndLegalize/
    │   ├── 00_Simplify_before.tir
    │   ├── 00_Simplify_after.tir
    │   └── ...
    ├── phase2_OptimizeForTarget/...
    └── codegen/
        ├── NN_codegen_before.tir
        └── NN_codegen_after.cpp

Dependencies
============
This module imports cleanly without tvm/tilelang installed (so ``pip install``
works on any environment). The actual tracing requires tvm + tilelang at
runtime: ``enable()`` raises ``RuntimeError`` if tvm is missing, and degrades
gracefully (no phase labels) if tilelang's engine module is unavailable.
"""
|
|
70
|
+
|
|
71
|
+
from __future__ import annotations
|
|
72
|
+
|
|
73
|
+
import contextlib
|
|
74
|
+
import difflib
|
|
75
|
+
import dis
|
|
76
|
+
import functools
|
|
77
|
+
import inspect
|
|
78
|
+
import os
|
|
79
|
+
import re
|
|
80
|
+
import threading
|
|
81
|
+
from dataclasses import dataclass
|
|
82
|
+
from typing import TYPE_CHECKING
|
|
83
|
+
|
|
84
|
+
from .diff import (
|
|
85
|
+
_ANSI_BLUE,
|
|
86
|
+
_ANSI_DIM,
|
|
87
|
+
_ANSI_GREEN,
|
|
88
|
+
_ANSI_RED,
|
|
89
|
+
_ANSI_RESET,
|
|
90
|
+
print_diff,
|
|
91
|
+
)
|
|
92
|
+
|
|
93
|
+
if TYPE_CHECKING:
|
|
94
|
+
from collections.abc import Callable
|
|
95
|
+
|
|
96
|
+
# Prefix placed in front of every message this module prints.
_TAG = "[lower_trace]"

# Values stored in ``LowerRecord.status``.
STATUS_COMPLETED = "completed"
STATUS_FAILED = "failed"
# NOTE(review): not assigned anywhere in this module; presumably kept for
# external consumers — confirm before removing.
STATUS_SKIPPED = "skipped"
STATUS_CODEGEN = "codegen"

# Registered codegen build functions that enable() tries to wrap (Layer 3).
# Names that are not registered in the current tvm/tilelang build are not
# wrapped.
_CODEGEN_FFI_NAMES: list[str] = [
    "target.build.tilelang_cuda",
    "target.build.tilelang_cuda_without_compile",
    "target.build.tilelang_cutedsl",
    "target.build.tilelang_cutedsl_without_compile",
    "target.build.tilelang_hip",
    "target.build.tilelang_hip_without_compile",
    "target.build.tilelang_metal",
    "target.build.tilelang_c",
    "target.build.tilelang_c_host",
    "target.build.tilelang_ascend",
    "target.build.tilelang_ascend_pto",
    "target.build.llvm",
    "target.build.webgpu",
    "target.build.tilelang_cpp",
    "target.build.tilelang_webgpu",
]
|
|
120
|
+
|
|
121
|
+
# ── Global state ──────────────────────────────────────────────────────────────
# _records: one LowerRecord per traced pass / codegen call. reset() (called by
# disable()) replaces it with a fresh empty list.
_records: list[LowerRecord] = []

# Original callables: saved by enable(), restored by disable(), which keeps
# the monkey-patching reversible.
_original_pass_call: Callable | None = None
_original_codegen_ffis: dict[str, Callable] = {}
# Set by enable() once the Layer 2 phase-wrapping step has run; cleared by disable().
_legacy_patched: bool = False
# (target, attr_name, original_or_MISSING, is_dict) — disable() restores from these.
_legacy_phase_originals: list[tuple[object, str, object, bool]] = []
# Sentinel meaning "the attribute did not exist before patching".
_MISSING: object = object()

# Current phase context: set by _wrap_phase (and temporarily by
# _wrap_codegen_ffi), read by _traced_pass_call.
_current_phase: str | None = None
# Global pass sequence number, shared across phases.
_pass_index: int = 0
# Output directory of the current run; created lazily by _ensure_run_dir().
_run_dir: str | None = None
# Re-entrant lock guarding the mutable state above.
_lock = threading.RLock()
# Number of pipeline runs seen so far (incremented when phase 1 starts).
_run_counter: int = 0

# Phase label for passes that run outside any wrapped phase.
_UNSCOPED_PHASE = "unscoped"
_DEFAULT_TRACE_DIR = os.path.join(".", "tmp", "lower_trace")
# Root output directory; overridable via enable(trace_dir=...).
_trace_dir: str = _DEFAULT_TRACE_DIR
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
@dataclass
class LowerRecord:
    """Record of a single pass (or codegen) execution.

    Attributes
    ----------
    phase : str
        Phase the pass ran in (e.g. ``phase1_LowerAndLegalize``, ``codegen``, ``unscoped``).
    name : str
        Display name of the pass (e.g. ``Simplify``, ``FlattenBuffer``).
    index : int
        Global sequence number (keeps increasing across phases).
    before_text : str
        IR text before the pass ran (``str(mod)``).
    after_text : str
        IR text after the pass ran (``str(result)``); for codegen records this
        is the generated source code.
    changed : bool
        Whether the before and after texts differ.
    add_lines, del_lines : int
        Number of added / deleted lines in the diff.
    status : str
        ``STATUS_COMPLETED`` / ``STATUS_FAILED`` / ``STATUS_CODEGEN``.
    error_msg : str
        Exception message when the pass failed.
    """

    phase: str
    name: str
    index: int
    before_text: str
    after_text: str
    changed: bool
    add_lines: int = 0
    del_lines: int = 0
    status: str = STATUS_COMPLETED
    error_msg: str = ""
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
# ── Helpers ───────────────────────────────────────────────────────────────────


def _get_tvm_ffi():
    """Return an object exposing ``get_global_func`` and ``register_global_func``.

    The modern ``tvm.ffi`` module is returned as-is when it provides both
    functions. Otherwise the legacy ``tvm._ffi`` module (as shipped in
    ``3rdparty/tvm``) is adapted: its registration function is called
    ``register_func``, so a small shim maps ``register_global_func`` onto it.
    """
    try:
        import tvm.ffi as modern

        wanted = ("register_global_func", "get_global_func")
        if all(hasattr(modern, attr) for attr in wanted):
            return modern
    except ImportError:
        pass

    import tvm._ffi as legacy

    class _LegacyFFI:
        """Shim exposing the ``register_global_func`` API on top of ``tvm._ffi``."""

        get_global_func = staticmethod(legacy.get_global_func)

        @staticmethod
        def register_global_func(name, func=None, override=False):
            return legacy.register_func(name, func, override=override)

    return _LegacyFFI()
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def _inspect_module_source(mod):
    """Return the source text held by a ``tvm.runtime.Module``.

    Tries ``inspect_source`` (newer tvm) first, then ``get_source``
    (``3rdparty/tvm``). The first of those attributes that is callable is
    invoked, and a falsy result is normalised to ``""``. Returns ``""`` when
    neither accessor is available. Used to capture the generated C++/CUDA
    source after codegen.
    """
    accessors = (getattr(mod, attr, None) for attr in ("inspect_source", "get_source"))
    reader = next((fn for fn in accessors if callable(fn)), None)
    if reader is None:
        return ""
    return reader() or ""
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def _get_pass_display_name(pass_obj) -> str:
    """Return a short display name for *pass_obj*.

    Takes the last dotted component of ``pass_obj.info.name`` (so
    ``tir.Simplify`` becomes ``Simplify``). If that lookup fails for any
    reason, the class name of *pass_obj* is used instead.
    """
    try:
        qualified = str(pass_obj.info.name)
        return qualified.rpartition(".")[2]
    except Exception:
        return type(pass_obj).__name__
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def _safe_filename_component(name: str) -> str:
    """Sanitize *name* so it is safe to use as a single path component (CWE-22).

    Every character outside ``[A-Za-z0-9._-]`` becomes ``_``, which removes
    path separators. The names ``.`` and ``..`` (and any all-dot name) would
    still resolve to the current or parent directory, so in that case the dots
    are replaced as well. An empty result becomes ``_``.
    """
    cleaned = re.sub(r"[^A-Za-z0-9._-]", "_", str(name))
    if not cleaned.strip("."):
        # "", ".", ".." ... — none of these is a safe directory/file name.
        return cleaned.replace(".", "_") or "_"
    return cleaned
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def _ensure_run_dir() -> str:
    """Return this run's output directory, creating it on first use.

    The directory is ``<_trace_dir>/run_<timestamp>_<pid>/``. Each run gets
    its own directory so that several compilations can be told apart. The
    path is cached in the module-level ``_run_dir``.
    """
    global _run_dir

    if _run_dir is None:
        from datetime import datetime

        stamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        _run_dir = os.path.join(_trace_dir, f"run_{stamp}_{os.getpid()}")
        os.makedirs(_run_dir, exist_ok=True)
    return _run_dir
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def _save_raw_files(record: LowerRecord):
    """Write a record's before/after text to disk.

    Files land in ``<run_dir>/<phase>/<index>_<name>_before|after.<ext>``.
    The "after" file of a codegen record holds generated C++ source and uses
    the ``.cpp`` extension; every other file uses ``.tir``.

    Best-effort: any failure is reported as a warning and swallowed, so that
    saving trace files never interrupts compilation.
    """
    try:
        phase_dir = os.path.join(_ensure_run_dir(), _safe_filename_component(record.phase))
        os.makedirs(phase_dir, exist_ok=True)

        stem = f"{record.index:02d}_{_safe_filename_component(record.name)}"
        after_suffix = ".cpp" if record.status == STATUS_CODEGEN else ".tir"
        targets = (
            (f"{stem}_before.tir", lambda: record.before_text),
            (f"{stem}_after{after_suffix}", lambda: record.after_text),
        )
        for filename, read_text in targets:
            with open(os.path.join(phase_dir, filename), "w", encoding="utf-8") as fh:
                fh.write(read_text())
    except Exception as exc:
        print(f" {_ANSI_RED}{_TAG} WARNING: could not save raw trace files: {exc}{_ANSI_RESET}")
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def _count_diff(before_text: str, after_text: str) -> tuple[int, int]:
    """Count added and deleted lines between two texts.

    Uses ``difflib.SequenceMatcher`` over the two texts' lines. A ``replace``
    opcode contributes to both counts.

    Returns
    -------
    tuple[int, int]
        ``(add_count, del_count)``.
    """
    old_lines = before_text.splitlines()
    new_lines = after_text.splitlines()
    matcher = difflib.SequenceMatcher(None, old_lines, new_lines)

    added = removed = 0
    for op, a_lo, a_hi, b_lo, b_hi in matcher.get_opcodes():
        if op in ("insert", "replace"):
            added += b_hi - b_lo
        if op in ("delete", "replace"):
            removed += a_hi - a_lo
    return added, removed
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
# ── Layer 1: core Pass interception ───────────────────────────────────────────


def _traced_pass_call(self, mod):
    """Core hook that replaces ``Pass.__call__`` while tracing is enabled.

    After ``enable()``, every TVM ``Pass.__call__`` goes through this
    function, whatever phase it runs in.

    Steps:

    1. Snapshot the input IR with ``str(mod)`` (the "before" text).
    2. Reserve a global, monotonically increasing index ``idx``.
    3. Call the original ``Pass.__call__``.
    4. Snapshot the result with ``str(result)`` (the "after" text).
    5. Compute diff stats, build a ``LowerRecord``, save it to disk and print
       a one-line status.
    6. If the IR changed, print a diff in the terminal.

    A pass that raises is recorded with ``STATUS_FAILED`` and its exception is
    re-raised unchanged.

    ``_current_phase`` is set by Layer 2 (``_wrap_phase``) and tells this
    function which phase the pass belongs to. When it is unset, the pass is
    tagged ``unscoped``.
    """
    global _pass_index

    phase = _current_phase or _UNSCOPED_PHASE
    before_text = str(mod)

    # Reserve the sequence number up front so failed passes get one too.
    with _lock:
        idx, _pass_index = _pass_index, _pass_index + 1

    try:
        result = _original_pass_call(self, mod)
    except Exception as err:
        with _lock:
            failure = LowerRecord(
                phase=phase,
                name=_get_pass_display_name(self),
                index=idx,
                before_text=before_text,
                after_text="",
                changed=False,
                status=STATUS_FAILED,
                error_msg=str(err),
            )
            _records.append(failure)
            _save_raw_files(failure)
            print(f" {_ANSI_RED}{_TAG} {phase}/{idx:02d}_{failure.name}: FAILED ({err}){_ANSI_RESET}")
        raise

    after_text = str(result)
    changed = after_text != before_text
    pass_name = _get_pass_display_name(self)
    if changed:
        add_count, del_count = _count_diff(before_text, after_text)
    else:
        add_count = del_count = 0

    with _lock:
        record = LowerRecord(
            phase=phase,
            name=pass_name,
            index=idx,
            before_text=before_text,
            after_text=after_text,
            changed=changed,
            add_lines=add_count,
            del_lines=del_count,
            status=STATUS_COMPLETED,
        )
        _records.append(record)
        _save_raw_files(record)
        if changed:
            verdict, colour = "CHANGED", _ANSI_GREEN
        else:
            verdict, colour = "NO-OP", _ANSI_DIM
        print(f" {_TAG} {phase}/{idx:02d}_{pass_name}: {colour}{verdict}{_ANSI_RESET}")

    if changed:
        label = f"{phase}/{pass_name}"
        print_diff(before_text, after_text, f"{label} (before)", f"{label} (after)")

    return result
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
# ── Layer 2: phase discovery and wrapping ─────────────────────────────────────


def _discover_phases(lower_func) -> list:
    """Discover the legacy-pipeline phase functions that ``lower`` calls.

    Scans the bytecode of ``tilelang.engine.lower`` for ``LOAD_GLOBAL``
    instructions. Every referenced name that resolves to a function *defined
    in* ``tilelang.engine.phase`` (e.g. ``LowerAndLegalize``,
    ``OptimizeForTarget``) is treated as a phase.

    Only plain Python functions whose ``__module__`` is the phase module are
    accepted. The phase module also exposes names it merely imports, such as
    classes like ``Target`` and helpers from other packages. Those are
    callable and may also be referenced by ``lower``, but wrapping them as
    phases would replace a class with a function and mislabel passes.

    If the bytecode scan finds nothing, every public function defined in the
    phase module is used instead. The result is sorted by source line so the
    phase order matches definition order.

    Returns
    -------
    list
        The phase functions; empty if ``tilelang.engine.phase`` cannot be
        imported.
    """
    try:
        from tilelang.engine import phase as phase_module
    except ImportError:
        return []

    module_name = phase_module.__name__

    def _is_phase(obj) -> bool:
        # A plain function defined in the phase module itself (not re-exported).
        return inspect.isfunction(obj) and getattr(obj, "__module__", None) == module_name

    phase_funcs = []
    seen_names = set()
    try:
        for instr in dis.get_instructions(lower_func):
            if instr.opname != "LOAD_GLOBAL" or instr.argval in seen_names:
                continue
            seen_names.add(instr.argval)
            func = getattr(phase_module, instr.argval, None)
            if _is_phase(func):
                phase_funcs.append(func)
    except (TypeError, OSError):
        pass

    # Fallback: if the bytecode scan found nothing, take every public function
    # defined in the phase module.
    if not phase_funcs:
        phase_funcs = [
            getattr(phase_module, name)
            for name in sorted(dir(phase_module))
            if not name.startswith("_") and _is_phase(getattr(phase_module, name, None))
        ]

    def _src_line(f):
        # All candidates live in the same module, so line numbers are comparable.
        try:
            return inspect.getsourcelines(f)[1]
        except (OSError, TypeError):
            return 999999

    phase_funcs.sort(key=_src_line)
    return phase_funcs
|
|
430
|
+
|
|
431
|
+
|
|
432
|
+
def _wrap_phase(original_func, phase_index, total_phases):
    """Wrap a phase function so ``_current_phase`` names it while it runs.

    Layer 1 (``_traced_pass_call``) reads ``_current_phase`` to attribute each
    pass to a phase. Phase names look like ``phase<N>_<func_name>``. From the
    second pipeline run on, they get a ``run<M>_`` prefix.

    When ``phase_index == 1`` the wrapper increments ``_run_counter``. On every
    run after the first it also clears ``_run_dir``, so the new run's passes
    are written to a fresh directory.

    The phase label that was active when the wrapper was entered is restored
    on every exit path (return, exception or ``BaseException``). A phase
    function that calls another wrapped phase therefore keeps its own label
    for the passes that follow the inner call.

    Parameters
    ----------
    original_func : Callable
        The phase function to wrap.
    phase_index : int
        1-based position of this phase in the discovered phase list.
    total_phases : int
        Number of discovered phases; the last phase prints a run summary.
    """
    base_phase_name = f"phase{phase_index}_{original_func.__name__}"

    @functools.wraps(original_func)
    def wrapper(*args, **kwargs):
        global _run_counter, _current_phase, _run_dir

        with _lock:
            if phase_index == 1:
                _run_counter += 1
                # From the second run on, drop the cached run dir so this
                # run's passes are saved to a new directory.
                if _run_counter > 1:
                    _run_dir = None
            run_prefix = f"run{_run_counter}_" if _run_counter > 1 else ""
            phase_name = f"{run_prefix}{base_phase_name}"
            previous_phase = _current_phase
            _current_phase = phase_name

        try:
            result = original_func(*args, **kwargs)
        except Exception as e:
            print(f" {_ANSI_RED}{_TAG} EXCEPTION in {phase_name}: {e}{_ANSI_RESET}")
            raise
        finally:
            # Restore the enclosing label (None at top level) on every exit path.
            with _lock:
                _current_phase = previous_phase

        with _lock:
            if phase_index == total_phases:
                print(
                    f" {_TAG} run {_run_counter} ({phase_name}) complete: {len(_records)} total records"
                )

        return result

    return wrapper
|
|
475
|
+
|
|
476
|
+
|
|
477
|
+
# ── Layer 3: codegen interception (simplified: capture only, no editing) ──────


def _wrap_codegen_ffi(original_build, ffi_name=""):
    """Wrap a codegen FFI so each build becomes one ``STATUS_CODEGEN`` record.

    Capture only: the wrapper snapshots the input module text before calling
    the build function and the generated source afterwards. It never modifies
    either of them.

    While the build runs, ``_current_phase`` is temporarily set to
    ``"codegen"``. Passes invoked from inside codegen (for example
    ``tir.transform.Simplify`` in ``device_codegen``) are therefore also
    captured by Layer 1 and attributed to the codegen phase. The previous
    phase is restored on every exit path.

    A failing build is recorded with ``STATUS_FAILED`` and re-raised. Failures
    while capturing a *successful* build's output are only reported as a
    warning. The diff is printed only when generated source was actually
    captured: an empty "after" text would otherwise render as the whole TIR
    module being deleted.

    Parameters
    ----------
    original_build : Callable
        The original codegen FFI function (e.g. ``target.build.tilelang_ascend``).
    ffi_name : str
        FFI registration name; informational only (not used by the wrapper).
    """

    @functools.wraps(original_build)
    def wrapper(*args, **kwargs):
        global _pass_index, _current_phase

        # Tracing has been disabled: behave exactly like the original build.
        if _original_pass_call is None:
            return original_build(*args, **kwargs)

        mod = args[0] if args else kwargs.get("mod")
        before_text = str(mod)

        with _lock:
            previous_phase = _current_phase
            _current_phase = "codegen"

        after_text = ""
        try:
            try:
                result = original_build(*args, **kwargs)
            except Exception as e:
                with _lock:
                    idx = _pass_index
                    _pass_index += 1
                    # Same record name as the printed label and the success path.
                    record = LowerRecord(
                        phase="codegen",
                        name="codegen",
                        index=idx,
                        before_text=before_text,
                        after_text="",
                        changed=False,
                        status=STATUS_FAILED,
                        error_msg=str(e),
                    )
                    _records.append(record)
                    _save_raw_files(record)
                print(f" {_ANSI_RED}{_TAG} codegen/{idx:02d}_codegen: FAILED ({e}){_ANSI_RESET}")
                raise

            # Capturing the output is best-effort: never fail a successful build.
            try:
                with _lock:
                    idx = _pass_index
                    _pass_index += 1
                after_text = _inspect_module_source(result)
                add_count, del_count = _count_diff(before_text, after_text)

                with _lock:
                    record = LowerRecord(
                        phase="codegen",
                        name="codegen",
                        index=idx,
                        before_text=before_text,
                        after_text=after_text,
                        changed=before_text != after_text,
                        add_lines=add_count,
                        del_lines=del_count,
                        status=STATUS_CODEGEN,
                    )
                    _records.append(record)
                    _save_raw_files(record)
                print(
                    f" {_TAG} codegen/{idx:02d}_codegen: {_ANSI_BLUE}CODEGEN{_ANSI_RESET} (+{add_count}/-{del_count})"
                )
            except Exception as exc:
                print(f" {_ANSI_RED}{_TAG} WARNING: post-codegen tracing failed: {exc}{_ANSI_RESET}")
        finally:
            with _lock:
                _current_phase = previous_phase

        if after_text:
            print_diff(before_text, after_text, "codegen (TIR before)", "codegen (C++ after)")
        return result

    return wrapper
|
|
568
|
+
|
|
569
|
+
|
|
570
|
+
# ── Enable / disable ──────────────────────────────────────────────────────────


def enable(*, trace_dir: str | None = None):
    """Install the three layers of monkey-patch hooks and start tracing passes.

    Idempotent: calling it again while the hooks are installed does not
    install them a second time.

    Parameters
    ----------
    trace_dir : str, optional
        Root output directory; defaults to ``./tmp/lower_trace``.

    Raises
    ------
    RuntimeError
        If tvm is not installed. tvm is a hard dependency for tracing
        (intercepting ``Pass.__call__`` requires it); install tilelang, which
        ships tvm, first.
    """
    global _trace_dir, _original_pass_call, _legacy_patched

    if trace_dir is not None:
        _trace_dir = str(trace_dir)

    # tvm is a hard dependency: without it Pass.__call__ cannot be intercepted.
    try:
        from tvm.ir.transform import Pass
    except ImportError as exc:
        raise RuntimeError(
            "lower_trace requires tvm to enable tracing. "
            "Install tilelang first: pip install tilelang (or use the tilelang-ascend environment)."
        ) from exc

    # Layer 1: intercept every Pass.__call__.
    if _original_pass_call is None:
        _original_pass_call = Pass.__call__
        Pass.__call__ = _traced_pass_call

    # Layer 3: intercept the codegen FFIs registered in this build.
    if not _original_codegen_ffis:
        _ffi = _get_tvm_ffi()
        for ffi_name in _CODEGEN_FFI_NAMES:
            try:
                # Most names are backend-specific and absent from any given
                # build; ask for None so missing backends are skipped quietly
                # instead of each producing a warning.
                try:
                    orig = _ffi.get_global_func(ffi_name, allow_missing=True)
                except TypeError:
                    # Signature without ``allow_missing``.
                    orig = _ffi.get_global_func(ffi_name)
                if orig is not None:
                    wrapped = _wrap_codegen_ffi(orig, ffi_name)
                    _original_codegen_ffis[ffi_name] = orig
                    _ffi.register_global_func(ffi_name, wrapped, override=True)
            except Exception as exc:
                print(f"{_TAG} WARNING: could not wrap codegen FFI {ffi_name}: {exc}")

    if _legacy_patched:
        return

    # Layer 2: wrap the phase functions (tilelang specific).
    # Without tilelang, degrade gracefully: passes are still traced (all tagged
    # unscoped) but there are no phase labels.
    try:
        import tilelang.engine.lower as lower_mod

        lower_func = lower_mod.lower
        patch_mod = lower_mod
    except (ImportError, AttributeError):
        try:
            from tilelang.engine import lower as lower_func

            import tilelang.engine as patch_mod
        except (ImportError, AttributeError) as e:
            print(
                f"{_TAG} WARNING: tilelang engine not found ({e}); phase tracing disabled (passes will be tagged 'unscoped')."
            )
            return

    try:
        from tilelang.engine import phase as phase_module
    except ImportError:
        phase_module = None

    glbls = getattr(lower_func, "__globals__", None)
    # Namespaces patched via setattr below. If ``lower``'s globals are one of
    # them (e.g. ``patch_mod`` is the module that defines ``lower``), patching
    # that dict again would save the *wrapper* as the "original", and
    # disable() could then reinstall it.
    setattr_namespaces = [vars(patch_mod)]
    if phase_module is not None:
        setattr_namespaces.append(vars(phase_module))
    patch_globals = glbls is not None and not any(glbls is ns for ns in setattr_namespaces)

    phase_funcs = _discover_phases(lower_func)
    for i, phase_func in enumerate(phase_funcs):
        wrapped = _wrap_phase(phase_func, i + 1, len(phase_funcs))
        name = phase_func.__name__

        # Save the original in every namespace patched (up to three) so
        # disable() can restore it.
        _legacy_phase_originals.append((patch_mod, name, getattr(patch_mod, name, _MISSING), False))
        setattr(patch_mod, name, wrapped)
        if phase_module is not None and hasattr(phase_module, name):
            _legacy_phase_originals.append(
                (phase_module, name, getattr(phase_module, name, _MISSING), False)
            )
            setattr(phase_module, name, wrapped)
        if patch_globals and name in glbls:
            _legacy_phase_originals.append((glbls, name, glbls[name], True))
            glbls[name] = wrapped

    _legacy_patched = True
    print(f"{_TAG} IR pass tracing enabled ({len(phase_funcs)} phases, trace_dir={_trace_dir}).")
|
|
667
|
+
|
|
668
|
+
|
|
669
|
+
def disable():
    """Uninstall all monkey-patch hooks and restore the original behaviour.

    Fully reversible: restores ``Pass.__call__``, the codegen FFIs and the
    phase functions, then resets the collected records and state.

    Phase-function patches are undone in reverse (LIFO) order. ``enable()``
    may record the same namespace twice — for example a ``setattr`` on a module
    followed by the same module's ``__dict__`` reached through
    ``lower.__globals__``. The second entry's saved "original" is then the
    wrapper itself. Undoing in reverse order lets the earliest (true) original
    win.
    """
    global _original_pass_call, _legacy_patched, _legacy_phase_originals
    global _run_counter, _run_dir, _trace_dir

    if _original_pass_call is not None:
        from tvm.ir.transform import Pass

        Pass.__call__ = _original_pass_call
        _original_pass_call = None

    # Restore the codegen FFIs (the FFI module is only needed if any were wrapped).
    if _original_codegen_ffis:
        _ffi = _get_tvm_ffi()
        for ffi_name, orig in _original_codegen_ffis.items():
            with contextlib.suppress(Exception):
                _ffi.register_global_func(ffi_name, orig, override=True)
        _original_codegen_ffis.clear()

    # Restore the phase functions (up to three locations each), newest first.
    for target, name, original, is_dict in reversed(_legacy_phase_originals):
        with contextlib.suppress(Exception):
            if original is _MISSING:
                # The attribute did not exist before patching: remove it.
                if is_dict:
                    del target[name]
                else:
                    delattr(target, name)
            elif is_dict:
                target[name] = original
            else:
                setattr(target, name, original)
    _legacy_phase_originals = []

    _legacy_patched = False
    _run_counter = 0
    _run_dir = None
    _trace_dir = _DEFAULT_TRACE_DIR
    reset()
|
|
712
|
+
|
|
713
|
+
|
|
714
|
+
def reset():
    """Drop the collected records and rewind the pass counter.

    ``_records`` is rebound to a new empty list (the old list object is left
    untouched) and ``_current_phase`` is cleared.

    ``_run_dir`` is deliberately kept. Clearing it would scatter the passes of
    one run across several directories, because pre-pipeline passes create
    ``_run_dir`` before any phase starts. A new run gets a fresh ``_run_dir``
    from ``_wrap_phase`` once ``_run_counter > 1``.
    """
    global _records, _current_phase, _pass_index

    _pass_index = 0
    _current_phase = None
    _records = []
|
|
725
|
+
|
|
726
|
+
|
|
727
|
+
# ── Convenience context manager ───────────────────────────────────────────────


@contextlib.contextmanager
def trace(*, trace_dir: str | None = None):
    """Context manager that calls ``enable()`` on entry and ``disable()`` on exit.

    ``trace_dir`` is forwarded to ``enable()``. The hooks are removed even if
    the body raises.

    Example::

        with trace():
            kernel = tilelang.jit(func)()
    """
    enable(trace_dir=trace_dir)
    try:
        yield None
    finally:
        # Always uninstall the hooks, including when compilation failed.
        disable()
|