optimum-rbln 0.7.5a0__py3-none-any.whl → 0.7.5a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +20 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +9 -4
- optimum/rbln/modeling.py +7 -5
- optimum/rbln/ops/__init__.py +1 -0
- optimum/rbln/ops/attn.py +10 -0
- optimum/rbln/ops/flash_attn.py +8 -0
- optimum/rbln/ops/sliding_window_attn.py +111 -0
- optimum/rbln/transformers/__init__.py +22 -3
- optimum/rbln/transformers/models/__init__.py +23 -0
- optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +93 -0
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +298 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +12 -6
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +81 -77
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +160 -88
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +11 -7
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +4 -4
- optimum/rbln/transformers/models/opt/__init__.py +16 -0
- optimum/rbln/transformers/models/opt/configuration_opt.py +19 -0
- optimum/rbln/transformers/models/opt/modeling_opt.py +78 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +16 -10
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +35 -52
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -0
- optimum/rbln/transformers/models/siglip/__init__.py +20 -0
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +66 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +146 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +1 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +121 -72
- optimum/rbln/utils/submodule.py +13 -1
- {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5a1.dist-info}/METADATA +1 -1
- {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5a1.dist-info}/RECORD +35 -24
- {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5a1.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5a1.dist-info}/licenses/LICENSE +0 -0
--- a/optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py
+++ b/optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py
@@ -30,7 +30,7 @@ from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
 from ....utils.logging import get_logger
 from ....utils.runtime_utils import RBLNPytorchRuntime
-from ...utils.rbln_quantization import
+from ...utils.rbln_quantization import prepare_model_for_quantization
 from .configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
 from .decoderonly_architecture import (
     DecoderOnlyWrapper,
@@ -59,6 +59,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         kvcache_block_size: int,
         use_attention_mask: bool,
         attn_impl: str,
+        use_position_ids: bool,
         **kwargs: Any,
     ) -> None:
         super().__init__(runtime, **kwargs)
@@ -72,6 +73,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         self.dec_attn_mask = dec_attn_mask
         self.block_tables = block_tables
         self.free_block_pool = free_block_pool
+        self.use_position_ids = use_position_ids

         self.kvcache_block_size = kvcache_block_size
         self.empty_block = -1
@@ -164,6 +166,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         batch_idx: Optional[int] = None,
         block_tables: Optional[torch.Tensor] = None,
         position_embed: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
     ):
         if input_ids is None and inputs_embeds is None:
             raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
@@ -189,10 +192,16 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 is_external_block_tables,
                 attention_mask=attention_mask,
                 position_embed=position_embed,
+                position_ids=position_ids,
             )
         else:
             return self.prefill_forward(
-                inputs,
+                inputs,
+                cache_position,
+                attention_mask,
+                batch_idx,
+                block_tables,
+                position_embed=position_embed,
             )

     def decode_forward(
@@ -203,6 +212,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         is_external_block_tables: bool = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_embed: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         batch_size = inputs.shape[0]
         if batch_size != self.batch_size:
@@ -232,35 +242,30 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             if self.batch_size < block_tables.shape[0]:
                 block_tables = block_tables[: self.batch_size]

-            if self.batch_size < attention_mask.shape[0]:
+            if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
                 attention_mask = attention_mask[: self.batch_size]

         logits = super().forward(
             inputs,
             cache_position,
-            attention_mask if self.use_attention_mask else None,
             block_tables,
             position_embed,
+            attention_mask if self.use_attention_mask else None,
+            position_ids if self.use_position_ids else None,
         )

-        return logits
+        return RBLNDecoderOnlyOutput(logits=logits)

-    def prefill_forward(
+    def _prepare_prefill_inputs(
         self,
         inputs: torch.Tensor,
-        cache_position: torch.Tensor
+        cache_position: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-        batch_idx: int = None,
-        block_tables: torch.Tensor = None,
-        is_external_block_tables: bool = None,
         position_embed: Optional[torch.Tensor] = None,
-    )
+    ):
         """
-        Performs chunked prefill for efficient KV-cache updates and memory optimization.
-        Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
-        and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
+        Prepare inputs for prefill phase.
         """
-
         # Handle continuous batching in a compiled graph by extracting valid inputs
         # If an attention mask is provided, select only the valid (non-masked) inputs
         inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
@@ -276,8 +281,11 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             )

         # Initialize attention mask for chunked processing
-
-
+        chunked_attention_mask = (
+            torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
+            if self.use_attention_mask
+            else None
+        )

         # Buffer for storing output logits
         out_buffers = [
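The chunk-sized mask introduced above spans the full KV range. A standalone shape sketch, with hypothetical sizes standing in for `self.prefill_chunk_size` and `self.max_seq_len`:

```python
import torch

# Hypothetical stand-ins for self.prefill_chunk_size / self.max_seq_len.
prefill_chunk_size, max_seq_len = 128, 4096
use_attention_mask = True

# One row per position in the current chunk, spanning the whole sequence.
chunked_attention_mask = (
    torch.zeros(1, 1, prefill_chunk_size, max_seq_len, dtype=torch.float32)
    if use_attention_mask
    else None
)
assert chunked_attention_mask.shape == (1, 1, 128, 4096)
```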
@@ -288,36 +296,80 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             )
         ]

-        #
-
-
-
-
-
-
-
-        else:
-            inputs = torch.nn.functional.pad(inputs, (0, padding_size))
+        # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
+        if query_length % self.prefill_chunk_size != 0:
+            padding_size = self.prefill_chunk_size - query_length % self.prefill_chunk_size
+            # inputs_embeds
+            if inputs.dim() == 3:
+                inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
+            # inputs_ids
+            else:
+                inputs = torch.nn.functional.pad(inputs, (0, padding_size))

-
-
-
-
-
-
-
-
-
-
-
+        cache_position = torch.cat(
+            [
+                cache_position,
+                torch.arange(
+                    query_length,
+                    query_length + padding_size,
+                    dtype=torch.int32,
+                ).unsqueeze(0),
+            ],
+            dim=-1,
+        )
+
+        if position_embed is not None:
+            position_embed = torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
+
+        # Overwrite position_ids and padded_cache_lengths
+        position_ids = None
+        padded_cache_lengths = 0
+
+        return (
+            inputs,
+            cache_position,
+            chunked_attention_mask,
+            out_buffers,
+            position_ids,
+            position_embed,
+            padded_cache_lengths,
+            query_length,
+        )

-
-
+    def prefill_forward(
+        self,
+        inputs: torch.Tensor,
+        cache_position: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        batch_idx: int = None,
+        block_tables: torch.Tensor = None,
+        is_external_block_tables: bool = None,
+        position_embed: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        """
+        Performs chunked prefill for efficient KV-cache updates and memory optimization.
+        Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
+        and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
+        """
+        (
+            inputs,
+            cache_position,
+            chunked_attention_mask,
+            out_buffers,
+            position_ids,
+            position_embed,
+            padded_cache_lengths,
+            query_length,
+        ) = self._prepare_prefill_inputs(inputs, cache_position, attention_mask, position_embed)

+        # Process input in chunks of size `prefill_chunk_size`
+        for step in range(0, query_length, self.prefill_chunk_size):
             # Extract the current chunk of inputs and cache positions
             input_chunk = inputs[:, step : step + self.prefill_chunk_size]
             cache_pos_chunk = cache_position[:, step : step + self.prefill_chunk_size]
+            position_ids_chunk = (
+                position_ids[:, step : step + self.prefill_chunk_size] if position_ids is not None else None
+            )
             if position_embed is not None:
                 position_embed_chunk = position_embed[:, :, :, step : step + self.prefill_chunk_size, :]

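The padding arithmetic in `_prepare_prefill_inputs` is easiest to verify with concrete numbers. A minimal sketch, assuming a hypothetical chunk size of 128 and a 300-token prompt:

```python
import torch

prefill_chunk_size = 128
query_length = 300  # hypothetical prompt length

# Pad the tail so the last chunk is a full prefill_chunk_size.
padding_size = prefill_chunk_size - query_length % prefill_chunk_size  # 84
cache_position = torch.arange(query_length, dtype=torch.int32).unsqueeze(0)
cache_position = torch.cat(
    [
        cache_position,
        torch.arange(query_length, query_length + padding_size, dtype=torch.int32).unsqueeze(0),
    ],
    dim=-1,
)
assert cache_position.shape[-1] == 384  # exactly three chunks of 128

# The chunk loop then visits starts 0, 128, 256 over the real query length.
assert list(range(0, query_length, prefill_chunk_size)) == [0, 128, 256]
```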
@@ -334,9 +386,10 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             logits = super().forward(
                 input_chunk,
                 cache_pos_chunk,
-                chunked_attention_mask if self.use_attention_mask else None,
-                query_position,
                 block_tables,
+                query_position,
+                chunked_attention_mask if self.use_attention_mask else None,
+                position_ids_chunk if position_ids is not None else None,
                 position_embed_chunk if position_embed is not None else None,
                 out=out_buffers,
             )
@@ -346,13 +399,14 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             self.dec_attn_mask[batch_idx].fill_(0)
             self.dec_attn_mask[batch_idx, :, :, :query_length] = 1

-        return logits
+        return RBLNDecoderOnlyOutput(logits=logits, padded_cache_lengths=padded_cache_lengths)


 @dataclass
 class RBLNDecoderOnlyOutput(ModelOutput):
     logits: torch.FloatTensor = None
     generate_idx: torch.Tensor = None
+    padded_cache_lengths: int = None


 class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
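`RBLNDecoderOnlyOutput` is a `ModelOutput` dataclass, so the new field travels with the logits rather than as a second return value. A small usage sketch, assuming the module path shown in this diff:

```python
import torch

from optimum.rbln.transformers.models.decoderonly.modeling_decoderonly import RBLNDecoderOnlyOutput

out = RBLNDecoderOnlyOutput(logits=torch.randn(1, 1, 32000), padded_cache_lengths=84)
# Callers shift subsequent cache writes past the padded prefill slots.
generate_idx = torch.tensor([[300]])
next_cache_position = generate_idx + out.padded_cache_lengths  # tensor([[384]])
```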
@@ -422,6 +476,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             max_seq_len=self.rbln_config.max_seq_len,
             use_attention_mask=self.rbln_config.use_attention_mask,
             attn_impl=self.rbln_config.attn_impl,
+            use_position_ids=self.rbln_config.use_position_ids,
         )
         self.decoders = {}
         for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
@@ -437,6 +492,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 kvcache_block_size=self.rbln_config.kvcache_block_size,
                 use_attention_mask=self.rbln_config.use_attention_mask,
                 attn_impl=self.rbln_config.attn_impl,
+                use_position_ids=self.rbln_config.use_position_ids,
             )

         # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
@@ -482,8 +538,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         trust_remote_code: bool = False,
         **kwargs,
     ):
-        from ...utils.rbln_quantization import prepare_model_for_quantization
-
         kwargs = cls.update_kwargs(kwargs)

         if config is None:
@@ -500,8 +554,16 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         with no_init_weights():
             model = AutoModelForCausalLM.from_config(config)

-
-
+        model = prepare_model_for_quantization(
+            model,
+            model_id,
+            kwargs.get("num_hidden_layers"),
+            use_auth_token=use_auth_token,
+            revision=revision,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            local_files_only=local_files_only,
+        )
         return model

     def __getattr__(self, __name: str) -> Any:
|
@@ -528,11 +590,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
528
590
|
def get_pytorch_model(
|
529
591
|
cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None, **kwargs
|
530
592
|
) -> "PreTrainedModel":
|
531
|
-
if
|
532
|
-
rbln_config is not None
|
533
|
-
and "format" in rbln_config.quantization
|
534
|
-
and rbln_config.quantization["format"] == "rbln"
|
535
|
-
):
|
593
|
+
if rbln_config and rbln_config.quantization:
|
536
594
|
model = cls.get_quantized_model(*args, **kwargs)
|
537
595
|
else:
|
538
596
|
model = super().get_pytorch_model(*args, **kwargs)
|
@@ -548,6 +606,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             "kvcache_block_size": rbln_config.kvcache_block_size,
             "use_rotary_emb": cls._use_rotary_emb,
             "use_attention_mask": rbln_config.use_attention_mask,
+            "use_position_ids": rbln_config.use_position_ids,
+            "use_inputs_embeds": rbln_config.use_inputs_embeds,
         }
         return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()

@@ -572,9 +632,10 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             static_tensors[name] = tensor
             context.mark_static_address(tensor)

-
-        def compile_model(wrapped_model, compile_config, example_inputs, compile_context, **kwargs):
+        def compile_model(wrapped_model, compile_config, example_inputs, compile_context, quantization):
             try:
+                if quantization:
+                    quantization.maybe_set_quantization_env()
                 original_linear = torch.nn.functional.linear
                 torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
                 compiled_model = RBLNModel.compile(
@@ -586,14 +647,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 return compiled_model
             finally:
                 torch.nn.functional.linear = original_linear
+                if quantization:
+                    quantization.maybe_reset_quantization_env()

         wrapped_model.phase = "prefill"
         compiled_prefill = compile_model(
-            wrapped_model,
-            prefill_compile_config,
-            prefill_example_inputs,
-            context,
-            quantize_config=rbln_config.quantization,
+            wrapped_model, prefill_compile_config, prefill_example_inputs, context, rbln_config.quantization
         )

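The set/reset pair guards compilation with process-level state; `maybe_set_quantization_env` and `maybe_reset_quantization_env` come from the reworked `rbln_quantization` module. The same pattern expressed as a context manager, sketched with a made-up environment variable name:

```python
import os
from contextlib import contextmanager

@contextmanager
def quantization_env(quantization):
    # Mirrors compile_model above: set before compiling, always reset after,
    # even if compilation raises. "RBLN_QUANTIZATION" is a hypothetical name.
    if quantization:
        os.environ["RBLN_QUANTIZATION"] = "1"
    try:
        yield
    finally:
        if quantization:
            os.environ.pop("RBLN_QUANTIZATION", None)
```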
|
599
658
|
wrapped_model.phase = "decode"
|
@@ -601,11 +660,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
601
660
|
for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_compile_configs[1:]):
|
602
661
|
dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
|
603
662
|
compiled_decoder = compile_model(
|
604
|
-
wrapped_model,
|
605
|
-
dec_compile_config,
|
606
|
-
dec_example_inputs,
|
607
|
-
context,
|
608
|
-
quantize_config=rbln_config.quantization,
|
663
|
+
wrapped_model, dec_compile_config, dec_example_inputs, context, rbln_config.quantization
|
609
664
|
)
|
610
665
|
compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder
|
611
666
|
|
@@ -763,6 +818,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         query_length: int,
         use_inputs_embeds: bool,
         use_attention_mask: bool,
+        use_position_ids: bool,
         max_seq_len: int,
         kvcache_block_size: int,
         kvcache_num_blocks: int,
@@ -785,26 +841,27 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             ),
         ]

-
+        max_block_cnt = max_seq_len // kvcache_block_size
+
+        if query_length > 1:
+            input_info.extend([("block_tables", [max_block_cnt], "int16")])
+        else:
+            input_info.extend([("block_tables", [batch_size, max_block_cnt], "int16")])
+
+        if query_length > 1:
             input_info.extend(
                 [
-                    ("
+                    ("query_position", [], "int16"),
                 ]
             )
-
-        if query_length > 1:
+        if use_attention_mask:
             input_info.extend(
                 [
-                    ("
+                    ("attention_mask", [batch_size, 1, query_length, max_seq_len], "float32"),
                 ]
             )
-
-
-
-        if query_length > 1:
-            input_info.extend([("block_tables", [max_block_cnt], "int16")])
-        else:
-            input_info.extend([("block_tables", [batch_size, max_block_cnt], "int16")])
+        if use_position_ids:
+            input_info.append(("position_ids", [batch_size, query_length], "int32"))

         input_info.extend(
             [
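A worked example of the shapes registered above, with hypothetical compile-time sizes:

```python
max_seq_len = 4096
kvcache_block_size = 1024
batch_size = 2

max_block_cnt = max_seq_len // kvcache_block_size  # 4

# Prefill (query_length > 1) runs one sequence at a time, so it compiles a
# single block-table row; decode compiles one row per batch entry.
prefill_entry = ("block_tables", [max_block_cnt], "int16")             # shape [4]
decode_entry = ("block_tables", [batch_size, max_block_cnt], "int16")  # shape [2, 4]
```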
@@ -898,6 +955,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             query_length=rbln_config.prefill_chunk_size,
             use_inputs_embeds=rbln_config.use_inputs_embeds,
             use_attention_mask=rbln_config.use_attention_mask,
+            use_position_ids=rbln_config.use_position_ids,
             max_seq_len=rbln_config.max_seq_len,
             kvcache_block_size=rbln_config.kvcache_block_size,
             kvcache_num_blocks=rbln_config.kvcache_num_blocks,
@@ -916,6 +974,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             query_length=1,
             use_inputs_embeds=rbln_config.use_inputs_embeds,
             use_attention_mask=rbln_config.use_attention_mask,
+            use_position_ids=rbln_config.use_position_ids,
             max_seq_len=rbln_config.max_seq_len,
             kvcache_block_size=rbln_config.kvcache_block_size,
             kvcache_num_blocks=rbln_config.kvcache_num_blocks,
@@ -977,6 +1036,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         generate_idx: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
+        padded_cache_lengths: Optional[torch.Tensor] = None,
         **kwargs,
     ):
         model_inputs = {}
@@ -984,13 +1044,17 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):

         if is_prefill_phase:
             generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
+            padded_cache_lengths = torch.zeros_like(generate_idx)
             cache_position = None
+            position_ids = None
         else:
             if inputs_embeds is not None:
-
+                # if `inputs_embeds` are passed, only use them in the 1st generation step for every prompt.
+                inputs_embeds = None

             input_ids = input_ids[:, -1:]
-
+            position_ids = generate_idx
+            cache_position = generate_idx + padded_cache_lengths if padded_cache_lengths is not None else generate_idx
             generate_idx = generate_idx + 1
             model_inputs.update({"input_ids": input_ids})

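With padded prefill, `position_ids` and `cache_position` deliberately diverge during decode: rotary-position math follows the true token count, while KV-cache writes skip the padded slots. A numeric sketch with hypothetical lengths:

```python
import torch

generate_idx = torch.tensor([[300], [40]])          # true generated lengths
padded_cache_lengths = torch.tensor([[84], [88]])   # padding added by chunked prefill

position_ids = generate_idx                           # -> [[300], [40]]
cache_position = generate_idx + padded_cache_lengths  # -> [[384], [128]]
```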
@@ -1009,6 +1073,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 "attention_mask": attention_mask,
                 "cache_position": cache_position,
                 "generate_idx": generate_idx,
+                "position_ids": position_ids,
+                "padded_cache_lengths": padded_cache_lengths,
             }
         )

@@ -1022,6 +1088,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     ) -> Dict[str, Any]:
         # update generate_idx
         model_kwargs["generate_idx"] = outputs.generate_idx
+        model_kwargs["padded_cache_lengths"] = outputs.padded_cache_lengths

         return model_kwargs

@@ -1032,6 +1099,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         cache_position: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
         generate_idx: Optional[torch.Tensor] = None,
+        padded_cache_lengths: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        return_dict: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
         """
@@ -1045,18 +1115,17 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             logits = []
             inputs = inputs_embeds if inputs_embeds is not None else input_ids
             batch_size = inputs.shape[0]
-
             for b_idx in range(batch_size):
                 cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
-
+                output = self.prefill_decoder(
                     input_ids=inputs[b_idx : b_idx + 1] if inputs_embeds is None else None,
                     inputs_embeds=inputs[b_idx : b_idx + 1] if inputs_embeds is not None else None,
                     attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
                     cache_position=cache_position,
                     batch_idx=b_idx,
                 )
-
-
+                padded_cache_lengths[b_idx] += output.padded_cache_lengths
+                logits.append(output.logits)
             logits = torch.cat(logits, dim=0)
         # Decoder
         else:
@@ -1072,9 +1141,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 input_ids=input_ids,
                 inputs_embeds=inputs_embeds,
                 cache_position=cache_position,
-
+                position_ids=position_ids if self.rbln_config.use_position_ids else None,
+            ).logits

-
-            logits
-
-
+        if not return_dict:
+            return logits, generate_idx, padded_cache_lengths
+        else:
+            return RBLNDecoderOnlyOutput(
+                logits=logits, generate_idx=generate_idx, padded_cache_lengths=padded_cache_lengths
+            )
--- a/optimum/rbln/transformers/models/idefics3/modeling_idefics3.py
+++ b/optimum/rbln/transformers/models/idefics3/modeling_idefics3.py
@@ -421,6 +421,7 @@ class RBLNIdefics3ForConditionalGeneration(RBLNModel):
         image_hidden_states: Optional[torch.FloatTensor] = None,
         cache_position: torch.Tensor = None,
         generate_idx: Optional[torch.Tensor] = None,
+        return_dict: Optional[bool] = None,
         **kwargs,
     ) -> Union[Tuple, Idefics3CausalLMOutputWithPast]:
         # Prefill
@@ -434,14 +435,14 @@ class RBLNIdefics3ForConditionalGeneration(RBLNModel):

             for b_idx in range(batch_size):
                 cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
-
+                output = self.text_model.prefill_decoder(
                     input_ids=inputs[b_idx : b_idx + 1] if inputs_embeds is None else None,
                     inputs_embeds=inputs[b_idx : b_idx + 1] if inputs_embeds is not None else None,
                     attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
                     cache_position=cache_position,
                     batch_idx=b_idx,
                 )
-                logits.append(
+                logits.append(output.logits)

             logits = torch.cat(logits, dim=0)

@@ -451,9 +452,12 @@ class RBLNIdefics3ForConditionalGeneration(RBLNModel):
                 input_ids=input_ids,
                 inputs_embeds=inputs_embeds,
                 cache_position=cache_position,
-            )
+            ).logits

-
-            logits
-
-
+        if not return_dict:
+            return logits, generate_idx
+        else:
+            return RBLNDecoderOnlyOutput(
+                logits=logits,
+                generate_idx=generate_idx,
+            )
--- a/optimum/rbln/transformers/models/llava_next/modeling_llava_next.py
+++ b/optimum/rbln/transformers/models/llava_next/modeling_llava_next.py
@@ -372,7 +372,7 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
             inputs_embeds = [inputs_embeds[i : i + 1, attention_mask[i].bool()] for i in range(batch_size)]
             for batch_idx in range(batch_size):
                 generate_idx[batch_idx] = inputs_embeds[batch_idx].shape[-2]
-
+                output = self.language_model.prefill_decoder(
                     inputs_embeds=inputs_embeds[batch_idx],
                     batch_idx=batch_idx,
                     cache_position=torch.arange(
|
|
382
382
|
).unsqueeze(0),
|
383
383
|
)
|
384
384
|
|
385
|
-
logits.append(
|
385
|
+
logits.append(output.logits)
|
386
386
|
logits = torch.cat(logits, dim=0)
|
387
387
|
else:
|
388
|
-
|
388
|
+
output = self.language_model.decoder(
|
389
389
|
inputs_embeds=inputs_embeds,
|
390
390
|
cache_position=cache_position,
|
391
391
|
)
|
392
|
-
|
392
|
+
logits = output.logits
|
393
393
|
return RBLNDecoderOnlyOutput(logits=logits, generate_idx=generate_idx)
|
394
394
|
|
395
395
|
# Almost copied from : https://github.com/huggingface/transformers/blob/6b550462139655d488d4c663086a63e98713c6b9/src/transformers/models/llava_next/modeling_llava_next.py
|
--- /dev/null
+++ b/optimum/rbln/transformers/models/opt/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_opt import RBLNOPTForCausalLMConfig
+from .modeling_opt import RBLNOPTForCausalLM
--- /dev/null
+++ b/optimum/rbln/transformers/models/opt/configuration_opt.py
@@ -0,0 +1,19 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+
+
+class RBLNOPTForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+    pass