liger-kernel-nightly 0.3.1.dev20241101201851__tar.gz → 0.3.1.dev20241102170757__tar.gz

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of liger-kernel-nightly might be problematic.

Files changed (58)
  1. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/PKG-INFO +10 -2
  2. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/README.md +9 -1
  3. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/pyproject.toml +1 -1
  4. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/ops/cross_entropy.py +3 -3
  5. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/ops/fused_linear_cross_entropy.py +10 -5
  6. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/ops/fused_linear_jsd.py +8 -3
  7. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/ops/kl_div.py +2 -2
  8. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/ops/utils.py +5 -1
  9. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/transformers/model/llama.py +21 -19
  10. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel_nightly.egg-info/PKG-INFO +10 -2
  11. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/LICENSE +0 -0
  12. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/LICENSE-Apache-2.0 +0 -0
  13. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/LICENSE-MIT-AutoAWQ +0 -0
  14. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  15. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/LICENSE-MIT-llmc +0 -0
  16. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/LICENSE-MIT-triton +0 -0
  17. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/NOTICE +0 -0
  18. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/setup.cfg +0 -0
  19. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/env_report.py +0 -0
  20. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/ops/__init__.py +0 -0
  21. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  22. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  23. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/ops/geglu.py +0 -0
  24. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/ops/jsd.py +0 -0
  25. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/ops/layer_norm.py +0 -0
  26. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/ops/rms_norm.py +0 -0
  27. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/ops/rope.py +0 -0
  28. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/ops/swiglu.py +0 -0
  29. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/transformers/__init__.py +0 -0
  30. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/transformers/auto_model.py +0 -0
  31. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  32. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  33. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/transformers/functional.py +0 -0
  34. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  35. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  36. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/transformers/geglu.py +0 -0
  37. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/transformers/jsd.py +0 -0
  38. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/transformers/kl_div.py +0 -0
  39. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/transformers/layer_norm.py +0 -0
  40. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/transformers/model/__init__.py +0 -0
  41. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/transformers/model/gemma.py +0 -0
  42. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/transformers/model/mistral.py +0 -0
  43. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  44. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/transformers/model/mllama.py +0 -0
  45. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/transformers/model/phi3.py +0 -0
  46. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  47. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  48. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  49. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/transformers/rms_norm.py +0 -0
  50. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/transformers/rope.py +0 -0
  51. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/transformers/swiglu.py +0 -0
  52. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  53. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/triton/__init__.py +0 -0
  54. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel/triton/monkey_patch.py +0 -0
  55. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
  56. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  57. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  58. {liger_kernel_nightly-0.3.1.dev20241101201851 → liger_kernel_nightly-0.3.1.dev20241102170757}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
--- a/PKG-INFO
+++ b/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: liger_kernel_nightly
-Version: 0.3.1.dev20241101201851
+Version: 0.3.1.dev20241102170757
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
 Copyright 2024 LinkedIn Corporation
@@ -163,11 +163,18 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and
 
 ## Installation
 
-### Dependencies
+### Dependencies
+
+#### CUDA
 
 - `torch >= 2.1.2`
 - `triton >= 2.3.0`
 
+#### ROCm
+
+- `torch >= 2.5.0` Install according to the instruction in Pytorch official webpage.
+- `triton >= 3.0.0` Install from pypi. (e.g. `pip install triton==3.0.0`)
+
 ### Optional Dependencies
 
 - `transformers >= 4.x`: Required if you plan to use the transformers models patching APIs. The specific model you are working will dictate the minimum version of transformers.
@@ -197,6 +204,7 @@ pip install -e .
 pip install -e .[transformers]
 ```
 
+
 ## Getting Started
 
 There are a couple of ways to apply Liger kernels, depending on the level of customization required.
--- a/README.md
+++ b/README.md
@@ -111,11 +111,18 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and
 
 ## Installation
 
-### Dependencies
+### Dependencies
+
+#### CUDA
 
 - `torch >= 2.1.2`
 - `triton >= 2.3.0`
 
+#### ROCm
+
+- `torch >= 2.5.0` Install according to the instruction in Pytorch official webpage.
+- `triton >= 3.0.0` Install from pypi. (e.g. `pip install triton==3.0.0`)
+
 ### Optional Dependencies
 
 - `transformers >= 4.x`: Required if you plan to use the transformers models patching APIs. The specific model you are working will dictate the minimum version of transformers.
@@ -145,6 +152,7 @@ pip install -e .
 pip install -e .[transformers]
 ```
 
+
 ## Getting Started
 
 There are a couple of ways to apply Liger kernels, depending on the level of customization required.
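The new ROCm requirements above can be verified at runtime. A minimal pre-flight sketch (not part of the package), reusing the packaging.version helper the repo already imports in ops/utils.py; the "+rocmX.Y" suffix handling assumes the usual ROCm wheel version format:

import torch
import triton
from packaging.version import Version

if torch.version.hip is not None:  # ROCm/HIP build of PyTorch
    torch_version = Version(torch.__version__.split("+")[0])  # drop e.g. "+rocm6.1"
    assert torch_version >= Version("2.5.0"), "ROCm support needs torch >= 2.5.0"
    assert Version(triton.__version__) >= Version("3.0.0"), "ROCm support needs triton >= 3.0.0"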
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "liger_kernel_nightly"
-version = "0.3.1.dev20241101201851"
+version = "0.3.1.dev20241102170757"
 description = "Efficient Triton kernels for LLM Training"
 urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
 readme = { file = "README.md", content-type = "text/markdown" }
--- a/src/liger_kernel/ops/cross_entropy.py
+++ b/src/liger_kernel/ops/cross_entropy.py
@@ -2,7 +2,7 @@ import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import element_mul_kernel
+from liger_kernel.ops.utils import element_mul_kernel, is_hip
 
 
 @triton.jit
@@ -194,7 +194,7 @@ def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reduction
         BLOCK_SIZE=BLOCK_SIZE,
         # TODO: 32 seems to give the best performance
         # Performance is quite sensitive to num_warps
-        num_warps=32,
+        num_warps=32 if not is_hip() else 16,
     )
 
     loss = torch.sum(loss_1d)
@@ -219,7 +219,7 @@ def cross_entropy_backward(_input, grad_output):
         grad_output,
         V,
         BLOCK_SIZE=BLOCK_SIZE,
-        num_warps=32,
+        num_warps=32 if not is_hip() else 16,
     )
 
     return _input
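The change repeated throughout this release is the num_warps dispatch. A plausible reading, not stated in the diff itself: Triton treats a "warp" on AMD hardware as a 64-lane wavefront, so 32 warps would request 2048 threads per block, above the usual 1024-thread cap, while 16 wavefronts keep the launch at exactly 1024 threads. A minimal sketch of the pattern:

import torch

def is_hip() -> bool:
    # Same helper this release adds to src/liger_kernel/ops/utils.py:
    # torch.version.hip is only set on ROCm builds of PyTorch.
    return torch.version.hip is not None

# CUDA: 32 warps * 32 lanes = 1024 threads per block.
# HIP:  16 "warps" (64-lane wavefronts) * 64 lanes = 1024 threads per block.
num_warps = 32 if not is_hip() else 16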
--- a/src/liger_kernel/ops/fused_linear_cross_entropy.py
+++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -2,7 +2,12 @@ import torch
 import triton
 
 from liger_kernel.ops.cross_entropy import liger_cross_entropy_kernel
-from liger_kernel.ops.utils import amp_custom_bwd, amp_custom_fwd, element_mul_kernel
+from liger_kernel.ops.utils import (
+    amp_custom_bwd,
+    amp_custom_fwd,
+    element_mul_kernel,
+    is_hip,
+)
 
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
@@ -88,7 +93,7 @@ def fused_linear_cross_entropy_forward(
            label_smoothing=label_smoothing,
            reduction=reduction,
            BLOCK_SIZE=BLOCK_SIZE,
-            num_warps=32,
+            num_warps=32 if not is_hip() else 16,
        )
 
        # gradient of logits_chunk is computed in-place by the above triton kernel.
@@ -153,7 +158,7 @@ def fused_linear_cross_entropy_backward(
            grad_output,
            H,
            BLOCK_SIZE=BLOCK_SIZE,
-            num_warps=32,
+            num_warps=32 if not is_hip() else 16,
        )
 
        # handle grad_weight
@@ -167,7 +172,7 @@ def fused_linear_cross_entropy_backward(
            grad_output,
            H,
            BLOCK_SIZE=BLOCK_SIZE,
-            num_warps=32,
+            num_warps=32 if not is_hip() else 16,
        )
 
        if grad_bias is not None:
@@ -180,7 +185,7 @@ def fused_linear_cross_entropy_backward(
            grad_output,
            1,
            BLOCK_SIZE=BLOCK_SIZE,
-            num_warps=32,
+            num_warps=32 if not is_hip() else 16,
        )
    return grad_input, grad_weight, grad_bias
 
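The comment carried through the import hunk above explains the sizing constraint: Triton's hard tensor-size limit is 1048576 elements, but capping fused blocks at 65536 reduces register spilling. A sketch of how a block size is typically derived under that cap (MAX_FUSED_SIZE and pick_block_size are illustrative names, not the file's actual code):

import triton

MAX_FUSED_SIZE = 65536  # stays well under Triton's hard limit of 1048576 elements

def pick_block_size(vocab_size: int) -> int:
    # Round the vocabulary size up to a power of two, then cap it to
    # limit register spilling, per the comment in the file.
    return min(MAX_FUSED_SIZE, triton.next_power_of_2(vocab_size))

print(pick_block_size(32000))   # 32768
print(pick_block_size(128256))  # 65536 (capped)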
--- a/src/liger_kernel/ops/fused_linear_jsd.py
+++ b/src/liger_kernel/ops/fused_linear_jsd.py
@@ -4,7 +4,12 @@ import torch
 import triton
 
 from liger_kernel.ops.jsd import _jsd_kernel
-from liger_kernel.ops.utils import amp_custom_bwd, amp_custom_fwd, element_mul_kernel
+from liger_kernel.ops.utils import (
+    amp_custom_bwd,
+    amp_custom_fwd,
+    element_mul_kernel,
+    is_hip,
+)
 
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
@@ -147,7 +152,7 @@ def fused_linear_jsd_backward(grad_output, grad_input, grad_weight):
            grad_output,
            H,
            BLOCK_SIZE=BLOCK_SIZE,
-            num_warps=32,
+            num_warps=32 if not is_hip() else 16,
        )
 
        # handle grad_weight
@@ -161,7 +166,7 @@ def fused_linear_jsd_backward(grad_output, grad_input, grad_weight):
            grad_output,
            H,
            BLOCK_SIZE=BLOCK_SIZE,
-            num_warps=32,
+            num_warps=32 if not is_hip() else 16,
        )
 
    return grad_input, grad_weight
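Both fused files import amp_custom_fwd/amp_custom_bwd from ops/utils alongside the new is_hip, but their definitions sit outside this diff. One common way such wrappers are written, assuming they bridge the torch >= 2.4 move of the AMP decorators from torch.cuda.amp to torch.amp (a hedged sketch, not the file's actual code):

import functools

# torch >= 2.4 exposes the AMP decorators as torch.amp.custom_fwd/custom_bwd
# with an explicit device_type; older releases keep them in torch.cuda.amp.
try:
    from torch.amp import custom_bwd, custom_fwd  # torch >= 2.4
    amp_custom_fwd = functools.partial(custom_fwd, device_type="cuda")
    amp_custom_bwd = functools.partial(custom_bwd, device_type="cuda")
except ImportError:
    from torch.cuda.amp import custom_bwd as amp_custom_bwd  # older torch
    from torch.cuda.amp import custom_fwd as amp_custom_fwd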
--- a/src/liger_kernel/ops/kl_div.py
+++ b/src/liger_kernel/ops/kl_div.py
@@ -4,13 +4,13 @@ import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.ops.utils import ensure_contiguous, is_hip
 
 
 def get_num_warps(BLOCK_SIZE):
     num_warps = 4
     if BLOCK_SIZE >= 32768:
-        num_warps = 32
+        num_warps = 32 if not is_hip() else 16
     elif BLOCK_SIZE >= 8192:
         num_warps = 16
     elif BLOCK_SIZE >= 2048:
--- a/src/liger_kernel/ops/utils.py
+++ b/src/liger_kernel/ops/utils.py
@@ -21,6 +21,10 @@ import triton.language as tl
 from packaging.version import Version
 
 
+def is_hip() -> bool:
+    return torch.version.hip is not None
+
+
 def ensure_contiguous(fn):
     @functools.wraps(fn)
     def wrapper(ctx, *args, **kwargs):
@@ -47,7 +51,7 @@ def calculate_settings(n):
 
     num_warps = 4
     if BLOCK_SIZE >= 32768:
-        num_warps = 32
+        num_warps = 32 if not is_hip() else 16
     elif BLOCK_SIZE >= 8192:
         num_warps = 16
     elif BLOCK_SIZE >= 2048:
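The new is_hip helper keys off torch.version.hip, which PyTorch populates only in ROCm builds. A quick way to see what a given build reports:

import torch

# On a ROCm wheel torch.version.hip is a version string (e.g. "6.1.x")
# and torch.version.cuda is None; on a CUDA wheel the reverse holds.
print("hip build: ", torch.version.hip)
print("cuda build:", torch.version.cuda)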
--- a/src/liger_kernel/transformers/model/llama.py
+++ b/src/liger_kernel/transformers/model/llama.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -17,6 +17,9 @@ from liger_kernel.transformers.fused_linear_cross_entropy import (
     LigerFusedLinearCrossEntropyLoss,
 )
 
+if TYPE_CHECKING:
+    from transformers.cache_utils import Cache
+
 
 @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
 @replace_return_docstrings(
@@ -27,7 +30,7 @@ def lce_forward_deprecated(
     input_ids: torch.LongTensor = None,
     attention_mask: Optional[torch.Tensor] = None,
     position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]] = None,
     inputs_embeds: Optional[torch.FloatTensor] = None,
     labels: Optional[torch.LongTensor] = None,
     use_cache: Optional[bool] = None,
@@ -153,19 +156,19 @@ def lce_forward_deprecated(
 )
 def lce_forward(
     self,
-    input_ids=None,
-    attention_mask=None,
-    position_ids=None,
-    past_key_values=None,
-    inputs_embeds=None,
-    labels=None,
-    use_cache=None,
-    output_attentions=None,
-    output_hidden_states=None,
-    return_dict=None,
-    cache_position=None,
-    num_logits_to_keep=0,
-    **kwargs,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    num_logits_to_keep: int = 0,
+    **loss_kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -224,7 +227,6 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
-        **kwargs,
     )
 
     hidden_states = outputs[0]
@@ -245,12 +247,12 @@ def lce_forward(
         shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
         shift_labels = shift_labels.view(-1)
 
-        reduction = "sum" if "num_items_in_batch" in kwargs else "mean"
+        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
         lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
 
         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
         if reduction == "sum":
-            loss /= kwargs["num_items_in_batch"]
+            loss /= loss_kwargs["num_items_in_batch"]
 
     else:  # if in inference mode materialize logits
         logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
@@ -259,7 +261,7 @@ def lce_forward(
             logits=logits,
             labels=labels,
             vocab_size=self.config.vocab_size,
-            **kwargs,
+            **loss_kwargs,
         )
 
     if not return_dict:
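The **kwargs to **loss_kwargs rename tracks upstream transformers: when the Trainer supplies num_items_in_batch (its gradient-accumulation loss fix), the patched forward sums token losses and divides by the global token count rather than taking a per-micro-batch mean. A standalone sketch of that switch, using plain F.cross_entropy as a stand-in for the fused Liger loss (normalized_lm_loss is an illustrative name):

import torch
import torch.nn.functional as F

def normalized_lm_loss(logits: torch.Tensor, labels: torch.Tensor, **loss_kwargs):
    # Mirrors the reduction switch in lce_forward above.
    reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
    loss = F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        labels.view(-1),
        ignore_index=-100,
        reduction=reduction,
    )
    if reduction == "sum":
        # Normalize by tokens across all accumulated micro-batches,
        # not just this one.
        loss = loss / loss_kwargs["num_items_in_batch"]
    return loss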
--- a/src/liger_kernel_nightly.egg-info/PKG-INFO
+++ b/src/liger_kernel_nightly.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: liger_kernel_nightly
-Version: 0.3.1.dev20241101201851
+Version: 0.3.1.dev20241102170757
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
 Copyright 2024 LinkedIn Corporation
@@ -163,11 +163,18 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and
 
 ## Installation
 
-### Dependencies
+### Dependencies
+
+#### CUDA
 
 - `torch >= 2.1.2`
 - `triton >= 2.3.0`
 
+#### ROCm
+
+- `torch >= 2.5.0` Install according to the instruction in Pytorch official webpage.
+- `triton >= 3.0.0` Install from pypi. (e.g. `pip install triton==3.0.0`)
+
 ### Optional Dependencies
 
 - `transformers >= 4.x`: Required if you plan to use the transformers models patching APIs. The specific model you are working will dictate the minimum version of transformers.
@@ -197,6 +204,7 @@ pip install -e .
 pip install -e .[transformers]
 ```
 
+
 ## Getting Started
 
 There are a couple of ways to apply Liger kernels, depending on the level of customization required.