optimum-rbln 0.1.13__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (103)
  1. optimum/rbln/__init__.py +41 -38
  2. optimum/rbln/__version__.py +16 -1
  3. optimum/rbln/diffusers/__init__.py +26 -2
  4. optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +97 -126
  5. optimum/rbln/diffusers/models/__init__.py +36 -3
  6. optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
  7. optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +73 -61
  8. optimum/rbln/diffusers/models/autoencoders/vae.py +83 -0
  9. optimum/rbln/diffusers/models/controlnet.py +54 -14
  10. optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
  11. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
  12. optimum/rbln/diffusers/models/unets/__init__.py +24 -0
  13. optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +82 -22
  14. optimum/rbln/diffusers/pipelines/__init__.py +23 -2
  15. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +13 -33
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
  17. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +18 -2
  18. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -2
  19. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +18 -2
  20. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
  21. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -2
  22. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -13
  23. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  27. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  28. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +24 -0
  29. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +15 -8
  30. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +15 -8
  31. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  32. optimum/rbln/modeling.py +238 -0
  33. optimum/rbln/modeling_base.py +186 -760
  34. optimum/rbln/modeling_config.py +31 -7
  35. optimum/rbln/ops/__init__.py +26 -0
  36. optimum/rbln/ops/attn.py +221 -0
  37. optimum/rbln/ops/flash_attn.py +70 -0
  38. optimum/rbln/ops/kv_cache_update.py +69 -0
  39. optimum/rbln/transformers/__init__.py +20 -2
  40. optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
  41. optimum/rbln/transformers/modeling_generic.py +385 -0
  42. optimum/rbln/transformers/models/auto/__init__.py +23 -0
  43. optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
  44. optimum/rbln/transformers/models/auto/modeling_auto.py +36 -12
  45. optimum/rbln/transformers/models/bart/__init__.py +0 -1
  46. optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
  47. optimum/rbln/transformers/models/bart/modeling_bart.py +10 -9
  48. optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
  49. optimum/rbln/transformers/models/clip/modeling_clip.py +8 -25
  50. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -10
  51. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +775 -514
  52. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +128 -260
  53. optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
  54. optimum/rbln/transformers/models/exaone/exaone_architecture.py +60 -45
  55. optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -2
  56. optimum/rbln/transformers/models/gemma/gemma_architecture.py +33 -104
  57. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
  58. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +3 -2
  59. optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
  60. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -75
  61. optimum/rbln/transformers/models/midm/midm_architecture.py +84 -238
  62. optimum/rbln/transformers/models/midm/modeling_midm.py +5 -6
  63. optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
  64. optimum/rbln/transformers/models/phi/phi_architecture.py +60 -261
  65. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
  66. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -103
  67. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
  68. optimum/rbln/transformers/models/t5/__init__.py +0 -1
  69. optimum/rbln/transformers/models/t5/modeling_t5.py +106 -5
  70. optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
  71. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  72. optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
  73. optimum/rbln/transformers/models/whisper/modeling_whisper.py +78 -55
  74. optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
  75. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
  76. optimum/rbln/transformers/utils/rbln_quantization.py +120 -4
  77. optimum/rbln/utils/decorator_utils.py +51 -11
  78. optimum/rbln/utils/hub.py +131 -0
  79. optimum/rbln/utils/import_utils.py +22 -1
  80. optimum/rbln/utils/logging.py +37 -0
  81. optimum/rbln/utils/model_utils.py +52 -0
  82. optimum/rbln/utils/runtime_utils.py +10 -4
  83. optimum/rbln/utils/save_utils.py +17 -0
  84. optimum/rbln/utils/submodule.py +137 -0
  85. optimum_rbln-0.2.0.dist-info/METADATA +117 -0
  86. optimum_rbln-0.2.0.dist-info/RECORD +114 -0
  87. {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.2.0.dist-info}/WHEEL +1 -1
  88. optimum_rbln-0.2.0.dist-info/licenses/LICENSE +288 -0
  89. optimum/rbln/transformers/cache_utils.py +0 -107
  90. optimum/rbln/transformers/generation/streamers.py +0 -139
  91. optimum/rbln/transformers/generation/utils.py +0 -397
  92. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
  93. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
  94. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
  95. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
  96. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
  97. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
  98. optimum/rbln/utils/context.py +0 -58
  99. optimum/rbln/utils/timer_utils.py +0 -43
  100. optimum_rbln-0.1.13.dist-info/METADATA +0 -120
  101. optimum_rbln-0.1.13.dist-info/RECORD +0 -107
  102. optimum_rbln-0.1.13.dist-info/entry_points.txt +0 -4
  103. optimum_rbln-0.1.13.dist-info/licenses/LICENSE +0 -201
optimum/rbln/transformers/models/llava_next/modeling_llava_next.py
@@ -20,10 +20,11 @@
 # are the intellectual property of Rebellions Inc. and may not be
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
+
 import inspect
 import logging
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -36,7 +37,7 @@ from transformers import (
 from transformers.modeling_outputs import BaseModelOutputWithPooling
 from transformers.models.llava_next.modeling_llava_next import LlavaNextCausalLMOutputWithPast
 
-from ....modeling_base import RBLNModel
+from ....modeling import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyOutput
 
@@ -166,19 +167,6 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
         self._padding_side = "left"  # set it to left by default, user can use setter to change padding_sides
         return super().__post_init__(**kwargs)
 
-    @classmethod
-    def get_pytorch_model(
-        cls,
-        model_id: str,
-        *args,
-        rbln_kwargs: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ) -> "PreTrainedModel":
-        # Optimum's TasksManager does not handle Llava.
-        kwargs = cls.update_kwargs(kwargs)
-        model = LlavaNextForConditionalGeneration.from_pretrained(model_id, *args, **kwargs)
-        return model
-
     def get_input_embeddings(self):
         return self.language_model.get_input_embeddings()
 
@@ -422,66 +410,6 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
 
         return outputs
 
-    def vllm_forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        pixel_values: torch.FloatTensor = None,
-        image_sizes: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        vision_feature_layer: Optional[int] = None,
-        vision_feature_select_strategy: Optional[str] = None,
-        cache_position: Union[List[torch.Tensor], torch.Tensor] = None,  # vllm keyword argument
-        batch_idx: Optional[int] = None,
-        **kwargs,
-    ) -> Union[Tuple, LlavaNextCausalLMOutputWithPast]:
-        is_prefill = cache_position.shape[-1] > 1
-
-        if inputs_embeds is not None:
-            raise NotImplementedError("Specifying inputs_embeds is not supported.")
-
-        if is_prefill:
-            # Get text_embeds
-            inputs_embeds = self.text_embedding(input_ids)
-
-            # If any images in the prompt, get image_embeds and merge with text
-            if pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) > 0:
-                image_features, _ = self.image_embedding(
-                    image_sizes, pixel_values, vision_feature_layer, vision_feature_select_strategy
-                )
-
-                def merge_vllm_multimodal_embeddings(
-                    input_ids: torch.Tensor,
-                    inputs_embeds: torch.Tensor,
-                    multimodal_embeddings: torch.Tensor,
-                    placeholder_token_id: int,
-                ) -> torch.Tensor:
-                    mask = input_ids == placeholder_token_id
-                    num_expected_tokens = mask.sum().item()
-
-                    if multimodal_embeddings.shape[0] != num_expected_tokens:
-                        raise ValueError(
-                            f"Attempted to assign {inputs_embeds[mask].shape} = {multimodal_embeddings.shape} "
-                            f"multimodal tokens to {num_expected_tokens} placeholders"
-                        )
-
-                    inputs_embeds[mask] = multimodal_embeddings
-                    return inputs_embeds
-
-                inputs_embeds = merge_vllm_multimodal_embeddings(
-                    input_ids, inputs_embeds, image_features, self.config.image_token_index
-                )
-
-        else:
-            inputs_embeds = self.text_embedding(input_ids=input_ids)
-
-        outputs: RBLNDecoderOnlyOutput = self.language_model.vllm_forward(
-            inputs_embeds=inputs_embeds,
-            batch_idx=batch_idx,
-            cache_position=cache_position,
-        )
-
-        return outputs
-
     # Almost copied from : https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/llava_next/modeling_llava_next.py
     def pack_image_features(self, image_features, image_sizes, image_newline=None):
         """
optimum/rbln/transformers/models/midm/midm_architecture.py
@@ -21,18 +21,25 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-from typing import Optional, Tuple, Union
+import math
+from typing import TYPE_CHECKING, Tuple
 
 import torch
 import torch.nn as nn
-from transformers.modeling_outputs import BaseModelOutputWithPast
 
-from ....transformers.models.decoderonly.decoderonly_architecture import (
-    RotaryEmbedding,
+from ..decoderonly.decoderonly_architecture import (
+    DecoderOnlyAttention,
+    DecoderOnlyForCausalLM,
+    DecoderOnlyLayer,
+    DecoderOnlyModel,
+    DecoderOnlyWrapper,
+    apply_rotary_pos_emb_partial,
     rotate_half,
-    slice_and_unsqueeze_cos_sin,
 )
-from ...cache_utils import RebelDynamicCache_4D
+
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel as MidmLMHeadModel
 
 
 def apply_rotary_to_tensor(tensor, cos, sin, rot_dim):
@@ -50,253 +57,92 @@ def apply_rotary_pos_emb(q, k, cos, sin):
     return q_embed, k_embed
 
 
-class MidmLMHeadModelWrapper(torch.nn.Module):
-    """A wrapper class for the Midm model with a language modeling head."""
-
-    def __init__(self, model, max_seq_len):
-        super().__init__()
-        self.model = model.transformer
-        self.lm_head = model.lm_head
-        self.config = model.config
-        self.max_seq_len = max_seq_len
-
-        self.config.partial_rotary_factor = model.config.rotary_percentage
-        self.config.head_dim = self.config.n_embd // self.config.n_head
+class MidmLMHeadModelWrapper(DecoderOnlyWrapper):
+    def get_rotary_emb(self, max_seq_len):
         self.config.rope_theta = 10000
-        self.rotary_emb = RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        attention_mask: torch.Tensor,
-        cache_position: torch.LongTensor,
-        batch_position: int,
-        query_idx: int,
-        *past_key_values,
-    ):
-        """Defines the forward pass for the wrapper model."""
-        if input_ids.shape[1] == 1:
-            rbln_batch_position = None
-        else:
-            rbln_batch_position = batch_position
-
-        past_key_values = RebelDynamicCache_4D.from_input_format(
-            cache_position,
-            self.config.num_hidden_layers,
-            *past_key_values,
-        )
-
-        outputs = _MidmModel.forward(
-            self.model,
-            input_ids=input_ids,
-            past_key_values=past_key_values,
-            attention_mask=attention_mask,
-            position_ids=cache_position,
-            rotary_pos_emb=self.rotary_emb,
-            batch_ids=rbln_batch_position,
-        )
-
-        hidden_states = outputs[0]
-        if batch_position >= 0:
-            hidden_states = hidden_states[:, query_idx].unsqueeze(1)
-
-        logits = self.lm_head(hidden_states)
-        output = (logits,) + outputs[1:]
-
-        return output, batch_position + query_idx
-
-
-def layernorm1p(module, input):
-    """Applies Layer Normalization with a slight modification on the weights."""
-    return torch.nn.functional.layer_norm(input, module.normalized_shape, module.weight + 1, module.bias, module.eps)
-
-
-class _MidmAttention:
-    """Custom implementation of the MidmAttention class with specific modifications."""
-
-    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
-        """Computes the attention weights and output."""
-        attn_weights = torch.matmul(query, key.transpose(-1, -2))
-
-        if self.scale_attn_weights:
-            attn_weights = attn_weights / torch.full(
-                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
+        self.config.head_dim = self.config.n_embd // self.config.n_head
+        self.config.partial_rotary_factor = self.config.rotary_percentage
+        return super().get_rotary_emb(max_seq_len=max_seq_len)
+
+    def convert_to_rbln_causal_lm(self, causal_lm: "MidmLMHeadModel"):
+        if self.attn_impl != "eager":
+            raise NotImplementedError(f"flash attention ({self.attn_impl}) is not implemented for {self.__class__}")
+        new_layers = []
+        for layer in causal_lm.transformer.h:
+            new_self_attn = MidmAttention(layer.attn)
+            new_layer = MidmLayer(layer, new_self_attn)
+            new_layers.append(new_layer)
+        new_model = MidmModel(causal_lm.transformer, new_layers)
+        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
+        return new_causal_lm
+
+
+class MidmModel(DecoderOnlyModel):
+    def get_layernorm1p(self, module: nn.LayerNorm):
+        def layernorm1p(input: torch.Tensor):
+            """Applies Layer Normalization with a slight modification on the weights."""
+            return torch.nn.functional.layer_norm(
+                input, module.normalized_shape, module.weight + 1, module.bias, module.eps
             )
 
-        if self.scale_attn_by_inverse_layer_idx or self.scale_qk_by_inverse_layer_idx:
-            attn_weights = attn_weights / float(self.layer_idx + 1)
-
-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
-
-        if self.scale_qk_by_inverse_layer_idx:
-            attn_weights = attn_weights * float(self.layer_idx + 1)
+        return layernorm1p
 
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-        attn_weights = attn_weights.type(value.dtype)
-
-        if head_mask is not None:
-            attn_weights = attn_weights * head_mask
-
-        attn_output = torch.matmul(attn_weights, value)
-        return attn_output, attn_weights
-
-    def forward(
-        self,
-        hidden_states: Optional[Tuple[torch.FloatTensor]],
-        attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_value: Optional[RebelDynamicCache_4D] = None,
-        batch_index: Optional[int] = None,
-        cos: Optional[torch.Tensor] = None,
-        sin: Optional[torch.Tensor] = None,
-    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
-        """Defines the forward pass for the attention mechanism."""
-        bsz, q_len, _ = hidden_states.size()
+    def get_last_layernorm(self) -> nn.LayerNorm:
+        if self._original_mod.use_layernorm1p:
+            return self.get_layernorm1p(self._original_mod.ln_f)
+        else:
+            return self._original_mod.ln_f
 
-        querys, keys, values = self.c_attn(hidden_states).split(self.split_size, dim=2)
+    def get_embedding(self) -> nn.Embedding:
+        return self._original_mod.wte
 
-        querys = self._split_heads(querys, self.num_heads, self.head_dim).contiguous()
-        keys = self._split_heads(keys, self.num_heads, self.head_dim).contiguous()
-        values = self._split_heads(values, self.num_heads, self.head_dim).contiguous()
+    def get_pos_embedding(self) -> nn.Embedding:
+        return self._original_mod.wpe
 
-        querys, keys = apply_rotary_pos_emb(querys, keys, cos, sin)
 
-        # Decoder
-        if (batch_index is None or batch_index == -1) and bsz > 1:
-            all_key_states = []
-            all_value_states = []
-            all_attn_output = []
+class MidmLayer(DecoderOnlyLayer):
+    def get_layernorm1p(self, module: nn.LayerNorm):
+        def layernorm1p(input: torch.Tensor):
+            """Applies Layer Normalization with a slight modification on the weights."""
+            return torch.nn.functional.layer_norm(
+                input, module.normalized_shape, module.weight + 1, module.bias, module.eps
+            )
 
-            for b in range(bsz):
-                query = querys[b].unsqueeze(0)
-                attn_mask = attention_mask[b].unsqueeze(0)
-                key = keys[b].unsqueeze(0)
-                value = values[b].unsqueeze(0)
+        return layernorm1p
 
-                key, value = past_key_value.update(
-                    key,
-                    value,
-                    self.layer_idx,
-                    b,
-                )
+    def get_pre_attention_layernorm(self) -> nn.LayerNorm:
+        if self._original_mod.use_layernorm1p:
+            return self.get_layernorm1p(self._original_mod.ln_1)
+        else:
+            return self._original_mod.ln_1
 
-                attn_output, _ = _MidmAttention._attn(self, query, key, value, attn_mask)
-                attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+    def get_post_attention_layernorm(self) -> nn.LayerNorm:
+        if self._original_mod.use_layernorm1p:
+            return self.get_layernorm1p(self._original_mod.ln_2)
+        else:
+            return self._original_mod.ln_2
 
-                all_key_states.append(key)
-                all_value_states.append(value)
-                all_attn_output.append(attn_output)
 
-            keys = torch.cat(all_key_states, dim=0)
-            values = torch.cat(all_value_states, dim=0)
-            attn_output = torch.cat(all_attn_output, dim=0)
+class MidmAttention(DecoderOnlyAttention):
+    def __post_init__(self):
+        self.c_attn = self._original_mod.c_attn
+        self.o_proj = self._original_mod.c_proj
+        self.split_size = self._original_mod.split_size
+        self.num_key_value_heads = self._original_mod.num_heads
 
-        else:
-            if batch_index is None or batch_index == -1:
-                batch_index = 0
-
-            keys, values = past_key_value.update(
-                keys,
-                values,
-                self.layer_idx,
-                batch_index,
-                read_first_step=True,
-            )
+    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
+        return query_states, key_states, value_states
 
-        attn_output, _ = _MidmAttention._attn(self, querys, keys, values, attention_mask)
-        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
-
-        attn_output = self.c_proj(attn_output)
-        return attn_output, keys, values
-
-
-class _MidmBlock:
-    """Custom implementation of the MidmBlock class with specific modifications."""
-
-    def forward(
-        self,
-        hidden_states: Optional[Tuple[torch.FloatTensor]],
-        layer_idx: int,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_value: Optional[RebelDynamicCache_4D] = None,
-        batch_ids: Optional[torch.LongTensor] = None,
-        cos: Optional[torch.Tensor] = None,
-        sin: Optional[torch.Tensor] = None,
-    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
-        """Defines the forward pass for the block."""
-        residual = hidden_states
-        if self.use_layernorm1p:
-            hidden_states = layernorm1p(self.ln_1, hidden_states)
-        else:
-            hidden_states = self.ln_1(hidden_states)
-
-        hidden_states, k, v = _MidmAttention.forward(
-            self.attn,
-            hidden_states,
-            attention_mask=attention_mask,
-            past_key_value=past_key_value,
-            cos=cos,
-            sin=sin,
-            batch_index=batch_ids,
-        )
-        past_key_value.assign(k, v, layer_idx)
-
-        hidden_states = hidden_states + residual
-
-        residual = hidden_states
-        if self.use_layernorm1p:
-            hidden_states = layernorm1p(self.ln_2, hidden_states)
-        else:
-            hidden_states = self.ln_2(hidden_states)
-
-        feed_forward_hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + feed_forward_hidden_states
-
-        return hidden_states, past_key_value
-
-
-class _MidmModel:
-    """Custom implementation of the MidmModel class with specific modifications."""
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[RebelDynamicCache_4D] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        rotary_pos_emb=None,
-        batch_ids: Optional[torch.LongTensor] = None,
-    ) -> Union[Tuple, BaseModelOutputWithPast]:
-        """Defines the forward pass for the model."""
-        input_shape = input_ids.size()
-
-        attention_mask = (1.0 - attention_mask) * -10000.0
-
-        inputs_embeds = self.wte(input_ids)
-
-        cos, sin = rotary_pos_emb(inputs_embeds, attention_mask.shape[-1])
-        cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
-        hidden_states = inputs_embeds
-
-        for layer_idx, (block, _) in enumerate(zip(self.h, past_key_values)):
-            hidden_states, updated_cache = _MidmBlock.forward(
-                block,
-                hidden_states,
-                layer_idx,
-                attention_mask=attention_mask,
-                past_key_value=past_key_values,
-                batch_ids=batch_ids,
-                cos=cos,
-                sin=sin,
-            )
+    def get_attn_scale(self):
+        scale = 1.0
+        if self._original_mod.scale_attn_weights:
+            scale /= math.sqrt(self.head_dim)
 
-        hidden_states = layernorm1p(self.ln_f, hidden_states)
-        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
-        hidden_states = hidden_states.view(output_shape)
+        if self._original_mod.scale_attn_by_inverse_layer_idx and not self._original_mod.scale_qk_by_inverse_layer_idx:
+            scale /= 1 + self.layer_idx
 
-        next_cache = updated_cache.to_legacy_cache()
+        return scale
 
-        return BaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=next_cache,
-        )
+    def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
+        return apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim=cos.shape[-1])
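
The rewrite above folds the hand-written _MidmAttention softmax scaling into a single pre-computed factor in MidmAttention.get_attn_scale. A standalone sketch of that rule, using illustrative names rather than the package API:

import math

def midm_attn_scale(head_dim: int, layer_idx: int,
                    scale_attn_weights: bool,
                    scale_attn_by_inverse_layer_idx: bool,
                    scale_qk_by_inverse_layer_idx: bool) -> float:
    # Mirrors MidmAttention.get_attn_scale in the hunk above: 1/sqrt(head_dim)
    # when scale_attn_weights is set, plus an extra 1/(layer_idx + 1) only when
    # the inverse-layer-idx division is not re-multiplied after the mask
    # (i.e. scale_qk_by_inverse_layer_idx is off).
    scale = 1.0
    if scale_attn_weights:
        scale /= math.sqrt(head_dim)
    if scale_attn_by_inverse_layer_idx and not scale_qk_by_inverse_layer_idx:
        scale /= 1 + layer_idx
    return scale
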
optimum/rbln/transformers/models/midm/modeling_midm.py
@@ -21,12 +21,11 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
+from transformers import AutoModelForCausalLM
+
 from ....utils import logging
-from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
-from .hf_hub_cached.modeling_midm import MidmLMHeadModel
-from .midm_architecture import (
-    MidmLMHeadModelWrapper,
-)
+from ..decoderonly import RBLNDecoderOnlyModelForCausalLM
+from .midm_architecture import MidmLMHeadModelWrapper
 
 
 logger = logging.get_logger(__name__)
@@ -47,7 +46,7 @@ class RBLNMidmLMHeadModel(RBLNDecoderOnlyModelForCausalLM):
     """
 
     _decoder_wrapper_cls = MidmLMHeadModelWrapper
-    _original_cls = MidmLMHeadModel
+    _hf_class = AutoModelForCausalLM
 
     @classmethod
    def from_pretrained(cls, *args, **kwargs):
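
With the vendored hf_hub_cached copy of Midm dropped and _hf_class switched to AutoModelForCausalLM, the original weights are resolved from the Hub at export time. A minimal sketch of that flow (the checkpoint id, trust_remote_code=True, and the rbln_* option follow the usual optimum-rbln pattern and are assumptions, not taken from this diff):

from optimum.rbln import RBLNMidmLMHeadModel

model = RBLNMidmLMHeadModel.from_pretrained(
    "KT-AI/midm-bitext-S-7B-inst-v1",  # illustrative checkpoint id
    export=True,                       # compile for RBLN NPUs
    trust_remote_code=True,            # Midm's modeling code now comes from the Hub, not a local copy
    rbln_max_seq_len=8192,             # illustrative compile-time option
)
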
@@ -21,7 +21,6 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-
 from ..decoderonly.decoderonly_architecture import DecoderOnlyWrapper
 
 