optimum-rbln 0.1.9__py3-none-any.whl → 0.1.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +47 -9
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +36 -31
- optimum/rbln/diffusers/models/controlnet.py +53 -43
- optimum/rbln/diffusers/models/unet_2d_condition.py +40 -31
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +4 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +28 -23
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +28 -23
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +28 -37
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +30 -39
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +24 -14
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +24 -15
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +26 -17
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -17
- optimum/rbln/modeling_alias.py +6 -11
- optimum/rbln/modeling_base.py +467 -261
- optimum/rbln/modeling_config.py +199 -73
- optimum/rbln/transformers/__init__.py +43 -1
- optimum/rbln/transformers/models/__init__.py +23 -1
- optimum/rbln/transformers/models/auto/__init__.py +14 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +95 -0
- optimum/rbln/transformers/models/bart/__init__.py +1 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +203 -58
- optimum/rbln/transformers/models/bart/modeling_bart.py +125 -0
- optimum/rbln/transformers/models/bert/__init__.py +24 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +101 -0
- optimum/rbln/transformers/models/clip/__init__.py +1 -1
- optimum/rbln/transformers/models/clip/modeling_clip.py +127 -26
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +28 -4
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +409 -150
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -8
- optimum/rbln/transformers/models/exaone/__init__.py +32 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +72 -0
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +181 -0
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +1725 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +78 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +1 -1
- optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +662 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +6 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -1
- optimum/rbln/transformers/models/phi/__init__.py +24 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +24 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +29 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +24 -0
- optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} +198 -168
- optimum/rbln/transformers/models/t5/__init__.py +1 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +55 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +122 -47
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +17 -12
- optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +172 -111
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +18 -16
- optimum/rbln/transformers/utils/rbln_quantization.py +48 -60
- optimum/rbln/utils/import_utils.py +50 -1
- optimum/rbln/utils/logging.py +82 -0
- optimum/rbln/utils/runtime_utils.py +33 -0
- optimum/rbln/utils/timer_utils.py +43 -0
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/METADATA +9 -7
- optimum_rbln-0.1.12.dist-info/RECORD +103 -0
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/WHEEL +1 -1
- optimum_rbln-0.1.12.dist-info/entry_points.txt +4 -0
- optimum_rbln-0.1.9.dist-info/RECORD +0 -78
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/bart/bart_architecture.py

