optimum-rbln 0.8.2rc0__py3-none-any.whl → 0.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of optimum-rbln might be problematic; see the registry's advisory page for details.

Files changed (105) hide show
  1. optimum/rbln/__init__.py +32 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/configuration_utils.py +20 -4
  4. optimum/rbln/diffusers/__init__.py +7 -0
  5. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +2 -2
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +2 -2
  7. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +2 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +2 -2
  9. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +2 -2
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +2 -2
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +2 -2
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +2 -2
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +3 -3
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +2 -2
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +4 -4
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +2 -2
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +2 -2
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +2 -2
  19. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  20. optimum/rbln/diffusers/models/__init__.py +3 -13
  21. optimum/rbln/diffusers/pipelines/__init__.py +11 -5
  22. optimum/rbln/diffusers/pipelines/auto_pipeline.py +237 -0
  23. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +11 -6
  24. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  25. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  26. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -1
  27. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  28. optimum/rbln/modeling.py +3 -2
  29. optimum/rbln/modeling_base.py +29 -4
  30. optimum/rbln/ops/attn.py +158 -0
  31. optimum/rbln/ops/flash_attn.py +166 -0
  32. optimum/rbln/transformers/__init__.py +24 -0
  33. optimum/rbln/transformers/configuration_generic.py +6 -4
  34. optimum/rbln/transformers/modeling_generic.py +13 -8
  35. optimum/rbln/transformers/modeling_outputs.py +37 -0
  36. optimum/rbln/transformers/models/__init__.py +31 -16
  37. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  38. optimum/rbln/transformers/models/auto/modeling_auto.py +14 -0
  39. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  40. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  41. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
  43. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +2 -2
  44. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +7 -6
  45. optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
  46. optimum/rbln/transformers/models/colpali/colpali_architecture.py +1 -4
  47. optimum/rbln/transformers/models/colpali/configuration_colpali.py +2 -2
  48. optimum/rbln/transformers/models/colpali/modeling_colpali.py +2 -10
  49. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +43 -174
  50. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +101 -91
  51. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +450 -0
  52. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +88 -0
  53. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +296 -986
  54. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  55. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  56. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +25 -0
  57. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  58. optimum/rbln/transformers/models/gemma/modeling_gemma.py +9 -0
  59. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
  60. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +217 -0
  61. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +25 -251
  62. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +2 -0
  63. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  64. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +86 -0
  65. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +507 -0
  66. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
  67. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +2 -2
  68. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -9
  69. optimum/rbln/transformers/models/llama/modeling_llama.py +12 -3
  70. optimum/rbln/transformers/models/llava/configuration_llava.py +2 -2
  71. optimum/rbln/transformers/models/llava/modeling_llava.py +53 -14
  72. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
  73. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +6 -16
  74. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -30
  75. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +4 -0
  76. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +2 -0
  77. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +1 -3
  78. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +2 -2
  79. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +1 -4
  80. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +3 -3
  81. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +6 -15
  82. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +4 -7
  83. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +77 -3
  84. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -4
  85. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +19 -2
  86. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +20 -1
  87. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  88. optimum/rbln/transformers/models/siglip/modeling_siglip.py +2 -2
  89. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  90. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  91. optimum/rbln/transformers/models/swin/modeling_swin.py +341 -0
  92. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  93. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  94. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +2 -2
  95. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
  96. optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -2
  97. optimum/rbln/transformers/models/whisper/modeling_whisper.py +20 -1
  98. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  99. optimum/rbln/transformers/utils/rbln_quantization.py +365 -65
  100. optimum/rbln/utils/runtime_utils.py +3 -3
  101. optimum/rbln/utils/submodule.py +10 -4
  102. {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3.dist-info}/METADATA +1 -1
  103. {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3.dist-info}/RECORD +105 -89
  104. {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3.dist-info}/WHEEL +0 -0
  105. {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,237 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import importlib
17
+ from typing import Type
18
+
19
+ from diffusers.models.controlnets import ControlNetUnionModel
20
+ from diffusers.pipelines.auto_pipeline import (
21
+ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
22
+ AUTO_INPAINT_PIPELINES_MAPPING,
23
+ AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
24
+ AutoPipelineForImage2Image,
25
+ AutoPipelineForInpainting,
26
+ AutoPipelineForText2Image,
27
+ _get_task_class,
28
+ )
29
+ from huggingface_hub.utils import validate_hf_hub_args
30
+
31
+ from optimum.rbln.modeling_base import RBLNBaseModel
32
+ from optimum.rbln.utils.model_utils import (
33
+ MODEL_MAPPING,
34
+ convert_hf_to_rbln_model_name,
35
+ convert_rbln_to_hf_model_name,
36
+ get_rbln_model_cls,
37
+ )
38
+
39
+
40
+ class RBLNAutoPipelineBase:
41
+ _model_mapping = None
42
+ _model_mapping_names = None
43
+
44
+ @classmethod
45
+ def get_rbln_cls(cls, pretrained_model_name_or_path, export=True, **kwargs):
46
+ if export:
47
+ hf_model_class = cls.infer_hf_model_class(pretrained_model_name_or_path, **kwargs)
48
+ rbln_class_name = convert_hf_to_rbln_model_name(hf_model_class.__name__)
49
+ else:
50
+ rbln_class_name = cls.get_rbln_model_cls_name(pretrained_model_name_or_path, **kwargs)
51
+ if convert_rbln_to_hf_model_name(rbln_class_name) not in cls._model_mapping_names.values():
52
+ raise ValueError(
53
+ f"The architecture '{rbln_class_name}' is not supported by the `{cls.__name__}.from_pretrained()` method. "
54
+ "Please use the `from_pretrained()` method of the appropriate class to load this model, "
55
+ f"or directly use '{rbln_class_name}.from_pretrained()`."
56
+ )
57
+
58
+ try:
59
+ rbln_cls = get_rbln_model_cls(rbln_class_name)
60
+ except AttributeError as e:
61
+ raise AttributeError(
62
+ f"Class '{rbln_class_name}' not found in 'optimum.rbln' module for model ID '{pretrained_model_name_or_path}'. "
63
+ "Ensure that the class name is correctly mapped and available in the 'optimum.rbln' module."
64
+ ) from e
65
+
66
+ return rbln_cls
67
+
68
+ @classmethod
69
+ def get_rbln_model_cls_name(cls, pretrained_model_name_or_path, **kwargs):
70
+ """
71
+ Retrieve the RBLN model class name recorded in the model index config.
72
+
73
+ Args:
74
+ pretrained_model_name_or_path (str): Identifier of the model.
75
+
76
+ Returns:
77
+ str: The `_class_name` value from the model's index configuration.
78
+ """
79
+ model_index_config = cls.load_config(pretrained_model_name_or_path)
80
+
81
+ if "_class_name" not in model_index_config:
82
+ raise ValueError(
83
+ "The `_class_name` field is missing from model_index_config. This is unexpected and should be reported as an issue. "
84
+ "Please use the `from_pretrained()` method of the appropriate class to load this model."
85
+ )
86
+
87
+ return model_index_config["_class_name"]
88
+
89
+ @classmethod
90
+ def infer_hf_model_class(
91
+ cls,
92
+ pretrained_model_or_path,
93
+ cache_dir=None,
94
+ force_download=False,
95
+ proxies=None,
96
+ token=None,
97
+ local_files_only=False,
98
+ revision=None,
99
+ **kwargs,
100
+ ):
101
+ config = cls.load_config(
102
+ pretrained_model_or_path,
103
+ cache_dir=cache_dir,
104
+ force_download=force_download,
105
+ proxies=proxies,
106
+ token=token,
107
+ local_files_only=local_files_only,
108
+ revision=revision,
109
+ )
110
+ pipeline_key_name = cls.get_pipeline_key_name(config, **kwargs)
111
+
112
+ pipeline_cls = _get_task_class(cls._model_mapping, pipeline_key_name)
113
+
114
+ return pipeline_cls
115
+
116
+ @classmethod
117
+ def get_pipeline_key_name(cls, config, **kwargs):
118
+ orig_class_name = config["_class_name"]
119
+ if "ControlPipeline" in orig_class_name:
120
+ to_replace = "ControlPipeline"
121
+ else:
122
+ to_replace = "Pipeline"
123
+
124
+ if "controlnet" in kwargs:
125
+ if isinstance(kwargs["controlnet"], ControlNetUnionModel):
126
+ orig_class_name = config["_class_name"].replace(to_replace, "ControlNetUnionPipeline")
127
+ else:
128
+ orig_class_name = config["_class_name"].replace(to_replace, "ControlNetPipeline")
129
+ if "enable_pag" in kwargs:
130
+ enable_pag = kwargs.pop("enable_pag")
131
+ if enable_pag:
132
+ orig_class_name = orig_class_name.replace(to_replace, "PAGPipeline")
133
+
134
+ return orig_class_name
135
+
136
+ @classmethod
137
+ @validate_hf_hub_args
138
+ def from_pretrained(cls, model_id, **kwargs):
139
+ rbln_cls = cls.get_rbln_cls(model_id, **kwargs)
140
+ return rbln_cls.from_pretrained(model_id, **kwargs)
141
+
142
+ @classmethod
143
+ def from_model(cls, model, **kwargs):
144
+ rbln_cls = get_rbln_model_cls(f"RBLN{model.__class__.__name__}")
145
+ return rbln_cls.from_model(model, **kwargs)
146
+
147
+ @staticmethod
148
+ def register(rbln_cls: Type[RBLNBaseModel], exist_ok=False):
149
+ """
150
+ Register a new RBLN model class.
151
+
152
+ Args:
153
+ rbln_cls (Type[RBLNBaseModel]): The RBLN model class to register.
154
+ exist_ok (bool): Whether to allow registering an already registered model.
155
+ """
156
+ if not issubclass(rbln_cls, RBLNBaseModel):
157
+ raise ValueError("`rbln_cls` must be a subclass of RBLNBaseModel.")
158
+
159
+ native_cls = getattr(importlib.import_module("optimum.rbln"), rbln_cls.__name__, None)
160
+ if rbln_cls.__name__ in MODEL_MAPPING or native_cls is not None:
161
+ if not exist_ok:
162
+ raise ValueError(f"Model for {rbln_cls.__name__} already registered.")
163
+
164
+ MODEL_MAPPING[rbln_cls.__name__] = rbln_cls
165
+
166
+
167
+ class RBLNAutoPipelineForText2Image(RBLNAutoPipelineBase, AutoPipelineForText2Image):
168
+ _model_mapping = AUTO_TEXT2IMAGE_PIPELINES_MAPPING
169
+ _model_mapping_names = {x[0]: x[1].__name__ for x in AUTO_TEXT2IMAGE_PIPELINES_MAPPING.items()}
170
+
171
+
172
+ class RBLNAutoPipelineForImage2Image(RBLNAutoPipelineBase, AutoPipelineForImage2Image):
173
+ _model_mapping = AUTO_IMAGE2IMAGE_PIPELINES_MAPPING
174
+ _model_mapping_names = {x[0]: x[1].__name__ for x in AUTO_IMAGE2IMAGE_PIPELINES_MAPPING.items()}
175
+
176
+ @classmethod
177
+ def get_pipeline_key_name(cls, config, **kwargs):
178
+ orig_class_name = config["_class_name"]
179
+ # the `orig_class_name` can be:
180
+ # - `*Pipeline` (for regular text-to-image checkpoint)
181
+ # - `*ControlPipeline` (for Flux tools specific checkpoint)
182
+ # - `*Img2ImgPipeline` (for refiner checkpoint)
183
+ if "Img2Img" in orig_class_name:
184
+ to_replace = "Img2ImgPipeline"
185
+ elif "ControlPipeline" in orig_class_name:
186
+ to_replace = "ControlPipeline"
187
+ else:
188
+ to_replace = "Pipeline"
189
+
190
+ if "controlnet" in kwargs:
191
+ if isinstance(kwargs["controlnet"], ControlNetUnionModel):
192
+ orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace)
193
+ else:
194
+ orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
195
+ if "enable_pag" in kwargs:
196
+ enable_pag = kwargs.pop("enable_pag")
197
+ if enable_pag:
198
+ orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
199
+
200
+ if to_replace == "ControlPipeline":
201
+ orig_class_name = orig_class_name.replace(to_replace, "ControlImg2ImgPipeline")
202
+
203
+ return orig_class_name
204
+
205
+
206
+ class RBLNAutoPipelineForInpainting(RBLNAutoPipelineBase, AutoPipelineForInpainting):
207
+ _model_mapping = AUTO_INPAINT_PIPELINES_MAPPING
208
+ _model_mapping_names = {x[0]: x[1].__name__ for x in AUTO_INPAINT_PIPELINES_MAPPING.items()}
209
+
210
+ @classmethod
211
+ def get_pipeline_key_name(cls, config, **kwargs):
212
+ orig_class_name = config["_class_name"]
213
+
214
+ # The `orig_class_name` can be:
215
+ # - `*InpaintPipeline` (for inpaint-specific checkpoint)
216
+ # - `*ControlPipeline` (for Flux tools specific checkpoint)
217
+ # - `*Pipeline` (for regular text-to-image checkpoint)
218
+ if "Inpaint" in orig_class_name:
219
+ to_replace = "InpaintPipeline"
220
+ elif "ControlPipeline" in orig_class_name:
221
+ to_replace = "ControlPipeline"
222
+ else:
223
+ to_replace = "Pipeline"
224
+
225
+ if "controlnet" in kwargs:
226
+ if isinstance(kwargs["controlnet"], ControlNetUnionModel):
227
+ orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace)
228
+ else:
229
+ orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
230
+ if "enable_pag" in kwargs:
231
+ enable_pag = kwargs.pop("enable_pag")
232
+ if enable_pag:
233
+ orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
234
+ if to_replace == "ControlPipeline":
235
+ orig_class_name = orig_class_name.replace(to_replace, "ControlInpaintPipeline")
236
+
237
+ return orig_class_name
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Any, Dict, Optional, Tuple
15
+ from typing import Any, Optional, Tuple
16
16
 
