optimum-rbln 0.1.11__py3-none-any.whl → 0.1.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. optimum/rbln/__init__.py +10 -7
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +0 -2
  4. optimum/rbln/diffusers/models/controlnet.py +0 -6
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +0 -3
  6. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +4 -0
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +18 -20
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +18 -20
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +19 -34
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +20 -35
  11. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +12 -13
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +12 -14
  13. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +13 -14
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +13 -14
  15. optimum/rbln/modeling_alias.py +4 -9
  16. optimum/rbln/modeling_base.py +105 -139
  17. optimum/rbln/modeling_config.py +51 -0
  18. optimum/rbln/transformers/__init__.py +8 -0
  19. optimum/rbln/transformers/models/__init__.py +4 -1
  20. optimum/rbln/transformers/models/auto/modeling_auto.py +1 -0
  21. optimum/rbln/transformers/models/bart/__init__.py +1 -1
  22. optimum/rbln/transformers/models/bart/bart_architecture.py +18 -12
  23. optimum/rbln/transformers/models/bart/modeling_bart.py +25 -6
  24. optimum/rbln/transformers/models/bert/modeling_bert.py +1 -2
  25. optimum/rbln/transformers/models/clip/modeling_clip.py +0 -1
  26. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +172 -100
  27. optimum/rbln/transformers/models/dpt/modeling_dpt.py +0 -1
  28. optimum/rbln/transformers/models/exaone/__init__.py +32 -0
  29. optimum/rbln/transformers/models/exaone/exaone_architecture.py +72 -0
  30. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +181 -0
  31. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +1725 -0
  32. optimum/rbln/transformers/models/exaone/modeling_exaone.py +78 -0
  33. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +148 -152
  34. optimum/rbln/transformers/models/midm/modeling_midm.py +5 -0
  35. optimum/rbln/transformers/models/qwen2/__init__.py +24 -0
  36. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -0
  37. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +29 -0
  38. optimum/rbln/transformers/models/seq2seq/__init__.py +24 -0
  39. optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} +107 -166
  40. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  41. optimum/rbln/transformers/models/t5/modeling_t5.py +55 -0
  42. optimum/rbln/transformers/models/t5/t5_architecture.py +46 -32
  43. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +0 -1
  44. optimum/rbln/transformers/models/whisper/modeling_whisper.py +37 -12
  45. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +1 -2
  46. optimum/rbln/utils/import_utils.py +14 -0
  47. optimum/rbln/utils/logging.py +1 -1
  48. optimum/rbln/utils/runtime_utils.py +1 -1
  49. optimum/rbln/utils/timer_utils.py +26 -2
  50. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.12.dist-info}/METADATA +4 -3
  51. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.12.dist-info}/RECORD +54 -44
  52. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.12.dist-info}/WHEEL +1 -1
  53. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.12.dist-info}/entry_points.txt +0 -0
  54. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.12.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/bart/bart_architecture.py

@@ -47,6 +47,12 @@ from transformers.utils import logging
 logger = logging.get_logger(__name__)
 
 
+class BartWrapper:
+    def __init__(self, model):
+        self.encoder = BartEncoderWrapper(model)
+        self.decoder = BartDecoderWrapper(model)
+
+
 class _BartAttention(BartAttention):
     def forward(
         self,
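Note: the new `BartWrapper` simply pairs the encoder- and decoder-side graph wrappers of one BART model under a single object, so each half can be traced and compiled separately. A minimal, self-contained sketch of the same pattern (the `EncoderView`/`DecoderView`/`Seq2SeqWrapper` names are illustrative stand-ins, not part of the package):

```python
# Sketch of the split-into-two-modules pattern behind BartWrapper: each
# half of the seq2seq model becomes its own nn.Module so it can be
# traced/compiled as an independent static graph.
import torch
from transformers import BartForConditionalGeneration

class EncoderView(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.encoder = model.get_encoder()

    def forward(self, input_ids, attention_mask):
        # keep only last_hidden_state so the traced graph has one output
        return self.encoder(input_ids=input_ids, attention_mask=attention_mask)[0]

class DecoderView(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.decoder = model.get_decoder()

    def forward(self, input_ids, encoder_hidden_states, encoder_attention_mask):
        return self.decoder(
            input_ids=input_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )[0]

class Seq2SeqWrapper:  # same shape as BartWrapper above
    def __init__(self, model):
        self.encoder = EncoderView(model)
        self.decoder = DecoderView(model)

wrapper = Seq2SeqWrapper(BartForConditionalGeneration.from_pretrained("facebook/bart-base"))
```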
@@ -238,6 +244,7 @@ class _BartSdpaAttention(BartSdpaAttention):
             value_states, dim=2, start=cache_position, end=cache_position + 1
         )
 
+        # need 4d shape (input tensors) for scaled_dot_product_attention
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
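Note: the new comment refers to the layout `torch.nn.functional.scaled_dot_product_attention` expects, namely 4-D `(batch, num_heads, seq_len, head_dim)` query/key/value tensors. A quick standalone check (all sizes arbitrary):

```python
# scaled_dot_product_attention consumes 4-D (batch, heads, seq, head_dim)
# tensors and returns an output of the same shape.
import torch
import torch.nn.functional as F

q = torch.randn(1, 12, 8, 64)
k = torch.randn(1, 12, 8, 64)
v = torch.randn(1, 12, 8, 64)
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 12, 8, 64])
```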
@@ -324,7 +331,6 @@ class _BartDecoder(BartDecoder):
         attn_impl: str = "eager",
     ):
         # embedding
-        # thkim fix : transformers == 4.44.2 compile
         if hasattr(self, "embed_scale"):
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
        else:
@@ -336,13 +342,15 @@ class _BartDecoder(BartDecoder):
             hidden_states = inputs_embeds + positions
         else:
             hidden_all = []
+            # the compiler's pattern matching depends on a take + add structure
             for i in range(input_ids.shape[0]):
                 # cache position [N,1]
                 positions_idx = cache_position[i]
+                # BART's positional embedding reserves an offset of 2
                 position_weight = self.embed_positions.weight[2:]
                 position = position_weight[positions_idx]
-                tmp_hidden = position + inputs_embeds[i]
-                hidden_all.append(tmp_hidden)
+                batch_hidden = position + inputs_embeds[i]
+                hidden_all.append(batch_hidden)
             hidden_states = torch.stack(hidden_all, dim=0)
 
         hidden_states = self.layernorm_embedding(hidden_states)
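Note: the offset of 2 mirrors Hugging Face's `BartLearnedPositionalEmbedding`, which reserves the first two rows of the weight table and shifts every position index by 2; indexing `weight[2:]` with raw positions therefore reproduces the same lookup. A small illustration (sizes are arbitrary):

```python
# BART's learned positional embedding table has 2 extra rows; HF adds an
# offset of 2 to each position index, which equals slicing weight[2:] and
# indexing with the raw positions, as the loop above does.
import torch

offset, d_model = 2, 16
embed = torch.nn.Embedding(1024 + offset, d_model)
positions = torch.arange(5)

a = embed(positions + offset)         # what HF BART does internally
b = embed.weight[offset:][positions]  # what the compiled loop does
print(torch.equal(a, b))              # True
```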
@@ -444,6 +452,7 @@ class BartDecoderWrapper(torch.nn.Module):
             self_kv_cache.append(past_key_values[i][1])
         self_kv_cache = torch.stack(self_kv_cache, dim=0)
 
+        # return batch_position to keep it as a variable within the graph
         return lm_logits, self_kv_cache, batch_position
 
 
@@ -467,9 +476,6 @@ class BartEncoderWrapper(torch.nn.Module):
         cross_key_value: torch.Tensor = None,
         batch_idx: torch.Tensor = None,
     ) -> Tuple[torch.Tensor]:
-        encoder_batch_size = input_ids.shape[0]
-        decoder_batch_size = encoder_batch_size  # TODO(taehoon) fix to enable beam-search
-
         # 1. run encoder
         encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
         last_hidden_states = encoder_outputs[0]
@@ -477,19 +483,19 @@ class BartEncoderWrapper(torch.nn.Module):
         # 2. run dummy decoder to get pre-calculated cross-key_values for generation
         dummy_past_key_value = []
         for _ in range(self.num_layers):
-            pkv_self_attn_key = torch.zeros(decoder_batch_size, self.num_heads, self.decoder_max_length, self.d_kv)
-            pkv_self_attn_value = torch.zeros(decoder_batch_size, self.num_heads, self.decoder_max_length, self.d_kv)
-            pkv_cross_attn_key = torch.zeros(encoder_batch_size, self.num_heads, self.encoder_max_length, self.d_kv)
-            pkv_cross_attn_value = torch.zeros(encoder_batch_size, self.num_heads, self.encoder_max_length, self.d_kv)
+            pkv_self_attn_key = torch.zeros(1, self.num_heads, self.decoder_max_length, self.d_kv)
+            pkv_self_attn_value = torch.zeros(1, self.num_heads, self.decoder_max_length, self.d_kv)
+            pkv_cross_attn_key = torch.zeros(1, self.num_heads, self.encoder_max_length, self.d_kv)
+            pkv_cross_attn_value = torch.zeros(1, self.num_heads, self.encoder_max_length, self.d_kv)
             layer_pkv = (pkv_self_attn_key, pkv_self_attn_value, pkv_cross_attn_key, pkv_cross_attn_value)
             dummy_past_key_value.append(layer_pkv)
 
-        decoder_attention_mask = torch.zeros(decoder_batch_size, self.decoder_max_length, dtype=torch.float32)
+        decoder_attention_mask = torch.zeros(1, self.decoder_max_length, dtype=torch.float32)
         decoder_attention_mask[:, :1] = 1
 
         decoder_outputs = _BartDecoder.forward(
             self.decoder,
-            input_ids=torch.zeros((decoder_batch_size, 1), dtype=torch.int64),
+            input_ids=torch.zeros((1, 1), dtype=torch.int64),
             attention_mask=decoder_attention_mask,
             encoder_attention_mask=attention_mask,
             cache_position=torch.tensor(0, dtype=torch.int32),
optimum/rbln/transformers/models/bart/modeling_bart.py

@@ -22,23 +22,25 @@
 # from Rebellions Inc.
 
 import inspect
-import logging
-from typing import TYPE_CHECKING, Any, Dict, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
 
-from transformers import AutoModel, BartConfig, BartModel, PretrainedConfig
+from transformers import BartConfig, BartForConditionalGeneration, BartModel, PretrainedConfig
 
 from ....modeling_base import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig
+from ....utils.logging import get_logger
+from ...models.seq2seq import RBLNModelForSeq2SeqLM
+from .bart_architecture import BartWrapper
 
 
-logger = logging.getLogger(__name__)
+logger = get_logger()
+
 
 if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
 
 
 class RBLNBartModel(RBLNModel):
-    auto_model_class = AutoModel  # feature extraction
     original_model_class = BartModel
     original_config_class = BartConfig
 
@@ -104,3 +106,20 @@ class RBLNBartModel(RBLNModel):
 
         rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
         return rbln_config
+
+
+class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
+    @classmethod
+    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
+        return BartWrapper(model)
+
+    def __getattr__(self, __name: str) -> Any:
+        def redirect(func):
+            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+        val = getattr(BartForConditionalGeneration, __name)
+
+        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+            return redirect(val)
+
+        return val
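Note: the `__getattr__` hook lets `RBLNBartForConditionalGeneration` borrow any attribute it does not define from `BartForConditionalGeneration`, re-binding plain methods to the RBLN instance on the fly. A generic, self-contained sketch of the same pattern (`Donor` and `Borrower` are illustrative names):

```python
# Method-borrowing via __getattr__: unknown attributes are looked up on a
# donor class; functions whose signature takes `self` are re-bound to the
# current instance, everything else is returned as-is.
import inspect
from typing import Any, Callable

class Donor:
    def greet(self, name: str) -> str:
        return f"hello {name} from {self.tag}"

class Borrower:
    def __init__(self, tag: str):
        self.tag = tag

    def __getattr__(self, name: str) -> Any:
        val = getattr(Donor, name)  # raises AttributeError if absent
        if isinstance(val, Callable) and "self" in inspect.signature(val).parameters:
            return lambda *args, **kwargs: val(self, *args, **kwargs)
        return val

print(Borrower("rbln").greet("world"))  # hello world from rbln
```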
optimum/rbln/transformers/models/bert/modeling_bert.py

@@ -25,7 +25,7 @@ import inspect
 import logging
 from typing import TYPE_CHECKING, Any, Dict, Optional, Union
 
-from transformers import AutoModel, BertConfig, BertModel, PretrainedConfig
+from transformers import BertConfig, BertModel, PretrainedConfig
 
 from ....modeling_base import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig
@@ -38,7 +38,6 @@ if TYPE_CHECKING:
 
 
 class RBLNBertModel(RBLNModel):
-    auto_model_class = AutoModel  # feature extraction
     original_model_class = BertModel
     original_config_class = BertConfig
 
optimum/rbln/transformers/models/clip/modeling_clip.py

@@ -58,7 +58,6 @@ class _TextEncoder(torch.nn.Module):
 
 
 class RBLNCLIPTextModel(RBLNModel):
-    auto_model_class = AutoModel  # feature extraction
     original_model_class = CLIPTextModel
     original_config_class = CLIPTextConfig
 
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py

@@ -20,8 +20,9 @@
 # are the intellectual property of Rebellions Inc. and may not be
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
+import functools
 import glob
-import logging
+import os
 from abc import ABC
 from dataclasses import dataclass
 from pathlib import Path
@@ -36,11 +37,12 @@ from transformers.utils import ModelOutput
 
 from ....modeling_base import RBLNModel
 from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
+from ....utils.logging import get_logger
 from ....utils.runtime_utils import RBLNPytorchRuntime
 from ....utils.timer_utils import rbln_timer
 
 
-logger = logging.getLogger(__name__)
+logger = get_logger()
 
 if TYPE_CHECKING:
     from transformers import (
@@ -97,7 +99,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
 @dataclass
 class RBLNDecoderOnlyOutput(ModelOutput):
     logits: torch.FloatTensor = None
-    past_cached_length: Union[int, torch.Tensor] = None
+    generate_idx: torch.Tensor = None
 
 
 class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
@@ -243,6 +245,54 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
 
         return model
 
+    def validate_quantization_config(quantize_config):
+        if quantize_config is not None:
+            q_format = quantize_config.get("format")
+            q_precision = quantize_config.get("precision")
+
+            if q_format not in SUPPORTED_QUANTIZATIONS:
+                raise ValueError(
+                    f"Invalid quantization format: {q_format}. "
+                    f"Supported formats are: {list(SUPPORTED_QUANTIZATIONS.keys())}"
+                )
+
+            if q_precision not in SUPPORTED_QUANTIZATIONS[q_format]:
+                raise ValueError(
+                    f"Invalid precision: {q_precision} for format: {q_format}. "
+                    f"Supported precisions are: {SUPPORTED_QUANTIZATIONS[q_format]}"
+                )
+
+        return quantize_config
+
+    @classmethod
+    def set_quantize_env(cls, quantize_config):
+        RBLN_QUANT_BITS_ENV = "RBLN_QUANT_BITS"
+        quantize_config = cls.validate_quantization_config(quantize_config)
+        if quantize_config is not None:
+            q_precision = quantize_config.get("precision")
+            quant_bits = q_precision.split("w")[1].split("a")[0]
+            os.environ[RBLN_QUANT_BITS_ENV] = quant_bits
+            return RBLN_QUANT_BITS_ENV
+        return None
+
+    @classmethod
+    def reset_quantize_env(cls, env_var_name):
+        if env_var_name is not None and env_var_name in os.environ:
+            del os.environ[env_var_name]
+
+    @classmethod
+    def manage_quantize_env(cls, func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            quantize_config = kwargs.get("quantize_config")
+            quantize_env_var = cls.set_quantize_env(quantize_config)
+            try:
+                return func(*args, **kwargs)
+            finally:
+                cls.reset_quantize_env(quantize_env_var)
+
+        return wrapper
+
     @classmethod
     @torch.inference_mode()
     def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
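Note: the three new helpers implement a scoped environment variable: `set_quantize_env` parses the weight bit width out of a precision string such as "w4a16" and exports `RBLN_QUANT_BITS`, and `manage_quantize_env` guarantees cleanup via try/finally even if compilation raises. A standalone sketch of the same pattern (the decorator name `with_env` is illustrative):

```python
# Scoped env var via decorator: set before the call, always removed after,
# mirroring set_quantize_env / reset_quantize_env / manage_quantize_env.
import functools
import os

def with_env(name: str, value: str):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            os.environ[name] = value
            try:
                return func(*args, **kwargs)
            finally:
                os.environ.pop(name, None)
        return wrapper
    return decorator

# "w4a16".split("w")[1].split("a")[0] == "4", as in set_quantize_env
@with_env("RBLN_QUANT_BITS", "w4a16".split("w")[1].split("a")[0])
def compile_step():
    return os.environ["RBLN_QUANT_BITS"]

print(compile_step())                   # 4
print("RBLN_QUANT_BITS" in os.environ)  # False
```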
@@ -252,7 +302,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         prefill_rbln_compile_config = rbln_compile_configs[0]
         dec_rbln_compile_config = rbln_compile_configs[1]
 
-        @rbln_timer("Jit Trace")
+        @rbln_timer("JIT trace")
         def get_scripted_model():
             # This function is nested to dealloc the example inputs before compilation.
             prefill_example_inputs = prefill_rbln_compile_config.get_dummy_inputs(fill=0)
@@ -271,7 +321,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
 
         prefill_scripted_model, dec_scripted_model = get_scripted_model()
 
-        @rbln_timer("TorchScript to IR")
+        @rbln_timer("Model conversion")
         def scripted_model_to_ir():
             prefill_ir = rebel.torchscript_to_ir(
                 prefill_scripted_model,
@@ -291,7 +341,18 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
             for i in range(model.config.num_hidden_layers * 2)
         ]
 
-        compiled_model = rebel.compile(
+        # Extract quantize_config from rbln_config
+        quantize_config = rbln_config.model_cfg.get("quantization", None)
+
+        @cls.manage_quantize_env
+        def compile_model(*args, **kwargs):
+            # Remove quantize_config from kwargs
+            kwargs.pop("quantize_config", None)
+
+            # Call rebel.compile with the updated kwargs
+            return rebel.compile(*args, **kwargs)
+
+        compiled_model = compile_model(
             prefill_ir,
             dec_ir,
             connections=connections,
@@ -299,7 +360,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
             npu=prefill_rbln_compile_config.npu,
             tensor_parallel_size=prefill_rbln_compile_config.tensor_parallel_size,
             use_weight_sharing=True,
+            quantize_config=quantize_config,
         )
+
         return compiled_model
 
     @classmethod
@@ -314,6 +377,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         rbln_quantization = rbln_kwargs.get("quantization", None)
         rbln_use_inputs_embeds = rbln_kwargs.get("use_inputs_embeds", None)
 
+        rbln_quantization = cls.validate_quantization_config(rbln_quantization)
+
         prefill_chunk_size = 128
         if rbln_max_seq_len is None:
             rbln_max_seq_len = getattr(model_config, "max_position_embeddings", None) or getattr(
@@ -330,16 +395,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         head_dim = getattr(model_config, "head_dim", None) or model_config.hidden_size // num_attention_heads
         hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
 
-        if rbln_quantization is not None:
-            q_format = rbln_quantization.get("format", None)
-            q_precision = rbln_quantization.get("precision", None)
-
-            if q_format not in SUPPORTED_QUANTIZATIONS.keys() or q_precision not in SUPPORTED_QUANTIZATIONS[q_format]:
-                raise ValueError(
-                    f'rbln_quantization="{rbln_quantization}" is not a supported quantization format or precesion, '
-                    f"Possible: {SUPPORTED_QUANTIZATIONS}"
-                )
-
         def get_input_info(
             batch_size,
             query_length,
@@ -439,50 +494,41 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
     def prepare_inputs_for_generation(
         self,
         input_ids: torch.LongTensor,
-        past_cached_length: Optional[torch.Tensor] = None,
+        generate_idx: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs,
     ):
         model_inputs = {}
-        # prefill phase
-        if past_cached_length is None:
-            # huggingface make dummy_input_ids if model_input_name is "input_embeds"
-            # https://github.com/huggingface/transformers/blob/174890280b340b89c5bfa092f6b4fb0e2dc2d7fc/src/transformers/generation/utils.py#L469
-            if self.rbln_config.model_cfg["use_inputs_embeds"] and inputs_embeds is not None:
-                input_tensors = inputs_embeds
-            else:
-                input_tensors = input_ids
+        is_prefill_phase = generate_idx is None
 
-            batch_size = input_tensors.shape[0]
-            l_input_tensors = []
-            cache_positions = []
-            past_cached_length = torch.zeros((batch_size, 1), dtype=torch.int32)
-            for i in range(batch_size):
-                input_tensor = input_tensors[i]
-                input_tensor = input_tensor[attention_mask[i] == 1]
-                valid_len = input_tensor.shape[0]
-                cache_position = torch.arange(0, valid_len, dtype=torch.int32)
-                past_cached_length[i] = valid_len
-                l_input_tensors.append(input_tensor.unsqueeze(0))
-                cache_positions.append(cache_position.unsqueeze(0))
-
-            input_tensors = l_input_tensors
-            if self.rbln_config.model_cfg["use_inputs_embeds"] and inputs_embeds is not None:
-                model_inputs.update({"inputs_embeds": input_tensors, "input_ids": input_ids})
-            else:
-                model_inputs.update({"input_ids": input_tensors, "inputs_embeds": inputs_embeds})
-        # decoder phase
+        if is_prefill_phase:
+            generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
+            cache_position = None
         else:
+            if inputs_embeds is not None:
+                raise NotImplementedError("Specifying inputs_embeds in decoder phase is not supported.")
+
             input_ids = input_ids[:, -1:]
-            cache_positions = past_cached_length
-            past_cached_length = past_cached_length + 1
+            cache_position = generate_idx
+            generate_idx = generate_idx + 1
+            model_inputs.update({"input_ids": input_ids})
+
+        if inputs_embeds is not None:
+            if self.rbln_config.model_cfg["use_inputs_embeds"]:
+                model_inputs.update({"inputs_embeds": inputs_embeds})
+            else:
+                raise ValueError(
+                    "Specifying inputs_embeds is only supported when using a compiled RBLN model with 'rbln_use_inputs_embeds' set to True."
+                )
+        else:
             model_inputs.update({"input_ids": input_ids})
 
         model_inputs.update(
             {
-                "cache_position": cache_positions,
-                "past_cached_length": past_cached_length,
+                "attention_mask": attention_mask,
+                "cache_position": cache_position,
+                "generate_idx": generate_idx,
             }
         )
 
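Note: in the prefill phase, `generate_idx` is now derived directly from the padded attention mask, one valid-token count per batch row, replacing the per-row Python loop that computed `past_cached_length`. A quick illustration:

```python
# generate_idx = number of non-padding tokens per row, exactly as
# prepare_inputs_for_generation computes it above.
import torch

attention_mask = torch.tensor([
    [1, 1, 1, 0, 0],  # 3 real tokens
    [1, 1, 1, 1, 1],  # 5 real tokens
])
generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
print(generate_idx)  # tensor([[3], [5]], dtype=torch.int32)
```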
@@ -494,42 +540,46 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         model_kwargs: Dict[str, Any],
         **kwargs,
     ) -> Dict[str, Any]:
-        # update past_cached_length
-        model_kwargs["past_cached_length"] = outputs.past_cached_length
+        # update generate_idx
+        model_kwargs["generate_idx"] = outputs.generate_idx
 
         return model_kwargs
 
     def forward(
         self,
-        input_ids: Optional[Union[List[torch.LongTensor], torch.LongTensor]] = None,
-        inputs_embeds: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
-        cache_position: Union[List[torch.Tensor], torch.Tensor] = None,  # vllm keyword argument
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        generate_idx: Optional[torch.Tensor] = None,
+        # from llava_next forward args
         batch_idx: Optional[int] = None,
-        past_cached_length: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
-        # prefll & hf generate
-        if isinstance(cache_position, list):
+        # prefill
+        if cache_position is None:
             logits = []
-            input_tensors = input_ids if inputs_embeds is None else inputs_embeds
-            for batch_idx, (input_tensor, cache_pos) in enumerate(zip(input_tensors, cache_position)):
+            input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
+            batch_size = input_tensors.shape[0]
+
+            for b_idx in range(batch_size):
+                # transform inputs into the vllm-style packed format
+                if attention_mask is not None:
+                    input_tensor = input_tensors[b_idx : b_idx + 1, attention_mask[b_idx].bool()]
+                else:
+                    input_tensor = input_tensors[b_idx : b_idx + 1]
+
+                cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
+
                 logit = self._forward_prefill(
                     input_ids=input_tensor if inputs_embeds is None else None,
                     inputs_embeds=input_tensor if inputs_embeds is not None else None,
-                    cache_position=cache_pos,
-                    batch_idx=batch_idx,
+                    cache_position=cache_position,
+                    batch_idx=b_idx if batch_idx is None else batch_idx,  # Llava-next prefill
                 )
                 logits.append(logit)
             logits = torch.cat(logits, dim=0)
-        # prefill & vllm step
-        elif cache_position.shape[-1] > 1:
-            logits = self._forward_prefill(
-                input_ids=input_ids,
-                inputs_embeds=inputs_embeds,
-                cache_position=cache_position,
-                batch_idx=batch_idx,
-            )
-        # common decoder
+        # decoder
         else:
             logits = self._forward_decoder(
                 input_ids=input_ids,
@@ -539,7 +589,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
 
         return RBLNDecoderOnlyOutput(
             logits=logits,
-            past_cached_length=past_cached_length,
+            generate_idx=generate_idx,
         )
 
     def _forward_prefill(
@@ -567,23 +617,18 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
             torch.empty(size=[], dtype=torch.int16, device="cpu"),
         ]
 
-        if self.rbln_config.model_cfg["use_inputs_embeds"] and inputs_embeds is not None:
-            model_input_name = "inputs_embeds"
-        else:
-            model_input_name = "input_ids"
-
-        input_tensors = input_ids if model_input_name == "input_ids" else inputs_embeds
-
+        input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
         query_length = input_tensors.shape[1]
-        attention_mask = self.prefill_attention_mask.clone()
+        _attention_mask = self.prefill_attention_mask.clone()
+
         for step in range(0, query_length, self.prefill_chunk_size):
-            if step + self.prefill_chunk_size > query_length:
-                # input_tensors = torch.nn.functional.pad(input_tensors, (0, step + self.prefill_chunk_size - query_length))
-                padding_needed = step + self.prefill_chunk_size - query_length
-                if model_input_name == "input_ids":
-                    input_tensors = torch.nn.functional.pad(input_tensors, (0, padding_needed))
+            # pad input_tensors & cache_position up to prefill_chunk_size
+            if (step + self.prefill_chunk_size) > query_length:
+                pad_to_chunk = step + self.prefill_chunk_size - query_length
+                if inputs_embeds is not None:
+                    input_tensors = torch.nn.functional.pad(input_tensors, (0, 0, 0, pad_to_chunk))
                 else:
-                    input_tensors = torch.nn.functional.pad(input_tensors, (0, 0, 0, padding_needed))
+                    input_tensors = torch.nn.functional.pad(input_tensors, (0, pad_to_chunk))
 
             cache_position = torch.cat(
                 [
@@ -597,25 +642,28 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
                 dim=-1,
             )
 
-            sliced_input_tensors = input_tensors[:, step : step + self.prefill_chunk_size]
-            sliced_cache_positions = cache_position[:, step : step + self.prefill_chunk_size]
+            # slice input_tensors & cache_position by prefill_chunk_size
+            _input_tensors = input_tensors[:, step : step + self.prefill_chunk_size]
+            _cache_position = cache_position[:, step : step + self.prefill_chunk_size]
 
+            # update attention_mask
             if step >= self.prefill_chunk_size:
-                attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
-                attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
+                _attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
+                _attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
 
-            query_idx = query_length % self.prefill_chunk_size - 1
+            query_idx = (query_length - 1) % self.prefill_chunk_size
 
             logits, _ = self.prefill_decoder(
-                input_ids=sliced_input_tensors.contiguous() if model_input_name == "input_ids" else None,
-                inputs_embeds=sliced_input_tensors.contiguous() if model_input_name == "inputs_embeds" else None,
-                attention_mask=attention_mask.contiguous(),
-                cache_position=sliced_cache_positions.contiguous(),
+                input_ids=_input_tensors.contiguous() if inputs_embeds is None else None,
+                inputs_embeds=_input_tensors.contiguous() if inputs_embeds is not None else None,
+                attention_mask=_attention_mask.contiguous(),
+                cache_position=_cache_position.contiguous(),
                 batch_position=torch.tensor(batch_idx, dtype=torch.int16),
                 query_idx=torch.tensor(query_idx, dtype=torch.int16),
                 out=out_buffers,
            )
 
+        # update decoder_attn_mask with the kv-cache length filled during prefill
         self.dec_attn_mask[batch_idx] = self.dec_attn_mask_init.clone()
         self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
 
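Note: the `query_idx` change fixes an off-by-one that hit whenever `query_length` was an exact multiple of `prefill_chunk_size`: the old `query_length % chunk - 1` evaluates to -1 there, while `(query_length - 1) % chunk` points at the last valid token of the final chunk. A worked check:

```python
# Old vs. new query_idx for prefill_chunk_size = 128. The old formula
# returns -1 whenever query_length is a multiple of the chunk size.
chunk = 128
for query_length in (1, 127, 128, 129, 256):
    old = query_length % chunk - 1
    new = (query_length - 1) % chunk
    print(f"{query_length:4d}  old={old:4d}  new={new:4d}")
# 128 -> old=-1, new=127; 256 -> old=-1, new=127
```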
@@ -627,11 +675,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         inputs_embeds: torch.Tensor = None,
         cache_position: torch.Tensor = None,
     ) -> torch.FloatTensor:
-        if self.rbln_config.model_cfg["use_inputs_embeds"] and inputs_embeds is not None:
-            model_input_name = "inputs_embeds"
-        else:
-            model_input_name = "input_ids"
-        input_tensors = input_ids if model_input_name == "input_ids" else inputs_embeds
+        input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
 
         batch_size = input_tensors.shape[0]
 
@@ -640,8 +684,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
             self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
 
         logits, _ = self.decoder(
-            input_ids=input_tensors.contiguous() if model_input_name == "input_ids" else None,
-            inputs_embeds=input_tensors.contiguous() if model_input_name == "inputs_embeds" else None,
+            input_ids=input_tensors.contiguous() if inputs_embeds is None else None,
+            inputs_embeds=input_tensors.contiguous() if inputs_embeds is not None else None,
             attention_mask=self.dec_attn_mask.contiguous(),
             cache_position=cache_position.contiguous(),
             batch_position=torch.tensor(0, dtype=torch.int16),
@@ -649,3 +693,31 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         )
 
         return logits
+
+    def vllm_forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        inputs_embeds: torch.Tensor = None,
+        cache_position: torch.Tensor = None,
+        batch_idx: Optional[int] = None,
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor]:
+        # prefill
+        if cache_position.shape[-1] > 1:
+            logits = self._forward_prefill(
+                input_ids=input_ids,
+                inputs_embeds=inputs_embeds,
+                cache_position=cache_position,
+                batch_idx=batch_idx,
+            )
+        # decoder
+        else:
+            logits = self._forward_decoder(
+                input_ids=input_ids,
+                inputs_embeds=inputs_embeds,
+                cache_position=cache_position,
+            )
+
+        return RBLNDecoderOnlyOutput(
+            logits=logits,
+        )
optimum/rbln/transformers/models/dpt/modeling_dpt.py

@@ -38,7 +38,6 @@ if TYPE_CHECKING:
 
 
 class RBLNDPTForDepthEstimation(RBLNModel):
-    model_type = "rbln_model"
     auto_model_class = AutoModelForDepthEstimation
     main_input_name = "pixel_values"
 
optimum/rbln/transformers/models/exaone/__init__.py (new file)

@@ -0,0 +1,32 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+import os
+from os import environ
+
+
+this_path = os.path.abspath(__file__)
+local_dir = "/" + os.path.join(*this_path.split("/")[:-1]) + "/hf_hub_cached"
+environ["LOCAL_CACHE_ROOT_CUSTOM_CODE_MIDM"] = local_dir
+
+from .modeling_exaone import RBLNExaoneForCausalLM
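Note: the `local_dir` construction above splits on "/" and is therefore POSIX-specific. Purely as an illustrative sketch (not the package's code), a portable equivalent would be:

```python
# Portable equivalent of the local_dir construction: let os.path resolve
# the directory of this file and join the cache folder name.
import os

local_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "hf_hub_cached")
os.environ["LOCAL_CACHE_ROOT_CUSTOM_CODE_MIDM"] = local_dir
```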