liger-kernel 0.0.0__py3-none-any.whl

--- /dev/null
+++ b/liger_kernel/transformers/model/llama.py
@@ -0,0 +1,143 @@
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from torch.nn import CrossEntropyLoss
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.models.llama.modeling_llama import (
+    _CONFIG_FOR_DOC,
+    LLAMA_INPUTS_DOCSTRING,
+)
+from transformers.utils import (
+    add_start_docstrings_to_model_forward,
+    replace_return_docstrings,
+)
+
+from liger_kernel.transformers.fused_linear_cross_entropy import (
+    LigerFusedLinearCrossEntropyLoss,
+)
+
+
+@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+@replace_return_docstrings(
+    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
+)
+def lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, LlamaForCausalLM
+
+    >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
+    >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
+
+    >>> prompt = "Hey, are you conscious? Can you talk to me?"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+    ```"""
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = (
+        return_dict if return_dict is not None else self.config.use_return_dict
+    )
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+    )
+
+    hidden_states = outputs[0]
+
+    loss = None
+    logits = None
+
+    if self.training:
+        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+
+        # flatten tokens
+        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+        shift_labels = shift_labels.view(-1)
+
+        lce = LigerFusedLinearCrossEntropyLoss()
+        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+
+    else:
+        if self.config.pretraining_tp > 1:
+            lm_head_slices = self.lm_head.weight.split(
+                self.vocab_size // self.config.pretraining_tp, dim=0
+            )
+            logits = [
+                F.linear(hidden_states, lm_head_slices[i])
+                for i in range(self.config.pretraining_tp)
+            ]
+            logits = torch.cat(logits, dim=-1)
+        else:
+            logits = self.lm_head(hidden_states)
+        logits = logits.float()
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
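
The training branch above is the heart of the patched forward: instead of projecting hidden states to vocabulary logits and then applying `CrossEntropyLoss`, it hands `lm_head.weight`, the shifted hidden states, and the shifted labels directly to `LigerFusedLinearCrossEntropyLoss`, so the `(num_tokens, vocab_size)` logits tensor is never allocated. A minimal plain-PyTorch sketch of the unfused reference computation it replaces (toy shapes, illustrative only, not part of the package):

```python
import torch
import torch.nn.functional as F

# Toy sizes for illustration (assumed, not from the package).
num_tokens, hidden_size, vocab_size = 8, 64, 1000

hidden = torch.randn(num_tokens, hidden_size)          # shift_hidden_states.view(-1, H)
labels = torch.randint(0, vocab_size, (num_tokens,))   # shift_labels.view(-1)
lm_head_weight = torch.randn(vocab_size, hidden_size)  # self.lm_head.weight

# Unfused reference: materialize logits, then take the cross entropy.
logits = F.linear(hidden, lm_head_weight)              # (num_tokens, vocab_size)
loss = F.cross_entropy(logits, labels, ignore_index=-100)
# The fused Liger kernel computes the same loss on GPU without allocating `logits`.
```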
--- /dev/null
+++ b/liger_kernel/transformers/monkey_patch.py
@@ -0,0 +1,103 @@
+from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
+from liger_kernel.transformers.model.llama import lce_forward
+from liger_kernel.transformers.rms_norm import LigerRMSNorm
+from liger_kernel.transformers.rope import liger_rotary_pos_emb
+from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP, LigerSwiGLUMLP
+
+
+# TODO: probably rename utils.py as hf_patcher.py to be more descriptive
+def apply_liger_kernel_to_llama(
+    rope: bool = True,
+    cross_entropy: bool = True,
+    fused_linear_cross_entropy: bool = False,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+) -> None:
+    """
+    Apply Liger kernels to replace the original implementations in HuggingFace Llama models (2 and 3)
+    to make GPU go burrr.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is False.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits are not materialized, which is more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+    """
+
+    assert not (
+        cross_entropy and fused_linear_cross_entropy
+    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+
+    from transformers.models.llama import modeling_llama
+
+    if rope:
+        modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        modeling_llama.LlamaRMSNorm = LigerRMSNorm
+    if swiglu:
+        modeling_llama.LlamaMLP = LigerSwiGLUMLP
+    if cross_entropy:
+        modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
+    if fused_linear_cross_entropy:
+        modeling_llama.LlamaForCausalLM.forward = lce_forward
+
+
+def apply_liger_kernel_to_mistral(
+    rope: bool = True,
+    cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+) -> None:
+    """
+    Apply Liger kernels to replace the original implementations in HuggingFace Mistral models
+    to make GPU go burrr.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+    """
+
+    from transformers.models.mistral import modeling_mistral
+
+    if rope:
+        modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        modeling_mistral.MistralRMSNorm = LigerRMSNorm
+    if cross_entropy:
+        modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
+    if swiglu:
+        modeling_mistral.MistralMLP = LigerSwiGLUMLP
+
+
+def apply_liger_kernel_to_mixtral(
+    rope: bool = True,
+    cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+) -> None:
+    """
+    Apply Liger kernels to replace the original implementations in HuggingFace Mixtral models
+    to make GPU go burrr.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+    """
+
+    from transformers.models.mixtral import modeling_mixtral
+
+    if rope:
+        modeling_mixtral.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        modeling_mixtral.MixtralRMSNorm = LigerRMSNorm
+    if cross_entropy:
+        modeling_mixtral.CrossEntropyLoss = LigerCrossEntropyLoss
+    if swiglu:
+        modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP
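
These patchers swap symbols on the `transformers` modeling modules, so they should run before the model is instantiated; modules that were constructed earlier keep the original implementations. A hedged usage sketch (the checkpoint name and dtype are illustrative, not from the package):

```python
import torch
from transformers import AutoModelForCausalLM

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama

# Patch first, then build the model so the swapped classes/functions take effect.
apply_liger_kernel_to_llama(
    rope=True,
    rms_norm=True,
    swiglu=True,
    cross_entropy=False,               # must be False when the fused variant is enabled
    fused_linear_cross_entropy=True,
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",        # illustrative checkpoint
    torch_dtype=torch.bfloat16,
)
```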
--- /dev/null
+++ b/liger_kernel/transformers/rms_norm.py
@@ -0,0 +1,16 @@
+import torch
+import torch.nn as nn
+
+from liger_kernel.ops.rms_norm import LigerRMSNormFunction
+
+
+class LigerRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        return LigerRMSNormFunction.apply(
+            hidden_states, self.weight, self.variance_epsilon
+        )
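
`LigerRMSNorm` keeps the constructor and forward signature of the HuggingFace `LlamaRMSNorm` it replaces, which is what makes the class-level swap in `monkey_patch.py` possible. For reference, an eager-mode sketch of the computation the Triton kernel performs (dtype upcasting and backward details omitted; this helper is illustrative, not part of the package):

```python
import torch

def rms_norm_reference(hidden_states: torch.Tensor,
                       weight: torch.Tensor,
                       eps: float = 1e-6) -> torch.Tensor:
    # Normalize by the root mean square over the last dimension, then scale.
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    return weight * hidden_states * torch.rsqrt(variance + eps)
```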
--- /dev/null
+++ b/liger_kernel/transformers/rope.py
@@ -0,0 +1,20 @@
+from liger_kernel.ops.rope import LigerRopeFunction
+
+
+def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """
+    Applies Rotary Positional Embedding (RoPE) operation to query and key states.
+
+    Args:
+        q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim).
+        k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim).
+        cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim).
+        sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim).
+        position_ids (torch.Tensor, optional): The position ids tensor. Defaults to None.
+        unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The query and key tensors after applying the RoPE operation.
+    """
+
+    return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim)
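
A shape-level usage sketch of the RoPE helper, following the shapes listed in the docstring. The `cos`/`sin` tensors here are random placeholders; in a real forward pass they come from the model's rotary embedding module, and the underlying Triton kernel expects CUDA tensors:

```python
import torch
from liger_kernel.transformers.rope import liger_rotary_pos_emb

bsz, n_q_head, n_kv_head, seq_len, head_dim = 1, 32, 8, 16, 64  # toy sizes (assumed)

q = torch.randn(bsz, n_q_head, seq_len, head_dim, device="cuda")
k = torch.randn(bsz, n_kv_head, seq_len, head_dim, device="cuda")
cos = torch.randn(1, seq_len, head_dim, device="cuda")  # placeholder, not a real rotation
sin = torch.randn(1, seq_len, head_dim, device="cuda")

q_rot, k_rot = liger_rotary_pos_emb(q, k, cos, sin)      # same shapes as q and k
```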
--- /dev/null
+++ b/liger_kernel/transformers/swiglu.py
@@ -0,0 +1,40 @@
+import torch.nn as nn
+
+from liger_kernel.ops.swiglu import LigerSiLUMulFunction
+
+
+class LigerSwiGLUMLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        if config.hidden_act not in ["silu", "swish"]:
+            raise ValueError(f"Activation function {config.hidden_act} not supported.")
+
+    def forward(self, x):
+
+        return self.down_proj(
+            LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x))
+        )
+
+
+class LigerBlockSparseTop2MLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.ffn_dim = config.intermediate_size
+        self.hidden_dim = config.hidden_size
+
+        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
+        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
+        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
+
+        if config.hidden_act not in ["silu", "swish"]:
+            raise ValueError(f"Activation function {config.hidden_act} not supported.")
+
+    def forward(self, x):
+
+        return self.w2(LigerSiLUMulFunction.apply(self.w1(x), self.w3(x)))
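
Both MLP classes defer the elementwise `SiLU(gate) * up` product to `LigerSiLUMulFunction`, keeping the surrounding `nn.Linear` projections unchanged. An eager-mode reference for what `LigerSwiGLUMLP.forward` computes (illustrative helper, not part of the package):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def swiglu_mlp_reference(x: torch.Tensor,
                         gate_proj: nn.Linear,
                         up_proj: nn.Linear,
                         down_proj: nn.Linear) -> torch.Tensor:
    # SwiGLU: SiLU(gate_proj(x)) * up_proj(x), followed by the down projection.
    # The Liger kernel fuses the SiLU and the elementwise multiply on GPU.
    return down_proj(F.silu(gate_proj(x)) * up_proj(x))
```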
--- /dev/null
+++ b/liger_kernel/triton/__init__.py
@@ -0,0 +1,3 @@
+from liger_kernel.triton.monkey_patch import (  # noqa: F401
+    apply_liger_triton_cache_manager,
+)
--- /dev/null
+++ b/liger_kernel/triton/monkey_patch.py
@@ -0,0 +1,44 @@
+import os
+import random
+
+from overrides import override
+from triton.runtime.cache import FileCacheManager
+
+
+class LigerTritonFileCacheManager(FileCacheManager):
+    @override
+    def put(self, data, filename, binary=True) -> str:
+        if not self.cache_dir:
+            raise RuntimeError("Could not create or locate cache dir")
+        binary = isinstance(data, bytes)
+        if not binary:
+            data = str(data)
+        assert self.lock_path is not None
+        filepath = self._make_path(filename)
+        # Random ID to avoid any collisions
+        rnd_id = random.randint(0, 1000000)
+        # include the PID in case several processes write at once, so we can see which PID made it
+        pid = os.getpid()
+        # use temp dir to be robust against program interruptions
+        temp_dir = os.path.join(self.cache_dir, f"tmp.pid_{pid}_{rnd_id}")
+        os.makedirs(temp_dir, exist_ok=True)
+        temp_path = os.path.join(temp_dir, filename)
+
+        mode = "wb" if binary else "w"
+        with open(temp_path, mode) as f:
+            f.write(data)
+        # Replace is guaranteed to be atomic on POSIX systems if it succeeds,
+        # so filepath cannot see a partial write
+        os.replace(temp_path, filepath)
+        os.removedirs(temp_dir)
+        return filepath
+
+
+def apply_liger_triton_cache_manager():
+    """
+    Experimental feature to get around transient FileNotFoundError in triton compilation.
+    For more details please see https://github.com/triton-lang/triton/pull/4295
+    """
+    os.environ["TRITON_CACHE_MANAGER"] = (
+        "liger_kernel.triton.monkey_patch:LigerTritonFileCacheManager"
+    )
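
`apply_liger_triton_cache_manager` only sets the `TRITON_CACHE_MANAGER` environment variable, so it should run before Triton compiles its first kernel in the process (that is, before any Liger kernel is launched). A minimal usage sketch:

```python
from liger_kernel.triton.monkey_patch import apply_liger_triton_cache_manager

# Point Triton at the more robust file cache manager before any kernel compiles.
apply_liger_triton_cache_manager()

# ...then apply the Liger kernels, build the model, and run as usual.
```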
--- /dev/null
+++ b/liger_kernel-0.0.0.dist-info/METADATA
@@ -0,0 +1,14 @@
+Metadata-Version: 2.1
+Name: liger-kernel
+Version: 0.0.0
+Requires-Dist: torch>=2.1.2
+Requires-Dist: triton>=2.3.0
+Requires-Dist: transformers>=4.40.1
+Provides-Extra: dev
+Requires-Dist: matplotlib>=3.7.2; extra == "dev"
+Requires-Dist: flake8>=4.0.1.1; extra == "dev"
+Requires-Dist: black>=24.4.2; extra == "dev"
+Requires-Dist: isort>=5.13.2; extra == "dev"
+Requires-Dist: pre-commit>=3.7.1; extra == "dev"
+Requires-Dist: torch-tb-profiler>=0.4.1; extra == "dev"
+
--- /dev/null
+++ b/liger_kernel-0.0.0.dist-info/RECORD
@@ -0,0 +1,24 @@
+liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+liger_kernel/ops/cross_entropy.py,sha256=XRnLWW2Jo1sVllDbyTuM8ir_6WZR791fFgqoaIVzPrM,10665
+liger_kernel/ops/fused_linear_cross_entropy.py,sha256=-z5NDZ1a7htumYPJbmw1QRqgp9N_QKZZiuaZSPCb9Y0,7311
+liger_kernel/ops/geglu.py,sha256=DiJSy4I8kouPFyNpKUuthfibZWRioPJMGR-4MJgebhg,3660
+liger_kernel/ops/rms_norm.py,sha256=iQd8ZDzNM-3b05eLzjh1Jfj2C8QKAtg59h-b-XuIo5s,4299
+liger_kernel/ops/rope.py,sha256=fYBct8gDQfKPZdMWlzkZZ8kBzh6nQ7DIpDsc7lZwM8c,8584
+liger_kernel/ops/swiglu.py,sha256=__QsfYxKyZHtRScm31zL3sAOVEblQFqKj2ll8I4Odqg,2835
+liger_kernel/ops/utils.py,sha256=cC7rvhiEBW-8x4qQRTUYWW790k3TA-S7pKbJmdRj-Xc,1080
+liger_kernel/transformers/__init__.py,sha256=7rOw9yZ8kNXO483Colx-EUq8GcTCvCZxrxF-S7pmkkU,172
+liger_kernel/transformers/cross_entropy.py,sha256=G-L4EaUYVc25NKZ2jrlaG-d5YUvDqJdUlawPN7K1d1g,389
+liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=h0AW9ubFGfz4DBwgh2CLW8rpKo9PvxYpB6AUzjx-1b0,501
+liger_kernel/transformers/geglu.py,sha256=FrLBHZRdI68jw9RR6MSTE59-xCzueOwSRp9jL8y-j98,896
+liger_kernel/transformers/monkey_patch.py,sha256=5h436874AENVnTjQAk4-Srp_GIr50CXAl2xeNTbqzJg,3988
+liger_kernel/transformers/rms_norm.py,sha256=2LHfEctSpzuNRaoZ9uUECSFK8fZeIxIsHm9QbEHZvDQ,452
+liger_kernel/transformers/rope.py,sha256=m-ah8vZBYW8tfplTXCiAPMHJWlB1tdp_JPXJeWE-Boo,943
+liger_kernel/transformers/swiglu.py,sha256=8kt4MffEZT5vx3k0WA-GO-WPLv5kGdnu_nAwlJyMI2U,1516
+liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+liger_kernel/transformers/model/llama.py,sha256=DJOjLT5-TGMLKaPqLqyW03rLae8lJTb3nwnfg2mVNXQ,5197
+liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
+liger_kernel/triton/monkey_patch.py,sha256=yRNaGdyG5PrwX5ed_MQdqtqvvpVvQ7ZD2FQ_9W1q9u8,1629
+liger_kernel-0.0.0.dist-info/METADATA,sha256=SBK5dFzMYYtFyorscmi__7u83TitHFpKMsRE9pUKXGI,461
+liger_kernel-0.0.0.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+liger_kernel-0.0.0.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel-0.0.0.dist-info/RECORD,,
--- /dev/null
+++ b/liger_kernel-0.0.0.dist-info/WHEEL
@@ -0,0 +1,5 @@
+Wheel-Version: 1.0
+Generator: bdist_wheel (0.44.0)
+Root-Is-Purelib: true
+Tag: py3-none-any
+
--- /dev/null
+++ b/liger_kernel-0.0.0.dist-info/top_level.txt
@@ -0,0 +1 @@
+liger_kernel