17
17
  from ....configuration_utils import RBLNAutoConfig, RBLNModelConfig
18
18
  from ....transformers import RBLNLlamaForCausalLMConfig, RBLNSiglipVisionModelConfig
@@ -56,11 +56,11 @@ class RBLNCosmosSafetyCheckerConfig(RBLNModelConfig):
56
56
  Configuration class for RBLN Cosmos Safety Checker.
57
57
  """
58
58
 
59
- submodules = ["aegis", "video_safety_model", "face_blur_filter", "siglip_encoder"]
59
+ submodules = ["llamaguard3", "video_safety_model", "face_blur_filter", "siglip_encoder"]
60
60
 
61
61
  def __init__(
62
62
  self,
63
- aegis: Optional[RBLNModelConfig] = None,
63
+ llamaguard3: Optional[RBLNModelConfig] = None,
64
64
  video_safety_model: Optional[RBLNModelConfig] = None,
65
65
  face_blur_filter: Optional[RBLNModelConfig] = None,
66
66
  siglip_encoder: Optional[RBLNSiglipVisionModelConfig] = None,
@@ -69,19 +69,24 @@ class RBLNCosmosSafetyCheckerConfig(RBLNModelConfig):
69
69
  image_size: Optional[Tuple[int, int]] = None,
70
70
  height: Optional[int] = None,
71
71
  width: Optional[int] = None,
72
- **kwargs: Dict[str, Any],
72
+ max_seq_len: Optional[int] = None,
73
+ **kwargs: Any,
73
74
  ):
74
75
  super().__init__(**kwargs)
75
76
  if height is not None and width is not None:
76
77
  image_size = (height, width)
77
78
 
79
+ if max_seq_len is None:
80
+ max_seq_len = 512
81
+
78
82
  tensor_parallel_size = kwargs.get("tensor_parallel_size")
79
83
 
80
- self.aegis = self.init_submodule_config(
84
+ self.llamaguard3 = self.init_submodule_config(
81
85
  RBLNLlamaForCausalLMConfig,
82
- aegis,
86
+ llamaguard3,
83
87
  batch_size=batch_size,
84
88
  tensor_parallel_size=tensor_parallel_size,
89
+ max_seq_len=max_seq_len,
85
90
  )
86
91
 
87
92
  self.siglip_encoder = self.init_submodule_config(
@@ -33,9 +33,9 @@ if is_cosmos_guardrail_available():
33
33
  from cosmos_guardrail import CosmosSafetyChecker
34
34
  from cosmos_guardrail.cosmos_guardrail import (
35
35
  COSMOS_GUARDRAIL_CHECKPOINT,
36
- Aegis,
37
36
  Blocklist,
38
37
  GuardrailRunner,
38
+ LlamaGuard3,
39
39
  ModelConfig,
40
40
  RetinaFaceFilter,
41
41
  SafetyClassifier,
@@ -55,7 +55,7 @@ else:
55
55
 
56
56
  COSMOS_GUARDRAIL_CHECKPOINT = None
57
57
 
58
- class Aegis(FailToImportCosmosGuardrail): ...
58
+ class LlamaGuard3(FailToImportCosmosGuardrail): ...
59
59
 
60
60
  class Blocklist(FailToImportCosmosGuardrail): ...
61
61
 
@@ -312,33 +312,31 @@ class RBLNVideoContentSafetyFilter(VideoContentSafetyFilter):
312
312
  self.encoder.save_pretrained(checkpoint_id)
313
313
 
314
314
 
315
- class RBLNAegis(Aegis):
315
+ class RBLNLlamaGuard3(LlamaGuard3):
316
316
  def __init__(
317
317
  self,
318
318
  checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
319
- base_model_id: str = "meta-llama/LlamaGuard-7b",
320
- aegis_adapter: str = "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0",
319
+ base_model_id: str = "meta-llama/Llama-Guard-3-8B",
321
320
  rbln_config: Optional[RBLNCosmosSafetyCheckerConfig] = None,
322
321
  ) -> None:
323
322
  if is_compiled_dir(checkpoint_id):
324
323
  torch.nn.Module.__init__(self)
325
- cache_dir = pathlib.Path(checkpoint_id) / "aegis"
324
+ cache_dir = pathlib.Path(checkpoint_id) / "llamaguard3"
326
325
  self.tokenizer = AutoTokenizer.from_pretrained(cache_dir)
327
- self.model = RBLNAutoModelForCausalLM.from_pretrained(cache_dir, rbln_config=rbln_config.aegis)
326
+ self.model = RBLNAutoModelForCausalLM.from_pretrained(cache_dir, rbln_config=rbln_config.llamaguard3)
328
327
 
329
328
  else:
330
- super().__init__(checkpoint_id, base_model_id, aegis_adapter)
331
- model = self.model.merge_and_unload() # peft merge
329
+ super().__init__(checkpoint_id, base_model_id)
330
+ model = self.model
332
331
  del self.model
333
-
334
- self.model = RBLNAutoModelForCausalLM.from_model(model, rbln_config=rbln_config.aegis)
332
+ self.model = RBLNAutoModelForCausalLM.from_model(model, rbln_config=rbln_config.llamaguard3)
335
333
 
336
334
  self.rbln_config = rbln_config
337
335
  self.dtype = torch.bfloat16
338
336
  self.device = torch.device("cpu")
339
337
 
340
338
  def save_pretrained(self, checkpoint_id: str):
341
- cache_dir = pathlib.Path(checkpoint_id) / "aegis"
339
+ cache_dir = pathlib.Path(checkpoint_id) / "llamaguard3"
342
340
  self.model.save_pretrained(cache_dir)
343
341
  self.tokenizer.save_pretrained(cache_dir)
344
342
 
@@ -351,8 +349,7 @@ class RBLNCosmosSafetyChecker(CosmosSafetyChecker):
351
349
  def __init__(
352
350
  self,
353
351
  checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
354
- aegis_model_id: str = "meta-llama/LlamaGuard-7b",
355
- aegis_adapter_id: str = "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0",
352
+ llamaguard_model_id: str = "meta-llama/Llama-Guard-3-8B",
356
353
  rbln_config: Optional[RBLNCosmosSafetyCheckerConfig] = None,
357
354
  ) -> None:
358
355
  torch.nn.Module.__init__(self)
@@ -369,10 +366,9 @@ class RBLNCosmosSafetyChecker(CosmosSafetyChecker):
369
366
  self.text_guardrail = GuardrailRunner(
370
367
  safety_models=[
371
368
  Blocklist(COSMOS_GUARDRAIL_CHECKPOINT), # Changed since it cannot be saved
372
- RBLNAegis(
369
+ RBLNLlamaGuard3(
373
370
  checkpoint_id=checkpoint_id,
374
- base_model_id=aegis_model_id,
375
- aegis_adapter=aegis_adapter_id,
371
+ base_model_id=llamaguard_model_id,
376
372
  rbln_config=rbln_config,
377
373
  ),
378
374
  ]
@@ -387,7 +383,7 @@ class RBLNCosmosSafetyChecker(CosmosSafetyChecker):
387
383
 
388
384
  def save_pretrained(self, save_dir: str):
389
385
  for text_safety_models in self.text_guardrail.safety_models:
390
- if isinstance(text_safety_models, RBLNAegis):
386
+ if isinstance(text_safety_models, RBLNLlamaGuard3):
391
387
  text_safety_models.save_pretrained(save_dir)
392
388
 
393
389
  for video_safety_models in self.video_guardrail.safety_models:
@@ -87,7 +87,7 @@ class RBLNCosmosTextToWorldPipeline(RBLNDiffusionMixin, CosmosTextToWorldPipelin
87
87
  export: bool = False,
88
88
  safety_checker: Optional[RBLNCosmosSafetyChecker] = None,
89
89
  rbln_config: Dict[str, Any] = {},
90
- **kwargs: Dict[str, Any],
90
+ **kwargs: Any,
91
91
  ):
92
92
  rbln_config, kwargs = cls.get_rbln_config_class().initialize_from_kwargs(rbln_config, **kwargs)
93
93
  if safety_checker is None and export:
@@ -87,7 +87,7 @@ class RBLNCosmosVideoToWorldPipeline(RBLNDiffusionMixin, CosmosVideoToWorldPipel
87
87
  export: bool = False,
88
88
  safety_checker: Optional[RBLNCosmosSafetyChecker] = None,
89
89
  rbln_config: Dict[str, Any] = {},
90
- **kwargs: Dict[str, Any],
90
+ **kwargs: Any,
91
91
  ):
92
92
  rbln_config, kwargs = cls.get_rbln_config_class().initialize_from_kwargs(rbln_config, **kwargs)
93
93
  if safety_checker is None and export:
@@ -22,12 +22,7 @@ from diffusers import (
22
22
  UNet2DConditionModel,
23
23
  VQModel,
24
24
  )
25
- from transformers import (
26
- CLIPImageProcessor,
27
- CLIPTextModelWithProjection,
28
- CLIPTokenizer,
29
- CLIPVisionModelWithProjection,
30
- )
25
+ from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
31
26
 
32
27
  from ...configurations import RBLNKandinskyV22CombinedPipelineConfig
33
28
  from ...modeling_diffusers import RBLNDiffusionMixin
optimum/rbln/modeling.py CHANGED
@@ -78,7 +78,7 @@ class RBLNModel(RBLNBaseModel):
78
78
  rbln_config: Optional[Union[RBLNModelConfig, Dict]] = None,
79
79
  model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
80
80
  subfolder: str = "",
81
- **kwargs: Dict[str, Any],
81
+ **kwargs: Any,
82
82
  ) -> "RBLNModel":
83
83
  """
