optimum-rbln 0.8.1a5__py3-none-any.whl → 0.8.1a7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. optimum/rbln/__init__.py +18 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/diffusers/__init__.py +21 -1
  4. optimum/rbln/diffusers/configurations/__init__.py +4 -0
  5. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +82 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_cosmos_transformer.py +68 -0
  8. optimum/rbln/diffusers/configurations/pipelines/__init__.py +1 -0
  9. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +0 -4
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +110 -0
  11. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +0 -2
  12. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +0 -4
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +1 -4
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +0 -4
  15. optimum/rbln/diffusers/modeling_diffusers.py +57 -40
  16. optimum/rbln/diffusers/models/__init__.py +4 -0
  17. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  18. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +6 -1
  19. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +219 -0
  20. optimum/rbln/diffusers/models/autoencoders/vae.py +49 -5
  21. optimum/rbln/diffusers/models/autoencoders/vq_model.py +6 -1
  22. optimum/rbln/diffusers/models/transformers/__init__.py +1 -0
  23. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +321 -0
  24. optimum/rbln/diffusers/pipelines/__init__.py +10 -0
  25. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  26. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +102 -0
  27. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +451 -0
  28. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +98 -0
  29. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +98 -0
  30. optimum/rbln/modeling.py +38 -2
  31. optimum/rbln/modeling_base.py +18 -2
  32. optimum/rbln/transformers/modeling_generic.py +3 -3
  33. optimum/rbln/transformers/models/bart/configuration_bart.py +12 -2
  34. optimum/rbln/transformers/models/bart/modeling_bart.py +16 -7
  35. optimum/rbln/transformers/models/bert/configuration_bert.py +18 -3
  36. optimum/rbln/transformers/models/bert/modeling_bert.py +24 -0
  37. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +13 -1
  38. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +15 -0
  39. optimum/rbln/transformers/models/clip/configuration_clip.py +12 -2
  40. optimum/rbln/transformers/models/clip/modeling_clip.py +27 -1
  41. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +22 -20
  42. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +6 -1
  43. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +8 -0
  44. optimum/rbln/transformers/models/dpt/configuration_dpt.py +6 -1
  45. optimum/rbln/transformers/models/dpt/modeling_dpt.py +6 -1
  46. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +5 -3
  47. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +8 -0
  48. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +16 -0
  49. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +8 -0
  50. optimum/rbln/transformers/models/resnet/configuration_resnet.py +6 -1
  51. optimum/rbln/transformers/models/resnet/modeling_resnet.py +5 -1
  52. optimum/rbln/transformers/models/roberta/configuration_roberta.py +12 -2
  53. optimum/rbln/transformers/models/roberta/modeling_roberta.py +16 -0
  54. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +6 -2
  55. optimum/rbln/transformers/models/siglip/configuration_siglip.py +7 -0
  56. optimum/rbln/transformers/models/siglip/modeling_siglip.py +7 -0
  57. optimum/rbln/transformers/models/t5/configuration_t5.py +12 -2
  58. optimum/rbln/transformers/models/t5/modeling_t5.py +10 -4
  59. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +7 -0
  60. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +6 -2
  61. optimum/rbln/transformers/models/vit/configuration_vit.py +6 -1
  62. optimum/rbln/transformers/models/vit/modeling_vit.py +7 -1
  63. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +7 -0
  64. optimum/rbln/transformers/models/whisper/configuration_whisper.py +7 -0
  65. optimum/rbln/transformers/models/whisper/modeling_whisper.py +6 -2
  66. optimum/rbln/utils/runtime_utils.py +49 -1
  67. {optimum_rbln-0.8.1a5.dist-info → optimum_rbln-0.8.1a7.dist-info}/METADATA +1 -1
  68. {optimum_rbln-0.8.1a5.dist-info → optimum_rbln-0.8.1a7.dist-info}/RECORD +70 -60
  69. {optimum_rbln-0.8.1a5.dist-info → optimum_rbln-0.8.1a7.dist-info}/WHEEL +0 -0
  70. {optimum_rbln-0.8.1a5.dist-info → optimum_rbln-0.8.1a7.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py ADDED
@@ -0,0 +1,98 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Any, Dict, Optional
+
+from diffusers import CosmosVideoToWorldPipeline
+from diffusers.schedulers import EDMEulerScheduler
+from transformers import T5TokenizerFast
+
+from ....transformers.models.t5.modeling_t5 import RBLNT5EncoderModel
+from ....utils.logging import get_logger
+from ...modeling_diffusers import RBLNDiffusionMixin
+from ...models.autoencoders.autoencoder_kl_cosmos import RBLNAutoencoderKLCosmos
+from ...models.transformers.transformer_cosmos import RBLNCosmosTransformer3DModel
+from .cosmos_guardrail import RBLNCosmosSafetyChecker
+
+
+logger = get_logger(__name__)
+
+
+class RBLNCosmosVideoToWorldPipeline(RBLNDiffusionMixin, CosmosVideoToWorldPipeline):
+    """
+    RBLN-accelerated implementation of the Cosmos Video to World pipeline for video-to-video generation.
+
+    This pipeline compiles Cosmos Video to World models to run efficiently on RBLN NPUs, enabling high-performance
+    inference for generating videos with distinctive artistic style and enhanced visual quality.
+    """
+
+    original_class = CosmosVideoToWorldPipeline
+    _submodules = ["text_encoder", "transformer", "vae"]
+    _optional_components = ["safety_checker"]
+
+    def __init__(
+        self,
+        text_encoder: RBLNT5EncoderModel,
+        tokenizer: T5TokenizerFast,
+        transformer: RBLNCosmosTransformer3DModel,
+        vae: RBLNAutoencoderKLCosmos,
+        scheduler: EDMEulerScheduler,
+        safety_checker: Optional[RBLNCosmosSafetyChecker] = None,
+    ):
+        if safety_checker is None:
+            safety_checker = RBLNCosmosSafetyChecker()
+
+        super().__init__(
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            transformer=transformer,
+            vae=vae,
+            scheduler=scheduler,
+            safety_checker=safety_checker,
+        )
+
+    def handle_additional_kwargs(self, **kwargs):
+        if "num_frames" in kwargs and kwargs["num_frames"] != self.transformer.rbln_config.num_frames:
+            logger.warning(
+                f"The transformer in this pipeline is compiled with 'num_frames={self.transformer.rbln_config.num_frames}'. 'num_frames' set by the user will be ignored"
+            )
+            kwargs.pop("num_frames")
+        if (
+            "max_sequence_length" in kwargs
+            and kwargs["max_sequence_length"] != self.transformer.rbln_config.max_seq_len
+        ):
+            logger.warning(
+                f"The transformer in this pipeline is compiled with 'max_seq_len={self.transformer.rbln_config.max_seq_len}'. 'max_sequence_length' set by the user will be ignored"
+            )
+            kwargs.pop("max_sequence_length")
+        return kwargs
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_id: str,
+        *,
+        export: bool = False,
+        safety_checker: Optional[RBLNCosmosSafetyChecker] = None,
+        rbln_config: Dict[str, Any] = {},
+        **kwargs: Dict[str, Any],
+    ):
+        rbln_config, kwargs = cls.get_rbln_config_class().initialize_from_kwargs(rbln_config, **kwargs)
+        if safety_checker is None and export:
+            safety_checker = RBLNCosmosSafetyChecker(rbln_config=rbln_config.safety_checker)
+
+        return super().from_pretrained(
+            model_id, export=export, safety_checker=safety_checker, rbln_config=rbln_config, **kwargs
+        )
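For orientation, a minimal usage sketch of the new pipeline follows. The checkpoint id, the diffusers video helpers, and the output handling are illustrative assumptions, not part of this diff:

# Hedged usage sketch for RBLNCosmosVideoToWorldPipeline (hypothetical model id).
from diffusers.utils import export_to_video, load_video
from optimum.rbln import RBLNCosmosVideoToWorldPipeline

# export=True compiles the submodules (text_encoder, transformer, vae) for the NPU;
# a default RBLNCosmosSafetyChecker is constructed when none is passed.
pipe = RBLNCosmosVideoToWorldPipeline.from_pretrained(
    "nvidia/Cosmos-1.0-Diffusion-7B-Video2World",  # hypothetical checkpoint id
    export=True,
)

# Per handle_additional_kwargs above, call-time num_frames / max_sequence_length
# that disagree with the compiled values are warned about and dropped.
frames = load_video("input.mp4")
output = pipe(prompt="a robot arm assembling a toy", video=frames)
export_to_video(output.frames[0], "output.mp4")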
optimum/rbln/modeling.py CHANGED
@@ -64,7 +64,12 @@ class RBLNModel(RBLNBaseModel):
     def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
         model = cls.wrap_model_if_needed(model, rbln_config)
         rbln_compile_config = rbln_config.compile_cfgs[0]
-        compiled_model = cls.compile(model, rbln_compile_config=rbln_compile_config)
+        compiled_model = cls.compile(
+            model,
+            rbln_compile_config=rbln_compile_config,
+            create_runtimes=rbln_config.create_runtimes,
+            device=rbln_config.device,
+        )
         return compiled_model

     @classmethod
@@ -237,7 +242,38 @@ class RBLNModel(RBLNBaseModel):
             for compiled_model in compiled_models
         ]

-    def forward(self, *args, return_dict: Optional[bool] = None, **kwargs):
+    def forward(self, *args: Any, return_dict: Optional[bool] = None, **kwargs: Dict[str, Any]) -> Any:
+        """
+        Defines the forward pass of the RBLN model, providing a drop-in replacement for HuggingFace PreTrainedModel.
+
+        This method executes the compiled RBLN model on RBLN NPU devices while maintaining full compatibility
+        with HuggingFace transformers and diffusers APIs. The RBLNModel can be used as a direct substitute
+        for any HuggingFace nn.Module/PreTrainedModel, enabling seamless integration into existing workflows.
+
+        Args:
+            *args: Variable length argument list containing model inputs. The format matches the original
+                HuggingFace model's forward method signature (e.g., input_ids, attention_mask for
+                transformers models, or sample, timestep for diffusers models).
+            return_dict:
+                Whether to return outputs as a dictionary-like object or as a tuple. When `None`:
+                - For transformers models: Uses `self.config.use_return_dict` (typically `True`)
+                - For diffusers models: Defaults to `True`
+            **kwargs: Arbitrary keyword arguments containing additional model inputs and parameters,
+                matching the original HuggingFace model's interface.
+
+        Returns:
+            Model outputs in the same format as the original HuggingFace model.
+
+            - If `return_dict=True`: Returns a dictionary-like object (e.g., BaseModelOutput,
+              CausalLMOutput) with named fields such as `logits`, `hidden_states`, etc.
+            - If `return_dict=False`: Returns a tuple containing the raw model outputs.
+
+        Note:
+            - This method maintains the exact same interface as the original HuggingFace model's forward method
+            - The compiled model runs on RBLN NPU hardware for accelerated inference
+            - All HuggingFace model features (generation, attention patterns, etc.) are preserved
+            - Can be used directly in HuggingFace pipelines, transformers.Trainer, and other workflows
+        """
         if self.hf_library_name == "transformers":
             return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         else:
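A short sketch of the drop-in behavior documented above, using RBLNBertForQuestionAnswering from this release; the checkpoint id is hypothetical and assumes a model already compiled with export=True:

# Sketch of RBLNModel.forward as a drop-in for the HF forward API.
from transformers import AutoTokenizer
from optimum.rbln import RBLNBertForQuestionAnswering

tokenizer = AutoTokenizer.from_pretrained("my-org/bert-qa-rbln")              # hypothetical id
model = RBLNBertForQuestionAnswering.from_pretrained("my-org/bert-qa-rbln")  # pre-compiled

inputs = tokenizer("Who designs RBLN NPUs?", "Rebellions designs RBLN NPUs.", return_tensors="pt")
outputs = model(**inputs)                        # return_dict -> config.use_return_dict
print(outputs.start_logits, outputs.end_logits)  # dict-like output, as with PreTrainedModel
as_tuple = model(**inputs, return_dict=False)    # raw tuple instead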
optimum/rbln/modeling_base.py CHANGED
@@ -27,7 +27,7 @@ from transformers import AutoConfig, AutoModel, GenerationConfig, PretrainedConf
 from .configuration_utils import RBLNAutoConfig, RBLNCompileConfig, RBLNModelConfig, get_rbln_config_class
 from .utils.hub import PushToHubMixin, pull_compiled_model_from_hub, validate_files
 from .utils.logging import get_logger
-from .utils.runtime_utils import UnavailableRuntime
+from .utils.runtime_utils import UnavailableRuntime, tp_and_devices_are_ok
 from .utils.save_utils import maybe_load_preprocessors
 from .utils.submodule import SubModulesMixin

@@ -374,7 +374,23 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         return from_pretrained_method(model_id=model_id, **kwargs, rbln_config=rbln_config)

     @classmethod
-    def compile(cls, model, rbln_compile_config: Optional[RBLNCompileConfig] = None, **kwargs):
+    def compile(
+        cls,
+        model,
+        rbln_compile_config: RBLNCompileConfig,
+        create_runtimes: bool,
+        device: Union[int, List[int]],
+        **kwargs,
+    ):
+        if create_runtimes:
+            runtime_cannot_be_created = tp_and_devices_are_ok(
+                tensor_parallel_size=rbln_compile_config.tensor_parallel_size,
+                device=device,
+                npu=rbln_compile_config.npu,
+            )
+            if runtime_cannot_be_created:
+                raise ValueError(runtime_cannot_be_created)
+
         compiled_model = rebel.compile_from_torch(
             model,
             input_info=rbln_compile_config.input_info,
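From the call site, tp_and_devices_are_ok returns an error message when a runtime cannot be created and a falsy value otherwise. A simplified sketch of that contract (not the actual implementation in optimum/rbln/utils/runtime_utils.py):

# Illustrative only: mirrors the call-site contract of tp_and_devices_are_ok.
from typing import List, Optional, Union

def check_tp_and_devices(
    tensor_parallel_size: Optional[int],
    device: Union[int, List[int], None],
    npu: Optional[str],
) -> Optional[str]:
    # Returning a string signals "runtime cannot be created"; None means OK.
    if tensor_parallel_size is not None and tensor_parallel_size > 1:
        devices = device if isinstance(device, list) else [device] if device is not None else []
        if devices and len(devices) != tensor_parallel_size:
            return (
                f"tensor_parallel_size ({tensor_parallel_size}) does not match "
                f"the number of devices given ({len(devices)})."
            )
    return None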
optimum/rbln/transformers/modeling_generic.py CHANGED
@@ -139,7 +139,7 @@ class RBLNTransformerEncoder(RBLNModel):
         return rbln_config


-class _RBLNImageModel(RBLNModel):
+class RBLNImageModel(RBLNModel):
     auto_model_class = AutoModel
     main_input_name = "pixel_values"
     output_class = BaseModelOutput
@@ -233,11 +233,11 @@ class RBLNTransformerEncoderForFeatureExtraction(RBLNTransformerEncoder):
     rbln_model_input_names = ["input_ids", "attention_mask"]


-class RBLNModelForImageClassification(_RBLNImageModel):
+class RBLNModelForImageClassification(RBLNImageModel):
     auto_model_class = AutoModelForImageClassification


-class RBLNModelForDepthEstimation(_RBLNImageModel):
+class RBLNModelForDepthEstimation(RBLNImageModel):
     auto_model_class = AutoModelForDepthEstimation


optimum/rbln/transformers/models/bart/configuration_bart.py CHANGED
@@ -17,8 +17,18 @@ from ..seq2seq import RBLNModelForSeq2SeqLMConfig


 class RBLNBartModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
-    pass
+    """
+    Configuration class for RBLNBartModel.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized BART models for feature extraction tasks.
+    """


 class RBLNBartForConditionalGenerationConfig(RBLNModelForSeq2SeqLMConfig):
-    pass
+    """
+    Configuration class for RBLNBartForConditionalGeneration.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized BART models for conditional text generation tasks.
+    """
optimum/rbln/transformers/models/bart/modeling_bart.py CHANGED
@@ -13,7 +13,7 @@
 # limitations under the License.

 import inspect
-from typing import TYPE_CHECKING, Any, Callable
+from typing import Any, Callable

 from transformers import BartForConditionalGeneration, PreTrainedModel

@@ -27,19 +27,28 @@ from .configuration_bart import RBLNBartForConditionalGenerationConfig
 logger = get_logger()


-if TYPE_CHECKING:
-    from transformers import PreTrainedModel
-
-
 class RBLNBartModel(RBLNTransformerEncoderForFeatureExtraction):
-    pass
+    """
+    RBLN optimized BART model for feature extraction tasks.
+
+    This class provides hardware-accelerated inference for BART encoder models
+    on RBLN devices, optimized for feature extraction use cases.
+    """


 class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
+    """
+    RBLN optimized BART model for conditional text generation tasks.
+
+    This class provides hardware-accelerated inference for BART models
+    on RBLN devices, supporting sequence-to-sequence generation tasks
+    such as summarization, translation, and text generation.
+    """
+
     support_causal_attn = True

     @classmethod
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNBartForConditionalGenerationConfig):
+    def wrap_model_if_needed(self, model: PreTrainedModel, rbln_config: RBLNBartForConditionalGenerationConfig):
         return BartWrapper(
             model, enc_max_seq_len=rbln_config.enc_max_seq_len, use_attention_mask=rbln_config.use_attention_mask
         )
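A hedged end-to-end sketch for the BART class above; the upstream checkpoint id and generation settings are illustrative assumptions:

# Hypothetical compile-and-generate flow for RBLNBartForConditionalGeneration.
from transformers import AutoTokenizer
from optimum.rbln import RBLNBartForConditionalGeneration

model = RBLNBartForConditionalGeneration.from_pretrained(
    "facebook/bart-base",  # assumed upstream checkpoint
    export=True,           # compile encoder/decoder for the NPU
)
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")

batch = tokenizer("RBLN NPUs accelerate transformer inference.", return_tensors="pt")
summary_ids = model.generate(**batch, max_new_tokens=32)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))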
optimum/rbln/transformers/models/bert/configuration_bert.py CHANGED
@@ -20,12 +20,27 @@ from ...configuration_generic import (


 class RBLNBertModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
-    pass
+    """
+    Configuration class for RBLNBertModel.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized BERT models for feature extraction tasks.
+    """


 class RBLNBertForMaskedLMConfig(RBLNModelForMaskedLMConfig):
-    pass
+    """
+    Configuration class for RBLNBertForMaskedLM.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized BERT models for masked language modeling tasks.
+    """


 class RBLNBertForQuestionAnsweringConfig(RBLNModelForQuestionAnsweringConfig):
-    pass
+    """
+    Configuration class for RBLNBertForQuestionAnswering.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized BERT models for question answering tasks.
+    """
optimum/rbln/transformers/models/bert/modeling_bert.py CHANGED
@@ -24,12 +24,36 @@ logger = get_logger(__name__)


 class RBLNBertModel(RBLNTransformerEncoderForFeatureExtraction):
+    """
+    RBLN optimized BERT model for feature extraction tasks.
+
+    This class provides hardware-accelerated inference for BERT models
+    on RBLN devices, optimized for extracting contextualized embeddings
+    and features from text sequences.
+    """
+
     rbln_model_input_names = ["input_ids", "attention_mask"]


 class RBLNBertForMaskedLM(RBLNModelForMaskedLM):
+    """
+    RBLN optimized BERT model for masked language modeling tasks.
+
+    This class provides hardware-accelerated inference for BERT models
+    on RBLN devices, supporting masked language modeling tasks such as
+    token prediction and text completion.
+    """
+
     rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]


 class RBLNBertForQuestionAnswering(RBLNModelForQuestionAnswering):
+    """
+    RBLN optimized BERT model for question answering tasks.
+
+    This class provides hardware-accelerated inference for BERT models
+    on RBLN devices, supporting extractive question answering tasks where
+    the model predicts start and end positions of answers in text.
+    """
+
     rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
optimum/rbln/transformers/models/blip_2/configuration_blip_2.py CHANGED
@@ -18,10 +18,22 @@ from ....configuration_utils import RBLNModelConfig


 class RBLNBlip2VisionModelConfig(RBLNModelConfig):
-    pass
+    """
+    Configuration class for RBLNBlip2VisionModel.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized BLIP-2 vision encoder models for multimodal tasks.
+    """


 class RBLNBlip2QFormerModelConfig(RBLNModelConfig):
+    """
+    Configuration class for RBLNBlip2QFormerModel.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized BLIP-2 Q-Former models that bridge vision and language modalities.
+    """
+
     def __init__(
         self,
         num_query_tokens: Optional[int] = None,
optimum/rbln/transformers/models/blip_2/modeling_blip_2.py CHANGED
@@ -65,6 +65,13 @@ class LoopProjector:


 class RBLNBlip2VisionModel(RBLNModel):
+    """
+    RBLN optimized BLIP-2 vision encoder model.
+
+    This class provides hardware-accelerated inference for BLIP-2 vision encoders
+    on RBLN devices, supporting image encoding for multimodal vision-language tasks.
+    """
+
     def get_input_embeddings(self):
         return self.embeddings

@@ -136,6 +143,14 @@ class RBLNBlip2VisionModel(RBLNModel):


 class RBLNBlip2QFormerModel(RBLNModel):
+    """
+    RBLN optimized BLIP-2 Q-Former model.
+
+    This class provides hardware-accelerated inference for BLIP-2 Q-Former models
+    on RBLN devices, which bridge vision and language modalities through cross-attention
+    mechanisms for multimodal understanding tasks.
+    """
+
     def get_input_embeddings(self):
         return self.embeddings.word_embeddings

optimum/rbln/transformers/models/clip/configuration_clip.py CHANGED
@@ -34,7 +34,12 @@ class RBLNCLIPTextModelConfig(RBLNModelConfig):


 class RBLNCLIPTextModelWithProjectionConfig(RBLNCLIPTextModelConfig):
-    pass
+    """
+    Configuration class for RBLNCLIPTextModelWithProjection.
+
+    This configuration inherits from RBLNCLIPTextModelConfig and stores
+    configuration parameters for CLIP text models with projection layers.
+    """


 class RBLNCLIPVisionModelConfig(RBLNModelConfig):
@@ -76,4 +81,9 @@ class RBLNCLIPVisionModelConfig(RBLNModelConfig):


 class RBLNCLIPVisionModelWithProjectionConfig(RBLNCLIPVisionModelConfig):
-    pass
+    """
+    Configuration class for RBLNCLIPVisionModelWithProjection.
+
+    This configuration inherits from RBLNCLIPVisionModelConfig and stores
+    configuration parameters for CLIP vision models with projection layers.
+    """
optimum/rbln/transformers/models/clip/modeling_clip.py CHANGED
@@ -43,6 +43,13 @@ class _TextEncoder(torch.nn.Module):


 class RBLNCLIPTextModel(RBLNModel):
+    """
+    RBLN optimized CLIP text encoder model.
+
+    This class provides hardware-accelerated inference for CLIP text encoders
+    on RBLN devices, supporting text encoding for multimodal tasks.
+    """
+
     @classmethod
     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNCLIPTextModelConfig) -> torch.nn.Module:
         return _TextEncoder(model).eval()
@@ -95,7 +102,12 @@ class RBLNCLIPTextModel(RBLNModel):


 class RBLNCLIPTextModelWithProjection(RBLNCLIPTextModel):
-    pass
+    """
+    RBLN optimized CLIP text encoder model with projection layer.
+
+    This class extends RBLNCLIPTextModel with a projection layer for
+    multimodal embedding alignment tasks.
+    """


 class _VisionEncoder(torch.nn.Module):
@@ -109,6 +121,13 @@ class _VisionEncoder(torch.nn.Module):


 class RBLNCLIPVisionModel(RBLNModel):
+    """
+    RBLN optimized CLIP vision encoder model.
+
+    This class provides hardware-accelerated inference for CLIP vision encoders
+    on RBLN devices, supporting image encoding for multimodal tasks.
+    """
+
     @classmethod
     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNCLIPVisionModelConfig) -> torch.nn.Module:
         return _VisionEncoder(model).eval()
@@ -182,6 +201,13 @@ class RBLNCLIPVisionModel(RBLNModel):


 class RBLNCLIPVisionModelWithProjection(RBLNCLIPVisionModel):
+    """
+    RBLN optimized CLIP vision encoder model with projection layer.
+
+    This class extends RBLNCLIPVisionModel with a projection layer for
+    multimodal embedding alignment tasks.
+    """
+
     def forward(
         self,
         pixel_values: Optional[torch.FloatTensor] = None,
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py CHANGED
@@ -78,7 +78,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
         )

-    def get_block_tables(self, cache_position: torch.Tensor, batch_idx: int = None):
+    def get_block_tables(self, cache_position: torch.Tensor, batch_idx: int = None) -> torch.Tensor:
         """
         Manages and returns the KV cache block tables.
         Updates the block tables based on the given cache_position, allocating new blocks or reusing existing ones as needed.
@@ -88,7 +88,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             batch_idx (int, optional): Specific batch index, used when phase is 'prefill'.

         Returns:
-            torch.Tensor: Updated block tables.
+            Updated block tables.
         """

         NO_BLOCKS_ERROR = (
@@ -458,6 +458,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     This class serves as the foundation for various decoder-only architectures like GPT, LLaMA, etc.

     The class provides core functionality for:
+
     1. Converting pre-trained transformer models to RBLN-optimized format
     2. Handling the compilation process for RBLN devices
     3. Managing inference operations for causal language modeling
@@ -532,7 +533,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     @classmethod
     def save_torch_artifacts(
         cls,
-        model: "PreTrainedModel",
+        model: PreTrainedModel,
         save_dir_path: Path,
         subfolder: str,
         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
@@ -566,7 +567,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     def get_quantized_model(
         cls,
         model_id: str,
-        config: Optional["PretrainedConfig"] = None,
+        config: Optional[PretrainedConfig] = None,
         use_auth_token: Optional[Union[bool, str]] = None,
         revision: Optional[str] = None,
         force_download: bool = False,
@@ -605,16 +606,15 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         return model

     def __getattr__(self, __name: str) -> Any:
-        """
-        Special method to delegate attribute access to the original Huggingface LM class.
-        This method is called when an attribute is not found in the current instance's dictionary.
-        It enables transparent access to the original model's attributes and methods while maintaining
-        proper method binding.
-
-        The method implements a delegation pattern that:
-        1. For methods: Creates a wrapper that properly binds 'self' to method calls
-        2. For other attributes: Returns them directly from the original class
-        """
+        # Special method to delegate attribute access to the original Huggingface LM class.
+        # This method is called when an attribute is not found in the current instance's dictionary.
+        # It enables transparent access to the original model's attributes and methods while maintaining
+        # proper method binding.
+
+        # The method implements a delegation pattern that:
+
+        # 1. For methods: Creates a wrapper that properly binds 'self' to method calls
+        # 2. For other attributes: Returns them directly from the original class

         def redirect(func):
             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
@@ -627,7 +627,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     @classmethod
     def get_pytorch_model(
         cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None, **kwargs
-    ) -> "PreTrainedModel":
+    ) -> PreTrainedModel:
         if rbln_config and rbln_config.quantization:
             model = cls.get_quantized_model(*args, **kwargs)
         else:
@@ -636,7 +636,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         return model

     @classmethod
-    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: "RBLNDecoderOnlyModelForCausalLMConfig"):
+    def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelForCausalLMConfig"):
         wrapper_cfg = {
             "max_seq_len": rbln_config.max_seq_len,
             "attn_impl": rbln_config.attn_impl,
@@ -654,7 +654,7 @@

     @classmethod
     @torch.inference_mode()
-    def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
+    def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
         wrapped_model = cls.wrap_model_if_needed(model, rbln_config)

         rbln_compile_configs = rbln_config.compile_cfgs
@@ -679,9 +679,11 @@
             quantization.maybe_set_quantization_env()
             original_linear = torch.nn.functional.linear
             torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
-            compiled_model = RBLNModel.compile(
+            compiled_model = cls.compile(
                 wrapped_model,
                 compile_config,
+                create_runtimes=rbln_config.create_runtimes,
+                device=rbln_config.device,
                 example_inputs=example_inputs,
                 compile_context=compile_context,
             )
@@ -973,8 +975,8 @@
     def _update_rbln_config(
         cls,
         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
-        model: Optional["PreTrainedModel"] = None,
-        model_config: Optional["PretrainedConfig"] = None,
+        model: Optional[PreTrainedModel] = None,
+        model_config: Optional[PretrainedConfig] = None,
         rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None,
     ) -> RBLNDecoderOnlyModelForCausalLMConfig:
         if rbln_config.max_seq_len is None:
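The get_block_tables docstring above describes paged KV-cache block management. A toy sketch of the allocation idea, with block size, dtype, and error text as assumptions rather than the RBLN implementation:

# Toy sketch of paged KV-cache block-table growth, not the RBLN implementation.
import torch

BLOCK_SIZE = 128  # assumed cache block size

class ToyBlockTable:
    def __init__(self, num_free_blocks: int):
        self.free = list(range(num_free_blocks))
        self.owned: list = []  # blocks already assigned to this sequence

    def blocks_for(self, cache_position: int) -> torch.Tensor:
        # Reuse existing blocks; allocate a new one when crossing a block boundary.
        needed = cache_position // BLOCK_SIZE + 1
        while len(self.owned) < needed:
            if not self.free:
                raise RuntimeError("No memory blocks available")  # cf. NO_BLOCKS_ERROR
            self.owned.append(self.free.pop(0))
        return torch.tensor(self.owned, dtype=torch.int16)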
optimum/rbln/transformers/models/distilbert/configuration_distilbert.py CHANGED
@@ -16,4 +16,9 @@ from ...configuration_generic import RBLNModelForQuestionAnsweringConfig


 class RBLNDistilBertForQuestionAnsweringConfig(RBLNModelForQuestionAnsweringConfig):
-    ""
+    """
+    Configuration class for RBLNDistilBertForQuestionAnswering.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized DistilBERT models for question answering tasks.
+    """
optimum/rbln/transformers/models/distilbert/modeling_distilbert.py CHANGED
@@ -16,4 +16,12 @@ from ...modeling_generic import RBLNModelForQuestionAnswering


 class RBLNDistilBertForQuestionAnswering(RBLNModelForQuestionAnswering):
+    """
+    RBLN optimized DistilBERT model for question answering tasks.
+
+    This class provides hardware-accelerated inference for DistilBERT models
+    on RBLN devices, supporting extractive question answering tasks where
+    the model predicts start and end positions of answers in text.
+    """
+
     rbln_model_input_names = ["input_ids", "attention_mask"]
optimum/rbln/transformers/models/dpt/configuration_dpt.py CHANGED
@@ -16,4 +16,9 @@ from ...configuration_generic import RBLNModelForDepthEstimationConfig


 class RBLNDPTForDepthEstimationConfig(RBLNModelForDepthEstimationConfig):
-    pass
+    """
+    Configuration class for RBLNDPTForDepthEstimation.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized DPT (Dense Prediction Transformer) models for depth estimation tasks.
+    """
optimum/rbln/transformers/models/dpt/modeling_dpt.py CHANGED
@@ -17,4 +17,9 @@ from ...modeling_generic import RBLNModelForDepthEstimation


 class RBLNDPTForDepthEstimation(RBLNModelForDepthEstimation):
-    pass
+    """
+    RBLN optimized DPT model for depth estimation tasks.
+
+    This class provides hardware-accelerated inference for DPT (Dense Prediction Transformer)
+    models on RBLN devices, supporting monocular depth estimation from single images.
+    """