optimum-rbln 0.1.12__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +27 -13
- optimum/rbln/__version__.py +16 -1
- optimum/rbln/diffusers/__init__.py +22 -2
- optimum/rbln/diffusers/models/__init__.py +34 -3
- optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
- optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +66 -111
- optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
- optimum/rbln/diffusers/models/controlnet.py +85 -65
- optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
- optimum/rbln/diffusers/models/unets/__init__.py +24 -0
- optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +129 -163
- optimum/rbln/diffusers/pipelines/__init__.py +60 -12
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -25
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -185
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -190
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -191
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -192
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -110
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -118
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +18 -128
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -131
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
- optimum/rbln/modeling.py +572 -0
- optimum/rbln/modeling_alias.py +1 -1
- optimum/rbln/modeling_base.py +176 -763
- optimum/rbln/modeling_diffusers.py +329 -0
- optimum/rbln/transformers/__init__.py +2 -2
- optimum/rbln/transformers/cache_utils.py +5 -9
- optimum/rbln/transformers/modeling_rope_utils.py +283 -0
- optimum/rbln/transformers/models/__init__.py +80 -31
- optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
- optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
- optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
- optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
- optimum/rbln/transformers/models/clip/modeling_clip.py +8 -34
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -5
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +779 -361
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +83 -142
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +64 -39
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +6 -29
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +31 -92
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +6 -31
- optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +29 -83
- optimum/rbln/transformers/models/midm/midm_architecture.py +88 -253
- optimum/rbln/transformers/models/midm/modeling_midm.py +8 -33
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
- optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
- optimum/rbln/transformers/models/phi/phi_architecture.py +61 -345
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +5 -29
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
- optimum/rbln/transformers/models/t5/__init__.py +1 -1
- optimum/rbln/transformers/models/t5/modeling_t5.py +157 -6
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
- optimum/rbln/transformers/utils/rbln_quantization.py +128 -5
- optimum/rbln/utils/decorator_utils.py +59 -0
- optimum/rbln/utils/hub.py +131 -0
- optimum/rbln/utils/import_utils.py +21 -0
- optimum/rbln/utils/model_utils.py +53 -0
- optimum/rbln/utils/runtime_utils.py +5 -5
- optimum/rbln/utils/submodule.py +114 -0
- optimum/rbln/utils/timer_utils.py +2 -2
- optimum_rbln-0.1.15.dist-info/METADATA +106 -0
- optimum_rbln-0.1.15.dist-info/RECORD +110 -0
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/generation/streamers.py +0 -139
- optimum/rbln/transformers/generation/utils.py +0 -397
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
- optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
- optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
- optimum_rbln-0.1.12.dist-info/METADATA +0 -119
- optimum_rbln-0.1.12.dist-info/RECORD +0 -103
- optimum_rbln-0.1.12.dist-info/entry_points.txt +0 -4
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
@@ -22,12 +22,23 @@
|
|
22
22
|
# from Rebellions Inc.
|
23
23
|
|
24
24
|
import inspect
|
25
|
-
from typing import TYPE_CHECKING, Any, Callable
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
from
|
25
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
|
26
|
+
|
27
|
+
import torch
|
28
|
+
import transformers
|
29
|
+
from transformers import (
|
30
|
+
AutoModelForTextEncoding,
|
31
|
+
PretrainedConfig,
|
32
|
+
T5EncoderModel,
|
33
|
+
T5ForConditionalGeneration,
|
34
|
+
)
|
35
|
+
from transformers.modeling_outputs import BaseModelOutput
|
36
|
+
|
37
|
+
from ....modeling import RBLNModel
|
38
|
+
from ....modeling_config import RBLNCompileConfig, RBLNConfig
|
39
|
+
from ....modeling_diffusers import RBLNDiffusionMixin
|
30
40
|
from ....utils.logging import get_logger
|
41
|
+
from ....utils.runtime_utils import RBLNPytorchRuntime
|
31
42
|
from ...models.seq2seq import RBLNModelForSeq2SeqLM
|
32
43
|
from .t5_architecture import T5Wrapper
|
33
44
|
|
@@ -35,7 +46,147 @@ from .t5_architecture import T5Wrapper
|
|
35
46
|
logger = get_logger()
|
36
47
|
|
37
48
|
if TYPE_CHECKING:
|
38
|
-
from transformers import PreTrainedModel
|
49
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
|
50
|
+
|
51
|
+
|
52
|
+
class RBLNRuntimeModel(RBLNPytorchRuntime):
    """Runtime adapter for the compiled T5 encoder.

    The four encoder inputs are handed to ``RBLNPytorchRuntime.forward``
    positionally, in the order ``input_ids, attention_mask, head_mask,
    inputs_embeds``; any extra keyword arguments are passed through unchanged.
    """

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.FloatTensor,
        head_mask: torch.FloatTensor,
        inputs_embeds: torch.FloatTensor,
        **kwargs,
    ):
        positional_inputs = (input_ids, attention_mask, head_mask, inputs_embeds)
        return super().forward(*positional_inputs, **kwargs)
