optimum-rbln 0.8.2rc0__py3-none-any.whl → 0.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of optimum-rbln might be problematic. Click here for more details.
- optimum/rbln/__init__.py +32 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/configuration_utils.py +20 -4
- optimum/rbln/diffusers/__init__.py +7 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +2 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +2 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +4 -4
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +2 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +2 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +2 -2
- optimum/rbln/diffusers/modeling_diffusers.py +1 -1
- optimum/rbln/diffusers/models/__init__.py +3 -13
- optimum/rbln/diffusers/pipelines/__init__.py +11 -5
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +237 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +11 -6
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/modeling.py +3 -2
- optimum/rbln/modeling_base.py +29 -4
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/transformers/__init__.py +24 -0
- optimum/rbln/transformers/configuration_generic.py +6 -4
- optimum/rbln/transformers/modeling_generic.py +13 -8
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +31 -16
- optimum/rbln/transformers/models/auto/__init__.py +2 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +14 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +2 -2
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +7 -6
- optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +1 -4
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +2 -2
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +2 -10
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +43 -174
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +101 -91
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +450 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +88 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +296 -986
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +25 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +9 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +217 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +25 -251
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +2 -0
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +86 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +507 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +2 -2
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -9
- optimum/rbln/transformers/models/llama/modeling_llama.py +12 -3
- optimum/rbln/transformers/models/llava/configuration_llava.py +2 -2
- optimum/rbln/transformers/models/llava/modeling_llava.py +53 -14
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +6 -16
- optimum/rbln/transformers/models/opt/modeling_opt.py +2 -30
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +4 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +2 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +1 -3
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +2 -2
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +1 -4
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +3 -3
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +6 -15
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +4 -7
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +77 -3
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -4
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +19 -2
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +20 -1
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +2 -2
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +341 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +2 -2
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -2
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +20 -1
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/utils/rbln_quantization.py +365 -65
- optimum/rbln/utils/runtime_utils.py +3 -3
- optimum/rbln/utils/submodule.py +10 -4
- {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3.dist-info}/METADATA +1 -1
- {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3.dist-info}/RECORD +105 -89
- {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,237 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
import importlib
|
|
17
|
+
from typing import Type
|
|
18
|
+
|
|
19
|
+
from diffusers.models.controlnets import ControlNetUnionModel
|
|
20
|
+
from diffusers.pipelines.auto_pipeline import (
|
|
21
|
+
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
|
|
22
|
+
AUTO_INPAINT_PIPELINES_MAPPING,
|
|
23
|
+
AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
|
|
24
|
+
AutoPipelineForImage2Image,
|
|
25
|
+
AutoPipelineForInpainting,
|
|
26
|
+
AutoPipelineForText2Image,
|
|
27
|
+
_get_task_class,
|
|
28
|
+
)
|
|
29
|
+
from huggingface_hub.utils import validate_hf_hub_args
|
|
30
|
+
|
|
31
|
+
from optimum.rbln.modeling_base import RBLNBaseModel
|
|
32
|
+
from optimum.rbln.utils.model_utils import (
|
|
33
|
+
MODEL_MAPPING,
|
|
34
|
+
convert_hf_to_rbln_model_name,
|
|
35
|
+
convert_rbln_to_hf_model_name,
|
|
36
|
+
get_rbln_model_cls,
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class RBLNAutoPipelineBase:
    """Shared class-resolution logic for the RBLN auto-pipeline classes.

    Concrete subclasses combine this mixin with one of diffusers'
    ``AutoPipelineFor*`` classes (expected to supply ``load_config``) and set
    the two mapping attributes below. ``from_pretrained`` / ``from_model``
    pick the matching ``RBLN*`` pipeline class and delegate to it.
    """

    # diffusers task mapping (pipeline key -> diffusers pipeline class); set by subclasses.
    _model_mapping = None
    # Same as ``_model_mapping`` but with each class replaced by its ``__name__``; set by subclasses.
    _model_mapping_names = None

    @classmethod
    def get_rbln_cls(cls, pretrained_model_name_or_path, export=True, **kwargs):
        """Resolve the RBLN pipeline class for ``pretrained_model_name_or_path``.

        Args:
            pretrained_model_name_or_path: Model ID or local path.
            export: If True, the diffusers pipeline class is inferred from the
                checkpoint config and the task mapping, then converted to its
                RBLN name. If False, the RBLN class name stored in the config's
                ``_class_name`` field is used directly and validated against
                this auto class's task mapping.
            **kwargs: Forwarded to the resolution helpers (hub download options,
                ``controlnet``, ``enable_pag``, ...).

        Returns:
            The RBLN pipeline class found in ``optimum.rbln``.

        Raises:
            ValueError: If a compiled pipeline's class is not supported by this auto class.
            AttributeError: If the resolved class name is not available in ``optimum.rbln``.
        """
        if export:
            hf_model_class = cls.infer_hf_model_class(pretrained_model_name_or_path, **kwargs)
            rbln_class_name = convert_hf_to_rbln_model_name(hf_model_class.__name__)
        else:
            rbln_class_name = cls.get_rbln_model_cls_name(pretrained_model_name_or_path, **kwargs)
            if convert_rbln_to_hf_model_name(rbln_class_name) not in cls._model_mapping_names.values():
                raise ValueError(
                    f"The architecture '{rbln_class_name}' is not supported by the `{cls.__name__}.from_pretrained()` method. "
                    "Please use the `from_pretrained()` method of the appropriate class to load this model, "
                    f"or directly use `{rbln_class_name}.from_pretrained()`."
                )

        try:
            rbln_cls = get_rbln_model_cls(rbln_class_name)
        except AttributeError as e:
            raise AttributeError(
                f"Class '{rbln_class_name}' not found in 'optimum.rbln' module for model ID '{pretrained_model_name_or_path}'. "
                "Ensure that the class name is correctly mapped and available in the 'optimum.rbln' module."
            ) from e

        return rbln_cls

    @classmethod
    def get_rbln_model_cls_name(cls, pretrained_model_name_or_path, **kwargs):
        """
        Read the RBLN pipeline class name recorded in a compiled model's config.

        Args:
            pretrained_model_name_or_path (str): Identifier (model ID or local path) of the model.
            **kwargs: Hub download options (``cache_dir``, ``force_download``, ``proxies``,
                ``token``, ``local_files_only``, ``revision``) are forwarded to
                ``load_config`` so private or pinned checkpoints resolve the same way
                as in ``infer_hf_model_class``. Other keys are ignored here.

        Returns:
            str: The value of the config's ``_class_name`` field.

        Raises:
            ValueError: If the config has no ``_class_name`` field.
        """
        hub_kwargs = {
            key: kwargs[key]
            for key in ("cache_dir", "force_download", "proxies", "token", "local_files_only", "revision")
            if key in kwargs
        }
        model_index_config = cls.load_config(pretrained_model_name_or_path, **hub_kwargs)

        if "_class_name" not in model_index_config:
            raise ValueError(
                "The `_class_name` field is missing from model_index_config. This is unexpected and should be reported as an issue. "
                "Please use the `from_pretrained()` method of the appropriate class to load this model."
            )

        return model_index_config["_class_name"]

    @classmethod
    def infer_hf_model_class(
        cls,
        pretrained_model_or_path,
        cache_dir=None,
        force_download=False,
        proxies=None,
        token=None,
        local_files_only=False,
        revision=None,
        **kwargs,
    ):
        """Return the diffusers pipeline class this auto class would pick for a checkpoint.

        The checkpoint config is loaded with the given hub options. The remaining
        ``kwargs`` (e.g. ``controlnet``, ``enable_pag``) influence the pipeline key
        computed by ``get_pipeline_key_name``, which is then looked up in
        ``_model_mapping``.
        """
        config = cls.load_config(
            pretrained_model_or_path,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            token=token,
            local_files_only=local_files_only,
            revision=revision,
        )
        pipeline_key_name = cls.get_pipeline_key_name(config, **kwargs)

        pipeline_cls = _get_task_class(cls._model_mapping, pipeline_key_name)

        return pipeline_cls

    @classmethod
    def get_pipeline_key_name(cls, config, **kwargs):
        """Derive the task-specific pipeline class name from a checkpoint config.

        Starts from ``config["_class_name"]``. It then inserts ``ControlNet`` /
        ``ControlNetUnion`` when a ``controlnet`` kwarg is present, and ``PAG``
        when ``enable_pag`` is truthy.
        """
        orig_class_name = config["_class_name"]
        if "ControlPipeline" in orig_class_name:
            to_replace = "ControlPipeline"
        else:
            to_replace = "Pipeline"

        if "controlnet" in kwargs:
            if isinstance(kwargs["controlnet"], ControlNetUnionModel):
                orig_class_name = config["_class_name"].replace(to_replace, "ControlNetUnionPipeline")
            else:
                orig_class_name = config["_class_name"].replace(to_replace, "ControlNetPipeline")
        if "enable_pag" in kwargs:
            # NOTE(review): `kwargs` is a fresh dict in this method, so this pop does not
            # remove `enable_pag` from the kwargs that `from_pretrained` forwards later.
            enable_pag = kwargs.pop("enable_pag")
            if enable_pag:
                orig_class_name = orig_class_name.replace(to_replace, "PAGPipeline")

        return orig_class_name

    @classmethod
    @validate_hf_hub_args
    def from_pretrained(cls, model_id, **kwargs):
        """Resolve the RBLN pipeline class for ``model_id`` and call its ``from_pretrained``.

        All ``kwargs`` (including ``export``) are used for class resolution and
        are then forwarded unchanged to the resolved class.
        """
        rbln_cls = cls.get_rbln_cls(model_id, **kwargs)
        return rbln_cls.from_pretrained(model_id, **kwargs)

    @classmethod
    def from_model(cls, model, **kwargs):
        """Convert an instantiated pipeline via ``RBLN<its class name>.from_model``."""
        rbln_cls = get_rbln_model_cls(f"RBLN{model.__class__.__name__}")
        return rbln_cls.from_model(model, **kwargs)

    @staticmethod
    def register(rbln_cls: Type[RBLNBaseModel], exist_ok=False):
        """
        Register a new RBLN model class.

        Args:
            rbln_cls (Type[RBLNBaseModel]): The RBLN model class to register.
            exist_ok (bool): Whether to allow registering an already registered model.

        Raises:
            ValueError: If ``rbln_cls`` is not an ``RBLNBaseModel`` subclass, or if
                the name is already registered (in ``MODEL_MAPPING`` or as a native
                ``optimum.rbln`` attribute) and ``exist_ok`` is False.
        """
        if not issubclass(rbln_cls, RBLNBaseModel):
            raise ValueError("`rbln_cls` must be a subclass of RBLNBaseModel.")

        # A class with the same name shipped by optimum.rbln itself also counts as "registered".
        native_cls = getattr(importlib.import_module("optimum.rbln"), rbln_cls.__name__, None)
        if rbln_cls.__name__ in MODEL_MAPPING or native_cls is not None:
            if not exist_ok:
                raise ValueError(f"Model for {rbln_cls.__name__} already registered.")

        MODEL_MAPPING[rbln_cls.__name__] = rbln_cls
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
class RBLNAutoPipelineForText2Image(RBLNAutoPipelineBase, AutoPipelineForText2Image):
    """RBLN counterpart of diffusers' ``AutoPipelineForText2Image``.

    Class resolution comes from ``RBLNAutoPipelineBase``, driven by diffusers'
    text-to-image task mapping.
    """

    _model_mapping = AUTO_TEXT2IMAGE_PIPELINES_MAPPING
    _model_mapping_names = {
        task_key: pipeline_cls.__name__ for task_key, pipeline_cls in AUTO_TEXT2IMAGE_PIPELINES_MAPPING.items()
    }
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
class RBLNAutoPipelineForImage2Image(RBLNAutoPipelineBase, AutoPipelineForImage2Image):
    """RBLN counterpart of diffusers' ``AutoPipelineForImage2Image``."""

    _model_mapping = AUTO_IMAGE2IMAGE_PIPELINES_MAPPING
    _model_mapping_names = {
        task_key: pipeline_cls.__name__ for task_key, pipeline_cls in AUTO_IMAGE2IMAGE_PIPELINES_MAPPING.items()
    }

    @classmethod
    def get_pipeline_key_name(cls, config, **kwargs):
        """Map a checkpoint's ``_class_name`` to its image-to-image pipeline key.

        ``ControlNet``/``ControlNetUnion`` is inserted when ``controlnet`` is given,
        and ``PAG`` when ``enable_pag`` is truthy. Flux-tools ``*ControlPipeline``
        names end up as ``*ControlImg2ImgPipeline``.
        """
        class_name = config["_class_name"]

        # `_class_name` takes one of these forms:
        #   - `*Img2ImgPipeline` (refiner checkpoint)
        #   - `*ControlPipeline` (Flux tools specific checkpoint)
        #   - `*Pipeline`        (regular text-to-image checkpoint)
        if "Img2Img" in class_name:
            suffix = "Img2ImgPipeline"
        elif "ControlPipeline" in class_name:
            suffix = "ControlPipeline"
        else:
            suffix = "Pipeline"

        if "controlnet" in kwargs:
            prefix = "ControlNetUnion" if isinstance(kwargs["controlnet"], ControlNetUnionModel) else "ControlNet"
            class_name = class_name.replace(suffix, prefix + suffix)

        if "enable_pag" in kwargs and kwargs.pop("enable_pag"):
            class_name = class_name.replace(suffix, "PAG" + suffix)

        if suffix == "ControlPipeline":
            class_name = class_name.replace(suffix, "ControlImg2ImgPipeline")

        return class_name
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
class RBLNAutoPipelineForInpainting(RBLNAutoPipelineBase, AutoPipelineForInpainting):
    """RBLN counterpart of diffusers' ``AutoPipelineForInpainting``."""

    _model_mapping = AUTO_INPAINT_PIPELINES_MAPPING
    _model_mapping_names = {
        task_key: pipeline_cls.__name__ for task_key, pipeline_cls in AUTO_INPAINT_PIPELINES_MAPPING.items()
    }

    @classmethod
    def get_pipeline_key_name(cls, config, **kwargs):
        """Map a checkpoint's ``_class_name`` to its inpainting pipeline key.

        ``ControlNet``/``ControlNetUnion`` is inserted when ``controlnet`` is given,
        and ``PAG`` when ``enable_pag`` is truthy. Flux-tools ``*ControlPipeline``
        names end up as ``*ControlInpaintPipeline``.
        """
        class_name = config["_class_name"]

        # `_class_name` takes one of these forms:
        #   - `*InpaintPipeline` (inpaint-specific checkpoint)
        #   - `*ControlPipeline` (Flux tools specific checkpoint)
        #   - `*Pipeline`        (regular text-to-image checkpoint)
        if "Inpaint" in class_name:
            suffix = "InpaintPipeline"
        elif "ControlPipeline" in class_name:
            suffix = "ControlPipeline"
        else:
            suffix = "Pipeline"

        if "controlnet" in kwargs:
            prefix = "ControlNetUnion" if isinstance(kwargs["controlnet"], ControlNetUnionModel) else "ControlNet"
            class_name = class_name.replace(suffix, prefix + suffix)

        if "enable_pag" in kwargs and kwargs.pop("enable_pag"):
            class_name = class_name.replace(suffix, "PAG" + suffix)

        if suffix == "ControlPipeline":
            class_name = class_name.replace(suffix, "ControlInpaintPipeline")

        return class_name
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from typing import Any,
|
|
15
|
+
from typing import Any, Optional, Tuple
|
|
16
16
|
|
|
17
17
|
from ....configuration_utils import RBLNAutoConfig, RBLNModelConfig
|
|
18
18
|
from ....transformers import RBLNLlamaForCausalLMConfig, RBLNSiglipVisionModelConfig
|
|
@@ -56,11 +56,11 @@ class RBLNCosmosSafetyCheckerConfig(RBLNModelConfig):
|
|
|
56
56
|
Configuration class for RBLN Cosmos Safety Checker.
|
|
57
57
|
"""
|
|
58
58
|
|
|
59
|
-
submodules = ["
|
|
59
|
+
submodules = ["llamaguard3", "video_safety_model", "face_blur_filter", "siglip_encoder"]
|
|
60
60
|
|
|
61
61
|
def __init__(
|
|
62
62
|
self,
|
|
63
|
-
|
|
63
|
+
llamaguard3: Optional[RBLNModelConfig] = None,
|
|
64
64
|
video_safety_model: Optional[RBLNModelConfig] = None,
|
|
65
65
|
face_blur_filter: Optional[RBLNModelConfig] = None,
|
|
66
66
|
siglip_encoder: Optional[RBLNSiglipVisionModelConfig] = None,
|
|
@@ -69,19 +69,24 @@ class RBLNCosmosSafetyCheckerConfig(RBLNModelConfig):
|
|
|
69
69
|
image_size: Optional[Tuple[int, int]] = None,
|
|
70
70
|
height: Optional[int] = None,
|
|
71
71
|
width: Optional[int] = None,
|
|
72
|
-
|
|
72
|
+
max_seq_len: Optional[int] = None,
|
|
73
|
+
**kwargs: Any,
|
|
73
74
|
):
|
|
74
75
|
super().__init__(**kwargs)
|
|
75
76
|
if height is not None and width is not None:
|
|
76
77
|
image_size = (height, width)
|
|
77
78
|
|
|
79
|
+
if max_seq_len is None:
|
|
80
|
+
max_seq_len = 512
|
|
81
|
+
|
|
78
82
|
tensor_parallel_size = kwargs.get("tensor_parallel_size")
|
|
79
83
|
|
|
80
|
-
self.
|
|
84
|
+
self.llamaguard3 = self.init_submodule_config(
|
|
81
85
|
RBLNLlamaForCausalLMConfig,
|
|
82
|
-
|
|
86
|
+
llamaguard3,
|
|
83
87
|
batch_size=batch_size,
|
|
84
88
|
tensor_parallel_size=tensor_parallel_size,
|
|
89
|
+
max_seq_len=max_seq_len,
|
|
85
90
|
)
|
|
86
91
|
|
|
87
92
|
self.siglip_encoder = self.init_submodule_config(
|
|
@@ -33,9 +33,9 @@ if is_cosmos_guardrail_available():
|
|
|
33
33
|
from cosmos_guardrail import CosmosSafetyChecker
|
|
34
34
|
from cosmos_guardrail.cosmos_guardrail import (
|
|
35
35
|
COSMOS_GUARDRAIL_CHECKPOINT,
|
|
36
|
-
Aegis,
|
|
37
36
|
Blocklist,
|
|
38
37
|
GuardrailRunner,
|
|
38
|
+
LlamaGuard3,
|
|
39
39
|
ModelConfig,
|
|
40
40
|
RetinaFaceFilter,
|
|
41
41
|
SafetyClassifier,
|
|
@@ -55,7 +55,7 @@ else:
|
|
|
55
55
|
|
|
56
56
|
COSMOS_GUARDRAIL_CHECKPOINT = None
|
|
57
57
|
|
|
58
|
-
class
|
|
58
|
+
class LlamaGuard3(FailToImportCosmosGuardrail): ...
|
|
59
59
|
|
|
60
60
|
class Blocklist(FailToImportCosmosGuardrail): ...
|
|
61
61
|
|
|
@@ -312,33 +312,31 @@ class RBLNVideoContentSafetyFilter(VideoContentSafetyFilter):
|
|
|
312
312
|
self.encoder.save_pretrained(checkpoint_id)
|
|
313
313
|
|
|
314
314
|
|
|
315
|
-
class
|
|
315
|
+
class RBLNLlamaGuard3(LlamaGuard3):
|
|
316
316
|
def __init__(
|
|
317
317
|
self,
|
|
318
318
|
checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
|
|
319
|
-
base_model_id: str = "meta-llama/
|
|
320
|
-
aegis_adapter: str = "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0",
|
|
319
|
+
base_model_id: str = "meta-llama/Llama-Guard-3-8B",
|
|
321
320
|
rbln_config: Optional[RBLNCosmosSafetyCheckerConfig] = None,
|
|
322
321
|
) -> None:
|
|
323
322
|
if is_compiled_dir(checkpoint_id):
|
|
324
323
|
torch.nn.Module.__init__(self)
|
|
325
|
-
cache_dir = pathlib.Path(checkpoint_id) / "
|
|
324
|
+
cache_dir = pathlib.Path(checkpoint_id) / "llamaguard3"
|
|
326
325
|
self.tokenizer = AutoTokenizer.from_pretrained(cache_dir)
|
|
327
|
-
self.model = RBLNAutoModelForCausalLM.from_pretrained(cache_dir, rbln_config=rbln_config.
|
|
326
|
+
self.model = RBLNAutoModelForCausalLM.from_pretrained(cache_dir, rbln_config=rbln_config.llamaguard3)
|
|
328
327
|
|
|
329
328
|
else:
|
|
330
|
-
super().__init__(checkpoint_id, base_model_id
|
|
331
|
-
model = self.model
|
|
329
|
+
super().__init__(checkpoint_id, base_model_id)
|
|
330
|
+
model = self.model
|
|
332
331
|
del self.model
|
|
333
|
-
|
|
334
|
-
self.model = RBLNAutoModelForCausalLM.from_model(model, rbln_config=rbln_config.aegis)
|
|
332
|
+
self.model = RBLNAutoModelForCausalLM.from_model(model, rbln_config=rbln_config.llamaguard3)
|
|
335
333
|
|
|
336
334
|
self.rbln_config = rbln_config
|
|
337
335
|
self.dtype = torch.bfloat16
|
|
338
336
|
self.device = torch.device("cpu")
|
|
339
337
|
|
|
340
338
|
def save_pretrained(self, checkpoint_id: str):
|
|
341
|
-
cache_dir = pathlib.Path(checkpoint_id) / "
|
|
339
|
+
cache_dir = pathlib.Path(checkpoint_id) / "llamaguard3"
|
|
342
340
|
self.model.save_pretrained(cache_dir)
|
|
343
341
|
self.tokenizer.save_pretrained(cache_dir)
|
|
344
342
|
|
|
@@ -351,8 +349,7 @@ class RBLNCosmosSafetyChecker(CosmosSafetyChecker):
|
|
|
351
349
|
def __init__(
|
|
352
350
|
self,
|
|
353
351
|
checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
|
|
354
|
-
|
|
355
|
-
aegis_adapter_id: str = "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0",
|
|
352
|
+
llamaguard_model_id: str = "meta-llama/Llama-Guard-3-8B",
|
|
356
353
|
rbln_config: Optional[RBLNCosmosSafetyCheckerConfig] = None,
|
|
357
354
|
) -> None:
|
|
358
355
|
torch.nn.Module.__init__(self)
|
|
@@ -369,10 +366,9 @@ class RBLNCosmosSafetyChecker(CosmosSafetyChecker):
|
|
|
369
366
|
self.text_guardrail = GuardrailRunner(
|
|
370
367
|
safety_models=[
|
|
371
368
|
Blocklist(COSMOS_GUARDRAIL_CHECKPOINT), # Changed since it cannot be saved
|
|
372
|
-
|
|
369
|
+
RBLNLlamaGuard3(
|
|
373
370
|
checkpoint_id=checkpoint_id,
|
|
374
|
-
base_model_id=
|
|
375
|
-
aegis_adapter=aegis_adapter_id,
|
|
371
|
+
base_model_id=llamaguard_model_id,
|
|
376
372
|
rbln_config=rbln_config,
|
|
377
373
|
),
|
|
378
374
|
]
|
|
@@ -387,7 +383,7 @@ class RBLNCosmosSafetyChecker(CosmosSafetyChecker):
|
|
|
387
383
|
|
|
388
384
|
def save_pretrained(self, save_dir: str):
|
|
389
385
|
for text_safety_models in self.text_guardrail.safety_models:
|
|
390
|
-
if isinstance(text_safety_models,
|
|
386
|
+
if isinstance(text_safety_models, RBLNLlamaGuard3):
|
|
391
387
|
text_safety_models.save_pretrained(save_dir)
|
|
392
388
|
|
|
393
389
|
for video_safety_models in self.video_guardrail.safety_models:
|
|
@@ -87,7 +87,7 @@ class RBLNCosmosTextToWorldPipeline(RBLNDiffusionMixin, CosmosTextToWorldPipelin
|
|
|
87
87
|
export: bool = False,
|
|
88
88
|
safety_checker: Optional[RBLNCosmosSafetyChecker] = None,
|
|
89
89
|
rbln_config: Dict[str, Any] = {},
|
|
90
|
-
**kwargs:
|
|
90
|
+
**kwargs: Any,
|
|
91
91
|
):
|
|
92
92
|
rbln_config, kwargs = cls.get_rbln_config_class().initialize_from_kwargs(rbln_config, **kwargs)
|
|
93
93
|
if safety_checker is None and export:
|
|
@@ -87,7 +87,7 @@ class RBLNCosmosVideoToWorldPipeline(RBLNDiffusionMixin, CosmosVideoToWorldPipel
|
|
|
87
87
|
export: bool = False,
|
|
88
88
|
safety_checker: Optional[RBLNCosmosSafetyChecker] = None,
|
|
89
89
|
rbln_config: Dict[str, Any] = {},
|
|
90
|
-
**kwargs:
|
|
90
|
+
**kwargs: Any,
|
|
91
91
|
):
|
|
92
92
|
rbln_config, kwargs = cls.get_rbln_config_class().initialize_from_kwargs(rbln_config, **kwargs)
|
|
93
93
|
if safety_checker is None and export:
|
|
@@ -22,12 +22,7 @@ from diffusers import (
|
|
|
22
22
|
UNet2DConditionModel,
|
|
23
23
|
VQModel,
|
|
24
24
|
)
|
|
25
|
-
from transformers import
|
|
26
|
-
CLIPImageProcessor,
|
|
27
|
-
CLIPTextModelWithProjection,
|
|
28
|
-
CLIPTokenizer,
|
|
29
|
-
CLIPVisionModelWithProjection,
|
|
30
|
-
)
|
|
25
|
+
from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
|
|
31
26
|
|
|
32
27
|
from ...configurations import RBLNKandinskyV22CombinedPipelineConfig
|
|
33
28
|
from ...modeling_diffusers import RBLNDiffusionMixin
|
optimum/rbln/modeling.py
CHANGED
|
@@ -78,7 +78,7 @@ class RBLNModel(RBLNBaseModel):
|
|
|
78
78
|
rbln_config: Optional[Union[RBLNModelConfig, Dict]] = None,
|
|
79
79
|
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
|
|
80
80
|
subfolder: str = "",
|
|
81
|
-
**kwargs:
|
|
81
|
+
**kwargs: Any,
|
|
82
82
|
) -> "RBLNModel":
|
|
83
83
|
"""
|
|
84
84
|
Converts and compiles a pre-trained HuggingFace library model into a RBLN model.
|
|
@@ -147,6 +147,7 @@ class RBLNModel(RBLNBaseModel):
|
|
|
147
147
|
model=model,
|
|
148
148
|
model_save_dir=save_dir,
|
|
149
149
|
rbln_config=rbln_config,
|
|
150
|
+
preprocessors=preprocessors,
|
|
150
151
|
**kwargs,
|
|
151
152
|
)
|
|
152
153
|
else:
|
|
@@ -241,7 +242,7 @@ class RBLNModel(RBLNBaseModel):
|
|
|
241
242
|
for compiled_model in compiled_models
|
|
242
243
|
]
|
|
243
244
|
|
|
244
|
-
def forward(self, *args: Any, return_dict: Optional[bool] = None, **kwargs:
|
|
245
|
+
def forward(self, *args: Any, return_dict: Optional[bool] = None, **kwargs: Any) -> Any:
|
|
245
246
|
"""
|
|
246
247
|
Defines the forward pass of the RBLN model, providing a drop-in replacement for HuggingFace PreTrainedModel.
|
|
247
248
|
|
optimum/rbln/modeling_base.py
CHANGED
|
@@ -348,7 +348,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
|
348
348
|
model_id: Union[str, Path],
|
|
349
349
|
export: bool = False,
|
|
350
350
|
rbln_config: Optional[Union[Dict, RBLNModelConfig]] = None,
|
|
351
|
-
**kwargs:
|
|
351
|
+
**kwargs: Any,
|
|
352
352
|
) -> "RBLNBaseModel":
|
|
353
353
|
"""
|
|
354
354
|
The `from_pretrained()` function is utilized in its standard form as in the HuggingFace transformers library.
|
|
@@ -523,10 +523,35 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
|
523
523
|
# First copy everything to a temporary directory
|
|
524
524
|
shutil.copytree(real_save_dir, tmp_dir)
|
|
525
525
|
|
|
526
|
-
# If everything succeeded,
|
|
526
|
+
# If everything succeeded, move files to target directory
|
|
527
527
|
if os.path.exists(save_directory_path):
|
|
528
|
-
|
|
529
|
-
|
|
528
|
+
# Merge files from tmp_dir into existing directory
|
|
529
|
+
def _merge_dir(src_root: str, dst_root: str):
|
|
530
|
+
for name in os.listdir(src_root):
|
|
531
|
+
src_item = os.path.join(src_root, name)
|
|
532
|
+
dst_item = os.path.join(dst_root, name)
|
|
533
|
+
|
|
534
|
+
if os.path.islink(src_item) or os.path.isfile(src_item):
|
|
535
|
+
os.makedirs(os.path.dirname(dst_item), exist_ok=True)
|
|
536
|
+
if os.path.isdir(dst_item) and not os.path.islink(dst_item):
|
|
537
|
+
shutil.rmtree(dst_item)
|
|
538
|
+
os.replace(src_item, dst_item)
|
|
539
|
+
elif os.path.isdir(src_item):
|
|
540
|
+
if os.path.islink(dst_item) or os.path.isfile(dst_item):
|
|
541
|
+
os.remove(dst_item)
|
|
542
|
+
os.makedirs(dst_item, exist_ok=True)
|
|
543
|
+
_merge_dir(src_item, dst_item)
|
|
544
|
+
else:
|
|
545
|
+
# Fallback for special file types
|
|
546
|
+
os.replace(src_item, dst_item)
|
|
547
|
+
|
|
548
|
+
_merge_dir(tmp_dir, str(save_directory_path))
|
|
549
|
+
|
|
550
|
+
# Remove the temporary directory tree after merge
|
|
551
|
+
shutil.rmtree(tmp_dir)
|
|
552
|
+
else:
|
|
553
|
+
# If target doesn't exist, just rename tmp_dir to target
|
|
554
|
+
os.rename(tmp_dir, save_directory_path)
|
|
530
555
|
|
|
531
556
|
except Exception as e:
|
|
532
557
|
# Clean up the temporary directory if anything fails
|
optimum/rbln/ops/attn.py
CHANGED
|
@@ -53,6 +53,45 @@ def paged_attn_decode_fake(
|
|
|
53
53
|
return torch.empty_like(q)
|
|
54
54
|
|
|
55
55
|
|
|
56
|
+
@torch.library.custom_op(
    "rbln_custom_ops::paged_attn_decode_kv_fp8",
    mutates_args=(["kcache", "vcache"]),
)
def paged_attn_decode_kv_fp8(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    mask: Tensor,
    kcache: Tensor,
    vcache: Tensor,
    seq: Tensor,
    scale: Tensor,
    block_table: Tensor,
    block_size: int,
    k_scale: Tensor,
    v_scale: Tensor,
) -> Tensor:
    """Define the ``rbln_custom_ops::paged_attn_decode_kv_fp8`` custom op.

    ``kcache`` and ``vcache`` are declared as mutated arguments in the op schema.
    The eager body here only returns an uninitialized tensor shaped like ``q``.
    It performs no attention and writes nothing to the caches.

    NOTE(review): the real computation is presumably supplied by the RBLN
    compiler/runtime when this op is lowered. The ``_kv_fp8`` suffix suggests
    an FP8 KV cache with ``k_scale``/``v_scale`` as its scales. Confirm both
    against the RBLN kernel spec.
    """
    return torch.empty_like(q)


@paged_attn_decode_kv_fp8.register_fake
def paged_attn_decode_kv_fp8_fake(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    mask: Tensor,
    kcache: Tensor,
    vcache: Tensor,
    seq: Tensor,
    scale: Tensor,
    block_table: Tensor,
    block_size: int,
    k_scale: Tensor,
    v_scale: Tensor,
) -> Tensor:
    """Fake implementation registered for tracing: the output has the same metadata as ``q``."""
    return torch.empty_like(q)
|
|
93
|
+
|
|
94
|
+
|
|
56
95
|
@torch.library.custom_op(
|
|
57
96
|
"rbln_custom_ops::paged_attn_prefill",
|
|
58
97
|
mutates_args=(["kcache", "vcache"]),
|
|
@@ -112,6 +151,45 @@ def paged_attn_prefill_fake(
|
|
|
112
151
|
return torch.empty_like(q)
|
|
113
152
|
|
|
114
153
|
|
|
154
|
+
@torch.library.custom_op(
    "rbln_custom_ops::paged_attn_prefill_kv_fp8",
    mutates_args=(["kcache", "vcache"]),
)
def paged_attn_prefill_kv_fp8(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    mask: Tensor,
    kcache: Tensor,
    vcache: Tensor,
    seq: Tensor,
    scale: Tensor,
    block_table: Tensor,
    block_size: int,
    k_scale: Tensor,
    v_scale: Tensor,
) -> Tensor:
    """Define the ``rbln_custom_ops::paged_attn_prefill_kv_fp8`` custom op.

    ``kcache`` and ``vcache`` are declared as mutated arguments in the op schema.
    The eager body here only returns an uninitialized tensor shaped like ``q``.
    It performs no attention and writes nothing to the caches.

    NOTE(review): the real computation is presumably supplied by the RBLN
    compiler/runtime when this op is lowered. The ``_kv_fp8`` suffix suggests
    an FP8 KV cache with ``k_scale``/``v_scale`` as its scales. Confirm both
    against the RBLN kernel spec.
    """
    return torch.empty_like(q)


@paged_attn_prefill_kv_fp8.register_fake
def paged_attn_prefill_kv_fp8_fake(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    mask: Tensor,
    kcache: Tensor,
    vcache: Tensor,
    seq: Tensor,
    scale: Tensor,
    block_table: Tensor,
    block_size: int,
    k_scale: Tensor,
    v_scale: Tensor,
) -> Tensor:
    """Fake implementation registered for tracing: the output has the same metadata as ``q``."""
    return torch.empty_like(q)
|
|
191
|
+
|
|
192
|
+
|
|
115
193
|
@torch.library.custom_op(
|
|
116
194
|
"rbln_custom_ops::paged_causal_attn_decode",
|
|
117
195
|
mutates_args=(["kcache", "vcache"]),
|
|
@@ -236,6 +314,86 @@ def paged_causal_attn_prefill_fake(
|
|
|
236
314
|
return torch.empty_like(q)
|
|
237
315
|
|
|
238
316
|
|
|
317
|
+
@torch.library.custom_op(
    "rbln_custom_ops::paged_causal_attn_decode_kv_fp8",
    mutates_args=(["kcache", "vcache"]),
)
def paged_causal_attn_decode_kv_fp8(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    kcache: Tensor,
    vcache: Tensor,
    seq: Tensor,
    scale: Tensor,
    block_table: Tensor,
    block_size: int,
    k_scale: Tensor,
    v_scale: Tensor,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """Define the ``rbln_custom_ops::paged_causal_attn_decode_kv_fp8`` custom op.

    Unlike the non-causal variant, ``mask`` is optional and comes last.
    ``kcache`` and ``vcache`` are declared as mutated arguments in the op schema.
    The eager body here only returns an uninitialized tensor shaped like ``q``.
    It performs no attention and writes nothing to the caches.

    NOTE(review): the real computation is presumably supplied by the RBLN
    compiler/runtime when this op is lowered. The ``_kv_fp8`` suffix suggests
    an FP8 KV cache with ``k_scale``/``v_scale`` as its scales. Confirm both
    against the RBLN kernel spec.
    """
    return torch.empty_like(q)


@paged_causal_attn_decode_kv_fp8.register_fake
def paged_causal_attn_decode_kv_fp8_fake(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    kcache: Tensor,
    vcache: Tensor,
    seq: Tensor,
    scale: Tensor,
    block_table: Tensor,
    block_size: int,
    k_scale: Tensor,
    v_scale: Tensor,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """Fake implementation registered for tracing: the output has the same metadata as ``q``."""
    return torch.empty_like(q)
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
@torch.library.custom_op(
    "rbln_custom_ops::paged_causal_attn_prefill_kv_fp8",
    mutates_args=(["kcache", "vcache"]),
)
def paged_causal_attn_prefill_kv_fp8(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    kcache: Tensor,
    vcache: Tensor,
    seq: Tensor,
    scale: Tensor,
    block_table: Tensor,
    block_size: int,
    is_bidirectional: bool,
    k_scale: Tensor,
    v_scale: Tensor,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """Define the ``rbln_custom_ops::paged_causal_attn_prefill_kv_fp8`` custom op.

    Takes an extra ``is_bidirectional`` flag and an optional trailing ``mask``.
    ``kcache`` and ``vcache`` are declared as mutated arguments in the op schema.
    The eager body here only returns an uninitialized tensor shaped like ``q``.
    It performs no attention and writes nothing to the caches.

    NOTE(review): the real computation is presumably supplied by the RBLN
    compiler/runtime when this op is lowered. The ``_kv_fp8`` suffix suggests
    an FP8 KV cache with ``k_scale``/``v_scale`` as its scales. Confirm both
    against the RBLN kernel spec.
    """
    return torch.empty_like(q)


@paged_causal_attn_prefill_kv_fp8.register_fake
def paged_causal_attn_prefill_kv_fp8_fake(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    kcache: Tensor,
    vcache: Tensor,
    seq: Tensor,
    scale: Tensor,
    block_table: Tensor,
    block_size: int,
    is_bidirectional: bool,
    k_scale: Tensor,
    v_scale: Tensor,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """Fake implementation registered for tracing: the output has the same metadata as ``q``."""
    return torch.empty_like(q)
|
|
395
|
+
|
|
396
|
+
|
|
239
397
|
@torch.library.custom_op(
|
|
240
398
|
"rbln_custom_ops::paged_add_softmax_attn_decode",
|
|
241
399
|
mutates_args=(["kcache", "vcache"]),
|