optimum-rbln 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. optimum/rbln/__init__.py +37 -2
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +36 -29
  4. optimum/rbln/diffusers/models/controlnet.py +56 -40
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +40 -28
  6. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +22 -15
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +22 -15
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +23 -17
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +24 -18
  10. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +22 -11
  11. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +22 -11
  12. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +24 -14
  13. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +24 -14
  14. optimum/rbln/modeling_alias.py +3 -3
  15. optimum/rbln/modeling_base.py +471 -231
  16. optimum/rbln/modeling_config.py +152 -77
  17. optimum/rbln/modeling_seq2seq.py +166 -77
  18. optimum/rbln/transformers/__init__.py +35 -1
  19. optimum/rbln/transformers/models/__init__.py +20 -1
  20. optimum/rbln/transformers/models/auto/__init__.py +14 -0
  21. optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
  22. optimum/rbln/transformers/models/auto/modeling_auto.py +94 -0
  23. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  24. optimum/rbln/transformers/models/bart/bart_architecture.py +189 -50
  25. optimum/rbln/transformers/models/bart/modeling_bart.py +106 -0
  26. optimum/rbln/transformers/models/bert/__init__.py +24 -0
  27. optimum/rbln/transformers/models/bert/modeling_bert.py +102 -0
  28. optimum/rbln/transformers/models/clip/__init__.py +1 -1
  29. optimum/rbln/transformers/models/clip/modeling_clip.py +127 -25
  30. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +28 -4
  31. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +302 -115
  32. optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -7
  33. optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -1
  34. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  35. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
  36. optimum/rbln/transformers/models/llama/modeling_llama.py +1 -1
  37. optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
  38. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +666 -0
  39. optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
  40. optimum/rbln/transformers/models/midm/modeling_midm.py +1 -1
  41. optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -1
  42. optimum/rbln/transformers/models/phi/__init__.py +24 -0
  43. optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
  44. optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
  45. optimum/rbln/transformers/models/t5/t5_architecture.py +92 -31
  46. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +17 -11
  47. optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
  48. optimum/rbln/transformers/models/whisper/modeling_whisper.py +141 -105
  49. optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
  50. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +17 -14
  51. optimum/rbln/transformers/utils/rbln_quantization.py +48 -60
  52. optimum/rbln/utils/import_utils.py +36 -1
  53. optimum/rbln/utils/logging.py +82 -0
  54. optimum/rbln/utils/runtime_utils.py +33 -0
  55. optimum/rbln/utils/timer_utils.py +19 -0
  56. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/METADATA +8 -7
  57. optimum_rbln-0.1.11.dist-info/RECORD +93 -0
  58. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/WHEEL +1 -1
  59. optimum_rbln-0.1.11.dist-info/entry_points.txt +4 -0
  60. optimum_rbln-0.1.9.dist-info/RECORD +0 -78
  61. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/licenses/LICENSE +0 -0
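The hunks that follow are from optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py (entry 31 above, +302 -115). The main user-visible changes are a reworked configuration path (a single rbln_kwargs dict consumed by _get_rbln_config and stored in rbln_config.model_cfg) and a new inputs_embeds input mode for decoder-only models. As a rough usage sketch only: the checkpoint name is arbitrary, the rbln_use_inputs_embeds spelling is inferred from the use_inputs_embeds key read below rather than confirmed from documentation, and compiling requires the rebel compiler and an RBLN NPU.

# Hypothetical usage sketch (not taken from this diff): compiling a decoder-only
# model with the options that _get_rbln_config reads from rbln_kwargs below.
from optimum.rbln import RBLNLlamaForCausalLM

model = RBLNLlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",  # arbitrary decoder-only checkpoint
    export=True,                       # compile from the PyTorch checkpoint
    rbln_max_seq_len=4096,             # -> model_cfg["max_seq_len"]
    rbln_batch_size=1,                 # -> model_cfg["batch_size"]
    rbln_use_inputs_embeds=True,       # -> model_cfg["use_inputs_embeds"] (assumed spelling)
)
model.save_pretrained("llama-2-7b-chat-rbln")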
@@ -23,19 +23,21 @@
  import glob
  import logging
  from abc import ABC
+ from dataclasses import dataclass
+ from pathlib import Path
  from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

  import rebel  # noqa: F401
  import torch  # noqa: F401
  from safetensors.torch import load_file
  from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
- from transformers.modeling_outputs import CausalLMOutputWithPast
  from transformers.modeling_utils import no_init_weights
+ from transformers.utils import ModelOutput

  from ....modeling_base import RBLNModel
- from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
+ from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
  from ....utils.runtime_utils import RBLNPytorchRuntime
- from ...utils.rbln_quantization import replace_quantized_linear_layers
+ from ....utils.timer_utils import rbln_timer


  logger = logging.getLogger(__name__)
@@ -56,7 +58,46 @@ SUPPORTED_QUANTIZATIONS = {


  class RBLNRuntimeModel(RBLNPytorchRuntime):
-     mandatory_members = ["main_input_name"]
+     mandatory_members = ["main_input_name", "embed_tokens"]
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor,
+         inputs_embeds: torch.Tensor,
+         attention_mask: torch.Tensor,
+         cache_position: torch.Tensor,
+         batch_position: torch.Tensor,
+         query_idx: torch.Tensor,
+         **kwargs,
+     ):
+         if inputs_embeds is None:
+             inp = input_ids
+             if self.embed_tokens is not None:
+                 inp = self.embed_tokens(inp)
+
+             return super().forward(
+                 inp,
+                 attention_mask,
+                 cache_position,
+                 batch_position,
+                 query_idx,
+                 **kwargs,
+             )
+         else:
+             return super().forward(
+                 inputs_embeds,
+                 attention_mask,
+                 cache_position,
+                 batch_position,
+                 query_idx,
+                 **kwargs,
+             )
+
+
+ @dataclass
+ class RBLNDecoderOnlyOutput(ModelOutput):
+     logits: torch.FloatTensor = None
+     past_cached_length: Union[int, torch.Tensor] = None


  class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
@@ -74,18 +115,57 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
      auto_model_class = AutoModelForCausalLM

      def __post_init__(self, **kwargs):
-         self.batch_size = self.rbln_config.meta["rbln_batch_size"]
-         self.max_seq_len = self.rbln_config.meta["rbln_max_seq_len"]
-         self.prefill_chunk_size = self.rbln_config.meta["rbln_prefill_chunk_size"]
+         self.batch_size = self.rbln_config.model_cfg["batch_size"]
+         self.max_seq_len = self.rbln_config.model_cfg["max_seq_len"]
+         self.prefill_chunk_size = self.rbln_config.model_cfg["prefill_chunk_size"]

-         self.prefill_attention_mask = torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.int64)
+         self.prefill_attention_mask = torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
          self.causal_mask = 1 - torch.triu(
              torch.ones(1, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
          )
-         self.dec_attn_mask_init = torch.zeros(1, 1, 1, self.max_seq_len, dtype=torch.int64)
-         self.dec_attn_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.int64)
-         self.prefill_decoder = RBLNRuntimeModel(runtime=self.model[0], main_input_name="input_ids")
-         self.decoder = RBLNRuntimeModel(runtime=self.model[1], main_input_name="input_ids")
+         self.dec_attn_mask_init = torch.zeros(1, 1, 1, self.max_seq_len, dtype=torch.float32)
+         self.dec_attn_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.float32)
+
+         main_input_name = self.main_input_name
+         if self.rbln_config.model_cfg["use_inputs_embeds"]:
+             main_input_name = "inputs_embeds"
+             artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
+             with no_init_weights():
+                 self.embed_tokens = torch.nn.Embedding(
+                     self.config.vocab_size,
+                     self.config.hidden_size,
+                     self.config.pad_token_id,
+                 )
+             self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
+         else:
+             self.embed_tokens = None
+
+         self.prefill_decoder = RBLNRuntimeModel(
+             runtime=self.model[0], main_input_name=main_input_name, embed_tokens=self.embed_tokens
+         )
+         self.decoder = RBLNRuntimeModel(
+             runtime=self.model[1], main_input_name=main_input_name, embed_tokens=self.embed_tokens
+         )
+
+     @classmethod
+     def save_torch_artifacts(
+         cls,
+         model: "PreTrainedModel",
+         save_dir_path: Path,
+         subfolder: str,
+         rbln_config: RBLNConfig,
+     ):
+         """
+         If you are unavoidably running on a CPU rather than an RBLN device,
+         store the torch tensor, weight, etc. in this function.
+         """
+         if rbln_config.model_cfg["use_inputs_embeds"]:
+             save_dict = {}
+             save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
+             torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
+
+     def get_input_embeddings(self):
+         return self.embed_tokens

      @classmethod
      def get_quantized_model(
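The save_torch_artifacts/__post_init__ pair above keeps the token embedding on the host when use_inputs_embeds is enabled: the embedding weights are written to torch_artifacts.pth at export time and reloaded into a plain torch.nn.Embedding alongside the compiled runtimes, so the prefill/decode runtimes can accept either token ids or precomputed embeddings. A minimal, self-contained sketch of that round trip with toy sizes (no RBLN device needed; the file name is kept only for illustration):

# Self-contained sketch of the torch_artifacts.pth round trip used above
# (toy vocabulary/hidden sizes; the path is illustrative).
import torch
from transformers.modeling_utils import no_init_weights

vocab_size, hidden_size, pad_token_id = 32, 8, 0
src = torch.nn.Embedding(vocab_size, hidden_size, pad_token_id)

# export side: store only the embedding weights (cf. save_torch_artifacts)
torch.save({"embed_tokens": src.state_dict()}, "torch_artifacts.pth")

# load side: rebuild the embedding without random init, then restore the weights
# (cf. __post_init__)
artifacts = torch.load("torch_artifacts.pth", weights_only=False)
with no_init_weights():
    embed_tokens = torch.nn.Embedding(vocab_size, hidden_size, pad_token_id)
embed_tokens.load_state_dict(artifacts["embed_tokens"])

inputs_embeds = embed_tokens(torch.tensor([[1, 2, 3]]))  # what the runtime consumes
print(inputs_embeds.shape)  # torch.Size([1, 3, 8])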
@@ -98,10 +178,10 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
          subfolder: str = "",
          local_files_only: bool = False,
          trust_remote_code: bool = False,
-         rbln_config_kwargs: Optional[Dict[str, Any]] = None,
-         rbln_constructor_kwargs: Optional[Dict[str, Any]] = None,
          **kwargs,
      ):
+         from ...utils.rbln_quantization import update_layers_to_quantized
+
          kwargs = cls.update_kwargs(kwargs)

          config = AutoConfig.from_pretrained(
@@ -116,37 +196,45 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):

          with no_init_weights():
              model = AutoModelForCausalLM.from_config(config)
-             replace_quantized_linear_layers(model)

-         state_dict = {}
-         for safetensor_file in glob.glob(f"{model_id}/*.safetensors"):
-             partial_state_dict = load_file(safetensor_file)
-             state_dict.update(partial_state_dict)
+         update_layers_to_quantized(model)

          n_layer = kwargs.get("num_hidden_layers", None)
-         if n_layer is not None:
-             keys_to_delete = []
-             for key in state_dict.keys():
-                 parts = key.split(".")
-                 if len(parts) > 2 and parts[2].isdigit():
-                     layer_num = int(parts[2])
-                     if layer_num >= n_layer:
-                         keys_to_delete.append(key)
-
-             for key in keys_to_delete:
-                 del state_dict[key]
-
-         model.load_state_dict(state_dict)
+         cls._load_weights_directly_to_model(model, model_id, n_layer)
+
          return model

+     def _load_weights_directly_to_model(model, model_id, n_layer=None):
+         """
+         Load safetensor file data directly into the model, filtering by layer if n_layer is provided.
+         """
+
+         model_params = dict(model.named_parameters(recurse=True))
+         model_buffers = dict(model.named_buffers(recurse=True))
+         safetensor_files = glob.glob(f"{model_id}/*.safetensors")
+
+         target_layers = list(range(n_layer)) if n_layer is not None else None
+
+         for safetensor_file in safetensor_files:
+             file_data = load_file(safetensor_file)
+             for key, value in file_data.items():
+                 if target_layers is not None:
+                     parts = key.split(".")
+
+                     if len(parts) > 2 and parts[2].isdigit() and (int(parts[2]) not in target_layers):
+                         continue
+
+                 if key in model_params:
+                     model_params[key].data.copy_(value)
+                 elif key in model_buffers:
+                     model_buffers[key].data.copy_(value)
+
+         return 0
+
      @classmethod
-     def get_pytorch_model(
-         cls,
-         *args,
-         **kwargs,
-     ) -> "PreTrainedModel":
-         rbln_config_kwargs = kwargs.get("rbln_config_kwargs", {})
-         rbln_quantization = rbln_config_kwargs.get("rbln_quantization", None)
+     def get_pytorch_model(cls, *args, **kwargs) -> "PreTrainedModel":
+         rbln_kwargs = kwargs.get("rbln_kwargs", {})
+         rbln_quantization = rbln_kwargs.get("quantization", None)

          if rbln_quantization is not None and rbln_quantization["format"] == "rbln":
              model = cls.get_quantized_model(*args, **kwargs)
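get_quantized_model no longer accumulates every safetensors shard into one big state_dict before calling load_state_dict; the new _load_weights_directly_to_model copies each tensor straight into the matching named parameter or buffer, skipping layer indices at or above the optional num_hidden_layers cutoff. A stand-alone sketch of the same copy-in-place pattern on a toy model (not the library code itself):

# Stand-alone sketch of copying safetensors contents directly into named
# parameters, as _load_weights_directly_to_model does above (toy two-layer model).
import torch
from safetensors.torch import load_file, save_file

source = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 4))
save_file({k: v.detach().clone() for k, v in source.state_dict().items()}, "shard.safetensors")

target = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 4))
params = dict(target.named_parameters(recurse=True))
buffers = dict(target.named_buffers(recurse=True))

for key, value in load_file("shard.safetensors").items():
    if key in params:
        params[key].data.copy_(value)    # in-place copy, no intermediate state_dict
    elif key in buffers:
        buffers[key].data.copy_(value)

assert torch.equal(target[0].weight, source[0].weight)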
@@ -160,13 +248,15 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
      def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
          wrapped_model = cls.wrap_model_if_needed(model, rbln_config)

-         prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
-         dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
+         rbln_compile_configs = rbln_config.compile_cfgs
+         prefill_rbln_compile_config = rbln_compile_configs[0]
+         dec_rbln_compile_config = rbln_compile_configs[1]

+         @rbln_timer("Jit Trace")
          def get_scripted_model():
              # This function is nested to dealloc the example inputs before compilation.
-             prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
-             dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=4)
+             prefill_example_inputs = prefill_rbln_compile_config.get_dummy_inputs(fill=0)
+             dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=4)

              batch_index = 3
              dec_example_inputs[batch_index].fill_(-1)  # fill batch_position -1 to indicate it is decoder.
@@ -181,17 +271,21 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):

          prefill_scripted_model, dec_scripted_model = get_scripted_model()

-         prefill_ir = rebel.torchscript_to_ir(
-             prefill_scripted_model,
-             input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
-         )
-         dec_ir = rebel.torchscript_to_ir(
-             dec_scripted_model,
-             input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
-         )
+         @rbln_timer("TorchScript to IR")
+         def scripted_model_to_ir():
+             prefill_ir = rebel.torchscript_to_ir(
+                 prefill_scripted_model,
+                 input_names=[v[0] for v in prefill_rbln_compile_config.input_info],
+             )
+             dec_ir = rebel.torchscript_to_ir(
+                 dec_scripted_model,
+                 input_names=[v[0] for v in dec_rbln_compile_config.input_info],
+             )
+             return prefill_ir, dec_ir

+         prefill_ir, dec_ir = scripted_model_to_ir()
          # Caching prefill_decoder/decoder I/O
-         cache_index_offset = 4
+         cache_index_offset = 5
          connections = [
              (prefill_ir.outputs[1 + i], prefill_ir.inputs[cache_index_offset + i])
              for i in range(model.config.num_hidden_layers * 2)
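get_scripted_model and the new scripted_model_to_ir helper are wrapped in @rbln_timer(...), which comes from the new optimum/rbln/utils/timer_utils.py (+19 lines) and is not shown in this diff. Purely for orientation, a guess at the general shape of such a decorator; the real implementation may log or format its output differently:

# Rough guess at the shape of a timing decorator like rbln_timer; the actual
# implementation lives in optimum/rbln/utils/timer_utils.py and is not shown here.
import functools
import logging
import time

logger = logging.getLogger(__name__)

def rbln_timer(label: str):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            start = time.perf_counter()
            result = func(*args, **kwargs)
            logger.info("%s took %.3f s", label, time.perf_counter() - start)
            return result
        return wrapper
    return decorator

@rbln_timer("Jit Trace")
def slow_step():
    time.sleep(0.1)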
@@ -201,9 +295,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
              prefill_ir,
              dec_ir,
              connections=connections,
-             fusion=prefill_rbln_runtime_config.fusion,
-             npu=prefill_rbln_runtime_config.npu,
-             tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
+             fusion=prefill_rbln_compile_config.fusion,
+             npu=prefill_rbln_compile_config.npu,
+             tensor_parallel_size=prefill_rbln_compile_config.tensor_parallel_size,
              use_weight_sharing=True,
          )
          return compiled_model
@@ -213,12 +307,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
          cls,
          preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
          model_config: "PretrainedConfig",
-         rbln_max_seq_len: Optional[int] = None,
-         rbln_batch_size: Optional[int] = None,
-         rbln_quantization: Optional[Dict[str, str]] = None,
-         **kwargs,
+         rbln_kwargs: Dict[str, Any] = {},
      ) -> RBLNConfig:
-         meta = {}
+         rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+         rbln_batch_size = rbln_kwargs.get("batch_size", None)
+         rbln_quantization = rbln_kwargs.get("quantization", None)
+         rbln_use_inputs_embeds = rbln_kwargs.get("use_inputs_embeds", None)

          prefill_chunk_size = 128
          if rbln_max_seq_len is None:
@@ -228,15 +322,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
          if rbln_max_seq_len is None:
              raise ValueError("`rbln_max_seq_len` should be specified.")
          rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-
-         meta["rbln_max_seq_len"] = rbln_max_seq_len
-         meta["rbln_batch_size"] = rbln_batch_size
-         meta["rbln_prefill_chunk_size"] = prefill_chunk_size
+         rbln_use_inputs_embeds = False if rbln_use_inputs_embeds is None else rbln_use_inputs_embeds

          num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
          num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
          num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
          head_dim = getattr(model_config, "head_dim", None) or model_config.hidden_size // num_attention_heads
+         hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")

          if rbln_quantization is not None:
              q_format = rbln_quantization.get("format", None)
@@ -247,21 +339,28 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
                      f'rbln_quantization="{rbln_quantization}" is not a supported quantization format or precesion, '
                      f"Possible: {SUPPORTED_QUANTIZATIONS}"
                  )
-             meta["rbln_quantization"] = rbln_quantization

          def get_input_info(
              batch_size,
              query_length,
+             use_inputs_embeds,
+             hidden_size,
          ):
+             if use_inputs_embeds:
+                 main_input = ("inputs_embeds", [batch_size, query_length, hidden_size], "float32")
+             else:
+                 main_input = ("input_ids", [batch_size, query_length], "int64")
+
              input_info = [
-                 ("input_ids", [batch_size, query_length], "int64"),
-                 ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "int64"),
+                 main_input,
+                 ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "float32"),
                  (
                      "cache_position",
                      [batch_size, query_length],
                      "int32",
                  ),
                  ("batch_position", [], "int16"),
+                 ("query_idx", [], "int16"),
              ]

              input_info.extend(
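Two interface changes stand out in get_input_info: the attention mask dtype moves from int64 to float32 (matching the float masks now built in __post_init__), and a scalar query_idx input is added, apparently so the compiled prefill graph can return only the last valid token's logits instead of a full chunk. The helper below reproduces the first five entries of input_info for both input modes with toy sizes; the per-layer past_key_values entries appended by input_info.extend(...) are omitted:

# Reproduces the head of input_info for both input modes
# (toy sizes; the per-layer KV-cache entries are omitted).
def head_of_input_info(batch_size, query_length, max_seq_len, hidden_size, use_inputs_embeds):
    if use_inputs_embeds:
        main_input = ("inputs_embeds", [batch_size, query_length, hidden_size], "float32")
    else:
        main_input = ("input_ids", [batch_size, query_length], "int64")
    return [
        main_input,
        ("attention_mask", [batch_size, 1, query_length, max_seq_len], "float32"),
        ("cache_position", [batch_size, query_length], "int32"),
        ("batch_position", [], "int16"),
        ("query_idx", [], "int16"),
    ]

# prefill graph: batch 1, one chunk of 128 tokens
print(head_of_input_info(1, 128, 4096, 2048, use_inputs_embeds=True))
# decode graph: full batch, one token per step
print(head_of_input_info(4, 1, 4096, 2048, use_inputs_embeds=False))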
@@ -285,22 +384,37 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
          prefill_input_info = get_input_info(
              batch_size=1,
              query_length=prefill_chunk_size,
+             use_inputs_embeds=rbln_use_inputs_embeds,
+             hidden_size=hidden_size,
          )
          dec_input_info = get_input_info(
              batch_size=rbln_batch_size,
              query_length=1,
+             use_inputs_embeds=rbln_use_inputs_embeds,
+             hidden_size=hidden_size,
          )

-         prefill_rbln_runtime_config = RBLNRuntimeConfig(input_info=prefill_input_info)
-         dec_rbln_runtime_config = RBLNRuntimeConfig(input_info=dec_input_info)
+         prefill_rbln_compile_config = RBLNCompileConfig(input_info=prefill_input_info)
+         dec_rbln_compile_config = RBLNCompileConfig(input_info=dec_input_info)

-         dec_rbln_runtime_config.batch_size = rbln_batch_size
+         rbln_config = RBLNConfig(
+             rbln_cls=cls.__name__,
+             compile_cfgs=[prefill_rbln_compile_config, dec_rbln_compile_config],
+             rbln_kwargs=rbln_kwargs,
+         )

-         rbln_config = RBLNConfig.from_rbln_runtime_configs(
-             [prefill_rbln_runtime_config, dec_rbln_runtime_config],
-             _rbln_meta=meta,
+         rbln_config.model_cfg.update(
+             {
+                 "max_seq_len": rbln_max_seq_len,
+                 "batch_size": rbln_batch_size,
+                 "prefill_chunk_size": prefill_chunk_size,
+                 "use_inputs_embeds": rbln_use_inputs_embeds,
+             }
          )

+         if rbln_quantization is not None:
+             rbln_config.model_cfg.update({"quantization": rbln_quantization})
+
          return rbln_config

      @classmethod
@@ -322,71 +436,117 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
      def _reorder_cache(self, past_key_values, beam_idx):
          raise NotImplementedError

-     # args input_ids, past_key_values and attention_mask are updated by _update_model_kwargs_for_generation() in _greedy_search() in GenerationMixin
-     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
-         batch_size = input_ids.shape[0]
-
-         # FIXME past_key_values is just carriier variable for past_cached_length
-         # torch.tensor((4,1),dtype=torch.int32) which refers a past_cached_length of each batch
-         past_cached_length = past_key_values
+     def prepare_inputs_for_generation(
+         self,
+         input_ids: torch.LongTensor,
+         past_cached_length: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         **kwargs,
+     ):
+         model_inputs = {}
+         # prefill phase
          if past_cached_length is None:
-             l_input_ids = []
+             # huggingface make dummy_input_ids if model_input_name is "input_embeds"
+             # https://github.com/huggingface/transformers/blob/174890280b340b89c5bfa092f6b4fb0e2dc2d7fc/src/transformers/generation/utils.py#L469
+             if self.rbln_config.model_cfg["use_inputs_embeds"] and inputs_embeds is not None:
+                 input_tensors = inputs_embeds
+             else:
+                 input_tensors = input_ids
+
+             batch_size = input_tensors.shape[0]
+             l_input_tensors = []
              cache_positions = []
              past_cached_length = torch.zeros((batch_size, 1), dtype=torch.int32)
              for i in range(batch_size):
-                 input_id = input_ids[i]
-                 input_id = input_id[attention_mask[i] == 1]
-                 valid_len = input_id.shape[-1]
+                 input_tensor = input_tensors[i]
+                 input_tensor = input_tensor[attention_mask[i] == 1]
+                 valid_len = input_tensor.shape[0]
                  cache_position = torch.arange(0, valid_len, dtype=torch.int32)
                  past_cached_length[i] = valid_len
-                 l_input_ids.append(input_id.unsqueeze(0))
+                 l_input_tensors.append(input_tensor.unsqueeze(0))
                  cache_positions.append(cache_position.unsqueeze(0))

-             input_ids = l_input_ids
+             input_tensors = l_input_tensors
+             if self.rbln_config.model_cfg["use_inputs_embeds"] and inputs_embeds is not None:
+                 model_inputs.update({"inputs_embeds": input_tensors, "input_ids": input_ids})
+             else:
+                 model_inputs.update({"input_ids": input_tensors, "inputs_embeds": inputs_embeds})
+         # decoder phase
          else:
              input_ids = input_ids[:, -1:]
              cache_positions = past_cached_length
              past_cached_length = past_cached_length + 1
+             model_inputs.update({"input_ids": input_ids})

-         model_inputs = {
-             "input_ids": input_ids,
-             "cache_position": cache_positions,
-             "past_cached_length": past_cached_length,
-         }
+         model_inputs.update(
+             {
+                 "cache_position": cache_positions,
+                 "past_cached_length": past_cached_length,
+             }
+         )

          return model_inputs

+     def _update_model_kwargs_for_generation(
+         self,
+         outputs: RBLNDecoderOnlyOutput,
+         model_kwargs: Dict[str, Any],
+         **kwargs,
+     ) -> Dict[str, Any]:
+         # update past_cached_length
+         model_kwargs["past_cached_length"] = outputs.past_cached_length
+
+         return model_kwargs
+
      def forward(
          self,
-         input_ids: torch.LongTensor = None,
+         input_ids: Optional[Union[List[torch.LongTensor], torch.LongTensor]] = None,
+         inputs_embeds: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
          cache_position: Union[List[torch.Tensor], torch.Tensor] = None,  # vllm keyword argument
          batch_idx: Optional[int] = None,
-         past_cached_length: Optional[torch.Tensor] = None,  # past_cached_length
+         past_cached_length: Optional[torch.Tensor] = None,
          **kwargs,
      ) -> Tuple[torch.FloatTensor]:
          # prefll & hf generate
          if isinstance(cache_position, list):
              logits = []
-             for batch_idx, (input_id, cache_pos) in enumerate(zip(input_ids, cache_position)):
-                 logit = self._forward_prefill(input_ids=input_id, cache_position=cache_pos, batch_idx=batch_idx)
+             input_tensors = input_ids if inputs_embeds is None else inputs_embeds
+             for batch_idx, (input_tensor, cache_pos) in enumerate(zip(input_tensors, cache_position)):
+                 logit = self._forward_prefill(
+                     input_ids=input_tensor if inputs_embeds is None else None,
+                     inputs_embeds=input_tensor if inputs_embeds is not None else None,
+                     cache_position=cache_pos,
+                     batch_idx=batch_idx,
+                 )
                  logits.append(logit)
              logits = torch.cat(logits, dim=0)
          # prefill & vllm step
          elif cache_position.shape[-1] > 1:
-             logits = self._forward_prefill(input_ids=input_ids, cache_position=cache_position, batch_idx=batch_idx)
+             logits = self._forward_prefill(
+                 input_ids=input_ids,
+                 inputs_embeds=inputs_embeds,
+                 cache_position=cache_position,
+                 batch_idx=batch_idx,
+             )
          # common decoder
          else:
-             logits = self._forward_decoder(input_ids=input_ids, cache_position=cache_position)
+             logits = self._forward_decoder(
+                 input_ids=input_ids,
+                 inputs_embeds=inputs_embeds,
+                 cache_position=cache_position,
+             )

-         return CausalLMOutputWithPast(
+         return RBLNDecoderOnlyOutput(
              logits=logits,
-             past_key_values=past_cached_length,  # past_cached_length
+             past_cached_length=past_cached_length,
          )

      def _forward_prefill(
          self,
          input_ids: torch.LongTensor = None,
-         cache_position: torch.Tensor = None,  # torch.tensor(,dtype=int32) (1,64) // (4,1)
+         inputs_embeds: torch.Tensor = None,
+         cache_position: torch.Tensor = None,
          batch_idx: int = None,
      ) -> torch.FloatTensor:
          if batch_idx is None or batch_idx >= self.batch_size:
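In the prefill branch of prepare_inputs_for_generation, each sequence is stripped of its padding using attention_mask, given its own cache_position range, and its valid length is recorded in past_cached_length, which now travels through generate() via RBLNDecoderOnlyOutput and the new _update_model_kwargs_for_generation rather than being smuggled in past_key_values. A toy reproduction of just that bookkeeping:

# Toy reproduction of the prefill-phase bookkeeping in prepare_inputs_for_generation:
# strip padding per sequence, build per-sequence cache positions, record valid lengths.
import torch

input_ids = torch.tensor([[0, 0, 5, 6, 7],      # left-padded sequence, 3 valid tokens
                          [1, 2, 3, 4, 9]])     # full sequence, 5 valid tokens
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

batch_size = input_ids.shape[0]
l_input_tensors, cache_positions = [], []
past_cached_length = torch.zeros((batch_size, 1), dtype=torch.int32)

for i in range(batch_size):
    valid = input_ids[i][attention_mask[i] == 1]
    past_cached_length[i] = valid.shape[0]
    l_input_tensors.append(valid.unsqueeze(0))
    cache_positions.append(torch.arange(0, valid.shape[0], dtype=torch.int32).unsqueeze(0))

print([t.shape for t in l_input_tensors])   # [torch.Size([1, 3]), torch.Size([1, 5])]
print(past_cached_length.squeeze(-1))       # tensor([3, 5], dtype=torch.int32)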
@@ -398,7 +558,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
              torch.empty(
                  size=[
                      1,
-                     self.prefill_chunk_size,
+                     1,
                      self.config.vocab_size,
                  ],
                  dtype=torch.float32,
@@ -407,11 +567,24 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
              torch.empty(size=[], dtype=torch.int16, device="cpu"),
          ]

-         query_length = input_ids.shape[1]
+         if self.rbln_config.model_cfg["use_inputs_embeds"] and inputs_embeds is not None:
+             model_input_name = "inputs_embeds"
+         else:
+             model_input_name = "input_ids"
+
+         input_tensors = input_ids if model_input_name == "input_ids" else inputs_embeds
+
+         query_length = input_tensors.shape[1]
          attention_mask = self.prefill_attention_mask.clone()
          for step in range(0, query_length, self.prefill_chunk_size):
              if step + self.prefill_chunk_size > query_length:
-                 input_ids = torch.nn.functional.pad(input_ids, (0, step + self.prefill_chunk_size - query_length))
+                 # input_tensors = torch.nn.functional.pad(input_tensors, (0, step + self.prefill_chunk_size - query_length))
+                 padding_needed = step + self.prefill_chunk_size - query_length
+                 if model_input_name == "input_ids":
+                     input_tensors = torch.nn.functional.pad(input_tensors, (0, padding_needed))
+                 else:
+                     input_tensors = torch.nn.functional.pad(input_tensors, (0, 0, 0, padding_needed))
+
              cache_position = torch.cat(
                  [
                      cache_position,
@@ -424,21 +597,24 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
                      dim=-1,
                  )

-             sliced_input_ids = input_ids[:, step : step + self.prefill_chunk_size]
+             sliced_input_tensors = input_tensors[:, step : step + self.prefill_chunk_size]
              sliced_cache_positions = cache_position[:, step : step + self.prefill_chunk_size]

              if step >= self.prefill_chunk_size:
                  attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
              attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask

+             query_idx = query_length % self.prefill_chunk_size - 1
+
              logits, _ = self.prefill_decoder(
-                 sliced_input_ids.contiguous(),
-                 attention_mask.contiguous(),
-                 sliced_cache_positions.contiguous(),
-                 torch.tensor(batch_idx, dtype=torch.int16),
+                 input_ids=sliced_input_tensors.contiguous() if model_input_name == "input_ids" else None,
+                 inputs_embeds=sliced_input_tensors.contiguous() if model_input_name == "inputs_embeds" else None,
+                 attention_mask=attention_mask.contiguous(),
+                 cache_position=sliced_cache_positions.contiguous(),
+                 batch_position=torch.tensor(batch_idx, dtype=torch.int16),
+                 query_idx=torch.tensor(query_idx, dtype=torch.int16),
                  out=out_buffers,
              )
-             logits = logits[:, query_length % self.prefill_chunk_size - 1].unsqueeze(1)

          self.dec_attn_mask[batch_idx] = self.dec_attn_mask_init.clone()
          self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
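_forward_prefill pads the final chunk up to prefill_chunk_size (1-D padding for token ids, 2-D padding for embeddings) and passes query_idx into the graph so the device returns only the last valid token's logits; this is why the output buffer earlier in the diff shrinks to [1, 1, vocab_size] and the host-side slice after the loop is removed. The chunking arithmetic, worked for a 300-token prompt with 128-token chunks:

# Worked example of the chunking arithmetic in _forward_prefill for a
# 300-token prompt with prefill_chunk_size = 128 (three chunks: 0, 128, 256).
import torch

prefill_chunk_size = 128
query_length = 300
ids = torch.ones(1, query_length, dtype=torch.int64)
embeds = torch.ones(1, query_length, 64)  # toy hidden size

for step in range(0, query_length, prefill_chunk_size):
    if step + prefill_chunk_size > query_length:
        padding_needed = step + prefill_chunk_size - query_length            # 84
        ids = torch.nn.functional.pad(ids, (0, padding_needed))              # pad sequence dim
        embeds = torch.nn.functional.pad(embeds, (0, 0, 0, padding_needed))  # pad sequence dim, keep hidden dim

query_idx = query_length % prefill_chunk_size - 1                            # 43
print(ids.shape, embeds.shape, query_idx)  # torch.Size([1, 384]) torch.Size([1, 384, 64]) 43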
@@ -446,19 +622,30 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
          return logits

      def _forward_decoder(
-         self, input_ids: torch.LongTensor = None, cache_position: torch.Tensor = None
+         self,
+         input_ids: torch.LongTensor = None,
+         inputs_embeds: torch.Tensor = None,
+         cache_position: torch.Tensor = None,
      ) -> torch.FloatTensor:
-         batch_size = input_ids.shape[0]
+         if self.rbln_config.model_cfg["use_inputs_embeds"] and inputs_embeds is not None:
+             model_input_name = "inputs_embeds"
+         else:
+             model_input_name = "input_ids"
+         input_tensors = input_ids if model_input_name == "input_ids" else inputs_embeds
+
+         batch_size = input_tensors.shape[0]

          for b_idx in range(batch_size):
              decoding_step = cache_position[b_idx].item()
              self.dec_attn_mask[b_idx, :, :, decoding_step] = 1

          logits, _ = self.decoder(
-             input_ids.contiguous(),
-             self.dec_attn_mask.contiguous(),
-             cache_position.contiguous(),
-             torch.tensor(0, dtype=torch.int16),
+             input_ids=input_tensors.contiguous() if model_input_name == "input_ids" else None,
+             inputs_embeds=input_tensors.contiguous() if model_input_name == "inputs_embeds" else None,
+             attention_mask=self.dec_attn_mask.contiguous(),
+             cache_position=cache_position.contiguous(),
+             batch_position=torch.tensor(0, dtype=torch.int16),
+             query_idx=torch.tensor(0, dtype=torch.int16),
          )

          return logits