optimum-rbln 0.1.7__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +14 -0
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/__init__.py +0 -1
- optimum/rbln/diffusers/models/controlnet.py +3 -0
- optimum/rbln/diffusers/models/unet_2d_condition.py +2 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +22 -144
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +107 -59
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +106 -54
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
- optimum/rbln/modeling_alias.py +14 -0
- optimum/rbln/modeling_base.py +110 -0
- optimum/rbln/transformers/__init__.py +6 -0
- optimum/rbln/transformers/cache_utils.py +111 -0
- optimum/rbln/transformers/generation/utils.py +0 -2
- optimum/rbln/transformers/models/__init__.py +2 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
- optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +515 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +349 -0
- optimum/rbln/transformers/models/gemma/__init__.py +24 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +116 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +61 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +201 -166
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +56 -220
- optimum/rbln/transformers/models/llama/llama_architecture.py +3 -610
- optimum/rbln/transformers/models/llama/modeling_llama.py +8 -442
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
- optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
- optimum/rbln/transformers/models/midm/modeling_midm.py +40 -272
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +125 -0
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/METADATA +2 -3
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/RECORD +38 -30
- optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -764
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/licenses/LICENSE +0 -0
@@ -22,17 +22,17 @@
|
|
22
22
|
# from Rebellions Inc.
|
23
23
|
"""RBLNStableDiffusionXLPipeline class for inference of diffusion models on rbln devices."""
|
24
24
|
|
25
|
-
from pathlib import Path
|
26
|
-
from tempfile import TemporaryDirectory
|
27
25
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
28
26
|
|
29
27
|
import torch
|
30
28
|
import torch.nn.functional as F
|
31
|
-
from diffusers import StableDiffusionXLControlNetImg2ImgPipeline
|
29
|
+
from diffusers import AutoencoderKL, ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline
|
32
30
|
from diffusers.image_processor import PipelineImageInput
|
31
|
+
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
33
32
|
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
34
33
|
from diffusers.utils import deprecate, logging
|
35
34
|
from diffusers.utils.torch_utils import is_compiled_module
|
35
|
+
from transformers import CLIPTextModel
|
36
36
|
|
37
37
|
from ....modeling_base import RBLNBaseModel
|
38
38
|
from ....transformers import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection
|
@@ -63,103 +63,152 @@ class RBLNStableDiffusionXLControlNetImg2ImgPipeline(StableDiffusionXLControlNet
|
|
63
63
|
- A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`],
|
64
64
|
"""
|
65
65
|
export = kwargs.pop("export", None)
|
66
|
-
text_encoder = kwargs.pop("text_encoder", None)
|
67
|
-
controlnets = kwargs.pop("controlnet", None)
|
68
66
|
vae = kwargs.pop("vae", None)
|
67
|
+
unet = kwargs.pop("unet", None)
|
68
|
+
text_encoder = kwargs.pop("text_encoder", None)
|
69
|
+
text_encoder_2 = kwargs.pop("text_encoder_2", None)
|
70
|
+
controlnet = kwargs.pop("controlnet", None)
|
71
|
+
model_save_dir = kwargs.pop("model_save_dir", None)
|
69
72
|
|
70
73
|
rbln_config_kwargs, rbln_constructor_kwargs = RBLNBaseModel.pop_rbln_kwargs_from_kwargs(kwargs)
|
74
|
+
|
71
75
|
kwargs_dict = {
|
72
76
|
"pretrained_model_name_or_path": model_id,
|
73
|
-
"vae": vae,
|
74
|
-
"controlnet": controlnets,
|
75
|
-
"text_encoder": text_encoder,
|
76
77
|
**kwargs,
|
77
78
|
}
|
78
79
|
|
80
|
+
kwargs_dict.update(
|
81
|
+
{
|
82
|
+
**({"vae": vae} if vae is not None and isinstance(vae, AutoencoderKL) else {}),
|
83
|
+
**({"unet": unet} if unet is not None and isinstance(unet, UNet2DConditionModel) else {}),
|
84
|
+
**(
|
85
|
+
{"text_encoder": text_encoder}
|
86
|
+
if text_encoder is not None and isinstance(text_encoder, CLIPTextModel)
|
87
|
+
else {}
|
88
|
+
),
|
89
|
+
**(
|
90
|
+
{"controlnet": controlnet}
|
91
|
+
if controlnet is not None
|
92
|
+
and (
|
93
|
+
isinstance(controlnet, ControlNetModel)
|
94
|
+
or all(isinstance(c, ControlNetModel) for c in controlnet)
|
95
|
+
)
|
96
|
+
else {}
|
97
|
+
),
|
98
|
+
}
|
99
|
+
)
|
100
|
+
|
79
101
|
model = super().from_pretrained(**{k: v for k, v in kwargs_dict.items() if v is not None})
|
80
102
|
|
81
103
|
if export is None or export is False:
|
82
104
|
return model
|
83
105
|
|
84
|
-
save_dir = TemporaryDirectory()
|
85
|
-
save_dir_path = Path(save_dir.name)
|
86
|
-
|
87
|
-
model.save_pretrained(save_directory=save_dir_path, **kwargs)
|
88
|
-
|
89
106
|
do_classifier_free_guidance = (
|
90
107
|
rbln_config_kwargs.pop("rbln_guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
|
91
108
|
)
|
92
109
|
|
93
|
-
vae
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
subfolder="text_encoder",
|
106
|
-
export=True,
|
107
|
-
**rbln_config_kwargs,
|
108
|
-
**rbln_constructor_kwargs,
|
109
|
-
)
|
110
|
-
text_encoder_2 = RBLNCLIPTextModelWithProjection.from_pretrained(
|
111
|
-
model_id=model_id,
|
112
|
-
subfolder="text_encoder_2",
|
113
|
-
export=True,
|
114
|
-
**rbln_config_kwargs,
|
115
|
-
**rbln_constructor_kwargs,
|
116
|
-
)
|
117
|
-
|
118
|
-
batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
|
119
|
-
unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
|
110
|
+
if not isinstance(vae, RBLNAutoencoderKL):
|
111
|
+
vae = RBLNAutoencoderKL.from_pretrained(
|
112
|
+
model_id=model_id,
|
113
|
+
subfolder="vae",
|
114
|
+
export=True,
|
115
|
+
model_save_dir=model_save_dir,
|
116
|
+
rbln_unet_sample_size=model.unet.config.sample_size,
|
117
|
+
rbln_use_encode=True,
|
118
|
+
rbln_vae_scale_factor=model.vae_scale_factor,
|
119
|
+
**rbln_config_kwargs,
|
120
|
+
**rbln_constructor_kwargs,
|
121
|
+
)
|
120
122
|
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
|
131
|
-
**rbln_config_kwargs,
|
132
|
-
**rbln_constructor_kwargs,
|
133
|
-
)
|
123
|
+
if not isinstance(text_encoder, RBLNCLIPTextModel):
|
124
|
+
text_encoder = RBLNCLIPTextModel.from_pretrained(
|
125
|
+
model_id=model_id,
|
126
|
+
subfolder="text_encoder",
|
127
|
+
export=True,
|
128
|
+
model_save_dir=model_save_dir,
|
129
|
+
**rbln_config_kwargs,
|
130
|
+
**rbln_constructor_kwargs,
|
131
|
+
)
|
134
132
|
|
135
|
-
if isinstance(
|
136
|
-
|
137
|
-
model_id=
|
133
|
+
if not isinstance(text_encoder_2, RBLNCLIPTextModel):
|
134
|
+
text_encoder_2 = RBLNCLIPTextModelWithProjection.from_pretrained(
|
135
|
+
model_id=model_id,
|
136
|
+
subfolder="text_encoder_2",
|
138
137
|
export=True,
|
139
|
-
|
140
|
-
rbln_vae_scale_factor=model.vae_scale_factor,
|
138
|
+
model_save_dir=model_save_dir,
|
141
139
|
**rbln_config_kwargs,
|
142
140
|
**rbln_constructor_kwargs,
|
143
141
|
)
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
142
|
+
|
143
|
+
batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
|
144
|
+
unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
|
145
|
+
|
146
|
+
if not isinstance(unet, RBLNUNet2DConditionModel):
|
147
|
+
unet = RBLNUNet2DConditionModel.from_pretrained(
|
148
|
+
model_id=model_id,
|
149
|
+
subfolder="unet",
|
148
150
|
export=True,
|
149
|
-
|
151
|
+
model_save_dir=model_save_dir,
|
152
|
+
rbln_max_seq_len=model.text_encoder.config.max_position_embeddings,
|
150
153
|
rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
|
154
|
+
rbln_batch_size=unet_batch_size,
|
155
|
+
rbln_use_encode=True,
|
151
156
|
rbln_vae_scale_factor=model.vae_scale_factor,
|
157
|
+
rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
|
152
158
|
**rbln_config_kwargs,
|
153
159
|
**rbln_constructor_kwargs,
|
154
160
|
)
|
155
|
-
controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
|
156
161
|
|
162
|
+
if not isinstance(controlnet, (RBLNControlNetModel, RBLNMultiControlNetModel)):
|
163
|
+
if isinstance(controlnet, (list, tuple)):
|
164
|
+
multicontrolnet = []
|
165
|
+
for i, cid in enumerate(controlnet):
|
166
|
+
subfolder_name = "controlnet" if i == 0 else f"controlnet_{i}"
|
167
|
+
multicontrolnet.append(
|
168
|
+
RBLNControlNetModel.from_pretrained(
|
169
|
+
model_id=cid.config._name_or_path,
|
170
|
+
subfolder=subfolder_name,
|
171
|
+
export=True,
|
172
|
+
model_save_dir=model_save_dir,
|
173
|
+
rbln_batch_size=unet_batch_size,
|
174
|
+
rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
|
175
|
+
rbln_vae_scale_factor=model.vae_scale_factor,
|
176
|
+
**rbln_config_kwargs,
|
177
|
+
**rbln_constructor_kwargs,
|
178
|
+
)
|
179
|
+
)
|
180
|
+
controlnet = RBLNMultiControlNetModel(multicontrolnet, config=controlnet[0].config)
|
181
|
+
controlnet_dict = ("optimum.rbln", "RBLNMultiControlNetModel")
|
182
|
+
else:
|
183
|
+
controlnet = RBLNControlNetModel.from_pretrained(
|
184
|
+
model_id=controlnet.config._name_or_path,
|
185
|
+
subfolder="controlnet",
|
186
|
+
export=True,
|
187
|
+
model_save_dir=model_save_dir,
|
188
|
+
rbln_batch_size=unet_batch_size,
|
189
|
+
rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
|
190
|
+
rbln_vae_scale_factor=model.vae_scale_factor,
|
191
|
+
**rbln_config_kwargs,
|
192
|
+
**rbln_constructor_kwargs,
|
193
|
+
)
|
194
|
+
controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
|
195
|
+
|
196
|
+
if model_save_dir is not None:
|
197
|
+
# To skip saving original pytorch modules
|
198
|
+
del (model.vae, model.text_encoder, model.unet, model.controlnet)
|
199
|
+
|
200
|
+
# Direct calling of `save_pretrained` causes config.unet = (None, None).
|
201
|
+
# So config must be saved again, later.
|
202
|
+
model.save_pretrained(model_save_dir)
|
203
|
+
|
204
|
+
# replace modules
|
157
205
|
model.vae = vae
|
158
206
|
model.text_encoder = text_encoder
|
159
207
|
model.unet = unet
|
160
208
|
model.text_encoder_2 = text_encoder_2
|
161
209
|
model.controlnet = controlnet
|
162
210
|
|
211
|
+
# update config to be able to load from file
|
163
212
|
update_dict = {
|
164
213
|
"vae": ("optimum.rbln", "RBLNAutoencoderKL"),
|
165
214
|
"text_encoder": ("optimum.rbln", "RBLNCLIPTextModel"),
|
@@ -169,14 +218,24 @@ class RBLNStableDiffusionXLControlNetImg2ImgPipeline(StableDiffusionXLControlNet
|
|
169
218
|
}
|
170
219
|
model.register_to_config(**update_dict)
|
171
220
|
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
221
|
+
if model_save_dir is not None:
|
222
|
+
# overwrite to replace incorrect config
|
223
|
+
model.save_config(model_save_dir)
|
224
|
+
|
225
|
+
# use for CI to access each compiled model
|
226
|
+
if rbln_constructor_kwargs.pop("rbln_optimize_host_memory", None) is False:
|
227
|
+
model.compiled_models = [
|
228
|
+
vae.compiled_models[0],
|
229
|
+
vae.compiled_models[1],
|
230
|
+
text_encoder.compiled_models[0],
|
231
|
+
text_encoder_2.compiled_models[0],
|
232
|
+
unet.compiled_models[0],
|
233
|
+
]
|
234
|
+
if isinstance(controlnet, RBLNMultiControlNetModel):
|
235
|
+
for c_model in controlnet.nets:
|
236
|
+
model.compiled_models.append(c_model.compiled_models[0])
|
237
|
+
else:
|
238
|
+
model.compiled_models.append(controlnet.compiled_models[0])
|
180
239
|
|
181
240
|
return model
|
182
241
|
|
optimum/rbln/modeling_alias.py
CHANGED
@@ -24,7 +24,9 @@
|
|
24
24
|
from .modeling_base import (
|
25
25
|
RBLNModelForAudioClassification,
|
26
26
|
RBLNModelForImageClassification,
|
27
|
+
RBLNModelForMaskedLM,
|
27
28
|
RBLNModelForQuestionAnswering,
|
29
|
+
RBLNModelForSequenceClassification,
|
28
30
|
)
|
29
31
|
from .modeling_seq2seq import RBLNModelForSeq2SeqLM
|
30
32
|
|
@@ -47,3 +49,15 @@ class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
|
|
47
49
|
|
48
50
|
class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
    """BART conditional-generation alias; every behavior is inherited unchanged from `RBLNModelForSeq2SeqLM`."""
|
52
|
+
|
53
|
+
|
54
|
+
class RBLNXLMRobertaForSequenceClassification(RBLNModelForSequenceClassification):
    """XLM-RoBERTa sequence-classification alias; all behavior comes from `RBLNModelForSequenceClassification`."""
|
56
|
+
|
57
|
+
|
58
|
+
class RBLNRobertaForSequenceClassification(RBLNModelForSequenceClassification):
    """RoBERTa sequence-classification alias; all behavior comes from `RBLNModelForSequenceClassification`."""
|
60
|
+
|
61
|
+
|
62
|
+
class RBLNRobertaForMaskedLM(RBLNModelForMaskedLM):
    """RoBERTa masked-LM alias; all behavior comes from `RBLNModelForMaskedLM`."""
|
optimum/rbln/modeling_base.py
CHANGED
@@ -39,7 +39,9 @@ from transformers import (
|
|
39
39
|
AutoModel,
|
40
40
|
AutoModelForAudioClassification,
|
41
41
|
AutoModelForImageClassification,
|
42
|
+
AutoModelForMaskedLM,
|
42
43
|
AutoModelForQuestionAnswering,
|
44
|
+
AutoModelForSequenceClassification,
|
43
45
|
GenerationConfig,
|
44
46
|
PretrainedConfig,
|
45
47
|
)
|
@@ -748,3 +750,111 @@ class RBLNModelForAudioClassification(RBLNModel):
|
|
748
750
|
)
|
749
751
|
|
750
752
|
return rbln_config
|
753
|
+
|
754
|
+
|
755
|
+
class RBLNModelForSequenceClassification(RBLNModel):
    """
    This is a generic model class that will be instantiated as one of the model classes of the library (with a
    sequence classification head) when created with the from_pretrained() class method.
    This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models.

    A class to convert and run pre-trained transformers based SequenceClassification models on RBLN devices.
    It implements the methods to convert a pre-trained transformers SequenceClassification model into a RBLN
    transformer model by:
    - transferring the checkpoint weights of the original into an optimized RBLN graph,
    - compiling the resulting graph using the RBLN compiler.

    Currently, this model class supports the 'XLMRoberta' and 'Roberta' model from the transformers library.
    Future updates may include support for additional model types.
    """

    model_type = "rbln_model"
    auto_model_class = AutoModelForSequenceClassification

    @classmethod
    def _get_rbln_config(
        cls,
        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
        model_config: Optional["PretrainedConfig"] = None,
        rbln_max_seq_len: Optional[int] = None,
        rbln_model_input_names: Optional[List[str]] = None,
        rbln_batch_size: Optional[int] = None,
    ) -> RBLNConfig:
        """Build the static-shape compile configuration for a sequence-classification model.

        Args:
            preprocessors: Iterated to find the first object exposing `model_max_length`, which is used as
                the last-resort sequence length. `None` is treated as "no preprocessors".
            model_config: Model config; `n_positions`, falling back to `max_position_embeddings`, gives the
                positional limit used both as the default sequence length and as an upper bound.
            rbln_max_seq_len: Static sequence length to compile for. Defaults to the positional limit,
                then to a preprocessor's `model_max_length`.
            rbln_model_input_names: Input tensor names. Defaults to `["input_ids", "attention_mask"]`.
            rbln_batch_size: Static batch size. Defaults to 1.

        Returns:
            An `RBLNConfig` with a single runtime config whose inputs are all int64 tensors of shape
            `[rbln_batch_size, rbln_max_seq_len]`. The config carries `{"rbln_max_seq_len": ...}` as meta.

        Raises:
            ValueError: If no sequence length can be determined, or if it exceeds the model's positional
                limit.
        """
        max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
            model_config, "max_position_embeddings", None
        )

        # Resolve the static sequence length: explicit argument > model config > preprocessor.
        if rbln_max_seq_len is None:
            rbln_max_seq_len = max_position_embeddings
        if rbln_max_seq_len is None and preprocessors is not None:
            # `preprocessors` is Optional; guard so a missing value does not raise TypeError on iteration.
            for tokenizer in preprocessors:
                if hasattr(tokenizer, "model_max_length"):
                    rbln_max_seq_len = tokenizer.model_max_length
                    break
        if rbln_max_seq_len is None:
            raise ValueError("`rbln_max_seq_len` should be specified!")

        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
            raise ValueError(
                f"`rbln_max_seq_len` ({rbln_max_seq_len}) should be less than or equal to "
                f"max_position_embeddings ({max_position_embeddings})!"
            )

        if rbln_model_input_names is None:
            # These are BERT's inputs
            rbln_model_input_names = ["input_ids", "attention_mask"]

        if rbln_batch_size is None:
            rbln_batch_size = 1

        # Every input is compiled with the same static int64 shape [batch, seq_len].
        input_info = [
            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
            for model_input_name in rbln_model_input_names
        ]

        rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
        rbln_runtime_config.batch_size = rbln_batch_size
        meta = {"rbln_max_seq_len": rbln_max_seq_len}

        return RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
|
814
|
+
|
815
|
+
class RBLNModelForMaskedLM(RBLNModel):
    """
    Generic RBLN model class for transformers models with a masked-language-modeling head.

    The original model is loaded through `AutoModelForMaskedLM` and compiled with static input shapes
    derived in `_get_rbln_config`.
    """

    model_type = "rbln_model"
    auto_model_class = AutoModelForMaskedLM

    @classmethod
    def _get_rbln_config(
        cls,
        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
        model_config: Optional["PretrainedConfig"] = None,
        rbln_max_seq_len: Optional[int] = None,
        rbln_model_input_names: Optional[List[str]] = None,
        rbln_batch_size: Optional[int] = None,
    ) -> RBLNConfig:
        """Build the static-shape compile configuration for a masked-LM model.

        Args:
            preprocessors: Iterated to find the first object exposing `model_max_length`, which is used as
                the last-resort sequence length. `None` is treated as "no preprocessors".
            model_config: Model config; `n_positions`, falling back to `max_position_embeddings`, gives the
                positional limit used both as the default sequence length and as an upper bound.
            rbln_max_seq_len: Static sequence length to compile for. Defaults to the positional limit,
                then to a preprocessor's `model_max_length`.
            rbln_model_input_names: Input tensor names. Defaults to `["input_ids", "attention_mask"]`.
            rbln_batch_size: Static batch size. Defaults to 1.

        Returns:
            An `RBLNConfig` with a single runtime config whose inputs are all int64 tensors of shape
            `[rbln_batch_size, rbln_max_seq_len]`. The config carries `{"rbln_max_seq_len": ...}` as meta.

        Raises:
            ValueError: If no sequence length can be determined, or if it exceeds the model's positional
                limit.
        """
        max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
            model_config, "max_position_embeddings", None
        )

        # Resolve the static sequence length: explicit argument > model config > preprocessor.
        if rbln_max_seq_len is None:
            rbln_max_seq_len = max_position_embeddings
        if rbln_max_seq_len is None and preprocessors is not None:
            # `preprocessors` is Optional; guard so a missing value does not raise TypeError on iteration.
            for tokenizer in preprocessors:
                if hasattr(tokenizer, "model_max_length"):
                    rbln_max_seq_len = tokenizer.model_max_length
                    break
        if rbln_max_seq_len is None:
            raise ValueError("`rbln_max_seq_len` should be specified!")

        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
            raise ValueError(
                f"`rbln_max_seq_len` ({rbln_max_seq_len}) should be less than or equal to "
                f"max_position_embeddings ({max_position_embeddings})!"
            )

        if rbln_model_input_names is None:
            # These are BERT's inputs
            rbln_model_input_names = ["input_ids", "attention_mask"]

        if rbln_batch_size is None:
            rbln_batch_size = 1

        # Every input is compiled with the same static int64 shape [batch, seq_len].
        input_info = [
            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
            for model_input_name in rbln_model_input_names
        ]

        rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
        rbln_runtime_config.batch_size = rbln_batch_size
        meta = {"rbln_max_seq_len": rbln_max_seq_len}

        return RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
|
@@ -27,30 +27,36 @@ from transformers.utils import _LazyModule
|
|
27
27
|
|
28
28
|
|
29
29
|
_import_structure = {
|
30
|
+
"cache_utils": ["RebelDynamicCache"],
|
30
31
|
"generation": ["BatchTextIteratorStreamer"],
|
31
32
|
"models": [
|
32
33
|
"RBLNCLIPTextModel",
|
33
34
|
"RBLNCLIPTextModelWithProjection",
|
34
35
|
"RBLNDPTForDepthEstimation",
|
36
|
+
"RBLNGemmaForCausalLM",
|
35
37
|
"RBLNGPT2LMHeadModel",
|
36
38
|
"RBLNWav2Vec2ForCTC",
|
37
39
|
"RBLNWhisperForConditionalGeneration",
|
38
40
|
"RBLNLlamaForCausalLM",
|
39
41
|
"RBLNMidmLMHeadModel",
|
42
|
+
"RBLNXLMRobertaModel"
|
40
43
|
],
|
41
44
|
}
|
42
45
|
|
43
46
|
if TYPE_CHECKING:
|
47
|
+
from .cache_utils import RebelDynamicCache
|
44
48
|
from .generation import BatchTextIteratorStreamer
|
45
49
|
from .models import (
|
46
50
|
RBLNCLIPTextModel,
|
47
51
|
RBLNCLIPTextModelWithProjection,
|
48
52
|
RBLNDPTForDepthEstimation,
|
53
|
+
RBLNGemmaForCausalLM,
|
49
54
|
RBLNGPT2LMHeadModel,
|
50
55
|
RBLNLlamaForCausalLM,
|
51
56
|
RBLNMidmLMHeadModel,
|
52
57
|
RBLNWav2Vec2ForCTC,
|
53
58
|
RBLNWhisperForConditionalGeneration,
|
59
|
+
RBLNXLMRobertaModel,
|
54
60
|
)
|
55
61
|
else:
|
56
62
|
import sys
|
@@ -0,0 +1,111 @@
|
|
1
|
+
from typing import Optional, Tuple
|
2
|
+
|
3
|
+
import torch
|
4
|
+
from transformers.cache_utils import DynamicCache
|
5
|
+
|
6
|
+
|
7
|
+
class RebelDynamicCache(DynamicCache):
    """
    A `DynamicCache` whose writes are positioned per batch entry.

    Key and value states are kept per layer in `key_cache` / `value_cache` (inherited from `DynamicCache`).
    `current_steps` holds one write offset per batch entry, which `update` uses as the start position of
    new states along dim -2.
    """

    def __init__(self, current_steps) -> None:
        super().__init__()
        # One write offset per batch entry; `from_input_format` fills it with position_ids[b][0].
        self.current_steps = current_steps

    def assign(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
    ) -> None:
        """Replace the cached key/value tensors of layer `layer_idx`, squeezing dim 2 of the inputs.

        `update` inserts a size-1 axis at dim 2 (`unsqueeze(2)`), so this presumably stores tensors
        produced from `update`'s output. NOTE(review): confirm with callers.
        """
        self.key_cache[layer_idx] = key_states.squeeze(2)
        self.value_cache[layer_idx] = value_states.squeeze(2)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        batch_idx: int,
        read_first_step: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Scatter new key/value states into the cached tensors of layer `layer_idx`, batch entry `batch_idx`.

        The write starts at `self.current_steps[batch_idx]`, or at `self.current_steps[0]` when
        `read_first_step` is True. It covers `key_states.shape[-2]` / `value_states.shape[-2]` positions
        along dim -2.

        The stored cache itself is not modified: `slice_scatter` is out-of-place, so the caller must store
        the returned tensors (e.g. via `assign`) for the write to persist.

        Returns:
            `(updated_keys, updated_values)`: the cached tensors of `batch_idx` with size-1 axes inserted at
            dims 0 and 2, with the new states written in.
        """
        current_step = self.current_steps[0 if read_first_step else batch_idx]
        kend = current_step + key_states.shape[-2]
        vend = current_step + value_states.shape[-2]

        cached_keys = self.key_cache[layer_idx][batch_idx].unsqueeze(0).unsqueeze(2)
        cached_values = self.value_cache[layer_idx][batch_idx].unsqueeze(0).unsqueeze(2)

        update_key_states = cached_keys.slice_scatter(key_states, dim=-2, start=current_step, end=kend)
        update_value_states = cached_values.slice_scatter(value_states, dim=-2, start=current_step, end=vend)

        return update_key_states, update_value_states

    @classmethod
    def from_input_format(cls, position_ids, num_hidden_layer, *past_key_values) -> "RebelDynamicCache":
        """Build a cache from the flat RBLN past-key-value argument list.

        Args:
            position_ids: 2-D tensor (unpacked as `batch, _ = position_ids.shape`). The first entry of each
                row becomes that batch entry's write offset in `current_steps`.
            num_hidden_layer: Number of layers to read from `past_key_values`.
            *past_key_values: Flat sequence `k_0, v_0, k_1, v_1, ...`. The key of layer i is at index
                `2 * i` and its value at `2 * i + 1`.

        Returns:
            A new instance of `cls` holding the given tensors (not copied).
        """
        batch, _ = position_ids.shape
        current_steps = [position_ids[b][0] for b in range(batch)]

        cache = cls(current_steps)

        for layer_idx in range(num_hidden_layer):
            cache.key_cache.append(past_key_values[layer_idx * 2])
            cache.value_cache.append(past_key_values[layer_idx * 2 + 1])

        return cache
|
75
|
+
|
76
|
+
|
77
|
+
class RebelDynamicCache_4D(RebelDynamicCache):
    """
    Variant of `RebelDynamicCache` that does not add or remove an axis at dim 2.

    `assign` stores the given tensors unchanged, and `update` only prepends a size-1 axis (`unsqueeze(0)`)
    to the cached batch entry before scattering. The parent additionally squeezes/unsqueezes dim 2.
    """

    def assign(
        self,
        keys: torch.Tensor,
        values: torch.Tensor,
        layer_idx: int,
    ) -> None:
        """Replace the cached key/value tensors of layer `layer_idx` with `keys` / `values` as-is."""
        self.key_cache[layer_idx] = keys
        self.value_cache[layer_idx] = values

    def update(
        self,
        keys: torch.Tensor,
        values: torch.Tensor,
        layer_idx: int,
        batch_idx: int,
        read_first_step: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Scatter new `keys` / `values` into the cached tensors of layer `layer_idx`, batch entry `batch_idx`.

        The write starts at `self.current_steps[batch_idx]`, or at `self.current_steps[0]` when
        `read_first_step` is True. It covers `keys.shape[-2]` / `values.shape[-2]` positions along dim -2.

        The stored cache itself is not modified: `slice_scatter` is out-of-place, so the caller must store
        the returned tensors (e.g. via `assign`) for the write to persist.

        Returns:
            `(updated_keys, updated_values)`: the cached tensors of `batch_idx` with a size-1 axis prepended,
            with the new states written in.
        """
        current_step = self.current_steps[0 if read_first_step else batch_idx]
        kend = current_step + keys.shape[-2]
        vend = current_step + values.shape[-2]

        cached_keys = self.key_cache[layer_idx][batch_idx].unsqueeze(0)
        cached_values = self.value_cache[layer_idx][batch_idx].unsqueeze(0)

        update_keys = cached_keys.slice_scatter(keys, dim=-2, start=current_step, end=kend)
        update_values = cached_values.slice_scatter(values, dim=-2, start=current_step, end=vend)

        return update_keys, update_values
|
@@ -32,7 +32,6 @@ class RBLNGenerationMixin:
|
|
32
32
|
generation_config: Optional[GenerationConfig] = None, # thkim change for 4.41.0
|
33
33
|
**model_kwargs,
|
34
34
|
) -> Union[SampleDecoderOnlyOutput, torch.LongTensor]:
|
35
|
-
|
36
35
|
###################### thkim change for 4.41.0 ############################
|
37
36
|
if generation_config is not None:
|
38
37
|
pad_token_id = generation_config.pad_token_id
|
@@ -216,7 +215,6 @@ class RBLNGenerationMixin:
|
|
216
215
|
do_sample: Optional[bool] = True,
|
217
216
|
**model_kwargs,
|
218
217
|
) -> Union[SampleDecoderOnlyOutput, torch.LongTensor]:
|
219
|
-
|
220
218
|
###################### thkim change for 4.41.0 ############################
|
221
219
|
if generation_config is not None:
|
222
220
|
pad_token_id = generation_config.pad_token_id
|
@@ -23,8 +23,10 @@
|
|
23
23
|
|
24
24
|
from .clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection
|
25
25
|
from .dpt import RBLNDPTForDepthEstimation
|
26
|
+
from .gemma import RBLNGemmaForCausalLM
|
26
27
|
from .gpt2 import RBLNGPT2LMHeadModel
|
27
28
|
from .llama import RBLNLlamaForCausalLM
|
28
29
|
from .midm import RBLNMidmLMHeadModel
|
29
30
|
from .wav2vec2 import RBLNWav2Vec2ForCTC
|
30
31
|
from .whisper import RBLNWhisperForConditionalGeneration
|
32
|
+
from .xlm_roberta import RBLNXLMRobertaModel
|
@@ -56,7 +56,6 @@ class _BartAttention(BartAttention):
|
|
56
56
|
cache_position: torch.Tensor,
|
57
57
|
key_value_states: Optional[torch.Tensor] = None,
|
58
58
|
) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
|
59
|
-
|
60
59
|
bsz, tgt_len, _ = hidden_states.size()
|
61
60
|
is_cross_attention = key_value_states is not None
|
62
61
|
|
@@ -111,7 +110,6 @@ class _BartSdpaAttention(BartSdpaAttention):
|
|
111
110
|
cache_position: torch.Tensor,
|
112
111
|
key_value_states: Optional[torch.Tensor] = None,
|
113
112
|
) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
|
114
|
-
|
115
113
|
bsz, tgt_len, _ = hidden_states.size()
|
116
114
|
is_cross_attention = key_value_states is not None
|
117
115
|
|
@@ -166,7 +164,6 @@ class _BartDecoderLayer(BartDecoderLayer):
|
|
166
164
|
cache_position: torch.Tensor,
|
167
165
|
attn_impl: str = "eager",
|
168
166
|
) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
|
169
|
-
|
170
167
|
# Self Attention Block
|
171
168
|
residual = hidden_states
|
172
169
|
self_attn_past_key_value = past_key_value[:2]
|
@@ -218,7 +215,6 @@ class _BartDecoder(BartDecoder):
|
|
218
215
|
cache_position: torch.Tensor,
|
219
216
|
attn_impl: str = "eager",
|
220
217
|
):
|
221
|
-
|
222
218
|
# embedding
|
223
219
|
positions_idx = cache_position + self.embed_positions.offset
|
224
220
|
positions = self.embed_positions.weight[positions_idx]
|
@@ -284,7 +280,6 @@ class BartDecoderWrapper(torch.nn.Module):
|
|
284
280
|
self_kv_cache: torch.Tensor,
|
285
281
|
cross_kv_cache: torch.Tensor,
|
286
282
|
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor]]:
|
287
|
-
|
288
283
|
# prepare past_key_values
|
289
284
|
kv_cache = ()
|
290
285
|
for i in range(0, self.num_layers * 2, 2):
|
@@ -0,0 +1,36 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
from .decoderonly_architecture import (
|
25
|
+
DecoderOnlyAttention,
|
26
|
+
DecoderOnlyDecoderLayer,
|
27
|
+
DecoderOnlyModel,
|
28
|
+
DecoderOnlyWrapper,
|
29
|
+
DynamicNTKScalingRotaryEmbedding,
|
30
|
+
LinearScalingRotaryEmbedding,
|
31
|
+
RotaryEmbedding,
|
32
|
+
apply_rotary_pos_emb,
|
33
|
+
rotate_half,
|
34
|
+
slice_and_unsqueeze_cos_sin,
|
35
|
+
)
|
36
|
+
from .modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM
|