optimum-rbln 0.7.4a3__py3-none-any.whl → 0.7.4a5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +156 -36
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/configuration_utils.py +772 -0
- optimum/rbln/diffusers/__init__.py +56 -0
- optimum/rbln/diffusers/configurations/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +54 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +44 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +48 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +66 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +221 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +285 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
- optimum/rbln/diffusers/modeling_diffusers.py +63 -122
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
- optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
- optimum/rbln/diffusers/models/controlnet.py +55 -70
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +43 -68
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +110 -113
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
- optimum/rbln/modeling.py +58 -39
- optimum/rbln/modeling_base.py +85 -80
- optimum/rbln/transformers/__init__.py +79 -8
- optimum/rbln/transformers/configuration_alias.py +49 -0
- optimum/rbln/transformers/configuration_generic.py +142 -0
- optimum/rbln/transformers/modeling_generic.py +193 -280
- optimum/rbln/transformers/models/__init__.py +96 -34
- optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
- optimum/rbln/transformers/models/bart/__init__.py +1 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +10 -84
- optimum/rbln/transformers/models/bert/__init__.py +1 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
- optimum/rbln/transformers/models/clip/__init__.py +6 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
- optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +50 -43
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +114 -141
- optimum/rbln/transformers/models/dpt/__init__.py +1 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
- optimum/rbln/transformers/models/exaone/__init__.py +1 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
- optimum/rbln/transformers/models/gemma/__init__.py +1 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
- optimum/rbln/transformers/models/llama/__init__.py +1 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +12 -23
- optimum/rbln/transformers/models/midm/__init__.py +1 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
- optimum/rbln/transformers/models/mistral/__init__.py +1 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
- optimum/rbln/transformers/models/phi/__init__.py +1 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +80 -97
- optimum/rbln/transformers/models/t5/__init__.py +1 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +22 -150
- optimum/rbln/transformers/models/time_series_transformers/__init__.py +1 -0
- optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
- optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +52 -54
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
- optimum/rbln/transformers/models/whisper/__init__.py +1 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +57 -72
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
- optimum/rbln/utils/submodule.py +26 -43
- {optimum_rbln-0.7.4a3.dist-info → optimum_rbln-0.7.4a5.dist-info}/METADATA +1 -1
- optimum_rbln-0.7.4a5.dist-info/RECORD +162 -0
- optimum/rbln/modeling_config.py +0 -310
- optimum_rbln-0.7.4a3.dist-info/RECORD +0 -126
- {optimum_rbln-0.7.4a3.dist-info → optimum_rbln-0.7.4a5.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.4a3.dist-info → optimum_rbln-0.7.4a5.dist-info}/licenses/LICENSE +0 -0
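Taken together, the file list reflects one structural refactor: optimum/rbln/modeling_config.py is removed and replaced by optimum/rbln/configuration_utils.py plus per-model configuration_*.py files. A minimal sketch of the corresponding import move, inferred only from the import paths visible in the diffs below (not from package documentation):

# 0.7.4a3 (removed module, per the deleted imports in the diffs below):
# from optimum.rbln.modeling_config import RBLNCompileConfig, RBLNConfig

# 0.7.4a5 (new module, per the added imports in the diffs below):
from optimum.rbln.configuration_utils import RBLNCompileConfig, RBLNModelConfig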
optimum/rbln/transformers/models/whisper/modeling_whisper.py
CHANGED
@@ -18,19 +18,14 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
 import rebel
 import torch
 from rebel.compile_context import CompileContext
-from transformers import (
-    AutoModelForSpeechSeq2Seq,
-    AutoProcessor,
-    PretrainedConfig,
-    WhisperForConditionalGeneration,
-    WhisperModel,
-)
+from transformers import AutoModelForSpeechSeq2Seq, WhisperForConditionalGeneration, WhisperModel
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 
+from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
-from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
 from ....utils.runtime_utils import RBLNPytorchRuntime
+from .configuration_whisper import RBLNWhisperForConditionalGenerationConfig
 from .generation_whisper import RBLNWhisperGenerationMixin
 from .whisper_architecture import WhisperWrapper
 
@@ -38,7 +33,14 @@ from .whisper_architecture import WhisperWrapper
 logger = get_logger(__name__)
 
 if TYPE_CHECKING:
-    from transformers import
+    from transformers import (
+        AutoFeatureExtractor,
+        AutoProcessor,
+        AutoTokenizer,
+        GenerationConfig,
+        PretrainedConfig,
+        PreTrainedModel,
+    )
 
 
 class RBLNRuntimeEncoder(RBLNPytorchRuntime):
@@ -117,10 +119,10 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)
 
-        self.batch_size = self.rbln_config.
-        self.dec_max_seq_len = self.rbln_config.
-        self.rbln_token_timestamps = self.rbln_config.
-        self.use_attention_mask = self.rbln_config.
+        self.batch_size = self.rbln_config.batch_size
+        self.dec_max_seq_len = self.rbln_config.dec_max_seq_len
+        self.rbln_token_timestamps = self.rbln_config.token_timestamps
+        self.use_attention_mask = self.rbln_config.use_attention_mask
 
         self.encoder = RBLNRuntimeEncoder(runtime=self.model[0], main_input_name="input_features")
         self.decoder = RBLNRuntimeDecoder(
@@ -169,16 +171,16 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
         raise NotImplementedError
 
     @classmethod
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config:
-        rbln_token_timestamps = rbln_config.model_cfg["token_timestamps"]
-        use_attention_mask = rbln_config.model_cfg.get("use_attention_mask", False)
+    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNWhisperForConditionalGenerationConfig):
         return WhisperWrapper(
-            model,
+            model,
+            use_attention_mask=rbln_config.use_attention_mask,
+            rbln_token_timestamps=rbln_config.token_timestamps,
         )
 
     @classmethod
     @torch.inference_mode()
-    def get_compiled_model(cls, model, rbln_config:
+    def get_compiled_model(cls, model, rbln_config: RBLNWhisperForConditionalGenerationConfig):
         wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
 
         enc_compile_config = rbln_config.compile_cfgs[0]
@@ -218,32 +220,21 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
         return {"encoder": compiled_encoder, "decoder": compiled_decoder}
 
     @classmethod
-    def
+    def _update_rbln_config(
         cls,
-        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor"],
-
-
-
-
-        rbln_token_timestamps = rbln_kwargs.get("token_timestamps", False)
-        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+        model: Optional["PreTrainedModel"] = None,
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_config: Optional[RBLNWhisperForConditionalGenerationConfig] = None,
+    ) -> RBLNWhisperForConditionalGenerationConfig:
         expected_seq_len = model_config.max_source_positions * 2
         num_mel_bins = model_config.num_mel_bins
-        enc_max_seq_len = model_config.max_source_positions
+        rbln_config.enc_max_seq_len = model_config.max_source_positions
 
         # 'whisper-large-v3-turbo' doesn't have 'max_length', but PretrainedConfig have default value for the key 'max_length'
-
-        if
-
-
-        # use_attention_mask conditions
-        rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
-        if rbln_use_attention_mask is None:
-            rbln_use_attention_mask = False
-            rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
-            if rbln_npu == "RBLN-CA02":
-                rbln_use_attention_mask = True
+        rbln_config.dec_max_seq_len = getattr(model_config, "max_target_positions", None)
+        if rbln_config.dec_max_seq_len is None:
+            rbln_config.dec_max_seq_len = model_config.max_length
 
         enc_input_info = [
             ("input_features", [1, num_mel_bins, expected_seq_len], "float32"),
@@ -252,9 +243,9 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
                 "cross_key_value_states",
                 [
                     model_config.decoder_layers * 2,
-
+                    rbln_config.batch_size,
                     model_config.decoder_attention_heads,
-                    enc_max_seq_len,
+                    rbln_config.enc_max_seq_len,
                     model_config.d_model // model_config.decoder_attention_heads,
                 ],
                 "float32",
@@ -262,9 +253,9 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
         ]
 
         dec_input_info = [
-            ("decoder_input_ids", [
-            ("cache_position", [
-            ("block_tables", [
+            ("decoder_input_ids", [rbln_config.batch_size, 1], "int64"),
+            ("cache_position", [rbln_config.batch_size, 1], "int32"),
+            ("block_tables", [rbln_config.batch_size, 1], "int16"),
         ]
         dec_input_info.extend(
             [
@@ -272,9 +263,9 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
                     "cross_key_value_states",
                     [
                         model_config.decoder_layers * 2,
-
+                        rbln_config.batch_size,
                         model_config.decoder_attention_heads,
-
+                        rbln_config.enc_max_seq_len,
                         model_config.d_model // model_config.decoder_attention_heads,
                     ],
                     "float32",
@@ -286,9 +277,9 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
                 (
                     f"self_key_value_states_{i}",
                     [
-
+                        rbln_config.batch_size,
                         model_config.decoder_attention_heads,
-
+                        rbln_config.dec_max_seq_len,
                         model_config.d_model // model_config.encoder_attention_heads,
                     ],
                     "float32",
@@ -297,26 +288,15 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
             ]
         )
 
-        if
-            dec_input_info.insert(
+        if rbln_config.use_attention_mask:
+            dec_input_info.insert(
+                1, ("decoder_attention_mask", [rbln_config.batch_size, rbln_config.dec_max_seq_len], "float32")
+            )
 
         enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
         dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
 
-        rbln_config
-            rbln_cls=cls.__name__,
-            compile_cfgs=[enc_compile_config, dec_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
-
-        rbln_config.model_cfg.update(
-            {
-                "batch_size": rbln_batch_size,
-                "dec_max_seq_len": rbln_dec_max_seq_len,
-                "token_timestamps": rbln_token_timestamps,
-                "use_attention_mask": rbln_use_attention_mask,
-            }
-        )
+        rbln_config.set_compile_cfgs([enc_compile_config, dec_compile_config])
 
         return rbln_config
 
@@ -324,18 +304,23 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
     def _create_runtimes(
         cls,
         compiled_models: List[rebel.RBLNCompiledModel],
-
-        activate_profiler: Optional[bool] = None,
+        rbln_config: RBLNWhisperForConditionalGenerationConfig,
     ) -> List[rebel.Runtime]:
-        if any(model_name not in
+        if any(model_name not in rbln_config.device_map for model_name in ["encoder", "decoder"]):
             cls._raise_missing_compiled_file_error(["encoder", "decoder"])
 
         return [
-
-
+            rebel.Runtime(
+                compiled_models[0],
+                tensor_type="pt",
+                device=rbln_config.device_map["encoder"],
+                activate_profiler=rbln_config.activate_profiler,
             ),
-
-
+            rebel.Runtime(
+                compiled_models[1],
+                tensor_type="pt",
+                device=rbln_config.device_map["decoder"],
+                activate_profiler=rbln_config.activate_profiler,
             ),
         ]
 
@@ -432,7 +417,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
         decoder_output = self.decoder(
             decoder_input_ids=decoder_input_ids.contiguous(),
             decoder_attention_mask=self.decoder_attention_mask,
-            cache_position=torch.zeros([self.rbln_config.
+            cache_position=torch.zeros([self.rbln_config.batch_size, 1], dtype=torch.int32),
         )
         lm_logits = decoder_output.logits
         self.language_cross = decoder_output.cross_attentions
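The modeling_whisper.py hunks above replace the untyped rbln_kwargs / model_cfg dictionary plumbing with a typed RBLNWhisperForConditionalGenerationConfig. A minimal sketch of what exporting a model with the new config object could look like; the config module path and its batch_size / token_timestamps / use_attention_mask fields come from the diff itself, while the constructor keywords, the export=True / rbln_config keywords of from_pretrained, and the openai/whisper-tiny model id are assumptions for illustration:

# Sketch only; not taken from the package documentation.
from optimum.rbln import RBLNWhisperForConditionalGeneration
from optimum.rbln.transformers.models.whisper.configuration_whisper import (
    RBLNWhisperForConditionalGenerationConfig,
)

# Field names mirror the attributes read in the diff above (assumed to be constructor kwargs).
rbln_config = RBLNWhisperForConditionalGenerationConfig(
    batch_size=1,
    token_timestamps=False,
    use_attention_mask=False,
)

# Assumes the public from_pretrained entry point forwards rbln_config the way the
# internal _from_pretrained call in the submodule.py diff below does.
model = RBLNWhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-tiny",
    export=True,
    rbln_config=rbln_config,
)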
optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py
ADDED
@@ -0,0 +1,19 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_generic import RBLNTransformerEncoderForFeatureExtractionConfig
+
+
+class RBLNXLMRobertaModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
+    pass
optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py
CHANGED
@@ -12,89 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import inspect
-from typing import TYPE_CHECKING, Optional, Union
 
-from
+from ...modeling_generic import RBLNTransformerEncoderForFeatureExtraction
 
-from ....modeling import RBLNModel
-from ....modeling_config import RBLNCompileConfig, RBLNConfig
-from ....utils.logging import get_logger
 
-
-
-
-if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
-
-
-class RBLNXLMRobertaModel(RBLNModel):
-    @classmethod
-    def _get_rbln_config(
-        cls,
-        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
-        model_config: Optional["PretrainedConfig"] = None,
-        rbln_kwargs={},
-    ) -> RBLNConfig:
-        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
-        rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
-        rbln_batch_size = rbln_kwargs.get("batch_size", None)
-
-        max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
-            model_config, "max_position_embeddings", None
-        )
-
-        if rbln_max_seq_len is None:
-            rbln_max_seq_len = max_position_embeddings
-            if rbln_max_seq_len is None:
-                for tokenizer in preprocessors:
-                    if hasattr(tokenizer, "model_max_length"):
-                        rbln_max_seq_len = tokenizer.model_max_length
-                        break
-                if rbln_max_seq_len is None:
-                    raise ValueError("`rbln_max_seq_len` should be specified!")
-
-        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
-            raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
-
-        signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()
-
-        if rbln_model_input_names is None:
-            for tokenizer in preprocessors:
-                if hasattr(tokenizer, "model_input_names"):
-                    rbln_model_input_names = [name for name in signature_params if name in tokenizer.model_input_names]
-
-                    invalid_params = set(rbln_model_input_names) - set(signature_params)
-                    if invalid_params:
-                        raise ValueError(f"Invalid model input names: {invalid_params}")
-                    break
-            if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
-                rbln_model_input_names = cls.rbln_model_input_names
-            elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
-                raise ValueError(
-                    "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
-                    f"and be sure to make the order of the inputs same as XLMRobertaModel forward() arguments like ({list(signature_params)})"
-                )
-        else:
-            invalid_params = set(rbln_model_input_names) - set(signature_params)
-            if invalid_params:
-                raise ValueError(f"Invalid model input names: {invalid_params}")
-            rbln_model_input_names = [name for name in signature_params if name in rbln_model_input_names]
-
-        if rbln_batch_size is None:
-            rbln_batch_size = 1
-
-        input_info = [
-            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
-            for model_input_name in rbln_model_input_names
-        ]
-
-        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
-
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[rbln_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
-        rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
-        return rbln_config
+class RBLNXLMRobertaModel(RBLNTransformerEncoderForFeatureExtraction):
+    pass
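With the hunk above, the XLM-RoBERTa-specific _get_rbln_config logic is deleted and RBLNXLMRobertaModel simply inherits the generic RBLNTransformerEncoderForFeatureExtraction, paired with the new RBLNXLMRobertaModelConfig. A hedged sketch of the same settings expressed through the config class; the keyword names mirror the rbln_kwargs keys (max_seq_len, model_input_names, batch_size) read by the removed code and are not verified against the new generic config signature:

# Sketch only; keyword names are assumptions carried over from the removed rbln_kwargs keys.
from optimum.rbln.transformers.models.xlm_roberta.configuration_xlm_roberta import (
    RBLNXLMRobertaModelConfig,
)

rbln_config = RBLNXLMRobertaModelConfig(
    max_seq_len=512,
    batch_size=1,
    model_input_names=["input_ids", "attention_mask"],
)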
optimum/rbln/utils/submodule.py
CHANGED
@@ -13,10 +13,9 @@
 # limitations under the License.
 
 import importlib
-from
-from typing import TYPE_CHECKING, Any, Dict, List
+from typing import TYPE_CHECKING, Any, Dict, List, Type
 
-from ..
+from ..configuration_utils import RBLNModelConfig
 
 
 if TYPE_CHECKING:
@@ -35,37 +34,32 @@ class SubModulesMixin:
 
     _rbln_submodules: List[Dict[str, Any]] = []
 
-    def __init__(
-        self,
-        *,
-        rbln_submodules: List["RBLNBaseModel"] = [],
-        **kwargs,
-    ) -> None:
+    def __init__(self, *, rbln_submodules: List["RBLNBaseModel"] = [], **kwargs) -> None:
         for submodule_meta, submodule in zip(self._rbln_submodules, rbln_submodules):
             setattr(self, submodule_meta["name"], submodule)
 
     @classmethod
     def _export_submodules_from_model(
-        cls,
-        model: "PreTrainedModel",
-        model_save_dir: str,
-        rbln_kwargs: Dict[str, Any],
-        **kwargs,
+        cls, model: "PreTrainedModel", model_save_dir: str, rbln_config: RBLNModelConfig, **kwargs
     ) -> List["RBLNBaseModel"]:
         rbln_submodules = []
         for submodule in cls._rbln_submodules:
             submodule_name = submodule["name"]
             torch_submodule: "PreTrainedModel" = getattr(model, submodule["name"])
             cls_name = torch_submodule.__class__.__name__
-            submodule_cls: "RBLNBaseModel" = getattr(importlib.import_module("optimum.rbln"), f"RBLN{cls_name}")
+            submodule_cls: Type["RBLNBaseModel"] = getattr(importlib.import_module("optimum.rbln"), f"RBLN{cls_name}")
+            submodule_rbln_config = getattr(rbln_config, submodule_name) or {}
 
-            if
-
+            if isinstance(submodule_rbln_config, dict):
+                submodule_rbln_config_class = submodule_cls.get_rbln_config_class()
+                submodule_rbln_config = submodule_rbln_config_class(**submodule_rbln_config)
+                setattr(rbln_config, submodule_name, submodule_rbln_config)
 
             rbln_submodule = submodule_cls.from_model(
                 model=torch_submodule,
                 subfolder=submodule_name,
                 model_save_dir=model_save_dir,
+                rbln_config=submodule_rbln_config,
                 **kwargs,
             )
 
@@ -74,55 +68,44 @@ class SubModulesMixin:
         return rbln_submodules
 
     @classmethod
-    def _load_submodules_from_compiled_models(
-        cls,
-        model_save_dir: str,
-        rbln_kwargs: Dict[str, Any],
-        **kwargs,
-    ):
+    def _load_submodules_from_compiled_models(cls, model_save_dir: str, rbln_config: RBLNModelConfig, **kwargs):
         rbln_submodules = []
         for submodule in cls._rbln_submodules:
             submodule_name = submodule["name"]
 
-            if submodule_name in rbln_kwargs:
-                kwargs["rbln_config"] = rbln_kwargs[submodule_name]
-
             # Get cls name for call the constructor of the rbln class
-            submodule_rbln_config =
-
-
+            submodule_rbln_config = getattr(rbln_config, submodule_name)
+
+            # RBLNModelConfig -> RBLNModel
+            submodule_cls: "RBLNBaseModel" = getattr(
+                importlib.import_module("optimum.rbln"), submodule_rbln_config.rbln_model_cls_name
+            )
 
             rbln_submodule = submodule_cls._from_pretrained(
                 model_id=model_save_dir,
                 config=None,
                 subfolder=submodule_name,
+                rbln_config=submodule_rbln_config,
                 **kwargs,
             )
+
+            # update submodule's rbln_config since it is updated in the from_pretrained method
+            setattr(rbln_config, submodule_name, rbln_submodule.rbln_config)
             rbln_submodules.append(rbln_submodule)
+
         return rbln_submodules
 
     @classmethod
-    def _load_submodules(
-        cls,
-        model_save_dir,
-        rbln_kwargs,
-        model=None,
-        **kwargs,
-    ):
+    def _load_submodules(cls, model_save_dir, rbln_config: RBLNModelConfig, model=None, **kwargs):
         # Two ways :
         # 1. Compile from pytorch object
         # 2. Load from compiled file
         if model is not None:
             return cls._export_submodules_from_model(
-                model=model,
-                model_save_dir=model_save_dir,
-                rbln_kwargs=rbln_kwargs,
-                **kwargs,
+                model=model, model_save_dir=model_save_dir, rbln_config=rbln_config, **kwargs
             )
 
         else:
             return cls._load_submodules_from_compiled_models(
-                model_save_dir=model_save_dir,
-                rbln_kwargs=rbln_kwargs,
-                **kwargs,
+                model_save_dir=model_save_dir, rbln_config=rbln_config, **kwargs
             )
{optimum_rbln-0.7.4a3.dist-info → optimum_rbln-0.7.4a5.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: optimum-rbln
-Version: 0.7.4a3
+Version: 0.7.4a5
 Summary: Optimum RBLN is the interface between the Hugging Face Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
 Project-URL: Homepage, https://rebellions.ai
 Project-URL: Documentation, https://docs.rbln.ai