optimum-rbln 0.1.9__py3-none-any.whl → 0.1.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. optimum/rbln/__init__.py +47 -9
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +36 -31
  4. optimum/rbln/diffusers/models/controlnet.py +53 -43
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +40 -31
  6. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +4 -0
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +28 -23
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +28 -23
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +28 -37
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +30 -39
  11. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +24 -14
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +24 -15
  13. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +26 -17
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -17
  15. optimum/rbln/modeling_alias.py +6 -11
  16. optimum/rbln/modeling_base.py +467 -261
  17. optimum/rbln/modeling_config.py +199 -73
  18. optimum/rbln/transformers/__init__.py +43 -1
  19. optimum/rbln/transformers/models/__init__.py +23 -1
  20. optimum/rbln/transformers/models/auto/__init__.py +14 -0
  21. optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
  22. optimum/rbln/transformers/models/auto/modeling_auto.py +95 -0
  23. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  24. optimum/rbln/transformers/models/bart/bart_architecture.py +203 -58
  25. optimum/rbln/transformers/models/bart/modeling_bart.py +125 -0
  26. optimum/rbln/transformers/models/bert/__init__.py +24 -0
  27. optimum/rbln/transformers/models/bert/modeling_bert.py +101 -0
  28. optimum/rbln/transformers/models/clip/__init__.py +1 -1
  29. optimum/rbln/transformers/models/clip/modeling_clip.py +127 -26
  30. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +28 -4
  31. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +409 -150
  32. optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -8
  33. optimum/rbln/transformers/models/exaone/__init__.py +32 -0
  34. optimum/rbln/transformers/models/exaone/exaone_architecture.py +72 -0
  35. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +181 -0
  36. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +1725 -0
  37. optimum/rbln/transformers/models/exaone/modeling_exaone.py +78 -0
  38. optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -1
  39. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  40. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
  41. optimum/rbln/transformers/models/llama/modeling_llama.py +1 -1
  42. optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
  43. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +662 -0
  44. optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
  45. optimum/rbln/transformers/models/midm/modeling_midm.py +6 -1
  46. optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -1
  47. optimum/rbln/transformers/models/phi/__init__.py +24 -0
  48. optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
  49. optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
  50. optimum/rbln/transformers/models/qwen2/__init__.py +24 -0
  51. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -0
  52. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +29 -0
  53. optimum/rbln/transformers/models/seq2seq/__init__.py +24 -0
  54. optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} +198 -168
  55. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  56. optimum/rbln/transformers/models/t5/modeling_t5.py +55 -0
  57. optimum/rbln/transformers/models/t5/t5_architecture.py +122 -47
  58. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +17 -12
  59. optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
  60. optimum/rbln/transformers/models/whisper/modeling_whisper.py +172 -111
  61. optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
  62. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +18 -16
  63. optimum/rbln/transformers/utils/rbln_quantization.py +48 -60
  64. optimum/rbln/utils/import_utils.py +50 -1
  65. optimum/rbln/utils/logging.py +82 -0
  66. optimum/rbln/utils/runtime_utils.py +33 -0
  67. optimum/rbln/utils/timer_utils.py +43 -0
  68. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/METADATA +9 -7
  69. optimum_rbln-0.1.12.dist-info/RECORD +103 -0
  70. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/WHEEL +1 -1
  71. optimum_rbln-0.1.12.dist-info/entry_points.txt +4 -0
  72. optimum_rbln-0.1.9.dist-info/RECORD +0 -78
  73. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/licenses/LICENSE +0 -0
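
Beyond the BART changes detailed below, the file list shows new BERT, Phi, Qwen2, EXAONE and LLaVA-Next model classes, auto factories, and a relocated seq2seq base class. As a rough orientation, the sketch below assumes optimum-rbln keeps the usual optimum `from_pretrained(..., export=True)` convention; only the `max_seq_len`/`batch_size` option names are actually visible in the `_get_rbln_config()` code later in this diff, everything else here is an assumption, not documented API.

    # Hypothetical usage sketch; keyword names are assumptions based on _get_rbln_config() below.
    from optimum.rbln import RBLNBartModel

    model = RBLNBartModel.from_pretrained(
        "facebook/bart-base",
        export=True,            # assumed flag: compile the original checkpoint for the RBLN NPU
        rbln_max_seq_len=512,   # consumed as rbln_kwargs["max_seq_len"] -> static input shape
        rbln_batch_size=1,      # consumed as rbln_kwargs["batch_size"]
    )
    model.save_pretrained("bart-base-rbln")  # assumed to follow the optimum save convention
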
optimum/rbln/transformers/models/bart/bart_architecture.py

@@ -47,6 +47,12 @@ from transformers.utils import logging
 logger = logging.get_logger(__name__)
 
 
+class BartWrapper:
+    def __init__(self, model):
+        self.encoder = BartEncoderWrapper(model)
+        self.decoder = BartDecoderWrapper(model)
+
+
 class _BartAttention(BartAttention):
     def forward(
         self,
@@ -54,6 +60,7 @@ class _BartAttention(BartAttention):
         past_key_value: Tuple[torch.Tensor],
         attention_mask: torch.Tensor,
         cache_position: torch.Tensor,
+        batch_index: torch.Tensor,
         key_value_states: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
         bsz, tgt_len, _ = hidden_states.size()
@@ -72,28 +79,83 @@ class _BartAttention(BartAttention):
         else:
             key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
             value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-            key_states = past_key_value[0].slice_scatter(
-                key_states, dim=2, start=cache_position, end=cache_position + 1
-            )
-            value_states = past_key_value[1].slice_scatter(
-                value_states, dim=2, start=cache_position, end=cache_position + 1
-            )
 
-        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
-        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
-        key_states = key_states.reshape(*proj_shape)
-        value_states = value_states.reshape(*proj_shape)
-
-        src_len = key_states.size(1)
-        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
-        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
-        attn_output = torch.bmm(attn_weights, value_states)
-        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
-        attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+        if cache_position.dim() > 0:
+            proj_shape = (bsz, self.num_heads, -1, self.head_dim)
+            query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+            key_states = key_states.reshape(*proj_shape)
+            value_states = value_states.reshape(*proj_shape)
+
+            all_key_states = []
+            all_value_states = []
+            all_attn_output = []
+            for b in range(bsz):
+                batch_query_states = query_states[b].unsqueeze(0).unsqueeze(2)
+                batch_attention_mask = attention_mask[b].unsqueeze(0).unsqueeze(2)
+                batch_key_states = key_states[b].unsqueeze(0).unsqueeze(2)
+                batch_value_states = value_states[b].unsqueeze(0).unsqueeze(2)
+                if not is_cross_attention:
+                    batch_key_states = (
+                        past_key_value[0][b]
+                        .unsqueeze(0)
+                        .unsqueeze(2)
+                        .slice_scatter(
+                            batch_key_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
+                        )
+                    )
+                    batch_value_states = (
+                        past_key_value[1][b]
+                        .unsqueeze(0)
+                        .unsqueeze(2)
+                        .slice_scatter(
+                            batch_value_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
+                        )
+                    )
+                attn_weights = torch.matmul(batch_query_states, batch_key_states.transpose(3, 4))
+                attn_weights = attn_weights + batch_attention_mask
+                attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+                attn_output = torch.matmul(attn_weights, batch_value_states)
+                attn_output = attn_output.view(1, self.num_heads, tgt_len, self.head_dim)
+                attn_output = attn_output.transpose(1, 2)
+                attn_output = attn_output.reshape(1, tgt_len, self.embed_dim)
+                all_key_states.append(batch_key_states)
+                all_value_states.append(batch_value_states)
+                all_attn_output.append(attn_output)
+            key_states = torch.cat(all_key_states, dim=0).squeeze(2)
+            value_states = torch.cat(all_value_states, dim=0).squeeze(2)
+            attn_output = torch.cat(all_attn_output, dim=0)
+
+        else:
+            if batch_index is None or batch_index == -1:
+                batch_index = 0
+
+            if not is_cross_attention:
+                key_states = past_key_value[0].slice_scatter(
+                    key_states, dim=2, start=cache_position, end=cache_position + 1
+                )
+                value_states = past_key_value[1].slice_scatter(
+                    value_states, dim=2, start=cache_position, end=cache_position + 1
+                )
+
+            proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+            query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+            key_states = key_states.reshape(*proj_shape)
+            value_states = value_states.reshape(*proj_shape)
+
+            src_len = key_states.size(1)
+            attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+            attn_output = torch.bmm(attn_weights, value_states)
+            attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+            attn_output = attn_output.transpose(1, 2)
+            key_states = key_states.unsqueeze(0)
+            value_states = value_states.unsqueeze(0)
+            attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
         attn_output = self.out_proj(attn_output)
 
         present_key_value = (key_states, value_states)
@@ -108,6 +170,7 @@ class _BartSdpaAttention(BartSdpaAttention):
         past_key_value: Tuple[torch.Tensor],
         attention_mask: torch.Tensor,
         cache_position: torch.Tensor,
+        batch_index: torch.Tensor,
         key_value_states: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
         bsz, tgt_len, _ = hidden_states.size()
@@ -126,23 +189,71 @@ class _BartSdpaAttention(BartSdpaAttention):
         else:
             key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
             value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-            key_states = past_key_value[0].slice_scatter(
-                key_states, dim=2, start=cache_position, end=cache_position + 1
-            )
-            value_states = past_key_value[1].slice_scatter(
-                value_states, dim=2, start=cache_position, end=cache_position + 1
-            )
 
         query_states = self._shape(query_states, tgt_len, bsz)
 
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=attention_mask,
-        )
-        attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+        if (batch_index is None or batch_index == -1) and bsz > 1:
+            all_key_states = []
+            all_value_states = []
+            all_attn_output = []
+
+            for b in range(bsz):
+                batch_query_states = query_states[b].unsqueeze(0)
+                batch_attention_mask = attention_mask[b].unsqueeze(0)
+                batch_key_states = key_states[b].unsqueeze(0)
+                batch_value_states = value_states[b].unsqueeze(0)
+
+                if not is_cross_attention:
+                    batch_key_states = (
+                        past_key_value[0][b]
+                        .unsqueeze(0)
+                        .slice_scatter(
+                            batch_key_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
+                        )
+                    )
+                    batch_value_states = (
+                        past_key_value[1][b]
+                        .unsqueeze(0)
+                        .slice_scatter(
+                            batch_value_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
+                        )
+                    )
+
+                attn_output = torch.nn.functional.scaled_dot_product_attention(
+                    batch_query_states, batch_key_states, batch_value_states, attn_mask=batch_attention_mask
+                )
+                attn_output = attn_output.transpose(1, 2)
+                attn_output = attn_output.reshape(1, tgt_len, self.embed_dim)
+                all_key_states.append(batch_key_states)
+                all_value_states.append(batch_value_states)
+                all_attn_output.append(attn_output)
+
+            key_states = torch.cat(all_key_states, dim=0)
+            value_states = torch.cat(all_value_states, dim=0)
+            attn_output = torch.cat(all_attn_output, dim=0)
+
+        else:
+            if batch_index is None or batch_index == -1:
+                batch_index = 0
+
+            if not is_cross_attention:
+                key_states = past_key_value[0].slice_scatter(
+                    key_states, dim=2, start=cache_position, end=cache_position + 1
+                )
+                value_states = past_key_value[1].slice_scatter(
+                    value_states, dim=2, start=cache_position, end=cache_position + 1
+                )
+
+            # need 4d shape (input tensors) for scaled_dot_product_attention
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask=attention_mask,
+            )
+            attn_output = attn_output.transpose(1, 2)
+            attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
         attn_output = self.out_proj(attn_output)
 
         present_key_value = (key_states, value_states)
@@ -162,6 +273,7 @@ class _BartDecoderLayer(BartDecoderLayer):
         encoder_hidden_states: torch.Tensor,
         past_key_value: Tuple[torch.Tensor],
         cache_position: torch.Tensor,
+        batch_ids: torch.Tensor,
         attn_impl: str = "eager",
     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
         # Self Attention Block
@@ -174,6 +286,7 @@ class _BartDecoderLayer(BartDecoderLayer):
             past_key_value=self_attn_past_key_value,
             attention_mask=attention_mask,
             cache_position=cache_position,
+            batch_index=batch_ids,
         )
         hidden_states = residual + hidden_states
         hidden_states = self.self_attn_layer_norm(hidden_states)
@@ -189,6 +302,7 @@ class _BartDecoderLayer(BartDecoderLayer):
             past_key_value=cross_attn_past_key_value,
             attention_mask=encoder_attention_mask,
             cache_position=cache_position,
+            batch_index=batch_ids,
         )
         hidden_states = residual + hidden_states
         hidden_states = self.encoder_attn_layer_norm(hidden_states)
@@ -213,14 +327,32 @@ class _BartDecoder(BartDecoder):
         encoder_hidden_states: torch.Tensor,
         past_key_values: torch.Tensor,
         cache_position: torch.Tensor,
+        batch_ids: torch.Tensor,
         attn_impl: str = "eager",
     ):
         # embedding
-        positions_idx = cache_position + self.embed_positions.offset
-        positions = self.embed_positions.weight[positions_idx]
+        if hasattr(self, "embed_scale"):
+            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+        else:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if cache_position.dim() == 0:
+            positions_idx = cache_position + self.embed_positions.offset
+            positions = self.embed_positions.weight[positions_idx]
+            hidden_states = inputs_embeds + positions
+        else:
+            hidden_all = []
+            # compiler pattern base dependency -> take + add
+            for i in range(input_ids.shape[0]):
+                # cache position [N,1]
+                positions_idx = cache_position[i]
+                # offset is set 2 in bart embedding
+                position_weight = self.embed_positions.weight[2:]
+                position = position_weight[positions_idx]
+                batch_hidden = position + inputs_embeds[i]
+                hidden_all.append(batch_hidden)
+            hidden_states = torch.stack(hidden_all, dim=0)
 
-        inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
-        hidden_states = inputs_embeds + positions
         hidden_states = self.layernorm_embedding(hidden_states)
 
         # prepare attn_mask
@@ -230,14 +362,14 @@ class _BartDecoder(BartDecoder):
                 attention_mask, input_shape, inputs_embeds, cache_position
             )
             encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
-                encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+                encoder_attention_mask, torch.float32, tgt_len=input_shape[-1]
            )
         else:
             attention_mask = _prepare_4d_causal_attention_mask(
                 attention_mask, input_shape, inputs_embeds, cache_position
             )
             encoder_attention_mask = _prepare_4d_attention_mask(
-                encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+                encoder_attention_mask, torch.float32, tgt_len=input_shape[-1]
             )
 
         # iterate decoder_layer
@@ -252,6 +384,7 @@ class _BartDecoder(BartDecoder):
                 encoder_attention_mask=encoder_attention_mask,
                 past_key_value=past_key_value,
                 cache_position=cache_position,
+                batch_ids=batch_ids,
                 attn_impl=attn_impl,
             )
             hidden_states = layer_outputs[0]
@@ -277,9 +410,14 @@ class BartDecoderWrapper(torch.nn.Module):
         attention_mask: torch.Tensor,
         encoder_attention_mask: torch.Tensor,
         cache_position: torch.Tensor,
+        batch_position: torch.Tensor,
         self_kv_cache: torch.Tensor,
         cross_kv_cache: torch.Tensor,
     ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor]]:
+        if input_ids.shape[1] == 1:
+            rbln_batch_position = None
+        else:
+            rbln_batch_position = batch_position
         # prepare past_key_values
         kv_cache = ()
         for i in range(0, self.num_layers * 2, 2):
@@ -291,7 +429,6 @@ class BartDecoderWrapper(torch.nn.Module):
                     cross_kv_cache[i + 1],
                 ),
             )
-
         # decode
         decoder_outputs = _BartDecoder.forward(
             self.decoder,
@@ -302,6 +439,7 @@ class BartDecoderWrapper(torch.nn.Module):
             past_key_values=kv_cache,
             encoder_hidden_states=torch.tensor([1]),
             attn_impl=self.config._attn_implementation,
+            batch_ids=rbln_batch_position,
         )
         sequence_output = decoder_outputs[0]
         lm_logits = self.lm_head(sequence_output)
@@ -314,7 +452,8 @@ class BartDecoderWrapper(torch.nn.Module):
             self_kv_cache.append(past_key_values[i][1])
         self_kv_cache = torch.stack(self_kv_cache, dim=0)
 
-        return lm_logits, self_kv_cache
+        # return batch_position to keep it as a variable within the graph
+        return lm_logits, self_kv_cache, batch_position
 
 
 class BartEncoderWrapper(torch.nn.Module):
@@ -330,10 +469,13 @@ class BartEncoderWrapper(torch.nn.Module):
         self.num_heads = self.config.decoder_attention_heads
         self.d_kv = self.config.d_model // self.num_heads
 
-    def forward(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor) -> Tuple[torch.Tensor]:
-        encoder_batch_size = input_ids.shape[0]
-        decoder_batch_size = encoder_batch_size  # TODO(taehoon) fix to enable beam-search
-
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        attention_mask: torch.LongTensor,
+        cross_key_value: torch.Tensor = None,
+        batch_idx: torch.Tensor = None,
+    ) -> Tuple[torch.Tensor]:
         # 1. run encoder
         encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
         last_hidden_states = encoder_outputs[0]
@@ -341,32 +483,35 @@ class BartEncoderWrapper(torch.nn.Module):
         # 2. run dummy decoder to get pre-calculated cross-key_values for generation
         dummy_past_key_value = []
         for _ in range(self.num_layers):
-            pkv_self_attn_key = torch.zeros(decoder_batch_size, self.num_heads, self.decoder_max_length, self.d_kv)
-            pkv_self_attn_value = torch.zeros(decoder_batch_size, self.num_heads, self.decoder_max_length, self.d_kv)
-            pkv_cross_attn_key = torch.zeros(encoder_batch_size, self.num_heads, self.encoder_max_length, self.d_kv)
-            pkv_cross_attn_value = torch.zeros(encoder_batch_size, self.num_heads, self.encoder_max_length, self.d_kv)
+            pkv_self_attn_key = torch.zeros(1, self.num_heads, self.decoder_max_length, self.d_kv)
+            pkv_self_attn_value = torch.zeros(1, self.num_heads, self.decoder_max_length, self.d_kv)
+            pkv_cross_attn_key = torch.zeros(1, self.num_heads, self.encoder_max_length, self.d_kv)
+            pkv_cross_attn_value = torch.zeros(1, self.num_heads, self.encoder_max_length, self.d_kv)
             layer_pkv = (pkv_self_attn_key, pkv_self_attn_value, pkv_cross_attn_key, pkv_cross_attn_value)
             dummy_past_key_value.append(layer_pkv)
 
-        decoder_attention_mask = torch.zeros(decoder_batch_size, self.decoder_max_length, dtype=torch.int64)
+        decoder_attention_mask = torch.zeros(1, self.decoder_max_length, dtype=torch.float32)
         decoder_attention_mask[:, :1] = 1
 
         decoder_outputs = _BartDecoder.forward(
             self.decoder,
-            input_ids=torch.zeros((decoder_batch_size, 1), dtype=torch.int64),
+            input_ids=torch.zeros((1, 1), dtype=torch.int64),
             attention_mask=decoder_attention_mask,
             encoder_attention_mask=attention_mask,
             cache_position=torch.tensor(0, dtype=torch.int32),
             encoder_hidden_states=last_hidden_states,
             past_key_values=dummy_past_key_value,
+            batch_ids=torch.tensor(0, dtype=torch.int32),
             attn_impl=self.config._attn_implementation,
         )
        first_past_kv = decoder_outputs[1]
 
-        # 3. return cross_key_values to recurrence port. fyi (enc_ir.outputs[0] -> dec_ir.inputs[5])
         encoder_kv = []
-        for layer_out in first_past_kv:  # for layer
-            encoder_kv.append(torch.stack(layer_out[2:], dim=0))
-        encoder_kv = torch.stack(encoder_kv, dim=0)
+        for i in range(self.model.config.decoder_layers):
+            encoder_kv.append(first_past_kv[i][2].unsqueeze(0))
+            encoder_kv.append(first_past_kv[i][3].unsqueeze(0))
+        encoder_kv = torch.cat(encoder_kv, dim=0)
+
+        cross_key_value = cross_key_value.slice_scatter(encoder_kv, dim=1, start=batch_idx, end=batch_idx + 1)
 
-        return encoder_kv
+        return cross_key_value
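
Both attention overrides above update a pre-allocated KV cache in place with torch.Tensor.slice_scatter instead of growing it, which keeps all tensor shapes static for ahead-of-time compilation. A standalone toy illustration of that update pattern (shapes and values are made up, not taken from the package):

    # Toy illustration of the slice_scatter cache update used above; shapes are arbitrary.
    import torch

    bsz, num_heads, max_len, head_dim = 2, 4, 16, 8
    past_key = torch.zeros(bsz, num_heads, max_len, head_dim)  # pre-allocated static cache
    new_key = torch.randn(bsz, num_heads, 1, head_dim)         # key for the current decode step
    cache_position = 5

    # Write the step-5 slot without changing the cache's shape (out-of-place, like the wrapper code).
    past_key = past_key.slice_scatter(new_key, dim=2, start=cache_position, end=cache_position + 1)
    print(bool(past_key[:, :, cache_position].abs().sum() > 0))  # True
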
optimum/rbln/transformers/models/bart/modeling_bart.py (new file)

@@ -0,0 +1,125 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+import inspect
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
+
+from transformers import BartConfig, BartForConditionalGeneration, BartModel, PretrainedConfig
+
+from ....modeling_base import RBLNModel
+from ....modeling_config import RBLNCompileConfig, RBLNConfig
+from ....utils.logging import get_logger
+from ...models.seq2seq import RBLNModelForSeq2SeqLM
+from .bart_architecture import BartWrapper
+
+
+logger = get_logger()
+
+
+if TYPE_CHECKING:
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
+
+
+class RBLNBartModel(RBLNModel):
+    original_model_class = BartModel
+    original_config_class = BartConfig
+
+    @classmethod
+    def _get_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_kwargs: Dict[str, Any] = {},
+    ) -> RBLNConfig:
+        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
+
+        max_position_embeddings = getattr(model_config, "max_position_embeddings", None)
+
+        if rbln_max_seq_len is None:
+            rbln_max_seq_len = max_position_embeddings
+            if rbln_max_seq_len is None:
+                for tokenizer in preprocessors:
+                    if hasattr(tokenizer, "model_max_length"):
+                        rbln_max_seq_len = tokenizer.model_max_length
+                        break
+                if rbln_max_seq_len is None:
+                    raise ValueError("`rbln_max_seq_len` should be specified!")
+
+        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
+            raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
+
+        if rbln_model_input_names is None:
+            for tokenizer in preprocessors:
+                if hasattr(tokenizer, "model_input_names"):
+                    rbln_model_input_names = tokenizer.model_input_names
+                    # BartModel's forward() does not take token_type_ids as input.
+                    # (Added because some of the tokenizers includes 'token_type_ids')
+                    if "token_type_ids" in rbln_model_input_names:
+                        rbln_model_input_names.remove("token_type_ids")
+                    break
+            if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
+                rbln_model_input_names = cls.rbln_model_input_names
+            elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
+                input_names_order = inspect.signature(cls.original_model_class.forward).parameters.keys()
+                raise ValueError(
+                    "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
+                    f"and be sure to make the order of the inputs same as BartModel forward() arguments like ({list(input_names_order)})"
+                )
+
+        if rbln_batch_size is None:
+            rbln_batch_size = 1
+
+        input_info = [
+            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
+            for model_input_name in rbln_model_input_names
+        ]
+
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+
+        rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
+        return rbln_config
+
+
+class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
+    @classmethod
+    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
+        return BartWrapper(model)
+
+    def __getattr__(self, __name: str) -> Any:
+        def redirect(func):
+            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+        val = getattr(BartForConditionalGeneration, __name)
+
+        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+            return redirect(val)
+
+        return val
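
The `__getattr__` override above lets the RBLN wrapper borrow methods from `BartForConditionalGeneration` and re-bind them to the wrapper instance. A self-contained toy version of that delegation pattern (simplified stand-in classes, not the real optimum-rbln or transformers types):

    # Toy demonstration of the __getattr__ redirect pattern; stand-in classes only.
    import inspect
    from typing import Any, Callable


    class Original:
        def greet(self, name):
            return f"hello {name} from {type(self).__name__}"


    class Wrapper:
        def __getattr__(self, __name: str) -> Any:
            def redirect(func):
                return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)

            val = getattr(Original, __name)
            if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
                # Bind the wrapper instance as `self`, so borrowed methods run
                # against the wrapper's own attributes.
                return redirect(val)
            return val


    print(Wrapper().greet("world"))  # -> "hello world from Wrapper"
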
optimum/rbln/transformers/models/bert/__init__.py (new file)

@@ -0,0 +1,24 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from .modeling_bert import RBLNBertModel
optimum/rbln/transformers/models/bert/modeling_bert.py (new file)

@@ -0,0 +1,101 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+import inspect
+import logging
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union
+
+from transformers import BertConfig, BertModel, PretrainedConfig
+
+from ....modeling_base import RBLNModel
+from ....modeling_config import RBLNCompileConfig, RBLNConfig
+
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
+
+
+class RBLNBertModel(RBLNModel):
+    original_model_class = BertModel
+    original_config_class = BertConfig
+
+    @classmethod
+    def _get_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_kwargs: Dict[str, Any] = {},
+    ) -> RBLNConfig:
+        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
+
+        max_position_embeddings = getattr(model_config, "max_position_embeddings", None)
+
+        if rbln_max_seq_len is None:
+            rbln_max_seq_len = max_position_embeddings
+            if rbln_max_seq_len is None:
+                for tokenizer in preprocessors:
+                    if hasattr(tokenizer, "model_max_length"):
+                        rbln_max_seq_len = tokenizer.model_max_length
+                        break
+                if rbln_max_seq_len is None:
+                    raise ValueError("`rbln_max_seq_len` should be specified!")
+
+        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
+            raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
+
+        if rbln_model_input_names is None:
+            for tokenizer in preprocessors:
+                if hasattr(tokenizer, "model_input_names"):
+                    rbln_model_input_names = tokenizer.model_input_names
+                    break
+            if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
+                rbln_model_input_names = cls.rbln_model_input_names
+            elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
+                input_names_order = inspect.signature(cls.original_model_class.forward).parameters.keys()
+                raise ValueError(
+                    "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
+                    f"and be sure to make the order of the inputs same as BertModel forward() arguments like ({list(input_names_order)})"
+                )
+
+        if rbln_batch_size is None:
+            rbln_batch_size = 1
+
+        input_info = [
+            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
+            for model_input_name in rbln_model_input_names
+        ]
+
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+
+        rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
+        return rbln_config
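
For both `RBLNBartModel` and `RBLNBertModel`, `_get_rbln_config` pins every model input to a fixed `[batch_size, max_seq_len]` int64 shape. A small sketch of what that `input_info` implies for callers (toy values; the padding requirement is an inference from the static shapes, not stated in this diff):

    # Sketch: the static input_info produced above, materialized as dummy tensors (toy values).
    import torch

    rbln_batch_size, rbln_max_seq_len = 1, 128
    rbln_model_input_names = ["input_ids", "attention_mask"]  # e.g. from tokenizer.model_input_names

    input_info = [
        (name, [rbln_batch_size, rbln_max_seq_len], "int64")
        for name in rbln_model_input_names
    ]

    # Every compiled input is a fixed-shape int64 tensor, so tokenized batches would need to be
    # padded/truncated to exactly max_seq_len before being fed to the compiled model.
    dummy_inputs = {name: torch.zeros(*shape, dtype=torch.int64) for name, shape, _ in input_info}
    print({k: tuple(v.shape) for k, v in dummy_inputs.items()})
    # {'input_ids': (1, 128), 'attention_mask': (1, 128)}
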
optimum/rbln/transformers/models/clip/__init__.py

@@ -21,4 +21,4 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-from .modeling_clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection
+from .modeling_clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection, RBLNCLIPVisionModel