optimum-rbln 0.1.13__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +41 -38
- optimum/rbln/__version__.py +16 -1
- optimum/rbln/diffusers/__init__.py +26 -2
- optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +97 -126
- optimum/rbln/diffusers/models/__init__.py +36 -3
- optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
- optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +73 -61
- optimum/rbln/diffusers/models/autoencoders/vae.py +83 -0
- optimum/rbln/diffusers/models/controlnet.py +54 -14
- optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
- optimum/rbln/diffusers/models/unets/__init__.py +24 -0
- optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +82 -22
- optimum/rbln/diffusers/pipelines/__init__.py +23 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +13 -33
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +18 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +18 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -13
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +24 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +15 -8
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +15 -8
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
- optimum/rbln/modeling.py +238 -0
- optimum/rbln/modeling_base.py +186 -760
- optimum/rbln/modeling_config.py +31 -7
- optimum/rbln/ops/__init__.py +26 -0
- optimum/rbln/ops/attn.py +221 -0
- optimum/rbln/ops/flash_attn.py +70 -0
- optimum/rbln/ops/kv_cache_update.py +69 -0
- optimum/rbln/transformers/__init__.py +20 -2
- optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
- optimum/rbln/transformers/modeling_generic.py +385 -0
- optimum/rbln/transformers/models/auto/__init__.py +23 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
- optimum/rbln/transformers/models/auto/modeling_auto.py +36 -12
- optimum/rbln/transformers/models/bart/__init__.py +0 -1
- optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
- optimum/rbln/transformers/models/bart/modeling_bart.py +10 -9
- optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
- optimum/rbln/transformers/models/clip/modeling_clip.py +8 -25
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -10
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +775 -514
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +128 -260
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +60 -45
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -2
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +33 -104
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +3 -2
- optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -75
- optimum/rbln/transformers/models/midm/midm_architecture.py +84 -238
- optimum/rbln/transformers/models/midm/modeling_midm.py +5 -6
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +60 -261
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -103
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
- optimum/rbln/transformers/models/t5/__init__.py +0 -1
- optimum/rbln/transformers/models/t5/modeling_t5.py +106 -5
- optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +78 -55
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
- optimum/rbln/transformers/utils/rbln_quantization.py +120 -4
- optimum/rbln/utils/decorator_utils.py +51 -11
- optimum/rbln/utils/hub.py +131 -0
- optimum/rbln/utils/import_utils.py +22 -1
- optimum/rbln/utils/logging.py +37 -0
- optimum/rbln/utils/model_utils.py +52 -0
- optimum/rbln/utils/runtime_utils.py +10 -4
- optimum/rbln/utils/save_utils.py +17 -0
- optimum/rbln/utils/submodule.py +137 -0
- optimum_rbln-0.2.0.dist-info/METADATA +117 -0
- optimum_rbln-0.2.0.dist-info/RECORD +114 -0
- {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.2.0.dist-info}/WHEEL +1 -1
- optimum_rbln-0.2.0.dist-info/licenses/LICENSE +288 -0
- optimum/rbln/transformers/cache_utils.py +0 -107
- optimum/rbln/transformers/generation/streamers.py +0 -139
- optimum/rbln/transformers/generation/utils.py +0 -397
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
- optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
- optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
- optimum/rbln/utils/context.py +0 -58
- optimum/rbln/utils/timer_utils.py +0 -43
- optimum_rbln-0.1.13.dist-info/METADATA +0 -120
- optimum_rbln-0.1.13.dist-info/RECORD +0 -107
- optimum_rbln-0.1.13.dist-info/entry_points.txt +0 -4
- optimum_rbln-0.1.13.dist-info/licenses/LICENSE +0 -201
optimum/rbln/transformers/models/exaone/exaone_architecture.py

@@ -20,62 +20,77 @@
 # are the intellectual property of Rebellions Inc. and may not be
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
-
+
+from typing import TYPE_CHECKING
+
+import torch.nn as nn

 from ....utils import logging
-from ...models.decoderonly import (
+from ...models.decoderonly.decoderonly_architecture import (
     DecoderOnlyAttention,
-
+    DecoderOnlyFlashAttention,
+    DecoderOnlyForCausalLM,
+    DecoderOnlyLayer,
     DecoderOnlyModel,
     DecoderOnlyWrapper,
-    RotaryEmbedding,
 )


+if TYPE_CHECKING:
+    from transformers import PreTrainedModel as ExaoneForCausalLM
+
 logger = logging.get_logger(__name__)


 class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
     """A wrapper class for the Exaone model with a language modeling head."""

[old lines 41-81 removed: the wrapper's previous method definitions; their content is not shown in the source diff]
+    def convert_to_rbln_causal_lm(self, causal_lm: "ExaoneForCausalLM"):
+        new_layers = []
+        for layer in causal_lm.transformer.h:
+            if self.attn_impl == "eager":
+                new_self_attn = ExaoneAttention(layer.attn.attention)
+            elif self.attn_impl == "flash_attn":
+                new_self_attn = ExaoneFlashAttention(
+                    layer.attn.attention, kvcache_partition_len=self.kvcache_partition_len
+                )
+            else:
+                raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
+
+            new_layer = ExaoneLayer(layer, new_self_attn)
+            new_layers.append(new_layer)
+        new_model = ExaoneModel(causal_lm.transformer, new_layers, partition_len=self.kvcache_partition_len)
+        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
+        return new_causal_lm
+
+
+class ExaoneModel(DecoderOnlyModel):
+    def get_embedding(self) -> nn.Embedding:
+        return self._original_mod.wte
+
+    def get_last_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.ln_f
+
+
+class ExaoneLayer(DecoderOnlyLayer):
+    def get_pre_attention_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.ln_1
+
+    def get_post_attention_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.ln_2
+
+
+class ExaoneAttention(DecoderOnlyAttention):
+    def __post_init__(self):
+        self.q_proj = self._original_mod.q_proj
+        self.k_proj = self._original_mod.k_proj
+        self.v_proj = self._original_mod.v_proj
+        self.o_proj = self._original_mod.out_proj
+
+
+class ExaoneFlashAttention(DecoderOnlyFlashAttention):
+    def __post_init__(self):
+        self.q_proj = self._original_mod.q_proj
+        self.k_proj = self._original_mod.k_proj
+        self.v_proj = self._original_mod.v_proj
+        self.o_proj = self._original_mod.out_proj
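Read together, the additions above are the new 0.2.0 pattern: instead of the removed wrapper methods, each architecture file declares thin adapters over the generic DecoderOnly* base classes and a convert_to_rbln_causal_lm method that rebuilds the Hugging Face module graph out of them. A minimal sketch of what supporting another decoder-only architecture would look like under this pattern; the MyArch* names are hypothetical, and the submodule attribute names (embed_tokens, input_layernorm, and so on) depend on the Hugging Face model being wrapped:

# Illustrative sketch only; "MyArch" is a hypothetical architecture, not part of this release.
import torch.nn as nn

from optimum.rbln.transformers.models.decoderonly.decoderonly_architecture import (
    DecoderOnlyAttention,
    DecoderOnlyForCausalLM,
    DecoderOnlyLayer,
    DecoderOnlyModel,
    DecoderOnlyWrapper,
)


class MyArchAttention(DecoderOnlyAttention):
    def __post_init__(self):
        # Point the generic attention at the wrapped module's projection layers.
        self.q_proj = self._original_mod.q_proj
        self.k_proj = self._original_mod.k_proj
        self.v_proj = self._original_mod.v_proj
        self.o_proj = self._original_mod.o_proj


class MyArchLayer(DecoderOnlyLayer):
    def get_pre_attention_layernorm(self) -> nn.LayerNorm:
        return self._original_mod.input_layernorm

    def get_post_attention_layernorm(self) -> nn.LayerNorm:
        return self._original_mod.post_attention_layernorm


class MyArchModel(DecoderOnlyModel):
    def get_embedding(self) -> nn.Embedding:
        return self._original_mod.embed_tokens

    def get_last_layernorm(self) -> nn.LayerNorm:
        return self._original_mod.norm


class MyArchForCausalLMWrapper(DecoderOnlyWrapper):
    def convert_to_rbln_causal_lm(self, causal_lm):
        # Rebuild the Hugging Face model layer by layer out of the adapters above.
        new_layers = [MyArchLayer(layer, MyArchAttention(layer.self_attn)) for layer in causal_lm.model.layers]
        new_model = MyArchModel(causal_lm.model, new_layers, partition_len=self.kvcache_partition_len)
        return DecoderOnlyForCausalLM(causal_lm, new_model)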
optimum/rbln/transformers/models/exaone/modeling_exaone.py

@@ -21,10 +21,12 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

+
+from transformers import AutoModelForCausalLM
+
 from ....utils import logging
 from ..decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .exaone_architecture import ExaoneForCausalLMWrapper
-from .hf_hub_cached.modeling_exaone import ExaoneForCausalLM


 logger = logging.get_logger(__name__)

@@ -45,7 +47,7 @@ class RBLNExaoneForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """

     _decoder_wrapper_cls = ExaoneForCausalLMWrapper
-
+    _hf_class = AutoModelForCausalLM

     @classmethod
     def from_pretrained(cls, *args, **kwargs):
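The two hunks above are the wiring layer: RBLNExaoneForCausalLM keeps only declarative class attributes, with _decoder_wrapper_cls naming the architecture wrapper and the new _hf_class attribute loading the original checkpoint through transformers' AutoModelForCausalLM now that the vendored hf_hub_cached modeling code is gone. A sketch of the same wiring for a hypothetical architecture (RBLNMyArchForCausalLM and MyArchForCausalLMWrapper are illustrative names, not part of the package):

# Sketch only; mirrors RBLNExaoneForCausalLM above with hypothetical names.
from transformers import AutoModelForCausalLM

from optimum.rbln.transformers.models.decoderonly import RBLNDecoderOnlyModelForCausalLM


class RBLNMyArchForCausalLM(RBLNDecoderOnlyModelForCausalLM):
    _decoder_wrapper_cls = MyArchForCausalLMWrapper  # the adapter wrapper from the previous sketch
    _hf_class = AutoModelForCausalLM  # load the original weights through transformers' auto class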
optimum/rbln/transformers/models/gemma/gemma_architecture.py

@@ -21,113 +21,42 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-from typing import
-
-import
-
-
-
-
-
-    DecoderOnlyDecoderLayer,
+from typing import TYPE_CHECKING
+
+from ...models.decoderonly.decoderonly_architecture import (
+    DecoderOnlyAttention,
+    DecoderOnlyFlashAttention,
+    DecoderOnlyForCausalLM,
+    DecoderOnlyLayer,
+    DecoderOnlyModel,
     DecoderOnlyWrapper,
-    slice_and_unsqueeze_cos_sin,
 )
-from ...models.decoderonly.decoderonly_architecture import DECODERONLY_ATTENTION_CLASSES
-
-
-class GemmaWrapper(DecoderOnlyWrapper):
-    def get_forward_dict(self):
-        forward_dict = {}
-        forward_dict.update(
-            {
-                "wrapper": GemmaModel.forward,
-                "model": DecoderOnlyDecoderLayer.forward,
-                "decoder_layer": DECODERONLY_ATTENTION_CLASSES[self.attn_implementation].forward,
-            }
-        )
-        return forward_dict
-
-
-class GemmaModel:
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        batch_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = True,
-        output_attentions: Optional[bool] = False,
-        output_hidden_states: Optional[bool] = False,
-        cache_pos_for_partitions: Optional[torch.Tensor] = None,
-        kvcache_partition_size: Optional[torch.Tensor] = None,
-        forward_dict: Optional[Dict[str, classmethod]] = None,
-        rotary_pos_emb=None,
-    ) -> Union[Tuple, BaseModelOutputWithPast]:
-        # retrieve input_ids and inputs_embeds
-        if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError(
-                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
-            )
-
-        # embed positions
-        inputs_embeds = self.embed_tokens(input_ids)
-        hidden_states = inputs_embeds

-        ##### GEMMA change from llama#####
-        hidden_states = hidden_states * (self.config.hidden_size**0.5)

-
+if TYPE_CHECKING:
+    from transformers import GemmaForCausalLM

-        # get cos,sin vector
-        cos, sin = rotary_pos_emb(inputs_embeds, attention_mask.shape[-1])
-        cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)

[old lines 88-111 removed; their content is not shown in the source diff]
-        hidden_states = layer_outputs[0]
-
-        updated_cache = layer_outputs[2 if output_attentions else 1]
-
-        if output_attentions:
-            all_self_attns += (layer_outputs[1],)
-
-        hidden_states = self.norm(hidden_states)
-
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        # convert RebelDynamicCache to legacy Tuple[Tuple[torch.Tensor]]
-        next_cache = updated_cache.to_legacy_cache()
-
-        return BaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=next_cache,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
-        )
+class GemmaWrapper(DecoderOnlyWrapper):
+    def convert_to_rbln_causal_lm(self, causal_lm: "GemmaForCausalLM"):
+        new_layers = []
+        for layer in causal_lm.model.layers:
+            if self.attn_impl == "eager":
+                new_self_attn = DecoderOnlyAttention(layer.self_attn)
+            elif self.attn_impl == "flash_attn":
+                new_self_attn = DecoderOnlyFlashAttention(
+                    layer.self_attn, kvcache_partition_len=self.kvcache_partition_len
+                )
+            else:
+                raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
+            new_layer = DecoderOnlyLayer(layer, new_self_attn)
+            new_layers.append(new_layer)
+        new_model = GemmaModel(causal_lm.model, new_layers, partition_len=self.kvcache_partition_len)
+        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
+        return new_causal_lm
+
+
+class GemmaModel(DecoderOnlyModel):
+    @property
+    def hidden_multiplier(self):
+        return self._original_mod.config.hidden_size**0.5
optimum/rbln/transformers/models/gpt2/gpt2_architecture.py

@@ -21,262 +21,74 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-
+import math
+from typing import TYPE_CHECKING, Tuple

 import torch
 import torch.nn as nn
-from transformers.modeling_outputs import BaseModelOutputWithPast

-from
+from ..decoderonly.decoderonly_architecture import (
+    DecoderOnlyAttention,
+    DecoderOnlyForCausalLM,
+    DecoderOnlyLayer,
+    DecoderOnlyModel,
+    DecoderOnlyWrapper,
+)


-
-
-        super().__init__()
-        self.model = model.transformer
-        self.lm_head = model.lm_head
-        self.config = model.config
-        self.max_seq_len = max_seq_len
-        self.forward_dict = self.get_forward_dict()
+if TYPE_CHECKING:
+    from transformers import GPT2LMHeadModel

-    def get_forward_dict(self):
-        forward_dict = {
-            "wrapper": _GPT2Model.forward,
-            "model": _GPT2Block.forward,
-            "decoder_layer": _GPT2Attention.forward,
-        }
-        return forward_dict

[old lines 50-61 removed; their content is not shown in the source diff]
-        rbln_batch_position = batch_position
+class GPT2Wrapper(DecoderOnlyWrapper):
+    def convert_to_rbln_causal_lm(self, causal_lm: "GPT2LMHeadModel"):
+        if self.attn_impl != "eager":
+            raise NotImplementedError(f"flash attention ({self.attn_impl}) is not implemented for {self.__class__}")
+        new_layers = []
+        for layer in causal_lm.transformer.h:
+            new_self_attn = GPT2Attention(layer.attn)
+            new_layer = GPT2Layer(layer, new_self_attn)
+            new_layers.append(new_layer)
+        new_model = GPT2Model(causal_lm.transformer, new_layers)
+        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
+        return new_causal_lm

-        # Formatting list of past_kv to DynamicCache class.
-        past_key_value = RebelDynamicCache_4D.from_input_format(
-            cache_position,
-            self.config.n_layer,
-            *past_key_values,
-        )

-
-
-
-            attention_mask=attention_mask,
-            position_ids=cache_position,
-            past_key_value=past_key_value,
-            batch_ids=rbln_batch_position,
-            forward_dict=self.forward_dict,
-            # rotary_emb differenct from_llama
-        )
+class GPT2Model(DecoderOnlyModel):
+    def get_last_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.ln_f

-
-
-        hidden_states = hidden_states[:, query_idx].unsqueeze(1)
-        logits = self.lm_head(hidden_states)
+    def get_embedding(self) -> nn.Embedding:
+        return self._original_mod.wte

-
+    def get_pos_embedding(self) -> nn.Embedding:
+        return self._original_mod.wpe

-        return output, batch_position + query_idx

+class GPT2Layer(DecoderOnlyLayer):
+    def get_pre_attention_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.ln_1

-
-
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[RebelDynamicCache_4D] = None,
-        batch_ids: Optional[torch.LongTensor] = None,
-        forward_dict: Optional[Dict[str, classmethod]] = None,
-    ) -> BaseModelOutputWithPast:
-        b_size, q_len = input_ids.shape
-        inputs_embeds = self.wte(input_ids)
+    def get_post_attention_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.ln_2

-        if position_ids.shape[0] > 1:
-            position_embeds = []
-            for b_idx in range(b_size):
-                position_embed = self.wpe(position_ids[b_idx])
-                # position_embed = position_embed.dtype(inputs_embeds.dtype)
-                position_embeds.append(position_embed)

-
-
-
+class GPT2Attention(DecoderOnlyAttention):
+    def __post_init__(self):
+        self.c_attn = self._original_mod.c_attn
+        self.o_proj = self._original_mod.c_proj
+        self.split_size = self._original_mod.split_size

-
+    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
+        return query_states, key_states, value_states

-
-
-
+    def get_attn_scale(self):
+        scale = 1.0
+        if self._original_mod.scale_attn_weights:
+            scale /= math.sqrt(self.head_dim)

-
-
-            block,
-            hidden_states,
-            layer_idx,
-            attention_mask=attention_mask,
-            past_key_value=past_key_value,
-            position_ids=position_ids,
-            batch_ids=batch_ids,
-            forward_dict=forward_dict,
-        )
+        if self._original_mod.scale_attn_by_inverse_layer_idx:
+            scale /= 1 + self.layer_idx

-
-        output_shape = (-1,) + (q_len,) + (hidden_states.size(-1),)
-        hidden_states = hidden_states.view(output_shape)
-
-        # convert RebelDynamicCache to legacy Tuple[Tuple[torch.Tensor]]
-        next_cache = updated_cache.to_legacy_cache()
-
-        return BaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=next_cache,
-        )
-
-
-class _GPT2Block:
-    def forward(
-        self,
-        hidden_states: Optional[Tuple[torch.FloatTensor]],
-        layer_idx: int,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[RebelDynamicCache_4D] = None,
-        batch_ids: Optional[torch.LongTensor] = None,
-        forward_dict: Optional[Dict[str, classmethod]] = None,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, RebelDynamicCache_4D]:
-        residual = hidden_states
-        hidden_states = self.ln_1(hidden_states)
-
-        hidden_states, k, v = forward_dict["decoder_layer"](
-            self.attn,
-            hidden_states=hidden_states,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_value=past_key_value,
-            batch_index=batch_ids,
-        )
-        past_key_value.assign(k, v, layer_idx)
-
-        # residual connection
-        hidden_states = residual + hidden_states
-
-        residual = hidden_states
-        hidden_states = self.ln_2(hidden_states)
-        hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-
-        return hidden_states, past_key_value
-
-
-class _GPT2Attention:
-    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
-        attn_weights = torch.matmul(query, key.transpose(-1, -2))
-
-        if self.scale_attn_weights:
-            attn_weights = attn_weights / torch.full(
-                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
-            )
-
-        # Layer-wise attention scaling
-        if self.scale_attn_by_inverse_layer_idx:
-            attn_weights = attn_weights / float(self.layer_idx + 1)
-
-        # -------------------
-        # Below are deleted since "where" op does not supported on RBLN graph.
-        # -------------------
-        # if not self.is_cross_attention:
-        #     # if only "normal" attention layer implements causal mask
-        #     query_length, key_length = query.size(-2), key.size(-2)
-        #     causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
-        #     mask_value = torch.finfo(attn_weights.dtype).min
-        #     # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
-        #     # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-        #     mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
-        #     attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
-
-        # Apply the attention mask
-        attn_weights.view(
-            -1,
-        )
-        attn_weights = attn_weights + attention_mask
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-        attn_output = torch.matmul(attn_weights, value)
-
-        return attn_output, attn_weights
-
-    def forward(
-        self,
-        hidden_states: Optional[Tuple[torch.FloatTensor]],
-        attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_value: Optional[RebelDynamicCache_4D] = None,
-        batch_index: Optional[int] = None,
-        **kwargs,
-    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
-        bsz, q_len, _ = hidden_states.size()
-        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
-
-        querys = self._split_heads(query, self.num_heads, self.head_dim)  # (batch, head, seq_length, head_features)
-        keys = self._split_heads(key, self.num_heads, self.head_dim)
-        values = self._split_heads(value, self.num_heads, self.head_dim)
-
-        # Decoder
-        if (batch_index is None or batch_index == -1) and bsz > 1:
-            all_keys = []
-            all_values = []
-            all_attn_output = []
-
-            for b in range(bsz):
-                query = querys[b].unsqueeze(0)
-                attn_mask = attention_mask[b].unsqueeze(0)
-                key = keys[b].unsqueeze(0)
-                value = values[b].unsqueeze(0)
-
-                key, value = past_key_value.update(
-                    key,
-                    value,
-                    self.layer_idx,
-                    b,
-                )
-
-                attn_output, _ = _GPT2Attention._attn(self, query, key, value, attn_mask)
-                attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
-
-                all_keys.append(key)
-                all_values.append(value)
-                all_attn_output.append(attn_output)
-
-            keys = torch.cat(all_keys, dim=0)
-            values = torch.cat(all_values, dim=0)
-            attn_output = torch.cat(all_attn_output, dim=0)
-
-        # Prefill
-        else:
-            if batch_index is None or batch_index == -1:
-                batch_index = 0
-
-            keys, values = past_key_value.update(
-                keys,
-                values,
-                self.layer_idx,
-                batch_index,
-                read_first_step=True,
-            )
-
-            attn_output, _ = _GPT2Attention._attn(self, querys, keys, values, attention_mask)
-            attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
-
-        attn_output = self.c_proj(attn_output)
-
-        return attn_output, keys, values
+        return scale
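GPT-2 is the fused-QKV case in this set: its adapter overrides projection to split the single c_attn output into query, key, and value, and folds the model's two scaling rules (scale_attn_weights and scale_attn_by_inverse_layer_idx) into get_attn_scale. A small worked example of the value get_attn_scale returns, under the assumption that the shared attention code multiplies the raw query-key scores by it (the call site is not part of this hunk):

import math

# Reproduces GPT2Attention.get_attn_scale() for sample values.
head_dim, layer_idx = 64, 3
scale = 1.0
scale /= math.sqrt(head_dim)  # scale_attn_weights
scale /= 1 + layer_idx        # scale_attn_by_inverse_layer_idx
print(scale)                  # 0.03125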
optimum/rbln/transformers/models/gpt2/modeling_gpt2.py

@@ -23,7 +23,7 @@

 from ....utils import logging
 from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
-from .gpt2_architecture import
+from .gpt2_architecture import GPT2Wrapper


 logger = logging.get_logger(__name__)

@@ -43,4 +43,5 @@ class RBLNGPT2LMHeadModel(RBLNDecoderOnlyModelForCausalLM):

     """

-    _decoder_wrapper_cls =
+    _decoder_wrapper_cls = GPT2Wrapper
+    _use_rotary_emb = False
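With the wrapper selected through _decoder_wrapper_cls and rotary embeddings turned off via the new _use_rotary_emb flag (GPT-2 uses learned position embeddings, exposed through get_pos_embedding in the architecture hunk above), RBLNGPT2LMHeadModel is driven entirely by the shared decoder-only base class. A usage sketch, assuming the optimum-style export=True flow and rbln_-prefixed compile options; the exact keyword names are not confirmed by this diff:

# Usage sketch; the keyword names below are assumptions, not taken from this diff.
from optimum.rbln import RBLNGPT2LMHeadModel

model = RBLNGPT2LMHeadModel.from_pretrained(
    "gpt2",
    export=True,            # compile the Hugging Face checkpoint for the RBLN NPU
    rbln_max_seq_len=1024,  # assumed compile-time option
    rbln_batch_size=1,      # assumed compile-time option
)
model.save_pretrained("gpt2-rbln")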