weco 0.2.5__tar.gz → 0.2.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {weco-0.2.5 → weco-0.2.6}/.github/workflows/release.yml +2 -2
- {weco-0.2.5 → weco-0.2.6}/PKG-INFO +47 -10
- {weco-0.2.5 → weco-0.2.6}/README.md +45 -8
- weco-0.2.6/examples/cuda/evaluate.py +157 -0
- weco-0.2.6/examples/cuda/guide.md +113 -0
- weco-0.2.6/examples/cuda/optimize.py +44 -0
- {weco-0.2.5/examples/simple-torch → weco-0.2.6/examples/hello-kernel-world}/evaluate.py +32 -17
- {weco-0.2.5/examples/simple-mlx → weco-0.2.6/examples/metal}/evaluate.py +28 -20
- weco-0.2.5/examples/simple-mlx/metal-examples.rst → weco-0.2.6/examples/metal/examples.rst +2 -1
- weco-0.2.6/examples/metal/optimize.py +28 -0
- weco-0.2.6/examples/triton/evaluate.py +153 -0
- weco-0.2.6/examples/triton/optimize.py +44 -0
- {weco-0.2.5 → weco-0.2.6}/pyproject.toml +2 -2
- {weco-0.2.5 → weco-0.2.6}/weco/__init__.py +1 -1
- {weco-0.2.5 → weco-0.2.6}/weco/cli.py +6 -1
- {weco-0.2.5 → weco-0.2.6}/weco/panels.py +12 -6
- {weco-0.2.5 → weco-0.2.6}/weco.egg-info/PKG-INFO +47 -10
- weco-0.2.6/weco.egg-info/SOURCES.txt +27 -0
- weco-0.2.5/examples/simple-mlx/optimize.py +0 -26
- weco-0.2.5/weco.egg-info/SOURCES.txt +0 -22
- {weco-0.2.5 → weco-0.2.6}/.github/workflows/lint.yml +0 -0
- {weco-0.2.5 → weco-0.2.6}/.gitignore +0 -0
- {weco-0.2.5 → weco-0.2.6}/LICENSE +0 -0
- {weco-0.2.5/examples/simple-torch → weco-0.2.6/examples/hello-kernel-world}/optimize.py +0 -0
- {weco-0.2.5 → weco-0.2.6}/setup.cfg +0 -0
- {weco-0.2.5 → weco-0.2.6}/weco/api.py +0 -0
- {weco-0.2.5 → weco-0.2.6}/weco/utils.py +0 -0
- {weco-0.2.5 → weco-0.2.6}/weco.egg-info/dependency_links.txt +0 -0
- {weco-0.2.5 → weco-0.2.6}/weco.egg-info/entry_points.txt +0 -0
- {weco-0.2.5 → weco-0.2.6}/weco.egg-info/requires.txt +0 -0
- {weco-0.2.5 → weco-0.2.6}/weco.egg-info/top_level.txt +0 -0
{weco-0.2.5 → weco-0.2.6}/.github/workflows/release.yml

@@ -90,7 +90,7 @@ jobs:
   GITHUB_TOKEN: ${{ github.token }}
 run: >-
   gh release create
-  'v0.2.5'
+  'v0.2.6'
   --repo '${{ github.repository }}'
   --notes ""

@@ -102,5 +102,5 @@ jobs:
 # sigstore-produced signatures and certificates.
 run: >-
   gh release upload
-  'v0.2.5' dist/**
+  'v0.2.6' dist/**
   --repo '${{ github.repository }}'
{weco-0.2.5 → weco-0.2.6}/PKG-INFO

@@ -1,8 +1,8 @@
 Metadata-Version: 2.4
 Name: weco
-Version: 0.2.5
+Version: 0.2.6
 Summary: Documentation for `weco`, a CLI for using Weco AI's code optimizer.
-Author-email: Weco AI Team <
+Author-email: Weco AI Team <contact@weco.ai>
 License: MIT
 Project-URL: Homepage, https://github.com/WecoAI/weco-cli
 Keywords: AI,Code Optimization,Code Generation

@@ -99,32 +99,69 @@ Here's how `weco` can be applied to common ML engineering tasks:
 
 ### Examples
 
-**Example 1: Optimizing PyTorch operations**
+**Example 1: Optimizing simple PyTorch operations**
 
 ```bash
-
-
+cd examples/hello-kernel-world
+pip install torch
+weco --source optimize.py \
+--eval-command "python evaluate.py --solution-path optimize.py --device cpu" \
 --metric speedup \
 --maximize true \
 --steps 15 \
---model
+--model claude-3-7-sonnet-20250219 \
 --additional-instructions "Fuse operations in the forward method while ensuring the max float deviation remains small. Maintain the same format of the code."
 ```
 
+Note that if you have an NVIDIA GPU, change the device to `cuda`. If you are running on Apple Silicon, set it to `mps`.
+
 **Example 2: Optimizing MLX operations with instructions from a file**
 
-Sometimes, additional context or instructions are too complex for a single command-line string. You can provide a path to a file containing these instructions.
+Let's optimize a 2D convolution operation in [`mlx`](https://github.com/ml-explore/mlx) using [Metal](https://developer.apple.com/documentation/metal/). Sometimes, additional context or instructions are too complex for a single command-line string. You can provide a path to a file containing these instructions.
 
 ```bash
-
-
+cd examples/metal
+pip install mlx
+weco --source optimize.py \
+--eval-command "python evaluate.py --solution-path optimize.py" \
 --metric speedup \
 --maximize true \
 --steps 30 \
 --model o3-mini \
---additional-instructions examples
+--additional-instructions examples.rst
 ```
 
+**Example 3: Level-Agnostic Optimization: Causal Self-Attention with Triton & CUDA**
+
+Given how useful causal multi-head self-attention is to transformers, it has been widely adopted across ML engineering and AI research. It's great to keep things high-level (in PyTorch) when doing research, but when moving to production you often need to write highly customized low-level kernels to make things run as fast as they can. The `weco` CLI can optimize kernels across a variety of abstraction levels and frameworks. Example 2 uses Metal, but let's explore two more frameworks:
+
+1. [Triton](https://github.com/triton-lang/triton)
+```bash
+cd examples/triton
+pip install torch triton
+weco --source optimize.py \
+--eval-command "python evaluate.py --solution-path optimize.py" \
+--metric speedup \
+--maximize true \
+--steps 30 \
+--model gemini-2.5-pro-preview-03-25 \
+--additional-instructions "Use triton to optimize the code while ensuring a small max float diff. Maintain the same code format."
+```
+
+2. [CUDA](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html)
+```bash
+cd examples/cuda
+pip install torch
+weco --source optimize.py \
+--eval-command "python evaluate.py --solution-path optimize.py" \
+--metric speedup \
+--maximize true \
+--steps 30 \
+--model gemini-2.5-pro-preview-03-25 \
+--additional-instructions guide.md
+```
+
+
 ---
 
 ### Command Line Arguments
{weco-0.2.5 → weco-0.2.6}/README.md

@@ -77,32 +77,69 @@ Here's how `weco` can be applied to common ML engineering tasks:
 
 ### Examples
 
-**Example 1: Optimizing PyTorch operations**
+**Example 1: Optimizing simple PyTorch operations**
 
 ```bash
-
-
+cd examples/hello-kernel-world
+pip install torch
+weco --source optimize.py \
+--eval-command "python evaluate.py --solution-path optimize.py --device cpu" \
 --metric speedup \
 --maximize true \
 --steps 15 \
---model
+--model claude-3-7-sonnet-20250219 \
 --additional-instructions "Fuse operations in the forward method while ensuring the max float deviation remains small. Maintain the same format of the code."
 ```
 
+Note that if you have an NVIDIA GPU, change the device to `cuda`. If you are running on Apple Silicon, set it to `mps`.
+
 **Example 2: Optimizing MLX operations with instructions from a file**
 
-Sometimes, additional context or instructions are too complex for a single command-line string. You can provide a path to a file containing these instructions.
+Let's optimize a 2D convolution operation in [`mlx`](https://github.com/ml-explore/mlx) using [Metal](https://developer.apple.com/documentation/metal/). Sometimes, additional context or instructions are too complex for a single command-line string. You can provide a path to a file containing these instructions.
 
 ```bash
-
-
+cd examples/metal
+pip install mlx
+weco --source optimize.py \
+--eval-command "python evaluate.py --solution-path optimize.py" \
 --metric speedup \
 --maximize true \
 --steps 30 \
 --model o3-mini \
---additional-instructions examples
+--additional-instructions examples.rst
 ```
 
+**Example 3: Level-Agnostic Optimization: Causal Self-Attention with Triton & CUDA**
+
+Given how useful causal multi-head self-attention is to transformers, it has been widely adopted across ML engineering and AI research. It's great to keep things high-level (in PyTorch) when doing research, but when moving to production you often need to write highly customized low-level kernels to make things run as fast as they can. The `weco` CLI can optimize kernels across a variety of abstraction levels and frameworks. Example 2 uses Metal, but let's explore two more frameworks:
+
+1. [Triton](https://github.com/triton-lang/triton)
+```bash
+cd examples/triton
+pip install torch triton
+weco --source optimize.py \
+--eval-command "python evaluate.py --solution-path optimize.py" \
+--metric speedup \
+--maximize true \
+--steps 30 \
+--model gemini-2.5-pro-preview-03-25 \
+--additional-instructions "Use triton to optimize the code while ensuring a small max float diff. Maintain the same code format."
+```
+
+2. [CUDA](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html)
+```bash
+cd examples/cuda
+pip install torch
+weco --source optimize.py \
+--eval-command "python evaluate.py --solution-path optimize.py" \
+--metric speedup \
+--maximize true \
+--steps 30 \
+--model gemini-2.5-pro-preview-03-25 \
+--additional-instructions guide.md
+```
+
+
 ---
 
 ### Command Line Arguments
weco-0.2.6/examples/cuda/evaluate.py (new file, 157 lines)

import time
import sys
import os
import pathlib
import importlib
import traceback
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


########################################################
# Baseline
########################################################
class Model(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    """

    def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop, max_seqlen):
        super().__init__()
        assert n_embd % n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        # output projection
        self.c_proj = nn.Linear(n_embd, n_embd)
        # regularization
        self.attn_dropout = nn.Dropout(attn_pdrop)
        self.resid_dropout = nn.Dropout(resid_pdrop)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer("bias", torch.tril(torch.ones(max_seqlen, max_seqlen)).view(1, 1, max_seqlen, max_seqlen))
        self.n_head = n_head
        self.n_embd = n_embd

    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y


########################################################
# Weco Solution
########################################################
def load_module_from_path(module_path: str, add_to_sys_modules: bool = False):
    # Clean out all old compiled extensions to prevent namespace collisions during build
    module_path = pathlib.Path(module_path)
    name = module_path.stem
    spec = importlib.util.spec_from_file_location(name, module_path)
    mod = importlib.util.module_from_spec(spec)  # type: ignore
    if add_to_sys_modules:
        sys.modules[name] = mod
    spec.loader.exec_module(mod)  # type: ignore
    return mod


########################################################
# Benchmark
########################################################
os.environ["MAX_JOBS"] = "1"  # number of workers for building with ninja


def get_inputs(batch_size, seq_len, n_embd, device):
    return torch.randn(batch_size, seq_len, n_embd, device=device, dtype=torch.float32)


def bench(f, inputs, n_warmup, n_rep):
    with torch.no_grad():
        # warmup
        for _ in range(n_warmup):
            f(inputs)  # noqa

        # benchmark
        t_avg = 0.0
        for _ in range(n_rep):
            torch.cuda.empty_cache()  # Clear cache before timing
            start_time = time.time()
            f(inputs)
            torch.cuda.synchronize()  # Wait for all computations to complete
            t_avg += time.time() - start_time
        t_avg /= n_rep * 1e-3
        return t_avg


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--solution-path", type=str, required=True)
    args = parser.parse_args()

    # benchmarking parameters
    n_correctness_trials = 10
    n_warmup = 1000
    n_rep = 5000

    # init parameters
    max_seqlen = 512
    seq_len = 256
    n_embd = 768
    n_head = 8
    # turn off dropout to measure correctness well
    attn_pdrop = 0.0
    resid_pdrop = 0.0

    # input parameters
    batch_size = 32

    # load solution module
    try:
        torch.manual_seed(0)
        solution_module = load_module_from_path(args.solution_path, add_to_sys_modules=False)
        solution_model = solution_module.Model(
            n_embd=n_embd, n_head=n_head, attn_pdrop=attn_pdrop, resid_pdrop=resid_pdrop, max_seqlen=max_seqlen
        ).to("cuda")
        assert isinstance(solution_model, nn.Module)
    except Exception:
        print(f"Candidate module initialization failed: {traceback.format_exc()}")
        exit(1)

    torch.manual_seed(0)
    baseline_model = Model(
        n_embd=n_embd, n_head=n_head, attn_pdrop=attn_pdrop, resid_pdrop=resid_pdrop, max_seqlen=max_seqlen
    ).to("cuda")

    # measure correctness
    max_diff_avg = 0
    for _ in range(n_correctness_trials):
        inputs = get_inputs(batch_size=batch_size, seq_len=seq_len, n_embd=n_embd, device="cuda")
        with torch.no_grad():
            baseline_output = baseline_model(inputs)
            optimized_output = solution_model(inputs)
        max_diff_avg += torch.max(torch.abs(optimized_output - baseline_output))
    max_diff_avg /= n_correctness_trials
    print(f"max float diff between values of baseline and optimized model: {max_diff_avg}")

    # measure performance
    inputs = get_inputs(batch_size=batch_size, seq_len=seq_len, n_embd=n_embd, device="cuda")
    t_avg_baseline = bench(baseline_model, inputs, n_warmup, n_rep)
    print(f"baseline time: {t_avg_baseline:.2f}ms")
    t_avg_optimized = bench(solution_model, inputs, n_warmup, n_rep)
    print(f"optimized time: {t_avg_optimized:.2f}ms")
    print(f"speedup: {t_avg_baseline / t_avg_optimized:.2f}x")
weco-0.2.6/examples/cuda/guide.md (new file, 113 lines)

# Writing In-line CUDA Kernels: 101

This document outlines the strategy to improve speedup by writing fused and optimized CUDA kernels using a single-file implementation.

## Requirements

- **Single-File Implementation:** Develop fused CUDA kernels within one file.
- **No Fallback Implementation:** Do not include any alternative or fallback code.
- **Simplicity & Readability:** Write simple, easy-to-understand code and include clear comments.
- **Avoid Templates:** Use plain fused kernel functions without templates.
- **Multiple Kernels Allowed:** You can define more than one kernel in the file if needed.
- **Model Class Requirement:** The solution must include a class `Model` (an instance of `nn.Module`), with the main computation in its `forward` method.
- **Preserve Initialization:** Do not change the initialization of the `Model` class.
- **Focus on Efficiency:** Concentrate solely on efficient PyTorch and CUDA coding without capturing logs.
- **Error Handling:** Any terminal output or errors will be reviewed by an LLM for feedback.

## GPU Hardware Specifications

Here are some details on the hardware you have access to.

```json
{
    "GPU Architecture": "Ampere",
    "GPU Memory": "40GB",
    "Memory Bandwidth": "1935 GB/s",
    "FP64 TFLOPS": "9.7",
    "FP64 Tensor Core TFLOPS": "19.5",
    "FP32 TFLOPS": "19.5",
    "TF32 Tensor Core TFLOPS": "156 (312 with sparsity)",
    "BFLOAT16 Tensor Core TFLOPS": "312 (624 with sparsity)",
    "FP16 Tensor Core TFLOPS": "312 (624 with sparsity)",
    "INT8 Tensor Core TOPS": "624 (1248 with sparsity)",
    "Register File Size": "64K 32-bit registers per SM",
    "Maximum number of registers per thread": "255",
    "Maximum number of thread blocks per SM": "32",
    "Shared memory capacity per SM": "164 KB",
    "Maximum shared memory per thread block": "163 KB"
}
```

## Baseline Code

The baseline implementation of the `Model` class simply performs an element-wise addition.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, a, b):
        return a + b
```

## Optimized Code

The optimized version employs a custom CUDA kernel for fused element-wise addition. The kernel is defined and compiled inline using PyTorch's `load_inline`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.cpp_extension import load_inline

# Define the custom CUDA kernel for element-wise addition
elementwise_add_source = '''
#include <torch/extension.h>
#include <cuda_runtime.h>

// CUDA kernel for element-wise addition
__global__ void elementwise_add_kernel(const float* a, const float* b, float* out, int size) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < size) {
        out[idx] = a[idx] + b[idx];
    }
}

// Launch function for the CUDA kernel
torch::Tensor elementwise_add_cuda(torch::Tensor a, torch::Tensor b) {
    auto size = a.numel();
    auto out = torch::zeros_like(a);
    const int block_size = 256;
    const int num_blocks = (size + block_size - 1) / block_size;
    elementwise_add_kernel<<<num_blocks, block_size>>>(a.data_ptr<float>(), b.data_ptr<float>(), out.data_ptr<float>(), size);
    return out;
}
'''

# C++ function prototype declaration
elementwise_add_cpp_source = "torch::Tensor elementwise_add_cuda(torch::Tensor a, torch::Tensor b);"

# Compile the inline CUDA code for element-wise addition
elementwise_add = load_inline(
    name="elementwise_add",
    cpp_sources=elementwise_add_cpp_source,
    cuda_sources=elementwise_add_source,
    functions=["elementwise_add_cuda"],
    verbose=True,
    extra_cflags=[""],
    extra_ldflags=[""],
)

class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.elementwise_add = elementwise_add

    def forward(self, a, b):
        return self.elementwise_add.elementwise_add_cuda(a, b)
```
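As a quick illustration (not part of the package), here is a minimal usage sketch of the guide's fused-addition `Model` once `load_inline` has compiled the extension. The tensor size and the `allclose` check are assumptions added for the example; it requires a CUDA-capable GPU and a working CUDA toolchain.

```python
import torch
# assumes the optimized Model class from the guide above is already defined in this session

model = Model()                               # the inline extension was compiled at load_inline time
a = torch.randn(1_000_000, device="cuda")     # illustrative size
b = torch.randn(1_000_000, device="cuda")
out = model(a, b)                             # dispatches to elementwise_add_cuda
torch.cuda.synchronize()                      # wait for the kernel before validating
assert torch.allclose(out, a + b)
```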
weco-0.2.6/examples/cuda/optimize.py (new file, 44 lines)

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class Model(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    """

    def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop, max_seqlen):
        super().__init__()
        assert n_embd % n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        # output projection
        self.c_proj = nn.Linear(n_embd, n_embd)
        # regularization
        self.attn_dropout = nn.Dropout(attn_pdrop)
        self.resid_dropout = nn.Dropout(resid_pdrop)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer("bias", torch.tril(torch.ones(max_seqlen, max_seqlen)).view(1, 1, max_seqlen, max_seqlen))
        self.n_head = n_head
        self.n_embd = n_embd

    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y
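For orientation (this snippet is not in the package), the masked softmax(QKᵀ/√d)V sequence in the baseline `forward` above is exactly the computation that fused kernels target; in stock PyTorch 2.x the same math can already be expressed as a single fused call, shown here as a reference point with illustrative shapes matching the example's 768-dim, 8-head setup.

```python
import torch
import torch.nn.functional as F

B, n_head, T, head_dim = 32, 8, 256, 96  # illustrative shapes
q = torch.randn(B, n_head, T, head_dim, device="cuda")
k = torch.randn(B, n_head, T, head_dim, device="cuda")
v = torch.randn(B, n_head, T, head_dim, device="cuda")

# Causal attention in one call; PyTorch selects a fused kernel when one is available.
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```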
{weco-0.2.5/examples/simple-torch → weco-0.2.6/examples/hello-kernel-world}/evaluate.py

@@ -28,10 +28,10 @@ class Model(nn.Module):
         Returns:
             torch.Tensor: Output tensor of shape (batch_size, hidden_size).
         """
-        x = torch.matmul(x, self.weight.T)
-        x = x / 2
-        x = torch.sum(x, dim=1, keepdim=True)
-        x = x * self.scaling_factor
+        x = torch.matmul(x, self.weight.T)
+        x = x / 2
+        x = torch.sum(x, dim=1, keepdim=True)
+        x = x * self.scaling_factor
         return x


@@ -60,15 +60,33 @@ def get_inputs(B, N, device):
     return torch.randn(B, N, device=device, dtype=torch.float32)


+@torch.no_grad()
 def bench(f, inputs, n_warmup, n_rep):
+    # Warm up
     for _ in range(n_warmup):
        f(inputs)  # noqa

+    # Benchmark
+    device_type = inputs.device.type
     t_avg = 0.0
     for _ in range(n_rep):
+        # Clear cache before timing
+        if device_type == "cuda":
+            torch.cuda.empty_cache()
+        elif device_type == "mps":
+            torch.mps.empty_cache()
+
+        # time forward pass
         start_time = time.time()
         f(inputs)
         t_avg += time.time() - start_time
+
+        # Synchronize after each iteration
+        if device_type == "cuda":
+            torch.cuda.synchronize()
+        elif device_type == "mps":
+            torch.mps.synchronize()
+
     t_avg /= n_rep * 1e-3
     return t_avg

@@ -81,14 +99,19 @@ if __name__ == "__main__":
     parser.add_argument("--device", default="cpu", type=str)
     args = parser.parse_args()

+    # benchmark parameters
+    n_correctness_trials = 10
+    n_warmup = 1000
+    n_rep = 5000
+
     # init and input parameters
-
+    batch_size, input_size, hidden_size, scaling_factor = 128, 10, 20, 1.5

     # load solution module
     try:
         torch.manual_seed(0)
         solution_module = load_module_from_path(args.solution_path, add_to_sys_modules=False)
-        solution_model = solution_module.Model(
+        solution_model = solution_module.Model(input_size, hidden_size, scaling_factor).to(args.device)
         assert isinstance(solution_model, nn.Module)
         assert hasattr(solution_model, "forward")
     except Exception:

@@ -96,13 +119,12 @@ if __name__ == "__main__":
         exit(1)

     torch.manual_seed(0)
-    baseline_model = Model(
+    baseline_model = Model(input_size, hidden_size, scaling_factor).to(args.device)

     # measure correctness
-    n_correctness_trials = 10
     max_diff_avg = 0
     for _ in range(n_correctness_trials):
-        inputs = get_inputs(
+        inputs = get_inputs(batch_size, input_size, args.device)
         baseline_output = baseline_model(inputs)
         optimized_output = solution_model(inputs)
         max_diff_avg += torch.max(torch.abs(optimized_output - baseline_output))

@@ -110,16 +132,9 @@ if __name__ == "__main__":
     print(f"max float diff between values of baseline and optimized model: {max_diff_avg}")

     # measure performance
-    inputs = get_inputs(
-    n_warmup = 100
-    n_rep = 500
-
-    # baseline
+    inputs = get_inputs(batch_size, input_size, args.device)
     t_avg_baseline = bench(baseline_model, inputs, n_warmup, n_rep)
     print(f"baseline time: {t_avg_baseline:.2f}ms")
-
-    # optimized
     t_avg_optimized = bench(solution_model, inputs, n_warmup, n_rep)
     print(f"optimized time: {t_avg_optimized:.2f}ms")
-
     print(f"speedup: {t_avg_baseline / t_avg_optimized:.2f}x")
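The updated `bench` above times with `time.time()` plus an explicit per-iteration synchronize for the active device. As an aside (not part of this diff), on CUDA the same measurement is often done with CUDA events, which time the device directly rather than host wall-clock time; a minimal sketch:

```python
import torch

def bench_with_cuda_events(f, inputs, n_warmup=100, n_rep=500):
    # warm up so lazy initialization and caching do not pollute the measurement
    for _ in range(n_warmup):
        f(inputs)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(n_rep):
        f(inputs)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / n_rep  # average milliseconds per call
```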
{weco-0.2.5/examples/simple-mlx → weco-0.2.6/examples/metal}/evaluate.py

@@ -5,6 +5,7 @@ import importlib
 import traceback
 import mlx.core as mx
 import mlx.nn as nn
+from typing import Union


 ########################################################

@@ -12,26 +13,27 @@ import mlx.nn as nn
 ########################################################
 class Model(nn.Module):
     """
-    Model that performs a
+    Model that performs a 2D convolution.
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        kernel_size (Union[int, tuple]): Size of the convolution kernel.
+        stride (Union[int, tuple]): Stride of the convolution. Default is 1.
     """

-    def __init__(self,
+    def __init__(self, in_channels: int, out_channels: int, kernel_size: Union[int, tuple], stride: Union[int, tuple] = 1):
         super(Model, self).__init__()
-        self.
-        self.scaling_factor = scaling_factor
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

     def __call__(self, x):
         """
         Args:
-            x (mx.array): Input tensor of shape (batch_size,
+            x (mx.array): Input tensor of shape (batch_size, height, width, in_channels).
         Returns:
-            mx.array: Output tensor of shape (batch_size,
+            mx.array: Output tensor of shape (batch_size, height, width, out_channels).
         """
-
-        x = x / 2  # Divide
-        x = mx.sum(x, axis=1, keepdims=True)  # Sum
-        x = x * self.scaling_factor  # Scaling
-        return x
+        return self.conv(x)


 ########################################################

@@ -52,9 +54,9 @@ def load_module_from_path(module_path: str, add_to_sys_modules: bool = False):
 ########################################################
 # Benchmark
 ########################################################
-def get_inputs(
+def get_inputs(batch_size, img_height, img_width, img_channels):
     # MLX doesn't use device parameter like PyTorch, as it automatically uses Metal
-    return mx.random.normal(shape=(
+    return mx.random.normal(shape=(batch_size, img_height, img_width, img_channels), dtype=mx.float32)


 def bench(f, inputs, n_warmup, n_rep):

@@ -86,7 +88,13 @@ if __name__ == "__main__":
     args = parser.parse_args()

     # init and input parameters
-
+    batch_size = 4
+    img_height = 224
+    img_width = 224
+    img_channels = 3
+    out_channels = 64
+    kernel_size = 3
+    stride = 1

     # Set the default device to 0
     mx.set_default_device(mx.gpu)

@@ -95,20 +103,20 @@ if __name__ == "__main__":
     try:
         mx.random.seed(0)
         solution_module = load_module_from_path(args.solution_path, add_to_sys_modules=False)
-        solution_model = solution_module.Model(
+        solution_model = solution_module.Model(img_channels, out_channels, kernel_size, stride)
         assert hasattr(solution_model, "__call__")
     except Exception:
         print(f"Candidate module initialization failed: {traceback.format_exc()}")
         exit(1)

     mx.random.seed(0)
-    baseline_model = Model(
+    baseline_model = Model(img_channels, out_channels, kernel_size, stride)

     # measure correctness
     n_correctness_trials = 10
     max_diff_avg = 0
     for _ in range(n_correctness_trials):
-        inputs = get_inputs(
+        inputs = get_inputs(batch_size, img_height, img_width, img_channels)
         baseline_output = baseline_model(inputs)
         optimized_output = solution_model(inputs)
         max_diff = mx.max(mx.abs(optimized_output - baseline_output))

@@ -118,9 +126,9 @@ if __name__ == "__main__":
     print(f"max float diff between values of baseline and optimized model: {max_diff_avg}")

     # measure performance
-    inputs = get_inputs(
-    n_warmup =
-    n_rep =
+    inputs = get_inputs(batch_size, img_height, img_width, img_channels)
+    n_warmup = 1000
+    n_rep = 5000

     # baseline
     t_avg_baseline = bench(baseline_model, inputs, n_warmup, n_rep)
{weco-0.2.5/examples/simple-mlx/metal-examples.rst → weco-0.2.6/examples/metal/examples.rst}

@@ -3,7 +3,8 @@
 Custom Metal Kernels
 ====================

-MLX supports writing custom Metal kernels through the Python and C++ APIs.
+MLX supports writing custom Metal kernels through the Python and C++ APIs. Using Metal kernels allows us to optimize the performance of our models by writing low-level parallelization and vectorization schemes and fusing operations efficiently.
+One important thing to keep in mind is correctness. The optimized code should produce outputs that are identical or at least numerically close to the reference implementation, i.e., a max float difference of less than 1e-5.

 When designing a custom kernel, ensure that you maintain the format of having a 'Model' class with a '__call__' method as this is what will be called to evaluate the solution.

weco-0.2.6/examples/metal/optimize.py (new file, 28 lines)

import mlx.core as mx  # noqa
import mlx.nn as nn
from typing import Union


class Model(nn.Module):
    """
    Model that performs a 2D convolution.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (Union[int, tuple]): Size of the convolution kernel.
        stride (Union[int, tuple]): Stride of the convolution. Default is 1.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: Union[int, tuple], stride: Union[int, tuple] = 1):
        super(Model, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def __call__(self, x):
        """
        Args:
            x (mx.array): Input tensor of shape (batch_size, height, width, in_channels).
        Returns:
            mx.array: Output tensor of shape (batch_size, height, width, out_channels).
        """
        return self.conv(x)
weco-0.2.6/examples/triton/evaluate.py (new file, 153 lines)

import time
import sys
import pathlib
import importlib
import traceback
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


########################################################
# Baseline
########################################################
class Model(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    """

    def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop, max_seqlen):
        super().__init__()
        assert n_embd % n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        # output projection
        self.c_proj = nn.Linear(n_embd, n_embd)
        # regularization
        self.attn_dropout = nn.Dropout(attn_pdrop)
        self.resid_dropout = nn.Dropout(resid_pdrop)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer("bias", torch.tril(torch.ones(max_seqlen, max_seqlen)).view(1, 1, max_seqlen, max_seqlen))
        self.n_head = n_head
        self.n_embd = n_embd

    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y


########################################################
# Weco Solution
########################################################
def load_module_from_path(module_path: str, add_to_sys_modules: bool = False):
    # Clean out all old compiled extensions to prevent namespace collisions during build
    module_path = pathlib.Path(module_path)
    name = module_path.stem
    spec = importlib.util.spec_from_file_location(name, module_path)
    mod = importlib.util.module_from_spec(spec)  # type: ignore
    if add_to_sys_modules:
        sys.modules[name] = mod
    spec.loader.exec_module(mod)  # type: ignore
    return mod


########################################################
# Benchmark
########################################################
def get_inputs(batch_size, seq_len, n_embd, device):
    return torch.randn(batch_size, seq_len, n_embd, device=device, dtype=torch.float32)


@torch.no_grad()
def bench(f, inputs, n_warmup, n_rep):
    # warmup
    for _ in range(n_warmup):
        f(inputs)  # noqa

    # benchmark
    t_avg = 0.0
    for _ in range(n_rep):
        torch.cuda.empty_cache()  # Clear cache before timing
        start_time = time.time()
        f(inputs)
        torch.cuda.synchronize()  # Wait for all computations to complete
        t_avg += time.time() - start_time
    t_avg /= n_rep * 1e-3
    return t_avg


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--solution-path", type=str, required=True)
    args = parser.parse_args()

    # benchmarking parameters
    n_correctness_trials = 10
    n_warmup = 1000
    n_rep = 5000

    # init parameters
    max_seqlen = 512
    seq_len = 256
    n_embd = 768
    n_head = 8
    # turn off dropout to measure correctness well
    attn_pdrop = 0.0
    resid_pdrop = 0.0

    # input parameters
    batch_size = 32

    # load solution module
    try:
        torch.manual_seed(0)
        solution_module = load_module_from_path(args.solution_path, add_to_sys_modules=False)
        solution_model = solution_module.Model(
            n_embd=n_embd, n_head=n_head, attn_pdrop=attn_pdrop, resid_pdrop=resid_pdrop, max_seqlen=max_seqlen
        ).to("cuda")
        assert isinstance(solution_model, nn.Module)
    except Exception:
        print(f"Candidate module initialization failed: {traceback.format_exc()}")
        exit(1)

    torch.manual_seed(0)
    baseline_model = Model(
        n_embd=n_embd, n_head=n_head, attn_pdrop=attn_pdrop, resid_pdrop=resid_pdrop, max_seqlen=max_seqlen
    ).to("cuda")

    # measure correctness
    max_diff_avg = 0
    for _ in range(n_correctness_trials):
        inputs = get_inputs(batch_size=batch_size, seq_len=seq_len, n_embd=n_embd, device="cuda")
        with torch.no_grad():
            baseline_output = baseline_model(inputs)
            optimized_output = solution_model(inputs)
        max_diff_avg += torch.max(torch.abs(optimized_output - baseline_output))
    max_diff_avg /= n_correctness_trials
    print(f"max float diff between values of baseline and optimized model: {max_diff_avg}")

    # measure performance
    inputs = get_inputs(batch_size=batch_size, seq_len=seq_len, n_embd=n_embd, device="cuda")
    t_avg_baseline = bench(baseline_model, inputs, n_warmup, n_rep)
    print(f"baseline time: {t_avg_baseline:.2f}ms")
    t_avg_optimized = bench(solution_model, inputs, n_warmup, n_rep)
    print(f"optimized time: {t_avg_optimized:.2f}ms")
    print(f"speedup: {t_avg_baseline / t_avg_optimized:.2f}x")
weco-0.2.6/examples/triton/optimize.py (new file, 44 lines)

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class Model(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    """

    def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop, max_seqlen):
        super().__init__()
        assert n_embd % n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        # output projection
        self.c_proj = nn.Linear(n_embd, n_embd)
        # regularization
        self.attn_dropout = nn.Dropout(attn_pdrop)
        self.resid_dropout = nn.Dropout(resid_pdrop)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer("bias", torch.tril(torch.ones(max_seqlen, max_seqlen)).view(1, 1, max_seqlen, max_seqlen))
        self.n_head = n_head
        self.n_embd = n_embd

    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y
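`examples/triton/optimize.py` above ships only the PyTorch baseline; the CLI is expected to rewrite its hot path as Triton kernels. For readers new to Triton, the canonical tutorial-style element-wise kernel below (not taken from this package) illustrates the general shape of such a kernel; the block size of 1024 is an arbitrary choice for the sketch.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # each program instance handles one BLOCK_SIZE-wide slice of the tensors
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = out.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=1024)
    return out
```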
{weco-0.2.5 → weco-0.2.6}/pyproject.toml

@@ -6,11 +6,11 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "weco"
 authors = [
-    {name = "Weco AI Team", email = "
+    {name = "Weco AI Team", email = "contact@weco.ai"},
 ]
 description = "Documentation for `weco`, a CLI for using Weco AI's code optimizer."
 readme = "README.md"
-version = "0.2.5"
+version = "0.2.6"
 license = {text = "MIT"}
 requires-python = ">=3.12"
 dependencies = ["requests", "rich"]
{weco-0.2.5 → weco-0.2.6}/weco/cli.py

@@ -321,7 +321,12 @@ def main() -> None:
     _, best_solution_panel = solution_panels.get_display(current_step=steps)

     # Update the end optimization layout
-
+    final_message = (
+        f"{summary_panel.metric_name.capitalize()} {'maximized' if summary_panel.maximize else 'minimized'}! Best solution {summary_panel.metric_name.lower()} = [green]{status_response['best_result']['metric_value']}[/] 🏆"
+        if best_solution_node is not None
+        else "[red] No solution found.[/]"
+    )
+    end_optimization_layout["summary"].update(summary_panel.get_display(final_message=final_message))
     end_optimization_layout["tree"].update(tree_panel.get_display())
     end_optimization_layout["best_solution"].update(best_solution_panel)

{weco-0.2.5 → weco-0.2.6}/weco/panels.py

@@ -12,7 +12,9 @@ class SummaryPanel:
     """Holds a summary of the optimization session."""

     def __init__(self, maximize: bool, metric_name: str, total_steps: int, model: str, session_id: str = None):
-        self.
+        self.maximize = maximize
+        self.metric_name = metric_name
+        self.goal = ("Maximizing" if self.maximize else "Minimizing") + f" {self.metric_name}..."
         self.total_input_tokens = 0
         self.total_output_tokens = 0
         self.total_steps = total_steps

@@ -39,24 +41,28 @@ class SummaryPanel:
         self.total_input_tokens += usage["input_tokens"]
         self.total_output_tokens += usage["output_tokens"]

-    def get_display(self) -> Panel:
+    def get_display(self, final_message: Optional[str] = None) -> Panel:
         """Create a summary panel with the relevant information."""
         layout = Layout(name="summary")
         summary_table = Table(show_header=False, box=None, padding=(0, 1))
         # Goal
-
+        if final_message is not None:
+            summary_table.add_row(f"[bold cyan]Result:[/] {final_message}")
+        else:
+            summary_table.add_row(f"[bold cyan]Goal:[/] {self.goal}")
+        summary_table.add_row("")
+        # Model used
+        summary_table.add_row(f"[bold cyan]Model:[/] {self.model}")
         summary_table.add_row("")
         # Log directory
         runs_dir = f".runs/{self.session_id}"
         summary_table.add_row(f"[bold cyan]Logs:[/] [blue underline]{runs_dir}[/]")
         summary_table.add_row("")
-        # Model used
-        summary_table.add_row(f"[bold cyan]Model:[/] [yellow]{self.model}[/]")
-        summary_table.add_row("")
         # Token counts
         summary_table.add_row(
             f"[bold cyan]Tokens:[/] ↑[yellow]{format_number(self.total_input_tokens)}[/] ↓[yellow]{format_number(self.total_output_tokens)}[/] = [green]{format_number(self.total_input_tokens + self.total_output_tokens)}[/]"
         )
+        summary_table.add_row("")
         # Progress bar
         summary_table.add_row(self.progress)

{weco-0.2.5 → weco-0.2.6}/weco.egg-info/PKG-INFO

(Same two hunks as the PKG-INFO diff shown above: the version bump to 0.2.6, the author email, and the rewritten Examples section.)
weco-0.2.6/weco.egg-info/SOURCES.txt (new file, 27 lines)

.gitignore
LICENSE
README.md
pyproject.toml
.github/workflows/lint.yml
.github/workflows/release.yml
examples/cuda/evaluate.py
examples/cuda/guide.md
examples/cuda/optimize.py
examples/hello-kernel-world/evaluate.py
examples/hello-kernel-world/optimize.py
examples/metal/evaluate.py
examples/metal/examples.rst
examples/metal/optimize.py
examples/triton/evaluate.py
examples/triton/optimize.py
weco/__init__.py
weco/api.py
weco/cli.py
weco/panels.py
weco/utils.py
weco.egg-info/PKG-INFO
weco.egg-info/SOURCES.txt
weco.egg-info/dependency_links.txt
weco.egg-info/entry_points.txt
weco.egg-info/requires.txt
weco.egg-info/top_level.txt
weco-0.2.5/examples/simple-mlx/optimize.py (deleted, 26 lines)

import mlx.core as mx
import mlx.nn as nn


class Model(nn.Module):
    """
    Model that performs a matrix multiplication, division, summation, and scaling.
    """

    def __init__(self, input_size, hidden_size, scaling_factor):
        super(Model, self).__init__()
        self.weight = mx.random.normal(shape=(hidden_size, input_size))
        self.scaling_factor = scaling_factor

    def __call__(self, x):
        """
        Args:
            x (mx.array): Input tensor of shape (batch_size, input_size).
        Returns:
            mx.array: Output tensor of shape (batch_size, hidden_size).
        """
        x = mx.matmul(x, mx.transpose(self.weight))
        x = x / 2
        x = mx.sum(x, axis=1, keepdims=True)
        x = x * self.scaling_factor
        return x
weco-0.2.5/weco.egg-info/SOURCES.txt (deleted, 22 lines)

.gitignore
LICENSE
README.md
pyproject.toml
.github/workflows/lint.yml
.github/workflows/release.yml
examples/simple-mlx/evaluate.py
examples/simple-mlx/metal-examples.rst
examples/simple-mlx/optimize.py
examples/simple-torch/evaluate.py
examples/simple-torch/optimize.py
weco/__init__.py
weco/api.py
weco/cli.py
weco/panels.py
weco/utils.py
weco.egg-info/PKG-INFO
weco.egg-info/SOURCES.txt
weco.egg-info/dependency_links.txt
weco.egg-info/entry_points.txt
weco.egg-info/requires.txt
weco.egg-info/top_level.txt
The remaining files listed above with +0 -0 are unchanged between weco-0.2.5 and weco-0.2.6.