optimum-rbln 0.1.12-py3-none-any.whl → 0.1.13-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +5 -1
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +30 -61
- optimum/rbln/diffusers/models/controlnet.py +36 -56
- optimum/rbln/diffusers/models/unet_2d_condition.py +57 -153
- optimum/rbln/diffusers/pipelines/__init__.py +40 -12
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -185
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -190
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +8 -191
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -192
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -110
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +12 -115
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +4 -122
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +4 -125
- optimum/rbln/modeling_base.py +12 -5
- optimum/rbln/modeling_diffusers.py +400 -0
- optimum/rbln/transformers/__init__.py +2 -0
- optimum/rbln/transformers/cache_utils.py +5 -9
- optimum/rbln/transformers/modeling_rope_utils.py +283 -0
- optimum/rbln/transformers/models/__init__.py +80 -31
- optimum/rbln/transformers/models/clip/modeling_clip.py +13 -22
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -2
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +376 -218
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +74 -16
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +18 -9
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -29
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +4 -30
- optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +27 -8
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -15
- optimum/rbln/transformers/models/midm/modeling_midm.py +4 -29
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
- optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
- optimum/rbln/transformers/models/phi/phi_architecture.py +75 -159
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +5 -29
- optimum/rbln/transformers/models/t5/__init__.py +1 -1
- optimum/rbln/transformers/models/t5/modeling_t5.py +57 -4
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +1 -1
- optimum/rbln/transformers/utils/rbln_quantization.py +8 -2
- optimum/rbln/utils/context.py +58 -0
- optimum/rbln/utils/decorator_utils.py +55 -0
- optimum/rbln/utils/import_utils.py +7 -0
- optimum/rbln/utils/runtime_utils.py +4 -4
- optimum/rbln/utils/timer_utils.py +2 -2
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/METADATA +8 -7
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/RECORD +52 -48
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py

@@ -22,14 +22,15 @@
 # from Rebellions Inc.
 import functools
 import glob
+import inspect
 import os
-from abc import ABC
 from dataclasses import dataclass
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

-import rebel
-import torch
+import rebel
+import torch
+import transformers
 from safetensors.torch import load_file
 from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
 from transformers.modeling_utils import no_init_weights

@@ -40,6 +41,7 @@ from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig,
 from ....utils.logging import get_logger
 from ....utils.runtime_utils import RBLNPytorchRuntime
 from ....utils.timer_utils import rbln_timer
+from .decoderonly_architecture import DecoderOnlyWrapper


 logger = get_logger()

@@ -102,19 +104,47 @@ class RBLNDecoderOnlyOutput(ModelOutput):
     generate_idx: torch.Tensor = None


-class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
+class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     """
-
-    This
-
-
-
-
-
+    A base class for decoder-only transformer models optimized for causal language modeling tasks on RBLN devices.
+    This class serves as the foundation for various decoder-only architectures like GPT, LLaMA, etc.
+
+    The class provides core functionality for:
+    1. Converting pre-trained transformer models to RBLN-optimized format
+    2. Handling the compilation process for RBLN devices
+    3. Managing inference operations for causal language modeling
+
+    This class inherits from RBLNModel and implements specific methods required for
+    decoder-only architectures and causal language modeling tasks.
+
+    Note:
+        - This class is designed to be subclassed by specific model implementations
+          (e.g., RBLNLlamaForCausalLM, RBLNGPT2LMHeadModel)
+        - Subclasses should implement model-specific conversion logic.
+        - The class handles RBLN-specific optimizations automatically during compilation
     """

     main_input_name = "input_ids"
     auto_model_class = AutoModelForCausalLM
+    _decoder_wrapper_cls = DecoderOnlyWrapper
+    _original_cls = None
+
+    @classmethod
+    @property
+    def original_cls(cls):
+        """
+        Lazily loads and caches the corresponding Hugging Face model class.
+        Removes 'RBLN' prefix from the class name to get the original class name
+        (e.g., RBLNLlamaForCausalLM -> LlamaForCausalLM) and imports it from
+        the transformers module.
+
+        Returns:
+            type: The original Hugging Face model class
+        """
+        if cls._original_cls is None:
+            hf_original_cls_name = cls.__name__[4:]
+            cls._original_cls = getattr(transformers, hf_original_cls_name)
+        return cls._original_cls

     def __post_init__(self, **kwargs):
         self.batch_size = self.rbln_config.model_cfg["batch_size"]

@@ -233,6 +263,26 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):

         return 0

+    def __getattr__(self, __name: str) -> Any:
+        """
+        Special method to delegate attribute access to the original Huggingface LM class.
+        This method is called when an attribute is not found in the current instance's dictionary.
+        It enables transparent access to the original model's attributes and methods while maintaining
+        proper method binding.
+
+        The method implements a delegation pattern that:
+        1. For methods: Creates a wrapper that properly binds 'self' to method calls
+        2. For other attributes: Returns them directly from the original class
+        """
+
+        def redirect(func):
+            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+        val = getattr(self.original_cls, __name)
+        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+            return redirect(val)
+        return val
+
     @classmethod
     def get_pytorch_model(cls, *args, **kwargs) -> "PreTrainedModel":
         rbln_kwargs = kwargs.get("rbln_kwargs", {})

@@ -293,6 +343,16 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):

         return wrapper

+    @classmethod
+    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
+        wrapper_cfg = {"max_seq_len": rbln_config.model_cfg["max_seq_len"]}
+
+        # If the model wrapper supports rbln-custom-flash-attention
+        if "kvcache_partition_len" in inspect.signature(cls._decoder_wrapper_cls.__init__).parameters:
+            wrapper_cfg["kvcache_partition_len"] = rbln_config.model_cfg.get("kvcache_partition_len")
+
+        return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
+
     @classmethod
     @torch.inference_mode()
     def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):

@@ -305,11 +365,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         @rbln_timer("JIT trace")
         def get_scripted_model():
             # This function is nested to dealloc the example inputs before compilation.
+            # FIXME: 3rd dummy_input(batch_idx) should be fill zero to compile flash_attn.
             prefill_example_inputs = prefill_rbln_compile_config.get_dummy_inputs(fill=0)
-            dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=
-
-            batch_index = 3
-            dec_example_inputs[batch_index].fill_(-1)  # fill batch_position -1 to indicate it is decoder.
+            dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=0)

             prefill_scripted_model = torch.jit.trace(
                 wrapped_model, prefill_example_inputs, check_trace=False, _store_inputs=False
optimum/rbln/transformers/models/exaone/exaone_architecture.py

@@ -20,31 +20,40 @@
 # are the intellectual property of Rebellions Inc. and may not be
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
+import torch

-
+from ....utils import logging
 from ...models.decoderonly import (
     DecoderOnlyAttention,
     DecoderOnlyDecoderLayer,
     DecoderOnlyModel,
     DecoderOnlyWrapper,
+    RotaryEmbedding,
 )


+logger = logging.get_logger(__name__)
+
+
 class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
     """A wrapper class for the Exaone model with a language modeling head."""

-    def __init__(self, model, max_seq_len):
+    def __init__(self, model, max_seq_len, kvcache_partition_len=None):
         super(DecoderOnlyWrapper, self).__init__()
         self.config = model.config
         self.model = self.convert_attribute_name(model.transformer)
         self.lm_head = model.lm_head
-        self.
-
-
-
-
-
-
+        self.rotary_emb = RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
+
+        if kvcache_partition_len is not None:
+            # WORKAROUND : for passing partition length as a value to the rbln compiler.
+            # What is actually used is the shape of this tensor.
+            self.kvcache_partition_size = torch.zeros(kvcache_partition_len, dtype=torch.int32)
+            self.attn_implementation = "flash_attn_rbln"
+            logger.info(f"Using rbln-flash-attention. (partition length : {kvcache_partition_len})")
+        else:
+            self.kvcache_partition_size = None
+            self.attn_implementation = "eager"

     @staticmethod
     def convert_attribute_name(model):
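
Two details of the flash-attention path above are worth spelling out: the partition length is handed to the RBLN compiler via the shape of a zero tensor (only the shape is consumed), and the base class decides whether to pass `kvcache_partition_len` at all by inspecting the wrapper's `__init__` signature. A minimal sketch under those assumptions; `PlainWrapper`, `FlashAttnWrapper`, and `build_wrapper` are illustrative names, not the library's.

```python
# Sketch of signature-based dispatch for the optional flash-attention option.
import inspect

import torch


class PlainWrapper:
    def __init__(self, model, max_seq_len):
        self.model, self.max_seq_len = model, max_seq_len


class FlashAttnWrapper:
    def __init__(self, model, max_seq_len, kvcache_partition_len=None):
        self.model, self.max_seq_len = model, max_seq_len
        if kvcache_partition_len is not None:
            # the partition length reaches the compiler as a tensor *shape*
            self.kvcache_partition_size = torch.zeros(kvcache_partition_len, dtype=torch.int32)
            self.attn_implementation = "flash_attn_rbln"
        else:
            self.kvcache_partition_size = None
            self.attn_implementation = "eager"


def build_wrapper(wrapper_cls, model, model_cfg):
    cfg = {"max_seq_len": model_cfg["max_seq_len"]}
    # only forward the option when the wrapper's __init__ accepts it
    if "kvcache_partition_len" in inspect.signature(wrapper_cls.__init__).parameters:
        cfg["kvcache_partition_len"] = model_cfg.get("kvcache_partition_len")
    return wrapper_cls(model, **cfg)


w = build_wrapper(FlashAttnWrapper, None, {"max_seq_len": 4096, "kvcache_partition_len": 1024})
print(w.attn_implementation, tuple(w.kvcache_partition_size.shape))  # flash_attn_rbln (1024,)

p = build_wrapper(PlainWrapper, None, {"max_seq_len": 4096})
print(type(p).__name__)  # PlainWrapper, option silently skipped
```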
optimum/rbln/transformers/models/exaone/modeling_exaone.py

@@ -21,21 +21,13 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-import
-import logging
-from typing import TYPE_CHECKING, Any, Callable
-
-from ....modeling_config import RBLNConfig
+from ....utils import logging
 from ..decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .exaone_architecture import ExaoneForCausalLMWrapper
 from .hf_hub_cached.modeling_exaone import ExaoneForCausalLM


-logger = logging.
-if TYPE_CHECKING:
-    from transformers import (
-        PreTrainedModel,
-    )
+logger = logging.get_logger(__name__)


 class RBLNExaoneForCausalLM(RBLNDecoderOnlyModelForCausalLM):

@@ -52,25 +44,8 @@ class RBLNExaoneForCausalLM(RBLNDecoderOnlyModelForCausalLM):

     """

-
-
-        rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-        return ExaoneForCausalLMWrapper(model, rbln_max_seq_len).eval()
-
-    def __getattr__(self, __name: str) -> Any:
-        """This is the key method to implement RBLN-Exaone.
-
-        Returns:
-            Any: Exaone's corresponding method
-        """
-
-        def redirect(func):
-            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-        val = getattr(ExaoneForCausalLM, __name)
-        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-            return redirect(val)
-        return val
+    _decoder_wrapper_cls = ExaoneForCausalLMWrapper
+    _original_cls = ExaoneForCausalLM

     @classmethod
     def from_pretrained(cls, *args, **kwargs):
optimum/rbln/transformers/models/gemma/gemma_architecture.py

@@ -29,11 +29,11 @@ from transformers.modeling_outputs import (
 )

 from ...models.decoderonly import (
-    DecoderOnlyAttention,
     DecoderOnlyDecoderLayer,
     DecoderOnlyWrapper,
     slice_and_unsqueeze_cos_sin,
 )
+from ...models.decoderonly.decoderonly_architecture import DECODERONLY_ATTENTION_CLASSES


 class GemmaWrapper(DecoderOnlyWrapper):

@@ -43,7 +43,7 @@ class GemmaWrapper(DecoderOnlyWrapper):
             {
                 "wrapper": GemmaModel.forward,
                 "model": DecoderOnlyDecoderLayer.forward,
-                "decoder_layer":
+                "decoder_layer": DECODERONLY_ATTENTION_CLASSES[self.attn_implementation].forward,
             }
         )
         return forward_dict

@@ -61,9 +61,17 @@ class GemmaModel:
         use_cache: Optional[bool] = True,
         output_attentions: Optional[bool] = False,
         output_hidden_states: Optional[bool] = False,
+        cache_pos_for_partitions: Optional[torch.Tensor] = None,
+        kvcache_partition_size: Optional[torch.Tensor] = None,
         forward_dict: Optional[Dict[str, classmethod]] = None,
         rotary_pos_emb=None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
+        # retrieve input_ids and inputs_embeds
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+            )
+
         # embed positions
         inputs_embeds = self.embed_tokens(input_ids)
         hidden_states = inputs_embeds

@@ -96,6 +104,8 @@ class GemmaModel:
                 batch_ids=batch_ids,
                 cos=cos,
                 sin=sin,
+                cache_pos_for_partitions=cache_pos_for_partitions,
+                kvcache_partition_size=kvcache_partition_size,
                 forward_dict=forward_dict,
             )

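
The `decoder_layer` entry above now dispatches through `DECODERONLY_ATTENTION_CLASSES`, a registry keyed by the wrapper's `attn_implementation` string (`"eager"` or `"flash_attn_rbln"`, per the Exaone hunk). A rough sketch of that registry pattern; the attention classes here are placeholders, not the actual implementations.

```python
# Illustrative dispatch-by-registry sketch: pick the attention forward to trace
# by indexing a dict with the attn_implementation string.
class EagerAttention:
    def forward(self, *args, **kwargs):
        return "eager attention path"


class FlashAttentionRBLN:
    def forward(self, *args, **kwargs):
        return "rbln flash attention path"


DECODERONLY_ATTENTION_CLASSES = {
    "eager": EagerAttention,
    "flash_attn_rbln": FlashAttentionRBLN,
}


def build_forward_dict(attn_implementation: str):
    # mirrors the `"decoder_layer": DECODERONLY_ATTENTION_CLASSES[...].forward` line
    return {"decoder_layer": DECODERONLY_ATTENTION_CLASSES[attn_implementation].forward}


print(build_forward_dict("flash_attn_rbln")["decoder_layer"](None))  # rbln flash attention path
```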
optimum/rbln/transformers/models/gemma/modeling_gemma.py

@@ -21,28 +21,18 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-import
-import logging
-from typing import TYPE_CHECKING, Any, Callable
-
-from transformers import GemmaForCausalLM
-
+from ....utils import logging
 from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .gemma_architecture import GemmaWrapper


-
-    from transformers import PreTrainedModel
-
-    from ....modeling_config import RBLNConfig
-
-logger = logging.getLogger(__name__)
+logger = logging.get_logger(__name__)


 class RBLNGemmaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """
     The Gemma Model transformer with a language modeling head (linear layer) on top.
-    This model inherits from [`
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.

     A class to convert and run pre-trained transformers based GemmaForCausalLM model on RBLN devices.
     It implements the methods to convert a pre-trained transformers GemmaForCausalLM model into a RBLN transformer model by:

@@ -50,18 +40,4 @@ class RBLNGemmaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     - compiling the resulting graph using the RBLN compiler.
     """

-
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-        return GemmaWrapper(model, rbln_max_seq_len).eval()
-
-    def __getattr__(self, __name: str) -> Any:
-        def redirect(func):
-            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-        val = getattr(GemmaForCausalLM, __name)
-
-        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-            return redirect(val)
-
-        return val
+    _decoder_wrapper_cls = GemmaWrapper
optimum/rbln/transformers/models/gpt2/modeling_gpt2.py

@@ -21,20 +21,12 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-import
-import logging
-from typing import TYPE_CHECKING, Any, Callable
-
-from transformers import GPT2LMHeadModel
-
-from ....modeling_config import RBLNConfig
+from ....utils import logging
 from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .gpt2_architecture import GPT2LMHeadModelWrapper


-logger = logging.
-if TYPE_CHECKING:
-    from transformers import PreTrainedModel
+logger = logging.get_logger(__name__)


 class RBLNGPT2LMHeadModel(RBLNDecoderOnlyModelForCausalLM):

@@ -42,7 +34,7 @@ class RBLNGPT2LMHeadModel(RBLNDecoderOnlyModelForCausalLM):
     The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
     embeddings).

-    This model inherits from [`
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the
     library implements for all its model.

     It implements the methods to convert a pre-trained transformers GPT2 model into a RBLN transformer model by:

@@ -51,22 +43,4 @@ class RBLNGPT2LMHeadModel(RBLNDecoderOnlyModelForCausalLM):

     """

-
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-        return GPT2LMHeadModelWrapper(model, rbln_max_seq_len).eval()
-
-    def __getattr__(self, __name: str) -> Any:
-        """This is the key method to implement RBLN-GPT2.
-
-        Returns:
-            Any: GPT2's corresponding method
-        """
-
-        def redirect(func):
-            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-        val = getattr(GPT2LMHeadModel, __name)
-        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-            return redirect(val)
-        return val
+    _decoder_wrapper_cls = GPT2LMHeadModelWrapper
optimum/rbln/transformers/models/llama/modeling_llama.py

@@ -21,28 +21,18 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-import
-import logging
-from typing import TYPE_CHECKING, Any, Callable
-
-from transformers import LlamaForCausalLM
-
+from ....utils import logging
 from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .llama_architecture import LlamaWrapper


-
-    from transformers import PreTrainedModel
-
-    from ....modeling_config import RBLNConfig
-
-logger = logging.getLogger(__name__)
+logger = logging.get_logger(__name__)


 class RBLNLlamaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """
     The Llama Model transformer with a language modeling head (linear layer) on top.
-    This model inherits from [`
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.

     A class to convert and run pre-trained transformers based LlamaForCausalLM model on RBLN devices.
     It implements the methods to convert a pre-trained transformers LlamaForCausalLM model into a RBLN transformer model by:

@@ -50,18 +40,4 @@ class RBLNLlamaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     - compiling the resulting graph using the RBLN compiler.
     """

-
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-        return LlamaWrapper(model, rbln_max_seq_len).eval()
-
-    def __getattr__(self, __name: str) -> Any:
-        def redirect(func):
-            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-        val = getattr(LlamaForCausalLM, __name)
-
-        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-            return redirect(val)
-
-        return val
+    _decoder_wrapper_cls = LlamaWrapper
optimum/rbln/transformers/models/llava_next/modeling_llava_next.py

@@ -350,9 +350,22 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
         is_prefill_phase = not generate_idx.bool().all()

         if is_prefill_phase:
+            # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing
+            # not very reliable, but we don't expect one to actually pass 500+ images for one prompt
+            # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True
+            legacy_processing = (
+                (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
+            ) or (input_ids.shape[-1] == 1 and pixel_values is not None)
+
             # Get the number of images in the prompt
             special_image_token_masks = [input_id == self.config.image_token_index for input_id in input_ids]
-
+            if legacy_processing:
+                num_special_image_tokens = [torch.sum(mask, dim=-1) for mask in special_image_token_masks]
+            else:
+                image_tokens_masks_diff = [
+                    torch.diff(mask, prepend=torch.tensor([0])) for mask in special_image_token_masks
+                ]
+                num_special_image_tokens = [int(torch.sum((diff == 1).int())) for diff in image_tokens_masks_diff]

             # Split images for each prompt
             if pixel_values is not None and pixel_values.size(0) > 0:

@@ -370,13 +383,19 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
                 image_features, feature_lens = self.image_embedding(
                     image_sizes[b_idx], pixel_values[b_idx], vision_feature_layer, vision_feature_select_strategy
                 )
-
-
-
-
-
-
-
+                if legacy_processing:
+                    inputs_embed, _, _, _, _ = self._merge_input_ids_with_image_features(
+                        image_features,
+                        feature_lens,
+                        inputs_embed.to(image_features.dtype),
+                        input_id,
+                        torch.ones_like(input_id, dtype=torch.long),
+                    )
+                else:
+                    special_image_mask = (
+                        (input_id == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embed)
+                    )
+                    inputs_embed = inputs_embed.masked_scatter(special_image_mask, image_features)

                 # Update generate_idx according to inputs_embed
                 generate_idx[b_idx] = inputs_embed.shape[1]
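
The prefill branch above distinguishes legacy processing, where every expanded `<image>` placeholder counts as one token, from the newer behaviour, where `torch.diff` on the image-token mask counts contiguous runs, i.e. one count per image. A small self-contained example of the two counting strategies, using a made-up token id:

```python
# Toy illustration of the two counting strategies (not the pipeline itself).
import torch

IMAGE_TOKEN = 32000
input_id = torch.tensor([1, IMAGE_TOKEN, IMAGE_TOKEN, IMAGE_TOKEN, 5, IMAGE_TOKEN, IMAGE_TOKEN, 2])
mask = (input_id == IMAGE_TOKEN).long()

legacy_count = int(torch.sum(mask))                  # 5 expanded placeholder tokens
diff = torch.diff(mask, prepend=torch.tensor([0]))   # a 1 marks the start of each run
run_count = int(torch.sum((diff == 1).int()))        # 2 images in the prompt

print(legacy_count, run_count)  # 5 2
```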
optimum/rbln/transformers/models/midm/midm_architecture.py

@@ -58,23 +58,12 @@ class MidmLMHeadModelWrapper(torch.nn.Module):
         self.model = model.transformer
         self.lm_head = model.lm_head
         self.config = model.config
-        self.head_dim = self.config.n_embd // self.config.n_head
-        self.max_position_embeddings = (
-            self.config.max_position_embeddings if max_seq_len > self.config.max_position_embeddings else max_seq_len
-        )
         self.max_seq_len = max_seq_len
-        self.rotary_dim = int(
-            model.config.hidden_size // model.config.num_attention_heads * model.config.rotary_percentage
-        )
-        self.rotary_emb = self._init_rope()

-
-
-
-
-            max_position_embeddings=self.max_position_embeddings,
-        )
-        return rotary_emb
+        self.config.partial_rotary_factor = model.config.rotary_percentage
+        self.config.head_dim = self.config.n_embd // self.config.n_head
+        self.config.rope_theta = 10000
+        self.rotary_emb = RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)

     def forward(
         self,
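
Instead of computing `head_dim` and `rotary_dim` by hand and building its own rope module, the Midm wrapper now writes `partial_rotary_factor`, `head_dim`, and `rope_theta` onto the config and defers to the shared `RotaryEmbedding` (presumably backed by the new `modeling_rope_utils.py` in the file list). A rough sketch of how those config fields reproduce the old `hidden_size // num_attention_heads * rotary_percentage` rotary dimension; `rotary_dim_from_config` and the `SimpleNamespace` config are illustrative stand-ins, not library code.

```python
# Sketch: config-driven rotary setup equivalent to the removed manual computation.
from types import SimpleNamespace


def rotary_dim_from_config(config) -> int:
    # head_dim scaled by the partial rotary factor, matching the old
    # `hidden_size // num_attention_heads * rotary_percentage` computation
    return int(config.head_dim * config.partial_rotary_factor)


config = SimpleNamespace(n_embd=4096, n_head=32, rotary_percentage=0.5)
config.partial_rotary_factor = config.rotary_percentage
config.head_dim = config.n_embd // config.n_head  # 128
config.rope_theta = 10000

print(rotary_dim_from_config(config))  # 64
```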
optimum/rbln/transformers/models/midm/modeling_midm.py

@@ -21,11 +21,7 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-import
-import logging
-from typing import TYPE_CHECKING, Any, Callable
-
-from ....modeling_config import RBLNConfig
+from ....utils import logging
 from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .hf_hub_cached.modeling_midm import MidmLMHeadModel
 from .midm_architecture import (

@@ -33,11 +29,7 @@ from .midm_architecture import (
 )


-logger = logging.
-if TYPE_CHECKING:
-    from transformers import (
-        PreTrainedModel,
-    )
+logger = logging.get_logger(__name__)


 class RBLNMidmLMHeadModel(RBLNDecoderOnlyModelForCausalLM):

@@ -54,25 +46,8 @@ class RBLNMidmLMHeadModel(RBLNDecoderOnlyModelForCausalLM):

     """

-
-
-        rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-        return MidmLMHeadModelWrapper(model, rbln_max_seq_len).eval()
-
-    def __getattr__(self, __name: str) -> Any:
-        """This is the key method to implement RBLN-Midm.
-
-        Returns:
-            Any: Midm's corresponding method
-        """
-
-        def redirect(func):
-            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-        val = getattr(MidmLMHeadModel, __name)
-        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-            return redirect(val)
-        return val
+    _decoder_wrapper_cls = MidmLMHeadModelWrapper
+    _original_cls = MidmLMHeadModel

     @classmethod
     def from_pretrained(cls, *args, **kwargs):
optimum/rbln/transformers/models/mistral/modeling_mistral.py

@@ -21,29 +21,18 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-import
-import logging
-from typing import TYPE_CHECKING, Any, Callable
-
-from transformers import MistralForCausalLM
-
+from ....utils import logging
 from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .mistral_architecture import MistralForCausalLMWrapper


-
-    from transformers import PreTrainedModel
-
-    from ....modeling_config import RBLNConfig
-
-
-logger = logging.getLogger(__name__)
+logger = logging.get_logger(__name__)


 class RBLNMistralForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """
     The Llama Model transformer with a language modeling head (linear layer) on top.
-    This model inherits from [`
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.

     A class to convert and run pre-trained transformers based LlamaForCausalLM model on RBLN devices.
     It implements the methods to convert a pre-trained transformers LlamaForCausalLM model into a RBLN transformer model by:

@@ -51,18 +40,4 @@ class RBLNMistralForCausalLM(RBLNDecoderOnlyModelForCausalLM):
    - compiling the resulting graph using the RBLN compiler.
     """

-
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-        return MistralForCausalLMWrapper(model, rbln_max_seq_len).eval()
-
-    def __getattr__(self, __name: str) -> Any:
-        def redirect(func):
-            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-        val = getattr(MistralForCausalLM, __name)
-
-        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-            return redirect(val)
-
-        return val
+    _decoder_wrapper_cls = MistralForCausalLMWrapper