optimum-rbln 0.7.4a4__py3-none-any.whl → 0.7.4a6__py3-none-any.whl
This diff compares the contents of two package versions as published to a supported public registry. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in that registry.
- optimum/rbln/__init__.py +164 -36
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +772 -0
- optimum/rbln/diffusers/__init__.py +56 -0
- optimum/rbln/diffusers/configurations/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +54 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +44 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +48 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +66 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +221 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +285 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
- optimum/rbln/diffusers/modeling_diffusers.py +63 -122
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
- optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
- optimum/rbln/diffusers/models/controlnet.py +55 -70
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +43 -68
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +110 -113
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
- optimum/rbln/modeling.py +58 -39
- optimum/rbln/modeling_base.py +107 -78
- optimum/rbln/transformers/__init__.py +87 -8
- optimum/rbln/transformers/configuration_alias.py +49 -0
- optimum/rbln/transformers/configuration_generic.py +142 -0
- optimum/rbln/transformers/modeling_generic.py +193 -280
- optimum/rbln/transformers/models/__init__.py +108 -34
- optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
- optimum/rbln/transformers/models/bart/__init__.py +1 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +10 -84
- optimum/rbln/transformers/models/bert/__init__.py +1 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
- optimum/rbln/transformers/models/clip/__init__.py +6 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
- optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +115 -84
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +282 -216
- optimum/rbln/transformers/models/dpt/__init__.py +1 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
- optimum/rbln/transformers/models/exaone/__init__.py +1 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
- optimum/rbln/transformers/models/gemma/__init__.py +1 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
- optimum/rbln/transformers/models/llama/__init__.py +1 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +12 -23
- optimum/rbln/transformers/models/midm/__init__.py +1 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
- optimum/rbln/transformers/models/mistral/__init__.py +1 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
- optimum/rbln/transformers/models/phi/__init__.py +1 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +68 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +608 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +214 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +80 -97
- optimum/rbln/transformers/models/t5/__init__.py +1 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +22 -150
- optimum/rbln/transformers/models/time_series_transformers/__init__.py +1 -0
- optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
- optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +52 -54
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
- optimum/rbln/transformers/models/whisper/__init__.py +1 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +57 -72
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
- optimum/rbln/utils/runtime_utils.py +33 -2
- optimum/rbln/utils/submodule.py +26 -43
- {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a6.dist-info}/METADATA +1 -1
- optimum_rbln-0.7.4a6.dist-info/RECORD +166 -0
- optimum/rbln/modeling_config.py +0 -310
- optimum_rbln-0.7.4a4.dist-info/RECORD +0 -126
- {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a6.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a6.dist-info}/licenses/LICENSE +0 -0
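The structural theme of this release is visible in the file list itself: `optimum/rbln/modeling_config.py` is deleted (-310 lines) in favor of a new `optimum/rbln/configuration_utils.py` (+772 lines), and every model family gains a typed `configuration_*.py` module. For downstream code, the import migration sketched below follows from the diffs in this section (a hedged sketch; only the module paths and class names that appear in the diffs are assumed):

    # Before (0.7.4a4): compile settings flowed through a generic RBLNConfig
    # from optimum.rbln.modeling_config import RBLNCompileConfig, RBLNConfig  # module removed in 0.7.4a6

    # After (0.7.4a6): a shared base plus per-model config classes
    from optimum.rbln.configuration_utils import RBLNCompileConfig, RBLNModelConfig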
optimum/rbln/transformers/models/t5/modeling_t5.py CHANGED
@@ -13,48 +13,21 @@
 # limitations under the License.
 
 import inspect
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any, Callable
 
 import torch
-from transformers import (
-
-
-    T5EncoderModel,
-    T5ForConditionalGeneration,
-)
-from transformers.modeling_outputs import BaseModelOutput
-
-from ....diffusers.modeling_diffusers import RBLNDiffusionMixin
-from ....modeling import RBLNModel
-from ....modeling_config import RBLNCompileConfig, RBLNConfig
-from ....utils.logging import get_logger
-from ....utils.runtime_utils import RBLNPytorchRuntime
+from transformers import AutoModelForTextEncoding, T5EncoderModel, T5ForConditionalGeneration
+
+from ...modeling_generic import RBLNTransformerEncoderForFeatureExtraction
 from ...models.seq2seq import RBLNModelForSeq2SeqLM
+from .configuration_t5 import RBLNT5EncoderModelConfig, RBLNT5ForConditionalGenerationConfig
 from .t5_architecture import T5Wrapper
 
 
-logger = get_logger()
-
 if TYPE_CHECKING:
-    from transformers import
-
-
-class RBLNRuntimeModel(RBLNPytorchRuntime):
-    def forward(
-        self,
-        input_ids: torch.LongTensor,
-        attention_mask: torch.FloatTensor,
-        head_mask: torch.FloatTensor,
-        inputs_embeds: torch.FloatTensor,
-        **kwargs,
-    ):
-        return super().forward(
-            input_ids,
-            attention_mask,
-            head_mask,
-            inputs_embeds,
-            **kwargs,
-        )
+    from transformers import PreTrainedModel
+
+    from ....diffusers.modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
 
 
 class T5EncoderWrapper(torch.nn.Module):
@@ -67,136 +40,35 @@ class T5EncoderWrapper(torch.nn.Module):
         return self.model(*args, **kwargs, return_dict=False)
 
 
-class RBLNT5EncoderModel(
+class RBLNT5EncoderModel(RBLNTransformerEncoderForFeatureExtraction):
     auto_model_class = AutoModelForTextEncoding
     rbln_model_input_names = ["input_ids", "attention_mask"]
 
-    def __post_init__(self, **kwargs):
-        self.model = RBLNRuntimeModel(runtime=self.model[0])
-
     @classmethod
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config:
+    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNT5EncoderModelConfig):
         return T5EncoderWrapper(model)
 
     @classmethod
-    def update_rbln_config_using_pipe(
-        batch_size = rbln_config.get("batch_size", 1)
-        max_sequence_length = rbln_config.get("max_sequence_length", 256)
-        model_input_names = ["input_ids"]
-
-        rbln_config.update(
-            {
-                "batch_size": batch_size,
-                "max_seq_len": max_sequence_length,
-                "model_input_names": model_input_names,
-            }
-        )
-
-        return rbln_config
-
-    @classmethod
-    def _get_rbln_config(
+    def update_rbln_config_using_pipe(
         cls,
-
-
-
-    ) ->
-
-
-
-
-        max_position_embeddings = getattr(model_config, "n_positions", None)
-
-        if rbln_max_seq_len is None:
-            rbln_max_seq_len = max_position_embeddings
-            if rbln_max_seq_len is None:
-                for tokenizer in preprocessors:
-                    if hasattr(tokenizer, "model_max_length"):
-                        rbln_max_seq_len = tokenizer.model_max_length
-                        break
-                if rbln_max_seq_len is None:
-                    raise ValueError("`rbln_max_seq_len` should be specified!")
-
-        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
-            raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
-
-        signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()
-
-        if rbln_model_input_names is None:
-            for tokenizer in preprocessors:
-                if hasattr(tokenizer, "model_input_names"):
-                    rbln_model_input_names = [name for name in signature_params if name in tokenizer.model_input_names]
-
-                    invalid_params = set(rbln_model_input_names) - set(signature_params)
-                    if invalid_params:
-                        raise ValueError(f"Invalid model input names: {invalid_params}")
-                    break
-            if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
-                rbln_model_input_names = cls.rbln_model_input_names
-            elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
-                raise ValueError(
-                    "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
-                    f"and be sure to make the order of the inputs same as T5EncoderModel forward() arguments like ({list(signature_params)})"
-                )
-        else:
-            invalid_params = set(rbln_model_input_names) - set(signature_params)
-            if invalid_params:
-                raise ValueError(f"Invalid model input names: {invalid_params}")
-            rbln_model_input_names = [name for name in signature_params if name in rbln_model_input_names]
-
-        if rbln_batch_size is None:
-            rbln_batch_size = 1
-
-        input_info = [
-            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
-            for model_input_name in rbln_model_input_names
-        ]
-
-        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
-
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[rbln_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
-
-        rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
+        pipe: "RBLNDiffusionMixin",
+        rbln_config: "RBLNDiffusionMixinConfig",
+        submodule_name: str,
+    ) -> "RBLNDiffusionMixinConfig":
+        submodule_config = getattr(rbln_config, submodule_name)
+        submodule_config.max_seq_len = rbln_config.max_seq_len or 256
+        submodule_config.model_input_names = ["input_ids"]
        return rbln_config
 
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
-        encoder_outputs = self.model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            inputs_embeds=inputs_embeds,
-            head_mask=head_mask,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-        if not return_dict:
-            return (encoder_outputs,)
-        else:
-            return BaseModelOutput(last_hidden_state=encoder_outputs)
-
 
 class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
     support_causal_attn = False
 
     @classmethod
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config:
-
-
-
-        return T5Wrapper(model, enc_max_seq_len=enc_max_seq_len, dec_max_seq_len=dec_max_seq_len)
+    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNT5ForConditionalGenerationConfig):
+        return T5Wrapper(
+            model, enc_max_seq_len=rbln_config.enc_max_seq_len, dec_max_seq_len=rbln_config.dec_max_seq_len
+        )
 
     def __getattr__(self, __name: str) -> Any:
         def redirect(func):
optimum/rbln/transformers/models/time_series_transformers/__init__.py CHANGED
@@ -22,4 +22,5 @@
 # from Rebellions Inc.
 
 from ....ops import paged_add_softmax_attn_decode, rbln_cache_update
+from .configuration_time_series_transformer import RBLNTimeSeriesTransformerForPredictionConfig
 from .modeling_time_series_transformers import RBLNTimeSeriesTransformerForPrediction
optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py ADDED
@@ -0,0 +1,34 @@
+from typing import Optional
+
+from ....configuration_utils import RBLNModelConfig
+
+
+class RBLNTimeSeriesTransformerForPredictionConfig(RBLNModelConfig):
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        enc_max_seq_len: Optional[int] = None,
+        dec_max_seq_len: Optional[int] = None,
+        num_parallel_samples: Optional[int] = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+            enc_max_seq_len (Optional[int]): Maximum sequence length for the encoder.
+            dec_max_seq_len (Optional[int]): Maximum sequence length for the decoder.
+            num_parallel_samples (Optional[int]): Number of samples to generate in parallel during prediction.
+            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size <= 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+        self.enc_max_seq_len = enc_max_seq_len
+        self.dec_max_seq_len = dec_max_seq_len
+        self.num_parallel_samples = num_parallel_samples
optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py CHANGED
@@ -25,7 +25,7 @@ import inspect
 import logging
 from dataclasses import dataclass
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable,
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
 
 import rebel
 import torch
@@ -38,9 +38,10 @@ from transformers import (
 from transformers.modeling_outputs import ModelOutput, SampleTSPredictionOutput, Seq2SeqTSModelOutput
 from transformers.modeling_utils import no_init_weights
 
+from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
-from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ....utils.runtime_utils import RBLNPytorchRuntime
+from .configuration_time_series_transformer import RBLNTimeSeriesTransformerForPredictionConfig
 from .time_series_transformers_architecture import TimeSeriesTransformersWrapper
 
 
@@ -124,9 +125,9 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
 
     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)
-        self.batch_size = self.rbln_config.
-        self.dec_max_seq_len = self.rbln_config.
-        self.num_parallel_samples = self.rbln_config.
+        self.batch_size = self.rbln_config.batch_size
+        self.dec_max_seq_len = self.rbln_config.dec_max_seq_len
+        self.num_parallel_samples = self.rbln_config.num_parallel_samples
 
         with no_init_weights():
             self._origin_model = TimeSeriesTransformerForPrediction._from_config(self.config)
@@ -156,12 +157,14 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
         return redirect(val)
 
     @classmethod
-    def wrap_model_if_needed(
-
+    def wrap_model_if_needed(
+        self, model: "PreTrainedModel", rbln_config: RBLNTimeSeriesTransformerForPredictionConfig
+    ):
+        return TimeSeriesTransformersWrapper(model, rbln_config.num_parallel_samples)
 
     @classmethod
     @torch.inference_mode()
-    def get_compiled_model(cls, model, rbln_config:
+    def get_compiled_model(cls, model, rbln_config: RBLNTimeSeriesTransformerForPredictionConfig):
         wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
 
         enc_compile_config = rbln_config.compile_cfgs[0]
@@ -206,7 +209,7 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
         model: "PreTrainedModel",
         save_dir_path: Path,
         subfolder: str,
-        rbln_config:
+        rbln_config: RBLNTimeSeriesTransformerForPredictionConfig,
     ):
         """
         If you are unavoidably running on a CPU rather than an RBLN device,
@@ -217,31 +220,28 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
         torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
 
     @classmethod
-    def
+    def _update_rbln_config(
         cls,
-        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
-
-
-
-
-
-
-        if not isinstance(rbln_batch_size, int):
-            raise TypeError(f"Expected rbln_batch_size to be an int, but got {type(rbln_batch_size)}")
-
-        rbln_num_parallel_samples = (
-            model_config.num_parallel_samples if rbln_num_parallel_samples is None else rbln_num_parallel_samples
-        )
-        if rbln_dec_max_seq_len is None:
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+        model: Optional["PreTrainedModel"] = None,
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_config: Optional[RBLNTimeSeriesTransformerForPredictionConfig] = None,
+    ) -> RBLNTimeSeriesTransformerForPredictionConfig:
+        rbln_config.num_parallel_samples = rbln_config.num_parallel_samples or model_config.num_parallel_samples
+
+        if rbln_config.dec_max_seq_len is None:
             predict_length = model_config.prediction_length
-
+            rbln_config.dec_max_seq_len = (
                 predict_length if predict_length % 64 == 0 else predict_length + (64 - predict_length % 64)
             )
 
         # model input info
         enc_input_info = [
-            (
+            (
+                "inputs_embeds",
+                [rbln_config.batch_size, model_config.context_length, model_config.feature_size],
+                "float32",
+            ),
         ]
         enc_input_info.extend(
             [
@@ -249,7 +249,7 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
                     "cross_key_value_states",
                     [
                         model_config.decoder_layers * 2,
-
+                        rbln_config.batch_size,
                         model_config.decoder_attention_heads,
                         model_config.context_length,
                         model_config.d_model // model_config.decoder_attention_heads,
@@ -260,8 +260,12 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
         )
 
         dec_input_info = [
-            (
-
+            (
+                "inputs_embeds",
+                [rbln_config.batch_size * rbln_config.num_parallel_samples, 1, model_config.feature_size],
+                "float32",
+            ),
+            ("attention_mask", [1, rbln_config.dec_max_seq_len], "float32"),
             ("cache_position", [], "int32"),
             ("block_tables", [1, 1], "int16"),
         ]
@@ -271,7 +275,7 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
                     "cross_key_value_states",
                     [
                         model_config.decoder_layers * 2,  # 4
-
+                        rbln_config.batch_size,  # 64
                         model_config.decoder_attention_heads,  # 2
                         model_config.context_length,  # 24
                         model_config.d_model // model_config.decoder_attention_heads,  # 13
@@ -286,8 +290,10 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
                     f"self_key_value_states_{i}",
                     [
                         1,
-                        model_config.decoder_attention_heads
-
+                        model_config.decoder_attention_heads
+                        * rbln_config.num_parallel_samples
+                        * rbln_config.batch_size,
+                        rbln_config.dec_max_seq_len,
                         model_config.d_model // model_config.encoder_attention_heads,
                     ],
                     "float32",
@@ -298,38 +304,30 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
         enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
         dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
 
-        rbln_config
-            rbln_cls=cls.__name__,
-            compile_cfgs=[enc_compile_config, dec_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
-
-        rbln_config.model_cfg.update(
-            {
-                "batch_size": rbln_batch_size,
-                "num_parallel_samples": rbln_num_parallel_samples,
-                "dec_max_seq_len": rbln_dec_max_seq_len,
-            }
-        )
-
+        rbln_config.set_compile_cfgs([enc_compile_config, dec_compile_config])
         return rbln_config
 
     @classmethod
     def _create_runtimes(
         cls,
         compiled_models: List[rebel.RBLNCompiledModel],
-
-        activate_profiler: Optional[bool] = None,
+        rbln_config: RBLNTimeSeriesTransformerForPredictionConfig,
     ) -> List[rebel.Runtime]:
-        if any(model_name not in
+        if any(model_name not in rbln_config.device_map for model_name in ["encoder", "decoder"]):
             cls._raise_missing_compiled_file_error(["encoder", "decoder"])
 
         return [
-
-
+            rebel.Runtime(
+                compiled_models[0],
+                tensor_type="pt",
+                device=rbln_config.device_map["encoder"],
+                activate_profiler=rbln_config.activate_profiler,
             ),
-
-
+            rebel.Runtime(
+                compiled_models[1],
+                tensor_type="pt",
+                device=rbln_config.device_map["decoder"],
+                activate_profiler=rbln_config.activate_profiler,
             ),
         ]
 
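One behavioral detail from `_update_rbln_config` worth isolating: when `dec_max_seq_len` is unset, it is derived from `model_config.prediction_length` rounded up to the next multiple of 64. The expression, lifted from the diff into a checkable helper:

    def default_dec_max_seq_len(predict_length: int) -> int:
        # Same arithmetic as _update_rbln_config: exact multiples of 64 pass
        # through; everything else is padded up to the next multiple.
        return predict_length if predict_length % 64 == 0 else predict_length + (64 - predict_length % 64)

    assert default_dec_max_seq_len(64) == 64
    assert default_dec_max_seq_len(24) == 64
    assert default_dec_max_seq_len(100) == 128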
optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py ADDED
@@ -0,0 +1,19 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_generic import RBLNModelForMaskedLMConfig
+
+
+class RBLNWav2Vec2ForCTCConfig(RBLNModelForMaskedLMConfig):
+    rbln_model_input_names = ["input_values"]
optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py CHANGED
@@ -12,26 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, Any, Dict, Union
 
 import torch
-from transformers import AutoModelForMaskedLM,
+from transformers import AutoModelForMaskedLM, Wav2Vec2ForCTC
 from transformers.modeling_outputs import CausalLMOutput
 
-from
-from
-from ....utils.logging import get_logger
-
-
-logger = get_logger(__name__)
-
-if TYPE_CHECKING:
-    from transformers import (
-        AutoFeatureExtractor,
-        AutoProcessor,
-        AutoTokenizer,
-        PretrainedConfig,
-    )
+from ...modeling_generic import RBLNModelForMaskedLM
+from .configuration_wav2vec import RBLNWav2Vec2ForCTCConfig
 
 
 class _Wav2Vec2(torch.nn.Module):
@@ -44,11 +31,11 @@ class _Wav2Vec2(torch.nn.Module):
         return self.model.lm_head(output[0])
 
 
-class RBLNWav2Vec2ForCTC(
+class RBLNWav2Vec2ForCTC(RBLNModelForMaskedLM):
     """
     Wav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).
 
-    This model inherits from [`
+    This model inherits from [`RBLNModelForMaskedLM`]. Check the superclass documentation for the generic methods the
     library implements for all its model.
 
     It implements the methods to convert a pre-trained Wav2Vec2 model into a RBLN Wav2Vec2 model by:
@@ -58,60 +45,10 @@ class RBLNWav2Vec2ForCTC(RBLNModel):
 
     main_input_name = "input_values"
     auto_model_class = AutoModelForMaskedLM
+    rbln_dtype = "float32"
+    output_class = CausalLMOutput
+    output_key = "logits"
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config:
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNWav2Vec2ForCTCConfig) -> torch.nn.Module:
         return _Wav2Vec2(model).eval()
-
-    @classmethod
-    def _get_rbln_config(
-        cls,
-        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
-        model_config: "PretrainedConfig",
-        rbln_kwargs: Dict[str, Any] = {},
-    ) -> RBLNConfig:
-        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
-        rbln_batch_size = rbln_kwargs.get("batch_size", None)
-
-        if rbln_max_seq_len is None:
-            for tokenizer in preprocessors:
-                if hasattr(tokenizer, "model_max_length"):
-                    rbln_max_seq_len = tokenizer.model_max_length
-                    break
-            if rbln_max_seq_len is None:
-                raise ValueError("`rbln_max_seq_len` should be specified!")
-
-        if rbln_batch_size is None:
-            rbln_batch_size = 1
-
-        input_info = [
-            (
-                "input_values",
-                [
-                    rbln_batch_size,
-                    rbln_max_seq_len,
-                ],
-                "float32",
-            ),
-        ]
-
-        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
-
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[rbln_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
-
-        rbln_config.model_cfg.update(
-            {
-                "max_seq_len": rbln_max_seq_len,
-                "batch_size": rbln_batch_size,
-            }
-        )
-
-        return rbln_config
-
-    def forward(self, input_values: "torch.Tensor", **kwargs):
-        outputs = super().forward(input_values, **kwargs)
-        return CausalLMOutput(logits=outputs)
optimum/rbln/transformers/models/whisper/configuration_whisper.py ADDED
@@ -0,0 +1,64 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import rebel
+
+from ....configuration_utils import RBLNModelConfig
+from ....utils.logging import get_logger
+
+
+logger = get_logger()
+
+
+class RBLNWhisperForConditionalGenerationConfig(RBLNModelConfig):
+    def __init__(
+        self,
+        batch_size: int = None,
+        token_timestamps: bool = None,
+        use_attention_mask: bool = None,
+        enc_max_seq_len: int = None,
+        dec_max_seq_len: int = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            batch_size (int, optional): The batch size for inference. Defaults to 1.
+            token_timestamps (bool, optional): Whether to output token timestamps during generation. Defaults to False.
+            use_attention_mask (bool, optional): Whether to use attention masks during inference. This is automatically
+                set to True for RBLN-CA02 devices.
+            enc_max_seq_len (int, optional): Maximum sequence length for the encoder.
+            dec_max_seq_len (int, optional): Maximum sequence length for the decoder.
+            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+        self.token_timestamps = token_timestamps or False
+        self.enc_max_seq_len = enc_max_seq_len
+        self.dec_max_seq_len = dec_max_seq_len
+
+        self.use_attention_mask = use_attention_mask
+        npu = self.npu or rebel.get_npu_name()
+        if npu == "RBLN-CA02":
+            if self.use_attention_mask is False:
+                logger.warning("Attention mask should be used with RBLN-CA02. Setting use_attention_mask to True.")
+            self.use_attention_mask = True
+        else:
+            self.use_attention_mask = self.use_attention_mask or False