liger-kernel-nightly 0.6.4.dev20251208235806__py3-none-any.whl → 0.6.4.dev20251209171241__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -35,8 +35,7 @@ from liger_kernel.transformers.model.smollm3 import lce_forward as smollm3_lce_f
35
35
  from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
36
36
  from liger_kernel.transformers.rms_norm import LigerRMSNorm
37
37
  from liger_kernel.transformers.rope import liger_rotary_pos_emb
38
- from liger_kernel.transformers.rope import liger_rotary_pos_emb_with_cast
39
- from liger_kernel.transformers.rope import liger_rotary_pos_emb_with_cast_and_leading_batch
38
+ from liger_kernel.transformers.rope import liger_rotary_pos_emb_vision
40
39
  from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
41
40
  from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
42
41
  from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
@@ -1754,8 +1753,8 @@ def apply_liger_kernel_to_qwen3_vl(
1754
1753
  from liger_kernel.transformers.model.qwen3_vl import lce_forward as qwen3_vl_lce_forward
1755
1754
 
1756
1755
  if rope:
1757
- modeling_qwen3_vl.apply_rotary_pos_emb = liger_rotary_pos_emb_with_cast
1758
- modeling_qwen3_vl.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_with_cast_and_leading_batch
1756
+ modeling_qwen3_vl.apply_rotary_pos_emb = liger_rotary_pos_emb
1757
+ modeling_qwen3_vl.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_vision
1759
1758
 
1760
1759
  if rms_norm:
1761
1760
  modeling_qwen3_vl.Qwen3VLTextRMSNorm = LigerRMSNorm
@@ -1829,8 +1828,8 @@ def apply_liger_kernel_to_qwen3_vl_moe(
1829
1828
  from liger_kernel.transformers.model.qwen3_vl_moe import lce_forward as qwen3_vl_moe_lce_forward
1830
1829
 
1831
1830
  if rope:
1832
- modeling_qwen3_vl_moe.apply_rotary_pos_emb = liger_rotary_pos_emb_with_cast
1833
- modeling_qwen3_vl_moe.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_with_cast_and_leading_batch
1831
+ modeling_qwen3_vl_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
1832
+ modeling_qwen3_vl_moe.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_vision
1834
1833
 
1835
1834
  if rms_norm:
1836
1835
  modeling_qwen3_vl_moe.Qwen3VLMoeTextRMSNorm = LigerRMSNorm
@@ -1,4 +1,3 @@
1
- from typing import Optional
2
1
  from typing import Tuple
3
2
 
4
3
  import torch
@@ -25,39 +24,41 @@ def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
25
24
  return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim)
26
25
 
27
26
 
28
- def liger_rotary_pos_emb_with_cast(
27
+ def liger_rotary_pos_emb_vision(
29
28
  q: torch.Tensor,
30
29
  k: torch.Tensor,
31
30
  cos: torch.Tensor,
32
31
  sin: torch.Tensor,
33
- position_ids: Optional[torch.Tensor] = None,
34
- unsqueeze_dim: int = 1,
35
32
  ) -> Tuple[torch.Tensor, torch.Tensor]:
33
+ """
34
+ Modified version of liger_rotary_pos_emb for qwen3_vl's apply_rotary_pos_emb_vision function.
35
+ Manually transposes the input and output to match the expected shape for liger_rotary_pos_emb.
36
+ Reference: https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py#L116
37
+
38
+ Args:
39
+ q (torch.Tensor): The query tensor of shape (seq_length, num_heads, head_dim),
40
+ with stride (num_heads * head_dim, head_dim, 1).
41
+ k (torch.Tensor): The key tensor of shape (seq_length, num_heads, head_dim),
42
+ with stride (num_heads * head_dim, head_dim, 1). Same as q.
43
+ cos (torch.Tensor): The cosine tensor of shape (seq_length, head_dim).
44
+ sin (torch.Tensor): The sine tensor of shape (seq_length, head_dim).
45
+
46
+ Returns:
47
+ Tuple[torch.Tensor, torch.Tensor]: The query and key tensors with the same shape and stride as inputs.
48
+ """
36
49
  orig_q_dtype, orig_k_dtype = q.dtype, k.dtype
37
50
 
38
- q32 = q.to(torch.float32)
39
- k32 = k.to(torch.float32)
51
+ # transpose to (1, num_heads, seq_length, head_dim) and cast to float32 to match liger_rotary_pos_emb input shape
52
+ # also unsqueeze for batch dim
53
+ q32 = q.to(torch.float32).unsqueeze(0).transpose(1, 2)
54
+ k32 = k.to(torch.float32).unsqueeze(0).transpose(1, 2)
40
55
  cos32 = cos.to(torch.float32)
41
56
  sin32 = sin.to(torch.float32)
42
57
 
43
- q_out, k_out = liger_rotary_pos_emb(q32, k32, cos32, sin32, position_ids=position_ids, unsqueeze_dim=unsqueeze_dim)
44
- return q_out.to(orig_q_dtype), k_out.to(orig_k_dtype)
45
-
46
-
47
- def liger_rotary_pos_emb_with_cast_and_leading_batch(
48
- q: torch.Tensor,
49
- k: torch.Tensor,
50
- cos: torch.Tensor,
51
- sin: torch.Tensor,
52
- position_ids: Optional[torch.Tensor] = None,
53
- unsqueeze_dim: int = 1,
54
- ) -> Tuple[torch.Tensor, torch.Tensor]:
55
- orig_q_dtype, orig_k_dtype = q.dtype, k.dtype
56
-
57
- q32 = q.to(torch.float32).unsqueeze(0)
58
- k32 = k.to(torch.float32).unsqueeze(0)
59
- cos32 = cos.to(torch.float32).unsqueeze(0)
60
- sin32 = sin.to(torch.float32).unsqueeze(0)
58
+ q_out, k_out = liger_rotary_pos_emb(q32, k32, cos32, sin32)
61
59
 
62
- q_out, k_out = liger_rotary_pos_emb(q32, k32, cos32, sin32, position_ids=position_ids, unsqueeze_dim=unsqueeze_dim)
63
- return q_out.to(orig_q_dtype).squeeze(0), k_out.to(orig_k_dtype).squeeze(0)
60
+ # transpose back to (seq_length, num_heads, head_dim) and cast back to original dtype
61
+ # also squeeze out batch dim
62
+ q_out = q_out.transpose(1, 2).squeeze(0).to(orig_q_dtype)
63
+ k_out = k_out.transpose(1, 2).squeeze(0).to(orig_k_dtype)
64
+ return q_out, k_out
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.6.4.dev20251208235806
3
+ Version: 0.6.4.dev20251209171241
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -60,12 +60,12 @@ liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCc
60
60
  liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
61
61
  liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
62
62
  liger_kernel/transformers/llama4_rope.py,sha256=kS6PSHEwf3dS7hD7C7p8S0geugx2EMCiP0h0F7LsUoY,3639
63
- liger_kernel/transformers/monkey_patch.py,sha256=0ER5BjQXcIKwgL2e7ji3_DIm1DaJzevzo53aXSe2YJU,135862
63
+ liger_kernel/transformers/monkey_patch.py,sha256=3MtDn6_1lljiloWHFK_GOo9fCO61QwgXh6OAu9KQdAc,135705
64
64
  liger_kernel/transformers/multi_token_attention.py,sha256=K3NIY9_5TPgZ4_Rahn0xnkMXxD_fmlJHK4CWGYvGQp0,1752
65
65
  liger_kernel/transformers/poly_norm.py,sha256=g5tC75i3qy1_N26ZUP-jfpct7ivQAEdJfIfx8IXzeyE,1377
66
66
  liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
67
67
  liger_kernel/transformers/rms_norm.py,sha256=HwddVqrqS58jE-M2_4NkFGARtCDBhGnkKyjBN9b3FYI,3004
68
- liger_kernel/transformers/rope.py,sha256=VMlDZI6zss9mLaLcN5XCE_ktmYRwAi_Eh4TIgO6NrIQ,2361
68
+ liger_kernel/transformers/rope.py,sha256=fmSG2h3GVtx9nqPNUOqH8TvGRKaJhsVZJMzsdVIMqPM,2872
69
69
  liger_kernel/transformers/softmax.py,sha256=yadlAgE4V2JByMwrDDa2s5SUBp8Jgd57xwnVvAWoBaI,264
70
70
  liger_kernel/transformers/sparsemax.py,sha256=0lQA0UEOs4mu8CMruZ3VLhImxQVXJWhPsAKUsYA7vj8,403
71
71
  liger_kernel/transformers/swiglu.py,sha256=dRR69wDWSWfdjtnsTECyxQqWVo5QkdXdXm9SpSQ4Jvw,4291
@@ -111,9 +111,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
111
111
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
112
112
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
113
113
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
114
- liger_kernel_nightly-0.6.4.dev20251208235806.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
115
- liger_kernel_nightly-0.6.4.dev20251208235806.dist-info/METADATA,sha256=tjEDcZKe6AONv56kfZB58iuv1zOQJ1BvIjHMVio3A1U,25468
116
- liger_kernel_nightly-0.6.4.dev20251208235806.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
117
- liger_kernel_nightly-0.6.4.dev20251208235806.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
118
- liger_kernel_nightly-0.6.4.dev20251208235806.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
119
- liger_kernel_nightly-0.6.4.dev20251208235806.dist-info/RECORD,,
114
+ liger_kernel_nightly-0.6.4.dev20251209171241.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
115
+ liger_kernel_nightly-0.6.4.dev20251209171241.dist-info/METADATA,sha256=UgQQVN7vzeOeYXRTEMJucnAgorAIj6QRpv8vBj5fin8,25468
116
+ liger_kernel_nightly-0.6.4.dev20251209171241.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
117
+ liger_kernel_nightly-0.6.4.dev20251209171241.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
118
+ liger_kernel_nightly-0.6.4.dev20251209171241.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
119
+ liger_kernel_nightly-0.6.4.dev20251209171241.dist-info/RECORD,,