```diff
@@ -47,6 +47,12 @@ from transformers.utils import logging
 logger = logging.get_logger(__name__)
 
 
+class BartWrapper:
+    def __init__(self, model):
+        self.encoder = BartEncoderWrapper(model)
+        self.decoder = BartDecoderWrapper(model)
+
+
 class _BartAttention(BartAttention):
     def forward(
         self,
@@ -54,6 +60,7 @@ class _BartAttention(BartAttention):
         past_key_value: Tuple[torch.Tensor],
         attention_mask: torch.Tensor,
         cache_position: torch.Tensor,
+        batch_index: torch.Tensor,
         key_value_states: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
         bsz, tgt_len, _ = hidden_states.size()
@@ -72,28 +79,83 @@ class _BartAttention(BartAttention):
         else:
             key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
             value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-            key_states = past_key_value[0].slice_scatter(
-                key_states, dim=2, start=cache_position, end=cache_position + 1
-            )
-            value_states = past_key_value[1].slice_scatter(
-                value_states, dim=2, start=cache_position, end=cache_position + 1
-            )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if cache_position.dim() > 0:
+            proj_shape = (bsz, self.num_heads, -1, self.head_dim)
+            query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+            key_states = key_states.reshape(*proj_shape)
+            value_states = value_states.reshape(*proj_shape)
+
+            all_key_states = []
+            all_value_states = []
+            all_attn_output = []
+            for b in range(bsz):
+                batch_query_states = query_states[b].unsqueeze(0).unsqueeze(2)
+                batch_attention_mask = attention_mask[b].unsqueeze(0).unsqueeze(2)
+                batch_key_states = key_states[b].unsqueeze(0).unsqueeze(2)
+                batch_value_states = value_states[b].unsqueeze(0).unsqueeze(2)
+                if not is_cross_attention:
+                    batch_key_states = (
+                        past_key_value[0][b]
+                        .unsqueeze(0)
+                        .unsqueeze(2)
+                        .slice_scatter(
+                            batch_key_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
+                        )
+                    )
+                    batch_value_states = (
+                        past_key_value[1][b]
+                        .unsqueeze(0)
+                        .unsqueeze(2)
+                        .slice_scatter(
+                            batch_value_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
+                        )
+                    )
+                attn_weights = torch.matmul(batch_query_states, batch_key_states.transpose(3, 4))
+                attn_weights = attn_weights + batch_attention_mask
+                attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+                attn_output = torch.matmul(attn_weights, batch_value_states)
+                attn_output = attn_output.view(1, self.num_heads, tgt_len, self.head_dim)
+                attn_output = attn_output.transpose(1, 2)
+                attn_output = attn_output.reshape(1, tgt_len, self.embed_dim)
+                all_key_states.append(batch_key_states)
+                all_value_states.append(batch_value_states)
+                all_attn_output.append(attn_output)
+            key_states = torch.cat(all_key_states, dim=0).squeeze(2)
+            value_states = torch.cat(all_value_states, dim=0).squeeze(2)
+            attn_output = torch.cat(all_attn_output, dim=0)
+
+        else:
+            if batch_index is None or batch_index == -1:
+                batch_index = 0
+
+            if not is_cross_attention:
+                key_states = past_key_value[0].slice_scatter(
+                    key_states, dim=2, start=cache_position, end=cache_position + 1
+                )
+                value_states = past_key_value[1].slice_scatter(
+                    value_states, dim=2, start=cache_position, end=cache_position + 1
+                )
+
+            proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+            query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+            key_states = key_states.reshape(*proj_shape)
+            value_states = value_states.reshape(*proj_shape)
+
+            src_len = key_states.size(1)
+            attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+            attn_output = torch.bmm(attn_weights, value_states)
+            attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+            attn_output = attn_output.transpose(1, 2)
+            key_states = key_states.unsqueeze(0)
+            value_states = value_states.unsqueeze(0)
+            attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
         attn_output = self.out_proj(attn_output)
 
         present_key_value = (key_states, value_states)
@@ -108,6 +170,7 @@ class _BartSdpaAttention(BartSdpaAttention):
         past_key_value: Tuple[torch.Tensor],
         attention_mask: torch.Tensor,
         cache_position: torch.Tensor,
+        batch_index: torch.Tensor,
         key_value_states: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
         bsz, tgt_len, _ = hidden_states.size()
@@ -126,23 +189,71 @@ class _BartSdpaAttention(BartSdpaAttention):
         else:
             key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
             value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-            key_states = past_key_value[0].slice_scatter(
-                key_states, dim=2, start=cache_position, end=cache_position + 1
-            )
-            value_states = past_key_value[1].slice_scatter(
-                value_states, dim=2, start=cache_position, end=cache_position + 1
-            )
 
         query_states = self._shape(query_states, tgt_len, bsz)
 
-
-
-
-
-
-
-
-
+        if (batch_index is None or batch_index == -1) and bsz > 1:
+            all_key_states = []
+            all_value_states = []
+            all_attn_output = []
+
+            for b in range(bsz):
+                batch_query_states = query_states[b].unsqueeze(0)
+                batch_attention_mask = attention_mask[b].unsqueeze(0)
+                batch_key_states = key_states[b].unsqueeze(0)
+                batch_value_states = value_states[b].unsqueeze(0)
+
+                if not is_cross_attention:
+                    batch_key_states = (
+                        past_key_value[0][b]
+                        .unsqueeze(0)
+                        .slice_scatter(
+                            batch_key_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
+                        )
+                    )
+                    batch_value_states = (
+                        past_key_value[1][b]
+                        .unsqueeze(0)
+                        .slice_scatter(
+                            batch_value_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
+                        )
+                    )
+
+                attn_output = torch.nn.functional.scaled_dot_product_attention(
+                    batch_query_states, batch_key_states, batch_value_states, attn_mask=batch_attention_mask
+                )
+                attn_output = attn_output.transpose(1, 2)
+                attn_output = attn_output.reshape(1, tgt_len, self.embed_dim)
+                all_key_states.append(batch_key_states)
+                all_value_states.append(batch_value_states)
+                all_attn_output.append(attn_output)
+
+            key_states = torch.cat(all_key_states, dim=0)
+            value_states = torch.cat(all_value_states, dim=0)
+            attn_output = torch.cat(all_attn_output, dim=0)
+
+        else:
+            if batch_index is None or batch_index == -1:
+                batch_index = 0
+
+            if not is_cross_attention:
+                key_states = past_key_value[0].slice_scatter(
+                    key_states, dim=2, start=cache_position, end=cache_position + 1
+                )
+                value_states = past_key_value[1].slice_scatter(
+                    value_states, dim=2, start=cache_position, end=cache_position + 1
+                )
+
+            # need 4d shape (input tensors) for scaled_dot_product_attention
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask=attention_mask,
+            )
+            attn_output = attn_output.transpose(1, 2)
+            attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
         attn_output = self.out_proj(attn_output)
 
         present_key_value = (key_states, value_states)
@@ -162,6 +273,7 @@ class _BartDecoderLayer(BartDecoderLayer):
         encoder_hidden_states: torch.Tensor,
         past_key_value: Tuple[torch.Tensor],
         cache_position: torch.Tensor,
+        batch_ids: torch.Tensor,
         attn_impl: str = "eager",
     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
         # Self Attention Block
@@ -174,6 +286,7 @@ class _BartDecoderLayer(BartDecoderLayer):
             past_key_value=self_attn_past_key_value,
             attention_mask=attention_mask,
             cache_position=cache_position,
+            batch_index=batch_ids,
         )
         hidden_states = residual + hidden_states
         hidden_states = self.self_attn_layer_norm(hidden_states)
@@ -189,6 +302,7 @@ class _BartDecoderLayer(BartDecoderLayer):
             past_key_value=cross_attn_past_key_value,
             attention_mask=encoder_attention_mask,
             cache_position=cache_position,
+            batch_index=batch_ids,
         )
         hidden_states = residual + hidden_states
         hidden_states = self.encoder_attn_layer_norm(hidden_states)
@@ -213,14 +327,32 @@ class _BartDecoder(BartDecoder):
         encoder_hidden_states: torch.Tensor,
         past_key_values: torch.Tensor,
         cache_position: torch.Tensor,
+        batch_ids: torch.Tensor,
         attn_impl: str = "eager",
     ):
         # embedding
-
-
+        if hasattr(self, "embed_scale"):
+            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+        else:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if cache_position.dim() == 0:
+            positions_idx = cache_position + self.embed_positions.offset
+            positions = self.embed_positions.weight[positions_idx]
+            hidden_states = inputs_embeds + positions
+        else:
+            hidden_all = []
+            # compiler pattern base dependency -> take + add
+            for i in range(input_ids.shape[0]):
+                # cache position [N,1]
+                positions_idx = cache_position[i]
+                # offset is set 2 in bart embedding
+                position_weight = self.embed_positions.weight[2:]
+                position = position_weight[positions_idx]
+                batch_hidden = position + inputs_embeds[i]
+                hidden_all.append(batch_hidden)
+            hidden_states = torch.stack(hidden_all, dim=0)
 
-        inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
-        hidden_states = inputs_embeds + positions
         hidden_states = self.layernorm_embedding(hidden_states)
 
         # prepare attn_mask
@@ -230,14 +362,14 @@ class _BartDecoder(BartDecoder):
                 attention_mask, input_shape, inputs_embeds, cache_position
             )
             encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
-                encoder_attention_mask,
+                encoder_attention_mask, torch.float32, tgt_len=input_shape[-1]
             )
         else:
             attention_mask = _prepare_4d_causal_attention_mask(
                 attention_mask, input_shape, inputs_embeds, cache_position
             )
             encoder_attention_mask = _prepare_4d_attention_mask(
-                encoder_attention_mask,
+                encoder_attention_mask, torch.float32, tgt_len=input_shape[-1]
             )
 
         # iterate decoder_layer
@@ -252,6 +384,7 @@ class _BartDecoder(BartDecoder):
                 encoder_attention_mask=encoder_attention_mask,
                 past_key_value=past_key_value,
                 cache_position=cache_position,
+                batch_ids=batch_ids,
                 attn_impl=attn_impl,
             )
             hidden_states = layer_outputs[0]
@@ -277,9 +410,14 @@ class BartDecoderWrapper(torch.nn.Module):
         attention_mask: torch.Tensor,
         encoder_attention_mask: torch.Tensor,
         cache_position: torch.Tensor,
+        batch_position: torch.Tensor,
         self_kv_cache: torch.Tensor,
         cross_kv_cache: torch.Tensor,
     ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor]]:
+        if input_ids.shape[1] == 1:
+            rbln_batch_position = None
+        else:
+            rbln_batch_position = batch_position
         # prepare past_key_values
         kv_cache = ()
         for i in range(0, self.num_layers * 2, 2):
@@ -291,7 +429,6 @@ class BartDecoderWrapper(torch.nn.Module):
                     cross_kv_cache[i + 1],
                 ),
             )
-
         # decode
         decoder_outputs = _BartDecoder.forward(
             self.decoder,
@@ -302,6 +439,7 @@ class BartDecoderWrapper(torch.nn.Module):
             past_key_values=kv_cache,
             encoder_hidden_states=torch.tensor([1]),
             attn_impl=self.config._attn_implementation,
+            batch_ids=rbln_batch_position,
         )
         sequence_output = decoder_outputs[0]
         lm_logits = self.lm_head(sequence_output)
@@ -314,7 +452,8 @@ class BartDecoderWrapper(torch.nn.Module):
             self_kv_cache.append(past_key_values[i][1])
         self_kv_cache = torch.stack(self_kv_cache, dim=0)
 
-        return
+        # return batch_position to keep it as a variable within the graph
+        return lm_logits, self_kv_cache, batch_position
 
 
 class BartEncoderWrapper(torch.nn.Module):
@@ -330,10 +469,13 @@ class BartEncoderWrapper(torch.nn.Module):
         self.num_heads = self.config.decoder_attention_heads
         self.d_kv = self.config.d_model // self.num_heads
 
-    def forward(
-
-
-
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        attention_mask: torch.LongTensor,
+        cross_key_value: torch.Tensor = None,
+        batch_idx: torch.Tensor = None,
+    ) -> Tuple[torch.Tensor]:
         # 1. run encoder
         encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
         last_hidden_states = encoder_outputs[0]
@@ -341,32 +483,35 @@ class BartEncoderWrapper(torch.nn.Module):
         # 2. run dummy decoder to get pre-calculated cross-key_values for generation
         dummy_past_key_value = []
        for _ in range(self.num_layers):
-            pkv_self_attn_key = torch.zeros(
-            pkv_self_attn_value = torch.zeros(
-            pkv_cross_attn_key = torch.zeros(
-            pkv_cross_attn_value = torch.zeros(
+            pkv_self_attn_key = torch.zeros(1, self.num_heads, self.decoder_max_length, self.d_kv)
+            pkv_self_attn_value = torch.zeros(1, self.num_heads, self.decoder_max_length, self.d_kv)
+            pkv_cross_attn_key = torch.zeros(1, self.num_heads, self.encoder_max_length, self.d_kv)
+            pkv_cross_attn_value = torch.zeros(1, self.num_heads, self.encoder_max_length, self.d_kv)
             layer_pkv = (pkv_self_attn_key, pkv_self_attn_value, pkv_cross_attn_key, pkv_cross_attn_value)
             dummy_past_key_value.append(layer_pkv)
 
-        decoder_attention_mask = torch.zeros(
+        decoder_attention_mask = torch.zeros(1, self.decoder_max_length, dtype=torch.float32)
        decoder_attention_mask[:, :1] = 1
 
         decoder_outputs = _BartDecoder.forward(
             self.decoder,
-            input_ids=torch.zeros((
+            input_ids=torch.zeros((1, 1), dtype=torch.int64),
             attention_mask=decoder_attention_mask,
             encoder_attention_mask=attention_mask,
             cache_position=torch.tensor(0, dtype=torch.int32),
             encoder_hidden_states=last_hidden_states,
             past_key_values=dummy_past_key_value,
+            batch_ids=torch.tensor(0, dtype=torch.int32),
             attn_impl=self.config._attn_implementation,
         )
         first_past_kv = decoder_outputs[1]
 
-        # 3. return cross_key_values to recurrence port. fyi (enc_ir.outputs[0] -> dec_ir.inputs[5])
         encoder_kv = []
-        for
-        encoder_kv.append(
-
+        for i in range(self.model.config.decoder_layers):
+            encoder_kv.append(first_past_kv[i][2].unsqueeze(0))
+            encoder_kv.append(first_past_kv[i][3].unsqueeze(0))
+        encoder_kv = torch.cat(encoder_kv, dim=0)
+
+        cross_key_value = cross_key_value.slice_scatter(encoder_kv, dim=1, start=batch_idx, end=batch_idx + 1)
 
-        return
+        return cross_key_value
```
optimum/rbln/transformers/models/bart/modeling_bart.py (new file)

