liger-kernel-nightly 0.5.10.dev20250611191801__py3-none-any.whl → 0.6.4.dev20260112233432__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of liger-kernel-nightly might be problematic.

Files changed (107)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +142 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +54 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +25 -5
  7. liger_kernel/chunked_loss/grpo_loss.py +46 -9
  8. liger_kernel/chunked_loss/jsd_loss.py +44 -13
  9. liger_kernel/ops/__init__.py +141 -0
  10. liger_kernel/ops/backends/README.md +151 -0
  11. liger_kernel/ops/backends/__init__.py +13 -0
  12. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  13. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
  14. liger_kernel/ops/backends/_ascend/ops/__init__.py +49 -0
  15. liger_kernel/ops/backends/_ascend/ops/geglu.py +266 -0
  16. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
  17. liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
  18. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  19. liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
  20. liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
  21. liger_kernel/ops/backends/registry.py +61 -0
  22. liger_kernel/ops/cross_entropy.py +130 -64
  23. liger_kernel/ops/dyt.py +5 -4
  24. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  25. liger_kernel/ops/fused_linear_cross_entropy.py +115 -22
  26. liger_kernel/ops/geglu.py +6 -4
  27. liger_kernel/ops/group_norm.py +7 -7
  28. liger_kernel/ops/grpo_loss.py +3 -1
  29. liger_kernel/ops/kl_div.py +8 -11
  30. liger_kernel/ops/layer_norm.py +135 -80
  31. liger_kernel/ops/llama4_rope.py +225 -0
  32. liger_kernel/ops/poly_norm.py +390 -0
  33. liger_kernel/ops/rms_norm.py +148 -71
  34. liger_kernel/ops/rope.py +1 -1
  35. liger_kernel/ops/swiglu.py +1 -1
  36. liger_kernel/ops/tiled_mlp.py +136 -0
  37. liger_kernel/ops/utils.py +14 -0
  38. liger_kernel/transformers/__init__.py +65 -0
  39. liger_kernel/transformers/auto_model.py +21 -0
  40. liger_kernel/transformers/cross_entropy.py +9 -4
  41. liger_kernel/transformers/dyt.py +1 -1
  42. liger_kernel/transformers/experimental/__init__.py +5 -0
  43. liger_kernel/transformers/experimental/embedding.py +1 -1
  44. liger_kernel/transformers/functional.py +56 -24
  45. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  46. liger_kernel/transformers/fused_linear_cross_entropy.py +17 -5
  47. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  48. liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
  49. liger_kernel/transformers/geglu.py +1 -1
  50. liger_kernel/transformers/group_norm.py +1 -1
  51. liger_kernel/transformers/grpo_loss.py +57 -2
  52. liger_kernel/transformers/jsd.py +1 -1
  53. liger_kernel/transformers/kl_div.py +1 -1
  54. liger_kernel/transformers/layer_norm.py +1 -1
  55. liger_kernel/transformers/llama4_rope.py +93 -0
  56. liger_kernel/transformers/model/exaone4.py +136 -0
  57. liger_kernel/transformers/model/falcon_h1.py +122 -0
  58. liger_kernel/transformers/model/gemma.py +28 -8
  59. liger_kernel/transformers/model/gemma2.py +34 -11
  60. liger_kernel/transformers/model/gemma3.py +102 -112
  61. liger_kernel/transformers/model/glm4.py +18 -5
  62. liger_kernel/transformers/model/glm4v.py +163 -0
  63. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  64. liger_kernel/transformers/model/gpt_oss.py +211 -0
  65. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  66. liger_kernel/transformers/model/internvl.py +157 -0
  67. liger_kernel/transformers/model/llama.py +26 -7
  68. liger_kernel/transformers/model/llama4.py +121 -0
  69. liger_kernel/transformers/model/llava.py +18 -6
  70. liger_kernel/transformers/model/loss_utils.py +34 -3
  71. liger_kernel/transformers/model/mistral.py +17 -10
  72. liger_kernel/transformers/model/mixtral.py +24 -9
  73. liger_kernel/transformers/model/mllama.py +18 -7
  74. liger_kernel/transformers/model/olmo2.py +18 -5
  75. liger_kernel/transformers/model/olmo3.py +142 -0
  76. liger_kernel/transformers/model/output_classes.py +147 -0
  77. liger_kernel/transformers/model/paligemma.py +42 -5
  78. liger_kernel/transformers/model/phi3.py +24 -159
  79. liger_kernel/transformers/model/qwen2.py +26 -4
  80. liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
  81. liger_kernel/transformers/model/qwen2_vl.py +24 -7
  82. liger_kernel/transformers/model/qwen3.py +22 -6
  83. liger_kernel/transformers/model/qwen3_moe.py +27 -7
  84. liger_kernel/transformers/model/qwen3_next.py +146 -0
  85. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  86. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  87. liger_kernel/transformers/model/smollm3.py +199 -0
  88. liger_kernel/transformers/model/smolvlm.py +158 -0
  89. liger_kernel/transformers/monkey_patch.py +1423 -100
  90. liger_kernel/transformers/multi_token_attention.py +2 -2
  91. liger_kernel/transformers/poly_norm.py +42 -0
  92. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  93. liger_kernel/transformers/rms_norm.py +15 -5
  94. liger_kernel/transformers/rope.py +45 -1
  95. liger_kernel/transformers/softmax.py +1 -1
  96. liger_kernel/transformers/sparsemax.py +1 -1
  97. liger_kernel/transformers/swiglu.py +18 -1
  98. liger_kernel/transformers/tiled_mlp.py +125 -0
  99. liger_kernel/transformers/tvd.py +1 -1
  100. liger_kernel/utils.py +52 -0
  101. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/METADATA +37 -25
  102. liger_kernel_nightly-0.6.4.dev20260112233432.dist-info/RECORD +132 -0
  103. liger_kernel_nightly-0.5.10.dev20250611191801.dist-info/RECORD +0 -95
  104. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/LICENSE +0 -0
  105. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/NOTICE +0 -0
  106. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/WHEEL +0 -0
  107. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/top_level.txt +0 -0
liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md
@@ -0,0 +1,485 @@
+ # Ascend NPU UB Manager Design Document
+
+ ## Overview
+
+ The UB Manager (Unified Buffer Manager) is a core component in **Liger-Kernel** responsible for managing Unified Buffer (UB) capacity on Ascend NPUs. By automatically detecting UB capacity and providing a unified tiling-strategy computation, it helps Triton kernels avoid UB overflow errors while maintaining high performance.
+
+ ## Design Goals
+
+ 1. **Automated UB Management**: Automatically detect device UB capacity without manual configuration
+ 2. **Unified Strategy System**: Use a single unified strategy function for all kernels, abstracting memory calculations
+ 3. **Flexible Parameters**: Support different memory multipliers and safety margins for different kernels
+ 4. **Easy to Use**: Simple interface that directly computes tiling results
+
+ ## Architecture Design
+
+ ### Core Components
+
+ ```
+ ┌───────────────────────────────────────────────────┐
+ │                 UB Manager System                 │
+ ├───────────────────────────────────────────────────┤
+ │                                                   │
+ │   ┌──────────────┐         ┌──────────────────┐   │
+ │   │  UBManager   │────────▶│ Default Strategy │   │
+ │   │  (Singleton) │         │     Function     │   │
+ │   └──────────────┘         └──────────────────┘   │
+ │          │                          │             │
+ │          ▼                          ▼             │
+ │   ┌──────────────┐         ┌──────────────────┐   │
+ │   │   Capacity   │         │ compute_default  │   │
+ │   │   Detection  │         │ _tiling_strategy │   │
+ │   └──────────────┘         └──────────────────┘   │
+ │                                                   │
+ └───────────────────────────────────────────────────┘
+            │                          │
+            ▼                          ▼
+     ┌──────────────┐         ┌──────────────────┐
+     │    GEGLU     │         │       ROPE       │
+     │    Kernel    │         │      Kernel      │
+     └──────────────┘         └──────────────────┘
+ ```
+
+ ### Class Diagram
+
+ ```
+ ┌──────────────────────────────────────┐
+ │              UBManager               │
+ ├──────────────────────────────────────┤
+ │ - _npu_model: str                    │
+ │ - _ub_capacity_bits: int             │
+ ├──────────────────────────────────────┤
+ │ + ub_capacity_bits: int              │
+ │ + ub_capacity_bytes: int             │
+ │ + npu_model: str                     │
+ │ - _detect_npu_model()                │
+ │ - _detect_ub_capacity()              │
+ └──────────────────────────────────────┘
+
+ ┌──────────────────────────────────────┐
+ │   compute_default_tiling_strategy    │
+ ├──────────────────────────────────────┤
+ │ + safety_margin: float               │
+ │ + dtype_size: int                    │
+ │ + memory_multiplier: float           │
+ │ + shapes: Tuple[Tuple[int, ...], ...]│
+ │ + tiling_dims: Tuple                 │
+ ├──────────────────────────────────────┤
+ │ Returns: Tuple[Tuple[int, ...], ...] │
+ │ (same structure as shapes)           │
+ └──────────────────────────────────────┘
+
+ ┌──────────────────────────────────────┐
+ │        _normalize_tiling_dims        │
+ ├──────────────────────────────────────┤
+ │ Helper function to normalize         │
+ │ tiling_dim (int or tuple) to a set   │
+ └──────────────────────────────────────┘
+ ```
+
+ ## Core Functionality
+
+ ### 1. UB Capacity Detection
+
+ The UB Manager detects UB capacity in the following priority order:
+
+ 1. **Environment Variable**: `ASCEND_UB_CAPACITY_BITS`
+ 2. **Device Properties**: Retrieved from `torch.npu.get_device_properties(0).ub_capacity_bits`
+ 3. **Model Defaults**: Use predefined values based on the detected NPU model
+
+ ```python
+ # Default UB capacity configuration (values are in bits)
+ _DEFAULT_UB_CAPACITIES = {
+     "Ascend910B1": 2097152,  # 256 KB
+     "Ascend910B4": 1572864,  # 192 KB
+     "default": 2097152,      # 256 KB
+ }
+ ```
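The detection order above fits in a few lines. The following is a minimal, self-contained sketch; the helper name `detect_ub_capacity_bits` is illustrative, and the real `UBManager` in `liger_kernel/ops/backends/_ascend/ub_manager.py` may structure this differently.

```python
import os

# Illustrative sketch of the three-step detection order (not the shipped code).
_DEFAULT_UB_CAPACITIES = {
    "Ascend910B1": 2097152,
    "Ascend910B4": 1572864,
    "default": 2097152,
}

def detect_ub_capacity_bits(npu_model: str = "default") -> int:
    # 1. Environment variable takes the highest priority.
    env_val = os.environ.get("ASCEND_UB_CAPACITY_BITS")
    if env_val is not None:
        return int(env_val)
    # 2. Device properties, when the torch NPU backend exposes them.
    try:
        import torch
        return torch.npu.get_device_properties(0).ub_capacity_bits
    except (ImportError, AttributeError, RuntimeError):
        pass
    # 3. Fall back to the model defaults.
    return _DEFAULT_UB_CAPACITIES.get(npu_model, _DEFAULT_UB_CAPACITIES["default"])
```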
+
+ ### 2. Unified Strategy System
+
+ All kernels use a single unified strategy function, `_default_strategy`, which abstracts the memory calculations:
+
+ ```
+ Memory formula: memory_multiplier * BLOCK_SIZE * unit_param * dtype_size * 8 bits
+ ```
+
+ Here, `unit_param` is computed automatically as the product of all fixed (non-tiling) dimensions in each shape.
+
+ The strategy function:
+ - Takes the UB capacity, safety margin, dtype size, memory multiplier, shapes, and tiling-dimension specifications
+ - For each shape, identifies which dimensions can be tiled (from `tiling_dims`)
+ - Calculates `unit_param` as the product of the fixed (non-tiling) dimensions
+ - Calculates the maximum safe block size that fits within the UB capacity
+ - Returns a tuple of `max_safe_block_size` values (one per shape)
+
+ The `compute_default_tiling_strategy` function:
+ - Calls `_default_strategy` to get the max_safe_block_size for each shape
+ - For each tiling dimension, computes the desired block size using `triton.next_power_of_2(original_dim)`
+ - Returns the final result with the same structure as the input shapes: tiling dimensions are replaced with computed block sizes, and non-tiling dimensions are padded to the next power of 2
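A quick sanity check of the formula, using the Ascend910B1 default capacity and the GEGLU-forward parameters used later in this document (illustrative arithmetic only, not output from the library):

```python
ub_capacity_bits = 2097152  # Ascend910B1 default (256 KB)
safety_margin = 0.80
memory_multiplier = 7.0     # GEGLU forward
dtype_size = 2              # float16
unit_param = 1              # all dimensions are tiled

safe_bits = ub_capacity_bits * safety_margin  # 1677721.6
max_block = safe_bits / (memory_multiplier * unit_param * dtype_size * 8)
print(max_block)  # ~14979.7 -> the largest power of 2 <= this is 8192
```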
+
+ ### 3. Parameter Structure
+
+ The unified strategy uses the following parameters:
+
+ - **`safety_margin`**: Safety margin as a float (e.g., 0.80 for 80%). Default is 0.80.
+ - **`dtype_size`**: Size of the data type in bytes (e.g., 2 for float16, 4 for float32)
+ - **`memory_multiplier`**: Memory multiplier for estimating peak memory usage
+   - For GEGLU: typically 10.0 for backward, 7.0 for forward
+   - For ROPE: typically 3.0
+ - **`shapes`**: Tuple of full shapes. Each shape is a tuple of dimension sizes.
+   - For ROPE: `((n_q_head, hd), (n_kv_head, hd))`
+   - For GEGLU: `((n_cols,),)`
+   - Either original shapes (padding is handled internally) or pre-padded shapes may be passed
+ - **`tiling_dims`**: Tuple specifying which dimensions can be tiled for each shape.
+   - Each element can be:
+     - `int`: a single dimension index (e.g., `0` for the first dimension)
+     - `tuple` of ints: multiple dimensions that can be tiled together (must be non-empty)
+   - For ROPE: `(0, 0)` means the first dimension of each shape can be tiled
+   - For GEGLU: `(0,)` means the first dimension of the shape can be tiled
+   - Length must match `len(shapes)`
+   - Fixed (non-tiling) dimensions are extracted from the shapes automatically and multiplied to get `unit_param`
+ - **Validation**: Raises `ValueError` if:
+   - Any `tiling_dim` is empty or invalid (e.g., an empty tuple)
+   - Any dimension index is out of bounds (negative or >= the shape length)
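To make these rules concrete, a few illustrative `tiling_dims` values for a pair of two-dimensional shapes (the comments describe the expected behavior under the rules above):

```python
shapes = ((32, 128), (32, 128))

tiling_dims = (0, 0)       # tile dim 0 of each shape; unit_param = 128 for both
tiling_dims = ((0, 1), 0)  # first shape: both dims tiled -> unit_param = 1.0
tiling_dims = (2, 0)       # ValueError: index 2 out of bounds for shape (32, 128)
tiling_dims = ((), 0)      # ValueError: empty tiling_dim
```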
+
+ ### 4. Strategy Computation Flow
+
+ ```
+ User calls compute_default_tiling_strategy()
+         ↓
+ Get UB manager instance
+         ↓
+ Validate shapes and tiling_dims (lengths must match)
+         ↓
+ Set defaults for dtype_size (4) and memory_multiplier (10.0)
+         ↓
+ Call _default_strategy() with:
+   - ub_capacity_bits
+   - safety_margin
+   - dtype_size
+   - memory_multiplier
+   - shapes
+   - tiling_dims
+         ↓
+ For each (shape, tiling_dim) pair:
+   Normalize tiling_dim to a set of dimension indices
+   Validate tiling dimensions are within shape bounds
+   (Raises ValueError if invalid)
+         ↓
+ Calculate unit_param:
+   unit_param = product of all non-tiling dimensions
+         ↓
+ Calculate max_block_size:
+   SAFE_UB_CAPACITY_BITS = ub_capacity_bits * safety_margin
+   max_block_size = SAFE_UB_CAPACITY_BITS / (memory_multiplier * unit_param * dtype_size * 8)
+         ↓
+ Find largest power of 2 <= max_block_size
+         ↓
+ Return tuple of max_safe_block_size (one per shape)
+         ↓
+ Build result with same structure as shapes:
+   For each (shape, tiling_dim, max_safe):
+     For each tiling dimension:
+       desired = triton.next_power_of_2(original_dim)
+       final = min(desired, max_safe)
+       final = max(1, final)
+     For each non-tiling dimension:
+       pad to triton.next_power_of_2(original_dim)
+         ↓
+ Return tuple of tiled shapes
+ ```
+
+ ## Usage Examples
+
+ ### Basic Usage
+
+ ```python
+ from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
+
+ # GEGLU forward
+ shapes = ((4096,),)
+ tile_shapes = compute_default_tiling_strategy(
+     safety_margin=0.80,
+     dtype_size=2,  # float16
+     memory_multiplier=7.0,
+     shapes=shapes,
+     tiling_dims=(0,),  # First dimension can be tiled
+ )
+ if tile_shapes is not None and len(tile_shapes) > 0:
+     block_size = tile_shapes[0][0]
+     # Call kernel with block_size
+
+ # ROPE forward
+ shapes = ((32, 128), (32, 128))  # (n_q_head, hd), (n_kv_head, hd)
+ tile_shapes = compute_default_tiling_strategy(
+     safety_margin=0.90,
+     dtype_size=4,  # float32
+     memory_multiplier=3.0,
+     shapes=shapes,
+     tiling_dims=(0, 0),  # First dimension of each shape can be tiled
+ )
+ if tile_shapes is not None and len(tile_shapes) == len(shapes):
+     q_tile_shape, k_tile_shape = tile_shapes
+     BLOCK_Q, _ = q_tile_shape  # Tiled dimension
+     BLOCK_K, _ = k_tile_shape  # Tiled dimension
+     # Call kernel with BLOCK_Q and BLOCK_K
+ ```
+
+ ## Strategy Function Details
+
+ ### `_normalize_tiling_dims` Helper Function
+
+ A helper function that normalizes tiling dimension specifications:
+
+ ```python
+ def _normalize_tiling_dims(tiling_dim: Union[int, Tuple[int, ...]]) -> set:
+     """
+     Normalize a tiling dimension specification to a set of dimension indices.
+
+     Args:
+         tiling_dim: Either an int (single dimension) or a tuple of ints (multiple dimensions).
+
+     Returns:
+         Set of dimension indices that can be tiled.
+     """
+ ```
+
+ This function converts `tiling_dim` from either an `int` or a `tuple` to a `set` for consistent processing.
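Since only the signature and docstring are shown above, here is a minimal sketch consistent with them (the shipped implementation may differ in its validation details):

```python
from typing import Tuple, Union

def _normalize_tiling_dims(tiling_dim: Union[int, Tuple[int, ...]]) -> set:
    # A single int becomes a one-element set; a tuple becomes a set of indices.
    if isinstance(tiling_dim, int):
        return {tiling_dim}
    return set(tiling_dim)
```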
+
+ ### `_default_strategy` Function
+
+ The core strategy function that calculates the maximum safe block size:
+
+ ```python
+ def _default_strategy(
+     ub_capacity_bits: int,
+     safety_margin: float,
+     dtype_size: int,
+     memory_multiplier: float,
+     shapes: Tuple[Tuple[int, ...], ...],
+     tiling_dims: Tuple[Union[int, Tuple[int, ...]], ...],
+ ) -> Tuple[int, ...]:
+     """
+     Calculate the maximum safe block size based on UB capacity.
+
+     Memory formula: memory_multiplier * BLOCK_SIZE * unit_param * dtype_size * 8 bits
+
+     For each shape, the fixed (non-tiling) dimensions are multiplied together to get unit_param.
+
+     Returns:
+         Tuple of max_safe_block_size (power of 2), one for each shape.
+
+     Raises:
+         ValueError: If any tiling_dim is empty or invalid, or if any dimension
+             index is out of bounds for the corresponding shape.
+     """
+ ```
+
+ **Key Steps:**
+ 1. For each `(shape, tiling_dim)` pair:
+    - Normalize `tiling_dim` to a set of dimension indices using `_normalize_tiling_dims`
+    - Validate that the tiling dimensions are within the shape bounds
+      - Raises `ValueError` if `tiling_dim` is empty or invalid
+      - Raises `ValueError` if any dimension index is out of bounds
+    - Calculate `unit_param` as the product of all non-tiling dimensions
+      - If all dimensions are tiling, `unit_param = 1.0`
+ 2. Calculate `SAFE_UB_CAPACITY_BITS = ub_capacity_bits * safety_margin`
+ 3. Solve for the max block size: `SAFE_UB_CAPACITY_BITS / (memory_multiplier * unit_param * dtype_size * 8)`
+ 4. Find the largest power of 2 <= max_block_size
+ 5. Return a tuple with one max_safe_block_size per shape
+
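Putting the key steps together, a hedged, self-contained sketch of `_default_strategy` (named `_default_strategy_sketch` to make clear it is not the shipped code; the normalization helper is inlined):

```python
from typing import Tuple, Union

def _default_strategy_sketch(
    ub_capacity_bits: int,
    safety_margin: float,
    dtype_size: int,
    memory_multiplier: float,
    shapes: Tuple[Tuple[int, ...], ...],
    tiling_dims: Tuple[Union[int, Tuple[int, ...]], ...],
) -> Tuple[int, ...]:
    results = []
    for shape, tiling_dim in zip(shapes, tiling_dims):
        dims = {tiling_dim} if isinstance(tiling_dim, int) else set(tiling_dim)
        if not dims or any(d < 0 or d >= len(shape) for d in dims):
            raise ValueError(f"invalid tiling_dim {tiling_dim} for shape {shape}")
        # unit_param: product of the fixed (non-tiling) dims; 1.0 if all dims tile.
        unit_param = 1.0
        for i, size in enumerate(shape):
            if i not in dims:
                unit_param *= size
        safe_bits = ub_capacity_bits * safety_margin
        max_block = safe_bits / (memory_multiplier * unit_param * dtype_size * 8)
        # Largest power of 2 <= max_block, clamped to at least 1.
        block = 1
        while block * 2 <= max_block:
            block *= 2
        results.append(block)
    return tuple(results)

# e.g. the ROPE case from this document:
# _default_strategy_sketch(2097152, 0.90, 4, 3.0, ((32, 128), (32, 128)), (0, 0))
# -> (128, 128)
```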
+ ### `compute_default_tiling_strategy` Function
+
+ The public interface that computes the final tiling results:
+
+ ```python
+ def compute_default_tiling_strategy(
+     safety_margin: float = 0.80,
+     dtype_size: Optional[int] = None,
+     memory_multiplier: Optional[float] = None,
+     shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
+     tiling_dims: Optional[Tuple[Union[int, Tuple[int, ...]], ...]] = None,
+ ) -> Optional[Tuple[Tuple[int, ...], ...]]:
+     """
+     Compute the tiling strategy using the default strategy function.
+
+     Returns a tuple of tiled shapes with the same structure as the input shapes.
+     Tiling dimensions are replaced with computed block sizes (powers of 2),
+     while non-tiling dimensions are padded to the next power of 2.
+
+     Returns:
+         Tuple of tiled shapes, or None if shapes/tiling_dims are empty or
+         their lengths don't match.
+
+     Raises:
+         ValueError: If any tiling_dim is empty or invalid, or if any dimension
+             index is out of bounds for the corresponding shape.
+     """
+ ```
+
+ **Key Steps:**
+ 1. Get the UB manager instance
+ 2. Validate `shapes` and `tiling_dims` (lengths must match, cannot be empty)
+    - Returns `None` if validation fails (empty or mismatched lengths)
+ 3. Set defaults for `dtype_size` (4) and `memory_multiplier` (10.0) if not provided
+ 4. Call `_default_strategy` to get `max_supported` (a tuple of max_safe_block_size, one per shape)
+    - May raise `ValueError` if `tiling_dims` are invalid (see the `_default_strategy` documentation)
+ 5. For each `(shape, tiling_dim, max_safe)`:
+    - Normalize `tiling_dim` to a set of dimension indices
+    - Validate that the tiling dimensions are within the shape bounds
+      - Raises `ValueError` if `tiling_dim` is empty or invalid
+      - Raises `ValueError` if any dimension index is out of bounds
+    - For each tiling dimension:
+      - Compute `desired = triton.next_power_of_2(original_dim)`
+      - Compute `final = min(desired, max_safe)`
+      - Ensure `final >= 1`
+      - Replace the dimension with `final`
+    - For each non-tiling dimension:
+      - Pad to `triton.next_power_of_2(original_dim)`
+ 6. Return a tuple of tiled shapes (same structure as the input `shapes`)
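Step 5 is just a clamp of the desired power-of-2 block size against the strategy's cap. A tiny illustration with hypothetical numbers (a pure-Python stand-in for `triton.next_power_of_2` keeps it self-contained):

```python
def next_power_of_2(n: int) -> int:
    # Pure-Python stand-in for triton.next_power_of_2, for illustration only.
    return 1 << (n - 1).bit_length()

desired = next_power_of_2(4096)         # 4096 (original tiling dimension)
max_safe = 1024                         # hypothetical cap from _default_strategy
final = max(1, min(desired, max_safe))  # 1024 -> the dimension becomes 1024
```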
+
+ ## Memory Analysis Examples
+
+ ### GEGLU Forward
+
+ ```
+ Memory analysis:
+ - Inputs: a, b
+ - Intermediates: a_cubed, tanh_arg, tanh_result, geglu_a
+ - Output: c
+ - Total: ~7x * BLOCK_SIZE * dtype_size
+
+ Strategy:
+ - shapes: ((n_cols,),)
+ - tiling_dims: (0,)  # First dimension can be tiled
+ - Fixed dimensions: none (all dimensions are tiling)
+ - unit_param = 1 (product of fixed dimensions)
+ - memory_multiplier = 7.0
+ - Formula: 7.0 * BLOCK_SIZE * 1 * dtype_size * 8 bits
+ - Returns: ((block_size,),)
+ ```
+
+ ### GEGLU Backward
+
+ ```
+ Memory analysis:
+ - More intermediates for gradient computation
+ - Total: ~10x * BLOCK_SIZE * dtype_size
+
+ Strategy:
+ - shapes: ((n_cols,),)
+ - tiling_dims: (0,)  # First dimension can be tiled
+ - Fixed dimensions: none (all dimensions are tiling)
+ - unit_param = 1 (product of fixed dimensions)
+ - memory_multiplier = 10.0
+ - Formula: 10.0 * BLOCK_SIZE * 1 * dtype_size * 8 bits
+ - Returns: ((block_size,),)
+ ```
+
+ ### ROPE Forward/Backward
+
+ ```
+ Memory analysis (based on the optimized ROPE kernel):
+ - cos_vals and sin_vals: pad_hd // 2 elements each (shared)
+ - In the q-heads loop (peak memory):
+   * q_left, q_right, new_left, new_right: 2 * BLOCK_Q * pad_hd elements
+ - In the k-heads loop (peak memory):
+   * k_left, k_right, new_left, new_right: 2 * BLOCK_K * pad_hd elements
+ - Plus shared cos/sin: pad_hd elements
+ - Conservative estimate: 3 * BLOCK_SIZE * pad_hd * dtype_size * 8 bits
+
+ Strategy:
+ - shapes: ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
+ - tiling_dims: (0, 0)  # First dimension of each shape can be tiled
+ - Fixed dimensions: pad_hd (second dimension, non-tiling)
+ - unit_param = pad_hd (product of fixed dimensions)
+ - memory_multiplier = 3.0
+ - Formula: 3.0 * BLOCK_SIZE * pad_hd * dtype_size * 8 bits
+ - Returns: ((block_size_q, pad_hd), (block_size_kv, pad_hd))
+ ```
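Plugging the ROPE numbers from the Basic Usage example into this formula (illustrative arithmetic, assuming the Ascend910B1 default capacity):

```python
ub_capacity_bits = 2097152           # Ascend910B1 default
safe_bits = ub_capacity_bits * 0.90  # 1887436.8
pad_hd = 128                         # unit_param (fixed dimension)
max_block = safe_bits / (3.0 * pad_hd * 4 * 8)  # dtype_size=4 (float32)
# max_block = 153.6 -> max_safe_block_size = 128
# desired = next_power_of_2(32) = 32; final = min(32, 128) = 32
# Result: ((32, 128), (32, 128))
```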
+
+ ## Extension Guide
+
+ ### Adding a New Kernel
+
+ To add tiling support for a new kernel:
+
+ 1. **Analyze memory usage**:
+    - Identify the peak memory usage in the kernel
+    - Determine the memory multiplier (e.g., 7.0, 10.0, 3.0)
+    - Identify which dimensions can be tiled and which are fixed
+    - Fixed dimensions will be automatically extracted and multiplied to get `unit_param`
+
+ 2. **Use `compute_default_tiling_strategy`** in your kernel:
+
+ ```python
+ def my_kernel_forward(input):
+     # Prepare parameters
+     n_cols = input.shape[-1]
+     dtype_size = input.element_size()
+
+     # Compute strategy
+     # Example 1: Simple case (all dimensions can be tiled)
+     shapes = ((n_cols,),)
+     tile_shapes = compute_default_tiling_strategy(
+         safety_margin=0.80,
+         dtype_size=dtype_size,
+         memory_multiplier=7.0,  # Based on your memory analysis
+         shapes=shapes,
+         tiling_dims=(0,),  # First dimension can be tiled
+     )
+
+     if tile_shapes is not None and len(tile_shapes) > 0:
+         block_size = tile_shapes[0][0]
+     else:
+         block_size = triton.next_power_of_2(n_cols)  # Fallback
+
+     # Example 2: Multiple shapes with fixed dimensions
+     # shapes = ((M, K), (K, N))
+     # tiling_dims = (0, 1)  # First shape: dim 0 can be tiled, dim 1 is fixed
+     #                       # Second shape: dim 0 is fixed, dim 1 can be tiled
+     # Returns: ((block_M, K), (K, block_N))
+
+     # Call kernel
+     kernel[(grid_size,)](
+         input,
+         BLOCK_SIZE=block_size,
+     )
+ ```
+
+ 3. **Document the memory analysis** in comments:
+
+ ```python
+ # My kernel tiling strategy:
+ # - Memory analysis:
+ #   * Input: input
+ #   * Intermediates: intermediate1, intermediate2
+ #   * Output: output
+ #   * Total: ~7x * BLOCK_SIZE * dtype_size
+ # - shapes: ((n_cols,),)
+ # - tiling_dims: (0,) means the first dimension can be tiled
+ # - Fixed dimensions: none (all dimensions are tiling)
+ # - unit_param = 1 (product of fixed dimensions)
+ # - Memory formula: 7.0 * BLOCK_SIZE * dtype_size * 8 bits (memory_multiplier = 7.0)
+ # - compute_default_tiling_strategy returns: ((block_size,),)
+ #   where block_size = min(triton.next_power_of_2(n_cols), max_safe_block_size)
+ ```
+
+ ## Future Improvements
+
+ 1. **Strategy Variants**: Specialized strategy functions could be added for specific kernels while keeping the unified interface
+ 2. **Multi-dimensional Tiling**: More complex tiling patterns could be supported if needed
liger_kernel/ops/backends/_ascend/ops/__init__.py
@@ -0,0 +1,49 @@
+ """
+ Ascend NPU operator implementations.
+
+ This module exports Ascend NPU-optimized implementations that will automatically
+ replace the default implementations when running on NPU devices.
+
+ Both Function classes and kernel functions can be exported here.
+
+ To add a new operator:
+ 1. Create the implementation file (e.g., rms_norm.py)
+ 2. Import the Function class and/or kernel functions here
+ 3. Optionally add to __all__ for explicit control
+
+ If __all__ is not defined, all public symbols will be auto-discovered.
+ """
+
+ from liger_kernel.ops.backends._ascend.ops.geglu import LigerGELUMulFunction
+ from liger_kernel.ops.backends._ascend.ops.geglu import geglu_backward
+ from liger_kernel.ops.backends._ascend.ops.geglu import geglu_forward
+ from liger_kernel.ops.backends._ascend.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
+ from liger_kernel.ops.backends._ascend.ops.qwen2vl_mrope import qwen2vl_mrope_backward
+ from liger_kernel.ops.backends._ascend.ops.qwen2vl_mrope import qwen2vl_mrope_forward
+ from liger_kernel.ops.backends._ascend.ops.rope import LigerRopeFunction
+ from liger_kernel.ops.backends._ascend.ops.rope import rope_backward
+ from liger_kernel.ops.backends._ascend.ops.rope import rope_forward
+ from liger_kernel.ops.backends._ascend.ops.swiglu import LigerSiLUMulFunction
+ from liger_kernel.ops.backends._ascend.ops.swiglu import swiglu_backward
+ from liger_kernel.ops.backends._ascend.ops.swiglu import swiglu_forward
+ from liger_kernel.ops.backends._ascend.ops.tvd import LigerTVDLossFunction
+ from liger_kernel.ops.backends._ascend.ops.tvd import tv_distance_forward_triton
+ from liger_kernel.ops.backends._ascend.ops.tvd import tvd_backward_triton
+
+ __all__ = [
+     "LigerGELUMulFunction",
+     "geglu_forward",
+     "geglu_backward",
+     "LigerQwen2VLMRopeFunction",
+     "qwen2vl_mrope_forward",
+     "qwen2vl_mrope_backward",
+     "LigerRopeFunction",
+     "rope_forward",
+     "rope_backward",
+     "LigerSiLUMulFunction",
+     "swiglu_forward",
+     "swiglu_backward",
+     "LigerTVDLossFunction",
+     "tv_distance_forward_triton",
+     "tvd_backward_triton",
+ ]
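The registration steps described in the module docstring amount to a few lines of code. A hypothetical sketch for an `rms_norm` operator (the file and symbol names below are illustrative; no such module ships in this release):

```python
# Hypothetical: after creating liger_kernel/ops/backends/_ascend/ops/rms_norm.py,
# wire it up in this __init__.py so it is exported alongside the other ops.
from liger_kernel.ops.backends._ascend.ops.rms_norm import LigerRMSNormFunction
from liger_kernel.ops.backends._ascend.ops.rms_norm import rms_norm_forward
from liger_kernel.ops.backends._ascend.ops.rms_norm import rms_norm_backward

__all__ += [
    "LigerRMSNormFunction",
    "rms_norm_forward",
    "rms_norm_backward",
]
```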