84
84
  Converts and compiles a pre-trained HuggingFace library model into a RBLN model.
@@ -147,6 +147,7 @@ class RBLNModel(RBLNBaseModel):
147
147
  model=model,
148
148
  model_save_dir=save_dir,
149
149
  rbln_config=rbln_config,
150
+ preprocessors=preprocessors,
150
151
  **kwargs,
151
152
  )
152
153
  else:
@@ -241,7 +242,7 @@ class RBLNModel(RBLNBaseModel):
241
242
  for compiled_model in compiled_models
242
243
  ]
243
244
 
244
- def forward(self, *args: Any, return_dict: Optional[bool] = None, **kwargs: Dict[str, Any]) -> Any:
245
+ def forward(self, *args: Any, return_dict: Optional[bool] = None, **kwargs: Any) -> Any:
245
246
  """
246
247
  Defines the forward pass of the RBLN model, providing a drop-in replacement for HuggingFace PreTrainedModel.
247
248
 
@@ -348,7 +348,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
348
348
  model_id: Union[str, Path],
349
349
  export: bool = False,
350
350
  rbln_config: Optional[Union[Dict, RBLNModelConfig]] = None,
351
- **kwargs: Dict[str, Any],
351
+ **kwargs: Any,
352
352
  ) -> "RBLNBaseModel":
353
353
  """
