liger-kernel-nightly 0.5.10.dev20250611191801__py3-none-any.whl → 0.6.4.dev20260112233432__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of liger-kernel-nightly might be problematic.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +142 -0
- liger_kernel/chunked_loss/dpo_loss.py +54 -3
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
- liger_kernel/chunked_loss/fused_linear_ppo.py +25 -5
- liger_kernel/chunked_loss/grpo_loss.py +46 -9
- liger_kernel/chunked_loss/jsd_loss.py +44 -13
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +49 -0
- liger_kernel/ops/backends/_ascend/ops/geglu.py +266 -0
- liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
- liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
- liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
- liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
- liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +130 -64
- liger_kernel/ops/dyt.py +5 -4
- liger_kernel/ops/fused_add_rms_norm.py +416 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +115 -22
- liger_kernel/ops/geglu.py +6 -4
- liger_kernel/ops/group_norm.py +7 -7
- liger_kernel/ops/grpo_loss.py +3 -1
- liger_kernel/ops/kl_div.py +8 -11
- liger_kernel/ops/layer_norm.py +135 -80
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/rms_norm.py +148 -71
- liger_kernel/ops/rope.py +1 -1
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/utils.py +14 -0
- liger_kernel/transformers/__init__.py +65 -0
- liger_kernel/transformers/auto_model.py +21 -0
- liger_kernel/transformers/cross_entropy.py +9 -4
- liger_kernel/transformers/dyt.py +1 -1
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +1 -1
- liger_kernel/transformers/functional.py +56 -24
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +17 -5
- liger_kernel/transformers/fused_linear_jsd.py +1 -1
- liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
- liger_kernel/transformers/geglu.py +1 -1
- liger_kernel/transformers/group_norm.py +1 -1
- liger_kernel/transformers/grpo_loss.py +57 -2
- liger_kernel/transformers/jsd.py +1 -1
- liger_kernel/transformers/kl_div.py +1 -1
- liger_kernel/transformers/layer_norm.py +1 -1
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/exaone4.py +136 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +28 -8
- liger_kernel/transformers/model/gemma2.py +34 -11
- liger_kernel/transformers/model/gemma3.py +102 -112
- liger_kernel/transformers/model/glm4.py +18 -5
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +26 -7
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +18 -6
- liger_kernel/transformers/model/loss_utils.py +34 -3
- liger_kernel/transformers/model/mistral.py +17 -10
- liger_kernel/transformers/model/mixtral.py +24 -9
- liger_kernel/transformers/model/mllama.py +18 -7
- liger_kernel/transformers/model/olmo2.py +18 -5
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +42 -5
- liger_kernel/transformers/model/phi3.py +24 -159
- liger_kernel/transformers/model/qwen2.py +26 -4
- liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
- liger_kernel/transformers/model/qwen2_vl.py +24 -7
- liger_kernel/transformers/model/qwen3.py +22 -6
- liger_kernel/transformers/model/qwen3_moe.py +27 -7
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +1423 -100
- liger_kernel/transformers/multi_token_attention.py +2 -2
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- liger_kernel/transformers/rms_norm.py +15 -5
- liger_kernel/transformers/rope.py +45 -1
- liger_kernel/transformers/softmax.py +1 -1
- liger_kernel/transformers/sparsemax.py +1 -1
- liger_kernel/transformers/swiglu.py +18 -1
- liger_kernel/transformers/tiled_mlp.py +125 -0
- liger_kernel/transformers/tvd.py +1 -1
- liger_kernel/utils.py +52 -0
- {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/METADATA +37 -25
- liger_kernel_nightly-0.6.4.dev20260112233432.dist-info/RECORD +132 -0
- liger_kernel_nightly-0.5.10.dev20250611191801.dist-info/RECORD +0 -95
- {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/top_level.txt +0 -0

liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md
@@ -0,0 +1,485 @@

# Ascend NPU UB Manager Design Document

## Overview

The UB Manager (Unified Buffer Manager) is a core component in **Liger-Kernel** responsible for managing the Unified Buffer (UB) capacity on Ascend NPUs. By automatically detecting UB capacity and providing unified tiling strategy computation, it helps Triton kernels avoid UB overflow errors while maintaining high performance.

## Design Goals

1. **Automated UB Management**: Automatically detect device UB capacity without manual configuration
2. **Unified Strategy System**: Use a single unified strategy function for all kernels, abstracting memory calculations
3. **Flexible Parameters**: Support different memory multipliers and safety margins for different kernels
4. **Easy to Use**: Simple interface that directly computes tiling results

## Architecture Design

### Core Components

```
┌─────────────────────────────────────────────────────────┐
│                    UB Manager System                    │
├─────────────────────────────────────────────────────────┤
│                                                         │
│  ┌──────────────┐         ┌──────────────────┐          │
│  │   UBManager  │         │ Default Strategy │          │
│  │  (Singleton) │────────▶│     Function     │          │
│  └──────────────┘         └──────────────────┘          │
│         │                           │                   │
│         │                           │                   │
│         ▼                           ▼                   │
│  ┌──────────────┐         ┌──────────────────┐          │
│  │   Capacity   │         │ compute_default  │          │
│  │   Detection  │         │ _tiling_strategy │          │
│  └──────────────┘         └──────────────────┘          │
│                                                         │
└─────────────────────────────────────────────────────────┘
          │                           │
          │                           │
          ▼                           ▼
  ┌──────────────┐         ┌──────────────────┐
  │    GEGLU     │         │       ROPE       │
  │    Kernel    │         │      Kernel      │
  └──────────────┘         └──────────────────┘
```

### Class Diagram

```
┌──────────────────────────────────────┐
│              UBManager               │
├──────────────────────────────────────┤
│ - _npu_model: str                    │
│ - _ub_capacity_bits: int             │
├──────────────────────────────────────┤
│ + ub_capacity_bits: int              │
│ + ub_capacity_bytes: int             │
│ + npu_model: str                     │
│ - _detect_npu_model()                │
│ - _detect_ub_capacity()              │
└──────────────────────────────────────┘

┌──────────────────────────────────────┐
│   compute_default_tiling_strategy    │
├──────────────────────────────────────┤
│ + safety_margin: float               │
│ + dtype_size: int                    │
│ + memory_multiplier: float           │
│ + shapes: Tuple[Tuple[int, ...], ...]│
│ + tiling_dims: Tuple                 │
├──────────────────────────────────────┤
│ Returns: Tuple[Tuple[int, ...], ...] │
│ (same structure as shapes)           │
└──────────────────────────────────────┘

┌──────────────────────────────────────┐
│        _normalize_tiling_dims        │
├──────────────────────────────────────┤
│ Helper function to normalize         │
│ tiling_dim (int or tuple) to set     │
└──────────────────────────────────────┘
```

## Core Functionality

### 1. UB Capacity Detection

The UB Manager detects UB capacity in the following priority order:

1. **Environment Variable**: `ASCEND_UB_CAPACITY_BITS`
2. **Device Properties**: Retrieved from `torch.npu.get_device_properties(0).ub_capacity_bits`
3. **Model Defaults**: Use predefined values based on the detected NPU model

```python
# Default UB capacity configuration
_DEFAULT_UB_CAPACITIES = {
    "Ascend910B1": 2097152,  # ~256 KB
    "Ascend910B4": 1572864,  # ~192 KB
    "default": 2097152,      # ~256 KB
}
```
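
A minimal sketch of this detection order follows. The helper name and exact fallback handling are illustrative, not the actual `UBManager` internals; it reuses the `_DEFAULT_UB_CAPACITIES` dict shown above.

```python
import os

def _detect_ub_capacity_bits(npu_model: str) -> int:
    # 1. Explicit override via environment variable
    env_value = os.environ.get("ASCEND_UB_CAPACITY_BITS")
    if env_value is not None:
        return int(env_value)
    # 2. Ask the device itself (requires an NPU-enabled torch build)
    try:
        import torch

        props = torch.npu.get_device_properties(0)
        if hasattr(props, "ub_capacity_bits"):
            return props.ub_capacity_bits
    except (ImportError, AttributeError, RuntimeError):
        pass
    # 3. Fall back to the per-model defaults above
    return _DEFAULT_UB_CAPACITIES.get(npu_model, _DEFAULT_UB_CAPACITIES["default"])
```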

### 2. Unified Strategy System

All kernels use a single unified strategy function `_default_strategy` that abstracts memory calculations:

```
Memory Formula: memory_multiplier * BLOCK_SIZE * unit_param * dtype_size * 8 bits
```

Where `unit_param` is automatically calculated as the product of all fixed (non-tiling) dimensions in each shape.

The strategy function:
- Takes UB capacity, safety margin, dtype size, memory multiplier, shapes, and tiling dimension specifications
- For each shape, identifies which dimensions can be tiled (from `tiling_dims`)
- Calculates `unit_param` as the product of fixed (non-tiling) dimensions
- Calculates the maximum safe block size that fits within UB capacity
- Returns a tuple of max_safe_block_size values (one for each shape)

The `compute_default_tiling_strategy` function:
- Calls `_default_strategy` to get max_safe_block_size for each shape
- For each tiling dimension, computes desired block size using `triton.next_power_of_2(original_dim)`
- Returns the final result with same structure as input shapes: tiling dimensions replaced with computed block sizes, non-tiling dimensions padded to next power of 2
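
To make the formula concrete, here is a small worked example using the default Ascend910B1 capacity from the table above and GEGLU-forward-style parameters (`unit_param = 1`); the concrete numbers are illustrative only.

```python
ub_capacity_bits = 2097152   # default Ascend910B1 capacity (bits)
safety_margin = 0.80
memory_multiplier = 7.0      # GEGLU forward
dtype_size = 2               # float16
unit_param = 1               # no fixed dimensions

safe_bits = ub_capacity_bits * safety_margin                                # 1,677,721.6
max_block = safe_bits / (memory_multiplier * unit_param * dtype_size * 8)   # ~14,979.7

# The strategy keeps the largest power of 2 that still fits:
block_size = 1
while block_size * 2 <= max_block:
    block_size *= 2
print(block_size)  # 8192
```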

### 3. Parameter Structure

The unified strategy uses the following parameters:

- **`safety_margin`**: Safety margin as a float (e.g., 0.80 for 80%). Default is 0.80.
- **`dtype_size`**: Size of data type in bytes (e.g., 2 for float16, 4 for float32)
- **`memory_multiplier`**: Memory multiplier for estimating peak memory usage
  - For GEGLU: typically 10.0 for backward, 7.0 for forward
  - For ROPE: typically 3.0
- **`shapes`**: Tuple of full shapes. Each shape is a tuple of dimension sizes.
  - For ROPE: `((n_q_head, hd), (n_kv_head, hd))`
  - For GEGLU: `((n_cols,),)`
  - Can pass original shapes (padding is handled internally) or padded shapes
- **`tiling_dims`**: Tuple specifying which dimensions can be tiled for each shape.
  - Each element can be:
    - `int`: single dimension index (e.g., `0` for first dimension)
    - `tuple of ints`: multiple dimensions that can be tiled together (non-empty)
  - For ROPE: `(0, 0)` means first dimension of each shape can be tiled
  - For GEGLU: `(0,)` means first dimension of the shape can be tiled
  - Length must match `len(shapes)`
  - Fixed dimensions (non-tiling) are automatically extracted from shapes and multiplied to get `unit_param`
- **Validation**: Raises `ValueError` if:
  - Any `tiling_dim` is empty or invalid (e.g., empty tuple)
  - Any dimension index is out of bounds (negative or >= shape length)
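
A small illustration of how `shapes` and `tiling_dims` combine into `unit_param`; the concrete dimension values are example numbers, not defaults.

```python
# ROPE-style call: the head dimension (index 1) stays fixed, index 0 is tiled.
shapes = ((32, 128), (8, 128))   # (n_q_head, hd), (n_kv_head, hd)
tiling_dims = (0, 0)
# Fixed dimensions per shape: (128,) and (128,) -> unit_param = 128 for both.

# GEGLU-style call: the only dimension is tiled, so there are no fixed dimensions.
shapes = ((4096,),)
tiling_dims = (0,)
# No fixed dimensions -> unit_param = 1.

# Invalid specifications raise ValueError, e.g. tiling_dims = ((),) (empty tuple)
# or tiling_dims = (2,) for a 1-D shape (index out of bounds).
```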

### 4. Strategy Computation Flow

```
User calls compute_default_tiling_strategy()
        │
        ▼
Get UB manager instance
        │
        ▼
Validate shapes and tiling_dims (lengths must match)
        │
        ▼
Set defaults for dtype_size (4) and memory_multiplier (10.0)
        │
        ▼
Call _default_strategy() with:
  - ub_capacity_bits
  - safety_margin
  - dtype_size
  - memory_multiplier
  - shapes
  - tiling_dims
        │
        ▼
For each (shape, tiling_dim) pair:
  Normalize tiling_dim to set of dimension indices
  Validate tiling dimensions are within shape bounds
  (Raises ValueError if invalid)
        │
        ▼
Calculate unit_param:
  unit_param = product of all non-tiling dimensions
        │
        ▼
Calculate max_block_size:
  SAFE_UB_CAPACITY_BITS = ub_capacity_bits * safety_margin
  max_block_size = SAFE_UB_CAPACITY_BITS / (memory_multiplier * unit_param * dtype_size * 8)
        │
        ▼
Find largest power of 2 <= max_block_size
        │
        ▼
Return tuple of max_safe_block_size (one per shape)
        │
        ▼
Build result with same structure as shapes:
  For each (shape, tiling_dim, max_safe):
    For each tiling dimension:
      desired = triton.next_power_of_2(original_dim)
      final = min(desired, max_safe)
      final = max(1, final)
    For each non-tiling dimension:
      pad to triton.next_power_of_2(original_dim)
        │
        ▼
Return tuple of tiled shapes
```

## Usage Examples

### Basic Usage

```python
from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy

# GEGLU forward
shapes = ((4096,),)
tile_shapes = compute_default_tiling_strategy(
    safety_margin=0.80,
    dtype_size=2,  # float16
    memory_multiplier=7.0,
    shapes=shapes,
    tiling_dims=(0,)  # First dimension can be tiled
)
if tile_shapes is not None and len(tile_shapes) > 0:
    block_size = tile_shapes[0][0]
    # Call kernel with block_size

# ROPE forward
shapes = ((32, 128), (32, 128))  # (n_q_head, hd), (n_kv_head, hd)
tile_shapes = compute_default_tiling_strategy(
    safety_margin=0.90,
    dtype_size=4,  # float32
    memory_multiplier=3.0,
    shapes=shapes,
    tiling_dims=(0, 0)  # First dimension of each shape can be tiled
)
if tile_shapes is not None and len(tile_shapes) == len(shapes):
    q_tile_shape, k_tile_shape = tile_shapes
    BLOCK_Q, _ = q_tile_shape  # Tiled dimension
    BLOCK_K, _ = k_tile_shape  # Tiled dimension
    # Call kernel with BLOCK_Q and BLOCK_K
```

## Strategy Function Details

### `_normalize_tiling_dims` Helper Function

A helper function that normalizes tiling dimension specifications:

```python
def _normalize_tiling_dims(tiling_dim: Union[int, Tuple[int, ...]]) -> set:
    """
    Normalize tiling dimension specification to a set of dimension indices.

    Args:
        tiling_dim: Either an int (single dimension) or tuple of ints (multiple dimensions).

    Returns:
        Set of dimension indices that can be tiled.
    """
```

This function handles the conversion of `tiling_dim` from either an `int` or `tuple` to a `set` for consistent processing.
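
An illustrative implementation of that behavior (the real helper may differ in its exact validation and error messages):

```python
from typing import Tuple, Union

def _normalize_tiling_dims(tiling_dim: Union[int, Tuple[int, ...]]) -> set:
    # A single int selects one tiling dimension.
    if isinstance(tiling_dim, int):
        return {tiling_dim}
    # A non-empty tuple selects several tiling dimensions at once.
    if isinstance(tiling_dim, tuple) and len(tiling_dim) > 0:
        return set(tiling_dim)
    raise ValueError(f"Invalid tiling_dim specification: {tiling_dim!r}")
```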

### `_default_strategy` Function

The core strategy function that calculates maximum safe block size:

```python
def _default_strategy(
    ub_capacity_bits: int,
    safety_margin: float,
    dtype_size: int,
    memory_multiplier: float,
    shapes: Tuple[Tuple[int, ...], ...],
    tiling_dims: Tuple[Union[int, Tuple[int, ...]], ...],
) -> Tuple[int, ...]:
    """
    Calculate maximum safe block size based on UB capacity.

    Memory formula: memory_multiplier * BLOCK_SIZE * unit_param * dtype_size * 8 bits

    For each shape, fixed dimensions (non-tiling) are multiplied together to get unit_param.

    Returns:
        Tuple of max_safe_block_size (power of 2), one for each shape.

    Raises:
        ValueError: If any tiling_dim is empty or invalid, or if any dimension
            index is out of bounds for the corresponding shape.
    """
```

**Key Steps:**
1. For each `(shape, tiling_dim)` pair:
   - Normalize `tiling_dim` to a set of dimension indices using `_normalize_tiling_dims`
   - Validate tiling dimensions are within shape bounds
     - Raises `ValueError` if `tiling_dim` is empty or invalid
     - Raises `ValueError` if any dimension index is out of bounds
   - Calculate `unit_param` as the product of all non-tiling dimensions
     - If all dimensions are tiling, `unit_param = 1.0`
2. Calculate `SAFE_UB_CAPACITY_BITS = ub_capacity_bits * safety_margin`
3. Solve for max_block_size: `SAFE_UB_CAPACITY_BITS / (memory_multiplier * unit_param * dtype_size * 8)`
4. Find largest power of 2 <= max_block_size
5. Return tuple with one max_safe_block_size per shape
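
A condensed sketch of these steps, reusing the `_normalize_tiling_dims` sketch above; it follows the documented formula but is not the verbatim implementation.

```python
def _default_strategy_sketch(ub_capacity_bits, safety_margin, dtype_size,
                             memory_multiplier, shapes, tiling_dims):
    safe_bits = ub_capacity_bits * safety_margin
    results = []
    for shape, tiling_dim in zip(shapes, tiling_dims):
        dims = _normalize_tiling_dims(tiling_dim)
        if any(d < 0 or d >= len(shape) for d in dims):
            raise ValueError(f"tiling dimension out of bounds for shape {shape}")
        # unit_param: product of the fixed (non-tiling) dimensions; 1.0 if all dims tile
        unit_param = 1.0
        for i, size in enumerate(shape):
            if i not in dims:
                unit_param *= size
        max_block = safe_bits / (memory_multiplier * unit_param * dtype_size * 8)
        # Largest power of 2 <= max_block (never below 1)
        block = 1
        while block * 2 <= max_block:
            block *= 2
        results.append(block)
    return tuple(results)
```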

### `compute_default_tiling_strategy` Function

The public interface that computes final tiling results:

```python
def compute_default_tiling_strategy(
    safety_margin: float = 0.80,
    dtype_size: Optional[int] = None,
    memory_multiplier: Optional[float] = None,
    shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
    tiling_dims: Optional[Tuple[Union[int, Tuple[int, ...]], ...]] = None,
) -> Optional[Tuple[Tuple[int, ...], ...]]:
    """
    Compute tiling strategy using the default strategy function.

    Returns tuple of tiled shapes with same structure as input shapes.
    Tiling dimensions are replaced with computed block sizes (power of 2),
    while non-tiling dimensions are padded to next power of 2.

    Returns:
        Tuple of tiled shapes, or None if shapes/tiling_dims are empty or
        lengths don't match.

    Raises:
        ValueError: If any tiling_dim is empty or invalid, or if any dimension
            index is out of bounds for the corresponding shape.
    """
```

**Key Steps:**
1. Get UB manager instance
2. Validate `shapes` and `tiling_dims` (lengths must match, cannot be empty)
   - Returns `None` if validation fails (empty or mismatched lengths)
3. Set defaults for `dtype_size` (4) and `memory_multiplier` (10.0) if not provided
4. Call `_default_strategy` to get `max_supported` (tuple of max_safe_block_size, one per shape)
   - May raise `ValueError` if `tiling_dims` are invalid (see `_default_strategy` documentation)
5. For each `(shape, tiling_dim, max_safe)`:
   - Normalize `tiling_dim` to a set of dimension indices
   - Validate tiling dimensions are within shape bounds
     - Raises `ValueError` if `tiling_dim` is empty or invalid
     - Raises `ValueError` if any dimension index is out of bounds
   - For each tiling dimension:
     - Compute `desired = triton.next_power_of_2(original_dim)`
     - Compute `final = min(desired, max_safe)`
     - Ensure `final >= 1`
     - Replace dimension with `final`
   - For each non-tiling dimension:
     - Pad to `triton.next_power_of_2(original_dim)`
6. Return tuple of tiled shapes (same structure as input `shapes`)
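
A sketch of how the final tile shapes could be assembled from the per-shape maxima, reusing the sketches above; the defaults and the UB-manager lookup are simplified here and the function name is illustrative.

```python
import triton

def compute_default_tiling_strategy_sketch(shapes, tiling_dims, safety_margin=0.80,
                                            dtype_size=4, memory_multiplier=10.0,
                                            ub_capacity_bits=2097152):
    if not shapes or not tiling_dims or len(shapes) != len(tiling_dims):
        return None
    max_supported = _default_strategy_sketch(
        ub_capacity_bits, safety_margin, dtype_size, memory_multiplier, shapes, tiling_dims
    )
    tiled = []
    for shape, tiling_dim, max_safe in zip(shapes, tiling_dims, max_supported):
        dims = _normalize_tiling_dims(tiling_dim)
        out = []
        for i, size in enumerate(shape):
            padded = triton.next_power_of_2(size)
            # Tiling dims are capped by the safe block size; fixed dims are only padded.
            out.append(max(1, min(padded, max_safe)) if i in dims else padded)
        tiled.append(tuple(out))
    return tuple(tiled)
```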

## Memory Analysis Examples

### GEGLU Forward

```
Memory analysis:
- Inputs: a, b
- Intermediates: a_cubed, tanh_arg, tanh_result, geglu_a
- Output: c
- Total: ~7x * BLOCK_SIZE * dtype_size

Strategy:
- shapes: ((n_cols,),)
- tiling_dims: (0,)  # First dimension can be tiled
- Fixed dimensions: none (all dimensions are tiling)
- unit_param = 1 (product of fixed dimensions)
- memory_multiplier = 7.0
- Formula: 7.0 * BLOCK_SIZE * 1 * dtype_size * 8 bits
- Returns: ((block_size,),)
```

### GEGLU Backward

```
Memory analysis:
- More intermediates for gradient computation
- Total: ~10x * BLOCK_SIZE * dtype_size

Strategy:
- shapes: ((n_cols,),)
- tiling_dims: (0,)  # First dimension can be tiled
- Fixed dimensions: none (all dimensions are tiling)
- unit_param = 1 (product of fixed dimensions)
- memory_multiplier = 10.0
- Formula: 10.0 * BLOCK_SIZE * 1 * dtype_size * 8 bits
- Returns: ((block_size,),)
```

### ROPE Forward/Backward

```
Memory analysis (based on optimized ROPE kernel):
- cos_vals and sin_vals: pad_hd // 2 elements each (shared)
- In q heads loop (peak memory):
  * q_left, q_right, new_left, new_right: 2 * BLOCK_Q * pad_hd elements
- In k heads loop (peak memory):
  * k_left, k_right, new_left, new_right: 2 * BLOCK_K * pad_hd elements
- Plus shared cos/sin: pad_hd elements
- Conservative estimate: 3 * BLOCK_SIZE * pad_hd * dtype_size * 8 bits

Strategy:
- shapes: ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
- tiling_dims: (0, 0)  # First dimension of each shape can be tiled
- Fixed dimensions: pad_hd (second dimension, non-tiling)
- unit_param = pad_hd (product of fixed dimensions)
- memory_multiplier = 3.0
- Formula: 3.0 * BLOCK_SIZE * pad_hd * dtype_size * 8 bits
- Returns: ((block_size_q, pad_hd), (block_size_kv, pad_hd))
```
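
A worked instance of the ROPE strategy, assuming the default 2,097,152-bit capacity, `pad_hd = 128`, float32, and hypothetical head counts of 256 (q) and 32 (kv):

```python
ub_capacity_bits = 2097152
safety_margin = 0.90
memory_multiplier = 3.0
dtype_size = 4               # float32
pad_hd = 128                 # fixed dimension -> unit_param

safe_bits = ub_capacity_bits * safety_margin                            # 1,887,436.8
max_block = safe_bits / (memory_multiplier * pad_hd * dtype_size * 8)   # 153.6
# Largest power of 2 <= 153.6 is 128, so max_safe_block_size = 128.

# n_q_head = 256: desired block 256 exceeds 128 and is capped to 128.
# n_kv_head = 32: desired block 32 already fits.
# Result: ((128, 128), (32, 128))
```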

## Extension Guide

### Adding a New Kernel

To add tiling support for a new kernel:

1. **Analyze memory usage**:
   - Identify peak memory usage in the kernel
   - Determine memory multiplier (e.g., 7.0, 10.0, 3.0)
   - Identify which dimensions can be tiled and which are fixed
   - Fixed dimensions will be automatically extracted and multiplied to get `unit_param`

2. **Use `compute_default_tiling_strategy`** in your kernel:

   ```python
   def my_kernel_forward(input):
       # Prepare parameters
       n_cols = input.shape[-1]
       dtype_size = input.element_size()

       # Compute strategy
       # Example 1: Simple case (all dimensions can be tiled)
       shapes = ((n_cols,),)
       tile_shapes = compute_default_tiling_strategy(
           safety_margin=0.80,
           dtype_size=dtype_size,
           memory_multiplier=7.0,  # Based on your memory analysis
           shapes=shapes,
           tiling_dims=(0,)  # First dimension can be tiled
       )

       if tile_shapes is not None and len(tile_shapes) > 0:
           block_size = tile_shapes[0][0]
       else:
           block_size = triton.next_power_of_2(n_cols)  # Fallback

       # Example 2: Multiple shapes with fixed dimensions
       # shapes = ((M, K), (K, N))
       # tiling_dims = (0, 1)  # First shape: dim 0 can be tiled, dim 1 is fixed
       #                       # Second shape: dim 0 is fixed, dim 1 can be tiled
       # Returns: ((block_M, K), (K, block_N))

       # Call kernel
       kernel[(grid_size,)](
           input,
           BLOCK_SIZE=block_size,
       )
   ```

3. **Document memory analysis** in comments:

   ```python
   # My kernel tiling strategy:
   # - Memory analysis:
   #   * Input: input
   #   * Intermediates: intermediate1, intermediate2
   #   * Output: output
   #   * Total: ~7x * BLOCK_SIZE * dtype_size
   # - shapes: ((n_cols,),)
   # - tiling_dims: (0,) means first dimension can be tiled
   # - Fixed dimensions: none (all dimensions are tiling)
   # - unit_param = 1 (product of fixed dimensions)
   # - Uses memory_multiplier=7.0 * BLOCK_SIZE * dtype_size * 8 bits for safety
   # - compute_default_tiling_strategy returns: ((block_size,),)
   #   where block_size = min(triton.next_power_of_2(n_cols), max_safe_block_size)
   ```

## Future Improvements

1. **Strategy Variants**: If needed, could add specialized strategy functions for specific kernels while keeping the unified interface
2. **Multi-dimensional Tiling**: Could extend to support more complex tiling patterns if needed

liger_kernel/ops/backends/_ascend/ops/__init__.py
@@ -0,0 +1,49 @@
"""
Ascend NPU operator implementations.

This module exports Ascend NPU-optimized implementations that will automatically
replace the default implementations when running on NPU devices.

Both Function classes and kernel functions can be exported here.

To add a new operator:
1. Create the implementation file (e.g., rms_norm.py)
2. Import the Function class and/or kernel functions here
3. Optionally add to __all__ for explicit control

If __all__ is not defined, all public symbols will be auto-discovered.
"""

from liger_kernel.ops.backends._ascend.ops.geglu import LigerGELUMulFunction
from liger_kernel.ops.backends._ascend.ops.geglu import geglu_backward
from liger_kernel.ops.backends._ascend.ops.geglu import geglu_forward
from liger_kernel.ops.backends._ascend.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
from liger_kernel.ops.backends._ascend.ops.qwen2vl_mrope import qwen2vl_mrope_backward
from liger_kernel.ops.backends._ascend.ops.qwen2vl_mrope import qwen2vl_mrope_forward
from liger_kernel.ops.backends._ascend.ops.rope import LigerRopeFunction
from liger_kernel.ops.backends._ascend.ops.rope import rope_backward
from liger_kernel.ops.backends._ascend.ops.rope import rope_forward
from liger_kernel.ops.backends._ascend.ops.swiglu import LigerSiLUMulFunction
from liger_kernel.ops.backends._ascend.ops.swiglu import swiglu_backward
from liger_kernel.ops.backends._ascend.ops.swiglu import swiglu_forward
from liger_kernel.ops.backends._ascend.ops.tvd import LigerTVDLossFunction
from liger_kernel.ops.backends._ascend.ops.tvd import tv_distance_forward_triton
from liger_kernel.ops.backends._ascend.ops.tvd import tvd_backward_triton

__all__ = [
    "LigerGELUMulFunction",
    "geglu_forward",
    "geglu_backward",
    "LigerQwen2VLMRopeFunction",
    "qwen2vl_mrope_forward",
    "qwen2vl_mrope_backward",
    "LigerRopeFunction",
    "rope_forward",
    "rope_backward",
    "LigerSiLUMulFunction",
    "swiglu_forward",
    "swiglu_backward",
    "LigerTVDLossFunction",
    "tv_distance_forward_triton",
    "tvd_backward_triton",
]