optimum-rbln 0.9.4a2__py3-none-any.whl → 0.10.0.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +44 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +230 -67
- optimum/rbln/diffusers/models/controlnet.py +2 -2
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +2 -2
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +2 -2
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -2
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -3
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +3 -12
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -3
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +2 -2
- optimum/rbln/modeling_base.py +11 -10
- optimum/rbln/ops/__init__.py +1 -0
- optimum/rbln/ops/attn.py +10 -0
- optimum/rbln/ops/flash_attn.py +8 -0
- optimum/rbln/ops/moe.py +180 -0
- optimum/rbln/ops/sliding_window_attn.py +9 -0
- optimum/rbln/transformers/__init__.py +44 -0
- optimum/rbln/transformers/modeling_attention_utils.py +124 -222
- optimum/rbln/transformers/modeling_outputs.py +25 -0
- optimum/rbln/transformers/modeling_rope_utils.py +78 -42
- optimum/rbln/transformers/models/__init__.py +38 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
- optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +7 -2
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -1
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -182
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +40 -23
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +144 -17
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +122 -48
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +5 -7
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +120 -128
- optimum/rbln/transformers/models/detr/__init__.py +23 -0
- optimum/rbln/transformers/models/detr/configuration_detr.py +38 -0
- optimum/rbln/transformers/models/detr/modeling_detr.py +53 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
- optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
- optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
- optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +2 -7
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +16 -18
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +5 -177
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
- optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +42 -0
- optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
- optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +168 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +6 -4
- optimum/rbln/transformers/models/llava/modeling_llava.py +0 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
- optimum/rbln/transformers/models/mixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/mixtral/configuration_mixtral.py +38 -0
- optimum/rbln/transformers/models/mixtral/mixtral_architecture.py +76 -0
- optimum/rbln/transformers/models/mixtral/modeling_mixtral.py +68 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
- optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
- optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
- optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
- optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +9 -5
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +13 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +271 -122
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
- optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
- optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
- optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +13 -1
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +263 -105
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +26 -34
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
- optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
- optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
- optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +10 -4
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +4 -18
- optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
- optimum/rbln/transformers/models/t5/t5_architecture.py +15 -16
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
- optimum/rbln/transformers/models/whisper/generation_whisper.py +8 -8
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
- optimum/rbln/transformers/utils/rbln_quantization.py +20 -12
- optimum/rbln/utils/deprecation.py +78 -1
- optimum/rbln/utils/hub.py +93 -2
- optimum/rbln/utils/import_utils.py +16 -1
- optimum/rbln/utils/runtime_utils.py +12 -8
- optimum/rbln/utils/submodule.py +24 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/METADATA +6 -6
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/RECORD +107 -81
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/diffusers/pipelines/auto_pipeline.py
CHANGED

@@ -176,7 +176,7 @@ class RBLNAutoPipelineBase:
         export: bool = None,
         rbln_config: Optional[Union[Dict[str, Any], RBLNModelConfig]] = None,
         **kwargs: Any,
-    ):
+    ) -> RBLNBaseModel:
         """
         Load an RBLN-accelerated Diffusers pipeline from a pretrained checkpoint or a compiled RBLN artifact.

@@ -201,8 +201,7 @@ class RBLNAutoPipelineBase:
             - Remaining arguments are forwarded to the Diffusers loader.

         Returns:
-            RBLNBaseModel: An instantiated RBLN model wrapping the Diffusers pipeline, ready for
-                inference on RBLN NPUs.
+            RBLNBaseModel: An instantiated RBLN model wrapping the Diffusers pipeline, ready for inference on RBLN NPUs.

         """
         rbln_cls = cls.get_rbln_cls(model_id, export=export, **kwargs)
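The `from_pretrained` entry point above now declares its return type. A minimal usage sketch (the concrete auto-pipeline class and checkpoint id are illustrative assumptions, not taken from this diff):

```python
# Hypothetical usage of the annotated loader; the class and checkpoint
# names below are assumptions for illustration only.
from optimum.rbln import RBLNAutoPipelineForText2Image

# export=True compiles the Diffusers checkpoint for RBLN NPUs; the call
# is now annotated to return an RBLNBaseModel subclass.
pipe = RBLNAutoPipelineForText2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    export=True,
)
pipe.save_pretrained("sd15-rbln")  # reload the compiled artifact later with export=False
```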
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py
CHANGED

@@ -26,7 +26,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F

@@ -260,7 +260,7 @@ class RBLNStableDiffusionControlNetPipeline(RBLNDiffusionMixin, StableDiffusionC
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
         **kwargs,
-    ):
+    ) -> Union[StableDiffusionPipelineOutput, Tuple]:
         r"""
         The call function to the pipeline for generation.

@@ -321,14 +321,7 @@ class RBLNStableDiffusionControlNetPipeline(RBLNDiffusionMixin, StableDiffusionC
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
-                plain tuple.
-            callback (`Callable`, *optional*):
-                A function that calls every `callback_steps` steps during inference. The function is called with the
-                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
-            callback_steps (`int`, *optional*, defaults to 1):
-                The frequency at which the `callback` function is called. If not specified, the callback is called at
-                every step.
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple.
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                 [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).

@@ -356,8 +349,6 @@ class RBLNStableDiffusionControlNetPipeline(RBLNDiffusionMixin, StableDiffusionC
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeine class.

-        Examples:
-
         Returns:
             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                 If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
CHANGED

@@ -26,7 +26,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F

@@ -253,7 +253,7 @@ class RBLNStableDiffusionControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableDif
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
         **kwargs,
-    ):
+    ) -> Union[StableDiffusionPipelineOutput, Tuple]:
         r"""
         The call function to the pipeline for generation.

@@ -347,8 +347,6 @@ class RBLNStableDiffusionControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableDif
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeine class.

-        Examples:
-
         Returns:
             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                 If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
CHANGED

@@ -294,7 +294,7 @@ class RBLNStableDiffusionXLControlNetPipeline(RBLNDiffusionMixin, StableDiffusio
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
         **kwargs,
-    ):
+    ) -> Union[StableDiffusionXLPipelineOutput, Tuple]:
         r"""
         The call function to the pipeline for generation.

@@ -431,8 +431,6 @@ class RBLNStableDiffusionXLControlNetPipeline(RBLNDiffusionMixin, StableDiffusio
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeine class.

-        Examples:
-
         Returns:
             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                 If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
CHANGED

@@ -309,7 +309,7 @@ class RBLNStableDiffusionXLControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableD
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
         **kwargs,
-    ):
+    ) -> Union[StableDiffusionXLPipelineOutput, Tuple]:
         r"""
         Function invoked when calling the pipeline for generation.

@@ -465,8 +465,6 @@ class RBLNStableDiffusionXLControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableD
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeine class.

-        Examples:
-
         Returns:
             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                 [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`
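All four ControlNet pipelines now annotate `__call__` with the standard Diffusers output contract. A sketch of both return shapes (pipeline construction elided; `pipe` and `control_image` are assumed to already exist):

```python
# `pipe` is assumed to be a loaded RBLNStableDiffusionControlNetPipeline.
out = pipe("a photo of a cat", image=control_image)
images = out.images  # StableDiffusionPipelineOutput when return_dict=True (default)

# With return_dict=False the pipeline returns a plain tuple instead.
images, nsfw_flags = pipe("a photo of a cat", image=control_image, return_dict=False)
```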
optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py
CHANGED

@@ -203,7 +203,7 @@ class RBLNRetinaFaceFilter(RetinaFaceFilter):
                 f"If you only need to compile the model without loading it to NPU, you can use:\n"
                 f"  from_pretrained(..., rbln_create_runtimes=False) or\n"
                 f"  from_pretrained(..., rbln_config={{..., 'create_runtimes': False}})\n\n"
-                f"To check your NPU status, run the 'rbln-
+                f"To check your NPU status, run the 'rbln-smi' command in your terminal.\n"
                 f"Make sure your NPU is properly installed and operational."
             )
             raise rebel.core.exception.RBLNRuntimeError(error_msg) from e

@@ -278,7 +278,7 @@ class RBLNVideoSafetyModel(VideoSafetyModel):
                 f"If you only need to compile the model without loading it to NPU, you can use:\n"
                 f"  from_pretrained(..., rbln_create_runtimes=False) or\n"
                 f"  from_pretrained(..., rbln_config={{..., 'create_runtimes': False}})\n\n"
-                f"To check your NPU status, run the 'rbln-
+                f"To check your NPU status, run the 'rbln-smi' command in your terminal.\n"
                 f"Make sure your NPU is properly installed and operational."
             )
             raise rebel.core.exception.RBLNRuntimeError(error_msg) from e
optimum/rbln/modeling_base.py
CHANGED
@@ -24,7 +24,7 @@ import torch
 from transformers import AutoConfig, AutoModel, GenerationConfig, PretrainedConfig
 from transformers.utils.hub import PushToHubMixin

-from .configuration_utils import
+from .configuration_utils import RBLNCompileConfig, RBLNModelConfig, get_rbln_config_class
 from .utils.hub import pull_compiled_model_from_hub, validate_files
 from .utils.logging import get_logger
 from .utils.runtime_utils import UnavailableRuntime, tp_and_devices_are_ok

@@ -90,7 +90,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):

         self.device = torch.device("cpu")
         self.training = False
-        self.dtype = rbln_config.
+        self.dtype = rbln_config.dtype

         # FIXME :: model_save_dir is not used after initialized. (This can be used when save/load)
         # This attribute is needed to keep one reference on the temporary directory, since garbage collecting it

@@ -206,8 +206,9 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
                 f"does not match the expected model class name ({cls.__name__})."
             )

-
-
+        config_cls = cls.get_rbln_config_class()
+        rbln_config, kwargs = config_cls.from_pretrained(
+            model_path_subfolder, rbln_config=rbln_config, return_unused_kwargs=True, **kwargs
         )

         if rbln_config.rbln_model_cls_name != cls.__name__:

@@ -223,8 +224,6 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         elif rbln_submodules is None:
             rbln_submodules = []

-        rbln_config.freeze()
-
         if config is None:
             if cls.hf_library_name == "transformers":
                 config = AutoConfig.from_pretrained(

@@ -308,11 +307,13 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
                     f"If you only need to compile the model without loading it to NPU, you can use:\n"
                     f"  from_pretrained(..., rbln_create_runtimes=False) or\n"
                     f"  from_pretrained(..., rbln_config={{..., 'create_runtimes': False}})\n\n"
-                    f"To check your NPU status, run the 'rbln-
+                    f"To check your NPU status, run the 'rbln-smi' command in your terminal.\n"
                     f"Make sure your NPU is properly installed and operational."
                 )
                 raise rebel.core.exception.RBLNRuntimeError(error_msg) from e

+        rbln_config.freeze()
+
         return cls(
             models,
             config,

@@ -451,15 +452,15 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         model_config: "PretrainedConfig",
         rbln_config: RBLNModelConfig,
     ) -> RBLNModelConfig:
-        rbln_config.
-        if not cls._supports_non_fp32 and rbln_config.
+        rbln_config.dtype = model.dtype
+        if not cls._supports_non_fp32 and rbln_config.dtype != torch.float32:
             raise NotImplementedError(
                 f"Currently, {cls.__name__} does not support non-fp32 dtype. Please use float32 dtype."
             )
         rbln_config = cls._update_rbln_config(
             preprocessors=preprocessors, model=model, model_config=model_config, rbln_config=rbln_config
         )
-
+
         if rbln_config.rbln_model_cls_name != cls.__name__:
             raise NameError(
                 f"Cannot get the rbln config. {cls.__name__} is not the same as {rbln_config.rbln_model_cls_name}. "
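Two behavioral changes stand out here: `rbln_config.freeze()` now runs only after runtimes are created, and the model dtype is plumbed through `rbln_config.dtype` into `self.dtype`. A sketch of how this surfaces to callers (the checkpoint id is illustrative; the `create_runtimes` key is the one shown in the error message above):

```python
import torch
from optimum.rbln import RBLNLlamaForCausalLM  # one of the package's concrete classes

model = RBLNLlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # illustrative checkpoint
    export=True,
    rbln_config={"create_runtimes": False},  # compile without loading onto an NPU
)
# self.dtype now mirrors rbln_config.dtype, which is populated from the
# source model during compilation (fp32 unless the class supports non-fp32).
assert model.dtype == torch.float32
```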
optimum/rbln/ops/__init__.py
CHANGED
optimum/rbln/ops/attn.py
CHANGED
@@ -205,6 +205,7 @@ def paged_causal_attn_decode(
     block_table: Tensor,
     block_size: int,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     """Defines the computation pattern for fused attention with KV cache updates.

@@ -228,6 +229,7 @@ def paged_causal_attn_decode(
     - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
     - block_size: [] - Number of tokens per block
     - mask: [batch=1, max_seq_len] - attention mask when use position_ids
+    - s_aux: [num_attention_heads, sink_len] - auxiliary states for attention

     Returns:
         Tensor: attn_output: [batch=1, n_heads, n_groups, 1, head_dim] - Attention output

@@ -247,6 +249,7 @@ def paged_causal_attn_decode_fake(
     block_table: Tensor,
     block_size: int,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)

@@ -267,6 +270,7 @@ def paged_causal_attn_prefill(
     block_size: int,
     is_bidirectional: bool,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     """Defines the computation pattern for prefill phase attention with KV cache updates.

@@ -290,6 +294,7 @@ def paged_causal_attn_prefill(
     - block_size: [] - Number of tokens per block
     - is_bidirectional: [] - Whether the attention is bidirectional at current sequence position
     - mask: [batch=1, max_seq_len] - attention mask when use position_ids
+    - s_aux: [num_attention_heads, sink_len] - auxiliary states for attention

     Returns:
         Tensor: attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output

@@ -310,6 +315,7 @@ def paged_causal_attn_prefill_fake(
     block_size: int,
     is_bidirectional: bool,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)

@@ -331,6 +337,7 @@ def paged_causal_attn_decode_kv_fp8(
     k_scale: Tensor,
     v_scale: Tensor,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)

@@ -349,6 +356,7 @@ def paged_causal_attn_decode_kv_fp8_fake(
     k_scale: Tensor,
     v_scale: Tensor,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)

@@ -371,6 +379,7 @@ def paged_causal_attn_prefill_kv_fp8(
     k_scale: Tensor,
     v_scale: Tensor,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)

@@ -390,6 +399,7 @@ def paged_causal_attn_prefill_kv_fp8_fake(
     k_scale: Tensor,
     v_scale: Tensor,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
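Each attention pattern above (and its flash and sliding-window counterparts below) gains an optional `s_aux` input of shape [num_attention_heads, sink_len]: per-head "attention sink" logits of the kind popularized by GPT-OSS-style models. The fused semantics are defined inside the RBLN compiler, but a minimal eager-PyTorch sketch of the common sink formulation (an assumption about intent, not code from this package) is:

```python
import torch

def sdpa_with_sinks(q, k, v, s_aux, scale):
    # q: [n_heads, q_len, head_dim]; k, v: [n_heads, kv_len, head_dim]
    # s_aux: [n_heads, sink_len] - logits that join the softmax normalization
    # but contribute no value vector, letting a head attend "nowhere".
    scores = torch.matmul(q, k.transpose(-1, -2)) * scale        # [h, q, kv]
    sinks = s_aux[:, None, :].expand(-1, scores.shape[-2], -1)   # [h, q, sink]
    probs = torch.softmax(torch.cat([scores, sinks], dim=-1), dim=-1)
    probs = probs[..., : scores.shape[-1]]  # drop sink columns; rows may sum to < 1
    return torch.matmul(probs, v)
```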
optimum/rbln/ops/flash_attn.py
CHANGED
@@ -198,6 +198,7 @@ def paged_flash_causal_attn_decode(
     block_size: int,
     partition: int,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     """Defines the computation pattern for fused causal flash attention with KV cache for decoding.

@@ -219,6 +220,7 @@ def paged_flash_causal_attn_decode_fake(
     block_size: int,
     partition: int,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)

@@ -241,6 +243,7 @@ def paged_flash_causal_attn_decode_kv_fp8(
     k_scale: Tensor,
     v_scale: Tensor,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)

@@ -260,6 +263,7 @@ def paged_flash_causal_attn_decode_kv_fp8_fake(
     k_scale: Tensor,
     v_scale: Tensor,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)

@@ -281,6 +285,7 @@ def paged_flash_causal_attn_prefill(
     partition: int,
     is_bidirectional: bool,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     """Defines the computation pattern for fused causal flash attention with KV cache for prefill.

@@ -303,6 +308,7 @@ def paged_flash_causal_attn_prefill_fake(
     partition: int,
     is_bidirectional: bool,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)

@@ -326,6 +332,7 @@ def paged_flash_causal_attn_prefill_kv_fp8(
     k_scale: Tensor,
     v_scale: Tensor,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)

@@ -346,5 +353,6 @@ def paged_flash_causal_attn_prefill_kv_fp8_fake(
     k_scale: Tensor,
     v_scale: Tensor,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
optimum/rbln/ops/moe.py
ADDED
@@ -0,0 +1,180 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+from torch import Tensor
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::custom_moe_glu",
+    mutates_args=(),
+)
+def custom_moe_glu(
+    hidden_states: Tensor,
+    gate_proj_weight: Tensor,
+    up_proj_weight: Tensor,
+    down_proj_weight: Tensor,
+    router_logits: Tensor,
+    topk: int,
+    norm_topk_prob: bool,
+    gate_proj_bias: Optional[Tensor] = None,
+    up_proj_bias: Optional[Tensor] = None,
+    down_proj_bias: Optional[Tensor] = None,
+) -> Tensor:
+    """
+    Customized MoE GLU operation.
+
+    Expected tensor shapes:
+        - hidden_states: [batch*seq_len, hidden_size]
+        - gate_proj_weight: [num_experts, hidden_size, intermediate_size]
+        - up_proj_weight: [num_experts, hidden_size, intermediate_size]
+        - down_proj_weight: [num_experts, intermediate_size, hidden_size]
+        - router_logits: [batch*seq_len, num_experts]
+        - topk: top k experts to select
+        - norm_topk_prob: whether to normalize the top k routing weights with softmax
+        - gate_proj_bias: [num_experts, intermediate_size]
+        - up_proj_bias: [num_experts, intermediate_size]
+        - down_proj_bias: [num_experts, hidden_size]
+
+    Returns:
+        Tensor: [batch * seq_len, hidden_size]
+    """
+
+    return torch.empty_like(hidden_states)
+
+
+@custom_moe_glu.register_fake
+def custom_moe_glu_fake(
+    hidden_states: Tensor,
+    gate_proj_weight: Tensor,
+    up_proj_weight: Tensor,
+    down_proj_weight: Tensor,
+    router_logits: Tensor,
+    topk: int,
+    norm_topk_prob: bool,
+    gate_proj_bias: Optional[Tensor] = None,
+    up_proj_bias: Optional[Tensor] = None,
+    down_proj_bias: Optional[Tensor] = None,
+) -> Tensor:
+    return torch.empty_like(hidden_states)
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::custom_moe_ff",
+    mutates_args=(),
+)
+def custom_moe_ff(
+    hidden_states: Tensor,
+    gate_proj_weight: Tensor,
+    down_proj_weight: Tensor,
+    masked_routing_weight: Tensor,
+    gate_proj_bias: Optional[Tensor] = None,
+    down_proj_bias: Optional[Tensor] = None,
+) -> Tensor:
+    """
+    Customized MoE FF operation.
+
+    Expected tensor shapes:
+        - hidden_states: [batch * seq_len, hidden_size]
+        - gate_proj_weight: [hidden_size, num_experts * intermediate_size]
+        - down_proj_weight: [num_experts * intermediate_size, hidden_size]
+        - masked_routing_weight: [batch * seq_len, num_experts]
+        - gate_proj_bias: [num_experts * intermediate_size]
+        - down_proj_bias: [hidden_size]
+
+    Returns:
+        Tensor: [batch * seq_len, hidden_size]
+    """
+    return torch.empty_like(hidden_states)
+
+
+@custom_moe_ff.register_fake
+def custom_moe_ff_fake(
+    hidden_states: Tensor,
+    gate_proj_weight: Tensor,
+    down_proj_weight: Tensor,
+    masked_routing_weight: Tensor,
+    gate_proj_bias: Optional[Tensor] = None,
+    down_proj_bias: Optional[Tensor] = None,
+) -> Tensor:
+    return torch.empty_like(hidden_states)
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::custom_moe_glu_mxfp4",
+    mutates_args=(),
+)
+def custom_moe_glu_mxfp4(
+    hidden_states: Tensor,
+    gate_proj_blocks: Tensor,
+    gate_proj_scales: Tensor,
+    gate_proj_bias: Tensor,
+    up_proj_blocks: Tensor,
+    up_proj_scales: Tensor,
+    up_proj_bias: Tensor,
+    down_proj_blocks: Tensor,
+    down_proj_scales: Tensor,
+    down_proj_bias: Tensor,
+    router_logits: Tensor,
+    alpha: Tensor,
+    limit: Tensor,
+    k: int,
+    post_norm: bool,
+) -> Tensor:
+    """
+    Customized MoE GLU operation.
+
+    Expected tensor shapes:
+        - hidden_states: [batch*seq_len, hidden_size]
+        - gate_proj_blocks: [num_experts, intermediate_size, hidden_size // 2]
+        - gate_proj_scales: [num_experts, intermediate_size, hidden_size // 32]
+        - gate_proj_bias: [num_experts, intermediate_size]
+        - up_proj_blocks: [num_experts, intermediate_size, hidden_size // 2]
+        - up_proj_scales: [num_experts, intermediate_size, hidden_size // 32]
+        - up_proj_bias: [num_experts, intermediate_size]
+        - down_proj_blocks: [num_experts, hidden_size, intermediate_size // 2]
+        - down_proj_scales: [num_experts, hidden_size, intermediate_size // 32]
+        - masked_routing_weight: [batch * seq_len, num_experts]
+        - expert_select_count: [num_experts]
+        - alpha: []
+        - limit: []
+
+    Returns:
+        Tensor: [batch * seq_len, hidden_size]
+    """
+
+    return torch.empty_like(hidden_states)
+
+
+@custom_moe_glu_mxfp4.register_fake
+def custom_moe_glu_mxfp4_fake(
+    hidden_states: Tensor,
+    gate_proj_blocks: Tensor,
+    gate_proj_scales: Tensor,
+    gate_proj_bias: Tensor,
+    up_proj_blocks: Tensor,
+    up_proj_scales: Tensor,
+    up_proj_bias: Tensor,
+    down_proj_blocks: Tensor,
+    down_proj_scales: Tensor,
+    down_proj_bias: Tensor,
+    router_logits: Tensor,
+    alpha: Tensor,
+    limit: Tensor,
+    k: int,
+    post_norm: bool,
+) -> Tensor:
+    return torch.empty_like(hidden_states)
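These ops are graph-pattern stubs: the eager bodies just return `empty_like`, and the RBLN compiler recognizes and lowers the pattern. For reference, a dense PyTorch sketch of what `custom_moe_glu` describes, under the documented shapes and assuming Mixtral-style top-k routing with a SiLU-gated GLU (the activation is not pinned down by the stub, so that part is an assumption):

```python
import torch
import torch.nn.functional as F

def moe_glu_reference(hidden_states, gate_w, up_w, down_w,
                      router_logits, topk, norm_topk_prob):
    # hidden_states: [T, H]; gate_w/up_w: [E, H, I]; down_w: [E, I, H];
    # router_logits: [T, E]. Optional biases omitted for brevity.
    probs = F.softmax(router_logits, dim=-1)                 # [T, E]
    weights, idx = torch.topk(probs, topk, dim=-1)           # [T, k]
    if norm_topk_prob:
        weights = weights / weights.sum(dim=-1, keepdim=True)
    routing = torch.zeros_like(probs).scatter_(1, idx, weights)   # [T, E]

    # Run every expert densely, then mix by routing weight; a fused
    # kernel would only evaluate the selected experts.
    gate = torch.einsum("th,ehi->eti", hidden_states, gate_w)     # [E, T, I]
    up = torch.einsum("th,ehi->eti", hidden_states, up_w)
    expert_out = torch.einsum("eti,eih->eth", F.silu(gate) * up, down_w)
    return torch.einsum("te,eth->th", routing, expert_out)        # [T, H]
```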
optimum/rbln/ops/sliding_window_attn.py
CHANGED

@@ -13,6 +13,8 @@
 # limitations under the License.


+from typing import Optional
+
 import torch
 from torch import Tensor

@@ -33,6 +35,7 @@ def paged_sliding_window_attn_prefill(
     block_table: Tensor,
     block_size: int,
     is_bidirectional: bool,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     """Defines the computation pattern for prefill phase attention with KV cache updates.

@@ -53,6 +56,7 @@ def paged_sliding_window_attn_prefill(
     - cache_offset: [] - The valid length in the combined sequence of the KV cache and the current projected key states.
     - scale: [] - Attention scale factor
     - is_bidirectional: [] - Whether the attention is bidirectional
+    - s_aux: [num_attention_heads, sink_len] - auxiliary states for attention
     Returns:
         Tensor: attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
     """

@@ -72,6 +76,7 @@ def paged_sliding_window_attn_prefill_fake(
     block_table: Tensor,
     block_size: int,
     is_bidirectional: bool,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)

@@ -91,6 +96,8 @@ def paged_sliding_window_attn_decode(
     scale: Tensor,
     block_table: Tensor,
     block_size: int,
+    attn_mask: Tensor,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)

@@ -107,5 +114,7 @@ def paged_sliding_window_attn_decode_fake(
     scale: Tensor,
     block_table: Tensor,
     block_size: int,
+    attn_mask: Tensor,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)