wafer_core-0.1.31-py3-none-any.whl → wafer_core-0.1.32-py3-none-any.whl
This diff shows the contents of publicly released versions of the package as published to a supported registry, and is provided for informational purposes only.
- wafer_core/tools/dispatch_baseline/__init__.py +73 -0
- wafer_core/tools/dispatch_baseline/analyzer.py +174 -0
- wafer_core/tools/dispatch_baseline/client.py +196 -0
- wafer_core/tools/dispatch_baseline/codegen.py +246 -0
- wafer_core/tools/dispatch_baseline/dtypes.py +217 -0
- wafer_core/tools/dispatch_baseline/executor.py +360 -0
- wafer_core/tools/dispatch_baseline/roofline.py +165 -0
- wafer_core/utils/kernel_utils/defense.py +812 -10
- wafer_core/utils/kernel_utils/test_reward_hacks.py +140 -0
- {wafer_core-0.1.31.dist-info → wafer_core-0.1.32.dist-info}/METADATA +1 -1
- {wafer_core-0.1.31.dist-info → wafer_core-0.1.32.dist-info}/RECORD +12 -4
- {wafer_core-0.1.31.dist-info → wafer_core-0.1.32.dist-info}/WHEEL +0 -0
wafer_core/tools/dispatch_baseline/__init__.py
@@ -0,0 +1,73 @@
"""Kernel trace tool for discovering SOTA kernels.

Given a PyTorch operation, traces what kernel PyTorch dispatches to on target hardware.

Usage:
    from wafer_core.tools.dispatch_baseline import quick_trace

    result = quick_trace(
        "torch.matmul(A, B)",
        {"A": (4096, 4096), "B": (4096, 4096)},
        hardware="H100",
    )
    print(result.summary())
"""

from wafer_core.tools.dispatch_baseline.client import (
    lookup_baseline,
    store_baseline,
)
from wafer_core.tools.dispatch_baseline.codegen import (
    generate_trace_script,
    parse_op_string,
    update_dtypes,
    update_shapes,
)
from wafer_core.tools.dispatch_baseline.dtypes import (
    HardwareSpec,
    KernelInfo,
    KernelTraceConfig,
    KernelTraceResult,
    OpSpec,
    RooflineAnalysis,
    TensorSpec,
)
from wafer_core.tools.dispatch_baseline.executor import (
    TraceExecutionResult,
    quick_trace,
    trace_kernel_local,
    trace_kernel_remote,
)
from wafer_core.tools.dispatch_baseline.roofline import (
    HARDWARE_SPECS,
    compute_roofline,
    get_hardware_spec,
)

__all__ = [
    # Data types
    "HardwareSpec",
    "KernelInfo",
    "KernelTraceConfig",
    "KernelTraceResult",
    "OpSpec",
    "RooflineAnalysis",
    "TensorSpec",
    "TraceExecutionResult",
    # Codegen
    "generate_trace_script",
    "parse_op_string",
    "update_dtypes",
    "update_shapes",
    # Execution
    "quick_trace",
    "trace_kernel_local",
    "trace_kernel_remote",
    # Database
    "lookup_baseline",
    "store_baseline",
    # Roofline
    "HARDWARE_SPECS",
    "compute_roofline",
    "get_hardware_spec",
]
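For orientation, a minimal sketch of using the re-exported API follows. The call pattern is taken from the module docstring above; the result fields (primary_kernel and its name/duration_us) are assumed from the KernelInfo/KernelTraceResult usage elsewhere in this diff, so treat this as illustrative rather than the package's documented interface.

    from wafer_core.tools.dispatch_baseline import quick_trace

    # Trace a softmax op; the shapes dict keys match the tensor names in the op string.
    result = quick_trace(
        "torch.nn.functional.softmax(x, dim=-1)",
        {"x": (8192, 8192)},
        hardware="H100",
    )
    if result.primary_kernel is not None:
        # KernelInfo carries the dispatched kernel name and its average duration in microseconds.
        print(result.primary_kernel.name, result.primary_kernel.duration_us)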
wafer_core/tools/dispatch_baseline/analyzer.py
@@ -0,0 +1,174 @@
"""Analyzer for kernel trace output.

Parses profiler output to extract kernel information and identify the primary kernel.
"""

import json
import re
from dataclasses import dataclass
from typing import Any

from wafer_core.tools.dispatch_baseline.dtypes import KernelInfo, KernelTraceResult, OpSpec


@dataclass(frozen=True)
class ParsedTraceResult:
    """Parsed trace output with environment info for caching."""

    result: KernelTraceResult
    pytorch_version: str
    runtime_version: str
    gpu_arch: str


def parse_trace_output(output: str, op_spec: OpSpec, hardware: str) -> ParsedTraceResult:
    """Parse kernel trace script output and extract kernel information.

    Args:
        output: Raw output from the trace script (stdout)
        op_spec: The operation that was traced
        hardware: Target hardware name

    Returns:
        ParsedTraceResult with extracted kernel information and environment info
    """
    # Look for our JSON marker in the output
    json_match = re.search(r"KERNEL_TRACE_RESULT_JSON:(.+)$", output, re.MULTILINE)

    if not json_match:
        return ParsedTraceResult(
            result=KernelTraceResult(
                op_spec=op_spec,
                hardware=hardware,
                kernels=[],
                primary_kernel=None,
                raw_output=output,
                error="Could not find KERNEL_TRACE_RESULT_JSON marker in output",
            ),
            pytorch_version="unknown",
            runtime_version="unknown",
            gpu_arch="unknown",
        )

    try:
        result_json = json.loads(json_match.group(1))
    except json.JSONDecodeError as e:
        return ParsedTraceResult(
            result=KernelTraceResult(
                op_spec=op_spec,
                hardware=hardware,
                kernels=[],
                primary_kernel=None,
                raw_output=output,
                error=f"Failed to parse JSON: {e}",
            ),
            pytorch_version="unknown",
            runtime_version="unknown",
            gpu_arch="unknown",
        )

    # Extract environment info
    env_info = result_json.get("environment", {})
    pytorch_version = env_info.get("pytorch_version", "unknown")
    runtime_version = env_info.get("runtime_version", "unknown")
    gpu_arch = env_info.get("gpu_arch", "unknown")

    # Extract kernels from JSON
    kernels = []
    for k in result_json.get("kernels", []):
        kernel_info = _parse_kernel_dict(k)
        if kernel_info:
            kernels.append(kernel_info)

    # Identify primary kernel (longest duration)
    primary_kernel = kernels[0] if kernels else None

    return ParsedTraceResult(
        result=KernelTraceResult(
            op_spec=op_spec,
            hardware=hardware,
            kernels=kernels,
            primary_kernel=primary_kernel,
            raw_output=output,
        ),
        pytorch_version=pytorch_version,
        runtime_version=runtime_version,
        gpu_arch=gpu_arch,
    )


def _parse_kernel_dict(kernel_dict: dict[str, Any]) -> KernelInfo | None:
    """Parse a kernel dictionary into KernelInfo."""
    name = kernel_dict.get("name", "")
    if not name:
        return None

    return KernelInfo(
        name=name,
        duration_us=kernel_dict.get("duration_us", 0.0),
    )


def extract_kernels_from_nsys(nsys_output: str) -> list[dict[str, Any]]:
    """Extract kernel information from nsys output.

    This is a fallback for when torch.profiler doesn't capture all kernels.

    Args:
        nsys_output: Output from `nsys stats` or similar

    Returns:
        List of kernel dictionaries
    """
    kernels = []

    # Look for GPU kernel summary lines
    # Format varies, but typically: name, duration, count, etc.
    lines = nsys_output.split("\n")

    for line in lines:
        # Skip headers and empty lines
        if not line.strip() or "---" in line or "Name" in line:
            continue

        # Try to parse as nsys kernel line
        # This is a simplified parser - real nsys output varies
        parts = line.split()
        if len(parts) >= 2:
            # Heuristic: last column is often the kernel name
            name = parts[-1]
            if any(c.isalpha() for c in name) and "_" in name:
                kernels.append({"name": name, "duration_us": 0.0})

    return kernels


def merge_kernel_infos(
    profiler_kernels: list[KernelInfo], nsys_kernels: list[dict[str, Any]]
) -> list[KernelInfo]:
    """Merge kernel info from multiple sources.

    Prefers profiler data when available, but adds any missing kernels from nsys.

    Args:
        profiler_kernels: Kernels from torch.profiler
        nsys_kernels: Kernels from nsys

    Returns:
        Merged list of KernelInfo objects
    """
    seen_names = {k.name for k in profiler_kernels}
    result = list(profiler_kernels)

    for nsys_kernel in nsys_kernels:
        name = nsys_kernel.get("name", "")
        if name and name not in seen_names:
            result.append(
                KernelInfo(
                    name=name,
                    duration_us=nsys_kernel.get("duration_us", 0.0),
                )
            )
            seen_names.add(name)

    return result
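The parser above only depends on the KERNEL_TRACE_RESULT_JSON marker line emitted by the generated trace script. A hedged sketch of that contract, using made-up sample values (the OpSpec/TensorSpec constructor arguments are assumed from their usage elsewhere in this diff):

    import json

    from wafer_core.tools.dispatch_baseline.analyzer import parse_trace_output
    from wafer_core.tools.dispatch_baseline.dtypes import OpSpec, TensorSpec

    # Simulated stdout from the auto-generated trace script; kernel name and timing are placeholders.
    payload = {
        "kernels": [{"name": "example_gemm_kernel", "duration_us": 812.4}],
        "environment": {"pytorch_version": "2.4.0", "runtime_version": "12.4", "gpu_arch": "sm_90"},
    }
    stdout = "some unrelated log line\nKERNEL_TRACE_RESULT_JSON:" + json.dumps(payload)

    op_spec = OpSpec(
        op="torch.matmul",
        inputs=[
            TensorSpec(name="A", shape=(4096, 4096), dtype="float16"),
            TensorSpec(name="B", shape=(4096, 4096), dtype="float16"),
        ],
        kwargs={},
    )
    parsed = parse_trace_output(stdout, op_spec, hardware="H100")
    assert parsed.gpu_arch == "sm_90"
    assert parsed.result.primary_kernel.name == "example_gemm_kernel"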
wafer_core/tools/dispatch_baseline/client.py
@@ -0,0 +1,196 @@
"""API client for baseline kernel trace results.

Interacts with the Supabase baselines table to store and retrieve kernel dispatch info.
Results are keyed by (op, shape, dtype, pytorch_version, runtime_version, gpu_arch).
Since kernel dispatch is deterministic for a given environment, this serves as a
shared lookup table across all users.
"""

import logging

import httpx

from wafer_core.tools.dispatch_baseline.dtypes import (
    KernelInfo,
    KernelTraceResult,
    OpSpec,
)

logger = logging.getLogger(__name__)

# API timeout for baseline lookups
API_TIMEOUT = 30.0


def _get_api_config() -> tuple[str, dict[str, str]]:
    """Get API URL and auth headers.

    Returns:
        Tuple of (api_url, headers)
    """
    # Import here to avoid circular imports and allow use without CLI
    try:
        from wafer.api_client import get_api_url
        from wafer.auth import get_auth_headers

        return get_api_url(), get_auth_headers()
    except ImportError:
        # If CLI not installed, return defaults (won't work without auth)
        return "https://api.wafer.ai", {}


def lookup_baseline(
    op_spec: OpSpec,
    hardware: str,
    pytorch_version: str,
    runtime_version: str,
    gpu_arch: str,
) -> KernelTraceResult | None:
    """Look up baseline result from the database.

    Args:
        op_spec: Operation specification
        hardware: Hardware name (for display, not part of cache key)
        pytorch_version: PyTorch version string
        runtime_version: CUDA version string
        gpu_arch: GPU architecture (e.g., "sm_90")

    Returns:
        KernelTraceResult if found, None otherwise
    """
    api_url, headers = _get_api_config()

    if not headers.get("Authorization"):
        logger.debug("No auth headers, skipping baseline lookup")
        return None

    # Build request
    request_data = {
        "op": op_spec.op,
        "inputs": [
            {"name": t.name, "shape": list(t.shape), "dtype": t.dtype}
            for t in op_spec.inputs
        ],
        "kwargs": op_spec.kwargs,
        "pytorch_version": pytorch_version,
        "runtime_version": runtime_version,
        "gpu_arch": gpu_arch,
    }

    try:
        with httpx.Client(timeout=API_TIMEOUT, headers=headers) as client:
            response = client.post(f"{api_url}/v1/baselines/lookup", json=request_data)
            response.raise_for_status()
            data = response.json()

        if not data.get("found"):
            return None

        # Reconstruct result
        kernels = [
            KernelInfo(
                name=k["name"],
                duration_us=k["duration_us"],
            )
            for k in data.get("kernels", [])
        ]

        primary_data = data.get("primary_kernel")
        primary_kernel = KernelInfo(
            name=primary_data["name"],
            duration_us=primary_data["duration_us"],
        ) if primary_data else (kernels[0] if kernels else None)

        return KernelTraceResult(
            op_spec=op_spec,
            hardware=hardware,
            kernels=kernels,
            primary_kernel=primary_kernel,
            # Note: roofline will be recomputed by the caller
        )

    except httpx.HTTPError as e:
        logger.warning(f"Baseline lookup failed: {e}")
        return None
    except Exception as e:
        logger.warning(f"Baseline lookup error: {e}")
        return None


def store_baseline(
    result: KernelTraceResult,
    pytorch_version: str,
    runtime_version: str,
    gpu_arch: str,
) -> bool:
    """Store baseline result in the database.

    Args:
        result: Trace result to cache
        pytorch_version: PyTorch version string
        runtime_version: CUDA version string
        gpu_arch: GPU architecture

    Returns:
        True if stored successfully, False otherwise
    """
    if result.error:
        # Don't cache errors
        return False

    if not result.primary_kernel:
        # Don't cache empty results
        return False

    api_url, headers = _get_api_config()

    if not headers.get("Authorization"):
        logger.debug("No auth headers, skipping baseline store")
        return False

    # Build request
    request_data = {
        "op": result.op_spec.op,
        "inputs": [
            {"name": t.name, "shape": list(t.shape), "dtype": t.dtype}
            for t in result.op_spec.inputs
        ],
        "kwargs": result.op_spec.kwargs,
        "pytorch_version": pytorch_version,
        "runtime_version": runtime_version,
        "gpu_arch": gpu_arch,
        "hardware_name": result.hardware,
        "primary_kernel": {
            "name": result.primary_kernel.name,
            "duration_us": result.primary_kernel.duration_us,
        },
        "kernels": [
            {
                "name": k.name,
                "duration_us": k.duration_us,
            }
            for k in result.kernels
        ],
    }

    try:
        with httpx.Client(timeout=API_TIMEOUT, headers=headers) as client:
            response = client.post(f"{api_url}/v1/baselines/store", json=request_data)
            response.raise_for_status()
            data = response.json()

        if data.get("created"):
            logger.info(f"Stored baseline: {result.op_spec.op} ({gpu_arch})")
        else:
            logger.debug(f"Baseline already exists: {result.op_spec.op} ({gpu_arch})")

        return True

    except httpx.HTTPError as e:
        logger.warning(f"Baseline store failed: {e}")
        return False
    except Exception as e:
        logger.warning(f"Baseline store error: {e}")
        return False
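A hedged sketch of the cache flow these two functions enable: look up first, and only trace (and then store) on a miss. The environment strings and the stub KernelTraceResult below are placeholder values for illustration; in practice they would come from a real trace via the executor module.

    from wafer_core.tools.dispatch_baseline.client import lookup_baseline, store_baseline
    from wafer_core.tools.dispatch_baseline.dtypes import KernelInfo, KernelTraceResult, OpSpec, TensorSpec

    op_spec = OpSpec(
        op="torch.matmul",
        inputs=[
            TensorSpec(name="A", shape=(4096, 4096), dtype="bfloat16"),
            TensorSpec(name="B", shape=(4096, 4096), dtype="bfloat16"),
        ],
        kwargs={},
    )

    # Dispatch is deterministic for a given environment, so a hit means the trace can be skipped.
    cached = lookup_baseline(
        op_spec,
        hardware="H100",
        pytorch_version="2.4.0",
        runtime_version="12.4",
        gpu_arch="sm_90",
    )
    if cached is None:
        # Stub result standing in for a real trace (kernel name and timing are placeholders).
        traced = KernelTraceResult(
            op_spec=op_spec,
            hardware="H100",
            kernels=[KernelInfo(name="example_gemm_kernel", duration_us=800.0)],
            primary_kernel=KernelInfo(name="example_gemm_kernel", duration_us=800.0),
        )
        store_baseline(traced, pytorch_version="2.4.0", runtime_version="12.4", gpu_arch="sm_90")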
wafer_core/tools/dispatch_baseline/codegen.py
@@ -0,0 +1,246 @@
"""Code generator for kernel trace profiling scripts.

Generates minimal Python scripts that:
1. Create tensors with specified shapes/dtypes
2. Run the PyTorch op with warmup
3. Profile using torch.profiler
4. Output structured JSON with kernel info
"""

from wafer_core.tools.dispatch_baseline.dtypes import KernelTraceConfig, OpSpec, TensorSpec


def generate_trace_script(config: KernelTraceConfig) -> str:
    """Generate a profiling script for the given operation.

    Args:
        config: Kernel trace configuration with op spec and settings

    Returns:
        Python script as a string
    """
    op_spec = config.op_spec
    tensor_setup = _generate_tensor_setup(op_spec.inputs)
    op_call = _generate_op_call(op_spec)

    script = f'''"""Auto-generated kernel trace script."""
import json
import sys
import torch

# Ensure CUDA is available
assert torch.cuda.is_available(), "CUDA not available"

# Setup tensors
{tensor_setup}

# Warmup
print("Warming up...", file=sys.stderr)
for _ in range({config.num_warmup}):
    _ = {op_call}
torch.cuda.synchronize()

# Profile with torch.profiler
print("Profiling...", file=sys.stderr)
with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    record_shapes=True,
    with_stack=False,
) as prof:
    for _ in range({config.num_runs}):
        _ = {op_call}
    torch.cuda.synchronize()

# Extract kernel information
kernels = []
for event in prof.key_averages():
    if event.device_type == torch.profiler.DeviceType.CUDA:
        # Use device_time (new) or cuda_time (old) for average time in us
        duration = getattr(event, 'device_time', None) or getattr(event, 'cuda_time', 0)
        kernels.append({{
            "name": event.key,
            "duration_us": duration,
            "count": event.count,
        }})

# Sort by duration (longest first)
kernels.sort(key=lambda k: k["duration_us"], reverse=True)

# Get environment info for caching
props = torch.cuda.get_device_properties(0)

# Detect runtime version (CUDA or ROCm)
if hasattr(torch.version, 'hip') and torch.version.hip:
    runtime_version = torch.version.hip
    # ROCm uses gcnArchName for architecture (e.g., "gfx942")
    gpu_arch = getattr(props, 'gcnArchName', f"gfx{{props.major}}{{props.minor}}")
else:
    runtime_version = torch.version.cuda or "unknown"
    gpu_arch = f"sm_{{props.major}}{{props.minor}}"

env_info = {{
    "pytorch_version": torch.__version__,
    "runtime_version": runtime_version,
    "gpu_arch": gpu_arch,
    "gpu_name": props.name,
}}

# Build result
result = {{
    "op": "{op_spec.op}",
    "inputs": {_serialize_inputs(op_spec.inputs)},
    "num_runs": {config.num_runs},
    "kernels": kernels,
    "total_cuda_time_us": sum(k["duration_us"] for k in kernels),
    "environment": env_info,
}}

# Output as JSON (marker for parsing)
print("KERNEL_TRACE_RESULT_JSON:" + json.dumps(result))
'''
    return script


def _generate_tensor_setup(inputs: list[TensorSpec]) -> str:
    """Generate tensor creation code."""
    lines = []
    for tensor in inputs:
        shape_str = ", ".join(str(d) for d in tensor.shape)
        dtype_map = {
            "float16": "torch.float16",
            "float32": "torch.float32",
            "bfloat16": "torch.bfloat16",
            "float64": "torch.float64",
            "int8": "torch.int8",
            "int16": "torch.int16",
            "int32": "torch.int32",
            "int64": "torch.int64",
        }
        dtype = dtype_map.get(tensor.dtype, f"torch.{tensor.dtype}")

        # Use randn for float types, randint for int types
        if "int" in tensor.dtype:
            lines.append(
                f'{tensor.name} = torch.randint(-128, 127, ({shape_str},), '
                f'dtype={dtype}, device="{tensor.device}")'
            )
        else:
            lines.append(
                f'{tensor.name} = torch.randn({shape_str}, dtype={dtype}, device="{tensor.device}")'
            )

    return "\n".join(lines)


def _generate_op_call(op_spec: OpSpec) -> str:
    """Generate the operation call code."""
    args = [t.name for t in op_spec.inputs]
    kwargs = [f"{k}={v}" for k, v in op_spec.kwargs.items()]
    all_args = ", ".join(args + kwargs)
    return f"{op_spec.op}({all_args})"


def _serialize_inputs(inputs: list[TensorSpec]) -> str:
    """Serialize inputs for JSON output."""
    items = []
    for t in inputs:
        items.append(f'{{"name": "{t.name}", "shape": {list(t.shape)}, "dtype": "{t.dtype}"}}')
    return "[" + ", ".join(items) + "]"


def parse_op_string(op_string: str) -> OpSpec:
    """Parse a simple operation string into an OpSpec.

    Supports formats like:
    - "torch.matmul(A, B)"
    - "torch.nn.functional.softmax(x, dim=-1)"

    For shapes, use --shape flags separately.

    Args:
        op_string: Operation string like "torch.matmul(A, B)"

    Returns:
        OpSpec with parsed operation and placeholder inputs
    """
    # Extract op name and arguments
    if "(" not in op_string or ")" not in op_string:
        raise ValueError(f"Invalid op format: {op_string}. Expected: op(args)")

    op_name = op_string[: op_string.index("(")].strip()
    args_str = op_string[op_string.index("(") + 1 : op_string.rindex(")")].strip()

    # Parse arguments (simple comma split, doesn't handle nested calls)
    args = [a.strip() for a in args_str.split(",") if a.strip()]

    # Separate positional args (tensor names) from kwargs
    inputs = []
    kwargs = {}

    for arg in args:
        if "=" in arg:
            key, value = arg.split("=", 1)
            kwargs[key.strip()] = value.strip()
        else:
            # Assume it's a tensor input with default shape
            inputs.append(
                TensorSpec(
                    name=arg,
                    shape=(1024, 1024),  # Default shape, override with --shape
                    dtype="float16",
                )
            )

    return OpSpec(op=op_name, inputs=inputs, kwargs=kwargs)


def update_shapes(op_spec: OpSpec, shapes: dict[str, tuple[int, ...]]) -> OpSpec:
    """Update tensor shapes in an OpSpec.

    Args:
        op_spec: Original OpSpec
        shapes: Dict mapping tensor name to new shape

    Returns:
        New OpSpec with updated shapes
    """
    new_inputs = []
    for tensor in op_spec.inputs:
        if tensor.name in shapes:
            new_inputs.append(
                TensorSpec(
                    name=tensor.name,
                    shape=shapes[tensor.name],
                    dtype=tensor.dtype,
                    device=tensor.device,
                )
            )
        else:
            new_inputs.append(tensor)

    return OpSpec(op=op_spec.op, inputs=new_inputs, kwargs=op_spec.kwargs)


def update_dtypes(op_spec: OpSpec, dtype: str) -> OpSpec:
    """Update all tensor dtypes in an OpSpec.

    Args:
        op_spec: Original OpSpec
        dtype: New dtype for all tensors

    Returns:
        New OpSpec with updated dtypes
    """
    new_inputs = [
        TensorSpec(
            name=t.name,
            shape=t.shape,
            dtype=dtype,
            device=t.device,
        )
        for t in op_spec.inputs
    ]
    return OpSpec(op=op_spec.op, inputs=new_inputs, kwargs=op_spec.kwargs)
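Tying the codegen helpers together, a hedged sketch of the intended pipeline: parse the op string, override the placeholder shapes and dtypes, and generate a standalone profiling script to run on the target GPU. The KernelTraceConfig field names (op_spec, num_warmup, num_runs) are inferred from their use in generate_trace_script above; any other fields or defaults are not shown in this diff.

    from wafer_core.tools.dispatch_baseline.codegen import (
        generate_trace_script,
        parse_op_string,
        update_dtypes,
        update_shapes,
    )
    from wafer_core.tools.dispatch_baseline.dtypes import KernelTraceConfig

    # Parse the op string, then replace the default (1024, 1024) float16 placeholders.
    op_spec = parse_op_string("torch.matmul(A, B)")
    op_spec = update_shapes(op_spec, {"A": (8192, 8192), "B": (8192, 8192)})
    op_spec = update_dtypes(op_spec, "bfloat16")

    # Warmup/run counts are illustrative values.
    config = KernelTraceConfig(op_spec=op_spec, num_warmup=5, num_runs=20)
    script = generate_trace_script(config)
    print(script)  # run this script on the target GPU, then feed its stdout to parse_trace_output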