optimum-rbln 0.1.11__py3-none-any.whl → 0.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. optimum/rbln/__init__.py +14 -7
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +30 -63
  4. optimum/rbln/diffusers/models/controlnet.py +36 -62
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +57 -156
  6. optimum/rbln/diffusers/pipelines/__init__.py +40 -12
  7. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -0
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -187
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -192
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +8 -206
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -207
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -111
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +12 -117
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +4 -123
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +4 -126
  16. optimum/rbln/modeling_alias.py +4 -9
  17. optimum/rbln/modeling_base.py +117 -144
  18. optimum/rbln/modeling_config.py +51 -0
  19. optimum/rbln/modeling_diffusers.py +400 -0
  20. optimum/rbln/transformers/__init__.py +10 -0
  21. optimum/rbln/transformers/cache_utils.py +5 -9
  22. optimum/rbln/transformers/modeling_rope_utils.py +283 -0
  23. optimum/rbln/transformers/models/__init__.py +80 -28
  24. optimum/rbln/transformers/models/auto/modeling_auto.py +1 -0
  25. optimum/rbln/transformers/models/bart/__init__.py +1 -1
  26. optimum/rbln/transformers/models/bart/bart_architecture.py +18 -12
  27. optimum/rbln/transformers/models/bart/modeling_bart.py +25 -6
  28. optimum/rbln/transformers/models/bert/modeling_bert.py +1 -2
  29. optimum/rbln/transformers/models/clip/modeling_clip.py +13 -23
  30. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -2
  31. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +376 -218
  32. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +246 -116
  33. optimum/rbln/transformers/models/dpt/modeling_dpt.py +0 -1
  34. optimum/rbln/transformers/models/exaone/__init__.py +32 -0
  35. optimum/rbln/transformers/models/exaone/exaone_architecture.py +81 -0
  36. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +181 -0
  37. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +1725 -0
  38. optimum/rbln/transformers/models/exaone/modeling_exaone.py +53 -0
  39. optimum/rbln/transformers/models/gemma/gemma_architecture.py +12 -2
  40. optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
  41. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +4 -30
  42. optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
  43. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +166 -151
  44. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -15
  45. optimum/rbln/transformers/models/midm/modeling_midm.py +8 -28
  46. optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
  47. optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
  48. optimum/rbln/transformers/models/phi/phi_architecture.py +75 -159
  49. optimum/rbln/transformers/models/qwen2/__init__.py +24 -0
  50. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +43 -0
  51. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +29 -0
  52. optimum/rbln/transformers/models/seq2seq/__init__.py +24 -0
  53. optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} +107 -166
  54. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  55. optimum/rbln/transformers/models/t5/modeling_t5.py +108 -0
  56. optimum/rbln/transformers/models/t5/t5_architecture.py +46 -32
  57. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +0 -1
  58. optimum/rbln/transformers/models/whisper/modeling_whisper.py +38 -13
  59. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +1 -2
  60. optimum/rbln/transformers/utils/rbln_quantization.py +8 -2
  61. optimum/rbln/utils/context.py +58 -0
  62. optimum/rbln/utils/decorator_utils.py +55 -0
  63. optimum/rbln/utils/import_utils.py +21 -0
  64. optimum/rbln/utils/logging.py +1 -1
  65. optimum/rbln/utils/runtime_utils.py +4 -4
  66. optimum/rbln/utils/timer_utils.py +26 -2
  67. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/METADATA +11 -9
  68. optimum_rbln-0.1.13.dist-info/RECORD +107 -0
  69. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/WHEEL +1 -1
  70. optimum_rbln-0.1.11.dist-info/RECORD +0 -93
  71. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/entry_points.txt +0 -0
  72. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py

@@ -20,15 +20,17 @@
  # are the intellectual property of Rebellions Inc. and may not be
  # copied, modified, or distributed without prior written permission
  # from Rebellions Inc.
+ import functools
  import glob
- import logging
- from abc import ABC
+ import inspect
+ import os
  from dataclasses import dataclass
  from pathlib import Path
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

- import rebel # noqa: F401
- import torch # noqa: F401
+ import rebel
+ import torch
+ import transformers
  from safetensors.torch import load_file
  from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
  from transformers.modeling_utils import no_init_weights
@@ -36,11 +38,13 @@ from transformers.utils import ModelOutput

  from ....modeling_base import RBLNModel
  from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
+ from ....utils.logging import get_logger
  from ....utils.runtime_utils import RBLNPytorchRuntime
  from ....utils.timer_utils import rbln_timer
+ from .decoderonly_architecture import DecoderOnlyWrapper


- logger = logging.getLogger(__name__)
+ logger = get_logger()

  if TYPE_CHECKING:
      from transformers import (
@@ -97,22 +101,50 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  @dataclass
  class RBLNDecoderOnlyOutput(ModelOutput):
      logits: torch.FloatTensor = None
-     past_cached_length: Union[int, torch.Tensor] = None
+     generate_idx: torch.Tensor = None


- class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
+ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
      """
-     The DecoderOnly Model transformer with a language modeling head (linear layer) on top.
-     This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
-
-     A class to convert and run pre-trained transformers based DecoderOnlyForCausalLM model on RBLN devices.
-     It implements the methods to convert a pre-trained transformers DecoderOnlyForCausalLM model into a RBLN transformer model by:
-     - transferring the checkpoint weights of the original into an optimized RBLN graph,
-     - compiling the resulting graph using the RBLN compiler.
+     A base class for decoder-only transformer models optimized for causal language modeling tasks on RBLN devices.
+     This class serves as the foundation for various decoder-only architectures like GPT, LLaMA, etc.
+
+     The class provides core functionality for:
+     1. Converting pre-trained transformer models to RBLN-optimized format
+     2. Handling the compilation process for RBLN devices
+     3. Managing inference operations for causal language modeling
+
+     This class inherits from RBLNModel and implements specific methods required for
+     decoder-only architectures and causal language modeling tasks.
+
+     Note:
+         - This class is designed to be subclassed by specific model implementations
+           (e.g., RBLNLlamaForCausalLM, RBLNGPT2LMHeadModel)
+         - Subclasses should implement model-specific conversion logic.
+         - The class handles RBLN-specific optimizations automatically during compilation
      """

      main_input_name = "input_ids"
      auto_model_class = AutoModelForCausalLM
+     _decoder_wrapper_cls = DecoderOnlyWrapper
+     _original_cls = None
+
+     @classmethod
+     @property
+     def original_cls(cls):
+         """
+         Lazily loads and caches the corresponding Hugging Face model class.
+         Removes 'RBLN' prefix from the class name to get the original class name
+         (e.g., RBLNLlamaForCausalLM -> LlamaForCausalLM) and imports it from
+         the transformers module.
+
+         Returns:
+             type: The original Hugging Face model class
+         """
+         if cls._original_cls is None:
+             hf_original_cls_name = cls.__name__[4:]
+             cls._original_cls = getattr(transformers, hf_original_cls_name)
+         return cls._original_cls

      def __post_init__(self, **kwargs):
          self.batch_size = self.rbln_config.model_cfg["batch_size"]
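The `original_cls` property above resolves the matching Hugging Face class purely by name: strip the `RBLN` prefix, look the remainder up on the `transformers` module, and cache the result on the class. A minimal standalone sketch of that lookup (the helper name is made up for illustration):

    import transformers

    def resolve_original_cls(rbln_cls_name: str) -> type:
        # "RBLNLlamaForCausalLM" -> "LlamaForCausalLM"
        hf_name = rbln_cls_name[len("RBLN"):]
        # Raises AttributeError if transformers does not export a class by that name.
        return getattr(transformers, hf_name)

    assert resolve_original_cls("RBLNLlamaForCausalLM") is transformers.LlamaForCausalLM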
@@ -231,6 +263,26 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):

          return 0

+     def __getattr__(self, __name: str) -> Any:
+         """
+         Special method to delegate attribute access to the original Huggingface LM class.
+         This method is called when an attribute is not found in the current instance's dictionary.
+         It enables transparent access to the original model's attributes and methods while maintaining
+         proper method binding.
+
+         The method implements a delegation pattern that:
+         1. For methods: Creates a wrapper that properly binds 'self' to method calls
+         2. For other attributes: Returns them directly from the original class
+         """
+
+         def redirect(func):
+             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+         val = getattr(self.original_cls, __name)
+         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+             return redirect(val)
+         return val
+
      @classmethod
      def get_pytorch_model(cls, *args, **kwargs) -> "PreTrainedModel":
          rbln_kwargs = kwargs.get("rbln_kwargs", {})
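The `__getattr__` hunk above is a delegation pattern: anything not defined on the RBLN class is fetched from the original Hugging Face class, and plain functions that expect a `self` parameter are re-bound to the RBLN instance before being returned. A minimal sketch of the same idea with made-up classes:

    import inspect
    from typing import Callable

    class Original:
        greeting = "hello"

        def shout(self, text: str) -> str:
            return text.upper()

    class Proxy:
        _original_cls = Original

        def __getattr__(self, name):
            # Called only when normal attribute lookup fails on Proxy.
            val = getattr(self._original_cls, name)
            if isinstance(val, Callable) and "self" in inspect.signature(val).parameters:
                return lambda *args, **kwargs: val(self, *args, **kwargs)
            return val

    proxy = Proxy()
    print(proxy.greeting)       # "hello"  (non-callable attribute: returned as-is)
    print(proxy.shout("rbln"))  # "RBLN"   (function: re-bound to the proxy instance)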
@@ -243,6 +295,64 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):

          return model

+     def validate_quantization_config(quantize_config):
+         if quantize_config is not None:
+             q_format = quantize_config.get("format")
+             q_precision = quantize_config.get("precision")
+
+             if q_format not in SUPPORTED_QUANTIZATIONS:
+                 raise ValueError(
+                     f"Invalid quantization format: {q_format}. "
+                     f"Supported formats are: {list(SUPPORTED_QUANTIZATIONS.keys())}"
+                 )
+
+             if q_precision not in SUPPORTED_QUANTIZATIONS[q_format]:
+                 raise ValueError(
+                     f"Invalid precision: {q_precision} for format: {q_format}. "
+                     f"Supported precisions are: {SUPPORTED_QUANTIZATIONS[q_format]}"
+                 )
+
+         return quantize_config
+
+     @classmethod
+     def set_quantize_env(cls, quantize_config):
+         RBLN_QUANT_BITS_ENV = "RBLN_QUANT_BITS"
+         quantize_config = cls.validate_quantization_config(quantize_config)
+         if quantize_config is not None:
+             q_precision = quantize_config.get("precision")
+             quant_bits = q_precision.split("w")[1].split("a")[0]
+             os.environ[RBLN_QUANT_BITS_ENV] = quant_bits
+             return RBLN_QUANT_BITS_ENV
+         return None
+
+     @classmethod
+     def reset_quantize_env(cls, env_var_name):
+         if env_var_name is not None and env_var_name in os.environ:
+             del os.environ[env_var_name]
+
+     @classmethod
+     def manage_quantize_env(cls, func):
+         @functools.wraps(func)
+         def wrapper(*args, **kwargs):
+             quantize_config = kwargs.get("quantize_config")
+             quantize_env_var = cls.set_quantize_env(quantize_config)
+             try:
+                 return func(*args, **kwargs)
+             finally:
+                 cls.reset_quantize_env(quantize_env_var)
+
+         return wrapper
+
+     @classmethod
+     def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
+         wrapper_cfg = {"max_seq_len": rbln_config.model_cfg["max_seq_len"]}
+
+         # If the model wrapper supports rbln-custom-flash-attention
+         if "kvcache_partition_len" in inspect.signature(cls._decoder_wrapper_cls.__init__).parameters:
+             wrapper_cfg["kvcache_partition_len"] = rbln_config.model_cfg.get("kvcache_partition_len")
+
+         return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
+
      @classmethod
      @torch.inference_mode()
      def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
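The quantization helpers above follow a simple set-then-restore pattern: the weight bit-width is parsed out of a precision string, exported through the RBLN_QUANT_BITS environment variable for the duration of compilation, and removed again afterwards. A rough standalone sketch of that flow, assuming a "w4a16"-style precision string like the one the parsing code expects (the helper name is illustrative):

    import os

    RBLN_QUANT_BITS_ENV = "RBLN_QUANT_BITS"

    def quant_bits_from_precision(precision: str) -> str:
        # "w4a16" -> "4": the weight bit-width sits between the "w" and the "a".
        return precision.split("w")[1].split("a")[0]

    os.environ[RBLN_QUANT_BITS_ENV] = quant_bits_from_precision("w4a16")
    try:
        # rebel.compile(...) would run here, picking the variable up if it needs it.
        pass
    finally:
        # Mirrors reset_quantize_env: the variable never outlives the compile call.
        os.environ.pop(RBLN_QUANT_BITS_ENV, None)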
@@ -252,14 +362,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
          prefill_rbln_compile_config = rbln_compile_configs[0]
          dec_rbln_compile_config = rbln_compile_configs[1]

-         @rbln_timer("Jit Trace")
+         @rbln_timer("JIT trace")
          def get_scripted_model():
              # This function is nested to dealloc the example inputs before compilation.
+             # FIXME: 3rd dummy_input(batch_idx) should be fill zero to compile flash_attn.
              prefill_example_inputs = prefill_rbln_compile_config.get_dummy_inputs(fill=0)
-             dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=4)
-
-             batch_index = 3
-             dec_example_inputs[batch_index].fill_(-1) # fill batch_position -1 to indicate it is decoder.
+             dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=0)

              prefill_scripted_model = torch.jit.trace(
                  wrapped_model, prefill_example_inputs, check_trace=False, _store_inputs=False
@@ -271,7 +379,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):

          prefill_scripted_model, dec_scripted_model = get_scripted_model()

-         @rbln_timer("TorchScript to IR")
+         @rbln_timer("Model conversion")
          def scripted_model_to_ir():
              prefill_ir = rebel.torchscript_to_ir(
                  prefill_scripted_model,
@@ -291,7 +399,18 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
              for i in range(model.config.num_hidden_layers * 2)
          ]

-         compiled_model = rebel.compile(
+         # Extract quantize_config from rbln_config
+         quantize_config = rbln_config.model_cfg.get("quantization", None)
+
+         @cls.manage_quantize_env
+         def compile_model(*args, **kwargs):
+             # Remove quantize_config from kwargs
+             kwargs.pop("quantize_config", None)
+
+             # Call rebel.compile with the updated kwargs
+             return rebel.compile(*args, **kwargs)
+
+         compiled_model = compile_model(
              prefill_ir,
              dec_ir,
              connections=connections,
@@ -299,7 +418,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
              npu=prefill_rbln_compile_config.npu,
              tensor_parallel_size=prefill_rbln_compile_config.tensor_parallel_size,
              use_weight_sharing=True,
+             quantize_config=quantize_config,
          )
+
          return compiled_model

      @classmethod
@@ -314,6 +435,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
          rbln_quantization = rbln_kwargs.get("quantization", None)
          rbln_use_inputs_embeds = rbln_kwargs.get("use_inputs_embeds", None)

+         rbln_quantization = cls.validate_quantization_config(rbln_quantization)
+
          prefill_chunk_size = 128
          if rbln_max_seq_len is None:
              rbln_max_seq_len = getattr(model_config, "max_position_embeddings", None) or getattr(
@@ -330,16 +453,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
          head_dim = getattr(model_config, "head_dim", None) or model_config.hidden_size // num_attention_heads
          hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")

-         if rbln_quantization is not None:
-             q_format = rbln_quantization.get("format", None)
-             q_precision = rbln_quantization.get("precision", None)
-
-             if q_format not in SUPPORTED_QUANTIZATIONS.keys() or q_precision not in SUPPORTED_QUANTIZATIONS[q_format]:
-                 raise ValueError(
-                     f'rbln_quantization="{rbln_quantization}" is not a supported quantization format or precesion, '
-                     f"Possible: {SUPPORTED_QUANTIZATIONS}"
-                 )
-
          def get_input_info(
              batch_size,
@@ -439,50 +552,41 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
      def prepare_inputs_for_generation(
          self,
          input_ids: torch.LongTensor,
-         past_cached_length: Optional[torch.Tensor] = None,
+         generate_idx: Optional[torch.Tensor] = None,
          attention_mask: Optional[torch.LongTensor] = None,
          inputs_embeds: Optional[torch.Tensor] = None,
          **kwargs,
      ):
          model_inputs = {}
-         # prefill phase
-         if past_cached_length is None:
-             # huggingface make dummy_input_ids if model_input_name is "input_embeds"
-             # https://github.com/huggingface/transformers/blob/174890280b340b89c5bfa092f6b4fb0e2dc2d7fc/src/transformers/generation/utils.py#L469
-             if self.rbln_config.model_cfg["use_inputs_embeds"] and inputs_embeds is not None:
-                 input_tensors = inputs_embeds
-             else:
-                 input_tensors = input_ids
+         is_prefill_phase = generate_idx is None

-             batch_size = input_tensors.shape[0]
-             l_input_tensors = []
-             cache_positions = []
-             past_cached_length = torch.zeros((batch_size, 1), dtype=torch.int32)
-             for i in range(batch_size):
-                 input_tensor = input_tensors[i]
-                 input_tensor = input_tensor[attention_mask[i] == 1]
-                 valid_len = input_tensor.shape[0]
-                 cache_position = torch.arange(0, valid_len, dtype=torch.int32)
-                 past_cached_length[i] = valid_len
-                 l_input_tensors.append(input_tensor.unsqueeze(0))
-                 cache_positions.append(cache_position.unsqueeze(0))
-
-             input_tensors = l_input_tensors
-             if self.rbln_config.model_cfg["use_inputs_embeds"] and inputs_embeds is not None:
-                 model_inputs.update({"inputs_embeds": input_tensors, "input_ids": input_ids})
-             else:
-                 model_inputs.update({"input_ids": input_tensors, "inputs_embeds": inputs_embeds})
-         # decoder phase
+         if is_prefill_phase:
+             generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
+             cache_position = None
          else:
+             if inputs_embeds is not None:
+                 raise NotImplementedError("Specifying inputs_embeds in decoder phase is not supported.")
+
              input_ids = input_ids[:, -1:]
-             cache_positions = past_cached_length
-             past_cached_length = past_cached_length + 1
+             cache_position = generate_idx
+             generate_idx = generate_idx + 1
+             model_inputs.update({"input_ids": input_ids})
+
+         if inputs_embeds is not None:
+             if self.rbln_config.model_cfg["use_inputs_embeds"]:
+                 model_inputs.update({"inputs_embeds": inputs_embeds})
+             else:
+                 raise ValueError(
+                     "The specifying inputs_embedst is only supported when using a compiled RBLN model with 'rbln_use_inputs_embeds' set to True."
+                 )
+         else:
              model_inputs.update({"input_ids": input_ids})

          model_inputs.update(
              {
-                 "cache_position": cache_positions,
-                 "past_cached_length": past_cached_length,
+                 "attention_mask": attention_mask,
+                 "cache_position": cache_position,
+                 "generate_idx": generate_idx,
              }
          )
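In the reworked `prepare_inputs_for_generation` above, `generate_idx` replaces `past_cached_length` as the per-row position counter: during prefill it is seeded with each prompt's true (unpadded) length taken from the attention mask, and during decode it supplies `cache_position` and then advances by one. A small numeric sketch of that bookkeeping (the tensors are illustrative):

    import torch

    # A padded batch holding prompts of length 3 and 5.
    attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                                   [1, 1, 1, 1, 1]])

    # Prefill: one counter per batch row, seeded with the prompt length.
    generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
    print(generate_idx.squeeze(-1).tolist())  # [3, 5]

    # Decode step: the current index feeds cache_position, then the counter advances.
    cache_position = generate_idx
    generate_idx = generate_idx + 1
    print(generate_idx.squeeze(-1).tolist())  # [4, 6]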

@@ -494,42 +598,46 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
          model_kwargs: Dict[str, Any],
          **kwargs,
      ) -> Dict[str, Any]:
-         # update past_cached_length
-         model_kwargs["past_cached_length"] = outputs.past_cached_length
+         # update generate_idx
+         model_kwargs["generate_idx"] = outputs.generate_idx

          return model_kwargs

      def forward(
          self,
-         input_ids: Optional[Union[List[torch.LongTensor], torch.LongTensor]] = None,
-         inputs_embeds: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
-         cache_position: Union[List[torch.Tensor], torch.Tensor] = None, # vllm keyword argument
+         input_ids: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         cache_position: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.LongTensor] = None,
+         generate_idx: Optional[torch.Tensor] = None,
+         # from llava_next forward args
          batch_idx: Optional[int] = None,
-         past_cached_length: Optional[torch.Tensor] = None,
          **kwargs,
      ) -> Tuple[torch.FloatTensor]:
-         # prefll & hf generate
-         if isinstance(cache_position, list):
+         # prefll
+         if cache_position is None:
              logits = []
-             input_tensors = input_ids if inputs_embeds is None else inputs_embeds
-             for batch_idx, (input_tensor, cache_pos) in enumerate(zip(input_tensors, cache_position)):
+             input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
+             batch_size = input_tensors.shape[0]
+
+             for b_idx in range(batch_size):
+                 # Transform inputs as vllm format
+                 if attention_mask is not None:
+                     input_tensor = input_tensors[b_idx : b_idx + 1, attention_mask[b_idx].bool()]
+                 else:
+                     input_tensor = input_tensors[b_idx : b_idx + 1]
+
+                 cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
+
                  logit = self._forward_prefill(
                      input_ids=input_tensor if inputs_embeds is None else None,
                      inputs_embeds=input_tensor if inputs_embeds is not None else None,
-                     cache_position=cache_pos,
-                     batch_idx=batch_idx,
+                     cache_position=cache_position,
+                     batch_idx=b_idx if batch_idx is None else batch_idx, # Llava-next prefill
                  )
                  logits.append(logit)
              logits = torch.cat(logits, dim=0)
-         # prefill & vllm step
-         elif cache_position.shape[-1] > 1:
-             logits = self._forward_prefill(
-                 input_ids=input_ids,
-                 inputs_embeds=inputs_embeds,
-                 cache_position=cache_position,
-                 batch_idx=batch_idx,
-             )
-         # common decoder
+         # decoder
          else:
              logits = self._forward_decoder(
                  input_ids=input_ids,
@@ -539,7 +647,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
              )

          return RBLNDecoderOnlyOutput(
              logits=logits,
-             past_cached_length=past_cached_length,
+             generate_idx=generate_idx,
          )

@@ -567,23 +675,18 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
              torch.empty(size=[], dtype=torch.int16, device="cpu"),
          ]

-         if self.rbln_config.model_cfg["use_inputs_embeds"] and inputs_embeds is not None:
-             model_input_name = "inputs_embeds"
-         else:
-             model_input_name = "input_ids"
-
-         input_tensors = input_ids if model_input_name == "input_ids" else inputs_embeds
-
+         input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
          query_length = input_tensors.shape[1]
-         attention_mask = self.prefill_attention_mask.clone()
+         _attention_mask = self.prefill_attention_mask.clone()
+
          for step in range(0, query_length, self.prefill_chunk_size):
-             if step + self.prefill_chunk_size > query_length:
-                 # input_tensors = torch.nn.functional.pad(input_tensors, (0, step + self.prefill_chunk_size - query_length))
-                 padding_needed = step + self.prefill_chunk_size - query_length
-                 if model_input_name == "input_ids":
-                     input_tensors = torch.nn.functional.pad(input_tensors, (0, padding_needed))
+             # pad input_tensors & cache_position for prefill_chunk
+             if (step + self.prefill_chunk_size) > query_length:
+                 pad_to_chunk = step + self.prefill_chunk_size - query_length
+                 if inputs_embeds is not None:
+                     input_tensors = torch.nn.functional.pad(input_tensors, (0, 0, 0, pad_to_chunk))
                  else:
-                     input_tensors = torch.nn.functional.pad(input_tensors, (0, 0, 0, padding_needed))
+                     input_tensors = torch.nn.functional.pad(input_tensors, (0, pad_to_chunk))

              cache_position = torch.cat(
                  [
@@ -597,25 +700,28 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
                  dim=-1,
              )

-             sliced_input_tensors = input_tensors[:, step : step + self.prefill_chunk_size]
-             sliced_cache_positions = cache_position[:, step : step + self.prefill_chunk_size]
+             # slice input_tensor & cache_position with prefill_chunk_size
+             _input_tensors = input_tensors[:, step : step + self.prefill_chunk_size]
+             _cache_position = cache_position[:, step : step + self.prefill_chunk_size]

+             # update attention_mask
              if step >= self.prefill_chunk_size:
-                 attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
-                 attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
+                 _attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
+                 _attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask

-             query_idx = query_length % self.prefill_chunk_size - 1
+             query_idx = (query_length - 1) % self.prefill_chunk_size

              logits, _ = self.prefill_decoder(
-                 input_ids=sliced_input_tensors.contiguous() if model_input_name == "input_ids" else None,
-                 inputs_embeds=sliced_input_tensors.contiguous() if model_input_name == "inputs_embeds" else None,
-                 attention_mask=attention_mask.contiguous(),
-                 cache_position=sliced_cache_positions.contiguous(),
+                 input_ids=_input_tensors.contiguous() if inputs_embeds is None else None,
+                 inputs_embeds=_input_tensors.contiguous() if inputs_embeds is not None else None,
+                 attention_mask=_attention_mask.contiguous(),
+                 cache_position=_cache_position.contiguous(),
                  batch_position=torch.tensor(batch_idx, dtype=torch.int16),
                  query_idx=torch.tensor(query_idx, dtype=torch.int16),
                  out=out_buffers,
              )

+         # update decoder_attn_mask with preprocessed kv-cache length in prefill phase
          self.dec_attn_mask[batch_idx] = self.dec_attn_mask_init.clone()
          self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
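A small but consequential fix in the chunked-prefill loop above is the `query_idx` formula. With a 128-token chunk, the old expression `query_length % self.prefill_chunk_size - 1` evaluates to -1 whenever the prompt length is an exact multiple of the chunk size, while the new `(query_length - 1) % self.prefill_chunk_size` always stays inside [0, prefill_chunk_size). A quick arithmetic check (plain Python, no RBLN runtime needed):

    prefill_chunk_size = 128

    for query_length in (1, 127, 128, 129, 256):
        old = query_length % prefill_chunk_size - 1
        new = (query_length - 1) % prefill_chunk_size
        print(f"{query_length:>3}: old={old:>4} new={new:>3}")
    #   1: old=   0 new=  0
    # 127: old= 126 new=126
    # 128: old=  -1 new=127
    # 129: old=   0 new=  0
    # 256: old=  -1 new=127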

@@ -627,11 +733,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
          inputs_embeds: torch.Tensor = None,
          cache_position: torch.Tensor = None,
      ) -> torch.FloatTensor:
-         if self.rbln_config.model_cfg["use_inputs_embeds"] and inputs_embeds is not None:
-             model_input_name = "inputs_embeds"
-         else:
-             model_input_name = "input_ids"
-         input_tensors = input_ids if model_input_name == "input_ids" else inputs_embeds
+         input_tensors = inputs_embeds if inputs_embeds is not None else input_ids

          batch_size = input_tensors.shape[0]

@@ -640,8 +742,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
              self.dec_attn_mask[b_idx, :, :, decoding_step] = 1

          logits, _ = self.decoder(
-             input_ids=input_tensors.contiguous() if model_input_name == "input_ids" else None,
-             inputs_embeds=input_tensors.contiguous() if model_input_name == "inputs_embeds" else None,
+             input_ids=input_tensors.contiguous() if inputs_embeds is None else None,
+             inputs_embeds=input_tensors.contiguous() if inputs_embeds is not None else None,
              attention_mask=self.dec_attn_mask.contiguous(),
              cache_position=cache_position.contiguous(),
              batch_position=torch.tensor(0, dtype=torch.int16),
@@ -649,3 +751,31 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
          )

          return logits
+
+     def vllm_forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         inputs_embeds: torch.Tensor = None,
+         cache_position: torch.Tensor = None,
+         batch_idx: Optional[int] = None,
+         **kwargs,
+     ) -> Tuple[torch.FloatTensor]:
+         # prefll
+         if cache_position.shape[-1] > 1:
+             logits = self._forward_prefill(
+                 input_ids=input_ids,
+                 inputs_embeds=inputs_embeds,
+                 cache_position=cache_position,
+                 batch_idx=batch_idx,
+             )
+         # decoder
+         else:
+             logits = self._forward_decoder(
+                 input_ids=input_ids,
+                 inputs_embeds=inputs_embeds,
+                 cache_position=cache_position,
+             )
+
+         return RBLNDecoderOnlyOutput(
+             logits=logits,
+         )
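With `vllm_forward` split out above, the two entry points use different signals to tell prefill from decode: `forward` (the Hugging Face `generate` path) treats a missing `cache_position` as prefill and loops over the batch itself, while `vllm_forward` receives explicit positions from the serving engine and checks how many there are. A small sketch of just that dispatch condition, as read from the code above:

    import torch

    def vllm_is_prefill(cache_position: torch.Tensor) -> bool:
        # More than one position in the last dimension -> a prompt chunk (prefill);
        # exactly one position -> a single decode step.
        return cache_position.shape[-1] > 1

    assert vllm_is_prefill(torch.arange(7, dtype=torch.int32).unsqueeze(0))   # 7 prompt positions
    assert not vllm_is_prefill(torch.tensor([[7]], dtype=torch.int32))        # one decode position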
optimum/rbln/transformers/models/dpt/modeling_dpt.py

@@ -38,7 +38,6 @@ if TYPE_CHECKING:


  class RBLNDPTForDepthEstimation(RBLNModel):
-     model_type = "rbln_model"
      auto_model_class = AutoModelForDepthEstimation
      main_input_name = "pixel_values"

optimum/rbln/transformers/models/exaone/__init__.py (new file)

@@ -0,0 +1,32 @@
+ # Copyright 2024 Rebellions Inc.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Portions of this software are licensed under the Apache License,
+ # Version 2.0. See the NOTICE file distributed with this work for
+ # additional information regarding copyright ownership.
+
+ # All other portions of this software, including proprietary code,
+ # are the intellectual property of Rebellions Inc. and may not be
+ # copied, modified, or distributed without prior written permission
+ # from Rebellions Inc.
+
+ import os
+ from os import environ
+
+
+ this_path = os.path.abspath(__file__)
+ local_dir = "/" + os.path.join(*this_path.split("/")[:-1]) + "/hf_hub_cached"
+ environ["LOCAL_CACHE_ROOT_CUSTOM_CODE_MIDM"] = local_dir
+
+ from .modeling_exaone import RBLNExaoneForCausalLM
optimum/rbln/transformers/models/exaone/exaone_architecture.py (new file)

@@ -0,0 +1,81 @@
+ # Copyright 2024 Rebellions Inc.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Portions of this software are licensed under the Apache License,
+ # Version 2.0. See the NOTICE file distributed with this work for
+ # additional information regarding copyright ownership.
+
+ # All other portions of this software, including proprietary code,
+ # are the intellectual property of Rebellions Inc. and may not be
+ # copied, modified, or distributed without prior written permission
+ # from Rebellions Inc.
+ import torch
+
+ from ....utils import logging
+ from ...models.decoderonly import (
+     DecoderOnlyAttention,
+     DecoderOnlyDecoderLayer,
+     DecoderOnlyModel,
+     DecoderOnlyWrapper,
+     RotaryEmbedding,
+ )
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
+     """A wrapper class for the Exaone model with a language modeling head."""
+
+     def __init__(self, model, max_seq_len, kvcache_partition_len=None):
+         super(DecoderOnlyWrapper, self).__init__()
+         self.config = model.config
+         self.model = self.convert_attribute_name(model.transformer)
+         self.lm_head = model.lm_head
+         self.rotary_emb = RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
+
+         if kvcache_partition_len is not None:
+             # WORKAROUND : for passing partition length as a value to the rbln compiler.
+             # What is actually used is the shape of this tensor.
+             self.kvcache_partition_size = torch.zeros(kvcache_partition_len, dtype=torch.int32)
+             self.attn_implementation = "flash_attn_rbln"
+             logger.info(f"Using rbln-flash-attention. (partition length : {kvcache_partition_len})")
+         else:
+             self.kvcache_partition_size = None
+             self.attn_implementation = "eager"
+
+     @staticmethod
+     def convert_attribute_name(model):
+         model.embed_tokens = model.wte
+         model.norm = model.ln_f
+         model.layers = model.h
+
+         for layer in model.layers:
+             layer.input_layernorm = layer.ln_1
+             layer.self_attn = layer.attn.attention
+             layer.post_attention_layernorm = layer.ln_2
+             layer.self_attn.o_proj = layer.self_attn.out_proj
+
+         return model
+
+     def get_forward_dict(self):
+         forward_dict = {}
+         forward_dict.update(
+             {
+                 "wrapper": DecoderOnlyModel.forward,
+                 "model": DecoderOnlyDecoderLayer.forward,
+                 "decoder_layer": DecoderOnlyAttention.forward,
+             }
+         )
+         return forward_dict
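`convert_attribute_name` above adapts Exaone's GPT-style module names (`wte`, `ln_f`, `h`, `ln_1`, `ln_2`, `out_proj`) to the LLaMA-style names the shared decoder-only wrapper expects (`embed_tokens`, `norm`, `layers`, `input_layernorm`, `post_attention_layernorm`, `o_proj`) by aliasing attributes on the same modules rather than copying them. A toy illustration of the aliasing idea using a plain namespace instead of real Exaone modules:

    from types import SimpleNamespace

    gpt_style = SimpleNamespace(wte="token embedding", ln_f="final norm", h=["layer0", "layer1"])

    # Alias the GPT-style names to the LLaMA-style names; both point at the same objects.
    gpt_style.embed_tokens = gpt_style.wte
    gpt_style.norm = gpt_style.ln_f
    gpt_style.layers = gpt_style.h

    assert gpt_style.layers is gpt_style.h
    assert gpt_style.embed_tokens == "token embedding"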