optimum-rbln 0.8.2a0__py3-none-any.whl → 0.9.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +116 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +171 -43
- optimum/rbln/diffusers/__init__.py +19 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +12 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +33 -18
- optimum/rbln/diffusers/models/__init__.py +4 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +32 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +32 -6
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +32 -3
- optimum/rbln/diffusers/models/controlnet.py +16 -1
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +26 -3
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +15 -5
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +23 -12
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +16 -46
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +50 -24
- optimum/rbln/modeling_base.py +116 -35
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +100 -0
- optimum/rbln/transformers/configuration_generic.py +7 -32
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +48 -65
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +93 -30
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/__init__.py +2 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
- optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
- optimum/rbln/transformers/models/clip/configuration_clip.py +21 -7
- optimum/rbln/transformers/models/clip/modeling_clip.py +183 -27
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -316
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +486 -892
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -14
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +212 -504
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +21 -6
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +60 -13
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +22 -16
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
- optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +32 -5
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +22 -50
- optimum/rbln/utils/runtime_utils.py +85 -17
- optimum/rbln/utils/submodule.py +31 -9
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
- optimum_rbln-0.9.3.dist-info/RECORD +264 -0
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
- optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.8.2a0.dist-info/RECORD +0 -211
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py (new file)

@@ -0,0 +1,165 @@
+import math
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+from transformers import PreTrainedModel
+
+from ..decoderonly.decoderonly_architecture import (
+    DecoderOnlyWrapper,
+    apply_rotary_pos_emb,
+)
+
+
+class Qwen2VisionTransformerWrapper(nn.Module):
+    def __init__(self, model: torch.nn.Module):
+        super().__init__()
+        self._original_mod = model
+        self.merger = model.merger
+        self.blocks = self.wrap_vision_blocks(model.blocks)
+
+    def wrap_vision_blocks(self, blocks: torch.nn.ModuleList):
+        wrapped_blocks = []
+        for i, block in enumerate(blocks):
+            wrapped_blocks.append(Qwen2VLVisionBlock(block))
+        return nn.ModuleList(wrapped_blocks)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        full_attn_masks: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+    ):
+        full_attn_masks = (1 - full_attn_masks) * torch.finfo(torch.float32).min
+
+        for block in self.blocks:
+            hidden_states = block(hidden_states, full_attn_masks, [cos, sin])
+
+        return self.merger(hidden_states)
+
+
+class Qwen2VLVisionBlock(torch.nn.Module):
+    def __init__(self, model: torch.nn.Module):
+        super().__init__()
+        self._origin_model = model
+        self.norm1 = model.norm1
+        self.norm2 = model.norm2
+
+        self.attn = VisionAttention(model.attn)
+        self.mlp = model.mlp
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attn_masks: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+    ) -> torch.Tensor:
+        hidden_states = hidden_states + self.attn(
+            self.norm1(hidden_states),
+            attn_masks,
+            position_embeddings,
+        )
+        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+        return hidden_states
+
+
+class VisionAttention(nn.Module):
+    def __init__(self, model: nn.Module) -> None:
+        super().__init__()
+        self._origin_model = model
+        self.num_heads = model.num_heads
+        self.head_dim = getattr(model, "head_dim", model.proj.in_features // model.num_heads)
+        self.qkv = model.qkv
+        self.proj = model.proj
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attn_masks: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+    ) -> torch.Tensor:
+        seq_length = hidden_states.shape[0]
+        hidden_states = hidden_states.unsqueeze(0)
+        q, k, v = (
+            self.qkv(hidden_states).reshape(1, seq_length, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4).unbind(0)
+        )
+
+        cos, sin = position_embeddings
+        q, k = apply_rotary_pos_emb(q, k, cos, sin)
+
+        attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
+        attn_weights = attn_weights + attn_masks
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
+        attn_output = torch.matmul(attn_weights, v)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(1, seq_length, -1)
+        attn_output = self.proj(attn_output).squeeze(0)
+
+        return attn_output
+
+
+class Qwen2VL_LanguageModelWrapper(DecoderOnlyWrapper):
+    def prepare_forward_args(self, *args):
+        args = list(args)
+        input_ids = None if self.rbln_config.use_inputs_embeds else args.pop(0)
+        inputs_embeds = args.pop(0) if self.rbln_config.use_inputs_embeds else None
+        cache_position = args.pop(0)
+        global_block_tables = args.pop(0)
+        local_block_tables = None
+        position_embeds = args.pop(0)
+        query_position = args.pop(0) if self.phase == "prefill" else None
+        position_ids = None
+        attention_mask = args.pop(0) if self.rbln_config.use_attention_mask else None
+        lora_int_id = args.pop(0) if self.rbln_config.lora_config else None
+        past_key_values = args
+
+        if len(past_key_values) != 2 * self.num_hidden_layers:
+            raise ValueError(
+                f"Different past_key_values to model's config. {len(past_key_values)} != {2 * self.num_hidden_layers}"
+            )
+
+        # [key, value] * n_layer -> ( (key, value) ) * n_layer
+        # cache shape : batch, n_heads, 1, max_seq_len, head_dim
+        _past_key_values = []
+        for i in range(self.config.num_hidden_layers):
+            key_states = past_key_values[i * 2]
+            value_states = past_key_values[i * 2 + 1]
+            past_key_value = [key_states, value_states]
+            _past_key_values.append(past_key_value)
+        past_key_values = _past_key_values
+
+        return (
+            input_ids,
+            inputs_embeds,
+            cache_position,
+            global_block_tables,
+            local_block_tables,
+            query_position,
+            attention_mask,
+            position_ids,
+            lora_int_id,
+            past_key_values,
+            position_embeds,
+        )
+
+    def convert_to_rbln_class(self, model: PreTrainedModel, max_seq_len: int):
+        new_layers = []
+
+        for layer_idx, layer in enumerate(model.model.language_model.layers):
+            is_sliding = layer_idx in self.rbln_config.sliding_window_layers
+            new_self_attn = self.get_rbln_attn_class()(
+                self.get_attn_layer(layer), self.rbln_config, is_sliding=is_sliding
+            )
+            new_layer = self.get_rbln_layer_class()(layer, new_self_attn)
+            new_layers.append(new_layer)
+
+        new_model = self.get_rbln_model_class()(
+            model.model.language_model,
+            new_layers,
+            self.rbln_config,
+            use_learned_pos_emb=self.__class__._use_learned_pos_emb,
+        )
+
+        new_model = self.get_rbln_causal_lm_class()(model.model, new_model)
+        return new_model
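A note on the mask convention above: `Qwen2VisionTransformerWrapper.forward` turns a 0/1 keep-mask into an additive bias via `(1 - mask) * torch.finfo(torch.float32).min`, which `VisionAttention` then simply adds to the attention scores before the softmax. A minimal standalone check (plain PyTorch, no RBLN dependencies; the shapes and the causal mask are illustrative) that this additive formulation matches a boolean-masked reference:

```python
import math

import torch

# Illustrative shapes: 1 batch, 2 heads, 5 tokens, head_dim 8.
q, k, v = (torch.randn(1, 2, 5, 8) for _ in range(3))
keep_mask = torch.tril(torch.ones(5, 5))  # 1 = attend, 0 = masked

# Additive bias, as built in Qwen2VisionTransformerWrapper.forward.
add_mask = (1 - keep_mask) * torch.finfo(torch.float32).min

# Manual attention, as in VisionAttention.forward.
scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(8) + add_mask
out_manual = torch.matmul(torch.softmax(scores, dim=-1), v)

# Reference: PyTorch's fused kernel with the boolean form of the same mask.
out_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=keep_mask.bool())
assert torch.allclose(out_manual, out_ref, atol=1e-5)
```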
optimum/rbln/transformers/models/qwen3/__init__.py (new file)

@@ -0,0 +1,16 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_qwen3 import RBLNQwen3ForCausalLMConfig, RBLNQwen3ModelConfig
+from .modeling_qwen3 import RBLNQwen3ForCausalLM, RBLNQwen3Model
optimum/rbln/transformers/models/qwen3/configuration_qwen3.py (new file)

@@ -0,0 +1,71 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
+
+
+class RBLNQwen3ForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+    """
+    Configuration class for RBLN Qwen3 models.
+
+    This class is an alias of RBLNDecoderOnlyModelForCausalLMConfig.
+
+    Example usage:
+    ```python
+    from optimum.rbln import RBLNQwen3ForCausalLM, RBLNQwen3ForCausalLMConfig
+
+    # Create a configuration object
+    config = RBLNQwen3ForCausalLMConfig(
+        batch_size=1,
+        max_seq_len=40960,
+        tensor_parallel_size=4,
+        kvcache_partition_len=16384
+    )
+
+    # Use the configuration with from_pretrained
+    model = RBLNQwen3ForCausalLM.from_pretrained(
+        "Qwen/Qwen3-4B",
+        export=True,
+        rbln_config=config
+    )
+    ```
+    """
+
+
+class RBLNQwen3ModelConfig(RBLNDecoderOnlyModelConfig):
+    """
+    Configuration class for RBLN Qwen3 models.
+
+    This class is an alias of RBLNDecoderOnlyModelConfig.
+
+    Example usage:
+    ```python
+    from optimum.rbln import RBLNQwen3Model, RBLNQwen3ModelConfig
+
+    # Create a configuration object
+    config = RBLNQwen3ModelConfig(
+        batch_size=1,
+        max_seq_len=40960,
+        tensor_parallel_size=4,
+        kvcache_partition_len=16384
+    )
+
+    # Use the configuration with from_pretrained
+    model = RBLNQwen3Model.from_pretrained(
+        "Qwen/Qwen3-Embedding-4B",
+        export=True,
+        rbln_config=config
+    )
+    ```
+    """
optimum/rbln/transformers/models/qwen3/modeling_qwen3.py (new file)

@@ -0,0 +1,133 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from transformers import PretrainedConfig
+
+from ....utils import logging
+from ...models.decoderonly import (
+    RBLNDecoderOnlyModel,
+    RBLNDecoderOnlyModelForCausalLM,
+    RBLNDecoderOnlyModelForCausalLMConfig,
+)
+from .qwen3_architecture import Qwen3Wrapper
+
+
+logger = logging.get_logger(__name__)
+
+if TYPE_CHECKING:
+    from transformers import PretrainedConfig
+
+
+class RBLNQwen3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
+    """
+    The Qwen3 Model transformer with a language modeling head (linear layer) on top.
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
+    A class to convert and run pre-trained transformers based Qwen3ForCausalLM model on RBLN devices.
+    It implements the methods to convert a pre-trained transformers Qwen3ForCausalLM model into a RBLN transformer model by:
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+    **Configuration:**
+    This model uses [`RBLNQwen3ForCausalLMConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+    the `rbln_config` parameter should be an instance of [`RBLNQwen3ForCausalLMConfig`] or a dictionary conforming to its structure.
+    See the [`RBLNQwen3ForCausalLMConfig`] class for all available configuration options.
+    Examples:
+    ```python
+    from optimum.rbln import RBLNQwen3ForCausalLM
+    # Simple usage using rbln_* arguments
+    # `max_seq_len` is automatically inferred from the model config
+    model = RBLNQwen3ForCausalLM.from_pretrained(
+        "Qwen/Qwen3-4B",
+        export=True,
+        rbln_batch_size=1,
+        rbln_tensor_parallel_size=4,
+    )
+    # Using a config dictionary
+    rbln_config = {
+        "batch_size": 1,
+        "max_seq_len": 40_960,
+        "tensor_parallel_size": 4,
+        "kvcache_partition_len": 8192,
+    }
+    model = RBLNQwen3ForCausalLM.from_pretrained(
+        "Qwen/Qwen3-4B",
+        export=True,
+        rbln_config=rbln_config
+    )
+    # Using a RBLNQwen3ForCausalLMConfig instance (recommended for type checking)
+    from optimum.rbln import RBLNQwen3ForCausalLMConfig
+    config = RBLNQwen3ForCausalLMConfig(
+        batch_size=1,
+        max_seq_len=40_960,
+        tensor_parallel_size=4,
+        kvcache_partition_len=8192,
+    )
+    model = RBLNQwen3ForCausalLM.from_pretrained(
+        "Qwen/Qwen3-4B",
+        export=True,
+        rbln_config=config
+    )
+    ```
+    """
+
+    _decoder_wrapper_cls = Qwen3Wrapper
+
+    @classmethod
+    def _update_sliding_window_config(
+        cls, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
+    ):
+        # https://github.com/huggingface/transformers/issues/35896
+        # There seems to be a bug in transformers(v4.52.4). Therefore, similar to when attn_implementation is eager,
+        # we set all layers to use sliding window in this version. This should be updated once the bug is fixed.
+
+        rbln_config.cache_impl = "sliding_window"
+        rbln_config.sliding_window = model_config.sliding_window
+        rbln_config.sliding_window_layers = list(range(model_config.num_hidden_layers))
+        return rbln_config
+
+    def forward(self, *args, **kwargs):
+        kwargs["return_dict"] = True
+        return super().forward(*args, **kwargs)
+
+
+class RBLNQwen3Model(RBLNDecoderOnlyModel):
+    """
+    The bare Qwen3 Model outputting raw hidden-states without any specific head on top.
+    This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+    A class to convert and run pre-trained transformers based Qwen3Model on RBLN devices.
+    It implements the methods to convert a pre-trained transformers Qwen3Model into a RBLN transformer model by:
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+    **Configuration:**
+    This model uses [`RBLNQwen3ModelConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+    the `rbln_config` parameter should be an instance of [`RBLNQwen3ModelConfig`] or a dictionary conforming to its structure.
+    See the [`RBLNQwen3ModelConfig`] class for all available configuration options.
+    Examples:
+    ```python
+    from optimum.rbln import RBLNQwen3Model
+    # Simple usage using rbln_* arguments
+    # `max_seq_len` is automatically inferred from the model config
+    model = RBLNQwen3Model.from_pretrained(
+        "Qwen/Qwen3-Embedding-4B",
+        export=True,
+        rbln_batch_size=1,
+        rbln_max_seq_len=40_960,
+        rbln_tensor_parallel_size=4,
+        rbln_kvcache_partition_len=8192,
+    )
+    ```
+    """
+
+    _decoder_wrapper_cls = Qwen3Wrapper
+    _use_rotary_emb = True
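For completeness, a hedged end-to-end sketch of running the compiled causal LM. Generation follows the standard `transformers` API that `RBLNDecoderOnlyModelForCausalLM` inherits (a dedicated `generation_decoderonly.py` appears in the file list above); the prompt and decoding arguments here are illustrative, not taken from this diff:

```python
from transformers import AutoTokenizer

from optimum.rbln import RBLNQwen3ForCausalLM

# Compile once (export=True), as in the docstring examples above.
model = RBLNQwen3ForCausalLM.from_pretrained(
    "Qwen/Qwen3-4B",
    export=True,
    rbln_batch_size=1,
    rbln_tensor_parallel_size=4,
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")

# Run inference on the compiled graph via the usual generate() API.
inputs = tokenizer("What is paged attention?", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```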
optimum/rbln/transformers/models/qwen3/qwen3_architecture.py (new file)

@@ -0,0 +1,31 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ..decoderonly.decoderonly_architecture import DecoderOnlyAttention, DecoderOnlyWrapper
+
+
+class Qwen3Wrapper(DecoderOnlyWrapper):
+    def get_rbln_attn_class(self):
+        return Qwen3Attention
+
+
+class Qwen3Attention(DecoderOnlyAttention):
+    def __post_init__(self):
+        self.k_proj = self._original_mod.k_proj
+        self.v_proj = self._original_mod.v_proj
+        self.q_proj = self._original_mod.q_proj
+        self.o_proj = self._original_mod.o_proj
+        self.q_norm = self._original_mod.q_norm
+        self.k_norm = self._original_mod.k_norm
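`Qwen3Attention.__post_init__` rebinds `q_norm`/`k_norm` because Qwen3, unlike Qwen2, applies an RMSNorm to the query and key states of every head before the rotary embedding; the base `DecoderOnlyAttention` that consumes these modules lives in decoderonly_architecture.py, outside this excerpt. A standalone illustration of that per-head QK-norm, independent of the RBLN classes:

```python
import torch
from torch import nn


class RMSNorm(nn.Module):
    # Standard RMSNorm over the last dimension, as Qwen3 uses for q_norm/k_norm.
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(variance + self.eps) * self.weight


head_dim = 128
q_norm, k_norm = RMSNorm(head_dim), RMSNorm(head_dim)

# (batch, num_heads, seq_len, head_dim): the norm acts on head_dim, per head.
q = q_norm(torch.randn(1, 8, 16, head_dim))
k = k_norm(torch.randn(1, 8, 16, head_dim))
# ...rotary embeddings and attention proceed as usual from here.
```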
optimum/rbln/transformers/models/resnet/configuration_resnet.py

@@ -13,6 +13,8 @@
 # limitations under the License.
 
 
+from typing import Optional
+
 from ...configuration_generic import RBLNModelForImageClassificationConfig
 
 
@@ -23,3 +25,18 @@ class RBLNResNetForImageClassificationConfig(RBLNModelForImageClassificationConfig):
     This configuration class stores the configuration parameters specific to
     RBLN-optimized ResNet models for image classification tasks.
     """
+
+    def __init__(self, output_hidden_states: Optional[bool] = None, **kwargs):
+        """
+        Args:
+            image_size (Optional[Union[int, Tuple[int, int]]]): The size of input images.
+                Can be an integer for square images or a tuple (height, width).
+            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+            output_hidden_states (bool, optional): Whether or not to return the hidden states of all layers.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+        self.output_hidden_states = output_hidden_states
optimum/rbln/transformers/models/resnet/modeling_resnet.py

@@ -13,7 +13,17 @@
 # limitations under the License.
 
 
+from typing import TYPE_CHECKING, Optional, Tuple, Union
+
+import torch
+from transformers.modeling_outputs import ImageClassifierOutputWithNoAttention
+
 from ...modeling_generic import RBLNModelForImageClassification
+from .configuration_resnet import RBLNResNetForImageClassificationConfig
+
+
+if TYPE_CHECKING:
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
 
 
 class RBLNResNetForImageClassification(RBLNModelForImageClassification):
@@ -24,3 +34,66 @@ class RBLNResNetForImageClassification(RBLNModelForImageClassification):
     on RBLN devices, supporting image classification with convolutional neural networks
     designed for computer vision tasks.
     """
+
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+        model: Optional["PreTrainedModel"] = None,
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_config: Optional["RBLNResNetForImageClassificationConfig"] = None,
+    ) -> "RBLNResNetForImageClassificationConfig":
+        if rbln_config.output_hidden_states is None:
+            rbln_config.output_hidden_states = getattr(model_config, "output_hidden_states", False)
+
+        rbln_config = super()._update_rbln_config(
+            preprocessors=preprocessors,
+            model=model,
+            model_config=model_config,
+            rbln_config=rbln_config,
+        )
+
+        return rbln_config
+
+    @classmethod
+    def _wrap_model_if_needed(
+        cls, model: torch.nn.Module, rbln_config: "RBLNResNetForImageClassificationConfig"
+    ) -> torch.nn.Module:
+        class _ResNetForImageClassification(torch.nn.Module):
+            def __init__(self, model: torch.nn.Module, output_hidden_states: bool):
+                super().__init__()
+                self.model = model
+                self.output_hidden_states = output_hidden_states
+
+            def forward(self, *args, **kwargs):
+                output = self.model(*args, output_hidden_states=self.output_hidden_states, **kwargs)
+                return output
+
+        return _ResNetForImageClassification(model, rbln_config.output_hidden_states)
+
+    def forward(
+        self, pixel_values: torch.Tensor, output_hidden_states: bool = None, return_dict: bool = None, **kwargs
+    ) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
+        """
+        Forward pass for the RBLN-optimized ResNet model for image classification.
+
+        Args:
+            pixel_values (torch.FloatTensor of shape (batch_size, channels, height, width)): The tensors corresponding to the input images.
+            output_hidden_states (bool, *optional*, defaults to False): Whether or not to return the hidden states of all layers.
+                See hidden_states under returned tensors for more details.
+            return_dict (bool, *optional*, defaults to True): Whether to return a dictionary of outputs.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns an ImageClassifierOutputWithNoAttention object.
+        """
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
+        )
+
+        if output_hidden_states != self.rbln_config.output_hidden_states:
+            raise ValueError(
+                f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states}. "
+                f"Please compile again with the correct argument."
+            )
+
+        return super().forward(pixel_values=pixel_values, return_dict=return_dict, **kwargs)
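Because `_wrap_model_if_needed` bakes `output_hidden_states` into the traced graph, the choice is fixed at compile time and the `forward` override can only validate it. A hedged usage sketch (the `microsoft/resnet-50` checkpoint and the dummy input are illustrative; the `rbln_*` kwarg spelling follows the pattern shown in the Qwen3 examples above):

```python
import torch

from optimum.rbln import RBLNResNetForImageClassification

# Compile with hidden states enabled; this is fixed into the RBLN graph.
model = RBLNResNetForImageClassification.from_pretrained(
    "microsoft/resnet-50",           # illustrative checkpoint
    export=True,
    rbln_output_hidden_states=True,
)

pixel_values = torch.randn(1, 3, 224, 224)

# Matches the compiled setting (or omit the argument to use the default).
outputs = model(pixel_values, output_hidden_states=True)

# A mismatch with the compiled setting raises ValueError at runtime.
try:
    model(pixel_values, output_hidden_states=False)
except ValueError as err:
    print(err)
```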
optimum/rbln/transformers/models/roberta/modeling_roberta.py

@@ -12,6 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Tuple, Union
+
+import torch
+from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput
+
 from ...modeling_generic import RBLNModelForMaskedLM, RBLNModelForSequenceClassification
 
 
@@ -26,6 +31,19 @@ class RBLNRobertaForMaskedLM(RBLNModelForMaskedLM):
 
     rbln_model_input_names = ["input_ids", "attention_mask"]
 
+    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Union[Tuple, MaskedLMOutput]:
+        """
+        Forward pass for the RBLN-optimized RoBERTa model for masked language modeling tasks.
+
+        Args:
+            input_ids (torch.LongTensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
+            attention_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a MaskedLMOutput object.
+        """
+        return super().forward(input_ids, attention_mask, **kwargs)
+
 
 class RBLNRobertaForSequenceClassification(RBLNModelForSequenceClassification):
     """
@@ -37,3 +55,18 @@ class RBLNRobertaForSequenceClassification(RBLNModelForSequenceClassification):
     """
 
     rbln_model_input_names = ["input_ids", "attention_mask"]
+
+    def forward(
+        self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs
+    ) -> Union[Tuple, SequenceClassifierOutput]:
+        """
+        Forward pass for the RBLN-optimized RoBERTa model for sequence classification tasks.
+
+        Args:
+            input_ids (torch.LongTensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
+            attention_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a SequenceClassifierOutput object.
+        """
+        return super().forward(input_ids, attention_mask, **kwargs)
optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py

@@ -12,11 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any,
-
-import rebel
+from typing import Any, Optional
 
 from ....configuration_utils import RBLNModelConfig
+from ....utils.deprecation import deprecate_kwarg
 from ....utils.logging import get_logger
 
 
@@ -24,14 +23,18 @@ logger = get_logger()
 
 
 class RBLNModelForSeq2SeqLMConfig(RBLNModelConfig):
+    support_paged_attention = None
+
+    @deprecate_kwarg(old_name="pad_token_id", version="0.10.0")
     def __init__(
         self,
         batch_size: Optional[int] = None,
         enc_max_seq_len: Optional[int] = None,
         dec_max_seq_len: Optional[int] = None,
         use_attention_mask: Optional[bool] = None,
-
-
+        kvcache_num_blocks: Optional[int] = None,
+        kvcache_block_size: Optional[int] = None,
+        **kwargs: Any,
     ):
         """
         Args:
@@ -39,9 +42,11 @@
             enc_max_seq_len (Optional[int]): Maximum sequence length for the encoder.
             dec_max_seq_len (Optional[int]): Maximum sequence length for the decoder.
             use_attention_mask (Optional[bool]): Whether to use attention masks during inference.
-
-
-
+            kvcache_num_blocks (Optional[int]): The total number of blocks to allocate for the
+                PagedAttention KV cache for the SelfAttention. Defaults to batch_size.
+            kvcache_block_size (Optional[int]): Sets the size (in number of tokens) of each block
+                in the PagedAttention KV cache for the SelfAttention. Defaults to dec_max_seq_len.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
 
         Raises:
             ValueError: If batch_size is not a positive integer.
@@ -55,12 +60,12 @@
         self.dec_max_seq_len = dec_max_seq_len
 
         self.use_attention_mask = use_attention_mask
-        npu = self.npu or rebel.get_npu_name()
-        if npu == "RBLN-CA02":
-            if self.use_attention_mask is False:
-                logger.warning("Attention mask should be used with RBLN-CA02. Setting use_attention_mask to True.")
-            self.use_attention_mask = True
-        else:
-            self.use_attention_mask = self.use_attention_mask or False
 
-        self.
+        if self.support_paged_attention:
+            self.kvcache_num_blocks = kvcache_num_blocks
+            self.kvcache_block_size = kvcache_block_size
+        else:
+            if kvcache_num_blocks is not None or kvcache_block_size is not None:
+                raise ValueError(
+                    "You cannot set kvcache_num_blocks or kvcache_block_size as paged attention is not supported for the model."
+                )
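The `@deprecate_kwarg` decorator used above is defined in optimum/rbln/utils/deprecation.py (+213 lines, not included in this excerpt). Inferring only from the call site, which passes `old_name` and `version`, a minimal sketch of what such a helper typically looks like; the real implementation may differ:

```python
import functools
import warnings
from typing import Optional


def deprecate_kwarg(old_name: str, version: str, new_name: Optional[str] = None):
    """Warn when a deprecated keyword is passed; optionally remap it.

    Hypothetical reimplementation for illustration only; the actual helper
    lives in optimum/rbln/utils/deprecation.py and may behave differently.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if old_name in kwargs:
                value = kwargs.pop(old_name)
                if new_name is not None:
                    kwargs[new_name] = value  # remap to the replacement kwarg
                warnings.warn(
                    f"`{old_name}` is deprecated and will be removed in v{version}.",
                    FutureWarning,
                )
            return func(*args, **kwargs)

        return wrapper

    return decorator
```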