optimum-rbln 0.7.5a0__py3-none-any.whl → 0.7.5rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +30 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +9 -4
- optimum/rbln/modeling.py +7 -5
- optimum/rbln/ops/__init__.py +1 -0
- optimum/rbln/ops/attn.py +10 -0
- optimum/rbln/ops/flash_attn.py +8 -0
- optimum/rbln/ops/sliding_window_attn.py +111 -0
- optimum/rbln/transformers/__init__.py +32 -3
- optimum/rbln/transformers/models/__init__.py +37 -0
- optimum/rbln/transformers/models/auto/__init__.py +1 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +7 -0
- optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +93 -0
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +298 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +12 -6
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +189 -90
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +186 -95
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
- optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +69 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +446 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +1057 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +11 -7
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +4 -4
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/opt/__init__.py +16 -0
- optimum/rbln/transformers/models/opt/configuration_opt.py +19 -0
- optimum/rbln/transformers/models/opt/modeling_opt.py +80 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +77 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +4 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -11
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +35 -52
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -0
- optimum/rbln/transformers/models/siglip/__init__.py +20 -0
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +66 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +146 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +1 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +121 -72
- optimum/rbln/utils/submodule.py +13 -1
- {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5rc0.dist-info}/METADATA +1 -1
- {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5rc0.dist-info}/RECORD +46 -31
- {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5rc0.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5rc0.dist-info}/licenses/LICENSE +0 -0
--- a/optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py
+++ b/optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py
@@ -30,7 +30,7 @@ from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
 from ....utils.logging import get_logger
 from ....utils.runtime_utils import RBLNPytorchRuntime
-from ...utils.rbln_quantization import
+from ...utils.rbln_quantization import prepare_model_for_quantization
 from .configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
 from .decoderonly_architecture import (
     DecoderOnlyWrapper,
@@ -59,6 +59,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         kvcache_block_size: int,
         use_attention_mask: bool,
         attn_impl: str,
+        use_position_ids: bool,
         **kwargs: Any,
     ) -> None:
         super().__init__(runtime, **kwargs)
@@ -72,6 +73,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         self.dec_attn_mask = dec_attn_mask
         self.block_tables = block_tables
         self.free_block_pool = free_block_pool
+        self.use_position_ids = use_position_ids

         self.kvcache_block_size = kvcache_block_size
         self.empty_block = -1
@@ -164,6 +166,9 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         batch_idx: Optional[int] = None,
         block_tables: Optional[torch.Tensor] = None,
         position_embed: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        local_block_tables: Optional[torch.Tensor] = None,
     ):
         if input_ids is None and inputs_embeds is None:
             raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
@@ -189,10 +194,19 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 is_external_block_tables,
                 attention_mask=attention_mask,
                 position_embed=position_embed,
+                position_ids=position_ids,
+                local_block_tables=local_block_tables,
             )
         else:
             return self.prefill_forward(
-                inputs,
+                inputs,
+                cache_position,
+                attention_mask,
+                batch_idx,
+                block_tables,
+                position_embed=position_embed,
+                token_type_ids=token_type_ids,
+                local_block_tables=local_block_tables,
             )

     def decode_forward(
@@ -203,6 +217,8 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         is_external_block_tables: bool = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_embed: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        local_block_tables: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         batch_size = inputs.shape[0]
         if batch_size != self.batch_size:
@@ -232,35 +248,32 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             if self.batch_size < block_tables.shape[0]:
                 block_tables = block_tables[: self.batch_size]

-            if self.batch_size < attention_mask.shape[0]:
+            if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
                 attention_mask = attention_mask[: self.batch_size]

         logits = super().forward(
             inputs,
             cache_position,
-            attention_mask if self.use_attention_mask else None,
             block_tables,
             position_embed,
+            attention_mask if self.use_attention_mask else None,
+            position_ids if self.use_position_ids else None,
         )

-        return logits
+        return RBLNDecoderOnlyOutput(logits=logits)

-    def
+    def _prepare_prefill_inputs(
         self,
         inputs: torch.Tensor,
-        cache_position: torch.Tensor
+        cache_position: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-        batch_idx: int = None,
-        block_tables: torch.Tensor = None,
-        is_external_block_tables: bool = None,
         position_embed: Optional[torch.Tensor] = None,
-
+        local_block_tables: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+    ):
         """
-
-        Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
-        and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
+        Prepare inputs for prefill phase.
         """
-
         # Handle continuous batching in a compiled graph by extracting valid inputs
         # If an attention mask is provided, select only the valid (non-masked) inputs
         inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
@@ -276,8 +289,11 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         )

         # Initialize attention mask for chunked processing
-
-
+        chunked_attention_mask = (
+            torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
+            if self.use_attention_mask
+            else None
+        )

         # Buffer for storing output logits
         out_buffers = [
@@ -288,40 +304,88 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             )
         ]

-        #
-
-
-
-
-
-
-
-
-            else:
-                inputs = torch.nn.functional.pad(inputs, (0, padding_size))
+        # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
+        if query_length % self.prefill_chunk_size != 0:
+            padding_size = self.prefill_chunk_size - query_length % self.prefill_chunk_size
+            # inputs_embeds
+            if inputs.dim() == 3:
+                inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
+            # inputs_ids
+            else:
+                inputs = torch.nn.functional.pad(inputs, (0, padding_size))

-
-
-
-
-
-
-
-
-
-
-
+            cache_position = torch.cat(
+                [
+                    cache_position,
+                    torch.arange(
+                        query_length,
+                        query_length + padding_size,
+                        dtype=torch.int32,
+                    ).unsqueeze(0),
+                ],
+                dim=-1,
+            )
+
+            if position_embed is not None:
+                position_embed = torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
+
+        # Overwrite position_ids and padded_cache_lengths
+        position_ids = None
+        padded_cache_lengths = 0
+
+        return (
+            inputs,
+            cache_position,
+            chunked_attention_mask,
+            out_buffers,
+            position_ids,
+            position_embed,
+            padded_cache_lengths,
+            query_length,
+        )

-
-
+    def prefill_forward(
+        self,
+        inputs: torch.Tensor,
+        cache_position: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        batch_idx: int = None,
+        block_tables: torch.Tensor = None,
+        is_external_block_tables: bool = None,
+        position_embed: Optional[torch.Tensor] = None,
+        local_block_tables: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        """
+        Performs chunked prefill for efficient KV-cache updates and memory optimization.
+        Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
+        and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
+        """
+        (
+            inputs,
+            cache_position,
+            chunked_attention_mask,
+            out_buffers,
+            position_ids,
+            position_embed,
+            padded_cache_lengths,
+            query_length,
+        ) = self._prepare_prefill_inputs(
+            inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
+        )

+        # Process input in chunks of size `prefill_chunk_size`
+        for step in range(0, query_length, self.prefill_chunk_size):
             # Extract the current chunk of inputs and cache positions
             input_chunk = inputs[:, step : step + self.prefill_chunk_size]
             cache_pos_chunk = cache_position[:, step : step + self.prefill_chunk_size]
+            position_ids_chunk = (
+                position_ids[:, step : step + self.prefill_chunk_size] if position_ids is not None else None
+            )
             if position_embed is not None:
                 position_embed_chunk = position_embed[:, :, :, step : step + self.prefill_chunk_size, :]

-            if self.use_attention_mask:
+            if self.use_attention_mask and not self.use_position_ids:
                 # Update attention mask to ensure proper causal behavior
                 if step >= self.prefill_chunk_size:
                     chunked_attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
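Note on the chunked prefill introduced above: the following is a small, framework-free sketch of the same padding-and-chunking flow, for readers who want to see the arithmetic in isolation. The sizes and tensor names here are illustrative only and are not taken from the diff.

```python
import torch

# Hypothetical sizes, for illustration only.
prefill_chunk_size = 128
query_length = 300  # prompt length before padding

input_ids = torch.randint(0, 32000, (1, query_length))
cache_position = torch.arange(query_length, dtype=torch.int32).unsqueeze(0)

# Pad the last chunk up to a multiple of `prefill_chunk_size`,
# mirroring what `_prepare_prefill_inputs` does in the diff.
if query_length % prefill_chunk_size != 0:
    padding_size = prefill_chunk_size - query_length % prefill_chunk_size
    input_ids = torch.nn.functional.pad(input_ids, (0, padding_size))
    cache_position = torch.cat(
        [cache_position, torch.arange(query_length, query_length + padding_size, dtype=torch.int32).unsqueeze(0)],
        dim=-1,
    )

# The prefill loop then walks the padded sequence one chunk at a time;
# each chunk would be fed to the compiled prefill runtime.
for step in range(0, query_length, prefill_chunk_size):
    input_chunk = input_ids[:, step : step + prefill_chunk_size]
    cache_pos_chunk = cache_position[:, step : step + prefill_chunk_size]
```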
@@ -334,10 +398,11 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             logits = super().forward(
                 input_chunk,
                 cache_pos_chunk,
-                chunked_attention_mask if self.use_attention_mask else None,
-                query_position,
                 block_tables,
                 position_embed_chunk if position_embed is not None else None,
+                query_position,
+                chunked_attention_mask if self.use_attention_mask else None,
+                position_ids_chunk if self.use_position_ids else None,
                 out=out_buffers,
             )

@@ -346,13 +411,14 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             self.dec_attn_mask[batch_idx].fill_(0)
             self.dec_attn_mask[batch_idx, :, :, :query_length] = 1

-        return logits
+        return RBLNDecoderOnlyOutput(logits=logits, padded_cache_lengths=padded_cache_lengths)


 @dataclass
 class RBLNDecoderOnlyOutput(ModelOutput):
     logits: torch.FloatTensor = None
     generate_idx: torch.Tensor = None
+    padded_cache_lengths: int = None


 class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
@@ -386,12 +452,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         if self.rbln_config.use_inputs_embeds:
             main_input_name = "inputs_embeds"
             artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
-
-            self.embed_tokens = torch.nn.Embedding(
-                self.config.vocab_size,
-                self.config.hidden_size,
-                self.config.pad_token_id,
-            )
+            self.embed_tokens = self._create_embedding_layer()
             self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
         else:
             self.embed_tokens = None
@@ -422,7 +483,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             max_seq_len=self.rbln_config.max_seq_len,
             use_attention_mask=self.rbln_config.use_attention_mask,
             attn_impl=self.rbln_config.attn_impl,
+            use_position_ids=self.rbln_config.use_position_ids,
         )
+
         self.decoders = {}
         for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
             self.decoders[batch_size] = RBLNRuntimeModel(
@@ -437,6 +500,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 kvcache_block_size=self.rbln_config.kvcache_block_size,
                 use_attention_mask=self.rbln_config.use_attention_mask,
                 attn_impl=self.rbln_config.attn_impl,
+                use_position_ids=self.rbln_config.use_position_ids,
             )

         # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
@@ -459,6 +523,15 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
             torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")

+    def _create_embedding_layer(self):
+        with no_init_weights():
+            embed_tokens = torch.nn.Embedding(
+                self.config.vocab_size,
+                self.config.hidden_size,
+                self.config.pad_token_id,
+            )
+        return embed_tokens
+
     def get_input_embeddings(self):
         return self.embed_tokens

@@ -482,8 +555,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         trust_remote_code: bool = False,
         **kwargs,
     ):
-        from ...utils.rbln_quantization import prepare_model_for_quantization
-
         kwargs = cls.update_kwargs(kwargs)

         if config is None:
@@ -500,8 +571,16 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         with no_init_weights():
             model = AutoModelForCausalLM.from_config(config)

-
-
+        model = prepare_model_for_quantization(
+            model,
+            model_id,
+            kwargs.get("num_hidden_layers"),
+            use_auth_token=use_auth_token,
+            revision=revision,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            local_files_only=local_files_only,
+        )
         return model

     def __getattr__(self, __name: str) -> Any:
@@ -528,11 +607,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     def get_pytorch_model(
         cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None, **kwargs
     ) -> "PreTrainedModel":
-        if
-            rbln_config is not None
-            and "format" in rbln_config.quantization
-            and rbln_config.quantization["format"] == "rbln"
-        ):
+        if rbln_config and rbln_config.quantization:
             model = cls.get_quantized_model(*args, **kwargs)
         else:
             model = super().get_pytorch_model(*args, **kwargs)
@@ -548,6 +623,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             "kvcache_block_size": rbln_config.kvcache_block_size,
             "use_rotary_emb": cls._use_rotary_emb,
             "use_attention_mask": rbln_config.use_attention_mask,
+            "use_position_ids": rbln_config.use_position_ids,
+            "use_inputs_embeds": rbln_config.use_inputs_embeds,
         }
         return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()

@@ -572,9 +649,10 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             static_tensors[name] = tensor
             context.mark_static_address(tensor)

-
-        def compile_model(wrapped_model, compile_config, example_inputs, compile_context, **kwargs):
+        def compile_model(wrapped_model, compile_config, example_inputs, compile_context, quantization):
             try:
+                if quantization:
+                    quantization.maybe_set_quantization_env()
                 original_linear = torch.nn.functional.linear
                 torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
                 compiled_model = RBLNModel.compile(
@@ -586,14 +664,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 return compiled_model
             finally:
                 torch.nn.functional.linear = original_linear
+                if quantization:
+                    quantization.maybe_reset_quantization_env()

         wrapped_model.phase = "prefill"
         compiled_prefill = compile_model(
-            wrapped_model,
-            prefill_compile_config,
-            prefill_example_inputs,
-            context,
-            quantize_config=rbln_config.quantization,
+            wrapped_model, prefill_compile_config, prefill_example_inputs, context, rbln_config.quantization
         )

         wrapped_model.phase = "decode"
@@ -601,11 +677,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_compile_configs[1:]):
             dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
             compiled_decoder = compile_model(
-                wrapped_model,
-                dec_compile_config,
-                dec_example_inputs,
-                context,
-                quantize_config=rbln_config.quantization,
+                wrapped_model, dec_compile_config, dec_example_inputs, context, rbln_config.quantization
             )
             compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder

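The refactored `compile_model` helper above now threads the quantization config through as an explicit argument and toggles a quantization environment around compilation, restoring state in `finally`. A minimal sketch of that same try/finally pattern, using a stand-in object rather than the real RBLN quantization config (the class and function names below are hypothetical):

```python
class DummyQuantization:
    """Stand-in for the quantization config object used in the diff."""

    def maybe_set_quantization_env(self):
        print("quantization env set")

    def maybe_reset_quantization_env(self):
        print("quantization env reset")


def compile_with_quantization(compile_fn, quantization=None):
    # Mirror the structure of `compile_model` in the diff: set the env (if any),
    # run the compilation, and always restore state afterwards.
    try:
        if quantization:
            quantization.maybe_set_quantization_env()
        return compile_fn()
    finally:
        if quantization:
            quantization.maybe_reset_quantization_env()


compile_with_quantization(lambda: "compiled", quantization=DummyQuantization())
```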
@@ -763,6 +835,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         query_length: int,
         use_inputs_embeds: bool,
         use_attention_mask: bool,
+        use_position_ids: bool,
         max_seq_len: int,
         kvcache_block_size: int,
         kvcache_num_blocks: int,
@@ -785,26 +858,27 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             ),
         ]

-
+        max_block_cnt = max_seq_len // kvcache_block_size
+
+        if query_length > 1:
+            input_info.extend([("block_tables", [max_block_cnt], "int16")])
+        else:
+            input_info.extend([("block_tables", [batch_size, max_block_cnt], "int16")])
+
+        if query_length > 1:
             input_info.extend(
                 [
-                    ("
+                    ("query_position", [], "int16"),
                 ]
             )
-
-        if query_length > 1:
+        if use_attention_mask:
             input_info.extend(
                 [
-                    ("
+                    ("attention_mask", [batch_size, 1, query_length, max_seq_len], "float32"),
                 ]
             )
-
-
-
-        if query_length > 1:
-            input_info.extend([("block_tables", [max_block_cnt], "int16")])
-        else:
-            input_info.extend([("block_tables", [batch_size, max_block_cnt], "int16")])
+        if use_position_ids:
+            input_info.append(("position_ids", [batch_size, query_length], "int32"))

         input_info.extend(
             [
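The reordered `input_info` construction above makes the block-table layout explicit: prefill graphs (query_length > 1) take a single block table of shape [max_block_cnt], while decode graphs take a per-batch table of shape [batch_size, max_block_cnt]. A small sketch of the shape arithmetic, with made-up example numbers:

```python
# Illustrative values only; not taken from the diff.
max_seq_len = 8192
kvcache_block_size = 1024
batch_size = 4

max_block_cnt = max_seq_len // kvcache_block_size  # 8 blocks cover the full sequence

prefill_block_tables_shape = [max_block_cnt]             # one request prefilled at a time
decode_block_tables_shape = [batch_size, max_block_cnt]  # every running request needs its own table

print(prefill_block_tables_shape, decode_block_tables_shape)  # [8] [4, 8]
```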
@@ -898,6 +972,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             query_length=rbln_config.prefill_chunk_size,
             use_inputs_embeds=rbln_config.use_inputs_embeds,
             use_attention_mask=rbln_config.use_attention_mask,
+            use_position_ids=rbln_config.use_position_ids,
             max_seq_len=rbln_config.max_seq_len,
             kvcache_block_size=rbln_config.kvcache_block_size,
             kvcache_num_blocks=rbln_config.kvcache_num_blocks,
@@ -916,6 +991,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             query_length=1,
             use_inputs_embeds=rbln_config.use_inputs_embeds,
             use_attention_mask=rbln_config.use_attention_mask,
+            use_position_ids=rbln_config.use_position_ids,
             max_seq_len=rbln_config.max_seq_len,
             kvcache_block_size=rbln_config.kvcache_block_size,
             kvcache_num_blocks=rbln_config.kvcache_num_blocks,
@@ -977,6 +1053,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         generate_idx: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
+        padded_cache_lengths: Optional[torch.Tensor] = None,
         **kwargs,
     ):
         model_inputs = {}
@@ -984,13 +1061,17 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):

         if is_prefill_phase:
             generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
+            padded_cache_lengths = torch.zeros_like(generate_idx)
             cache_position = None
+            position_ids = None
         else:
             if inputs_embeds is not None:
-
+                # if `inputs_embeds` are passed, only use them in the 1st generation step for every prompt.
+                inputs_embeds = None

             input_ids = input_ids[:, -1:]
-
+            position_ids = generate_idx
+            cache_position = generate_idx + padded_cache_lengths if padded_cache_lengths is not None else generate_idx
             generate_idx = generate_idx + 1
             model_inputs.update({"input_ids": input_ids})

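Because chunked prefill may write padding into the KV cache, the decode-phase inputs prepared above now distinguish the logical position of the next token (`position_ids`, taken from `generate_idx`) from the physical cache slot to write (`cache_position`, shifted by `padded_cache_lengths`). A toy numeric sketch of that relationship; the values are made up:

```python
import torch

# Suppose a prompt of 300 real tokens was padded to 384 during chunked prefill.
generate_idx = torch.tensor([[300]])          # number of real tokens seen so far
padded_cache_lengths = torch.tensor([[84]])   # padding written into the KV cache by prefill

position_ids = generate_idx                           # logical position of the next token: 300
cache_position = generate_idx + padded_cache_lengths  # physical cache slot to write: 384

print(position_ids.item(), cache_position.item())  # 300 384
```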
@@ -1009,6 +1090,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 "attention_mask": attention_mask,
                 "cache_position": cache_position,
                 "generate_idx": generate_idx,
+                "position_ids": position_ids,
+                "padded_cache_lengths": padded_cache_lengths,
             }
         )

@@ -1022,6 +1105,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     ) -> Dict[str, Any]:
         # update generate_idx
         model_kwargs["generate_idx"] = outputs.generate_idx
+        model_kwargs["padded_cache_lengths"] = outputs.padded_cache_lengths

         return model_kwargs

@@ -1032,6 +1116,10 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         cache_position: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
         generate_idx: Optional[torch.Tensor] = None,
+        padded_cache_lengths: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        return_dict: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
         """
@@ -1045,18 +1133,18 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             logits = []
             inputs = inputs_embeds if inputs_embeds is not None else input_ids
             batch_size = inputs.shape[0]
-
             for b_idx in range(batch_size):
                 cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
-
+                output = self.prefill_decoder(
                     input_ids=inputs[b_idx : b_idx + 1] if inputs_embeds is None else None,
                     inputs_embeds=inputs[b_idx : b_idx + 1] if inputs_embeds is not None else None,
                     attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
                     cache_position=cache_position,
                     batch_idx=b_idx,
+                    token_type_ids=token_type_ids[b_idx : b_idx + 1] if token_type_ids is not None else None,
                 )
-
-
+                padded_cache_lengths[b_idx] += output.padded_cache_lengths
+                logits.append(output.logits)
             logits = torch.cat(logits, dim=0)
         # Decoder
         else:
@@ -1072,9 +1160,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 input_ids=input_ids,
                 inputs_embeds=inputs_embeds,
                 cache_position=cache_position,
-
+                position_ids=position_ids if self.rbln_config.use_position_ids else None,
+            ).logits

-
-        logits
-
-
+        if not return_dict:
+            return logits, generate_idx, padded_cache_lengths
+        else:
+            return RBLNDecoderOnlyOutput(
+                logits=logits, generate_idx=generate_idx, padded_cache_lengths=padded_cache_lengths
+            )
--- a/optimum/rbln/transformers/models/exaone/exaone_architecture.py
+++ b/optimum/rbln/transformers/models/exaone/exaone_architecture.py
@@ -41,7 +41,10 @@ class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
         for layer in causal_lm.transformer.h:
             if self.attn_impl == "eager":
                 new_self_attn = ExaoneAttention(
-                    layer.attn.attention,
+                    layer.attn.attention,
+                    self.use_attention_mask,
+                    kvcache_block_size=self.kvcache_block_size,
+                    use_position_ids=self.use_position_ids,
                 )
             elif self.attn_impl == "flash_attn":
                 new_self_attn = ExaoneFlashAttention(
@@ -49,6 +52,7 @@ class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
                     kvcache_partition_len=self.kvcache_partition_len,
                     use_attention_mask=self.use_attention_mask,
                     kvcache_block_size=self.kvcache_block_size,
+                    use_position_ids=self.use_position_ids,
                 )
             else:
                 raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
--- a/optimum/rbln/transformers/models/gemma/gemma_architecture.py
+++ b/optimum/rbln/transformers/models/gemma/gemma_architecture.py
@@ -34,7 +34,10 @@ class GemmaWrapper(DecoderOnlyWrapper):
         for layer in causal_lm.model.layers:
             if self.attn_impl == "eager":
                 new_self_attn = DecoderOnlyAttention(
-                    layer.self_attn,
+                    layer.self_attn,
+                    self.use_attention_mask,
+                    kvcache_block_size=self.kvcache_block_size,
+                    use_position_ids=self.use_position_ids,
                 )
             elif self.attn_impl == "flash_attn":
                 new_self_attn = DecoderOnlyFlashAttention(
@@ -42,6 +45,7 @@ class GemmaWrapper(DecoderOnlyWrapper):
                     kvcache_partition_len=self.kvcache_partition_len,
                     use_attention_mask=self.use_attention_mask,
                     kvcache_block_size=self.kvcache_block_size,
+                    use_position_ids=self.use_position_ids,
                 )
             else:
                 raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
--- /dev/null
+++ b/optimum/rbln/transformers/models/gemma3/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_gemma3 import RBLNGemma3ForCausalLMConfig, RBLNGemma3ForConditionalGenerationConfig
+from .modeling_gemma3 import RBLNGemma3ForCausalLM, RBLNGemma3ForConditionalGeneration
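The new gemma3 package exports the RBLN Gemma3 classes, and the `__init__.py` and auto-model changes listed in the file summary suggest they are re-exported at the package root. A hedged usage sketch in the usual optimum-rbln style; the model id and the `from_pretrained(..., export=True)` compile-on-load flow are assumptions, not taken from this diff:

```python
# Sketch only: assumes the classes are importable from the package root,
# as the listed __init__.py changes suggest.
from optimum.rbln import RBLNGemma3ForCausalLM

model = RBLNGemma3ForCausalLM.from_pretrained(
    "google/gemma-3-4b-it",  # hypothetical model id
    export=True,             # compile the checkpoint for RBLN NPUs
)
```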