optimum-rbln 0.1.8__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. optimum/rbln/__init__.py +40 -2
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +39 -32
  4. optimum/rbln/diffusers/models/controlnet.py +60 -43
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +43 -31
  6. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +2 -3
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +22 -15
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +22 -15
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +23 -17
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +24 -18
  11. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +22 -11
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +22 -11
  13. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +24 -14
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +24 -14
  15. optimum/rbln/modeling_alias.py +8 -4
  16. optimum/rbln/modeling_base.py +512 -238
  17. optimum/rbln/modeling_config.py +152 -77
  18. optimum/rbln/modeling_seq2seq.py +166 -77
  19. optimum/rbln/transformers/__init__.py +37 -1
  20. optimum/rbln/transformers/models/__init__.py +21 -1
  21. optimum/rbln/transformers/models/auto/__init__.py +14 -0
  22. optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
  23. optimum/rbln/transformers/models/auto/modeling_auto.py +94 -0
  24. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  25. optimum/rbln/transformers/models/bart/bart_architecture.py +189 -50
  26. optimum/rbln/transformers/models/bart/modeling_bart.py +106 -0
  27. optimum/rbln/transformers/models/bert/__init__.py +24 -0
  28. optimum/rbln/transformers/models/bert/modeling_bert.py +102 -0
  29. optimum/rbln/transformers/models/clip/__init__.py +1 -1
  30. optimum/rbln/transformers/models/clip/modeling_clip.py +128 -26
  31. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +32 -7
  32. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +406 -104
  33. optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -7
  34. optimum/rbln/transformers/models/gemma/gemma_architecture.py +10 -3
  35. optimum/rbln/transformers/models/gemma/modeling_gemma.py +9 -3
  36. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  37. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +6 -89
  38. optimum/rbln/transformers/models/llama/modeling_llama.py +9 -3
  39. optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
  40. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +666 -0
  41. optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
  42. optimum/rbln/transformers/models/midm/modeling_midm.py +5 -88
  43. optimum/rbln/transformers/models/mistral/__init__.py +24 -0
  44. optimum/rbln/transformers/models/mistral/mistral_architecture.py +29 -0
  45. optimum/rbln/transformers/models/mistral/modeling_mistral.py +68 -0
  46. optimum/rbln/transformers/models/phi/__init__.py +24 -0
  47. optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
  48. optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
  49. optimum/rbln/transformers/models/t5/t5_architecture.py +92 -31
  50. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +18 -12
  51. optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
  52. optimum/rbln/transformers/models/whisper/modeling_whisper.py +141 -105
  53. optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
  54. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +25 -16
  55. optimum/rbln/transformers/utils/__init__.py +0 -0
  56. optimum/rbln/transformers/utils/rbln_quantization.py +97 -0
  57. optimum/rbln/utils/import_utils.py +37 -5
  58. optimum/rbln/utils/logging.py +82 -0
  59. optimum/rbln/utils/runtime_utils.py +35 -1
  60. optimum/rbln/utils/timer_utils.py +19 -0
  61. {optimum_rbln-0.1.8.dist-info → optimum_rbln-0.1.11.dist-info}/METADATA +15 -7
  62. optimum_rbln-0.1.11.dist-info/RECORD +93 -0
  63. {optimum_rbln-0.1.8.dist-info → optimum_rbln-0.1.11.dist-info}/WHEEL +1 -1
  64. optimum_rbln-0.1.11.dist-info/entry_points.txt +4 -0
  65. optimum_rbln-0.1.8.dist-info/RECORD +0 -73
  66. {optimum_rbln-0.1.8.dist-info → optimum_rbln-0.1.11.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py
@@ -20,18 +20,24 @@
 # are the intellectual property of Rebellions Inc. and may not be
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
+import glob
 import logging
-from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from abc import ABC
+from dataclasses import dataclass
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

 import rebel  # noqa: F401
 import torch  # noqa: F401
-from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
-from transformers.modeling_outputs import CausalLMOutputWithPast
+from safetensors.torch import load_file
+from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
+from transformers.modeling_utils import no_init_weights
+from transformers.utils import ModelOutput

 from ....modeling_base import RBLNModel
-from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
+from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
 from ....utils.runtime_utils import RBLNPytorchRuntime
+from ....utils.timer_utils import rbln_timer


 logger = logging.getLogger(__name__)
@@ -44,9 +50,54 @@ if TYPE_CHECKING:
         PretrainedConfig,
     )

+SUPPORTED_QUANTIZATIONS = {
+    "rbln": [
+        "w4a16",
+    ],
+}
+

 class RBLNRuntimeModel(RBLNPytorchRuntime):
-    mandatory_members = ["main_input_name"]
+    mandatory_members = ["main_input_name", "embed_tokens"]
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        inputs_embeds: torch.Tensor,
+        attention_mask: torch.Tensor,
+        cache_position: torch.Tensor,
+        batch_position: torch.Tensor,
+        query_idx: torch.Tensor,
+        **kwargs,
+    ):
+        if inputs_embeds is None:
+            inp = input_ids
+            if self.embed_tokens is not None:
+                inp = self.embed_tokens(inp)
+
+            return super().forward(
+                inp,
+                attention_mask,
+                cache_position,
+                batch_position,
+                query_idx,
+                **kwargs,
+            )
+        else:
+            return super().forward(
+                inputs_embeds,
+                attention_mask,
+                cache_position,
+                batch_position,
+                query_idx,
+                **kwargs,
+            )
+
+
+@dataclass
+class RBLNDecoderOnlyOutput(ModelOutput):
+    logits: torch.FloatTensor = None
+    past_cached_length: Union[int, torch.Tensor] = None


 class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
@@ -64,52 +115,177 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
     auto_model_class = AutoModelForCausalLM

     def __post_init__(self, **kwargs):
-        self.batch_size = self.rbln_config.meta["rbln_batch_size"]
-        self.max_seq_len = self.rbln_config.meta["rbln_max_seq_len"]
-        self.prefill_chunk_size = self.rbln_config.meta["rbln_prefill_chunk_size"]
+        self.batch_size = self.rbln_config.model_cfg["batch_size"]
+        self.max_seq_len = self.rbln_config.model_cfg["max_seq_len"]
+        self.prefill_chunk_size = self.rbln_config.model_cfg["prefill_chunk_size"]

-        self.prefill_attention_mask = torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.int64)
+        self.prefill_attention_mask = torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
         self.causal_mask = 1 - torch.triu(
             torch.ones(1, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
         )
-        self.dec_attn_mask_init = torch.zeros(1, 1, 1, self.max_seq_len, dtype=torch.int64)
-        self.dec_attn_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.int64)
-        self.prefill_decoder = RBLNRuntimeModel(runtime=self.model[0], main_input_name="input_ids")
-        self.decoder = RBLNRuntimeModel(runtime=self.model[1], main_input_name="input_ids")
+        self.dec_attn_mask_init = torch.zeros(1, 1, 1, self.max_seq_len, dtype=torch.float32)
+        self.dec_attn_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.float32)
+
+        main_input_name = self.main_input_name
+        if self.rbln_config.model_cfg["use_inputs_embeds"]:
+            main_input_name = "inputs_embeds"
+            artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
+            with no_init_weights():
+                self.embed_tokens = torch.nn.Embedding(
+                    self.config.vocab_size,
+                    self.config.hidden_size,
+                    self.config.pad_token_id,
+                )
+            self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
+        else:
+            self.embed_tokens = None
+
+        self.prefill_decoder = RBLNRuntimeModel(
+            runtime=self.model[0], main_input_name=main_input_name, embed_tokens=self.embed_tokens
+        )
+        self.decoder = RBLNRuntimeModel(
+            runtime=self.model[1], main_input_name=main_input_name, embed_tokens=self.embed_tokens
+        )

     @classmethod
-    @abstractmethod
-    def wrapping_torch_model(self, model: "PreTrainedModel", rbln_max_seq_len: int):
-        pass
+    def save_torch_artifacts(
+        cls,
+        model: "PreTrainedModel",
+        save_dir_path: Path,
+        subfolder: str,
+        rbln_config: RBLNConfig,
+    ):
+        """
+        If you are unavoidably running on a CPU rather than an RBLN device,
+        store the torch tensor, weight, etc. in this function.
+        """
+        if rbln_config.model_cfg["use_inputs_embeds"]:
+            save_dict = {}
+            save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
+            torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
+
+    def get_input_embeddings(self):
+        return self.embed_tokens
+
+    @classmethod
+    def get_quantized_model(
+        cls,
+        model_id: str,
+        use_auth_token: Optional[Union[bool, str]] = None,
+        revision: Optional[str] = None,
+        force_download: bool = False,
+        cache_dir: Optional[str] = None,
+        subfolder: str = "",
+        local_files_only: bool = False,
+        trust_remote_code: bool = False,
+        **kwargs,
+    ):
+        from ...utils.rbln_quantization import update_layers_to_quantized
+
+        kwargs = cls.update_kwargs(kwargs)
+
+        config = AutoConfig.from_pretrained(
+            model_id,
+            use_auth_token=use_auth_token,
+            revision=revision,
+            force_download=force_download,
+            cache_dir=cache_dir,
+            trust_remote_code=trust_remote_code,
+            **kwargs,
+        )
+
+        with no_init_weights():
+            model = AutoModelForCausalLM.from_config(config)
+
+        update_layers_to_quantized(model)
+
+        n_layer = kwargs.get("num_hidden_layers", None)
+        cls._load_weights_directly_to_model(model, model_id, n_layer)
+
+        return model
+
+    def _load_weights_directly_to_model(model, model_id, n_layer=None):
+        """
+        Load safetensor file data directly into the model, filtering by layer if n_layer is provided.
+        """
+
+        model_params = dict(model.named_parameters(recurse=True))
+        model_buffers = dict(model.named_buffers(recurse=True))
+        safetensor_files = glob.glob(f"{model_id}/*.safetensors")
+
+        target_layers = list(range(n_layer)) if n_layer is not None else None
+
+        for safetensor_file in safetensor_files:
+            file_data = load_file(safetensor_file)
+            for key, value in file_data.items():
+                if target_layers is not None:
+                    parts = key.split(".")
+
+                    if len(parts) > 2 and parts[2].isdigit() and (int(parts[2]) not in target_layers):
+                        continue
+
+                if key in model_params:
+                    model_params[key].data.copy_(value)
+                elif key in model_buffers:
+                    model_buffers[key].data.copy_(value)
+
+        return 0
+
+    @classmethod
+    def get_pytorch_model(cls, *args, **kwargs) -> "PreTrainedModel":
+        rbln_kwargs = kwargs.get("rbln_kwargs", {})
+        rbln_quantization = rbln_kwargs.get("quantization", None)
+
+        if rbln_quantization is not None and rbln_quantization["format"] == "rbln":
+            model = cls.get_quantized_model(*args, **kwargs)
+        else:
+            model = super().get_pytorch_model(*args, **kwargs)
+
+        return model

     @classmethod
     @torch.inference_mode()
     def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
-        wrapped_model = cls.wrapping_torch_model(model, rbln_config.meta["rbln_max_seq_len"])
+        wrapped_model = cls.wrap_model_if_needed(model, rbln_config)

-        prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
-        dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
+        rbln_compile_configs = rbln_config.compile_cfgs
+        prefill_rbln_compile_config = rbln_compile_configs[0]
+        dec_rbln_compile_config = rbln_compile_configs[1]

-        prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
-        dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=4)
+        @rbln_timer("Jit Trace")
+        def get_scripted_model():
+            # This function is nested to dealloc the example inputs before compilation.
+            prefill_example_inputs = prefill_rbln_compile_config.get_dummy_inputs(fill=0)
+            dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=4)

-        batch_index = 3
-        dec_example_inputs[batch_index].fill_(-1)  # fill batch_position -1 to indicate it is decoder.
+            batch_index = 3
+            dec_example_inputs[batch_index].fill_(-1)  # fill batch_position -1 to indicate it is decoder.
+
+            prefill_scripted_model = torch.jit.trace(
+                wrapped_model, prefill_example_inputs, check_trace=False, _store_inputs=False
+            )
+            dec_scripted_model = torch.jit.trace(
+                wrapped_model, dec_example_inputs, check_trace=False, _store_inputs=False
+            )
+            return prefill_scripted_model, dec_scripted_model

-        prefill_scripted_model = torch.jit.trace(wrapped_model, prefill_example_inputs, check_trace=False)
-        dec_scripted_model = torch.jit.trace(wrapped_model, dec_example_inputs, check_trace=False)
+        prefill_scripted_model, dec_scripted_model = get_scripted_model()

-        prefill_ir = rebel.torchscript_to_ir(
-            prefill_scripted_model,
-            input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
-        )
-        dec_ir = rebel.torchscript_to_ir(
-            dec_scripted_model,
-            input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
-        )
+        @rbln_timer("TorchScript to IR")
+        def scripted_model_to_ir():
+            prefill_ir = rebel.torchscript_to_ir(
+                prefill_scripted_model,
+                input_names=[v[0] for v in prefill_rbln_compile_config.input_info],
+            )
+            dec_ir = rebel.torchscript_to_ir(
+                dec_scripted_model,
+                input_names=[v[0] for v in dec_rbln_compile_config.input_info],
+            )
+            return prefill_ir, dec_ir

+        prefill_ir, dec_ir = scripted_model_to_ir()
         # Caching prefill_decoder/decoder I/O
-        cache_index_offset = 4
+        cache_index_offset = 5
         connections = [
             (prefill_ir.outputs[1 + i], prefill_ir.inputs[cache_index_offset + i])
             for i in range(model.config.num_hidden_layers * 2)
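Note: a minimal, standalone sketch of the key-filtering rule that `_load_weights_directly_to_model` applies above, assuming the usual "model.layers.<idx>...." checkpoint key layout; the example keys below are made up.

# Mirrors the filter above: keys whose third dot-separated component is a layer
# index outside range(n_layer) are skipped; everything else is loaded.
def keep_key(key: str, n_layer=None) -> bool:
    if n_layer is None:
        return True
    parts = key.split(".")
    if len(parts) > 2 and parts[2].isdigit():
        return int(parts[2]) < n_layer
    return True  # embeddings, final norm, lm_head, etc. are always kept

keys = [
    "model.embed_tokens.weight",
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.10.self_attn.q_proj.weight",
    "lm_head.weight",
]
print([k for k in keys if keep_key(k, n_layer=8)])  # drops only the layer-10 weight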
@@ -119,9 +295,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
             prefill_ir,
             dec_ir,
             connections=connections,
-            fusion=prefill_rbln_runtime_config.fusion,
-            npu=prefill_rbln_runtime_config.npu,
-            tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
+            fusion=prefill_rbln_compile_config.fusion,
+            npu=prefill_rbln_compile_config.npu,
+            tensor_parallel_size=prefill_rbln_compile_config.tensor_parallel_size,
             use_weight_sharing=True,
         )
         return compiled_model
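Note: `cache_index_offset` moves from 4 to 5 because the compiled graphs gain a new leading `query_idx` input, so the per-layer KV-cache tensors start one position later. A purely illustrative recap of the input order declared by `get_input_info` in the next hunk:

# Leading graph inputs, in order, followed by past_key_values_0 ... past_key_values_{2*num_hidden_layers - 1}.
leading_inputs = ["input_ids or inputs_embeds", "attention_mask", "cache_position", "batch_position", "query_idx"]
cache_index_offset = len(leading_inputs)  # == 5; was 4 before query_idx was added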
@@ -131,39 +307,60 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
         model_config: "PretrainedConfig",
-        rbln_max_seq_len: Optional[int] = None,
-        rbln_batch_size: Optional[int] = None,
-        **kwargs,
+        rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-        meta = {}
+        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_quantization = rbln_kwargs.get("quantization", None)
+        rbln_use_inputs_embeds = rbln_kwargs.get("use_inputs_embeds", None)

         prefill_chunk_size = 128
         if rbln_max_seq_len is None:
-            rbln_max_seq_len = getattr(model_config, "max_position_embeddings", None)
+            rbln_max_seq_len = getattr(model_config, "max_position_embeddings", None) or getattr(
+                model_config, "n_positions", None
+            )
+        if rbln_max_seq_len is None:
+            raise ValueError("`rbln_max_seq_len` should be specified.")
         rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-
-        meta["rbln_max_seq_len"] = rbln_max_seq_len
-        meta["rbln_batch_size"] = rbln_batch_size
-        meta["rbln_prefill_chunk_size"] = prefill_chunk_size
+        rbln_use_inputs_embeds = False if rbln_use_inputs_embeds is None else rbln_use_inputs_embeds
+
+        num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
+        num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
+        num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
+        head_dim = getattr(model_config, "head_dim", None) or model_config.hidden_size // num_attention_heads
+        hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
+
+        if rbln_quantization is not None:
+            q_format = rbln_quantization.get("format", None)
+            q_precision = rbln_quantization.get("precision", None)
+
+            if q_format not in SUPPORTED_QUANTIZATIONS.keys() or q_precision not in SUPPORTED_QUANTIZATIONS[q_format]:
+                raise ValueError(
+                    f'rbln_quantization="{rbln_quantization}" is not a supported quantization format or precesion, '
+                    f"Possible: {SUPPORTED_QUANTIZATIONS}"
+                )

         def get_input_info(
             batch_size,
             query_length,
+            use_inputs_embeds,
+            hidden_size,
         ):
-            head_dim = (
-                model_config.head_dim
-                if hasattr(model_config, "head_dim")
-                else model_config.hidden_size // model_config.num_attention_heads
-            )
+            if use_inputs_embeds:
+                main_input = ("inputs_embeds", [batch_size, query_length, hidden_size], "float32")
+            else:
+                main_input = ("input_ids", [batch_size, query_length], "int64")
+
             input_info = [
-                ("input_ids", [batch_size, query_length], "int64"),
-                ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "int64"),
+                main_input,
+                ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "float32"),
                 (
                     "cache_position",
                     [batch_size, query_length],
                     "int32",
                 ),
                 ("batch_position", [], "int16"),
+                ("query_idx", [], "int16"),
             ]

             input_info.extend(
@@ -172,13 +369,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
                         f"past_key_values_{i}",
                         [
                             rbln_batch_size,
-                            model_config.num_key_value_heads,
+                            num_key_value_heads,
                             rbln_max_seq_len,
                             head_dim,
                         ],
                         "float32",
                     )
-                    for i in range(model_config.num_hidden_layers * 2)
+                    for i in range(num_hidden_layers * 2)
                 ]
             )

@@ -187,22 +384,37 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         prefill_input_info = get_input_info(
             batch_size=1,
             query_length=prefill_chunk_size,
+            use_inputs_embeds=rbln_use_inputs_embeds,
+            hidden_size=hidden_size,
         )
         dec_input_info = get_input_info(
             batch_size=rbln_batch_size,
             query_length=1,
+            use_inputs_embeds=rbln_use_inputs_embeds,
+            hidden_size=hidden_size,
         )

-        prefill_rbln_runtime_config = RBLNRuntimeConfig(input_info=prefill_input_info)
-        dec_rbln_runtime_config = RBLNRuntimeConfig(input_info=dec_input_info)
+        prefill_rbln_compile_config = RBLNCompileConfig(input_info=prefill_input_info)
+        dec_rbln_compile_config = RBLNCompileConfig(input_info=dec_input_info)

-        dec_rbln_runtime_config.batch_size = rbln_batch_size
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[prefill_rbln_compile_config, dec_rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )

-        rbln_config = RBLNConfig.from_rbln_runtime_configs(
-            [prefill_rbln_runtime_config, dec_rbln_runtime_config],
-            _rbln_meta=meta,
+        rbln_config.model_cfg.update(
+            {
+                "max_seq_len": rbln_max_seq_len,
+                "batch_size": rbln_batch_size,
+                "prefill_chunk_size": prefill_chunk_size,
+                "use_inputs_embeds": rbln_use_inputs_embeds,
+            }
         )

+        if rbln_quantization is not None:
+            rbln_config.model_cfg.update({"quantization": rbln_quantization})
+
         return rbln_config

     @classmethod
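Note: a small sketch of the data flow into `_get_rbln_config`, assuming the caller has already collected the user-facing options into the `rbln_kwargs` dict (the key names are the ones read above; the example values are made up):

# Illustrative only: what rbln_kwargs looks like and what ends up in model_cfg.
SUPPORTED_QUANTIZATIONS = {"rbln": ["w4a16"]}

rbln_kwargs = {
    "max_seq_len": 4096,
    "batch_size": 2,
    "use_inputs_embeds": True,
    "quantization": {"format": "rbln", "precision": "w4a16"},
}

q = rbln_kwargs["quantization"]
assert q["format"] in SUPPORTED_QUANTIZATIONS and q["precision"] in SUPPORTED_QUANTIZATIONS[q["format"]]

model_cfg = {
    "max_seq_len": rbln_kwargs["max_seq_len"],
    "batch_size": rbln_kwargs["batch_size"],
    "prefill_chunk_size": 128,
    "use_inputs_embeds": rbln_kwargs["use_inputs_embeds"],
    "quantization": q,
}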
@@ -224,82 +436,155 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
     def _reorder_cache(self, past_key_values, beam_idx):
         raise NotImplementedError

-    # args input_ids, past_key_values and attention_mask are updated by _update_model_kwargs_for_generation() in _greedy_search() in GenerationMixin
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
-        batch_size = input_ids.shape[0]
-
-        # FIXME past_key_values is just carriier variable for past_cached_length
-        # torch.tensor((4,1),dtype=torch.int32) which refers a past_cached_length of each batch
-        past_cached_length = past_key_values
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.LongTensor,
+        past_cached_length: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        model_inputs = {}
+        # prefill phase
         if past_cached_length is None:
-            l_input_ids = []
+            # huggingface make dummy_input_ids if model_input_name is "input_embeds"
+            # https://github.com/huggingface/transformers/blob/174890280b340b89c5bfa092f6b4fb0e2dc2d7fc/src/transformers/generation/utils.py#L469
+            if self.rbln_config.model_cfg["use_inputs_embeds"] and inputs_embeds is not None:
+                input_tensors = inputs_embeds
+            else:
+                input_tensors = input_ids
+
+            batch_size = input_tensors.shape[0]
+            l_input_tensors = []
             cache_positions = []
             past_cached_length = torch.zeros((batch_size, 1), dtype=torch.int32)
             for i in range(batch_size):
-                input_id = input_ids[i]
-                input_id = input_id[attention_mask[i] == 1]
-                valid_len = input_id.shape[-1]
+                input_tensor = input_tensors[i]
+                input_tensor = input_tensor[attention_mask[i] == 1]
+                valid_len = input_tensor.shape[0]
                 cache_position = torch.arange(0, valid_len, dtype=torch.int32)
                 past_cached_length[i] = valid_len
-                l_input_ids.append(input_id.unsqueeze(0))
+                l_input_tensors.append(input_tensor.unsqueeze(0))
                 cache_positions.append(cache_position.unsqueeze(0))

-            input_ids = l_input_ids
+            input_tensors = l_input_tensors
+            if self.rbln_config.model_cfg["use_inputs_embeds"] and inputs_embeds is not None:
+                model_inputs.update({"inputs_embeds": input_tensors, "input_ids": input_ids})
+            else:
+                model_inputs.update({"input_ids": input_tensors, "inputs_embeds": inputs_embeds})
+        # decoder phase
         else:
             input_ids = input_ids[:, -1:]
             cache_positions = past_cached_length
             past_cached_length = past_cached_length + 1
+            model_inputs.update({"input_ids": input_ids})

-        model_inputs = {
-            "input_ids": input_ids,
-            "cache_position": cache_positions,
-            "past_cached_length": past_cached_length,
-        }
+        model_inputs.update(
+            {
+                "cache_position": cache_positions,
+                "past_cached_length": past_cached_length,
+            }
+        )

         return model_inputs

+    def _update_model_kwargs_for_generation(
+        self,
+        outputs: RBLNDecoderOnlyOutput,
+        model_kwargs: Dict[str, Any],
+        **kwargs,
+    ) -> Dict[str, Any]:
+        # update past_cached_length
+        model_kwargs["past_cached_length"] = outputs.past_cached_length
+
+        return model_kwargs
+
     def forward(
         self,
-        input_ids: torch.LongTensor = None,
+        input_ids: Optional[Union[List[torch.LongTensor], torch.LongTensor]] = None,
+        inputs_embeds: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
         cache_position: Union[List[torch.Tensor], torch.Tensor] = None,  # vllm keyword argument
         batch_idx: Optional[int] = None,
-        past_cached_length: Optional[torch.Tensor] = None,  # past_cached_length
+        past_cached_length: Optional[torch.Tensor] = None,
        **kwargs,
     ) -> Tuple[torch.FloatTensor]:
         # prefll & hf generate
         if isinstance(cache_position, list):
             logits = []
-            for batch_idx, (input_id, cache_pos) in enumerate(zip(input_ids, cache_position)):
-                logit = self._forward_prefill(input_ids=input_id, cache_position=cache_pos, batch_idx=batch_idx)
+            input_tensors = input_ids if inputs_embeds is None else inputs_embeds
+            for batch_idx, (input_tensor, cache_pos) in enumerate(zip(input_tensors, cache_position)):
+                logit = self._forward_prefill(
+                    input_ids=input_tensor if inputs_embeds is None else None,
+                    inputs_embeds=input_tensor if inputs_embeds is not None else None,
+                    cache_position=cache_pos,
+                    batch_idx=batch_idx,
+                )
                 logits.append(logit)
             logits = torch.cat(logits, dim=0)
         # prefill & vllm step
         elif cache_position.shape[-1] > 1:
-            logits = self._forward_prefill(input_ids=input_ids, cache_position=cache_position, batch_idx=batch_idx)
+            logits = self._forward_prefill(
+                input_ids=input_ids,
+                inputs_embeds=inputs_embeds,
+                cache_position=cache_position,
+                batch_idx=batch_idx,
+            )
         # common decoder
         else:
-            logits = self._forward_decoder(input_ids=input_ids, cache_position=cache_position)
+            logits = self._forward_decoder(
+                input_ids=input_ids,
+                inputs_embeds=inputs_embeds,
+                cache_position=cache_position,
+            )

-        return CausalLMOutputWithPast(
+        return RBLNDecoderOnlyOutput(
             logits=logits,
-            past_key_values=past_cached_length,  # past_cached_length
+            past_cached_length=past_cached_length,
         )

     def _forward_prefill(
         self,
         input_ids: torch.LongTensor = None,
-        cache_position: torch.Tensor = None,  # torch.tensor(,dtype=int32) (1,64) // (4,1)
+        inputs_embeds: torch.Tensor = None,
+        cache_position: torch.Tensor = None,
         batch_idx: int = None,
     ) -> torch.FloatTensor:
         if batch_idx is None or batch_idx >= self.batch_size:
             raise RuntimeError(
                 f"Invalid batch_idx ({batch_idx}). It must be a non-null value less than the batch size ({self.batch_size})."
             )
-        query_length = input_ids.shape[1]
+
+        out_buffers = [
+            torch.empty(
+                size=[
+                    1,
+                    1,
+                    self.config.vocab_size,
+                ],
+                dtype=torch.float32,
+                device="cpu",
+            ),
+            torch.empty(size=[], dtype=torch.int16, device="cpu"),
+        ]
+
+        if self.rbln_config.model_cfg["use_inputs_embeds"] and inputs_embeds is not None:
+            model_input_name = "inputs_embeds"
+        else:
+            model_input_name = "input_ids"
+
+        input_tensors = input_ids if model_input_name == "input_ids" else inputs_embeds
+
+        query_length = input_tensors.shape[1]
         attention_mask = self.prefill_attention_mask.clone()
         for step in range(0, query_length, self.prefill_chunk_size):
             if step + self.prefill_chunk_size > query_length:
-                input_ids = torch.nn.functional.pad(input_ids, (0, step + self.prefill_chunk_size - query_length))
+                # input_tensors = torch.nn.functional.pad(input_tensors, (0, step + self.prefill_chunk_size - query_length))
+                padding_needed = step + self.prefill_chunk_size - query_length
+                if model_input_name == "input_ids":
+                    input_tensors = torch.nn.functional.pad(input_tensors, (0, padding_needed))
+                else:
+                    input_tensors = torch.nn.functional.pad(input_tensors, (0, 0, 0, padding_needed))
+
             cache_position = torch.cat(
                 [
                     cache_position,
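Note: the two padding branches added above differ only because of tensor rank; torch.nn.functional.pad pads from the last dimension backwards, so [batch, seq] token IDs need one pad pair while [batch, seq, hidden] embeddings need two. A standalone illustration with made-up shapes:

import torch
import torch.nn.functional as F

padding_needed = 3
input_ids = torch.ones(1, 5, dtype=torch.long)  # [batch, seq]
inputs_embeds = torch.ones(1, 5, 8)             # [batch, seq, hidden]

print(F.pad(input_ids, (0, padding_needed)).shape)            # torch.Size([1, 8])
print(F.pad(inputs_embeds, (0, 0, 0, padding_needed)).shape)  # torch.Size([1, 8, 8])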
@@ -312,18 +597,24 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
                 dim=-1,
             )

-            sliced_input_ids = input_ids[:, step : step + self.prefill_chunk_size]
+            sliced_input_tensors = input_tensors[:, step : step + self.prefill_chunk_size]
             sliced_cache_positions = cache_position[:, step : step + self.prefill_chunk_size]
-            attention_mask[:, :, :, :step] = 1
+
+            if step >= self.prefill_chunk_size:
+                attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
             attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask

+            query_idx = query_length % self.prefill_chunk_size - 1
+
             logits, _ = self.prefill_decoder(
-                sliced_input_ids.contiguous(),
-                attention_mask.contiguous(),
-                sliced_cache_positions.contiguous(),
-                torch.tensor(batch_idx, dtype=torch.int16),
+                input_ids=sliced_input_tensors.contiguous() if model_input_name == "input_ids" else None,
+                inputs_embeds=sliced_input_tensors.contiguous() if model_input_name == "inputs_embeds" else None,
+                attention_mask=attention_mask.contiguous(),
+                cache_position=sliced_cache_positions.contiguous(),
+                batch_position=torch.tensor(batch_idx, dtype=torch.int16),
+                query_idx=torch.tensor(query_idx, dtype=torch.int16),
+                out=out_buffers,
             )
-            logits = logits[:, query_length % self.prefill_chunk_size - 1].unsqueeze(1)

         self.dec_attn_mask[batch_idx] = self.dec_attn_mask_init.clone()
         self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
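Note: the mask update above was changed from re-marking the whole prefix (`attention_mask[:, :, :, :step] = 1`) to marking only the chunk finished in the previous iteration. Because `attention_mask` persists across loop iterations, the incremental writes accumulate to the same prefix mask; a small standalone check with made-up sizes:

import torch

chunk, max_seq_len, query_length = 4, 12, 12
causal_mask = 1 - torch.triu(torch.ones(1, 1, chunk, chunk), diagonal=1)

mask = torch.zeros(1, 1, chunk, max_seq_len)
for step in range(0, query_length, chunk):
    if step >= chunk:
        mask[:, :, :, step - chunk : step] = 1        # chunk completed in the previous iteration
    mask[:, :, :, step : step + chunk] = causal_mask  # causal window for the current chunk

old_style = torch.zeros(1, 1, chunk, max_seq_len)
old_style[:, :, :, : query_length - chunk] = 1
old_style[:, :, :, query_length - chunk : query_length] = causal_mask
assert torch.equal(mask, old_style)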
@@ -331,19 +622,30 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         return logits

     def _forward_decoder(
-        self, input_ids: torch.LongTensor = None, cache_position: torch.Tensor = None
+        self,
+        input_ids: torch.LongTensor = None,
+        inputs_embeds: torch.Tensor = None,
+        cache_position: torch.Tensor = None,
     ) -> torch.FloatTensor:
-        batch_size = input_ids.shape[0]
+        if self.rbln_config.model_cfg["use_inputs_embeds"] and inputs_embeds is not None:
+            model_input_name = "inputs_embeds"
+        else:
+            model_input_name = "input_ids"
+        input_tensors = input_ids if model_input_name == "input_ids" else inputs_embeds
+
+        batch_size = input_tensors.shape[0]

         for b_idx in range(batch_size):
             decoding_step = cache_position[b_idx].item()
             self.dec_attn_mask[b_idx, :, :, decoding_step] = 1

         logits, _ = self.decoder(
-            input_ids.contiguous(),
-            self.dec_attn_mask.contiguous(),
-            cache_position.contiguous(),
-            torch.tensor(0, dtype=torch.int16),
+            input_ids=input_tensors.contiguous() if model_input_name == "input_ids" else None,
+            inputs_embeds=input_tensors.contiguous() if model_input_name == "inputs_embeds" else None,
+            attention_mask=self.dec_attn_mask.contiguous(),
+            cache_position=cache_position.contiguous(),
+            batch_position=torch.tensor(0, dtype=torch.int16),
+            query_idx=torch.tensor(0, dtype=torch.int16),
         )

         return logits