optimum-rbln 0.1.12__py3-none-any.whl → 0.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. optimum/rbln/__init__.py +5 -1
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +30 -61
  4. optimum/rbln/diffusers/models/controlnet.py +36 -56
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +57 -153
  6. optimum/rbln/diffusers/pipelines/__init__.py +40 -12
  7. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +7 -0
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -185
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -190
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +8 -191
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -192
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -110
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +12 -115
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +4 -122
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +4 -125
  16. optimum/rbln/modeling_base.py +12 -5
  17. optimum/rbln/modeling_diffusers.py +400 -0
  18. optimum/rbln/transformers/__init__.py +2 -0
  19. optimum/rbln/transformers/cache_utils.py +5 -9
  20. optimum/rbln/transformers/modeling_rope_utils.py +283 -0
  21. optimum/rbln/transformers/models/__init__.py +80 -31
  22. optimum/rbln/transformers/models/clip/modeling_clip.py +13 -22
  23. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -2
  24. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +376 -218
  25. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +74 -16
  26. optimum/rbln/transformers/models/exaone/exaone_architecture.py +18 -9
  27. optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -29
  28. optimum/rbln/transformers/models/gemma/gemma_architecture.py +12 -2
  29. optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
  30. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +4 -30
  31. optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
  32. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +27 -8
  33. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -15
  34. optimum/rbln/transformers/models/midm/modeling_midm.py +4 -29
  35. optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
  36. optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
  37. optimum/rbln/transformers/models/phi/phi_architecture.py +75 -159
  38. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +5 -29
  39. optimum/rbln/transformers/models/t5/__init__.py +1 -1
  40. optimum/rbln/transformers/models/t5/modeling_t5.py +57 -4
  41. optimum/rbln/transformers/models/whisper/modeling_whisper.py +1 -1
  42. optimum/rbln/transformers/utils/rbln_quantization.py +8 -2
  43. optimum/rbln/utils/context.py +58 -0
  44. optimum/rbln/utils/decorator_utils.py +55 -0
  45. optimum/rbln/utils/import_utils.py +7 -0
  46. optimum/rbln/utils/runtime_utils.py +4 -4
  47. optimum/rbln/utils/timer_utils.py +2 -2
  48. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/METADATA +8 -7
  49. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/RECORD +52 -48
  50. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/WHEEL +0 -0
  51. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/entry_points.txt +0 -0
  52. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/licenses/LICENSE +0 -0
@@ -22,14 +22,15 @@
  # from Rebellions Inc.
  import functools
  import glob
+ import inspect
  import os
- from abc import ABC
  from dataclasses import dataclass
  from pathlib import Path
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

- import rebel  # noqa: F401
- import torch  # noqa: F401
+ import rebel
+ import torch
+ import transformers
  from safetensors.torch import load_file
  from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
  from transformers.modeling_utils import no_init_weights
@@ -40,6 +41,7 @@ from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig,
  from ....utils.logging import get_logger
  from ....utils.runtime_utils import RBLNPytorchRuntime
  from ....utils.timer_utils import rbln_timer
+ from .decoderonly_architecture import DecoderOnlyWrapper


  logger = get_logger()
@@ -102,19 +104,47 @@ class RBLNDecoderOnlyOutput(ModelOutput):
      generate_idx: torch.Tensor = None


- class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
+ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
      """
-     The DecoderOnly Model transformer with a language modeling head (linear layer) on top.
-     This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
-
-     A class to convert and run pre-trained transformers based DecoderOnlyForCausalLM model on RBLN devices.
-     It implements the methods to convert a pre-trained transformers DecoderOnlyForCausalLM model into a RBLN transformer model by:
-     - transferring the checkpoint weights of the original into an optimized RBLN graph,
-     - compiling the resulting graph using the RBLN compiler.
+     A base class for decoder-only transformer models optimized for causal language modeling tasks on RBLN devices.
+     This class serves as the foundation for various decoder-only architectures like GPT, LLaMA, etc.
+
+     The class provides core functionality for:
+     1. Converting pre-trained transformer models to RBLN-optimized format
+     2. Handling the compilation process for RBLN devices
+     3. Managing inference operations for causal language modeling
+
+     This class inherits from RBLNModel and implements specific methods required for
+     decoder-only architectures and causal language modeling tasks.
+
+     Note:
+         - This class is designed to be subclassed by specific model implementations
+           (e.g., RBLNLlamaForCausalLM, RBLNGPT2LMHeadModel)
+         - Subclasses should implement model-specific conversion logic.
+         - The class handles RBLN-specific optimizations automatically during compilation
      """

      main_input_name = "input_ids"
      auto_model_class = AutoModelForCausalLM
+     _decoder_wrapper_cls = DecoderOnlyWrapper
+     _original_cls = None
+
+     @classmethod
+     @property
+     def original_cls(cls):
+         """
+         Lazily loads and caches the corresponding Hugging Face model class.
+         Removes 'RBLN' prefix from the class name to get the original class name
+         (e.g., RBLNLlamaForCausalLM -> LlamaForCausalLM) and imports it from
+         the transformers module.
+
+         Returns:
+             type: The original Hugging Face model class
+         """
+         if cls._original_cls is None:
+             hf_original_cls_name = cls.__name__[4:]
+             cls._original_cls = getattr(transformers, hf_original_cls_name)
+         return cls._original_cls

      def __post_init__(self, **kwargs):
          self.batch_size = self.rbln_config.model_cfg["batch_size"]
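The new `original_cls` hook resolves the backing Hugging Face class from the subclass name on first use. Below is a standalone sketch of that lookup (illustrative only, not part of the diff); it uses a plain classmethod instead of the stacked `@classmethod`/`@property` so the snippet stays portable across Python versions:

# Sketch of the lazy Hugging Face class lookup; names mirror the hunk above but
# this mixin and the demo subclass are made up for illustration.
import transformers


class _LazyOriginalClassMixin:
    _original_cls = None

    @classmethod
    def get_original_cls(cls):
        # Strip the "RBLN" prefix and resolve the class from transformers,
        # caching the result on the class for later lookups.
        if cls._original_cls is None:
            cls._original_cls = getattr(transformers, cls.__name__[len("RBLN"):])
        return cls._original_cls


class RBLNLlamaForCausalLM(_LazyOriginalClassMixin):
    """Stand-in with the same name shape as the real optimum-rbln class."""


assert RBLNLlamaForCausalLM.get_original_cls() is transformers.LlamaForCausalLM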
@@ -233,6 +263,26 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):

          return 0

+     def __getattr__(self, __name: str) -> Any:
+         """
+         Special method to delegate attribute access to the original Huggingface LM class.
+         This method is called when an attribute is not found in the current instance's dictionary.
+         It enables transparent access to the original model's attributes and methods while maintaining
+         proper method binding.
+
+         The method implements a delegation pattern that:
+         1. For methods: Creates a wrapper that properly binds 'self' to method calls
+         2. For other attributes: Returns them directly from the original class
+         """
+
+         def redirect(func):
+             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+         val = getattr(self.original_cls, __name)
+         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+             return redirect(val)
+         return val
+
      @classmethod
      def get_pytorch_model(cls, *args, **kwargs) -> "PreTrainedModel":
          rbln_kwargs = kwargs.get("rbln_kwargs", {})
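This shared `__getattr__` replaces the per-model copies that later hunks delete. A self-contained sketch of the same delegation pattern, with made-up classes, shows the rebinding behavior:

# Unknown attributes are looked up on a "source" class; functions that take
# `self` are re-bound to the delegating instance. Source/Delegator are illustrative.
import inspect
from typing import Any, Callable


class Source:
    greeting = "hello"

    def describe(self, name: str) -> str:
        return f"{self.greeting}, {name}"


class Delegator:
    _original_cls = Source

    def __getattr__(self, name: str) -> Any:
        val = getattr(self._original_cls, name)  # raises AttributeError if missing
        if isinstance(val, Callable) and "self" in inspect.signature(val).parameters:
            # Bind the *delegating* instance as `self` for the borrowed function.
            return lambda *args, **kwargs: val(self, *args, **kwargs)
        return val


d = Delegator()
d.greeting = "hi"          # instance attribute, found without __getattr__
print(d.describe("rbln"))  # "hi, rbln" -- Source.describe runs with d as self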
@@ -293,6 +343,16 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):

          return wrapper

+     @classmethod
+     def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
+         wrapper_cfg = {"max_seq_len": rbln_config.model_cfg["max_seq_len"]}
+
+         # If the model wrapper supports rbln-custom-flash-attention
+         if "kvcache_partition_len" in inspect.signature(cls._decoder_wrapper_cls.__init__).parameters:
+             wrapper_cfg["kvcache_partition_len"] = rbln_config.model_cfg.get("kvcache_partition_len")
+
+         return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
+
      @classmethod
      @torch.inference_mode()
      def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
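The base `wrap_model_if_needed` probes the wrapper's `__init__` signature before forwarding `kvcache_partition_len`, so wrappers that predate rbln flash attention keep working. A minimal sketch of that capability check, with stand-in wrapper classes (not optimum-rbln code):

import inspect


class LegacyWrapper:
    def __init__(self, model, max_seq_len):
        self.model, self.max_seq_len = model, max_seq_len


class FlashAttnWrapper:
    def __init__(self, model, max_seq_len, kvcache_partition_len=None):
        self.model, self.max_seq_len = model, max_seq_len
        self.kvcache_partition_len = kvcache_partition_len


def build_wrapper(wrapper_cls, model, model_cfg: dict):
    # Only forward the optional kwarg if the target __init__ declares it.
    cfg = {"max_seq_len": model_cfg["max_seq_len"]}
    if "kvcache_partition_len" in inspect.signature(wrapper_cls.__init__).parameters:
        cfg["kvcache_partition_len"] = model_cfg.get("kvcache_partition_len")
    return wrapper_cls(model, **cfg)


cfg = {"max_seq_len": 4096, "kvcache_partition_len": 1024}
assert not hasattr(build_wrapper(LegacyWrapper, object(), cfg), "kvcache_partition_len")
assert build_wrapper(FlashAttnWrapper, object(), cfg).kvcache_partition_len == 1024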
@@ -305,11 +365,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
          @rbln_timer("JIT trace")
          def get_scripted_model():
              # This function is nested to dealloc the example inputs before compilation.
+             # FIXME: 3rd dummy_input(batch_idx) should be fill zero to compile flash_attn.
              prefill_example_inputs = prefill_rbln_compile_config.get_dummy_inputs(fill=0)
-             dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=4)
-
-             batch_index = 3
-             dec_example_inputs[batch_index].fill_(-1)  # fill batch_position -1 to indicate it is decoder.
+             dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=0)

              prefill_scripted_model = torch.jit.trace(
                  wrapped_model, prefill_example_inputs, check_trace=False, _store_inputs=False
@@ -20,31 +20,40 @@
  # are the intellectual property of Rebellions Inc. and may not be
  # copied, modified, or distributed without prior written permission
  # from Rebellions Inc.
+ import torch

-
+ from ....utils import logging
  from ...models.decoderonly import (
      DecoderOnlyAttention,
      DecoderOnlyDecoderLayer,
      DecoderOnlyModel,
      DecoderOnlyWrapper,
+     RotaryEmbedding,
  )


+ logger = logging.get_logger(__name__)
+
+
  class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
      """A wrapper class for the Exaone model with a language modeling head."""

-     def __init__(self, model, max_seq_len):
+     def __init__(self, model, max_seq_len, kvcache_partition_len=None):
          super(DecoderOnlyWrapper, self).__init__()
          self.config = model.config
          self.model = self.convert_attribute_name(model.transformer)
          self.lm_head = model.lm_head
-         self.head_dim = self.config.hidden_size // self.config.num_attention_heads
-         self.max_position_embeddings = (
-             self.config.max_position_embeddings if max_seq_len > self.config.max_position_embeddings else max_seq_len
-         )
-         self.max_seq_len = max_seq_len
-         self.rope_scaling = getattr(self.config, "rope_scaling", None)
-         self.rotary_emb = self._init_rope()
+         self.rotary_emb = RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
+
+         if kvcache_partition_len is not None:
+             # WORKAROUND : for passing partition length as a value to the rbln compiler.
+             # What is actually used is the shape of this tensor.
+             self.kvcache_partition_size = torch.zeros(kvcache_partition_len, dtype=torch.int32)
+             self.attn_implementation = "flash_attn_rbln"
+             logger.info(f"Using rbln-flash-attention. (partition length : {kvcache_partition_len})")
+         else:
+             self.kvcache_partition_size = None
+             self.attn_implementation = "eager"

      @staticmethod
      def convert_attribute_name(model):
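The `kvcache_partition_size` tensor above exists only so its shape reaches the compiler, per the WORKAROUND comment. A generic illustration of that trick using plain PyTorch tracing (not the RBLN compiler; the module and names are made up):

import torch


class PartitionedModule(torch.nn.Module):
    def __init__(self, partition_len: int):
        super().__init__()
        # The values are never read; only the shape carries information.
        self.register_buffer("kvcache_partition_size", torch.zeros(partition_len, dtype=torch.int32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        partition_len = self.kvcache_partition_size.shape[0]  # becomes a constant at trace time
        return x.reshape(-1, partition_len).sum(dim=-1)


traced = torch.jit.trace(PartitionedModule(4), torch.arange(8.0))
print(traced(torch.arange(8.0)))  # tensor([ 6., 22.])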
@@ -21,21 +21,13 @@
  # copied, modified, or distributed without prior written permission
  # from Rebellions Inc.

- import inspect
- import logging
- from typing import TYPE_CHECKING, Any, Callable
-
- from ....modeling_config import RBLNConfig
+ from ....utils import logging
  from ..decoderonly import RBLNDecoderOnlyModelForCausalLM
  from .exaone_architecture import ExaoneForCausalLMWrapper
  from .hf_hub_cached.modeling_exaone import ExaoneForCausalLM


- logger = logging.getLogger(__name__)
- if TYPE_CHECKING:
-     from transformers import (
-         PreTrainedModel,
-     )
+ logger = logging.get_logger(__name__)


  class RBLNExaoneForCausalLM(RBLNDecoderOnlyModelForCausalLM):
@@ -52,25 +44,8 @@ class RBLNExaoneForCausalLM(RBLNDecoderOnlyModelForCausalLM):

      """

-     @classmethod
-     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-         rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-         return ExaoneForCausalLMWrapper(model, rbln_max_seq_len).eval()
-
-     def __getattr__(self, __name: str) -> Any:
-         """This is the key method to implement RBLN-Exaone.
-
-         Returns:
-             Any: Exaone's corresponding method
-         """
-
-         def redirect(func):
-             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-         val = getattr(ExaoneForCausalLM, __name)
-         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-             return redirect(val)
-         return val
+     _decoder_wrapper_cls = ExaoneForCausalLMWrapper
+     _original_cls = ExaoneForCausalLM

      @classmethod
      def from_pretrained(cls, *args, **kwargs):
@@ -29,11 +29,11 @@ from transformers.modeling_outputs import (
  )

  from ...models.decoderonly import (
-     DecoderOnlyAttention,
      DecoderOnlyDecoderLayer,
      DecoderOnlyWrapper,
      slice_and_unsqueeze_cos_sin,
  )
+ from ...models.decoderonly.decoderonly_architecture import DECODERONLY_ATTENTION_CLASSES


  class GemmaWrapper(DecoderOnlyWrapper):
@@ -43,7 +43,7 @@ class GemmaWrapper(DecoderOnlyWrapper):
              {
                  "wrapper": GemmaModel.forward,
                  "model": DecoderOnlyDecoderLayer.forward,
-                 "decoder_layer": DecoderOnlyAttention.forward,
+                 "decoder_layer": DECODERONLY_ATTENTION_CLASSES[self.attn_implementation].forward,
              }
          )
          return forward_dict
@@ -61,9 +61,17 @@ class GemmaModel:
          use_cache: Optional[bool] = True,
          output_attentions: Optional[bool] = False,
          output_hidden_states: Optional[bool] = False,
+         cache_pos_for_partitions: Optional[torch.Tensor] = None,
+         kvcache_partition_size: Optional[torch.Tensor] = None,
          forward_dict: Optional[Dict[str, classmethod]] = None,
          rotary_pos_emb=None,
      ) -> Union[Tuple, BaseModelOutputWithPast]:
+         # retrieve input_ids and inputs_embeds
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError(
+                 "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+             )
+
          # embed positions
          inputs_embeds = self.embed_tokens(input_ids)
          hidden_states = inputs_embeds
@@ -96,6 +104,8 @@ class GemmaModel:
                  batch_ids=batch_ids,
                  cos=cos,
                  sin=sin,
+                 cache_pos_for_partitions=cache_pos_for_partitions,
+                 kvcache_partition_size=kvcache_partition_size,
                  forward_dict=forward_dict,
              )

@@ -21,28 +21,18 @@
  # copied, modified, or distributed without prior written permission
  # from Rebellions Inc.

- import inspect
- import logging
- from typing import TYPE_CHECKING, Any, Callable
-
- from transformers import GemmaForCausalLM
-
+ from ....utils import logging
  from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
  from .gemma_architecture import GemmaWrapper


- if TYPE_CHECKING:
-     from transformers import PreTrainedModel
-
-     from ....modeling_config import RBLNConfig
-
- logger = logging.getLogger(__name__)
+ logger = logging.get_logger(__name__)


  class RBLNGemmaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
      """
      The Gemma Model transformer with a language modeling head (linear layer) on top.
-     This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+     This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.

      A class to convert and run pre-trained transformers based GemmaForCausalLM model on RBLN devices.
      It implements the methods to convert a pre-trained transformers GemmaForCausalLM model into a RBLN transformer model by:
@@ -50,18 +40,4 @@ class RBLNGemmaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
      - compiling the resulting graph using the RBLN compiler.
      """

-     @classmethod
-     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-         rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-         return GemmaWrapper(model, rbln_max_seq_len).eval()
-
-     def __getattr__(self, __name: str) -> Any:
-         def redirect(func):
-             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-         val = getattr(GemmaForCausalLM, __name)
-
-         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-             return redirect(val)
-
-         return val
+     _decoder_wrapper_cls = GemmaWrapper
@@ -21,20 +21,12 @@
  # copied, modified, or distributed without prior written permission
  # from Rebellions Inc.

- import inspect
- import logging
- from typing import TYPE_CHECKING, Any, Callable
-
- from transformers import GPT2LMHeadModel
-
- from ....modeling_config import RBLNConfig
+ from ....utils import logging
  from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
  from .gpt2_architecture import GPT2LMHeadModelWrapper


- logger = logging.getLogger(__name__)
- if TYPE_CHECKING:
-     from transformers import PreTrainedModel
+ logger = logging.get_logger(__name__)


  class RBLNGPT2LMHeadModel(RBLNDecoderOnlyModelForCausalLM):
@@ -42,7 +34,7 @@ class RBLNGPT2LMHeadModel(RBLNDecoderOnlyModelForCausalLM):
      The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
      embeddings).

-     This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the
+     This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the
      library implements for all its model.

      It implements the methods to convert a pre-trained transformers GPT2 model into a RBLN transformer model by:
@@ -51,22 +43,4 @@ class RBLNGPT2LMHeadModel(RBLNDecoderOnlyModelForCausalLM):

      """

-     @classmethod
-     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-         rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-         return GPT2LMHeadModelWrapper(model, rbln_max_seq_len).eval()
-
-     def __getattr__(self, __name: str) -> Any:
-         """This is the key method to implement RBLN-GPT2.
-
-         Returns:
-             Any: GPT2's corresponding method
-         """
-
-         def redirect(func):
-             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-         val = getattr(GPT2LMHeadModel, __name)
-         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-             return redirect(val)
-         return val
+     _decoder_wrapper_cls = GPT2LMHeadModelWrapper
@@ -21,28 +21,18 @@
  # copied, modified, or distributed without prior written permission
  # from Rebellions Inc.

- import inspect
- import logging
- from typing import TYPE_CHECKING, Any, Callable
-
- from transformers import LlamaForCausalLM
-
+ from ....utils import logging
  from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
  from .llama_architecture import LlamaWrapper


- if TYPE_CHECKING:
-     from transformers import PreTrainedModel
-
-     from ....modeling_config import RBLNConfig
-
- logger = logging.getLogger(__name__)
+ logger = logging.get_logger(__name__)


  class RBLNLlamaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
      """
      The Llama Model transformer with a language modeling head (linear layer) on top.
-     This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+     This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.

      A class to convert and run pre-trained transformers based LlamaForCausalLM model on RBLN devices.
      It implements the methods to convert a pre-trained transformers LlamaForCausalLM model into a RBLN transformer model by:
@@ -50,18 +40,4 @@ class RBLNLlamaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
      - compiling the resulting graph using the RBLN compiler.
      """

-     @classmethod
-     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-         rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-         return LlamaWrapper(model, rbln_max_seq_len).eval()
-
-     def __getattr__(self, __name: str) -> Any:
-         def redirect(func):
-             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-         val = getattr(LlamaForCausalLM, __name)
-
-         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-             return redirect(val)
-
-         return val
+     _decoder_wrapper_cls = LlamaWrapper
@@ -350,9 +350,22 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
          is_prefill_phase = not generate_idx.bool().all()

          if is_prefill_phase:
+             # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing
+             # not very reliable, but we don't expect one to actually pass 500+ images for one prompt
+             # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True
+             legacy_processing = (
+                 (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
+             ) or (input_ids.shape[-1] == 1 and pixel_values is not None)
+
              # Get the number of images in the prompt
              special_image_token_masks = [input_id == self.config.image_token_index for input_id in input_ids]
-             num_special_image_tokens = [torch.sum(mask, dim=-1) for mask in special_image_token_masks]
+             if legacy_processing:
+                 num_special_image_tokens = [torch.sum(mask, dim=-1) for mask in special_image_token_masks]
+             else:
+                 image_tokens_masks_diff = [
+                     torch.diff(mask, prepend=torch.tensor([0])) for mask in special_image_token_masks
+                 ]
+                 num_special_image_tokens = [int(torch.sum((diff == 1).int())) for diff in image_tokens_masks_diff]

              # Split images for each prompt
              if pixel_values is not None and pixel_values.size(0) > 0:
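The non-legacy branch above counts images by rising edges in the image-token mask rather than by summing tokens, since each image may already be expanded into a run of placeholder tokens. A standalone illustration with made-up token ids:

import torch

IMAGE_TOKEN_INDEX = 32000  # illustrative value
input_id = torch.tensor([1, 32000, 32000, 32000, 9, 32000, 32000, 7])

mask = (input_id == IMAGE_TOKEN_INDEX).int()
diff = torch.diff(mask, prepend=torch.tensor([0]))
num_images = int(torch.sum((diff == 1).int()))  # count 0 -> 1 transitions

print(int(mask.sum()))  # 5 image *tokens* ...
print(num_images)       # ... but only 2 images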
@@ -370,13 +383,19 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
                  image_features, feature_lens = self.image_embedding(
                      image_sizes[b_idx], pixel_values[b_idx], vision_feature_layer, vision_feature_select_strategy
                  )
-                 inputs_embed, _, _, _, _ = self._merge_input_ids_with_image_features(
-                     image_features,
-                     feature_lens,
-                     inputs_embed.to(image_features.dtype),
-                     input_id,
-                     torch.ones_like(input_id, dtype=torch.long),
-                 )
+                 if legacy_processing:
+                     inputs_embed, _, _, _, _ = self._merge_input_ids_with_image_features(
+                         image_features,
+                         feature_lens,
+                         inputs_embed.to(image_features.dtype),
+                         input_id,
+                         torch.ones_like(input_id, dtype=torch.long),
+                     )
+                 else:
+                     special_image_mask = (
+                         (input_id == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embed)
+                     )
+                     inputs_embed = inputs_embed.masked_scatter(special_image_mask, image_features)

                  # Update generate_idx according to inputs_embed
                  generate_idx[b_idx] = inputs_embed.shape[1]
@@ -58,23 +58,12 @@ class MidmLMHeadModelWrapper(torch.nn.Module):
          self.model = model.transformer
          self.lm_head = model.lm_head
          self.config = model.config
-         self.head_dim = self.config.n_embd // self.config.n_head
-         self.max_position_embeddings = (
-             self.config.max_position_embeddings if max_seq_len > self.config.max_position_embeddings else max_seq_len
-         )
          self.max_seq_len = max_seq_len
-         self.rotary_dim = int(
-             model.config.hidden_size // model.config.num_attention_heads * model.config.rotary_percentage
-         )
-         self.rotary_emb = self._init_rope()

-     def _init_rope(self):
-         """Initializes the Rotary Position Embeddings."""
-         rotary_emb = RotaryEmbedding(
-             self.rotary_dim,
-             max_position_embeddings=self.max_position_embeddings,
-         )
-         return rotary_emb
+         self.config.partial_rotary_factor = model.config.rotary_percentage
+         self.config.head_dim = self.config.n_embd // self.config.n_head
+         self.config.rope_theta = 10000
+         self.rotary_emb = RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)

      def forward(
          self,
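The Midm wrapper now describes its rotary setup through standard config fields consumed by the shared `RotaryEmbedding` (added in `modeling_rope_utils.py`). Under the usual Hugging Face convention, `partial_rotary_factor` reproduces the rotary dimension the removed code computed by hand; a quick check with illustrative values (not Midm's actual config):

n_embd, n_head, rotary_percentage = 4096, 32, 0.5

head_dim = n_embd // n_head                     # 128, matches config.head_dim above
rotary_dim = int(head_dim * rotary_percentage)  # 64, same as the removed
                                                # hidden_size // num_attention_heads * rotary_percentage
print(head_dim, rotary_dim)                     # 128 64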
@@ -21,11 +21,7 @@
  # copied, modified, or distributed without prior written permission
  # from Rebellions Inc.

- import inspect
- import logging
- from typing import TYPE_CHECKING, Any, Callable
-
- from ....modeling_config import RBLNConfig
+ from ....utils import logging
  from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
  from .hf_hub_cached.modeling_midm import MidmLMHeadModel
  from .midm_architecture import (
@@ -33,11 +29,7 @@ from .midm_architecture import (
  )


- logger = logging.getLogger(__name__)
- if TYPE_CHECKING:
-     from transformers import (
-         PreTrainedModel,
-     )
+ logger = logging.get_logger(__name__)


  class RBLNMidmLMHeadModel(RBLNDecoderOnlyModelForCausalLM):
@@ -54,25 +46,8 @@ class RBLNMidmLMHeadModel(RBLNDecoderOnlyModelForCausalLM):

      """

-     @classmethod
-     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-         rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-         return MidmLMHeadModelWrapper(model, rbln_max_seq_len).eval()
-
-     def __getattr__(self, __name: str) -> Any:
-         """This is the key method to implement RBLN-Midm.
-
-         Returns:
-             Any: Midm's corresponding method
-         """
-
-         def redirect(func):
-             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-         val = getattr(MidmLMHeadModel, __name)
-         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-             return redirect(val)
-         return val
+     _decoder_wrapper_cls = MidmLMHeadModelWrapper
+     _original_cls = MidmLMHeadModel

      @classmethod
      def from_pretrained(cls, *args, **kwargs):
@@ -21,29 +21,18 @@
  # copied, modified, or distributed without prior written permission
  # from Rebellions Inc.

- import inspect
- import logging
- from typing import TYPE_CHECKING, Any, Callable
-
- from transformers import MistralForCausalLM
-
+ from ....utils import logging
  from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
  from .mistral_architecture import MistralForCausalLMWrapper


- if TYPE_CHECKING:
-     from transformers import PreTrainedModel
-
-     from ....modeling_config import RBLNConfig
-
-
- logger = logging.getLogger(__name__)
+ logger = logging.get_logger(__name__)


  class RBLNMistralForCausalLM(RBLNDecoderOnlyModelForCausalLM):
      """
      The Llama Model transformer with a language modeling head (linear layer) on top.
-     This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+     This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.

      A class to convert and run pre-trained transformers based LlamaForCausalLM model on RBLN devices.
      It implements the methods to convert a pre-trained transformers LlamaForCausalLM model into a RBLN transformer model by:
@@ -51,18 +40,4 @@ class RBLNMistralForCausalLM(RBLNDecoderOnlyModelForCausalLM):
      - compiling the resulting graph using the RBLN compiler.
      """

-     @classmethod
-     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-         rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-         return MistralForCausalLMWrapper(model, rbln_max_seq_len).eval()
-
-     def __getattr__(self, __name: str) -> Any:
-         def redirect(func):
-             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-         val = getattr(MistralForCausalLM, __name)
-
-         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-             return redirect(val)
-
-         return val
+     _decoder_wrapper_cls = MistralForCausalLMWrapper