optimum-rbln 0.1.13__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- optimum/rbln/__init__.py +41 -38
- optimum/rbln/__version__.py +16 -1
- optimum/rbln/diffusers/__init__.py +26 -2
- optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +97 -126
- optimum/rbln/diffusers/models/__init__.py +36 -3
- optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
- optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +73 -61
- optimum/rbln/diffusers/models/autoencoders/vae.py +83 -0
- optimum/rbln/diffusers/models/controlnet.py +54 -14
- optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
- optimum/rbln/diffusers/models/unets/__init__.py +24 -0
- optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +82 -22
- optimum/rbln/diffusers/pipelines/__init__.py +23 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +13 -33
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +18 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +18 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -13
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +24 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +15 -8
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +15 -8
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
- optimum/rbln/modeling.py +238 -0
- optimum/rbln/modeling_base.py +186 -760
- optimum/rbln/modeling_config.py +31 -7
- optimum/rbln/ops/__init__.py +26 -0
- optimum/rbln/ops/attn.py +221 -0
- optimum/rbln/ops/flash_attn.py +70 -0
- optimum/rbln/ops/kv_cache_update.py +69 -0
- optimum/rbln/transformers/__init__.py +20 -2
- optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
- optimum/rbln/transformers/modeling_generic.py +385 -0
- optimum/rbln/transformers/models/auto/__init__.py +23 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
- optimum/rbln/transformers/models/auto/modeling_auto.py +36 -12
- optimum/rbln/transformers/models/bart/__init__.py +0 -1
- optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
- optimum/rbln/transformers/models/bart/modeling_bart.py +10 -9
- optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
- optimum/rbln/transformers/models/clip/modeling_clip.py +8 -25
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -10
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +775 -514
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +128 -260
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +60 -45
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -2
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +33 -104
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +3 -2
- optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -75
- optimum/rbln/transformers/models/midm/midm_architecture.py +84 -238
- optimum/rbln/transformers/models/midm/modeling_midm.py +5 -6
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +60 -261
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -103
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
- optimum/rbln/transformers/models/t5/__init__.py +0 -1
- optimum/rbln/transformers/models/t5/modeling_t5.py +106 -5
- optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +78 -55
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
- optimum/rbln/transformers/utils/rbln_quantization.py +120 -4
- optimum/rbln/utils/decorator_utils.py +51 -11
- optimum/rbln/utils/hub.py +131 -0
- optimum/rbln/utils/import_utils.py +22 -1
- optimum/rbln/utils/logging.py +37 -0
- optimum/rbln/utils/model_utils.py +52 -0
- optimum/rbln/utils/runtime_utils.py +10 -4
- optimum/rbln/utils/save_utils.py +17 -0
- optimum/rbln/utils/submodule.py +137 -0
- optimum_rbln-0.2.0.dist-info/METADATA +117 -0
- optimum_rbln-0.2.0.dist-info/RECORD +114 -0
- {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.2.0.dist-info}/WHEEL +1 -1
- optimum_rbln-0.2.0.dist-info/licenses/LICENSE +288 -0
- optimum/rbln/transformers/cache_utils.py +0 -107
- optimum/rbln/transformers/generation/streamers.py +0 -139
- optimum/rbln/transformers/generation/utils.py +0 -397
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
- optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
- optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
- optimum/rbln/utils/context.py +0 -58
- optimum/rbln/utils/timer_utils.py +0 -43
- optimum_rbln-0.1.13.dist-info/METADATA +0 -120
- optimum_rbln-0.1.13.dist-info/RECORD +0 -107
- optimum_rbln-0.1.13.dist-info/entry_points.txt +0 -4
- optimum_rbln-0.1.13.dist-info/licenses/LICENSE +0 -201
@@ -20,10 +20,11 @@
 # are the intellectual property of Rebellions Inc. and may not be
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
+
 import inspect
 import logging
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict,
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -36,7 +37,7 @@ from transformers import (
 from transformers.modeling_outputs import BaseModelOutputWithPooling
 from transformers.models.llava_next.modeling_llava_next import LlavaNextCausalLMOutputWithPast
 
-from ....
+from ....modeling import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyOutput
 
@@ -166,19 +167,6 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
         self._padding_side = "left"  # set it to left by default, user can use setter to change padding_sides
         return super().__post_init__(**kwargs)
 
-    @classmethod
-    def get_pytorch_model(
-        cls,
-        model_id: str,
-        *args,
-        rbln_kwargs: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ) -> "PreTrainedModel":
-        # Optimum's TasksManager does not handle Llava.
-        kwargs = cls.update_kwargs(kwargs)
-        model = LlavaNextForConditionalGeneration.from_pretrained(model_id, *args, **kwargs)
-        return model
-
     def get_input_embeddings(self):
         return self.language_model.get_input_embeddings()
 
@@ -422,66 +410,6 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
 
         return outputs
 
-    def vllm_forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        pixel_values: torch.FloatTensor = None,
-        image_sizes: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        vision_feature_layer: Optional[int] = None,
-        vision_feature_select_strategy: Optional[str] = None,
-        cache_position: Union[List[torch.Tensor], torch.Tensor] = None,  # vllm keyword argument
-        batch_idx: Optional[int] = None,
-        **kwargs,
-    ) -> Union[Tuple, LlavaNextCausalLMOutputWithPast]:
-        is_prefill = cache_position.shape[-1] > 1
-
-        if inputs_embeds is not None:
-            raise NotImplementedError("Specifying inputs_embeds is not supported.")
-
-        if is_prefill:
-            # Get text_embeds
-            inputs_embeds = self.text_embedding(input_ids)
-
-            # If any images in the prompt, get image_embeds and merge with text
-            if pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) > 0:
-                image_features, _ = self.image_embedding(
-                    image_sizes, pixel_values, vision_feature_layer, vision_feature_select_strategy
-                )
-
-                def merge_vllm_multimodal_embeddings(
-                    input_ids: torch.Tensor,
-                    inputs_embeds: torch.Tensor,
-                    multimodal_embeddings: torch.Tensor,
-                    placeholder_token_id: int,
-                ) -> torch.Tensor:
-                    mask = input_ids == placeholder_token_id
-                    num_expected_tokens = mask.sum().item()
-
-                    if multimodal_embeddings.shape[0] != num_expected_tokens:
-                        raise ValueError(
-                            f"Attempted to assign {inputs_embeds[mask].shape} = {multimodal_embeddings.shape} "
-                            f"multimodal tokens to {num_expected_tokens} placeholders"
-                        )
-
-                    inputs_embeds[mask] = multimodal_embeddings
-                    return inputs_embeds
-
-                inputs_embeds = merge_vllm_multimodal_embeddings(
-                    input_ids, inputs_embeds, image_features, self.config.image_token_index
-                )
-
-        else:
-            inputs_embeds = self.text_embedding(input_ids=input_ids)
-
-        outputs: RBLNDecoderOnlyOutput = self.language_model.vllm_forward(
-            inputs_embeds=inputs_embeds,
-            batch_idx=batch_idx,
-            cache_position=cache_position,
-        )
-
-        return outputs
-
     # Almost copied from : https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/llava_next/modeling_llava_next.py
     def pack_image_features(self, image_features, image_sizes, image_newline=None):
         """
@@ -21,18 +21,25 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-
+import math
+from typing import TYPE_CHECKING, Tuple
 
 import torch
 import torch.nn as nn
-from transformers.modeling_outputs import BaseModelOutputWithPast
 
-from
-
+from ..decoderonly.decoderonly_architecture import (
+    DecoderOnlyAttention,
+    DecoderOnlyForCausalLM,
+    DecoderOnlyLayer,
+    DecoderOnlyModel,
+    DecoderOnlyWrapper,
+    apply_rotary_pos_emb_partial,
     rotate_half,
-    slice_and_unsqueeze_cos_sin,
 )
-
+
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel as MidmLMHeadModel
 
 
 def apply_rotary_to_tensor(tensor, cos, sin, rot_dim):
@@ -50,253 +57,92 @@ def apply_rotary_pos_emb(q, k, cos, sin):
     return q_embed, k_embed
 
 
-class MidmLMHeadModelWrapper(
-
-
-    def __init__(self, model, max_seq_len):
-        super().__init__()
-        self.model = model.transformer
-        self.lm_head = model.lm_head
-        self.config = model.config
-        self.max_seq_len = max_seq_len
-
-        self.config.partial_rotary_factor = model.config.rotary_percentage
-        self.config.head_dim = self.config.n_embd // self.config.n_head
+class MidmLMHeadModelWrapper(DecoderOnlyWrapper):
+    def get_rotary_emb(self, max_seq_len):
         self.config.rope_theta = 10000
-        self.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        outputs = _MidmModel.forward(
-            self.model,
-            input_ids=input_ids,
-            past_key_values=past_key_values,
-            attention_mask=attention_mask,
-            position_ids=cache_position,
-            rotary_pos_emb=self.rotary_emb,
-            batch_ids=rbln_batch_position,
-        )
-
-        hidden_states = outputs[0]
-        if batch_position >= 0:
-            hidden_states = hidden_states[:, query_idx].unsqueeze(1)
-
-        logits = self.lm_head(hidden_states)
-        output = (logits,) + outputs[1:]
-
-        return output, batch_position + query_idx
-
-
-def layernorm1p(module, input):
-    """Applies Layer Normalization with a slight modification on the weights."""
-    return torch.nn.functional.layer_norm(input, module.normalized_shape, module.weight + 1, module.bias, module.eps)
-
-
-class _MidmAttention:
-    """Custom implementation of the MidmAttention class with specific modifications."""
-
-    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
-        """Computes the attention weights and output."""
-        attn_weights = torch.matmul(query, key.transpose(-1, -2))
-
-        if self.scale_attn_weights:
-            attn_weights = attn_weights / torch.full(
-                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
+        self.config.head_dim = self.config.n_embd // self.config.n_head
+        self.config.partial_rotary_factor = self.config.rotary_percentage
+        return super().get_rotary_emb(max_seq_len=max_seq_len)
+
+    def convert_to_rbln_causal_lm(self, causal_lm: "MidmLMHeadModel"):
+        if self.attn_impl != "eager":
+            raise NotImplementedError(f"flash attention ({self.attn_impl}) is not implemented for {self.__class__}")
+        new_layers = []
+        for layer in causal_lm.transformer.h:
+            new_self_attn = MidmAttention(layer.attn)
+            new_layer = MidmLayer(layer, new_self_attn)
+            new_layers.append(new_layer)
+        new_model = MidmModel(causal_lm.transformer, new_layers)
+        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
+        return new_causal_lm
+
+
+class MidmModel(DecoderOnlyModel):
+    def get_layernorm1p(self, module: nn.LayerNorm):
+        def layernorm1p(input: torch.Tensor):
+            """Applies Layer Normalization with a slight modification on the weights."""
+            return torch.nn.functional.layer_norm(
+                input, module.normalized_shape, module.weight + 1, module.bias, module.eps
             )
 
-
-        attn_weights = attn_weights / float(self.layer_idx + 1)
-
-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
-
-        if self.scale_qk_by_inverse_layer_idx:
-            attn_weights = attn_weights * float(self.layer_idx + 1)
+        return layernorm1p
 
-
-
-
-
-
-
-        attn_output = torch.matmul(attn_weights, value)
-        return attn_output, attn_weights
-
-    def forward(
-        self,
-        hidden_states: Optional[Tuple[torch.FloatTensor]],
-        attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_value: Optional[RebelDynamicCache_4D] = None,
-        batch_index: Optional[int] = None,
-        cos: Optional[torch.Tensor] = None,
-        sin: Optional[torch.Tensor] = None,
-    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
-        """Defines the forward pass for the attention mechanism."""
-        bsz, q_len, _ = hidden_states.size()
+    def get_last_layernorm(self) -> nn.LayerNorm:
+        if self._original_mod.use_layernorm1p:
+            return self.get_layernorm1p(self._original_mod.ln_f)
+        else:
+            return self._original_mod.ln_f
 
-
+    def get_embedding(self) -> nn.Embedding:
+        return self._original_mod.wte
 
-
-
-        values = self._split_heads(values, self.num_heads, self.head_dim).contiguous()
+    def get_pos_embedding(self) -> nn.Embedding:
+        return self._original_mod.wpe
 
-        querys, keys = apply_rotary_pos_emb(querys, keys, cos, sin)
 
-
-
-
-
-
+class MidmLayer(DecoderOnlyLayer):
+    def get_layernorm1p(self, module: nn.LayerNorm):
+        def layernorm1p(input: torch.Tensor):
+            """Applies Layer Normalization with a slight modification on the weights."""
+            return torch.nn.functional.layer_norm(
+                input, module.normalized_shape, module.weight + 1, module.bias, module.eps
+            )
 
-
-            query = querys[b].unsqueeze(0)
-            attn_mask = attention_mask[b].unsqueeze(0)
-            key = keys[b].unsqueeze(0)
-            value = values[b].unsqueeze(0)
+        return layernorm1p
 
-
-
-
-
-
-            )
+    def get_pre_attention_layernorm(self) -> nn.LayerNorm:
+        if self._original_mod.use_layernorm1p:
+            return self.get_layernorm1p(self._original_mod.ln_1)
+        else:
+            return self._original_mod.ln_1
 
-
-
+    def get_post_attention_layernorm(self) -> nn.LayerNorm:
+        if self._original_mod.use_layernorm1p:
+            return self.get_layernorm1p(self._original_mod.ln_2)
+        else:
+            return self._original_mod.ln_2
 
-            all_key_states.append(key)
-            all_value_states.append(value)
-            all_attn_output.append(attn_output)
 
-
-
-
+class MidmAttention(DecoderOnlyAttention):
+    def __post_init__(self):
+        self.c_attn = self._original_mod.c_attn
+        self.o_proj = self._original_mod.c_proj
+        self.split_size = self._original_mod.split_size
+        self.num_key_value_heads = self._original_mod.num_heads
 
-
-
-
-
-        keys, values = past_key_value.update(
-            keys,
-            values,
-            self.layer_idx,
-            batch_index,
-            read_first_step=True,
-        )
+    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
+        return query_states, key_states, value_states
 
-
-
-
-
-        return attn_output, keys, values
-
-
-class _MidmBlock:
-    """Custom implementation of the MidmBlock class with specific modifications."""
-
-    def forward(
-        self,
-        hidden_states: Optional[Tuple[torch.FloatTensor]],
-        layer_idx: int,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_value: Optional[RebelDynamicCache_4D] = None,
-        batch_ids: Optional[torch.LongTensor] = None,
-        cos: Optional[torch.Tensor] = None,
-        sin: Optional[torch.Tensor] = None,
-    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
-        """Defines the forward pass for the block."""
-        residual = hidden_states
-        if self.use_layernorm1p:
-            hidden_states = layernorm1p(self.ln_1, hidden_states)
-        else:
-            hidden_states = self.ln_1(hidden_states)
-
-        hidden_states, k, v = _MidmAttention.forward(
-            self.attn,
-            hidden_states,
-            attention_mask=attention_mask,
-            past_key_value=past_key_value,
-            cos=cos,
-            sin=sin,
-            batch_index=batch_ids,
-        )
-        past_key_value.assign(k, v, layer_idx)
-
-        hidden_states = hidden_states + residual
-
-        residual = hidden_states
-        if self.use_layernorm1p:
-            hidden_states = layernorm1p(self.ln_2, hidden_states)
-        else:
-            hidden_states = self.ln_2(hidden_states)
-
-        feed_forward_hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + feed_forward_hidden_states
-
-        return hidden_states, past_key_value
-
-
-class _MidmModel:
-    """Custom implementation of the MidmModel class with specific modifications."""
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[RebelDynamicCache_4D] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        rotary_pos_emb=None,
-        batch_ids: Optional[torch.LongTensor] = None,
-    ) -> Union[Tuple, BaseModelOutputWithPast]:
-        """Defines the forward pass for the model."""
-        input_shape = input_ids.size()
-
-        attention_mask = (1.0 - attention_mask) * -10000.0
-
-        inputs_embeds = self.wte(input_ids)
-
-        cos, sin = rotary_pos_emb(inputs_embeds, attention_mask.shape[-1])
-        cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
-        hidden_states = inputs_embeds
-
-        for layer_idx, (block, _) in enumerate(zip(self.h, past_key_values)):
-            hidden_states, updated_cache = _MidmBlock.forward(
-                block,
-                hidden_states,
-                layer_idx,
-                attention_mask=attention_mask,
-                past_key_value=past_key_values,
-                batch_ids=batch_ids,
-                cos=cos,
-                sin=sin,
-            )
+    def get_attn_scale(self):
+        scale = 1.0
+        if self._original_mod.scale_attn_weights:
+            scale /= math.sqrt(self.head_dim)
 
-
-
-        hidden_states = hidden_states.view(output_shape)
+        if self._original_mod.scale_attn_by_inverse_layer_idx and not self._original_mod.scale_qk_by_inverse_layer_idx:
+            scale /= 1 + self.layer_idx
 
-
+        return scale
 
-
-
-            past_key_values=next_cache,
-        )
+    def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
+        return apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim=cos.shape[-1])
@@ -21,12 +21,11 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
+from transformers import AutoModelForCausalLM
+
 from ....utils import logging
-from
-from .
-from .midm_architecture import (
-    MidmLMHeadModelWrapper,
-)
+from ..decoderonly import RBLNDecoderOnlyModelForCausalLM
+from .midm_architecture import MidmLMHeadModelWrapper
 
 
 logger = logging.get_logger(__name__)
@@ -47,7 +46,7 @@ class RBLNMidmLMHeadModel(RBLNDecoderOnlyModelForCausalLM):
     """
 
     _decoder_wrapper_cls = MidmLMHeadModelWrapper
-
+    _hf_class = AutoModelForCausalLM
 
     @classmethod
    def from_pretrained(cls, *args, **kwargs):