liger-kernel-nightly 0.4.2.dev20241204180758__tar.gz → 0.4.2.dev20241207011709__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {liger_kernel_nightly-0.4.2.dev20241204180758/src/liger_kernel_nightly.egg-info → liger_kernel_nightly-0.4.2.dev20241207011709}/PKG-INFO +5 -4
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/README.md +3 -3
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/pyproject.toml +2 -1
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/chunked_loss/cpo_loss.py +3 -3
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/chunked_loss/dpo_loss.py +4 -3
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/chunked_loss/fused_linear_preference.py +146 -44
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/chunked_loss/orpo_loss.py +11 -3
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/chunked_loss/simpo_loss.py +5 -3
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/transformers/__init__.py +1 -0
- liger_kernel_nightly-0.4.2.dev20241207011709/src/liger_kernel/transformers/orpo_trainer.py +171 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709/src/liger_kernel_nightly.egg-info}/PKG-INFO +5 -4
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel_nightly.egg-info/SOURCES.txt +1 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel_nightly.egg-info/requires.txt +1 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/LICENSE +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/NOTICE +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/setup.cfg +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/__init__.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/chunked_loss/functional.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/env_report.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/ops/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/ops/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/ops/geglu.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/ops/group_norm.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/ops/jsd.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/ops/kl_div.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/ops/layer_norm.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/ops/rms_norm.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/ops/utils.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/transformers/auto_model.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/transformers/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/transformers/functional.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/transformers/group_norm.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/transformers/jsd.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/transformers/kl_div.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/transformers/layer_norm.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/transformers/model/gemma.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/transformers/model/gemma2.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/transformers/model/llama.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/transformers/model/mistral.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/transformers/model/mixtral.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/transformers/model/mllama.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/transformers/model/phi3.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/transformers/model/qwen2.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/transformers/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/transformers/swiglu.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel/utils.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
- {liger_kernel_nightly-0.4.2.dev20241204180758 → liger_kernel_nightly-0.4.2.dev20241207011709}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0

PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: liger_kernel_nightly
-Version: 0.4.2.dev20241204180758
+Version: 0.4.2.dev20241207011709
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
 Copyright 2024 LinkedIn Corporation
@@ -36,6 +36,7 @@ Provides-Extra: transformers
 Requires-Dist: transformers~=4.0; extra == "transformers"
 Provides-Extra: dev
 Requires-Dist: transformers>=4.44.2; extra == "dev"
+Requires-Dist: trl>=0.11.0; extra == "dev"
 Requires-Dist: matplotlib>=3.7.2; extra == "dev"
 Requires-Dist: flake8>=4.0.1.1; extra == "dev"
 Requires-Dist: black>=24.4.2; extra == "dev"
@@ -55,7 +56,7 @@ Requires-Dist: seaborn; extra == "dev"
         <th style="padding: 10px;" colspan="2">Stable</th>
         <th style="padding: 10px;" colspan="2">Nightly</th>
         <th style="padding: 10px;">Discord</th>
-        <th style="padding: 10px;">
+        <th style="padding: 10px;">Build</th>
     </tr>
     <tr>
         <td style="padding: 10px;">
@@ -84,8 +85,8 @@ Requires-Dist: seaborn; extra == "dev"
         </a>
     </td>
     <td style="padding: 10px;">
-        <a href="https://
-        <img src="https://
+        <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/ci.yml">
+        <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/ci.yml/badge.svg?event=schedule" alt="Build">
         </a>
     </td>
 </tr>

README.md
@@ -8,7 +8,7 @@
         <th style="padding: 10px;" colspan="2">Stable</th>
         <th style="padding: 10px;" colspan="2">Nightly</th>
         <th style="padding: 10px;">Discord</th>
-        <th style="padding: 10px;">
+        <th style="padding: 10px;">Build</th>
     </tr>
     <tr>
         <td style="padding: 10px;">
@@ -37,8 +37,8 @@
         </a>
     </td>
     <td style="padding: 10px;">
-        <a href="https://
-        <img src="https://
+        <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/ci.yml">
+        <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/ci.yml/badge.svg?event=schedule" alt="Build">
         </a>
     </td>
 </tr>

pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "liger_kernel_nightly"
-version = "0.4.2.dev20241204180758"
+version = "0.4.2.dev20241207011709"
 description = "Efficient Triton kernels for LLM Training"
 urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
 readme = { file = "README.md", content-type = "text/markdown" }
@@ -21,6 +21,7 @@ transformers = [
 
 dev = [
     "transformers>=4.44.2",
+    "trl>=0.11.0",
     "matplotlib>=3.7.2",
     "flake8>=4.0.1.1",
     "black>=24.4.2",

src/liger_kernel/chunked_loss/cpo_loss.py
@@ -9,7 +9,7 @@ from liger_kernel.chunked_loss.fused_linear_preference import (
 class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
 
     @staticmethod
-    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
+    def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
         """
         Compute odds-ratio loss.
         Args:
@@ -18,7 +18,7 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
             beta (float): Weight for the odds ratio loss.
         """
         logits = beta * (chosen_logps - rejected_logps)
-        loss = F.logsigmoid(logits).
+        loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
         return loss
 
     @staticmethod
@@ -55,7 +55,7 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
         )
 
     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, *grad_output):
         # Get gradients for _input, weight, bias, and target from the base class
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
         # Return these gradients, followed by None for the remaining inputs
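
The recurring change across these chunked losses is the new full_target argument: each preference_loss_fn now divides its summed log-sigmoid term by the number of preference pairs, i.e. half of the concatenated chosen-plus-rejected batch. A minimal standalone sketch of the CPO term above (illustrative shapes and random inputs, not the library's fused path):

import torch
import torch.nn.functional as F

def cpo_preference_loss(chosen_logps, rejected_logps, full_target, beta=0.1):
    logits = beta * (chosen_logps - rejected_logps)
    # full_target stacks chosen and rejected targets, so shape[0] // 2 is the pair count
    return F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)

pairs, seq_len = 4, 16
chosen_logps = torch.randn(pairs)
rejected_logps = torch.randn(pairs)
full_target = torch.zeros(2 * pairs, seq_len, dtype=torch.long)
print(cpo_preference_loss(chosen_logps, rejected_logps, full_target))  # scalar tensor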

src/liger_kernel/chunked_loss/dpo_loss.py
@@ -12,6 +12,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
     def preference_loss_fn(
         chosen_logps,
         rejected_logps,
+        full_target,
         ref_chosen_logps=None,
         ref_rejected_logps=None,
         beta=0.1,
@@ -34,8 +35,8 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         rejected_logratios = rejected_logps - ref_rejected_logps
 
         logits_diff = beta * (chosen_logratios - rejected_logratios)
-
-        return
+        loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
+        return loss
 
     @staticmethod
     def forward(
@@ -73,7 +74,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         )
 
     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, *grad_output):
         # Get gradients for _input, weight, bias, and target from the base class
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
         # Return these gradients, followed by None for the remaining inputs
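
Read together with the hunk above, the completed DPO term is the standard Bradley-Terry objective on policy-versus-reference log-ratios. A hedged sketch (the zero fallback for absent reference log-probs is an assumption made for illustration, not something this file shows):

import torch
import torch.nn.functional as F

def dpo_preference_loss(chosen_logps, rejected_logps, full_target,
                        ref_chosen_logps=None, ref_rejected_logps=None, beta=0.1):
    # Assumption for the sketch: treat missing reference log-probs as zeros.
    if ref_chosen_logps is None:
        ref_chosen_logps = torch.zeros_like(chosen_logps)
    if ref_rejected_logps is None:
        ref_rejected_logps = torch.zeros_like(rejected_logps)
    chosen_logratios = chosen_logps - ref_chosen_logps
    rejected_logratios = rejected_logps - ref_rejected_logps
    logits_diff = beta * (chosen_logratios - rejected_logratios)
    # Negative log-sigmoid, normalized by the number of preference pairs
    return -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)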

src/liger_kernel/chunked_loss/fused_linear_preference.py
@@ -52,7 +52,17 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
 
         chosen_logps = average_log_prob[:len_chosen_chunk]
         rejected_logps = average_log_prob[len_chosen_chunk:]
-
+
+        chosen_logits = logits_chunk[:len_chosen_chunk]
+        rejected_logits = logits_chunk[len_chosen_chunk:]
+
+        return (
+            chosen_logps,
+            rejected_logps,
+            chosen_logits,
+            rejected_logits,
+            chosen_nll_loss,
+        )
 
     @staticmethod
     def forward(
@@ -103,6 +113,12 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         grad_rejected_inputs = []
         grad_bias = torch.zeros_like(bias) if bias is not None else None
         loss_acc = torch.zeros((), device=_input.device)
+        policy_chosen_logps = []
+        policy_rejected_logps = []
+        policy_chosen_logits_mean = torch.zeros((), device=_input.device)
+        policy_rejected_logits_mean = torch.zeros((), device=_input.device)
+        policy_nll_loss = torch.zeros((), device=_input.device)
+        aggregated_aux_outputs = []  # aggregated aux outputs from all chunks
 
         loss_func_to_call = partial(
             LigerFusedLinearPreferenceBase._compute_loss,
@@ -118,32 +134,72 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             **loss_kwargs,
         )
 
+        def accumulate_helper(input_chunk, target_chunk):
+            if bias is not None:
+                return torch.func.grad_and_value(
+                    loss_func_to_call, argnums=(0, 1, 3), has_aux=True
+                )(input_chunk, weight, target_chunk, bias)
+            else:
+                return torch.func.grad_and_value(
+                    loss_func_to_call, argnums=(0, 1), has_aux=True
+                )(input_chunk, weight, target_chunk)
+
         def accumulate_chunk(input_chunk, target_chunk):
             if bias is not None:
                 (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (
                     chunk_loss,
-                    (
-
-
-
-
-
-
+                    (
+                        chunk_chosen_logps,
+                        chunk_rejected_logps,
+                        chunk_chosen_logits_mean,
+                        chunk_rejected_logits_mean,
+                        chunk_nll_loss,
+                        *aux_outputs,
+                    ),
+                ) = accumulate_helper(input_chunk, target_chunk)
+                grad_bias.add_(chunk_grad_bias)  # accumulate bias gradient
             else:
                 (chunk_grad_input, chunk_grad_weight), (
                     chunk_loss,
-                    (
-
-
-
-
-
+                    (
+                        chunk_chosen_logps,
+                        chunk_rejected_logps,
+                        chunk_chosen_logits_mean,
+                        chunk_rejected_logits_mean,
+                        chunk_nll_loss,
+                        *aux_outputs,
+                    ),
+                ) = accumulate_helper(input_chunk, target_chunk)
+
             grad_weight.add_(chunk_grad_weight)
             loss_acc.add_(chunk_loss)
+            policy_chosen_logps.append(chunk_chosen_logps)
+            policy_rejected_logps.append(chunk_rejected_logps)
+            policy_chosen_logits_mean.add_(chunk_chosen_logits_mean)
+            policy_rejected_logits_mean.add_(chunk_rejected_logits_mean)
+            policy_nll_loss.add_(chunk_nll_loss)
+
+            # Initialize storage for aux_outputs
+            if len(aggregated_aux_outputs) == 0:
+                for aux in aux_outputs:
+                    if aux.ndim == 0:
+                        aggregated_aux_outputs.append(
+                            torch.zeros((), device=aux.device)
+                        )
+                    else:
+                        aggregated_aux_outputs.append([])
+
+            # Process each aux_output
+            for i, aux in enumerate(aux_outputs):
+                if aux.ndim == 0:
+                    aggregated_aux_outputs[i].add_(aux)
+                else:
+                    aggregated_aux_outputs[i].append(aux)
+
             return chunk_grad_input
 
         if compiled:
-
+            accumulate_helper = torch.compile(accumulate_helper)
 
         len_chosen = target.shape[0] // 2
         chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
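
The new accumulate_helper exists so that only the differentiation step is compiled: torch.func.grad_and_value(..., has_aux=True) requires the wrapped function to return (loss, aux) and hands back (grads, (loss, aux)) without differentiating through aux, which is how each chunk's metrics ride along with its gradients. A toy sketch of that contract (stand-in function, not the fused loss):

import torch

def loss_with_metrics(w, x):
    # Stand-in for _compute_loss: returns (scalar_loss, aux_metrics)
    pred = x @ w
    loss = (pred ** 2).mean()
    aux = (pred.detach().mean(),)  # reported back, but not differentiated
    return loss, aux

w = torch.randn(3)
x = torch.randn(5, 3)
grad_w, (loss, aux) = torch.func.grad_and_value(
    loss_with_metrics, argnums=0, has_aux=True
)(w, x)
print(grad_w.shape, loss, aux)  # gradient w.r.t. w only; aux passed through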
@@ -168,6 +224,12 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
                 [chosen_target_chunk, rejected_target_chunk], dim=0
             )
 
+            # mark input_chunk, target_chunk, and target dimension 1 as dynamic to prevent torch.compile recompilation
+            torch._dynamo.mark_dynamic(input_chunk, 1)
+            torch._dynamo.mark_dynamic(target_chunk, 1)
+            torch._dynamo.mark_dynamic(target, 1)
+
+            # accumulate loss, gradients, and metrics
             grad_input = accumulate_chunk(input_chunk, target_chunk)
 
             grad_chosen_inputs.append(grad_input[: chosen_target_chunk.shape[0]])
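
Chunk sizes along the sequence dimension vary from call to call, so a compiled accumulate_helper would otherwise re-trace for every new shape; torch._dynamo.mark_dynamic(tensor, dim) declares that dimension symbolic up front. A hedged toy illustration of the same idea:

import torch

@torch.compile
def rowsum(x):
    return x.sum(dim=1)

x = torch.randn(4, 16)
# Mark dim 1 as dynamic before the first call so the trace uses a symbolic size.
torch._dynamo.mark_dynamic(x, 1)
print(rowsum(x))
print(rowsum(torch.randn(4, 32)))  # typically reuses the dynamic-shape graph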
@@ -175,21 +237,37 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
 
         # combine grad_chosen_inputs and grad_rejected_inputs
         grad_inputs = grad_chosen_inputs + grad_rejected_inputs
+        policy_chosen_logps = torch.cat(policy_chosen_logps, dim=0)
+        policy_rejected_logps = torch.cat(policy_rejected_logps, dim=0)
+
+        # Aggregate aux outputs lists into tensors
+        for i, aux in enumerate(aggregated_aux_outputs):
+            if isinstance(aux, list):
+                aggregated_aux_outputs[i] = torch.cat(aux, dim=0)
 
         ctx.save_for_backward(
             torch.cat(grad_inputs, dim=0),
             grad_weight,
             grad_bias,
         )
-
+        return_vars = (
+            policy_chosen_logps,
+            policy_rejected_logps,
+            policy_chosen_logits_mean,
+            policy_rejected_logits_mean,
+            policy_nll_loss,
+        )
+        return loss_acc, (*return_vars, *aggregated_aux_outputs)
 
     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, *grad_output):
         grad_input, grad_weight, grad_bias = ctx.saved_tensors
-        if torch.ne(
-
-
-
+        if torch.ne(
+            grad_output[0][0], torch.tensor(1.0, device=grad_output[0][0].device)
+        ):
+            grad_input = grad_input * grad_output[0][0]
+            grad_weight = grad_weight * grad_output[0][0]
+            grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None
 
         return grad_input, grad_weight, None, grad_bias, None, None, None
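
The signature change to backward(ctx, *grad_output) follows from the forward change: an autograd.Function receives one incoming gradient per forward output, so once forward returns (loss, metrics_tuple) the old single-argument backward would no longer match. A toy sketch of that rule (not the library's class):

import torch

class TwoOutputs(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Two outputs -> backward receives two gradient arguments
        return x * 2, x.detach() + 1

    @staticmethod
    def backward(ctx, *grad_output):
        grad_main, grad_aux = grad_output  # one grad per forward output
        return grad_main * 2  # only the first output is differentiable here

x = torch.ones(3, requires_grad=True)
y, aux = TwoOutputs.apply(x)
y.sum().backward()
print(x.grad)  # tensor([2., 2., 2.])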
@@ -228,40 +306,64 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
             loss_kwargs (dict): Additional arguments for the loss function.
         """
-
-
-
-
-
-
-
-
-
+        (
+            chosen_logps,
+            rejected_logps,
+            chosen_logits,
+            rejected_logits,
+            chosen_nll_loss,
+        ) = LigerFusedLinearPreferenceBase.chunk_forward(
+            input_chunk,
+            weight,
+            target_chunk,
+            bias=bias,
+            ignore_index=ignore_index,
+            compute_nll_loss=compute_nll_loss,
         )
         chosen_nll_loss = (
             chosen_nll_loss
             / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
         )
+        chosen_logits_mean = chosen_logits.sum() / (
+            full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
+        )
+        rejected_logits_mean = rejected_logits.sum() / (
+            full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
+        )
 
         if use_ref_model:
             with torch.no_grad():
-
-
-
-
-
-
-
-
-
+                (
+                    ref_chosen_logps,
+                    ref_rejected_logps,
+                    ref_chosen_logits,
+                    ref_rejected_logits,
+                    ref_chosen_nll_loss,
+                ) = LigerFusedLinearPreferenceBase.chunk_forward(
+                    input_chunk,
+                    ref_weight,
+                    target_chunk,
+                    ref_bias,
+                    ignore_index=ignore_index,
+                    compute_nll_loss=False,  # We don't need NLL loss for the reference model
                 )
             loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
             loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
 
-
-            chosen_logps, rejected_logps, beta=beta, **loss_kwargs
+        preference_loss_outputs = preference_loss_fn(
+            chosen_logps, rejected_logps, full_target, beta=beta, **loss_kwargs
         )
-
+        if isinstance(preference_loss_outputs, tuple):
+            preference_loss, *aux_outputs = preference_loss_outputs
+        else:
+            preference_loss, aux_outputs = preference_loss_outputs, []
 
-        loss = alpha * chosen_nll_loss -
-
+        loss = alpha * chosen_nll_loss - preference_loss
+        return_vars = (
+            chosen_logps,
+            rejected_logps,
+            chosen_logits_mean,
+            rejected_logits_mean,
+            chosen_nll_loss,
+        )
+        return loss, (*return_vars, *aux_outputs)
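
_compute_loss now tolerates both return styles from preference_loss_fn: a bare scalar (CPO, DPO, SimPO) or a (loss, *metrics) tuple (ORPO), forwarding any extra metrics up through forward. A hypothetical restatement of just that dispatch:

def unpack_preference_outputs(outputs):
    # Mirrors the isinstance(...) branch above; names are illustrative.
    if isinstance(outputs, tuple):
        preference_loss, *aux_outputs = outputs
    else:
        preference_loss, aux_outputs = outputs, []
    return preference_loss, aux_outputs

print(unpack_preference_outputs(0.5))              # (0.5, [])
print(unpack_preference_outputs((0.5, 1.0, 2.0)))  # (0.5, [1.0, 2.0])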

src/liger_kernel/chunked_loss/orpo_loss.py
@@ -9,7 +9,7 @@ from liger_kernel.chunked_loss.fused_linear_preference import (
 class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
 
     @staticmethod
-    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
+    def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
         """
         Compute odds-ratio loss.
         Args:
@@ -22,7 +22,15 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
             - torch.log1p(-torch.exp(rejected_logps))
         )
         ratio = F.logsigmoid(log_odds)
-
+        loss = beta * ratio.sum() / (full_target.shape[0] // 2)
+
+        chosen_rewards = beta * chosen_logps
+        rejected_rewards = beta * rejected_logps
+
+        log_odds_ratio = torch.sum(ratio) / (full_target.shape[0] // 2)
+        log_odds_chosen = torch.sum(log_odds) / (full_target.shape[0] // 2)
+
+        return loss, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen
 
     @staticmethod
     def forward(
@@ -56,7 +64,7 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
         )
 
     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, *grad_output):
         # Get gradients for _input, weight, bias, and target from the base class
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
         # Return these gradients, followed by None for the remaining inputs
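
ORPO is the one loss here that uses the tuple return: besides the scalar loss it emits the reward and log-odds statistics that the new LigerORPOTrainer (added below) logs as metrics. As a restatement of the code rather than new math, writing $p_c = \exp(\text{chosen\_logps})$ and $p_r = \exp(\text{rejected\_logps})$, with $B$ the concatenated batch size:

$$\text{log\_odds} = \log\frac{p_c}{1-p_c} - \log\frac{p_r}{1-p_r}, \qquad \text{loss} = \frac{\beta}{B/2}\sum \log\sigma(\text{log\_odds})$$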

src/liger_kernel/chunked_loss/simpo_loss.py
@@ -9,7 +9,9 @@ from liger_kernel.chunked_loss.fused_linear_preference import (
 class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
 
     @staticmethod
-    def preference_loss_fn(
+    def preference_loss_fn(
+        chosen_logps, rejected_logps, full_target, beta=0.1, gamma=0.5
+    ):
         """
         Compute odds-ratio loss.
         Args:
@@ -19,7 +21,7 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
             gamma (float): The simpo gamma, margin term.
         """
         logits = beta * (chosen_logps - rejected_logps) - gamma
-        loss = F.logsigmoid(logits).
+        loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
         return loss
 
     @staticmethod
@@ -58,7 +60,7 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
         )
 
     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, *grad_output):
         # Get gradients for _input, weight, bias, and target from the base class
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
         # Return these gradients, followed by None for the remaining inputs
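
SimPO's term differs from CPO's only by the margin gamma (the log-probs are already length-normalized upstream in chunk_forward's average_log_prob). Matching the code above, the quantity later subtracted from the NLL term in _compute_loss is

$$\frac{1}{B/2}\sum \log\sigma\big(\beta(\text{chosen\_logps} - \text{rejected\_logps}) - \gamma\big)$$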

src/liger_kernel/transformers/__init__.py
@@ -22,6 +22,7 @@ from liger_kernel.transformers.monkey_patch import (  # noqa: F401
     apply_liger_kernel_to_qwen2,
     apply_liger_kernel_to_qwen2_vl,
 )
+from liger_kernel.transformers.orpo_trainer import LigerORPOTrainer  # noqa: F401
 from liger_kernel.transformers.rms_norm import LigerRMSNorm  # noqa: F401
 from liger_kernel.transformers.rope import liger_rotary_pos_emb  # noqa: F401
 from liger_kernel.transformers.swiglu import (  # noqa: F401

src/liger_kernel/transformers/orpo_trainer.py
@@ -0,0 +1,171 @@
+from typing import Any, Callable, Dict, List, Literal, Tuple, Union
+
+import torch
+import torch.nn as nn
+from torch.distributed.fsdp import FullyShardedDataParallel
+from trl.trainer import ORPOTrainer
+
+from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss
+
+
+class _FSDPForwardRedirection:
+    """
+    Modified based on
+    https://github.com/Lightning-AI/pytorch-lightning/blob/d3f9c83d6efa4f1def36aa6c199600946cdb9117/src/lightning/pytorch/strategies/strategy.py#L601-L648
+    Redirect a method call through FullyShardedDataParallel.forward so that the FSDP module's root pre-forward and
+    post-forward can be properly executed around the method call.
+    This is needed in cases where we call a submodule of a FSDP module. For instance, when we want to call only
+    the `LlamaModel` part out of a FSDP-wrapped `LlamaForCausalLM` to get the hidden states without involving
+    GPU-memory-heavy `lm_head` and cross entropy computation, doing this directly (i.e. `model.model.forward()`)
+    will not work because the first `nn.Emebedding` layer is not independently wrapped as a FSDP module (because of
+    the transformer-based wrapping policy), and not calling it through FSDP root module forward will not all-gather
+    its parameter, thus resulting in "RuntimeError: 'weight' must be 2-D" error. Similarly, if we want to call just
+    the `lm_head` part of a model, we need this trick too to properly get its params all-gathered.
+    """
+
+    def __call__(
+        self,
+        wrapper_module: FullyShardedDataParallel,
+        method: Callable,
+        *args: Any,
+        **kwargs: Any,
+    ):
+        """Reroutes a method call through the `wrapper_module`'s `forward` method.
+        Args:
+            wrapper_module: The module that has `original_module` wrapped.
+            original_module: The module that was wrapped inside `wrapper_module`.
+            method_name: The name of the method that should be called on the `original_module` after inputs get
+                redirected through the `wrapper_module`'s `forward` method.
+            *args: The positional arguments to the method `method_name`. They will get passed to a patched
+                `forward` method instead.
+            **kwargs: The keyword arguments to the method `method_name`. They will get passed to a patched
+                `forward` method instead.
+        """
+        assert isinstance(wrapper_module, FullyShardedDataParallel)
+        original_module = wrapper_module._fsdp_wrapped_module
+        original_forward = original_module.forward
+
+        def wrapped_forward(*_args: Any, **_kwargs: Any) -> Any:
+            # Unpatch ourselves immediately before calling the method `method_name`
+            # because itself may want to call the real `forward`
+            original_module.forward = original_forward  # type: ignore[method-assign]
+            # Call the actual method e.g. `.training_step(...)`
+            out = method(*_args, **_kwargs)
+            return out
+
+        # Patch the original_module's forward so we can redirect the arguments back to the real method
+        original_module.forward = wrapped_forward  # type: ignore[method-assign]
+        wrapper_output = wrapper_module(*args, **kwargs)
+        return wrapper_output
+
+
+class LigerORPOTrainer(ORPOTrainer):
+    def concatenated_forward(
+        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
+    ) -> Tuple[
+        torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor
+    ]:
+        """
+        Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
+        We do this to avoid doing two forward passes, because it's faster for FSDP.
+        """
+        concatenated_batch = self.concatenated_inputs(
+            batch,
+            is_encoder_decoder=self.is_encoder_decoder,
+            label_pad_token_id=self.label_pad_token_id,
+            padding_value=self.padding_value,
+            device=self.accelerator.device,
+        )
+        # if self.accelerator.is_main_process:
+        #     import pdb; pdb.set_trace()
+        # torch.distributed.barrier()
+        model_kwargs = (
+            {
+                "decoder_input_ids": self._shift_right(
+                    concatenated_batch["concatenated_labels"]
+                ),
+            }
+            if self.is_encoder_decoder
+            else {}
+        )
+
+        if self.aux_loss_enabled:
+            model_kwargs["output_router_logits"] = True
+
+        if isinstance(model, FullyShardedDataParallel):
+            outputs = _FSDPForwardRedirection()(
+                model,
+                model._fsdp_wrapped_module.model,
+                concatenated_batch["concatenated_input_ids"],
+                attention_mask=concatenated_batch["concatenated_attention_mask"],
+                use_cache=False,
+                **model_kwargs,
+            )
+        else:
+            if isinstance(model, torch.nn.DataParallel):
+                model = model.module
+            outputs = model.model(
+                concatenated_batch["concatenated_input_ids"],
+                attention_mask=concatenated_batch["concatenated_attention_mask"],
+                use_cache=False,
+                **model_kwargs,
+            )
+
+        orpo_loss_fn = LigerFusedLinearORPOLoss(
+            ignore_index=self.label_pad_token_id, beta=self.beta
+        )
+
+        def orpo_partial(lm_head, last_hidden_state, concatenated_labels):
+            return orpo_loss_fn(
+                lm_head.weight, last_hidden_state, concatenated_labels, lm_head.bias
+            )
+
+        orpo_loss, aux_outputs = _FSDPForwardRedirection()(
+            model,
+            orpo_partial,
+            model.lm_head,
+            outputs.last_hidden_state,
+            concatenated_batch["concatenated_labels"],
+        )
+        return orpo_loss, aux_outputs
+
+    def get_batch_loss_metrics(
+        self,
+        model,
+        batch: Dict[str, Union[List, torch.LongTensor]],
+        train_eval: Literal["train", "eval"] = "train",
+    ):
+        """Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
+        metrics = {}
+        loss, aux_outputs = self.concatenated_forward(model, batch)
+        (
+            policy_chosen_logps,
+            policy_rejected_logps,
+            policy_chosen_logits,
+            policy_rejected_logits,
+            policy_nll_loss,
+        ) = aux_outputs[:5]
+
+        # return loss, metrics
+        chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = aux_outputs[
+            5:
+        ]
+
+        reward_accuracies = (chosen_rewards > rejected_rewards).float()
+
+        prefix = "eval_" if train_eval == "eval" else ""
+        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean()
+        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean()
+        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean()
+        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean()
+        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean()
+        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean()
+        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean()
+        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean()
+        metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean()
+        metrics[f"{prefix}log_odds_ratio"] = log_odds_ratio
+        metrics[f"{prefix}log_odds_chosen"] = log_odds_chosen
+        for k, v in metrics.items():
+            metrics[k] = v.item()
+
+        return loss, metrics
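
A hedged usage sketch for the new trainer: it drops in where trl's ORPOTrainer would go, with the fused linear-plus-ORPO loss replacing full-logit materialization. The model and dataset names are placeholders, and the ORPOConfig fields follow trl's documented API rather than anything shown in this diff:

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import ORPOConfig

from liger_kernel.transformers import LigerORPOTrainer

# Placeholder model/dataset; requires trl>=0.11.0, the new dev dependency above.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct", torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

trainer = LigerORPOTrainer(
    model=model,
    args=ORPOConfig(output_dir="orpo-liger", beta=0.1, max_length=1024),
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
trainer.train()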

src/liger_kernel_nightly.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: liger_kernel_nightly
-Version: 0.4.2.dev20241204180758
+Version: 0.4.2.dev20241207011709
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
 Copyright 2024 LinkedIn Corporation
@@ -36,6 +36,7 @@ Provides-Extra: transformers
 Requires-Dist: transformers~=4.0; extra == "transformers"
 Provides-Extra: dev
 Requires-Dist: transformers>=4.44.2; extra == "dev"
+Requires-Dist: trl>=0.11.0; extra == "dev"
 Requires-Dist: matplotlib>=3.7.2; extra == "dev"
 Requires-Dist: flake8>=4.0.1.1; extra == "dev"
 Requires-Dist: black>=24.4.2; extra == "dev"
@@ -55,7 +56,7 @@ Requires-Dist: seaborn; extra == "dev"
         <th style="padding: 10px;" colspan="2">Stable</th>
         <th style="padding: 10px;" colspan="2">Nightly</th>
         <th style="padding: 10px;">Discord</th>
-        <th style="padding: 10px;">
+        <th style="padding: 10px;">Build</th>
     </tr>
     <tr>
         <td style="padding: 10px;">
@@ -84,8 +85,8 @@ Requires-Dist: seaborn; extra == "dev"
         </a>
     </td>
     <td style="padding: 10px;">
-        <a href="https://
-        <img src="https://
+        <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/ci.yml">
+        <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/ci.yml/badge.svg?event=schedule" alt="Build">
         </a>
     </td>
 </tr>

src/liger_kernel_nightly.egg-info/SOURCES.txt
@@ -40,6 +40,7 @@ src/liger_kernel/transformers/jsd.py
 src/liger_kernel/transformers/kl_div.py
 src/liger_kernel/transformers/layer_norm.py
 src/liger_kernel/transformers/monkey_patch.py
+src/liger_kernel/transformers/orpo_trainer.py
 src/liger_kernel/transformers/qwen2vl_mrope.py
 src/liger_kernel/transformers/rms_norm.py
 src/liger_kernel/transformers/rope.py