optimum-rbln 0.8.3a4__py3-none-any.whl → 0.8.4a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of optimum-rbln was flagged as potentially problematic.
Files changed (31)
  1. optimum/rbln/__init__.py +14 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +15 -0
  4. optimum/rbln/modeling.py +2 -4
  5. optimum/rbln/modeling_base.py +44 -13
  6. optimum/rbln/transformers/__init__.py +14 -0
  7. optimum/rbln/transformers/configuration_generic.py +2 -0
  8. optimum/rbln/transformers/modeling_generic.py +12 -4
  9. optimum/rbln/transformers/models/__init__.py +18 -0
  10. optimum/rbln/transformers/models/auto/__init__.py +1 -0
  11. optimum/rbln/transformers/models/auto/modeling_auto.py +7 -0
  12. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  13. optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
  14. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +6 -1
  15. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +6 -3
  16. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +7 -1
  17. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +12 -31
  18. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +1 -1
  19. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +7 -1
  20. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  21. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +86 -0
  22. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +507 -0
  23. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
  24. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +2 -0
  25. optimum/rbln/transformers/models/swin/modeling_swin.py +32 -7
  26. optimum/rbln/transformers/utils/rbln_quantization.py +47 -31
  27. optimum/rbln/utils/submodule.py +10 -4
  28. {optimum_rbln-0.8.3a4.dist-info → optimum_rbln-0.8.4a0.dist-info}/METADATA +1 -1
  29. {optimum_rbln-0.8.3a4.dist-info → optimum_rbln-0.8.4a0.dist-info}/RECORD +31 -26
  30. {optimum_rbln-0.8.3a4.dist-info → optimum_rbln-0.8.4a0.dist-info}/WHEEL +0 -0
  31. {optimum_rbln-0.8.3a4.dist-info → optimum_rbln-0.8.4a0.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py CHANGED
@@ -47,6 +47,7 @@ _import_structure = {
  "RBLNAutoModelForSpeechSeq2Seq",
  "RBLNAutoModelForVision2Seq",
  "RBLNAutoModelForTextEncoding",
+ "RBLNAutoModelForZeroShotObjectDetection",
  "RBLNBartForConditionalGeneration",
  "RBLNBartForConditionalGenerationConfig",
  "RBLNBartModel",
@@ -97,6 +98,12 @@ _import_structure = {
  "RBLNGPT2ModelConfig",
  "RBLNGPT2LMHeadModel",
  "RBLNGPT2LMHeadModelConfig",
+ "RBLNGroundingDinoDecoder",
+ "RBLNGroundingDinoDecoderConfig",
+ "RBLNGroundingDinoForObjectDetection",
+ "RBLNGroundingDinoForObjectDetectionConfig",
+ "RBLNGroundingDinoEncoder",
+ "RBLNGroundingDinoEncoderConfig",
  "RBLNIdefics3VisionTransformer",
  "RBLNIdefics3ForConditionalGeneration",
  "RBLNIdefics3ForConditionalGenerationConfig",
@@ -326,6 +333,7 @@ if TYPE_CHECKING:
  RBLNAutoModelForSpeechSeq2Seq,
  RBLNAutoModelForTextEncoding,
  RBLNAutoModelForVision2Seq,
+ RBLNAutoModelForZeroShotObjectDetection,
  RBLNBartForConditionalGeneration,
  RBLNBartForConditionalGenerationConfig,
  RBLNBartModel,
@@ -376,6 +384,12 @@ if TYPE_CHECKING:
  RBLNGPT2LMHeadModelConfig,
  RBLNGPT2Model,
  RBLNGPT2ModelConfig,
+ RBLNGroundingDinoDecoder,
+ RBLNGroundingDinoDecoderConfig,
+ RBLNGroundingDinoEncoder,
+ RBLNGroundingDinoEncoderConfig,
+ RBLNGroundingDinoForObjectDetection,
+ RBLNGroundingDinoForObjectDetectionConfig,
  RBLNIdefics3ForConditionalGeneration,
  RBLNIdefics3ForConditionalGenerationConfig,
  RBLNIdefics3VisionTransformer,
optimum/rbln/__version__.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
  commit_id: COMMIT_ID
  __commit_id__: COMMIT_ID
 
- __version__ = version = '0.8.3a4'
- __version_tuple__ = version_tuple = (0, 8, 3, 'a4')
+ __version__ = version = '0.8.4a0'
+ __version_tuple__ = version_tuple = (0, 8, 4, 'a0')
 
  __commit_id__ = commit_id = None
optimum/rbln/configuration_utils.py CHANGED
@@ -476,6 +476,7 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
  non_save_attributes = [
      "_frozen",
      "_runtime_options",
+     "torch_dtype",
      "npu",
      "tensor_parallel_size",
      "create_runtimes",
@@ -566,6 +567,7 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
  tensor_parallel_size: Optional[int] = None,
  timeout: Optional[int] = None,
  optimum_rbln_version: Optional[str] = None,
+ _torch_dtype: Optional[str] = None,
  _compile_cfgs: List[RBLNCompileConfig] = [],
  **kwargs: Any,
  ):
@@ -583,6 +585,7 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
  tensor_parallel_size (Optional[int]): Size for tensor parallelism to distribute the model across devices.
  timeout (Optional[int]): The timeout for the runtime in seconds. If it isn't provided, it will be set to 60 by default.
  optimum_rbln_version (Optional[str]): The optimum-rbln version used for this configuration.
+ _torch_dtype (Optional[str]): The data type to use for the model.
  _compile_cfgs (List[RBLNCompileConfig]): List of compilation configurations for the model.
  **kwargs: Additional keyword arguments.
 
@@ -610,6 +613,7 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
  self.npu = npu
  self.tensor_parallel_size = tensor_parallel_size
 
+ self._torch_dtype = _torch_dtype or "float32"
  self.optimum_rbln_version = optimum_rbln_version
  if self.optimum_rbln_version is None:
      self.optimum_rbln_version = __version__
@@ -639,6 +643,17 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
 
      raise ValueError(f"Unexpected arguments: {kwargs.keys()}")
 
+ @property
+ def torch_dtype(self):
+     return getattr(torch, self._torch_dtype)
+
+ @torch_dtype.setter
+ def torch_dtype(self, torch_dtype: Union[str, torch.dtype]):
+     if isinstance(torch_dtype, torch.dtype):
+         torch_dtype = RBLNCompileConfig.normalize_dtype(torch_dtype)
+
+     self._torch_dtype = torch_dtype
+
  @property
  def rbln_model_cls_name(self) -> str:
      return self.__class__.__name__[:-6]
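
For context: the new property keeps `_torch_dtype` as a plain string (so the config serializes cleanly) and materializes a `torch.dtype` on access. A minimal standalone sketch of the same pattern; `normalize_dtype` below is a stand-in for `RBLNCompileConfig.normalize_dtype`, whose implementation is not shown in this diff:

import torch

def normalize_dtype(dtype):
    # Stand-in for RBLNCompileConfig.normalize_dtype: torch.bfloat16 -> "bfloat16".
    return str(dtype).replace("torch.", "") if isinstance(dtype, torch.dtype) else dtype

class DtypeConfig:
    def __init__(self, _torch_dtype=None):
        self._torch_dtype = _torch_dtype or "float32"  # string form, safe to serialize

    @property
    def torch_dtype(self) -> torch.dtype:
        return getattr(torch, self._torch_dtype)

    @torch_dtype.setter
    def torch_dtype(self, value):
        if isinstance(value, torch.dtype):
            value = normalize_dtype(value)
        self._torch_dtype = value

cfg = DtypeConfig()
cfg.torch_dtype = torch.bfloat16          # accepts a torch.dtype or a string
assert cfg._torch_dtype == "bfloat16"     # stored as a string
assert cfg.torch_dtype is torch.bfloat16  # read back as a torch.dtype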
optimum/rbln/modeling.py CHANGED
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, get_args, ge
  import rebel
  import torch
  from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
- from transformers import AutoConfig, PretrainedConfig
+ from transformers import PretrainedConfig
  from transformers.modeling_outputs import BaseModelOutput
 
  from .configuration_utils import DEFAULT_COMPILED_MODEL_NAME, RBLNModelConfig
@@ -119,9 +119,6 @@ class RBLNModel(RBLNBaseModel):
  # Save configs
  if config is None:
      config = model.config
- # remote_config
- if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map:
-     config = AutoConfig.from_pretrained(config._name_or_path, **kwargs)
 
  if hasattr(model, "can_generate") and model.can_generate():
      import json
@@ -147,6 +144,7 @@
  model=model,
  model_save_dir=save_dir,
  rbln_config=rbln_config,
+ preprocessors=preprocessors,
  **kwargs,
  )
  else:
optimum/rbln/modeling_base.py CHANGED
@@ -34,7 +34,7 @@ from .utils.submodule import SubModulesMixin
 
 
  if TYPE_CHECKING:
-     from transformers import PreTrainedModel
+     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
 
  logger = get_logger(__name__)
 
@@ -53,6 +53,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
  config_class = AutoConfig
  config_name = "config.json"
  hf_library_name = "transformers"
+ _supports_non_fp32 = False
 
  def __init__(
      self,
@@ -91,7 +92,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
 
  self.device = torch.device("cpu")
  self.training = False
- self.dtype = torch.float32
+ self.dtype = rbln_config.torch_dtype
 
  # FIXME :: model_save_dir is not used after initialized. (This can be used when save/load)
  # This attribute is needed to keep one reference on the temporary directory, since garbage collecting it
@@ -400,8 +401,21 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
  return compiled_model
 
  @classmethod
- def update_rbln_config(cls, **others) -> RBLNModelConfig:
-     rbln_config = cls._update_rbln_config(**others)
+ def update_rbln_config(
+     cls,
+     preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+     model: "PreTrainedModel",
+     model_config: "PretrainedConfig",
+     rbln_config: RBLNModelConfig,
+ ) -> RBLNModelConfig:
+     rbln_config.torch_dtype = model.dtype
+     if not cls._supports_non_fp32 and rbln_config.torch_dtype != torch.float32:
+         raise NotImplementedError(
+             f"Currently, {cls.__name__} does not support non-fp32 dtype. Please use float32 dtype."
+         )
+     rbln_config = cls._update_rbln_config(
+         preprocessors=preprocessors, model=model, model_config=model_config, rbln_config=rbln_config
+     )
  rbln_config.freeze()
  if rbln_config.rbln_model_cls_name != cls.__name__:
      raise NameError(
@@ -444,12 +458,12 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
 
  # This method mimics the interface of torch.nn.Module.parameters()
  # specifically for code that uses `next(model.parameters())` to infer
- # the device or dtype. It yields a single dummy tensor on CPU with float32 dtype.
+ # the device or dtype. It yields a single dummy tensor on CPU with model dtype.
 
  # Warning:
  # This does NOT yield the actual model parameters used by the RBLN runtime.
  # Code relying on iterating through all model parameters will not work as expected.
- yield torch.tensor([1.0], dtype=torch.float32, device=torch.device("cpu"))
+ yield torch.tensor([1.0], dtype=self.dtype, device=torch.device("cpu"))
 
  def __call__(self, *args, **kwargs):
      return self.forward(*args, **kwargs)
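
A common downstream pattern affected by this change; a sketch assuming `model` is a loaded RBLN model whose config carries a non-fp32 dtype:

# Hypothetical usage: much HF-adjacent code probes dtype/device this way.
dummy = next(model.parameters())
print(dummy.dtype)   # now follows rbln_config.torch_dtype instead of always torch.float32
print(dummy.device)  # cpu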
@@ -525,13 +539,30 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
 
  # If everything succeeded, move files to target directory
  if os.path.exists(save_directory_path):
-     # Move files from tmp_dir to existing directory (overwrite existing files)
-     for item in os.listdir(tmp_dir):
-         src_path = os.path.join(tmp_dir, item)
-         dst_path = os.path.join(save_directory_path, item)
-         shutil.move(src_path, dst_path)
-     # Clean up empty tmp_dir
-     os.rmdir(tmp_dir)
+     # Merge files from tmp_dir into existing directory
+     def _merge_dir(src_root: str, dst_root: str):
+         for name in os.listdir(src_root):
+             src_item = os.path.join(src_root, name)
+             dst_item = os.path.join(dst_root, name)
+
+             if os.path.islink(src_item) or os.path.isfile(src_item):
+                 os.makedirs(os.path.dirname(dst_item), exist_ok=True)
+                 if os.path.isdir(dst_item) and not os.path.islink(dst_item):
+                     shutil.rmtree(dst_item)
+                 os.replace(src_item, dst_item)
+             elif os.path.isdir(src_item):
+                 if os.path.islink(dst_item) or os.path.isfile(dst_item):
+                     os.remove(dst_item)
+                 os.makedirs(dst_item, exist_ok=True)
+                 _merge_dir(src_item, dst_item)
+             else:
+                 # Fallback for special file types
+                 os.replace(src_item, dst_item)
+
+     _merge_dir(tmp_dir, str(save_directory_path))
+
+     # Remove the temporary directory tree after merge
+     shutil.rmtree(tmp_dir)
  else:
      # If target doesn't exist, just rename tmp_dir to target
      os.rename(tmp_dir, save_directory_path)
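
The old loop had a sharp edge: `shutil.move(src, dst)` nests `src` *inside* `dst` whenever `dst` is an existing directory, instead of overwriting it. A minimal sketch reproducing that failure mode with hypothetical paths (the new `_merge_dir` avoids it by recursing and replacing file-for-file):

import os
import shutil
import tempfile

tmp = tempfile.mkdtemp()
os.makedirs(os.path.join(tmp, "out", "sub"))       # pre-existing target subdirectory
os.makedirs(os.path.join(tmp, "new", "sub"))
open(os.path.join(tmp, "new", "sub", "f.bin"), "w").close()

# Old behavior: "new/sub" lands at "out/sub/sub" rather than merging into "out/sub".
shutil.move(os.path.join(tmp, "new", "sub"), os.path.join(tmp, "out", "sub"))
assert os.path.exists(os.path.join(tmp, "out", "sub", "sub", "f.bin"))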
optimum/rbln/transformers/__init__.py CHANGED
@@ -35,6 +35,7 @@ _import_structure = {
  "RBLNAutoModelForSpeechSeq2Seq",
  "RBLNAutoModelForVision2Seq",
  "RBLNAutoModelForTextEncoding",
+ "RBLNAutoModelForZeroShotObjectDetection",
  "RBLNBartForConditionalGeneration",
  "RBLNBartForConditionalGenerationConfig",
  "RBLNBartModel",
@@ -85,6 +86,12 @@ _import_structure = {
  "RBLNGPT2LMHeadModelConfig",
  "RBLNGPT2Model",
  "RBLNGPT2ModelConfig",
+ "RBLNGroundingDinoDecoder",
+ "RBLNGroundingDinoDecoderConfig",
+ "RBLNGroundingDinoForObjectDetection",
+ "RBLNGroundingDinoForObjectDetectionConfig",
+ "RBLNGroundingDinoEncoder",
+ "RBLNGroundingDinoEncoderConfig",
  "RBLNIdefics3ForConditionalGeneration",
  "RBLNIdefics3ForConditionalGenerationConfig",
  "RBLNIdefics3VisionTransformer",
@@ -178,6 +185,7 @@ if TYPE_CHECKING:
  RBLNAutoModelForSpeechSeq2Seq,
  RBLNAutoModelForTextEncoding,
  RBLNAutoModelForVision2Seq,
+ RBLNAutoModelForZeroShotObjectDetection,
  RBLNBartForConditionalGeneration,
  RBLNBartForConditionalGenerationConfig,
  RBLNBartModel,
@@ -228,6 +236,12 @@ if TYPE_CHECKING:
  RBLNGPT2LMHeadModelConfig,
  RBLNGPT2Model,
  RBLNGPT2ModelConfig,
+ RBLNGroundingDinoDecoder,
+ RBLNGroundingDinoDecoderConfig,
+ RBLNGroundingDinoEncoder,
+ RBLNGroundingDinoEncoderConfig,
+ RBLNGroundingDinoForObjectDetection,
+ RBLNGroundingDinoForObjectDetectionConfig,
  RBLNIdefics3ForConditionalGeneration,
  RBLNIdefics3ForConditionalGenerationConfig,
  RBLNIdefics3VisionTransformer,
optimum/rbln/transformers/configuration_generic.py CHANGED
@@ -25,6 +25,7 @@ class RBLNTransformerEncoderConfig(RBLNModelConfig):
  max_seq_len: Optional[int] = None,
  batch_size: Optional[int] = None,
  model_input_names: Optional[List[str]] = None,
+ model_input_shapes: Optional[List[Tuple[int, int]]] = None,
  **kwargs: Any,
  ):
  """
@@ -45,6 +46,7 @@ class RBLNTransformerEncoderConfig(RBLNModelConfig):
      raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
 
  self.model_input_names = model_input_names or self.rbln_model_input_names
+ self.model_input_shapes = model_input_shapes
 
 
  class RBLNImageModelConfig(RBLNModelConfig):
optimum/rbln/transformers/modeling_generic.py CHANGED
@@ -127,10 +127,18 @@ class RBLNTransformerEncoder(RBLNModel):
      "This is an internal error. Please report it to the developers."
  )
 
- input_info = [
-     (model_input_name, [rbln_config.batch_size, rbln_config.max_seq_len], cls.rbln_dtype)
-     for model_input_name in rbln_config.model_input_names
- ]
+ if rbln_config.model_input_shapes is None:
+     input_info = [
+         (model_input_name, [rbln_config.batch_size, rbln_config.max_seq_len], cls.rbln_dtype)
+         for model_input_name in rbln_config.model_input_names
+     ]
+ else:
+     input_info = [
+         (model_input_name, model_input_shape, cls.rbln_dtype)
+         for model_input_name, model_input_shape in zip(
+             rbln_config.model_input_names, rbln_config.model_input_shapes
+         )
+     ]
 
  rbln_config.set_compile_cfgs([RBLNCompileConfig(input_info=input_info)])
  return rbln_config
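
With `model_input_shapes` set, each input gets its own static shape instead of the uniform `[batch_size, max_seq_len]`. A sketch of the resulting `input_info`; the names, shapes, and the `"int64"` dtype standing in for `cls.rbln_dtype` are illustrative:

model_input_names = ["input_ids", "attention_mask"]
model_input_shapes = [(1, 128), (1, 128)]

# Mirrors the zip() branch above: one (name, shape, dtype) triple per input.
input_info = [
    (name, list(shape), "int64")
    for name, shape in zip(model_input_names, model_input_shapes)
]
assert input_info[0] == ("input_ids", [1, 128], "int64")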
optimum/rbln/transformers/models/__init__.py CHANGED
@@ -37,6 +37,7 @@ _import_structure = {
  "RBLNAutoModelForVision2Seq",
  "RBLNAutoModelForImageTextToText",
  "RBLNAutoModelForTextEncoding",
+ "RBLNAutoModelForZeroShotObjectDetection",
  ],
  "bart": [
      "RBLNBartForConditionalGeneration",
@@ -165,6 +166,14 @@ _import_structure = {
  "RBLNXLMRobertaForSequenceClassification",
  "RBLNXLMRobertaForSequenceClassificationConfig",
  ],
+ "grounding_dino": [
+     "RBLNGroundingDinoForObjectDetection",
+     "RBLNGroundingDinoForObjectDetectionConfig",
+     "RBLNGroundingDinoEncoder",
+     "RBLNGroundingDinoEncoderConfig",
+     "RBLNGroundingDinoDecoder",
+     "RBLNGroundingDinoDecoderConfig",
+ ],
  }
 
  if TYPE_CHECKING:
@@ -184,6 +193,7 @@ if TYPE_CHECKING:
  RBLNAutoModelForSpeechSeq2Seq,
  RBLNAutoModelForTextEncoding,
  RBLNAutoModelForVision2Seq,
+ RBLNAutoModelForZeroShotObjectDetection,
  )
  from .bart import (
      RBLNBartForConditionalGeneration,
@@ -236,6 +246,14 @@ if TYPE_CHECKING:
  RBLNGemma3ForConditionalGeneration,
  RBLNGemma3ForConditionalGenerationConfig,
  )
  from .gpt2 import RBLNGPT2LMHeadModel, RBLNGPT2LMHeadModelConfig, RBLNGPT2Model, RBLNGPT2ModelConfig
+ from .grounding_dino import (
+     RBLNGroundingDinoDecoder,
+     RBLNGroundingDinoDecoderConfig,
+     RBLNGroundingDinoEncoder,
+     RBLNGroundingDinoEncoderConfig,
+     RBLNGroundingDinoForObjectDetection,
+     RBLNGroundingDinoForObjectDetectionConfig,
+ )
  from .idefics3 import (
      RBLNIdefics3ForConditionalGeneration,
      RBLNIdefics3ForConditionalGenerationConfig,
optimum/rbln/transformers/models/auto/__init__.py CHANGED
@@ -27,4 +27,5 @@ from .modeling_auto import (
  RBLNAutoModelForSpeechSeq2Seq,
  RBLNAutoModelForTextEncoding,
  RBLNAutoModelForVision2Seq,
+ RBLNAutoModelForZeroShotObjectDetection,
  )
optimum/rbln/transformers/models/auto/modeling_auto.py CHANGED
@@ -39,6 +39,8 @@ from transformers.models.auto.modeling_auto import (
  MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES,
  MODEL_FOR_VISION_2_SEQ_MAPPING,
  MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES,
+ MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING,
+ MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES,
  MODEL_MAPPING,
  MODEL_MAPPING_NAMES,
  )
@@ -122,3 +124,8 @@ class RBLNAutoModelForQuestionAnswering(_BaseAutoModelClass):
  class RBLNAutoModelForTextEncoding(_BaseAutoModelClass):
      _model_mapping = MODEL_FOR_TEXT_ENCODING_MAPPING
      _model_mapping_names = MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES
+
+
+ class RBLNAutoModelForZeroShotObjectDetection(_BaseAutoModelClass):
+     _model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING
+     _model_mapping_names = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES
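
Assuming the usual optimum-rbln `from_pretrained` flow, the new auto class dispatches checkpoints through transformers' zero-shot object detection mapping. The checkpoint id and `export=True` flag below are illustrative, not taken from this diff:

from optimum.rbln import RBLNAutoModelForZeroShotObjectDetection

# Hypothetical usage; "IDEA-Research/grounding-dino-tiny" is an example checkpoint.
model = RBLNAutoModelForZeroShotObjectDetection.from_pretrained(
    "IDEA-Research/grounding-dino-tiny",
    export=True,  # compile the PyTorch checkpoint for the RBLN NPU
)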
optimum/rbln/transformers/models/bert/bert_architecture.py ADDED
@@ -0,0 +1,16 @@
+ import torch
+
+
+ class BertModelWrapper(torch.nn.Module):
+     def __init__(self, model, rbln_config):
+         super().__init__()
+         self.model = model
+         self.rbln_config = rbln_config
+
+     def forward(self, *args, **kwargs):
+         output = self.model(*args, **kwargs)
+         if isinstance(output, torch.Tensor):
+             return output
+         elif isinstance(output, tuple):
+             return tuple(x for x in output if x is not None)
+         return output
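
The filtering appears to exist to sanitize tuple outputs for tracing, since `torch.jit.trace` rejects `None` entries in returned containers. A sketch of how the wrapper would be applied; the checkpoint id is illustrative, and `rbln_config` is unused by `forward`:

import torch
from transformers import BertModel

hf_model = BertModel.from_pretrained("bert-base-uncased", return_dict=False).eval()
wrapped = BertModelWrapper(hf_model, rbln_config=None)

input_ids = torch.zeros(1, 128, dtype=torch.int64)
attention_mask = torch.ones(1, 128, dtype=torch.int64)
outputs = wrapped(input_ids, attention_mask)  # tuple with any None entries dropped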
optimum/rbln/transformers/models/bert/modeling_bert.py CHANGED
@@ -12,15 +12,15 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
 
- from ....utils.logging import get_logger
+ import torch
+
  from ...modeling_generic import (
      RBLNModelForMaskedLM,
      RBLNModelForQuestionAnswering,
      RBLNTransformerEncoderForFeatureExtraction,
  )
-
-
- logger = get_logger(__name__)
+ from .bert_architecture import BertModelWrapper
+ from .configuration_bert import RBLNBertModelConfig
 
 
  class RBLNBertModel(RBLNTransformerEncoderForFeatureExtraction):
@@ -34,6 +34,10 @@ class RBLNBertModel(RBLNTransformerEncoderForFeatureExtraction):
 
      rbln_model_input_names = ["input_ids", "attention_mask"]
 
+     @classmethod
+     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNBertModelConfig) -> torch.nn.Module:
+         return BertModelWrapper(model, rbln_config)
+
 
  class RBLNBertForMaskedLM(RBLNModelForMaskedLM):
      """
optimum/rbln/transformers/models/blip_2/modeling_blip_2.py CHANGED
@@ -174,7 +174,12 @@ class RBLNBlip2QFormerModel(RBLNModel):
  return Blip2QFormerModelWrapper(model).eval()
 
  @classmethod
- def _update_submodule_config(cls, model: "PreTrainedModel", rbln_config: "RBLNModelConfig") -> "RBLNModelConfig":
+ def _update_submodule_config(
+     cls,
+     model: "PreTrainedModel",
+     rbln_config: RBLNModelConfig,
+     preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+ ):
  if rbln_config.num_query_tokens is None:
      rbln_config.num_query_tokens = model.config.num_query_tokens
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py CHANGED
@@ -1066,7 +1066,7 @@ class RotaryEmbedding(nn.Module):
  rope_type = "default"
 
  inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, max_seq_len_cached)
- cache_position = torch.arange(0, max_seq_len_cached, dtype=torch.float32)
+ cache_position = torch.arange(0, max_seq_len_cached)
  cache_position_expanded = cache_position[:, None]
 
  if rope_type == "dynamic":
@@ -1085,8 +1085,8 @@ class RotaryEmbedding(nn.Module):
 
  def forward(self, x, seq_len):
      return (
-         self._cos_cached[:seq_len].to(dtype=x.dtype),
-         self._sin_cached[:seq_len].to(dtype=x.dtype),
+         self._cos_cached[:seq_len].to(dtype=torch.float32),
+         self._sin_cached[:seq_len].to(dtype=torch.float32),
      )
 
 
@@ -1116,8 +1116,11 @@ def rotate_half(x):
 
  def apply_rotary_pos_emb(q, k, cos, sin):
      """Applies Rotary Position Embedding to the query and key tensors."""
+     dtype = q.dtype
      q_embed = (q * cos) + (rotate_half(q) * sin)
      k_embed = (k * cos) + (rotate_half(k) * sin)
+     q_embed = q_embed.to(dtype)
+     k_embed = k_embed.to(dtype)
      return q_embed, k_embed
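Taken together, these changes keep the rotary tables in float32 while letting activations run at lower precision, with a cast back after the rotation. A standalone sketch of the dtype round-trip; shapes are illustrative:

import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

q = torch.randn(1, 8, 16, 64, dtype=torch.bfloat16)  # half-precision activations
cos = torch.randn(16, 64, dtype=torch.float32)       # tables kept fp32 by forward()
sin = torch.randn(16, 64, dtype=torch.float32)

dtype = q.dtype
q_embed = (q * cos) + (rotate_half(q) * sin)  # promoted to float32 for the rotation
q_embed = q_embed.to(dtype)                   # cast back, as in apply_rotary_pos_emb
assert q_embed.dtype == torch.bfloat16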
optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py CHANGED
@@ -317,7 +317,13 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
 
  # Initialize attention mask for chunked processing
  chunked_attention_mask = (
-     torch.zeros(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32)
+     torch.zeros(
+         1,
+         1,
+         self.rbln_config.prefill_chunk_size,
+         self.rbln_config.max_seq_len,
+         dtype=self.rbln_config.torch_dtype,
+     )
      if self.rbln_config.use_attention_mask
      else None
  )
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py CHANGED
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
  import rebel
  import torch
  from rebel.compile_context import CompileContext
- from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
+ from transformers import AutoModel, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
  from transformers.modeling_outputs import BaseModelOutputWithPast
  from transformers.modeling_utils import no_init_weights
 
@@ -33,7 +33,7 @@ from ...modeling_attention_utils import (
      validate_sliding_window,
  )
  from ...modeling_outputs import RBLNDecoderOnlyOutput
- from ...utils.rbln_quantization import prepare_model_for_quantization
+ from ...utils.rbln_quantization import get_quantized_model
  from .configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
  from .decoderonly_architecture import DecoderOnlyWrapper
  from .decoderonly_runtime_utils import RBLNPageTableManager, RBLNRuntimeModel
@@ -72,6 +72,7 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
  auto_model_class = AutoModel
  _decoder_wrapper_cls = DecoderOnlyWrapper
  _use_rotary_emb = True
+ _supports_non_fp32 = True
 
  def __post_init__(self, **kwargs):
      if self.rbln_config.use_inputs_embeds:
@@ -86,10 +87,8 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
  def setup_runtime(self):
      # Initialize resources to be used across Runtime instances (prefill and decode phases)
      page_table_manager = RBLNPageTableManager(self.rbln_config)
-     dec_attn_mask = torch.zeros(
-         self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=torch.float32
-     )
-     out_buffers = [torch.empty(self.prefill_output_size, dtype=torch.float32, device="cpu")]
+     dec_attn_mask = torch.zeros(self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=self.dtype)
+     out_buffers = [torch.empty(self.prefill_output_size, dtype=self.dtype)]
 
      common_kwargs = {
          "main_input_name": "inputs_embeds" if self.rbln_config.use_inputs_embeds else "input_ids",
@@ -143,35 +142,17 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
  ):
      kwargs = cls.update_kwargs(kwargs)
 
-     if config is None:
-         config = AutoConfig.from_pretrained(
-             model_id,
-             use_auth_token=use_auth_token,
-             revision=revision,
-             force_download=force_download,
-             cache_dir=cache_dir,
-             trust_remote_code=trust_remote_code,
-             **kwargs,
-         )
-     if config.torch_dtype == torch.bfloat16:
-         # FIXME: bfloat16 is not supported by rebel-compiler
-         config.torch_dtype = torch.float32
-
-     with no_init_weights():
-         model = cls.auto_model_class.from_config(config)
-
-     model = prepare_model_for_quantization(
-         model,
+     return get_quantized_model(
+         cls.auto_model_class,
          model_id,
-         kwargs.get("num_hidden_layers"),
          use_auth_token=use_auth_token,
          revision=revision,
          cache_dir=cache_dir,
         force_download=force_download,
         local_files_only=local_files_only,
         rbln_quantization=rbln_config.quantization,
+        **kwargs,
      )
-     return model
 
  def __getattr__(self, __name: str) -> Any:
      # Special method to delegate attribute access to the original Huggingface LM class.
@@ -365,7 +346,7 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
 
  input_info = []
  if rbln_config.use_inputs_embeds:
-     input_info.append(("inputs_embeds", [batch_size, query_length, hidden_size], "float32"))
+     input_info.append(("inputs_embeds", [batch_size, query_length, hidden_size], rbln_config.torch_dtype))
  else:
      input_info.append(("input_ids", [batch_size, query_length], "int64"))
 
@@ -384,16 +365,16 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
 
  if rbln_config.use_attention_mask:
      if rbln_config.use_position_ids:
-         input_info.append(("attention_mask", [batch_size, rbln_config.max_seq_len], "float32"))
+         input_info.append(("attention_mask", [batch_size, rbln_config.max_seq_len], rbln_config.torch_dtype))
      else:
          input_info.append(
-             ("attention_mask", [batch_size, 1, query_length, rbln_config.max_seq_len], "float32")
+             ("attention_mask", [batch_size, 1, query_length, rbln_config.max_seq_len], rbln_config.torch_dtype)
          )
 
  if rbln_config.use_position_ids:
      input_info.append(("position_ids", [batch_size, query_length], "int32"))
 
- kvcache_dtype = "float32"
+ kvcache_dtype = rbln_config.torch_dtype
  if rbln_config.quantization and rbln_config.quantization.kv_caches == "fp8":
      kvcache_dtype = "float8_e4m3fn"
optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py CHANGED
@@ -20,6 +20,6 @@ class RBLNDepthAnythingForDepthEstimation(RBLNModelForDepthEstimation):
  """
  RBLN optimized DepthAnythingForDepthEstimation model for depth estimation tasks.
 
- This class provides hardware-accelerated inference for Depth Anything V2 Small
+ This class provides hardware-accelerated inference for Depth Anything V2
  models on RBLN devices, providing the most capable monocular depth estimation (MDE) model.
  """
optimum/rbln/transformers/models/gemma3/modeling_gemma3.py CHANGED
@@ -345,6 +345,7 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
  """
 
  _decoder_wrapper_cls = Gemma3ForCausalLMWrapper
+ _supports_non_fp32 = False
 
  def setup_runtime(self):
      # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
@@ -403,7 +404,12 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
  return rbln_config
 
  @classmethod
- def _update_submodule_config(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
+ def _update_submodule_config(
+     cls,
+     model: "PreTrainedModel",
+     rbln_config: RBLNModelConfig,
+     preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+ ):
  if rbln_config.image_prefill_chunk_size is None:
      rbln_config.image_prefill_chunk_size = model.config.mm_tokens_per_image
 
optimum/rbln/transformers/models/grounding_dino/__init__.py ADDED
@@ -0,0 +1,10 @@
+ from .configuration_grounding_dino import (
+     RBLNGroundingDinoDecoderConfig,
+     RBLNGroundingDinoEncoderConfig,
+     RBLNGroundingDinoForObjectDetectionConfig,
+ )
+ from .modeling_grounding_dino import (
+     RBLNGroundingDinoDecoder,
+     RBLNGroundingDinoEncoder,
+     RBLNGroundingDinoForObjectDetection,
+ )
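
A hedged end-to-end sketch of how these exports would typically be used, following the standard transformers Grounding DINO processor API; the checkpoint id, image URL, and `export=True` flag are illustrative:

import requests
import torch
from PIL import Image
from transformers import AutoProcessor
from optimum.rbln import RBLNGroundingDinoForObjectDetection

processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
model = RBLNGroundingDinoForObjectDetection.from_pretrained(
    "IDEA-Research/grounding-dino-tiny", export=True
)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, text="a cat. a remote control.", return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Convert raw logits/boxes into labeled detections; (height, width) per image.
results = processor.post_process_grounded_object_detection(
    outputs, inputs.input_ids, target_sizes=[image.size[::-1]]
)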