optimum-rbln 0.7.5a1__py3-none-any.whl → 0.7.5rc1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only.
- optimum/rbln/__init__.py +10 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +1 -2
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +1 -2
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +1 -2
- optimum/rbln/modeling.py +53 -9
- optimum/rbln/modeling_base.py +22 -3
- optimum/rbln/transformers/__init__.py +10 -0
- optimum/rbln/transformers/modeling_generic.py +0 -19
- optimum/rbln/transformers/models/__init__.py +14 -0
- optimum/rbln/transformers/models/auto/__init__.py +1 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +7 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +114 -19
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +29 -10
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
- optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +69 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +446 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +1057 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +2 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +4 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +3 -2
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +0 -3
- {optimum_rbln-0.7.5a1.dist-info → optimum_rbln-0.7.5rc1.dist-info}/METADATA +1 -1
- {optimum_rbln-0.7.5a1.dist-info → optimum_rbln-0.7.5rc1.dist-info}/RECORD +31 -27
- {optimum_rbln-0.7.5a1.dist-info → optimum_rbln-0.7.5rc1.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.5a1.dist-info → optimum_rbln-0.7.5rc1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py
CHANGED
@@ -38,6 +38,7 @@ _import_structure = {
         "RBLNAutoModelForCTC",
         "RBLNAutoModelForDepthEstimation",
         "RBLNAutoModelForImageClassification",
+        "RBLNAutoModelForImageTextToText",
         "RBLNAutoModelForMaskedLM",
         "RBLNAutoModelForQuestionAnswering",
         "RBLNAutoModelForSeq2SeqLM",
@@ -78,6 +79,10 @@ _import_structure = {
         "RBLNExaoneForCausalLMConfig",
         "RBLNGemmaForCausalLM",
         "RBLNGemmaForCausalLMConfig",
+        "RBLNGemma3ForCausalLM",
+        "RBLNGemma3ForCausalLMConfig",
+        "RBLNGemma3ForConditionalGeneration",
+        "RBLNGemma3ForConditionalGenerationConfig",
         "RBLNGPT2LMHeadModel",
         "RBLNGPT2LMHeadModelConfig",
         "RBLNIdefics3VisionTransformer",
@@ -259,6 +264,7 @@ if TYPE_CHECKING:
         RBLNAutoModelForCTC,
         RBLNAutoModelForDepthEstimation,
         RBLNAutoModelForImageClassification,
+        RBLNAutoModelForImageTextToText,
         RBLNAutoModelForMaskedLM,
         RBLNAutoModelForQuestionAnswering,
         RBLNAutoModelForSeq2SeqLM,
@@ -297,6 +303,10 @@ if TYPE_CHECKING:
         RBLNDPTForDepthEstimationConfig,
         RBLNExaoneForCausalLM,
         RBLNExaoneForCausalLMConfig,
+        RBLNGemma3ForCausalLM,
+        RBLNGemma3ForCausalLMConfig,
+        RBLNGemma3ForConditionalGeneration,
+        RBLNGemma3ForConditionalGenerationConfig,
         RBLNGemmaForCausalLM,
         RBLNGemmaForCausalLMConfig,
         RBLNGPT2LMHeadModel,
optimum/rbln/__version__.py
CHANGED
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.7.5a1'
-__version_tuple__ = version_tuple = (0, 7, 5, 'a1')
+__version__ = version = '0.7.5rc1'
+__version_tuple__ = version_tuple = (0, 7, 5, 'rc1')
optimum/rbln/diffusers/models/transformers/prior_transformer.py
CHANGED
@@ -58,8 +58,7 @@ class _PriorTransformer(torch.nn.Module):
 class RBLNPriorTransformer(RBLNModel):
     hf_library_name = "diffusers"
     auto_model_class = PriorTransformer
-
-    output_key = "predicted_image_embedding"
+    _output_class = PriorTransformerOutput
 
     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)
optimum/rbln/diffusers/models/transformers/transformer_sd3.py
CHANGED
@@ -61,8 +61,7 @@ class SD3Transformer2DModelWrapper(torch.nn.Module):
 class RBLNSD3Transformer2DModel(RBLNModel):
     hf_library_name = "diffusers"
     auto_model_class = SD3Transformer2DModel
-
-    output_key = "sample"
+    _output_class = Transformer2DModelOutput
 
     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)
optimum/rbln/diffusers/models/unets/unet_2d_condition.py
CHANGED
@@ -143,8 +143,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
     hf_library_name = "diffusers"
     auto_model_class = UNet2DConditionModel
     _rbln_config_class = RBLNUNet2DConditionModelConfig
-
-    output_key = "sample"
+    _output_class = UNet2DConditionOutput
 
     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)
optimum/rbln/modeling.py
CHANGED
@@ -14,7 +14,7 @@
 
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union, get_args, get_origin, get_type_hints
 
 import rebel
 import torch
@@ -49,8 +49,7 @@ class RBLNModel(RBLNBaseModel):
     ```
     """
 
-
-    output_key = "last_hidden_state"
+    _output_class = None
 
     @classmethod
     def update_kwargs(cls, kwargs):
@@ -245,16 +244,61 @@
         # Format output according to task requirements
         return self._prepare_output(output, return_dict)
 
+    @classmethod
+    def get_hf_output_class(cls):
+        """
+        Dynamically gets the output class from the corresponding HuggingFace model class.
+
+        Returns:
+            type: The appropriate output class from transformers or diffusers
+        """
+        if cls._output_class:
+            return cls._output_class
+
+        hf_class = cls.get_hf_class()
+        if hf_class is None:
+            raise ValueError(f"No HuggingFace model class found for {cls.__name__}")
+
+        hints = get_type_hints(hf_class.forward) if hasattr(hf_class, "forward") else {}
+        ret = hints.get("return")
+
+        if ret is not None:
+            candidates = get_args(ret) if get_origin(ret) is Union else (ret,)
+
+            for t in candidates:
+                if t is type(None):  # Skip NoneType in Union
+                    continue
+                mod = getattr(t, "__module__", "")
+                if "transformers" in mod or "diffusers" in mod:
+                    cls._output_class = t
+                    return t
+
+        # Fallback to BaseModelOutput
+        cls._output_class = BaseModelOutput
+        return BaseModelOutput
+
     def _prepare_output(self, output, return_dict):
         """
         Prepare model output based on return_dict flag.
         This method can be overridden by subclasses to provide task-specific output handling.
         """
+        tuple_output = (output,) if not isinstance(output, (tuple, list)) else output
         if not return_dict:
-            return
+            return tuple_output
         else:
-
-
-
-
-
+            output_class = self.get_hf_output_class()
+            if hasattr(output_class, "loss"):
+                tuple_output = (None,) + tuple_output
+
+            # Truncate if we have too many outputs, otherwise use as is
+            if hasattr(output_class, "__annotations__"):
+                num_fields = len(output_class.__annotations__)
+                if len(tuple_output) > num_fields:
+                    tuple_output = tuple_output[:num_fields]
+                    logger.warning(
+                        f"Truncating output to {num_fields} fields for {output_class.__name__}. "
+                        f"Expected {num_fields} fields, but got {len(tuple_output)} fields."
+                        "This is unexpected. Please report this issue to the developers."
+                    )
+
+            return output_class(*tuple_output)
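For readers scanning the diff: the new `get_hf_output_class` replaces the per-class `output_key`/`output_class` attributes deleted elsewhere in this release. The output dataclass is now inferred from the return annotation of the matching HuggingFace `forward()`, cached on `_output_class`, and filled positionally by `_prepare_output`, with `BaseModelOutput` as the fallback. A minimal standalone sketch of that inference step, assuming only `transformers` and `torch`; `ToyModel`, `resolve_output_class`, and the tensor shape are illustrative, not part of the package:

```python
from typing import Tuple, Union, get_args, get_origin, get_type_hints

import torch
from transformers.modeling_outputs import BaseModelOutput


class ToyModel:
    # Stand-in for a HuggingFace model class with an annotated forward().
    def forward(self, x: torch.Tensor) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
        ...


def resolve_output_class(hf_class) -> type:
    # Mirror of the diff's logic: read forward()'s return annotation, pick the first
    # transformers/diffusers type from a Union, skip NoneType, fall back to BaseModelOutput.
    hints = get_type_hints(hf_class.forward) if hasattr(hf_class, "forward") else {}
    ret = hints.get("return")
    if ret is not None:
        candidates = get_args(ret) if get_origin(ret) is Union else (ret,)
        for t in candidates:
            if t is type(None):
                continue
            mod = getattr(t, "__module__", "")
            if "transformers" in mod or "diffusers" in mod:
                return t
    return BaseModelOutput


output_cls = resolve_output_class(ToyModel)   # -> BaseModelOutput
wrapped = output_cls(torch.zeros(1, 4, 8))    # like output_class(*tuple_output) in the diff
print(output_cls.__name__, wrapped.last_hidden_state.shape)
```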
optimum/rbln/modeling_base.py
CHANGED
@@ -178,9 +178,27 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         return str(model_path)
 
     @classmethod
-    def _load_compiled_models(cls, model_path: str):
+    def _load_compiled_models(cls, model_path: str, expected_compiled_model_names: List[str]):
         compiled_models = Path(model_path).glob("*.rbln")
-
+        expected_compiled_models = [
+            Path(model_path) / f"{compiled_model_name}.rbln" for compiled_model_name in expected_compiled_model_names
+        ]
+        unexpected_compiled_models = [cm for cm in compiled_models if cm not in expected_compiled_models]
+        if unexpected_compiled_models:
+            # TODO(jongho): fix after May release. raise error if unexpected compiled models are found
+            logger.warning(
+                f"Unexpected compiled models found: {[cm.name for cm in unexpected_compiled_models]}. "
+                f"Please check the model path: {model_path}"
+            )
+
+        rbln_compiled_models = {}
+        for compiled_model in expected_compiled_models:
+            if not compiled_model.exists():
+                raise FileNotFoundError(
+                    f"Expected RBLN compiled model '{compiled_model.name}' not found at '{model_path}'. "
+                    "Please ensure all models specified in `rbln_config` are present."
+                )
+            rbln_compiled_models[compiled_model.stem] = rebel.RBLNCompiledModel(compiled_model)
         return rbln_compiled_models
 
     @classmethod
@@ -271,7 +289,8 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         )
         config = PretrainedConfig(**config)
 
-
+        compiled_model_names = [cfg.compiled_model_name for cfg in rbln_config.compile_cfgs]
+        rbln_compiled_models = cls._load_compiled_models(model_path_subfolder, compiled_model_names)
 
         if subfolder != "":
             model_save_dir = Path(model_path_subfolder).absolute().parent
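In plain terms, the change above makes artifact loading strict: the expected `*.rbln` names now come from `rbln_config.compile_cfgs`, stray files only trigger a warning for now, and a missing expected file raises. A rough standalone sketch of that validation without the `rebel` runtime; the function name and logger setup here are hypothetical, not the library's API:

```python
import logging
from pathlib import Path
from typing import List

logger = logging.getLogger("rbln-artifact-check")


def check_compiled_artifacts(model_path: str, expected_names: List[str]) -> List[Path]:
    """Validate a directory of *.rbln files against the names declared in the config."""
    found = list(Path(model_path).glob("*.rbln"))
    expected = [Path(model_path) / f"{name}.rbln" for name in expected_names]

    # Extra files only produce a warning (the diff keeps a TODO to turn this into an error later).
    unexpected = [p for p in found if p not in expected]
    if unexpected:
        logger.warning("Unexpected compiled models found: %s", [p.name for p in unexpected])

    # Missing files are fatal, mirroring the new FileNotFoundError in the diff.
    for p in expected:
        if not p.exists():
            raise FileNotFoundError(f"Expected RBLN compiled model '{p.name}' not found at '{model_path}'.")
    return expected
```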
optimum/rbln/transformers/__init__.py
CHANGED
@@ -34,6 +34,7 @@ _import_structure = {
         "RBLNAutoModelForCTC",
         "RBLNAutoModelForDepthEstimation",
         "RBLNAutoModelForImageClassification",
+        "RBLNAutoModelForImageTextToText",
         "RBLNAutoModelForMaskedLM",
         "RBLNAutoModelForQuestionAnswering",
         "RBLNAutoModelForSeq2SeqLM",
@@ -72,6 +73,10 @@ _import_structure = {
         "RBLNExaoneForCausalLMConfig",
         "RBLNGemmaForCausalLM",
         "RBLNGemmaForCausalLMConfig",
+        "RBLNGemma3ForCausalLM",
+        "RBLNGemma3ForCausalLMConfig",
+        "RBLNGemma3ForConditionalGeneration",
+        "RBLNGemma3ForConditionalGenerationConfig",
         "RBLNGPT2LMHeadModel",
         "RBLNGPT2LMHeadModelConfig",
         "RBLNIdefics3VisionTransformer",
@@ -148,6 +153,7 @@ if TYPE_CHECKING:
         RBLNAutoModelForCTC,
         RBLNAutoModelForDepthEstimation,
         RBLNAutoModelForImageClassification,
+        RBLNAutoModelForImageTextToText,
         RBLNAutoModelForMaskedLM,
         RBLNAutoModelForQuestionAnswering,
         RBLNAutoModelForSeq2SeqLM,
@@ -184,6 +190,10 @@ if TYPE_CHECKING:
         RBLNDPTForDepthEstimationConfig,
         RBLNExaoneForCausalLM,
         RBLNExaoneForCausalLMConfig,
+        RBLNGemma3ForCausalLM,
+        RBLNGemma3ForCausalLMConfig,
+        RBLNGemma3ForConditionalGeneration,
+        RBLNGemma3ForConditionalGenerationConfig,
         RBLNGemmaForCausalLM,
         RBLNGemmaForCausalLMConfig,
         RBLNGPT2LMHeadModel,
optimum/rbln/transformers/modeling_generic.py
CHANGED
@@ -36,11 +36,7 @@ from transformers import (
 )
 from transformers.modeling_outputs import (
     BaseModelOutput,
-    DepthEstimatorOutput,
-    ImageClassifierOutput,
-    MaskedLMOutput,
     QuestionAnsweringModelOutput,
-    SequenceClassifierOutput,
 )
 
 from ..configuration_utils import RBLNCompileConfig
@@ -63,8 +59,6 @@ class _RBLNTransformerEncoder(RBLNModel):
     auto_model_class = AutoModel
     rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
     rbln_dtype = "int64"
-    output_class = BaseModelOutput
-    output_key = "last_hidden_state"
 
     @classmethod
     def _update_rbln_config(
@@ -149,7 +143,6 @@ class _RBLNImageModel(RBLNModel):
     auto_model_class = AutoModel
     main_input_name = "pixel_values"
     output_class = BaseModelOutput
-    output_key = "last_hidden_state"
 
     @classmethod
     def _update_rbln_config(
@@ -223,15 +216,11 @@ class RBLNModelForQuestionAnswering(_RBLNTransformerEncoder):
 class RBLNModelForSequenceClassification(_RBLNTransformerEncoder):
     auto_model_class = AutoModelForSequenceClassification
     rbln_model_input_names = ["input_ids", "attention_mask"]
-    output_class = SequenceClassifierOutput
-    output_key = "logits"
 
 
 class RBLNModelForMaskedLM(_RBLNTransformerEncoder):
     auto_model_class = AutoModelForMaskedLM
     rbln_model_input_names = ["input_ids", "attention_mask"]
-    output_class = MaskedLMOutput
-    output_key = "logits"
 
 
 class RBLNModelForTextEncoding(_RBLNTransformerEncoder):
@@ -243,20 +232,14 @@ class RBLNTransformerEncoderForFeatureExtraction(_RBLNTransformerEncoder):
     # TODO: RBLNModel is also for feature extraction.
     auto_model_class = AutoModel
     rbln_model_input_names = ["input_ids", "attention_mask"]
-    output_class = BaseModelOutput
-    output_key = "last_hidden_state"
 
 
 class RBLNModelForImageClassification(_RBLNImageModel):
     auto_model_class = AutoModelForImageClassification
-    output_class = ImageClassifierOutput
-    output_key = "logits"
 
 
 class RBLNModelForDepthEstimation(_RBLNImageModel):
     auto_model_class = AutoModelForDepthEstimation
-    output_class = DepthEstimatorOutput
-    output_key = "predicted_depth"
 
 
 class RBLNModelForAudioClassification(RBLNModel):
@@ -273,8 +256,6 @@ class RBLNModelForAudioClassification(RBLNModel):
     """
 
     auto_model_class = AutoModelForAudioClassification
-    output_class = SequenceClassifierOutput
-    output_key = "logits"
 
     @classmethod
     def _update_rbln_config(
optimum/rbln/transformers/models/__init__.py
CHANGED
@@ -31,6 +31,7 @@ _import_structure = {
         "RBLNAutoModelForSequenceClassification",
         "RBLNAutoModelForSpeechSeq2Seq",
         "RBLNAutoModelForVision2Seq",
+        "RBLNAutoModelForImageTextToText",
     ],
     "bart": [
         "RBLNBartForConditionalGeneration",
@@ -80,6 +81,12 @@ _import_structure = {
     ],
     "exaone": ["RBLNExaoneForCausalLM", "RBLNExaoneForCausalLMConfig"],
     "gemma": ["RBLNGemmaForCausalLM", "RBLNGemmaForCausalLMConfig"],
+    "gemma3": [
+        "RBLNGemma3ForCausalLM",
+        "RBLNGemma3ForCausalLMConfig",
+        "RBLNGemma3ForConditionalGeneration",
+        "RBLNGemma3ForConditionalGenerationConfig",
+    ],
     "gpt2": ["RBLNGPT2LMHeadModel", "RBLNGPT2LMHeadModelConfig"],
     "idefics3": [
         "RBLNIdefics3VisionTransformer",
@@ -121,6 +128,7 @@ if TYPE_CHECKING:
         RBLNAutoModelForCTC,
         RBLNAutoModelForDepthEstimation,
         RBLNAutoModelForImageClassification,
+        RBLNAutoModelForImageTextToText,
         RBLNAutoModelForMaskedLM,
         RBLNAutoModelForQuestionAnswering,
         RBLNAutoModelForSeq2SeqLM,
@@ -170,6 +178,12 @@ if TYPE_CHECKING:
     )
     from .exaone import RBLNExaoneForCausalLM, RBLNExaoneForCausalLMConfig
     from .gemma import RBLNGemmaForCausalLM, RBLNGemmaForCausalLMConfig
+    from .gemma3 import (
+        RBLNGemma3ForCausalLM,
+        RBLNGemma3ForCausalLMConfig,
+        RBLNGemma3ForConditionalGeneration,
+        RBLNGemma3ForConditionalGenerationConfig,
+    )
    from .gpt2 import RBLNGPT2LMHeadModel, RBLNGPT2LMHeadModelConfig
     from .idefics3 import (
         RBLNIdefics3ForConditionalGeneration,
optimum/rbln/transformers/models/auto/modeling_auto.py
CHANGED
@@ -23,6 +23,8 @@ from transformers.models.auto.modeling_auto import (
     MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES,
     MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
     MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
+    MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING,
+    MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
     MODEL_FOR_MASKED_LM_MAPPING,
     MODEL_FOR_MASKED_LM_MAPPING_NAMES,
     MODEL_FOR_QUESTION_ANSWERING_MAPPING,
@@ -90,6 +92,11 @@ class RBLNAutoModelForVision2Seq(_BaseAutoModelClass):
     _model_mapping_names = MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
 
 
+class RBLNAutoModelForImageTextToText(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING
+    _model_mapping_names = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
+
+
 class RBLNAutoModelForMaskedLM(_BaseAutoModelClass):
     _model_mapping = MODEL_FOR_MASKED_LM_MAPPING
     _model_mapping_names = MODEL_FOR_MASKED_LM_MAPPING_NAMES
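The `RBLNAutoModelForImageTextToText` class added here routes image-text-to-text checkpoints (such as the Gemma3 conditional-generation models introduced in this release) to their RBLN counterparts. A hedged usage sketch, assuming the usual optimum-style `from_pretrained(..., export=True)` compile-and-load flow; the checkpoint id, output directory, and the export flag are placeholders/assumptions, not taken from this diff:

```python
from optimum.rbln import RBLNAutoModelForImageTextToText

# Compile a HuggingFace image-text-to-text checkpoint for RBLN NPUs (model id is illustrative).
model = RBLNAutoModelForImageTextToText.from_pretrained(
    "google/gemma-3-4b-it",  # placeholder checkpoint id
    export=True,             # assumed optimum-style flag that triggers compilation
)
model.save_pretrained("gemma3-rbln")  # persists the compiled *.rbln artifacts
```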
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py
CHANGED
@@ -157,7 +157,11 @@ class DecoderOnlyWrapper(nn.Module):
         self.config = causal_lm.config
 
         if use_rotary_emb:
-
+            rotary_embs = self.get_rotary_emb(max_seq_len=max_seq_len)
+            if isinstance(rotary_embs, tuple):
+                self.rotary_emb_global, self.rotary_emb_local = rotary_embs
+            else:
+                self.rotary_emb = rotary_embs
         else:
             self.rotary_emb = None
 
@@ -195,7 +199,10 @@ class DecoderOnlyWrapper(nn.Module):
         for layer in causal_lm.model.layers:
             if self.attn_impl == "eager":
                 new_self_attn = DecoderOnlyAttention(
-                    layer.self_attn,
+                    layer.self_attn,
+                    self.use_attention_mask,
+                    self.use_position_ids,
+                    kvcache_block_size=self.kvcache_block_size,
                 )
             elif self.attn_impl == "flash_attn":
                 new_self_attn = DecoderOnlyFlashAttention(
@@ -203,6 +210,7 @@ class DecoderOnlyWrapper(nn.Module):
                     kvcache_partition_len=self.kvcache_partition_len,
                     kvcache_block_size=self.kvcache_block_size,
                     use_attention_mask=self.use_attention_mask,
+                    use_position_ids=self.use_position_ids,
                 )
             else:
                 raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
@@ -363,6 +371,13 @@ class DecoderOnlyForCausalLM(nn.Module):
             hidden_states = hidden_states[:, query_position.to(torch.int).unsqueeze(0)]
 
         logits = self.lm_head(hidden_states)
+
+        # Apply final logit softmaxing if configured, e.g. for Gemma2
+        if getattr(self.config, "final_logit_softcapping", None) is not None:
+            logits = logits / self.config.final_logit_softcapping
+            logits = torch.tanh(logits)
+            logits = logits * self.config.final_logit_softcapping
+
         return logits
 
 
@@ -610,7 +625,7 @@ class DecoderOnlyAttention(nn.Module):
         self_attn: Original attention module from the base model
     """
 
-    def __init__(self, self_attn, use_attention_mask, kvcache_block_size):
+    def __init__(self, self_attn, use_attention_mask, use_position_ids, kvcache_block_size):
         super().__init__()
         self._original_mod = self_attn
         self.layer_idx = self_attn.layer_idx
@@ -629,6 +644,7 @@ class DecoderOnlyAttention(nn.Module):
             self.num_key_value_heads = self.num_heads
 
         self.use_attention_mask = use_attention_mask
+        self.use_position_ids = use_position_ids
         self.attention = self.get_attention()
         self.kvcache_block_size = kvcache_block_size
         self.__post_init__()
@@ -643,7 +659,9 @@ class DecoderOnlyAttention(nn.Module):
         self.attention.phase = phase
 
     def get_attention(self):
-        return AttentionOp(
+        return AttentionOp(
+            self.num_heads, self.head_dim, self.num_key_value_heads, self.use_attention_mask, self.use_position_ids
+        )
 
     def __post_init__(self):
         self.q_proj = self._original_mod.q_proj
@@ -716,13 +734,16 @@ class DecoderOnlyAttention(nn.Module):
 
 
 class AttentionOp(nn.Module):
-    def __init__(
+    def __init__(
+        self, num_heads: int, head_dim: int, num_key_value_heads: int, use_attention_mask: bool, use_position_ids: bool
+    ):
         super().__init__()
         self.num_heads = num_heads
         self.head_dim = head_dim
         self.num_key_value_heads = num_key_value_heads
         self.phase = "prefill"
         self.use_attention_mask = use_attention_mask
+        self.use_position_ids = use_position_ids
 
     def forward(
         self,
@@ -755,7 +776,8 @@ class AttentionOp(nn.Module):
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)  # 1, 32, 1, 128, 128
         value_state = value_state.unsqueeze(2)
-
+
+        if self.use_attention_mask and not self.use_position_ids:
             attn_mask = attn_mask.unsqueeze(2)
 
         if self.phase == "decode":
@@ -772,7 +794,7 @@ class AttentionOp(nn.Module):
             )
 
         if self.phase == "decode":
-            if self.use_attention_mask:
+            if self.use_attention_mask and not self.use_position_ids:
                 attn_output = torch.ops.rbln_custom_ops.paged_attn_decode(
                     q=query_state,
                     k=key_state,
@@ -796,11 +818,11 @@ class AttentionOp(nn.Module):
                     scale=scale,
                     block_table=block_tables,
                     block_size=block_size,
-                    mask=None,
+                    mask=attn_mask if self.use_position_ids else None,
                 )
 
         else:
-            if self.use_attention_mask:
+            if self.use_attention_mask and not self.use_position_ids:
                 attn_output = torch.ops.rbln_custom_ops.paged_attn_prefill(
                     q=query_state,
                     k=key_state,
@@ -824,8 +846,8 @@ class AttentionOp(nn.Module):
                     scale=scale,
                     block_table=block_tables,
                     block_size=block_size,
-                    is_bidirectional=False,
-                    mask=None,
+                    is_bidirectional=True if self.phase == "image_prefill" else False,  # FIXME, Hard-coded for Gemma3.
+                    mask=attn_mask if self.use_position_ids else None,
                 )
 
         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
@@ -927,10 +949,13 @@ class RotaryEmbedding(nn.Module):
 
 
 class DecoderOnlyFlashAttention(DecoderOnlyAttention):
-    def __init__(self, self_attn, kvcache_partition_len, kvcache_block_size, use_attention_mask):
+    def __init__(self, self_attn, kvcache_partition_len, kvcache_block_size, use_attention_mask, use_position_ids):
         self.kvcache_partition_size = kvcache_partition_len
         super().__init__(
-            self_attn=self_attn,
+            self_attn=self_attn,
+            use_attention_mask=use_attention_mask,
+            use_position_ids=use_position_ids,
+            kvcache_block_size=kvcache_block_size,
         )
 
     def get_attention(self):
@@ -940,6 +965,7 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
             self.num_key_value_heads,
             self.kvcache_partition_size,
             self.use_attention_mask,
+            self.use_position_ids,
         )
 
     def forward(
@@ -991,12 +1017,14 @@ class FlashAttentionOp(AttentionOp):
         num_key_value_heads: int,
         kvcache_partition_len: int,
         use_attention_mask: bool,
+        use_position_ids: bool,
     ):
         super().__init__(
             num_heads=num_heads,
             head_dim=head_dim,
             num_key_value_heads=num_key_value_heads,
            use_attention_mask=use_attention_mask,
+            use_position_ids=use_position_ids,
        )
         self.kvcache_partition_size = kvcache_partition_len
 
@@ -1016,7 +1044,7 @@ class FlashAttentionOp(AttentionOp):
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)
         value_state = value_state.unsqueeze(2)
-        if self.use_attention_mask:
+        if self.use_attention_mask and not self.use_position_ids:
             attn_mask = attn_mask.unsqueeze(2)
 
         if self.phase == "decode":
@@ -1033,7 +1061,7 @@ class FlashAttentionOp(AttentionOp):
             )
 
         if self.phase == "decode":
-            if self.use_attention_mask:
+            if self.use_attention_mask and not self.use_position_ids:
                 attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_decode(
                     q=query_state,
                     k=key_state,
@@ -1059,10 +1087,10 @@ class FlashAttentionOp(AttentionOp):
                     block_table=block_tables,
                     block_size=kvcache_block_size,
                     partition=self.kvcache_partition_size,
-                    mask=None,
+                    mask=attn_mask if self.use_position_ids else None,
                 )
         else:
-            if self.use_attention_mask:
+            if self.use_attention_mask and not self.use_position_ids:
                 attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_prefill(
                     q=query_state,
                     k=key_state,
@@ -1088,8 +1116,8 @@ class FlashAttentionOp(AttentionOp):
                     block_table=block_tables,
                     block_size=kvcache_block_size,
                     partition=self.kvcache_partition_size,
-                    is_bidirectional=False,
-                    mask=None,
+                    is_bidirectional=True if self.phase == "image_prefill" else False,
+                    mask=attn_mask if self.use_position_ids else None,
                 )
 
         # reshape for removing repeat_kv
@@ -1098,3 +1126,70 @@ class FlashAttentionOp(AttentionOp):
         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
 
         return attn_output
+
+
+class SlidingWindowAttentionOp(AttentionOp):
+    def forward(
+        self,
+        query_state: torch.Tensor,
+        key_state: torch.Tensor,
+        value_state: torch.Tensor,
+        attn_mask: torch.Tensor,
+        past_key_state: torch.Tensor,
+        past_value_state: torch.Tensor,
+        seq_position: Tuple[torch.Tensor],
+        scale: torch.Tensor,
+        block_tables: torch.Tensor,
+        block_size: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
+        key_state = key_state.unsqueeze(2)
+        value_state = value_state.unsqueeze(2)
+
+        if self.phase == "decode":
+            batch_size = key_state.shape[0]
+        else:
+            batch_size = 1
+
+        query_state = query_state.view(
+            batch_size,
+            self.num_key_value_heads,
+            self.num_heads // self.num_key_value_heads,
+            -1,  # seq len
+            self.head_dim,
+        )
+
+        if self.phase == "decode":
+            attn_output = torch.ops.rbln_custom_ops.paged_sliding_window_attn_decode(
+                q=query_state,
+                k=key_state,
+                v=value_state,
+                kcache=past_key_state.unsqueeze(2),
+                vcache=past_value_state.unsqueeze(2),
+                cache_seq_len=seq_position[0],
+                cache_offset=seq_position[1],
+                scale=scale,
+                block_table=block_tables,
+                block_size=block_size,
+            )
+        else:
+            attn_output = torch.ops.rbln_custom_ops.paged_sliding_window_attn_prefill(
+                q=query_state,
+                k=key_state,
+                v=value_state,
+                kcache=past_key_state.unsqueeze(2),
+                vcache=past_value_state.unsqueeze(2),
+                cache_seq_len=seq_position[0],
+                cache_offset=seq_position[1],
+                scale=scale,
+                block_table=block_tables,
+                block_size=block_size,
+                is_bidirectional=True if self.phase == "image_prefill" else False,
+            )
+
+        # reshape for removing repeat_kv
+        attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
+
+        return attn_output