optimum-rbln 0.8.2a1__py3-none-any.whl → 0.8.2a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of optimum-rbln has been flagged as potentially problematic.
- optimum/rbln/__init__.py +8 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +16 -1
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +3 -0
- optimum/rbln/diffusers/modeling_diffusers.py +3 -4
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +1 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +10 -2
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +4 -30
- optimum/rbln/modeling.py +2 -3
- optimum/rbln/modeling_base.py +17 -13
- optimum/rbln/transformers/__init__.py +8 -0
- optimum/rbln/transformers/models/__init__.py +2 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +12 -1
- optimum/rbln/transformers/models/clip/modeling_clip.py +123 -28
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +13 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +2 -3
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +107 -249
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +18 -1
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +377 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +275 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +2 -0
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +2 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -0
- optimum/rbln/utils/hub.py +8 -47
- optimum/rbln/utils/runtime_utils.py +28 -2
- {optimum_rbln-0.8.2a1.dist-info → optimum_rbln-0.8.2a3.dist-info}/METADATA +1 -1
- {optimum_rbln-0.8.2a1.dist-info → optimum_rbln-0.8.2a3.dist-info}/RECORD +34 -30
- {optimum_rbln-0.8.2a1.dist-info → optimum_rbln-0.8.2a3.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.2a1.dist-info → optimum_rbln-0.8.2a3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/gemma3/modeling_gemma3.py

@@ -31,15 +31,11 @@ from transformers.models.gemma3.modeling_gemma3 import Gemma3TextScaledWordEmbedding
 
 from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
-from ....utils.logging import get_logger
 from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM, RBLNDecoderOnlyOutput, RBLNRuntimeModel
 from .configuration_gemma3 import RBLNGemma3ForCausalLMConfig
 from .gemma3_architecture import Gemma3ForCausalLMWrapper
 
 
-logger = get_logger()
-
-
 if TYPE_CHECKING:
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, Gemma3ForConditionalGeneration
 
@@ -320,194 +316,28 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         self.prefill = self.runtime if self.phase == "prefill" else None  # FIXME
         self.decode = self.runtime if self.phase == "decode" else None
 
-    def pad_for_chunked_images(
-        self,
-        inputs: torch.Tensor,
-        attention_mask: torch.Tensor,
-        position_ids: torch.Tensor,
-        token_type_ids: Optional[torch.Tensor] = None,
-    ):
-        """
-        Pad inputs so that spans of image tokens align with prefill chunk boundaries.
-
-        Args:
-            inputs: (1, seq_len, hidden_size) tensor.
-            attention_mask: (1, seq_len) tensor, 1 for valid, 0 for masked.
-            position_ids: (1, seq_len) tensor for RoPE.
-            token_type_ids: (1, seq_len) tensor, 0 for text, 1 for image.
-
-        Returns:
-            (inputs_padded, attention_mask_padded, position_ids_padded, padded_len, token_type_ids_padded).
-        """
-
-        if token_type_ids is None:
-            return inputs, attention_mask, position_ids, 0, torch.zeros(inputs.shape[:2], dtype=torch.long)
-
-        seq_len = inputs.shape[1]
-
-        # Find image start positions
-        image_starts = [
-            s
-            for s in range(seq_len - self.rbln_config.prefill_chunk_size + 1)
-            if torch.all(token_type_ids[:, s : s + self.rbln_config.prefill_chunk_size] == 1)
-        ]
-
-        # Initialize padded tensors
-        padded_input_len = seq_len
-        for image_start in image_starts:
-            pad_needed = (
-                self.rbln_config.prefill_chunk_size
-                - (image_start + padded_input_len - seq_len) % self.rbln_config.prefill_chunk_size
-            ) % self.rbln_config.prefill_chunk_size
-            padded_input_len += pad_needed
-        total_padding = padded_input_len - seq_len
-
-        if inputs.dim() == 3:
-            inputs_padded = torch.zeros(1, padded_input_len, inputs.shape[2], dtype=inputs.dtype)
-        else:
-            inputs_padded = torch.zeros(1, padded_input_len, dtype=inputs.dtype)
-        attention_mask_padded = torch.zeros(1, padded_input_len, dtype=attention_mask.dtype)
-        position_ids_padded = torch.zeros(1, padded_input_len, dtype=position_ids.dtype)
-        token_type_ids_padded = torch.zeros(1, padded_input_len, dtype=token_type_ids.dtype)
-
-        # Fill padded tensors
-        dest_pos = 0
-        src_pos = 0
-        last_pos_id = -1
-        for image_start in image_starts + [seq_len]:
-            # Text segment
-            if src_pos < image_start:
-                length = image_start - src_pos
-                inputs_padded[:, dest_pos : dest_pos + length] = inputs[:, src_pos:image_start]
-                attention_mask_padded[:, dest_pos : dest_pos + length] = attention_mask[:, src_pos:image_start]
-                position_ids_padded[:, dest_pos : dest_pos + length] = position_ids[:, src_pos:image_start]
-                token_type_ids_padded[:, dest_pos : dest_pos + length] = token_type_ids[:, src_pos:image_start]
-                dest_pos += length
-                last_pos_id = position_ids[0, image_start - 1].item()
-                src_pos = image_start
-
-            # Padding
-            pad_needed = (
-                self.rbln_config.prefill_chunk_size - dest_pos % self.rbln_config.prefill_chunk_size
-            ) % self.rbln_config.prefill_chunk_size
-            if pad_needed and dest_pos < padded_input_len:
-                position_ids_padded[:, dest_pos : dest_pos + pad_needed] = torch.arange(
-                    last_pos_id + 1, last_pos_id + pad_needed + 1, dtype=position_ids.dtype
-                ).unsqueeze(0)
-                dest_pos += pad_needed
-
-            # Image segment
-            if src_pos < seq_len and src_pos == image_start:
-                inputs_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = inputs[
-                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
-                ]
-                attention_mask_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = attention_mask[
-                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
-                ]
-                position_ids_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = position_ids[
-                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
-                ]
-                token_type_ids_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = token_type_ids[
-                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
-                ]
-                dest_pos += self.rbln_config.prefill_chunk_size
-                src_pos += self.rbln_config.prefill_chunk_size
-                last_pos_id = position_ids[0, image_start + self.rbln_config.prefill_chunk_size - 1].item()
-
-        return inputs_padded, attention_mask_padded, position_ids_padded, total_padding, token_type_ids_padded
-
-    def _prepare_prefill_inputs(
-        self,
-        inputs: torch.Tensor,
-        cache_position: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_embed: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
-    ):
-        """
-        Prepare inputs for prefill phase.
-        """
-        # Handle continuous batching in a compiled graph by extracting valid inputs
-        # If an attention mask is provided, select only the valid (non-masked) inputs
-        inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
-        token_type_ids = (
-            token_type_ids[:, attention_mask.bool()]
-            if attention_mask is not None and token_type_ids is not None
-            else token_type_ids
-        )
-
-        if position_embed is not None:
-            position_embed = (
-                position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
-            )
-
-        seq_len = inputs.shape[1]
-        # Initialize attention mask for chunked processing
-        if self.rbln_config.use_attention_mask:
-            chunked_attention_mask = (
-                torch.ones(1, seq_len, dtype=torch.float32)
-                if self.rbln_config.use_position_ids
-                else torch.zeros(
-                    1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32
-                )
-            )
-        else:
-            chunked_attention_mask = None
-
-        # Buffer for storing output logits
-        out_buffers = [
-            torch.empty(
-                size=self.output_size,
-                dtype=torch.float32,
-                device="cpu",
-            )
-        ]
-
-        inputs, chunked_attention_mask, position_ids, padded_cache_lengths, token_type_ids_padded = (
-            self.pad_for_chunked_images(inputs, chunked_attention_mask, cache_position, token_type_ids)
-        )
-
-        query_length = inputs.shape[1]
-        if query_length > self.rbln_config.max_seq_len:
-            raise ValueError(
-                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.rbln_config.max_seq_len})."
-            )
-
-        # Align attention_mask to compiled shape
-        if self.rbln_config.use_position_ids:
-            chunked_attention_mask = torch.nn.functional.pad(
-                chunked_attention_mask, (0, self.rbln_config.max_seq_len - query_length)
-            )
-
-        # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
-        padding_size = 0
-        if query_length % self.rbln_config.prefill_chunk_size != 0:
-            padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
-            # inputs_embeds
-            if inputs.dim() == 3:
-                inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
-            # inputs_ids
-            else:
-                inputs = torch.nn.functional.pad(inputs, (0, padding_size))
-
-            position_ids = torch.cat(
-                [
-                    position_ids,
-                    torch.arange(
-                        query_length,
-                        query_length + padding_size,
-                        dtype=torch.int32,
-                    ).unsqueeze(0),
-                ],
-                dim=-1,
-            )
-            token_type_ids_padded = torch.nn.functional.pad(token_type_ids_padded, (0, padding_size))
+    def _prepare_prefill_inputs(self, *args, **kwargs):
+        (
+            inputs,
+            cache_position,
+            chunked_attention_mask,
+            out_buffers,
+            position_ids,
+            position_embed,
+            padded_cache_lengths,
+            query_length,
+            token_type_ids,
+        ) = super()._prepare_prefill_inputs(*args, **kwargs)
 
-
-
+        # chunked_attention_mask shape
+        chunked_attention_mask = torch.zeros(1, chunked_attention_mask.shape[-1], dtype=torch.float32)
 
-
+        # As gemma3 has different prefill chunk sizes for image and text, we need to pad the inputs to the max of the two.
+        padding_size = max(self.rbln_config.prefill_chunk_size, self.rbln_config.image_prefill_chunk_size)
+        inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
+        cache_position = torch.nn.functional.pad(cache_position, (0, padding_size))
+        position_ids = torch.nn.functional.pad(position_ids, (0, padding_size))
+        token_type_ids = torch.nn.functional.pad(token_type_ids, (0, padding_size), value=-1)
 
         return (
             inputs,
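A note on the padding calls in this hunk: `torch.nn.functional.pad` reads its pad tuple from the last dimension backwards, which is why the 3-D `inputs` (batch, seq, hidden) gets `(0, 0, 0, padding_size)` while the 2-D tensors get `(0, padding_size)`. A minimal, self-contained sketch (shapes chosen only for illustration):

```python
import torch
import torch.nn.functional as F

# (batch, seq, hidden): pad the sequence (second-to-last) dim on the right.
inputs = torch.zeros(1, 10, 64)
padded = F.pad(inputs, (0, 0, 0, 6))  # (0, 0) covers hidden, (0, 6) covers seq
assert padded.shape == (1, 16, 64)

# (batch, seq): pad the last dim on the right.
cache_position = torch.arange(10).unsqueeze(0)
assert F.pad(cache_position, (0, 6)).shape == (1, 16)

# token_type_ids are padded with -1 so padding is neither text (0) nor image (1).
token_type_ids = torch.zeros(1, 10, dtype=torch.long)
assert torch.all(F.pad(token_type_ids, (0, 6), value=-1)[:, 10:] == -1)
```

Padding `token_type_ids` with `-1` keeps padded positions distinct from both text (`0`) and image (`1`) tokens, which the chunk classification in the next hunk relies on.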
@@ -518,7 +348,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
             position_embed,
             padded_cache_lengths,
             query_length,
-            token_type_ids_padded,
+            token_type_ids,
         )
 
     def prefill_forward(
@@ -541,65 +371,73 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         (
             inputs,
             cache_position,
-            padded_attention_mask,
+            chunked_attention_mask,
             out_buffers,
             position_ids,
             position_embed,
             padded_cache_lengths,
             query_length,
-            token_type_ids_padded,
+            token_type_ids,
         ) = self._prepare_prefill_inputs(
             inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
         )
-
-
-
-
-
-
-
-
-
-            # Extract the current chunk of inputs and cache positions
-            input_chunk = inputs[:, step : step + self.rbln_config.prefill_chunk_size]
-            cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
-            position_ids_chunk = (
-                position_ids[:, step : step + self.rbln_config.prefill_chunk_size]
-                if position_ids is not None
-                else None
+
+        step = 0
+        while step < query_length:
+            # Check if the prefill chunk is an image prefill
+            is_image_prefill = torch.all(
+                token_type_ids[:, step : step + self.rbln_config.image_prefill_chunk_size] == 1
+            )
+            prefill_chunk_size = (
+                self.rbln_config.image_prefill_chunk_size if is_image_prefill else self.rbln_config.prefill_chunk_size
             )
 
-            if
-
-
-                padded_attention_mask[0, step : step + self.rbln_config.prefill_chunk_size]
-            )
-
-            # Define query position
-            query_position = (
-                torch.sum(
-                    chunked_attention_mask[0][step : step + self.rbln_config.prefill_chunk_size],
-                    dim=-1,
-                    dtype=torch.int16,
-                ).squeeze(0)
-                - 1
+            # Check if the prefill chunk is a text prefill which has image tokens in it.
+            is_text_prefill_with_image_tokens = not is_image_prefill and torch.any(
+                token_type_ids[:, step : step + prefill_chunk_size] == 1
             )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+            # Check if the prefill chunk crosses a block boundary, requiring padding to align with block boundaries
+            is_cross_block_boundary = (
+                step // self.rbln_config.kvcache_block_size
+                != (step + prefill_chunk_size) // self.rbln_config.kvcache_block_size
+            )
+
+            # Check if the prefill chunk is the last chunk
+            is_last_chunk = step + prefill_chunk_size >= query_length
+
+            if is_cross_block_boundary:
+                padding_size = prefill_chunk_size - (step + prefill_chunk_size) % self.rbln_config.kvcache_block_size
+                padded_cache_lengths += padding_size
+
+            # If a text prefill ends with image tokens, we only process the text part.
+            num_processed_tokens = prefill_chunk_size
+            if is_text_prefill_with_image_tokens:
+                first_image_token_idx = torch.where(token_type_ids[:, step : step + prefill_chunk_size] == 1)[1][0]
+                num_processed_tokens = first_image_token_idx
+            if is_last_chunk:
+                num_processed_tokens = query_length - step
+
+            input_chunk = inputs[:, step : step + prefill_chunk_size]
+            cache_pos_chunk = cache_position[:, step : step + prefill_chunk_size].clone() + padded_cache_lengths
+            position_ids_chunk = position_ids[:, step : step + prefill_chunk_size].clone()
+            chunked_attention_mask[
+                :, step + padded_cache_lengths : step + num_processed_tokens + padded_cache_lengths
+            ] = 1
+            query_position = torch.tensor(num_processed_tokens - 1, dtype=torch.int16)
+
+            if is_image_prefill:
+                logits = self.image_prefill(
+                    input_chunk,
+                    cache_pos_chunk,
+                    block_tables,
+                    local_block_tables,
+                    query_position,
+                    chunked_attention_mask,
+                    position_ids_chunk,
+                    out=out_buffers,
+                )
             else:
-                # Forward pass for the current chunk
                 logits = self.prefill(
                     input_chunk,
                     cache_pos_chunk,
@@ -611,6 +449,11 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
                     out=out_buffers,
                 )
 
+            step += num_processed_tokens
+
+        if not is_external_block_tables:
+            self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask
+
         return RBLNGemma3ForCausalLMOutput(
             logits=logits, padded_cache_lengths=padded_cache_lengths, attention_mask=chunked_attention_mask
         )
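To make the new control flow easier to review: the old fixed-stride `for` loop is gone, and the `while` loop now advances by a per-chunk `num_processed_tokens` that depends on whether the window is a pure image chunk, a text chunk that runs into image tokens, or the final partial chunk. A standalone sketch of that classification, with toy chunk sizes (the real values come from `rbln_config`; `0` = text, `1` = image as in the diff):

```python
import torch

prefill_chunk_size = 4         # toy value, not a real default
image_prefill_chunk_size = 8   # toy value; pinned to mm_tokens_per_image for Gemma3

# 6 text tokens, one 8-token image span, 2 trailing text tokens
token_type_ids = torch.tensor([[0] * 6 + [1] * 8 + [0] * 2])
query_length = token_type_ids.shape[1]

step = 0
while step < query_length:
    window = token_type_ids[:, step : step + image_prefill_chunk_size]
    is_image_prefill = bool(torch.all(window == 1))
    chunk = image_prefill_chunk_size if is_image_prefill else prefill_chunk_size
    text_window = token_type_ids[:, step : step + chunk]

    # A text chunk that runs into image tokens stops at the first image token,
    # mirroring num_processed_tokens in the diff:
    if not is_image_prefill and torch.any(text_window == 1):
        processed = int(torch.where(text_window == 1)[1][0])
    else:
        processed = min(chunk, query_length - step)

    print(f"step={step:2d} {'image' if is_image_prefill else 'text '} processed={processed}")
    step += processed
```

Running this visits the sequence as a 4-token text chunk, a 2-token text remainder, the full 8-token image chunk, then the 2-token tail, which is exactly the stride pattern the compiled `prefill` / `image_prefill` runtimes expect.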
@@ -757,13 +600,14 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
 
     @classmethod
     def _update_submodule_config(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
-        if rbln_config.
-            rbln_config.
+        if rbln_config.image_prefill_chunk_size is None:
+            rbln_config.image_prefill_chunk_size = model.config.mm_tokens_per_image
 
-        if rbln_config.
-
-            f"
+        if rbln_config.image_prefill_chunk_size != model.config.mm_tokens_per_image:
+            raise ValueError(
+                f"Image prefill chunk size is different from mm_tokens_per_image: {rbln_config.image_prefill_chunk_size} != {model.config.mm_tokens_per_image}"
             )
+
         return rbln_config
 
     @classmethod
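The net effect of `_update_submodule_config`: `image_prefill_chunk_size` is pinned to the model's `mm_tokens_per_image` (256 for the released Gemma3 checkpoints), defaulting to it when unset and rejecting any other value. A minimal standalone mirror of that check (plain classes stand in for `rbln_config` and `model.config`):

```python
class RblnCfg:          # stand-in for rbln_config
    image_prefill_chunk_size = None

class ModelCfg:         # stand-in for model.config
    mm_tokens_per_image = 256

rbln_cfg, model_cfg = RblnCfg(), ModelCfg()

if rbln_cfg.image_prefill_chunk_size is None:
    rbln_cfg.image_prefill_chunk_size = model_cfg.mm_tokens_per_image
if rbln_cfg.image_prefill_chunk_size != model_cfg.mm_tokens_per_image:
    raise ValueError("image_prefill_chunk_size must equal mm_tokens_per_image")

assert rbln_cfg.image_prefill_chunk_size == 256
```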
@@ -777,14 +621,22 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
         # Update rbln_config with super class
         rbln_config = super()._update_rbln_config(preprocessors, model, model_config, rbln_config)
 
-
-
+        if not (rbln_config.use_attention_mask and rbln_config.use_position_ids):
+            raise ValueError("use_attention_mask and use_position_ids must be True for RBLNGemma3ForCausalLM")
+
+        # Update image prefill compile config
+        img_prefill_input_info = cls.get_input_info(
+            batch_size=1,
+            query_length=rbln_config.image_prefill_chunk_size,
+            rbln_config=rbln_config,
+            model_config=model_config,
+        )
         image_prefill_compile_config = RBLNCompileConfig(
-            compiled_model_name="image_prefill", input_info=
+            compiled_model_name="image_prefill", input_info=img_prefill_input_info
         )
         # Insert image_prefill compile config at index 1
-
-        compile_cfgs.insert(
+        compile_cfgs = rbln_config.compile_cfgs
+        compile_cfgs.insert(1, image_prefill_compile_config)
         rbln_config.set_compile_cfgs(compile_cfgs)
 
         return rbln_config
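Ordering matters here: inserting the image-prefill config at index 1 is what makes `rbln_compile_configs[1]`, `compiled_models[1]`, and `device_map["image_prefill"]` line up in the later hunks. A toy illustration of the resulting order (names only; the decoder batch sizes are made up):

```python
# Order produced by the superclass: prefill first, then one entry per decoder batch size.
compile_cfg_names = ["prefill", "decoder_batch_1", "decoder_batch_4"]

# What _update_rbln_config does, expressed by name:
compile_cfg_names.insert(1, "image_prefill")

assert compile_cfg_names == ["prefill", "image_prefill", "decoder_batch_1", "decoder_batch_4"]
```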
@@ -840,11 +692,14 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
         )
 
         image_prefill_compile_config = rbln_compile_configs[1]
+        image_prefill_example_inputs = image_prefill_compile_config.get_dummy_inputs(
+            fill=0, static_tensors=static_tensors
+        )
         wrapped_model.phase = "image_prefill"
         compiled_image_prefill = compile_model(
             wrapped_model,
             image_prefill_compile_config,
-
+            image_prefill_example_inputs,
             context,
             rbln_config.quantization,
         )
@@ -884,12 +739,14 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
                 tensor_type="pt",
                 device=rbln_config.device_map["prefill"],
                 activate_profiler=rbln_config.activate_profiler,
+                timeout=rbln_config.timeout,
             ),
             rebel.Runtime(
                 compiled_models[1],
                 tensor_type="pt",
                 device=rbln_config.device_map["image_prefill"],
                 activate_profiler=rbln_config.activate_profiler,
+                timeout=rbln_config.timeout,
             ),
             *[
                 rebel.Runtime(
@@ -897,6 +754,7 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
                     tensor_type="pt",
                     device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
                     activate_profiler=rbln_config.activate_profiler,
+                    timeout=rbln_config.timeout,
                 )
                 for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
             ],
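All three `rebel.Runtime` constructions now forward `rbln_config.timeout`. Assuming `timeout` is surfaced like other `RBLNModelConfig` options (an assumption; the config-side change is not shown in this diff), it would be set at export time roughly like this:

```python
from optimum.rbln import RBLNGemma3ForCausalLM

# Hypothetical sketch: "timeout" is assumed to be an ordinary rbln_config field
# that ends up as rbln_config.timeout on the compiled model.
model = RBLNGemma3ForCausalLM.from_pretrained(
    "google/gemma-3-4b-it",        # illustrative checkpoint id
    export=True,
    rbln_config={"timeout": 120},  # assumed to be seconds
)
```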
optimum/rbln/transformers/models/llava_next/configuration_llava_next.py

@@ -15,6 +15,11 @@
 from typing import Any, Dict, Optional
 
 from ....configuration_utils import RBLNModelConfig
+from ....utils.logging import get_logger
+from ...models.clip import RBLNCLIPVisionModelConfig
+
+
+logger = get_logger(__name__)
 
 
 class RBLNLlavaNextForConditionalGenerationConfig(RBLNModelConfig):
@@ -50,5 +55,17 @@ class RBLNLlavaNextForConditionalGenerationConfig(RBLNModelConfig):
         if not isinstance(self.batch_size, int) or self.batch_size < 0:
             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
 
-        self.vision_tower =
+        self.vision_tower = self.init_submodule_config(
+            RBLNCLIPVisionModelConfig,
+            vision_tower,
+        )
+
+        if self.vision_tower.output_hidden_states is False:
+            raise ValueError(
+                f"LlavaNext requires output_hidden_states to be True, but found output_hidden_states={self.vision_tower.output_hidden_states}. "
+                f"Please compile again with the correct argument."
+            )
+        else:
+            self.vision_tower.output_hidden_states = True
+
         self.language_model = language_model
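Behaviorally, the new `__init__` logic means the vision tower is always compiled with `output_hidden_states=True`, and an explicit `False` now fails fast instead of being silently overridden. A hedged sketch of both paths (constructor kwargs beyond those visible in this diff are assumptions):

```python
from optimum.rbln import RBLNLlavaNextForConditionalGenerationConfig

# Leaving the vision tower config unset: output_hidden_states is forced to True.
cfg = RBLNLlavaNextForConditionalGenerationConfig(batch_size=1)
assert cfg.vision_tower.output_hidden_states is True

# Explicitly passing False now raises at config-construction time.
try:
    RBLNLlavaNextForConditionalGenerationConfig(
        batch_size=1,
        vision_tower={"output_hidden_states": False},  # assumed dict-form submodule config
    )
except ValueError as err:
    print(err)
```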
optimum/rbln/transformers/models/qwen3/__init__.py (new file)

@@ -0,0 +1,16 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_qwen3 import RBLNQwen3ForCausalLMConfig, RBLNQwen3ModelConfig
+from .modeling_qwen3 import RBLNQwen3ForCausalLM, RBLNQwen3Model
optimum/rbln/transformers/models/qwen3/configuration_qwen3.py (new file)

@@ -0,0 +1,71 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+
+
+class RBLNQwen3ForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+    """
+    Configuration class for RBLN Qwen3 models.
+
+    This class is an alias of RBLNDecoderOnlyModelForCausalLMConfig.
+
+    Example usage:
+    ```python
+    from optimum.rbln import RBLNQwen3ForCausalLM, RBLNQwen3ForCausalLMConfig
+
+    # Create a configuration object
+    config = RBLNQwen3ForCausalLMConfig(
+        batch_size=1,
+        max_seq_len=40960,
+        tensor_parallel_size=4,
+        kvcache_partition_len=16384
+    )
+
+    # Use the configuration with from_pretrained
+    model = RBLNQwen3ForCausalLM.from_pretrained(
+        "Qwen/Qwen3-4B",
+        export=True,
+        rbln_config=config
+    )
+    ```
+    """
+
+
+class RBLNQwen3ModelConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+    """
+    Configuration class for RBLN Qwen3 models.
+
+    This class is an alias of RBLNDecoderOnlyModelForCausalLMConfig.
+
+    Example usage:
+    ```python
+    from optimum.rbln import RBLNQwen3Model, RBLNQwen3ModelConfig
+
+    # Create a configuration object
+    config = RBLNQwen3ModelConfig(
+        batch_size=1,
+        max_seq_len=40960,
+        tensor_parallel_size=4,
+        kvcache_partition_len=16384
+    )
+
+    # Use the configuration with from_pretrained
+    model = RBLNQwen3Model.from_pretrained(
+        "Qwen/Qwen3-Embedding-4B",
+        export=True,
+        rbln_config=config
+    )
+    ```
+    """