liger-kernel-nightly 0.5.2.dev20241211055800__tar.gz → 0.5.2.dev20241211213024__tar.gz

This diff compares the contents of two publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Files changed (70)
  1. {liger_kernel_nightly-0.5.2.dev20241211055800/src/liger_kernel_nightly.egg-info → liger_kernel_nightly-0.5.2.dev20241211213024}/PKG-INFO +1 -1
  2. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/pyproject.toml +1 -1
  3. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/chunked_loss/fused_linear_preference.py +6 -2
  4. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024/src/liger_kernel_nightly.egg-info}/PKG-INFO +1 -1
  5. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/LICENSE +0 -0
  6. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/NOTICE +0 -0
  7. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/README.md +0 -0
  8. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/setup.cfg +0 -0
  9. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/__init__.py +0 -0
  10. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  11. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  12. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  13. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/chunked_loss/functional.py +0 -0
  14. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  15. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  16. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  17. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/env_report.py +0 -0
  18. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/ops/__init__.py +0 -0
  19. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/ops/cross_entropy.py +0 -0
  20. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  21. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  22. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  23. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  24. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/ops/geglu.py +0 -0
  25. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/ops/group_norm.py +0 -0
  26. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/ops/jsd.py +0 -0
  27. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/ops/kl_div.py +0 -0
  28. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/ops/layer_norm.py +0 -0
  29. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  30. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/ops/rms_norm.py +0 -0
  31. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/ops/rope.py +0 -0
  32. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/ops/swiglu.py +0 -0
  33. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/ops/utils.py +0 -0
  34. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/transformers/__init__.py +0 -0
  35. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/transformers/auto_model.py +0 -0
  36. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  37. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  38. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/transformers/functional.py +0 -0
  39. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  40. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  41. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/transformers/geglu.py +0 -0
  42. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/transformers/group_norm.py +0 -0
  43. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/transformers/jsd.py +0 -0
  44. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/transformers/kl_div.py +0 -0
  45. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/transformers/layer_norm.py +0 -0
  46. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/transformers/model/__init__.py +0 -0
  47. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/transformers/model/gemma.py +0 -0
  48. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  49. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/transformers/model/llama.py +0 -0
  50. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/transformers/model/mistral.py +0 -0
  51. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  52. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/transformers/model/mllama.py +0 -0
  53. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/transformers/model/phi3.py +0 -0
  54. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  55. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  56. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  57. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  58. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/transformers/rms_norm.py +0 -0
  59. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/transformers/rope.py +0 -0
  60. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/transformers/swiglu.py +0 -0
  61. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  62. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  63. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  64. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/triton/__init__.py +0 -0
  65. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/triton/monkey_patch.py +0 -0
  66. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel/utils.py +0 -0
  67. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
  68. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  69. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  70. {liger_kernel_nightly-0.5.2.dev20241211055800 → liger_kernel_nightly-0.5.2.dev20241211213024}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.2.dev20241211055800
3
+ Version: 0.5.2.dev20241211213024
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.5.2.dev20241211055800"
7
+ version = "0.5.2.dev20241211213024"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -29,7 +29,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
29
29
  compute_nll_loss=True,
30
30
  compiled=True,
31
31
  use_ref_model=False,
32
- # TODO: ref input
32
+ ref_input=None,
33
33
  ref_weight=None,
34
34
  ref_bias=None,
35
35
  **loss_kwargs,
@@ -59,6 +59,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
59
59
  compute_nll_loss (bool): Whether to compute NLL loss.
60
60
  compiled (bool): Whether to use torch compile for chunk accumulation.
61
61
  use_ref_model (bool): Whether to use a reference model for the alignment loss.
62
+ ref_input (torch.Tensor): Reference input tensor. Shape: (batch_size, seq_len, hidden_size).
62
63
  ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
63
64
  ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
64
65
  loss_kwargs (dict): Other possible arguments that a loss function might need
@@ -92,6 +93,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
92
93
  compute_nll_loss=compute_nll_loss,
93
94
  full_target=target,
94
95
  use_ref_model=use_ref_model,
96
+ ref_input=ref_input,
95
97
  ref_weight=ref_weight,
96
98
  ref_bias=ref_bias,
97
99
  **loss_kwargs,
@@ -301,6 +303,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
301
303
  beta=0.1,
302
304
  compute_nll_loss=True,
303
305
  use_ref_model=False,
306
+ ref_input=None,
304
307
  ref_weight=None,
305
308
  ref_bias=None,
306
309
  **loss_kwargs,
@@ -319,6 +322,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
319
322
  beta (float): Weight for the preference loss.
320
323
  compute_nll_loss (bool): Whether to compute NLL loss.
321
324
  use_ref_model (bool): Whether to use a reference model for the alignment loss.
325
+ ref_input (torch.Tensor): Reference input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
322
326
  ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
323
327
  ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
324
328
  loss_kwargs (dict): Additional arguments for the loss function.
@@ -357,7 +361,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
357
361
  ref_rejected_logits,
358
362
  ref_chosen_nll_loss,
359
363
  ) = LigerFusedLinearPreferenceBase.chunk_forward(
360
- input_chunk,
364
+ ref_input,
361
365
  ref_weight,
362
366
  target_chunk,
363
367
  ref_bias,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.2.dev20241211055800
3
+ Version: 0.5.2.dev20241211213024
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation