optimum-rbln 0.1.9__py3-none-any.whl → 0.1.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. optimum/rbln/__init__.py +47 -9
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +36 -31
  4. optimum/rbln/diffusers/models/controlnet.py +53 -43
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +40 -31
  6. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +4 -0
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +28 -23
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +28 -23
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +28 -37
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +30 -39
  11. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +24 -14
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +24 -15
  13. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +26 -17
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -17
  15. optimum/rbln/modeling_alias.py +6 -11
  16. optimum/rbln/modeling_base.py +467 -261
  17. optimum/rbln/modeling_config.py +199 -73
  18. optimum/rbln/transformers/__init__.py +43 -1
  19. optimum/rbln/transformers/models/__init__.py +23 -1
  20. optimum/rbln/transformers/models/auto/__init__.py +14 -0
  21. optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
  22. optimum/rbln/transformers/models/auto/modeling_auto.py +95 -0
  23. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  24. optimum/rbln/transformers/models/bart/bart_architecture.py +203 -58
  25. optimum/rbln/transformers/models/bart/modeling_bart.py +125 -0
  26. optimum/rbln/transformers/models/bert/__init__.py +24 -0
  27. optimum/rbln/transformers/models/bert/modeling_bert.py +101 -0
  28. optimum/rbln/transformers/models/clip/__init__.py +1 -1
  29. optimum/rbln/transformers/models/clip/modeling_clip.py +127 -26
  30. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +28 -4
  31. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +409 -150
  32. optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -8
  33. optimum/rbln/transformers/models/exaone/__init__.py +32 -0
  34. optimum/rbln/transformers/models/exaone/exaone_architecture.py +72 -0
  35. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +181 -0
  36. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +1725 -0
  37. optimum/rbln/transformers/models/exaone/modeling_exaone.py +78 -0
  38. optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -1
  39. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  40. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
  41. optimum/rbln/transformers/models/llama/modeling_llama.py +1 -1
  42. optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
  43. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +662 -0
  44. optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
  45. optimum/rbln/transformers/models/midm/modeling_midm.py +6 -1
  46. optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -1
  47. optimum/rbln/transformers/models/phi/__init__.py +24 -0
  48. optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
  49. optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
  50. optimum/rbln/transformers/models/qwen2/__init__.py +24 -0
  51. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -0
  52. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +29 -0
  53. optimum/rbln/transformers/models/seq2seq/__init__.py +24 -0
  54. optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} +198 -168
  55. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  56. optimum/rbln/transformers/models/t5/modeling_t5.py +55 -0
  57. optimum/rbln/transformers/models/t5/t5_architecture.py +122 -47
  58. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +17 -12
  59. optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
  60. optimum/rbln/transformers/models/whisper/modeling_whisper.py +172 -111
  61. optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
  62. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +18 -16
  63. optimum/rbln/transformers/utils/rbln_quantization.py +48 -60
  64. optimum/rbln/utils/import_utils.py +50 -1
  65. optimum/rbln/utils/logging.py +82 -0
  66. optimum/rbln/utils/runtime_utils.py +33 -0
  67. optimum/rbln/utils/timer_utils.py +43 -0
  68. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/METADATA +9 -7
  69. optimum_rbln-0.1.12.dist-info/RECORD +103 -0
  70. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/WHEEL +1 -1
  71. optimum_rbln-0.1.12.dist-info/entry_points.txt +4 -0
  72. optimum_rbln-0.1.9.dist-info/RECORD +0 -78
  73. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/licenses/LICENSE +0 -0
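Before the detailed diffs: the biggest functional change in this release is the rework of the sequence-to-sequence stack (entries 54-56 above), which replaces the 0.1.9 `RBLNRuntimeConfig`/`meta` configuration scheme with `RBLNCompileConfig` and a `model_cfg` dict. A minimal sketch of the before/after access pattern, using `SimpleNamespace` stand-ins (not the real classes) so it runs standalone:

    from types import SimpleNamespace

    # 0.1.9: model-level values lived in a flat `meta` dict with "rbln_"-prefixed keys.
    old_cfg = SimpleNamespace(meta={"rbln_enc_max_seq_len": 512})

    # 0.1.12: the same values sit in `model_cfg` without the prefix; per-module
    # compile settings move from rbln_config[DEFAULT_COMPILED_MODEL_NAME][i]
    # (RBLNRuntimeConfig) to rbln_config.compile_cfgs[i] (RBLNCompileConfig).
    new_cfg = SimpleNamespace(model_cfg={"enc_max_seq_len": 512})

    assert old_cfg.meta["rbln_enc_max_seq_len"] == new_cfg.model_cfg["enc_max_seq_len"]

The diff below shows this migration in the relocated modeling_seq2seq.py.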
optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py}

@@ -23,24 +23,17 @@
 
 import inspect
 import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
-
-import rebel
-import torch
-from transformers import (
-    AutoModelForSeq2SeqLM,
-    BartConfig,
-    BartForConditionalGeneration,
-    PretrainedConfig,
-    T5ForConditionalGeneration,
-)
+from abc import ABC
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+import rebel  # noqa: F401
+import torch  # noqa: F401
+from transformers import AutoModelForSeq2SeqLM, GenerationConfig, PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 
-from .modeling_base import RBLNModel
-from .modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
-from .transformers.models.bart import BartDecoderWrapper, BartEncoderWrapper
-from .transformers.models.t5 import T5DecoderWrapper, T5EncoderWrapper
-from .utils.runtime_utils import RBLNPytorchRuntime
+from ....modeling_base import RBLNModel
+from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
+from ....utils.runtime_utils import RBLNPytorchRuntime
 
 
 logger = logging.getLogger(__name__)
@@ -59,7 +52,6 @@ class RBLNRuntimeEncoder(RBLNPytorchRuntime):
 
     def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
         _ = super().forward(*args, **kwargs)
-        # Just indicates that it is not None
        return BaseModelOutput(last_hidden_state=torch.tensor([1.0]))
 
 
@@ -71,7 +63,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         return Seq2SeqLMOutput(logits=outputs)
 
 
-class RBLNModelForSeq2SeqLM(RBLNModel):
+class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
     """
     This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence-to-sequence language modeling head) when created with the from_pretrained() class method.
     This model inherits from [`RBLNBaseModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
@@ -84,136 +76,59 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
     Currently, this model class only supports the 'bart' and 't5' models from the transformers library. Future updates may include support for additional model types.
     """
 
+    main_input_name = "input_ids"
     auto_model_class = AutoModelForSeq2SeqLM
 
     def __post_init__(self, **kwargs):
-        self.model_dim = self.config.d_model
-        self.batch_size = self.rbln_config[DEFAULT_COMPILED_MODEL_NAME][0].batch_size
-        self.enc_max_seq_len = self.rbln_config.meta["rbln_enc_max_seq_len"]
-        self.dec_max_seq_len = self.rbln_config.meta["rbln_dec_max_seq_len"]
-        self.pad_token_id = self.rbln_config.meta["rbln_pad_token_id"]
         self.encoder = RBLNRuntimeEncoder(runtime=self.model[0], main_input_name="input_ids")
         self.decoder = RBLNRuntimeDecoder(runtime=self.model[1], main_input_name="input_ids")
 
-    def can_generate(self):
-        return True
+    @classmethod
+    @torch.inference_mode()
+    def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNConfig):
+        wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
 
-    def get_encoder(self):
-        return self.encoder
+        wrapped_model.encoder.encoder_max_length = rbln_config.model_cfg["enc_max_seq_len"]
+        wrapped_model.encoder.decoder_max_length = rbln_config.model_cfg["dec_max_seq_len"]
 
-    def get_decoder(self):
-        return self.decoder
-
-    def __getattr__(self, __name: str) -> Any:
-        def redirect(func):
-            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+        wrapped_model.decoder.encoder_max_length = rbln_config.model_cfg["enc_max_seq_len"]
+        wrapped_model.decoder.decoder_max_length = rbln_config.model_cfg["dec_max_seq_len"]
 
-        if "T5ForConditionalGeneration" == self.config.architectures:
-            val = getattr(T5ForConditionalGeneration, __name)
-        else:
-            val = getattr(BartForConditionalGeneration, __name)
+        enc_rbln_compile_config = rbln_config.compile_cfgs[0]
+        dec_rbln_compile_config = rbln_config.compile_cfgs[1]
 
-        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-            return redirect(val)
-        return val
+        enc_example_inputs = enc_rbln_compile_config.get_dummy_inputs(fill=0)
+        dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=0)
 
-    def prepare_inputs_for_generation(
-        self,
-        input_ids,
-        past_key_values=None,
-        attention_mask=None,
-        decoder_attention_mask=None,
-        **kwargs,
-    ):
-        max_seq_len = self.dec_max_seq_len
-        cur_seq_len = input_ids.shape[-1]
-        decoder_batch_size = input_ids.shape[0]
-        input_ids = input_ids[:, cur_seq_len - 1 : cur_seq_len].contiguous()
+        enc_example_inputs[3].fill_(0)
+        dec_example_inputs[4].fill_(-1)
 
-        # In greedy decoding
-        decoder_attention_mask = torch.zeros(decoder_batch_size, max_seq_len, dtype=torch.int64)
-        decoder_attention_mask[:, :cur_seq_len] = 1
-        cache_position = torch.tensor(cur_seq_len - 1, dtype=torch.int32)
-
-        return {
-            "decoder_input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "attention_mask": attention_mask,
-            "decoder_attention_mask": decoder_attention_mask,
-            "cache_position": cache_position,
-        }
-
-    @classmethod
-    def update_kwargs(cls, kwargs):
-        kwargs.update(
-            {
-                "torchscript": True,
-                "return_dict": False,
-                "use_cache": True,
-            }
-        )
-        return kwargs
-
-    @classmethod
-    def get_compiled_model(cls, model, rbln_config: RBLNConfig):
-        def optimized_models(model):
-            if isinstance(model, T5ForConditionalGeneration):
-                encoder_model = T5EncoderWrapper(model).eval()
-                decoder_model = T5DecoderWrapper(model).eval()
-            elif isinstance(model, BartForConditionalGeneration):
-                encoder_model = BartEncoderWrapper(model).eval()
-                decoder_model = BartDecoderWrapper(model).eval()
-            else:
-                raise ValueError(f"{model.__class__.__name__} is not supported yet.")
-
-            return encoder_model, decoder_model
-
-        wrapped_encoder, wrapped_decoder = optimized_models(model)
-
-        wrapped_encoder.encoder_max_length = rbln_config.meta["rbln_enc_max_seq_len"]
-        wrapped_encoder.decoder_max_length = rbln_config.meta["rbln_dec_max_seq_len"]
-        wrapped_encoder.decoder_batch_size = rbln_config.meta["rbln_batch_size"]
-
-        wrapped_decoder.encoder_max_length = rbln_config.meta["rbln_enc_max_seq_len"]
-        wrapped_decoder.decoder_max_length = rbln_config.meta["rbln_dec_max_seq_len"]
-        wrapped_decoder.decoder_batch_size = rbln_config.meta["rbln_batch_size"]
-
-        enc_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
-        dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
-
-        if isinstance(model, T5ForConditionalGeneration):
-            enc_example_inputs = enc_rbln_runtime_config.get_dummy_inputs(fill=1)
-            dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=1)
-        else:
-            enc_example_inputs = enc_rbln_runtime_config.get_dummy_inputs(fill=0)
-            dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=0)
-
-        enc_scripted_model = torch.jit.trace(wrapped_encoder, enc_example_inputs, check_trace=False)
-        dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs, check_trace=False)
+        enc_scripted_model = torch.jit.trace(wrapped_model.encoder, enc_example_inputs, check_trace=False)
+        dec_scripted_model = torch.jit.trace(wrapped_model.decoder, dec_example_inputs, check_trace=False)
 
         enc_ir = rebel.torchscript_to_ir(
             enc_scripted_model,
-            input_names=[v[0] for v in enc_rbln_runtime_config.input_info],
-            name=enc_rbln_runtime_config.rbln_mod_name,
+            input_names=[v[0] for v in enc_rbln_compile_config.input_info],
+            name=enc_rbln_compile_config.mod_name,
         )
         dec_ir = rebel.torchscript_to_ir(
             dec_scripted_model,
-            input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
-            name=dec_rbln_runtime_config.rbln_mod_name,
+            input_names=[v[0] for v in dec_rbln_compile_config.input_info],
+            name=dec_rbln_compile_config.mod_name,
         )
-        dec_ir.decoder_batch_size = rbln_config.meta["rbln_batch_size"]
 
         connections = [
-            (enc_ir.outputs[0], dec_ir.inputs[5]),
-            (dec_ir.outputs[1], dec_ir.inputs[4]),
+            (enc_ir.outputs[0], enc_ir.inputs[2], dec_ir.inputs[6]),
+            (dec_ir.outputs[1], dec_ir.inputs[5]),
         ]
+
         compiled_model = rebel.compile(
             enc_ir,
             dec_ir,
             connections=connections,
-            fusion=enc_rbln_runtime_config.fusion,
-            npu=enc_rbln_runtime_config.npu,
-            tensor_parallel_size=enc_rbln_runtime_config.tensor_parallel_size,
+            fusion=enc_rbln_compile_config.fusion,
+            npu=enc_rbln_compile_config.npu,
+            tensor_parallel_size=enc_rbln_compile_config.tensor_parallel_size,
         )
         return compiled_model
 
@@ -222,20 +137,20 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
         model_config: "PretrainedConfig",
-        rbln_enc_max_seq_len: Optional[int] = None,
-        rbln_dec_max_seq_len: Optional[int] = None,
-        rbln_batch_size: Optional[int] = 1,
+        rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-        meta = {}
+        rbln_enc_max_seq_len = rbln_kwargs.get("enc_max_seq_len", None)
+        rbln_dec_max_seq_len = rbln_kwargs.get("dec_max_seq_len", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
 
-        if isinstance(model_config, BartConfig):
-            n_layer = model_config.decoder_layers
-            n_head = model_config.decoder_attention_heads
-            d_kv = model_config.d_model // model_config.encoder_attention_heads
-        else:
-            n_layer = model_config.num_layers
-            n_head = model_config.num_heads
-            d_kv = model_config.d_kv
+        n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
+        n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
+        d_kv = (
+            model_config.d_kv
+            if hasattr(model_config, "d_kv")
+            else model_config.d_model // model_config.encoder_attention_heads
+        )
 
         max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
             model_config, "max_position_embeddings", None
@@ -274,28 +189,34 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
         if max_position_embeddings is not None and rbln_dec_max_seq_len > max_position_embeddings:
             raise ValueError("`rbln_dec_max_seq_len` should be less or equal than max_position_embeddings!")
 
-        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-
-        meta["rbln_enc_max_seq_len"] = rbln_enc_max_seq_len
-        meta["rbln_dec_max_seq_len"] = rbln_dec_max_seq_len
-        meta["rbln_batch_size"] = rbln_batch_size
-        meta["rbln_pad_token_id"] = rbln_pad_token_id
-
         # model input info
         enc_input_info = [
-            ("input_ids", [rbln_batch_size, rbln_enc_max_seq_len], "int64"),
-            ("attention_mask", [rbln_batch_size, rbln_enc_max_seq_len], "int64"),
+            ("input_ids", [1, rbln_enc_max_seq_len], "int64"),
+            ("attention_mask", [1, rbln_enc_max_seq_len], "float32"),
+            (
+                "cross_key_value_states",
+                [
+                    n_layer * 2,
+                    rbln_batch_size,
+                    n_head,
+                    rbln_enc_max_seq_len,
+                    d_kv,
+                ],
+                "float32",
+            ),
+            ("batch_idx", [], "int32"),
         ]
 
         dec_input_info = [
             ("input_ids", [rbln_batch_size, 1], "int64"),
-            ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "int64"),
-            ("encoder_attention_mask", [rbln_batch_size, rbln_enc_max_seq_len], "int64"),
+            ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"),
+            ("encoder_attention_mask", [rbln_batch_size, rbln_enc_max_seq_len], "float32"),
             (
                 "cache_position",
-                [],
+                [rbln_batch_size, 1],
                 "int32",
             ),
+            ("batch_position", [], "int32"),
         ]
         dec_input_info.extend(
             [
@@ -327,12 +248,22 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
                 )
             ]
         )
-        enc_rbln_runtime_config = RBLNRuntimeConfig(rbln_mod_name="encoder", input_info=enc_input_info)
-        dec_rbln_runtime_config = RBLNRuntimeConfig(rbln_mod_name="decoder", input_info=dec_input_info)
+        enc_rbln_compile_config = RBLNCompileConfig(mod_name="encoder", input_info=enc_input_info)
+        dec_rbln_compile_config = RBLNCompileConfig(mod_name="decoder", input_info=dec_input_info)
 
-        rbln_config = RBLNConfig.from_rbln_runtime_configs(
-            [enc_rbln_runtime_config, dec_rbln_runtime_config],
-            _rbln_meta=meta,
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[enc_rbln_compile_config, dec_rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+
+        rbln_config.model_cfg.update(
+            {
+                "enc_max_seq_len": rbln_enc_max_seq_len,
+                "dec_max_seq_len": rbln_dec_max_seq_len,
+                "batch_size": rbln_batch_size,
+                "pad_token_id": rbln_pad_token_id,
+            }
         )
 
         return rbln_config
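`_get_rbln_config` now takes a single `rbln_kwargs` dict instead of individual keyword arguments. Assuming the public `from_pretrained` entry point still collects `rbln_`-prefixed keywords and strips the prefix before building `rbln_kwargs` (which the `rbln_kwargs.get("enc_max_seq_len", None)` lookups above imply), compile-time options would be passed like this; the model id and sizes are illustrative:

    from optimum.rbln import RBLNT5ForConditionalGeneration

    # export=True compiles the PyTorch checkpoint for the RBLN NPU; each rbln_*
    # keyword is expected to land in rbln_kwargs without its prefix, e.g.
    # rbln_enc_max_seq_len -> rbln_kwargs["enc_max_seq_len"].
    model = RBLNT5ForConditionalGeneration.from_pretrained(
        "t5-small",
        export=True,
        rbln_enc_max_seq_len=512,
        rbln_dec_max_seq_len=256,
        rbln_batch_size=1,
    )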
@@ -347,7 +278,52 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
             compiled_models[0].create_runtime("decoder", tensor_type="pt", device=device_val),
         ]
 
+    def can_generate(self):
+        return True
+
+    def get_encoder(self):
+        return self.encoder
+
+    def get_decoder(self):
+        return self.decoder
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        attention_mask=None,
+        decoder_attention_mask=None,
+        **kwargs,
+    ):
+        cur_seq_len = input_ids.shape[-1]
+        cache_position = cur_seq_len - 1
+        max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
+        decoder_batch_size = input_ids.shape[0]
+        input_ids = input_ids[:, cur_seq_len - 1 : cur_seq_len].contiguous()
+        decoder_attention_mask = torch.zeros(decoder_batch_size, max_seq_len, dtype=torch.float32)
+        decoder_attention_mask[:, :cur_seq_len] = 1
+
+        return {
+            "decoder_input_ids": input_ids,
+            "attention_mask": attention_mask.to(torch.float32),
+            "decoder_attention_mask": decoder_attention_mask,
+            "cache_position": cache_position,
+        }
+
     def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        cache_position: Union[List[torch.Tensor], torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor]:
+        # common decoder
+        cache_position = torch.full((self.rbln_config.model_cfg["batch_size"], 1), cache_position, dtype=torch.int32)
+        logits = self._forward_decoder(input_ids=input_ids, cache_position=cache_position, **kwargs).logits
+
+        return Seq2SeqLMOutput(
+            logits=logits,
+        )
+
+    def _forward_decoder(
         self,
         attention_mask: Optional[torch.FloatTensor] = None,
         decoder_input_ids: Optional[torch.LongTensor] = None,
@@ -355,35 +331,73 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
         cache_position: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
+        dec_attention_mask = decoder_attention_mask.clone()
+        for b_idx in range(self.rbln_config.model_cfg["batch_size"]):
+            dec_attention_mask[b_idx, : cache_position[b_idx] + 1] = 1
+
         decoder_output = self.decoder(
             input_ids=decoder_input_ids,
-            attention_mask=decoder_attention_mask,
+            attention_mask=dec_attention_mask,
             encoder_attention_mask=attention_mask,
             cache_position=cache_position,
+            batch_position=torch.tensor(0, dtype=torch.int32),
         )
-        lm_logits = decoder_output.logits
+        lm_logits = decoder_output.logits[0]
 
         return Seq2SeqLMOutput(logits=lm_logits)
 
+    def vllm_forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        cache_position: Union[List[torch.Tensor], torch.Tensor] = None,  # vllm keyword argument
+        batch_idx: Optional[torch.LongTensor] = None,
+        enc_lengths: List[int] = None,  # vllm return current attention_mask length
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor]:
+        # When using vllm, need the output of the encoder (ex. vocab_size + 100) and use that value act as start_token_id in decoder (ex. vocab_size + 99)
+        # encoder
+        if batch_idx is not None:
+            enc_attention_mask = torch.zeros(1, self.rbln_config.model_cfg["enc_max_seq_len"], dtype=torch.float32)
+            enc_attention_mask[0][: enc_lengths[batch_idx] + 1] = 1
+            padding_need = self.rbln_config.model_cfg["enc_max_seq_len"] - input_ids.shape[-1]
+            input_ids = torch.nn.functional.pad(input_ids, (0, padding_need))
+            _ = self.encoder(input_ids, enc_attention_mask, batch_idx=batch_idx.to(torch.int32))
+            logits = torch.zeros(1, 1, self.config.vocab_size + 100)
+            logits[0][0][-1] = 1
+        # decoder
+        else:
+            input_ids[input_ids == (self.config.vocab_size + 99)] = self.config.decoder_start_token_id
+            cache_position[cache_position != 0] = cache_position[cache_position != 0] - 2
+
+            enc_attention_mask = torch.zeros(
+                self.rbln_config.model_cfg["batch_size"],
+                self.rbln_config.model_cfg["enc_max_seq_len"],
+                dtype=torch.float32,
+            )
+            dec_attention_mask = torch.zeros(
+                self.rbln_config.model_cfg["batch_size"],
+                self.rbln_config.model_cfg["dec_max_seq_len"],
+                dtype=torch.float32,
+            )
+            for batch_idx in range(self.rbln_config.model_cfg["batch_size"]):
+                enc_attention_mask[batch_idx, : enc_lengths[batch_idx] + 1] = 1
+
+            logits = self._forward_decoder(
+                attention_mask=enc_attention_mask,
+                decoder_input_ids=input_ids,
+                decoder_attention_mask=dec_attention_mask,
+                cache_position=cache_position,
+            ).logits
+
+        return Seq2SeqLMOutput(logits=logits)
+
     def _prepare_encoder_decoder_kwargs_for_generation(
         self,
         inputs_tensor: torch.Tensor,
         model_kwargs,
         model_input_name: Optional[str] = None,
-        *args,
-        **kwargs,
+        generation_config: Optional[GenerationConfig] = None,
     ) -> Dict[str, Any]:
-        ########## thkim change start ###################
-        # padding input_ids & attention_mask regardless of user's tokenizer usage
-        batch_size, input_len = inputs_tensor.shape
-        inputs_tensor = torch.nn.functional.pad(
-            inputs_tensor, (0, self.enc_max_seq_len - input_len), value=self.pad_token_id
-        )
-        model_kwargs["attention_mask"] = torch.nn.functional.pad(
-            model_kwargs["attention_mask"], (0, self.enc_max_seq_len - input_len), value=0
-        )
-        ########## thkim change end ###################
-
         # 1. get encoder
         encoder = self.get_encoder()
 
@@ -401,10 +415,26 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
             argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
         }
 
+        batch_size, input_len = inputs_tensor.shape
+        inputs_tensor = torch.nn.functional.pad(
+            inputs_tensor,
+            (0, self.rbln_config.model_cfg["enc_max_seq_len"] - input_len),
+            value=self.rbln_config.model_cfg["pad_token_id"],
+        )
+        model_kwargs["attention_mask"] = torch.nn.functional.pad(
+            model_kwargs["attention_mask"], (0, self.rbln_config.model_cfg["enc_max_seq_len"] - input_len)
+        )
+
         # 3. make sure that encoder returns `ModelOutput`
         model_input_name = model_input_name if model_input_name is not None else self.main_input_name
         encoder_kwargs["return_dict"] = True
-        encoder_kwargs[model_input_name] = inputs_tensor
-        model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs)
+        encoder_kwargs["output_hidden_states"] = False
+        encoder_kwargs["output_attentions"] = False
+
+        for b in range(batch_size):
+            batch_idx = torch.tensor(b, dtype=torch.int32)
+            encoder_kwargs["input_ids"] = inputs_tensor[b].unsqueeze(0)
+            encoder_kwargs["attention_mask"] = model_kwargs["attention_mask"][b].unsqueeze(0).to(torch.float32)
+            model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs, batch_idx=batch_idx)
 
         return model_kwargs
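The new `enc_input_info` gives the encoder a batch dimension of 1 plus a scalar `batch_idx`, and `_prepare_encoder_decoder_kwargs_for_generation` now calls the encoder once per sample: each call writes that sample's cross-attention KV states into the shared `cross_key_value_states` buffer that the `connections` argument ties between encoder and decoder at compile time. A toy host-side emulation of that contract (shapes and the random KV stand-in are illustrative; the real sharing happens on-device):

    import torch

    # Shared buffer shaped like cross_key_value_states in enc_input_info:
    # [n_layer * 2, batch, n_head, enc_max_seq_len, d_kv].
    n_layer2, batch, n_head, enc_len, d_kv = 4, 2, 8, 16, 64
    cross_kv = torch.zeros(n_layer2, batch, n_head, enc_len, d_kv)

    def run_encoder(input_ids: torch.Tensor, batch_idx: int) -> None:
        # Stand-in for the encoder's computed KV states for one sequence.
        kv = torch.randn(n_layer2, n_head, enc_len, d_kv)
        cross_kv[:, batch_idx] = kv  # write into the slot selected by batch_idx

    for b in range(batch):  # mirrors the per-sample loop in the diff above
        run_encoder(torch.zeros(1, enc_len, dtype=torch.int64), b)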
optimum/rbln/transformers/models/t5/__init__.py

@@ -21,4 +21,5 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
+from .modeling_t5 import RBLNT5ForConditionalGeneration
 from .t5_architecture import T5DecoderWrapper, T5EncoderWrapper
optimum/rbln/transformers/models/t5/modeling_t5.py

@@ -0,0 +1,55 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+import inspect
+from typing import TYPE_CHECKING, Any, Callable
+
+from transformers import T5ForConditionalGeneration
+
+from ....modeling_config import RBLNConfig
+from ....utils.logging import get_logger
+from ...models.seq2seq import RBLNModelForSeq2SeqLM
+from .t5_architecture import T5Wrapper
+
+
+logger = get_logger()
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel
+
+
+class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
+    @classmethod
+    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
+        return T5Wrapper(model)
+
+    def __getattr__(self, __name: str) -> Any:
+        def redirect(func):
+            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+        val = getattr(T5ForConditionalGeneration, __name)
+
+        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+            return redirect(val)
+
+        return val
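Putting the new T5 class to work end to end, a hedged sketch (model id and generation settings are illustrative; compiling and running requires an RBLN NPU environment, and the top-level import assumes the class is re-exported from optimum/rbln/__init__.py, as the expanded entry 1 in the file list suggests):

    from transformers import AutoTokenizer
    from optimum.rbln import RBLNT5ForConditionalGeneration

    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    # export=True traces and compiles the checkpoint via get_compiled_model above.
    model = RBLNT5ForConditionalGeneration.from_pretrained("t5-small", export=True)

    inputs = tokenizer("translate English to German: Hello, world!", return_tensors="pt")
    outputs = model.generate(**inputs, max_length=32)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))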