liger-kernel-nightly 0.5.2.dev20241212071131__py3-none-any.whl → 0.5.2.dev20241217060137__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -38,7 +38,7 @@ def lce_forward_deprecated(
38
38
  cache_position: Optional[torch.LongTensor] = None,
39
39
  ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
40
40
  r"""
41
- Copy paste Mixtral's forward from transfomers v4.44.2 but replace torch cross entropy with liger fused linear cross entropy
41
+ Copy paste Mixtral's forward from transformers v4.44.2 but replace torch cross entropy with liger fused linear cross entropy
42
42
 
43
43
 
44
44
  Args:
@@ -17,7 +17,7 @@ class _FSDPForwardRedirection:
17
17
  This is needed in cases where we call a submodule of a FSDP module. For instance, when we want to call only
18
18
  the `LlamaModel` part out of a FSDP-wrapped `LlamaForCausalLM` to get the hidden states without involving
19
19
  GPU-memory-heavy `lm_head` and cross entropy computation, doing this directly (i.e. `model.model.forward()`)
20
- will not work because the first `nn.Emebedding` layer is not independently wrapped as a FSDP module (because of
20
+ will not work because the first `nn.Embedding` layer is not independently wrapped as a FSDP module (because of
21
21
  the transformer-based wrapping policy), and not calling it through FSDP root module forward will not all-gather
22
22
  its parameter, thus resulting in "RuntimeError: 'weight' must be 2-D" error. Similarly, if we want to call just
23
23
  the `lm_head` part of a model, we need this trick too to properly get its params all-gathered.
@@ -125,6 +125,10 @@ class LigerORPOTrainer(ORPOTrainer):
125
125
  outputs.last_hidden_state,
126
126
  concatenated_batch["concatenated_labels"],
127
127
  )
128
+ # if aux_loss_enabled, add the aux_loss to the orpo_loss
129
+ if self.aux_loss_enabled:
130
+ orpo_loss += self.aux_loss_coef * outputs.aux_loss
131
+
128
132
  return orpo_loss, aux_outputs
129
133
 
130
134
  def get_batch_loss_metrics(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.2.dev20241212071131
3
+ Version: 0.5.2.dev20241217060137
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -49,18 +49,18 @@ liger_kernel/transformers/model/gemma.py,sha256=R4huxuR48gkLrdT8KqV7As2v9dZtEmcG
49
49
  liger_kernel/transformers/model/gemma2.py,sha256=zxQsxCRqkoxCES3GJPVI7soUuF3J5HZDlvJgaBos1zM,10836
50
50
  liger_kernel/transformers/model/llama.py,sha256=RinsgC_eR-YNvZd2SHPQxZ4eyR3uViaTFCM3SvI5nks,10426
51
51
  liger_kernel/transformers/model/mistral.py,sha256=XpL1rlWg_llvW3z_Hf_d8WQs7uQaH4ds7EZ2SxjQHsU,5144
52
- liger_kernel/transformers/model/mixtral.py,sha256=nyDS1dBpsOXYC2DuW59Hgu7ZrGftrHuWPfNqjcNPIxs,11503
52
+ liger_kernel/transformers/model/mixtral.py,sha256=JlNS6DA6SJqeHDk7j2LZymPQ3wngrTIo3wUGFBqHuJs,11504
53
53
  liger_kernel/transformers/model/mllama.py,sha256=mesNCgj0Ea1O-fqRD4LVxDJ1CR2abY_zAzK_bfVzkiU,11222
54
54
  liger_kernel/transformers/model/phi3.py,sha256=xUZPlaPKwknLjHc3uUW3EPodm1h0vD3G7Qnhh51v-Io,10332
55
55
  liger_kernel/transformers/model/qwen2.py,sha256=EyhSSzQOskGjSnCsKMZpd1s5IAIlHd5PBO3q0MoCs00,9619
56
56
  liger_kernel/transformers/model/qwen2_vl.py,sha256=bIQe2bWiY--G84FhCD29Gdi64_qHP6vbcGsK6vKysQE,8547
57
57
  liger_kernel/transformers/trainer/__init__.py,sha256=c4OQVJmhNOloj0JYSEc0j_cQuBbzGWILfaowUR1hmRw,210
58
- liger_kernel/transformers/trainer/orpo_trainer.py,sha256=jko6oq_XQdBSmXubp05E-_YXOyhtB5Bj75dg5YNwOsE,7517
58
+ liger_kernel/transformers/trainer/orpo_trainer.py,sha256=O2k2vdHl-O1S-U61aEmyUFu3QrEuNAipQa2oUBb3HAA,7679
59
59
  liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
60
60
  liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
61
- liger_kernel_nightly-0.5.2.dev20241212071131.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
62
- liger_kernel_nightly-0.5.2.dev20241212071131.dist-info/METADATA,sha256=OXd0vITMpiCdQW9JuXmKsjTwg612utOa0p9biJBsEgo,21055
63
- liger_kernel_nightly-0.5.2.dev20241212071131.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
64
- liger_kernel_nightly-0.5.2.dev20241212071131.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
65
- liger_kernel_nightly-0.5.2.dev20241212071131.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
66
- liger_kernel_nightly-0.5.2.dev20241212071131.dist-info/RECORD,,
61
+ liger_kernel_nightly-0.5.2.dev20241217060137.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
62
+ liger_kernel_nightly-0.5.2.dev20241217060137.dist-info/METADATA,sha256=s4F2CNLYmapm4S_h0kRqQVPItXe5hHkR81gBQL6P1L8,21055
63
+ liger_kernel_nightly-0.5.2.dev20241217060137.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
64
+ liger_kernel_nightly-0.5.2.dev20241217060137.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
65
+ liger_kernel_nightly-0.5.2.dev20241217060137.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
66
+ liger_kernel_nightly-0.5.2.dev20241217060137.dist-info/RECORD,,