optimum-rbln 0.1.12__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. optimum/rbln/__init__.py +27 -13
  2. optimum/rbln/__version__.py +16 -1
  3. optimum/rbln/diffusers/__init__.py +22 -2
  4. optimum/rbln/diffusers/models/__init__.py +34 -3
  5. optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
  6. optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +66 -111
  7. optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
  8. optimum/rbln/diffusers/models/controlnet.py +85 -65
  9. optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
  10. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
  11. optimum/rbln/diffusers/models/unets/__init__.py +24 -0
  12. optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +129 -163
  13. optimum/rbln/diffusers/pipelines/__init__.py +60 -12
  14. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -25
  15. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -185
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -190
  17. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -191
  18. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -192
  19. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
  20. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -110
  21. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -118
  22. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
  23. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
  27. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
  28. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +18 -128
  29. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -131
  30. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
  31. optimum/rbln/modeling.py +572 -0
  32. optimum/rbln/modeling_alias.py +1 -1
  33. optimum/rbln/modeling_base.py +176 -763
  34. optimum/rbln/modeling_diffusers.py +329 -0
  35. optimum/rbln/transformers/__init__.py +2 -2
  36. optimum/rbln/transformers/cache_utils.py +5 -9
  37. optimum/rbln/transformers/modeling_rope_utils.py +283 -0
  38. optimum/rbln/transformers/models/__init__.py +80 -31
  39. optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
  40. optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
  41. optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
  43. optimum/rbln/transformers/models/clip/modeling_clip.py +8 -34
  44. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -5
  45. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +779 -361
  46. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +83 -142
  47. optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
  48. optimum/rbln/transformers/models/exaone/exaone_architecture.py +64 -39
  49. optimum/rbln/transformers/models/exaone/modeling_exaone.py +6 -29
  50. optimum/rbln/transformers/models/gemma/gemma_architecture.py +31 -92
  51. optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
  52. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
  53. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +6 -31
  54. optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
  55. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +29 -83
  56. optimum/rbln/transformers/models/midm/midm_architecture.py +88 -253
  57. optimum/rbln/transformers/models/midm/modeling_midm.py +8 -33
  58. optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
  59. optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
  60. optimum/rbln/transformers/models/phi/phi_architecture.py +61 -345
  61. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +5 -29
  62. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
  63. optimum/rbln/transformers/models/t5/__init__.py +1 -1
  64. optimum/rbln/transformers/models/t5/modeling_t5.py +157 -6
  65. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  66. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  67. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
  68. optimum/rbln/transformers/utils/rbln_quantization.py +128 -5
  69. optimum/rbln/utils/decorator_utils.py +59 -0
  70. optimum/rbln/utils/hub.py +131 -0
  71. optimum/rbln/utils/import_utils.py +21 -0
  72. optimum/rbln/utils/model_utils.py +53 -0
  73. optimum/rbln/utils/runtime_utils.py +5 -5
  74. optimum/rbln/utils/submodule.py +114 -0
  75. optimum/rbln/utils/timer_utils.py +2 -2
  76. optimum_rbln-0.1.15.dist-info/METADATA +106 -0
  77. optimum_rbln-0.1.15.dist-info/RECORD +110 -0
  78. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
  79. optimum/rbln/transformers/generation/streamers.py +0 -139
  80. optimum/rbln/transformers/generation/utils.py +0 -397
  81. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
  82. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
  83. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
  84. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
  85. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
  86. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
  87. optimum_rbln-0.1.12.dist-info/METADATA +0 -119
  88. optimum_rbln-0.1.12.dist-info/RECORD +0 -103
  89. optimum_rbln-0.1.12.dist-info/entry_points.txt +0 -4
  90. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
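The file list shows the diffusers side being reorganized into `autoencoders/`, `transformers/`, and `unets/` subpackages, plus new Stable Diffusion 3 pipelines and a new `modeling_diffusers.py` runtime. A minimal usage sketch for the new SD3 pipeline follows; the class name `RBLNStableDiffusion3Pipeline` and the `export`/`rbln_*` keyword arguments are inferred from the file names and the library's usual conventions, not confirmed by this diff.

```python
# Hedged sketch only: class name and kwargs are assumptions, not taken from this diff.
from optimum.rbln import RBLNStableDiffusion3Pipeline

# Compile the Hugging Face checkpoint for RBLN NPUs once, then reuse the artifacts.
pipe = RBLNStableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    export=True,          # trace and compile instead of loading precompiled files
    rbln_img_width=1024,  # assumed compile-time static shape hints
    rbln_img_height=1024,
)
pipe.save_pretrained("sd3-rbln")

image = pipe("a watercolor fox in the snow").images[0]
image.save("fox.png")
```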
optimum/rbln/transformers/models/llava_next/modeling_llava_next.py

@@ -23,7 +23,7 @@
 import inspect
 import logging
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -36,7 +36,7 @@ from transformers import (
 from transformers.modeling_outputs import BaseModelOutputWithPooling
 from transformers.models.llava_next.modeling_llava_next import LlavaNextCausalLMOutputWithPast
 
-from ....modeling_base import RBLNModel
+from ....modeling import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyOutput
 
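This import move mirrors the split of `modeling_base.py` into the new `optimum/rbln/modeling.py` (see the file list above). Downstream code that imported `RBLNModel` from the old module path can bridge both layouts with a small shim; this is a sketch under the assumption that the class name itself is unchanged between 0.1.12 and 0.1.15.

```python
# Compatibility shim (sketch): prefer the 0.1.15 module path, fall back to 0.1.12.
try:
    from optimum.rbln.modeling import RBLNModel        # layout in 0.1.15
except ImportError:
    from optimum.rbln.modeling_base import RBLNModel   # layout in 0.1.12
```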
@@ -166,19 +166,6 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
         self._padding_side = "left"  # set it to left by default, user can use setter to change padding_sides
         return super().__post_init__(**kwargs)
 
-    @classmethod
-    def get_pytorch_model(
-        cls,
-        model_id: str,
-        *args,
-        rbln_kwargs: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ) -> "PreTrainedModel":
-        # Optimum's TasksManager does not handle Llava.
-        kwargs = cls.update_kwargs(kwargs)
-        model = LlavaNextForConditionalGeneration.from_pretrained(model_id, *args, **kwargs)
-        return model
-
     def get_input_embeddings(self):
         return self.language_model.get_input_embeddings()
 
@@ -350,9 +337,22 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
         is_prefill_phase = not generate_idx.bool().all()
 
         if is_prefill_phase:
+            # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing
+            # not very reliable, but we don't expect one to actually pass 500+ images for one prompt
+            # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True
+            legacy_processing = (
+                (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
+            ) or (input_ids.shape[-1] == 1 and pixel_values is not None)
+
             # Get the number of images in the prompt
             special_image_token_masks = [input_id == self.config.image_token_index for input_id in input_ids]
-            num_special_image_tokens = [torch.sum(mask, dim=-1) for mask in special_image_token_masks]
+            if legacy_processing:
+                num_special_image_tokens = [torch.sum(mask, dim=-1) for mask in special_image_token_masks]
+            else:
+                image_tokens_masks_diff = [
+                    torch.diff(mask, prepend=torch.tensor([0])) for mask in special_image_token_masks
+                ]
+                num_special_image_tokens = [int(torch.sum((diff == 1).int())) for diff in image_tokens_masks_diff]
 
             # Split images for each prompt
             if pixel_values is not None and pixel_values.size(0) > 0:
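The new `legacy_processing` flag separates prompts whose image placeholders were already expanded by the processor (one token per image patch, `image_seq_length` of them) from legacy prompts that carry a single placeholder per image. A standalone sketch of that heuristic is below; the function name is illustrative and not part of the package.

```python
import torch


def looks_like_legacy_prompt(
    input_ids: torch.Tensor,  # (batch, seq_len)
    pixel_values,             # tensor or None
    image_token_index: int,
    image_seq_length: int,
) -> bool:
    # Prefill: fewer placeholder tokens than one full image embedding sequence
    # means the processor did not expand them (legacy behavior).
    max_image_tokens = (input_ids == image_token_index).sum(1).max()
    prefill_is_legacy = max_image_tokens < image_seq_length
    # Decode (single-token step): legacy behavior is signalled by pixel_values
    # still being passed alongside the one new token.
    decode_is_legacy = input_ids.shape[-1] == 1 and pixel_values is not None
    return bool(prefill_is_legacy or decode_is_legacy)


# Three raw placeholders against image_seq_length=576 -> treated as a legacy prompt.
ids = torch.tensor([[1, 32000, 32000, 32000, 2]])
print(looks_like_legacy_prompt(ids, None, image_token_index=32000, image_seq_length=576))  # True
```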
@@ -370,13 +370,19 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
                     image_features, feature_lens = self.image_embedding(
                         image_sizes[b_idx], pixel_values[b_idx], vision_feature_layer, vision_feature_select_strategy
                     )
-                    inputs_embed, _, _, _, _ = self._merge_input_ids_with_image_features(
-                        image_features,
-                        feature_lens,
-                        inputs_embed.to(image_features.dtype),
-                        input_id,
-                        torch.ones_like(input_id, dtype=torch.long),
-                    )
+                    if legacy_processing:
+                        inputs_embed, _, _, _, _ = self._merge_input_ids_with_image_features(
+                            image_features,
+                            feature_lens,
+                            inputs_embed.to(image_features.dtype),
+                            input_id,
+                            torch.ones_like(input_id, dtype=torch.long),
+                        )
+                    else:
+                        special_image_mask = (
+                            (input_id == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embed)
+                        )
+                        inputs_embed = inputs_embed.masked_scatter(special_image_mask, image_features)
 
                 # Update generate_idx according to inputs_embed
                 generate_idx[b_idx] = inputs_embed.shape[1]
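In the non-legacy branch the placeholders already occupy one position per image feature, so the merge reduces to `masked_scatter` over the broadcast placeholder mask. A self-contained toy example of that pattern (shapes invented for illustration):

```python
import torch

image_token_index = 99
input_id = torch.tensor([[1, 99, 99, 99, 5, 2]])      # three expanded image placeholders
inputs_embed = torch.zeros(1, 6, 4)                   # stand-in text embeddings
image_features = torch.arange(12, dtype=torch.float32).reshape(3, 4)  # one row per placeholder

# Broadcast the placeholder mask over the hidden dimension, then scatter the
# image features into those positions, exactly as in the hunk above.
special_image_mask = (input_id == image_token_index).unsqueeze(-1).expand_as(inputs_embed)
merged = inputs_embed.masked_scatter(special_image_mask, image_features)

print(merged[0, 1])  # tensor([0., 1., 2., 3.]) -> first feature row landed on the first placeholder
```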
@@ -403,66 +409,6 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
 
         return outputs
 
-    def vllm_forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        pixel_values: torch.FloatTensor = None,
-        image_sizes: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        vision_feature_layer: Optional[int] = None,
-        vision_feature_select_strategy: Optional[str] = None,
-        cache_position: Union[List[torch.Tensor], torch.Tensor] = None,  # vllm keyword argument
-        batch_idx: Optional[int] = None,
-        **kwargs,
-    ) -> Union[Tuple, LlavaNextCausalLMOutputWithPast]:
-        is_prefill = cache_position.shape[-1] > 1
-
-        if inputs_embeds is not None:
-            raise NotImplementedError("Specifying inputs_embeds is not supported.")
-
-        if is_prefill:
-            # Get text_embeds
-            inputs_embeds = self.text_embedding(input_ids)
-
-            # If any images in the prompt, get image_embeds and merge with text
-            if pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) > 0:
-                image_features, _ = self.image_embedding(
-                    image_sizes, pixel_values, vision_feature_layer, vision_feature_select_strategy
-                )
-
-                def merge_vllm_multimodal_embeddings(
-                    input_ids: torch.Tensor,
-                    inputs_embeds: torch.Tensor,
-                    multimodal_embeddings: torch.Tensor,
-                    placeholder_token_id: int,
-                ) -> torch.Tensor:
-                    mask = input_ids == placeholder_token_id
-                    num_expected_tokens = mask.sum().item()
-
-                    if multimodal_embeddings.shape[0] != num_expected_tokens:
-                        raise ValueError(
-                            f"Attempted to assign {inputs_embeds[mask].shape} = {multimodal_embeddings.shape} "
-                            f"multimodal tokens to {num_expected_tokens} placeholders"
-                        )
-
-                    inputs_embeds[mask] = multimodal_embeddings
-                    return inputs_embeds
-
-                inputs_embeds = merge_vllm_multimodal_embeddings(
-                    input_ids, inputs_embeds, image_features, self.config.image_token_index
-                )
-
-        else:
-            inputs_embeds = self.text_embedding(input_ids=input_ids)
-
-        outputs: RBLNDecoderOnlyOutput = self.language_model.vllm_forward(
-            inputs_embeds=inputs_embeds,
-            batch_idx=batch_idx,
-            cache_position=cache_position,
-        )
-
-        return outputs
-
     # Almost copied from : https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/llava_next/modeling_llava_next.py
     def pack_image_features(self, image_features, image_sizes, image_newline=None):
         """
optimum/rbln/transformers/models/midm/midm_architecture.py

@@ -21,18 +21,24 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-from typing import Optional, Tuple, Union
+from typing import TYPE_CHECKING, Tuple
 
 import torch
 import torch.nn as nn
-from transformers.modeling_outputs import BaseModelOutputWithPast
 
-from ....transformers.models.decoderonly.decoderonly_architecture import (
-    RotaryEmbedding,
-    rotate_half,
-    slice_and_unsqueeze_cos_sin,
+from ....transformers.models.decoderonly.decoderonly_architecture import rotate_half
+from ..decoderonly.decoderonly_architecture import (
+    DecoderOnlyAttention,
+    DecoderOnlyForCausalLM,
+    DecoderOnlyLayer,
+    DecoderOnlyModel,
+    DecoderOnlyWrapper,
+    apply_rotary_pos_emb_partial,
 )
-from ...cache_utils import RebelDynamicCache_4D
+
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel as MidmLMHeadModel
 
 
 def apply_rotary_to_tensor(tensor, cos, sin, rot_dim):
@@ -50,264 +56,93 @@ def apply_rotary_pos_emb(q, k, cos, sin):
     return q_embed, k_embed
 
 
-class MidmLMHeadModelWrapper(torch.nn.Module):
-    """A wrapper class for the Midm model with a language modeling head."""
+class MidmLMHeadModelWrapper(DecoderOnlyWrapper):
+    def get_rotary_emb(self, max_seq_len):
+        self.config.rope_theta = 10000
+        self.config.head_dim = self.config.n_embd // self.config.n_head
+        self.config.partial_rotary_factor = self.config.rotary_percentage
+        return super().get_rotary_emb(max_seq_len=max_seq_len)
+
+    def convert_to_rbln_causal_lm(self, causal_lm: "MidmLMHeadModel"):
+        if self.attn_impl != "eager":
+            raise NotImplementedError(f"flash attention ({self.attn_impl}) is not implemented for {self.__class__}")
+        new_layers = []
+        for layer in causal_lm.transformer.h:
+            new_self_attn = MidmAttention(layer.attn)
+            new_layer = MidmLayer(layer, new_self_attn)
+            new_layers.append(new_layer)
+        new_model = MidmModel(causal_lm.transformer, new_layers)
+        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
+        return new_causal_lm
+
+
+class MidmModel(DecoderOnlyModel):
+    mask_fmin = -10000.0
+
+    def get_layernorm1p(self, module: nn.LayerNorm):
+        def layernorm1p(input: torch.Tensor):
+            """Applies Layer Normalization with a slight modification on the weights."""
+            return torch.nn.functional.layer_norm(
+                input, module.normalized_shape, module.weight + 1, module.bias, module.eps
+            )
 
-    def __init__(self, model, max_seq_len):
-        super().__init__()
-        self.model = model.transformer
-        self.lm_head = model.lm_head
-        self.config = model.config
-        self.head_dim = self.config.n_embd // self.config.n_head
-        self.max_position_embeddings = (
-            self.config.max_position_embeddings if max_seq_len > self.config.max_position_embeddings else max_seq_len
-        )
-        self.max_seq_len = max_seq_len
-        self.rotary_dim = int(
-            model.config.hidden_size // model.config.num_attention_heads * model.config.rotary_percentage
-        )
-        self.rotary_emb = self._init_rope()
+        return layernorm1p
 
-    def _init_rope(self):
-        """Initializes the Rotary Position Embeddings."""
-        rotary_emb = RotaryEmbedding(
-            self.rotary_dim,
-            max_position_embeddings=self.max_position_embeddings,
-        )
-        return rotary_emb
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        attention_mask: torch.Tensor,
-        cache_position: torch.LongTensor,
-        batch_position: int,
-        query_idx: int,
-        *past_key_values,
-    ):
-        """Defines the forward pass for the wrapper model."""
-        if input_ids.shape[1] == 1:
-            rbln_batch_position = None
+    def get_last_layernorm(self) -> nn.LayerNorm:
+        if self._original_mod.use_layernorm1p:
+            return self.get_layernorm1p(self._original_mod.ln_f)
         else:
-            rbln_batch_position = batch_position
-
-        past_key_values = RebelDynamicCache_4D.from_input_format(
-            cache_position,
-            self.config.num_hidden_layers,
-            *past_key_values,
-        )
-
-        outputs = _MidmModel.forward(
-            self.model,
-            input_ids=input_ids,
-            past_key_values=past_key_values,
-            attention_mask=attention_mask,
-            position_ids=cache_position,
-            rotary_pos_emb=self.rotary_emb,
-            batch_ids=rbln_batch_position,
-        )
+            return self._original_mod.ln_f
 
-        hidden_states = outputs[0]
-        if batch_position >= 0:
-            hidden_states = hidden_states[:, query_idx].unsqueeze(1)
+    def get_embedding(self) -> nn.Embedding:
+        return self._original_mod.wte
 
-        logits = self.lm_head(hidden_states)
-        output = (logits,) + outputs[1:]
+    def get_pos_embedding(self) -> nn.Embedding:
+        return self._original_mod.wpe
 
-        return output, batch_position + query_idx
 
-
-def layernorm1p(module, input):
-    """Applies Layer Normalization with a slight modification on the weights."""
-    return torch.nn.functional.layer_norm(input, module.normalized_shape, module.weight + 1, module.bias, module.eps)
-
-
-class _MidmAttention:
-    """Custom implementation of the MidmAttention class with specific modifications."""
-
-    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
-        """Computes the attention weights and output."""
-        attn_weights = torch.matmul(query, key.transpose(-1, -2))
-
-        if self.scale_attn_weights:
-            attn_weights = attn_weights / torch.full(
-                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
+class MidmLayer(DecoderOnlyLayer):
+    def get_layernorm1p(self, module: nn.LayerNorm):
+        def layernorm1p(input: torch.Tensor):
+            """Applies Layer Normalization with a slight modification on the weights."""
+            return torch.nn.functional.layer_norm(
+                input, module.normalized_shape, module.weight + 1, module.bias, module.eps
             )
 
-        if self.scale_attn_by_inverse_layer_idx or self.scale_qk_by_inverse_layer_idx:
-            attn_weights = attn_weights / float(self.layer_idx + 1)
-
-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
-
-        if self.scale_qk_by_inverse_layer_idx:
-            attn_weights = attn_weights * float(self.layer_idx + 1)
-
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-        attn_weights = attn_weights.type(value.dtype)
-
-        if head_mask is not None:
-            attn_weights = attn_weights * head_mask
-
-        attn_output = torch.matmul(attn_weights, value)
-        return attn_output, attn_weights
-
-    def forward(
-        self,
-        hidden_states: Optional[Tuple[torch.FloatTensor]],
-        attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_value: Optional[RebelDynamicCache_4D] = None,
-        batch_index: Optional[int] = None,
-        cos: Optional[torch.Tensor] = None,
-        sin: Optional[torch.Tensor] = None,
-    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
-        """Defines the forward pass for the attention mechanism."""
-        bsz, q_len, _ = hidden_states.size()
-
-        querys, keys, values = self.c_attn(hidden_states).split(self.split_size, dim=2)
-
-        querys = self._split_heads(querys, self.num_heads, self.head_dim).contiguous()
-        keys = self._split_heads(keys, self.num_heads, self.head_dim).contiguous()
-        values = self._split_heads(values, self.num_heads, self.head_dim).contiguous()
-
-        querys, keys = apply_rotary_pos_emb(querys, keys, cos, sin)
-
-        # Decoder
-        if (batch_index is None or batch_index == -1) and bsz > 1:
-            all_key_states = []
-            all_value_states = []
-            all_attn_output = []
-
-            for b in range(bsz):
-                query = querys[b].unsqueeze(0)
-                attn_mask = attention_mask[b].unsqueeze(0)
-                key = keys[b].unsqueeze(0)
-                value = values[b].unsqueeze(0)
-
-                key, value = past_key_value.update(
-                    key,
-                    value,
-                    self.layer_idx,
-                    b,
-                )
-
-                attn_output, _ = _MidmAttention._attn(self, query, key, value, attn_mask)
-                attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
-
-                all_key_states.append(key)
-                all_value_states.append(value)
-                all_attn_output.append(attn_output)
-
-            keys = torch.cat(all_key_states, dim=0)
-            values = torch.cat(all_value_states, dim=0)
-            attn_output = torch.cat(all_attn_output, dim=0)
+        return layernorm1p
 
+    def get_pre_attention_layernorm(self) -> nn.LayerNorm:
+        if self._original_mod.use_layernorm1p:
+            return self.get_layernorm1p(self._original_mod.ln_1)
         else:
-            if batch_index is None or batch_index == -1:
-                batch_index = 0
-
-            keys, values = past_key_value.update(
-                keys,
-                values,
-                self.layer_idx,
-                batch_index,
-                read_first_step=True,
-            )
+            return self._original_mod.ln_1
 
-        attn_output, _ = _MidmAttention._attn(self, querys, keys, values, attention_mask)
-        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
-
-        attn_output = self.c_proj(attn_output)
-        return attn_output, keys, values
-
-
-class _MidmBlock:
-    """Custom implementation of the MidmBlock class with specific modifications."""
-
-    def forward(
-        self,
-        hidden_states: Optional[Tuple[torch.FloatTensor]],
-        layer_idx: int,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_value: Optional[RebelDynamicCache_4D] = None,
-        batch_ids: Optional[torch.LongTensor] = None,
-        cos: Optional[torch.Tensor] = None,
-        sin: Optional[torch.Tensor] = None,
-    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
-        """Defines the forward pass for the block."""
-        residual = hidden_states
-        if self.use_layernorm1p:
-            hidden_states = layernorm1p(self.ln_1, hidden_states)
+    def get_post_attention_layernorm(self) -> nn.LayerNorm:
+        if self._original_mod.use_layernorm1p:
+            return self.get_layernorm1p(self._original_mod.ln_2)
         else:
-            hidden_states = self.ln_1(hidden_states)
-
-        hidden_states, k, v = _MidmAttention.forward(
-            self.attn,
-            hidden_states,
-            attention_mask=attention_mask,
-            past_key_value=past_key_value,
-            cos=cos,
-            sin=sin,
-            batch_index=batch_ids,
+            return self._original_mod.ln_2
+
+
+class MidmAttention(DecoderOnlyAttention):
+    def __post_init__(self):
+        self.c_attn = self._original_mod.c_attn
+        self.o_proj = self._original_mod.c_proj
+        self.split_size = self._original_mod.split_size
+        self.num_key_value_heads = self._original_mod.num_heads
+
+    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
+        return query_states, key_states, value_states
+
+    def rbln_attention(self, *args, **kwargs):
+        return super().rbln_attention(
+            *args,
+            **kwargs,
+            layer_idx=self.layer_idx,
+            scale_attn_weights=self._original_mod.scale_attn_weights,
+            scale_attn_by_inverse_layer_idx=self._original_mod.scale_attn_by_inverse_layer_idx,
         )
-        past_key_value.assign(k, v, layer_idx)
-
-        hidden_states = hidden_states + residual
 
-        residual = hidden_states
-        if self.use_layernorm1p:
-            hidden_states = layernorm1p(self.ln_2, hidden_states)
-        else:
-            hidden_states = self.ln_2(hidden_states)
-
-        feed_forward_hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + feed_forward_hidden_states
-
-        return hidden_states, past_key_value
-
-
-class _MidmModel:
-    """Custom implementation of the MidmModel class with specific modifications."""
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[RebelDynamicCache_4D] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        rotary_pos_emb=None,
-        batch_ids: Optional[torch.LongTensor] = None,
-    ) -> Union[Tuple, BaseModelOutputWithPast]:
-        """Defines the forward pass for the model."""
-        input_shape = input_ids.size()
-
-        attention_mask = (1.0 - attention_mask) * -10000.0
-
-        inputs_embeds = self.wte(input_ids)
-
-        cos, sin = rotary_pos_emb(inputs_embeds, attention_mask.shape[-1])
-        cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
-        hidden_states = inputs_embeds
-
-        for layer_idx, (block, _) in enumerate(zip(self.h, past_key_values)):
-            hidden_states, updated_cache = _MidmBlock.forward(
-                block,
-                hidden_states,
-                layer_idx,
-                attention_mask=attention_mask,
-                past_key_value=past_key_values,
-                batch_ids=batch_ids,
-                cos=cos,
-                sin=sin,
-            )
-
-        hidden_states = layernorm1p(self.ln_f, hidden_states)
-        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
-        hidden_states = hidden_states.view(output_shape)
-
-        next_cache = updated_cache.to_legacy_cache()
-
-        return BaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=next_cache,
-        )
+    def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
+        return apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim=cos.shape[-1])
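After the rewrite, the Midm classes only override the model-specific numerics and defer everything else to the shared DecoderOnly base classes: the "layernorm1p" variant (LayerNorm applied with `weight + 1`) and partial rotary embeddings (only the first `rotary_percentage` share of each head is rotated). A standalone sketch of both follows; `partial_rope` is an illustrative stand-in for `apply_rotary_pos_emb_partial`, not the package's implementation.

```python
import torch
import torch.nn.functional as F


def layernorm1p(x: torch.Tensor, module: torch.nn.LayerNorm) -> torch.Tensor:
    # Midm keeps LayerNorm weights offset by -1, so +1 is added back at call time.
    return F.layer_norm(x, module.normalized_shape, module.weight + 1, module.bias, module.eps)


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def partial_rope(q, k, cos, sin, rot_dim):
    # Illustrative stand-in for apply_rotary_pos_emb_partial: rotate only the
    # first rot_dim channels of each head and pass the rest through untouched.
    q_rot, q_pass = q[..., :rot_dim], q[..., rot_dim:]
    k_rot, k_pass = k[..., :rot_dim], k[..., rot_dim:]
    q_rot = q_rot * cos + rotate_half(q_rot) * sin
    k_rot = k_rot * cos + rotate_half(k_rot) * sin
    return torch.cat((q_rot, q_pass), dim=-1), torch.cat((k_rot, k_pass), dim=-1)


# head_dim=8, rotary_percentage=0.5 -> rotate the first 4 channels only.
q = k = torch.randn(1, 2, 3, 8)        # (batch, heads, seq, head_dim)
cos = torch.ones(1, 1, 3, 4)
sin = torch.zeros(1, 1, 3, 4)
q_out, k_out = partial_rope(q, k, cos, sin, rot_dim=4)
assert torch.allclose(q_out, q)        # identity rotation leaves q unchanged
```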
optimum/rbln/transformers/models/midm/modeling_midm.py

@@ -21,23 +21,15 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-import inspect
-import logging
-from typing import TYPE_CHECKING, Any, Callable
 
-from ....modeling_config import RBLNConfig
-from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
-from .hf_hub_cached.modeling_midm import MidmLMHeadModel
-from .midm_architecture import (
-    MidmLMHeadModelWrapper,
-)
+from transformers import AutoModelForCausalLM
 
+from ....utils import logging
+from ..decoderonly import RBLNDecoderOnlyModelForCausalLM
+from .midm_architecture import MidmLMHeadModelWrapper
 
-logger = logging.getLogger(__name__)
-if TYPE_CHECKING:
-    from transformers import (
-        PreTrainedModel,
-    )
+
+logger = logging.get_logger(__name__)
 
 
 class RBLNMidmLMHeadModel(RBLNDecoderOnlyModelForCausalLM):
@@ -54,25 +46,8 @@ class RBLNMidmLMHeadModel(RBLNDecoderOnlyModelForCausalLM):
 
     """
 
-    @classmethod
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-        return MidmLMHeadModelWrapper(model, rbln_max_seq_len).eval()
-
-    def __getattr__(self, __name: str) -> Any:
-        """This is the key method to implement RBLN-Midm.
-
-        Returns:
-            Any: Midm's corresponding method
-        """
-
-        def redirect(func):
-            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-        val = getattr(MidmLMHeadModel, __name)
-        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-            return redirect(val)
-        return val
+    _decoder_wrapper_cls = MidmLMHeadModelWrapper
+    _hf_class = AutoModelForCausalLM
 
     @classmethod
     def from_pretrained(cls, *args, **kwargs):
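With the `__getattr__` redirection and `wrap_model_if_needed` gone, a backend now only declares its wrapper class (and, where needed, the Hugging Face loader class) and inherits the rest from `RBLNDecoderOnlyModelForCausalLM`. A hypothetical sketch of what a new backend would look like under this pattern; `MyArchWrapper` and `RBLNMyArchForCausalLM` are made-up names, and the exact hooks the base classes expect are assumed from the Midm example above.

```python
# Hypothetical sketch only: the MyArch names below do not exist in the package.
from transformers import AutoModelForCausalLM

from optimum.rbln.transformers.models.decoderonly import RBLNDecoderOnlyModelForCausalLM
from optimum.rbln.transformers.models.decoderonly.decoderonly_architecture import DecoderOnlyWrapper


class MyArchWrapper(DecoderOnlyWrapper):
    # A real wrapper would override convert_to_rbln_causal_lm / get_rotary_emb
    # the way MidmLMHeadModelWrapper does above.
    pass


class RBLNMyArchForCausalLM(RBLNDecoderOnlyModelForCausalLM):
    _decoder_wrapper_cls = MyArchWrapper   # which torch wrapper to compile
    _hf_class = AutoModelForCausalLM       # how to load the original checkpoint
```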
optimum/rbln/transformers/models/mistral/modeling_mistral.py

@@ -21,29 +21,18 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-import inspect
-import logging
-from typing import TYPE_CHECKING, Any, Callable
-
-from transformers import MistralForCausalLM
-
+from ....utils import logging
 from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .mistral_architecture import MistralForCausalLMWrapper
 
 
-if TYPE_CHECKING:
-    from transformers import PreTrainedModel
-
-    from ....modeling_config import RBLNConfig
-
-
-logger = logging.getLogger(__name__)
+logger = logging.get_logger(__name__)
 
 
 class RBLNMistralForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """
     The Llama Model transformer with a language modeling head (linear layer) on top.
-    This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
 
     A class to convert and run pre-trained transformers based LlamaForCausalLM model on RBLN devices.
     It implements the methods to convert a pre-trained transformers LlamaForCausalLM model into a RBLN transformer model by:
@@ -51,18 +40,4 @@ class RBLNMistralForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     - compiling the resulting graph using the RBLN compiler.
     """
 
-    @classmethod
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-        return MistralForCausalLMWrapper(model, rbln_max_seq_len).eval()
-
-    def __getattr__(self, __name: str) -> Any:
-        def redirect(func):
-            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-        val = getattr(MistralForCausalLM, __name)
-
-        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-            return redirect(val)
-
-        return val
+    _decoder_wrapper_cls = MistralForCausalLMWrapper
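The public entry point is unchanged by the refactor: the class is still compiled and loaded through `from_pretrained`. A usage sketch follows; the `rbln_*` keyword arguments shown are the library's usual compile-time options and are assumed here rather than taken from this diff.

```python
# Usage sketch (kwargs assumed, not shown in this diff).
from optimum.rbln import RBLNMistralForCausalLM

# Compile the Hugging Face checkpoint for RBLN NPUs once...
model = RBLNMistralForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.2",
    export=True,
    rbln_batch_size=1,
    rbln_max_seq_len=4096,
)
model.save_pretrained("mistral-7b-rbln")

# ...then reload the precompiled artifacts without re-exporting.
model = RBLNMistralForCausalLM.from_pretrained("mistral-7b-rbln", export=False)
```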