optimum-rbln 0.1.15__py3-none-any.whl → 0.2.1a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +26 -33
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/diffusers/__init__.py +4 -0
- optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +66 -24
- optimum/rbln/diffusers/models/__init__.py +2 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +38 -12
- optimum/rbln/diffusers/models/autoencoders/vae.py +0 -1
- optimum/rbln/diffusers/models/controlnet.py +1 -1
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +1 -1
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +5 -7
- optimum/rbln/diffusers/pipelines/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +8 -7
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +17 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +17 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +23 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +1 -2
- optimum/rbln/modeling.py +13 -347
- optimum/rbln/modeling_base.py +24 -4
- optimum/rbln/modeling_config.py +31 -7
- optimum/rbln/ops/__init__.py +26 -0
- optimum/rbln/ops/attn.py +221 -0
- optimum/rbln/ops/flash_attn.py +70 -0
- optimum/rbln/ops/kv_cache_update.py +69 -0
- optimum/rbln/transformers/__init__.py +20 -0
- optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
- optimum/rbln/transformers/modeling_generic.py +385 -0
- optimum/rbln/transformers/models/auto/__init__.py +23 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +0 -1
- optimum/rbln/transformers/models/bart/__init__.py +0 -1
- optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
- optimum/rbln/transformers/models/bart/modeling_bart.py +8 -4
- optimum/rbln/transformers/models/clip/modeling_clip.py +1 -1
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -7
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +329 -328
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +92 -107
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +2 -3
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -10
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
- optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +1 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +11 -11
- optimum/rbln/transformers/models/midm/modeling_midm.py +0 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +2 -3
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +57 -57
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
- optimum/rbln/transformers/models/t5/__init__.py +0 -1
- optimum/rbln/transformers/models/t5/modeling_t5.py +5 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
- optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +77 -54
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
- optimum/rbln/transformers/utils/rbln_quantization.py +1 -2
- optimum/rbln/utils/decorator_utils.py +51 -15
- optimum/rbln/utils/import_utils.py +8 -1
- optimum/rbln/utils/logging.py +38 -1
- optimum/rbln/utils/model_utils.py +0 -1
- optimum/rbln/utils/runtime_utils.py +9 -3
- optimum/rbln/utils/save_utils.py +17 -0
- optimum/rbln/utils/submodule.py +23 -0
- optimum_rbln-0.2.1a0.dist-info/METADATA +121 -0
- {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.1a0.dist-info}/RECORD +76 -72
- optimum_rbln-0.2.1a0.dist-info/licenses/LICENSE +288 -0
- optimum/rbln/transformers/cache_utils.py +0 -107
- optimum/rbln/utils/timer_utils.py +0 -43
- optimum_rbln-0.1.15.dist-info/METADATA +0 -106
- optimum_rbln-0.1.15.dist-info/licenses/LICENSE +0 -201
- {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.1a0.dist-info}/WHEEL +0 -0
@@ -27,6 +27,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
|
|
27
27
|
|
28
28
|
import rebel
|
29
29
|
import torch
|
30
|
+
from rebel.compile_context import CompileContext
|
30
31
|
from transformers import (
|
31
32
|
AutoModelForSpeechSeq2Seq,
|
32
33
|
AutoProcessor,
|
@@ -37,23 +38,16 @@ from transformers import (
|
|
37
38
|
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
|
38
39
|
|
39
40
|
from ....modeling import RBLNModel
|
40
|
-
from ....modeling_config import
|
41
|
+
from ....modeling_config import RBLNCompileConfig, RBLNConfig
|
41
42
|
from ....utils.runtime_utils import RBLNPytorchRuntime
|
42
43
|
from .generation_whisper import RBLNWhisperGenerationMixin
|
43
|
-
from .whisper_architecture import
|
44
|
-
_WhisperDecoderWrapper,
|
45
|
-
_WhisperEncoderWrapper,
|
46
|
-
)
|
44
|
+
from .whisper_architecture import WhisperWrapper
|
47
45
|
|
48
46
|
|
49
47
|
logger = logging.getLogger(__name__)
|
50
48
|
|
51
49
|
if TYPE_CHECKING:
|
52
|
-
from transformers import
|
53
|
-
AutoFeatureExtractor,
|
54
|
-
AutoProcessor,
|
55
|
-
PretrainedConfig,
|
56
|
-
)
|
50
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, PretrainedConfig, PreTrainedModel
|
57
51
|
|
58
52
|
|
59
53
|
class RBLNRuntimeEncoder(RBLNPytorchRuntime):
|
@@ -164,47 +158,51 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
|
|
164
158
|
# TODO(jongho): implement
|
165
159
|
raise NotImplementedError
|
166
160
|
|
161
|
+
@classmethod
|
162
|
+
def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
|
163
|
+
rbln_token_timestamps = rbln_config.model_cfg["token_timestamps"]
|
164
|
+
return WhisperWrapper(model, rbln_token_timestamps)
|
165
|
+
|
167
166
|
@classmethod
|
168
167
|
@torch.inference_mode()
|
169
168
|
def get_compiled_model(cls, model, rbln_config: RBLNConfig):
|
170
|
-
|
171
|
-
wrapped_encoder = _WhisperEncoderWrapper(model).eval()
|
172
|
-
wrapped_decoder = _WhisperDecoderWrapper(model, output_attentions=rbln_token_timestamps).eval()
|
169
|
+
wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
|
173
170
|
|
174
|
-
|
175
|
-
|
171
|
+
enc_compile_config = rbln_config.compile_cfgs[0]
|
172
|
+
dec_compile_config = rbln_config.compile_cfgs[1]
|
176
173
|
|
177
|
-
|
178
|
-
dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=1)
|
174
|
+
context = CompileContext(use_weight_sharing=False)
|
179
175
|
|
180
|
-
|
181
|
-
dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs, check_trace=False)
|
176
|
+
enc_example_inputs = enc_compile_config.get_dummy_inputs(fill=0)
|
182
177
|
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
name
|
178
|
+
# Mark encoder's static tensors (cross kv states)
|
179
|
+
static_tensors = {}
|
180
|
+
for (name, _, _), tensor in zip(enc_compile_config.input_info, enc_example_inputs):
|
181
|
+
if "key_value_states" in name:
|
182
|
+
static_tensors[name] = tensor
|
183
|
+
context.mark_static_address(tensor)
|
184
|
+
|
185
|
+
dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
|
186
|
+
|
187
|
+
# Mark decoder's static tensors (self kv states)
|
188
|
+
for (name, _, _), tensor in zip(dec_compile_config.input_info, dec_example_inputs):
|
189
|
+
if "key_value_states" in name:
|
190
|
+
context.mark_static_address(tensor)
|
191
|
+
|
192
|
+
compiled_encoder = super().compile(
|
193
|
+
wrapped_model.encoder,
|
194
|
+
enc_compile_config,
|
195
|
+
example_inputs=enc_example_inputs,
|
196
|
+
compile_context=context,
|
187
197
|
)
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
198
|
+
compiled_decoder = super().compile(
|
199
|
+
wrapped_model.decoder,
|
200
|
+
dec_compile_config,
|
201
|
+
example_inputs=dec_example_inputs,
|
202
|
+
compile_context=context,
|
192
203
|
)
|
193
204
|
|
194
|
-
|
195
|
-
connections = [
|
196
|
-
(enc_ir.outputs[0], dec_ir.inputs[4]),
|
197
|
-
(dec_ir.outputs[1], dec_ir.inputs[3]),
|
198
|
-
]
|
199
|
-
compiled_model = rebel.compile(
|
200
|
-
enc_ir,
|
201
|
-
dec_ir,
|
202
|
-
connections=connections,
|
203
|
-
fusion=enc_rbln_compile_config.fusion,
|
204
|
-
npu=enc_rbln_compile_config.npu,
|
205
|
-
tensor_parallel_size=enc_rbln_compile_config.tensor_parallel_size,
|
206
|
-
)
|
207
|
-
return compiled_model
|
205
|
+
return {"encoder": compiled_encoder, "decoder": compiled_decoder}
|
208
206
|
|
209
207
|
@classmethod
|
210
208
|
def _get_rbln_config(
|
@@ -228,6 +226,22 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
|
|
228
226
|
|
229
227
|
# model input info
|
230
228
|
enc_input_info = [("input_features", [rbln_batch_size, num_mel_bins, expected_seq_len], "float32")]
|
229
|
+
enc_input_info.extend(
|
230
|
+
[
|
231
|
+
(
|
232
|
+
"cross_key_value_states",
|
233
|
+
[
|
234
|
+
model_config.decoder_layers * 2,
|
235
|
+
rbln_batch_size,
|
236
|
+
model_config.decoder_attention_heads,
|
237
|
+
enc_max_seq_len,
|
238
|
+
model_config.d_model // model_config.decoder_attention_heads,
|
239
|
+
],
|
240
|
+
"float32",
|
241
|
+
)
|
242
|
+
]
|
243
|
+
)
|
244
|
+
|
231
245
|
dec_input_info = [
|
232
246
|
("decoder_input_ids", [rbln_batch_size, 1], "int64"),
|
233
247
|
("decoder_attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "int64"),
|
@@ -236,13 +250,13 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
|
|
236
250
|
dec_input_info.extend(
|
237
251
|
[
|
238
252
|
(
|
239
|
-
"
|
253
|
+
"cross_key_value_states",
|
240
254
|
[
|
241
255
|
model_config.decoder_layers * 2,
|
242
256
|
rbln_batch_size,
|
243
257
|
model_config.decoder_attention_heads,
|
244
|
-
|
245
|
-
model_config.d_model // model_config.
|
258
|
+
enc_max_seq_len,
|
259
|
+
model_config.d_model // model_config.decoder_attention_heads,
|
246
260
|
],
|
247
261
|
"float32",
|
248
262
|
)
|
@@ -251,25 +265,25 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
|
|
251
265
|
dec_input_info.extend(
|
252
266
|
[
|
253
267
|
(
|
254
|
-
"
|
268
|
+
f"self_key_value_states_{i}",
|
255
269
|
[
|
256
|
-
model_config.decoder_layers * 2,
|
257
270
|
rbln_batch_size,
|
258
271
|
model_config.decoder_attention_heads,
|
259
|
-
|
272
|
+
rbln_dec_max_seq_len,
|
260
273
|
model_config.d_model // model_config.encoder_attention_heads,
|
261
274
|
],
|
262
275
|
"float32",
|
263
276
|
)
|
277
|
+
for i in range(model_config.decoder_layers * 2)
|
264
278
|
]
|
265
279
|
)
|
266
280
|
|
267
|
-
|
268
|
-
|
281
|
+
enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
|
282
|
+
dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
|
269
283
|
|
270
284
|
rbln_config = RBLNConfig(
|
271
285
|
rbln_cls=cls.__name__,
|
272
|
-
compile_cfgs=[
|
286
|
+
compile_cfgs=[enc_compile_config, dec_compile_config],
|
273
287
|
rbln_kwargs=rbln_kwargs,
|
274
288
|
)
|
275
289
|
|
@@ -285,12 +299,21 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
|
|
285
299
|
|
286
300
|
@classmethod
|
287
301
|
def _create_runtimes(
|
288
|
-
cls,
|
302
|
+
cls,
|
303
|
+
compiled_models: List[rebel.RBLNCompiledModel],
|
304
|
+
rbln_device_map: Dict[str, int],
|
305
|
+
activate_profiler: Optional[bool] = None,
|
289
306
|
) -> List[rebel.Runtime]:
|
290
|
-
|
307
|
+
if any(model_name not in rbln_device_map for model_name in ["encoder", "decoder"]):
|
308
|
+
cls._raise_missing_compiled_file_error(["encoder", "decoder"])
|
309
|
+
|
291
310
|
return [
|
292
|
-
compiled_models[0].create_runtime(
|
293
|
-
|
311
|
+
compiled_models[0].create_runtime(
|
312
|
+
tensor_type="pt", device=rbln_device_map["encoder"], activate_profiler=activate_profiler
|
313
|
+
),
|
314
|
+
compiled_models[1].create_runtime(
|
315
|
+
tensor_type="pt", device=rbln_device_map["decoder"], activate_profiler=activate_profiler
|
316
|
+
),
|
294
317
|
]
|
295
318
|
|
296
319
|
def prepare_inputs_for_generation(
|