optimum-rbln 0.1.11__py3-none-any.whl → 0.1.12__py3-none-any.whl

This diff shows the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only and reflects only the changes between those published versions.
Files changed (54)
  1. optimum/rbln/__init__.py +10 -7
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +0 -2
  4. optimum/rbln/diffusers/models/controlnet.py +0 -6
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +0 -3
  6. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +4 -0
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +18 -20
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +18 -20
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +19 -34
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +20 -35
  11. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +12 -13
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +12 -14
  13. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +13 -14
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +13 -14
  15. optimum/rbln/modeling_alias.py +4 -9
  16. optimum/rbln/modeling_base.py +105 -139
  17. optimum/rbln/modeling_config.py +51 -0
  18. optimum/rbln/transformers/__init__.py +8 -0
  19. optimum/rbln/transformers/models/__init__.py +4 -1
  20. optimum/rbln/transformers/models/auto/modeling_auto.py +1 -0
  21. optimum/rbln/transformers/models/bart/__init__.py +1 -1
  22. optimum/rbln/transformers/models/bart/bart_architecture.py +18 -12
  23. optimum/rbln/transformers/models/bart/modeling_bart.py +25 -6
  24. optimum/rbln/transformers/models/bert/modeling_bert.py +1 -2
  25. optimum/rbln/transformers/models/clip/modeling_clip.py +0 -1
  26. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +172 -100
  27. optimum/rbln/transformers/models/dpt/modeling_dpt.py +0 -1
  28. optimum/rbln/transformers/models/exaone/__init__.py +32 -0
  29. optimum/rbln/transformers/models/exaone/exaone_architecture.py +72 -0
  30. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +181 -0
  31. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +1725 -0
  32. optimum/rbln/transformers/models/exaone/modeling_exaone.py +78 -0
  33. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +148 -152
  34. optimum/rbln/transformers/models/midm/modeling_midm.py +5 -0
  35. optimum/rbln/transformers/models/qwen2/__init__.py +24 -0
  36. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -0
  37. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +29 -0
  38. optimum/rbln/transformers/models/seq2seq/__init__.py +24 -0
  39. optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} +107 -166
  40. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  41. optimum/rbln/transformers/models/t5/modeling_t5.py +55 -0
  42. optimum/rbln/transformers/models/t5/t5_architecture.py +46 -32
  43. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +0 -1
  44. optimum/rbln/transformers/models/whisper/modeling_whisper.py +37 -12
  45. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +1 -2
  46. optimum/rbln/utils/import_utils.py +14 -0
  47. optimum/rbln/utils/logging.py +1 -1
  48. optimum/rbln/utils/runtime_utils.py +1 -1
  49. optimum/rbln/utils/timer_utils.py +26 -2
  50. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.12.dist-info}/METADATA +4 -3
  51. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.12.dist-info}/RECORD +54 -44
  52. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.12.dist-info}/WHEEL +1 -1
  53. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.12.dist-info}/entry_points.txt +0 -0
  54. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.12.dist-info}/licenses/LICENSE +0 -0
@@ -23,24 +23,17 @@
 
 import inspect
 import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
-
-import rebel
-import torch
-from transformers import (
-    AutoModelForSeq2SeqLM,
-    BartConfig,
-    BartForConditionalGeneration,
-    PretrainedConfig,
-    T5ForConditionalGeneration,
-)
+from abc import ABC
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+import rebel  # noqa: F401
+import torch  # noqa: F401
+from transformers import AutoModelForSeq2SeqLM, GenerationConfig, PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 
-from .modeling_base import RBLNModel
-from .modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
-from .transformers.models.bart import BartDecoderWrapper, BartEncoderWrapper
-from .transformers.models.t5 import T5DecoderWrapper, T5EncoderWrapper
-from .utils.runtime_utils import RBLNPytorchRuntime
+from ....modeling_base import RBLNModel
+from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
+from ....utils.runtime_utils import RBLNPytorchRuntime
 
 
 logger = logging.getLogger(__name__)
@@ -59,7 +52,6 @@ class RBLNRuntimeEncoder(RBLNPytorchRuntime):
 
     def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
         _ = super().forward(*args, **kwargs)
-        # Just indicates that it is not None
         return BaseModelOutput(last_hidden_state=torch.tensor([1.0]))
 
 
@@ -71,7 +63,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         return Seq2SeqLMOutput(logits=outputs)
 
 
-class RBLNModelForSeq2SeqLM(RBLNModel):
+class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
     """
     This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence-to-sequence language modeling head) when created with the from_pretrained() class method.
     This model inherits from [`RBLNBaseModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
@@ -84,91 +76,35 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
     Currently, this model class only supports the 'bart' and 't5' models from the transformers library. Future updates may include support for additional model types.
     """
 
+    main_input_name = "input_ids"
     auto_model_class = AutoModelForSeq2SeqLM
 
     def __post_init__(self, **kwargs):
-        self.model_dim = self.config.d_model
-        self.batch_size = self.rbln_config.model_cfg["batch_size"]
-        self.enc_max_seq_len = self.rbln_config.model_cfg["enc_max_seq_len"]
-        self.dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
-        self.pad_token_id = self.rbln_config.model_cfg["pad_token_id"]
         self.encoder = RBLNRuntimeEncoder(runtime=self.model[0], main_input_name="input_ids")
         self.decoder = RBLNRuntimeDecoder(runtime=self.model[1], main_input_name="input_ids")
-        self.enc_attention_mask = torch.zeros(1, self.enc_max_seq_len, dtype=torch.float32)
-        self.dec_enc_attention_mask = torch.zeros(self.batch_size, self.enc_max_seq_len, dtype=torch.float32)
-
-    def can_generate(self):
-        return True
-
-    def get_encoder(self):
-        return self.encoder
-
-    def get_decoder(self):
-        return self.decoder
-
-    def __getattr__(self, __name: str) -> Any:
-        def redirect(func):
-            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-        if "T5ForConditionalGeneration" == self.config.architectures:
-            val = getattr(T5ForConditionalGeneration, __name)
-        else:
-            val = getattr(BartForConditionalGeneration, __name)
-
-        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-            return redirect(val)
-        return val
-
-    @classmethod
-    def update_kwargs(cls, kwargs):
-        kwargs.update(
-            {
-                "torchscript": True,
-                "return_dict": False,
-                "use_cache": True,
-            }
-        )
-        return kwargs
 
     @classmethod
-    def get_compiled_model(cls, model, rbln_config: RBLNConfig):
-        def optimized_models(model):
-            if isinstance(model, T5ForConditionalGeneration):
-                encoder_model = T5EncoderWrapper(model).eval()
-                decoder_model = T5DecoderWrapper(model).eval()
-            elif isinstance(model, BartForConditionalGeneration):
-                encoder_model = BartEncoderWrapper(model).eval()
-                decoder_model = BartDecoderWrapper(model).eval()
-            else:
-                raise ValueError(f"{model.__class__.__name__} is not supported yet.")
-
-            return encoder_model, decoder_model
-
-        wrapped_encoder, wrapped_decoder = optimized_models(model)
+    @torch.inference_mode()
+    def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNConfig):
+        wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
 
-        wrapped_encoder.encoder_max_length = rbln_config.model_cfg["enc_max_seq_len"]
-        wrapped_encoder.decoder_max_length = rbln_config.model_cfg["dec_max_seq_len"]
-        wrapped_encoder.decoder_batch_size = rbln_config.model_cfg["batch_size"]
+        wrapped_model.encoder.encoder_max_length = rbln_config.model_cfg["enc_max_seq_len"]
+        wrapped_model.encoder.decoder_max_length = rbln_config.model_cfg["dec_max_seq_len"]
 
-        wrapped_decoder.encoder_max_length = rbln_config.model_cfg["enc_max_seq_len"]
-        wrapped_decoder.decoder_max_length = rbln_config.model_cfg["dec_max_seq_len"]
-        wrapped_decoder.decoder_batch_size = rbln_config.model_cfg["batch_size"]
+        wrapped_model.decoder.encoder_max_length = rbln_config.model_cfg["enc_max_seq_len"]
+        wrapped_model.decoder.decoder_max_length = rbln_config.model_cfg["dec_max_seq_len"]
 
         enc_rbln_compile_config = rbln_config.compile_cfgs[0]
         dec_rbln_compile_config = rbln_config.compile_cfgs[1]
 
-        if isinstance(model, T5ForConditionalGeneration):
-            enc_example_inputs = enc_rbln_compile_config.get_dummy_inputs(fill=1)
-            dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=1)
-        else:
-            enc_example_inputs = enc_rbln_compile_config.get_dummy_inputs(fill=0)
-            dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=0)
+        enc_example_inputs = enc_rbln_compile_config.get_dummy_inputs(fill=0)
+        dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=0)
 
         enc_example_inputs[3].fill_(0)
         dec_example_inputs[4].fill_(-1)
 
-        enc_scripted_model = torch.jit.trace(wrapped_encoder, enc_example_inputs, check_trace=False)
-        dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs, check_trace=False)
+        enc_scripted_model = torch.jit.trace(wrapped_model.encoder, enc_example_inputs, check_trace=False)
+        dec_scripted_model = torch.jit.trace(wrapped_model.decoder, dec_example_inputs, check_trace=False)
 
         enc_ir = rebel.torchscript_to_ir(
             enc_scripted_model,
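In this hunk the compile path stops special-casing T5 and BART: get_compiled_model now calls cls.wrap_model_if_needed and expects the returned object to expose .encoder and .decoder modules that torch.jit.trace can consume. A minimal sketch of that contract, for orientation only (the class name below is hypothetical; only wrap_model_if_needed and the .encoder/.decoder attributes come from the diff):

import torch

class ExampleSeq2SeqWrapper:  # hypothetical name; mirrors the T5Wrapper added later in this diff
    def __init__(self, encoder: torch.nn.Module, decoder: torch.nn.Module):
        # get_compiled_model reads wrapped_model.encoder / wrapped_model.decoder,
        # sets encoder_max_length / decoder_max_length on each, and traces them separately.
        self.encoder = encoder
        self.decoder = decoder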
@@ -180,13 +116,12 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
             input_names=[v[0] for v in dec_rbln_compile_config.input_info],
             name=dec_rbln_compile_config.mod_name,
         )
-        dec_ir.decoder_batch_size = rbln_config.model_cfg["batch_size"]
 
         connections = [
             (enc_ir.outputs[0], enc_ir.inputs[2], dec_ir.inputs[6]),
-            # (enc_ir.outputs[0], enc_ir.inputs[2]),
             (dec_ir.outputs[1], dec_ir.inputs[5]),
         ]
+
         compiled_model = rebel.compile(
             enc_ir,
             dec_ir,
@@ -209,14 +144,13 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
         rbln_batch_size = rbln_kwargs.get("batch_size", None)
         rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
 
-        if isinstance(model_config, BartConfig):
-            n_layer = model_config.decoder_layers
-            n_head = model_config.decoder_attention_heads
-            d_kv = model_config.d_model // model_config.encoder_attention_heads
-        else:
-            n_layer = model_config.num_layers
-            n_head = model_config.num_heads
-            d_kv = model_config.d_kv
+        n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
+        n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
+        d_kv = (
+            model_config.d_kv
+            if hasattr(model_config, "d_kv")
+            else model_config.d_model // model_config.encoder_attention_heads
+        )
 
         max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
             model_config, "max_position_embeddings", None
@@ -270,7 +204,6 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
                 ],
                 "float32",
             ),
-            # int16 available?
             ("batch_idx", [], "int32"),
         ]
 
@@ -281,7 +214,6 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
             (
                 "cache_position",
                 [rbln_batch_size, 1],
-                # [],
                 "int32",
             ),
             ("batch_position", [], "int32"),
@@ -346,35 +278,32 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
             compiled_models[0].create_runtime("decoder", tensor_type="pt", device=device_val),
         ]
 
+    def can_generate(self):
+        return True
+
+    def get_encoder(self):
+        return self.encoder
+
+    def get_decoder(self):
+        return self.decoder
+
     def prepare_inputs_for_generation(
         self,
         input_ids,
-        past_key_values=None,
         attention_mask=None,
         decoder_attention_mask=None,
         **kwargs,
     ):
-        past_cache_length = past_key_values
-        if past_cache_length == 0:
-            cache_pos = []
-            for i in range(input_ids.shape[0]):
-                cache_pos.append([0])
-            cache_position = torch.tensor(cache_pos, dtype=torch.int32)
-
-        max_seq_len = self.dec_max_seq_len
         cur_seq_len = input_ids.shape[-1]
+        cache_position = cur_seq_len - 1
+        max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
         decoder_batch_size = input_ids.shape[0]
         input_ids = input_ids[:, cur_seq_len - 1 : cur_seq_len].contiguous()
-        # In greedy decoding
         decoder_attention_mask = torch.zeros(decoder_batch_size, max_seq_len, dtype=torch.float32)
         decoder_attention_mask[:, :cur_seq_len] = 1
-        cache_pos = []
-        for i in range(input_ids.shape[0]):
-            cache_pos.append([cur_seq_len - 1])
-        cache_position = torch.tensor(cache_pos, dtype=torch.int32)
+
         return {
             "decoder_input_ids": input_ids,
-            "past_key_values": past_key_values,
             "attention_mask": attention_mask.to(torch.float32),
             "decoder_attention_mask": decoder_attention_mask,
             "cache_position": cache_position,
@@ -383,41 +312,12 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
     def forward(
         self,
         input_ids: torch.LongTensor = None,
-        cache_position: Union[List[torch.Tensor], torch.Tensor] = None,  # vllm keyword argument
-        batch_idx: Optional[torch.LongTensor] = None,
-        enc_lengths: List[int] = None,
+        cache_position: Union[List[torch.Tensor], torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
         # common decoder
-        if enc_lengths is None:
-            output = self._forward_decoder(input_ids=input_ids, cache_position=cache_position, **kwargs)
-            return output
-
-        # vllm & encoder
-        if batch_idx is not None:
-            enc_attention_mask = self.enc_attention_mask.clone()
-            enc_attention_mask[0][: enc_lengths[batch_idx] + 1] = 1
-            padding_need = self.enc_max_seq_len - input_ids.shape[-1]
-            input_ids = torch.nn.functional.pad(input_ids, (0, padding_need))
-            _ = self.encoder(input_ids, enc_attention_mask, batch_idx=batch_idx.to(torch.int32))
-            logits = torch.zeros(1, 1, self.config.vocab_size + 100)
-            logits[0][0][-1] = 1
-        # vllm & decoder
-        else:
-            input_ids[input_ids == (self.config.vocab_size + 99)] = self.config.decoder_start_token_id
-            cache_position[cache_position != 0] = cache_position[cache_position != 0] - 2
-
-            enc_attention_mask = self.dec_enc_attention_mask.clone()
-            dec_attention_mask = torch.zeros(self.batch_size, self.dec_max_seq_len, dtype=torch.float32)
-            for batch_idx in range(self.batch_size):
-                enc_attention_mask[batch_idx, : enc_lengths[batch_idx] + 1] = 1
-
-            logits = self._forward_decoder(
-                attention_mask=enc_attention_mask,
-                decoder_input_ids=input_ids,
-                decoder_attention_mask=dec_attention_mask,
-                cache_position=cache_position,
-            ).logits
+        cache_position = torch.full((self.rbln_config.model_cfg["batch_size"], 1), cache_position, dtype=torch.int32)
+        logits = self._forward_decoder(input_ids=input_ids, cache_position=cache_position, **kwargs).logits
 
         return Seq2SeqLMOutput(
             logits=logits,
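With this change, prepare_inputs_for_generation passes forward a plain integer cache_position, and forward broadcasts it to a [batch_size, 1] int32 tensor before calling the decoder runtime. A tiny illustration of that broadcast (the batch size and step are assumed values, not from the diff):

import torch

batch_size, step = 2, 5  # assumed values for illustration
cache_position = torch.full((batch_size, 1), step, dtype=torch.int32)
print(cache_position)  # tensor([[5], [5]], dtype=torch.int32)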
@@ -446,25 +346,58 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
 
         return Seq2SeqLMOutput(logits=lm_logits)
 
+    def vllm_forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        cache_position: Union[List[torch.Tensor], torch.Tensor] = None,  # vllm keyword argument
+        batch_idx: Optional[torch.LongTensor] = None,
+        enc_lengths: List[int] = None,  # vllm return current attention_mask length
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor]:
+        # When using vllm, need the output of the encoder (ex. vocab_size + 100) and use that value act as start_token_id in decoder (ex. vocab_size + 99)
+        # encoder
+        if batch_idx is not None:
+            enc_attention_mask = torch.zeros(1, self.rbln_config.model_cfg["enc_max_seq_len"], dtype=torch.float32)
+            enc_attention_mask[0][: enc_lengths[batch_idx] + 1] = 1
+            padding_need = self.rbln_config.model_cfg["enc_max_seq_len"] - input_ids.shape[-1]
+            input_ids = torch.nn.functional.pad(input_ids, (0, padding_need))
+            _ = self.encoder(input_ids, enc_attention_mask, batch_idx=batch_idx.to(torch.int32))
+            logits = torch.zeros(1, 1, self.config.vocab_size + 100)
+            logits[0][0][-1] = 1
+        # decoder
+        else:
+            input_ids[input_ids == (self.config.vocab_size + 99)] = self.config.decoder_start_token_id
+            cache_position[cache_position != 0] = cache_position[cache_position != 0] - 2
+
+            enc_attention_mask = torch.zeros(
+                self.rbln_config.model_cfg["batch_size"],
+                self.rbln_config.model_cfg["enc_max_seq_len"],
+                dtype=torch.float32,
+            )
+            dec_attention_mask = torch.zeros(
+                self.rbln_config.model_cfg["batch_size"],
+                self.rbln_config.model_cfg["dec_max_seq_len"],
+                dtype=torch.float32,
+            )
+            for batch_idx in range(self.rbln_config.model_cfg["batch_size"]):
+                enc_attention_mask[batch_idx, : enc_lengths[batch_idx] + 1] = 1
+
+            logits = self._forward_decoder(
+                attention_mask=enc_attention_mask,
+                decoder_input_ids=input_ids,
+                decoder_attention_mask=dec_attention_mask,
+                cache_position=cache_position,
+            ).logits
+
+        return Seq2SeqLMOutput(logits=logits)
+
     def _prepare_encoder_decoder_kwargs_for_generation(
         self,
         inputs_tensor: torch.Tensor,
         model_kwargs,
         model_input_name: Optional[str] = None,
-        *args,
-        **kwargs,
+        generation_config: Optional[GenerationConfig] = None,
     ) -> Dict[str, Any]:
-        ########## thkim change start ###################
-        # padding input_ids & attention_mask regardless of user's tokenizer usage
-        batch_size, input_len = inputs_tensor.shape
-        inputs_tensor = torch.nn.functional.pad(
-            inputs_tensor, (0, self.enc_max_seq_len - input_len), value=self.pad_token_id
-        )
-        model_kwargs["attention_mask"] = torch.nn.functional.pad(
-            model_kwargs["attention_mask"], (0, self.enc_max_seq_len - input_len), value=0
-        )
-        ########## thkim change end ###################
-
         # 1. get encoder
         encoder = self.get_encoder()
 
@@ -482,18 +415,26 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
             argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
         }
 
+        batch_size, input_len = inputs_tensor.shape
+        inputs_tensor = torch.nn.functional.pad(
+            inputs_tensor,
+            (0, self.rbln_config.model_cfg["enc_max_seq_len"] - input_len),
+            value=self.rbln_config.model_cfg["pad_token_id"],
+        )
+        model_kwargs["attention_mask"] = torch.nn.functional.pad(
+            model_kwargs["attention_mask"], (0, self.rbln_config.model_cfg["enc_max_seq_len"] - input_len)
+        )
+
         # 3. make sure that encoder returns `ModelOutput`
         model_input_name = model_input_name if model_input_name is not None else self.main_input_name
         encoder_kwargs["return_dict"] = True
-        encoder_kwargs[model_input_name] = inputs_tensor
+        encoder_kwargs["output_hidden_states"] = False
+        encoder_kwargs["output_attentions"] = False
+
         for b in range(batch_size):
             batch_idx = torch.tensor(b, dtype=torch.int32)
-            cb_inputs = {}
-            cb_inputs["return_dict"] = True
-            cb_inputs["output_hidden_states"] = False
-            cb_inputs["output_attentions"] = False
-            cb_inputs["input_ids"] = encoder_kwargs["input_ids"][b].unsqueeze(0)
-            cb_inputs["attention_mask"] = encoder_kwargs["attention_mask"][b].unsqueeze(0).to(torch.float32)
-            model_kwargs["encoder_outputs"] = encoder(**cb_inputs, batch_idx=batch_idx)
+            encoder_kwargs["input_ids"] = inputs_tensor[b].unsqueeze(0)
+            encoder_kwargs["attention_mask"] = model_kwargs["attention_mask"][b].unsqueeze(0).to(torch.float32)
+            model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs, batch_idx=batch_idx)
 
         return model_kwargs
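The vllm_forward method added earlier in this file relies on a sentinel handshake: the encoder step returns a fake distribution whose only nonzero logit sits at index vocab_size + 99 (inside a tensor of size vocab_size + 100), and the decoder step maps any occurrence of that sentinel id back to decoder_start_token_id before the real decoder runs. A small standalone illustration of that handshake, with assumed T5-like values (not taken from this diff):

import torch

vocab_size, decoder_start_token_id = 32100, 0  # assumed values for illustration
fake_logits = torch.zeros(1, 1, vocab_size + 100)
fake_logits[0][0][-1] = 1                       # argmax lands on the sentinel id vocab_size + 99
sampled = torch.tensor([[fake_logits.argmax(-1).item()]])
sampled[sampled == (vocab_size + 99)] = decoder_start_token_id
print(sampled)  # tensor([[0]])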
@@ -21,4 +21,5 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
+from .modeling_t5 import RBLNT5ForConditionalGeneration
 from .t5_architecture import T5DecoderWrapper, T5EncoderWrapper
@@ -0,0 +1,55 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+import inspect
+from typing import TYPE_CHECKING, Any, Callable
+
+from transformers import T5ForConditionalGeneration
+
+from ....modeling_config import RBLNConfig
+from ....utils.logging import get_logger
+from ...models.seq2seq import RBLNModelForSeq2SeqLM
+from .t5_architecture import T5Wrapper
+
+
+logger = get_logger()
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel
+
+
+class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
+    @classmethod
+    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
+        return T5Wrapper(model)
+
+    def __getattr__(self, __name: str) -> Any:
+        def redirect(func):
+            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+        val = getattr(T5ForConditionalGeneration, __name)
+
+        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+            return redirect(val)
+
+        return val
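The new modeling_t5.py gives T5 its own entry point instead of routing everything through the generic seq2seq class. A hedged usage sketch follows; the checkpoint name and the rbln_* keyword arguments are assumptions based on the usual optimum-rbln from_pretrained convention (and on the top-level re-export suggested by the __init__.py change listed above), not values taken from this diff:

from transformers import AutoTokenizer
from optimum.rbln import RBLNT5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # assumed checkpoint
model = RBLNT5ForConditionalGeneration.from_pretrained(
    "t5-small",
    export=True,               # trace and compile the PyTorch model for the RBLN runtime
    rbln_batch_size=1,         # stored in rbln_config.model_cfg["batch_size"]
    rbln_enc_max_seq_len=512,  # assumed kwarg name for enc_max_seq_len
    rbln_dec_max_seq_len=512,  # assumed kwarg name for dec_max_seq_len
)
inputs = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt")
outputs = model.generate(**inputs, max_length=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))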
@@ -43,6 +43,12 @@ if TYPE_CHECKING:
     from transformers import T5ForConditionalGeneration
 
 
+class T5Wrapper:
+    def __init__(self, model):
+        self.encoder = T5EncoderWrapper(model)
+        self.decoder = T5DecoderWrapper(model)
+
+
 class T5Encoder(T5Stack):
     def forward(
         self,
@@ -122,19 +128,26 @@ class T5EncoderWrapper(torch.nn.Module):
         )
         self.encoder_max_length = None
         self.decoder_max_length = None
-        self.decoder_batch_size = 1
 
-    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
-        cross_key_value: torch.Tensor = None, batch_idx: torch.Tensor = None,
-    ) -> torch.Tensor:
-        encoder_batch_size = input_ids.shape[0]
-        decoder_batch_size = self.decoder_batch_size
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        cross_key_value: torch.Tensor = None,
+        batch_idx: torch.Tensor = None,
+    ) -> torch.Tensor:
         decoder_max_length = self.decoder_max_length or self.default_max_length
         encoder_max_length = self.encoder_max_length or self.default_max_length
 
         attn_layer = self.encoder.block[0].layer[0].SelfAttention
         encoder_position_bias = T5Attention.compute_bias(attn_layer, encoder_max_length, encoder_max_length)
-        encoder_outputs = T5Encoder.forward(self.encoder, input_ids, attention_mask, encoder_position_bias, batch_ids=torch.tensor(0, dtype=torch.int32))
+        encoder_outputs = T5Encoder.forward(
+            self.encoder,
+            input_ids,
+            attention_mask,
+            encoder_position_bias,
+            batch_ids=torch.tensor(0, dtype=torch.int32),
+        )
 
         attn_layer = self.decoder.block[0].layer[0].SelfAttention
         decoder_position_bias = T5Attention.compute_bias(attn_layer, decoder_max_length, decoder_max_length)
@@ -145,22 +158,14 @@ class T5EncoderWrapper(torch.nn.Module):
 
         dummy_past_key_value = []
         for i in range(self.config.num_layers):
-            pkv_self_attn_key = torch.zeros(
-                decoder_batch_size, self.config.num_heads, decoder_max_length, self.config.d_kv
-            )
-            pkv_self_attn_value = torch.zeros(
-                decoder_batch_size, self.config.num_heads, decoder_max_length, self.config.d_kv
-            )
-            pkv_cross_attn_key = torch.zeros(
-                encoder_batch_size, self.config.num_heads, encoder_max_length, self.config.d_kv
-            )
-            pkv_cross_attn_value = torch.zeros(
-                encoder_batch_size, self.config.num_heads, encoder_max_length, self.config.d_kv
-            )
+            pkv_self_attn_key = torch.zeros(1, self.config.num_heads, decoder_max_length, self.config.d_kv)
+            pkv_self_attn_value = torch.zeros(1, self.config.num_heads, decoder_max_length, self.config.d_kv)
+            pkv_cross_attn_key = torch.zeros(1, self.config.num_heads, encoder_max_length, self.config.d_kv)
+            pkv_cross_attn_value = torch.zeros(1, self.config.num_heads, encoder_max_length, self.config.d_kv)
             layer_pkv = (pkv_self_attn_key, pkv_self_attn_value, pkv_cross_attn_key, pkv_cross_attn_value)
             dummy_past_key_value.append(layer_pkv)
 
-        decoder_attention_mask = torch.zeros(decoder_batch_size, decoder_max_length, dtype=torch.float32)
+        decoder_attention_mask = torch.zeros(1, decoder_max_length, dtype=torch.float32)
         decoder_attention_mask[:, :1] = 1
 
         # Since first step of decoder has different graph to further step of it,
@@ -168,7 +173,7 @@ class T5EncoderWrapper(torch.nn.Module):
         # TODO(jongho): Separate first-step-decoder.
         decoder_outputs = T5Decoder.forward(
             self.decoder,
-            input_ids=torch.zeros(decoder_batch_size, 1, dtype=torch.int64),
+            input_ids=torch.zeros(1, 1, dtype=torch.int64),
             attention_mask=decoder_attention_mask,
             position_bias=decoder_position_bias,
             encoder_decoder_position_bias=encoder_decoder_position_bias,
@@ -187,7 +192,7 @@ class T5EncoderWrapper(torch.nn.Module):
             cross_kv_cache.append(past_key_values[i][3])
         cross_kv_cache = torch.stack(cross_kv_cache, dim=0)
 
-        cross_key_value = cross_key_value.slice_scatter(cross_kv_cache, dim=1, start=batch_idx, end=batch_idx+1)
+        cross_key_value = cross_key_value.slice_scatter(cross_kv_cache, dim=1, start=batch_idx, end=batch_idx + 1)
 
         return cross_key_value
 
@@ -240,6 +245,7 @@ class T5DecoderWrapper(torch.nn.Module):
         attn_layer = self.model.decoder.block[0].layer[0].SelfAttention
         _decoder_position_bias = T5Attention.compute_bias(attn_layer, decoder_max_length, decoder_max_length)
 
+        # position_bias need to compute with batch (for cb)
         batch_decoder_position_bias = []
         for i in range(input_ids.shape[0]):
             batch_position_bias = _decoder_position_bias[:, :, cache_position[i][0]].unsqueeze(2)
@@ -259,7 +265,7 @@ class T5DecoderWrapper(torch.nn.Module):
             encoder_decoder_position_bias=encoder_decoder_position_bias,
             past_key_values=kv_cache,
             cache_position=cache_position,
-            batch_ids=rbln_batch_position
+            batch_ids=rbln_batch_position,
         )
 
         past_key_values = decoder_outputs.past_key_values
@@ -312,7 +318,7 @@ class _T5Attention(T5Attention):
             value_states = shape(self.v(hidden_states), batch_size)
         else:
             # cross-attn
-            if cache_position.dim() == 0 :
+            if cache_position.dim() == 0:
                 key_states = shape(self.k(key_value_states), key_value_states.shape[0])
                 value_states = shape(self.v(key_value_states), key_value_states.shape[0])
                 past_key_value = key_states, value_states
@@ -331,18 +337,24 @@ class _T5Attention(T5Attention):
                 batch_value_states = value_states[b].unsqueeze(0)
 
                 if is_self_attn and past_key_value is not None:
-                    batch_key_states = past_key_value[0][b].unsqueeze(0).slice_scatter(
-                        batch_key_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
+                    batch_key_states = (
+                        past_key_value[0][b]
+                        .unsqueeze(0)
+                        .slice_scatter(
+                            batch_key_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
+                        )
                     )
-                    batch_value_states = past_key_value[1][b].unsqueeze(0).slice_scatter(
-                        batch_value_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
+                    batch_value_states = (
+                        past_key_value[1][b]
+                        .unsqueeze(0)
+                        .slice_scatter(
+                            batch_value_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
+                        )
                     )
 
                 scores = torch.matmul(batch_query_states, batch_key_states.transpose(3, 2))
                 scores += position_bias[b]
-                attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
-                    scores
-                )
+                attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
                 attn_output = unshape(torch.matmul(attn_weights, batch_value_states), 1)
                 all_key_states.append(batch_key_states)
                 all_value_states.append(batch_value_states)
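The reformatted self-attention branch writes each batch's new key/value slice into its preallocated cache with Tensor.slice_scatter rather than in-place index assignment. A standalone illustration of that operation, with shapes assumed purely for brevity:

import torch

cache = torch.zeros(1, 8, 16, 64)  # assumed (batch, heads, max_len, d_kv) cache slot
step = torch.ones(1, 8, 1, 64)     # new key/value states for the current position
pos = 5
cache = cache.slice_scatter(step, dim=-2, start=pos, end=pos + 1)
print(cache[0, 0, pos, 0])  # tensor(1.)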
@@ -371,7 +383,9 @@ class _T5Attention(T5Attention):
                 scores
             )  # (batch_size, n_heads, seq_length, key_length)
 
-            attn_output = unshape(torch.matmul(attn_weights, value_states), batch_size)  # (batch_size, seq_length, dim)
+            attn_output = unshape(
+                torch.matmul(attn_weights, value_states), batch_size
+            )  # (batch_size, seq_length, dim)
 
         attn_output = self.o(attn_output)
         present_key_value = (key_states, value_states)
@@ -65,7 +65,6 @@ class RBLNWav2Vec2ForCTC(RBLNModel):
     - compiling the resulting graph using the RBLN compiler.
     """
 
-    model_type = "rbln_model"
     main_input_name = "input_values"
     auto_model_class = AutoModelForMaskedLM