liger-kernel-nightly 0.4.2.dev20241121224158__py3-none-any.whl → 0.4.2.dev20241121225747__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,18 +15,6 @@ from liger_kernel.ops.rms_norm import LigerRMSNormFunction
15
15
  from liger_kernel.ops.rope import LigerRopeFunction
16
16
  from liger_kernel.ops.swiglu import LigerSiLUMulFunction
17
17
 
18
- liger_swiglu = LigerSiLUMulFunction.apply
19
- liger_fused_linear_cross_entropy = LigerFusedLinearCrossEntropyFunction.apply
20
- liger_geglu = LigerGELUMulFunction.apply
21
- liger_rms_norm = LigerRMSNormFunction.apply
22
- liger_rope = LigerRopeFunction.apply
23
- liger_qwen2vl_mrope = LigerQwen2VLMRopeFunction.apply
24
- liger_layer_norm = LigerLayerNormFunction.apply
25
- liger_kl_div = LigerKLDivLossFunction.apply
26
- liger_jsd = LigerJSDFunction.apply
27
- liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
28
- liger_group_norm = LigerGroupNormFunction.apply
29
-
30
18
 
31
19
  # conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
32
20
  # `weight` and `size_average` are placeholders and not implemented yet
@@ -56,3 +44,130 @@ def liger_cross_entropy(
56
44
  if not return_z_loss:
57
45
  return loss
58
46
  return loss, z_loss
47
+
48
+
49
def liger_fused_linear_cross_entropy(
    input,
    weight,
    target,
    bias=None,
    ignore_index: int = -100,
    lse_square_scale: float = 0.0,
    label_smoothing: float = 0.0,
    reduction: str = "mean",
    softcap: Optional[float] = None,
):
    """Run ``LigerFusedLinearCrossEntropyFunction`` with a keyword-friendly signature.

    Every argument is forwarded unchanged, in declaration order, to the
    autograd function's ``apply``. Its result is returned as-is. This wrapper
    only adds defaults and keyword support on top of ``apply``.

    NOTE(review): expected tensor shapes for ``input``/``weight``/``target``
    are defined by the underlying autograd function, not here.
    """
    # ``apply`` receives these positionally, so the tuple order is significant.
    forward_args = (
        input,
        weight,
        target,
        bias,
        ignore_index,
        lse_square_scale,
        label_smoothing,
        reduction,
        softcap,
    )
    return LigerFusedLinearCrossEntropyFunction.apply(*forward_args)
71
+
72
+
73
def liger_fused_linear_jsd(
    student_input,
    student_weight,
    teacher_input,
    teacher_weight,
    shift_labels=None,
    jsd_beta: float = 0.5,
    ignore_index: int = -100,
    temperature: float = 1.0,
):
    """Run ``LigerFusedLinearJSDFunction`` with a keyword-friendly signature.

    Arguments are passed positionally to ``apply`` in this order:
    1. the student input/weight pair;
    2. the teacher input/weight pair;
    3. ``shift_labels``, ``jsd_beta``, ``ignore_index`` and ``temperature``.

    The autograd function's result is returned unchanged.
    """
    student = (student_input, student_weight)
    teacher = (teacher_input, teacher_weight)
    options = (shift_labels, jsd_beta, ignore_index, temperature)
    return LigerFusedLinearJSDFunction.apply(*student, *teacher, *options)
93
+
94
+
95
def liger_geglu(a, b):
    """Run ``LigerGELUMulFunction`` on the two operands ``a`` and ``b``.

    Both operands are handed, in order, to the autograd function's ``apply``.
    Its result is returned unchanged.
    """
    gelu_mul = LigerGELUMulFunction.apply
    return gelu_mul(a, b)
97
+
98
+
99
def liger_group_norm(
    X,
    affine_scaling_weight,
    affine_shifting_bias,
    num_channels,
    num_groups,
    eps,
):
    """Run ``LigerGroupNormFunction`` on ``X``.

    Arguments are forwarded positionally, in this order:
    1. ``X``;
    2. the affine scaling weight and shifting bias;
    3. ``num_channels``, ``num_groups`` and ``eps``.

    The autograd function's result is returned unchanged.
    """
    affine_params = (affine_scaling_weight, affine_shifting_bias)
    return LigerGroupNormFunction.apply(
        X, *affine_params, num_channels, num_groups, eps
    )
115
+
116
+
117
def liger_jsd(
    input,
    target,
    shift_labels=None,
    beta: float = 0.5,
    ignore_index: int = -100,
):
    """Run ``LigerJSDFunction`` with a keyword-friendly signature.

    ``input``, ``target``, ``shift_labels``, ``beta`` and ``ignore_index``
    are passed positionally to ``apply`` in that order. The result is
    returned unchanged.
    """
    jsd = LigerJSDFunction.apply
    return jsd(input, target, shift_labels, beta, ignore_index)
131
+
132
+
133
# Conforms to the signature of torch.nn.functional.kl_div:
# https://pytorch.org/docs/stable/generated/torch.nn.functional.kl_div.html#torch.nn.functional.kl_div
# `size_average` and `reduce` are deprecated in the torch API. They are
# accepted here only for signature compatibility and are otherwise ignored.
def liger_kl_div(
    input,
    target,
    size_average: bool = True,
    reduce: bool = True,
    reduction: str = "mean",
    log_target: bool = False,
    eps: float = 1e-10,
):
    """KL divergence via ``LigerKLDivLossFunction`` with a torch-compatible signature.

    Args:
        input: forwarded unchanged. NOTE(review): presumably log-probabilities,
            as in ``torch.nn.functional.kl_div`` (confirm against the kernel).
        target: forwarded unchanged.
        size_average: ignored placeholder kept for torch API compatibility.
        reduce: ignored placeholder kept for torch API compatibility.
        reduction: forwarded unchanged. Defaults to ``"mean"``, matching torch.
        log_target: forwarded unchanged.
        eps: forwarded unchanged.

    Returns:
        Whatever ``LigerKLDivLossFunction.apply`` returns.
    """
    # Note: this functional keeps torch's default reduction, "mean". Liger's
    # KL-div loss is described as defaulting to "batchmean" instead.
    # NOTE(review): callers wanting that behavior must pass
    # reduction="batchmean" explicitly.
    # `size_average` and `reduce` are intentionally not forwarded.
    return LigerKLDivLossFunction.apply(
        input,
        target,
        reduction,
        log_target,
        eps,
    )
152
+
153
+
154
def liger_layer_norm(X, W, B, eps):
    """Run ``LigerLayerNormFunction`` on ``X``.

    ``X``, ``W``, ``B`` and ``eps`` are forwarded positionally, unchanged.
    NOTE(review): ``W``/``B`` are presumably the affine weight and bias;
    confirm against the kernel.
    """
    norm_args = (X, W, B, eps)
    return LigerLayerNormFunction.apply(*norm_args)
156
+
157
+
158
def liger_qwen2vl_mrope(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
    """Run ``LigerQwen2VLMRopeFunction`` on the query/key pair.

    ``q``, ``k``, ``cos``, ``sin``, ``mrope_section`` and ``unsqueeze_dim``
    (default ``1``) are passed positionally to ``apply``. The result is
    returned unchanged.
    """
    mrope = LigerQwen2VLMRopeFunction.apply
    return mrope(q, k, cos, sin, mrope_section, unsqueeze_dim)
160
+
161
+
162
def liger_rms_norm(
    X,
    W,
    eps,
    offset: float = 0.0,
    casting_mode: str = "llama",
    in_place: bool = True,
):
    """Run ``LigerRMSNormFunction`` on ``X``.

    All arguments are forwarded positionally, unchanged. Defaults are:
    ``offset=0.0``, ``casting_mode="llama"`` and ``in_place=True``.

    NOTE(review): the effect of ``in_place`` is defined by the kernel. Confirm
    whether it mutates any argument before relying on the inputs afterwards.
    """
    rms_norm_fn = LigerRMSNormFunction.apply
    return rms_norm_fn(X, W, eps, offset, casting_mode, in_place)
166
+
167
+
168
def liger_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Run ``LigerRopeFunction`` on the query/key pair.

    ``q``, ``k``, ``cos`` and ``sin`` are forwarded first, followed by
    ``position_ids`` (default ``None``) and ``unsqueeze_dim`` (default ``1``).
    The autograd function's result is returned unchanged.
    """
    rotary_inputs = (q, k, cos, sin)
    return LigerRopeFunction.apply(*rotary_inputs, position_ids, unsqueeze_dim)
170
+
171
+
172
def liger_swiglu(a, b):
    """Run ``LigerSiLUMulFunction`` on the two operands ``a`` and ``b``.

    The operands are passed, in order, to the autograd function's ``apply``.
    Its result is returned unchanged.
    """
    operands = (a, b)
    return LigerSiLUMulFunction.apply(*operands)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.4.2.dev20241121224158
3
+ Version: 0.4.2.dev20241121225747
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -25,7 +25,7 @@ liger_kernel/ops/experimental/mm_int8int2.py,sha256=JpGVZCgRC6T8XMUJ_QbZRS2XU1bh
25
25
  liger_kernel/transformers/__init__.py,sha256=gia-eBxr7TLxU0GdDf8AfCY4WgDlFLqIGSt7EoQGsBA,1336
26
26
  liger_kernel/transformers/auto_model.py,sha256=RMIwQHSiXoksXFTIqFZ4PLBgoqkxJJAT3q1Qh47bGN8,1552
27
27
  liger_kernel/transformers/cross_entropy.py,sha256=yEm_YQ7oa3_BzT3hdW6KrAslduhSqWcJQVNZZDcWCg4,1758
28
- liger_kernel/transformers/functional.py,sha256=jwTHmyjOVC1_I-6ztY1EbbRqPIfFHojcHrP2c4P6U4I,2123
28
+ liger_kernel/transformers/functional.py,sha256=sUBoU8Vb4pLpr9G6IdkRsToYgh-rCXL4OLYat7Tv_GU,4450
29
29
  liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=_i0PXSp5iZ9pKXdEeZ4lvHCENJYjV4y74yz3ZRG5XQg,1484
30
30
  liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
31
31
  liger_kernel/transformers/geglu.py,sha256=QcrME_8ooIn0xa59LaC0aoOdRrBIFd11Y0bAyF0NfCw,1130
@@ -52,9 +52,9 @@ liger_kernel/transformers/model/qwen2.py,sha256=EyhSSzQOskGjSnCsKMZpd1s5IAIlHd5P
52
52
  liger_kernel/transformers/model/qwen2_vl.py,sha256=bIQe2bWiY--G84FhCD29Gdi64_qHP6vbcGsK6vKysQE,8547
53
53
  liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
54
54
  liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
55
- liger_kernel_nightly-0.4.2.dev20241121224158.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
56
- liger_kernel_nightly-0.4.2.dev20241121224158.dist-info/METADATA,sha256=3HyUur6qJmSMTQaxiLaiDaGUrvU3_ILHlvWdobywuso,21891
57
- liger_kernel_nightly-0.4.2.dev20241121224158.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
58
- liger_kernel_nightly-0.4.2.dev20241121224158.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
59
- liger_kernel_nightly-0.4.2.dev20241121224158.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
60
- liger_kernel_nightly-0.4.2.dev20241121224158.dist-info/RECORD,,
55
+ liger_kernel_nightly-0.4.2.dev20241121225747.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
56
+ liger_kernel_nightly-0.4.2.dev20241121225747.dist-info/METADATA,sha256=_x0oNBWnmCOeIMGuUSKIpU4v2_j8cG-2u-gfzz071hM,21891
57
+ liger_kernel_nightly-0.4.2.dev20241121225747.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
58
+ liger_kernel_nightly-0.4.2.dev20241121225747.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
59
+ liger_kernel_nightly-0.4.2.dev20241121225747.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
60
+ liger_kernel_nightly-0.4.2.dev20241121225747.dist-info/RECORD,,