optimum-rbln 0.8.0.post2__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. optimum/rbln/__init__.py +24 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +45 -33
  4. optimum/rbln/diffusers/__init__.py +21 -1
  5. optimum/rbln/diffusers/configurations/__init__.py +4 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +9 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +4 -2
  10. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +9 -2
  11. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +70 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +4 -2
  13. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +9 -2
  14. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +9 -2
  15. optimum/rbln/diffusers/configurations/pipelines/__init__.py +1 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +29 -9
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +114 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +28 -12
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +18 -6
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +13 -6
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +12 -6
  22. optimum/rbln/diffusers/modeling_diffusers.py +72 -65
  23. optimum/rbln/diffusers/models/__init__.py +4 -0
  24. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  25. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +17 -1
  26. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +219 -0
  27. optimum/rbln/diffusers/models/autoencoders/vae.py +45 -8
  28. optimum/rbln/diffusers/models/autoencoders/vq_model.py +17 -1
  29. optimum/rbln/diffusers/models/controlnet.py +14 -8
  30. optimum/rbln/diffusers/models/transformers/__init__.py +1 -0
  31. optimum/rbln/diffusers/models/transformers/prior_transformer.py +10 -0
  32. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +321 -0
  33. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -0
  34. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +11 -1
  35. optimum/rbln/diffusers/pipelines/__init__.py +10 -0
  36. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -4
  37. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -0
  38. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -0
  39. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +7 -0
  40. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +7 -0
  41. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  42. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +102 -0
  43. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +455 -0
  44. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +98 -0
  45. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +98 -0
  46. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +7 -0
  47. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +48 -27
  48. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +7 -0
  49. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +7 -0
  50. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -0
  51. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +7 -0
  52. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +7 -0
  53. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +7 -0
  54. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -0
  55. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -0
  56. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -0
  57. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +7 -0
  58. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -0
  59. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +7 -0
  60. optimum/rbln/modeling.py +71 -37
  61. optimum/rbln/modeling_base.py +63 -109
  62. optimum/rbln/transformers/__init__.py +41 -47
  63. optimum/rbln/transformers/configuration_generic.py +16 -13
  64. optimum/rbln/transformers/modeling_generic.py +21 -22
  65. optimum/rbln/transformers/modeling_rope_utils.py +5 -2
  66. optimum/rbln/transformers/models/__init__.py +54 -4
  67. optimum/rbln/transformers/models/{wav2vec2/configuration_wav2vec.py → audio_spectrogram_transformer/__init__.py} +2 -4
  68. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +21 -0
  69. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +28 -0
  70. optimum/rbln/transformers/models/auto/auto_factory.py +35 -12
  71. optimum/rbln/transformers/models/bart/bart_architecture.py +14 -1
  72. optimum/rbln/transformers/models/bart/configuration_bart.py +12 -2
  73. optimum/rbln/transformers/models/bart/modeling_bart.py +16 -7
  74. optimum/rbln/transformers/models/bert/configuration_bert.py +18 -3
  75. optimum/rbln/transformers/models/bert/modeling_bert.py +24 -0
  76. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +15 -3
  77. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +50 -4
  78. optimum/rbln/transformers/models/clip/configuration_clip.py +15 -5
  79. optimum/rbln/transformers/models/clip/modeling_clip.py +38 -13
  80. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  81. optimum/rbln/transformers/models/colpali/colpali_architecture.py +221 -0
  82. optimum/rbln/transformers/models/colpali/configuration_colpali.py +68 -0
  83. optimum/rbln/transformers/models/colpali/modeling_colpali.py +383 -0
  84. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +111 -14
  85. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -35
  86. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +253 -195
  87. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  88. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  89. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +27 -0
  90. optimum/rbln/transformers/models/dpt/configuration_dpt.py +6 -1
  91. optimum/rbln/transformers/models/dpt/modeling_dpt.py +6 -1
  92. optimum/rbln/transformers/models/exaone/configuration_exaone.py +24 -1
  93. optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
  94. optimum/rbln/transformers/models/exaone/modeling_exaone.py +66 -5
  95. optimum/rbln/transformers/models/gemma/configuration_gemma.py +24 -1
  96. optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
  97. optimum/rbln/transformers/models/gemma/modeling_gemma.py +49 -0
  98. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
  99. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +18 -250
  100. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +89 -244
  101. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +4 -1
  102. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -1
  103. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +12 -2
  104. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +41 -4
  105. optimum/rbln/transformers/models/llama/configuration_llama.py +24 -1
  106. optimum/rbln/transformers/models/llama/modeling_llama.py +49 -0
  107. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +10 -2
  108. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +32 -4
  109. optimum/rbln/transformers/models/midm/configuration_midm.py +24 -1
  110. optimum/rbln/transformers/models/midm/midm_architecture.py +6 -1
  111. optimum/rbln/transformers/models/midm/modeling_midm.py +66 -5
  112. optimum/rbln/transformers/models/mistral/configuration_mistral.py +24 -1
  113. optimum/rbln/transformers/models/mistral/modeling_mistral.py +62 -4
  114. optimum/rbln/transformers/models/opt/configuration_opt.py +4 -1
  115. optimum/rbln/transformers/models/opt/modeling_opt.py +10 -0
  116. optimum/rbln/transformers/models/opt/opt_architecture.py +7 -1
  117. optimum/rbln/transformers/models/phi/configuration_phi.py +24 -1
  118. optimum/rbln/transformers/models/phi/modeling_phi.py +49 -0
  119. optimum/rbln/transformers/models/phi/phi_architecture.py +1 -1
  120. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +24 -1
  121. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -4
  122. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +31 -3
  123. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +54 -25
  124. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +6 -4
  125. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  126. optimum/rbln/transformers/models/resnet/configuration_resnet.py +25 -0
  127. optimum/rbln/transformers/models/resnet/modeling_resnet.py +26 -0
  128. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  129. optimum/rbln/transformers/{configuration_alias.py → models/roberta/configuration_roberta.py} +12 -28
  130. optimum/rbln/transformers/{modeling_alias.py → models/roberta/modeling_roberta.py} +14 -28
  131. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -1
  132. optimum/rbln/transformers/models/seq2seq/{configuration_seq2seq2.py → configuration_seq2seq.py} +2 -2
  133. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +7 -3
  134. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +41 -3
  135. optimum/rbln/transformers/models/siglip/configuration_siglip.py +10 -0
  136. optimum/rbln/transformers/models/siglip/modeling_siglip.py +69 -21
  137. optimum/rbln/transformers/models/t5/configuration_t5.py +12 -2
  138. optimum/rbln/transformers/models/t5/modeling_t5.py +56 -8
  139. optimum/rbln/transformers/models/t5/t5_architecture.py +5 -1
  140. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/__init__.py +1 -1
  141. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/configuration_time_series_transformer.py +9 -2
  142. optimum/rbln/transformers/models/{time_series_transformers/modeling_time_series_transformers.py → time_series_transformer/modeling_time_series_transformer.py} +20 -11
  143. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  144. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  145. optimum/rbln/transformers/models/vit/modeling_vit.py +25 -0
  146. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -1
  147. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +26 -0
  148. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  149. optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -1
  150. optimum/rbln/transformers/models/whisper/modeling_whisper.py +41 -17
  151. optimum/rbln/transformers/models/xlm_roberta/__init__.py +16 -2
  152. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +15 -2
  153. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -3
  154. optimum/rbln/utils/model_utils.py +20 -0
  155. optimum/rbln/utils/runtime_utils.py +49 -1
  156. optimum/rbln/utils/submodule.py +6 -8
  157. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/METADATA +6 -6
  158. optimum_rbln-0.8.1.dist-info/RECORD +211 -0
  159. optimum_rbln-0.8.0.post2.dist-info/RECORD +0 -184
  160. /optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/time_series_transformers_architecture.py +0 -0
  161. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/WHEEL +0 -0
  162. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/licenses/LICENSE +0 -0
@@ -15,7 +15,7 @@
 import importlib
 import os
 import shutil
-from abc import ABC, abstractmethod
+from abc import ABC
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
@@ -24,10 +24,10 @@ import rebel
 import torch
 from transformers import AutoConfig, AutoModel, GenerationConfig, PretrainedConfig
 
-from .configuration_utils import RBLNAutoConfig, RBLNCompileConfig, RBLNModelConfig
+from .configuration_utils import RBLNAutoConfig, RBLNCompileConfig, RBLNModelConfig, get_rbln_config_class
 from .utils.hub import PushToHubMixin, pull_compiled_model_from_hub, validate_files
 from .utils.logging import get_logger
-from .utils.runtime_utils import UnavailableRuntime
+from .utils.runtime_utils import UnavailableRuntime, tp_and_devices_are_ok
 from .utils.save_utils import maybe_load_preprocessors
 from .utils.submodule import SubModulesMixin
 
@@ -47,40 +47,6 @@ class RBLNBaseModelConfig(RBLNModelConfig):
 
 
 class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
-    """
-    An abstract base class for compiling, loading, and saving neural network models from the huggingface
-    transformers and diffusers libraries to run on RBLN NPU devices.
-
-    This class supports loading and saving models using the `from_pretrained` and `save_pretrained` methods,
-    similar to the huggingface libraries.
-
-    The `from_pretrained` method loads a model corresponding to the given `model_id` from a local repository
-    or the huggingface hub onto the NPU. If the model is a PyTorch model and `export=True` is passed as a
-    kwarg, it compiles the PyTorch model corresponding to the given `model_id` before loading. If `model_id`
-    is an already rbln-compiled model, it can be directly loaded onto the NPU with `export=False`.
-
-    `rbln_npu` is a kwarg required for compilation, specifying the name of the NPU to be used. If this
-    keyword is not specified, the NPU installed on the host machine is used. If no NPU is installed on the
-    host machine, an error occurs.
-
-    `rbln_device` specifies the device to be used at runtime. If not specified, device 0 is used.
-
-    `rbln_create_runtimes` indicates whether to create runtime objects. If False, the runtime does not load
-    the model onto the NPU. This option is particularly useful when you want to perform compilation only on a
-    host machine without an NPU.
-
-    `RBLNModel`, `RBLNModelFor*`, etc. are all child classes of RBLNBaseModel.
-
-    Models compiled in this way can be saved to a local repository using `save_pretrained` or uploaded to
-    the huggingface hub.
-
-    It also supports generation through `generate` (for transformers models that support generation).
-
-    RBLNBaseModel is a class for models consisting of an arbitrary number of `torch.nn.Module`s, and
-    therefore is an abstract class without explicit implementations of `forward` or `export` functions.
-    To inherit from this class, `forward`, `export`, etc. must be implemented.
-    """
-
     model_type = "rbln_model"
     auto_model_class = AutoModel
     config_class = AutoConfig
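
The deleted docstring still describes the workflow that remains in place: pass `export=True` to compile a PyTorch checkpoint (optionally with `rbln_create_runtimes=False` on a host without an NPU), save the artifact, then reload it with `export=False`. A minimal sketch of that flow, assuming a Llama checkpoint; the model id and NPU name are illustrative, and the `rbln_*` kwarg names are taken from the removed docstring:

    from optimum.rbln import RBLNLlamaForCausalLM

    # Compile on a host without an NPU: build the artifact but skip runtime creation.
    model = RBLNLlamaForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",  # illustrative model id
        export=True,                 # compile the PyTorch model before loading
        rbln_npu="RBLN-CA12",        # illustrative NPU name; defaults to the host's installed NPU
        rbln_create_runtimes=False,  # compile only; do not load the model onto an NPU
    )
    model.save_pretrained("llama-2-7b-rbln")

    # Later, on a machine with an NPU, load the compiled artifact directly.
    model = RBLNLlamaForCausalLM.from_pretrained("llama-2-7b-rbln", export=False, rbln_device=0)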
@@ -156,7 +122,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         subfolder: str = "",
         local_files_only: bool = False,
     ) -> str:
-        """Load the directory containing the compiled model files."""
+        # Load the directory containing the compiled model files.
         model_path = Path(model_id)
 
         if model_path.is_dir():
@@ -372,22 +338,59 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
     def prepare_rbln_config(
         cls, rbln_config: Optional[Union[Dict[str, Any], RBLNModelConfig]] = None, **kwargs
     ) -> Tuple[RBLNModelConfig, Dict[str, Any]]:
-        """
-        Extract rbln-config from kwargs and convert it to RBLNModelConfig.
-        """
+        # Extract rbln-config from kwargs and convert it to RBLNModelConfig.
+
         config_cls = cls.get_rbln_config_class()
         rbln_config, kwargs = config_cls.initialize_from_kwargs(rbln_config, **kwargs)
         return rbln_config, kwargs
 
     @classmethod
-    def from_pretrained(cls, model_id: Union[str, Path], export: bool = False, **kwargs) -> "RBLNBaseModel":
+    def from_pretrained(
+        cls: Type["RBLNBaseModel"],
+        model_id: Union[str, Path],
+        export: bool = False,
+        rbln_config: Optional[Union[Dict, RBLNModelConfig]] = None,
+        **kwargs: Dict[str, Any],
+    ) -> "RBLNBaseModel":
+        """
+        The `from_pretrained()` function is utilized in its standard form as in the HuggingFace transformers library.
+        User can use this function to load a pre-trained model from the HuggingFace library and convert it to a RBLN model to be run on RBLN NPUs.
+
+        Args:
+            model_id: The model id of the pre-trained model to be loaded. It can be downloaded from the HuggingFace model hub or a local path, or a model id of a compiled model using the RBLN Compiler.
+            export: A boolean flag to indicate whether the model should be compiled.
+            rbln_config: Configuration for RBLN model compilation and runtime. This can be provided as a dictionary or an instance of the model's configuration class (e.g., `RBLNLlamaForCausalLMConfig` for Llama models).
+                For detailed configuration options, see the specific model's configuration class documentation.
+
+            kwargs: Additional keyword arguments. Arguments with the prefix 'rbln_' are passed to rbln_config, while the remaining arguments are passed to the HuggingFace library.
+
+        Returns:
+            A RBLN model instance ready for inference on RBLN NPU devices.
+        """
+
         if isinstance(model_id, Path):
             model_id = model_id.as_posix()
         from_pretrained_method = cls._export if export else cls._from_pretrained
-        return from_pretrained_method(model_id=model_id, **kwargs)
+        return from_pretrained_method(model_id=model_id, **kwargs, rbln_config=rbln_config)
 
     @classmethod
-    def compile(cls, model, rbln_compile_config: Optional[RBLNCompileConfig] = None, **kwargs):
+    def compile(
+        cls,
+        model,
+        rbln_compile_config: RBLNCompileConfig,
+        create_runtimes: bool,
+        device: Union[int, List[int]],
+        **kwargs,
+    ):
+        if create_runtimes:
+            runtime_cannot_be_created = tp_and_devices_are_ok(
+                tensor_parallel_size=rbln_compile_config.tensor_parallel_size,
+                device=device,
+                npu=rbln_compile_config.npu,
+            )
+            if runtime_cannot_be_created:
+                raise ValueError(runtime_cannot_be_created)
+
        compiled_model = rebel.compile_from_torch(
             model,
             input_info=rbln_compile_config.input_info,
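
The reworked signature promotes `rbln_config` to an explicit argument, and `compile()` now validates tensor-parallel size, device indices, and NPU compatibility up front via `tp_and_devices_are_ok` instead of failing later at runtime creation. Per the docstring, `rbln_config` accepts either a dict or the model's typed config class; a hedged sketch of both forms (the option names `batch_size`/`max_seq_len` are illustrative):

    from optimum.rbln import RBLNLlamaForCausalLM, RBLNLlamaForCausalLMConfig

    # As a plain dict (option names are illustrative; see the model's config class docs).
    model = RBLNLlamaForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",
        export=True,
        rbln_config={"batch_size": 1, "max_seq_len": 4096},
    )

    # As the typed config class; `rbln_`-prefixed kwargs merge into the same config.
    config = RBLNLlamaForCausalLMConfig(batch_size=1, max_seq_len=4096)
    model = RBLNLlamaForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf", export=True, rbln_config=config
    )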
@@ -411,15 +414,13 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
 
     @classmethod
     def get_hf_class(cls):
-        """
-        Lazily loads and caches the corresponding HuggingFace model class.
-        Removes 'RBLN' prefix from the class name to get the original class name
-        (e.g., RBLNLlamaForCausalLM -> LlamaForCausalLM) and imports it from
-        the transformers/diffusers module.
+        # Lazily loads and caches the corresponding HuggingFace model class.
+        # Removes 'RBLN' prefix from the class name to get the original class name
+        # (e.g., RBLNLlamaForCausalLM -> LlamaForCausalLM) and imports it from
+        # the transformers/diffusers module.
 
-        Returns:
-            type: The original HuggingFace model class
-        """
+        # Returns:
+        #     type: The original HuggingFace model class
        if cls._hf_class is None:
             hf_cls_name = cls.__name__[4:]
             library = importlib.import_module(cls.hf_library_name)
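
For reference, the resolution described in those comments reduces to this standalone sketch (assuming `hf_library_name` is `"transformers"`, as it is for transformers-backed models):

    import importlib

    def resolve_hf_class(rbln_cls_name: str, hf_library_name: str = "transformers"):
        # "RBLNLlamaForCausalLM" -> "LlamaForCausalLM", imported from the HF library.
        hf_cls_name = rbln_cls_name[4:]  # drop the "RBLN" prefix
        library = importlib.import_module(hf_library_name)
        return getattr(library, hf_cls_name, None)

    assert resolve_hf_class("RBLNLlamaForCausalLM").__name__ == "LlamaForCausalLM"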
@@ -428,18 +429,10 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
 
     @classmethod
     def get_rbln_config_class(cls) -> Type[RBLNModelConfig]:
-        """
-        Lazily loads and caches the corresponding RBLN model config class.
-        """
+        # Lazily loads and caches the corresponding RBLN model config class.
         if cls._rbln_config_class is None:
             rbln_config_class_name = cls.__name__ + "Config"
-            library = importlib.import_module("optimum.rbln")
-            cls._rbln_config_class = getattr(library, rbln_config_class_name, None)
-            if cls._rbln_config_class is None:
-                raise ValueError(
-                    f"RBLN config class {rbln_config_class_name} not found. This is an internal error. "
-                    "Please report it to the developers."
-                )
+            cls._rbln_config_class = get_rbln_config_class(rbln_config_class_name)
         return cls._rbln_config_class
 
     def can_generate(self):
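
The inline lookup moves into the shared `get_rbln_config_class` helper in `configuration_utils` (imported at the top of the file). Judging from the deleted lines, the helper's contract is equivalent to this sketch; the body is an assumption mirroring the removed code, not the helper's actual source:

    import importlib

    def get_rbln_config_class(rbln_config_class_name: str):
        # Assumed behavior, mirroring the deleted inline lookup above.
        library = importlib.import_module("optimum.rbln")
        config_class = getattr(library, rbln_config_class_name, None)
        if config_class is None:
            raise ValueError(
                f"RBLN config class {rbln_config_class_name} not found. This is an internal error. "
                "Please report it to the developers."
            )
        return config_class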
@@ -449,17 +442,15 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         return self
 
     def parameters(self):
-        """
-        Provides a dummy parameter generator for compatibility.
+        # A dummy parameter generator for compatibility.
 
-        This method mimics the interface of torch.nn.Module.parameters()
-        specifically for code that uses `next(model.parameters())` to infer
-        the device or dtype. It yields a single dummy tensor on CPU with float32 dtype.
+        # This method mimics the interface of torch.nn.Module.parameters()
+        # specifically for code that uses `next(model.parameters())` to infer
+        # the device or dtype. It yields a single dummy tensor on CPU with float32 dtype.
 
-        Warning:
-            This does NOT yield the actual model parameters used by the RBLN runtime.
-            Code relying on iterating through all model parameters will not work as expected.
-        """
+        # Warning:
+        #     This does NOT yield the actual model parameters used by the RBLN runtime.
+        #     Code relying on iterating through all model parameters will not work as expected.
         yield torch.tensor([1.0], dtype=torch.float32, device=torch.device("cpu"))
 
     def __call__(self, *args, **kwargs):
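
The idiom those comments guard looks like this in practice (a sketch; `model` is any loaded RBLN model instance):

    import torch  # the dummy tensor below is a torch.Tensor

    # Works: downstream code probing device/dtype from the "first parameter".
    first_param = next(model.parameters())
    assert first_param.dtype == torch.float32
    assert first_param.device.type == "cpu"

    # Does NOT work as expected: the generator yields only the single dummy tensor,
    # never the real weights held by the RBLN runtime.
    num_elements = sum(p.numel() for p in model.parameters())  # always 1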
@@ -547,7 +538,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
 
     @staticmethod
     def _raise_missing_compiled_file_error(missing_files: List[str]):
-        """Raises a KeyError with a message indicating missing compiled model files."""
+        # Raises a KeyError with a message indicating missing compiled model files.
 
         if len(missing_files) == 1:
             message = f"The rbln model folder is missing the required '{missing_files[0]}.rbln' file. "
@@ -563,40 +554,3 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
             "and ensure the compilation completes successfully."
         )
         raise KeyError(message)
-
-    @classmethod
-    @abstractmethod
-    def _update_rbln_config(cls, **rbln_config_kwargs) -> RBLNModelConfig:
-        pass
-
-    @classmethod
-    @abstractmethod
-    def _create_runtimes(
-        cls,
-        compiled_models: List[rebel.RBLNCompiledModel],
-        rbln_config: RBLNModelConfig,
-    ) -> List[rebel.Runtime]:
-        # compiled_models -> runtimes
-        pass
-
-    @classmethod
-    @abstractmethod
-    def get_pytorch_model(cls, *args, **kwargs):
-        pass
-
-    @classmethod
-    @abstractmethod
-    def from_model(
-        cls,
-        model: "PreTrainedModel",
-        config: Optional[PretrainedConfig] = None,
-        rbln_config: Optional[RBLNModelConfig] = None,
-        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
-        subfolder: str = "",
-        **kwargs,
-    ):
-        pass
-
-    @abstractmethod
-    def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
-        pass
@@ -18,16 +18,9 @@ from transformers.utils import _LazyModule
 
 
 _import_structure = {
-    "configuration_alias": [
-        "RBLNASTForAudioClassificationConfig",
-        "RBLNDistilBertForQuestionAnsweringConfig",
-        "RBLNResNetForImageClassificationConfig",
-        "RBLNXLMRobertaForSequenceClassificationConfig",
-        "RBLNRobertaForSequenceClassificationConfig",
-        "RBLNRobertaForMaskedLMConfig",
-        "RBLNViTForImageClassificationConfig",
-    ],
     "models": [
+        "RBLNASTForAudioClassification",
+        "RBLNASTForAudioClassificationConfig",
         "RBLNAutoModel",
         "RBLNAutoModelForAudioClassification",
         "RBLNAutoModelForCausalLM",
@@ -51,12 +44,14 @@ _import_structure = {
         "RBLNBertForQuestionAnsweringConfig",
         "RBLNBertModel",
         "RBLNBertModelConfig",
-        "RBLNBlip2VisionModelConfig",
-        "RBLNBlip2VisionModel",
-        "RBLNBlip2QFormerModel",
-        "RBLNBlip2QFormerModelConfig",
         "RBLNBlip2ForConditionalGeneration",
         "RBLNBlip2ForConditionalGenerationConfig",
+        "RBLNBlip2QFormerModel",
+        "RBLNBlip2QFormerModelConfig",
+        "RBLNBlip2VisionModel",
+        "RBLNBlip2VisionModelConfig",
+        "RBLNColPaliForRetrieval",
+        "RBLNColPaliForRetrievalConfig",
         "RBLNCLIPTextModel",
         "RBLNCLIPTextModelConfig",
         "RBLNCLIPTextModelWithProjection",
@@ -67,40 +62,48 @@ _import_structure = {
         "RBLNCLIPVisionModelWithProjectionConfig",
         "RBLNDecoderOnlyModelForCausalLM",
         "RBLNDecoderOnlyModelForCausalLMConfig",
+        "RBLNDistilBertForQuestionAnswering",
+        "RBLNDistilBertForQuestionAnsweringConfig",
         "RBLNDPTForDepthEstimation",
         "RBLNDPTForDepthEstimationConfig",
         "RBLNExaoneForCausalLM",
         "RBLNExaoneForCausalLMConfig",
-        "RBLNGemmaForCausalLM",
-        "RBLNGemmaForCausalLMConfig",
         "RBLNGemma3ForCausalLM",
         "RBLNGemma3ForCausalLMConfig",
         "RBLNGemma3ForConditionalGeneration",
         "RBLNGemma3ForConditionalGenerationConfig",
+        "RBLNGemmaForCausalLM",
+        "RBLNGemmaForCausalLMConfig",
         "RBLNGPT2LMHeadModel",
         "RBLNGPT2LMHeadModelConfig",
-        "RBLNIdefics3VisionTransformer",
         "RBLNIdefics3ForConditionalGeneration",
         "RBLNIdefics3ForConditionalGenerationConfig",
+        "RBLNIdefics3VisionTransformer",
         "RBLNIdefics3VisionTransformerConfig",
         "RBLNLlamaForCausalLM",
         "RBLNLlamaForCausalLMConfig",
-        "RBLNOPTForCausalLM",
-        "RBLNOPTForCausalLMConfig",
         "RBLNLlavaNextForConditionalGeneration",
         "RBLNLlavaNextForConditionalGenerationConfig",
         "RBLNMidmLMHeadModel",
         "RBLNMidmLMHeadModelConfig",
         "RBLNMistralForCausalLM",
         "RBLNMistralForCausalLMConfig",
+        "RBLNOPTForCausalLM",
+        "RBLNOPTForCausalLMConfig",
         "RBLNPhiForCausalLM",
         "RBLNPhiForCausalLMConfig",
-        "RBLNQwen2ForCausalLM",
-        "RBLNQwen2ForCausalLMConfig",
         "RBLNQwen2_5_VisionTransformerPretrainedModel",
         "RBLNQwen2_5_VisionTransformerPretrainedModelConfig",
         "RBLNQwen2_5_VLForConditionalGeneration",
         "RBLNQwen2_5_VLForConditionalGenerationConfig",
+        "RBLNQwen2ForCausalLM",
+        "RBLNQwen2ForCausalLMConfig",
+        "RBLNResNetForImageClassification",
+        "RBLNResNetForImageClassificationConfig",
+        "RBLNRobertaForMaskedLM",
+        "RBLNRobertaForMaskedLMConfig",
+        "RBLNRobertaForSequenceClassification",
+        "RBLNRobertaForSequenceClassificationConfig",
         "RBLNSiglipVisionModel",
         "RBLNSiglipVisionModelConfig",
         "RBLNT5EncoderModel",
@@ -109,44 +112,23 @@ _import_structure = {
         "RBLNT5ForConditionalGenerationConfig",
         "RBLNTimeSeriesTransformerForPrediction",
         "RBLNTimeSeriesTransformerForPredictionConfig",
+        "RBLNViTForImageClassification",
+        "RBLNViTForImageClassificationConfig",
         "RBLNWav2Vec2ForCTC",
         "RBLNWav2Vec2ForCTCConfig",
         "RBLNWhisperForConditionalGeneration",
         "RBLNWhisperForConditionalGenerationConfig",
+        "RBLNXLMRobertaForSequenceClassification",
+        "RBLNXLMRobertaForSequenceClassificationConfig",
         "RBLNXLMRobertaModel",
         "RBLNXLMRobertaModelConfig",
     ],
-    "modeling_alias": [
-        "RBLNASTForAudioClassification",
-        "RBLNDistilBertForQuestionAnswering",
-        "RBLNResNetForImageClassification",
-        "RBLNXLMRobertaForSequenceClassification",
-        "RBLNRobertaForSequenceClassification",
-        "RBLNRobertaForMaskedLM",
-        "RBLNViTForImageClassification",
-    ],
 }
 
 if TYPE_CHECKING:
-    from .configuration_alias import (
-        RBLNASTForAudioClassificationConfig,
-        RBLNDistilBertForQuestionAnsweringConfig,
-        RBLNResNetForImageClassificationConfig,
-        RBLNRobertaForMaskedLMConfig,
-        RBLNRobertaForSequenceClassificationConfig,
-        RBLNViTForImageClassificationConfig,
-        RBLNXLMRobertaForSequenceClassificationConfig,
-    )
-    from .modeling_alias import (
-        RBLNASTForAudioClassification,
-        RBLNDistilBertForQuestionAnswering,
-        RBLNResNetForImageClassification,
-        RBLNRobertaForMaskedLM,
-        RBLNRobertaForSequenceClassification,
-        RBLNViTForImageClassification,
-        RBLNXLMRobertaForSequenceClassification,
-    )
     from .models import (
+        RBLNASTForAudioClassification,
+        RBLNASTForAudioClassificationConfig,
         RBLNAutoModel,
         RBLNAutoModelForAudioClassification,
         RBLNAutoModelForCausalLM,
@@ -186,6 +168,8 @@ if TYPE_CHECKING:
         RBLNCLIPVisionModelWithProjectionConfig,
         RBLNDecoderOnlyModelForCausalLM,
         RBLNDecoderOnlyModelForCausalLMConfig,
+        RBLNDistilBertForQuestionAnswering,
+        RBLNDistilBertForQuestionAnsweringConfig,
         RBLNDPTForDepthEstimation,
         RBLNDPTForDepthEstimationConfig,
         RBLNExaoneForCausalLM,
@@ -220,6 +204,12 @@ if TYPE_CHECKING:
         RBLNQwen2_5_VLForConditionalGenerationConfig,
         RBLNQwen2ForCausalLM,
         RBLNQwen2ForCausalLMConfig,
+        RBLNResNetForImageClassification,
+        RBLNResNetForImageClassificationConfig,
+        RBLNRobertaForMaskedLM,
+        RBLNRobertaForMaskedLMConfig,
+        RBLNRobertaForSequenceClassification,
+        RBLNRobertaForSequenceClassificationConfig,
         RBLNSiglipVisionModel,
         RBLNSiglipVisionModelConfig,
         RBLNT5EncoderModel,
@@ -228,10 +218,14 @@ if TYPE_CHECKING:
         RBLNT5ForConditionalGenerationConfig,
         RBLNTimeSeriesTransformerForPrediction,
         RBLNTimeSeriesTransformerForPredictionConfig,
+        RBLNViTForImageClassification,
+        RBLNViTForImageClassificationConfig,
         RBLNWav2Vec2ForCTC,
         RBLNWav2Vec2ForCTCConfig,
         RBLNWhisperForConditionalGeneration,
         RBLNWhisperForConditionalGenerationConfig,
+        RBLNXLMRobertaForSequenceClassification,
+        RBLNXLMRobertaForSequenceClassificationConfig,
         RBLNXLMRobertaModel,
         RBLNXLMRobertaModelConfig,
     )
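
With `configuration_alias` and `modeling_alias` folded into `models`, the public import path is unchanged; only the internal module layout moved. A quick check using names from the lists above:

    from optimum.rbln.transformers import (
        RBLNResNetForImageClassification,
        RBLNResNetForImageClassificationConfig,
        RBLNViTForImageClassification,
    )
    # These names previously resolved through modeling_alias/configuration_alias;
    # in 0.8.1 they resolve through the per-model packages under models/.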
@@ -12,12 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 from ..configuration_utils import RBLNModelConfig
 
 
-class _RBLNTransformerEncoderConfig(RBLNModelConfig):
+class RBLNTransformerEncoderConfig(RBLNModelConfig):
     rbln_model_input_names: Optional[List[str]] = None
 
     def __init__(
@@ -25,7 +25,7 @@ class _RBLNTransformerEncoderConfig(RBLNModelConfig):
         max_seq_len: Optional[int] = None,
         batch_size: Optional[int] = None,
         model_input_names: Optional[List[str]] = None,
-        **kwargs,
+        **kwargs: Dict[str, Any],
     ):
         """
         Args:
@@ -47,9 +47,12 @@ class _RBLNTransformerEncoderConfig(RBLNModelConfig):
         self.model_input_names = model_input_names or self.rbln_model_input_names
 
 
-class _RBLNImageModelConfig(RBLNModelConfig):
+class RBLNImageModelConfig(RBLNModelConfig):
     def __init__(
-        self, image_size: Optional[Union[int, Tuple[int, int]]] = None, batch_size: Optional[int] = None, **kwargs
+        self,
+        image_size: Optional[Union[int, Tuple[int, int]]] = None,
+        batch_size: Optional[int] = None,
+        **kwargs: Dict[str, Any],
     ):
         """
         Args:
@@ -86,32 +89,32 @@ class _RBLNImageModelConfig(RBLNModelConfig):
         return self.image_size["height"]
 
 
-class RBLNModelForQuestionAnsweringConfig(_RBLNTransformerEncoderConfig):
+class RBLNModelForQuestionAnsweringConfig(RBLNTransformerEncoderConfig):
     pass
 
 
-class RBLNModelForSequenceClassificationConfig(_RBLNTransformerEncoderConfig):
+class RBLNModelForSequenceClassificationConfig(RBLNTransformerEncoderConfig):
     pass
 
 
-class RBLNModelForMaskedLMConfig(_RBLNTransformerEncoderConfig):
+class RBLNModelForMaskedLMConfig(RBLNTransformerEncoderConfig):
     pass
 
 
-class RBLNModelForTextEncodingConfig(_RBLNTransformerEncoderConfig):
+class RBLNModelForTextEncodingConfig(RBLNTransformerEncoderConfig):
     pass
 
 
 # FIXME : Appropriate name ?
-class RBLNTransformerEncoderForFeatureExtractionConfig(_RBLNTransformerEncoderConfig):
+class RBLNTransformerEncoderForFeatureExtractionConfig(RBLNTransformerEncoderConfig):
     pass
 
 
-class RBLNModelForImageClassificationConfig(_RBLNImageModelConfig):
+class RBLNModelForImageClassificationConfig(RBLNImageModelConfig):
     pass
 
 
-class RBLNModelForDepthEstimationConfig(_RBLNImageModelConfig):
+class RBLNModelForDepthEstimationConfig(RBLNImageModelConfig):
     pass
 
 
@@ -121,7 +124,7 @@ class RBLNModelForAudioClassificationConfig(RBLNModelConfig):
         batch_size: Optional[int] = None,
         max_length: Optional[int] = None,
         num_mel_bins: Optional[int] = None,
-        **kwargs,
+        **kwargs: Dict[str, Any],
     ):
         """
         Args:
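
Dropping the leading underscore makes `RBLNTransformerEncoderConfig` and `RBLNImageModelConfig` part of the importable surface. A hedged instantiation sketch using the parameters shown above; the values are illustrative, and the import path simply follows the file layout in the change list (the class may also be re-exported at higher levels):

    from optimum.rbln.transformers.configuration_generic import RBLNModelForMaskedLMConfig

    # Encoder-style configs accept max_seq_len / batch_size / model_input_names.
    config = RBLNModelForMaskedLMConfig(
        max_seq_len=512,
        batch_size=2,
        model_input_names=["input_ids", "attention_mask"],
    )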
@@ -43,9 +43,9 @@ from ..configuration_utils import RBLNCompileConfig
 from ..modeling import RBLNModel
 from ..utils.logging import get_logger
 from .configuration_generic import (
+    RBLNImageModelConfig,
     RBLNModelForAudioClassificationConfig,
-    _RBLNImageModelConfig,
-    _RBLNTransformerEncoderConfig,
+    RBLNTransformerEncoderConfig,
 )
 
 
@@ -55,7 +55,7 @@ if TYPE_CHECKING:
 logger = get_logger()
 
 
-class _RBLNTransformerEncoder(RBLNModel):
+class RBLNTransformerEncoder(RBLNModel):
     auto_model_class = AutoModel
     rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
     rbln_dtype = "int64"
@@ -66,8 +66,8 @@ class _RBLNTransformerEncoder(RBLNModel):
         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
         model: Optional["PreTrainedModel"] = None,
         model_config: Optional["PretrainedConfig"] = None,
-        rbln_config: Optional[_RBLNTransformerEncoderConfig] = None,
-    ) -> _RBLNTransformerEncoderConfig:
+        rbln_config: Optional[RBLNTransformerEncoderConfig] = None,
+    ) -> RBLNTransformerEncoderConfig:
         return cls.update_rbln_config_for_transformers_encoder(
             preprocessors=preprocessors,
             model=model,
@@ -81,8 +81,8 @@ class _RBLNTransformerEncoder(RBLNModel):
         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
         model: Optional["PreTrainedModel"] = None,
         model_config: Optional["PretrainedConfig"] = None,
-        rbln_config: Optional[_RBLNTransformerEncoderConfig] = None,
-    ) -> _RBLNTransformerEncoderConfig:
+        rbln_config: Optional[RBLNTransformerEncoderConfig] = None,
+    ) -> RBLNTransformerEncoderConfig:
         max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
             model_config, "max_position_embeddings", None
         )
@@ -139,7 +139,7 @@ class _RBLNTransformerEncoder(RBLNModel):
         return rbln_config
 
 
-class _RBLNImageModel(RBLNModel):
+class RBLNImageModel(RBLNModel):
     auto_model_class = AutoModel
     main_input_name = "pixel_values"
     output_class = BaseModelOutput
@@ -150,8 +150,8 @@ class _RBLNImageModel(RBLNModel):
         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
         model: Optional["PreTrainedModel"] = None,
         model_config: Optional["PretrainedConfig"] = None,
-        rbln_config: Optional[_RBLNImageModelConfig] = None,
-    ) -> _RBLNImageModelConfig:
+        rbln_config: Optional[RBLNImageModelConfig] = None,
+    ) -> RBLNImageModelConfig:
         return cls.update_rbln_config_for_image_model(
             preprocessors=preprocessors,
             model=model,
@@ -165,8 +165,8 @@ class _RBLNImageModel(RBLNModel):
         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
         model: Optional["PreTrainedModel"] = None,
         model_config: Optional["PretrainedConfig"] = None,
-        rbln_config: Optional[_RBLNImageModelConfig] = None,
-    ) -> _RBLNImageModelConfig:
+        rbln_config: Optional[RBLNImageModelConfig] = None,
+    ) -> RBLNImageModelConfig:
         if rbln_config.image_size is None:
             for processor in preprocessors:
                 if hasattr(processor, "size"):
@@ -196,15 +196,14 @@ class _RBLNImageModel(RBLNModel):
         return rbln_config
 
 
-class RBLNModelForQuestionAnswering(_RBLNTransformerEncoder):
+class RBLNModelForQuestionAnswering(RBLNTransformerEncoder):
     auto_model_class = AutoModelForQuestionAnswering
     rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
     output_class = QuestionAnsweringModelOutput
 
     def _prepare_output(self, output, return_dict):
-        """
-        Prepare QuestionAnswering specific output format.
-        """
+        # Prepare QuestionAnswering specific output format.
+
         start_logits, end_logits = output
 
         if not return_dict:
@@ -213,32 +212,32 @@ class RBLNModelForQuestionAnswering(_RBLNTransformerEncoder):
         return QuestionAnsweringModelOutput(start_logits=start_logits, end_logits=end_logits)
 
 
-class RBLNModelForSequenceClassification(_RBLNTransformerEncoder):
+class RBLNModelForSequenceClassification(RBLNTransformerEncoder):
     auto_model_class = AutoModelForSequenceClassification
     rbln_model_input_names = ["input_ids", "attention_mask"]
 
 
-class RBLNModelForMaskedLM(_RBLNTransformerEncoder):
+class RBLNModelForMaskedLM(RBLNTransformerEncoder):
     auto_model_class = AutoModelForMaskedLM
     rbln_model_input_names = ["input_ids", "attention_mask"]
 
 
-class RBLNModelForTextEncoding(_RBLNTransformerEncoder):
+class RBLNModelForTextEncoding(RBLNTransformerEncoder):
     auto_model_class = AutoModelForTextEncoding
     rbln_model_input_names = ["input_ids", "attention_mask"]
 
 
-class RBLNTransformerEncoderForFeatureExtraction(_RBLNTransformerEncoder):
+class RBLNTransformerEncoderForFeatureExtraction(RBLNTransformerEncoder):
     # TODO: RBLNModel is also for feature extraction.
     auto_model_class = AutoModel
     rbln_model_input_names = ["input_ids", "attention_mask"]
 
 
-class RBLNModelForImageClassification(_RBLNImageModel):
+class RBLNModelForImageClassification(RBLNImageModel):
     auto_model_class = AutoModelForImageClassification
 
 
-class RBLNModelForDepthEstimation(_RBLNImageModel):
+class RBLNModelForDepthEstimation(RBLNImageModel):
     auto_model_class = AutoModelForDepthEstimation
 
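The same de-privatization applies to the model base classes, which makes the subclassing pattern above available to downstream code. A hypothetical task head following that pattern; `RBLNModelForTokenClassification` does not exist in the package and is shown only to illustrate the now-public extension point:

    from transformers import AutoModelForTokenClassification

    from optimum.rbln.transformers.modeling_generic import RBLNTransformerEncoder

    class RBLNModelForTokenClassification(RBLNTransformerEncoder):
        auto_model_class = AutoModelForTokenClassification
        rbln_model_input_names = ["input_ids", "attention_mask"]
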
@@ -48,10 +48,13 @@ def _compute_default_rope_parameters(
         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
         post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
     """
-
     base = config.rope_theta
     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+    head_dim = (
+        config.head_dim
+        if hasattr(config, "head_dim") and config.head_dim is not None
+        else config.hidden_size // config.num_attention_heads
+    )
     dim = int(head_dim * partial_rotary_factor)
 
     attention_factor = 1.0  # Unused in this type of RoPE
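
The rewritten expression changes behavior only when a config defines `head_dim` but sets it to `None`: `getattr` with a default returns that `None` (the attribute exists), whereas the new form falls back to `hidden_size // num_attention_heads`. A minimal repro with a stand-in config object:

    from types import SimpleNamespace

    # Stand-in config: the attribute exists but is None, as some configs ship it.
    config = SimpleNamespace(hidden_size=4096, num_attention_heads=32, head_dim=None)

    # Old expression: the attribute exists, so the getattr default is ignored.
    old = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
    assert old is None  # int(head_dim * partial_rotary_factor) would then raise a TypeError

    # New expression: the explicit None check falls back to 4096 // 32 = 128.
    new = (
        config.head_dim
        if hasattr(config, "head_dim") and config.head_dim is not None
        else config.hidden_size // config.num_attention_heads
    )
    assert new == 128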