optimum-rbln 0.1.12__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +27 -13
- optimum/rbln/__version__.py +16 -1
- optimum/rbln/diffusers/__init__.py +22 -2
- optimum/rbln/diffusers/models/__init__.py +34 -3
- optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
- optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +66 -111
- optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
- optimum/rbln/diffusers/models/controlnet.py +85 -65
- optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
- optimum/rbln/diffusers/models/unets/__init__.py +24 -0
- optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +129 -163
- optimum/rbln/diffusers/pipelines/__init__.py +60 -12
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -25
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -185
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -190
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -191
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -192
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -110
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -118
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +18 -128
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -131
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
- optimum/rbln/modeling.py +572 -0
- optimum/rbln/modeling_alias.py +1 -1
- optimum/rbln/modeling_base.py +176 -763
- optimum/rbln/modeling_diffusers.py +329 -0
- optimum/rbln/transformers/__init__.py +2 -2
- optimum/rbln/transformers/cache_utils.py +5 -9
- optimum/rbln/transformers/modeling_rope_utils.py +283 -0
- optimum/rbln/transformers/models/__init__.py +80 -31
- optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
- optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
- optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
- optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
- optimum/rbln/transformers/models/clip/modeling_clip.py +8 -34
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -5
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +779 -361
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +83 -142
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +64 -39
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +6 -29
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +31 -92
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +6 -31
- optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +29 -83
- optimum/rbln/transformers/models/midm/midm_architecture.py +88 -253
- optimum/rbln/transformers/models/midm/modeling_midm.py +8 -33
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
- optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
- optimum/rbln/transformers/models/phi/phi_architecture.py +61 -345
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +5 -29
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
- optimum/rbln/transformers/models/t5/__init__.py +1 -1
- optimum/rbln/transformers/models/t5/modeling_t5.py +157 -6
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
- optimum/rbln/transformers/utils/rbln_quantization.py +128 -5
- optimum/rbln/utils/decorator_utils.py +59 -0
- optimum/rbln/utils/hub.py +131 -0
- optimum/rbln/utils/import_utils.py +21 -0
- optimum/rbln/utils/model_utils.py +53 -0
- optimum/rbln/utils/runtime_utils.py +5 -5
- optimum/rbln/utils/submodule.py +114 -0
- optimum/rbln/utils/timer_utils.py +2 -2
- optimum_rbln-0.1.15.dist-info/METADATA +106 -0
- optimum_rbln-0.1.15.dist-info/RECORD +110 -0
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/generation/streamers.py +0 -139
- optimum/rbln/transformers/generation/utils.py +0 -397
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
- optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
- optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
- optimum_rbln-0.1.12.dist-info/METADATA +0 -119
- optimum_rbln-0.1.12.dist-info/RECORD +0 -103
- optimum_rbln-0.1.12.dist-info/entry_points.txt +0 -4
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
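The hunks below are the decoder-only refactor at the heart of this release: the per-model files stop monkey-patching `forward_dict` tables onto copied HuggingFace code and instead subclass the shared building blocks (`DecoderOnlyWrapper`, `DecoderOnlyModel`, `DecoderOnlyLayer`, `DecoderOnlyAttention`), overriding only the hooks that differ per architecture. A minimal, self-contained sketch of that override pattern — the hook names are taken from the diff, but the base-class internals here are simplified assumptions, not the real optimum-rbln implementation:

```python
import torch
import torch.nn as nn


class SketchDecoderOnlyModel(nn.Module):
    """Toy stand-in for DecoderOnlyModel: wraps an original HF module and
    reads its submodules through overridable hooks."""

    def __init__(self, original_mod: nn.Module):
        super().__init__()
        self._original_mod = original_mod

    def get_embedding(self) -> nn.Module:
        return self._original_mod.embed_tokens  # Llama-style default

    def get_last_layernorm(self) -> nn.Module:
        return self._original_mod.norm

    @property
    def hidden_multiplier(self) -> float:
        return 1.0  # Gemma overrides this with hidden_size ** 0.5

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        hidden = self.get_embedding()(input_ids) * self.hidden_multiplier
        # ... decoder layers would run here in the real wrapper ...
        return self.get_last_layernorm()(hidden)


class SketchGPT2Model(SketchDecoderOnlyModel):
    """GPT-2 names its submodules differently, so only the hooks change
    (compare GPT2Model in the gpt2_architecture.py hunk below)."""

    def get_embedding(self) -> nn.Module:
        return self._original_mod.wte

    def get_last_layernorm(self) -> nn.Module:
        return self._original_mod.ln_f
```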
optimum/rbln/transformers/models/gemma/gemma_architecture.py

@@ -21,103 +21,42 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-from typing import
+from typing import TYPE_CHECKING

-import
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPast,
-)
-
-from ...models.decoderonly import (
+from ...models.decoderonly.decoderonly_architecture import (
     DecoderOnlyAttention,
-
+    DecoderOnlyFlashAttention,
+    DecoderOnlyForCausalLM,
+    DecoderOnlyLayer,
+    DecoderOnlyModel,
     DecoderOnlyWrapper,
-    slice_and_unsqueeze_cos_sin,
 )


-
-
-        forward_dict = {}
-        forward_dict.update(
-            {
-                "wrapper": GemmaModel.forward,
-                "model": DecoderOnlyDecoderLayer.forward,
-                "decoder_layer": DecoderOnlyAttention.forward,
-            }
-        )
-        return forward_dict
-
-
-class GemmaModel:
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        batch_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = True,
-        output_attentions: Optional[bool] = False,
-        output_hidden_states: Optional[bool] = False,
-        forward_dict: Optional[Dict[str, classmethod]] = None,
-        rotary_pos_emb=None,
-    ) -> Union[Tuple, BaseModelOutputWithPast]:
-        # embed positions
-        inputs_embeds = self.embed_tokens(input_ids)
-        hidden_states = inputs_embeds
-
-        ##### GEMMA change from llama#####
-        hidden_states = hidden_states * (self.config.hidden_size**0.5)
-
-        attention_mask = (1 - attention_mask) * torch.finfo(torch.float16).min
-
-        # get cos,sin vector
-        cos, sin = rotary_pos_emb(inputs_embeds, attention_mask.shape[-1])
-        cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
+if TYPE_CHECKING:
+    from transformers import GemmaForCausalLM

-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None

[old lines 84-106 removed; content not rendered in the source diff]
-            all_self_attns += (layer_outputs[1],)
-
-        hidden_states = self.norm(hidden_states)
-
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        # convert RebelDynamicCache to legacy Tuple[Tuple[torch.Tensor]]
-        next_cache = updated_cache.to_legacy_cache()
-
-        return BaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=next_cache,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
-        )
+class GemmaWrapper(DecoderOnlyWrapper):
+    def convert_to_rbln_causal_lm(self, causal_lm: "GemmaForCausalLM"):
+        new_layers = []
+        for layer in causal_lm.model.layers:
+            if self.attn_impl == "eager":
+                new_self_attn = DecoderOnlyAttention(layer.self_attn)
+            elif self.attn_impl == "flash_attn":
+                new_self_attn = DecoderOnlyFlashAttention(
+                    layer.self_attn, kvcache_partition_len=self.kvcache_partition_len
+                )
+            else:
+                raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
+            new_layer = DecoderOnlyLayer(layer, new_self_attn)
+            new_layers.append(new_layer)
+        new_model = GemmaModel(causal_lm.model, new_layers)
+        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
+        return new_causal_lm
+
+
+class GemmaModel(DecoderOnlyModel):
+    @property
+    def hidden_multiplier(self):
+        return self._original_mod.config.hidden_size**0.5
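The only Gemma-specific behaviour carried over from the deleted hand-written forward is the embedding scaling (`hidden_states * hidden_size**0.5`), which now lives in the `hidden_multiplier` property. A quick check that the two formulations match, using a hypothetical hidden size:

```python
import torch

hidden_size = 2048                      # hypothetical; not taken from the diff
inputs_embeds = torch.randn(1, 4, hidden_size)

# 0.1.12: scaling was inlined in the copied GemmaModel.forward
old_style = inputs_embeds * (hidden_size**0.5)

# 0.1.15: the factor is supplied by GemmaModel.hidden_multiplier
hidden_multiplier = hidden_size**0.5
new_style = inputs_embeds * hidden_multiplier

assert torch.equal(old_style, new_style)
```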
optimum/rbln/transformers/models/gemma/modeling_gemma.py

@@ -21,28 +21,18 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-import
-import logging
-from typing import TYPE_CHECKING, Any, Callable
-
-from transformers import GemmaForCausalLM
-
+from ....utils import logging
 from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .gemma_architecture import GemmaWrapper


-
-    from transformers import PreTrainedModel
-
-    from ....modeling_config import RBLNConfig
-
-logger = logging.getLogger(__name__)
+logger = logging.get_logger(__name__)


 class RBLNGemmaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """
     The Gemma Model transformer with a language modeling head (linear layer) on top.
-    This model inherits from [`
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.

     A class to convert and run pre-trained transformers based GemmaForCausalLM model on RBLN devices.
     It implements the methods to convert a pre-trained transformers GemmaForCausalLM model into a RBLN transformer model by:

@@ -50,18 +40,4 @@ class RBLNGemmaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     - compiling the resulting graph using the RBLN compiler.
     """

-
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-        return GemmaWrapper(model, rbln_max_seq_len).eval()
-
-    def __getattr__(self, __name: str) -> Any:
-        def redirect(func):
-            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-        val = getattr(GemmaForCausalLM, __name)
-
-        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-            return redirect(val)
-
-        return val
+    _decoder_wrapper_cls = GemmaWrapper
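On the modeling side the pattern is declarative: each `RBLN*ForCausalLM` now only names its wrapper class, and the shared `RBLNDecoderOnlyModelForCausalLM` base presumably builds it. A toy version of that class-attribute pattern (the base-class behaviour here is an assumption for illustration only):

```python
class SketchWrapper:
    def __init__(self, hf_model, max_seq_len):
        self.hf_model = hf_model
        self.max_seq_len = max_seq_len


class SketchDecoderOnlyModelForCausalLM:
    # Subclasses declare which wrapper to build; the wrapping logic itself
    # lives in the shared base class instead of per-model overrides.
    _decoder_wrapper_cls = SketchWrapper

    def __init__(self, hf_model, max_seq_len=2048):
        self.wrapped_model = self._decoder_wrapper_cls(hf_model, max_seq_len)


class SketchGemmaForCausalLM(SketchDecoderOnlyModelForCausalLM):
    _decoder_wrapper_cls = SketchWrapper  # GemmaWrapper in the real code
```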
optimum/rbln/transformers/models/gpt2/gpt2_architecture.py

@@ -21,262 +21,74 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-from typing import
+from typing import TYPE_CHECKING, Tuple

 import torch
 import torch.nn as nn
-from transformers.modeling_outputs import BaseModelOutputWithPast

-from
+from ..decoderonly.decoderonly_architecture import (
+    DecoderOnlyAttention,
+    DecoderOnlyForCausalLM,
+    DecoderOnlyLayer,
+    DecoderOnlyModel,
+    DecoderOnlyWrapper,
+)


-
-
-        super().__init__()
-        self.model = model.transformer
-        self.lm_head = model.lm_head
-        self.config = model.config
-        self.max_seq_len = max_seq_len
-        self.forward_dict = self.get_forward_dict()
+if TYPE_CHECKING:
+    from transformers import GPT2LMHeadModel

-    def get_forward_dict(self):
-        forward_dict = {
-            "wrapper": _GPT2Model.forward,
-            "model": _GPT2Block.forward,
-            "decoder_layer": _GPT2Attention.forward,
-        }
-        return forward_dict

[old lines 50-61 removed; content not rendered in the source diff]
-        rbln_batch_position = batch_position
+class GPT2Wrapper(DecoderOnlyWrapper):
+    def convert_to_rbln_causal_lm(self, causal_lm: "GPT2LMHeadModel"):
+        if self.attn_impl != "eager":
+            raise NotImplementedError(f"flash attention ({self.attn_impl}) is not implemented for {self.__class__}")
+        new_layers = []
+        for layer in causal_lm.transformer.h:
+            new_self_attn = GPT2Attention(layer.attn)
+            new_layer = GPT2Layer(layer, new_self_attn)
+            new_layers.append(new_layer)
+        new_model = GPT2Model(causal_lm.transformer, new_layers)
+        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
+        return new_causal_lm

-        # Formatting list of past_kv to DynamicCache class.
-        past_key_value = RebelDynamicCache_4D.from_input_format(
-            cache_position,
-            self.config.n_layer,
-            *past_key_values,
-        )
-
-        outputs = self.forward_dict["wrapper"](
-            self.model,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=cache_position,
-            past_key_value=past_key_value,
-            batch_ids=rbln_batch_position,
-            forward_dict=self.forward_dict,
-            # rotary_emb differenct from_llama
-        )
-
-        hidden_states = outputs[0]
-        if batch_position >= 0:
-            hidden_states = hidden_states[:, query_idx].unsqueeze(1)
-        logits = self.lm_head(hidden_states)
-
-        output = (logits,) + outputs[1:]

-
+class GPT2Model(DecoderOnlyModel):
+    mask_fmin = torch.finfo(torch.float32).min

+    def get_last_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.ln_f

-
-
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[RebelDynamicCache_4D] = None,
-        batch_ids: Optional[torch.LongTensor] = None,
-        forward_dict: Optional[Dict[str, classmethod]] = None,
-    ) -> BaseModelOutputWithPast:
-        b_size, q_len = input_ids.shape
-        inputs_embeds = self.wte(input_ids)
+    def get_embedding(self) -> nn.Embedding:
+        return self._original_mod.wte

-
-
-            for b_idx in range(b_size):
-                position_embed = self.wpe(position_ids[b_idx])
-                # position_embed = position_embed.dtype(inputs_embeds.dtype)
-                position_embeds.append(position_embed)
+    def get_pos_embedding(self) -> nn.Embedding:
+        return self._original_mod.wpe

-            position_embeds = torch.cat(position_embeds, dim=0).unsqueeze(1)
-        else:
-            position_embeds = self.wpe(position_ids)

-
+class GPT2Layer(DecoderOnlyLayer):
+    def get_pre_attention_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.ln_1

-
-
-        attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
+    def get_post_attention_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.ln_2

-        for layer_idx, block in enumerate(self.h):
-            hidden_states, updated_cache = forward_dict["model"](
-                block,
-                hidden_states,
-                layer_idx,
-                attention_mask=attention_mask,
-                past_key_value=past_key_value,
-                position_ids=position_ids,
-                batch_ids=batch_ids,
-                forward_dict=forward_dict,
-            )
-
-        hidden_states = self.ln_f(hidden_states)
-        output_shape = (-1,) + (q_len,) + (hidden_states.size(-1),)
-        hidden_states = hidden_states.view(output_shape)
-
-        # convert RebelDynamicCache to legacy Tuple[Tuple[torch.Tensor]]
-        next_cache = updated_cache.to_legacy_cache()
-
-        return BaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=next_cache,
-        )

+class GPT2Attention(DecoderOnlyAttention):
+    def __post_init__(self):
+        self.c_attn = self._original_mod.c_attn
+        self.o_proj = self._original_mod.c_proj
+        self.split_size = self._original_mod.split_size
+        self.num_key_value_heads = self._original_mod.num_heads

[old lines 147-149 removed; content not rendered in the source diff]
-        hidden_states: Optional[Tuple[torch.FloatTensor]],
-        layer_idx: int,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[RebelDynamicCache_4D] = None,
-        batch_ids: Optional[torch.LongTensor] = None,
-        forward_dict: Optional[Dict[str, classmethod]] = None,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, RebelDynamicCache_4D]:
-        residual = hidden_states
-        hidden_states = self.ln_1(hidden_states)
+    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
+        return query_states, key_states, value_states

[old lines 162-167 removed; content not rendered in the source diff]
-            batch_index=batch_ids,
+    def rbln_attention(self, *args, **kwargs):
+        return super().rbln_attention(
+            *args,
+            **kwargs,
+            layer_idx=self.layer_idx,
+            scale_attn_by_inverse_layer_idx=self._original_mod.scale_attn_by_inverse_layer_idx,
         )
-        past_key_value.assign(k, v, layer_idx)
-
-        # residual connection
-        hidden_states = residual + hidden_states
-
-        residual = hidden_states
-        hidden_states = self.ln_2(hidden_states)
-        hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-
-        return hidden_states, past_key_value
-
-
-class _GPT2Attention:
-    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
-        attn_weights = torch.matmul(query, key.transpose(-1, -2))
-
-        if self.scale_attn_weights:
-            attn_weights = attn_weights / torch.full(
-                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
-            )
-
-        # Layer-wise attention scaling
-        if self.scale_attn_by_inverse_layer_idx:
-            attn_weights = attn_weights / float(self.layer_idx + 1)
-
-        # -------------------
-        # Below are deleted since "where" op does not supported on RBLN graph.
-        # -------------------
-        # if not self.is_cross_attention:
-        #     # if only "normal" attention layer implements causal mask
-        #     query_length, key_length = query.size(-2), key.size(-2)
-        #     causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
-        #     mask_value = torch.finfo(attn_weights.dtype).min
-        #     # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
-        #     # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-        #     mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
-        #     attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
-
-        # Apply the attention mask
-        attn_weights.view(
-            -1,
-        )
-        attn_weights = attn_weights + attention_mask
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-        attn_output = torch.matmul(attn_weights, value)
-
-        return attn_output, attn_weights
-
-    def forward(
-        self,
-        hidden_states: Optional[Tuple[torch.FloatTensor]],
-        attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_value: Optional[RebelDynamicCache_4D] = None,
-        batch_index: Optional[int] = None,
-        **kwargs,
-    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
-        bsz, q_len, _ = hidden_states.size()
-        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
-
-        querys = self._split_heads(query, self.num_heads, self.head_dim)  # (batch, head, seq_length, head_features)
-        keys = self._split_heads(key, self.num_heads, self.head_dim)
-        values = self._split_heads(value, self.num_heads, self.head_dim)
-
-        # Decoder
-        if (batch_index is None or batch_index == -1) and bsz > 1:
-            all_keys = []
-            all_values = []
-            all_attn_output = []
-
-            for b in range(bsz):
-                query = querys[b].unsqueeze(0)
-                attn_mask = attention_mask[b].unsqueeze(0)
-                key = keys[b].unsqueeze(0)
-                value = values[b].unsqueeze(0)
-
-                key, value = past_key_value.update(
-                    key,
-                    value,
-                    self.layer_idx,
-                    b,
-                )
-
-                attn_output, _ = _GPT2Attention._attn(self, query, key, value, attn_mask)
-                attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
-
-                all_keys.append(key)
-                all_values.append(value)
-                all_attn_output.append(attn_output)
-
-            keys = torch.cat(all_keys, dim=0)
-            values = torch.cat(all_values, dim=0)
-            attn_output = torch.cat(all_attn_output, dim=0)
-
-        # Prefill
-        else:
-            if batch_index is None or batch_index == -1:
-                batch_index = 0
-
-            keys, values = past_key_value.update(
-                keys,
-                values,
-                self.layer_idx,
-                batch_index,
-                read_first_step=True,
-            )
-
-            attn_output, _ = _GPT2Attention._attn(self, querys, keys, values, attention_mask)
-            attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
-
-        attn_output = self.c_proj(attn_output)
-
-        return attn_output, keys, values
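GPT-2 keeps query, key and value in one fused `c_attn` projection, which is why `GPT2Attention` above only overrides `projection` instead of reimplementing attention. A standalone illustration of that split — sizes are hypothetical, and `nn.Linear` stands in for GPT-2's `Conv1D` (same shape behaviour):

```python
import torch
import torch.nn as nn

batch, seq_len, hidden_size = 1, 4, 768
split_size = hidden_size

# Fused QKV projection: one weight producing q, k and v side by side.
c_attn = nn.Linear(hidden_size, 3 * hidden_size)

hidden_states = torch.randn(batch, seq_len, hidden_size)
query, key, value = c_attn(hidden_states).split(split_size, dim=2)

assert query.shape == key.shape == value.shape == (batch, seq_len, hidden_size)
```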
optimum/rbln/transformers/models/gpt2/modeling_gpt2.py

@@ -21,20 +21,12 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-import
-import logging
-from typing import TYPE_CHECKING, Any, Callable
-
-from transformers import GPT2LMHeadModel
-
-from ....modeling_config import RBLNConfig
+from ....utils import logging
 from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
-from .gpt2_architecture import GPT2LMHeadModelWrapper
+from .gpt2_architecture import GPT2Wrapper  # GPT2LMHeadModelWrapper


-logger = logging.
-if TYPE_CHECKING:
-    from transformers import PreTrainedModel
+logger = logging.get_logger(__name__)


 class RBLNGPT2LMHeadModel(RBLNDecoderOnlyModelForCausalLM):

@@ -42,7 +34,7 @@ class RBLNGPT2LMHeadModel(RBLNDecoderOnlyModelForCausalLM):
     The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
     embeddings).

-    This model inherits from [`
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the
     library implements for all its model.

     It implements the methods to convert a pre-trained transformers GPT2 model into a RBLN transformer model by:

@@ -51,22 +43,5 @@ class RBLNGPT2LMHeadModel(RBLNDecoderOnlyModelForCausalLM):

     """

-
-
-        rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-        return GPT2LMHeadModelWrapper(model, rbln_max_seq_len).eval()
-
-    def __getattr__(self, __name: str) -> Any:
-        """This is the key method to implement RBLN-GPT2.
-
-        Returns:
-            Any: GPT2's corresponding method
-        """
-
-        def redirect(func):
-            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-        val = getattr(GPT2LMHeadModel, __name)
-        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-            return redirect(val)
-        return val
+    _decoder_wrapper_cls = GPT2Wrapper
+    _use_rotary_emb = False
optimum/rbln/transformers/models/llama/modeling_llama.py

@@ -21,28 +21,18 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-import
-import logging
-from typing import TYPE_CHECKING, Any, Callable
-
-from transformers import LlamaForCausalLM
-
+from ....utils import logging
 from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .llama_architecture import LlamaWrapper


-
-    from transformers import PreTrainedModel
-
-    from ....modeling_config import RBLNConfig
-
-logger = logging.getLogger(__name__)
+logger = logging.get_logger(__name__)


 class RBLNLlamaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """
     The Llama Model transformer with a language modeling head (linear layer) on top.
-    This model inherits from [`
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.

     A class to convert and run pre-trained transformers based LlamaForCausalLM model on RBLN devices.
     It implements the methods to convert a pre-trained transformers LlamaForCausalLM model into a RBLN transformer model by:

@@ -50,18 +40,4 @@ class RBLNLlamaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     - compiling the resulting graph using the RBLN compiler.
     """

-
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-        return LlamaWrapper(model, rbln_max_seq_len).eval()
-
-    def __getattr__(self, __name: str) -> Any:
-        def redirect(func):
-            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-        val = getattr(LlamaForCausalLM, __name)
-
-        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-            return redirect(val)
-
-        return val
+    _decoder_wrapper_cls = LlamaWrapper
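For context, models compiled with optimum-rbln are still driven through `from_pretrained`. A usage sketch, assuming the documented optimum-rbln API, an installed RBLN SDK, and a hypothetical checkpoint and compile-time settings:

```python
from optimum.rbln import RBLNLlamaForCausalLM

# export=True compiles the HuggingFace checkpoint for RBLN NPUs.
compiled = RBLNLlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",  # hypothetical model id
    export=True,
    rbln_max_seq_len=4096,            # hypothetical compile-time settings
    rbln_batch_size=1,
)
compiled.save_pretrained("llama-2-7b-rbln")

# Later runs load the precompiled artifacts directly.
model = RBLNLlamaForCausalLM.from_pretrained("llama-2-7b-rbln", export=False)
```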
|