optimum-rbln 0.9.4a2__py3-none-any.whl → 0.9.5a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +36 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +35 -16
- optimum/rbln/modeling_base.py +6 -6
- optimum/rbln/ops/__init__.py +1 -0
- optimum/rbln/ops/attn.py +10 -0
- optimum/rbln/ops/flash_attn.py +8 -0
- optimum/rbln/ops/moe.py +180 -0
- optimum/rbln/ops/sliding_window_attn.py +9 -0
- optimum/rbln/transformers/__init__.py +36 -0
- optimum/rbln/transformers/modeling_attention_utils.py +118 -222
- optimum/rbln/transformers/modeling_outputs.py +25 -0
- optimum/rbln/transformers/modeling_rope_utils.py +78 -42
- optimum/rbln/transformers/models/__init__.py +28 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -182
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -21
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +118 -16
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +121 -48
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +5 -7
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +75 -107
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
- optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
- optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
- optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +16 -18
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +1 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
- optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
- optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
- optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +6 -4
- optimum/rbln/transformers/models/llava/modeling_llava.py +0 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
- optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
- optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
- optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
- optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
- optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +271 -122
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
- optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
- optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
- optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +263 -105
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +26 -34
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
- optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
- optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
- optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +4 -18
- optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
- optimum/rbln/transformers/models/t5/t5_architecture.py +15 -16
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
- optimum/rbln/transformers/utils/rbln_quantization.py +20 -12
- optimum/rbln/utils/import_utils.py +16 -1
- optimum/rbln/utils/runtime_utils.py +10 -6
- optimum/rbln/utils/submodule.py +24 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +81 -62
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
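Judging from the file list, the headline changes in 0.9.5a4 are: new model families (Gemma2, GPT-OSS, PaliGemma, Qwen2-MoE, Qwen3-MoE) with matching configuration and architecture modules, new MoE custom ops (`optimum/rbln/ops/moe.py`), parallel refactors of the Qwen2-VL and Qwen2.5-VL stacks including a new backbone-level `RBLNQwen2_5_VLModel`, removal of the standalone ColQwen2 architecture file, and dtype/configuration plumbing across the decoder-only and vision models. The Phi and Qwen2.5-VL hunks below illustrate the two recurring refactors: attention wrappers now receive the original attention module explicitly, and the compute dtype comes from `rbln_config` instead of being hard-coded.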
optimum/rbln/transformers/models/phi/phi_architecture.py

@@ -20,7 +20,6 @@ from transformers import PhiForCausalLM
 from ..decoderonly.decoderonly_architecture import (
     DecoderOnlyAttention,
     DecoderOnlyLayer,
-    DecoderOnlyModel,
     DecoderOnlyWrapper,
     apply_rotary_pos_emb_partial,
 )

@@ -37,9 +36,6 @@ class PhiWrapper(DecoderOnlyWrapper):
     def get_rbln_layer_class(self):
         return PhiLayer
 
-    def get_rbln_model_class(self):
-        return PhiModel
-
     def get_model_layer(self, model: Union["PhiForCausalLM", "PhiModel"]):
         return model.model if self.is_causal_lm else model
 
@@ -48,13 +44,15 @@ class PhiWrapper(DecoderOnlyWrapper):
 
 
 class PhiAttention(DecoderOnlyAttention):
-    def __post_init__(self):
-        self.q_proj = self._original_mod.q_proj
-        self.k_proj = self._original_mod.k_proj
-        self.v_proj = self._original_mod.v_proj
-        self.o_proj = self._original_mod.dense
-        self.qk_layernorm = self._original_mod.qk_layernorm
-        self.rotary_ndims = self._original_mod.rotary_ndims
+    def __post_init__(self, self_attn):
+        self.q_proj = self_attn.q_proj
+        self.k_proj = self_attn.k_proj
+        self.v_proj = self_attn.v_proj
+        self.o_proj = self_attn.dense
+        self.qk_layernorm = self_attn.qk_layernorm
+        self.rotary_ndims = self_attn.rotary_ndims
+        self.q_layernorm = getattr(self_attn, "q_layernorm", None)
+        self.k_layernorm = getattr(self_attn, "k_layernorm", None)
 
     def projection(self, hidden_states, lora_int_id) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         if lora_int_id is not None:
@@ -65,8 +63,8 @@ class PhiAttention(DecoderOnlyAttention):
         value_states = self.v_proj(hidden_states)
 
         if self.qk_layernorm:
-            query_states = self._original_mod.q_layernorm(query_states)
-            key_states = self._original_mod.k_layernorm(key_states)
+            query_states = self.q_layernorm(query_states)
+            key_states = self.k_layernorm(key_states)
 
         return query_states, key_states, value_states
 
@@ -75,8 +73,7 @@ class PhiAttention(DecoderOnlyAttention):
 
 
 class PhiLayer(DecoderOnlyLayer):
-    def get_post_attention_layernorm(self):
-        raise NotImplementedError
+    _POST_ATTN_LAYERNORM = None
 
     def forward(
         self,
@@ -103,13 +100,8 @@ class PhiLayer(DecoderOnlyLayer):
             block_tables=block_tables,
         )
 
-        feed_forward_hidden_states = self._original_mod.mlp(hidden_states)
+        feed_forward_hidden_states = self.mlp(hidden_states)
 
         hidden_states = attn_output + feed_forward_hidden_states + residual
 
         return hidden_states
-
-
-class PhiModel(DecoderOnlyModel):
-    def get_last_layernorm(self):
-        return self._original_mod.final_layernorm
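The Phi hunks above show the recurring attention refactor: `__post_init__` now receives the original attention module directly, and the optional q/k layernorms fall back to `None` via `getattr`. A minimal runnable sketch of that wiring pattern, using hypothetical stand-in modules (`ToyPhiSelfAttn` and `ToyWiredAttention` are illustrative names, not optimum-rbln classes):

```python
import torch
import torch.nn as nn


class ToyPhiSelfAttn(nn.Module):
    """Stand-in for transformers' PhiAttention: only the attributes the
    wrapper reads (q/k/v projections, `dense`, `qk_layernorm`)."""

    def __init__(self, hidden_size: int = 64, qk_layernorm: bool = True):
        super().__init__()
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.qk_layernorm = qk_layernorm
        if qk_layernorm:
            self.q_layernorm = nn.LayerNorm(hidden_size)
            self.k_layernorm = nn.LayerNorm(hidden_size)


class ToyWiredAttention:
    """Mirrors the new `__post_init__(self, self_attn)` shape: submodules are
    taken off the passed-in attention module, and the layernorms default to
    None when the checkpoint does not define them."""

    def __init__(self, self_attn: nn.Module):
        self.q_proj = self_attn.q_proj
        self.k_proj = self_attn.k_proj
        self.v_proj = self_attn.v_proj
        self.o_proj = self_attn.dense
        self.qk_layernorm = self_attn.qk_layernorm
        self.q_layernorm = getattr(self_attn, "q_layernorm", None)
        self.k_layernorm = getattr(self_attn, "k_layernorm", None)

    def projection(self, hidden_states: torch.Tensor):
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        if self.qk_layernorm:
            query_states = self.q_layernorm(query_states)
            key_states = self.k_layernorm(key_states)
        return query_states, key_states, value_states


attn = ToyWiredAttention(ToyPhiSelfAttn())
q, k, v = attn.projection(torch.randn(1, 4, 64))
print(q.shape, k.shape, v.shape)  # three tensors of shape (1, 4, 64)
```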
optimum/rbln/transformers/models/qwen2_5_vl/__init__.py

@@ -15,5 +15,10 @@
 from .configuration_qwen2_5_vl import (
     RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
     RBLNQwen2_5_VLForConditionalGenerationConfig,
+    RBLNQwen2_5_VLModelConfig,
+)
+from .modeling_qwen2_5_vl import (
+    RBLNQwen2_5_VisionTransformerPretrainedModel,
+    RBLNQwen2_5_VLForConditionalGeneration,
+    RBLNQwen2_5_VLModel,
 )
-from .modeling_qwen2_5_vl import RBLNQwen2_5_VisionTransformerPretrainedModel, RBLNQwen2_5_VLForConditionalGeneration
optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py

@@ -15,7 +15,7 @@
 from typing import Any, List, Optional, Union
 
 from ....configuration_utils import RBLNModelConfig
-from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
 
 
 class RBLNQwen2_5_VLForConditionalGenerationConfig(RBLNDecoderOnlyModelForCausalLMConfig):
@@ -56,6 +56,16 @@ class RBLNQwen2_5_VLForConditionalGenerationConfig(RBLNDecoderOnlyModelForCausalLMConfig):
         self.visual = visual
 
 
+class RBLNQwen2_5_VLModelConfig(RBLNDecoderOnlyModelConfig):
+    """
+    Configuration class for RBLNQwen2_5_VLModel.
+    """
+
+    def __init__(self, visual: Optional[RBLNModelConfig] = None, **kwargs: Any):
+        super().__init__(**kwargs)
+        self.visual = self.initialize_submodule_config(submodule_config=visual)
+
+
 class RBLNQwen2_5_VisionTransformerPretrainedModelConfig(RBLNModelConfig):
     """
     Configuration class for RBLNQwen2_5_VisionTransformerPretrainedModel.
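A minimal construction sketch for the new config, assuming the re-exports shown in the `__init__.py` hunk above; the keyword values are illustrative, and `max_seq_len` is assumed to be accepted by the `RBLNDecoderOnlyModelConfig` base (as in the docstring example later in this diff):

```python
from optimum.rbln.transformers.models.qwen2_5_vl import RBLNQwen2_5_VLModelConfig

config = RBLNQwen2_5_VLModelConfig(
    visual={"max_seq_lens": 6400},  # normalized via initialize_submodule_config
    max_seq_len=32_768,             # assumed base-config option, not a default
)
```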
optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py

@@ -17,7 +17,13 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
 
 import torch
-from transformers import AutoModelForVision2Seq, Qwen2_5_VLForConditionalGeneration
+from transformers import (
+    AutoModelForVision2Seq,
+    PretrainedConfig,
+    PreTrainedModel,
+    Qwen2_5_VLConfig,
+    Qwen2_5_VLForConditionalGeneration,
+)
 from transformers.modeling_utils import no_init_weights
 from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
     Qwen2_5_VisionPatchEmbed,
@@ -30,8 +36,8 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
 from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
 from ....utils.logging import get_logger
-from ...modeling_outputs import RBLNDecoderOnlyOutput
-from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM
+from ...modeling_outputs import RBLNDecoderOnlyOutput, _validate_output_hidden_states
+from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyModel, RBLNDecoderOnlyModelForCausalLM
 from .configuration_qwen2_5_vl import (
     RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
     RBLNQwen2_5_VLForConditionalGenerationConfig,
@@ -42,7 +48,7 @@ from .qwen2_5_vl_architecture import Qwen2_5_VisionTransformerWrapper, Qwen2_5_VL_LanguageModelWrapper
 logger = get_logger(__name__)
 
 if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
 
 
 class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
@@ -55,6 +61,7 @@ class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
     """
 
     auto_model_class = None
+    _supports_non_fp32 = True
 
     def __post_init__(self, **kwargs):
         self.transformer = self.model[0]
@@ -91,7 +98,7 @@ class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
     def _wrap_model_if_needed(
         cls, model: "PreTrainedModel", rbln_config: RBLNQwen2_5_VisionTransformerPretrainedModelConfig
     ):
-        return Qwen2_5_VisionTransformerWrapper(model).eval()
+        return Qwen2_5_VisionTransformerWrapper(model, rbln_config).eval()
 
     def __getattr__(self, __name: str) -> Any:
         def redirect(func):
@@ -126,22 +133,22 @@ class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
             )
 
             input_info = [
-                ("hidden_states", [max_seq_len, hidden_size], "float32"),
-                ("full_attn_masks", [1, 1, max_seq_len, max_seq_len], "float32"),
+                ("hidden_states", [max_seq_len, hidden_size], rbln_config.dtype),
+                ("full_attn_masks", [1, 1, max_seq_len, max_seq_len], rbln_config.dtype),
                 (
                     "window_attn_masks",
                     [max_seq_len // window_seq_len, 1, window_seq_len, window_seq_len],
-                    "float32",
+                    rbln_config.dtype,
                 ),
                 (
                     "cos",
                     [1, 1, max_seq_len, head_dim],
-                    "float32",
+                    rbln_config.dtype,
                 ),
                 (
                     "sin",
                     [1, 1, max_seq_len, head_dim],
-                    "float32",
+                    rbln_config.dtype,
                 ),
             ]
             input_infos.append(input_info)
@@ -203,7 +210,7 @@ class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
             1,
             window_seq_len,
             window_seq_len,
-            dtype=torch.float32,
+            dtype=hidden_states.dtype,
         )
         for i, valid_len in enumerate(window_valid_lengths):
             if valid_len < window_seq_len:
@@ -242,7 +249,7 @@ class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
             1,
             max_seq_len,
             max_seq_len,
-            dtype=torch.float32,
+            dtype=hidden_state_padded.dtype,
         )
         for i, valid_len in enumerate(window_valid_lengths):
             start = i * window_seq_len
@@ -253,7 +260,7 @@ class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
         return hidden_state_full_padded, cos_full_padded, sin_full_padded, full_attn_masks
 
     def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
-        hidden_states = self.patch_embed(hidden_states)
+        hidden_states = self.patch_embed(hidden_states).to(self.rbln_config.dtype)
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
         window_index, cu_window_seqlens = self.get_window_index(grid_thw)
         cu_window_seqlens = torch.tensor(
@@ -270,7 +277,7 @@ class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
         rotary_pos_emb = rotary_pos_emb[window_index, :, :]
         rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
         emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
-        position_embeddings = (emb.cos(), emb.sin())
+        position_embeddings = (emb.cos().to(self.rbln_config.dtype), emb.sin().to(self.rbln_config.dtype))
 
         cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
             dim=0,
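The vision-tower hunks above thread a configurable compute dtype (`rbln_config.dtype`) through embeddings, masks, and rotary tables instead of assuming float32. A toy sketch of the pattern in plain PyTorch, with `compute_dtype` standing in for `rbln_config.dtype`:

```python
import torch

# Hypothetical stand-in for rbln_config.dtype; the real value comes from the
# RBLN model config (the diff applies it to masks, embeddings, and rope).
compute_dtype = torch.float16

hidden_states = torch.randn(8, 64)               # e.g. patch embeddings
hidden_states = hidden_states.to(compute_dtype)  # cast once, up front

# Masks and rotary tables are then created directly in the same dtype, so the
# compiled graph never mixes float32 and float16 tensors:
attn_mask = torch.full((1, 1, 8, 8), 1.0, dtype=hidden_states.dtype)
emb = torch.randn(1, 1, 8, 16)
cos, sin = emb.cos().to(compute_dtype), emb.sin().to(compute_dtype)
print(hidden_states.dtype, attn_mask.dtype, cos.dtype)
```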
@@ -338,66 +345,47 @@ class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
         return hidden_states
 
 
-class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
-    """
-    RBLNQwen2_5_VLForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
-    optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
-
-    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
-
-    Important Note:
-        This model includes a Large Language Model (LLM). For optimal performance, it is highly recommended to use
-        tensor parallelism for the language model. This can be achieved by using the `rbln_config` parameter in the
-        `from_pretrained` method. Refer to the `from_pretrained` documentation and the RBLNQwen2_5_VLForConditionalGenerationConfig class for details.
-
-    Examples:
-        ```python
-        from optimum.rbln import RBLNQwen2_5_VLForConditionalGeneration
-
-        model = RBLNQwen2_5_VLForConditionalGeneration.from_pretrained(
-            "Qwen/Qwen2.5-VL-7B-Instruct",
-            export=True,
-            rbln_config={
-                "visual": {
-                    "max_seq_lens": 6400,
-                    "device": 0,
-                },
-                "tensor_parallel_size": 8,
-                "kvcache_partition_len": 16_384,
-                "max_seq_len": 114_688,
-                "device": [0, 1, 2, 3, 4, 5, 6, 7],
-            },
-        )
-
-        model.save_pretrained("compiled-qwen2.5-vl-7b-instruct")
-        ```
-    """
-
-    _supports_non_fp32 = False
-
+class RBLNQwen2_5_VLModel(RBLNDecoderOnlyModel):
     auto_model_class = AutoModelForVision2Seq
+    _decoder_wrapper_cls = Qwen2_5_VL_LanguageModelWrapper
+    _use_rotary_emb = False
     _rbln_submodules = [
         {"name": "visual"},
     ]
-
-
+    _config_class = Qwen2_5_VLConfig
+    _rotary_emb_class = Qwen2_5_VLRotaryEmbedding
+    _get_rope_index_func = Qwen2_5_VLModel.get_rope_index
 
     def __post_init__(self, **kwargs):
+        if hasattr(self.config, "embedding_dim"):
+            self.embedding_dim = self.config.embedding_dim
+
+        if not isinstance(self.config.text_config, PretrainedConfig):
+            self.config = self._config_class(
+                text_config=self.config.text_config, vision_config=self.config.vision_config
+            )
+
         super().__post_init__(**kwargs)
         self.visual = self.rbln_submodules[0]
-        self.mrope_section = self.config.rope_scaling["mrope_section"]
-        ...
+        self.rotary_emb = self._rotary_emb_class(self.config)
+        if not self.can_generate():
+            self.block_tables = torch.arange(self.rbln_config.kvcache_num_blocks, dtype=torch.int16)
+
+    @property
+    def logits_last_dim(self):
+        if self.can_generate():
+            return self.config.vocab_size
+        else:
+            return self.embedding_dim if hasattr(self, "embedding_dim") else self.config.hidden_size
 
-    ...
+    def _create_embedding_layer(self):
+        with no_init_weights():
+            embed_tokens = torch.nn.Embedding(
+                self.config.text_config.vocab_size,
+                self.config.text_config.hidden_size,
+                self.config.text_config.pad_token_id,
+            )
+        return embed_tokens
 
     @classmethod
     def get_input_info(
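A hedged export sketch for the new `RBLNQwen2_5_VLModel` entry point: the class and keyword names follow the hunk above, the checkpoint id and `visual` options are borrowed from the conditional-generation docstring elsewhere in this diff, and the subpackage import path is used because a top-level `optimum.rbln` re-export is implied by the `__init__.py` changes but not shown here:

```python
from optimum.rbln.transformers.models.qwen2_5_vl import RBLNQwen2_5_VLModel

model = RBLNQwen2_5_VLModel.from_pretrained(
    "Qwen/Qwen2.5-VL-7B-Instruct",
    export=True,
    rbln_config={"visual": {"max_seq_lens": 6400}},
)
```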
@@ -414,61 +402,25 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
             (
                 "position_emb",
                 [2, batch_size, 1, query_length, model_config.hidden_size // model_config.num_attention_heads],
-                "float32",
+                rbln_config.dtype,
             ),
         )
 
         return input_info
 
-    def prepare_inputs_for_generation(
-        self,
-        input_ids: torch.LongTensor,
-        generate_idx: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        pixel_values=None,
-        pixel_values_videos=None,
-        image_grid_thw=None,
-        video_grid_thw=None,
-        second_per_grid_ts=None,
-        **kwargs,
-    ):
-        model_inputs = {}
-        is_prefill_phase = generate_idx is None
-
-        if is_prefill_phase:
-            generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
-            cache_position = None
-            model_inputs.update({"input_ids": input_ids})
-        else:
-            if inputs_embeds is not None:
-                raise NotImplementedError("Specifying inputs_embeds in decoder phase is not supported.")
-
-            input_ids = input_ids[:, -1:]
-            cache_position = generate_idx
-            generate_idx = generate_idx + 1
-            model_inputs.update({"input_ids": input_ids})
-
-        model_inputs.update(
-            {
-                "attention_mask": attention_mask,
-                "cache_position": cache_position,
-                "generate_idx": generate_idx,
-                "pixel_values": pixel_values,
-                "pixel_values_videos": pixel_values_videos,
-                "image_grid_thw": image_grid_thw,
-                "video_grid_thw": video_grid_thw,
-                "second_per_grid_ts": second_per_grid_ts,
-            }
-        )
-
-        return model_inputs
-
     def _get_position_embeddings(self, hidden_states, position_ids):
         cos, sin = self.rotary_emb(hidden_states, position_ids)
-        mrope_section = self.mrope_section * 2
-        cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(1)
-        sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(1)
+        mrope_section = self.config.rope_scaling["mrope_section"] * 2
+        cos = (
+            torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1)
+            .unsqueeze(1)
+            .to(self.rbln_config.dtype)
+        )
+        sin = (
+            torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1)
+            .unsqueeze(1)
+            .to(self.rbln_config.dtype)
+        )
         return torch.stack([cos, sin])
 
     def _preprocess_prefill(
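The `_get_position_embeddings` rewrite reads the multimodal-RoPE section widths from `config.rope_scaling["mrope_section"]` and interleaves the temporal/height/width components with an `i % 3` gather. Note that `mrope_section * 2` is Python list repetition, so it doubles the list (covering both halves of the rotary table), not the values. A toy demonstration with made-up shapes:

```python
import torch

# Toy shapes, not the model's: 3 rope "axes" (t/h/w), head_dim 16, so the
# doubled section list below covers a cos table of width 2 * 8 = 16.
mrope_section = [2, 3, 3]    # hypothetical per-axis widths (sum to 8)
doubled = mrope_section * 2  # list repetition: [2, 3, 3, 2, 3, 3]

cos = torch.randn(3, 1, 5, 16)  # (axes, batch, seq, head_dim)
# Chunk i of the split takes its rows from axis i % 3, interleaving the
# temporal, height, and width rotary components along the last dimension:
merged = torch.cat(
    [m[i % 3] for i, m in enumerate(cos.split(doubled, dim=-1))], dim=-1
)
print(merged.shape)  # torch.Size([1, 5, 16])
```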
@@ -482,7 +434,7 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
         second_per_grid_ts: torch.Tensor = None,
     ):
         batch_size = input_ids.shape[0]
-        inputs_embeds = self.embed_tokens(input_ids)
+        inputs_embeds = self.embed_tokens(input_ids).to(self.rbln_config.dtype)
 
         if pixel_values is not None:
             image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
@@ -517,7 +469,7 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
         max_inputs_len = input_ids.shape[1]
 
         head_dim = getattr(self.config, "head_dim", None) or self.config.hidden_size // self.config.num_attention_heads
-        all_position_embeds = torch.zeros(2, batch_size, 1, max_inputs_len, head_dim)
+        all_position_embeds = torch.zeros(2, batch_size, 1, max_inputs_len, head_dim, dtype=self.rbln_config.dtype)
        all_rope_deltas = []
 
        image_token_id = self.config.image_token_id
@@ -531,8 +483,7 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
             vision_tokens = input_id[0][vision_start_indices + 1]
             image_nums = (vision_tokens == image_token_id).sum()
             video_nums = (vision_tokens == video_token_id).sum()
-            position_ids, rope_deltas = Qwen2_5_VLForConditionalGeneration.get_rope_index(
-                self,
+            position_ids, rope_deltas = self._get_rope_index_func(
                 input_id,
                 image_grid_thw[image_idx : image_idx + image_nums] if image_grid_thw is not None else None,
                 video_grid_thw[video_idx : video_idx + video_nums] if video_grid_thw is not None else None,
@@ -550,6 +501,180 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
 
         return inputs_embeds, all_position_embeds, rope_deltas
 
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        pixel_values: Optional[torch.Tensor] = None,
+        pixel_values_videos: Optional[torch.FloatTensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        second_per_grid_ts: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+    ) -> RBLNDecoderOnlyOutput:
+        inputs_embeds, position_embed, rope_deltas = self._preprocess_prefill(
+            input_ids,
+            attention_mask,
+            pixel_values,
+            pixel_values_videos,
+            image_grid_thw,
+            video_grid_thw,
+            second_per_grid_ts,
+        )
+
+        self.rope_deltas = rope_deltas
+        batch_size, seq_len = inputs_embeds.shape[:2]
+
+        output_hidden_states = _validate_output_hidden_states(output_hidden_states, self.rbln_config)
+
+        all_hidden_states = (
+            tuple(
+                torch.zeros(
+                    batch_size,
+                    seq_len,
+                    self.config.hidden_size,
+                    dtype=self.rbln_config.dtype,
+                )
+                for _ in range(self.config.num_hidden_layers + 1)
+            )
+            if output_hidden_states
+            else None
+        )
+
+        logits = []
+        for b_idx in range(batch_size):
+            query_length = attention_mask[b_idx].sum(dim=-1).int().item()
+            cache_position = torch.arange(query_length, dtype=torch.int32).unsqueeze(0)
+
+            output = self.prefill_decoder(
+                inputs_embeds=inputs_embeds[b_idx : b_idx + 1],
+                attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
+                cache_position=cache_position,
+                batch_idx=b_idx,
+                position_embed=position_embed[:, b_idx : b_idx + 1],
+                block_tables=self.block_tables,
+            )
+            logits.append(output.logits)
+            if self.rbln_config.output_hidden_states:
+                for l_idx in range(self.config.num_hidden_layers + 1):
+                    all_hidden_states[l_idx][b_idx].copy_(output.hidden_states[l_idx][0])
+        logits = torch.cat(logits, dim=0)
+
+        if not return_dict:
+            return_value = logits if not output_hidden_states else (logits, all_hidden_states)
+            return return_value
+        else:
+            return (
+                RBLNDecoderOnlyOutput(logits=logits, hidden_states=all_hidden_states)
+                if output_hidden_states
+                else RBLNDecoderOnlyOutput(logits=logits)
+            )
+
+
+# MRO: RBLNQwen2_5_VLForConditionalGeneration -> RBLNQwen2_5_VLModel -> RBLNDecoderOnlyModelForCausalLM -> RBLNDecoderOnlyModel -> RBLNModel
+class RBLNQwen2_5_VLForConditionalGeneration(RBLNQwen2_5_VLModel, RBLNDecoderOnlyModelForCausalLM):
+    """
+    RBLNQwen2_5_VLForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
+    optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
+
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    Important Note:
+        This model includes a Large Language Model (LLM). For optimal performance, it is highly recommended to use
+        tensor parallelism for the language model. This can be achieved by using the `rbln_config` parameter in the
+        `from_pretrained` method. Refer to the `from_pretrained` documentation and the RBLNQwen2_5_VLForConditionalGenerationConfig class for details.
+
+    Examples:
+        ```python
+        from optimum.rbln import RBLNQwen2_5_VLForConditionalGeneration
+
+        model = RBLNQwen2_5_VLForConditionalGeneration.from_pretrained(
+            "Qwen/Qwen2.5-VL-7B-Instruct",
+            export=True,
+            rbln_config={
+                "visual": {
+                    "max_seq_lens": 6400,
+                    "device": 0,
+                },
+                "tensor_parallel_size": 8,
+                "kvcache_partition_len": 16_384,
+                "max_seq_len": 114_688,
+                "device": [0, 1, 2, 3, 4, 5, 6, 7],
+            },
+        )
+
+        model.save_pretrained("compiled-qwen2.5-vl-7b-instruct")
+        ```
+    """
+
+    auto_model_class = AutoModelForVision2Seq
+    _decoder_wrapper_cls = Qwen2_5_VL_LanguageModelWrapper
+    _supports_non_fp32 = True
+    _use_rotary_emb = False
+    _rbln_submodules = [
+        {"name": "visual"},
+    ]
+
+    def __post_init__(self, **kwargs):
+        super().__post_init__(**kwargs)
+        self.rope_deltas = torch.zeros(self.rbln_config.batch_size)
+
+    def can_generate(self):
+        return True
+
+    @classmethod
+    def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
+        model.model.lm_head = model.lm_head
+        return model
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.LongTensor,
+        generate_idx: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        pixel_values=None,
+        pixel_values_videos=None,
+        image_grid_thw=None,
+        video_grid_thw=None,
+        second_per_grid_ts=None,
+        **kwargs,
+    ):
+        model_inputs = {}
+        is_prefill_phase = generate_idx is None
+
+        if is_prefill_phase:
+            generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
+            cache_position = None
+            model_inputs.update({"input_ids": input_ids})
+        else:
+            if inputs_embeds is not None:
+                raise NotImplementedError("Specifying inputs_embeds in decoder phase is not supported.")
+
+            input_ids = input_ids[:, -1:]
+            cache_position = generate_idx
+            generate_idx = generate_idx + 1
+            model_inputs.update({"input_ids": input_ids})
+
+        model_inputs.update(
+            {
+                "attention_mask": attention_mask,
+                "cache_position": cache_position,
+                "generate_idx": generate_idx,
+                "pixel_values": pixel_values,
+                "pixel_values_videos": pixel_values_videos,
+                "image_grid_thw": image_grid_thw,
+                "video_grid_thw": video_grid_thw,
+                "second_per_grid_ts": second_per_grid_ts,
+            }
+        )
+
+        return model_inputs
+
     def _preprocess_decoder(
         self,
         input_ids: torch.LongTensor = None,
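A hedged sketch of calling the new prefill `forward` with hidden-state capture, reusing `model` from the earlier export sketch and assuming it was compiled with `output_hidden_states` enabled in its `rbln_config` (which `_validate_output_hidden_states` appears to reconcile with the runtime flag):

```python
import torch

input_ids = torch.ones(1, 8, dtype=torch.long)
attention_mask = torch.ones(1, 8, dtype=torch.long)

out = model.forward(
    input_ids=input_ids,
    attention_mask=attention_mask,
    output_hidden_states=True,
    return_dict=True,
)
# out.hidden_states holds num_hidden_layers + 1 tensors of shape
# (batch, seq, hidden_size), matching the buffers allocated in forward().
```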
@@ -560,14 +685,14 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
                 f"Cache position size mismatch: got {cache_position.shape[0]}, expected {self.rbln_config.batch_size}."
             )
 
-        inputs_embeds = self.embed_tokens(input_ids)
+        inputs_embeds = self.embed_tokens(input_ids).to(self.rbln_config.dtype)
         position_embeds = []
         for b_idx in range(self.rbln_config.batch_size):
             delta = cache_position[b_idx] + self.rope_deltas[b_idx]
             position_ids = torch.arange(1).view(1, -1)
             position_ids = position_ids.add(delta)
             position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
-            position_embed = self._get_position_embeddings(torch.zeros(1, dtype=torch.float32), position_ids)
+            position_embed = self._get_position_embeddings(torch.zeros(1, dtype=self.rbln_config.dtype), position_ids)
             position_embeds.append(position_embed)
 
         position_embeds = torch.cat(position_embeds, dim=1)
@@ -587,8 +712,10 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
         second_per_grid_ts: Optional[torch.Tensor] = None,
         generate_idx: Optional[torch.Tensor] = None,
         return_dict: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
         **kwargs,
     ) -> RBLNDecoderOnlyOutput:
+        output_hidden_states = _validate_output_hidden_states(output_hidden_states, self.rbln_config)
         # Prefill
         if cache_position is None:
             inputs_embeds, position_embed, rope_deltas = self._preprocess_prefill(
@@ -601,8 +728,21 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
                 second_per_grid_ts,
             )
 
+            batch_size, seq_len = inputs_embeds.shape[:2]
+            all_hidden_states = (
+                tuple(
+                    torch.zeros(
+                        batch_size,
+                        seq_len,
+                        self.config.hidden_size,
+                        dtype=self.rbln_config.dtype,
+                    )
+                    for _ in range(self.config.num_hidden_layers + 1)
+                )
+                if output_hidden_states
+                else None
+            )
             self.rope_deltas = rope_deltas
-            batch_size = inputs_embeds.shape[0]
 
             logits = []
             for b_idx in range(batch_size):
@@ -616,8 +756,11 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
                     position_embed=position_embed[:, b_idx : b_idx + 1],
                 )
                 logits.append(output.logits)
+                if self.rbln_config.output_hidden_states:
+                    for l_idx in range(self.config.num_hidden_layers + 1):
+                        all_hidden_states[l_idx][b_idx].copy_(output.hidden_states[l_idx][0])
             logits = torch.cat(logits, dim=0)
-
+        # Decoder
         else:
             inputs_embeds, position_embed = self._preprocess_decoder(input_ids, cache_position)
             output = self.decoder(
@@ -626,11 +769,17 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
                 position_embed=position_embed,
             )
             logits = output.logits
+            all_hidden_states = output.hidden_states
 
         if not return_dict:
-            return logits, generate_idx
+            return_value = (
+                logits,
+                generate_idx if not output_hidden_states else (logits, generate_idx, all_hidden_states),
+            )
+            return return_value
         else:
             return RBLNDecoderOnlyOutput(
                 logits=logits,
                 generate_idx=generate_idx,
+                hidden_states=all_hidden_states,
             )