```diff
@@ -0,0 +1,125 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+import inspect
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
+
+from transformers import BartConfig, BartForConditionalGeneration, BartModel, PretrainedConfig
+
+from ....modeling_base import RBLNModel
+from ....modeling_config import RBLNCompileConfig, RBLNConfig
+from ....utils.logging import get_logger
+from ...models.seq2seq import RBLNModelForSeq2SeqLM
+from .bart_architecture import BartWrapper
+
+
+logger = get_logger()
+
+
+if TYPE_CHECKING:
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
+
+
+class RBLNBartModel(RBLNModel):
+    original_model_class = BartModel
+    original_config_class = BartConfig
+
+    @classmethod
+    def _get_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_kwargs: Dict[str, Any] = {},
+    ) -> RBLNConfig:
+        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
+
+        max_position_embeddings = getattr(model_config, "max_position_embeddings", None)
+
+        if rbln_max_seq_len is None:
+            rbln_max_seq_len = max_position_embeddings
+            if rbln_max_seq_len is None:
+                for tokenizer in preprocessors:
+                    if hasattr(tokenizer, "model_max_length"):
+                        rbln_max_seq_len = tokenizer.model_max_length
+                        break
+                if rbln_max_seq_len is None:
+                    raise ValueError("`rbln_max_seq_len` should be specified!")
+
+        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
+            raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
+
+        if rbln_model_input_names is None:
+            for tokenizer in preprocessors:
+                if hasattr(tokenizer, "model_input_names"):
+                    rbln_model_input_names = tokenizer.model_input_names
+                    # BartModel's forward() does not take token_type_ids as input.
+                    # (Added because some of the tokenizers includes 'token_type_ids')
+                    if "token_type_ids" in rbln_model_input_names:
+                        rbln_model_input_names.remove("token_type_ids")
+                    break
+            if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
+                rbln_model_input_names = cls.rbln_model_input_names
+            elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
+                input_names_order = inspect.signature(cls.original_model_class.forward).parameters.keys()
+                raise ValueError(
+                    "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
+                    f"and be sure to make the order of the inputs same as BartModel forward() arguments like ({list(input_names_order)})"
+                )
+
+        if rbln_batch_size is None:
+            rbln_batch_size = 1
+
+        input_info = [
+            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
+            for model_input_name in rbln_model_input_names
+        ]
+
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+
+        rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
+        return rbln_config
+
+
+class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
+    @classmethod
+    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
+        return BartWrapper(model)
+
+    def __getattr__(self, __name: str) -> Any:
+        def redirect(func):
+            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+        val = getattr(BartForConditionalGeneration, __name)
+
+        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+            return redirect(val)
+
+        return val
```
optimum/rbln/transformers/models/bert/__init__.py (new file)

