optimum-rbln 0.9.4a2__py3-none-any.whl → 0.9.5a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +36 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +35 -16
- optimum/rbln/modeling_base.py +6 -6
- optimum/rbln/ops/__init__.py +1 -0
- optimum/rbln/ops/attn.py +10 -0
- optimum/rbln/ops/flash_attn.py +8 -0
- optimum/rbln/ops/moe.py +180 -0
- optimum/rbln/ops/sliding_window_attn.py +9 -0
- optimum/rbln/transformers/__init__.py +36 -0
- optimum/rbln/transformers/modeling_attention_utils.py +118 -222
- optimum/rbln/transformers/modeling_outputs.py +25 -0
- optimum/rbln/transformers/modeling_rope_utils.py +78 -42
- optimum/rbln/transformers/models/__init__.py +28 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -182
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -21
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +118 -16
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +121 -48
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +5 -7
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +75 -107
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
- optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
- optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
- optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +16 -18
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +1 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
- optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
- optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
- optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +6 -4
- optimum/rbln/transformers/models/llava/modeling_llava.py +0 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
- optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
- optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
- optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
- optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
- optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +271 -122
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
- optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
- optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
- optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +263 -105
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +26 -34
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
- optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
- optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
- optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +4 -18
- optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
- optimum/rbln/transformers/models/t5/t5_architecture.py +15 -16
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
- optimum/rbln/transformers/utils/rbln_quantization.py +20 -12
- optimum/rbln/utils/import_utils.py +16 -1
- optimum/rbln/utils/runtime_utils.py +10 -6
- optimum/rbln/utils/submodule.py +24 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +81 -62
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py

@@ -26,15 +26,16 @@ from transformers.modeling_utils import no_init_weights
 from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
 from ....utils.logging import get_logger
+from ....utils.runtime_utils import is_compiler_supports_buffer_resize
 from ...modeling_attention_utils import (
     RBLNDecoderOnlyFlashAttentionMixin,
     set_default_values,
     validate_attention_method,
     validate_sliding_window,
 )
-from ...modeling_outputs import RBLNDecoderOnlyOutput
+from ...modeling_outputs import RBLNDecoderOnlyOutput, _validate_output_hidden_states
 from ...utils.rbln_quantization import get_quantized_model
-from .configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
+from .configuration_decoderonly import KVCacheMeta, RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
 from .decoderonly_architecture import DecoderOnlyWrapper
 from .decoderonly_runtime_utils import RBLNPageTableManager, RBLNRuntimeModel
 from .generation_decoderonly import RBLNDecoderOnlyGenerationMixin
@@ -230,7 +231,7 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
         quantization=None,
         phase: str = "prefill",
-    ):
+    ) -> rebel.RBLNCompiledModel:
         try:
             wrapped_model.phase = phase
             if quantization:
@@ -252,21 +253,15 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
            quantization.maybe_reset_quantization_env()

    @classmethod
-    def _get_compile_context(
-        cls,
-        compile_config: RBLNCompileConfig,
-        example_inputs: List[torch.Tensor],
-    ):
+    def _get_compile_context(cls, compile_config: RBLNCompileConfig, example_inputs: List[torch.Tensor]):
        context = CompileContext(use_weight_sharing=True)

        # Mark static tensors (self kv states)
        static_tensors = {}
-        idx = 0
        for (name, _, _), tensor in zip(compile_config.input_info, example_inputs):
            if "past_key_values" in name:
                static_tensors[name] = tensor
-                context.mark_static_address(tensor,
-                idx += 1
+                context.mark_static_address(tensor, name)

        return context, static_tensors

@@ -281,7 +276,7 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
        prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
        context, static_tensors = cls._get_compile_context(prefill_compile_config, prefill_example_inputs)

-        compiled_models = {}
+        compiled_models: dict[str, rebel.RBLNCompiledModel] = {}
        compiled_models["prefill"] = cls._compile_model(
            wrapped_model,
            prefill_compile_config,
@@ -307,14 +302,10 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
            )
            compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder

-
-
-
-
-            compiled_models=compiled_models,
-            model_config=model.config,
-            rbln_config=rbln_config,
-        )
+        if rbln_config.is_auto_num_blocks:
+            if not is_compiler_supports_buffer_resize():
+                raise RuntimeError("`kvcache_num_blocks` must be set.")
+            cls.set_kvcache_num_blocks_after_compilation(compiled_models, rbln_config)

        return compiled_models

@@ -330,8 +321,8 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
        return model

    @classmethod
-    def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
-        return use_local_attention
+    def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True, logits_to_keep: int = None):
+        return is_prefill and (use_local_attention or logits_to_keep == 1)

    @classmethod
    def get_input_info(
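`use_query_position` now takes `logits_to_keep` and only requests the extra `query_position` input during prefill, when either local attention is active or only the last logit is kept. Restating the predicate as standalone Python (illustration only, not library code):

```python
# Standalone re-statement of the predicate introduced in the hunk above.
def use_query_position(use_local_attention: bool, is_prefill: bool = True, logits_to_keep: int = None) -> bool:
    return is_prefill and (use_local_attention or logits_to_keep == 1)

assert use_query_position(True, is_prefill=True) is True                      # local attention during prefill
assert use_query_position(False, is_prefill=True, logits_to_keep=1) is True   # only the last logit is kept
assert use_query_position(False, is_prefill=True) is False                    # plain prefill, all logits kept
assert use_query_position(True, is_prefill=False) is False                    # decode phase never needs it
```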
@@ -350,7 +341,7 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):

        input_info = []
        if rbln_config.use_inputs_embeds:
-            input_info.append(("inputs_embeds", [batch_size, query_length, hidden_size], rbln_config.
+            input_info.append(("inputs_embeds", [batch_size, query_length, hidden_size], rbln_config.dtype))
        else:
            input_info.append(("input_ids", [batch_size, query_length], "int64"))

@@ -364,15 +355,15 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
        if rbln_config.use_local_attention:
            input_info.append(("local_block_tables", [1] if is_prefill else [batch_size, 1], "int16"))

-        if cls.use_query_position(rbln_config.use_local_attention, is_prefill):
+        if cls.use_query_position(rbln_config.use_local_attention, is_prefill, rbln_config.logits_to_keep):
            input_info.append(("query_position", [], "int16"))

        if rbln_config.use_attention_mask:
            if rbln_config.use_position_ids:
-                input_info.append(("attention_mask", [batch_size, rbln_config.max_seq_len], rbln_config.
+                input_info.append(("attention_mask", [batch_size, rbln_config.max_seq_len], rbln_config.dtype))
            else:
                input_info.append(
-                    ("attention_mask", [batch_size, 1, query_length, rbln_config.max_seq_len], rbln_config.
+                    ("attention_mask", [batch_size, 1, query_length, rbln_config.max_seq_len], rbln_config.dtype)
                )

        if rbln_config.use_position_ids:
@@ -381,29 +372,36 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
        if rbln_config.use_lora:
            input_info.append(("lora_int_ids", [batch_size], "int32"))

-
-
-
+        if len(rbln_config.kvcache_metas) > 0:
+            # Meta is already set, use it
+            input_info.extend(
+                [
+                    (kvcache_meta.name, kvcache_meta.compile_shape, kvcache_meta.dtype)
+                    for kvcache_meta in rbln_config.kvcache_metas
+                ]
+            )

-
-        rbln_config.
-
-
-
-
-
-
-            (
-
-
-
-
-                kvcache_dtype,
+        else:
+            kvcache_dtype = rbln_config.dtype
+            if rbln_config.quantization and rbln_config.quantization.kv_caches == "fp8":
+                kvcache_dtype = "float8_e4m3fn"
+
+            kvcache_metas = []
+            for i in range(num_hidden_layers * 2):
+                layer_idx = i // 2
+                name = f"past_key_values_{i}"
+                kvcache_meta = KVCacheMeta.make(
+                    name,
+                    layer_idx,
+                    num_key_value_heads,
+                    head_dim,
+                    RBLNCompileConfig.normalize_dtype(kvcache_dtype),
+                    rbln_config,
                )
-
-
-
+                kvcache_metas.append(kvcache_meta)
+                input_info.append((name, kvcache_meta.compile_shape, kvcache_meta.dtype))
+
+            rbln_config.kvcache_metas.extend(kvcache_metas)

        return input_info

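The `else` branch above builds one `KVCacheMeta` per key/value cache tensor, two per hidden layer, named `past_key_values_{i}` with `layer_idx = i // 2`. A toy illustration of the naming it produces (plain Python, not library code):

```python
# Toy illustration of the naming loop above; two cache entries share each layer index.
num_hidden_layers = 2
entries = [(f"past_key_values_{i}", i // 2) for i in range(num_hidden_layers * 2)]
print(entries)
# [('past_key_values_0', 0), ('past_key_values_1', 0),
#  ('past_key_values_2', 1), ('past_key_values_3', 1)]
```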
@@ -475,51 +473,39 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
            max_seq_len=rbln_config.max_seq_len,
        )

-
-
-        #
+        # Validate kvcache_num_blocks based on the number of full blocks required.
+        # Eager mode restriction:
+        # - num_blocks must be at least equal to the batch size
+        # Flash attention restriction:
+        # - num_blocks must be at least equal to (max_seq_len // kvcache_block_size) + 1
+        # - num_blocks must be no greater than the number of full blocks.
        if rbln_config.attn_impl == "flash_attn":
-
-
-
+            if rbln_config.is_auto_num_blocks:
+                # Do nothing
+                pass

-
-            if
-
-
-
+            else:
+                if rbln_config.kvcache_num_blocks > rbln_config.num_full_blocks:
+                    logger.warning(
+                        f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
+                        f" than the required number of blocks ({rbln_config.num_full_blocks})."
+                        "This can cause a failure during model compilation."
+                    )
+                elif rbln_config.kvcache_num_blocks < rbln_config.num_min_blocks:
+                    raise ValueError(
+                        f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is less"
+                        f" than the minimum number of blocks ({rbln_config.num_min_blocks})."
                    )
-            if min_blocks_for_flash > estimated_max_num_blocks:
-                # NOTE: Just try to compile with lower bound of blocks for flash attention.
-                # Even if it's larger than the estimated maximum number of blocks.
-                rbln_config.kvcache_num_blocks = min_blocks_for_flash
-            else:
-                logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
-                rbln_config.kvcache_num_blocks = estimated_max_num_blocks
-
-            if rbln_config.kvcache_num_blocks < rbln_config.batch_size:
-                raise RuntimeError(
-                    f"Batch size ({rbln_config.batch_size}) exceeds num_blocks ({rbln_config.kvcache_num_blocks}). "
-                    "Ensure the number of blocks is at least equal to the batch size."
-                )
-            else:
-                rbln_config.kvcache_num_blocks = num_full_blocks
-        elif rbln_config.kvcache_num_blocks > estimated_max_num_blocks:
-            logger.warning(
-                f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
-                f" than the estimated maximum number of blocks ({estimated_max_num_blocks})."
-                "This can cause a failure during model compilation."
-            )
        else:
-            if rbln_config.
-
-
+            if rbln_config.is_auto_num_blocks:
+                # Eager attention should use fixed number of blocks.
+                rbln_config.kvcache_num_blocks = rbln_config.num_full_blocks
+            elif rbln_config.kvcache_num_blocks > rbln_config.num_full_blocks:
                logger.warning(
                    f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
-                    f" than the required number of blocks ({num_full_blocks})."
+                    f" than the required number of blocks ({rbln_config.num_full_blocks})."
                    "This can cause a failure during model compilation."
                )
-        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")

        return rbln_config

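The hunk above replaces the old estimate-based fallback with two explicit paths: an auto path (`is_auto_num_blocks`, resolved after compilation when the compiler supports buffer resizing, per the earlier hunk) and a user-set path validated against `num_full_blocks` / `num_min_blocks`. A hedged sketch of pinning the value explicitly at export time; the model id and the surrounding argument values are placeholders, only `kvcache_num_blocks`, `attn_impl`, and `max_seq_len` are the config fields exercised here:

```python
# Sketch only: explicitly setting the KV-cache block count at export time.
from optimum.rbln import RBLNLlamaForCausalLM

model = RBLNLlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",      # placeholder checkpoint
    export=True,
    rbln_config={
        "attn_impl": "flash_attn",
        "max_seq_len": 4096,
        "kvcache_num_blocks": 17,    # explicit value; omit it to take the auto path
    },
)
```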
@@ -643,15 +629,7 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
            raise ValueError(
                f"Batch size ({batch_size}) must be equal to the batch size of the model ({self.rbln_config.batch_size})."
            )
-
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
-        )
-        if output_hidden_states != self.rbln_config.output_hidden_states:
-            raise ValueError(
-                f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
-                f"Please compile again with the correct argument."
-            )
+        output_hidden_states = _validate_output_hidden_states(output_hidden_states, self.rbln_config)

        all_last_hidden_states = []
        all_hidden_states = (
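Both call sites now delegate to `_validate_output_hidden_states`, whose body lives in `modeling_outputs.py` and is not part of this excerpt. Judging from the inline code it replaces, the helper presumably amounts to the following sketch:

```python
# Hypothetical reconstruction of the shared helper; the actual implementation is in
# optimum/rbln/transformers/modeling_outputs.py and is not shown in this diff excerpt.
def _validate_output_hidden_states(output_hidden_states, rbln_config):
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else rbln_config.output_hidden_states
    )
    if output_hidden_states != rbln_config.output_hidden_states:
        raise ValueError(
            f"Variable output_hidden_states {output_hidden_states} is not equal to "
            f"rbln_config.output_hidden_states {rbln_config.output_hidden_states}. "
            "Please compile again with the correct argument."
        )
    return output_hidden_states
```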
@@ -660,7 +638,7 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
                    self.rbln_config.batch_size,
                    inputs.shape[1],
                    self.config.hidden_size,
-                    dtype=self.rbln_config.
+                    dtype=self.rbln_config.dtype,
                )
                for _ in range(self.config.num_hidden_layers + 1)
            )
@@ -700,6 +678,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
    1. Converting pre-trained transformer models to RBLN-optimized format
    2. Handling the compilation process for RBLN devices
    3. Managing inference operations for causal language modeling
+
    This class inherits from RBLNModel and implements specific methods required for
    decoder-only architectures and causal language modeling tasks.

@@ -716,10 +695,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
    def logits_last_dim(self):
        return self.config.vocab_size

-    @classmethod
-    def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
-        return is_prefill
-
    def set_lora_int_ids(self, lora_int_ids: Optional[torch.Tensor]):
        if isinstance(lora_int_ids, int):
            lora_int_ids = torch.tensor([lora_int_ids], dtype=torch.int32)
@@ -803,14 +778,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
        )
        padded_cache_lengths = torch.zeros_like(generate_idx)

-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
-        )
-        if output_hidden_states != self.rbln_config.output_hidden_states:
-            raise ValueError(
-                f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
-                f"Please compile again with the correct argument."
-            )
+        output_hidden_states = _validate_output_hidden_states(output_hidden_states, self.rbln_config)

        # Prefill
        if cache_position is None:
@@ -829,7 +797,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):

            all_hidden_states = (
                tuple(
-                    torch.zeros(batch_size, input_len, self.config.hidden_size, dtype=self.rbln_config.
+                    torch.zeros(batch_size, input_len, self.config.hidden_size, dtype=self.rbln_config.dtype)
                    for _ in range(self.config.num_hidden_layers + 1)
                )
                if self.rbln_config.output_hidden_states
optimum/rbln/transformers/models/exaone/exaone_architecture.py

@@ -18,9 +18,6 @@ import torch.nn as nn

 from ....utils import logging
 from ...models.decoderonly.decoderonly_architecture import (
-    DecoderOnlyAttention,
-    DecoderOnlyLayer,
-    DecoderOnlyModel,
     DecoderOnlyWrapper,
 )

@@ -42,36 +39,3 @@ class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):

    def get_model_layer(self, causal_lm: "ExaoneForCausalLM"):
        return causal_lm.transformer
-
-    def get_rbln_attn_class(self):
-        return ExaoneAttention
-
-    def get_rbln_layer_class(self):
-        return ExaoneLayer
-
-    def get_rbln_model_class(self):
-        return ExaoneModel
-
-
-class ExaoneModel(DecoderOnlyModel):
-    def get_embedding(self) -> nn.Embedding:
-        return self._original_mod.wte
-
-    def get_last_layernorm(self) -> nn.LayerNorm:
-        return self._original_mod.ln_f
-
-
-class ExaoneLayer(DecoderOnlyLayer):
-    def get_pre_attention_layernorm(self) -> nn.LayerNorm:
-        return self._original_mod.ln_1
-
-    def get_post_attention_layernorm(self) -> nn.LayerNorm:
-        return self._original_mod.ln_2
-
-
-class ExaoneAttention(DecoderOnlyAttention):
-    def __post_init__(self):
-        self.q_proj = self._original_mod.q_proj
-        self.k_proj = self._original_mod.k_proj
-        self.v_proj = self._original_mod.v_proj
-        self.o_proj = self._original_mod.out_proj
optimum/rbln/transformers/models/gemma2/__init__.py (new file)

@@ -0,0 +1,16 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_gemma2 import RBLNGemma2ForCausalLMConfig, RBLNGemma2ModelConfig
+from .modeling_gemma2 import RBLNGemma2ForCausalLM, RBLNGemma2Model
optimum/rbln/transformers/models/gemma2/configuration_gemma2.py (new file)

@@ -0,0 +1,45 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
+
+
+class RBLNGemma2ForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+    """
+    Configuration class for RBLN Gemma2 models.
+    This class is an alias of RBLNDecoderOnlyModelForCausalLMConfig.
+    Example usage:
+    ```python
+    from optimum.rbln import RBLNGemma2ForCausalLM, RBLNGemma2ForCausalLMConfig
+    # Create a configuration object
+    config = RBLNGemma2ForCausalLMConfig(
+        batch_size=1,
+        max_seq_len=8192,
+        tensor_parallel_size=4
+    )
+    # Use the configuration with from_pretrained
+    model = RBLNGemma2ForCausalLM.from_pretrained(
+        "google/gemma-2-9b",
+        export=True,
+        rbln_config=config
+    )
+    ```
+    """
+
+
+class RBLNGemma2ModelConfig(RBLNDecoderOnlyModelConfig):
+    """
+    Configuration class for RBLN Gemma2 models.
+    This class is an alias of RBLNDecoderOnlyModelConfig.
+    """
optimum/rbln/transformers/models/gemma2/gemma2_architecture.py (new file)

@@ -0,0 +1,83 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import torch
+
+from ...models.decoderonly.decoderonly_architecture import DecoderOnlyAttention, DecoderOnlyLayer, DecoderOnlyModel
+from ..decoderonly.decoderonly_architecture import DecoderOnlyWrapper
+
+
+class Gemma2Wrapper(DecoderOnlyWrapper):
+    def get_rbln_layer_class(self):
+        return Gemma2DecoderLayer
+
+    def get_rbln_attn_class(self):
+        return Gemma2Attention
+
+    def get_rbln_model_class(self):
+        return Gemma2Model
+
+
+class Gemma2DecoderLayer(DecoderOnlyLayer):
+    _PRE_FF_LAYERNORM_ATTRS = ["pre_feedforward_layernorm"]
+    _POST_FF_LAYERNORM_ATTRS = ["post_feedforward_layernorm"]
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        seq_positions: Union[torch.LongTensor, Tuple[torch.LongTensor]],
+        past_key_values: Tuple[Tuple[torch.Tensor]],
+        cos: Optional[torch.Tensor] = None,
+        sin: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
+        lora_int_id: Optional[torch.Tensor] = None,
+    ):
+        residual = hidden_states
+        hidden_states = self.get_pre_attention_layernorm()(hidden_states)
+
+        hidden_states = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            seq_positions=seq_positions,
+            past_key_values=past_key_values,
+            cos=cos,
+            sin=sin,
+            block_tables=block_tables,
+            lora_int_id=lora_int_id,
+        )
+        hidden_states = self.get_post_attention_layernorm()(hidden_states)
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.get_pre_feedforward_layernorm()(hidden_states)
+        hidden_states = self.forward_mlp(hidden_states, lora_int_id)
+        hidden_states = self.get_post_feedforward_layernorm()(hidden_states)
+        hidden_states = residual + hidden_states
+
+        return hidden_states
+
+
+class Gemma2Attention(DecoderOnlyAttention):
+    def get_attn_scale(self, self_attn):
+        return self_attn.config.query_pre_attn_scalar**-0.5
+
+
+class Gemma2Model(DecoderOnlyModel):
+    @property
+    def hidden_multiplier(self):
+        return self.config.hidden_size**0.5
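`Gemma2Attention.get_attn_scale` and `Gemma2Model.hidden_multiplier` carry two Gemma2-specific scales into the RBLN graph: attention scores scaled by `query_pre_attn_scalar ** -0.5` and hidden states scaled by `hidden_size ** 0.5`. A quick numeric illustration with placeholder values (not read from any checkpoint):

```python
# Placeholder values, only to make the two scale formulas above concrete.
query_pre_attn_scalar = 256
hidden_size = 3584

attn_scale = query_pre_attn_scalar ** -0.5   # 0.0625
hidden_multiplier = hidden_size ** 0.5       # ~59.87
print(attn_scale, hidden_multiplier)
```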
optimum/rbln/transformers/models/gemma2/modeling_gemma2.py (new file)

@@ -0,0 +1,101 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ....utils import logging
+from ...models.decoderonly import (
+    RBLNDecoderOnlyModel,
+    RBLNDecoderOnlyModelForCausalLM,
+)
+from .gemma2_architecture import Gemma2Wrapper
+
+
+logger = logging.get_logger(__name__)
+
+
+class RBLNGemma2ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
+    """
+    The Gemma2 Model transformer with a language modeling head (linear layer) on top.
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    A class to convert and run pre-trained transformers based Gemma2ForCausalLM model on RBLN devices.
+    It implements the methods to convert a pre-trained transformers Gemma2ForCausalLM model into a RBLN transformer model by:
+
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+
+    **Configuration:**
+    This model uses [`RBLNGemma2ForCausalLMConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+    the `rbln_config` parameter should be an instance of [`RBLNGemma2ForCausalLMConfig`] or a dictionary conforming to its structure.
+
+    See the [`RBLNGemma2ForCausalLMConfig`] class for all available configuration options.
+    Examples:
+    ```python
+    from optimum.rbln import RBLNGemma2ForCausalLM
+    # Simple usage using rbln_* arguments
+    # `max_seq_len` is automatically inferred from the model config
+    model = RBLNGemma2ForCausalLM.from_pretrained(
+        "google/gemma-2-9b",
+        export=True,
+        rbln_batch_size=1,
+        rbln_tensor_parallel_size=4,
+    )
+    # Using a config dictionary
+    rbln_config = {
+        "batch_size": 1,
+        "max_seq_len": 8192,
+        "tensor_parallel_size": 4,
+    }
+    model = RBLNGemma2ForCausalLM.from_pretrained(
+        "google/gemma-2-9b",
+        export=True,
+        rbln_config=rbln_config
+    )
+    # Using a RBLNMistralForCausalLMConfig instance (recommended for type checking)
+    from optimum.rbln import RBLNGemma2ForCausalLMConfig
+    config = RBLNGemma2ForCausalLMConfig(
+        batch_size=1,
+        max_seq_len=8192,
+        tensor_parallel_size=4
+    )
+    model = RBLNGemma2ForCausalLM.from_pretrained(
+        "google/gemma-2-9b",
+        export=True,
+        rbln_config=config
+    )
+    ```
+    """
+
+    _decoder_wrapper_cls = Gemma2Wrapper
+
+
+class RBLNGemma2Model(RBLNDecoderOnlyModel):
+    """
+    The Gemma2 Model transformer without a language modeling head.
+    This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    A class to convert and run pre-trained transformers based Gemma2Model model on RBLN devices.
+    It implements the methods to convert a pre-trained transformers Gemma2Model model into a RBLN transformer model by:
+
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+
+    **Configuration:**
+    This model uses [`RBLNGemma2ModelConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+    the `rbln_config` parameter should be an instance of [`RBLNGemma2ModelConfig`] or a dictionary conforming to its structure.
+
+    See the [`RBLNGemma2ModelConfig`] class for all available configuration options.
+    """
+
+    _decoder_wrapper_cls = Gemma2Wrapper