|
68
|
+
|
69
|
+
|
70
|
+
class T5EncoderWrapper(torch.nn.Module):
    """``nn.Module`` wrapper around a T5 encoder that always requests tuple outputs.

    Any caller-supplied ``return_dict`` is discarded and ``return_dict=False`` is
    forwarded to the wrapped model instead.
    """

    def __init__(self, model: "T5EncoderModel") -> None:
        super().__init__()
        self.model = model

    def forward(self, *args, **kwargs):
        forwarded = {name: value for name, value in kwargs.items() if name != "return_dict"}
        forwarded["return_dict"] = False
        return self.model(*args, **forwarded)
|
78
|
+
|
79
|
+
|
80
|
+
class RBLNT5EncoderModel(RBLNModel):
    """RBLN-compiled T5 encoder loaded through ``AutoModelForTextEncoding``.

    Inputs are compiled as int64 tensors of shape ``[batch_size, max_seq_len]``
    (see ``_get_rbln_config``).
    """

    auto_model_class = AutoModelForTextEncoding
    rbln_model_input_names = ["input_ids", "attention_mask"]

    def __post_init__(self, **kwargs):
        # Wrap the first runtime so ``forward`` can call it with named encoder inputs.
        self.model = RBLNRuntimeModel(runtime=self.model[0])

    @classmethod
    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
        """Wrap the HF encoder so it is always called with ``return_dict=False``."""
        return T5EncoderWrapper(model)

    @classmethod
    def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
        """Fill encoder compile options when the encoder is part of a diffusion pipeline.

        ``batch_size`` defaults to 1 and ``max_sequence_length`` to 256 (stored as
        ``max_seq_len``); only ``input_ids`` is used as a model input. ``rbln_config``
        is updated in place and returned. ``pipe`` is currently not used.
        """
        batch_size = rbln_config.get("batch_size", 1)
        max_sequence_length = rbln_config.get("max_sequence_length", 256)
        model_input_names = ["input_ids"]

        rbln_config.update(
            {
                "batch_size": batch_size,
                "max_seq_len": max_sequence_length,
                "model_input_names": model_input_names,
            }
        )

        return rbln_config

    @classmethod
    def _get_rbln_config(
        cls,
        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
        model_config: Optional["PretrainedConfig"] = None,
        rbln_kwargs: Optional[Dict[str, Any]] = None,
    ) -> RBLNConfig:
        """Build the compile configuration for the encoder.

        Resolution order:
            * ``max_seq_len``: ``rbln_kwargs`` -> ``model_config.n_positions`` ->
              first preprocessor exposing ``model_max_length``.
            * ``model_input_names``: ``rbln_kwargs`` -> first preprocessor exposing
              ``model_input_names`` -> ``cls.rbln_model_input_names``.
            * ``batch_size``: ``rbln_kwargs`` -> 1.

        Raises:
            ValueError: if no sequence length can be determined, if it exceeds
                ``n_positions``, or if no model input names can be determined.
        """
        # Fresh dict per call: a shared ``{}`` default would be handed to RBLNConfig
        # and could leak state between calls if anything mutates it.
        if rbln_kwargs is None:
            rbln_kwargs = {}
        # ``preprocessors`` is Optional; iterate over nothing rather than over ``None``.
        if preprocessors is None:
            preprocessors = []

        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
        rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
        rbln_batch_size = rbln_kwargs.get("batch_size", None)

        max_position_embeddings = getattr(model_config, "n_positions", None)

        if rbln_max_seq_len is None:
            rbln_max_seq_len = max_position_embeddings
        if rbln_max_seq_len is None:
            for tokenizer in preprocessors:
                if hasattr(tokenizer, "model_max_length"):
                    rbln_max_seq_len = tokenizer.model_max_length
                    break
        if rbln_max_seq_len is None:
            raise ValueError("`rbln_max_seq_len` should be specified!")

        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
            raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")

        if rbln_model_input_names is None:
            for tokenizer in preprocessors:
                if hasattr(tokenizer, "model_input_names"):
                    rbln_model_input_names = tokenizer.model_input_names
                    break
        if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
            rbln_model_input_names = cls.rbln_model_input_names
        elif rbln_model_input_names is None:
            original_model_class = getattr(transformers, model_config.architectures[0])
            input_names_order = inspect.signature(original_model_class.forward).parameters.keys()
            raise ValueError(
                "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
                f"and be sure to make the order of the inputs same as T5EncoderModel forward() arguments like ({list(input_names_order)})"
            )

        if rbln_batch_size is None:
            rbln_batch_size = 1

        input_info = [
            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
            for model_input_name in rbln_model_input_names
        ]

        rbln_compile_config = RBLNCompileConfig(input_info=input_info)

        rbln_config = RBLNConfig(
            rbln_cls=cls.__name__,
            compile_cfgs=[rbln_compile_config],
            rbln_kwargs=rbln_kwargs,
        )

        rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
        return rbln_config

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
        """Run the compiled encoder.

        All arguments are forwarded to the runtime as keyword arguments. Returns
        ``(encoder_outputs,)`` when ``return_dict`` is falsy (including the default
        ``None``), otherwise ``BaseModelOutput(last_hidden_state=encoder_outputs)``.
        """
        encoder_outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        if not return_dict:
            return (encoder_outputs,)
        else:
            return BaseModelOutput(last_hidden_state=encoder_outputs)
