liger-kernel-nightly 0.4.2.dev20241122052539__tar.gz → 0.4.2.dev20241122175637__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {liger_kernel_nightly-0.4.2.dev20241122052539/src/liger_kernel_nightly.egg-info → liger_kernel_nightly-0.4.2.dev20241122175637}/PKG-INFO +1 -1
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/pyproject.toml +1 -1
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/ops/cross_entropy.py +12 -6
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -11
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637/src/liger_kernel_nightly.egg-info}/PKG-INFO +1 -1
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/LICENSE +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/NOTICE +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/README.md +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/setup.cfg +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/chunked_loss/functional.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/env_report.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/ops/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/ops/geglu.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/ops/group_norm.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/ops/jsd.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/ops/kl_div.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/ops/layer_norm.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/ops/rms_norm.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/ops/utils.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/__init__.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/auto_model.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/functional.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/group_norm.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/jsd.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/kl_div.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/layer_norm.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/model/gemma.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/model/gemma2.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/model/llama.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/model/mistral.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/model/mixtral.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/model/mllama.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/model/phi3.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/model/qwen2.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/swiglu.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
- {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
4
4
|
|
5
5
|
[project]
|
6
6
|
name = "liger_kernel_nightly"
|
7
|
-
version = "0.4.2.dev20241122052539"
|
7
|
+
version = "0.4.2.dev20241122175637"
|
8
8
|
description = "Efficient Triton kernels for LLM Training"
|
9
9
|
urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
|
10
10
|
readme = { file = "README.md", content-type = "text/markdown" }
|
@@ -92,8 +92,8 @@ def liger_cross_entropy_kernel(
|
|
92
92
|
# 3. [Online softmax] first pass: find max + sum
|
93
93
|
m = float("-inf") # m is the max value. use the notation from the paper
|
94
94
|
d = 0.0 # d is the sum. use the notation from the paper
|
95
|
-
ori_X_y = tl.load(
|
96
|
-
X_ptr + y
|
95
|
+
ori_X_y = tl.load(X_ptr + y).cast(
|
96
|
+
tl.float32
|
97
97
|
) # we need to store the original value of X_y for the loss calculation
|
98
98
|
if HAS_SOFTCAPPING:
|
99
99
|
ori_X_y = softcap * tanh(ori_X_y / softcap)
|
@@ -106,8 +106,11 @@ def liger_cross_entropy_kernel(
|
|
106
106
|
for i in range(0, n_cols, BLOCK_SIZE):
|
107
107
|
X_offsets = i + tl.arange(0, BLOCK_SIZE)
|
108
108
|
X_block = tl.load(
|
109
|
-
X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
|
110
|
-
)
|
109
|
+
X_ptr + X_offsets,
|
110
|
+
mask=X_offsets < n_cols,
|
111
|
+
other=float("-inf"),
|
112
|
+
# Ensure float32 precision for softmax calculation
|
113
|
+
).cast(tl.float32)
|
111
114
|
if HAS_SOFTCAPPING:
|
112
115
|
X_block = softcap * tanh(X_block / softcap)
|
113
116
|
block_max = tl.max(X_block)
|
@@ -141,8 +144,11 @@ def liger_cross_entropy_kernel(
|
|
141
144
|
for i in range(0, n_cols, BLOCK_SIZE):
|
142
145
|
X_offsets = i + tl.arange(0, BLOCK_SIZE)
|
143
146
|
X_block = tl.load(
|
144
|
-
X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
|
145
|
-
)
|
147
|
+
X_ptr + X_offsets,
|
148
|
+
mask=X_offsets < n_cols,
|
149
|
+
other=float("-inf"),
|
150
|
+
# Ensure float32 precision for softmax calculation
|
151
|
+
).cast(tl.float32)
|
146
152
|
if HAS_SOFTCAPPING:
|
147
153
|
intermediate = tanh(X_block / softcap)
|
148
154
|
X_block = softcap * intermediate
|
@@ -26,7 +26,6 @@ def fused_linear_cross_entropy_forward(
|
|
26
26
|
reduction="mean",
|
27
27
|
softcap=None,
|
28
28
|
):
|
29
|
-
dtype = _input.dtype
|
30
29
|
device = _input.device
|
31
30
|
|
32
31
|
# inputs have shape: BT x H
|
@@ -74,9 +73,6 @@ def fused_linear_cross_entropy_forward(
|
|
74
73
|
loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size,
|
75
74
|
n_non_ignore = (target_chunk != ignore_index).sum().item()
|
76
75
|
|
77
|
-
# when doing CE, use the upcasted precision
|
78
|
-
logits_chunk = logits_chunk.float()
|
79
|
-
|
80
76
|
# ensure _input and target are contiguous
|
81
77
|
logits_chunk = logits_chunk.contiguous()
|
82
78
|
target_chunk = target_chunk.contiguous()
|
@@ -103,13 +99,6 @@ def fused_linear_cross_entropy_forward(
|
|
103
99
|
num_warps=32 if not is_hip() else 16,
|
104
100
|
)
|
105
101
|
|
106
|
-
# gradient of logits_chunk is computed in-place by the above triton kernel.
|
107
|
-
# Following HuggingFace model source code, we do the forward and backward
|
108
|
-
# w.r.t. logits in fp32 for numerical stability especially as the num classes (vocab size) is huge.
|
109
|
-
# (reference: https://github.com/huggingface/transformers/blob/v4.42.4/src/transformers/models/llama/modeling_llama.py#L1194)
|
110
|
-
# Propagating to lm_head's backward, we'll switch back to the original dtype.
|
111
|
-
logits_chunk = logits_chunk.to(dtype)
|
112
|
-
|
113
102
|
# gradient of logits_chunk is computed in-place by the above triton kernel and is of shape: chunk_size x V
|
114
103
|
# thus grad_input[start_idx: end_idx] should be of shape: chunk_size x H
|
115
104
|
# additionally, since we are chunking the inputs, observe that the loss and gradients are calculated only
|
File without changes
|
{liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241122175637}/NOTICE
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|