354
354
  The `from_pretrained()` function is utilized in its standard form as in the HuggingFace transformers library.
@@ -523,10 +523,35 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
523
523
  # First copy everything to a temporary directory
524
524
  shutil.copytree(real_save_dir, tmp_dir)
525
525
 
526
- # If everything succeeded, atomically replace the target directory
526
+ # If everything succeeded, move files to target directory
527
527
  if os.path.exists(save_directory_path):
528
- shutil.rmtree(save_directory_path)
529
- os.rename(tmp_dir, save_directory_path)
528
+ # Merge files from tmp_dir into existing directory
529
+ def _merge_dir(src_root: str, dst_root: str):
530
+ for name in os.listdir(src_root):
531
+ src_item = os.path.join(src_root, name)
532
+ dst_item = os.path.join(dst_root, name)
533
+
534
+ if os.path.islink(src_item) or os.path.isfile(src_item):
535
+ os.makedirs(os.path.dirname(dst_item), exist_ok=True)
536
+ if os.path.isdir(dst_item) and not os.path.islink(dst_item):
537
+ shutil.rmtree(dst_item)
538
+ os.replace(src_item, dst_item)
539
+ elif os.path.isdir(src_item):
540
+ if os.path.islink(dst_item) or os.path.isfile(dst_item):
541
+ os.remove(dst_item)
542
+ os.makedirs(dst_item, exist_ok=True)
543
+ _merge_dir(src_item, dst_item)
544
+ else:
545
+ # Fallback for special file types
546
+ os.replace(src_item, dst_item)
547
+
548
+ _merge_dir(tmp_dir, str(save_directory_path))
549
+
550
+ # Remove the temporary directory tree after merge
551
+ shutil.rmtree(tmp_dir)
552
+ else:
553
+ # If target doesn't exist, just rename tmp_dir to target
554
+ os.rename(tmp_dir, save_directory_path)
530
555
 
