optimum-rbln 0.1.12__py3-none-any.whl → 0.1.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +5 -1
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +30 -61
- optimum/rbln/diffusers/models/controlnet.py +36 -56
- optimum/rbln/diffusers/models/unet_2d_condition.py +57 -153
- optimum/rbln/diffusers/pipelines/__init__.py +40 -12
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -185
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -190
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +8 -191
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -192
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -110
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +12 -115
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +4 -122
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +4 -125
- optimum/rbln/modeling_base.py +12 -5
- optimum/rbln/modeling_diffusers.py +400 -0
- optimum/rbln/transformers/__init__.py +2 -0
- optimum/rbln/transformers/cache_utils.py +5 -9
- optimum/rbln/transformers/modeling_rope_utils.py +283 -0
- optimum/rbln/transformers/models/__init__.py +80 -31
- optimum/rbln/transformers/models/clip/modeling_clip.py +13 -22
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -2
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +376 -218
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +74 -16
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +18 -9
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -29
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +4 -30
- optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +27 -8
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -15
- optimum/rbln/transformers/models/midm/modeling_midm.py +4 -29
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
- optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
- optimum/rbln/transformers/models/phi/phi_architecture.py +75 -159
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +5 -29
- optimum/rbln/transformers/models/t5/__init__.py +1 -1
- optimum/rbln/transformers/models/t5/modeling_t5.py +57 -4
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +1 -1
- optimum/rbln/transformers/utils/rbln_quantization.py +8 -2
- optimum/rbln/utils/context.py +58 -0
- optimum/rbln/utils/decorator_utils.py +55 -0
- optimum/rbln/utils/import_utils.py +7 -0
- optimum/rbln/utils/runtime_utils.py +4 -4
- optimum/rbln/utils/timer_utils.py +2 -2
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/METADATA +8 -7
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/RECORD +52 -48
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/phi/modeling_phi.py
@@ -21,28 +21,18 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-import
-import
-from typing import TYPE_CHECKING, Any, Callable
-
-from transformers import PhiForCausalLM
-
-from ..decoderonly import RBLNDecoderOnlyModelForCausalLM
+from ....utils import logging
+from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .phi_architecture import PhiWrapper
 
 
-
-from transformers import PreTrainedModel
-
-from ....modeling_config import RBLNConfig
-
-logger = logging.getLogger(__name__)
+logger = logging.get_logger(__name__)
 
 
 class RBLNPhiForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """
     The Phi Model transformer with a language modeling head (linear layer) on top.
-    This model inherits from [`
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
 
     A class to convert and run pre-trained transformers based PhiForCausalLM model on RBLN devices.
     It implements the methods to convert a pre-trained transformers PhiForCausalLM model into a RBLN transformer model by:
@@ -50,20 +40,4 @@ class RBLNPhiForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     - compiling the resulting graph using the RBLN compiler.
     """
 
-
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-        return PhiWrapper(model, rbln_max_seq_len).eval()
-
-    def __getattr__(self, __name: str) -> Any:
-        def redirect(func):
-            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-        val = getattr(PhiForCausalLM, __name)
-
-        if isinstance(val, Callable) and "self" in set(
-            inspect.signature(val).parameters
-        ):
-            return redirect(val)
-
-        return val
+    _decoder_wrapper_cls = PhiWrapper
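The hunks above replace Phi's copied `wrap_model_if_needed`/`__getattr__` boilerplate with a single class attribute, `_decoder_wrapper_cls = PhiWrapper`. A minimal sketch of how a shared base class could consume such a hook, assuming a hypothetical `wrap_model_if_needed` on the base class (the actual `RBLNDecoderOnlyModelForCausalLM` internals are not part of this diff):

```python
# Hypothetical sketch only: mirrors the behavior of the removed per-model
# override, written once against the _decoder_wrapper_cls attribute.
class DecoderOnlyBaseSketch:
    _decoder_wrapper_cls = None  # e.g. PhiWrapper in RBLNPhiForCausalLM

    def wrap_model_if_needed(self, model, rbln_config):
        max_seq_len = rbln_config.model_cfg["max_seq_len"]
        # Same call the removed Phi-specific override made: wrap, then eval().
        return self._decoder_wrapper_cls(model, max_seq_len).eval()
```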
optimum/rbln/transformers/models/phi/phi_architecture.py
@@ -33,46 +33,12 @@ from transformers.modeling_outputs import (
 from ...cache_utils import RebelDynamicCache
 from ..decoderonly import (
     DecoderOnlyWrapper,
-    DynamicNTKScalingRotaryEmbedding,
-    LinearScalingRotaryEmbedding,
-    RotaryEmbedding,
     apply_rotary_pos_emb,
     slice_and_unsqueeze_cos_sin,
 )
 
 
 class PhiWrapper(DecoderOnlyWrapper):
-    def _init_rope(self):
-        if self.rope_scaling is None:
-            rotary_emb = RotaryEmbedding(
-                int(self.config.partial_rotary_factor * self.head_dim),
-                max_position_embeddings=self.max_position_embeddings,
-                base=self.config.rope_theta,
-            )
-        else:
-            scaling_type = self.rope_scaling["type"]
-            scaling_factor = self.rope_scaling["factor"]
-            if scaling_type == "linear":
-                rotary_emb = LinearScalingRotaryEmbedding(
-                    int(self.config.partial_rotary_factor * self.head_dim),
-                    max_position_embeddings=self.max_position_embeddings,
-                    scaling_factor=scaling_factor,
-                    base=self.config.rope_theta,
-                    max_seq_len=self.max_seq_len,
-                )
-            elif scaling_type == "dynamic":
-                rotary_emb = DynamicNTKScalingRotaryEmbedding(
-                    int(self.config.partial_rotary_factor * self.head_dim),
-                    max_position_embeddings=self.max_position_embeddings,
-                    scaling_factor=scaling_factor,
-                    base=self.config.rope_theta,
-                    max_seq_len=self.max_seq_len,
-                )
-            else:
-                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
-
-        return rotary_emb
-
     def get_forward_dict(self):
         forward_dict = {}
         forward_dict.update(
@@ -86,6 +52,41 @@ class PhiWrapper(DecoderOnlyWrapper):
 
 
 class PhiAttention:
+    def _attn(self, query_state, key_state, value_state, attn_mask, past_key_value, batch_idx=0, is_prefill=False):
+        # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
+        key_state = key_state.unsqueeze(2)
+        value_state = value_state.unsqueeze(2)
+        attn_mask = attn_mask.unsqueeze(2)
+
+        query_state = query_state.view(
+            1,
+            self.num_key_value_heads,
+            self.num_heads // self.num_key_value_heads,
+            -1,
+            self.head_dim,
+        )
+
+        key_state, value_state = past_key_value.update(key_state, value_state, self.layer_idx, batch_idx, is_prefill)
+
+        # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
+        attn_weights = torch.matmul(
+            query_state.to(torch.float32),
+            key_state.to(torch.float32).transpose(3, 4),
+        ) / math.sqrt(self.head_dim)
+        attn_weights = attn_weights + attn_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_state.dtype)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+        attn_output = torch.matmul(attn_weights, value_state)
+
+        # reshape for removing repeat_kv
+        attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)
+
+        return attn_output, key_state, value_state
+
     def forward(
         self,
         hidden_states: torch.Tensor,
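The new `PhiAttention._attn` helper avoids `repeat_kv` by reshaping the query into per-KV-head groups and letting `matmul` broadcast over the shared key/value tensors. A standalone illustration of that reshape (the sizes below are made up for the example):

```python
import torch

# Group query heads by KV head: [1, kv_heads, heads_per_kv, q_len, head_dim].
num_heads, num_kv_heads, head_dim, q_len, kv_len = 8, 4, 16, 1, 32
q = torch.randn(1, num_heads, q_len, head_dim)
k = torch.randn(1, num_kv_heads, kv_len, head_dim)

q_grouped = q.view(1, num_kv_heads, num_heads // num_kv_heads, q_len, head_dim)
k_grouped = k.unsqueeze(2)                      # [1, kv_heads, 1, kv_len, head_dim]
scores = q_grouped @ k_grouped.transpose(3, 4)  # broadcasts instead of repeat_kv
assert scores.shape == (1, num_kv_heads, num_heads // num_kv_heads, q_len, kv_len)
```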
@@ -95,7 +96,6 @@ class PhiAttention:
         output_attentions: bool = False,
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
-        rotary_pos_emb=None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
@@ -108,24 +108,18 @@ class PhiAttention:
         query_states = self.q_layernorm(query_states)
         key_states = self.k_layernorm(key_states)
 
-        query_states = query_states.view(
-
-        ).transpose(1, 2)
-        key_states = key_states.view(
-            bsz, q_len, self.num_key_value_heads, self.head_dim
-        ).transpose(1, 2)
-        value_states = value_states.view(
-            bsz, q_len, self.num_key_value_heads, self.head_dim
-        ).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
         # Partial rotary embedding
         query_rot, query_pass = (
-            query_states[..., :
-            query_states[...,
+            query_states[..., : self.rotary_ndims],
+            query_states[..., self.rotary_ndims :],
         )
         key_rot, key_pass = (
-            key_states[..., :
-            key_states[...,
+            key_states[..., : self.rotary_ndims],
+            key_states[..., self.rotary_ndims :],
         )
 
         # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
@@ -135,113 +129,38 @@ class PhiAttention:
         query_states = torch.cat((query_rot, query_pass), dim=-1)
         key_states = torch.cat((key_rot, key_pass), dim=-1)
 
-        # Decoder
-        if
-
-            all_value_states = []
-            all_attn_output = []
-
+        # Decoder (bsz > 1)
+        if bsz > 1:
+            iterate_results = {"key_states": [], "value_states": [], "attn_output": []}
             for b in range(bsz):
-
-
-
-
-
-
-
-
-
-
-                query_state = query_state.view(
-                    1,
-                    self.num_key_value_heads,
-                    self.num_heads // self.num_key_value_heads,
-                    q_len,
-                    self.head_dim,
+                attn_output, key_state, value_state = PhiAttention._attn(
+                    self,
+                    query_states[b].unsqueeze(0),
+                    key_states[b].unsqueeze(0),
+                    value_states[b].unsqueeze(0),
+                    attention_mask[b].unsqueeze(0),
+                    past_key_value,
+                    batch_idx=b,
+                    is_prefill=False,
                 )
-
-
-
-
-
-
-
-
-                # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
-                attn_weights = torch.matmul(
-                    query_state.to(torch.float32),
-                    key_state.to(torch.float32).transpose(3, 4),
-                ) / math.sqrt(self.head_dim)
-                attn_weights = attn_weights + attn_mask
-
-                # upcast attention to fp32
-                attn_weights = nn.functional.softmax(
-                    attn_weights, dim=-1, dtype=torch.float32
-                ).to(query_states.dtype)
-                attn_weights = nn.functional.dropout(
-                    attn_weights, p=self.attention_dropout, training=self.training
-                )
-                attn_output = torch.matmul(attn_weights, value_state)
-
-                # reshape for removing repeat_kv
-                attn_output = attn_output.view(1, self.num_heads, q_len, self.head_dim)
-                attn_output = attn_output.transpose(1, 2).contiguous()
-                attn_output = attn_output.reshape(
-                    1, q_len, self.num_heads * self.head_dim
-                )
-
-                all_key_states.append(key_state)
-                all_value_states.append(value_state)
-                all_attn_output.append(attn_output)
-
-            key_states = torch.cat(all_key_states, dim=0)
-            value_states = torch.cat(all_value_states, dim=0)
-            attn_output = torch.cat(all_attn_output, dim=0)
+                iterate_results["key_states"].append(key_state)
+                iterate_results["value_states"].append(value_state)
+                iterate_results["attn_output"].append(attn_output)
+
+            key_states = torch.cat(iterate_results["key_states"], dim=0)
+            value_states = torch.cat(iterate_results["value_states"], dim=0)
+            attn_output = torch.cat(iterate_results["attn_output"], dim=0)
+        # Prefill & Decoder (bsz == 1)
         else:
-
-
-
-            # reshape for removing repeat_kv
-            key_states = key_states.unsqueeze(2)
-            value_states = value_states.unsqueeze(2)
-            attention_mask = attention_mask.unsqueeze(2)
-            query_states = query_states.view(
-                1,
-                self.num_key_value_heads,
-                self.num_heads // self.num_key_value_heads,
-                q_len,
-                self.head_dim,
-            )
-
-            key_states, value_states = past_key_value.update(
+            attn_output, key_states, value_states = PhiAttention._attn(
+                self,
+                query_states,
                 key_states,
                 value_states,
-
-
-
-
-
-            # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
-            attn_weights = torch.matmul(
-                query_states.to(torch.float32),
-                key_states.to(torch.float32).transpose(3, 4),
-            ) / math.sqrt(self.head_dim)
-            attn_weights = attn_weights + attention_mask
-
-            # upcast attention to fp32
-            attn_weights = torch.nn.functional.softmax(
-                attn_weights, dim=-1, dtype=torch.float32
-            ).to(value_states.dtype)
-            attn_weights = torch.nn.functional.dropout(
-                attn_weights, p=self.attention_dropout, training=self.training
-            )
-            attn_output = torch.matmul(attn_weights, value_states)
-
-            # reshape for removing repeat_kv
-            attn_output = attn_output.view(1, self.num_heads, q_len, self.head_dim)
-            attn_output = attn_output.transpose(1, 2).contiguous()
-            attn_output = attn_output.reshape(
-                bsz, q_len, self.num_heads * self.head_dim
+                attention_mask,
+                past_key_value,
+                batch_idx=batch_index,
+                is_prefill=True,
             )
 
         attn_output = self.dense(attn_output)
@@ -265,12 +184,9 @@ class PhiDecoderLayer:
         batch_ids: Optional[torch.LongTensor] = None,
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
-        rotary_pos_emb=None,
         forward_dict: Optional[Dict[str, classmethod]] = None,
         **kwargs,
-    ) -> Tuple[
-        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
-    ]:
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
             hidden_states (`torch.FloatTensor`):
@@ -294,9 +210,7 @@ class PhiDecoderLayer:
         hidden_states = self.input_layernorm(hidden_states)
 
         # Self Attention
-        attn_outputs, self_attn_weights, key_states, value_states = forward_dict[
-            "decoder_layer"
-        ](
+        attn_outputs, self_attn_weights, key_states, value_states = forward_dict["decoder_layer"](
             self.self_attn,
             hidden_states=hidden_states,
             attention_mask=attention_mask,
@@ -307,7 +221,6 @@ class PhiDecoderLayer:
             use_cache=use_cache,
             cos=cos,
             sin=sin,
-            rotary_pos_emb=rotary_pos_emb,
             **kwargs,
         )
         past_key_value.assign(key_states, value_states, layer_idx)
@@ -339,6 +252,8 @@ class PhiModel:
         use_cache: Optional[bool] = True,
         output_attentions: Optional[bool] = False,
         output_hidden_states: Optional[bool] = False,
+        cache_pos_for_partitions: Optional[torch.Tensor] = None,
+        kvcache_partition_size: Optional[torch.Tensor] = None,
         forward_dict: Optional[Dict[str, classmethod]] = None,
         rotary_pos_emb=None,
     ) -> BaseModelOutputWithPast:
@@ -378,7 +293,8 @@ class PhiModel:
                 batch_ids=batch_ids,
                 cos=cos,
                 sin=sin,
-
+                cache_pos_for_partitions=cache_pos_for_partitions,
+                kvcache_partition_size=kvcache_partition_size,
                 forward_dict=forward_dict,
             )
 
optimum/rbln/transformers/models/qwen2/modeling_qwen2.py
@@ -21,28 +21,18 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-import
-import
-from typing import TYPE_CHECKING, Any, Callable
-
-from transformers import Qwen2ForCausalLM
-
-from ..decoderonly import RBLNDecoderOnlyModelForCausalLM
+from ....utils import logging
+from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .qwen2_architecture import QWEN2Wrapper
 
 
-
-from transformers import PreTrainedModel
-
-from ....modeling_config import RBLNConfig
-
-logger = logging.getLogger(__name__)
+logger = logging.get_logger(__name__)
 
 
 class RBLNQwen2ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """
     The Llama Model transformer with a language modeling head (linear layer) on top.
-    This model inherits from [`
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
 
     A class to convert and run pre-trained transformers based LlamaForCausalLM model on RBLN devices.
     It implements the methods to convert a pre-trained transformers LlamaForCausalLM model into a RBLN transformer model by:
@@ -50,18 +40,4 @@ class RBLNQwen2ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     - compiling the resulting graph using the RBLN compiler.
     """
 
-
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-        return QWEN2Wrapper(model, rbln_max_seq_len).eval()
-
-    def __getattr__(self, __name: str) -> Any:
-        def redirect(func):
-            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-        val = getattr(Qwen2ForCausalLM, __name)
-
-        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-            return redirect(val)
-
-        return val
+    _decoder_wrapper_cls = QWEN2Wrapper
optimum/rbln/transformers/models/t5/__init__.py
@@ -21,5 +21,5 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-from .modeling_t5 import RBLNT5ForConditionalGeneration
+from .modeling_t5 import RBLNT5EncoderModel, RBLNT5ForConditionalGeneration
 from .t5_architecture import T5DecoderWrapper, T5EncoderWrapper
optimum/rbln/transformers/models/t5/modeling_t5.py
@@ -22,11 +22,16 @@
 # from Rebellions Inc.
 
 import inspect
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
 
-from transformers import
+from transformers import (
+    AutoModelForTextEncoding,
+    PretrainedConfig,
+    T5ForConditionalGeneration,
+)
 
-from ....
+from ....modeling_base import RBLNModel
+from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
 from ...models.seq2seq import RBLNModelForSeq2SeqLM
 from .t5_architecture import T5Wrapper
@@ -35,7 +40,55 @@ from .t5_architecture import T5Wrapper
 logger = get_logger()
 
 if TYPE_CHECKING:
-    from transformers import PreTrainedModel
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
+
+
+class RBLNT5EncoderModel(RBLNModel):
+    auto_model_class = AutoModelForTextEncoding
+
+    @classmethod
+    def _get_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_kwargs: Dict[str, Any] = {},
+    ) -> RBLNConfig:
+        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+
+        max_position_embeddings = getattr(model_config, "n_positions", None)
+
+        if rbln_max_seq_len is None:
+            rbln_max_seq_len = max_position_embeddings
+            if rbln_max_seq_len is None:
+                for tokenizer in preprocessors:
+                    if hasattr(tokenizer, "model_max_length"):
+                        rbln_max_seq_len = tokenizer.model_max_length
+                        break
+                if rbln_max_seq_len is None:
+                    raise ValueError("`rbln_max_seq_len` should be specified!")
+
+        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
+            raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
+
+        if rbln_batch_size is None:
+            rbln_batch_size = 1
+
+        input_info = [
+            ("input_ids", [rbln_batch_size, rbln_max_seq_len], "int64"),
+            ("attention_mask", [rbln_batch_size, rbln_max_seq_len], "int64"),
+        ]
+
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+
+        rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
+        return rbln_config
 
 
 class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
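The new `RBLNT5EncoderModel` is loaded like other optimum-rbln models; a hedged usage sketch, assuming the class is re-exported at package level and that `rbln_`-prefixed keyword arguments are forwarded into `rbln_kwargs` (checkpoint name and values are illustrative, and compilation requires the RBLN toolchain):

```python
from optimum.rbln import RBLNT5EncoderModel  # assumed top-level re-export

encoder = RBLNT5EncoderModel.from_pretrained(
    "t5-small",            # illustrative checkpoint
    export=True,           # compile the PyTorch model for RBLN NPUs
    rbln_max_seq_len=512,  # else falls back to config.n_positions, then tokenizer.model_max_length
    rbln_batch_size=1,     # defaults to 1 when omitted
)
```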
optimum/rbln/transformers/models/whisper/modeling_whisper.py
@@ -102,7 +102,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
 class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin):
     """
     The Whisper Model with a language modeling head. Can be used for automatic speech recognition.
-    This model inherits from [`
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
 
     A class to convert and run pre-trained transformers based LlamaForCausalLM model on RBLN devices.
     It implements the methods to convert a pre-trained transformers LlamaForCausalLM model into a RBLN transformer model by:
optimum/rbln/transformers/utils/rbln_quantization.py
@@ -31,8 +31,13 @@ from torch.nn import functional as F
 
 # Constants
 QUANTIZED_WEIGHTS = {
-    "q_proj",
-    "
+    "q_proj",
+    "k_proj",
+    "v_proj",
+    "o_proj",
+    "gate_proj",
+    "up_proj",
+    "down_proj",
 }
 
 
@@ -81,6 +86,7 @@ def create_qlinear(layer: Linear) -> Linear:
     """
     Converts a standard linear layer to a quantized linear (qlinear) layer with a custom forward pass.
     """
+
     def qlinear_forward(self, inputs: torch.Tensor) -> torch.Tensor:
         if inputs.dtype != self.scales.dtype:
             raise TypeError(f"Expected input dtype {self.scales.dtype}, but got {inputs.dtype}")
optimum/rbln/utils/context.py (new file)
@@ -0,0 +1,58 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Union
+
+from optimum.exporters import TasksManager
+from transformers import AutoConfig, AutoModel
+
+
+@contextmanager
+def override_auto_classes(config_func=None, model_func=None, skip_taskmanager=True):
+    """Temporarily override Auto classes with original model classes"""
+    original_config = AutoConfig.from_pretrained
+    original_model = AutoModel.from_pretrained
+    original_get_model_from_task = TasksManager.get_model_from_task
+
+    def get_model_from_task(
+        task: str,
+        model_name_or_path: Union[str, Path],
+        **kwargs,
+    ):
+        return model_func(model_name_or_path, **kwargs)
+
+    def none_func(*args, **kwargs):
+        return None
+
+    try:
+        AutoConfig.from_pretrained = config_func or none_func
+        AutoModel.from_pretrained = model_func or none_func
+        if skip_taskmanager:
+            TasksManager.get_model_from_task = none_func if model_func is None else get_model_from_task
+        yield
+    finally:
+        AutoConfig.from_pretrained = original_config
+        AutoModel.from_pretrained = original_model
+        TasksManager.get_model_from_task = original_get_model_from_task
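A usage sketch for the new `override_auto_classes` helper: inside the block, `AutoModel.from_pretrained` (and the `TasksManager.get_model_from_task` path) dispatch to the concrete loader passed as `model_func`, while `AutoConfig.from_pretrained` is stubbed out when no `config_func` is given. The T5 loader below is chosen only for illustration:

```python
from transformers import AutoModel, T5EncoderModel

from optimum.rbln.utils.context import override_auto_classes

with override_auto_classes(model_func=T5EncoderModel.from_pretrained):
    # Resolves through the patched AutoModel, so a concrete T5EncoderModel comes back.
    model = AutoModel.from_pretrained("t5-small")
```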
optimum/rbln/utils/decorator_utils.py (new file)
@@ -0,0 +1,55 @@
+from functools import wraps
+
+from .logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
+def remove_compile_time_kwargs(func):
+    """
+    Decorator to handle compile-time parameters during inference.
+
+    For RBLN-optimized pipelines, several parameters must be determined during compilation
+    and cannot be modified during inference. This decorator:
+    1. Removes and warns about LoRA scale in cross_attention_kwargs
+    2. Removes and warns about image dimension parameters (height, width)
+
+    Args:
+        func: The pipeline's __call__ method to be wrapped
+    """
+
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        height_exists = "height" in kwargs and kwargs["height"] is not None
+        width_exists = "width" in kwargs and kwargs["width"] is not None
+        if height_exists or width_exists:
+            logger.warning(
+                "Image dimension parameters (`height`, `width`) will be ignored during inference. "
+                "Image dimensions must be specified during model compilation using from_pretrained()."
+            )
+            kwargs.pop("width", None)
+            kwargs.pop("height", None)
+
+        if "cross_attention_kwargs" in kwargs:
+            cross_attention_kwargs = kwargs.get("cross_attention_kwargs")
+            if not cross_attention_kwargs:
+                return func(self, *args, **kwargs)
+
+            has_scale = "scale" in cross_attention_kwargs
+            if has_scale:
+                logger.warning(
+                    "LoRA scale in cross_attention_kwargs will be ignored during inference. "
+                    "To adjust LoRA scale, specify it during model compilation using from_pretrained()."
+                )
+
+                # If scale is the only key, set to None
+                # Otherwise, remove scale and preserve other settings
+                if len(cross_attention_kwargs) == 1:
+                    kwargs["cross_attention_kwargs"] = None
+                else:
+                    kwargs["cross_attention_kwargs"].pop("scale")
+
+        return func(self, *args, **kwargs)
+
+    return wrapper
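A usage sketch for `remove_compile_time_kwargs`; the pipeline class below is hypothetical and only demonstrates how compile-time-only arguments are stripped before the wrapped `__call__` runs:

```python
from optimum.rbln.utils.decorator_utils import remove_compile_time_kwargs


class FakeRBLNPipeline:
    @remove_compile_time_kwargs
    def __call__(self, prompt, cross_attention_kwargs=None, **kwargs):
        # height/width were popped (with a warning) and "scale" was removed from
        # cross_attention_kwargs, because those values are fixed at compile time.
        return {"prompt": prompt, "cross_attention_kwargs": cross_attention_kwargs}


pipe = FakeRBLNPipeline()
out = pipe("a cat", height=768, width=768, cross_attention_kwargs={"scale": 0.5, "other_key": 1.0})
assert out["cross_attention_kwargs"] == {"other_key": 1.0}
```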