liger-kernel 0.6.3__py3-none-any.whl → 0.6.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. liger_kernel/chunked_loss/cosine_similarity_loss.py +20 -5
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
  4. liger_kernel/chunked_loss/grpo_loss.py +8 -5
  5. liger_kernel/chunked_loss/jsd_loss.py +39 -11
  6. liger_kernel/ops/__init__.py +141 -0
  7. liger_kernel/ops/backends/README.md +151 -0
  8. liger_kernel/ops/backends/__init__.py +13 -0
  9. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  10. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +492 -0
  11. liger_kernel/ops/backends/_ascend/ops/__init__.py +61 -0
  12. liger_kernel/ops/backends/_ascend/ops/embedding.py +214 -0
  13. liger_kernel/ops/backends/_ascend/ops/geglu.py +191 -0
  14. liger_kernel/ops/backends/_ascend/ops/llama4_rope.py +298 -0
  15. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +275 -0
  16. liger_kernel/ops/backends/_ascend/ops/rope.py +265 -0
  17. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  18. liger_kernel/ops/backends/_ascend/ops/tvd.py +223 -0
  19. liger_kernel/ops/backends/_ascend/ub_manager.py +367 -0
  20. liger_kernel/ops/backends/registry.py +61 -0
  21. liger_kernel/ops/cross_entropy.py +71 -11
  22. liger_kernel/ops/dyt.py +5 -2
  23. liger_kernel/ops/fused_add_rms_norm.py +21 -23
  24. liger_kernel/ops/fused_linear_cross_entropy.py +32 -5
  25. liger_kernel/ops/geglu.py +5 -3
  26. liger_kernel/ops/group_norm.py +12 -8
  27. liger_kernel/ops/grpo_loss.py +3 -1
  28. liger_kernel/ops/kl_div.py +8 -11
  29. liger_kernel/ops/layer_norm.py +89 -69
  30. liger_kernel/ops/poly_norm.py +19 -21
  31. liger_kernel/ops/rms_norm.py +149 -71
  32. liger_kernel/ops/tiled_mlp.py +136 -0
  33. liger_kernel/ops/utils.py +25 -0
  34. liger_kernel/transformers/__init__.py +25 -0
  35. liger_kernel/transformers/auto_model.py +21 -0
  36. liger_kernel/transformers/cross_entropy.py +9 -4
  37. liger_kernel/transformers/dyt.py +1 -1
  38. liger_kernel/transformers/experimental/embedding.py +1 -1
  39. liger_kernel/transformers/functional.py +44 -26
  40. liger_kernel/transformers/fused_add_rms_norm.py +1 -1
  41. liger_kernel/transformers/fused_linear_cross_entropy.py +9 -4
  42. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  43. liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
  44. liger_kernel/transformers/geglu.py +1 -1
  45. liger_kernel/transformers/group_norm.py +1 -1
  46. liger_kernel/transformers/grpo_loss.py +57 -2
  47. liger_kernel/transformers/jsd.py +1 -1
  48. liger_kernel/transformers/kl_div.py +1 -1
  49. liger_kernel/transformers/layer_norm.py +1 -1
  50. liger_kernel/transformers/llama4_rope.py +1 -1
  51. liger_kernel/transformers/model/exaone4.py +136 -0
  52. liger_kernel/transformers/model/falcon_h1.py +19 -5
  53. liger_kernel/transformers/model/gemma.py +17 -6
  54. liger_kernel/transformers/model/gemma2.py +17 -8
  55. liger_kernel/transformers/model/gemma3.py +35 -16
  56. liger_kernel/transformers/model/glm4.py +16 -4
  57. liger_kernel/transformers/model/glm4v.py +16 -4
  58. liger_kernel/transformers/model/glm4v_moe.py +23 -4
  59. liger_kernel/transformers/model/gpt_oss.py +211 -0
  60. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  61. liger_kernel/transformers/model/internvl.py +12 -5
  62. liger_kernel/transformers/model/llama.py +14 -5
  63. liger_kernel/transformers/model/llama4.py +16 -4
  64. liger_kernel/transformers/model/llava.py +12 -4
  65. liger_kernel/transformers/model/loss_utils.py +37 -3
  66. liger_kernel/transformers/model/mistral.py +15 -6
  67. liger_kernel/transformers/model/mixtral.py +16 -7
  68. liger_kernel/transformers/model/mllama.py +12 -4
  69. liger_kernel/transformers/model/olmo2.py +16 -4
  70. liger_kernel/transformers/model/olmo3.py +142 -0
  71. liger_kernel/transformers/model/output_classes.py +147 -0
  72. liger_kernel/transformers/model/paligemma.py +23 -5
  73. liger_kernel/transformers/model/phi3.py +14 -7
  74. liger_kernel/transformers/model/qwen2.py +16 -3
  75. liger_kernel/transformers/model/qwen2_5_vl.py +14 -6
  76. liger_kernel/transformers/model/qwen2_vl.py +16 -4
  77. liger_kernel/transformers/model/qwen3.py +20 -5
  78. liger_kernel/transformers/model/qwen3_moe.py +19 -5
  79. liger_kernel/transformers/model/qwen3_next.py +17 -5
  80. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  81. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  82. liger_kernel/transformers/model/smollm3.py +15 -6
  83. liger_kernel/transformers/monkey_patch.py +584 -49
  84. liger_kernel/transformers/multi_token_attention.py +1 -1
  85. liger_kernel/transformers/poly_norm.py +1 -1
  86. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  87. liger_kernel/transformers/rms_norm.py +8 -3
  88. liger_kernel/transformers/rope.py +45 -1
  89. liger_kernel/transformers/softmax.py +1 -1
  90. liger_kernel/transformers/sparsemax.py +1 -1
  91. liger_kernel/transformers/swiglu.py +18 -1
  92. liger_kernel/transformers/tiled_mlp.py +125 -0
  93. liger_kernel/transformers/tvd.py +1 -1
  94. liger_kernel/utils.py +54 -0
  95. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/METADATA +14 -4
  96. liger_kernel-0.6.5.dist-info/RECORD +134 -0
  97. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/WHEEL +1 -1
  98. liger_kernel-0.6.3.dist-info/RECORD +0 -111
  99. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/licenses/LICENSE +0 -0
  100. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/licenses/NOTICE +0 -0
  101. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,151 @@
1
+ # Adding a New Vendor Backend
2
+
3
+ This directory contains vendor-specific operator implementations that automatically replace the default (CUDA) implementations when running on the corresponding device.
4
+
5
+ ## Concepts
6
+
7
+ - **Vendor**: Chip manufacturer (e.g., `ascend`, `intel`, `nvidia`)
8
+ - **Device**: Device type (e.g., `npu`, `xpu`, `cuda`)
9
+ - **VendorInfo**: Defines the mapping between vendor and device (see the minimal registry sketch below)
10
+
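The registry module referenced later in this README (`registry.py`, exporting `VendorInfo`, `register_vendor()`, `VENDOR_REGISTRY`, and `get_vendor_for_device()`) is not shown in this hunk. As a minimal illustrative sketch of how those pieces could fit together (the packaged module may carry extra fields or key the registry differently):

```python
# Illustrative sketch only; the packaged registry.py may differ.
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass(frozen=True)
class VendorInfo:
    vendor: str  # chip manufacturer, e.g. "ascend"
    device: str  # device type reported by infer_device(), e.g. "npu"


# Assumed here to be keyed by device type so the active vendor can be looked up at import time.
VENDOR_REGISTRY: Dict[str, VendorInfo] = {}


def register_vendor(info: VendorInfo) -> None:
    VENDOR_REGISTRY[info.device] = info


def get_vendor_for_device(device: str) -> Optional[VendorInfo]:
    return VENDOR_REGISTRY.get(device)
```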
11
+ ## Directory Structure
12
+
13
+ ```
14
+ backends/
15
+ ├── README.md
16
+ ├── __init__.py
17
+ ├── registry.py # VendorInfo, register_vendor(), VENDOR_REGISTRY
18
+ ├── _ascend/ # Ascend (Huawei) vendor - supports NPU
19
+ │ ├── __init__.py # Registers VendorInfo for NPU
20
+ │ └── ops/
21
+ │ ├── __init__.py # Exports vendor-specific implementations
22
+ │ └── geglu.py # NPU-specific GEGLU implementation
23
+ └── _<vendor>/ # Your new vendor backend
24
+ └── ...
25
+ ```
26
+
27
+ ## How It Works
28
+
29
+ 1. When `liger_kernel.ops.backends` is imported, it imports all vendor packages (e.g., `_ascend`)
30
+ 2. Each vendor's `__init__.py` calls `register_vendor()` to register itself
31
+ 3. When `liger_kernel.ops` is imported, `_replace_with_vendor_ops()` is called
32
+ 4. It detects the current device via `infer_device()` and looks up the vendor
33
+ 5. Vendor implementations replace/add to the `liger_kernel.ops` namespace (a sketch of this flow follows below)
34
+
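As a hypothetical sketch of steps 3-5 (the real `_replace_with_vendor_ops()` in `liger_kernel/ops/__init__.py` is not shown in this hunk and may differ):

```python
import importlib
import sys

from liger_kernel.ops.backends import get_vendor_for_device
from liger_kernel.utils import infer_device


def _replace_with_vendor_ops_sketch() -> None:
    # Step 4: detect the current device and look up its registered vendor.
    vendor = get_vendor_for_device(infer_device())
    if vendor is None:
        return  # no vendor registered for this device: keep the default (CUDA) ops

    # Step 5: pull in the vendor's ops package, e.g. liger_kernel.ops.backends._ascend.ops.
    vendor_ops = importlib.import_module(f"liger_kernel.ops.backends._{vendor.vendor}.ops")
    ops_module = sys.modules["liger_kernel.ops"]

    # Copy each exported symbol into the liger_kernel.ops namespace, overriding
    # same-named defaults and adding vendor-only operators.
    for name in getattr(vendor_ops, "__all__", []):
        setattr(ops_module, name, getattr(vendor_ops, name))
```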
35
+ ## Adding a New Vendor
36
+
37
+ ### Step 1: Create Directory Structure
38
+
39
+ ```bash
40
+ mkdir -p backends/_<vendor>/ops
41
+ touch backends/_<vendor>/__init__.py
42
+ touch backends/_<vendor>/ops/__init__.py
43
+ ```
44
+
45
+ ### Step 2: Register Your Vendor
46
+
47
+ In `backends/_<vendor>/__init__.py`, register your vendor:
48
+
49
+ ```python
50
+ """
51
+ <Vendor> backend for Liger-Kernel.
52
+ """
53
+
54
+ from liger_kernel.ops.backends.registry import VendorInfo, register_vendor
55
+
56
+ register_vendor(
57
+ VendorInfo(
58
+ vendor="<vendor>",
59
+ device="<device>",
60
+ )
61
+ )
62
+ ```
63
+
64
+
65
+ ### Step 3: Ensure Device Detection Works
66
+
67
+ Make sure `infer_device()` in `liger_kernel/utils.py` can detect your device:
68
+
69
+ ```python
70
+ def infer_device():
71
+ if torch.cuda.is_available():
72
+ return "cuda"
73
+ if is_npu_available():
74
+ return "npu"
75
+ # Add your device detection here
76
+ if is_<device>_available():
77
+ return "<device>"
78
+ return "cpu"
79
+ ```
80
+
81
+ ### Step 4: Implement Vendor-Specific Operators
82
+
83
+ Create operator files in `backends/_<vendor>/ops/`. For example, `geglu.py`:
84
+
85
+ ```python
86
+ import torch
87
+
88
+ class LigerGELUMulFunction(torch.autograd.Function):
89
+ """
90
+ Vendor-specific LigerGELUMulFunction implementation.
91
+ """
92
+ @staticmethod
93
+ def forward(ctx, a, b):
94
+ # Your vendor-specific forward implementation
95
+ ...
96
+
97
+ @staticmethod
98
+ def backward(ctx, dc):
99
+ # Your vendor-specific backward implementation
100
+ ...
101
+
102
+ # Optional: vendor-specific kernel functions
103
+ def geglu_forward_vendor(a, b):
104
+ ...
105
+
106
+ def geglu_backward_vendor(a, b, dc):
107
+ ...
108
+ ```
109
+
110
+ ### Step 5: Export in `ops/__init__.py`
111
+
112
+ In `backends/_<vendor>/ops/__init__.py`, export your implementations:
113
+
114
+ ```python
115
+ """
116
+ <Vendor>-specific operator implementations.
117
+ """
118
+
119
+ from .<module> import (
120
+ LigerGELUMulFunction,
121
+ geglu_forward_vendor as geglu_forward, # Rename to match default API
122
+ geglu_backward_vendor as geglu_backward,
123
+ )
124
+
125
+ # Explicitly declare what to export (recommended)
126
+ __all__ = [
127
+ "LigerGELUMulFunction",
128
+ "geglu_forward",
129
+ "geglu_backward",
130
+ ]
131
+ ```
132
+
133
+ ## Key Points
134
+
135
+ ### Incremental Override
136
+
137
+ You **don't need to implement all operators**. Only implement the ones that require vendor-specific adaptations. Unimplemented operators will automatically fall back to the default (CUDA) implementation.
138
+
139
+ ### Vendor-Specific Additions
140
+
141
+ Vendors can also **add new operators** that don't exist in the default implementation. These will be exported to the `liger_kernel.ops` namespace for users to import.
142
+
143
+ ### Naming Convention
144
+
145
+ - Use the **same class/function names** as the default implementations for overrides
146
+ - This allows seamless replacement without changing user code
147
+ - Use `as` imports to rename if your internal naming differs
148
+
149
+ ## Example: Ascend NPU Backend
150
+
151
+ See `_ascend/` directory for a complete example of the Ascend NPU backend implementation.
@@ -0,0 +1,13 @@
1
+ import importlib
2
+ import pkgutil
3
+
4
+ from liger_kernel.ops.backends.registry import VENDOR_REGISTRY # noqa: F401
5
+ from liger_kernel.ops.backends.registry import VendorInfo # noqa: F401
6
+ from liger_kernel.ops.backends.registry import get_vendor_for_device # noqa: F401
7
+ from liger_kernel.ops.backends.registry import register_vendor # noqa: F401
8
+
9
+ # Auto-import all _<vendor> subpackages to trigger registration
10
+ # Each vendor's __init__.py calls register_vendor() when imported
11
+ for _, modname, ispkg in pkgutil.iter_modules(__path__):
12
+ if ispkg and modname.startswith("_"):
13
+ importlib.import_module(f"{__name__}.{modname}")
@@ -0,0 +1,5 @@
1
+ from liger_kernel.ops.backends.registry import VendorInfo
2
+ from liger_kernel.ops.backends.registry import register_vendor
3
+
4
+ # Register Ascend vendor for NPU device
5
+ register_vendor(VendorInfo(vendor="ascend", device="npu"))
@@ -0,0 +1,492 @@
1
+ # Ascend NPU UB Manager Design Document
2
+
3
+ ## Overview
4
+
5
+ The UB Manager (Unified Buffer Manager) is a core component in **Liger-Kernel** responsible for managing the Unified Buffer (UB) capacity on Ascend NPUs. By automatically detecting UB capacity and providing unified tiling strategy computation, it helps Triton kernels avoid UB overflow errors while maintaining high performance.
6
+
7
+ ## Design Goals
8
+
9
+ 1. **Automated UB Management**: Automatically detect device UB capacity without manual configuration
10
+ 2. **Unified Strategy System**: Use a single unified strategy function for all kernels, abstracting memory calculations
11
+ 3. **Flexible Parameters**: Support different memory multipliers and safety margins for different kernels
12
+ 4. **Easy to Use**: Simple interface that directly computes tiling results
13
+
14
+ ## Architecture Design
15
+
16
+ ### Core Components
17
+
18
+ ```
19
+ ┌─────────────────────────────────────────────────────────┐
20
+ │ UB Manager System │
21
+ ├─────────────────────────────────────────────────────────┤
22
+ │ │
23
+ │ ┌──────────────┐ ┌──────────────────┐ │
24
+ │ │ UBManager │ │ Default Strategy │ │
25
+ │ │ (Singleton)│────────▶│ Function │ │
26
+ │ └──────────────┘ └──────────────────┘ │
27
+ │ │ │ │
28
+ │ │ │ │
29
+ │ ▼ ▼ │
30
+ │ ┌──────────────┐ ┌──────────────────┐ │
31
+ │ │ Capacity │ │ compute_default │ │
32
+ │ │ Detection │ │ _tiling_strategy│ │
33
+ │ └──────────────┘ └──────────────────┘ │
34
+ │ │
35
+ └─────────────────────────────────────────────────────────┘
36
+ │ │
37
+ │ │
38
+ ▼ ▼
39
+ ┌──────────────┐ ┌──────────────────┐
40
+ │ GEGLU │ │ ROPE │
41
+ │ Kernel │ │ Kernel │
42
+ └──────────────┘ └──────────────────┘
43
+ ```
44
+
45
+ ### Class Diagram
46
+
47
+ ```
48
+ ┌──────────────────────────────────────┐
49
+ │ UBManager │
50
+ ├──────────────────────────────────────┤
51
+ │ - _npu_model: str │
52
+ │ - _ub_capacity_bits: int │
53
+ ├──────────────────────────────────────┤
54
+ │ + ub_capacity_bits: int │
55
+ │ + ub_capacity_bytes: int │
56
+ │ + npu_model: str │
57
+ │ - _detect_npu_model() │
58
+ │ - _detect_ub_capacity() │
59
+ │ (raises RuntimeError if fails) │
60
+ └──────────────────────────────────────┘
61
+
62
+ ┌──────────────────────────────────────┐
63
+ │ compute_default_tiling_strategy │
64
+ ├──────────────────────────────────────┤
65
+ │ + safety_margin: float │
66
+ │ + dtype_size: int │
67
+ │ + memory_multiplier: float │
68
+ │ + shapes: Tuple[Tuple[int, ...], ...]│
69
+ │ + tiling_dims: Tuple │
70
+ ├──────────────────────────────────────┤
71
+ │ Returns: Tuple[Tuple[int, ...], ...] │
72
+ │ (same structure as shapes) │
73
+ └──────────────────────────────────────┘
74
+
75
+ ┌──────────────────────────────────────┐
76
+ │ _normalize_tiling_dims │
77
+ ├──────────────────────────────────────┤
78
+ │ Helper function to normalize │
79
+ │ tiling_dim (int or tuple) to set │
80
+ └──────────────────────────────────────┘
81
+ ```
82
+
83
+ ## Core Functionality
84
+
85
+ ### 1. UB Capacity Detection
86
+
87
+ The UB Manager detects UB capacity in the following priority order:
88
+
89
+ 1. **Environment Variable**: `ASCEND_UB_CAPACITY_BITS` (in bits)
90
+ - If set, this value is used directly
91
+ - Must be a positive integer representing UB capacity in bits
92
+
93
+ 2. **get_soc_spec**: Query UB size from CANN's `get_soc_spec("UB_SIZE")`
94
+ - Returns UB size in bytes
95
+ - Automatically converted to bits (bytes * 8)
96
+ - Requires CANN environment to be sourced (e.g., `source /usr/local/Ascend/ascend-toolkit/set_env.sh`)
97
+
98
+ 3. **Error Handling**: If neither method succeeds, raises `RuntimeError` with clear instructions
99
+
100
+
101
+ ```python
102
+ # Detection flow:
103
+ # 1. Check ASCEND_UB_CAPACITY_BITS env var (bits)
104
+ # 2. Try get_soc_spec("UB_SIZE") (bytes) -> convert to bits
105
+ # 3. Raise RuntimeError if both fail
106
+ ```
107
+
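A rough sketch of that detection order follows; the `get_soc_spec` import path is an assumption (it comes from the CANN toolkit, which must be sourced), and the real `_detect_ub_capacity()` method may differ:

```python
import os


def _detect_ub_capacity_bits_sketch() -> int:
    # 1. Explicit override via environment variable (value in bits).
    env = os.environ.get("ASCEND_UB_CAPACITY_BITS")
    if env is not None:
        bits = int(env)
        if bits <= 0:
            raise RuntimeError("ASCEND_UB_CAPACITY_BITS must be a positive integer (bits)")
        return bits

    # 2. Query the UB size (bytes) from CANN and convert to bits.
    try:
        from tbe.common.platform import get_soc_spec  # import path is an assumption

        return int(get_soc_spec("UB_SIZE")) * 8
    except Exception as exc:
        # 3. Neither method succeeded: fail with clear instructions.
        raise RuntimeError(
            "Unable to determine UB capacity: set ASCEND_UB_CAPACITY_BITS or source the CANN "
            "environment (e.g. /usr/local/Ascend/ascend-toolkit/set_env.sh)."
        ) from exc
```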
108
+ ### 2. Unified Strategy System
109
+
110
+ All kernels use a single unified strategy function `_default_strategy` that abstracts memory calculations:
111
+
112
+ ```
113
+ Memory Formula: memory_multiplier * BLOCK_SIZE * unit_param * dtype_size * 8 bits
114
+ ```
115
+
116
+ Where `unit_param` is automatically calculated as the product of all fixed (non-tiling) dimensions in each shape.
117
+
118
+ The strategy function:
119
+ - Takes UB capacity, safety margin, dtype size, memory multiplier, shapes, and tiling dimension specifications
120
+ - For each shape, identifies which dimensions can be tiled (from `tiling_dims`)
121
+ - Calculates `unit_param` as the product of fixed (non-tiling) dimensions
122
+ - Calculates the maximum safe block size that fits within UB capacity
123
+ - Returns a tuple of max_safe_block_size values (one for each shape)
124
+
125
+ The `compute_default_tiling_strategy` function:
126
+ - Calls `_default_strategy` to get max_safe_block_size for each shape
127
+ - For each tiling dimension, computes desired block size using `triton.next_power_of_2(original_dim)`
128
+ - Returns the final result with the same structure as the input shapes: tiling dimensions are replaced with computed block sizes, and non-tiling dimensions are padded to the next power of 2 (a small worked example follows below)
129
+
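To make the formula concrete, here is a small worked example using the GEGLU-forward parameters from the usage section below; the 192 KB UB capacity is purely illustrative (the real value is detected at runtime):

```python
# Illustrative numbers only.
ub_capacity_bits = 192 * 1024 * 8   # assume a 192 KB UB for this example
safety_margin = 0.80
dtype_size = 2                      # float16
memory_multiplier = 7.0             # GEGLU forward
unit_param = 1                      # all dimensions are tiling, so no fixed dims

safe_bits = ub_capacity_bits * safety_margin                                # 1_258_291.2
max_block = safe_bits / (memory_multiplier * unit_param * dtype_size * 8)  # ~11_234.7
max_safe_block_size = 1 << (int(max_block).bit_length() - 1)               # 8192

# For n_cols = 4096: desired = triton.next_power_of_2(4096) = 4096,
# final BLOCK_SIZE = min(4096, 8192) = 4096.
```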
130
+ ### 3. Parameter Structure
131
+
132
+ The unified strategy uses the following parameters:
133
+
134
+ - **`safety_margin`**: Safety margin as a float (e.g., 0.80 for 80%). Default is 0.80.
135
+ - **`dtype_size`**: Size of data type in bytes (e.g., 2 for float16, 4 for float32)
136
+ - **`memory_multiplier`**: Memory multiplier for estimating peak memory usage
137
+ - For GEGLU: typically 10.0 for backward, 7.0 for forward
138
+ - For ROPE: typically 3.0
139
+ - **`shapes`**: Tuple of full shapes. Each shape is a tuple of dimension sizes.
140
+ - For ROPE: `((n_q_head, hd), (n_kv_head, hd))`
141
+ - For GEGLU: `((n_cols,),)`
142
+ - Can pass original shapes (will handle padding internally) or padded shapes
143
+ - **`tiling_dims`**: Tuple specifying which dimensions can be tiled for each shape.
144
+ - Each element can be:
145
+ - `int`: single dimension index (e.g., `0` for first dimension)
146
+ - `tuple of ints`: multiple dimensions that can be tiled together (non-empty)
147
+ - For ROPE: `(0, 0)` means first dimension of each shape can be tiled
148
+ - For GEGLU: `(0,)` means first dimension of the shape can be tiled
149
+ - Length must match `len(shapes)`
150
+ - Fixed dimensions (non-tiling) are automatically extracted from shapes and multiplied to get `unit_param`
151
+ - **Validation**: Raises `ValueError` if:
152
+ - Any `tiling_dim` is empty or invalid (e.g., empty tuple)
153
+ - Any dimension index is out of bounds (negative or >= shape length)
154
+
155
+ ### 4. Strategy Computation Flow
156
+
157
+ ```
158
+ User calls compute_default_tiling_strategy()
159
+
160
+
161
+ Get UB manager instance
162
+
163
+
164
+ Validate shapes and tiling_dims (lengths must match)
165
+
166
+
167
+ Set defaults for dtype_size (4) and memory_multiplier (10.0)
168
+
169
+
170
+ Call _default_strategy() with:
171
+ - ub_capacity_bits
172
+ - safety_margin
173
+ - dtype_size
174
+ - memory_multiplier
175
+ - shapes
176
+ - tiling_dims
177
+
178
+
179
+ For each (shape, tiling_dim) pair:
180
+ Normalize tiling_dim to set of dimension indices
181
+ Validate tiling dimensions are within shape bounds
182
+ (Raises ValueError if invalid)
183
+
184
+
185
+ Calculate unit_param:
186
+ unit_param = product of all non-tiling dimensions
187
+
188
+
189
+ Calculate max_block_size:
190
+ SAFE_UB_CAPACITY_BITS = ub_capacity_bits * safety_margin
191
+ max_block_size = SAFE_UB_CAPACITY_BITS / (memory_multiplier * unit_param * dtype_size * 8)
192
+
193
+
194
+ Find largest power of 2 <= max_block_size
195
+
196
+
197
+ Return tuple of max_safe_block_size (one per shape)
198
+
199
+
200
+ Build result with same structure as shapes:
201
+ For each (shape, tiling_dim, max_safe):
202
+ For each tiling dimension:
203
+ desired = triton.next_power_of_2(original_dim)
204
+ final = min(desired, max_safe)
205
+ final = max(1, final)
206
+ For each non-tiling dimension:
207
+ pad to triton.next_power_of_2(original_dim)
208
+
209
+
210
+ Return tuple of tiled shapes
211
+ ```
212
+
213
+ ## Usage Examples
214
+
215
+ ### Basic Usage
216
+
217
+ ```python
218
+ from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
219
+
220
+ # GEGLU forward
221
+ shapes = ((4096,),)
222
+ tile_shapes = compute_default_tiling_strategy(
223
+ safety_margin=0.80,
224
+ dtype_size=2, # float16
225
+ memory_multiplier=7.0,
226
+ shapes=shapes,
227
+ tiling_dims=(0,) # First dimension can be tiled
228
+ )
229
+ if tile_shapes is not None and len(tile_shapes) > 0:
230
+ block_size = tile_shapes[0][0]
231
+ # Call kernel with block_size
232
+
233
+ # ROPE forward
234
+ shapes = ((32, 128), (32, 128)) # (n_q_head, hd), (n_kv_head, hd)
235
+ tile_shapes = compute_default_tiling_strategy(
236
+ safety_margin=0.90,
237
+ dtype_size=4, # float32
238
+ memory_multiplier=3.0,
239
+ shapes=shapes,
240
+ tiling_dims=(0, 0) # First dimension of each shape can be tiled
241
+ )
242
+ if tile_shapes is not None and len(tile_shapes) == len(shapes):
243
+ q_tile_shape, k_tile_shape = tile_shapes
244
+ BLOCK_Q, _ = q_tile_shape # Tiled dimension
245
+ BLOCK_K, _ = k_tile_shape # Tiled dimension
246
+ # Call kernel with BLOCK_Q and BLOCK_K
247
+ ```
248
+
249
+ ## Strategy Function Details
250
+
251
+ ### `_normalize_tiling_dims` Helper Function
252
+
253
+ A helper function that normalizes tiling dimension specifications:
254
+
255
+ ```python
256
+ def _normalize_tiling_dims(tiling_dim: Union[int, Tuple[int, ...]]) -> set:
257
+ """
258
+ Normalize tiling dimension specification to a set of dimension indices.
259
+
260
+ Args:
261
+ tiling_dim: Either an int (single dimension) or tuple of ints (multiple dimensions).
262
+
263
+ Returns:
264
+ Set of dimension indices that can be tiled.
265
+ """
266
+ ```
267
+
268
+ This function handles the conversion of `tiling_dim` from either an `int` or `tuple` to a `set` for consistent processing.
269
+
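A minimal sketch of such a helper, based only on the docstring above (the packaged implementation may differ):

```python
from typing import Tuple, Union


def _normalize_tiling_dims_sketch(tiling_dim: Union[int, Tuple[int, ...]]) -> set:
    # Single index -> one-element set; non-empty tuple of indices -> set.
    if isinstance(tiling_dim, int):
        return {tiling_dim}
    if isinstance(tiling_dim, tuple) and len(tiling_dim) > 0:
        return set(tiling_dim)
    raise ValueError(f"Invalid tiling_dim: {tiling_dim!r}")
```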
270
+ ### `_default_strategy` Function
271
+
272
+ The core strategy function that calculates maximum safe block size:
273
+
274
+ ```python
275
+ def _default_strategy(
276
+ ub_capacity_bits: int,
277
+ safety_margin: float,
278
+ dtype_size: int,
279
+ memory_multiplier: float,
280
+ shapes: Tuple[Tuple[int, ...], ...],
281
+ tiling_dims: Tuple[Union[int, Tuple[int, ...]], ...],
282
+ ) -> Tuple[int, ...]:
283
+ """
284
+ Calculate maximum safe block size based on UB capacity.
285
+
286
+ Memory formula: memory_multiplier * BLOCK_SIZE * unit_param * dtype_size * 8 bits
287
+
288
+ For each shape, fixed dimensions (non-tiling) are multiplied together to get unit_param.
289
+
290
+ Returns:
291
+ Tuple of max_safe_block_size (power of 2), one for each shape.
292
+
293
+ Raises:
294
+ ValueError: If any tiling_dim is empty or invalid, or if any dimension
295
+ index is out of bounds for the corresponding shape.
296
+ """
297
+ ```
298
+
299
+ **Key Steps** (sketched in code after the list):
300
+ 1. For each `(shape, tiling_dim)` pair:
301
+ - Normalize `tiling_dim` to a set of dimension indices using `_normalize_tiling_dims`
302
+ - Validate tiling dimensions are within shape bounds
303
+ - Raises `ValueError` if `tiling_dim` is empty or invalid
304
+ - Raises `ValueError` if any dimension index is out of bounds
305
+ - Calculate `unit_param` as the product of all non-tiling dimensions
306
+ - If all dimensions are tiling, `unit_param = 1.0`
307
+ 2. Calculate `SAFE_UB_CAPACITY_BITS = ub_capacity_bits * safety_margin`
308
+ 3. Solve for max_block_size: `SAFE_UB_CAPACITY_BITS / (memory_multiplier * unit_param * dtype_size * 8)`
309
+ 4. Find largest power of 2 <= max_block_size
310
+ 5. Return tuple with one max_safe_block_size per shape
311
+
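Putting these steps together, a condensed sketch of the calculation (validation abbreviated; the packaged function may differ in details):

```python
from typing import Tuple, Union


def _default_strategy_sketch(
    ub_capacity_bits: int,
    safety_margin: float,
    dtype_size: int,
    memory_multiplier: float,
    shapes: Tuple[Tuple[int, ...], ...],
    tiling_dims: Tuple[Union[int, Tuple[int, ...]], ...],
) -> Tuple[int, ...]:
    safe_bits = ub_capacity_bits * safety_margin
    results = []
    for shape, tiling_dim in zip(shapes, tiling_dims):
        dims = {tiling_dim} if isinstance(tiling_dim, int) else set(tiling_dim)
        if not dims or any(d < 0 or d >= len(shape) for d in dims):
            raise ValueError(f"Invalid tiling_dims {tiling_dim!r} for shape {shape!r}")
        # unit_param: product of the fixed (non-tiling) dimensions, 1.0 if there are none.
        unit_param = 1.0
        for i, size in enumerate(shape):
            if i not in dims:
                unit_param *= size
        max_block = safe_bits / (memory_multiplier * unit_param * dtype_size * 8)
        # Largest power of 2 <= max_block, clamped to at least 1.
        results.append(1 << (max(int(max_block), 1).bit_length() - 1))
    return tuple(results)
```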
312
+ ### `compute_default_tiling_strategy` Function
313
+
314
+ The public interface that computes final tiling results:
315
+
316
+ ```python
317
+ def compute_default_tiling_strategy(
318
+ safety_margin: float = 0.80,
319
+ dtype_size: Optional[int] = None,
320
+ memory_multiplier: Optional[float] = None,
321
+ shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
322
+ tiling_dims: Optional[Tuple[Union[int, Tuple[int, ...]], ...]] = None,
323
+ ) -> Optional[Tuple[Tuple[int, ...], ...]]:
324
+ """
325
+ Compute tiling strategy using the default strategy function.
326
+
327
+ Returns tuple of tiled shapes with same structure as input shapes.
328
+ Tiling dimensions are replaced with computed block sizes (power of 2),
329
+ while non-tiling dimensions are padded to next power of 2.
330
+
331
+ Returns:
332
+ Tuple of tiled shapes, or None if shapes/tiling_dims are empty or
333
+ lengths don't match.
334
+
335
+ Raises:
336
+ ValueError: If any tiling_dim is empty or invalid, or if any dimension
337
+ index is out of bounds for the corresponding shape.
338
+ """
339
+ ```
340
+
341
+ **Key Steps** (a condensed sketch of the wrapper follows the list):
342
+ 1. Get UB manager instance
343
+ 2. Validate `shapes` and `tiling_dims` (lengths must match, cannot be empty)
344
+ - Returns `None` if validation fails (empty or mismatched lengths)
345
+ 3. Set defaults for `dtype_size` (4) and `memory_multiplier` (10.0) if not provided
346
+ 4. Call `_default_strategy` to get `max_supported` (tuple of max_safe_block_size, one per shape)
347
+ - May raise `ValueError` if `tiling_dims` are invalid (see `_default_strategy` documentation)
348
+ 5. For each `(shape, tiling_dim, max_safe)`:
349
+ - Normalize `tiling_dim` to a set of dimension indices
350
+ - Validate tiling dimensions are within shape bounds
351
+ - Raises `ValueError` if `tiling_dim` is empty or invalid
352
+ - Raises `ValueError` if any dimension index is out of bounds
353
+ - For each tiling dimension:
354
+ - Compute `desired = triton.next_power_of_2(original_dim)`
355
+ - Compute `final = min(desired, max_safe)`
356
+ - Ensure `final >= 1`
357
+ - Replace dimension with `final`
358
+ - For each non-tiling dimension:
359
+ - Pad to `triton.next_power_of_2(original_dim)`
360
+ 6. Return tuple of tiled shapes (same structure as input `shapes`)
361
+
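A matching condensed sketch of the public wrapper, reusing `_default_strategy_sketch` from above; note that the real function reads UB capacity from the `UBManager` singleton rather than taking it as an argument, and may differ in other details:

```python
import triton


def compute_default_tiling_strategy_sketch(
    ub_capacity_bits,        # in the real API this comes from the UBManager singleton
    safety_margin=0.80,
    dtype_size=None,
    memory_multiplier=None,
    shapes=None,
    tiling_dims=None,
):
    # Steps 2-3: validate inputs and apply defaults.
    if not shapes or not tiling_dims or len(shapes) != len(tiling_dims):
        return None
    dtype_size = 4 if dtype_size is None else dtype_size
    memory_multiplier = 10.0 if memory_multiplier is None else memory_multiplier

    # Step 4: one max_safe_block_size per shape.
    max_supported = _default_strategy_sketch(
        ub_capacity_bits, safety_margin, dtype_size, memory_multiplier, shapes, tiling_dims
    )

    # Steps 5-6: rebuild each shape, capping tiling dims and padding fixed dims.
    tiled = []
    for shape, tiling_dim, max_safe in zip(shapes, tiling_dims, max_supported):
        dims = {tiling_dim} if isinstance(tiling_dim, int) else set(tiling_dim)
        new_shape = []
        for i, size in enumerate(shape):
            padded = triton.next_power_of_2(size)
            new_shape.append(max(1, min(padded, max_safe)) if i in dims else padded)
        tiled.append(tuple(new_shape))
    return tuple(tiled)
```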
362
+ ## Memory Analysis Examples
363
+
364
+ ### GEGLU Forward
365
+
366
+ ```
367
+ Memory analysis:
368
+ - Inputs: a, b
369
+ - Intermediates: a_cubed, tanh_arg, tanh_result, geglu_a
370
+ - Output: c
371
+ - Total: ~7x * BLOCK_SIZE * dtype_size
372
+
373
+ Strategy:
374
+ - shapes: ((n_cols,),)
375
+ - tiling_dims: (0,) # First dimension can be tiled
376
+ - Fixed dimensions: none (all dimensions are tiling)
377
+ - unit_param = 1 (product of fixed dimensions)
378
+ - memory_multiplier = 7.0
379
+ - Formula: 7.0 * BLOCK_SIZE * 1 * dtype_size * 8 bits
380
+ - Returns: ((block_size,),)
381
+ ```
382
+
383
+ ### GEGLU Backward
384
+
385
+ ```
386
+ Memory analysis:
387
+ - More intermediates for gradient computation
388
+ - Total: ~10x * BLOCK_SIZE * dtype_size
389
+
390
+ Strategy:
391
+ - shapes: ((n_cols,),)
392
+ - tiling_dims: (0,) # First dimension can be tiled
393
+ - Fixed dimensions: none (all dimensions are tiling)
394
+ - unit_param = 1 (product of fixed dimensions)
395
+ - memory_multiplier = 10.0
396
+ - Formula: 10.0 * BLOCK_SIZE * 1 * dtype_size * 8 bits
397
+ - Returns: ((block_size,),)
398
+ ```
399
+
400
+ ### ROPE Forward/Backward
401
+
402
+ ```
403
+ Memory analysis (based on optimized ROPE kernel):
404
+ - cos_vals and sin_vals: pad_hd // 2 elements each (shared)
405
+ - In q heads loop (peak memory):
406
+ * q_left, q_right, new_left, new_right: 2 * BLOCK_Q * pad_hd elements
407
+ - In k heads loop (peak memory):
408
+ * k_left, k_right, new_left, new_right: 2 * BLOCK_K * pad_hd elements
409
+ - Plus shared cos/sin: pad_hd elements
410
+ - Conservative estimate: 3 * BLOCK_SIZE * pad_hd * dtype_size * 8 bits
411
+
412
+ Strategy:
413
+ - shapes: ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
414
+ - tiling_dims: (0, 0) # First dimension of each shape can be tiled
415
+ - Fixed dimensions: pad_hd (second dimension, non-tiling)
416
+ - unit_param = pad_hd (product of fixed dimensions)
417
+ - memory_multiplier = 3.0
418
+ - Formula: 3.0 * BLOCK_SIZE * pad_hd * dtype_size * 8 bits
419
+ - Returns: ((block_size_q, pad_hd), (block_size_kv, pad_hd))
420
+ ```
421
+
422
+ ## Extension Guide
423
+
424
+ ### Adding a New Kernel
425
+
426
+ To add tiling support for a new kernel:
427
+
428
+ 1. **Analyze memory usage**:
429
+ - Identify peak memory usage in the kernel
430
+ - Determine memory multiplier (e.g., 7.0, 10.0, 3.0)
431
+ - Identify which dimensions can be tiled and which are fixed
432
+ - Fixed dimensions will be automatically extracted and multiplied to get `unit_param`
433
+
434
+ 2. **Use `compute_default_tiling_strategy`** in your kernel:
435
+
436
+ ```python
437
+ def my_kernel_forward(input):
438
+ # Prepare parameters
439
+ n_cols = input.shape[-1]
440
+ dtype_size = input.element_size()
441
+
442
+ # Compute strategy
443
+ # Example 1: Simple case (all dimensions can be tiled)
444
+ shapes = ((n_cols,),)
445
+ tile_shapes = compute_default_tiling_strategy(
446
+ safety_margin=0.80,
447
+ dtype_size=dtype_size,
448
+ memory_multiplier=7.0, # Based on your memory analysis
449
+ shapes=shapes,
450
+ tiling_dims=(0,) # First dimension can be tiled
451
+ )
452
+
453
+ if tile_shapes is not None and len(tile_shapes) > 0:
454
+ block_size = tile_shapes[0][0]
455
+ else:
456
+ block_size = triton.next_power_of_2(n_cols) # Fallback
457
+
458
+ # Example 2: Multiple shapes with fixed dimensions
459
+ # shapes = ((M, K), (K, N))
460
+ # tiling_dims = (0, 1) # First shape: dim 0 can be tiled, dim 1 is fixed
461
+ # # Second shape: dim 0 is fixed, dim 1 can be tiled
462
+ # Returns: ((block_M, K), (K, block_N))
463
+
464
+ # Call kernel
465
+ kernel[(grid_size,)](
466
+ input,
467
+ BLOCK_SIZE=block_size,
468
+ )
469
+ ```
470
+
471
+ 3. **Document memory analysis** in comments:
472
+
473
+ ```python
474
+ # My kernel tiling strategy:
475
+ # - Memory analysis:
476
+ # * Input: input
477
+ # * Intermediates: intermediate1, intermediate2
478
+ # * Output: output
479
+ # * Total: ~7x * BLOCK_SIZE * dtype_size
480
+ # - shapes: ((n_cols,),)
481
+ # - tiling_dims: (0,) means first dimension can be tiled
482
+ # - Fixed dimensions: none (all dimensions are tiling)
483
+ # - unit_param = 1 (product of fixed dimensions)
484
+ # - Uses memory_multiplier=7.0 * BLOCK_SIZE * dtype_size * 8 bits for safety
485
+ # - compute_default_tiling_strategy returns: ((block_size,),)
486
+ # where block_size = min(triton.next_power_of_2(n_cols), max_safe_block_size)
487
+ ```
488
+
489
+ ## Future Improvements
490
+
491
+ 1. **Strategy Variants**: If needed, could add specialized strategy functions for specific kernels while keeping the unified interface
492
+ 2. **Multi-dimensional Tiling**: Could extend to support more complex tiling patterns if needed