531
556
  except Exception as e:
532
557
  # Clean up the temporary directory if anything fails
optimum/rbln/ops/attn.py CHANGED
@@ -53,6 +53,45 @@ def paged_attn_decode_fake(
53
53
  return torch.empty_like(q)
54
54
 
55
55
 
56
+ @torch.library.custom_op(
57
+ "rbln_custom_ops::paged_attn_decode_kv_fp8",
58
+ mutates_args=(["kcache", "vcache"]),
59
+ )
60
+ def paged_attn_decode_kv_fp8(
61
+ q: Tensor,
62
+ k: Tensor,
63
+ v: Tensor,
64
+ mask: Tensor,
65
+ kcache: Tensor,
66
+ vcache: Tensor,
67
+ seq: Tensor,
68
+ scale: Tensor,
69
+ block_table: Tensor,
70
+ block_size: int,
71
+ k_scale: Tensor,
72
+ v_scale: Tensor,
73
+ ) -> Tensor:
74
+ return torch.empty_like(q)
75
+
76
+
77
+ @paged_attn_decode_kv_fp8.register_fake
78
+ def paged_attn_decode_kv_fp8_fake(
79
+ q: Tensor,
80
+ k: Tensor,
81
+ v: Tensor,
82
+ mask: Tensor,
83
+ kcache: Tensor,
84
+ vcache: Tensor,
85
+ seq: Tensor,
86
+ scale: Tensor,
87
+ block_table: Tensor,
88
+ block_size: int,
89
+ k_scale: Tensor,
90
+ v_scale: Tensor,
91
+ ) -> Tensor:
92
+ return torch.empty_like(q)
93
+
94
+
56
95
  @torch.library.custom_op(
57
96
  "rbln_custom_ops::paged_attn_prefill",
58
97
  mutates_args=(["kcache", "vcache"]),
@@ -112,6 +151,45 @@ def paged_attn_prefill_fake(
112
151
  return torch.empty_like(q)
113
152
 
114
153
 
154
+ @torch.library.custom_op(
155
+ "rbln_custom_ops::paged_attn_prefill_kv_fp8",
156
+ mutates_args=(["kcache", "vcache"]),
157
+ )
158
+ def paged_attn_prefill_kv_fp8(
159
+ q: Tensor,
160
+ k: Tensor,
161
+ v: Tensor,
162
+ mask: Tensor,
163
+ kcache: Tensor,
164
+ vcache: Tensor,
165
+ seq: Tensor,
166
+ scale: Tensor,
167
+ block_table: Tensor,
168
+ block_size: int,
169
+ k_scale: Tensor,
170
+ v_scale: Tensor,
171
+ ) -> Tensor:
172
+ return torch.empty_like(q)
173
+
174
+
175
+ @paged_attn_prefill_kv_fp8.register_fake
176
+ def paged_attn_prefill_kv_fp8_fake(
177
+ q: Tensor,
178
+ k: Tensor,
179
+ v: Tensor,
180
+ mask: Tensor,
181
+ kcache: Tensor,
182
+ vcache: Tensor,
183
+ seq: Tensor,
184
+ scale: Tensor,
185
+ block_table: Tensor,
186
+ block_size: int,
187
+ k_scale: Tensor,
188
+ v_scale: Tensor,
189
+ ) -> Tensor:
190
+ return torch.empty_like(q)
191
+
192
+
115
193
  @torch.library.custom_op(
116
194
  "rbln_custom_ops::paged_causal_attn_decode",
117
195
  mutates_args=(["kcache", "vcache"]),
@@ -236,6 +314,86 @@ def paged_causal_attn_prefill_fake(
236
314
  return torch.empty_like(q)
237
315
 
238
316
 
317
+ @torch.library.custom_op(
318
+ "rbln_custom_ops::paged_causal_attn_decode_kv_fp8",
319
+ mutates_args=(["kcache", "vcache"]),
320
+ )
321
+ def paged_causal_attn_decode_kv_fp8(
322
+ q: Tensor,
323
+ k: Tensor,
324
+ v: Tensor,
325
+ kcache: Tensor,
326
+ vcache: Tensor,
327
+ seq: Tensor,
328
+ scale: Tensor,
329
+ block_table: Tensor,
330
+ block_size: int,
331
+ k_scale: Tensor,
332
+ v_scale: Tensor,
333
+ mask: Optional[Tensor] = None,
334
+ ) -> Tensor:
335
+ return torch.empty_like(q)
336
+
337
+
338
+ @paged_causal_attn_decode_kv_fp8.register_fake
339
+ def paged_causal_attn_decode_kv_fp8_fake(
340
+ q: Tensor,
341
+ k: Tensor,
342
+ v: Tensor,
343
+ kcache: Tensor,
344
+ vcache: Tensor,
345
+ seq: Tensor,
346
+ scale: Tensor,
347
+ block_table: Tensor,
348
+ block_size: int,
349
+ k_scale: Tensor,
350
+ v_scale: Tensor,
351
+ mask: Optional[Tensor] = None,
352
+ ) -> Tensor:
353
+ return torch.empty_like(q)
354
+
355
+
356
+ @torch.library.custom_op(
357
+ "rbln_custom_ops::paged_causal_attn_prefill_kv_fp8",
358
+ mutates_args=(["kcache", "vcache"]),
359
+ )
360
+ def paged_causal_attn_prefill_kv_fp8(
361
+ q: Tensor,
362
+ k: Tensor,
363
+ v: Tensor,
364
+ kcache: Tensor,
365
+ vcache: Tensor,
366
+ seq: Tensor,
367
+ scale: Tensor,
368
+ block_table: Tensor,
369
+ block_size: int,
370
+ is_bidirectional: bool,
371
+ k_scale: Tensor,
372
+ v_scale: Tensor,
373
+ mask: Optional[Tensor] = None,
374
+ ) -> Tensor:
375
+ return torch.empty_like(q)
376
+
377
+
378
+ @paged_causal_attn_prefill_kv_fp8.register_fake
379
+ def paged_causal_attn_prefill_kv_fp8_fake(
380
+ q: Tensor,
381
+ k: Tensor,
382
+ v: Tensor,
383
+ kcache: Tensor,
384
+ vcache: Tensor,
385
+ seq: Tensor,
386
+ scale: Tensor,
387
+ block_table: Tensor,
388
+ block_size: int,
389
+ is_bidirectional: bool,
390
+ k_scale: Tensor,
391
+ v_scale: Tensor,
392
+ mask: Optional[Tensor] = None,
393
+ ) -> Tensor:
394
+ return torch.empty_like(q)
395
+
396
+
239
397
  @torch.library.custom_op(
240
398
  "rbln_custom_ops::paged_add_softmax_attn_decode",
241
399
  mutates_args=(["kcache", "vcache"]),