optimum-rbln 0.7.5a1__py3-none-any.whl → 0.7.5rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +10 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +1 -2
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +1 -2
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +1 -2
- optimum/rbln/modeling.py +53 -9
- optimum/rbln/modeling_base.py +22 -3
- optimum/rbln/transformers/__init__.py +10 -0
- optimum/rbln/transformers/modeling_generic.py +0 -19
- optimum/rbln/transformers/models/__init__.py +14 -0
- optimum/rbln/transformers/models/auto/__init__.py +1 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +7 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +114 -19
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +29 -10
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
- optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +69 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +446 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +1057 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +2 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +4 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +3 -2
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +0 -3
- {optimum_rbln-0.7.5a1.dist-info → optimum_rbln-0.7.5rc1.dist-info}/METADATA +1 -1
- {optimum_rbln-0.7.5a1.dist-info → optimum_rbln-0.7.5rc1.dist-info}/RECORD +31 -27
- {optimum_rbln-0.7.5a1.dist-info → optimum_rbln-0.7.5rc1.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.5a1.dist-info → optimum_rbln-0.7.5rc1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py
@@ -167,6 +167,8 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         block_tables: Optional[torch.Tensor] = None,
         position_embed: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        local_block_tables: Optional[torch.Tensor] = None,
     ):
         if input_ids is None and inputs_embeds is None:
             raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
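The two new arguments are optional tensors threaded through the runtime's `forward`. Below is a minimal sketch of a call site, assuming a compiled `RBLNRuntimeModel` instance; the shapes, dtypes, and the token-type convention are illustrative assumptions, not taken from this diff:

```python
import torch

def prefill_with_new_kwargs(runtime, seq_len: int = 8):
    # `runtime` is assumed to be an instantiated RBLNRuntimeModel.
    input_ids = torch.ones(1, seq_len, dtype=torch.int64)        # (batch, seq)
    token_type_ids = torch.zeros(1, seq_len, dtype=torch.int64)  # assumed: 0 = text, nonzero = image
    local_block_tables = torch.zeros(1, 1, dtype=torch.int16)    # assumed shape for the local KV-cache table
    # Both new kwargs default to None, so existing call sites are unaffected.
    return runtime(
        input_ids=input_ids,
        cache_position=torch.arange(seq_len, dtype=torch.int32).unsqueeze(0),
        token_type_ids=token_type_ids,
        local_block_tables=local_block_tables,
    )
```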
@@ -193,6 +195,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 attention_mask=attention_mask,
                 position_embed=position_embed,
                 position_ids=position_ids,
+                local_block_tables=local_block_tables,
             )
         else:
             return self.prefill_forward(
@@ -202,6 +205,8 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 batch_idx,
                 block_tables,
                 position_embed=position_embed,
+                token_type_ids=token_type_ids,
+                local_block_tables=local_block_tables,
             )

     def decode_forward(
@@ -213,6 +218,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         attention_mask: Optional[torch.Tensor] = None,
         position_embed: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
+        local_block_tables: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         batch_size = inputs.shape[0]
         if batch_size != self.batch_size:
@@ -262,6 +268,8 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         cache_position: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         position_embed: Optional[torch.Tensor] = None,
+        local_block_tables: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
     ):
         """
         Prepare inputs for prefill phase.
@@ -345,6 +353,8 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         block_tables: torch.Tensor = None,
         is_external_block_tables: bool = None,
         position_embed: Optional[torch.Tensor] = None,
+        local_block_tables: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         """
         Performs chunked prefill for efficient KV-cache updates and memory optimization.
@@ -360,7 +370,9 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             position_embed,
             padded_cache_lengths,
             query_length,
-        ) = self._prepare_prefill_inputs(inputs, cache_position, attention_mask, position_embed)
+        ) = self._prepare_prefill_inputs(
+            inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
+        )

         # Process input in chunks of size `prefill_chunk_size`
         for step in range(0, query_length, self.prefill_chunk_size):
@@ -373,7 +385,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             if position_embed is not None:
                 position_embed_chunk = position_embed[:, :, :, step : step + self.prefill_chunk_size, :]

-            if self.use_attention_mask:
+            if self.use_attention_mask and not self.use_position_ids:
                 # Update attention mask to ensure proper causal behavior
                 if step >= self.prefill_chunk_size:
                     chunked_attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
@@ -387,10 +399,10 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 input_chunk,
                 cache_pos_chunk,
                 block_tables,
+                position_embed_chunk if position_embed is not None else None,
                 query_position,
                 chunked_attention_mask if self.use_attention_mask else None,
-                position_ids_chunk if self.use_position_ids else None,
-                position_embed_chunk if position_embed is not None else None,
+                position_ids_chunk if self.use_position_ids else None,
                 out=out_buffers,
             )

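The reordered call above sits inside the chunked-prefill loop. As a standalone illustration, here is the stepping-and-mask pattern in isolation; the `prefill_chunk_size` stride and the mask update mirror the hunk, while everything else (padding, on-device KV-cache writes) is omitted:

```python
import torch

def iter_prefill_chunks(inputs_embeds: torch.Tensor, prefill_chunk_size: int = 256):
    """Simplified sketch of how the runtime walks a long prompt chunk by chunk."""
    query_length = inputs_embeds.shape[1]
    chunked_attention_mask = torch.zeros(1, 1, prefill_chunk_size, query_length)
    for step in range(0, query_length, prefill_chunk_size):
        chunk = inputs_embeds[:, step : step + prefill_chunk_size]
        if step >= prefill_chunk_size:
            # All positions before the current chunk become fully attendable.
            chunked_attention_mask[:, :, :, step - prefill_chunk_size : step] = 1
        yield chunk, chunked_attention_mask
```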
@@ -440,12 +452,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         if self.rbln_config.use_inputs_embeds:
             main_input_name = "inputs_embeds"
             artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
-
-            self.embed_tokens = torch.nn.Embedding(
-                self.config.vocab_size,
-                self.config.hidden_size,
-                self.config.pad_token_id,
-            )
+            self.embed_tokens = self._create_embedding_layer()
             self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
         else:
             self.embed_tokens = None
@@ -478,6 +485,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             attn_impl=self.rbln_config.attn_impl,
             use_position_ids=self.rbln_config.use_position_ids,
         )
+
         self.decoders = {}
         for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
             self.decoders[batch_size] = RBLNRuntimeModel(
@@ -515,6 +523,15 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
         torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")

+    def _create_embedding_layer(self):
+        with no_init_weights():
+            embed_tokens = torch.nn.Embedding(
+                self.config.vocab_size,
+                self.config.hidden_size,
+                self.config.pad_token_id,
+            )
+        return embed_tokens
+
     def get_input_embeddings(self):
         return self.embed_tokens

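The new `_create_embedding_layer` helper builds the embedding inside transformers' `no_init_weights` context manager, which turns the `nn.init` calls into no-ops; the random initialization would be wasted work anyway, since `load_state_dict` overwrites the weights immediately afterwards. The same idea in isolation (a sketch; `no_init_weights` is the helper transformers itself uses for fast model loading):

```python
import torch
from transformers.modeling_utils import no_init_weights

def build_embedding(vocab_size: int, hidden_size: int, pad_token_id: int) -> torch.nn.Embedding:
    with no_init_weights():
        # Allocated but not randomly initialized.
        embed_tokens = torch.nn.Embedding(vocab_size, hidden_size, pad_token_id)
    return embed_tokens

# The saved weights then fill the uninitialized buffer:
#   embed_tokens.load_state_dict(artifacts["embed_tokens"])
```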
@@ -1101,6 +1118,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         generate_idx: Optional[torch.Tensor] = None,
         padded_cache_lengths: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
         return_dict: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
@@ -1123,6 +1141,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                     attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
                     cache_position=cache_position,
                     batch_idx=b_idx,
+                    token_type_ids=token_type_ids[b_idx : b_idx + 1] if token_type_ids is not None else None,
                 )
                 padded_cache_lengths[b_idx] += output.padded_cache_lengths
                 logits.append(output.logits)
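Note the slicing style: `token_type_ids[b_idx : b_idx + 1]` keeps the batch dimension (shape `(1, seq)`), whereas plain indexing would drop it. A quick illustration:

```python
import torch

token_type_ids = torch.zeros(4, 16, dtype=torch.int64)  # (batch, seq)
print(token_type_ids[0].shape)    # torch.Size([16])     -- batch dim dropped
print(token_type_ids[0:1].shape)  # torch.Size([1, 16])  -- batch dim kept for per-sample prefill
```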
optimum/rbln/transformers/models/exaone/exaone_architecture.py
@@ -41,7 +41,10 @@ class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
         for layer in causal_lm.transformer.h:
             if self.attn_impl == "eager":
                 new_self_attn = ExaoneAttention(
-                    layer.attn.attention, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+                    layer.attn.attention,
+                    self.use_attention_mask,
+                    kvcache_block_size=self.kvcache_block_size,
+                    use_position_ids=self.use_position_ids,
                 )
             elif self.attn_impl == "flash_attn":
                 new_self_attn = ExaoneFlashAttention(
@@ -49,6 +52,7 @@ class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
                     kvcache_partition_len=self.kvcache_partition_len,
                     use_attention_mask=self.use_attention_mask,
                     kvcache_block_size=self.kvcache_block_size,
+                    use_position_ids=self.use_position_ids,
                 )
             else:
                 raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
optimum/rbln/transformers/models/gemma/gemma_architecture.py
@@ -34,7 +34,10 @@ class GemmaWrapper(DecoderOnlyWrapper):
         for layer in causal_lm.model.layers:
             if self.attn_impl == "eager":
                 new_self_attn = DecoderOnlyAttention(
-                    layer.self_attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+                    layer.self_attn,
+                    self.use_attention_mask,
+                    kvcache_block_size=self.kvcache_block_size,
+                    use_position_ids=self.use_position_ids,
                 )
             elif self.attn_impl == "flash_attn":
                 new_self_attn = DecoderOnlyFlashAttention(
@@ -42,6 +45,7 @@ class GemmaWrapper(DecoderOnlyWrapper):
                     kvcache_partition_len=self.kvcache_partition_len,
                     use_attention_mask=self.use_attention_mask,
                     kvcache_block_size=self.kvcache_block_size,
+                    use_position_ids=self.use_position_ids,
                 )
             else:
                 raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
optimum/rbln/transformers/models/gemma3/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_gemma3 import RBLNGemma3ForCausalLMConfig, RBLNGemma3ForConditionalGenerationConfig
+from .modeling_gemma3 import RBLNGemma3ForCausalLM, RBLNGemma3ForConditionalGeneration
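These two imports are the module's public surface; the `+10`-line changes to `optimum/rbln/__init__.py` and `optimum/rbln/transformers/__init__.py` in the file list above suggest the classes are re-exported at the top level. A hedged loading sketch, assuming the library's usual `from_pretrained(..., export=True)` compile path; the checkpoint id is a placeholder:

```python
from optimum.rbln import RBLNGemma3ForCausalLM

# Compile the PyTorch checkpoint for RBLN NPUs (hypothetical model id).
model = RBLNGemma3ForCausalLM.from_pretrained(
    "google/gemma-3-4b-it",
    export=True,
)
```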
optimum/rbln/transformers/models/gemma3/configuration_gemma3.py
@@ -0,0 +1,69 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from ....configuration_utils import RBLNModelConfig
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+from ..siglip.configuration_siglip import RBLNSiglipVisionModelConfig
+
+
+class RBLNGemma3ForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+    def __init__(
+        self,
+        prefill_chunk_size: Optional[int] = None,
+        use_position_ids: Optional[bool] = None,
+        use_attention_mask: Optional[bool] = None,
+        **kwargs,
+    ):
+        # use_attention_mask and use_position_ids are always True for Gemma3
+        use_attention_mask = use_attention_mask or True
+        use_position_ids = use_position_ids or True
+        prefill_chunk_size = prefill_chunk_size or 256
+
+        super().__init__(
+            prefill_chunk_size=prefill_chunk_size,
+            use_attention_mask=use_attention_mask,
+            use_position_ids=use_position_ids,
+            **kwargs,
+        )
+
+
+class RBLNGemma3ForConditionalGenerationConfig(RBLNModelConfig):
+    submodules = ["vision_tower", "language_model"]
+
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        vision_tower: Optional[RBLNModelConfig] = None,
+        language_model: Optional[RBLNModelConfig] = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+            vision_tower (Optional[RBLNModelConfig]): Configuration for the vision encoder component.
+            language_model (Optional[RBLNModelConfig]): Configuration for the language model component.
+            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+        self.vision_tower = self.init_submodule_config(RBLNSiglipVisionModelConfig, vision_tower)
+        self.language_model = self.init_submodule_config(RBLNGemma3ForCausalLMConfig, language_model)