|
39
190
|
|
40
191
|
|
41
192
|
class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
|
@@ -28,7 +28,7 @@ import torch
|
|
28
28
|
from transformers import AutoModelForMaskedLM, PretrainedConfig, Wav2Vec2ForCTC
|
29
29
|
from transformers.modeling_outputs import CausalLMOutput
|
30
30
|
|
31
|
-
from ....
|
31
|
+
from ....modeling import RBLNModel
|
32
32
|
from ....modeling_config import RBLNCompileConfig, RBLNConfig
|
33
33
|
|
34
34
|
|
@@ -36,7 +36,7 @@ from transformers import (
|
|
36
36
|
)
|
37
37
|
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
|
38
38
|
|
39
|
-
from ....
|
39
|
+
from ....modeling import RBLNModel
|
40
40
|
from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
|
41
41
|
from ....utils.runtime_utils import RBLNPytorchRuntime
|
42
42
|
from .generation_whisper import RBLNWhisperGenerationMixin
|
@@ -102,7 +102,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
|
|
102
102
|
class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin):
|
103
103
|
"""
|
104
104
|
The Whisper Model with a language modeling head. Can be used for automatic speech recognition.
|
105
|
-
This model inherits from [`
|
105
|
+
This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
|
106
106
|
|
107
107
|
A class to convert and run pre-trained transformers based LlamaForCausalLM model on RBLN devices.
|
108
108
|
It implements the methods to convert a pre-trained transformers LlamaForCausalLM model into a RBLN transformer model by:
|
@@ -22,12 +22,12 @@
|
|
22
22
|
# from Rebellions Inc.
|
23
23
|
|
24
24
|
import logging
|
25
|
-
from typing import TYPE_CHECKING,
|
25
|
+
from typing import TYPE_CHECKING, Optional, Union
|
26
26
|
|
27
27
|
import torch
|
28
|
-
from transformers import PretrainedConfig
|
28
|
+
from transformers import PretrainedConfig
|
29
29
|
|
30
|
-
from ....
|
30
|
+
from ....modeling import RBLNModel
|
31
31
|
from ....modeling_config import RBLNCompileConfig, RBLNConfig
|
32
32
|
|
33
33
|
|
@@ -38,38 +38,6 @@ if TYPE_CHECKING:
|
|
38
38
|
|
39
39
|
|
40
40
|
class RBLNXLMRobertaModel(RBLNModel):
|
41
|
-
original_model_class = XLMRobertaModel
|
42
|
-
original_config_class = XLMRobertaConfig
|
43
|
-
|
44
|
-
@classmethod
|
45
|
-
def get_pytorch_model(
|
46
|
-
cls,
|
47
|
-
model_id: str,
|
48
|
-
use_auth_token: Optional[Union[bool, str]] = None,
|
49
|
-
revision: Optional[str] = None,
|
50
|
-
force_download: bool = False,
|
51
|
-
cache_dir: Optional[str] = None,
|
52
|
-
subfolder: str = "",
|
53
|
-
local_files_only: bool = False,
|
54
|
-
trust_remote_code: bool = False,
|
55
|
-
rbln_kwargs: Optional[Dict[str, Any]] = None,
|
56
|
-
**kwargs,
|
57
|
-
) -> "PreTrainedModel":
|
58
|
-
model: "PreTrainedModel" = super().get_pytorch_model(
|
59
|
-
model_id=model_id,
|
60
|
-
use_auth_token=use_auth_token,
|
61
|
-
revision=revision,
|
62
|
-
force_download=force_download,
|
63
|
-
cache_dir=cache_dir,
|
64
|
-
subfolder=subfolder,
|
65
|
-
local_files_only=local_files_only,
|
66
|
-
trust_remote_code=trust_remote_code,
|
67
|
-
rbln_kwargs=rbln_kwargs,
|
68
|
-
library_name="transformers",
|
69
|
-
)
|
70
|
-
|
71
|
-
return model
|
72
|
-
|
73
41
|
@classmethod
|
74
42
|
def _get_rbln_config(
|
75
43
|
cls,
|
@@ -22,21 +22,117 @@
|
|
22
22
|
# from Rebellions Inc.
|
23
23
|
|
24
24
|
|
25
|
-
|
25
|
+
import functools
|
26
|
+
import glob
|
27
|
+
import os
|
28
|
+
from typing import Any, Callable, Dict, Optional
|
26
29
|
|
27
30
|
import torch
|
31
|
+
from safetensors.torch import load_file
|
28
32
|
from torch.nn import Linear, Parameter
|
29
33
|
from torch.nn import functional as F
|
30
34
|
|
35
|
+
from ...utils.logging import get_logger
|
36
|
+
|
37
|
+
|
38
|
+
logger = get_logger()
|
39
|
+
|
40
|
+
SUPPORTED_QUANTIZATIONS: Dict[str, list[str]] = {
    "rbln": ["w4a16"],
}


class QuantizationManager:
    """Validates quantization configs and exposes the weight bit width to the compiler.

    The bit width is communicated through the ``RBLN_QUANT_BITS`` environment
    variable, which ``with_quantization_env`` sets for the duration of a call.
    """

    # The RBLN_QUANT_BITS environment variable defines the precision of each layer during the graph compilation process.
    # It specifies the quantization bit depth. For instance, setting RBLN_QUANT_BITS=4 will apply 4-bit precision for quantization.
    RBLN_QUANT_BITS_ENV = "RBLN_QUANT_BITS"

    @staticmethod
    def _raise_invalid_config_error(
        key: str, value: str, valid_values: list[str], context: Optional[str] = None
    ) -> None:
        """Raise a ``ValueError`` naming the invalid value and the supported ones."""
        context_info = f" for {context}" if context else ""
        valid_values_str = ", ".join(valid_values)
        raise ValueError(f"Invalid {key}: {value}{context_info}. " f"Supported values are: {valid_values_str}")

    @staticmethod
    def validate_quantization_config(quantize_config: Optional[dict]) -> Optional[dict]:
        """Return ``quantize_config`` if its ``format``/``precision`` pair is supported.

        Returns ``None`` for an empty or missing config.

        Raises:
            ValueError: for an unknown format or an unsupported precision.
        """
        if not quantize_config:
            return None

        q_format = quantize_config.get("format")
        q_precision = quantize_config.get("precision")

        if q_format not in SUPPORTED_QUANTIZATIONS:
            QuantizationManager._raise_invalid_config_error(
                "quantization format", q_format, list(SUPPORTED_QUANTIZATIONS.keys())
            )

        if q_precision not in SUPPORTED_QUANTIZATIONS[q_format]:
            QuantizationManager._raise_invalid_config_error(
                "precision", q_precision, SUPPORTED_QUANTIZATIONS[q_format], q_format
            )

        return quantize_config

    @classmethod
    def _set_env_var(cls, name: str, value: str) -> None:
        os.environ[name] = value

    @classmethod
    def _unset_env_var(cls, name: str) -> None:
        os.environ.pop(name, None)

    @classmethod
    def set_quantization_env(cls, quantize_config: Optional[dict]) -> Optional[str]:
        """Validate ``quantize_config`` and export its weight bit width.

        The bit width is parsed from a precision string of the form
        ``w<bits>a<bits>`` (e.g. ``"w4a16"`` -> ``"4"``). Returns the name of the
        variable that was set, or ``None`` when no config was given.
        """
        quantize_config = cls.validate_quantization_config(quantize_config)
        if quantize_config:
            q_precision: str = quantize_config["precision"]
            quant_bits = q_precision.split("w")[1].split("a")[0]
            cls._set_env_var(cls.RBLN_QUANT_BITS_ENV, quant_bits)
            return cls.RBLN_QUANT_BITS_ENV
        return None

    @classmethod
    def reset_quantization_env(cls, env_var_name: Optional[str]) -> None:
        """Remove ``env_var_name`` from the environment (no-op for ``None``)."""
        if env_var_name:
            cls._unset_env_var(env_var_name)

    @classmethod
    def with_quantization_env(cls, func: Callable) -> Callable:
        """Decorate ``func`` so ``RBLN_QUANT_BITS`` reflects its ``quantize_config`` kwarg.

        The variable is set only while ``func`` runs. Afterwards, any value that was
        already present before the call is restored rather than deleted.
        """

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            quantize_config = kwargs.get("quantize_config")
            # Remember a value the user may have exported so it survives the call.
            previous_value = os.environ.get(cls.RBLN_QUANT_BITS_ENV)
            quantize_env_var = cls.set_quantization_env(quantize_config)
            try:
                return func(*args, **kwargs)
            finally:
                if quantize_env_var is not None and previous_value is not None:
                    cls._set_env_var(quantize_env_var, previous_value)
                else:
                    cls.reset_quantization_env(quantize_env_var)

        return wrapper
|
113
|
+
|
31
114
|
|
32
115
|
# Constants
|
33
116
|
QUANTIZED_WEIGHTS = {
|
34
|
-
"q_proj",
|
35
|
-
"
|
117
|
+
"q_proj",
|
118
|
+
"k_proj",
|
119
|
+
"v_proj",
|
120
|
+
"o_proj",
|
121
|
+
"gate_proj",
|
122
|
+
"up_proj",
|
123
|
+
"down_proj",
|
36
124
|
}
|
37
125
|
|
38
126
|
|
39
|
-
def
|
127
|
+
def prepare_model_for_quantization(model: torch.nn.Module, model_id: str, n_layer: Optional[int] = None) -> None:
    """
    Convert the target linear layers of ``model`` to quantized (qlinear) layers, then
    copy the ``*.safetensors`` checkpoint tensors found in the ``model_id`` directory
    into the model.

    Args:
        model: Module modified in place.
        model_id: Directory holding the ``*.safetensors`` files.
        n_layer: Optional layer-count limit forwarded to ``load_weights``.
    """
    # Swap layers first, presumably so quantization-specific parameters exist
    # before tensors are copied in by name.
    update_layers_to_quantize(model)
    load_weights(model, model_id, n_layer=n_layer)
|
133
|
+
|
134
|
+
|
135
|
+
def update_layers_to_quantize(module: torch.nn.Module) -> None:
|
40
136
|
"""
|
41
137
|
Updates specified linear layers to quantized (qlinear) layers in the given module.
|
42
138
|
"""
|
@@ -49,7 +145,33 @@ def update_layers_to_quantized(module: torch.nn.Module) -> None:
|
|
49
145
|
processed_layers.append(name)
|
50
146
|
|
51
147
|
if processed_layers:
|
52
|
-
|
148
|
+
logger.debug(f"Updated the following linear layers to quantized layers:\n {{{', '.join(processed_layers)}}}")
|
149
|
+
|
150
|
+
|
151
|
+
def load_weights(model, model_id, n_layer=None):
    """
    Load safetensors checkpoint data directly into ``model``'s parameters and buffers.

    Every ``*.safetensors`` file directly inside the ``model_id`` directory is read,
    in sorted order. Each tensor is copied in place into the parameter or buffer
    with the same name; tensors matching neither are ignored.

    Args:
        model: Target module.
        model_id: Directory containing the checkpoint files.
        n_layer: When given, a tensor whose key has a numeric third dotted component
            (index 2) is skipped unless that number is below ``n_layer``.
    """
    model_params = dict(model.named_parameters(recurse=True))
    model_buffers = dict(model.named_buffers(recurse=True))

    # Escape the directory so characters like "[" in the path are matched literally;
    # sort so the load order does not depend on the filesystem.
    pattern = os.path.join(glob.escape(str(model_id)), "*.safetensors")
    safetensor_files = sorted(glob.glob(pattern))

    for safetensor_file in safetensor_files:
        file_data = load_file(safetensor_file)
        for key, value in file_data.items():
            if n_layer is not None:
                parts = key.split(".")
                # Non-negative indices only, so ">= n_layer" equals "not in range(n_layer)".
                if len(parts) > 2 and parts[2].isdigit() and int(parts[2]) >= n_layer:
                    continue

            if key in model_params:
                model_params[key].data.copy_(value)
            elif key in model_buffers:
                model_buffers[key].data.copy_(value)
|
53
175
|
|
54
176
|
|
55
177
|
def is_target_for_qlinear_replacement(layer_name: str, layer: torch.nn.Module) -> bool:
|
@@ -81,6 +203,7 @@ def create_qlinear(layer: Linear) -> Linear:
|
|
81
203
|
"""
|
82
204
|
Converts a standard linear layer to a quantized linear (qlinear) layer with a custom forward pass.
|
83
205
|
"""
|
206
|
+
|
84
207
|
def qlinear_forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
85
208
|
if inputs.dtype != self.scales.dtype:
|
86
209
|
raise TypeError(f"Expected input dtype {self.scales.dtype}, but got {inputs.dtype}")
|
@@ -0,0 +1,59 @@
|
|
1
|
+
from functools import wraps
|
2
|
+
|
3
|
+
from .logging import get_logger
|
4
|
+
|
5
|
+
|
6
|
+
logger = get_logger(__name__)
|
7
|
+
|
8
|
+
|
9
|
+
def remove_compile_time_kwargs(func):
    """
    Decorator to handle compile-time parameters during inference.

    For RBLN-optimized pipelines, several parameters must be determined during compilation
    and cannot be modified during inference. This decorator:
    1. Removes and warns about LoRA scale in cross_attention_kwargs
    2. Removes and warns about image dimension parameters (height, width) when either
       given value differs from the compiled size ``self.vae.image_size``
       (``[0]`` = height, ``[1]`` = width). Omitted/``None`` or matching values are
       passed through unchanged.

    The caller's ``cross_attention_kwargs`` dict is never mutated; a filtered copy is
    forwarded instead.

    Args:
        func: The pipeline's __call__ method to be wrapped
    """

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        height = kwargs.get("height")
        width = kwargs.get("width")
        if height is not None or width is not None:
            # Only touch the VAE when a dimension was actually requested.
            compiled_image_size = self.vae.image_size
            # Compare each dimension independently so passing just one of them works.
            height_mismatch = height is not None and height != compiled_image_size[0]
            width_mismatch = width is not None and width != compiled_image_size[1]
            if height_mismatch or width_mismatch:
                logger.warning(
                    "Image dimension parameters (`height`, `width`) will be ignored during inference. "
                    "Image dimensions must be specified during model compilation using from_pretrained()."
                )
                kwargs.pop("width", None)
                kwargs.pop("height", None)

        cross_attention_kwargs = kwargs.get("cross_attention_kwargs")
        if cross_attention_kwargs and "scale" in cross_attention_kwargs:
            logger.warning(
                "LoRA scale in cross_attention_kwargs will be ignored during inference. "
                "To adjust LoRA scale, specify it during model compilation using from_pretrained()."
            )
            # Drop only "scale" (on a copy); fall back to None when nothing else remains.
            remaining = {key: value for key, value in cross_attention_kwargs.items() if key != "scale"}
            kwargs["cross_attention_kwargs"] = remaining or None

        return func(self, *args, **kwargs)

    return wrapper
|
@@ -0,0 +1,131 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
import os
|
25
|
+
from pathlib import Path
|
26
|
+
from typing import List, Optional, Union
|
27
|
+
|
28
|
+
from huggingface_hub import HfApi, HfFolder, hf_hub_download
|
29
|
+
|
30
|
+
|
31
|
+
class PushToHubMixin:
    """Mixin that uploads a locally saved model directory to the Hugging Face Hub."""

    def push_to_hub(
        self,
        save_directory: str,
        repository_id: str,
        private: Optional[bool] = None,
        use_auth_token: Union[bool, str] = True,
    ) -> None:
        """Upload every file under ``save_directory`` to the ``repository_id`` repo.

        The repository is created if it does not already exist. Each file keeps its
        path relative to ``save_directory``, so same-named files in different
        subdirectories do not overwrite each other on the Hub.

        Raises:
            ValueError: if ``use_auth_token`` is falsy (from ``_get_huggingface_token``).
        """
        huggingface_token = _get_huggingface_token(use_auth_token)
        api = HfApi()

        api.create_repo(
            token=huggingface_token,
            repo_id=repository_id,
            exist_ok=True,
            private=private,
        )
        for path, _subdirs, files in os.walk(save_directory):
            for name in files:
                local_file_path = os.path.join(path, name)
                # Preserve the directory layout; Hub paths always use "/" separators.
                hub_file_path = Path(os.path.relpath(local_file_path, save_directory)).as_posix()
                # FIXME: when huggingface_hub fixes the return of upload_file
                try:
                    api.upload_file(
                        token=huggingface_token,
                        repo_id=f"{repository_id}",
                        path_or_fileobj=os.path.join(os.getcwd(), local_file_path),
                        path_in_repo=hub_file_path,
                    )
                except (KeyError, NameError):
                    # Deliberately ignored: workaround for the upload_file issue noted above.
                    pass
|
64
|
+
|
65
|
+
|
66
|
+
def pull_compiled_model_from_hub(
    model_id: Union[str, Path],
    subfolder: str,
    use_auth_token: Optional[Union[bool, str]],
    revision: Optional[str],
    cache_dir: Optional[str],
    force_download: bool,
    local_files_only: bool,
) -> Path:
    """Pull model files from the Hugging Face Hub.

    Every file of the repository is downloaded into the local cache. Before that, the
    function checks that ``subfolder`` (the repository root when ``subfolder`` is
    ``""``) directly contains at least one ``*.rbln`` file and exactly one
    ``rbln_config.json``.

    Returns:
        The local cache directory corresponding to ``subfolder``.

    Raises:
        ValueError: if ``use_auth_token`` is falsy (from ``_get_huggingface_token``).
        FileNotFoundError: if the ``*.rbln`` files or ``rbln_config.json`` are missing.
    """
    huggingface_token = _get_huggingface_token(use_auth_token)
    repo_files = [
        Path(repo_file)
        for repo_file in HfApi().list_repo_files(model_id, revision=revision, token=huggingface_token)
    ]

    # Select files whose parent directory is exactly ``subfolder``. ``Path("")`` equals
    # ``Path(".")``, the parent of every top-level file, so this also covers the root.
    # ``Path.match`` is right-anchored and would also count files of nested submodules.
    target_dir = Path(subfolder)
    rbln_files = [p for p in repo_files if p.parent == target_dir and p.suffix == ".rbln"]
    rbln_config_filenames = [p for p in repo_files if p.parent == target_dir and p.name == "rbln_config.json"]

    validate_files(rbln_files, rbln_config_filenames, f"repository {model_id}")

    snapshot_root: Optional[Path] = None
    for repo_file in repo_files:
        cached_file = hf_hub_download(
            repo_id=model_id,
            # ``repo_file`` is already relative to the repository root (it includes
            # ``subfolder``); passing ``subfolder`` too would request "sub/sub/...".
            filename=repo_file.as_posix(),
            token=huggingface_token,
            revision=revision,
            cache_dir=cache_dir,
            force_download=force_download,
            local_files_only=local_files_only,
        )
        # Climb as many levels as the repo-relative path has components.
        snapshot_root = Path(cached_file).parents[len(repo_file.parts) - 1]

    # validate_files guarantees at least one file was downloaded.
    return snapshot_root / subfolder
|
107
|
+
|
108
|
+
|
109
|
+
def validate_files(
    files: List[Path],
    config_files: List[Path],
    location: str,
):
    """Ensure ``location`` provides model files and exactly one ``rbln_config.json``.

    Raises:
        FileNotFoundError: if ``files`` is empty, or if ``config_files`` is empty.
        FileExistsError: if ``config_files`` holds more than one entry.
    """
    if len(files) == 0:
        raise FileNotFoundError(f"Could not find any rbln model file in {location}")

    config_count = len(config_files)
    if config_count != 1:
        if config_count == 0:
            raise FileNotFoundError(f"Could not find `rbln_config.json` file in {location}")
        raise FileExistsError(f"Multiple rbln_config.json files found in {location}. This is not expected.")
|
123
|
+
|
124
|
+
|
125
|
+
def _get_huggingface_token(use_auth_token: Union[bool, str]) -> str:
|
126
|
+
if isinstance(use_auth_token, str):
|
127
|
+
return use_auth_token
|
128
|
+
elif use_auth_token:
|
129
|
+
return HfFolder.get_token()
|
130
|
+
else:
|
131
|
+
raise ValueError("`use_auth_token` must be provided to interact with the Hugging Face Hub.")
|
@@ -37,6 +37,27 @@ class VersionCompat:
|
|
37
37
|
|
38
38
|
|
39
39
|
RBLN_VERSION_COMPATS = {
|
40
|
+
"0.1.15": [
|
41
|
+
VersionCompat(
|
42
|
+
package_name="rebel-compiler",
|
43
|
+
min_version="0.6.2",
|
44
|
+
max_version="0.6.3",
|
45
|
+
),
|
46
|
+
],
|
47
|
+
"0.1.14": [
|
48
|
+
VersionCompat(
|
49
|
+
package_name="rebel-compiler",
|
50
|
+
min_version="0.6.2",
|
51
|
+
max_version="0.6.3",
|
52
|
+
),
|
53
|
+
],
|
54
|
+
"0.1.13": [
|
55
|
+
VersionCompat(
|
56
|
+
package_name="rebel-compiler",
|
57
|
+
min_version="0.6.0",
|
58
|
+
max_version="0.6.2",
|
59
|
+
),
|
60
|
+
],
|
40
61
|
"0.1.12": [
|
41
62
|
VersionCompat(
|
42
63
|
package_name="rebel-compiler",
|