optimum-rbln 0.1.11__py3-none-any.whl → 0.1.13__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in one of the supported registries. It is provided for informational purposes only and reflects the changes between those published versions.
- optimum/rbln/__init__.py +14 -7
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +30 -63
- optimum/rbln/diffusers/models/controlnet.py +36 -62
- optimum/rbln/diffusers/models/unet_2d_condition.py +57 -156
- optimum/rbln/diffusers/pipelines/__init__.py +40 -12
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -187
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -192
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +8 -206
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -207
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -111
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +12 -117
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +4 -123
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +4 -126
- optimum/rbln/modeling_alias.py +4 -9
- optimum/rbln/modeling_base.py +117 -144
- optimum/rbln/modeling_config.py +51 -0
- optimum/rbln/modeling_diffusers.py +400 -0
- optimum/rbln/transformers/__init__.py +10 -0
- optimum/rbln/transformers/cache_utils.py +5 -9
- optimum/rbln/transformers/modeling_rope_utils.py +283 -0
- optimum/rbln/transformers/models/__init__.py +80 -28
- optimum/rbln/transformers/models/auto/modeling_auto.py +1 -0
- optimum/rbln/transformers/models/bart/__init__.py +1 -1
- optimum/rbln/transformers/models/bart/bart_architecture.py +18 -12
- optimum/rbln/transformers/models/bart/modeling_bart.py +25 -6
- optimum/rbln/transformers/models/bert/modeling_bert.py +1 -2
- optimum/rbln/transformers/models/clip/modeling_clip.py +13 -23
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -2
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +376 -218
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +246 -116
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +0 -1
- optimum/rbln/transformers/models/exaone/__init__.py +32 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +81 -0
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +181 -0
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +1725 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +53 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +4 -30
- optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +166 -151
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -15
- optimum/rbln/transformers/models/midm/modeling_midm.py +8 -28
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
- optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
- optimum/rbln/transformers/models/phi/phi_architecture.py +75 -159
- optimum/rbln/transformers/models/qwen2/__init__.py +24 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +43 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +29 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +24 -0
- optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} +107 -166
- optimum/rbln/transformers/models/t5/__init__.py +1 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +108 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +46 -32
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +0 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +38 -13
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +1 -2
- optimum/rbln/transformers/utils/rbln_quantization.py +8 -2
- optimum/rbln/utils/context.py +58 -0
- optimum/rbln/utils/decorator_utils.py +55 -0
- optimum/rbln/utils/import_utils.py +21 -0
- optimum/rbln/utils/logging.py +1 -1
- optimum/rbln/utils/runtime_utils.py +4 -4
- optimum/rbln/utils/timer_utils.py +26 -2
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/METADATA +11 -9
- optimum_rbln-0.1.13.dist-info/RECORD +107 -0
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/WHEEL +1 -1
- optimum_rbln-0.1.11.dist-info/RECORD +0 -93
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/exaone/modeling_exaone.py (new file)

```diff
@@ -0,0 +1,53 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from ....utils import logging
+from ..decoderonly import RBLNDecoderOnlyModelForCausalLM
+from .exaone_architecture import ExaoneForCausalLMWrapper
+from .hf_hub_cached.modeling_exaone import ExaoneForCausalLM
+
+
+logger = logging.get_logger(__name__)
+
+
+class RBLNExaoneForCausalLM(RBLNDecoderOnlyModelForCausalLM):
+    """
+    The Exaone Model transformer with a language modeling head on top (linear layer with weights tied to the input
+    embeddings).
+
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the
+    library implements for all its model.
+
+    It implements the methods to convert a pre-trained transformers Exaone model into a RBLN transformer model by:
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+
+    """
+
+    _decoder_wrapper_cls = ExaoneForCausalLMWrapper
+    _original_cls = ExaoneForCausalLM
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        kwargs.setdefault("trust_remote_code", True)
+        return super().from_pretrained(*args, **kwargs)
```
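A minimal usage sketch for the new EXAONE class is shown below. Only the `trust_remote_code` default in `from_pretrained` is confirmed by the hunk above; the root-level re-export, the `export` flag, the `rbln_*` compile-time keyword arguments and the checkpoint name follow the pattern of the other optimum-rbln causal-LM classes and are assumptions, not part of this diff.

```python
# Hypothetical usage sketch for RBLNExaoneForCausalLM.
# export= and rbln_* kwargs are assumed from the other optimum-rbln model classes.
from optimum.rbln import RBLNExaoneForCausalLM
from transformers import AutoTokenizer

model_id = "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"  # example checkpoint, assumed

# trust_remote_code defaults to True here (see from_pretrained above);
# export=True asks optimum-rbln to compile the checkpoint into an RBLN graph.
model = RBLNExaoneForCausalLM.from_pretrained(
    model_id,
    export=True,
    rbln_max_seq_len=4096,  # assumed compile-time option
    rbln_batch_size=1,      # assumed compile-time option
)

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
inputs = tokenizer("Hello, my name is", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```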
optimum/rbln/transformers/models/gemma/gemma_architecture.py

```diff
@@ -29,11 +29,11 @@ from transformers.modeling_outputs import (
 )
 
 from ...models.decoderonly import (
-    DecoderOnlyAttention,
     DecoderOnlyDecoderLayer,
     DecoderOnlyWrapper,
     slice_and_unsqueeze_cos_sin,
 )
+from ...models.decoderonly.decoderonly_architecture import DECODERONLY_ATTENTION_CLASSES
 
 
 class GemmaWrapper(DecoderOnlyWrapper):
@@ -43,7 +43,7 @@ class GemmaWrapper(DecoderOnlyWrapper):
             {
                 "wrapper": GemmaModel.forward,
                 "model": DecoderOnlyDecoderLayer.forward,
-                "decoder_layer":
+                "decoder_layer": DECODERONLY_ATTENTION_CLASSES[self.attn_implementation].forward,
             }
         )
         return forward_dict
@@ -61,9 +61,17 @@ class GemmaModel:
        use_cache: Optional[bool] = True,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
+       cache_pos_for_partitions: Optional[torch.Tensor] = None,
+       kvcache_partition_size: Optional[torch.Tensor] = None,
        forward_dict: Optional[Dict[str, classmethod]] = None,
        rotary_pos_emb=None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
+        # retrieve input_ids and inputs_embeds
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+            )
+
         # embed positions
         inputs_embeds = self.embed_tokens(input_ids)
         hidden_states = inputs_embeds
@@ -96,6 +104,8 @@ class GemmaModel:
                 batch_ids=batch_ids,
                 cos=cos,
                 sin=sin,
+                cache_pos_for_partitions=cache_pos_for_partitions,
+                kvcache_partition_size=kvcache_partition_size,
                 forward_dict=forward_dict,
             )
 
```
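The hard-wired `DecoderOnlyAttention` import is replaced by a lookup in `DECODERONLY_ATTENTION_CLASSES` keyed by `self.attn_implementation`. A minimal, self-contained sketch of that registry-dispatch pattern follows; the attention classes and the registry keys used here are illustrative assumptions, only the `DECODERONLY_ATTENTION_CLASSES` and `attn_implementation` names come from the diff.

```python
# Illustrative sketch of the registry-dispatch pattern used above.
# The keys and attention classes are assumptions made for this example.
class EagerAttention:
    def forward(self, hidden_states):
        return f"eager({hidden_states})"


class FlashAttention:
    def forward(self, hidden_states):
        return f"flash({hidden_states})"


DECODERONLY_ATTENTION_CLASSES = {
    "eager": EagerAttention,
    "flash_attn": FlashAttention,
}


class Wrapper:
    def __init__(self, attn_implementation: str = "eager"):
        self.attn_implementation = attn_implementation

    def get_forward_dict(self):
        # Pick the attention forward at wrap time instead of importing one fixed class,
        # so every model wrapper can share a single dispatch table.
        return {"decoder_layer": DECODERONLY_ATTENTION_CLASSES[self.attn_implementation].forward}


fwd = Wrapper("flash_attn").get_forward_dict()["decoder_layer"]
print(fwd(FlashAttention(), "x"))  # flash(x)
```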
optimum/rbln/transformers/models/gemma/modeling_gemma.py

```diff
@@ -21,28 +21,18 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-import
-import logging
-from typing import TYPE_CHECKING, Any, Callable
-
-from transformers import GemmaForCausalLM
-
+from ....utils import logging
 from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .gemma_architecture import GemmaWrapper
 
 
-
-from transformers import PreTrainedModel
-
-from ....modeling_config import RBLNConfig
-
-logger = logging.getLogger(__name__)
+logger = logging.get_logger(__name__)
 
 
 class RBLNGemmaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """
     The Gemma Model transformer with a language modeling head (linear layer) on top.
-    This model inherits from [`
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
 
     A class to convert and run pre-trained transformers based GemmaForCausalLM model on RBLN devices.
     It implements the methods to convert a pre-trained transformers GemmaForCausalLM model into a RBLN transformer model by:
@@ -50,18 +40,4 @@ class RBLNGemmaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     - compiling the resulting graph using the RBLN compiler.
     """
 
-
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-        return GemmaWrapper(model, rbln_max_seq_len).eval()
-
-    def __getattr__(self, __name: str) -> Any:
-        def redirect(func):
-            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-        val = getattr(GemmaForCausalLM, __name)
-
-        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-            return redirect(val)
-
-        return val
+    _decoder_wrapper_cls = GemmaWrapper
```
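The per-model `wrap_model_if_needed` and `__getattr__` overrides are replaced by a single declarative class attribute, `_decoder_wrapper_cls` (plus `_original_cls` in the EXAONE case). A rough sketch of how a shared base class can consume such an attribute is given below; the helper name `wrap_model` and the `rbln_config` shape are assumptions, the real logic lives in `RBLNDecoderOnlyModelForCausalLM` in `modeling_decoderonly.py` and is not reproduced here.

```python
# Sketch of the declarative wrapper-class pattern, assuming a base class that
# reads _decoder_wrapper_cls; names other than _decoder_wrapper_cls are hypothetical.
class DummyWrapper:
    def __init__(self, model, max_seq_len):
        self.model, self.max_seq_len = model, max_seq_len

    def eval(self):
        return self


class BaseForCausalLM:
    _decoder_wrapper_cls = None  # each subclass only has to override this

    @classmethod
    def wrap_model(cls, model, rbln_config: dict):
        # Hypothetical helper: the real base class builds the torch wrapper it will compile.
        return cls._decoder_wrapper_cls(model, rbln_config["max_seq_len"]).eval()


class GemmaLike(BaseForCausalLM):
    _decoder_wrapper_cls = DummyWrapper  # stands in for GemmaWrapper


wrapped = GemmaLike.wrap_model(model="hf-model", rbln_config={"max_seq_len": 8192})
print(type(wrapped).__name__, wrapped.max_seq_len)  # DummyWrapper 8192
```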
optimum/rbln/transformers/models/gpt2/modeling_gpt2.py

```diff
@@ -21,20 +21,12 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-import
-import logging
-from typing import TYPE_CHECKING, Any, Callable
-
-from transformers import GPT2LMHeadModel
-
-from ....modeling_config import RBLNConfig
+from ....utils import logging
 from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .gpt2_architecture import GPT2LMHeadModelWrapper
 
 
-logger = logging.
-if TYPE_CHECKING:
-    from transformers import PreTrainedModel
+logger = logging.get_logger(__name__)
 
 
 class RBLNGPT2LMHeadModel(RBLNDecoderOnlyModelForCausalLM):
@@ -42,7 +34,7 @@ class RBLNGPT2LMHeadModel(RBLNDecoderOnlyModelForCausalLM):
     The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
     embeddings).
 
-    This model inherits from [`
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the
     library implements for all its model.
 
     It implements the methods to convert a pre-trained transformers GPT2 model into a RBLN transformer model by:
@@ -51,22 +43,4 @@ class RBLNGPT2LMHeadModel(RBLNDecoderOnlyModelForCausalLM):
 
     """
 
-
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-        return GPT2LMHeadModelWrapper(model, rbln_max_seq_len).eval()
-
-    def __getattr__(self, __name: str) -> Any:
-        """This is the key method to implement RBLN-GPT2.
-
-        Returns:
-            Any: GPT2's corresponding method
-        """
-
-        def redirect(func):
-            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-        val = getattr(GPT2LMHeadModel, __name)
-        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-            return redirect(val)
-        return val
+    _decoder_wrapper_cls = GPT2LMHeadModelWrapper
```
optimum/rbln/transformers/models/llama/modeling_llama.py

```diff
@@ -21,28 +21,18 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-import
-import logging
-from typing import TYPE_CHECKING, Any, Callable
-
-from transformers import LlamaForCausalLM
-
+from ....utils import logging
 from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .llama_architecture import LlamaWrapper
 
 
-
-from transformers import PreTrainedModel
-
-from ....modeling_config import RBLNConfig
-
-logger = logging.getLogger(__name__)
+logger = logging.get_logger(__name__)
 
 
 class RBLNLlamaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """
     The Llama Model transformer with a language modeling head (linear layer) on top.
-    This model inherits from [`
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
 
     A class to convert and run pre-trained transformers based LlamaForCausalLM model on RBLN devices.
     It implements the methods to convert a pre-trained transformers LlamaForCausalLM model into a RBLN transformer model by:
@@ -50,18 +40,4 @@ class RBLNLlamaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     - compiling the resulting graph using the RBLN compiler.
     """
 
-
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-        return LlamaWrapper(model, rbln_max_seq_len).eval()
-
-    def __getattr__(self, __name: str) -> Any:
-        def redirect(func):
-            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-        val = getattr(LlamaForCausalLM, __name)
-
-        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-            return redirect(val)
-
-        return val
+    _decoder_wrapper_cls = LlamaWrapper
```
optimum/rbln/transformers/models/llava_next/modeling_llava_next.py

```diff
@@ -114,7 +114,7 @@ class LoopProjector:
         return self.forward(*args, **kwds)
 
     def __repr__(self) -> str:
-        return repr(self.
+        return repr(self.multi_modal_projector)
 
 
 class RBLNLlavaNextForConditionalGeneration(RBLNModel):
@@ -228,29 +228,26 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
         pixel_values=None,
         image_sizes=None,
         attention_mask=None,
-
+        generate_idx=None,
         **kwargs,
     ):
         # Prepare HF generation
-        is_prefill_phase =
+        is_prefill_phase = generate_idx is None
         batch_size = input_ids.shape[0]
 
         model_inputs = self.language_model.prepare_inputs_for_generation(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
-
+            generate_idx=generate_idx,  # Not affect
             attention_mask=attention_mask,
             **kwargs,
         )
 
         if is_prefill_phase:
-            model_inputs["
-        else:
-            model_inputs["past_cached_length"] = past_cached_length + 1
+            model_inputs["generate_idx"] = torch.zeros((batch_size, 1), dtype=torch.int32)
 
         model_inputs.update(
             {
-                # "position_ids": position_ids or cache_positions,
                 "pixel_values": pixel_values,
                 "image_sizes": image_sizes,
                 "attention_mask": attention_mask,
```
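`generate_idx` replaces `past_cached_length` as the per-sequence progress counter: it is `None` on the prefill call, is initialized to a `(batch_size, 1)` zero tensor, and each generation step then writes the updated counter back into `model_kwargs` (see the following hunk). A stripped-down sketch of that bookkeeping, outside of any RBLN runtime, is shown below; `prepare_inputs` and `update_after_step` are illustrative stand-ins for the real methods.

```python
# Standalone sketch of the generate_idx bookkeeping shown in the hunks above and below.
from typing import Optional

import torch


def prepare_inputs(input_ids: torch.Tensor, generate_idx: Optional[torch.Tensor]):
    # Prefill is recognized simply by generate_idx still being None.
    is_prefill_phase = generate_idx is None
    if is_prefill_phase:
        generate_idx = torch.zeros((input_ids.shape[0], 1), dtype=torch.int32)
    return {"input_ids": input_ids, "generate_idx": generate_idx, "is_prefill": is_prefill_phase}


def update_after_step(model_kwargs: dict, new_generate_idx: torch.Tensor) -> dict:
    # Mirrors the model-kwargs update hook: carry the counter into the next decode step.
    model_kwargs["generate_idx"] = new_generate_idx
    return model_kwargs


ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
step0 = prepare_inputs(ids, None)                            # prefill: counter created at zero
kwargs = update_after_step(step0, step0["generate_idx"] + ids.shape[1])
step1 = prepare_inputs(ids[:, -1:], kwargs["generate_idx"])  # decode: counter reused
print(step0["is_prefill"], step1["is_prefill"], step1["generate_idx"].squeeze(1).tolist())
# True False [3, 3]
```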
optimum/rbln/transformers/models/llava_next/modeling_llava_next.py (continued)

```diff
@@ -264,43 +261,28 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
         model_kwargs: Dict[str, Any],
         **kwargs,
     ) -> Dict[str, Any]:
-        # update
-        model_kwargs["
+        # update generate_idx
+        model_kwargs["generate_idx"] = outputs.generate_idx
 
         return model_kwargs
 
-    def
+    def text_embedding(
         self,
-        input_ids: torch.
-        inputs_embeds: torch.Tensor,
-        multimodal_embeddings: torch.Tensor,
-        placeholder_token_id: int,
+        input_ids: torch.LongTensor,
     ) -> torch.Tensor:
-
-
-
-
-        if multimodal_embeddings.shape[0] != num_expected_tokens:
-            raise ValueError(
-                f"Attempted to assign {inputs_embeds[mask].shape} = {multimodal_embeddings.shape} "
-                f"multimodal tokens to {num_expected_tokens} placeholders"
-            )
+        for_inputs_embeds_ids = input_ids.clone()
+        for_inputs_embeds_ids[(input_ids == self.config.image_token_index)] = 0
+        inputs_embeds = self.get_input_embeddings()(for_inputs_embeds_ids)
 
-        inputs_embeds[mask] = multimodal_embeddings
         return inputs_embeds
 
-    def
+    def image_embedding(
         self,
-        input_ids: torch.LongTensor,
         image_sizes: torch.LongTensor,
-        attention_mask: torch.Tensor,
         pixel_values: torch.FloatTensor,
         vision_feature_layer: int,
         vision_feature_select_strategy: str,
-
-        past_cached_length: torch.Tensor,
-        from_vllm_prefill: bool = False,
-    ) -> List[torch.Tensor]:
+    ) -> torch.Tensor:
         vision_feature_layer = (
             vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
         )
```
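The new `text_embedding` uses a compact trick for embedding prompts that still contain image placeholder tokens: placeholder positions are remapped to token id 0 before the embedding lookup, since those positions are overwritten by image features later anyway. A self-contained version with a toy embedding table (the `image_token_index` value is arbitrary, chosen for the example):

```python
# Self-contained illustration of the placeholder-remapping trick in text_embedding.
import torch
import torch.nn as nn

image_token_index = 32000  # arbitrary placeholder id for the example
embed_tokens = nn.Embedding(num_embeddings=32001, embedding_dim=8)

input_ids = torch.tensor([[1, 5, image_token_index, image_token_index, 9]])

# Clone so the original ids (still needed to locate placeholders) stay untouched,
# then map placeholder positions to a harmless id before the lookup.
for_inputs_embeds_ids = input_ids.clone()
for_inputs_embeds_ids[input_ids == image_token_index] = 0
inputs_embeds = embed_tokens(for_inputs_embeds_ids)

print(inputs_embeds.shape)  # torch.Size([1, 5, 8]); positions 2-3 are later replaced by image features
```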
optimum/rbln/transformers/models/llava_next/modeling_llava_next.py (continued)

```diff
@@ -310,84 +292,137 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
             else self.config.vision_feature_select_strategy
         )
 
-        #
-
-
-
+        # ! infer image_num_patches from image_sizes
+        image_num_patches = [
+            image_size_to_num_patches(
+                image_size=imsize,
+                grid_pinpoints=self.config.image_grid_pinpoints,
+                patch_size=self.config.vision_config.image_size,
+            )
+            for imsize in image_sizes
+        ]
 
-
+        # figure out if pixel_values is concatenated or stacked
+        if pixel_values.dim() == 5:
+            # stacking when input is (batch_size, num_patches, num_channels, height, width)
+            _pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)]
+            pixel_values = torch.cat(_pixel_values_list, dim=0)
+        elif pixel_values.dim() != 4:
+            # otherwise has to be stacked from list of (num_patches, num_channels, height, width)
+            raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
 
-
-
-        # ! infer image_num_patches from image_sizes
-        image_num_patches = [
-            image_size_to_num_patches(
-                image_size=imsize,
-                grid_pinpoints=self.config.image_grid_pinpoints,
-                patch_size=self.config.vision_config.image_size,
-            )
-            for imsize in image_sizes
-        ]
-        # figure out if pixel_values is concatenated or stacked
-        if pixel_values.dim() == 5:
-            # stacking when input is (batch_size, num_patches, num_channels, height, width)
-            _pixel_values_list = [
-                pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)
-            ]
-            pixel_values = torch.cat(_pixel_values_list, dim=0)
-        elif pixel_values.dim() != 4:
-            # otherwise has to be stacked from list of (num_patches, num_channels, height, width)
-            raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
-
-        image_features = self.vision_tower(pixel_values, output_hidden_states=True)
-        selected_image_feature = image_features.hidden_states[vision_feature_layer]
-
-        if vision_feature_select_strategy == "default":
-            selected_image_feature = selected_image_feature[:, 1:]
-        elif vision_feature_select_strategy == "full":
-            selected_image_feature = selected_image_feature
-
-        image_features = self.multi_modal_projector(selected_image_feature)
-        image_features = torch.split(image_features, image_num_patches, dim=0)
-
-        # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
-        image_features, feature_lens = self.pack_image_features(
-            image_features,
-            image_sizes,
-            image_newline=self.image_newline,
-        )
+        image_features = self.vision_tower(pixel_values, output_hidden_states=True)
+        selected_image_feature = image_features.hidden_states[vision_feature_layer]
 
-
+        if vision_feature_select_strategy == "default":
+            selected_image_feature = selected_image_feature[:, 1:]
+        elif vision_feature_select_strategy == "full":
+            selected_image_feature = selected_image_feature
+
+        image_features = self.multi_modal_projector(selected_image_feature)
+        image_features = torch.split(image_features, image_num_patches, dim=0)
+
+        # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
+        image_features, feature_lens = self.pack_image_features(
+            image_features,
+            image_sizes,
+            image_newline=self.image_newline,
+        )
 
-
-
-
-
+        return image_features, feature_lens
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: torch.LongTensor = None,
+        pixel_values: torch.FloatTensor = None,
+        image_sizes: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        vision_feature_layer: Optional[int] = None,
+        vision_feature_select_strategy: Optional[str] = None,
+        cache_position: torch.Tensor = None,
+        generate_idx: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[Tuple, LlavaNextCausalLMOutputWithPast]:
+        if inputs_embeds is not None:
+            raise NotImplementedError("Specifying inputs_embeds is not supported.")
+
+        is_prefill_phase = not generate_idx.bool().all()
+
+        if is_prefill_phase:
+            # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing
+            # not very reliable, but we don't expect one to actually pass 500+ images for one prompt
+            # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True
+            legacy_processing = (
+                (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
+            ) or (input_ids.shape[-1] == 1 and pixel_values is not None)
+
+            # Get the number of images in the prompt
+            special_image_token_masks = [input_id == self.config.image_token_index for input_id in input_ids]
+            if legacy_processing:
+                num_special_image_tokens = [torch.sum(mask, dim=-1) for mask in special_image_token_masks]
             else:
-
-
-
-
-
-
+                image_tokens_masks_diff = [
+                    torch.diff(mask, prepend=torch.tensor([0])) for mask in special_image_token_masks
+                ]
+                num_special_image_tokens = [int(torch.sum((diff == 1).int())) for diff in image_tokens_masks_diff]
+
+            # Split images for each prompt
+            if pixel_values is not None and pixel_values.size(0) > 0:
+                pixel_values = pixel_values.split(num_special_image_tokens, dim=0)
+                image_sizes = image_sizes.split(num_special_image_tokens, dim=0)
+
+            logits = []
+            for b_idx in range(input_ids.shape[0]):
+                # Get text_embeds from input_id
+                input_id = input_ids[b_idx : b_idx + 1, attention_mask[b_idx].bool()]
+                inputs_embed = self.text_embedding(input_id)
+
+                # If any images in the prompt, get image_embeds and merge with text
+                if num_special_image_tokens[b_idx] > 0:
+                    image_features, feature_lens = self.image_embedding(
+                        image_sizes[b_idx], pixel_values[b_idx], vision_feature_layer, vision_feature_select_strategy
+                    )
+                    if legacy_processing:
+                        inputs_embed, _, _, _, _ = self._merge_input_ids_with_image_features(
+                            image_features,
+                            feature_lens,
+                            inputs_embed.to(image_features.dtype),
+                            input_id,
+                            torch.ones_like(input_id, dtype=torch.long),
+                        )
+                    else:
+                        special_image_mask = (
+                            (input_id == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embed)
+                        )
+                        inputs_embed = inputs_embed.masked_scatter(special_image_mask, image_features)
+
+                # Update generate_idx according to inputs_embed
+                generate_idx[b_idx] = inputs_embed.shape[1]
+
+                logit = self.language_model._forward_prefill(
+                    inputs_embeds=inputs_embed,
+                    batch_idx=b_idx,
+                    cache_position=torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0),
                 )
 
-
+                logits.append(logit)
 
-
-
-
-
-
+            logits = torch.cat(logits, dim=0)
+            outputs = RBLNDecoderOnlyOutput(logits=logits, generate_idx=generate_idx)
+
+        else:
+            inputs_embeds = self.text_embedding(input_ids)
 
-
-
-
-
+            outputs: RBLNDecoderOnlyOutput = self.language_model(
+                inputs_embeds=inputs_embeds,
+                cache_position=cache_position,
+                generate_idx=generate_idx,
+            )
 
-        return
+        return outputs
 
-    def
+    def vllm_forward(
         self,
         input_ids: torch.LongTensor = None,
         pixel_values: torch.FloatTensor = None,
```
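In the non-legacy prefill path the image features are written into the text embeddings with `masked_scatter`, using a mask built from the image placeholder token. A toy, self-contained version of that merge with random stand-in tensors and arbitrary sizes:

```python
# Toy illustration of the masked_scatter merge used in the non-legacy prefill path.
import torch

image_token_index = 32000  # arbitrary placeholder id for the example
hidden = 4

input_id = torch.tensor([[7, image_token_index, image_token_index, 42]])
inputs_embed = torch.zeros(1, 4, hidden)  # stand-in for text embeddings
image_features = torch.ones(2, hidden)    # one row per placeholder token

# Expand the boolean placeholder mask to the embedding width, then scatter the
# image rows into exactly those positions.
special_image_mask = (input_id == image_token_index).unsqueeze(-1).expand_as(inputs_embed)
merged = inputs_embed.masked_scatter(special_image_mask, image_features)

print(merged[0, :, 0])  # tensor([0., 1., 1., 0.]) -> placeholders now hold image features
```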
optimum/rbln/transformers/models/llava_next/modeling_llava_next.py (continued)

```diff
@@ -397,72 +432,52 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
         vision_feature_select_strategy: Optional[str] = None,
         cache_position: Union[List[torch.Tensor], torch.Tensor] = None,  # vllm keyword argument
         batch_idx: Optional[int] = None,
-        past_cached_length: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Union[Tuple, LlavaNextCausalLMOutputWithPast]:
-
-        from_hf_generate_prefill = isinstance(input_ids, list)
+        is_prefill = cache_position.shape[-1] > 1
 
         if inputs_embeds is not None:
             raise NotImplementedError("Specifying inputs_embeds is not supported.")
 
-        if
-
-
+        if is_prefill:
+            # Get text_embeds
+            inputs_embeds = self.text_embedding(input_ids)
 
-            #
-
-
+            # If any images in the prompt, get image_embeds and merge with text
+            if pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) > 0:
+                image_features, _ = self.image_embedding(
+                    image_sizes, pixel_values, vision_feature_layer, vision_feature_select_strategy
+                )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                def merge_vllm_multimodal_embeddings(
+                    input_ids: torch.Tensor,
+                    inputs_embeds: torch.Tensor,
+                    multimodal_embeddings: torch.Tensor,
+                    placeholder_token_id: int,
+                ) -> torch.Tensor:
+                    mask = input_ids == placeholder_token_id
+                    num_expected_tokens = mask.sum().item()
+
+                    if multimodal_embeddings.shape[0] != num_expected_tokens:
+                        raise ValueError(
+                            f"Attempted to assign {inputs_embeds[mask].shape} = {multimodal_embeddings.shape} "
+                            f"multimodal tokens to {num_expected_tokens} placeholders"
+                        )
+
+                    inputs_embeds[mask] = multimodal_embeddings
+                    return inputs_embeds
+
+                inputs_embeds = merge_vllm_multimodal_embeddings(
+                    input_ids, inputs_embeds, image_features, self.config.image_token_index
                 )
-
-                cache_position[b_idx] = cache_pos
-                past_cached_length[b_idx] += embed.shape[1]
-
-            elif from_vllm_prefill:
-                inputs_embeds, cache_position = self._embed(
-                    input_ids=input_ids,
-                    image_sizes=image_sizes,
-                    attention_mask=torch.ones_like(input_ids),
-                    pixel_values=pixel_values,
-                    vision_feature_layer=vision_feature_layer,
-                    vision_feature_select_strategy=vision_feature_select_strategy,
-                    cache_position=cache_position,
-                    past_cached_length=past_cached_length,
-                    from_vllm_prefill=from_vllm_prefill,
-                )
+
         else:
-
-            inputs_embeds, cache_position = self._embed(
-                input_ids=input_ids,
-                image_sizes=image_sizes,
-                attention_mask=torch.ones_like(input_ids),
-                pixel_values=pixel_values,
-                vision_feature_layer=vision_feature_layer,
-                vision_feature_select_strategy=vision_feature_select_strategy,
-                cache_position=cache_position,
-                past_cached_length=past_cached_length,
-            )
+            inputs_embeds = self.text_embedding(input_ids=input_ids)
 
-        outputs: RBLNDecoderOnlyOutput = self.language_model(
+        outputs: RBLNDecoderOnlyOutput = self.language_model.vllm_forward(
            inputs_embeds=inputs_embeds,
            batch_idx=batch_idx,
            cache_position=cache_position,
-           past_cached_length=past_cached_length,
         )
 
         return outputs
```
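On the vLLM path the phase is inferred from `cache_position` alone: a prefill call passes a position vector covering the whole prompt, while a decode step passes a single position. A tiny sketch of that check, with made-up prompt lengths:

```python
# Sketch of the cache_position-based phase check used by vllm_forward.
import torch


def is_prefill(cache_position: torch.Tensor) -> bool:
    # Prefill: positions 0..N-1 for the whole prompt; decode: one position per step.
    return cache_position.shape[-1] > 1


print(is_prefill(torch.arange(0, 17).unsqueeze(0)))  # True  (prompt of 17 tokens)
print(is_prefill(torch.tensor([[17]])))              # False (single decode step)
```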