optimum-rbln 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl
This diff shows the content of publicly released versions of this package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
- optimum/rbln/__init__.py +17 -0
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/__init__.py +0 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/models/controlnet.py +7 -3
- optimum/rbln/diffusers/models/unet_2d_condition.py +5 -5
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +23 -146
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +107 -59
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +106 -54
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
- optimum/rbln/modeling_alias.py +19 -1
- optimum/rbln/modeling_base.py +162 -18
- optimum/rbln/transformers/__init__.py +8 -0
- optimum/rbln/transformers/cache_utils.py +111 -0
- optimum/rbln/transformers/generation/utils.py +0 -2
- optimum/rbln/transformers/models/__init__.py +3 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
- optimum/rbln/transformers/models/clip/modeling_clip.py +1 -1
- optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +516 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +464 -0
- optimum/rbln/transformers/models/gemma/__init__.py +24 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +123 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +67 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +201 -166
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +10 -257
- optimum/rbln/transformers/models/llama/llama_architecture.py +3 -610
- optimum/rbln/transformers/models/llama/modeling_llama.py +12 -440
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
- optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
- optimum/rbln/transformers/models/midm/modeling_midm.py +10 -325
- optimum/rbln/transformers/models/mistral/__init__.py +24 -0
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +29 -0
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +68 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +131 -0
- optimum/rbln/transformers/utils/__init__.py +0 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +109 -0
- optimum/rbln/utils/import_utils.py +1 -4
- optimum/rbln/utils/runtime_utils.py +2 -1
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/METADATA +11 -5
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/RECORD +48 -35
- optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -764
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py CHANGED

```diff
@@ -22,17 +22,17 @@
 # from Rebellions Inc.
 """RBLNStableDiffusionXLPipeline class for inference of diffusion models on rbln devices."""
 
-from pathlib import Path
-from tempfile import TemporaryDirectory
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
-from diffusers import StableDiffusionXLControlNetImg2ImgPipeline
+from diffusers import AutoencoderKL, ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline
 from diffusers.image_processor import PipelineImageInput
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
 from diffusers.utils import deprecate, logging
 from diffusers.utils.torch_utils import is_compiled_module
+from transformers import CLIPTextModel
 
 from ....modeling_base import RBLNBaseModel
 from ....transformers import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection
@@ -63,103 +63,152 @@ class RBLNStableDiffusionXLControlNetImg2ImgPipeline(StableDiffusionXLControlNet
                 - A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`],
         """
         export = kwargs.pop("export", None)
-        text_encoder = kwargs.pop("text_encoder", None)
-        controlnets = kwargs.pop("controlnet", None)
         vae = kwargs.pop("vae", None)
+        unet = kwargs.pop("unet", None)
+        text_encoder = kwargs.pop("text_encoder", None)
+        text_encoder_2 = kwargs.pop("text_encoder_2", None)
+        controlnet = kwargs.pop("controlnet", None)
+        model_save_dir = kwargs.pop("model_save_dir", None)
 
         rbln_config_kwargs, rbln_constructor_kwargs = RBLNBaseModel.pop_rbln_kwargs_from_kwargs(kwargs)
+
         kwargs_dict = {
             "pretrained_model_name_or_path": model_id,
-            "vae": vae,
-            "controlnet": controlnets,
-            "text_encoder": text_encoder,
             **kwargs,
         }
 
+        kwargs_dict.update(
+            {
+                **({"vae": vae} if vae is not None and isinstance(vae, AutoencoderKL) else {}),
+                **({"unet": unet} if unet is not None and isinstance(unet, UNet2DConditionModel) else {}),
+                **(
+                    {"text_encoder": text_encoder}
+                    if text_encoder is not None and isinstance(text_encoder, CLIPTextModel)
+                    else {}
+                ),
+                **(
+                    {"controlnet": controlnet}
+                    if controlnet is not None
+                    and (
+                        isinstance(controlnet, ControlNetModel)
+                        or all(isinstance(c, ControlNetModel) for c in controlnet)
+                    )
+                    else {}
+                ),
+            }
+        )
+
         model = super().from_pretrained(**{k: v for k, v in kwargs_dict.items() if v is not None})
 
         if export is None or export is False:
             return model
 
-        save_dir = TemporaryDirectory()
-        save_dir_path = Path(save_dir.name)
-
-        model.save_pretrained(save_directory=save_dir_path, **kwargs)
-
         do_classifier_free_guidance = (
             rbln_config_kwargs.pop("rbln_guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
        )
 
-        vae
-
-
-
-
-
-
-
-
-
-
-
-            subfolder="text_encoder",
-            export=True,
-            **rbln_config_kwargs,
-            **rbln_constructor_kwargs,
-        )
-        text_encoder_2 = RBLNCLIPTextModelWithProjection.from_pretrained(
-            model_id=model_id,
-            subfolder="text_encoder_2",
-            export=True,
-            **rbln_config_kwargs,
-            **rbln_constructor_kwargs,
-        )
-
-        batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
-        unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
+        if not isinstance(vae, RBLNAutoencoderKL):
+            vae = RBLNAutoencoderKL.from_pretrained(
+                model_id=model_id,
+                subfolder="vae",
+                export=True,
+                model_save_dir=model_save_dir,
+                rbln_unet_sample_size=model.unet.config.sample_size,
+                rbln_use_encode=True,
+                rbln_vae_scale_factor=model.vae_scale_factor,
+                **rbln_config_kwargs,
+                **rbln_constructor_kwargs,
+            )
 
-
-
-
-
-
-
-
-
-
-            rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
-            **rbln_config_kwargs,
-            **rbln_constructor_kwargs,
-        )
+        if not isinstance(text_encoder, RBLNCLIPTextModel):
+            text_encoder = RBLNCLIPTextModel.from_pretrained(
+                model_id=model_id,
+                subfolder="text_encoder",
+                export=True,
+                model_save_dir=model_save_dir,
+                **rbln_config_kwargs,
+                **rbln_constructor_kwargs,
+            )
 
-        if isinstance(
-
-            model_id=
+        if not isinstance(text_encoder_2, RBLNCLIPTextModel):
+            text_encoder_2 = RBLNCLIPTextModelWithProjection.from_pretrained(
+                model_id=model_id,
+                subfolder="text_encoder_2",
                 export=True,
-
-            rbln_vae_scale_factor=model.vae_scale_factor,
+                model_save_dir=model_save_dir,
                 **rbln_config_kwargs,
                 **rbln_constructor_kwargs,
             )
-
-
-
-
+
+        batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
+        unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
+
+        if not isinstance(unet, RBLNUNet2DConditionModel):
+            unet = RBLNUNet2DConditionModel.from_pretrained(
+                model_id=model_id,
+                subfolder="unet",
                 export=True,
-
+                model_save_dir=model_save_dir,
+                rbln_max_seq_len=model.text_encoder.config.max_position_embeddings,
                 rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
+                rbln_batch_size=unet_batch_size,
+                rbln_use_encode=True,
                 rbln_vae_scale_factor=model.vae_scale_factor,
+                rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
                 **rbln_config_kwargs,
                 **rbln_constructor_kwargs,
             )
-        controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
 
+        if not isinstance(controlnet, (RBLNControlNetModel, RBLNMultiControlNetModel)):
+            if isinstance(controlnet, (list, tuple)):
+                multicontrolnet = []
+                for i, cid in enumerate(controlnet):
+                    subfolder_name = "controlnet" if i == 0 else f"controlnet_{i}"
+                    multicontrolnet.append(
+                        RBLNControlNetModel.from_pretrained(
+                            model_id=cid.config._name_or_path,
+                            subfolder=subfolder_name,
+                            export=True,
+                            model_save_dir=model_save_dir,
+                            rbln_batch_size=unet_batch_size,
+                            rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
+                            rbln_vae_scale_factor=model.vae_scale_factor,
+                            **rbln_config_kwargs,
+                            **rbln_constructor_kwargs,
+                        )
+                    )
+                controlnet = RBLNMultiControlNetModel(multicontrolnet, config=controlnet[0].config)
+                controlnet_dict = ("optimum.rbln", "RBLNMultiControlNetModel")
+            else:
+                controlnet = RBLNControlNetModel.from_pretrained(
+                    model_id=controlnet.config._name_or_path,
+                    subfolder="controlnet",
+                    export=True,
+                    model_save_dir=model_save_dir,
+                    rbln_batch_size=unet_batch_size,
+                    rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
+                    rbln_vae_scale_factor=model.vae_scale_factor,
+                    **rbln_config_kwargs,
+                    **rbln_constructor_kwargs,
+                )
+                controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
+
+        if model_save_dir is not None:
+            # To skip saving original pytorch modules
+            del (model.vae, model.text_encoder, model.unet, model.controlnet)
+
+            # Direct calling of `save_pretrained` causes config.unet = (None, None).
+            # So config must be saved again, later.
+            model.save_pretrained(model_save_dir)
+
+        # replace modules
         model.vae = vae
         model.text_encoder = text_encoder
         model.unet = unet
         model.text_encoder_2 = text_encoder_2
         model.controlnet = controlnet
 
+        # update config to be able to load from file
         update_dict = {
             "vae": ("optimum.rbln", "RBLNAutoencoderKL"),
             "text_encoder": ("optimum.rbln", "RBLNCLIPTextModel"),
@@ -169,14 +218,24 @@ class RBLNStableDiffusionXLControlNetImg2ImgPipeline(StableDiffusionXLControlNet
         }
         model.register_to_config(**update_dict)
 
-
-
-
-
-
-
-
-
+        if model_save_dir is not None:
+            # overwrite to replace incorrect config
+            model.save_config(model_save_dir)
+
+        # use for CI to access each compiled model
+        if rbln_constructor_kwargs.pop("rbln_optimize_host_memory", None) is False:
+            model.compiled_models = [
+                vae.compiled_models[0],
+                vae.compiled_models[1],
+                text_encoder.compiled_models[0],
+                text_encoder_2.compiled_models[0],
+                unet.compiled_models[0],
+            ]
+            if isinstance(controlnet, RBLNMultiControlNetModel):
+                for c_model in controlnet.nets:
+                    model.compiled_models.append(c_model.compiled_models[0])
+            else:
+                model.compiled_models.append(controlnet.compiled_models[0])
 
         return model
 
```
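Taken together, the rewritten `from_pretrained` accepts already-compiled RBLN submodules, compiles any that are missing, and, when `model_save_dir` is given, persists the compiled artifacts instead of round-tripping through a temporary directory. A minimal usage sketch of the new flow, assuming the pipeline class is exported from the package root as in prior releases; the checkpoint ids and save path below are illustrative, not taken from the diff:

```python
# Hedged sketch of the 0.1.9 compile-and-save flow (ids and paths are illustrative).
from diffusers import ControlNetModel
from optimum.rbln import RBLNStableDiffusionXLControlNetImg2ImgPipeline

controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0")
pipe = RBLNStableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    export=True,                   # compile each submodule for the RBLN device
    model_save_dir="./sdxl_rbln",  # new in 0.1.9: persist compiled models directly
    rbln_guidance_scale=5.0,       # >1.0 doubles the UNet/ControlNet batch for CFG
)
```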
optimum/rbln/modeling_alias.py CHANGED

```diff
@@ -24,7 +24,9 @@
 from .modeling_base import (
     RBLNModelForAudioClassification,
     RBLNModelForImageClassification,
+    RBLNModelForMaskedLM,
     RBLNModelForQuestionAnswering,
+    RBLNModelForSequenceClassification,
 )
 from .modeling_seq2seq import RBLNModelForSeq2SeqLM
 
@@ -34,7 +36,11 @@ class RBLNASTForAudioClassification(RBLNModelForAudioClassification):
 
 
 class RBLNBertForQuestionAnswering(RBLNModelForQuestionAnswering):
-    pass
+    rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
+
+
+class RBLNDistilBertForQuestionAnswering(RBLNModelForQuestionAnswering):
+    rbln_model_input_names = ["input_ids", "attention_mask"]
 
 
 class RBLNResNetForImageClassification(RBLNModelForImageClassification):
@@ -47,3 +53,15 @@ class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
 
 class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
     pass
+
+
+class RBLNXLMRobertaForSequenceClassification(RBLNModelForSequenceClassification):
+    pass
+
+
+class RBLNRobertaForSequenceClassification(RBLNModelForSequenceClassification):
+    pass
+
+
+class RBLNRobertaForMaskedLM(RBLNModelForMaskedLM):
+    pass
```
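The new aliases route RoBERTa and XLM-RoBERTa classification and masked-LM checkpoints through the generic `RBLNModelForSequenceClassification`/`RBLNModelForMaskedLM` bases added in `modeling_base.py` below. A hedged sketch, assuming the alias is re-exported from the package root like its siblings; the checkpoint id is illustrative:

```python
# Hedged sketch: compiling an XLM-RoBERTa sequence-classification checkpoint.
from optimum.rbln import RBLNXLMRobertaForSequenceClassification

model = RBLNXLMRobertaForSequenceClassification.from_pretrained(
    "BAAI/bge-reranker-base",  # illustrative XLM-R checkpoint
    export=True,               # trace and compile with the RBLN compiler
    rbln_max_seq_len=512,      # static sequence length baked into the graph
    rbln_batch_size=1,
)
```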
optimum/rbln/modeling_base.py CHANGED

```diff
@@ -39,7 +39,9 @@ from transformers import (
     AutoModel,
     AutoModelForAudioClassification,
     AutoModelForImageClassification,
+    AutoModelForMaskedLM,
     AutoModelForQuestionAnswering,
+    AutoModelForSequenceClassification,
     GenerationConfig,
     PretrainedConfig,
 )
@@ -49,10 +51,15 @@ from .utils.runtime_utils import UnavailableRuntime
 from .utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
 
 
-logger = logging.getLogger(__name__)
-
 if TYPE_CHECKING:
-    from transformers import
+    from transformers import (
+        AutoFeatureExtractor,
+        AutoProcessor,
+        AutoTokenizer,
+        PreTrainedModel,
+    )
+
+logger = logging.getLogger(__name__)
 
 
 class RBLNBaseModel(OptimizedModel, ABC):
@@ -154,13 +161,23 @@ class RBLNBaseModel(OptimizedModel, ABC):
                 Directory where to save the model file.
         """
         real_save_dir = self.model_save_dir / self.subfolder
+        save_directory_path = Path(save_directory)
         if os.path.exists(real_save_dir) and os.path.isdir(real_save_dir):
+            if save_directory_path.absolute() == real_save_dir.absolute():
+                raise FileExistsError(
+                    f"Cannot save model to '{save_directory}'. "
+                    f"This directory already exists and contains the model files."
+                )
             shutil.copytree(real_save_dir, save_directory, dirs_exist_ok=True)
             self.config.save_pretrained(save_directory)
             if self.generation_config is not None:
                 self.generation_config.save_pretrained(save_directory)
         else:
-            raise FileNotFoundError(
+            raise FileNotFoundError(
+                f"Unable to save the model. The model directory '{real_save_dir}' does not exist or is not accessible. "
+                f"Cannot save to the specified destination '{save_directory}'. "
+                f"Please ensure the model directory exists and you have the necessary permissions to access it."
+            )
 
     @classmethod
     def _from_pretrained(
@@ -194,7 +211,12 @@ class RBLNBaseModel(OptimizedModel, ABC):
             token = HfFolder().get_token()
         else:
             token = use_auth_token
-        repo_files = list(
+        repo_files = list(
+            map(
+                Path,
+                HfApi().list_repo_files(model_id, revision=revision, token=token),
+            )
+        )
 
         pattern = "*.rbln" if subfolder == "" else f"{subfolder}/*.rbln"
         rbln_files = [p for p in repo_files if p.match(pattern)]
@@ -285,7 +307,7 @@ class RBLNBaseModel(OptimizedModel, ABC):
             preprocessors,
             model_save_dir=model_save_dir,
             subfolder=subfolder,
-            rbln_compiled_models=None if rbln_optimize_host_memory else rbln_compiled_models,
+            rbln_compiled_models=(None if rbln_optimize_host_memory else rbln_compiled_models),
             **kwargs,
         )
 
@@ -375,7 +397,7 @@ class RBLNBaseModel(OptimizedModel, ABC):
         return self.forward(*args, **kwargs)
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module) -> torch.nn.Module:
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
         # Wrap the model if needed.
         return model
 
@@ -398,7 +420,9 @@ class RBLNBaseModel(OptimizedModel, ABC):
     @classmethod
     @abstractmethod
     def _create_runtimes(
-        cls,
+        cls,
+        compiled_models: List[rebel.RBLNCompiledModel],
+        rbln_device_map: Dict[str, int],
     ) -> List[rebel.Runtime]:
         # compiled_models -> runtimes
         pass
@@ -495,7 +519,7 @@ class RBLNModel(RBLNBaseModel):
 
     @classmethod
     def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
-        model = cls.wrap_model_if_needed(model)
+        model = cls.wrap_model_if_needed(model, rbln_config)
         rbln_runtime_configs = list(rbln_config.values())
         if len(rbln_runtime_configs) != 1:
             raise ValueError
@@ -596,7 +620,9 @@ class RBLNModel(RBLNBaseModel):
 
     @classmethod
     def _create_runtimes(
-        cls,
+        cls,
+        compiled_models: List[rebel.RBLNCompiledModel],
+        rbln_device_map: Dict[str, int],
     ) -> List[rebel.Runtime]:
         device = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
         return [compiled_model.create_runtime(tensor_type="pt", device=device) for compiled_model in compiled_models]
@@ -616,8 +642,8 @@ class RBLNModelForQuestionAnswering(RBLNModel):
         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
         model_config: Optional["PretrainedConfig"] = None,
         rbln_max_seq_len: Optional[int] = None,
-        rbln_model_input_names: Optional[List[str]] = None,
         rbln_batch_size: Optional[int] = None,
+        rbln_model_input_names: Optional[List[str]] = None,
     ) -> RBLNConfig:
         if rbln_max_seq_len is None:
             for tokenizer in preprocessors:
@@ -627,15 +653,15 @@ class RBLNModelForQuestionAnswering(RBLNModel):
         if rbln_max_seq_len is None:
             raise ValueError("`rbln_max_seq_len` should be specified!")
 
-        if rbln_model_input_names is None:
-            # These are BERT's inputs
-            rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
-
         if rbln_batch_size is None:
             rbln_batch_size = 1
+
+        if rbln_model_input_names is not None:
+            cls.rbln_model_input_names = rbln_model_input_names
+
         input_info = [
             (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
-            for model_input_name in rbln_model_input_names
+            for model_input_name in cls.rbln_model_input_names
         ]
 
         rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
@@ -672,7 +698,13 @@ class RBLNModelForImageClassification(RBLNModel):
         if rbln_batch_size is None:
             rbln_batch_size = 1
 
-        input_info = [
+        input_info = [
+            (
+                "pixel_values",
+                [rbln_batch_size, 3, rbln_image_size, rbln_image_size],
+                "float32",
+            )
+        ]
 
         rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
         rbln_runtime_config.batch_size = rbln_batch_size
@@ -737,7 +769,11 @@ class RBLNModelForAudioClassification(RBLNModel):
         meta["rbln_num_mel_bins"] = rbln_num_mel_bins
 
         model_input_info = [
-            (
+            (
+                "input_values",
+                [rbln_batch_size, rbln_max_length, rbln_num_mel_bins],
+                "float32",
+            ),
         ]
 
         rbln_runtime_config = RBLNRuntimeConfig(input_info=model_input_info, batch_size=rbln_batch_size)
@@ -748,3 +784,111 @@ class RBLNModelForAudioClassification(RBLNModel):
         )
 
         return rbln_config
+
+
+class RBLNModelForSequenceClassification(RBLNModel):
+    """
+    This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence classification head) when created with the from_pretrained() class method
+    This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    A class to convert and run pre-trained transformers based SequenceClassification models on RBLN devices.
+    It implements the methods to convert a pre-trained transformers SequenceClassification model into a RBLN transformer model by:
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+
+    Currently, this model class supports the 'XLMRoberta' and 'Roberta' model from the transformers library. Future updates may include support for additional model types.
+    """
+
+    model_type = "rbln_model"
+    auto_model_class = AutoModelForSequenceClassification
+
+    @classmethod
+    def _get_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_max_seq_len: Optional[int] = None,
+        rbln_model_input_names: Optional[List[str]] = None,
+        rbln_batch_size: Optional[int] = None,
+    ) -> RBLNConfig:
+        max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
+            model_config, "max_position_embeddings", None
+        )
+
+        if rbln_max_seq_len is None:
+            rbln_max_seq_len = max_position_embeddings
+            if rbln_max_seq_len is None:
+                for tokenizer in preprocessors:
+                    if hasattr(tokenizer, "model_max_length"):
+                        rbln_max_seq_len = tokenizer.model_max_length
+                        break
+                if rbln_max_seq_len is None:
+                    raise ValueError("`rbln_max_seq_len` should be specified!")
+
+        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
+            raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
+
+        if rbln_model_input_names is None:
+            # These are BERT's inputs
+            rbln_model_input_names = ["input_ids", "attention_mask"]
+
+        if rbln_batch_size is None:
+            rbln_batch_size = 1
+        input_info = [
+            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
+            for model_input_name in rbln_model_input_names
+        ]
+
+        rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
+        rbln_runtime_config.batch_size = rbln_batch_size
+        meta = {"rbln_max_seq_len": rbln_max_seq_len}
+
+        return RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
+
+
+class RBLNModelForMaskedLM(RBLNModel):
+    model_type = "rbln_model"
+    auto_model_class = AutoModelForMaskedLM
+
+    @classmethod
+    def _get_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_max_seq_len: Optional[int] = None,
+        rbln_model_input_names: Optional[List[str]] = None,
+        rbln_batch_size: Optional[int] = None,
+    ) -> RBLNConfig:
+        max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
+            model_config, "max_position_embeddings", None
+        )
+
+        if rbln_max_seq_len is None:
+            rbln_max_seq_len = max_position_embeddings
+            if rbln_max_seq_len is None:
+                for tokenizer in preprocessors:
+                    if hasattr(tokenizer, "model_max_length"):
+                        rbln_max_seq_len = tokenizer.model_max_length
+                        break
+                if rbln_max_seq_len is None:
+                    raise ValueError("`rbln_max_seq_len` should be specified!")
+
+        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
+            raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
+
+        if rbln_model_input_names is None:
+            # These are BERT's inputs
+            rbln_model_input_names = ["input_ids", "attention_mask"]
+
+        if rbln_batch_size is None:
+            rbln_batch_size = 1
+        input_info = [
+            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
+            for model_input_name in rbln_model_input_names
+        ]
+
+        rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
+        rbln_runtime_config.batch_size = rbln_batch_size
+        meta = {"rbln_max_seq_len": rbln_max_seq_len}
+
+        return RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
```
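Both new `_get_rbln_config` implementations reduce to the same recipe: resolve a static `rbln_max_seq_len` (model config first, tokenizer `model_max_length` as fallback), then emit one `(name, shape, dtype)` triple per model input. The shape construction can be checked as plain data (values illustrative):

```python
# The input_info construction from the classes above, evaluated standalone.
rbln_batch_size, rbln_max_seq_len = 1, 512
rbln_model_input_names = ["input_ids", "attention_mask"]  # BERT-style defaults

input_info = [
    (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
    for model_input_name in rbln_model_input_names
]
assert input_info == [
    ("input_ids", [1, 512], "int64"),
    ("attention_mask", [1, 512], "int64"),
]
```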
optimum/rbln/transformers/__init__.py CHANGED

```diff
@@ -27,30 +27,38 @@ from transformers.utils import _LazyModule
 
 
 _import_structure = {
+    "cache_utils": ["RebelDynamicCache"],
     "generation": ["BatchTextIteratorStreamer"],
     "models": [
         "RBLNCLIPTextModel",
         "RBLNCLIPTextModelWithProjection",
         "RBLNDPTForDepthEstimation",
+        "RBLNGemmaForCausalLM",
         "RBLNGPT2LMHeadModel",
         "RBLNWav2Vec2ForCTC",
         "RBLNWhisperForConditionalGeneration",
         "RBLNLlamaForCausalLM",
         "RBLNMidmLMHeadModel",
+        "RBLNMistralForCausalLM",
+        "RBLNXLMRobertaModel",
     ],
 }
 
 if TYPE_CHECKING:
+    from .cache_utils import RebelDynamicCache
     from .generation import BatchTextIteratorStreamer
     from .models import (
         RBLNCLIPTextModel,
         RBLNCLIPTextModelWithProjection,
         RBLNDPTForDepthEstimation,
+        RBLNGemmaForCausalLM,
         RBLNGPT2LMHeadModel,
         RBLNLlamaForCausalLM,
         RBLNMidmLMHeadModel,
+        RBLNMistralForCausalLM,
         RBLNWav2Vec2ForCTC,
         RBLNWhisperForConditionalGeneration,
+        RBLNXLMRobertaModel,
     )
 else:
     import sys
```
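With the lazy-import table updated, the new symbols resolve on first access through `_LazyModule`. A short sketch of the imports this hunk enables (requires the 0.1.9 wheel):

```python
# Imports made available by the updated _import_structure above.
from optimum.rbln.transformers import (
    RBLNGemmaForCausalLM,    # new in 0.1.9
    RBLNMistralForCausalLM,  # new in 0.1.9
    RBLNXLMRobertaModel,     # new in 0.1.9
    RebelDynamicCache,       # new cache_utils export
)
```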