optimum-rbln 0.9.4a2__py3-none-any.whl → 0.10.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. optimum/rbln/__init__.py +44 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +230 -67
  4. optimum/rbln/diffusers/models/controlnet.py +2 -2
  5. optimum/rbln/diffusers/models/transformers/prior_transformer.py +2 -2
  6. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +2 -2
  7. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -2
  8. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -3
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +3 -12
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -4
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -3
  12. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  13. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +2 -2
  14. optimum/rbln/modeling_base.py +11 -10
  15. optimum/rbln/ops/__init__.py +1 -0
  16. optimum/rbln/ops/attn.py +10 -0
  17. optimum/rbln/ops/flash_attn.py +8 -0
  18. optimum/rbln/ops/moe.py +180 -0
  19. optimum/rbln/ops/sliding_window_attn.py +9 -0
  20. optimum/rbln/transformers/__init__.py +44 -0
  21. optimum/rbln/transformers/modeling_attention_utils.py +124 -222
  22. optimum/rbln/transformers/modeling_outputs.py +25 -0
  23. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  24. optimum/rbln/transformers/models/__init__.py +38 -0
  25. optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
  26. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  27. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +7 -2
  28. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -1
  29. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  30. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  31. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -182
  32. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +40 -23
  33. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  34. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  35. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +144 -17
  36. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
  37. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +122 -48
  38. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +5 -7
  39. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +120 -128
  40. optimum/rbln/transformers/models/detr/__init__.py +23 -0
  41. optimum/rbln/transformers/models/detr/configuration_detr.py +38 -0
  42. optimum/rbln/transformers/models/detr/modeling_detr.py +53 -0
  43. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  44. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  45. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  46. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  47. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  48. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  49. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +2 -7
  50. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +16 -18
  51. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +5 -177
  52. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  53. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  54. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +42 -0
  55. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  56. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +168 -0
  57. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  58. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +6 -4
  59. optimum/rbln/transformers/models/llava/modeling_llava.py +0 -1
  60. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  61. optimum/rbln/transformers/models/mixtral/__init__.py +16 -0
  62. optimum/rbln/transformers/models/mixtral/configuration_mixtral.py +38 -0
  63. optimum/rbln/transformers/models/mixtral/mixtral_architecture.py +76 -0
  64. optimum/rbln/transformers/models/mixtral/modeling_mixtral.py +68 -0
  65. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  66. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  67. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  68. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  69. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  70. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  71. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +9 -5
  72. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  73. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +13 -1
  74. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +271 -122
  75. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  76. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  77. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  78. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  79. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  80. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  81. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +13 -1
  82. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +263 -105
  83. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +26 -34
  84. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  85. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  86. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  87. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  88. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  89. optimum/rbln/transformers/models/resnet/configuration_resnet.py +10 -4
  90. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  91. optimum/rbln/transformers/models/siglip/modeling_siglip.py +4 -18
  92. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  93. optimum/rbln/transformers/models/t5/t5_architecture.py +15 -16
  94. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  95. optimum/rbln/transformers/models/whisper/generation_whisper.py +8 -8
  96. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  97. optimum/rbln/transformers/utils/rbln_quantization.py +20 -12
  98. optimum/rbln/utils/deprecation.py +78 -1
  99. optimum/rbln/utils/hub.py +93 -2
  100. optimum/rbln/utils/import_utils.py +16 -1
  101. optimum/rbln/utils/runtime_utils.py +12 -8
  102. optimum/rbln/utils/submodule.py +24 -0
  103. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/METADATA +6 -6
  104. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/RECORD +107 -81
  105. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  106. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/WHEEL +0 -0
  107. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/entry_points.txt +0 -0
  108. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/licenses/LICENSE +0 -0

optimum/rbln/diffusers/pipelines/auto_pipeline.py CHANGED
@@ -176,7 +176,7 @@ class RBLNAutoPipelineBase:
  export: bool = None,
  rbln_config: Optional[Union[Dict[str, Any], RBLNModelConfig]] = None,
  **kwargs: Any,
- ):
+ ) -> RBLNBaseModel:
  """
  Load an RBLN-accelerated Diffusers pipeline from a pretrained checkpoint or a compiled RBLN artifact.

@@ -201,8 +201,7 @@ class RBLNAutoPipelineBase:
  - Remaining arguments are forwarded to the Diffusers loader.

  Returns:
- RBLNBaseModel: An instantiated RBLN model wrapping the Diffusers pipeline, ready for
- inference on RBLN NPUs.
+ RBLNBaseModel: An instantiated RBLN model wrapping the Diffusers pipeline, ready for inference on RBLN NPUs.

  """
  rbln_cls = cls.get_rbln_cls(model_id, export=export, **kwargs)
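
Note: a minimal usage sketch of the updated `from_pretrained` signature. The pipeline class and checkpoint ID below are illustrative assumptions, not taken from this diff.

    from optimum.rbln import RBLNAutoPipelineForText2Image  # assumed concrete auto class

    # The loader is now annotated to return an RBLNBaseModel. export=True compiles
    # the PyTorch checkpoint for RBLN NPUs before loading it as a runtime.
    pipe = RBLNAutoPipelineForText2Image.from_pretrained(
        "runwayml/stable-diffusion-v1-5",  # illustrative checkpoint
        export=True,
    )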

optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py CHANGED
@@ -26,7 +26,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import Any, Callable, Dict, List, Optional, Union
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union

  import torch
  import torch.nn.functional as F
@@ -260,7 +260,7 @@ class RBLNStableDiffusionControlNetPipeline(RBLNDiffusionMixin, StableDiffusionC
  callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
  callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
  **kwargs,
- ):
+ ) -> Union[StableDiffusionPipelineOutput, Tuple]:
  r"""
  The call function to the pipeline for generation.

@@ -321,14 +321,7 @@ class RBLNStableDiffusionControlNetPipeline(RBLNDiffusionMixin, StableDiffusionC
  output_type (`str`, *optional*, defaults to `"pil"`):
  The output format of the generated image. Choose between `PIL.Image` or `np.array`.
  return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
- plain tuple.
- callback (`Callable`, *optional*):
- A function that calls every `callback_steps` steps during inference. The function is called with the
- following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
- callback_steps (`int`, *optional*, defaults to 1):
- The frequency at which the `callback` function is called. If not specified, the callback is called at
- every step.
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple.
  cross_attention_kwargs (`dict`, *optional*):
  A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
  [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
@@ -356,8 +349,6 @@ class RBLNStableDiffusionControlNetPipeline(RBLNDiffusionMixin, StableDiffusionC
  will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
  `._callback_tensor_inputs` attribute of your pipeine class.

- Examples:
-
  Returns:
  [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
  If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
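
Note: the new return annotation documents existing behavior; `return_dict` still decides which form you get. A hedged sketch, assuming `pipe` is a loaded RBLNStableDiffusionControlNetPipeline and `control_image` a prepared conditioning image:

    # Default (return_dict=True): a StableDiffusionPipelineOutput with .images.
    out = pipe(prompt="a photo of an astronaut", image=control_image)
    images = out.images

    # return_dict=False: a plain tuple with the images first.
    images, *rest = pipe(prompt="a photo of an astronaut", image=control_image, return_dict=False)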

optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py CHANGED
@@ -26,7 +26,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import Any, Callable, Dict, List, Optional, Union
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union

  import torch
  import torch.nn.functional as F
@@ -253,7 +253,7 @@ class RBLNStableDiffusionControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableDif
  callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
  callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
  **kwargs,
- ):
+ ) -> Union[StableDiffusionPipelineOutput, Tuple]:
  r"""
  The call function to the pipeline for generation.

@@ -347,8 +347,6 @@ class RBLNStableDiffusionControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableDif
  will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
  `._callback_tensor_inputs` attribute of your pipeine class.

- Examples:
-
  Returns:
  [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
  If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,

optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py CHANGED
@@ -294,7 +294,7 @@ class RBLNStableDiffusionXLControlNetPipeline(RBLNDiffusionMixin, StableDiffusio
  callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
  callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
  **kwargs,
- ):
+ ) -> Union[StableDiffusionXLPipelineOutput, Tuple]:
  r"""
  The call function to the pipeline for generation.

@@ -431,8 +431,6 @@ class RBLNStableDiffusionXLControlNetPipeline(RBLNDiffusionMixin, StableDiffusio
  will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
  `._callback_tensor_inputs` attribute of your pipeine class.

- Examples:
-
  Returns:
  [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
  If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,

optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py CHANGED
@@ -309,7 +309,7 @@ class RBLNStableDiffusionXLControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableD
  callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
  callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
  **kwargs,
- ):
+ ) -> Union[StableDiffusionXLPipelineOutput, Tuple]:
  r"""
  Function invoked when calling the pipeline for generation.

@@ -465,8 +465,6 @@ class RBLNStableDiffusionXLControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableD
  will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
  `._callback_tensor_inputs` attribute of your pipeine class.

- Examples:
-
  Returns:
  [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
  [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`

optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py CHANGED
@@ -203,7 +203,7 @@ class RBLNRetinaFaceFilter(RetinaFaceFilter):
  f"If you only need to compile the model without loading it to NPU, you can use:\n"
  f" from_pretrained(..., rbln_create_runtimes=False) or\n"
  f" from_pretrained(..., rbln_config={{..., 'create_runtimes': False}})\n\n"
- f"To check your NPU status, run the 'rbln-stat' command in your terminal.\n"
+ f"To check your NPU status, run the 'rbln-smi' command in your terminal.\n"
  f"Make sure your NPU is properly installed and operational."
  )
  raise rebel.core.exception.RBLNRuntimeError(error_msg) from e
@@ -278,7 +278,7 @@ class RBLNVideoSafetyModel(VideoSafetyModel):
  f"If you only need to compile the model without loading it to NPU, you can use:\n"
  f" from_pretrained(..., rbln_create_runtimes=False) or\n"
  f" from_pretrained(..., rbln_config={{..., 'create_runtimes': False}})\n\n"
- f"To check your NPU status, run the 'rbln-stat' command in your terminal.\n"
+ f"To check your NPU status, run the 'rbln-smi' command in your terminal.\n"
  f"Make sure your NPU is properly installed and operational."
  )
  raise rebel.core.exception.RBLNRuntimeError(error_msg) from e

optimum/rbln/modeling_base.py CHANGED
@@ -24,7 +24,7 @@ import torch
  from transformers import AutoConfig, AutoModel, GenerationConfig, PretrainedConfig
  from transformers.utils.hub import PushToHubMixin

- from .configuration_utils import RBLNAutoConfig, RBLNCompileConfig, RBLNModelConfig, get_rbln_config_class
+ from .configuration_utils import RBLNCompileConfig, RBLNModelConfig, get_rbln_config_class
  from .utils.hub import pull_compiled_model_from_hub, validate_files
  from .utils.logging import get_logger
  from .utils.runtime_utils import UnavailableRuntime, tp_and_devices_are_ok
@@ -90,7 +90,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):

  self.device = torch.device("cpu")
  self.training = False
- self.dtype = rbln_config.torch_dtype
+ self.dtype = rbln_config.dtype

  # FIXME :: model_save_dir is not used after initialized. (This can be used when save/load)
  # This attribute is needed to keep one reference on the temporary directory, since garbage collecting it
@@ -206,8 +206,9 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
  f"does not match the expected model class name ({cls.__name__})."
  )

- rbln_config, kwargs = RBLNAutoConfig.load(
- model_path_subfolder, passed_rbln_config=rbln_config, kwargs=kwargs, return_unused_kwargs=True
+ config_cls = cls.get_rbln_config_class()
+ rbln_config, kwargs = config_cls.from_pretrained(
+ model_path_subfolder, rbln_config=rbln_config, return_unused_kwargs=True, **kwargs
  )

  if rbln_config.rbln_model_cls_name != cls.__name__:
@@ -223,8 +224,6 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
  elif rbln_submodules is None:
  rbln_submodules = []

- rbln_config.freeze()
-
  if config is None:
  if cls.hf_library_name == "transformers":
  config = AutoConfig.from_pretrained(
@@ -308,11 +307,13 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
  f"If you only need to compile the model without loading it to NPU, you can use:\n"
  f" from_pretrained(..., rbln_create_runtimes=False) or\n"
  f" from_pretrained(..., rbln_config={{..., 'create_runtimes': False}})\n\n"
- f"To check your NPU status, run the 'rbln-stat' command in your terminal.\n"
+ f"To check your NPU status, run the 'rbln-smi' command in your terminal.\n"
  f"Make sure your NPU is properly installed and operational."
  )
  raise rebel.core.exception.RBLNRuntimeError(error_msg) from e

+ rbln_config.freeze()
+
  return cls(
  models,
  config,
@@ -451,15 +452,15 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
  model_config: "PretrainedConfig",
  rbln_config: RBLNModelConfig,
  ) -> RBLNModelConfig:
- rbln_config.torch_dtype = model.dtype
- if not cls._supports_non_fp32 and rbln_config.torch_dtype != torch.float32:
+ rbln_config.dtype = model.dtype
+ if not cls._supports_non_fp32 and rbln_config.dtype != torch.float32:
  raise NotImplementedError(
  f"Currently, {cls.__name__} does not support non-fp32 dtype. Please use float32 dtype."
  )
  rbln_config = cls._update_rbln_config(
  preprocessors=preprocessors, model=model, model_config=model_config, rbln_config=rbln_config
  )
- rbln_config.freeze()
+
  if rbln_config.rbln_model_cls_name != cls.__name__:
  raise NameError(
  f"Cannot get the rbln config. {cls.__name__} is not the same as {rbln_config.rbln_model_cls_name}. "

optimum/rbln/ops/__init__.py CHANGED
@@ -16,4 +16,5 @@ from .attn import *
  from .flash_attn import *
  from .kv_cache_update import *
  from .linear import linear
+ from .moe import *
  from .sliding_window_attn import *
optimum/rbln/ops/attn.py CHANGED
@@ -205,6 +205,7 @@ def paged_causal_attn_decode(
  block_table: Tensor,
  block_size: int,
  mask: Optional[Tensor] = None,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  """Defines the computation pattern for fused attention with KV cache updates.

@@ -228,6 +229,7 @@
  - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
  - block_size: [] - Number of tokens per block
  - mask: [batch=1, max_seq_len] - attention mask when use position_ids
+ - s_aux: [num_attention_heads, sink_len] - auxiliary states for attention

  Returns:
  Tensor: attn_output: [batch=1, n_heads, n_groups, 1, head_dim] - Attention output
@@ -247,6 +249,7 @@ def paged_causal_attn_decode_fake(
  block_table: Tensor,
  block_size: int,
  mask: Optional[Tensor] = None,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  return torch.empty_like(q)

@@ -267,6 +270,7 @@ def paged_causal_attn_prefill(
  block_size: int,
  is_bidirectional: bool,
  mask: Optional[Tensor] = None,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  """Defines the computation pattern for prefill phase attention with KV cache updates.

@@ -290,6 +294,7 @@
  - block_size: [] - Number of tokens per block
  - is_bidirectional: [] - Whether the attention is bidirectional at current sequence position
  - mask: [batch=1, max_seq_len] - attention mask when use position_ids
+ - s_aux: [num_attention_heads, sink_len] - auxiliary states for attention

  Returns:
  Tensor: attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
@@ -310,6 +315,7 @@ def paged_causal_attn_prefill_fake(
  block_size: int,
  is_bidirectional: bool,
  mask: Optional[Tensor] = None,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  return torch.empty_like(q)

@@ -331,6 +337,7 @@ def paged_causal_attn_decode_kv_fp8(
  k_scale: Tensor,
  v_scale: Tensor,
  mask: Optional[Tensor] = None,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  return torch.empty_like(q)

@@ -349,6 +356,7 @@ def paged_causal_attn_decode_kv_fp8_fake(
  k_scale: Tensor,
  v_scale: Tensor,
  mask: Optional[Tensor] = None,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  return torch.empty_like(q)

@@ -371,6 +379,7 @@ def paged_causal_attn_prefill_kv_fp8(
  k_scale: Tensor,
  v_scale: Tensor,
  mask: Optional[Tensor] = None,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  return torch.empty_like(q)

@@ -390,6 +399,7 @@ def paged_causal_attn_prefill_kv_fp8_fake(
  k_scale: Tensor,
  v_scale: Tensor,
  mask: Optional[Tensor] = None,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  return torch.empty_like(q)
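
Note: every attention op now accepts an optional s_aux tensor of shape [num_attention_heads, sink_len]. These ops are compiler-facing stubs, so the diff does not define the math; the sketch below shows one conventional way such auxiliary "attention sink" logits fold into softmax normalization, offered as an assumption rather than the kernel's actual semantics.

    import torch

    def softmax_with_sink(scores: torch.Tensor, s_aux: torch.Tensor) -> torch.Tensor:
        # scores: [n_heads, q_len, kv_len] raw attention logits
        # s_aux:  [n_heads, sink_len] per-head sink logits, broadcast over queries
        sink = s_aux[:, None, :].expand(-1, scores.shape[1], -1)
        combined = torch.cat([scores, sink], dim=-1)
        probs = torch.softmax(combined, dim=-1)
        # Sink columns absorb probability mass but contribute no value output.
        return probs[..., : scores.shape[-1]]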

optimum/rbln/ops/flash_attn.py CHANGED
@@ -198,6 +198,7 @@ def paged_flash_causal_attn_decode(
  block_size: int,
  partition: int,
  mask: Optional[Tensor] = None,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  """Defines the computation pattern for fused causal flash attention with KV cache for decoding.

@@ -219,6 +220,7 @@ def paged_flash_causal_attn_decode_fake(
  block_size: int,
  partition: int,
  mask: Optional[Tensor] = None,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  return torch.empty_like(q)

@@ -241,6 +243,7 @@ def paged_flash_causal_attn_decode_kv_fp8(
  k_scale: Tensor,
  v_scale: Tensor,
  mask: Optional[Tensor] = None,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  return torch.empty_like(q)

@@ -260,6 +263,7 @@ def paged_flash_causal_attn_decode_kv_fp8_fake(
  k_scale: Tensor,
  v_scale: Tensor,
  mask: Optional[Tensor] = None,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  return torch.empty_like(q)

@@ -281,6 +285,7 @@ def paged_flash_causal_attn_prefill(
  partition: int,
  is_bidirectional: bool,
  mask: Optional[Tensor] = None,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  """Defines the computation pattern for fused causal flash attention with KV cache for prefill.

@@ -303,6 +308,7 @@ def paged_flash_causal_attn_prefill_fake(
  partition: int,
  is_bidirectional: bool,
  mask: Optional[Tensor] = None,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  return torch.empty_like(q)

@@ -326,6 +332,7 @@ def paged_flash_causal_attn_prefill_kv_fp8(
  k_scale: Tensor,
  v_scale: Tensor,
  mask: Optional[Tensor] = None,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  return torch.empty_like(q)

@@ -346,5 +353,6 @@ def paged_flash_causal_attn_prefill_kv_fp8_fake(
  k_scale: Tensor,
  v_scale: Tensor,
  mask: Optional[Tensor] = None,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  return torch.empty_like(q)

optimum/rbln/ops/moe.py ADDED
@@ -0,0 +1,180 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional
+
+ import torch
+ from torch import Tensor
+
+
+ @torch.library.custom_op(
+ "rbln_custom_ops::custom_moe_glu",
+ mutates_args=(),
+ )
+ def custom_moe_glu(
+ hidden_states: Tensor,
+ gate_proj_weight: Tensor,
+ up_proj_weight: Tensor,
+ down_proj_weight: Tensor,
+ router_logits: Tensor,
+ topk: int,
+ norm_topk_prob: bool,
+ gate_proj_bias: Optional[Tensor] = None,
+ up_proj_bias: Optional[Tensor] = None,
+ down_proj_bias: Optional[Tensor] = None,
+ ) -> Tensor:
+ """
+ Customized MoE GLU operation.
+
+ Expected tensor shapes:
+ - hidden_states: [batch*seq_len, hidden_size]
+ - gate_proj_weight: [num_experts, hidden_size, intermediate_size]
+ - up_proj_weight: [num_experts, hidden_size, intermediate_size]
+ - down_proj_weight: [num_experts, intermediate_size, hidden_size]
+ - router_logits: [batch*seq_len, num_experts]
+ - topk: top k experts to select
+ - norm_topk_prob: whether to normalize the top k routing weights with softmax
+ - gate_proj_bias: [num_experts, intermediate_size]
+ - up_proj_bias: [num_experts, intermediate_size]
+ - down_proj_bias: [num_experts, hidden_size]
+
+ Returns:
+ Tensor: [batch * seq_len, hidden_size]
+ """
+
+ return torch.empty_like(hidden_states)
+
+
+ @custom_moe_glu.register_fake
+ def custom_moe_glu_fake(
+ hidden_states: Tensor,
+ gate_proj_weight: Tensor,
+ up_proj_weight: Tensor,
+ down_proj_weight: Tensor,
+ router_logits: Tensor,
+ topk: int,
+ norm_topk_prob: bool,
+ gate_proj_bias: Optional[Tensor] = None,
+ up_proj_bias: Optional[Tensor] = None,
+ down_proj_bias: Optional[Tensor] = None,
+ ) -> Tensor:
+ return torch.empty_like(hidden_states)
+
+
+ @torch.library.custom_op(
+ "rbln_custom_ops::custom_moe_ff",
+ mutates_args=(),
+ )
+ def custom_moe_ff(
+ hidden_states: Tensor,
+ gate_proj_weight: Tensor,
+ down_proj_weight: Tensor,
+ masked_routing_weight: Tensor,
+ gate_proj_bias: Optional[Tensor] = None,
+ down_proj_bias: Optional[Tensor] = None,
+ ) -> Tensor:
+ """
+ Customized MoE FF operation.
+
+ Expected tensor shapes:
+ - hidden_states: [batch * seq_len, hidden_size]
+ - gate_proj_weight: [hidden_size, num_experts * intermediate_size]
+ - down_proj_weight: [num_experts * intermediate_size, hidden_size]
+ - masked_routing_weight: [batch * seq_len, num_experts]
+ - gate_proj_bias: [num_experts * intermediate_size]
+ - down_proj_bias: [hidden_size]
+
+ Returns:
+ Tensor: [batch * seq_len, hidden_size]
+ """
+ return torch.empty_like(hidden_states)
+
+
+ @custom_moe_ff.register_fake
+ def custom_moe_ff_fake(
+ hidden_states: Tensor,
+ gate_proj_weight: Tensor,
+ down_proj_weight: Tensor,
+ masked_routing_weight: Tensor,
+ gate_proj_bias: Optional[Tensor] = None,
+ down_proj_bias: Optional[Tensor] = None,
+ ) -> Tensor:
+ return torch.empty_like(hidden_states)
+
+
+ @torch.library.custom_op(
+ "rbln_custom_ops::custom_moe_glu_mxfp4",
+ mutates_args=(),
+ )
+ def custom_moe_glu_mxfp4(
+ hidden_states: Tensor,
+ gate_proj_blocks: Tensor,
+ gate_proj_scales: Tensor,
+ gate_proj_bias: Tensor,
+ up_proj_blocks: Tensor,
+ up_proj_scales: Tensor,
+ up_proj_bias: Tensor,
+ down_proj_blocks: Tensor,
+ down_proj_scales: Tensor,
+ down_proj_bias: Tensor,
+ router_logits: Tensor,
+ alpha: Tensor,
+ limit: Tensor,
+ k: int,
+ post_norm: bool,
+ ) -> Tensor:
+ """
+ Customized MoE GLU operation.
+
+ Expected tensor shapes:
+ - hidden_states: [batch*seq_len, hidden_size]
+ - gate_proj_blocks: [num_experts, intermediate_size, hidden_size // 2]
+ - gate_proj_scales: [num_experts, intermediate_size, hidden_size // 32]
+ - gate_proj_bias: [num_experts, intermediate_size]
+ - up_proj_blocks: [num_experts, intermediate_size, hidden_size // 2]
+ - up_proj_scales: [num_experts, intermediate_size, hidden_size // 32]
+ - up_proj_bias: [num_experts, intermediate_size]
+ - down_proj_blocks: [num_experts, hidden_size, intermediate_size // 2]
+ - down_proj_scales: [num_experts, hidden_size, intermediate_size // 32]
+ - masked_routing_weight: [batch * seq_len, num_experts]
+ - expert_select_count: [num_experts]
+ - alpha: []
+ - limit: []
+
+ Returns:
+ Tensor: [batch * seq_len, hidden_size]
+ """
+
+ return torch.empty_like(hidden_states)
+
+
+ @custom_moe_glu_mxfp4.register_fake
+ def custom_moe_glu_mxfp4_fake(
+ hidden_states: Tensor,
+ gate_proj_blocks: Tensor,
+ gate_proj_scales: Tensor,
+ gate_proj_bias: Tensor,
+ up_proj_blocks: Tensor,
+ up_proj_scales: Tensor,
+ up_proj_bias: Tensor,
+ down_proj_blocks: Tensor,
+ down_proj_scales: Tensor,
+ down_proj_bias: Tensor,
+ router_logits: Tensor,
+ alpha: Tensor,
+ limit: Tensor,
+ k: int,
+ post_norm: bool,
+ ) -> Tensor:
+ return torch.empty_like(hidden_states)
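
Note: like the attention ops, these MoE ops are stubs whose eager bodies return torch.empty_like(hidden_states); the real kernels are pattern-matched in by the RBLN compiler. A minimal smoke-test sketch using the documented shapes (values illustrative; assumes optimum.rbln.ops imports cleanly in your environment):

    import torch
    import optimum.rbln.ops  # registers rbln_custom_ops::custom_moe_glu et al.

    tokens, hidden, inter, n_experts, topk = 4, 64, 128, 8, 2
    hidden_states = torch.randn(tokens, hidden)
    gate_w = torch.randn(n_experts, hidden, inter)
    up_w = torch.randn(n_experts, hidden, inter)
    down_w = torch.randn(n_experts, inter, hidden)
    router_logits = torch.randn(tokens, n_experts)

    # Eagerly this only exercises registration and shape propagation; the result
    # is an uninitialized placeholder outside RBLN compilation.
    out = torch.ops.rbln_custom_ops.custom_moe_glu(
        hidden_states, gate_w, up_w, down_w, router_logits, topk, True
    )
    assert out.shape == hidden_states.shape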

optimum/rbln/ops/sliding_window_attn.py CHANGED
@@ -13,6 +13,8 @@
  # limitations under the License.


+ from typing import Optional
+
  import torch
  from torch import Tensor

@@ -33,6 +35,7 @@ def paged_sliding_window_attn_prefill(
  block_table: Tensor,
  block_size: int,
  is_bidirectional: bool,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  """Defines the computation pattern for prefill phase attention with KV cache updates.

@@ -53,6 +56,7 @@
  - cache_offset: [] - The valid length in the combined sequence of the KV cache and the current projected key states.
  - scale: [] - Attention scale factor
  - is_bidirectional: [] - Whether the attention is bidirectional
+ - s_aux: [num_attention_heads, sink_len] - auxiliary states for attention
  Returns:
  Tensor: attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
  """
@@ -72,6 +76,7 @@ def paged_sliding_window_attn_prefill_fake(
  block_table: Tensor,
  block_size: int,
  is_bidirectional: bool,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  return torch.empty_like(q)

@@ -91,6 +96,8 @@ def paged_sliding_window_attn_decode(
  scale: Tensor,
  block_table: Tensor,
  block_size: int,
+ attn_mask: Tensor,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  return torch.empty_like(q)

@@ -107,5 +114,7 @@ def paged_sliding_window_attn_decode_fake(
  scale: Tensor,
  block_table: Tensor,
  block_size: int,
+ attn_mask: Tensor,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  return torch.empty_like(q)