optimum-rbln 0.1.9__py3-none-any.whl → 0.1.12__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (73)
  1. optimum/rbln/__init__.py +47 -9
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +36 -31
  4. optimum/rbln/diffusers/models/controlnet.py +53 -43
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +40 -31
  6. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +4 -0
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +28 -23
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +28 -23
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +28 -37
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +30 -39
  11. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +24 -14
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +24 -15
  13. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +26 -17
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -17
  15. optimum/rbln/modeling_alias.py +6 -11
  16. optimum/rbln/modeling_base.py +467 -261
  17. optimum/rbln/modeling_config.py +199 -73
  18. optimum/rbln/transformers/__init__.py +43 -1
  19. optimum/rbln/transformers/models/__init__.py +23 -1
  20. optimum/rbln/transformers/models/auto/__init__.py +14 -0
  21. optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
  22. optimum/rbln/transformers/models/auto/modeling_auto.py +95 -0
  23. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  24. optimum/rbln/transformers/models/bart/bart_architecture.py +203 -58
  25. optimum/rbln/transformers/models/bart/modeling_bart.py +125 -0
  26. optimum/rbln/transformers/models/bert/__init__.py +24 -0
  27. optimum/rbln/transformers/models/bert/modeling_bert.py +101 -0
  28. optimum/rbln/transformers/models/clip/__init__.py +1 -1
  29. optimum/rbln/transformers/models/clip/modeling_clip.py +127 -26
  30. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +28 -4
  31. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +409 -150
  32. optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -8
  33. optimum/rbln/transformers/models/exaone/__init__.py +32 -0
  34. optimum/rbln/transformers/models/exaone/exaone_architecture.py +72 -0
  35. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +181 -0
  36. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +1725 -0
  37. optimum/rbln/transformers/models/exaone/modeling_exaone.py +78 -0
  38. optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -1
  39. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  40. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
  41. optimum/rbln/transformers/models/llama/modeling_llama.py +1 -1
  42. optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
  43. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +662 -0
  44. optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
  45. optimum/rbln/transformers/models/midm/modeling_midm.py +6 -1
  46. optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -1
  47. optimum/rbln/transformers/models/phi/__init__.py +24 -0
  48. optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
  49. optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
  50. optimum/rbln/transformers/models/qwen2/__init__.py +24 -0
  51. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -0
  52. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +29 -0
  53. optimum/rbln/transformers/models/seq2seq/__init__.py +24 -0
  54. optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} +198 -168
  55. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  56. optimum/rbln/transformers/models/t5/modeling_t5.py +55 -0
  57. optimum/rbln/transformers/models/t5/t5_architecture.py +122 -47
  58. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +17 -12
  59. optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
  60. optimum/rbln/transformers/models/whisper/modeling_whisper.py +172 -111
  61. optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
  62. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +18 -16
  63. optimum/rbln/transformers/utils/rbln_quantization.py +48 -60
  64. optimum/rbln/utils/import_utils.py +50 -1
  65. optimum/rbln/utils/logging.py +82 -0
  66. optimum/rbln/utils/runtime_utils.py +33 -0
  67. optimum/rbln/utils/timer_utils.py +43 -0
  68. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/METADATA +9 -7
  69. optimum_rbln-0.1.12.dist-info/RECORD +103 -0
  70. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/WHEEL +1 -1
  71. optimum_rbln-0.1.12.dist-info/entry_points.txt +4 -0
  72. optimum_rbln-0.1.9.dist-info/RECORD +0 -78
  73. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/licenses/LICENSE +0 -0
@@ -20,25 +20,29 @@
  # are the intellectual property of Rebellions Inc. and may not be
  # copied, modified, or distributed without prior written permission
  # from Rebellions Inc.
+ import functools
  import glob
- import logging
+ import os
  from abc import ABC
+ from dataclasses import dataclass
+ from pathlib import Path
  from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

  import rebel # noqa: F401
  import torch # noqa: F401
  from safetensors.torch import load_file
  from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
- from transformers.modeling_outputs import CausalLMOutputWithPast
  from transformers.modeling_utils import no_init_weights
+ from transformers.utils import ModelOutput

  from ....modeling_base import RBLNModel
- from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
+ from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
+ from ....utils.logging import get_logger
  from ....utils.runtime_utils import RBLNPytorchRuntime
- from ...utils.rbln_quantization import replace_quantized_linear_layers
+ from ....utils.timer_utils import rbln_timer


- logger = logging.getLogger(__name__)
+ logger = get_logger()

  if TYPE_CHECKING:
      from transformers import (
@@ -56,7 +60,46 @@ SUPPORTED_QUANTIZATIONS = {


  class RBLNRuntimeModel(RBLNPytorchRuntime):
-     mandatory_members = ["main_input_name"]
+     mandatory_members = ["main_input_name", "embed_tokens"]
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor,
+         inputs_embeds: torch.Tensor,
+         attention_mask: torch.Tensor,
+         cache_position: torch.Tensor,
+         batch_position: torch.Tensor,
+         query_idx: torch.Tensor,
+         **kwargs,
+     ):
+         if inputs_embeds is None:
+             inp = input_ids
+             if self.embed_tokens is not None:
+                 inp = self.embed_tokens(inp)
+
+             return super().forward(
+                 inp,
+                 attention_mask,
+                 cache_position,
+                 batch_position,
+                 query_idx,
+                 **kwargs,
+             )
+         else:
+             return super().forward(
+                 inputs_embeds,
+                 attention_mask,
+                 cache_position,
+                 batch_position,
+                 query_idx,
+                 **kwargs,
+             )
+
+
+ @dataclass
+ class RBLNDecoderOnlyOutput(ModelOutput):
+     logits: torch.FloatTensor = None
+     generate_idx: torch.Tensor = None


  class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
@@ -74,18 +117,57 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
      auto_model_class = AutoModelForCausalLM

      def __post_init__(self, **kwargs):
-         self.batch_size = self.rbln_config.meta["rbln_batch_size"]
-         self.max_seq_len = self.rbln_config.meta["rbln_max_seq_len"]
-         self.prefill_chunk_size = self.rbln_config.meta["rbln_prefill_chunk_size"]
+         self.batch_size = self.rbln_config.model_cfg["batch_size"]
+         self.max_seq_len = self.rbln_config.model_cfg["max_seq_len"]
+         self.prefill_chunk_size = self.rbln_config.model_cfg["prefill_chunk_size"]

-         self.prefill_attention_mask = torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.int64)
+         self.prefill_attention_mask = torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
          self.causal_mask = 1 - torch.triu(
              torch.ones(1, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
          )
-         self.dec_attn_mask_init = torch.zeros(1, 1, 1, self.max_seq_len, dtype=torch.int64)
-         self.dec_attn_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.int64)
-         self.prefill_decoder = RBLNRuntimeModel(runtime=self.model[0], main_input_name="input_ids")
-         self.decoder = RBLNRuntimeModel(runtime=self.model[1], main_input_name="input_ids")
+         self.dec_attn_mask_init = torch.zeros(1, 1, 1, self.max_seq_len, dtype=torch.float32)
+         self.dec_attn_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.float32)
+
+         main_input_name = self.main_input_name
+         if self.rbln_config.model_cfg["use_inputs_embeds"]:
+             main_input_name = "inputs_embeds"
+             artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
+             with no_init_weights():
+                 self.embed_tokens = torch.nn.Embedding(
+                     self.config.vocab_size,
+                     self.config.hidden_size,
+                     self.config.pad_token_id,
+                 )
+             self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
+         else:
+             self.embed_tokens = None
+
+         self.prefill_decoder = RBLNRuntimeModel(
+             runtime=self.model[0], main_input_name=main_input_name, embed_tokens=self.embed_tokens
+         )
+         self.decoder = RBLNRuntimeModel(
+             runtime=self.model[1], main_input_name=main_input_name, embed_tokens=self.embed_tokens
+         )
+
+     @classmethod
+     def save_torch_artifacts(
+         cls,
+         model: "PreTrainedModel",
+         save_dir_path: Path,
+         subfolder: str,
+         rbln_config: RBLNConfig,
+     ):
+         """
+         If you are unavoidably running on a CPU rather than an RBLN device,
+         store the torch tensor, weight, etc. in this function.
+         """
+         if rbln_config.model_cfg["use_inputs_embeds"]:
+             save_dict = {}
+             save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
+             torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
+
+     def get_input_embeddings(self):
+         return self.embed_tokens


      @classmethod
@@ -98,10 +180,10 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
          subfolder: str = "",
          local_files_only: bool = False,
          trust_remote_code: bool = False,
-         rbln_config_kwargs: Optional[Dict[str, Any]] = None,
-         rbln_constructor_kwargs: Optional[Dict[str, Any]] = None,
          **kwargs,
      ):
+         from ...utils.rbln_quantization import update_layers_to_quantized
+
          kwargs = cls.update_kwargs(kwargs)

          config = AutoConfig.from_pretrained(
@@ -116,37 +198,45 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):

          with no_init_weights():
              model = AutoModelForCausalLM.from_config(config)
-             replace_quantized_linear_layers(model)

-         state_dict = {}
-         for safetensor_file in glob.glob(f"{model_id}/*.safetensors"):
-             partial_state_dict = load_file(safetensor_file)
-             state_dict.update(partial_state_dict)
+         update_layers_to_quantized(model)

          n_layer = kwargs.get("num_hidden_layers", None)
-         if n_layer is not None:
-             keys_to_delete = []
-             for key in state_dict.keys():
-                 parts = key.split(".")
-                 if len(parts) > 2 and parts[2].isdigit():
-                     layer_num = int(parts[2])
-                     if layer_num >= n_layer:
-                         keys_to_delete.append(key)
-
-             for key in keys_to_delete:
-                 del state_dict[key]
-
-         model.load_state_dict(state_dict)
+         cls._load_weights_directly_to_model(model, model_id, n_layer)

          return model

+     def _load_weights_directly_to_model(model, model_id, n_layer=None):
+         """
+         Load safetensor file data directly into the model, filtering by layer if n_layer is provided.
+         """
+
+         model_params = dict(model.named_parameters(recurse=True))
+         model_buffers = dict(model.named_buffers(recurse=True))
+         safetensor_files = glob.glob(f"{model_id}/*.safetensors")
+
+         target_layers = list(range(n_layer)) if n_layer is not None else None
+
+         for safetensor_file in safetensor_files:
+             file_data = load_file(safetensor_file)
+             for key, value in file_data.items():
+                 if target_layers is not None:
+                     parts = key.split(".")
+
+                     if len(parts) > 2 and parts[2].isdigit() and (int(parts[2]) not in target_layers):
+                         continue
+
+                 if key in model_params:
+                     model_params[key].data.copy_(value)
+                 elif key in model_buffers:
+                     model_buffers[key].data.copy_(value)
+
+         return 0
+
      @classmethod
-     def get_pytorch_model(
-         cls,
-         *args,
-         **kwargs,
-     ) -> "PreTrainedModel":
-         rbln_config_kwargs = kwargs.get("rbln_config_kwargs", {})
-         rbln_quantization = rbln_config_kwargs.get("rbln_quantization", None)
+     def get_pytorch_model(cls, *args, **kwargs) -> "PreTrainedModel":
+         rbln_kwargs = kwargs.get("rbln_kwargs", {})
+         rbln_quantization = rbln_kwargs.get("quantization", None)

          if rbln_quantization is not None and rbln_quantization["format"] == "rbln":
              model = cls.get_quantized_model(*args, **kwargs)
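Note (not part of the package diff): the new _load_weights_directly_to_model helper copies safetensors tensors straight into the instantiated model and skips layers beyond n_layer by inspecting the third dot-separated component of each key. A standalone sketch of that key filter; keep_key is a hypothetical helper name and the LLaMA-style key layout is an assumption, not taken from the diff:

    # Keys such as "model.layers.12.self_attn.q_proj.weight" carry the layer
    # index at parts[2]; anything >= n_layer is skipped, everything else is kept.
    def keep_key(key: str, n_layer: int) -> bool:
        parts = key.split(".")
        if len(parts) > 2 and parts[2].isdigit():
            return int(parts[2]) < n_layer
        return True  # embeddings, norms, lm_head, ...

    assert keep_key("model.layers.3.mlp.up_proj.weight", n_layer=4)
    assert not keep_key("model.layers.12.mlp.up_proj.weight", n_layer=4)
    assert keep_key("model.embed_tokens.weight", n_layer=4)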
@@ -155,18 +245,68 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):

          return model

+     def validate_quantization_config(quantize_config):
+         if quantize_config is not None:
+             q_format = quantize_config.get("format")
+             q_precision = quantize_config.get("precision")
+
+             if q_format not in SUPPORTED_QUANTIZATIONS:
+                 raise ValueError(
+                     f"Invalid quantization format: {q_format}. "
+                     f"Supported formats are: {list(SUPPORTED_QUANTIZATIONS.keys())}"
+                 )
+
+             if q_precision not in SUPPORTED_QUANTIZATIONS[q_format]:
+                 raise ValueError(
+                     f"Invalid precision: {q_precision} for format: {q_format}. "
+                     f"Supported precisions are: {SUPPORTED_QUANTIZATIONS[q_format]}"
+                 )
+
+         return quantize_config
+
+     @classmethod
+     def set_quantize_env(cls, quantize_config):
+         RBLN_QUANT_BITS_ENV = "RBLN_QUANT_BITS"
+         quantize_config = cls.validate_quantization_config(quantize_config)
+         if quantize_config is not None:
+             q_precision = quantize_config.get("precision")
+             quant_bits = q_precision.split("w")[1].split("a")[0]
+             os.environ[RBLN_QUANT_BITS_ENV] = quant_bits
+             return RBLN_QUANT_BITS_ENV
+         return None
+
+     @classmethod
+     def reset_quantize_env(cls, env_var_name):
+         if env_var_name is not None and env_var_name in os.environ:
+             del os.environ[env_var_name]
+
+     @classmethod
+     def manage_quantize_env(cls, func):
+         @functools.wraps(func)
+         def wrapper(*args, **kwargs):
+             quantize_config = kwargs.get("quantize_config")
+             quantize_env_var = cls.set_quantize_env(quantize_config)
+             try:
+                 return func(*args, **kwargs)
+             finally:
+                 cls.reset_quantize_env(quantize_env_var)
+
+         return wrapper
+
      @classmethod
      @torch.inference_mode()
      def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
          wrapped_model = cls.wrap_model_if_needed(model, rbln_config)

-         prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
-         dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
+         rbln_compile_configs = rbln_config.compile_cfgs
+         prefill_rbln_compile_config = rbln_compile_configs[0]
+         dec_rbln_compile_config = rbln_compile_configs[1]

+         @rbln_timer("JIT trace")
          def get_scripted_model():
              # This function is nested to dealloc the example inputs before compilation.
-             prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
-             dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=4)
+             prefill_example_inputs = prefill_rbln_compile_config.get_dummy_inputs(fill=0)
+             dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=4)

              batch_index = 3
              dec_example_inputs[batch_index].fill_(-1) # fill batch_position -1 to indicate it is decoder.
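Note (not part of the package diff): compilation is now routed through the manage_quantize_env decorator, which exports the weight bit-width from the precision string into RBLN_QUANT_BITS for the duration of the compile call and clears it afterwards. A small sketch of the precision parsing, assuming precision strings of the form "w4a16" (weight bits, then activation bits); export_quant_bits is a hypothetical helper name:

    import os

    def export_quant_bits(precision: str) -> str:
        quant_bits = precision.split("w")[1].split("a")[0]  # "w4a16" -> "4"
        os.environ["RBLN_QUANT_BITS"] = quant_bits
        return quant_bits

    assert export_quant_bits("w4a16") == "4"
    del os.environ["RBLN_QUANT_BITS"]  # mirror of reset_quantize_env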
@@ -181,31 +321,48 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):

          prefill_scripted_model, dec_scripted_model = get_scripted_model()

-         prefill_ir = rebel.torchscript_to_ir(
-             prefill_scripted_model,
-             input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
-         )
-         dec_ir = rebel.torchscript_to_ir(
-             dec_scripted_model,
-             input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
-         )
+         @rbln_timer("Model conversion")
+         def scripted_model_to_ir():
+             prefill_ir = rebel.torchscript_to_ir(
+                 prefill_scripted_model,
+                 input_names=[v[0] for v in prefill_rbln_compile_config.input_info],
+             )
+             dec_ir = rebel.torchscript_to_ir(
+                 dec_scripted_model,
+                 input_names=[v[0] for v in dec_rbln_compile_config.input_info],
+             )
+             return prefill_ir, dec_ir

+         prefill_ir, dec_ir = scripted_model_to_ir()
          # Caching prefill_decoder/decoder I/O
-         cache_index_offset = 4
+         cache_index_offset = 5
          connections = [
              (prefill_ir.outputs[1 + i], prefill_ir.inputs[cache_index_offset + i])
              for i in range(model.config.num_hidden_layers * 2)
          ]

-         compiled_model = rebel.compile(
+         # Extract quantize_config from rbln_config
+         quantize_config = rbln_config.model_cfg.get("quantization", None)
+
+         @cls.manage_quantize_env
+         def compile_model(*args, **kwargs):
+             # Remove quantize_config from kwargs
+             kwargs.pop("quantize_config", None)
+
+             # Call rebel.compile with the updated kwargs
+             return rebel.compile(*args, **kwargs)
+
+         compiled_model = compile_model(
              prefill_ir,
              dec_ir,
              connections=connections,
-             fusion=prefill_rbln_runtime_config.fusion,
-             npu=prefill_rbln_runtime_config.npu,
-             tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
+             fusion=prefill_rbln_compile_config.fusion,
+             npu=prefill_rbln_compile_config.npu,
+             tensor_parallel_size=prefill_rbln_compile_config.tensor_parallel_size,
              use_weight_sharing=True,
+             quantize_config=quantize_config,
          )
+
          return compiled_model

      @classmethod
@@ -213,12 +370,14 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
          cls,
          preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
          model_config: "PretrainedConfig",
-         rbln_max_seq_len: Optional[int] = None,
-         rbln_batch_size: Optional[int] = None,
-         rbln_quantization: Optional[Dict[str, str]] = None,
-         **kwargs,
+         rbln_kwargs: Dict[str, Any] = {},
      ) -> RBLNConfig:
-         meta = {}
+         rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+         rbln_batch_size = rbln_kwargs.get("batch_size", None)
+         rbln_quantization = rbln_kwargs.get("quantization", None)
+         rbln_use_inputs_embeds = rbln_kwargs.get("use_inputs_embeds", None)
+
+         rbln_quantization = cls.validate_quantization_config(rbln_quantization)

          prefill_chunk_size = 128
          if rbln_max_seq_len is None:
@@ -228,40 +387,35 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
              if rbln_max_seq_len is None:
                  raise ValueError("`rbln_max_seq_len` should be specified.")
          rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-
-         meta["rbln_max_seq_len"] = rbln_max_seq_len
-         meta["rbln_batch_size"] = rbln_batch_size
-         meta["rbln_prefill_chunk_size"] = prefill_chunk_size
+         rbln_use_inputs_embeds = False if rbln_use_inputs_embeds is None else rbln_use_inputs_embeds

          num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
          num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
          num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
          head_dim = getattr(model_config, "head_dim", None) or model_config.hidden_size // num_attention_heads
-
-         if rbln_quantization is not None:
-             q_format = rbln_quantization.get("format", None)
-             q_precision = rbln_quantization.get("precision", None)
-
-             if q_format not in SUPPORTED_QUANTIZATIONS.keys() or q_precision not in SUPPORTED_QUANTIZATIONS[q_format]:
-                 raise ValueError(
-                     f'rbln_quantization="{rbln_quantization}" is not a supported quantization format or precesion, '
-                     f"Possible: {SUPPORTED_QUANTIZATIONS}"
-                 )
-             meta["rbln_quantization"] = rbln_quantization
+         hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")

          def get_input_info(
              batch_size,
              query_length,
+             use_inputs_embeds,
+             hidden_size,
          ):
+             if use_inputs_embeds:
+                 main_input = ("inputs_embeds", [batch_size, query_length, hidden_size], "float32")
+             else:
+                 main_input = ("input_ids", [batch_size, query_length], "int64")
+
              input_info = [
-                 ("input_ids", [batch_size, query_length], "int64"),
-                 ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "int64"),
+                 main_input,
+                 ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "float32"),
                  (
                      "cache_position",
                      [batch_size, query_length],
                      "int32",
                  ),
                  ("batch_position", [], "int16"),
+                 ("query_idx", [], "int16"),
              ]

              input_info.extend(
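Note (not part of the package diff): get_input_info now produces both graph signatures from the same helper, with prefill compiled at batch_size=1 and query_length=prefill_chunk_size and decode at batch_size=rbln_batch_size and query_length=1, while the main input switches between input_ids and inputs_embeds. Illustrative shapes only, assuming prefill_chunk_size=128, rbln_batch_size=4, rbln_max_seq_len=8192 and hidden_size=4096:

    # prefill graph (one sequence, one chunk)
    prefill_ids = ("input_ids", [1, 128], "int64")
    prefill_embeds = ("inputs_embeds", [1, 128, 4096], "float32")
    prefill_mask = ("attention_mask", [1, 1, 128, 8192], "float32")

    # decode graph (whole batch, one token per step)
    decode_ids = ("input_ids", [4, 1], "int64")
    decode_mask = ("attention_mask", [4, 1, 1, 8192], "float32")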
@@ -285,22 +439,37 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
          prefill_input_info = get_input_info(
              batch_size=1,
              query_length=prefill_chunk_size,
+             use_inputs_embeds=rbln_use_inputs_embeds,
+             hidden_size=hidden_size,
          )
          dec_input_info = get_input_info(
              batch_size=rbln_batch_size,
              query_length=1,
+             use_inputs_embeds=rbln_use_inputs_embeds,
+             hidden_size=hidden_size,
          )

-         prefill_rbln_runtime_config = RBLNRuntimeConfig(input_info=prefill_input_info)
-         dec_rbln_runtime_config = RBLNRuntimeConfig(input_info=dec_input_info)
+         prefill_rbln_compile_config = RBLNCompileConfig(input_info=prefill_input_info)
+         dec_rbln_compile_config = RBLNCompileConfig(input_info=dec_input_info)

-         dec_rbln_runtime_config.batch_size = rbln_batch_size
+         rbln_config = RBLNConfig(
+             rbln_cls=cls.__name__,
+             compile_cfgs=[prefill_rbln_compile_config, dec_rbln_compile_config],
+             rbln_kwargs=rbln_kwargs,
+         )

-         rbln_config = RBLNConfig.from_rbln_runtime_configs(
-             [prefill_rbln_runtime_config, dec_rbln_runtime_config],
-             _rbln_meta=meta,
+         rbln_config.model_cfg.update(
+             {
+                 "max_seq_len": rbln_max_seq_len,
+                 "batch_size": rbln_batch_size,
+                 "prefill_chunk_size": prefill_chunk_size,
+                 "use_inputs_embeds": rbln_use_inputs_embeds,
+             }
          )

+         if rbln_quantization is not None:
+             rbln_config.model_cfg.update({"quantization": rbln_quantization})
+
          return rbln_config

      @classmethod
@@ -322,71 +491,112 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
      def _reorder_cache(self, past_key_values, beam_idx):
          raise NotImplementedError

-     # args input_ids, past_key_values and attention_mask are updated by _update_model_kwargs_for_generation() in _greedy_search() in GenerationMixin
-     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
-         batch_size = input_ids.shape[0]
-
-         # FIXME past_key_values is just carriier variable for past_cached_length
-         # torch.tensor((4,1),dtype=torch.int32) which refers a past_cached_length of each batch
-         past_cached_length = past_key_values
-         if past_cached_length is None:
-             l_input_ids = []
-             cache_positions = []
-             past_cached_length = torch.zeros((batch_size, 1), dtype=torch.int32)
-             for i in range(batch_size):
-                 input_id = input_ids[i]
-                 input_id = input_id[attention_mask[i] == 1]
-                 valid_len = input_id.shape[-1]
-                 cache_position = torch.arange(0, valid_len, dtype=torch.int32)
-                 past_cached_length[i] = valid_len
-                 l_input_ids.append(input_id.unsqueeze(0))
-                 cache_positions.append(cache_position.unsqueeze(0))
-
-             input_ids = l_input_ids
+     def prepare_inputs_for_generation(
+         self,
+         input_ids: torch.LongTensor,
+         generate_idx: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         **kwargs,
+     ):
+         model_inputs = {}
+         is_prefill_phase = generate_idx is None
+
+         if is_prefill_phase:
+             generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
+             cache_position = None
          else:
-             input_ids = input_ids[:, -1:]
-             cache_positions = past_cached_length
-             past_cached_length = past_cached_length + 1
+             if inputs_embeds is not None:
+                 raise NotImplementedError("Specifying inputs_embeds in decoder phase is not supported.")

-         model_inputs = {
-             "input_ids": input_ids,
-             "cache_position": cache_positions,
-             "past_cached_length": past_cached_length,
-         }
+             input_ids = input_ids[:, -1:]
+             cache_position = generate_idx
+             generate_idx = generate_idx + 1
+             model_inputs.update({"input_ids": input_ids})
+
+         if inputs_embeds is not None:
+             if self.rbln_config.model_cfg["use_inputs_embeds"]:
+                 model_inputs.update({"inputs_embeds": inputs_embeds})
+             else:
+                 raise ValueError(
+                     "The specifying inputs_embedst is only supported when using a compiled RBLN model with 'rbln_use_inputs_embeds' set to True."
+                 )
+         else:
+             model_inputs.update({"input_ids": input_ids})
+
+         model_inputs.update(
+             {
+                 "attention_mask": attention_mask,
+                 "cache_position": cache_position,
+                 "generate_idx": generate_idx,
+             }
+         )

          return model_inputs

+     def _update_model_kwargs_for_generation(
+         self,
+         outputs: RBLNDecoderOnlyOutput,
+         model_kwargs: Dict[str, Any],
+         **kwargs,
+     ) -> Dict[str, Any]:
+         # update generate_idx
+         model_kwargs["generate_idx"] = outputs.generate_idx
+
+         return model_kwargs
+
      def forward(
          self,
-         input_ids: torch.LongTensor = None,
-         cache_position: Union[List[torch.Tensor], torch.Tensor] = None, # vllm keyword argument
+         input_ids: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         cache_position: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.LongTensor] = None,
+         generate_idx: Optional[torch.Tensor] = None,
+         # from llava_next forward args
          batch_idx: Optional[int] = None,
-         past_cached_length: Optional[torch.Tensor] = None, # past_cached_length
          **kwargs,
      ) -> Tuple[torch.FloatTensor]:
-         # prefll & hf generate
-         if isinstance(cache_position, list):
+         # prefll
+         if cache_position is None:
              logits = []
-             for batch_idx, (input_id, cache_pos) in enumerate(zip(input_ids, cache_position)):
-                 logit = self._forward_prefill(input_ids=input_id, cache_position=cache_pos, batch_idx=batch_idx)
+             input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
+             batch_size = input_tensors.shape[0]
+
+             for b_idx in range(batch_size):
+                 # Transform inputs as vllm format
+                 if attention_mask is not None:
+                     input_tensor = input_tensors[b_idx : b_idx + 1, attention_mask[b_idx].bool()]
+                 else:
+                     input_tensor = input_tensors[b_idx : b_idx + 1]
+
+                 cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
+
+                 logit = self._forward_prefill(
+                     input_ids=input_tensor if inputs_embeds is None else None,
+                     inputs_embeds=input_tensor if inputs_embeds is not None else None,
+                     cache_position=cache_position,
+                     batch_idx=b_idx if batch_idx is None else batch_idx, # Llava-next prefill
+                 )
                  logits.append(logit)
              logits = torch.cat(logits, dim=0)
-         # prefill & vllm step
-         elif cache_position.shape[-1] > 1:
-             logits = self._forward_prefill(input_ids=input_ids, cache_position=cache_position, batch_idx=batch_idx)
-         # common decoder
+         # decoder
          else:
-             logits = self._forward_decoder(input_ids=input_ids, cache_position=cache_position)
+             logits = self._forward_decoder(
+                 input_ids=input_ids,
+                 inputs_embeds=inputs_embeds,
+                 cache_position=cache_position,
+             )

-         return CausalLMOutputWithPast(
+         return RBLNDecoderOnlyOutput(
              logits=logits,
-             past_key_values=past_cached_length, # past_cached_length
+             generate_idx=generate_idx,
          )

      def _forward_prefill(
          self,
          input_ids: torch.LongTensor = None,
-         cache_position: torch.Tensor = None, # torch.tensor(,dtype=int32) (1,64) // (4,1)
+         inputs_embeds: torch.Tensor = None,
+         cache_position: torch.Tensor = None,
          batch_idx: int = None,
      ) -> torch.FloatTensor:
          if batch_idx is None or batch_idx >= self.batch_size:
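Note (not part of the package diff): the past_cached_length carrier has been replaced by an explicit generate_idx that rides along in model_kwargs. In the prefill phase it is derived from the attention mask, and each decode step reuses the current value as the cache position before advancing it by one. A rough, standalone sketch of that bookkeeping with made-up mask values:

    import torch

    attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]])    # two padded prompts
    generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()  # tensor([[3], [4]])

    # one decode step per generated token
    cache_position = generate_idx      # where the new token's KV entry is written
    generate_idx = generate_idx + 1    # carried forward via _update_model_kwargs_for_generation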
@@ -398,7 +608,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
              torch.empty(
                  size=[
                      1,
-                     self.prefill_chunk_size,
+                     1,
                      self.config.vocab_size,
                  ],
                  dtype=torch.float32,
@@ -407,11 +617,19 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
              torch.empty(size=[], dtype=torch.int16, device="cpu"),
          ]

-         query_length = input_ids.shape[1]
-         attention_mask = self.prefill_attention_mask.clone()
+         input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
+         query_length = input_tensors.shape[1]
+         _attention_mask = self.prefill_attention_mask.clone()
+
          for step in range(0, query_length, self.prefill_chunk_size):
-             if step + self.prefill_chunk_size > query_length:
-                 input_ids = torch.nn.functional.pad(input_ids, (0, step + self.prefill_chunk_size - query_length))
+             # pad input_tensors & cache_position for prefill_chunk
+             if (step + self.prefill_chunk_size) > query_length:
+                 pad_to_chunk = step + self.prefill_chunk_size - query_length
+                 if inputs_embeds is not None:
+                     input_tensors = torch.nn.functional.pad(input_tensors, (0, 0, 0, pad_to_chunk))
+                 else:
+                     input_tensors = torch.nn.functional.pad(input_tensors, (0, pad_to_chunk))
+
              cache_position = torch.cat(
                  [
                      cache_position,
@@ -424,41 +642,82 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
                      dim=-1,
                  )

-             sliced_input_ids = input_ids[:, step : step + self.prefill_chunk_size]
-             sliced_cache_positions = cache_position[:, step : step + self.prefill_chunk_size]
+             # slice input_tensor & cache_position with prefill_chunk_size
+             _input_tensors = input_tensors[:, step : step + self.prefill_chunk_size]
+             _cache_position = cache_position[:, step : step + self.prefill_chunk_size]

+             # update attention_mask
              if step >= self.prefill_chunk_size:
-                 attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
-                 attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
+                 _attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
+                 _attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
+
+             query_idx = (query_length - 1) % self.prefill_chunk_size

              logits, _ = self.prefill_decoder(
-                 sliced_input_ids.contiguous(),
-                 attention_mask.contiguous(),
-                 sliced_cache_positions.contiguous(),
-                 torch.tensor(batch_idx, dtype=torch.int16),
+                 input_ids=_input_tensors.contiguous() if inputs_embeds is None else None,
+                 inputs_embeds=_input_tensors.contiguous() if inputs_embeds is not None else None,
+                 attention_mask=_attention_mask.contiguous(),
+                 cache_position=_cache_position.contiguous(),
+                 batch_position=torch.tensor(batch_idx, dtype=torch.int16),
+                 query_idx=torch.tensor(query_idx, dtype=torch.int16),
                  out=out_buffers,
             )
-             logits = logits[:, query_length % self.prefill_chunk_size - 1].unsqueeze(1)

+         # update decoder_attn_mask with preprocessed kv-cache length in prefill phase
          self.dec_attn_mask[batch_idx] = self.dec_attn_mask_init.clone()
          self.dec_attn_mask[batch_idx, :, :, :query_length] = 1

          return logits

      def _forward_decoder(
-         self, input_ids: torch.LongTensor = None, cache_position: torch.Tensor = None
+         self,
+         input_ids: torch.LongTensor = None,
+         inputs_embeds: torch.Tensor = None,
+         cache_position: torch.Tensor = None,
      ) -> torch.FloatTensor:
-         batch_size = input_ids.shape[0]
+         input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
+
+         batch_size = input_tensors.shape[0]

          for b_idx in range(batch_size):
              decoding_step = cache_position[b_idx].item()
              self.dec_attn_mask[b_idx, :, :, decoding_step] = 1

          logits, _ = self.decoder(
-             input_ids.contiguous(),
-             self.dec_attn_mask.contiguous(),
-             cache_position.contiguous(),
-             torch.tensor(0, dtype=torch.int16),
+             input_ids=input_tensors.contiguous() if inputs_embeds is None else None,
+             inputs_embeds=input_tensors.contiguous() if inputs_embeds is not None else None,
+             attention_mask=self.dec_attn_mask.contiguous(),
+             cache_position=cache_position.contiguous(),
+             batch_position=torch.tensor(0, dtype=torch.int16),
+             query_idx=torch.tensor(0, dtype=torch.int16),
          )

          return logits
+
+     def vllm_forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         inputs_embeds: torch.Tensor = None,
+         cache_position: torch.Tensor = None,
+         batch_idx: Optional[int] = None,
+         **kwargs,
+     ) -> Tuple[torch.FloatTensor]:
+         # prefll
+         if cache_position.shape[-1] > 1:
+             logits = self._forward_prefill(
+                 input_ids=input_ids,
+                 inputs_embeds=inputs_embeds,
+                 cache_position=cache_position,
+                 batch_idx=batch_idx,
+             )
+         # decoder
+         else:
+             logits = self._forward_decoder(
+                 input_ids=input_ids,
+                 inputs_embeds=inputs_embeds,
+                 cache_position=cache_position,
+             )
+
+         return RBLNDecoderOnlyOutput(
+             logits=logits,
+         )
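Note (not part of the package diff): for reference, a worked example of the chunked-prefill arithmetic used in _forward_prefill above, assuming prefill_chunk_size=128 and an illustrative 300-token prompt:

    prefill_chunk_size = 128
    query_length = 300

    steps = list(range(0, query_length, prefill_chunk_size))      # [0, 128, 256]
    pad_to_chunk = steps[-1] + prefill_chunk_size - query_length  # 84 padding positions in the last chunk
    query_idx = (query_length - 1) % prefill_chunk_size           # 43: offset of the last real token

    assert steps[-1] + prefill_chunk_size == 384                  # padded length is a multiple of the chunk size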