liger-kernel 0.6.4__py3-none-any.whl → 0.6.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. liger_kernel/chunked_loss/cosine_similarity_loss.py +7 -1
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +10 -3
  3. liger_kernel/chunked_loss/jsd_loss.py +21 -6
  4. liger_kernel/ops/__init__.py +141 -0
  5. liger_kernel/ops/backends/README.md +151 -0
  6. liger_kernel/ops/backends/__init__.py +13 -0
  7. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  8. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +492 -0
  9. liger_kernel/ops/backends/_ascend/ops/__init__.py +61 -0
  10. liger_kernel/ops/backends/_ascend/ops/embedding.py +214 -0
  11. liger_kernel/ops/backends/_ascend/ops/geglu.py +191 -0
  12. liger_kernel/ops/backends/_ascend/ops/llama4_rope.py +298 -0
  13. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +275 -0
  14. liger_kernel/ops/backends/_ascend/ops/rope.py +265 -0
  15. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  16. liger_kernel/ops/backends/_ascend/ops/tvd.py +223 -0
  17. liger_kernel/ops/backends/_ascend/ub_manager.py +367 -0
  18. liger_kernel/ops/backends/registry.py +61 -0
  19. liger_kernel/ops/cross_entropy.py +14 -4
  20. liger_kernel/ops/dyt.py +5 -2
  21. liger_kernel/ops/fused_add_rms_norm.py +21 -23
  22. liger_kernel/ops/fused_linear_cross_entropy.py +2 -1
  23. liger_kernel/ops/geglu.py +5 -3
  24. liger_kernel/ops/group_norm.py +12 -8
  25. liger_kernel/ops/kl_div.py +8 -11
  26. liger_kernel/ops/layer_norm.py +17 -16
  27. liger_kernel/ops/poly_norm.py +19 -21
  28. liger_kernel/ops/rms_norm.py +149 -71
  29. liger_kernel/ops/utils.py +25 -0
  30. liger_kernel/transformers/__init__.py +6 -0
  31. liger_kernel/transformers/auto_model.py +21 -0
  32. liger_kernel/transformers/cross_entropy.py +1 -1
  33. liger_kernel/transformers/dyt.py +1 -1
  34. liger_kernel/transformers/experimental/embedding.py +1 -1
  35. liger_kernel/transformers/functional.py +20 -20
  36. liger_kernel/transformers/fused_add_rms_norm.py +1 -1
  37. liger_kernel/transformers/fused_linear_cross_entropy.py +1 -1
  38. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  39. liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
  40. liger_kernel/transformers/geglu.py +1 -1
  41. liger_kernel/transformers/group_norm.py +1 -1
  42. liger_kernel/transformers/grpo_loss.py +1 -1
  43. liger_kernel/transformers/jsd.py +1 -1
  44. liger_kernel/transformers/kl_div.py +1 -1
  45. liger_kernel/transformers/layer_norm.py +1 -1
  46. liger_kernel/transformers/llama4_rope.py +1 -1
  47. liger_kernel/transformers/model/exaone4.py +136 -0
  48. liger_kernel/transformers/model/gemma2.py +3 -3
  49. liger_kernel/transformers/model/gemma3.py +11 -5
  50. liger_kernel/transformers/model/gpt_oss.py +211 -0
  51. liger_kernel/transformers/model/loss_utils.py +6 -0
  52. liger_kernel/transformers/model/paligemma.py +1 -0
  53. liger_kernel/transformers/monkey_patch.py +196 -39
  54. liger_kernel/transformers/multi_token_attention.py +1 -1
  55. liger_kernel/transformers/poly_norm.py +1 -1
  56. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  57. liger_kernel/transformers/rms_norm.py +8 -3
  58. liger_kernel/transformers/rope.py +28 -27
  59. liger_kernel/transformers/softmax.py +1 -1
  60. liger_kernel/transformers/sparsemax.py +1 -1
  61. liger_kernel/transformers/swiglu.py +1 -1
  62. liger_kernel/transformers/tiled_mlp.py +5 -13
  63. liger_kernel/transformers/tvd.py +1 -1
  64. liger_kernel/utils.py +54 -0
  65. {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/METADATA +11 -4
  66. liger_kernel-0.6.5.dist-info/RECORD +134 -0
  67. {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/WHEEL +1 -1
  68. liger_kernel-0.6.4.dist-info/RECORD +0 -118
  69. {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/licenses/LICENSE +0 -0
  70. {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/licenses/NOTICE +0 -0
  71. {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/top_level.txt +0 -0
liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md
@@ -0,0 +1,492 @@
# Ascend NPU UB Manager Design Document

## Overview

The UB Manager (Unified Buffer Manager) is a core component in **Liger-Kernel** responsible for managing the Unified Buffer (UB) capacity on Ascend NPUs. By automatically detecting UB capacity and providing unified tiling strategy computation, it helps Triton kernels avoid UB overflow errors while maintaining high performance.

## Design Goals

1. **Automated UB Management**: Automatically detect device UB capacity without manual configuration
2. **Unified Strategy System**: Use a single unified strategy function for all kernels, abstracting memory calculations
3. **Flexible Parameters**: Support different memory multipliers and safety margins for different kernels
4. **Easy to Use**: Simple interface that directly computes tiling results

## Architecture Design

### Core Components

```
┌────────────────────────────────────────────────────────────┐
│                      UB Manager System                      │
├────────────────────────────────────────────────────────────┤
│                                                            │
│   ┌──────────────┐            ┌──────────────────┐        │
│   │  UBManager   │            │ Default Strategy │        │
│   │ (Singleton)  │───────────▶│     Function     │        │
│   └──────────────┘            └──────────────────┘        │
│          │                             │                  │
│          ▼                             ▼                  │
│   ┌──────────────┐            ┌──────────────────┐        │
│   │   Capacity   │            │ compute_default  │        │
│   │   Detection  │            │ _tiling_strategy │        │
│   └──────────────┘            └──────────────────┘        │
│                                                            │
└────────────────────────────────────────────────────────────┘
           │                             │
           ▼                             ▼
   ┌──────────────┐            ┌──────────────────┐
   │    GEGLU     │            │       ROPE       │
   │    Kernel    │            │      Kernel      │
   └──────────────┘            └──────────────────┘
```

### Class Diagram

```
┌──────────────────────────────────────┐
│              UBManager               │
├──────────────────────────────────────┤
│ - _npu_model: str                    │
│ - _ub_capacity_bits: int             │
├──────────────────────────────────────┤
│ + ub_capacity_bits: int              │
│ + ub_capacity_bytes: int             │
│ + npu_model: str                     │
│ - _detect_npu_model()                │
│ - _detect_ub_capacity()              │
│   (raises RuntimeError if fails)     │
└──────────────────────────────────────┘

┌──────────────────────────────────────┐
│   compute_default_tiling_strategy    │
├──────────────────────────────────────┤
│ + safety_margin: float               │
│ + dtype_size: int                    │
│ + memory_multiplier: float           │
│ + shapes: Tuple[Tuple[int, ...], ...]│
│ + tiling_dims: Tuple                 │
├──────────────────────────────────────┤
│ Returns: Tuple[Tuple[int, ...], ...] │
│ (same structure as shapes)           │
└──────────────────────────────────────┘

┌──────────────────────────────────────┐
│        _normalize_tiling_dims        │
├──────────────────────────────────────┤
│ Helper function to normalize         │
│ tiling_dim (int or tuple) to set     │
└──────────────────────────────────────┘
```

## Core Functionality

### 1. UB Capacity Detection

The UB Manager detects UB capacity in the following priority order:

1. **Environment Variable**: `ASCEND_UB_CAPACITY_BITS` (in bits)
   - If set, this value is used directly
   - Must be a positive integer representing UB capacity in bits

2. **get_soc_spec**: Query UB size from CANN's `get_soc_spec("UB_SIZE")`
   - Returns UB size in bytes
   - Automatically converted to bits (bytes * 8)
   - Requires CANN environment to be sourced (e.g., `source /usr/local/Ascend/ascend-toolkit/set_env.sh`)

3. **Error Handling**: If neither method succeeds, raises `RuntimeError` with clear instructions

```python
# Detection flow:
# 1. Check ASCEND_UB_CAPACITY_BITS env var (bits)
# 2. Try get_soc_spec("UB_SIZE") (bytes) -> convert to bits
# 3. Raise RuntimeError if both fail
```
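
A minimal sketch of this detection order follows; the helper name `_detect_ub_capacity_bits` and the `get_soc_spec` import path are illustrative assumptions, not necessarily the exact Liger-Kernel internals.

```python
import os


def _detect_ub_capacity_bits() -> int:
    """Return UB capacity in bits, following the documented priority order."""
    # 1. Environment variable override (value is already in bits).
    env_val = os.environ.get("ASCEND_UB_CAPACITY_BITS")
    if env_val is not None:
        bits = int(env_val)
        if bits <= 0:
            raise RuntimeError("ASCEND_UB_CAPACITY_BITS must be a positive integer (bits).")
        return bits

    # 2. Ask CANN for the UB size in bytes and convert to bits.
    try:
        from tbe.common.platform import get_soc_spec  # import path may differ across CANN versions

        ub_bytes = int(get_soc_spec("UB_SIZE"))
        return ub_bytes * 8
    except Exception as exc:
        # 3. Neither source worked: fail loudly with instructions.
        raise RuntimeError(
            "Unable to detect Ascend UB capacity. Set ASCEND_UB_CAPACITY_BITS or "
            "source the CANN environment (e.g. /usr/local/Ascend/ascend-toolkit/set_env.sh)."
        ) from exc
```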

### 2. Unified Strategy System

All kernels use a single unified strategy function `_default_strategy` that abstracts memory calculations:

```
Memory Formula: memory_multiplier * BLOCK_SIZE * unit_param * dtype_size * 8 bits
```

Where `unit_param` is automatically calculated as the product of all fixed (non-tiling) dimensions in each shape.

The strategy function:
- Takes UB capacity, safety margin, dtype size, memory multiplier, shapes, and tiling dimension specifications
- For each shape, identifies which dimensions can be tiled (from `tiling_dims`)
- Calculates `unit_param` as the product of fixed (non-tiling) dimensions
- Calculates the maximum safe block size that fits within UB capacity
- Returns a tuple of max_safe_block_size values (one for each shape)

The `compute_default_tiling_strategy` function:
- Calls `_default_strategy` to get max_safe_block_size for each shape
- For each tiling dimension, computes desired block size using `triton.next_power_of_2(original_dim)`
- Returns the final result with same structure as input shapes: tiling dimensions replaced with computed block sizes, non-tiling dimensions padded to next power of 2
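
As a quick worked example for the GEGLU-forward settings used later in this document (the 192 KiB UB capacity is an assumed value for illustration; the real value comes from UB detection):

```python
# Worked example for GEGLU forward, assuming a hypothetical 192 KiB UB.
ub_capacity_bits = 192 * 1024 * 8            # 1,572,864 bits (assumed device value)
safe_bits = ub_capacity_bits * 0.80          # 1,258,291.2 bits with safety_margin=0.80
bits_per_block_elem = 7.0 * 1 * 2 * 8        # memory_multiplier * unit_param * dtype_size(fp16) * 8
max_block = safe_bits / bits_per_block_elem  # ~11,234 elements
# Largest power of 2 <= 11,234 is 8,192, so for n_cols = 4096 the final
# block size is min(next_power_of_2(4096), 8192) = 4096 (no tiling needed).
```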

### 3. Parameter Structure

The unified strategy uses the following parameters (a short example follows the list):

- **`safety_margin`**: Safety margin as a float (e.g., 0.80 for 80%). Default is 0.80.
- **`dtype_size`**: Size of data type in bytes (e.g., 2 for float16, 4 for float32)
- **`memory_multiplier`**: Memory multiplier for estimating peak memory usage
  - For GEGLU: typically 10.0 for backward, 7.0 for forward
  - For ROPE: typically 3.0
- **`shapes`**: Tuple of full shapes. Each shape is a tuple of dimension sizes.
  - For ROPE: `((n_q_head, hd), (n_kv_head, hd))`
  - For GEGLU: `((n_cols,),)`
  - Original shapes can be passed (padding is handled internally), or already-padded shapes
- **`tiling_dims`**: Tuple specifying which dimensions can be tiled for each shape.
  - Each element can be:
    - `int`: a single dimension index (e.g., `0` for the first dimension)
    - `tuple of ints`: multiple dimensions that can be tiled together (non-empty)
  - For ROPE: `(0, 0)` means the first dimension of each shape can be tiled
  - For GEGLU: `(0,)` means the first dimension of the shape can be tiled
  - Length must match `len(shapes)`
  - Fixed (non-tiling) dimensions are automatically extracted from the shapes and multiplied to get `unit_param`
- **Validation**: Raises `ValueError` if:
  - Any `tiling_dim` is empty or invalid (e.g., an empty tuple)
  - Any dimension index is out of bounds (negative or >= shape length)
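
For example, with ROPE-style parameters the head dimension stays fixed and folds into `unit_param`, while the head-count dimension is tiled (the shape values below are illustrative only):

```python
shapes = ((32, 128), (8, 128))   # (n_q_head, hd), (n_kv_head, hd) -- illustrative values
tiling_dims = (0, 0)             # tile the head-count dimension of each shape

# For each shape, the non-tiling dimensions multiply into unit_param:
for shape, tdim in zip(shapes, tiling_dims):
    unit_param = 1
    for i, d in enumerate(shape):
        if i != tdim:            # dimension not being tiled -> fixed
            unit_param *= d
    print(unit_param)            # 128 for both shapes (the head dimension)
```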

### 4. Strategy Computation Flow

```
User calls compute_default_tiling_strategy()
        ↓
Get UB manager instance
        ↓
Validate shapes and tiling_dims (lengths must match)
        ↓
Set defaults for dtype_size (4) and memory_multiplier (10.0)
        ↓
Call _default_strategy() with:
  - ub_capacity_bits
  - safety_margin
  - dtype_size
  - memory_multiplier
  - shapes
  - tiling_dims
        ↓
For each (shape, tiling_dim) pair:
  Normalize tiling_dim to set of dimension indices
  Validate tiling dimensions are within shape bounds
  (Raises ValueError if invalid)
        ↓
Calculate unit_param:
  unit_param = product of all non-tiling dimensions
        ↓
Calculate max_block_size:
  SAFE_UB_CAPACITY_BITS = ub_capacity_bits * safety_margin
  max_block_size = SAFE_UB_CAPACITY_BITS / (memory_multiplier * unit_param * dtype_size * 8)
        ↓
Find largest power of 2 <= max_block_size
        ↓
Return tuple of max_safe_block_size (one per shape)
        ↓
Build result with same structure as shapes:
  For each (shape, tiling_dim, max_safe):
    For each tiling dimension:
      desired = triton.next_power_of_2(original_dim)
      final = min(desired, max_safe)
      final = max(1, final)
    For each non-tiling dimension:
      pad to triton.next_power_of_2(original_dim)
        ↓
Return tuple of tiled shapes
```

## Usage Examples

### Basic Usage

```python
from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy

# GEGLU forward
shapes = ((4096,),)
tile_shapes = compute_default_tiling_strategy(
    safety_margin=0.80,
    dtype_size=2,  # float16
    memory_multiplier=7.0,
    shapes=shapes,
    tiling_dims=(0,),  # First dimension can be tiled
)
if tile_shapes is not None and len(tile_shapes) > 0:
    block_size = tile_shapes[0][0]
    # Call kernel with block_size

# ROPE forward
shapes = ((32, 128), (32, 128))  # (n_q_head, hd), (n_kv_head, hd)
tile_shapes = compute_default_tiling_strategy(
    safety_margin=0.90,
    dtype_size=4,  # float32
    memory_multiplier=3.0,
    shapes=shapes,
    tiling_dims=(0, 0),  # First dimension of each shape can be tiled
)
if tile_shapes is not None and len(tile_shapes) == len(shapes):
    q_tile_shape, k_tile_shape = tile_shapes
    BLOCK_Q, _ = q_tile_shape  # Tiled dimension
    BLOCK_K, _ = k_tile_shape  # Tiled dimension
    # Call kernel with BLOCK_Q and BLOCK_K
```

## Strategy Function Details

### `_normalize_tiling_dims` Helper Function

A helper function that normalizes tiling dimension specifications:

```python
def _normalize_tiling_dims(tiling_dim: Union[int, Tuple[int, ...]]) -> set:
    """
    Normalize tiling dimension specification to a set of dimension indices.

    Args:
        tiling_dim: Either an int (single dimension) or tuple of ints (multiple dimensions).

    Returns:
        Set of dimension indices that can be tiled.
    """
```

This function handles the conversion of `tiling_dim` from either an `int` or `tuple` to a `set` for consistent processing.
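
The body is only a few lines; a minimal sketch consistent with the documented behavior (not necessarily the exact Liger-Kernel source) could look like:

```python
from typing import Tuple, Union


def _normalize_tiling_dims(tiling_dim: Union[int, Tuple[int, ...]]) -> set:
    # Single index -> one-element set.
    if isinstance(tiling_dim, int):
        return {tiling_dim}
    # Tuple of indices -> set; an empty tuple is rejected per the validation rules.
    if isinstance(tiling_dim, tuple) and len(tiling_dim) > 0:
        return set(tiling_dim)
    raise ValueError(f"Invalid tiling_dim specification: {tiling_dim!r}")
```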

### `_default_strategy` Function

The core strategy function that calculates maximum safe block size:

```python
def _default_strategy(
    ub_capacity_bits: int,
    safety_margin: float,
    dtype_size: int,
    memory_multiplier: float,
    shapes: Tuple[Tuple[int, ...], ...],
    tiling_dims: Tuple[Union[int, Tuple[int, ...]], ...],
) -> Tuple[int, ...]:
    """
    Calculate maximum safe block size based on UB capacity.

    Memory formula: memory_multiplier * BLOCK_SIZE * unit_param * dtype_size * 8 bits

    For each shape, fixed dimensions (non-tiling) are multiplied together to get unit_param.

    Returns:
        Tuple of max_safe_block_size (power of 2), one for each shape.

    Raises:
        ValueError: If any tiling_dim is empty or invalid, or if any dimension
            index is out of bounds for the corresponding shape.
    """
```

**Key Steps:**
1. For each `(shape, tiling_dim)` pair:
   - Normalize `tiling_dim` to a set of dimension indices using `_normalize_tiling_dims`
   - Validate tiling dimensions are within shape bounds
     - Raises `ValueError` if `tiling_dim` is empty or invalid
     - Raises `ValueError` if any dimension index is out of bounds
   - Calculate `unit_param` as the product of all non-tiling dimensions
     - If all dimensions are tiling, `unit_param = 1.0`
2. Calculate `SAFE_UB_CAPACITY_BITS = ub_capacity_bits * safety_margin`
3. Solve for max_block_size: `SAFE_UB_CAPACITY_BITS / (memory_multiplier * unit_param * dtype_size * 8)`
4. Find largest power of 2 <= max_block_size
5. Return tuple with one max_safe_block_size per shape
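
Putting the key steps together, a compact sketch of `_default_strategy` consistent with the description above (illustrative only, not the verbatim library source):

```python
from typing import Tuple, Union


def _default_strategy(
    ub_capacity_bits: int,
    safety_margin: float,
    dtype_size: int,
    memory_multiplier: float,
    shapes: Tuple[Tuple[int, ...], ...],
    tiling_dims: Tuple[Union[int, Tuple[int, ...]], ...],
) -> Tuple[int, ...]:
    safe_bits = ub_capacity_bits * safety_margin
    results = []
    for shape, tiling_dim in zip(shapes, tiling_dims):
        # Step 1: normalize and validate the tiling dims (see _normalize_tiling_dims above).
        dims = {tiling_dim} if isinstance(tiling_dim, int) else set(tiling_dim)
        if not dims or any(d < 0 or d >= len(shape) for d in dims):
            raise ValueError(f"Invalid tiling dims {tiling_dim!r} for shape {shape}")
        # unit_param = product of the fixed (non-tiling) dimensions; 1.0 if there are none.
        unit_param = 1.0
        for i, d in enumerate(shape):
            if i not in dims:
                unit_param *= d
        # Steps 2-3: memory_multiplier * BLOCK_SIZE * unit_param * dtype_size * 8 <= safe_bits.
        max_block_size = safe_bits / (memory_multiplier * unit_param * dtype_size * 8)
        # Step 4: largest power of 2 <= max_block_size (at least 1).
        block = 1
        while block * 2 <= max_block_size:
            block *= 2
        results.append(block)
    # Step 5: one upper bound per input shape.
    return tuple(results)
```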

### `compute_default_tiling_strategy` Function

The public interface that computes final tiling results:

```python
def compute_default_tiling_strategy(
    safety_margin: float = 0.80,
    dtype_size: Optional[int] = None,
    memory_multiplier: Optional[float] = None,
    shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
    tiling_dims: Optional[Tuple[Union[int, Tuple[int, ...]], ...]] = None,
) -> Optional[Tuple[Tuple[int, ...], ...]]:
    """
    Compute tiling strategy using the default strategy function.

    Returns tuple of tiled shapes with same structure as input shapes.
    Tiling dimensions are replaced with computed block sizes (power of 2),
    while non-tiling dimensions are padded to next power of 2.

    Returns:
        Tuple of tiled shapes, or None if shapes/tiling_dims are empty or
        lengths don't match.

    Raises:
        ValueError: If any tiling_dim is empty or invalid, or if any dimension
            index is out of bounds for the corresponding shape.
    """
```

**Key Steps:**
1. Get UB manager instance
2. Validate `shapes` and `tiling_dims` (lengths must match, cannot be empty)
   - Returns `None` if validation fails (empty or mismatched lengths)
3. Set defaults for `dtype_size` (4) and `memory_multiplier` (10.0) if not provided
4. Call `_default_strategy` to get `max_supported` (tuple of max_safe_block_size, one per shape)
   - May raise `ValueError` if `tiling_dims` are invalid (see `_default_strategy` documentation)
5. For each `(shape, tiling_dim, max_safe)`:
   - Normalize `tiling_dim` to a set of dimension indices
   - Validate tiling dimensions are within shape bounds
     - Raises `ValueError` if `tiling_dim` is empty or invalid
     - Raises `ValueError` if any dimension index is out of bounds
   - For each tiling dimension:
     - Compute `desired = triton.next_power_of_2(original_dim)`
     - Compute `final = min(desired, max_safe)`
     - Ensure `final >= 1`
     - Replace dimension with `final`
   - For each non-tiling dimension:
     - Pad to `triton.next_power_of_2(original_dim)`
6. Return tuple of tiled shapes (same structure as input `shapes`)
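
A matching sketch of the public wrapper, reusing `_default_strategy` from the previous sketch; access to the UB capacity via the `UBManager` singleton is approximated here with the `_detect_ub_capacity_bits` helper sketched earlier, so this is illustrative rather than the actual implementation:

```python
from typing import Optional, Tuple, Union

import triton


def compute_default_tiling_strategy(
    safety_margin: float = 0.80,
    dtype_size: Optional[int] = None,
    memory_multiplier: Optional[float] = None,
    shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
    tiling_dims: Optional[Tuple[Union[int, Tuple[int, ...]], ...]] = None,
) -> Optional[Tuple[Tuple[int, ...], ...]]:
    # Steps 1-2: bail out with None on empty or mismatched inputs.
    if not shapes or not tiling_dims or len(shapes) != len(tiling_dims):
        return None
    # Step 3: documented defaults.
    dtype_size = 4 if dtype_size is None else dtype_size
    memory_multiplier = 10.0 if memory_multiplier is None else memory_multiplier
    # Step 4: per-shape upper bounds derived from UB capacity.
    ub_capacity_bits = _detect_ub_capacity_bits()
    max_supported = _default_strategy(
        ub_capacity_bits, safety_margin, dtype_size, memory_multiplier, shapes, tiling_dims
    )
    # Steps 5-6: rebuild each shape, tiling where allowed and padding fixed dims.
    result = []
    for shape, tiling_dim, max_safe in zip(shapes, tiling_dims, max_supported):
        dims = {tiling_dim} if isinstance(tiling_dim, int) else set(tiling_dim)
        tiled = tuple(
            max(1, min(triton.next_power_of_2(d), max_safe)) if i in dims else triton.next_power_of_2(d)
            for i, d in enumerate(shape)
        )
        result.append(tiled)
    return tuple(result)
```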

## Memory Analysis Examples

### GEGLU Forward

```
Memory analysis:
- Inputs: a, b
- Intermediates: a_cubed, tanh_arg, tanh_result, geglu_a
- Output: c
- Total: ~7x * BLOCK_SIZE * dtype_size

Strategy:
- shapes: ((n_cols,),)
- tiling_dims: (0,)  # First dimension can be tiled
- Fixed dimensions: none (all dimensions are tiling)
- unit_param = 1 (product of fixed dimensions)
- memory_multiplier = 7.0
- Formula: 7.0 * BLOCK_SIZE * 1 * dtype_size * 8 bits
- Returns: ((block_size,),)
```

### GEGLU Backward

```
Memory analysis:
- More intermediates for gradient computation
- Total: ~10x * BLOCK_SIZE * dtype_size

Strategy:
- shapes: ((n_cols,),)
- tiling_dims: (0,)  # First dimension can be tiled
- Fixed dimensions: none (all dimensions are tiling)
- unit_param = 1 (product of fixed dimensions)
- memory_multiplier = 10.0
- Formula: 10.0 * BLOCK_SIZE * 1 * dtype_size * 8 bits
- Returns: ((block_size,),)
```

### ROPE Forward/Backward

```
Memory analysis (based on optimized ROPE kernel):
- cos_vals and sin_vals: pad_hd // 2 elements each (shared)
- In q heads loop (peak memory):
  * q_left, q_right, new_left, new_right: 2 * BLOCK_Q * pad_hd elements
- In k heads loop (peak memory):
  * k_left, k_right, new_left, new_right: 2 * BLOCK_K * pad_hd elements
- Plus shared cos/sin: pad_hd elements
- Conservative estimate: 3 * BLOCK_SIZE * pad_hd * dtype_size * 8 bits

Strategy:
- shapes: ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
- tiling_dims: (0, 0)  # First dimension of each shape can be tiled
- Fixed dimensions: pad_hd (second dimension, non-tiling)
- unit_param = pad_hd (product of fixed dimensions)
- memory_multiplier = 3.0
- Formula: 3.0 * BLOCK_SIZE * pad_hd * dtype_size * 8 bits
- Returns: ((block_size_q, pad_hd), (block_size_kv, pad_hd))
```
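
Plugging the ROPE numbers into the formula shows how the head-count dimension ends up tiled; as before, the 192 KiB UB capacity and float32 dtype are assumptions for illustration:

```python
ub_bits = 192 * 1024 * 8                         # hypothetical UB capacity in bits
safe_bits = ub_bits * 0.90                       # safety_margin = 0.90
pad_hd = 128                                     # unit_param: the fixed head dimension
max_block = safe_bits / (3.0 * pad_hd * 4 * 8)   # ~115 heads fit per tile
# Largest power of 2 <= 115 is 64, so a 32-head query tensor stays untiled
# (BLOCK_Q = 32), while e.g. a 128-head tensor would be split into two
# 64-head tiles.
```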

## Extension Guide

### Adding a New Kernel

To add tiling support for a new kernel:

1. **Analyze memory usage**:
   - Identify peak memory usage in the kernel
   - Determine memory multiplier (e.g., 7.0, 10.0, 3.0)
   - Identify which dimensions can be tiled and which are fixed
   - Fixed dimensions will be automatically extracted and multiplied to get `unit_param`

2. **Use `compute_default_tiling_strategy`** in your kernel:

   ```python
   def my_kernel_forward(input):
       # Prepare parameters
       n_cols = input.shape[-1]
       dtype_size = input.element_size()

       # Compute strategy
       # Example 1: Simple case (all dimensions can be tiled)
       shapes = ((n_cols,),)
       tile_shapes = compute_default_tiling_strategy(
           safety_margin=0.80,
           dtype_size=dtype_size,
           memory_multiplier=7.0,  # Based on your memory analysis
           shapes=shapes,
           tiling_dims=(0,),  # First dimension can be tiled
       )

       if tile_shapes is not None and len(tile_shapes) > 0:
           block_size = tile_shapes[0][0]
       else:
           block_size = triton.next_power_of_2(n_cols)  # Fallback

       # Example 2: Multiple shapes with fixed dimensions
       # shapes = ((M, K), (K, N))
       # tiling_dims = (0, 1)  # First shape: dim 0 can be tiled, dim 1 is fixed
       #                       # Second shape: dim 0 is fixed, dim 1 can be tiled
       # Returns: ((block_M, K), (K, block_N))

       # Call kernel
       kernel[(grid_size,)](
           input,
           BLOCK_SIZE=block_size,
       )
   ```

3. **Document memory analysis** in comments:

   ```python
   # My kernel tiling strategy:
   # - Memory analysis:
   #   * Input: input
   #   * Intermediates: intermediate1, intermediate2
   #   * Output: output
   #   * Total: ~7x * BLOCK_SIZE * dtype_size
   # - shapes: ((n_cols,),)
   # - tiling_dims: (0,) means first dimension can be tiled
   # - Fixed dimensions: none (all dimensions are tiling)
   # - unit_param = 1 (product of fixed dimensions)
   # - Uses memory_multiplier=7.0 * BLOCK_SIZE * dtype_size * 8 bits for safety
   # - compute_default_tiling_strategy returns: ((block_size,),)
   #   where block_size = min(triton.next_power_of_2(n_cols), max_safe_block_size)
   ```

## Future Improvements

1. **Strategy Variants**: If needed, could add specialized strategy functions for specific kernels while keeping the unified interface
2. **Multi-dimensional Tiling**: Could extend to support more complex tiling patterns if needed
liger_kernel/ops/backends/_ascend/ops/__init__.py
@@ -0,0 +1,61 @@
"""
Ascend NPU operator implementations.

This module exports Ascend NPU-optimized implementations that will automatically
replace the default implementations when running on NPU devices.

Both Function classes and kernel functions can be exported here.

To add a new operator:
1. Create the implementation file (e.g., rms_norm.py)
2. Import the Function class and/or kernel functions here
3. Optionally add to __all__ for explicit control

If __all__ is not defined, all public symbols will be auto-discovered.
"""

from liger_kernel.ops.backends._ascend.ops.embedding import LigerEmbeddingFunction
from liger_kernel.ops.backends._ascend.ops.embedding import embedding_backward
from liger_kernel.ops.backends._ascend.ops.embedding import embedding_forward
from liger_kernel.ops.backends._ascend.ops.geglu import LigerGELUMulFunction
from liger_kernel.ops.backends._ascend.ops.geglu import geglu_backward
from liger_kernel.ops.backends._ascend.ops.geglu import geglu_forward
from liger_kernel.ops.backends._ascend.ops.llama4_rope import LigerLlama4RopeFunction
from liger_kernel.ops.backends._ascend.ops.llama4_rope import llama4_rope_backward
from liger_kernel.ops.backends._ascend.ops.llama4_rope import llama4_rope_forward
from liger_kernel.ops.backends._ascend.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
from liger_kernel.ops.backends._ascend.ops.qwen2vl_mrope import qwen2vl_mrope_backward
from liger_kernel.ops.backends._ascend.ops.qwen2vl_mrope import qwen2vl_mrope_forward
from liger_kernel.ops.backends._ascend.ops.rope import LigerRopeFunction
from liger_kernel.ops.backends._ascend.ops.rope import rope_backward
from liger_kernel.ops.backends._ascend.ops.rope import rope_forward
from liger_kernel.ops.backends._ascend.ops.swiglu import LigerSiLUMulFunction
from liger_kernel.ops.backends._ascend.ops.swiglu import swiglu_backward
from liger_kernel.ops.backends._ascend.ops.swiglu import swiglu_forward
from liger_kernel.ops.backends._ascend.ops.tvd import LigerTVDLossFunction
from liger_kernel.ops.backends._ascend.ops.tvd import tv_distance_forward_triton
from liger_kernel.ops.backends._ascend.ops.tvd import tvd_backward_triton

__all__ = [
    "LigerEmbeddingFunction",
    "embedding_forward",
    "embedding_backward",
    "LigerGELUMulFunction",
    "geglu_forward",
    "geglu_backward",
    "LigerQwen2VLMRopeFunction",
    "qwen2vl_mrope_forward",
    "qwen2vl_mrope_backward",
    "LigerRopeFunction",
    "rope_forward",
    "rope_backward",
    "LigerSiLUMulFunction",
    "swiglu_forward",
    "swiglu_backward",
    "LigerTVDLossFunction",
    "tv_distance_forward_triton",
    "tvd_backward_triton",
    "LigerLlama4RopeFunction",
    "llama4_rope_forward",
    "llama4_rope_backward",
]
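
As an illustration of the registration steps described in the module docstring, wiring in a hypothetical Ascend `rms_norm` operator (not part of this release) would only require the matching imports and, optionally, `__all__` entries:

```python
# Hypothetical example only: an Ascend rms_norm module does not exist in this release.
# Step 2: import the Function class and/or kernels from the new implementation file.
from liger_kernel.ops.backends._ascend.ops.rms_norm import LigerRMSNormFunction  # hypothetical module
from liger_kernel.ops.backends._ascend.ops.rms_norm import rms_norm_forward      # hypothetical kernel

# Step 3: optionally list them in __all__ (otherwise public symbols are auto-discovered).
__all__ += ["LigerRMSNormFunction", "rms_norm_forward"]
```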