weco 0.2.5__tar.gz → 0.2.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. {weco-0.2.5 → weco-0.2.6}/.github/workflows/release.yml +2 -2
  2. {weco-0.2.5 → weco-0.2.6}/PKG-INFO +47 -10
  3. {weco-0.2.5 → weco-0.2.6}/README.md +45 -8
  4. weco-0.2.6/examples/cuda/evaluate.py +157 -0
  5. weco-0.2.6/examples/cuda/guide.md +113 -0
  6. weco-0.2.6/examples/cuda/optimize.py +44 -0
  7. {weco-0.2.5/examples/simple-torch → weco-0.2.6/examples/hello-kernel-world}/evaluate.py +32 -17
  8. {weco-0.2.5/examples/simple-mlx → weco-0.2.6/examples/metal}/evaluate.py +28 -20
  9. weco-0.2.5/examples/simple-mlx/metal-examples.rst → weco-0.2.6/examples/metal/examples.rst +2 -1
  10. weco-0.2.6/examples/metal/optimize.py +28 -0
  11. weco-0.2.6/examples/triton/evaluate.py +153 -0
  12. weco-0.2.6/examples/triton/optimize.py +44 -0
  13. {weco-0.2.5 → weco-0.2.6}/pyproject.toml +2 -2
  14. {weco-0.2.5 → weco-0.2.6}/weco/__init__.py +1 -1
  15. {weco-0.2.5 → weco-0.2.6}/weco/cli.py +6 -1
  16. {weco-0.2.5 → weco-0.2.6}/weco/panels.py +12 -6
  17. {weco-0.2.5 → weco-0.2.6}/weco.egg-info/PKG-INFO +47 -10
  18. weco-0.2.6/weco.egg-info/SOURCES.txt +27 -0
  19. weco-0.2.5/examples/simple-mlx/optimize.py +0 -26
  20. weco-0.2.5/weco.egg-info/SOURCES.txt +0 -22
  21. {weco-0.2.5 → weco-0.2.6}/.github/workflows/lint.yml +0 -0
  22. {weco-0.2.5 → weco-0.2.6}/.gitignore +0 -0
  23. {weco-0.2.5 → weco-0.2.6}/LICENSE +0 -0
  24. {weco-0.2.5/examples/simple-torch → weco-0.2.6/examples/hello-kernel-world}/optimize.py +0 -0
  25. {weco-0.2.5 → weco-0.2.6}/setup.cfg +0 -0
  26. {weco-0.2.5 → weco-0.2.6}/weco/api.py +0 -0
  27. {weco-0.2.5 → weco-0.2.6}/weco/utils.py +0 -0
  28. {weco-0.2.5 → weco-0.2.6}/weco.egg-info/dependency_links.txt +0 -0
  29. {weco-0.2.5 → weco-0.2.6}/weco.egg-info/entry_points.txt +0 -0
  30. {weco-0.2.5 → weco-0.2.6}/weco.egg-info/requires.txt +0 -0
  31. {weco-0.2.5 → weco-0.2.6}/weco.egg-info/top_level.txt +0 -0
{weco-0.2.5 → weco-0.2.6}/.github/workflows/release.yml
@@ -90,7 +90,7 @@ jobs:
  GITHUB_TOKEN: ${{ github.token }}
  run: >-
  gh release create
- 'v0.2.5'
+ 'v0.2.6'
  --repo '${{ github.repository }}'
  --notes ""

@@ -102,5 +102,5 @@ jobs:
  # sigstore-produced signatures and certificates.
  run: >-
  gh release upload
- 'v0.2.5' dist/**
+ 'v0.2.6' dist/**
  --repo '${{ github.repository }}'
{weco-0.2.5 → weco-0.2.6}/PKG-INFO
@@ -1,8 +1,8 @@
  Metadata-Version: 2.4
  Name: weco
- Version: 0.2.5
+ Version: 0.2.6
  Summary: Documentation for `weco`, a CLI for using Weco AI's code optimizer.
- Author-email: Weco AI Team <dhruv@weco.ai>
+ Author-email: Weco AI Team <contact@weco.ai>
  License: MIT
  Project-URL: Homepage, https://github.com/WecoAI/weco-cli
  Keywords: AI,Code Optimization,Code Generation
@@ -99,32 +99,69 @@ Here's how `weco` can be applied to common ML engineering tasks:

  ### Examples

- **Example 1: Optimizing PyTorch operations**
+ **Example 1: Optimizing simple PyTorch operations**

  ```bash
- weco --source examples/simple-torch/optimize.py \
- --eval-command "python examples/simple-torch/evaluate.py --solution-path examples/simple-torch/optimize.py --device mps" \
+ cd examples/hello-kernel-world
+ pip install torch
+ weco --source optimize.py \
+ --eval-command "python evaluate.py --solution-path optimize.py --device cpu" \
  --metric speedup \
  --maximize true \
  --steps 15 \
- --model o3-mini \
+ --model claude-3-7-sonnet-20250219 \
  --additional-instructions "Fuse operations in the forward method while ensuring the max float deviation remains small. Maintain the same format of the code."
  ```

+ Note that if you have an NVIDIA GPU, change the device to `cuda`. If you are running this on Apple Silicon, set it to `mps`.
+
  **Example 2: Optimizing MLX operations with instructions from a file**

- Sometimes, additional context or instructions are too complex for a single command-line string. You can provide a path to a file containing these instructions.
+ Let's optimize a 2D convolution operation in [`mlx`](https://github.com/ml-explore/mlx) using [Metal](https://developer.apple.com/documentation/metal/). Sometimes, additional context or instructions are too complex for a single command-line string. You can provide a path to a file containing these instructions.

  ```bash
- weco --source examples/simple-mlx/optimize.py \
- --eval-command "python examples/simple-mlx/evaluate.py --solution-path examples/simple-mlx/optimize.py" \
+ cd examples/metal
+ pip install mlx
+ weco --source optimize.py \
+ --eval-command "python evaluate.py --solution-path optimize.py" \
  --metric speedup \
  --maximize true \
  --steps 30 \
  --model o3-mini \
- --additional-instructions examples/simple-mlx/metal-examples.rst
+ --additional-instructions examples.rst
  ```

+ **Example 3: Level-Agnostic Optimization: Causal Self-Attention with Triton & CUDA**
+
+ Because causal multi-head self-attention is so useful in transformers, it has been widely adopted across ML engineering and AI research. It's great to keep things high-level (in PyTorch) during research, but moving to production often requires highly customized low-level kernels to make things run as fast as possible. The `weco` CLI can optimize kernels across different abstraction levels and frameworks. Example 2 uses Metal, but let's explore two more frameworks:
+
+ 1. [Triton](https://github.com/triton-lang/triton)
+ ```bash
+ cd examples/triton
+ pip install torch triton
+ weco --source optimize.py \
+ --eval-command "python evaluate.py --solution-path optimize.py" \
+ --metric speedup \
+ --maximize true \
+ --steps 30 \
+ --model gemini-2.5-pro-preview-03-25 \
+ --additional-instructions "Use triton to optimize the code while ensuring a small max float diff. Maintain the same code format."
+ ```
+
+ 2. [CUDA](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html)
+ ```bash
+ cd examples/cuda
+ pip install torch
+ weco --source optimize.py \
+ --eval-command "python evaluate.py --solution-path optimize.py" \
+ --metric speedup \
+ --maximize true \
+ --steps 30 \
+ --model gemini-2.5-pro-preview-03-25 \
+ --additional-instructions guide.md
+ ```
+
+
  ---

  ### Command Line Arguments
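The Example 1 note asks you to switch `--device` between `cpu`, `cuda`, and `mps` by hand. A minimal sketch of choosing the value automatically, assuming PyTorch may or may not be installed; `pick_device` is a hypothetical helper, not part of `weco` or the example scripts:

```python
def pick_device() -> str:
    """Return "cuda", "mps", or "cpu" for the --device flag of evaluate.py.

    Hypothetical helper: falls back to "cpu" when PyTorch is not installed.
    """
    try:
        import torch
    except ImportError:
        return "cpu"
    if torch.cuda.is_available():
        return "cuda"
    # torch.backends.mps only exists in newer PyTorch releases, so look it up defensively.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"


if __name__ == "__main__":
    device = pick_device()
    print(f'--eval-command "python evaluate.py --solution-path optimize.py --device {device}"')
```

CUDA is checked first because a machine that reports both is almost certainly not Apple Silicon; the CPU fallback keeps the evaluation runnable everywhere.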
{weco-0.2.5 → weco-0.2.6}/README.md
@@ -77,32 +77,69 @@ Here's how `weco` can be applied to common ML engineering tasks:

  ### Examples

- **Example 1: Optimizing PyTorch operations**
+ **Example 1: Optimizing simple PyTorch operations**

  ```bash
- weco --source examples/simple-torch/optimize.py \
- --eval-command "python examples/simple-torch/evaluate.py --solution-path examples/simple-torch/optimize.py --device mps" \
+ cd examples/hello-kernel-world
+ pip install torch
+ weco --source optimize.py \
+ --eval-command "python evaluate.py --solution-path optimize.py --device cpu" \
  --metric speedup \
  --maximize true \
  --steps 15 \
- --model o3-mini \
+ --model claude-3-7-sonnet-20250219 \
  --additional-instructions "Fuse operations in the forward method while ensuring the max float deviation remains small. Maintain the same format of the code."
  ```

+ Note that if you have an NVIDIA GPU, change the device to `cuda`. If you are running this on Apple Silicon, set it to `mps`.
+
  **Example 2: Optimizing MLX operations with instructions from a file**

- Sometimes, additional context or instructions are too complex for a single command-line string. You can provide a path to a file containing these instructions.
+ Let's optimize a 2D convolution operation in [`mlx`](https://github.com/ml-explore/mlx) using [Metal](https://developer.apple.com/documentation/metal/). Sometimes, additional context or instructions are too complex for a single command-line string. You can provide a path to a file containing these instructions.

  ```bash
- weco --source examples/simple-mlx/optimize.py \
- --eval-command "python examples/simple-mlx/evaluate.py --solution-path examples/simple-mlx/optimize.py" \
+ cd examples/metal
+ pip install mlx
+ weco --source optimize.py \
+ --eval-command "python evaluate.py --solution-path optimize.py" \
  --metric speedup \
  --maximize true \
  --steps 30 \
  --model o3-mini \
- --additional-instructions examples/simple-mlx/metal-examples.rst
+ --additional-instructions examples.rst
  ```

+ **Example 3: Level-Agnostic Optimization: Causal Self-Attention with Triton & CUDA**
+
+ Because causal multi-head self-attention is so useful in transformers, it has been widely adopted across ML engineering and AI research. It's great to keep things high-level (in PyTorch) during research, but moving to production often requires highly customized low-level kernels to make things run as fast as possible. The `weco` CLI can optimize kernels across different abstraction levels and frameworks. Example 2 uses Metal, but let's explore two more frameworks:
+
+ 1. [Triton](https://github.com/triton-lang/triton)
+ ```bash
+ cd examples/triton
+ pip install torch triton
+ weco --source optimize.py \
+ --eval-command "python evaluate.py --solution-path optimize.py" \
+ --metric speedup \
+ --maximize true \
+ --steps 30 \
+ --model gemini-2.5-pro-preview-03-25 \
+ --additional-instructions "Use triton to optimize the code while ensuring a small max float diff. Maintain the same code format."
+ ```
+
+ 2. [CUDA](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html)
+ ```bash
+ cd examples/cuda
+ pip install torch
+ weco --source optimize.py \
+ --eval-command "python evaluate.py --solution-path optimize.py" \
+ --metric speedup \
+ --maximize true \
+ --steps 30 \
+ --model gemini-2.5-pro-preview-03-25 \
+ --additional-instructions guide.md
+ ```
+
+
  ---

  ### Command Line Arguments
weco-0.2.6/examples/cuda/evaluate.py
@@ -0,0 +1,157 @@
+ import time
+ import sys
+ import os
+ import pathlib
+ import importlib
+ import traceback
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import math
+
+
+ ########################################################
+ # Baseline
+ ########################################################
+ class Model(nn.Module):
+     """
+     A vanilla multi-head masked self-attention layer with a projection at the end.
+     """
+
+     def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop, max_seqlen):
+         super().__init__()
+         assert n_embd % n_head == 0
+         # key, query, value projections for all heads, but in a batch
+         self.c_attn = nn.Linear(n_embd, 3 * n_embd)
+         # output projection
+         self.c_proj = nn.Linear(n_embd, n_embd)
+         # regularization
+         self.attn_dropout = nn.Dropout(attn_pdrop)
+         self.resid_dropout = nn.Dropout(resid_pdrop)
+         # causal mask to ensure that attention is only applied to the left in the input sequence
+         self.register_buffer("bias", torch.tril(torch.ones(max_seqlen, max_seqlen)).view(1, 1, max_seqlen, max_seqlen))
+         self.n_head = n_head
+         self.n_embd = n_embd
+
+     def forward(self, x):
+         B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
+         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+         q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
+         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+
+         # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
+         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+         att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
+         att = F.softmax(att, dim=-1)
+         att = self.attn_dropout(att)
+         y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+         y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
+         # output projection
+         y = self.resid_dropout(self.c_proj(y))
+         return y
+
+
+ ########################################################
+ # Weco Solution
+ ########################################################
+ def load_module_from_path(module_path: str, add_to_sys_modules: bool = False):
+     # Clean out all old compiled extensions to prevent namespace collisions during build
+     module_path = pathlib.Path(module_path)
+     name = module_path.stem
+     spec = importlib.util.spec_from_file_location(name, module_path)
+     mod = importlib.util.module_from_spec(spec) # type: ignore
+     if add_to_sys_modules:
+         sys.modules[name] = mod
+     spec.loader.exec_module(mod) # type: ignore
+     return mod
+
+
+ ########################################################
+ # Benchmark
+ ########################################################
+ os.environ["MAX_JOBS"] = "1" # number of workers for building with ninja
+
+
+ def get_inputs(batch_size, seq_len, n_embd, device):
+     return torch.randn(batch_size, seq_len, n_embd, device=device, dtype=torch.float32)
+
+
+ def bench(f, inputs, n_warmup, n_rep):
+     with torch.no_grad():
+         # warmup
+         for _ in range(n_warmup):
+             f(inputs) # noqa
+
+         # benchmark
+         t_avg = 0.0
+         for _ in range(n_rep):
+             torch.cuda.empty_cache() # Clear cache before timing
+             start_time = time.time()
+             f(inputs)
+             torch.cuda.synchronize() # Wait for all computations to complete
+             t_avg += time.time() - start_time
+         t_avg /= n_rep * 1e-3
+         return t_avg
+
+
+ if __name__ == "__main__":
+     import argparse
+
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--solution-path", type=str, required=True)
+     args = parser.parse_args()
+
+     # benchmarking parameters
+     n_correctness_trials = 10
+     n_warmup = 1000
+     n_rep = 5000
+
+     # init parameters
+     max_seqlen = 512
+     seq_len = 256
+     n_embd = 768
+     n_head = 8
+     # turn off dropout to measure correctness well
+     attn_pdrop = 0.0
+     resid_pdrop = 0.0
+
+     # input parameters
+     batch_size = 32
+
+     # load solution module
+     try:
+         torch.manual_seed(0)
+         solution_module = load_module_from_path(args.solution_path, add_to_sys_modules=False)
+         solution_model = solution_module.Model(
+             n_embd=n_embd, n_head=n_head, attn_pdrop=attn_pdrop, resid_pdrop=resid_pdrop, max_seqlen=max_seqlen
+         ).to("cuda")
+         assert isinstance(solution_model, nn.Module)
+     except Exception:
+         print(f"Candidate module initialization failed: {traceback.format_exc()}")
+         exit(1)
+
+     torch.manual_seed(0)
+     baseline_model = Model(
+         n_embd=n_embd, n_head=n_head, attn_pdrop=attn_pdrop, resid_pdrop=resid_pdrop, max_seqlen=max_seqlen
+     ).to("cuda")
+
+     # measure correctness
+     max_diff_avg = 0
+     for _ in range(n_correctness_trials):
+         inputs = get_inputs(batch_size=batch_size, seq_len=seq_len, n_embd=n_embd, device="cuda")
+         with torch.no_grad():
+             baseline_output = baseline_model(inputs)
+             optimized_output = solution_model(inputs)
+         max_diff_avg += torch.max(torch.abs(optimized_output - baseline_output))
+     max_diff_avg /= n_correctness_trials
+     print(f"max float diff between values of baseline and optimized model: {max_diff_avg}")
+
+     # measure performance
+     inputs = get_inputs(batch_size=batch_size, seq_len=seq_len, n_embd=n_embd, device="cuda")
+     t_avg_baseline = bench(baseline_model, inputs, n_warmup, n_rep)
+     print(f"baseline time: {t_avg_baseline:.2f}ms")
+     t_avg_optimized = bench(solution_model, inputs, n_warmup, n_rep)
+     print(f"optimized time: {t_avg_optimized:.2f}ms")
+     print(f"speedup: {t_avg_baseline / t_avg_optimized:.2f}x")
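The baseline `Model` in `examples/cuda/evaluate.py` builds a lower-triangular `bias` buffer with `torch.tril` and fills masked scores with `-inf` before the softmax, so position `t` attends only to positions `0..t`. A dependency-free sketch of that masking for a single head on small lists; it only illustrates the semantics (no projections, dropout, batching, or multiple heads), and `causal_attention` is a hypothetical name:

```python
import math
from typing import List

Matrix = List[List[float]]


def causal_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Single-head causal attention: softmax(q k^T / sqrt(d) with future positions masked) v."""
    T, d = len(q), len(q[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for t in range(T):
        # Scores against every position; future positions (s > t) get -inf,
        # mirroring att.masked_fill(bias == 0, float("-inf")).
        scores = [
            sum(qi * ki for qi, ki in zip(q[t], k[s])) * scale if s <= t else -math.inf
            for s in range(T)
        ]
        # Numerically stable softmax; exp(-inf) == 0.0 removes masked positions.
        m = max(scores)
        weights = [math.exp(x - m) for x in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        # Weighted sum of the value rows.
        out.append([sum(w * v[s][j] for s, w in enumerate(weights)) for j in range(len(v[0]))])
    return out
```

With all-zero queries and keys, every unmasked score is equal, so output row `t` is simply the mean of the first `t + 1` value rows; row 0 always equals `v[0]`.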
weco-0.2.6/examples/cuda/guide.md
@@ -0,0 +1,113 @@
+ # Writing In-line CUDA Kernels: 101
+
+ This document outlines the strategy to improve speedup by writing fused and optimized CUDA kernels using a single-file implementation.
+
+ ## Requirements
+
+ - **Single-File Implementation:** Develop fused CUDA kernels within one file.
+ - **No Fallback Implementation:** Do not include any alternative or fallback code.
+ - **Simplicity & Readability:** Write simple, easy-to-understand code and include clear comments.
+ - **Avoid Templates:** Use plain fused kernel functions without templates.
+ - **Multiple Kernels Allowed:** You can define more than one kernel in the file if needed.
+ - **Model Class Requirement:** The solution must include a class `Model` (a subclass of `nn.Module`), with the main computation in its `forward` method.
+ - **Preserve Initialization:** Do not change the initialization of the `Model` class.
+ - **Focus on Efficiency:** Concentrate solely on efficient PyTorch and CUDA coding without capturing logs.
+ - **Error Handling:** Any terminal output or errors will be reviewed by an LLM for feedback.
+
+ ## GPU Hardware Specifications
+
+ Here are some details on the hardware you have access to.
+
+ ```json
+ {
+     "GPU Architecture": "Ampere",
+     "GPU Memory": "40GB",
+     "Memory Bandwidth": "1935 GB/s",
+     "FP64 TFLOPS": "9.7",
+     "FP64 Tensor Core TFLOPS": "19.5",
+     "FP32 TFLOPS": "19.5",
+     "TF32 Tensor Core TFLOPS": "156 (312 with sparsity)",
+     "BFLOAT16 Tensor Core TFLOPS": "312 (624 with sparsity)",
+     "FP16 Tensor Core TFLOPS": "312 (624 with sparsity)",
+     "INT8 Tensor Core TOPS": "624 (1248 with sparsity)",
+     "Register File Size": "64K 32-bit registers per SM",
+     "Maximum number of registers per thread": "255",
+     "Maximum number of thread blocks per SM": "32",
+     "Shared memory capacity per SM": "164 KB",
+     "Maximum shared memory per thread block": "163 KB"
+ }
+ ```
+
+ ## Baseline Code
+
+ The baseline implementation of the `Model` class simply performs an element-wise addition.
+
+ ```python
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class Model(nn.Module):
+     def __init__(self) -> None:
+         super().__init__()
+
+     def forward(self, a, b):
+         return a + b
+ ```
+
+ ## Optimized Code
+
+ The optimized version employs a custom CUDA kernel for fused element-wise addition. The kernel is defined and compiled inline using PyTorch's `load_inline`.
+
+ ```python
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.cpp_extension import load_inline
+
+ # Define the custom CUDA kernel for element-wise addition
+ elementwise_add_source = '''
+ #include <torch/extension.h>
+ #include <cuda_runtime.h>
+
+ // CUDA kernel for element-wise addition
+ __global__ void elementwise_add_kernel(const float* a, const float* b, float* out, int size) {
+     int idx = blockIdx.x * blockDim.x + threadIdx.x;
+     if (idx < size) {
+         out[idx] = a[idx] + b[idx];
+     }
+ }
+
+ // Launch function for the CUDA kernel
+ torch::Tensor elementwise_add_cuda(torch::Tensor a, torch::Tensor b) {
+     auto size = a.numel();
+     auto out = torch::zeros_like(a);
+     const int block_size = 256;
+     const int num_blocks = (size + block_size - 1) / block_size;
+     elementwise_add_kernel<<<num_blocks, block_size>>>(a.data_ptr<float>(), b.data_ptr<float>(), out.data_ptr<float>(), size);
+     return out;
+ }
+ '''
+
+ # C++ function prototype declaration
+ elementwise_add_cpp_source = "torch::Tensor elementwise_add_cuda(torch::Tensor a, torch::Tensor b);"
+
+ # Compile the inline CUDA code for element-wise addition
+ elementwise_add = load_inline(
+     name="elementwise_add",
+     cpp_sources=elementwise_add_cpp_source,
+     cuda_sources=elementwise_add_source,
+     functions=["elementwise_add_cuda"],
+     verbose=True,
+     extra_cflags=[""],
+     extra_ldflags=[""],
+ )
+
+ class Model(nn.Module):
+     def __init__(self) -> None:
+         super().__init__()
+         self.elementwise_add = elementwise_add
+
+     def forward(self, a, b):
+         return self.elementwise_add.elementwise_add_cuda(a, b)
+ ```
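The launch function in `guide.md` rounds the grid size up with `(size + block_size - 1) / block_size` and guards each thread with `if (idx < size)`. A small Python simulation of that 1D indexing, assuming one element per thread as in `elementwise_add_kernel`, shows why both pieces are needed: the rounded-up grid covers every element, and the guard discards the surplus threads in the last block. `simulate_launch` is a hypothetical name for illustration only.

```python
from typing import List, Tuple


def simulate_launch(size: int, block_size: int = 256) -> Tuple[int, List[int]]:
    """Return (num_blocks, indices written) for a 1D grid with one element per thread."""
    # Ceiling division, as in the guide's launch function.
    num_blocks = (size + block_size - 1) // block_size
    written = []
    for block_idx in range(num_blocks):
        for thread_idx in range(block_size):
            idx = block_idx * block_size + thread_idx  # blockIdx.x * blockDim.x + threadIdx.x
            if idx < size:  # bounds guard; surplus threads do nothing
                written.append(idx)
    return num_blocks, written


num_blocks, written = simulate_launch(1000)
print(num_blocks, num_blocks * 256 - len(written))  # prints: 4 24
```

For 1000 elements, plain integer division would launch only 3 blocks (768 threads) and skip the tail; rounding up launches 4 blocks (1024 threads), and the guard keeps the 24 extra threads from writing out of bounds.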
weco-0.2.6/examples/cuda/optimize.py
@@ -0,0 +1,44 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import math
+
+
+ class Model(nn.Module):
+     """
+     A vanilla multi-head masked self-attention layer with a projection at the end.
+     """
+
+     def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop, max_seqlen):
+         super().__init__()
+         assert n_embd % n_head == 0
+         # key, query, value projections for all heads, but in a batch
+         self.c_attn = nn.Linear(n_embd, 3 * n_embd)
+         # output projection
+         self.c_proj = nn.Linear(n_embd, n_embd)
+         # regularization
+         self.attn_dropout = nn.Dropout(attn_pdrop)
+         self.resid_dropout = nn.Dropout(resid_pdrop)
+         # causal mask to ensure that attention is only applied to the left in the input sequence
+         self.register_buffer("bias", torch.tril(torch.ones(max_seqlen, max_seqlen)).view(1, 1, max_seqlen, max_seqlen))
+         self.n_head = n_head
+         self.n_embd = n_embd
+
+     def forward(self, x):
+         B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
+         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+         q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
+         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+
+         # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
+         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+         att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
+         att = F.softmax(att, dim=-1)
+         att = self.attn_dropout(att)
+         y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+         y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
+         # output projection
+         y = self.resid_dropout(self.c_proj(y))
+         return y
{weco-0.2.5/examples/simple-torch → weco-0.2.6/examples/hello-kernel-world}/evaluate.py
@@ -28,10 +28,10 @@ class Model(nn.Module):
          Returns:
              torch.Tensor: Output tensor of shape (batch_size, hidden_size).
          """
-         x = torch.matmul(x, self.weight.T) # Gemm
-         x = x / 2 # Divide
-         x = torch.sum(x, dim=1, keepdim=True) # Sum
-         x = x * self.scaling_factor # Scaling
+         x = torch.matmul(x, self.weight.T)
+         x = x / 2
+         x = torch.sum(x, dim=1, keepdim=True)
+         x = x * self.scaling_factor
          return x


@@ -60,15 +60,33 @@ def get_inputs(B, N, device):
      return torch.randn(B, N, device=device, dtype=torch.float32)


+ @torch.no_grad()
  def bench(f, inputs, n_warmup, n_rep):
+     # Warm up
      for _ in range(n_warmup):
          f(inputs) # noqa

+     # Benchmark
+     device_type = inputs.device.type
      t_avg = 0.0
      for _ in range(n_rep):
+         # Clear cache before timing
+         if device_type == "cuda":
+             torch.cuda.empty_cache()
+         elif device_type == "mps":
+             torch.mps.empty_cache()
+
+         # time forward pass
          start_time = time.time()
          f(inputs)
          t_avg += time.time() - start_time
+
+         # Synchronize after each iteration
+         if device_type == "cuda":
+             torch.cuda.synchronize()
+         elif device_type == "mps":
+             torch.mps.synchronize()
+
      t_avg /= n_rep * 1e-3
      return t_avg

@@ -81,14 +99,19 @@ if __name__ == "__main__":
      parser.add_argument("--device", default="cpu", type=str)
      args = parser.parse_args()

+     # benchmark parameters
+     n_correctness_trials = 10
+     n_warmup = 1000
+     n_rep = 5000
+
      # init and input parameters
-     B, N, H, S = 128, 10, 20, 1.5
+     batch_size, input_size, hidden_size, scaling_factor = 128, 10, 20, 1.5

      # load solution module
      try:
          torch.manual_seed(0)
          solution_module = load_module_from_path(args.solution_path, add_to_sys_modules=False)
-         solution_model = solution_module.Model(N, H, S).to(args.device)
+         solution_model = solution_module.Model(input_size, hidden_size, scaling_factor).to(args.device)
          assert isinstance(solution_model, nn.Module)
          assert hasattr(solution_model, "forward")
      except Exception:
@@ -96,13 +119,12 @@ if __name__ == "__main__":
          exit(1)

      torch.manual_seed(0)
-     baseline_model = Model(N, H, S).to(args.device)
+     baseline_model = Model(input_size, hidden_size, scaling_factor).to(args.device)

      # measure correctness
-     n_correctness_trials = 10
      max_diff_avg = 0
      for _ in range(n_correctness_trials):
-         inputs = get_inputs(B, N, args.device)
+         inputs = get_inputs(batch_size, input_size, args.device)
          baseline_output = baseline_model(inputs)
          optimized_output = solution_model(inputs)
          max_diff_avg += torch.max(torch.abs(optimized_output - baseline_output))
@@ -110,16 +132,9 @@ if __name__ == "__main__":
      print(f"max float diff between values of baseline and optimized model: {max_diff_avg}")

      # measure performance
-     inputs = get_inputs(B, N, args.device)
-     n_warmup = 100
-     n_rep = 500
-
-     # baseline
+     inputs = get_inputs(batch_size, input_size, args.device)
      t_avg_baseline = bench(baseline_model, inputs, n_warmup, n_rep)
      print(f"baseline time: {t_avg_baseline:.2f}ms")
-
-     # optimized
      t_avg_optimized = bench(solution_model, inputs, n_warmup, n_rep)
      print(f"optimized time: {t_avg_optimized:.2f}ms")
-
      print(f"speedup: {t_avg_baseline / t_avg_optimized:.2f}x")
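In the updated `bench` for `hello-kernel-world`, `time.time()` is read before `torch.cuda.synchronize()` / `torch.mps.synchronize()` is called. On asynchronous backends, the measured interval may therefore cover little more than the kernel launch. A minimal sketch of a loop that synchronizes before stopping the clock and converts to milliseconds the same way (`total / (n_rep * 1e-3)`); `bench_ms` is a hypothetical name, and the `sync` callable is an assumption standing in for the backend-specific call such as `torch.cuda.synchronize`:

```python
import time
from typing import Callable


def bench_ms(f: Callable[[], None], n_warmup: int, n_rep: int,
             sync: Callable[[], None] = lambda: None) -> float:
    """Average wall-clock time of f() in milliseconds, synchronizing before each clock read."""
    for _ in range(n_warmup):
        f()
    sync()  # make sure warmup work has finished before timing starts

    total = 0.0
    for _ in range(n_rep):
        start = time.perf_counter()
        f()
        sync()  # wait for queued device work before reading the clock
        total += time.perf_counter() - start
    # seconds -> milliseconds: total / n_rep * 1000 == total / (n_rep * 1e-3)
    return total / (n_rep * 1e-3)
```

On a CUDA device, a call could look like `bench_ms(lambda: model(inputs), 1000, 5000, sync=torch.cuda.synchronize)` inside `torch.no_grad()`. `time.perf_counter` is used because it is a monotonic, high-resolution clock intended for measuring short intervals.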
{weco-0.2.5/examples/simple-mlx → weco-0.2.6/examples/metal}/evaluate.py
@@ -5,6 +5,7 @@ import importlib
  import traceback
  import mlx.core as mx
  import mlx.nn as nn
+ from typing import Union


  ########################################################
@@ -12,26 +13,27 @@ import mlx.nn as nn
  ########################################################
  class Model(nn.Module):
      """
-     Model that performs a matrix multiplication, division, summation, and scaling.
+     Model that performs a 2D convolution.
+
+     Args:
+         in_channels (int): Number of input channels.
+         out_channels (int): Number of output channels.
+         kernel_size (Union[int, tuple]): Size of the convolution kernel.
+         stride (Union[int, tuple]): Stride of the convolution. Default is 1.
      """

-     def __init__(self, input_size, hidden_size, scaling_factor):
+     def __init__(self, in_channels: int, out_channels: int, kernel_size: Union[int, tuple], stride: Union[int, tuple] = 1):
          super(Model, self).__init__()
-         self.weight = mx.random.normal(shape=(hidden_size, input_size))
-         self.scaling_factor = scaling_factor
+         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

      def __call__(self, x):
          """
          Args:
-             x (mx.array): Input tensor of shape (batch_size, input_size).
+             x (mx.array): Input tensor of shape (batch_size, height, width, in_channels).
          Returns:
-             mx.array: Output tensor of shape (batch_size, hidden_size).
+             mx.array: Output tensor of shape (batch_size, height, width, out_channels).
          """
-         x = mx.matmul(x, mx.transpose(self.weight)) # Gemm
-         x = x / 2 # Divide
-         x = mx.sum(x, axis=1, keepdims=True) # Sum
-         x = x * self.scaling_factor # Scaling
-         return x
+         return self.conv(x)


  ########################################################
@@ -52,9 +54,9 @@ def load_module_from_path(module_path: str, add_to_sys_modules: bool = False):
  ########################################################
  # Benchmark
  ########################################################
- def get_inputs(B, N):
+ def get_inputs(batch_size, img_height, img_width, img_channels):
      # MLX doesn't use device parameter like PyTorch, as it automatically uses Metal
-     return mx.random.normal(shape=(B, N), dtype=mx.float32)
+     return mx.random.normal(shape=(batch_size, img_height, img_width, img_channels), dtype=mx.float32)


  def bench(f, inputs, n_warmup, n_rep):
@@ -86,7 +88,13 @@ if __name__ == "__main__":
      args = parser.parse_args()

      # init and input parameters
-     B, N, H, S = 128, 10, 20, 1.5
+     batch_size = 4
+     img_height = 224
+     img_width = 224
+     img_channels = 3
+     out_channels = 64
+     kernel_size = 3
+     stride = 1

      # Set the default device to 0
      mx.set_default_device(mx.gpu)
@@ -95,20 +103,20 @@ if __name__ == "__main__":
      try:
          mx.random.seed(0)
          solution_module = load_module_from_path(args.solution_path, add_to_sys_modules=False)
-         solution_model = solution_module.Model(N, H, S)
+         solution_model = solution_module.Model(img_channels, out_channels, kernel_size, stride)
          assert hasattr(solution_model, "__call__")
      except Exception:
          print(f"Candidate module initialization failed: {traceback.format_exc()}")
          exit(1)

      mx.random.seed(0)
-     baseline_model = Model(N, H, S)
+     baseline_model = Model(img_channels, out_channels, kernel_size, stride)

      # measure correctness
      n_correctness_trials = 10
      max_diff_avg = 0
      for _ in range(n_correctness_trials):
-         inputs = get_inputs(B, N)
+         inputs = get_inputs(batch_size, img_height, img_width, img_channels)
          baseline_output = baseline_model(inputs)
          optimized_output = solution_model(inputs)
          max_diff = mx.max(mx.abs(optimized_output - baseline_output))
@@ -118,9 +126,9 @@ if __name__ == "__main__":
      print(f"max float diff between values of baseline and optimized model: {max_diff_avg}")

      # measure performance
-     inputs = get_inputs(B, N)
-     n_warmup = 100
-     n_rep = 500
+     inputs = get_inputs(batch_size, img_height, img_width, img_channels)
+     n_warmup = 1000
+     n_rep = 5000

      # baseline
      t_avg_baseline = bench(baseline_model, inputs, n_warmup, n_rep)
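The updated `Model` docstring describes the output as `(batch_size, height, width, out_channels)`, but `mlx.nn.Conv2d` defaults to `padding=0`, so the spatial size shrinks whenever `kernel_size > 1`. A small helper using the standard convolution output-size formula (`conv2d_out_hw` is a hypothetical name; square kernels and symmetric padding are assumed) gives the shape the Metal benchmark actually compares:

```python
from typing import Tuple


def conv2d_out_hw(h: int, w: int, kernel_size: int, stride: int = 1,
                  padding: int = 0, dilation: int = 1) -> Tuple[int, int]:
    """Spatial output size of a 2D convolution with a square kernel and symmetric padding."""
    def out(n: int) -> int:
        # floor((n + 2p - d(k - 1) - 1) / s) + 1
        return (n + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
    return out(h), out(w)


# Values from the Metal benchmark: 224x224 input, 3x3 kernel, stride 1, no padding.
print(conv2d_out_hw(224, 224, kernel_size=3))  # (222, 222)
```

With `batch_size = 4` and `out_channels = 64`, each benchmark output therefore has shape `(4, 222, 222, 64)` in MLX's channels-last layout; a padding of 1 would keep the 224x224 spatial size.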
@@ -3,7 +3,8 @@
 Custom Metal Kernels
 ====================
 
-MLX supports writing custom Metal kernels through the Python and C++ APIs.
+MLX supports writing custom Metal kernels through the Python and C++ APIs. Using Metal kernels allows us to optimize the performance of our models by writing low-level parallelization and vectorization schemes that fuse operations efficiently.
+One important thing to keep in mind is correctness. The optimized code should produce outputs that are identical, or at least numerically close, to the reference implementation (i.e., a max float difference below 1e-5).
 
 When designing a custom kernel, ensure that you maintain the format of having a 'Model' class with a '__call__' method as this is what will be called to evaluate the solution.
 
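The correctness rule above (max float difference below 1e-5) can be written as a small check. This is a sketch on flat Python lists with hypothetical helper names; the evaluate scripts instead compute `max(abs(optimized - baseline))` on arrays and average it over several random inputs.

```python
from typing import Sequence


def max_abs_diff(reference: Sequence[float], candidate: Sequence[float]) -> float:
    """Largest elementwise |candidate - reference| between two flattened outputs."""
    if len(reference) != len(candidate) or not reference:
        raise ValueError("outputs must be non-empty and have the same number of elements")
    return max(abs(c - r) for r, c in zip(reference, candidate))


def is_numerically_close(reference: Sequence[float], candidate: Sequence[float], tol: float = 1e-5) -> bool:
    """Accept a candidate whose max float difference is below tol (1e-5 per examples.rst)."""
    return max_abs_diff(reference, candidate) < tol


baseline = [0.1 + 0.2, 1.0, -3.5]
optimized = [0.3, 1.0, -3.5]  # 0.1 + 0.2 is not exactly 0.3 in binary floating point
print(max_abs_diff(baseline, optimized))  # tiny, on the order of 1e-17
print(is_numerically_close(baseline, optimized))  # True
print(is_numerically_close(baseline, [0.3, 1.001, -3.5]))  # False
```

A tolerance rather than exact equality is needed because a fused kernel usually changes the order of floating-point operations.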
@@ -0,0 +1,28 @@
+import mlx.core as mx  # noqa
+import mlx.nn as nn
+from typing import Union
+
+
+class Model(nn.Module):
+    """
+    Model that performs a 2D convolution.
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        kernel_size (Union[int, tuple]): Size of the convolution kernel.
+        stride (Union[int, tuple]): Stride of the convolution. Default is 1.
+    """
+
+    def __init__(self, in_channels: int, out_channels: int, kernel_size: Union[int, tuple], stride: Union[int, tuple] = 1):
+        super(Model, self).__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride)
+
+    def __call__(self, x):
+        """
+        Args:
+            x (mx.array): Input tensor of shape (batch_size, height, width, in_channels).
+        Returns:
+            mx.array: Output tensor of shape (batch_size, height, width, out_channels).
+        """
+        return self.conv(x)
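The docstring describes the output as keeping the input's height and width, but a convolution without padding shrinks the spatial dimensions, and a stride above 1 shrinks them further. A sketch of the standard output-size formula for MLX's channels-last (NHWC) layout follows, assuming `nn.Conv2d`'s defaults of zero padding and unit dilation when those arguments are not passed; `conv2d_output_shape_nhwc` is a hypothetical helper and the sizes are made up, not the values used by the evaluate script.

```python
from typing import Tuple, Union

IntOrPair = Union[int, Tuple[int, int]]


def _pair(v: IntOrPair) -> Tuple[int, int]:
    # Accept either a single int or an (h, w) pair, like the Model's kernel_size/stride.
    return (v, v) if isinstance(v, int) else (int(v[0]), int(v[1]))


def conv2d_output_shape_nhwc(
    input_shape: Tuple[int, int, int, int],
    out_channels: int,
    kernel_size: IntOrPair,
    stride: IntOrPair = 1,
    padding: IntOrPair = 0,
    dilation: IntOrPair = 1,
) -> Tuple[int, int, int, int]:
    """Output shape of a 2D convolution applied to an NHWC input."""
    n, h, w, _ = input_shape
    (kh, kw), (sh, sw) = _pair(kernel_size), _pair(stride)
    (ph, pw), (dh, dw) = _pair(padding), _pair(dilation)
    # floor((size + 2*pad - dilation*(k - 1) - 1) / stride) + 1
    h_out = (h + 2 * ph - dh * (kh - 1) - 1) // sh + 1
    w_out = (w + 2 * pw - dw * (kw - 1) - 1) // sw + 1
    return (n, h_out, w_out, out_channels)


print(conv2d_output_shape_nhwc((8, 64, 64, 3), out_channels=16, kernel_size=3, stride=1))  # (8, 62, 62, 16)
print(conv2d_output_shape_nhwc((8, 64, 64, 3), out_channels=16, kernel_size=3, stride=2))  # (8, 31, 31, 16)
```

A custom Metal kernel replacing `self.conv` has to produce exactly this shape for the correctness comparison to succeed.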
@@ -0,0 +1,153 @@
+import time
+import sys
+import pathlib
+import importlib
+import traceback
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+
+
+########################################################
+# Baseline
+########################################################
+class Model(nn.Module):
+    """
+    A vanilla multi-head masked self-attention layer with a projection at the end.
+    """
+
+    def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop, max_seqlen):
+        super().__init__()
+        assert n_embd % n_head == 0
+        # key, query, value projections for all heads, but in a batch
+        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
+        # output projection
+        self.c_proj = nn.Linear(n_embd, n_embd)
+        # regularization
+        self.attn_dropout = nn.Dropout(attn_pdrop)
+        self.resid_dropout = nn.Dropout(resid_pdrop)
+        # causal mask to ensure that attention is only applied to the left in the input sequence
+        self.register_buffer("bias", torch.tril(torch.ones(max_seqlen, max_seqlen)).view(1, 1, max_seqlen, max_seqlen))
+        self.n_head = n_head
+        self.n_embd = n_embd
+
+    def forward(self, x):
+        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
+        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
+        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
+        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
+        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
+
+        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
+        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
+        att = F.softmax(att, dim=-1)
+        att = self.attn_dropout(att)
+        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
+        # output projection
+        y = self.resid_dropout(self.c_proj(y))
+        return y
+
+
+########################################################
+# Weco Solution
+########################################################
+def load_module_from_path(module_path: str, add_to_sys_modules: bool = False):
+    # Clean out all old compiled extensions to prevent namespace collisions during build
+    module_path = pathlib.Path(module_path)
+    name = module_path.stem
+    spec = importlib.util.spec_from_file_location(name, module_path)
+    mod = importlib.util.module_from_spec(spec)  # type: ignore
+    if add_to_sys_modules:
+        sys.modules[name] = mod
+    spec.loader.exec_module(mod)  # type: ignore
+    return mod
+
+
+########################################################
+# Benchmark
+########################################################
+def get_inputs(batch_size, seq_len, n_embd, device):
+    return torch.randn(batch_size, seq_len, n_embd, device=device, dtype=torch.float32)
+
+
+@torch.no_grad()
+def bench(f, inputs, n_warmup, n_rep):
+    # warmup
+    for _ in range(n_warmup):
+        f(inputs)  # noqa
+
+    # benchmark
+    t_avg = 0.0
+    for _ in range(n_rep):
+        torch.cuda.empty_cache()  # Clear cache before timing
+        start_time = time.time()
+        f(inputs)
+        torch.cuda.synchronize()  # Wait for all computations to complete
+        t_avg += time.time() - start_time
+    t_avg /= n_rep * 1e-3
+    return t_avg
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--solution-path", type=str, required=True)
+    args = parser.parse_args()
+
+    # benchmarking parameters
+    n_correctness_trials = 10
+    n_warmup = 1000
+    n_rep = 5000
+
+    # init parameters
+    max_seqlen = 512
+    seq_len = 256
+    n_embd = 768
+    n_head = 8
+    # turn off dropout to measure correctness well
+    attn_pdrop = 0.0
+    resid_pdrop = 0.0
+
+    # input parameters
+    batch_size = 32
+
+    # load solution module
+    try:
+        torch.manual_seed(0)
+        solution_module = load_module_from_path(args.solution_path, add_to_sys_modules=False)
+        solution_model = solution_module.Model(
+            n_embd=n_embd, n_head=n_head, attn_pdrop=attn_pdrop, resid_pdrop=resid_pdrop, max_seqlen=max_seqlen
+        ).to("cuda")
+        assert isinstance(solution_model, nn.Module)
+    except Exception:
+        print(f"Candidate module initialization failed: {traceback.format_exc()}")
+        exit(1)
+
+    torch.manual_seed(0)
+    baseline_model = Model(
+        n_embd=n_embd, n_head=n_head, attn_pdrop=attn_pdrop, resid_pdrop=resid_pdrop, max_seqlen=max_seqlen
+    ).to("cuda")
+
+    # measure correctness
+    max_diff_avg = 0
+    for _ in range(n_correctness_trials):
+        inputs = get_inputs(batch_size=batch_size, seq_len=seq_len, n_embd=n_embd, device="cuda")
+        with torch.no_grad():
+            baseline_output = baseline_model(inputs)
+            optimized_output = solution_model(inputs)
+        max_diff_avg += torch.max(torch.abs(optimized_output - baseline_output))
+    max_diff_avg /= n_correctness_trials
+    print(f"max float diff between values of baseline and optimized model: {max_diff_avg}")
+
+    # measure performance
+    inputs = get_inputs(batch_size=batch_size, seq_len=seq_len, n_embd=n_embd, device="cuda")
+    t_avg_baseline = bench(baseline_model, inputs, n_warmup, n_rep)
+    print(f"baseline time: {t_avg_baseline:.2f}ms")
+    t_avg_optimized = bench(solution_model, inputs, n_warmup, n_rep)
+    print(f"optimized time: {t_avg_optimized:.2f}ms")
+    print(f"speedup: {t_avg_baseline / t_avg_optimized:.2f}x")
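`load_module_from_path` imports the candidate solution straight from a file path. Below is a self-contained sketch of the same `importlib.util` pattern. It imports `importlib.util` explicitly, because a bare `import importlib` does not by itself guarantee that the `util` submodule is loaded (in practice it usually already is, once libraries such as `torch` have imported it). It also guards against `spec_from_file_location` returning `None`. The demo's file name and `Model` body are made up.

```python
import importlib.util
import pathlib
import sys
import tempfile
from types import ModuleType


def load_module_from_path(module_path: str, add_to_sys_modules: bool = False) -> ModuleType:
    """Import a Python source file by path, without it being on sys.path."""
    path = pathlib.Path(module_path)
    spec = importlib.util.spec_from_file_location(path.stem, str(path))
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot create an import spec for {path}")
    mod = importlib.util.module_from_spec(spec)
    if add_to_sys_modules:
        # Register before executing so other code can look the module up by name.
        sys.modules[path.stem] = mod
    spec.loader.exec_module(mod)  # runs the file's top-level code
    return mod


# Demo with a throwaway candidate file.
with tempfile.TemporaryDirectory() as tmp:
    candidate = pathlib.Path(tmp) / "candidate_solution.py"
    candidate.write_text("class Model:\n    def __call__(self, x):\n        return x * 2\n")
    solution = load_module_from_path(str(candidate))
    model = solution.Model()
    print(callable(model), model(21))  # True 42
```

Loading by path rather than by module name lets the evaluator point at whichever file `--solution-path` names, without touching `sys.path`.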
@@ -0,0 +1,44 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+
+
+class Model(nn.Module):
+    """
+    A vanilla multi-head masked self-attention layer with a projection at the end.
+    """
+
+    def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop, max_seqlen):
+        super().__init__()
+        assert n_embd % n_head == 0
+        # key, query, value projections for all heads, but in a batch
+        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
+        # output projection
+        self.c_proj = nn.Linear(n_embd, n_embd)
+        # regularization
+        self.attn_dropout = nn.Dropout(attn_pdrop)
+        self.resid_dropout = nn.Dropout(resid_pdrop)
+        # causal mask to ensure that attention is only applied to the left in the input sequence
+        self.register_buffer("bias", torch.tril(torch.ones(max_seqlen, max_seqlen)).view(1, 1, max_seqlen, max_seqlen))
+        self.n_head = n_head
+        self.n_embd = n_embd
+
+    def forward(self, x):
+        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
+        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
+        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
+        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
+        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
+
+        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
+        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
+        att = F.softmax(att, dim=-1)
+        att = self.attn_dropout(att)
+        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
+        # output projection
+        y = self.resid_dropout(self.c_proj(y))
+        return y
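In the baseline, the `masked_fill(..., float("-inf"))` step followed by `softmax` is what makes the attention causal: every position ends up with zero weight on later positions. A plain-Python sketch for a single head, with no batching, projections or dropout, shows that arithmetic. It illustrates the math a fused Triton or CUDA kernel has to reproduce; it is not an optimized implementation, and `causal_attention` is a hypothetical name.

```python
import math
from typing import List

Matrix = List[List[float]]


def causal_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Single-head causal attention: softmax(q k^T / sqrt(hs) with future masked) @ v."""
    T, hs = len(q), len(q[0])
    scale = 1.0 / math.sqrt(hs)
    out: Matrix = []
    for t in range(T):
        # Scaled scores against every key; positions s > t get -inf,
        # like masked_fill with a lower-triangular (tril) mask.
        scores = [
            sum(q[t][d] * k[s][d] for d in range(hs)) * scale if s <= t else float("-inf")
            for s in range(T)
        ]
        # Numerically stable softmax: subtract the row max (finite, since s == t is never masked).
        m = max(scores)
        weights = [math.exp(x - m) for x in scores]  # exp(-inf) == 0.0
        total = sum(weights)
        weights = [w / total for w in weights]
        # Weighted sum of the value rows.
        out.append([sum(weights[s] * v[s][d] for s in range(T)) for d in range(len(v[0]))])
    return out


q = k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
y = causal_attention(q, k, v)
print(y[0])  # [1.0, 2.0]: the first position can only attend to itself
```

Because masked scores contribute exactly zero weight, changing a later value row never changes earlier outputs, which is a cheap property to test on an optimized kernel.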
@@ -6,11 +6,11 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "weco"
 authors = [
-    {name = "Weco AI Team", email = "dhruv@weco.ai"},
+    {name = "Weco AI Team", email = "contact@weco.ai"},
 ]
 description = "Documentation for `weco`, a CLI for using Weco AI's code optimizer."
 readme = "README.md"
-version = "0.2.5"
+version = "0.2.6"
 license = {text = "MIT"}
 requires-python = ">=3.12"
 dependencies = ["requests", "rich"]
@@ -1,4 +1,4 @@
 # DO NOT EDIT
-__pkg_version__ = "0.2.5"
+__pkg_version__ = "0.2.6"
 __api_version__ = "v1"
 __base_url__ = f"https://api.aide.weco.ai/{__api_version__}"
@@ -321,7 +321,12 @@ def main() -> None:
  _, best_solution_panel = solution_panels.get_display(current_step=steps)
 
  # Update the end optimization layout
- end_optimization_layout["summary"].update(summary_panel.get_display())
+ final_message = (
+ f"{summary_panel.metric_name.capitalize()} {'maximized' if summary_panel.maximize else 'minimized'}! Best solution {summary_panel.metric_name.lower()} = [green]{status_response['best_result']['metric_value']}[/] 🏆"
+ if best_solution_node is not None
+ else "[red] No solution found.[/]"
+ )
+ end_optimization_layout["summary"].update(summary_panel.get_display(final_message=final_message))
  end_optimization_layout["tree"].update(tree_panel.get_display())
  end_optimization_layout["best_solution"].update(best_solution_panel)
 
@@ -12,7 +12,9 @@ class SummaryPanel:
     """Holds a summary of the optimization session."""
 
     def __init__(self, maximize: bool, metric_name: str, total_steps: int, model: str, session_id: str = None):
-        self.goal = ("Maximizing" if maximize else "Minimizing") + f" {metric_name}..."
+        self.maximize = maximize
+        self.metric_name = metric_name
+        self.goal = ("Maximizing" if self.maximize else "Minimizing") + f" {self.metric_name}..."
         self.total_input_tokens = 0
         self.total_output_tokens = 0
         self.total_steps = total_steps
@@ -39,24 +41,28 @@ class SummaryPanel:
         self.total_input_tokens += usage["input_tokens"]
         self.total_output_tokens += usage["output_tokens"]
 
-    def get_display(self) -> Panel:
+    def get_display(self, final_message: Optional[str] = None) -> Panel:
         """Create a summary panel with the relevant information."""
         layout = Layout(name="summary")
         summary_table = Table(show_header=False, box=None, padding=(0, 1))
         # Goal
-        summary_table.add_row(f"[bold cyan]Goal:[/] {self.goal}")
+        if final_message is not None:
+            summary_table.add_row(f"[bold cyan]Result:[/] {final_message}")
+        else:
+            summary_table.add_row(f"[bold cyan]Goal:[/] {self.goal}")
+        summary_table.add_row("")
+        # Model used
+        summary_table.add_row(f"[bold cyan]Model:[/] {self.model}")
         summary_table.add_row("")
         # Log directory
         runs_dir = f".runs/{self.session_id}"
         summary_table.add_row(f"[bold cyan]Logs:[/] [blue underline]{runs_dir}[/]")
         summary_table.add_row("")
-        # Model used
-        summary_table.add_row(f"[bold cyan]Model:[/] [yellow]{self.model}[/]")
-        summary_table.add_row("")
         # Token counts
         summary_table.add_row(
             f"[bold cyan]Tokens:[/] ↑[yellow]{format_number(self.total_input_tokens)}[/] ↓[yellow]{format_number(self.total_output_tokens)}[/] = [green]{format_number(self.total_input_tokens + self.total_output_tokens)}[/]"
         )
+        summary_table.add_row("")
         # Progress bar
         summary_table.add_row(self.progress)
 
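The `cli.py` and `panels.py` changes swap the Goal row for a Result row once optimization ends. The sketch below restates that logic as two plain functions that return Rich markup strings, so it can be checked without a terminal. Both function names are hypothetical, and unlike `cli.py`, which tests `best_solution_node`, this version keys off whether a best value exists.

```python
from typing import List, Optional


def build_final_message(metric_name: str, maximize: bool, best_value: Optional[float]) -> str:
    """Result line shown when optimization ends (same wording as the cli.py change)."""
    if best_value is None:
        return "[red] No solution found.[/]"
    verb = "maximized" if maximize else "minimized"
    return f"{metric_name.capitalize()} {verb}! Best solution {metric_name.lower()} = [green]{best_value}[/] 🏆"


def summary_rows(goal: str, model: str, final_message: Optional[str] = None) -> List[str]:
    """First summary rows: a Result row replaces the Goal row when a final message is given."""
    first = f"[bold cyan]Result:[/] {final_message}" if final_message is not None else f"[bold cyan]Goal:[/] {goal}"
    return [first, "", f"[bold cyan]Model:[/] {model}", ""]


msg = build_final_message("speedup", maximize=True, best_value=2.5)
print(msg)  # Speedup maximized! Best solution speedup = [green]2.5[/] 🏆
print(summary_rows("Maximizing speedup...", "o3-mini", final_message=msg)[0])
```

Keeping `final_message` optional means the same `get_display` call serves both the live view and the end-of-run view.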
@@ -1,8 +1,8 @@
 Metadata-Version: 2.4
 Name: weco
-Version: 0.2.5
+Version: 0.2.6
 Summary: Documentation for `weco`, a CLI for using Weco AI's code optimizer.
-Author-email: Weco AI Team <dhruv@weco.ai>
+Author-email: Weco AI Team <contact@weco.ai>
 License: MIT
 Project-URL: Homepage, https://github.com/WecoAI/weco-cli
 Keywords: AI,Code Optimization,Code Generation
@@ -99,32 +99,69 @@ Here's how `weco` can be applied to common ML engineering tasks:
 
 ### Examples
 
-**Example 1: Optimizing PyTorch operations**
+**Example 1: Optimizing simple PyTorch operations**
 
 ```bash
-weco --source examples/simple-torch/optimize.py \
-     --eval-command "python examples/simple-torch/evaluate.py --solution-path examples/simple-torch/optimize.py --device mps" \
+cd examples/hello-kernel-world
+pip install torch
+weco --source optimize.py \
+     --eval-command "python evaluate.py --solution-path optimize.py --device cpu" \
      --metric speedup \
      --maximize true \
      --steps 15 \
-     --model o3-mini \
+     --model claude-3-7-sonnet-20250219 \
      --additional-instructions "Fuse operations in the forward method while ensuring the max float deviation remains small. Maintain the same format of the code."
 ```
 
+If you have an NVIDIA GPU, change the device to `cuda`. If you are running this on Apple Silicon, set it to `mps`.
+
 **Example 2: Optimizing MLX operations with instructions from a file**
 
-Sometimes, additional context or instructions are too complex for a single command-line string. You can provide a path to a file containing these instructions.
+Let's optimize a 2D convolution operation in [`mlx`](https://github.com/ml-explore/mlx) using [Metal](https://developer.apple.com/documentation/metal/). Sometimes, additional context or instructions are too complex for a single command-line string. You can provide a path to a file containing these instructions.
 
 ```bash
-weco --source examples/simple-mlx/optimize.py \
-     --eval-command "python examples/simple-mlx/evaluate.py --solution-path examples/simple-mlx/optimize.py" \
+cd examples/metal
+pip install mlx
+weco --source optimize.py \
+     --eval-command "python evaluate.py --solution-path optimize.py" \
      --metric speedup \
      --maximize true \
      --steps 30 \
      --model o3-mini \
-     --additional-instructions examples/simple-mlx/metal-examples.rst
+     --additional-instructions examples.rst
 ```
 
+**Example 3: Level-Agnostic Optimization: Causal Self-Attention with Triton & CUDA**
+
+Given how useful causal multi-head self-attention is to transformers, it has been widely adopted across ML engineering and AI research. It's great to keep things at a high level (in PyTorch) when doing research, but when moving to production you often need to write highly customized low-level kernels to make things run as fast as possible. The `weco` CLI can optimize kernels across a variety of abstraction levels and frameworks. Example 2 uses Metal, but let's explore two more frameworks:
+
+1. [Triton](https://github.com/triton-lang/triton)
+```bash
+cd examples/triton
+pip install torch triton
+weco --source optimize.py \
+     --eval-command "python evaluate.py --solution-path optimize.py" \
+     --metric speedup \
+     --maximize true \
+     --steps 30 \
+     --model gemini-2.5-pro-preview-03-25 \
+     --additional-instructions "Use triton to optimize the code while ensuring a small max float diff. Maintain the same code format."
+```
+
+2. [CUDA](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html)
+```bash
+cd examples/cuda
+pip install torch
+weco --source optimize.py \
+     --eval-command "python evaluate.py --solution-path optimize.py" \
+     --metric speedup \
+     --maximize true \
+     --steps 30 \
+     --model gemini-2.5-pro-preview-03-25 \
+     --additional-instructions guide.md
+```
+
+
 ---
 
 ### Command Line Arguments
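Example 1 has you choose `--device` by hand. Below is a sketch of picking it automatically and assembling the same `weco` invocation as an argument list. `detect_device` and `hello_kernel_world_command` are hypothetical helpers, not part of `weco`, and only the flags shown in the examples above are used. The detection falls back to `cpu` when PyTorch is not installed or has no MPS backend.

```python
import shlex
from typing import List


def detect_device() -> str:
    """Return "cuda", "mps", or "cpu" depending on what PyTorch can see."""
    try:
        import torch
    except ImportError:
        return "cpu"
    if torch.cuda.is_available():
        return "cuda"
    mps = getattr(torch.backends, "mps", None)  # absent in older PyTorch releases
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"


def hello_kernel_world_command(device: str, steps: int = 15) -> List[str]:
    """argv for the Example 1 run, to be executed from examples/hello-kernel-world."""
    eval_cmd = f"python evaluate.py --solution-path optimize.py --device {device}"
    return [
        "weco",
        "--source", "optimize.py",
        "--eval-command", eval_cmd,
        "--metric", "speedup",
        "--maximize", "true",
        "--steps", str(steps),
        "--model", "claude-3-7-sonnet-20250219",
        "--additional-instructions",
        "Fuse operations in the forward method while ensuring the max float deviation remains small. "
        "Maintain the same format of the code.",
    ]


print(shlex.join(hello_kernel_world_command(detect_device())))
```

The list can be passed to `subprocess.run(cmd, cwd="examples/hello-kernel-world")`; building it as a list avoids shell-quoting the nested `--eval-command` string.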
@@ -0,0 +1,27 @@
+.gitignore
+LICENSE
+README.md
+pyproject.toml
+.github/workflows/lint.yml
+.github/workflows/release.yml
+examples/cuda/evaluate.py
+examples/cuda/guide.md
+examples/cuda/optimize.py
+examples/hello-kernel-world/evaluate.py
+examples/hello-kernel-world/optimize.py
+examples/metal/evaluate.py
+examples/metal/examples.rst
+examples/metal/optimize.py
+examples/triton/evaluate.py
+examples/triton/optimize.py
+weco/__init__.py
+weco/api.py
+weco/cli.py
+weco/panels.py
+weco/utils.py
+weco.egg-info/PKG-INFO
+weco.egg-info/SOURCES.txt
+weco.egg-info/dependency_links.txt
+weco.egg-info/entry_points.txt
+weco.egg-info/requires.txt
+weco.egg-info/top_level.txt
@@ -1,26 +0,0 @@
-import mlx.core as mx
-import mlx.nn as nn
-
-
-class Model(nn.Module):
-    """
-    Model that performs a matrix multiplication, division, summation, and scaling.
-    """
-
-    def __init__(self, input_size, hidden_size, scaling_factor):
-        super(Model, self).__init__()
-        self.weight = mx.random.normal(shape=(hidden_size, input_size))
-        self.scaling_factor = scaling_factor
-
-    def __call__(self, x):
-        """
-        Args:
-            x (mx.array): Input tensor of shape (batch_size, input_size).
-        Returns:
-            mx.array: Output tensor of shape (batch_size, hidden_size).
-        """
-        x = mx.matmul(x, mx.transpose(self.weight))
-        x = x / 2
-        x = mx.sum(x, axis=1, keepdims=True)
-        x = x * self.scaling_factor
-        return x
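The removed MLX example (matmul, divide by 2, sum over the hidden dimension, scale) shows that fusion can come from algebra as well as from a custom kernel. The sum over hidden units distributes over the matmul, so the whole model reduces to one dot product per row against the column sums of the weight matrix, scaled by `scaling_factor / 2`. A plain-Python sketch checking that identity on random data follows. The function names are hypothetical, and each returns one scalar per row instead of the `(batch_size, 1)` array the MLX code keeps. The two versions agree only up to floating-point rounding, because the summation order changes.

```python
import random
from typing import List

Matrix = List[List[float]]


def reference(x: Matrix, w: Matrix, scale: float) -> List[float]:
    """matmul(x, w^T), then / 2, then sum over hidden units, then * scale, step by step."""
    out = []
    for row in x:
        hidden = [sum(xi * wi for xi, wi in zip(row, w_j)) for w_j in w]  # hidden_size values
        out.append(sum(h / 2 for h in hidden) * scale)
    return out


def fused(x: Matrix, w: Matrix, scale: float) -> List[float]:
    """Same result via sum_j (x . w_j) = x . (sum_j w_j): one dot product per row."""
    w_sum = [sum(col) for col in zip(*w)]  # input_size values; can be precomputed once
    factor = scale / 2
    return [sum(xi * wi for xi, wi in zip(row, w_sum)) * factor for row in x]


random.seed(0)
batch, input_size, hidden_size, scaling = 4, 8, 16, 1.5
x = [[random.gauss(0, 1) for _ in range(input_size)] for _ in range(batch)]
w = [[random.gauss(0, 1) for _ in range(hidden_size and input_size)] for _ in range(hidden_size)]
diff = max(abs(a - b) for a, b in zip(reference(x, w, scaling), fused(x, w, scaling)))
print(f"max float diff: {diff:.2e}")  # small, but not necessarily 0
```

This cuts the work from `batch * hidden_size * input_size` multiplies to roughly `batch * input_size`, once the weight column sums are cached.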
@@ -1,22 +0,0 @@
-.gitignore
-LICENSE
-README.md
-pyproject.toml
-.github/workflows/lint.yml
-.github/workflows/release.yml
-examples/simple-mlx/evaluate.py
-examples/simple-mlx/metal-examples.rst
-examples/simple-mlx/optimize.py
-examples/simple-torch/evaluate.py
-examples/simple-torch/optimize.py
-weco/__init__.py
-weco/api.py
-weco/cli.py
-weco/panels.py
-weco/utils.py
-weco.egg-info/PKG-INFO
-weco.egg-info/SOURCES.txt
-weco.egg-info/dependency_links.txt
-weco.egg-info/entry_points.txt
-weco.egg-info/requires.txt
-weco.egg-info/top_level.txt
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes