optimum-rbln 0.1.12__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +27 -13
- optimum/rbln/__version__.py +16 -1
- optimum/rbln/diffusers/__init__.py +22 -2
- optimum/rbln/diffusers/models/__init__.py +34 -3
- optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
- optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +66 -111
- optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
- optimum/rbln/diffusers/models/controlnet.py +85 -65
- optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
- optimum/rbln/diffusers/models/unets/__init__.py +24 -0
- optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +129 -163
- optimum/rbln/diffusers/pipelines/__init__.py +60 -12
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -25
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -185
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -190
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -191
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -192
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -110
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -118
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +18 -128
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -131
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
- optimum/rbln/modeling.py +572 -0
- optimum/rbln/modeling_alias.py +1 -1
- optimum/rbln/modeling_base.py +176 -763
- optimum/rbln/modeling_diffusers.py +329 -0
- optimum/rbln/transformers/__init__.py +2 -2
- optimum/rbln/transformers/cache_utils.py +5 -9
- optimum/rbln/transformers/modeling_rope_utils.py +283 -0
- optimum/rbln/transformers/models/__init__.py +80 -31
- optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
- optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
- optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
- optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
- optimum/rbln/transformers/models/clip/modeling_clip.py +8 -34
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -5
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +779 -361
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +83 -142
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +64 -39
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +6 -29
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +31 -92
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +6 -31
- optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +29 -83
- optimum/rbln/transformers/models/midm/midm_architecture.py +88 -253
- optimum/rbln/transformers/models/midm/modeling_midm.py +8 -33
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
- optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
- optimum/rbln/transformers/models/phi/phi_architecture.py +61 -345
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +5 -29
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
- optimum/rbln/transformers/models/t5/__init__.py +1 -1
- optimum/rbln/transformers/models/t5/modeling_t5.py +157 -6
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
- optimum/rbln/transformers/utils/rbln_quantization.py +128 -5
- optimum/rbln/utils/decorator_utils.py +59 -0
- optimum/rbln/utils/hub.py +131 -0
- optimum/rbln/utils/import_utils.py +21 -0
- optimum/rbln/utils/model_utils.py +53 -0
- optimum/rbln/utils/runtime_utils.py +5 -5
- optimum/rbln/utils/submodule.py +114 -0
- optimum/rbln/utils/timer_utils.py +2 -2
- optimum_rbln-0.1.15.dist-info/METADATA +106 -0
- optimum_rbln-0.1.15.dist-info/RECORD +110 -0
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/generation/streamers.py +0 -139
- optimum/rbln/transformers/generation/utils.py +0 -397
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
- optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
- optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
- optimum_rbln-0.1.12.dist-info/METADATA +0 -119
- optimum_rbln-0.1.12.dist-info/RECORD +0 -103
- optimum_rbln-0.1.12.dist-info/entry_points.txt +0 -4
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/phi/modeling_phi.py

@@ -21,28 +21,18 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-import
-import
-from typing import TYPE_CHECKING, Any, Callable
-
-from transformers import PhiForCausalLM
-
-from ..decoderonly import RBLNDecoderOnlyModelForCausalLM
+from ....utils import logging
+from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .phi_architecture import PhiWrapper


-
-from transformers import PreTrainedModel
-
-from ....modeling_config import RBLNConfig
-
-logger = logging.getLogger(__name__)
+logger = logging.get_logger(__name__)


 class RBLNPhiForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """
     The Phi Model transformer with a language modeling head (linear layer) on top.
-    This model inherits from [`
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.

     A class to convert and run pre-trained transformers based PhiForCausalLM model on RBLN devices.
     It implements the methods to convert a pre-trained transformers PhiForCausalLM model into a RBLN transformer model by:
@@ -50,20 +40,4 @@ class RBLNPhiForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     - compiling the resulting graph using the RBLN compiler.
     """

-
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-        return PhiWrapper(model, rbln_max_seq_len).eval()
-
-    def __getattr__(self, __name: str) -> Any:
-        def redirect(func):
-            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-        val = getattr(PhiForCausalLM, __name)
-
-        if isinstance(val, Callable) and "self" in set(
-            inspect.signature(val).parameters
-        ):
-            return redirect(val)
-
-        return val
+    _decoder_wrapper_cls = PhiWrapper
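Net effect of this file: the Phi subclass no longer carries its own `wrap_model_if_needed` and `__getattr__` redirection boilerplate; it only declares `_decoder_wrapper_cls = PhiWrapper` and inherits the rest. Below is a minimal sketch of how such a class-attribute hook is typically consumed by a shared base class; the base-class name, method name, and signature here are illustrative assumptions, not optimum-rbln's actual API.

```python
# Illustrative sketch of the `_decoder_wrapper_cls` hook pattern (assumed, not
# the actual optimum-rbln base class).
import torch


class DecoderWrapperSketch(torch.nn.Module):
    """Stand-in for PhiWrapper / QWEN2Wrapper: wraps a HF model for compilation."""

    def __init__(self, model: torch.nn.Module, max_seq_len: int):
        super().__init__()
        self.model = model
        self.max_seq_len = max_seq_len

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)


class DecoderOnlyBaseSketch:
    _decoder_wrapper_cls = None  # each subclass declares its wrapper here

    @classmethod
    def wrap_model(cls, model: torch.nn.Module, max_seq_len: int) -> torch.nn.Module:
        # The shared base resolves the wrapper from the class attribute,
        # so per-model subclasses shrink to a single declaration.
        if cls._decoder_wrapper_cls is None:
            raise NotImplementedError(f"{cls.__name__} must set _decoder_wrapper_cls")
        return cls._decoder_wrapper_cls(model, max_seq_len).eval()


class PhiSketch(DecoderOnlyBaseSketch):
    # Mirrors the one-line subclass body left after this diff.
    _decoder_wrapper_cls = DecoderWrapperSketch


wrapped = PhiSketch.wrap_model(torch.nn.Linear(4, 4), max_seq_len=128)
```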
optimum/rbln/transformers/models/phi/phi_architecture.py

@@ -21,386 +21,102 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-import
-from typing import Dict, Optional, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple

 import torch
-
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPast,
-)
+from transformers import PhiForCausalLM

-from
-
+from ..decoderonly.decoderonly_architecture import (
+    DecoderOnlyAttention,
+    DecoderOnlyForCausalLM,
+    DecoderOnlyLayer,
+    DecoderOnlyModel,
     DecoderOnlyWrapper,
-
-    LinearScalingRotaryEmbedding,
-    RotaryEmbedding,
-    apply_rotary_pos_emb,
-    slice_and_unsqueeze_cos_sin,
+    apply_rotary_pos_emb_partial,
 )


-
-
-        if self.rope_scaling is None:
-            rotary_emb = RotaryEmbedding(
-                int(self.config.partial_rotary_factor * self.head_dim),
-                max_position_embeddings=self.max_position_embeddings,
-                base=self.config.rope_theta,
-            )
-        else:
-            scaling_type = self.rope_scaling["type"]
-            scaling_factor = self.rope_scaling["factor"]
-            if scaling_type == "linear":
-                rotary_emb = LinearScalingRotaryEmbedding(
-                    int(self.config.partial_rotary_factor * self.head_dim),
-                    max_position_embeddings=self.max_position_embeddings,
-                    scaling_factor=scaling_factor,
-                    base=self.config.rope_theta,
-                    max_seq_len=self.max_seq_len,
-                )
-            elif scaling_type == "dynamic":
-                rotary_emb = DynamicNTKScalingRotaryEmbedding(
-                    int(self.config.partial_rotary_factor * self.head_dim),
-                    max_position_embeddings=self.max_position_embeddings,
-                    scaling_factor=scaling_factor,
-                    base=self.config.rope_theta,
-                    max_seq_len=self.max_seq_len,
-                )
-            else:
-                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
-
-        return rotary_emb
-
-    def get_forward_dict(self):
-        forward_dict = {}
-        forward_dict.update(
-            {
-                "wrapper": PhiModel.forward,
-                "model": PhiDecoderLayer.forward,
-                "decoder_layer": PhiAttention.forward,
-            }
-        )
-        return forward_dict
-
+if TYPE_CHECKING:
+    from transformers import PhiForCausalLM

-class PhiAttention:
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        past_key_value: Optional[RebelDynamicCache] = None,
-        batch_index: Optional[int] = None,
-        output_attentions: bool = False,
-        cos: Optional[torch.Tensor] = None,
-        sin: Optional[torch.Tensor] = None,
-        rotary_pos_emb=None,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        bsz, q_len, _ = hidden_states.size()

+class PhiWrapper(DecoderOnlyWrapper):
+    def convert_to_rbln_causal_lm(self, causal_lm: "PhiForCausalLM"):
+        new_layers = []
+        for layer in causal_lm.model.layers:
+            if self.attn_impl == "eager":
+                new_self_attn = PhiAttention(layer.self_attn)
+            elif self.attn_impl == "flash_attn":
+                raise NotImplementedError(f"flash attn for {self.__class__} is not implemented yet.")
+            else:
+                raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
+            new_layer = PhiLayer(layer, new_self_attn)
+            new_layers.append(new_layer)
+        new_model = PhiModel(causal_lm.model, new_layers)
+        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
+        return new_causal_lm
+
+
+class PhiAttention(DecoderOnlyAttention):
+    def __post_init__(self):
+        self.q_proj = self._original_mod.q_proj
+        self.k_proj = self._original_mod.k_proj
+        self.v_proj = self._original_mod.v_proj
+        self.o_proj = self._original_mod.dense
+        self.qk_layernorm = self._original_mod.qk_layernorm
+        self.rotary_ndims = self._original_mod.rotary_ndims
+        self.num_key_value_heads = self.num_heads
+
+    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)

         if self.qk_layernorm:
-            query_states = self.q_layernorm(query_states)
-            key_states = self.k_layernorm(key_states)
-
-        query_states = query_states.view(
-            bsz, q_len, self.num_heads, self.head_dim
-        ).transpose(1, 2)
-        key_states = key_states.view(
-            bsz, q_len, self.num_key_value_heads, self.head_dim
-        ).transpose(1, 2)
-        value_states = value_states.view(
-            bsz, q_len, self.num_key_value_heads, self.head_dim
-        ).transpose(1, 2)
-
-        # Partial rotary embedding
-        query_rot, query_pass = (
-            query_states[..., : rotary_pos_emb.dim],
-            query_states[..., rotary_pos_emb.dim :],
-        )
-        key_rot, key_pass = (
-            key_states[..., : rotary_pos_emb.dim],
-            key_states[..., rotary_pos_emb.dim :],
-        )
-
-        # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
-        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
-
-        # [batch_size, seq_length, num_heads, head_dim]
-        query_states = torch.cat((query_rot, query_pass), dim=-1)
-        key_states = torch.cat((key_rot, key_pass), dim=-1)
-
-        # Decoder
-        if (batch_index is None or batch_index == -1) and bsz > 1:
-            all_key_states = []
-            all_value_states = []
-            all_attn_output = []
-
-            for b in range(bsz):
-                query_state = query_states[b].unsqueeze(0)
-                attn_mask = attention_mask[b].unsqueeze(0)
-                key_state = key_states[b].unsqueeze(0)
-                value_state = value_states[b].unsqueeze(0)
-
-                # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
-                key_state = key_state.unsqueeze(2)
-                value_state = value_state.unsqueeze(2)
-                attn_mask = attn_mask.unsqueeze(2)
-
-                query_state = query_state.view(
-                    1,
-                    self.num_key_value_heads,
-                    self.num_heads // self.num_key_value_heads,
-                    q_len,
-                    self.head_dim,
-                )
-
-                key_state, value_state = past_key_value.update(
-                    key_state,
-                    value_state,
-                    self.layer_idx,
-                    b,
-                )
-
-                # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
-                attn_weights = torch.matmul(
-                    query_state.to(torch.float32),
-                    key_state.to(torch.float32).transpose(3, 4),
-                ) / math.sqrt(self.head_dim)
-                attn_weights = attn_weights + attn_mask
-
-                # upcast attention to fp32
-                attn_weights = nn.functional.softmax(
-                    attn_weights, dim=-1, dtype=torch.float32
-                ).to(query_states.dtype)
-                attn_weights = nn.functional.dropout(
-                    attn_weights, p=self.attention_dropout, training=self.training
-                )
-                attn_output = torch.matmul(attn_weights, value_state)
+            query_states = self._original_mod.q_layernorm(query_states)
+            key_states = self._original_mod.k_layernorm(key_states)

-
-                attn_output = attn_output.view(1, self.num_heads, q_len, self.head_dim)
-                attn_output = attn_output.transpose(1, 2).contiguous()
-                attn_output = attn_output.reshape(
-                    1, q_len, self.num_heads * self.head_dim
-                )
+        return query_states, key_states, value_states

-
-
-                all_attn_output.append(attn_output)
+    def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
+        return apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim=self.rotary_ndims)

-            key_states = torch.cat(all_key_states, dim=0)
-            value_states = torch.cat(all_value_states, dim=0)
-            attn_output = torch.cat(all_attn_output, dim=0)
-        else:
-            if batch_index is None or batch_index == -1:
-                batch_index = 0

-
-
-
-            attention_mask = attention_mask.unsqueeze(2)
-            query_states = query_states.view(
-                1,
-                self.num_key_value_heads,
-                self.num_heads // self.num_key_value_heads,
-                q_len,
-                self.head_dim,
-            )
+class PhiLayer(DecoderOnlyLayer):
+    def get_post_attention_layernorm(self):
+        raise NotImplementedError

-            key_states, value_states = past_key_value.update(
-                key_states,
-                value_states,
-                self.layer_idx,
-                batch_index,
-                read_first_step=True,
-            )
-
-            # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
-            attn_weights = torch.matmul(
-                query_states.to(torch.float32),
-                key_states.to(torch.float32).transpose(3, 4),
-            ) / math.sqrt(self.head_dim)
-            attn_weights = attn_weights + attention_mask
-
-            # upcast attention to fp32
-            attn_weights = torch.nn.functional.softmax(
-                attn_weights, dim=-1, dtype=torch.float32
-            ).to(value_states.dtype)
-            attn_weights = torch.nn.functional.dropout(
-                attn_weights, p=self.attention_dropout, training=self.training
-            )
-            attn_output = torch.matmul(attn_weights, value_states)
-
-            # reshape for removing repeat_kv
-            attn_output = attn_output.view(1, self.num_heads, q_len, self.head_dim)
-            attn_output = attn_output.transpose(1, 2).contiguous()
-            attn_output = attn_output.reshape(
-                bsz, q_len, self.num_heads * self.head_dim
-            )
-
-        attn_output = self.dense(attn_output)
-
-        if not output_attentions:
-            attn_weights = None
-
-        return attn_output, attn_weights, key_states, value_states
-
-
-class PhiDecoderLayer:
     def forward(
         self,
         hidden_states: torch.Tensor,
-
-
-
-
-        output_attentions: Optional[bool] = None,
-        use_cache: Optional[bool] = None,
-        batch_ids: Optional[torch.LongTensor] = None,
+        attention_mask: torch.Tensor,
+        current_steps: torch.LongTensor,
+        batch_position: torch.Tensor,
+        past_key_values: Tuple[Tuple[torch.Tensor]],
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
-
-        forward_dict: Optional[Dict[str, classmethod]] = None,
-        **kwargs,
-    ) -> Tuple[
-        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
-    ]:
-        """
-        Args:
-            hidden_states (`torch.FloatTensor`):
-                input to the layer of shape `(batch, seq_len, embed_dim)`
-            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
-                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
-            position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
-                Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
-                `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
-            use_cache (`bool`, *optional*):
-                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
-                (see `past_key_values`).
-            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
-        """
-
+    ):
         residual = hidden_states

-        hidden_states = self.
+        hidden_states = self.get_pre_attention_layernorm()(hidden_states)

-
-        attn_outputs, self_attn_weights, key_states, value_states = forward_dict[
-            "decoder_layer"
-        ](
-            self.self_attn,
+        attn_outputs, present_key_values = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
-
-
-
-            batch_index=batch_ids,
-            use_cache=use_cache,
+            current_steps=current_steps,
+            batch_position=batch_position,
+            past_key_values=past_key_values,
             cos=cos,
             sin=sin,
-            rotary_pos_emb=rotary_pos_emb,
-            **kwargs,
         )
-        past_key_value.assign(key_states, value_states, layer_idx)

-
+        feed_forward_hidden_states = self._original_mod.mlp(hidden_states)

-        feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
         hidden_states = attn_outputs + feed_forward_hidden_states + residual
-        outputs = (hidden_states,)
-
-        if output_attentions:
-            outputs += (self_attn_weights,)
-
-        if use_cache:
-            outputs += (past_key_value,)
-
-        return outputs
-

-
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[RebelDynamicCache] = None,
-        batch_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = True,
-        output_attentions: Optional[bool] = False,
-        output_hidden_states: Optional[bool] = False,
-        forward_dict: Optional[Dict[str, classmethod]] = None,
-        rotary_pos_emb=None,
-    ) -> BaseModelOutputWithPast:
-        # retrieve input_ids and inputs_embeds
-        if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError(
-                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
-            )
-
-        # embed positions
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
-
-        hidden_states = inputs_embeds
-        attention_mask = (1 - attention_mask) * torch.finfo(torch.float16).min
-
-        # get cos,sin vector
-        cos, sin = rotary_pos_emb(inputs_embeds, attention_mask.shape[-1])
-        cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
-
-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-
-        for layer_idx, decoder_layer in enumerate(self.layers):
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-            layer_outputs = forward_dict["model"](
-                decoder_layer,
-                hidden_states,
-                layer_idx,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_values,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
-                batch_ids=batch_ids,
-                cos=cos,
-                sin=sin,
-                rotary_pos_emb=rotary_pos_emb,
-                forward_dict=forward_dict,
-            )
+        return hidden_states, present_key_values

-            hidden_states = layer_outputs[0]

-
-
-                all_self_attns += (layer_outputs[1],)
-
-        hidden_states = self.final_layernorm(hidden_states)
-
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        # convert RebelDynamicCache to legacy Tuple[Tuple[torch.Tensor]]
-        next_cache = updated_cache.to_legacy_cache()
-
-        return BaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=next_cache,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
-        )
+class PhiModel(DecoderOnlyModel):
+    def get_last_layernorm(self):
+        return self._original_mod.final_layernorm
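In the rewritten attention, the hand-rolled `query_rot`/`query_pass` split above is replaced by a single call to `apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim=self.rotary_ndims)`. The sketch below shows what partial rotary application amounts to, following the slicing logic of the removed code; the real helper's exact signature and cos/sin layout are assumptions.

```python
# Sketch of partial RoPE: rotate only the first `ndim` channels of each head
# (ndim = rotary_ndims, i.e. partial_rotary_factor * head_dim for Phi) and
# pass the remaining channels through unchanged. Assumed behaviour, shown for
# illustration only.
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Standard RoPE helper: split the last dim in half and return (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_partial_sketch(q, k, cos, sin, ndim):
    q_rot, q_pass = q[..., :ndim], q[..., ndim:]
    k_rot, k_pass = k[..., :ndim], k[..., ndim:]
    q_rot = q_rot * cos + rotate_half(q_rot) * sin
    k_rot = k_rot * cos + rotate_half(k_rot) * sin
    return torch.cat((q_rot, q_pass), dim=-1), torch.cat((k_rot, k_pass), dim=-1)


# Shapes: [batch, heads, seq, head_dim]; cos/sin cover only the rotary slice.
q = torch.randn(1, 2, 8, 64)
k = torch.randn(1, 2, 8, 64)
cos = torch.ones(1, 1, 8, 32)
sin = torch.zeros(1, 1, 8, 32)
q_out, k_out = apply_rotary_partial_sketch(q, k, cos, sin, ndim=32)
```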
optimum/rbln/transformers/models/qwen2/modeling_qwen2.py

@@ -21,28 +21,18 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-import
-import
-from typing import TYPE_CHECKING, Any, Callable
-
-from transformers import Qwen2ForCausalLM
-
-from ..decoderonly import RBLNDecoderOnlyModelForCausalLM
+from ....utils import logging
+from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .qwen2_architecture import QWEN2Wrapper


-
-from transformers import PreTrainedModel
-
-from ....modeling_config import RBLNConfig
-
-logger = logging.getLogger(__name__)
+logger = logging.get_logger(__name__)


 class RBLNQwen2ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """
     The Llama Model transformer with a language modeling head (linear layer) on top.
-    This model inherits from [`
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.

     A class to convert and run pre-trained transformers based LlamaForCausalLM model on RBLN devices.
     It implements the methods to convert a pre-trained transformers LlamaForCausalLM model into a RBLN transformer model by:
@@ -50,18 +40,4 @@ class RBLNQwen2ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     - compiling the resulting graph using the RBLN compiler.
     """

-
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-        return QWEN2Wrapper(model, rbln_max_seq_len).eval()
-
-    def __getattr__(self, __name: str) -> Any:
-        def redirect(func):
-            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-        val = getattr(Qwen2ForCausalLM, __name)
-
-        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-            return redirect(val)
-
-        return val
+    _decoder_wrapper_cls = QWEN2Wrapper
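This file drops the same `__getattr__` "redirect" boilerplate as modeling_phi.py in favor of the shared base class. For reference, the self-contained sketch below isolates what that removed trick did: attribute lookups that miss on the RBLN class fall back to the Hugging Face class, and plain functions whose first parameter is `self` are re-bound to the RBLN instance. The stand-in class names are hypothetical.

```python
# Self-contained sketch of the removed __getattr__ redirection (stand-in
# classes; the real code delegated to Qwen2ForCausalLM / PhiForCausalLM).
import inspect
from typing import Any, Callable


class HFStandIn:
    """Plays the role of the Hugging Face *ForCausalLM class."""

    def prepare_inputs_for_generation(self, input_ids):
        return {"input_ids": input_ids}


class RBLNStandIn:
    def __getattr__(self, __name: str) -> Any:
        def redirect(func):
            # Re-bind the borrowed function so that `self` is the RBLN model.
            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)

        val = getattr(HFStandIn, __name)
        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
            return redirect(val)
        return val


model = RBLNStandIn()
assert model.prepare_inputs_for_generation([1, 2, 3]) == {"input_ids": [1, 2, 3]}
```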
optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py

@@ -31,7 +31,7 @@ import torch # noqa: F401
 from transformers import AutoModelForSeq2SeqLM, GenerationConfig, PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput

-from ....
+from ....modeling import RBLNModel
 from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
 from ....utils.runtime_utils import RBLNPytorchRuntime

@@ -346,51 +346,6 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):

         return Seq2SeqLMOutput(logits=lm_logits)

-    def vllm_forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        cache_position: Union[List[torch.Tensor], torch.Tensor] = None, # vllm keyword argument
-        batch_idx: Optional[torch.LongTensor] = None,
-        enc_lengths: List[int] = None, # vllm return current attention_mask length
-        **kwargs,
-    ) -> Tuple[torch.FloatTensor]:
-        # When using vllm, need the output of the encoder (ex. vocab_size + 100) and use that value act as start_token_id in decoder (ex. vocab_size + 99)
-        # encoder
-        if batch_idx is not None:
-            enc_attention_mask = torch.zeros(1, self.rbln_config.model_cfg["enc_max_seq_len"], dtype=torch.float32)
-            enc_attention_mask[0][: enc_lengths[batch_idx] + 1] = 1
-            padding_need = self.rbln_config.model_cfg["enc_max_seq_len"] - input_ids.shape[-1]
-            input_ids = torch.nn.functional.pad(input_ids, (0, padding_need))
-            _ = self.encoder(input_ids, enc_attention_mask, batch_idx=batch_idx.to(torch.int32))
-            logits = torch.zeros(1, 1, self.config.vocab_size + 100)
-            logits[0][0][-1] = 1
-        # decoder
-        else:
-            input_ids[input_ids == (self.config.vocab_size + 99)] = self.config.decoder_start_token_id
-            cache_position[cache_position != 0] = cache_position[cache_position != 0] - 2
-
-            enc_attention_mask = torch.zeros(
-                self.rbln_config.model_cfg["batch_size"],
-                self.rbln_config.model_cfg["enc_max_seq_len"],
-                dtype=torch.float32,
-            )
-            dec_attention_mask = torch.zeros(
-                self.rbln_config.model_cfg["batch_size"],
-                self.rbln_config.model_cfg["dec_max_seq_len"],
-                dtype=torch.float32,
-            )
-            for batch_idx in range(self.rbln_config.model_cfg["batch_size"]):
-                enc_attention_mask[batch_idx, : enc_lengths[batch_idx] + 1] = 1
-
-            logits = self._forward_decoder(
-                attention_mask=enc_attention_mask,
-                decoder_input_ids=input_ids,
-                decoder_attention_mask=dec_attention_mask,
-                cache_position=cache_position,
-            ).logits
-
-        return Seq2SeqLMOutput(logits=logits)
-
     def _prepare_encoder_decoder_kwargs_for_generation(
         self,
         inputs_tensor: torch.Tensor,
optimum/rbln/transformers/models/t5/__init__.py

@@ -21,5 +21,5 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-from .modeling_t5 import RBLNT5ForConditionalGeneration
+from .modeling_t5 import RBLNT5EncoderModel, RBLNT5ForConditionalGeneration
 from .t5_architecture import T5DecoderWrapper, T5EncoderWrapper