optimum-rbln 0.1.11__py3-none-any.whl → 0.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. optimum/rbln/__init__.py +14 -7
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +30 -63
  4. optimum/rbln/diffusers/models/controlnet.py +36 -62
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +57 -156
  6. optimum/rbln/diffusers/pipelines/__init__.py +40 -12
  7. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -0
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -187
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -192
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +8 -206
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -207
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -111
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +12 -117
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +4 -123
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +4 -126
  16. optimum/rbln/modeling_alias.py +4 -9
  17. optimum/rbln/modeling_base.py +117 -144
  18. optimum/rbln/modeling_config.py +51 -0
  19. optimum/rbln/modeling_diffusers.py +400 -0
  20. optimum/rbln/transformers/__init__.py +10 -0
  21. optimum/rbln/transformers/cache_utils.py +5 -9
  22. optimum/rbln/transformers/modeling_rope_utils.py +283 -0
  23. optimum/rbln/transformers/models/__init__.py +80 -28
  24. optimum/rbln/transformers/models/auto/modeling_auto.py +1 -0
  25. optimum/rbln/transformers/models/bart/__init__.py +1 -1
  26. optimum/rbln/transformers/models/bart/bart_architecture.py +18 -12
  27. optimum/rbln/transformers/models/bart/modeling_bart.py +25 -6
  28. optimum/rbln/transformers/models/bert/modeling_bert.py +1 -2
  29. optimum/rbln/transformers/models/clip/modeling_clip.py +13 -23
  30. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -2
  31. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +376 -218
  32. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +246 -116
  33. optimum/rbln/transformers/models/dpt/modeling_dpt.py +0 -1
  34. optimum/rbln/transformers/models/exaone/__init__.py +32 -0
  35. optimum/rbln/transformers/models/exaone/exaone_architecture.py +81 -0
  36. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +181 -0
  37. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +1725 -0
  38. optimum/rbln/transformers/models/exaone/modeling_exaone.py +53 -0
  39. optimum/rbln/transformers/models/gemma/gemma_architecture.py +12 -2
  40. optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
  41. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +4 -30
  42. optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
  43. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +166 -151
  44. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -15
  45. optimum/rbln/transformers/models/midm/modeling_midm.py +8 -28
  46. optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
  47. optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
  48. optimum/rbln/transformers/models/phi/phi_architecture.py +75 -159
  49. optimum/rbln/transformers/models/qwen2/__init__.py +24 -0
  50. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +43 -0
  51. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +29 -0
  52. optimum/rbln/transformers/models/seq2seq/__init__.py +24 -0
  53. optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} +107 -166
  54. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  55. optimum/rbln/transformers/models/t5/modeling_t5.py +108 -0
  56. optimum/rbln/transformers/models/t5/t5_architecture.py +46 -32
  57. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +0 -1
  58. optimum/rbln/transformers/models/whisper/modeling_whisper.py +38 -13
  59. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +1 -2
  60. optimum/rbln/transformers/utils/rbln_quantization.py +8 -2
  61. optimum/rbln/utils/context.py +58 -0
  62. optimum/rbln/utils/decorator_utils.py +55 -0
  63. optimum/rbln/utils/import_utils.py +21 -0
  64. optimum/rbln/utils/logging.py +1 -1
  65. optimum/rbln/utils/runtime_utils.py +4 -4
  66. optimum/rbln/utils/timer_utils.py +26 -2
  67. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/METADATA +11 -9
  68. optimum_rbln-0.1.13.dist-info/RECORD +107 -0
  69. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/WHEEL +1 -1
  70. optimum_rbln-0.1.11.dist-info/RECORD +0 -93
  71. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/entry_points.txt +0 -0
  72. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py}
@@ -23,24 +23,17 @@
 
  import inspect
  import logging
- from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
-
- import rebel
- import torch
- from transformers import (
-     AutoModelForSeq2SeqLM,
-     BartConfig,
-     BartForConditionalGeneration,
-     PretrainedConfig,
-     T5ForConditionalGeneration,
- )
+ from abc import ABC
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+ import rebel  # noqa: F401
+ import torch  # noqa: F401
+ from transformers import AutoModelForSeq2SeqLM, GenerationConfig, PretrainedConfig, PreTrainedModel
  from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 
- from .modeling_base import RBLNModel
- from .modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
- from .transformers.models.bart import BartDecoderWrapper, BartEncoderWrapper
- from .transformers.models.t5 import T5DecoderWrapper, T5EncoderWrapper
- from .utils.runtime_utils import RBLNPytorchRuntime
+ from ....modeling_base import RBLNModel
+ from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
+ from ....utils.runtime_utils import RBLNPytorchRuntime
 
 
  logger = logging.getLogger(__name__)
@@ -59,7 +52,6 @@ class RBLNRuntimeEncoder(RBLNPytorchRuntime):
 
      def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
          _ = super().forward(*args, **kwargs)
-         # Just indicates that it is not None
          return BaseModelOutput(last_hidden_state=torch.tensor([1.0]))
 
 
@@ -71,7 +63,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
          return Seq2SeqLMOutput(logits=outputs)
 
 
- class RBLNModelForSeq2SeqLM(RBLNModel):
+ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
      """
      This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence-to-sequence language modeling head) when created with the from_pretrained() class method.
      This model inherits from [`RBLNBaseModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
@@ -84,91 +76,35 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
      Currently, this model class only supports the 'bart' and 't5' models from the transformers library. Future updates may include support for additional model types.
      """
 
+     main_input_name = "input_ids"
      auto_model_class = AutoModelForSeq2SeqLM
 
      def __post_init__(self, **kwargs):
-         self.model_dim = self.config.d_model
-         self.batch_size = self.rbln_config.model_cfg["batch_size"]
-         self.enc_max_seq_len = self.rbln_config.model_cfg["enc_max_seq_len"]
-         self.dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
-         self.pad_token_id = self.rbln_config.model_cfg["pad_token_id"]
          self.encoder = RBLNRuntimeEncoder(runtime=self.model[0], main_input_name="input_ids")
          self.decoder = RBLNRuntimeDecoder(runtime=self.model[1], main_input_name="input_ids")
-         self.enc_attention_mask = torch.zeros(1, self.enc_max_seq_len, dtype=torch.float32)
-         self.dec_enc_attention_mask = torch.zeros(self.batch_size, self.enc_max_seq_len, dtype=torch.float32)
-
-     def can_generate(self):
-         return True
-
-     def get_encoder(self):
-         return self.encoder
-
-     def get_decoder(self):
-         return self.decoder
-
-     def __getattr__(self, __name: str) -> Any:
-         def redirect(func):
-             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-         if "T5ForConditionalGeneration" == self.config.architectures:
-             val = getattr(T5ForConditionalGeneration, __name)
-         else:
-             val = getattr(BartForConditionalGeneration, __name)
-
-         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-             return redirect(val)
-         return val
-
-     @classmethod
-     def update_kwargs(cls, kwargs):
-         kwargs.update(
-             {
-                 "torchscript": True,
-                 "return_dict": False,
-                 "use_cache": True,
-             }
-         )
-         return kwargs
 
      @classmethod
-     def get_compiled_model(cls, model, rbln_config: RBLNConfig):
-         def optimized_models(model):
-             if isinstance(model, T5ForConditionalGeneration):
-                 encoder_model = T5EncoderWrapper(model).eval()
-                 decoder_model = T5DecoderWrapper(model).eval()
-             elif isinstance(model, BartForConditionalGeneration):
-                 encoder_model = BartEncoderWrapper(model).eval()
-                 decoder_model = BartDecoderWrapper(model).eval()
-             else:
-                 raise ValueError(f"{model.__class__.__name__} is not supported yet.")
-
-             return encoder_model, decoder_model
-
-         wrapped_encoder, wrapped_decoder = optimized_models(model)
+     @torch.inference_mode()
+     def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNConfig):
+         wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
 
-         wrapped_encoder.encoder_max_length = rbln_config.model_cfg["enc_max_seq_len"]
-         wrapped_encoder.decoder_max_length = rbln_config.model_cfg["dec_max_seq_len"]
-         wrapped_encoder.decoder_batch_size = rbln_config.model_cfg["batch_size"]
+         wrapped_model.encoder.encoder_max_length = rbln_config.model_cfg["enc_max_seq_len"]
+         wrapped_model.encoder.decoder_max_length = rbln_config.model_cfg["dec_max_seq_len"]
 
-         wrapped_decoder.encoder_max_length = rbln_config.model_cfg["enc_max_seq_len"]
-         wrapped_decoder.decoder_max_length = rbln_config.model_cfg["dec_max_seq_len"]
-         wrapped_decoder.decoder_batch_size = rbln_config.model_cfg["batch_size"]
+         wrapped_model.decoder.encoder_max_length = rbln_config.model_cfg["enc_max_seq_len"]
+         wrapped_model.decoder.decoder_max_length = rbln_config.model_cfg["dec_max_seq_len"]
 
          enc_rbln_compile_config = rbln_config.compile_cfgs[0]
          dec_rbln_compile_config = rbln_config.compile_cfgs[1]
 
-         if isinstance(model, T5ForConditionalGeneration):
-             enc_example_inputs = enc_rbln_compile_config.get_dummy_inputs(fill=1)
-             dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=1)
-         else:
-             enc_example_inputs = enc_rbln_compile_config.get_dummy_inputs(fill=0)
-             dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=0)
+         enc_example_inputs = enc_rbln_compile_config.get_dummy_inputs(fill=0)
+         dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=0)
 
          enc_example_inputs[3].fill_(0)
          dec_example_inputs[4].fill_(-1)
 
-         enc_scripted_model = torch.jit.trace(wrapped_encoder, enc_example_inputs, check_trace=False)
-         dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs, check_trace=False)
+         enc_scripted_model = torch.jit.trace(wrapped_model.encoder, enc_example_inputs, check_trace=False)
+         dec_scripted_model = torch.jit.trace(wrapped_model.decoder, dec_example_inputs, check_trace=False)
 
          enc_ir = rebel.torchscript_to_ir(
              enc_scripted_model,
@@ -180,13 +116,12 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
              input_names=[v[0] for v in dec_rbln_compile_config.input_info],
              name=dec_rbln_compile_config.mod_name,
          )
-         dec_ir.decoder_batch_size = rbln_config.model_cfg["batch_size"]
 
          connections = [
              (enc_ir.outputs[0], enc_ir.inputs[2], dec_ir.inputs[6]),
-             # (enc_ir.outputs[0], enc_ir.inputs[2]),
              (dec_ir.outputs[1], dec_ir.inputs[5]),
          ]
+
          compiled_model = rebel.compile(
              enc_ir,
              dec_ir,
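The refactor above replaces the per-architecture `optimized_models` helper with a single `wrap_model_if_needed` hook that subclasses override, after which the encoder and decoder halves are traced separately from the returned wrapper. A minimal stand-in for that pattern, in plain torch only (the real wrappers live in each model's *_architecture.py, and `ToyWrapper` here is purely illustrative):

import torch


class ToyWrapper(torch.nn.Module):
    # Stand-in for T5Wrapper/BartWrapper: one module exposing .encoder/.decoder
    # submodules so each half can be traced on its own static-shape inputs.
    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Linear(8, 8)
        self.decoder = torch.nn.Linear(8, 8)


wrapped_model = ToyWrapper().eval()
enc_example_inputs = (torch.zeros(1, 8),)  # like get_dummy_inputs(fill=0)
dec_example_inputs = (torch.zeros(1, 8),)

# check_trace=False mirrors the calls above: graphs that mutate cache state
# would fail a re-run verification even when the trace itself is fine.
enc_scripted = torch.jit.trace(wrapped_model.encoder, enc_example_inputs, check_trace=False)
dec_scripted = torch.jit.trace(wrapped_model.decoder, dec_example_inputs, check_trace=False)
print(type(enc_scripted), type(dec_scripted))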
@@ -209,14 +144,13 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
          rbln_batch_size = rbln_kwargs.get("batch_size", None)
          rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
 
-         if isinstance(model_config, BartConfig):
-             n_layer = model_config.decoder_layers
-             n_head = model_config.decoder_attention_heads
-             d_kv = model_config.d_model // model_config.encoder_attention_heads
-         else:
-             n_layer = model_config.num_layers
-             n_head = model_config.num_heads
-             d_kv = model_config.d_kv
+         n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
+         n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
+         d_kv = (
+             model_config.d_kv
+             if hasattr(model_config, "d_kv")
+             else model_config.d_model // model_config.encoder_attention_heads
+         )
 
          max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
              model_config, "max_position_embeddings", None
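The new branch-free lookup works because BART-style configs expose `decoder_layers`/`decoder_attention_heads` while T5-style configs expose `num_layers`/`num_heads` plus the per-head dimension `d_kv` directly. A quick check against stock transformers configs (the helper name `derive_kv_dims` is mine; the printed values assume library-default configs):

from transformers import BartConfig, T5Config


def derive_kv_dims(model_config):
    # BART-style configs carry decoder_layers/decoder_attention_heads;
    # T5-style configs carry num_layers/num_heads and the per-head dim d_kv.
    n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
    n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
    d_kv = (
        model_config.d_kv
        if hasattr(model_config, "d_kv")
        else model_config.d_model // model_config.encoder_attention_heads
    )
    return n_layer, n_head, d_kv


print(derive_kv_dims(BartConfig()))  # (12, 16, 64) with library defaults
print(derive_kv_dims(T5Config()))    # (6, 8, 64) with library defaults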
@@ -270,7 +204,6 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
                  ],
                  "float32",
              ),
-             # int16 available?
              ("batch_idx", [], "int32"),
          ]
 
@@ -281,7 +214,6 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
              (
                  "cache_position",
                  [rbln_batch_size, 1],
-                 # [],
                  "int32",
              ),
              ("batch_position", [], "int32"),
@@ -346,35 +278,32 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
              compiled_models[0].create_runtime("decoder", tensor_type="pt", device=device_val),
          ]
 
+     def can_generate(self):
+         return True
+
+     def get_encoder(self):
+         return self.encoder
+
+     def get_decoder(self):
+         return self.decoder
+
      def prepare_inputs_for_generation(
          self,
          input_ids,
-         past_key_values=None,
          attention_mask=None,
          decoder_attention_mask=None,
          **kwargs,
      ):
-         past_cache_length = past_key_values
-         if past_cache_length == 0:
-             cache_pos = []
-             for i in range(input_ids.shape[0]):
-                 cache_pos.append([0])
-             cache_position = torch.tensor(cache_pos, dtype=torch.int32)
-
-         max_seq_len = self.dec_max_seq_len
          cur_seq_len = input_ids.shape[-1]
+         cache_position = cur_seq_len - 1
+         max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
          decoder_batch_size = input_ids.shape[0]
          input_ids = input_ids[:, cur_seq_len - 1 : cur_seq_len].contiguous()
-         # In greedy decoding
          decoder_attention_mask = torch.zeros(decoder_batch_size, max_seq_len, dtype=torch.float32)
          decoder_attention_mask[:, :cur_seq_len] = 1
-         cache_pos = []
-         for i in range(input_ids.shape[0]):
-             cache_pos.append([cur_seq_len - 1])
-         cache_position = torch.tensor(cache_pos, dtype=torch.int32)
+
          return {
              "decoder_input_ids": input_ids,
-             "past_key_values": past_key_values,
              "attention_mask": attention_mask.to(torch.float32),
              "decoder_attention_mask": decoder_attention_mask,
              "cache_position": cache_position,
@@ -383,41 +312,12 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
      def forward(
          self,
          input_ids: torch.LongTensor = None,
-         cache_position: Union[List[torch.Tensor], torch.Tensor] = None,  # vllm keyword argument
-         batch_idx: Optional[torch.LongTensor] = None,
-         enc_lengths: List[int] = None,
+         cache_position: Union[List[torch.Tensor], torch.Tensor] = None,
          **kwargs,
      ) -> Tuple[torch.FloatTensor]:
          # common decoder
-         if enc_lengths is None:
-             output = self._forward_decoder(input_ids=input_ids, cache_position=cache_position, **kwargs)
-             return output
-
-         # vllm & encoder
-         if batch_idx is not None:
-             enc_attention_mask = self.enc_attention_mask.clone()
-             enc_attention_mask[0][: enc_lengths[batch_idx] + 1] = 1
-             padding_need = self.enc_max_seq_len - input_ids.shape[-1]
-             input_ids = torch.nn.functional.pad(input_ids, (0, padding_need))
-             _ = self.encoder(input_ids, enc_attention_mask, batch_idx=batch_idx.to(torch.int32))
-             logits = torch.zeros(1, 1, self.config.vocab_size + 100)
-             logits[0][0][-1] = 1
-         # vllm & decoder
-         else:
-             input_ids[input_ids == (self.config.vocab_size + 99)] = self.config.decoder_start_token_id
-             cache_position[cache_position != 0] = cache_position[cache_position != 0] - 2
-
-             enc_attention_mask = self.dec_enc_attention_mask.clone()
-             dec_attention_mask = torch.zeros(self.batch_size, self.dec_max_seq_len, dtype=torch.float32)
-             for batch_idx in range(self.batch_size):
-                 enc_attention_mask[batch_idx, : enc_lengths[batch_idx] + 1] = 1
-
-             logits = self._forward_decoder(
-                 attention_mask=enc_attention_mask,
-                 decoder_input_ids=input_ids,
-                 decoder_attention_mask=dec_attention_mask,
-                 cache_position=cache_position,
-             ).logits
+         cache_position = torch.full((self.rbln_config.model_cfg["batch_size"], 1), cache_position, dtype=torch.int32)
+         logits = self._forward_decoder(input_ids=input_ids, cache_position=cache_position, **kwargs).logits
 
          return Seq2SeqLMOutput(
              logits=logits,
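In the simplified `forward`, the scalar `cache_position` produced by `prepare_inputs_for_generation` is broadcast to the static `(batch_size, 1)` int32 shape the compiled decoder expects, e.g.:

import torch

batch_size = 4          # stand-in for rbln_config.model_cfg["batch_size"]
cache_position = 6      # scalar step index from prepare_inputs_for_generation
cache_position = torch.full((batch_size, 1), cache_position, dtype=torch.int32)
print(cache_position.shape)  # torch.Size([4, 1]), every entry 6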
@@ -446,25 +346,58 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
 
          return Seq2SeqLMOutput(logits=lm_logits)
 
+     def vllm_forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         cache_position: Union[List[torch.Tensor], torch.Tensor] = None,  # vllm keyword argument
+         batch_idx: Optional[torch.LongTensor] = None,
+         enc_lengths: List[int] = None,  # vllm return current attention_mask length
+         **kwargs,
+     ) -> Tuple[torch.FloatTensor]:
+         # When using vllm, need the output of the encoder (ex. vocab_size + 100) and use that value act as start_token_id in decoder (ex. vocab_size + 99)
+         # encoder
+         if batch_idx is not None:
+             enc_attention_mask = torch.zeros(1, self.rbln_config.model_cfg["enc_max_seq_len"], dtype=torch.float32)
+             enc_attention_mask[0][: enc_lengths[batch_idx] + 1] = 1
+             padding_need = self.rbln_config.model_cfg["enc_max_seq_len"] - input_ids.shape[-1]
+             input_ids = torch.nn.functional.pad(input_ids, (0, padding_need))
+             _ = self.encoder(input_ids, enc_attention_mask, batch_idx=batch_idx.to(torch.int32))
+             logits = torch.zeros(1, 1, self.config.vocab_size + 100)
+             logits[0][0][-1] = 1
+         # decoder
+         else:
+             input_ids[input_ids == (self.config.vocab_size + 99)] = self.config.decoder_start_token_id
+             cache_position[cache_position != 0] = cache_position[cache_position != 0] - 2
+
+             enc_attention_mask = torch.zeros(
+                 self.rbln_config.model_cfg["batch_size"],
+                 self.rbln_config.model_cfg["enc_max_seq_len"],
+                 dtype=torch.float32,
+             )
+             dec_attention_mask = torch.zeros(
+                 self.rbln_config.model_cfg["batch_size"],
+                 self.rbln_config.model_cfg["dec_max_seq_len"],
+                 dtype=torch.float32,
+             )
+             for batch_idx in range(self.rbln_config.model_cfg["batch_size"]):
+                 enc_attention_mask[batch_idx, : enc_lengths[batch_idx] + 1] = 1
+
+             logits = self._forward_decoder(
+                 attention_mask=enc_attention_mask,
+                 decoder_input_ids=input_ids,
+                 decoder_attention_mask=dec_attention_mask,
+                 cache_position=cache_position,
+             ).logits
+
+         return Seq2SeqLMOutput(logits=logits)
+
      def _prepare_encoder_decoder_kwargs_for_generation(
          self,
          inputs_tensor: torch.Tensor,
          model_kwargs,
          model_input_name: Optional[str] = None,
-         *args,
-         **kwargs,
+         generation_config: Optional[GenerationConfig] = None,
      ) -> Dict[str, Any]:
-         ########## thkim change start ###################
-         # padding input_ids & attention_mask regardless of user's tokenizer usage
-         batch_size, input_len = inputs_tensor.shape
-         inputs_tensor = torch.nn.functional.pad(
-             inputs_tensor, (0, self.enc_max_seq_len - input_len), value=self.pad_token_id
-         )
-         model_kwargs["attention_mask"] = torch.nn.functional.pad(
-             model_kwargs["attention_mask"], (0, self.enc_max_seq_len - input_len), value=0
-         )
-         ########## thkim change end ###################
-
          # 1. get encoder
          encoder = self.get_encoder()
 
@@ -482,18 +415,26 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
              argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
          }
 
+         batch_size, input_len = inputs_tensor.shape
+         inputs_tensor = torch.nn.functional.pad(
+             inputs_tensor,
+             (0, self.rbln_config.model_cfg["enc_max_seq_len"] - input_len),
+             value=self.rbln_config.model_cfg["pad_token_id"],
+         )
+         model_kwargs["attention_mask"] = torch.nn.functional.pad(
+             model_kwargs["attention_mask"], (0, self.rbln_config.model_cfg["enc_max_seq_len"] - input_len)
+         )
+
          # 3. make sure that encoder returns `ModelOutput`
          model_input_name = model_input_name if model_input_name is not None else self.main_input_name
          encoder_kwargs["return_dict"] = True
-         encoder_kwargs[model_input_name] = inputs_tensor
+         encoder_kwargs["output_hidden_states"] = False
+         encoder_kwargs["output_attentions"] = False
+
          for b in range(batch_size):
              batch_idx = torch.tensor(b, dtype=torch.int32)
-             cb_inputs = {}
-             cb_inputs["return_dict"] = True
-             cb_inputs["output_hidden_states"] = False
-             cb_inputs["output_attentions"] = False
-             cb_inputs["input_ids"] = encoder_kwargs["input_ids"][b].unsqueeze(0)
-             cb_inputs["attention_mask"] = encoder_kwargs["attention_mask"][b].unsqueeze(0).to(torch.float32)
-             model_kwargs["encoder_outputs"] = encoder(**cb_inputs, batch_idx=batch_idx)
+             encoder_kwargs["input_ids"] = inputs_tensor[b].unsqueeze(0)
+             encoder_kwargs["attention_mask"] = model_kwargs["attention_mask"][b].unsqueeze(0).to(torch.float32)
+             model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs, batch_idx=batch_idx)
 
          return model_kwargs
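Because the compiled encoder has a fixed (1, enc_max_seq_len) input shape, prefill pads every request to that length and runs one encoder call per batch slot, with `batch_idx` selecting the KV-cache row to fill. A shape-only sketch of that loop (the constants and the commented device call are stand-ins):

import torch

enc_max_seq_len, pad_token_id = 16, 0   # stand-ins for rbln_config.model_cfg values
inputs_tensor = torch.tensor([[37, 423, 5], [98, 3, 12]])  # (batch=2, input_len=3)

batch_size, input_len = inputs_tensor.shape
inputs_tensor = torch.nn.functional.pad(
    inputs_tensor, (0, enc_max_seq_len - input_len), value=pad_token_id
)
attention_mask = torch.nn.functional.pad(
    torch.ones(batch_size, input_len), (0, enc_max_seq_len - input_len)
)

for b in range(batch_size):
    batch_idx = torch.tensor(b, dtype=torch.int32)   # routes KV writes to cache row b
    one_ids = inputs_tensor[b].unsqueeze(0)          # (1, enc_max_seq_len)
    one_mask = attention_mask[b].unsqueeze(0).to(torch.float32)
    # encoder(one_ids, one_mask, batch_idx=batch_idx) would run here on-device
    print(b, one_ids.shape, one_mask.shape)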
optimum/rbln/transformers/models/t5/__init__.py
@@ -21,4 +21,5 @@
  # copied, modified, or distributed without prior written permission
  # from Rebellions Inc.
 
+ from .modeling_t5 import RBLNT5EncoderModel, RBLNT5ForConditionalGeneration
  from .t5_architecture import T5DecoderWrapper, T5EncoderWrapper
optimum/rbln/transformers/models/t5/modeling_t5.py (new file)
@@ -0,0 +1,108 @@
+ # Copyright 2024 Rebellions Inc.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Portions of this software are licensed under the Apache License,
+ # Version 2.0. See the NOTICE file distributed with this work for
+ # additional information regarding copyright ownership.
+
+ # All other portions of this software, including proprietary code,
+ # are the intellectual property of Rebellions Inc. and may not be
+ # copied, modified, or distributed without prior written permission
+ # from Rebellions Inc.
+
+ import inspect
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
+
+ from transformers import (
+     AutoModelForTextEncoding,
+     PretrainedConfig,
+     T5ForConditionalGeneration,
+ )
+
+ from ....modeling_base import RBLNModel
+ from ....modeling_config import RBLNCompileConfig, RBLNConfig
+ from ....utils.logging import get_logger
+ from ...models.seq2seq import RBLNModelForSeq2SeqLM
+ from .t5_architecture import T5Wrapper
+
+
+ logger = get_logger()
+
+ if TYPE_CHECKING:
+     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
+
+
+ class RBLNT5EncoderModel(RBLNModel):
+     auto_model_class = AutoModelForTextEncoding
+
+     @classmethod
+     def _get_rbln_config(
+         cls,
+         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+         model_config: Optional["PretrainedConfig"] = None,
+         rbln_kwargs: Dict[str, Any] = {},
+     ) -> RBLNConfig:
+         rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+         rbln_batch_size = rbln_kwargs.get("batch_size", None)
+
+         max_position_embeddings = getattr(model_config, "n_positions", None)
+
+         if rbln_max_seq_len is None:
+             rbln_max_seq_len = max_position_embeddings
+             if rbln_max_seq_len is None:
+                 for tokenizer in preprocessors:
+                     if hasattr(tokenizer, "model_max_length"):
+                         rbln_max_seq_len = tokenizer.model_max_length
+                         break
+                 if rbln_max_seq_len is None:
+                     raise ValueError("`rbln_max_seq_len` should be specified!")
+
+         if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
+             raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
+
+         if rbln_batch_size is None:
+             rbln_batch_size = 1
+
+         input_info = [
+             ("input_ids", [rbln_batch_size, rbln_max_seq_len], "int64"),
+             ("attention_mask", [rbln_batch_size, rbln_max_seq_len], "int64"),
+         ]
+
+         rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+
+         rbln_config = RBLNConfig(
+             rbln_cls=cls.__name__,
+             compile_cfgs=[rbln_compile_config],
+             rbln_kwargs=rbln_kwargs,
+         )
+
+         rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
+         return rbln_config
+
+
+ class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
+     @classmethod
+     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
+         return T5Wrapper(model)
+
+     def __getattr__(self, __name: str) -> Any:
+         def redirect(func):
+             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+         val = getattr(T5ForConditionalGeneration, __name)
+
+         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+             return redirect(val)
+
+         return val
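A hedged usage sketch for the two classes this new file adds. The import path follows the t5/__init__.py change above; the compile-time kwargs (`export`, `rbln_max_seq_len`, `rbln_batch_size`) are assumptions inferred from the `rbln_kwargs` read in `_get_rbln_config` and the `from_pretrained()` flow described in the class docstrings, and actual compilation requires RBLN hardware and the rebel compiler:

from transformers import AutoTokenizer

from optimum.rbln.transformers.models.t5 import (
    RBLNT5EncoderModel,
    RBLNT5ForConditionalGeneration,
)

tokenizer = AutoTokenizer.from_pretrained("t5-small")

# Text encoder: max_seq_len/batch_size mirror the rbln_kwargs consumed above.
encoder = RBLNT5EncoderModel.from_pretrained(
    "t5-small", export=True, rbln_max_seq_len=512, rbln_batch_size=1
)

# Seq2seq generation: can_generate() returns True, so .generate() drives the
# prepare_inputs_for_generation/forward loop defined in RBLNModelForSeq2SeqLM.
model = RBLNT5ForConditionalGeneration.from_pretrained("t5-small", export=True)
inputs = tokenizer("translate English to German: Hello.", return_tensors="pt")
outputs = model.generate(**inputs, max_length=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))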