```diff
@@ -0,0 +1,24 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from .modeling_bert import RBLNBertModel
```
optimum/rbln/transformers/models/bert/modeling_bert.py (new file)

```diff
@@ -0,0 +1,101 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+import inspect
+import logging
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union
+
+from transformers import BertConfig, BertModel, PretrainedConfig
+
+from ....modeling_base import RBLNModel
+from ....modeling_config import RBLNCompileConfig, RBLNConfig
+
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
+
+
+class RBLNBertModel(RBLNModel):
+    original_model_class = BertModel
+    original_config_class = BertConfig
+
+    @classmethod
+    def _get_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_kwargs: Dict[str, Any] = {},
+    ) -> RBLNConfig:
+        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
+
+        max_position_embeddings = getattr(model_config, "max_position_embeddings", None)
+
+        if rbln_max_seq_len is None:
+            rbln_max_seq_len = max_position_embeddings
+            if rbln_max_seq_len is None:
+                for tokenizer in preprocessors:
+                    if hasattr(tokenizer, "model_max_length"):
+                        rbln_max_seq_len = tokenizer.model_max_length
+                        break
+                if rbln_max_seq_len is None:
+                    raise ValueError("`rbln_max_seq_len` should be specified!")
+
+        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
+            raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
+
+        if rbln_model_input_names is None:
+            for tokenizer in preprocessors:
+                if hasattr(tokenizer, "model_input_names"):
+                    rbln_model_input_names = tokenizer.model_input_names
+                    break
+            if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
+                rbln_model_input_names = cls.rbln_model_input_names
+            elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
+                input_names_order = inspect.signature(cls.original_model_class.forward).parameters.keys()
+                raise ValueError(
+                    "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
+                    f"and be sure to make the order of the inputs same as BertModel forward() arguments like ({list(input_names_order)})"
+                )
+
+        if rbln_batch_size is None:
+            rbln_batch_size = 1
+
+        input_info = [
+            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
+            for model_input_name in rbln_model_input_names
+        ]
+
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+
+        rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
+        return rbln_config
```
optimum/rbln/transformers/models/clip/__init__.py

```diff
@@ -21,4 +21,4 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-from .modeling_clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection
+from .modeling_clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection, RBLNCLIPVisionModel
```