optimum-rbln 0.1.11__py3-none-any.whl → 0.1.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +14 -7
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +30 -63
- optimum/rbln/diffusers/models/controlnet.py +36 -62
- optimum/rbln/diffusers/models/unet_2d_condition.py +57 -156
- optimum/rbln/diffusers/pipelines/__init__.py +40 -12
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -187
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -192
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +8 -206
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -207
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -111
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +12 -117
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +4 -123
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +4 -126
- optimum/rbln/modeling_alias.py +4 -9
- optimum/rbln/modeling_base.py +117 -144
- optimum/rbln/modeling_config.py +51 -0
- optimum/rbln/modeling_diffusers.py +400 -0
- optimum/rbln/transformers/__init__.py +10 -0
- optimum/rbln/transformers/cache_utils.py +5 -9
- optimum/rbln/transformers/modeling_rope_utils.py +283 -0
- optimum/rbln/transformers/models/__init__.py +80 -28
- optimum/rbln/transformers/models/auto/modeling_auto.py +1 -0
- optimum/rbln/transformers/models/bart/__init__.py +1 -1
- optimum/rbln/transformers/models/bart/bart_architecture.py +18 -12
- optimum/rbln/transformers/models/bart/modeling_bart.py +25 -6
- optimum/rbln/transformers/models/bert/modeling_bert.py +1 -2
- optimum/rbln/transformers/models/clip/modeling_clip.py +13 -23
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -2
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +376 -218
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +246 -116
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +0 -1
- optimum/rbln/transformers/models/exaone/__init__.py +32 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +81 -0
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +181 -0
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +1725 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +53 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +4 -30
- optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +166 -151
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -15
- optimum/rbln/transformers/models/midm/modeling_midm.py +8 -28
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
- optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
- optimum/rbln/transformers/models/phi/phi_architecture.py +75 -159
- optimum/rbln/transformers/models/qwen2/__init__.py +24 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +43 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +29 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +24 -0
- optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} +107 -166
- optimum/rbln/transformers/models/t5/__init__.py +1 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +108 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +46 -32
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +0 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +38 -13
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +1 -2
- optimum/rbln/transformers/utils/rbln_quantization.py +8 -2
- optimum/rbln/utils/context.py +58 -0
- optimum/rbln/utils/decorator_utils.py +55 -0
- optimum/rbln/utils/import_utils.py +21 -0
- optimum/rbln/utils/logging.py +1 -1
- optimum/rbln/utils/runtime_utils.py +4 -4
- optimum/rbln/utils/timer_utils.py +26 -2
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/METADATA +11 -9
- optimum_rbln-0.1.13.dist-info/RECORD +107 -0
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/WHEEL +1 -1
- optimum_rbln-0.1.11.dist-info/RECORD +0 -93
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/licenses/LICENSE +0 -0
@@ -58,23 +58,12 @@ class MidmLMHeadModelWrapper(torch.nn.Module):
|
|
58
58
|
self.model = model.transformer
|
59
59
|
self.lm_head = model.lm_head
|
60
60
|
self.config = model.config
|
61
|
-
self.head_dim = self.config.n_embd // self.config.n_head
|
62
|
-
self.max_position_embeddings = (
|
63
|
-
self.config.max_position_embeddings if max_seq_len > self.config.max_position_embeddings else max_seq_len
|
64
|
-
)
|
65
61
|
self.max_seq_len = max_seq_len
|
66
|
-
self.rotary_dim = int(
|
67
|
-
model.config.hidden_size // model.config.num_attention_heads * model.config.rotary_percentage
|
68
|
-
)
|
69
|
-
self.rotary_emb = self._init_rope()
|
70
62
|
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
max_position_embeddings=self.max_position_embeddings,
|
76
|
-
)
|
77
|
-
return rotary_emb
|
63
|
+
self.config.partial_rotary_factor = model.config.rotary_percentage
|
64
|
+
self.config.head_dim = self.config.n_embd // self.config.n_head
|
65
|
+
self.config.rope_theta = 10000
|
66
|
+
self.rotary_emb = RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
|
78
67
|
|
79
68
|
def forward(
|
80
69
|
self,
|
@@ -21,11 +21,7 @@
|
|
21
21
|
# copied, modified, or distributed without prior written permission
|
22
22
|
# from Rebellions Inc.
|
23
23
|
|
24
|
-
import
|
25
|
-
import logging
|
26
|
-
from typing import TYPE_CHECKING, Any, Callable
|
27
|
-
|
28
|
-
from ....modeling_config import RBLNConfig
|
24
|
+
from ....utils import logging
|
29
25
|
from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
|
30
26
|
from .hf_hub_cached.modeling_midm import MidmLMHeadModel
|
31
27
|
from .midm_architecture import (
|
@@ -33,11 +29,7 @@ from .midm_architecture import (
|
|
33
29
|
)
|
34
30
|
|
35
31
|
|
36
|
-
logger = logging.
|
37
|
-
if TYPE_CHECKING:
|
38
|
-
from transformers import (
|
39
|
-
PreTrainedModel,
|
40
|
-
)
|
32
|
+
logger = logging.get_logger(__name__)
|
41
33
|
|
42
34
|
|
43
35
|
class RBLNMidmLMHeadModel(RBLNDecoderOnlyModelForCausalLM):
|
@@ -54,22 +46,10 @@ class RBLNMidmLMHeadModel(RBLNDecoderOnlyModelForCausalLM):
|
|
54
46
|
|
55
47
|
"""
|
56
48
|
|
57
|
-
|
58
|
-
|
59
|
-
rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
|
60
|
-
return MidmLMHeadModelWrapper(model, rbln_max_seq_len).eval()
|
61
|
-
|
62
|
-
def __getattr__(self, __name: str) -> Any:
|
63
|
-
"""This is the key method to implement RBLN-Midm.
|
49
|
+
_decoder_wrapper_cls = MidmLMHeadModelWrapper
|
50
|
+
_original_cls = MidmLMHeadModel
|
64
51
|
|
65
|
-
|
66
|
-
|
67
|
-
""
|
68
|
-
|
69
|
-
def redirect(func):
|
70
|
-
return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
|
71
|
-
|
72
|
-
val = getattr(MidmLMHeadModel, __name)
|
73
|
-
if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
|
74
|
-
return redirect(val)
|
75
|
-
return val
|
52
|
+
@classmethod
|
53
|
+
def from_pretrained(cls, *args, **kwargs):
|
54
|
+
kwargs.setdefault("trust_remote_code", True)
|
55
|
+
return super().from_pretrained(*args, **kwargs)
|
@@ -21,29 +21,18 @@
|
|
21
21
|
# copied, modified, or distributed without prior written permission
|
22
22
|
# from Rebellions Inc.
|
23
23
|
|
24
|
-
import
|
25
|
-
import logging
|
26
|
-
from typing import TYPE_CHECKING, Any, Callable
|
27
|
-
|
28
|
-
from transformers import MistralForCausalLM
|
29
|
-
|
24
|
+
from ....utils import logging
|
30
25
|
from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
|
31
26
|
from .mistral_architecture import MistralForCausalLMWrapper
|
32
27
|
|
33
28
|
|
34
|
-
|
35
|
-
from transformers import PreTrainedModel
|
36
|
-
|
37
|
-
from ....modeling_config import RBLNConfig
|
38
|
-
|
39
|
-
|
40
|
-
logger = logging.getLogger(__name__)
|
29
|
+
logger = logging.get_logger(__name__)
|
41
30
|
|
42
31
|
|
43
32
|
class RBLNMistralForCausalLM(RBLNDecoderOnlyModelForCausalLM):
|
44
33
|
"""
|
45
34
|
The Llama Model transformer with a language modeling head (linear layer) on top.
|
46
|
-
This model inherits from [`
|
35
|
+
This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
|
47
36
|
|
48
37
|
A class to convert and run pre-trained transformers based LlamaForCausalLM model on RBLN devices.
|
49
38
|
It implements the methods to convert a pre-trained transformers LlamaForCausalLM model into a RBLN transformer model by:
|
@@ -51,18 +40,4 @@ class RBLNMistralForCausalLM(RBLNDecoderOnlyModelForCausalLM):
|
|
51
40
|
- compiling the resulting graph using the RBLN compiler.
|
52
41
|
"""
|
53
42
|
|
54
|
-
|
55
|
-
def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
|
56
|
-
rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
|
57
|
-
return MistralForCausalLMWrapper(model, rbln_max_seq_len).eval()
|
58
|
-
|
59
|
-
def __getattr__(self, __name: str) -> Any:
|
60
|
-
def redirect(func):
|
61
|
-
return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
|
62
|
-
|
63
|
-
val = getattr(MistralForCausalLM, __name)
|
64
|
-
|
65
|
-
if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
|
66
|
-
return redirect(val)
|
67
|
-
|
68
|
-
return val
|
43
|
+
_decoder_wrapper_cls = MistralForCausalLMWrapper
|
@@ -21,28 +21,18 @@
|
|
21
21
|
# copied, modified, or distributed without prior written permission
|
22
22
|
# from Rebellions Inc.
|
23
23
|
|
24
|
-
import
|
25
|
-
import
|
26
|
-
from typing import TYPE_CHECKING, Any, Callable
|
27
|
-
|
28
|
-
from transformers import PhiForCausalLM
|
29
|
-
|
30
|
-
from ..decoderonly import RBLNDecoderOnlyModelForCausalLM
|
24
|
+
from ....utils import logging
|
25
|
+
from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
|
31
26
|
from .phi_architecture import PhiWrapper
|
32
27
|
|
33
28
|
|
34
|
-
|
35
|
-
from transformers import PreTrainedModel
|
36
|
-
|
37
|
-
from ....modeling_config import RBLNConfig
|
38
|
-
|
39
|
-
logger = logging.getLogger(__name__)
|
29
|
+
logger = logging.get_logger(__name__)
|
40
30
|
|
41
31
|
|
42
32
|
class RBLNPhiForCausalLM(RBLNDecoderOnlyModelForCausalLM):
|
43
33
|
"""
|
44
34
|
The Phi Model transformer with a language modeling head (linear layer) on top.
|
45
|
-
This model inherits from [`
|
35
|
+
This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
|
46
36
|
|
47
37
|
A class to convert and run pre-trained transformers based PhiForCausalLM model on RBLN devices.
|
48
38
|
It implements the methods to convert a pre-trained transformers PhiForCausalLM model into a RBLN transformer model by:
|
@@ -50,20 +40,4 @@ class RBLNPhiForCausalLM(RBLNDecoderOnlyModelForCausalLM):
|
|
50
40
|
- compiling the resulting graph using the RBLN compiler.
|
51
41
|
"""
|
52
42
|
|
53
|
-
|
54
|
-
def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
|
55
|
-
rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
|
56
|
-
return PhiWrapper(model, rbln_max_seq_len).eval()
|
57
|
-
|
58
|
-
def __getattr__(self, __name: str) -> Any:
|
59
|
-
def redirect(func):
|
60
|
-
return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
|
61
|
-
|
62
|
-
val = getattr(PhiForCausalLM, __name)
|
63
|
-
|
64
|
-
if isinstance(val, Callable) and "self" in set(
|
65
|
-
inspect.signature(val).parameters
|
66
|
-
):
|
67
|
-
return redirect(val)
|
68
|
-
|
69
|
-
return val
|
43
|
+
_decoder_wrapper_cls = PhiWrapper
|
@@ -33,46 +33,12 @@ from transformers.modeling_outputs import (
|
|
33
33
|
from ...cache_utils import RebelDynamicCache
|
34
34
|
from ..decoderonly import (
|
35
35
|
DecoderOnlyWrapper,
|
36
|
-
DynamicNTKScalingRotaryEmbedding,
|
37
|
-
LinearScalingRotaryEmbedding,
|
38
|
-
RotaryEmbedding,
|
39
36
|
apply_rotary_pos_emb,
|
40
37
|
slice_and_unsqueeze_cos_sin,
|
41
38
|
)
|
42
39
|
|
43
40
|
|
44
41
|
class PhiWrapper(DecoderOnlyWrapper):
|
45
|
-
def _init_rope(self):
|
46
|
-
if self.rope_scaling is None:
|
47
|
-
rotary_emb = RotaryEmbedding(
|
48
|
-
int(self.config.partial_rotary_factor * self.head_dim),
|
49
|
-
max_position_embeddings=self.max_position_embeddings,
|
50
|
-
base=self.config.rope_theta,
|
51
|
-
)
|
52
|
-
else:
|
53
|
-
scaling_type = self.rope_scaling["type"]
|
54
|
-
scaling_factor = self.rope_scaling["factor"]
|
55
|
-
if scaling_type == "linear":
|
56
|
-
rotary_emb = LinearScalingRotaryEmbedding(
|
57
|
-
int(self.config.partial_rotary_factor * self.head_dim),
|
58
|
-
max_position_embeddings=self.max_position_embeddings,
|
59
|
-
scaling_factor=scaling_factor,
|
60
|
-
base=self.config.rope_theta,
|
61
|
-
max_seq_len=self.max_seq_len,
|
62
|
-
)
|
63
|
-
elif scaling_type == "dynamic":
|
64
|
-
rotary_emb = DynamicNTKScalingRotaryEmbedding(
|
65
|
-
int(self.config.partial_rotary_factor * self.head_dim),
|
66
|
-
max_position_embeddings=self.max_position_embeddings,
|
67
|
-
scaling_factor=scaling_factor,
|
68
|
-
base=self.config.rope_theta,
|
69
|
-
max_seq_len=self.max_seq_len,
|
70
|
-
)
|
71
|
-
else:
|
72
|
-
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
73
|
-
|
74
|
-
return rotary_emb
|
75
|
-
|
76
42
|
def get_forward_dict(self):
|
77
43
|
forward_dict = {}
|
78
44
|
forward_dict.update(
|
@@ -86,6 +52,41 @@ class PhiWrapper(DecoderOnlyWrapper):
|
|
86
52
|
|
87
53
|
|
88
54
|
class PhiAttention:
|
55
|
+
def _attn(self, query_state, key_state, value_state, attn_mask, past_key_value, batch_idx=0, is_prefill=False):
|
56
|
+
# reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
|
57
|
+
key_state = key_state.unsqueeze(2)
|
58
|
+
value_state = value_state.unsqueeze(2)
|
59
|
+
attn_mask = attn_mask.unsqueeze(2)
|
60
|
+
|
61
|
+
query_state = query_state.view(
|
62
|
+
1,
|
63
|
+
self.num_key_value_heads,
|
64
|
+
self.num_heads // self.num_key_value_heads,
|
65
|
+
-1,
|
66
|
+
self.head_dim,
|
67
|
+
)
|
68
|
+
|
69
|
+
key_state, value_state = past_key_value.update(key_state, value_state, self.layer_idx, batch_idx, is_prefill)
|
70
|
+
|
71
|
+
# Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
|
72
|
+
attn_weights = torch.matmul(
|
73
|
+
query_state.to(torch.float32),
|
74
|
+
key_state.to(torch.float32).transpose(3, 4),
|
75
|
+
) / math.sqrt(self.head_dim)
|
76
|
+
attn_weights = attn_weights + attn_mask
|
77
|
+
|
78
|
+
# upcast attention to fp32
|
79
|
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_state.dtype)
|
80
|
+
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
81
|
+
attn_output = torch.matmul(attn_weights, value_state)
|
82
|
+
|
83
|
+
# reshape for removing repeat_kv
|
84
|
+
attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim)
|
85
|
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
86
|
+
attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)
|
87
|
+
|
88
|
+
return attn_output, key_state, value_state
|
89
|
+
|
89
90
|
def forward(
|
90
91
|
self,
|
91
92
|
hidden_states: torch.Tensor,
|
@@ -95,7 +96,6 @@ class PhiAttention:
|
|
95
96
|
output_attentions: bool = False,
|
96
97
|
cos: Optional[torch.Tensor] = None,
|
97
98
|
sin: Optional[torch.Tensor] = None,
|
98
|
-
rotary_pos_emb=None,
|
99
99
|
**kwargs,
|
100
100
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
101
101
|
bsz, q_len, _ = hidden_states.size()
|
@@ -108,24 +108,18 @@ class PhiAttention:
|
|
108
108
|
query_states = self.q_layernorm(query_states)
|
109
109
|
key_states = self.k_layernorm(key_states)
|
110
110
|
|
111
|
-
query_states = query_states.view(
|
112
|
-
|
113
|
-
).transpose(1, 2)
|
114
|
-
key_states = key_states.view(
|
115
|
-
bsz, q_len, self.num_key_value_heads, self.head_dim
|
116
|
-
).transpose(1, 2)
|
117
|
-
value_states = value_states.view(
|
118
|
-
bsz, q_len, self.num_key_value_heads, self.head_dim
|
119
|
-
).transpose(1, 2)
|
111
|
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
112
|
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
113
|
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
120
114
|
|
121
115
|
# Partial rotary embedding
|
122
116
|
query_rot, query_pass = (
|
123
|
-
query_states[..., :
|
124
|
-
query_states[...,
|
117
|
+
query_states[..., : self.rotary_ndims],
|
118
|
+
query_states[..., self.rotary_ndims :],
|
125
119
|
)
|
126
120
|
key_rot, key_pass = (
|
127
|
-
key_states[..., :
|
128
|
-
key_states[...,
|
121
|
+
key_states[..., : self.rotary_ndims],
|
122
|
+
key_states[..., self.rotary_ndims :],
|
129
123
|
)
|
130
124
|
|
131
125
|
# [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
|
@@ -135,113 +129,38 @@ class PhiAttention:
|
|
135
129
|
query_states = torch.cat((query_rot, query_pass), dim=-1)
|
136
130
|
key_states = torch.cat((key_rot, key_pass), dim=-1)
|
137
131
|
|
138
|
-
# Decoder
|
139
|
-
if
|
140
|
-
|
141
|
-
all_value_states = []
|
142
|
-
all_attn_output = []
|
143
|
-
|
132
|
+
# Decoder (bsz > 1)
|
133
|
+
if bsz > 1:
|
134
|
+
iterate_results = {"key_states": [], "value_states": [], "attn_output": []}
|
144
135
|
for b in range(bsz):
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
query_state = query_state.view(
|
156
|
-
1,
|
157
|
-
self.num_key_value_heads,
|
158
|
-
self.num_heads // self.num_key_value_heads,
|
159
|
-
q_len,
|
160
|
-
self.head_dim,
|
136
|
+
attn_output, key_state, value_state = PhiAttention._attn(
|
137
|
+
self,
|
138
|
+
query_states[b].unsqueeze(0),
|
139
|
+
key_states[b].unsqueeze(0),
|
140
|
+
value_states[b].unsqueeze(0),
|
141
|
+
attention_mask[b].unsqueeze(0),
|
142
|
+
past_key_value,
|
143
|
+
batch_idx=b,
|
144
|
+
is_prefill=False,
|
161
145
|
)
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
# Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
|
171
|
-
attn_weights = torch.matmul(
|
172
|
-
query_state.to(torch.float32),
|
173
|
-
key_state.to(torch.float32).transpose(3, 4),
|
174
|
-
) / math.sqrt(self.head_dim)
|
175
|
-
attn_weights = attn_weights + attn_mask
|
176
|
-
|
177
|
-
# upcast attention to fp32
|
178
|
-
attn_weights = nn.functional.softmax(
|
179
|
-
attn_weights, dim=-1, dtype=torch.float32
|
180
|
-
).to(query_states.dtype)
|
181
|
-
attn_weights = nn.functional.dropout(
|
182
|
-
attn_weights, p=self.attention_dropout, training=self.training
|
183
|
-
)
|
184
|
-
attn_output = torch.matmul(attn_weights, value_state)
|
185
|
-
|
186
|
-
# reshape for removing repeat_kv
|
187
|
-
attn_output = attn_output.view(1, self.num_heads, q_len, self.head_dim)
|
188
|
-
attn_output = attn_output.transpose(1, 2).contiguous()
|
189
|
-
attn_output = attn_output.reshape(
|
190
|
-
1, q_len, self.num_heads * self.head_dim
|
191
|
-
)
|
192
|
-
|
193
|
-
all_key_states.append(key_state)
|
194
|
-
all_value_states.append(value_state)
|
195
|
-
all_attn_output.append(attn_output)
|
196
|
-
|
197
|
-
key_states = torch.cat(all_key_states, dim=0)
|
198
|
-
value_states = torch.cat(all_value_states, dim=0)
|
199
|
-
attn_output = torch.cat(all_attn_output, dim=0)
|
146
|
+
iterate_results["key_states"].append(key_state)
|
147
|
+
iterate_results["value_states"].append(value_state)
|
148
|
+
iterate_results["attn_output"].append(attn_output)
|
149
|
+
|
150
|
+
key_states = torch.cat(iterate_results["key_states"], dim=0)
|
151
|
+
value_states = torch.cat(iterate_results["value_states"], dim=0)
|
152
|
+
attn_output = torch.cat(iterate_results["attn_output"], dim=0)
|
153
|
+
# Prefill & Decoder (bsz == 1)
|
200
154
|
else:
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
# reshape for removing repeat_kv
|
205
|
-
key_states = key_states.unsqueeze(2)
|
206
|
-
value_states = value_states.unsqueeze(2)
|
207
|
-
attention_mask = attention_mask.unsqueeze(2)
|
208
|
-
query_states = query_states.view(
|
209
|
-
1,
|
210
|
-
self.num_key_value_heads,
|
211
|
-
self.num_heads // self.num_key_value_heads,
|
212
|
-
q_len,
|
213
|
-
self.head_dim,
|
214
|
-
)
|
215
|
-
|
216
|
-
key_states, value_states = past_key_value.update(
|
155
|
+
attn_output, key_states, value_states = PhiAttention._attn(
|
156
|
+
self,
|
157
|
+
query_states,
|
217
158
|
key_states,
|
218
159
|
value_states,
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
# Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
|
225
|
-
attn_weights = torch.matmul(
|
226
|
-
query_states.to(torch.float32),
|
227
|
-
key_states.to(torch.float32).transpose(3, 4),
|
228
|
-
) / math.sqrt(self.head_dim)
|
229
|
-
attn_weights = attn_weights + attention_mask
|
230
|
-
|
231
|
-
# upcast attention to fp32
|
232
|
-
attn_weights = torch.nn.functional.softmax(
|
233
|
-
attn_weights, dim=-1, dtype=torch.float32
|
234
|
-
).to(value_states.dtype)
|
235
|
-
attn_weights = torch.nn.functional.dropout(
|
236
|
-
attn_weights, p=self.attention_dropout, training=self.training
|
237
|
-
)
|
238
|
-
attn_output = torch.matmul(attn_weights, value_states)
|
239
|
-
|
240
|
-
# reshape for removing repeat_kv
|
241
|
-
attn_output = attn_output.view(1, self.num_heads, q_len, self.head_dim)
|
242
|
-
attn_output = attn_output.transpose(1, 2).contiguous()
|
243
|
-
attn_output = attn_output.reshape(
|
244
|
-
bsz, q_len, self.num_heads * self.head_dim
|
160
|
+
attention_mask,
|
161
|
+
past_key_value,
|
162
|
+
batch_idx=batch_index,
|
163
|
+
is_prefill=True,
|
245
164
|
)
|
246
165
|
|
247
166
|
attn_output = self.dense(attn_output)
|
@@ -265,12 +184,9 @@ class PhiDecoderLayer:
|
|
265
184
|
batch_ids: Optional[torch.LongTensor] = None,
|
266
185
|
cos: Optional[torch.Tensor] = None,
|
267
186
|
sin: Optional[torch.Tensor] = None,
|
268
|
-
rotary_pos_emb=None,
|
269
187
|
forward_dict: Optional[Dict[str, classmethod]] = None,
|
270
188
|
**kwargs,
|
271
|
-
) -> Tuple[
|
272
|
-
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
|
273
|
-
]:
|
189
|
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
274
190
|
"""
|
275
191
|
Args:
|
276
192
|
hidden_states (`torch.FloatTensor`):
|
@@ -294,9 +210,7 @@ class PhiDecoderLayer:
|
|
294
210
|
hidden_states = self.input_layernorm(hidden_states)
|
295
211
|
|
296
212
|
# Self Attention
|
297
|
-
attn_outputs, self_attn_weights, key_states, value_states = forward_dict[
|
298
|
-
"decoder_layer"
|
299
|
-
](
|
213
|
+
attn_outputs, self_attn_weights, key_states, value_states = forward_dict["decoder_layer"](
|
300
214
|
self.self_attn,
|
301
215
|
hidden_states=hidden_states,
|
302
216
|
attention_mask=attention_mask,
|
@@ -307,7 +221,6 @@ class PhiDecoderLayer:
|
|
307
221
|
use_cache=use_cache,
|
308
222
|
cos=cos,
|
309
223
|
sin=sin,
|
310
|
-
rotary_pos_emb=rotary_pos_emb,
|
311
224
|
**kwargs,
|
312
225
|
)
|
313
226
|
past_key_value.assign(key_states, value_states, layer_idx)
|
@@ -339,6 +252,8 @@ class PhiModel:
|
|
339
252
|
use_cache: Optional[bool] = True,
|
340
253
|
output_attentions: Optional[bool] = False,
|
341
254
|
output_hidden_states: Optional[bool] = False,
|
255
|
+
cache_pos_for_partitions: Optional[torch.Tensor] = None,
|
256
|
+
kvcache_partition_size: Optional[torch.Tensor] = None,
|
342
257
|
forward_dict: Optional[Dict[str, classmethod]] = None,
|
343
258
|
rotary_pos_emb=None,
|
344
259
|
) -> BaseModelOutputWithPast:
|
@@ -378,7 +293,8 @@ class PhiModel:
|
|
378
293
|
batch_ids=batch_ids,
|
379
294
|
cos=cos,
|
380
295
|
sin=sin,
|
381
|
-
|
296
|
+
cache_pos_for_partitions=cache_pos_for_partitions,
|
297
|
+
kvcache_partition_size=kvcache_partition_size,
|
382
298
|
forward_dict=forward_dict,
|
383
299
|
)
|
384
300
|
|
@@ -0,0 +1,24 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
from .modeling_qwen2 import RBLNQwen2ForCausalLM
|
@@ -0,0 +1,43 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
from ....utils import logging
|
25
|
+
from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
|
26
|
+
from .qwen2_architecture import QWEN2Wrapper
|
27
|
+
|
28
|
+
|
29
|
+
logger = logging.get_logger(__name__)
|
30
|
+
|
31
|
+
|
32
|
+
class RBLNQwen2ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
|
33
|
+
"""
|
34
|
+
The Llama Model transformer with a language modeling head (linear layer) on top.
|
35
|
+
This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
|
36
|
+
|
37
|
+
A class to convert and run pre-trained transformers based LlamaForCausalLM model on RBLN devices.
|
38
|
+
It implements the methods to convert a pre-trained transformers LlamaForCausalLM model into a RBLN transformer model by:
|
39
|
+
- transferring the checkpoint weights of the original into an optimized RBLN graph,
|
40
|
+
- compiling the resulting graph using the RBLN compiler.
|
41
|
+
"""
|
42
|
+
|
43
|
+
_decoder_wrapper_cls = QWEN2Wrapper
|
@@ -0,0 +1,29 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
|
25
|
+
from ..decoderonly.decoderonly_architecture import DecoderOnlyWrapper
|
26
|
+
|
27
|
+
|
28
|
+
class QWEN2Wrapper(DecoderOnlyWrapper):
|
29
|
+
pass
|
@@ -0,0 +1,24 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
from .modeling_seq2seq import RBLNModelForSeq2SeqLM
|