optimum-rbln 0.1.12__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +27 -13
- optimum/rbln/__version__.py +16 -1
- optimum/rbln/diffusers/__init__.py +22 -2
- optimum/rbln/diffusers/models/__init__.py +34 -3
- optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
- optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +66 -111
- optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
- optimum/rbln/diffusers/models/controlnet.py +85 -65
- optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
- optimum/rbln/diffusers/models/unets/__init__.py +24 -0
- optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +129 -163
- optimum/rbln/diffusers/pipelines/__init__.py +60 -12
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -25
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -185
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -190
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -191
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -192
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -110
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -118
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +18 -128
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -131
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
- optimum/rbln/modeling.py +572 -0
- optimum/rbln/modeling_alias.py +1 -1
- optimum/rbln/modeling_base.py +176 -763
- optimum/rbln/modeling_diffusers.py +329 -0
- optimum/rbln/transformers/__init__.py +2 -2
- optimum/rbln/transformers/cache_utils.py +5 -9
- optimum/rbln/transformers/modeling_rope_utils.py +283 -0
- optimum/rbln/transformers/models/__init__.py +80 -31
- optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
- optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
- optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
- optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
- optimum/rbln/transformers/models/clip/modeling_clip.py +8 -34
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -5
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +779 -361
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +83 -142
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +64 -39
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +6 -29
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +31 -92
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +6 -31
- optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +29 -83
- optimum/rbln/transformers/models/midm/midm_architecture.py +88 -253
- optimum/rbln/transformers/models/midm/modeling_midm.py +8 -33
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
- optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
- optimum/rbln/transformers/models/phi/phi_architecture.py +61 -345
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +5 -29
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
- optimum/rbln/transformers/models/t5/__init__.py +1 -1
- optimum/rbln/transformers/models/t5/modeling_t5.py +157 -6
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
- optimum/rbln/transformers/utils/rbln_quantization.py +128 -5
- optimum/rbln/utils/decorator_utils.py +59 -0
- optimum/rbln/utils/hub.py +131 -0
- optimum/rbln/utils/import_utils.py +21 -0
- optimum/rbln/utils/model_utils.py +53 -0
- optimum/rbln/utils/runtime_utils.py +5 -5
- optimum/rbln/utils/submodule.py +114 -0
- optimum/rbln/utils/timer_utils.py +2 -2
- optimum_rbln-0.1.15.dist-info/METADATA +106 -0
- optimum_rbln-0.1.15.dist-info/RECORD +110 -0
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/generation/streamers.py +0 -139
- optimum/rbln/transformers/generation/utils.py +0 -397
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
- optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
- optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
- optimum_rbln-0.1.12.dist-info/METADATA +0 -119
- optimum_rbln-0.1.12.dist-info/RECORD +0 -103
- optimum_rbln-0.1.12.dist-info/entry_points.txt +0 -4
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/llava_next/modeling_llava_next.py

@@ -23,7 +23,7 @@
 import inspect
 import logging
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict,
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union

 import numpy as np
 import torch

@@ -36,7 +36,7 @@ from transformers import (
 from transformers.modeling_outputs import BaseModelOutputWithPooling
 from transformers.models.llava_next.modeling_llava_next import LlavaNextCausalLMOutputWithPast

-from ....
+from ....modeling import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyOutput

@@ -166,19 +166,6 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
         self._padding_side = "left"  # set it to left by default, user can use setter to change padding_sides
         return super().__post_init__(**kwargs)

-    @classmethod
-    def get_pytorch_model(
-        cls,
-        model_id: str,
-        *args,
-        rbln_kwargs: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ) -> "PreTrainedModel":
-        # Optimum's TasksManager does not handle Llava.
-        kwargs = cls.update_kwargs(kwargs)
-        model = LlavaNextForConditionalGeneration.from_pretrained(model_id, *args, **kwargs)
-        return model
-
     def get_input_embeddings(self):
         return self.language_model.get_input_embeddings()

@@ -350,9 +337,22 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
         is_prefill_phase = not generate_idx.bool().all()

         if is_prefill_phase:
+            # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing
+            # not very reliable, but we don't expect one to actually pass 500+ images for one prompt
+            # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True
+            legacy_processing = (
+                (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
+            ) or (input_ids.shape[-1] == 1 and pixel_values is not None)
+
             # Get the number of images in the prompt
             special_image_token_masks = [input_id == self.config.image_token_index for input_id in input_ids]
-
+            if legacy_processing:
+                num_special_image_tokens = [torch.sum(mask, dim=-1) for mask in special_image_token_masks]
+            else:
+                image_tokens_masks_diff = [
+                    torch.diff(mask, prepend=torch.tensor([0])) for mask in special_image_token_masks
+                ]
+                num_special_image_tokens = [int(torch.sum((diff == 1).int())) for diff in image_tokens_masks_diff]

             # Split images for each prompt
             if pixel_values is not None and pixel_values.size(0) > 0:

@@ -370,13 +370,19 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
                 image_features, feature_lens = self.image_embedding(
                     image_sizes[b_idx], pixel_values[b_idx], vision_feature_layer, vision_feature_select_strategy
                 )
-
-
-
-
-
-
-
+                if legacy_processing:
+                    inputs_embed, _, _, _, _ = self._merge_input_ids_with_image_features(
+                        image_features,
+                        feature_lens,
+                        inputs_embed.to(image_features.dtype),
+                        input_id,
+                        torch.ones_like(input_id, dtype=torch.long),
+                    )
+                else:
+                    special_image_mask = (
+                        (input_id == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embed)
+                    )
+                    inputs_embed = inputs_embed.masked_scatter(special_image_mask, image_features)

                 # Update generate_idx according to inputs_embed
                 generate_idx[b_idx] = inputs_embed.shape[1]
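In the non-legacy branch added above, image features are merged into the text embedding with `torch.Tensor.masked_scatter`. A minimal standalone sketch of that merge step, with toy tensor sizes and a made-up `IMAGE_TOKEN_ID` standing in for `config.image_token_index` (illustration only, not the pipeline's actual shapes):

```python
import torch

# Toy dimensions: one prompt of 6 tokens, hidden size 4, two image placeholder tokens.
IMAGE_TOKEN_ID = 32000  # hypothetical placeholder id; the model uses config.image_token_index
input_id = torch.tensor([1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 5, 6, 7])
inputs_embed = torch.zeros(1, 6, 4)   # text embeddings for the prompt
image_features = torch.ones(2, 4)     # one feature row per image placeholder token

# Mask that is True at every image-token position, broadcast over the hidden dimension
# so it matches inputs_embed's shape.
special_image_mask = (input_id == IMAGE_TOKEN_ID).unsqueeze(-1).expand_as(inputs_embed)

# masked_scatter consumes image_features element by element, writing wherever the mask is True.
merged = inputs_embed.masked_scatter(special_image_mask, image_features)
print(merged[0, 1])  # -> tensor([1., 1., 1., 1.])
```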
@@ -403,66 +409,6 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):

         return outputs

-    def vllm_forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        pixel_values: torch.FloatTensor = None,
-        image_sizes: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        vision_feature_layer: Optional[int] = None,
-        vision_feature_select_strategy: Optional[str] = None,
-        cache_position: Union[List[torch.Tensor], torch.Tensor] = None,  # vllm keyword argument
-        batch_idx: Optional[int] = None,
-        **kwargs,
-    ) -> Union[Tuple, LlavaNextCausalLMOutputWithPast]:
-        is_prefill = cache_position.shape[-1] > 1
-
-        if inputs_embeds is not None:
-            raise NotImplementedError("Specifying inputs_embeds is not supported.")
-
-        if is_prefill:
-            # Get text_embeds
-            inputs_embeds = self.text_embedding(input_ids)
-
-            # If any images in the prompt, get image_embeds and merge with text
-            if pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) > 0:
-                image_features, _ = self.image_embedding(
-                    image_sizes, pixel_values, vision_feature_layer, vision_feature_select_strategy
-                )
-
-                def merge_vllm_multimodal_embeddings(
-                    input_ids: torch.Tensor,
-                    inputs_embeds: torch.Tensor,
-                    multimodal_embeddings: torch.Tensor,
-                    placeholder_token_id: int,
-                ) -> torch.Tensor:
-                    mask = input_ids == placeholder_token_id
-                    num_expected_tokens = mask.sum().item()
-
-                    if multimodal_embeddings.shape[0] != num_expected_tokens:
-                        raise ValueError(
-                            f"Attempted to assign {inputs_embeds[mask].shape} = {multimodal_embeddings.shape} "
-                            f"multimodal tokens to {num_expected_tokens} placeholders"
-                        )
-
-                    inputs_embeds[mask] = multimodal_embeddings
-                    return inputs_embeds
-
-                inputs_embeds = merge_vllm_multimodal_embeddings(
-                    input_ids, inputs_embeds, image_features, self.config.image_token_index
-                )
-
-        else:
-            inputs_embeds = self.text_embedding(input_ids=input_ids)
-
-        outputs: RBLNDecoderOnlyOutput = self.language_model.vllm_forward(
-            inputs_embeds=inputs_embeds,
-            batch_idx=batch_idx,
-            cache_position=cache_position,
-        )
-
-        return outputs
-
     # Almost copied from : https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/llava_next/modeling_llava_next.py
     def pack_image_features(self, image_features, image_sizes, image_newline=None):
         """
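For comparison, the `merge_vllm_multimodal_embeddings` helper deleted above performed the same merge with a boolean-index assignment guarded by a placeholder-count check. A reduced sketch of that pattern (toy tensors, hypothetical `PLACEHOLDER_ID`), shown only to contrast it with the `masked_scatter` path that remains in `forward`:

```python
import torch

PLACEHOLDER_ID = 32000  # stand-in for config.image_token_index
input_ids = torch.tensor([[1, PLACEHOLDER_ID, PLACEHOLDER_ID, 4]])
inputs_embeds = torch.zeros(1, 4, 3)
multimodal_embeddings = torch.ones(2, 3)  # one embedding row per placeholder

mask = input_ids == PLACEHOLDER_ID
num_expected_tokens = mask.sum().item()
if multimodal_embeddings.shape[0] != num_expected_tokens:
    raise ValueError(
        f"Got {multimodal_embeddings.shape[0]} multimodal rows for {num_expected_tokens} placeholders"
    )

# Boolean-index assignment writes one embedding row per True position in the mask,
# which is equivalent to the masked_scatter form for correctly sized inputs.
inputs_embeds[mask] = multimodal_embeddings
```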
optimum/rbln/transformers/models/midm/midm_architecture.py

@@ -21,18 +21,24 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-from typing import
+from typing import TYPE_CHECKING, Tuple

 import torch
 import torch.nn as nn
-from transformers.modeling_outputs import BaseModelOutputWithPast

-from ....transformers.models.decoderonly.decoderonly_architecture import
-
-
-
+from ....transformers.models.decoderonly.decoderonly_architecture import rotate_half
+from ..decoderonly.decoderonly_architecture import (
+    DecoderOnlyAttention,
+    DecoderOnlyForCausalLM,
+    DecoderOnlyLayer,
+    DecoderOnlyModel,
+    DecoderOnlyWrapper,
+    apply_rotary_pos_emb_partial,
 )
-
+
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel as MidmLMHeadModel


 def apply_rotary_to_tensor(tensor, cos, sin, rot_dim):
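The new imports above (`rotate_half`, `apply_rotary_pos_emb_partial`) reflect that Midm applies rotary embeddings to only a fraction of each head dimension (`rotary_percentage`). A self-contained sketch of the partial-rotary idea, assuming the usual `rotate_half` convention; this is illustrative, not the library's exact implementation:

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Standard RoPE helper: split the last dim in half and map (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_partial(q, k, cos, sin, rot_dim):
    # Rotate only the first rot_dim channels; pass the remaining channels through unchanged.
    q_rot, q_pass = q[..., :rot_dim], q[..., rot_dim:]
    k_rot, k_pass = k[..., :rot_dim], k[..., rot_dim:]
    q_rot = q_rot * cos + rotate_half(q_rot) * sin
    k_rot = k_rot * cos + rotate_half(k_rot) * sin
    return torch.cat([q_rot, q_pass], dim=-1), torch.cat([k_rot, k_pass], dim=-1)

# Toy shapes: batch=1, heads=2, seq=4, head_dim=8, rotary_percentage=0.5 -> rot_dim=4.
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
cos = torch.randn(1, 1, 4, 4)  # broadcast over heads; last dim equals rot_dim
sin = torch.randn(1, 1, 4, 4)
q_embed, k_embed = apply_rotary_partial(q, k, cos, sin, rot_dim=4)
```

This is consistent with the rewritten attention below passing `ndim=cos.shape[-1]` to `apply_rotary_pos_emb_partial`.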
@@ -50,264 +56,93 @@ def apply_rotary_pos_emb(q, k, cos, sin):
     return q_embed, k_embed


-class MidmLMHeadModelWrapper(
-
+class MidmLMHeadModelWrapper(DecoderOnlyWrapper):
+    def get_rotary_emb(self, max_seq_len):
+        self.config.rope_theta = 10000
+        self.config.head_dim = self.config.n_embd // self.config.n_head
+        self.config.partial_rotary_factor = self.config.rotary_percentage
+        return super().get_rotary_emb(max_seq_len=max_seq_len)
+
+    def convert_to_rbln_causal_lm(self, causal_lm: "MidmLMHeadModel"):
+        if self.attn_impl != "eager":
+            raise NotImplementedError(f"flash attention ({self.attn_impl}) is not implemented for {self.__class__}")
+        new_layers = []
+        for layer in causal_lm.transformer.h:
+            new_self_attn = MidmAttention(layer.attn)
+            new_layer = MidmLayer(layer, new_self_attn)
+            new_layers.append(new_layer)
+        new_model = MidmModel(causal_lm.transformer, new_layers)
+        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
+        return new_causal_lm
+
+
+class MidmModel(DecoderOnlyModel):
+    mask_fmin = -10000.0
+
+    def get_layernorm1p(self, module: nn.LayerNorm):
+        def layernorm1p(input: torch.Tensor):
+            """Applies Layer Normalization with a slight modification on the weights."""
+            return torch.nn.functional.layer_norm(
+                input, module.normalized_shape, module.weight + 1, module.bias, module.eps
+            )

-
-        super().__init__()
-        self.model = model.transformer
-        self.lm_head = model.lm_head
-        self.config = model.config
-        self.head_dim = self.config.n_embd // self.config.n_head
-        self.max_position_embeddings = (
-            self.config.max_position_embeddings if max_seq_len > self.config.max_position_embeddings else max_seq_len
-        )
-        self.max_seq_len = max_seq_len
-        self.rotary_dim = int(
-            model.config.hidden_size // model.config.num_attention_heads * model.config.rotary_percentage
-        )
-        self.rotary_emb = self._init_rope()
+        return layernorm1p

-    def
-
-
-            self.rotary_dim,
-            max_position_embeddings=self.max_position_embeddings,
-        )
-        return rotary_emb
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        attention_mask: torch.Tensor,
-        cache_position: torch.LongTensor,
-        batch_position: int,
-        query_idx: int,
-        *past_key_values,
-    ):
-        """Defines the forward pass for the wrapper model."""
-        if input_ids.shape[1] == 1:
-            rbln_batch_position = None
+    def get_last_layernorm(self) -> nn.LayerNorm:
+        if self._original_mod.use_layernorm1p:
+            return self.get_layernorm1p(self._original_mod.ln_f)
         else:
-
-
-        past_key_values = RebelDynamicCache_4D.from_input_format(
-            cache_position,
-            self.config.num_hidden_layers,
-            *past_key_values,
-        )
-
-        outputs = _MidmModel.forward(
-            self.model,
-            input_ids=input_ids,
-            past_key_values=past_key_values,
-            attention_mask=attention_mask,
-            position_ids=cache_position,
-            rotary_pos_emb=self.rotary_emb,
-            batch_ids=rbln_batch_position,
-        )
+            return self._original_mod.ln_f

-
-
-        hidden_states = hidden_states[:, query_idx].unsqueeze(1)
+    def get_embedding(self) -> nn.Embedding:
+        return self._original_mod.wte

-
-
+    def get_pos_embedding(self) -> nn.Embedding:
+        return self._original_mod.wpe

-        return output, batch_position + query_idx

-
-    def
-
-
-
-
-class _MidmAttention:
-    """Custom implementation of the MidmAttention class with specific modifications."""
-
-    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
-        """Computes the attention weights and output."""
-        attn_weights = torch.matmul(query, key.transpose(-1, -2))
-
-        if self.scale_attn_weights:
-            attn_weights = attn_weights / torch.full(
-                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
+class MidmLayer(DecoderOnlyLayer):
+    def get_layernorm1p(self, module: nn.LayerNorm):
+        def layernorm1p(input: torch.Tensor):
+            """Applies Layer Normalization with a slight modification on the weights."""
+            return torch.nn.functional.layer_norm(
+                input, module.normalized_shape, module.weight + 1, module.bias, module.eps
             )

-
-        attn_weights = attn_weights / float(self.layer_idx + 1)
-
-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
-
-        if self.scale_qk_by_inverse_layer_idx:
-            attn_weights = attn_weights * float(self.layer_idx + 1)
-
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-        attn_weights = attn_weights.type(value.dtype)
-
-        if head_mask is not None:
-            attn_weights = attn_weights * head_mask
-
-        attn_output = torch.matmul(attn_weights, value)
-        return attn_output, attn_weights
-
-    def forward(
-        self,
-        hidden_states: Optional[Tuple[torch.FloatTensor]],
-        attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_value: Optional[RebelDynamicCache_4D] = None,
-        batch_index: Optional[int] = None,
-        cos: Optional[torch.Tensor] = None,
-        sin: Optional[torch.Tensor] = None,
-    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
-        """Defines the forward pass for the attention mechanism."""
-        bsz, q_len, _ = hidden_states.size()
-
-        querys, keys, values = self.c_attn(hidden_states).split(self.split_size, dim=2)
-
-        querys = self._split_heads(querys, self.num_heads, self.head_dim).contiguous()
-        keys = self._split_heads(keys, self.num_heads, self.head_dim).contiguous()
-        values = self._split_heads(values, self.num_heads, self.head_dim).contiguous()
-
-        querys, keys = apply_rotary_pos_emb(querys, keys, cos, sin)
-
-        # Decoder
-        if (batch_index is None or batch_index == -1) and bsz > 1:
-            all_key_states = []
-            all_value_states = []
-            all_attn_output = []
-
-            for b in range(bsz):
-                query = querys[b].unsqueeze(0)
-                attn_mask = attention_mask[b].unsqueeze(0)
-                key = keys[b].unsqueeze(0)
-                value = values[b].unsqueeze(0)
-
-                key, value = past_key_value.update(
-                    key,
-                    value,
-                    self.layer_idx,
-                    b,
-                )
-
-                attn_output, _ = _MidmAttention._attn(self, query, key, value, attn_mask)
-                attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
-
-                all_key_states.append(key)
-                all_value_states.append(value)
-                all_attn_output.append(attn_output)
-
-            keys = torch.cat(all_key_states, dim=0)
-            values = torch.cat(all_value_states, dim=0)
-            attn_output = torch.cat(all_attn_output, dim=0)
+        return layernorm1p

+    def get_pre_attention_layernorm(self) -> nn.LayerNorm:
+        if self._original_mod.use_layernorm1p:
+            return self.get_layernorm1p(self._original_mod.ln_1)
         else:
-
-            batch_index = 0
-
-            keys, values = past_key_value.update(
-                keys,
-                values,
-                self.layer_idx,
-                batch_index,
-                read_first_step=True,
-            )
+            return self._original_mod.ln_1

-
-
-
-        attn_output = self.c_proj(attn_output)
-        return attn_output, keys, values
-
-
-class _MidmBlock:
-    """Custom implementation of the MidmBlock class with specific modifications."""
-
-    def forward(
-        self,
-        hidden_states: Optional[Tuple[torch.FloatTensor]],
-        layer_idx: int,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_value: Optional[RebelDynamicCache_4D] = None,
-        batch_ids: Optional[torch.LongTensor] = None,
-        cos: Optional[torch.Tensor] = None,
-        sin: Optional[torch.Tensor] = None,
-    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
-        """Defines the forward pass for the block."""
-        residual = hidden_states
-        if self.use_layernorm1p:
-            hidden_states = layernorm1p(self.ln_1, hidden_states)
+    def get_post_attention_layernorm(self) -> nn.LayerNorm:
+        if self._original_mod.use_layernorm1p:
+            return self.get_layernorm1p(self._original_mod.ln_2)
         else:
-
-
-
-
-
-
-
-
-
-
+            return self._original_mod.ln_2
+
+
+class MidmAttention(DecoderOnlyAttention):
+    def __post_init__(self):
+        self.c_attn = self._original_mod.c_attn
+        self.o_proj = self._original_mod.c_proj
+        self.split_size = self._original_mod.split_size
+        self.num_key_value_heads = self._original_mod.num_heads
+
+    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
+        return query_states, key_states, value_states
+
+    def rbln_attention(self, *args, **kwargs):
+        return super().rbln_attention(
+            *args,
+            **kwargs,
+            layer_idx=self.layer_idx,
+            scale_attn_weights=self._original_mod.scale_attn_weights,
+            scale_attn_by_inverse_layer_idx=self._original_mod.scale_attn_by_inverse_layer_idx,
         )
-        past_key_value.assign(k, v, layer_idx)
-
-        hidden_states = hidden_states + residual

-
-
-            hidden_states = layernorm1p(self.ln_2, hidden_states)
-        else:
-            hidden_states = self.ln_2(hidden_states)
-
-        feed_forward_hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + feed_forward_hidden_states
-
-        return hidden_states, past_key_value
-
-
-class _MidmModel:
-    """Custom implementation of the MidmModel class with specific modifications."""
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[RebelDynamicCache_4D] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        rotary_pos_emb=None,
-        batch_ids: Optional[torch.LongTensor] = None,
-    ) -> Union[Tuple, BaseModelOutputWithPast]:
-        """Defines the forward pass for the model."""
-        input_shape = input_ids.size()
-
-        attention_mask = (1.0 - attention_mask) * -10000.0
-
-        inputs_embeds = self.wte(input_ids)
-
-        cos, sin = rotary_pos_emb(inputs_embeds, attention_mask.shape[-1])
-        cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
-        hidden_states = inputs_embeds
-
-        for layer_idx, (block, _) in enumerate(zip(self.h, past_key_values)):
-            hidden_states, updated_cache = _MidmBlock.forward(
-                block,
-                hidden_states,
-                layer_idx,
-                attention_mask=attention_mask,
-                past_key_value=past_key_values,
-                batch_ids=batch_ids,
-                cos=cos,
-                sin=sin,
-            )
-
-        hidden_states = layernorm1p(self.ln_f, hidden_states)
-        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
-        hidden_states = hidden_states.view(output_shape)
-
-        next_cache = updated_cache.to_legacy_cache()
-
-        return BaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=next_cache,
-        )
+    def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
+        return apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim=cos.shape[-1])
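The `get_layernorm1p` closures in the rewritten architecture above compute LayerNorm with the module's `weight + 1` rather than `weight`. A quick standalone check of what that offset means (toy module, not RBLN code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

ln = nn.LayerNorm(8)
nn.init.normal_(ln.weight)  # make the effect visible; the default weight is all ones
x = torch.randn(2, 8)

# layernorm1p: identical to LayerNorm except the learned scale is offset by +1,
# so a zero-initialized weight corresponds to an identity scale.
out_1p = F.layer_norm(x, ln.normalized_shape, ln.weight + 1, ln.bias, ln.eps)
out_std = F.layer_norm(x, ln.normalized_shape, ln.weight, ln.bias, ln.eps)

# The two outputs differ by exactly the plain normalized input.
diff = out_1p - out_std
print(torch.allclose(diff, F.layer_norm(x, ln.normalized_shape)))  # True
```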
optimum/rbln/transformers/models/midm/modeling_midm.py

@@ -21,23 +21,15 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-import inspect
-import logging
-from typing import TYPE_CHECKING, Any, Callable

-from
-from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
-from .hf_hub_cached.modeling_midm import MidmLMHeadModel
-from .midm_architecture import (
-    MidmLMHeadModelWrapper,
-)
+from transformers import AutoModelForCausalLM

+from ....utils import logging
+from ..decoderonly import RBLNDecoderOnlyModelForCausalLM
+from .midm_architecture import MidmLMHeadModelWrapper

-
-
-from transformers import (
-    PreTrainedModel,
-)
+
+logger = logging.get_logger(__name__)


 class RBLNMidmLMHeadModel(RBLNDecoderOnlyModelForCausalLM):

@@ -54,25 +46,8 @@ class RBLNMidmLMHeadModel(RBLNDecoderOnlyModelForCausalLM):

     """

-
-
-        rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-        return MidmLMHeadModelWrapper(model, rbln_max_seq_len).eval()
-
-    def __getattr__(self, __name: str) -> Any:
-        """This is the key method to implement RBLN-Midm.
-
-        Returns:
-            Any: Midm's corresponding method
-        """
-
-        def redirect(func):
-            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-        val = getattr(MidmLMHeadModel, __name)
-        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-            return redirect(val)
-        return val
+    _decoder_wrapper_cls = MidmLMHeadModelWrapper
+    _hf_class = AutoModelForCausalLM

     @classmethod
     def from_pretrained(cls, *args, **kwargs):
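After this change the Midm class (and Mistral, below) only declares `_decoder_wrapper_cls` and, where needed, `_hf_class`; the shared export logic lives in `RBLNDecoderOnlyModelForCausalLM`, whose internals are not part of this diff. The following is a hypothetical sketch of how a base class can consume such declarative attributes — the method names (`get_pytorch_model`, `wrap_model_if_needed`) mirror the removed per-model code and are illustrative, not the library's confirmed API:

```python
# Hypothetical sketch only: RBLNDecoderOnlyModelForCausalLM's real implementation is not shown in this diff.
from transformers import AutoModelForCausalLM


class DecoderOnlyBaseSketch:
    _decoder_wrapper_cls = None        # set by each subclass, e.g. MidmLMHeadModelWrapper
    _hf_class = AutoModelForCausalLM   # overridable, e.g. for trust_remote_code checkpoints

    @classmethod
    def get_pytorch_model(cls, model_id: str, **kwargs):
        # One shared loading path replaces the per-model __getattr__ redirection.
        return cls._hf_class.from_pretrained(model_id, **kwargs)

    @classmethod
    def wrap_model_if_needed(cls, model, max_seq_len: int):
        # The wrapper class converts the HF module graph into the RBLN-compilable form.
        return cls._decoder_wrapper_cls(model, max_seq_len).eval()
```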
optimum/rbln/transformers/models/mistral/modeling_mistral.py

@@ -21,29 +21,18 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-import
-import logging
-from typing import TYPE_CHECKING, Any, Callable
-
-from transformers import MistralForCausalLM
-
+from ....utils import logging
 from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .mistral_architecture import MistralForCausalLMWrapper


-
-from transformers import PreTrainedModel
-
-from ....modeling_config import RBLNConfig
-
-
-logger = logging.getLogger(__name__)
+logger = logging.get_logger(__name__)


 class RBLNMistralForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """
     The Llama Model transformer with a language modeling head (linear layer) on top.
-    This model inherits from [`
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.

     A class to convert and run pre-trained transformers based LlamaForCausalLM model on RBLN devices.
     It implements the methods to convert a pre-trained transformers LlamaForCausalLM model into a RBLN transformer model by:

@@ -51,18 +40,4 @@ class RBLNMistralForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     - compiling the resulting graph using the RBLN compiler.
     """

-
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-        return MistralForCausalLMWrapper(model, rbln_max_seq_len).eval()
-
-    def __getattr__(self, __name: str) -> Any:
-        def redirect(func):
-            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-        val = getattr(MistralForCausalLM, __name)
-
-        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-            return redirect(val)
-
-        return val
+    _decoder_wrapper_cls = MistralForCausalLMWrapper
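End-user code is unaffected by the wrapper refactor; compilation still goes through `from_pretrained(..., export=True)`. A hedged usage sketch — the checkpoint id and the `rbln_*` values are placeholders, and the exact set of supported `rbln_*` options should be checked against the optimum-rbln documentation for this release:

```python
from transformers import AutoTokenizer
from optimum.rbln import RBLNMistralForCausalLM

model_id = "mistralai/Mistral-7B-Instruct-v0.2"  # placeholder checkpoint

# export=True converts and compiles the Hugging Face checkpoint for RBLN NPUs;
# the rbln_* kwargs feed the compile-time RBLNConfig (values here are examples).
model = RBLNMistralForCausalLM.from_pretrained(
    model_id,
    export=True,
    rbln_max_seq_len=4096,
    rbln_batch_size=1,
)
model.save_pretrained("mistral-7b-rbln")  # reload later without re-compiling

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Hello, RBLN!", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```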