optimum-rbln 0.8.1a6__py3-none-any.whl → 0.8.1a7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. optimum/rbln/__version__.py +2 -2
  2. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +0 -4
  3. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +0 -2
  4. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +0 -4
  5. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +0 -4
  6. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +0 -4
  7. optimum/rbln/diffusers/modeling_diffusers.py +16 -18
  8. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +6 -1
  9. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +13 -3
  10. optimum/rbln/diffusers/models/autoencoders/vq_model.py +6 -1
  11. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +70 -14
  12. optimum/rbln/modeling.py +38 -2
  13. optimum/rbln/modeling_base.py +18 -2
  14. optimum/rbln/transformers/modeling_generic.py +3 -3
  15. optimum/rbln/transformers/models/bart/configuration_bart.py +12 -2
  16. optimum/rbln/transformers/models/bart/modeling_bart.py +16 -7
  17. optimum/rbln/transformers/models/bert/configuration_bert.py +18 -3
  18. optimum/rbln/transformers/models/bert/modeling_bert.py +24 -0
  19. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +13 -1
  20. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +15 -0
  21. optimum/rbln/transformers/models/clip/configuration_clip.py +12 -2
  22. optimum/rbln/transformers/models/clip/modeling_clip.py +27 -1
  23. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +22 -20
  24. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +6 -1
  25. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +8 -0
  26. optimum/rbln/transformers/models/dpt/configuration_dpt.py +6 -1
  27. optimum/rbln/transformers/models/dpt/modeling_dpt.py +6 -1
  28. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +5 -3
  29. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +8 -0
  30. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +16 -0
  31. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +8 -0
  32. optimum/rbln/transformers/models/resnet/configuration_resnet.py +6 -1
  33. optimum/rbln/transformers/models/resnet/modeling_resnet.py +5 -1
  34. optimum/rbln/transformers/models/roberta/configuration_roberta.py +12 -2
  35. optimum/rbln/transformers/models/roberta/modeling_roberta.py +16 -0
  36. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +6 -2
  37. optimum/rbln/transformers/models/siglip/configuration_siglip.py +7 -0
  38. optimum/rbln/transformers/models/siglip/modeling_siglip.py +7 -0
  39. optimum/rbln/transformers/models/t5/configuration_t5.py +12 -2
  40. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +7 -0
  41. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +6 -2
  42. optimum/rbln/transformers/models/vit/configuration_vit.py +6 -1
  43. optimum/rbln/transformers/models/vit/modeling_vit.py +7 -1
  44. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +7 -0
  45. optimum/rbln/transformers/models/whisper/configuration_whisper.py +7 -0
  46. optimum/rbln/transformers/models/whisper/modeling_whisper.py +6 -2
  47. optimum/rbln/utils/runtime_utils.py +46 -1
  48. {optimum_rbln-0.8.1a6.dist-info → optimum_rbln-0.8.1a7.dist-info}/METADATA +1 -1
  49. {optimum_rbln-0.8.1a6.dist-info → optimum_rbln-0.8.1a7.dist-info}/RECORD +51 -51
  50. {optimum_rbln-0.8.1a6.dist-info → optimum_rbln-0.8.1a7.dist-info}/WHEEL +0 -0
  51. {optimum_rbln-0.8.1a6.dist-info → optimum_rbln-0.8.1a7.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__version__.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE
 
- __version__ = version = '0.8.1a6'
- __version_tuple__ = version_tuple = (0, 8, 1, 'a6')
+ __version__ = version = '0.8.1a7'
+ __version_tuple__ = version_tuple = (0, 8, 1, 'a7')
optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py CHANGED
@@ -20,10 +20,6 @@ from ..models import RBLNAutoencoderKLConfig, RBLNControlNetModelConfig, RBLNUNe
 
 
  class RBLNStableDiffusionControlNetPipelineBaseConfig(RBLNModelConfig):
- """
- Base configuration for Stable Diffusion ControlNet pipelines.
- """
-
  submodules = ["text_encoder", "unet", "vae", "controlnet"]
  _vae_uses_encoder = False
 
optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py CHANGED
@@ -21,8 +21,6 @@ from ..models.configuration_prior_transformer import RBLNPriorTransformerConfig
 
 
  class RBLNKandinskyV22PipelineBaseConfig(RBLNModelConfig):
- """Base configuration class for Kandinsky V2.2 decoder pipelines."""
-
  submodules = ["unet", "movq"]
  _movq_uses_encoder = False
 
optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py CHANGED
@@ -20,10 +20,6 @@ from ..models import RBLNAutoencoderKLConfig, RBLNUNet2DConditionModelConfig
 
 
  class RBLNStableDiffusionPipelineBaseConfig(RBLNModelConfig):
- """
- Base configuration for Stable Diffusion pipelines.
- """
-
  submodules = ["text_encoder", "unet", "vae"]
  _vae_uses_encoder = False
 
optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py CHANGED
@@ -20,10 +20,6 @@ from ..models import RBLNAutoencoderKLConfig, RBLNSD3Transformer2DModelConfig
 
 
  class RBLNStableDiffusion3PipelineBaseConfig(RBLNModelConfig):
- """
- Base configuration for Stable Diffusion 3 pipelines.
- """
-
  submodules = ["transformer", "text_encoder", "text_encoder_2", "text_encoder_3", "vae"]
  _vae_uses_encoder = False
 
optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py CHANGED
@@ -20,10 +20,6 @@ from ..models import RBLNAutoencoderKLConfig, RBLNUNet2DConditionModelConfig
 
 
  class RBLNStableDiffusionXLPipelineBaseConfig(RBLNModelConfig):
- """
- Base configuration for Stable Diffusion XL pipelines.
- """
-
  submodules = ["text_encoder", "text_encoder_2", "unet", "vae"]
  _vae_uses_encoder = False
 
optimum/rbln/diffusers/modeling_diffusers.py CHANGED
@@ -412,24 +412,22 @@ class RBLNDiffusionMixin:
  return compiled_image_size
 
  def handle_additional_kwargs(self, **kwargs):
- """
- Function to handle additional compile-time parameters during inference.
-
- If the additional variable is determined by another module, this method should be overrided.
-
- Example:
- ```python
- if hasattr(self, "movq"):
- compiled_image_size = self.movq.image_size
- kwargs["height"] = compiled_image_size[0]
- kwargs["width"] = compiled_image_size[1]
-
- compiled_num_frames = self.unet.rbln_config.num_frames
- if compiled_num_frames is not None:
- kwargs["num_frames"] = compiled_num_frames
- return kwargs
- ```
- """
+ # Function to handle additional compile-time parameters during inference.
+
+ # If the additional variable is determined by another module, this method should be overrided.
+
+ # Example:
+ # ```python
+ # if hasattr(self, "movq"):
+ # compiled_image_size = self.movq.image_size
+ # kwargs["height"] = compiled_image_size[0]
+ # kwargs["width"] = compiled_image_size[1]
+
+ # compiled_num_frames = self.unet.rbln_config.num_frames
+ # if compiled_num_frames is not None:
+ # kwargs["num_frames"] = compiled_num_frames
+ # return kwargs
+ # ```
  return kwargs
 
  @remove_compile_time_kwargs
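The example that used to live in this docstring still describes the intended override pattern: a pipeline whose image size or frame count is fixed at compile time pins those kwargs before inference. A rough sketch follows; the subclass name is hypothetical, and the `movq`/`unet` attribute accesses are taken from the comment above.

```python
from optimum.rbln.diffusers.modeling_diffusers import RBLNDiffusionMixin


class MyRBLNVideoPipeline(RBLNDiffusionMixin):  # hypothetical subclass, for illustration only
    def handle_additional_kwargs(self, **kwargs):
        # Pin height/width to the spatial size the MoVQ decoder was compiled with.
        if hasattr(self, "movq"):
            compiled_image_size = self.movq.image_size
            kwargs["height"] = compiled_image_size[0]
            kwargs["width"] = compiled_image_size[1]

        # Pin num_frames to the compiled UNet configuration, if one was recorded.
        compiled_num_frames = self.unet.rbln_config.num_frames
        if compiled_num_frames is not None:
            kwargs["num_frames"] = compiled_num_frames
        return kwargs
```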
optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py CHANGED
@@ -80,7 +80,12 @@ class RBLNAutoencoderKL(RBLNModel):
 
  wrapped_model.eval()
 
- compiled_models[model_name] = cls.compile(wrapped_model, rbln_compile_config=rbln_config.compile_cfgs[i])
+ compiled_models[model_name] = cls.compile(
+ wrapped_model,
+ rbln_compile_config=rbln_config.compile_cfgs[i],
+ create_runtimes=rbln_config.create_runtimes,
+ device=rbln_config.device_map[model_name],
+ )
 
  return compiled_models
 
optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py CHANGED
@@ -99,11 +99,21 @@ class RBLNAutoencoderKLCosmos(RBLNModel):
  compiled_models = {}
  if rbln_config.uses_encoder:
  encoder_model, decoder_model = cls.wrap_model_if_needed(model, rbln_config)
- enc_compiled_model = cls.compile(encoder_model, rbln_compile_config=rbln_config.compile_cfgs[0])
+ enc_compiled_model = cls.compile(
+ encoder_model,
+ rbln_compile_config=rbln_config.compile_cfgs[0],
+ create_runtimes=rbln_config.create_runtimes,
+ device=rbln_config.device_map["encoder"],
+ )
  compiled_models["encoder"] = enc_compiled_model
  else:
  decoder_model = cls.wrap_model_if_needed(model, rbln_config)
- dec_compiled_model = cls.compile(decoder_model, rbln_compile_config=rbln_config.compile_cfgs[-1])
+ dec_compiled_model = cls.compile(
+ decoder_model,
+ rbln_compile_config=rbln_config.compile_cfgs[-1],
+ create_runtimes=rbln_config.create_runtimes,
+ device=rbln_config.device_map["decoder"],
+ )
  compiled_models["decoder"] = dec_compiled_model
 
  finally:
@@ -115,7 +125,7 @@ class RBLNAutoencoderKLCosmos(RBLNModel):
  def update_rbln_config_using_pipe(
  cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
  ) -> "RBLNDiffusionMixinConfig":
- rbln_config.vae.num_channels_latents = pipe.transformer.config.in_channels
+ rbln_config.vae.num_channels_latents = pipe.transformer.config.out_channels
  rbln_config.vae.vae_scale_factor_temporal = pipe.vae_scale_factor_temporal
  rbln_config.vae.vae_scale_factor_spatial = pipe.vae_scale_factor_spatial
  return rbln_config
optimum/rbln/diffusers/models/autoencoders/vq_model.py CHANGED
@@ -78,7 +78,12 @@ class RBLNVQModel(RBLNModel):
 
  wrapped_model.eval()
 
- compiled_models[model_name] = cls.compile(wrapped_model, rbln_compile_config=rbln_config.compile_cfgs[i])
+ compiled_models[model_name] = cls.compile(
+ wrapped_model,
+ rbln_compile_config=rbln_config.compile_cfgs[i],
+ create_runtimes=rbln_config.create_runtimes,
+ device=rbln_config.device_map[model_name],
+ )
 
  return compiled_models
 
optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py CHANGED
@@ -25,7 +25,7 @@ from huggingface_hub import snapshot_download
  from transformers import AutoTokenizer, SiglipProcessor
 
  from .... import RBLNAutoModelForCausalLM, RBLNSiglipVisionModel
- from ....utils.runtime_utils import RBLNPytorchRuntime
+ from ....utils.runtime_utils import RBLNPytorchRuntime, UnavailableRuntime
  from .configuration_cosmos_guardrail import RBLNCosmosSafetyCheckerConfig
 
 
@@ -129,6 +129,8 @@ class RBLNSigLIPEncoder(SigLIPEncoder):
  self.model = RBLNSiglipVisionModel.from_pretrained(
  self.checkpoint_dir,
  rbln_device=rbln_config.siglip_encoder.device,
+ rbln_create_runtimes=rbln_config.siglip_encoder.create_runtimes,
+ rbln_activate_profiler=rbln_config.aegis.activate_profiler,
  )
  else:
  super().__init__(model_name, checkpoint_id)
@@ -139,6 +141,8 @@ class RBLNSigLIPEncoder(SigLIPEncoder):
  rbln_device=rbln_config.siglip_encoder.device,
  rbln_image_size=rbln_config.siglip_encoder.image_size,
  rbln_npu=rbln_config.siglip_encoder.npu,
+ rbln_create_runtimes=rbln_config.siglip_encoder.create_runtimes,
+ rbln_activate_profiler=rbln_config.siglip_encoder.activate_profiler,
  )
  self.rbln_config = rbln_config
 
@@ -191,7 +195,29 @@ class RBLNRetinaFaceFilter(RetinaFaceFilter):
  )
 
  self.rbln_config = rbln_config
- runtime = rebel.Runtime(self.compiled_model, tensor_type="pt", device=self.rbln_config.face_blur_filter.device)
+
+ try:
+ runtime = (
+ rebel.Runtime(
+ self.compiled_model,
+ tensor_type="pt",
+ device=self.rbln_config.face_blur_filter.device,
+ activate_profiler=rbln_config.face_blur_filter.activate_profiler,
+ )
+ if self.rbln_config.face_blur_filter.create_runtimes
+ else UnavailableRuntime()
+ )
+ except rebel.core.exception.RBLNRuntimeError as e:
+ error_msg = (
+ f"\nFailed to create RBLN runtime: {str(e)}\n\n"
+ f"If you only need to compile the model without loading it to NPU, you can use:\n"
+ f" from_pretrained(..., rbln_create_runtimes=False) or\n"
+ f" from_pretrained(..., rbln_config={{..., 'create_runtimes': False}})\n\n"
+ f"To check your NPU status, run the 'rbln-stat' command in your terminal.\n"
+ f"Make sure your NPU is properly installed and operational."
+ )
+ raise rebel.core.exception.RBLNRuntimeError(error_msg) from e
+
  self.net = RBLNPytorchRuntime(runtime)
 
  def save_pretrained(self, checkpoint_id: str):
@@ -245,11 +271,28 @@ class RBLNVideoSafetyModel(VideoSafetyModel):
  npu=self.rbln_config.video_safety_model.npu,
  )
 
- runtime = rebel.Runtime(
- self.compiled_model,
- tensor_type="pt",
- device=self.rbln_config.video_safety_model.device,
- )
+ try:
+ runtime = (
+ rebel.Runtime(
+ self.compiled_model,
+ tensor_type="pt",
+ device=self.rbln_config.video_safety_model.device,
+ activate_profiler=rbln_config.video_safety_model.activate_profiler,
+ )
+ if self.rbln_config.video_safety_model.create_runtimes
+ else UnavailableRuntime()
+ )
+ except rebel.core.exception.RBLNRuntimeError as e:
+ error_msg = (
+ f"\nFailed to create RBLN runtime: {str(e)}\n\n"
+ f"If you only need to compile the model without loading it to NPU, you can use:\n"
+ f" from_pretrained(..., rbln_create_runtimes=False) or\n"
+ f" from_pretrained(..., rbln_config={{..., 'create_runtimes': False}})\n\n"
+ f"To check your NPU status, run the 'rbln-stat' command in your terminal.\n"
+ f"Make sure your NPU is properly installed and operational."
+ )
+ raise rebel.core.exception.RBLNRuntimeError(error_msg) from e
+
  self.network = RBLNPytorchRuntime(runtime)
 
  def save_pretrained(self, checkpoint_id: str):
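Both runtime-creation sites above now point at the same compile-only escape hatch. As the new error message suggests, either keyword form below should skip runtime creation; the checkpoint id is a placeholder, and any RBLN model class behaves the same way here.

```python
from optimum.rbln import RBLNSiglipVisionModel

# Compile/load without allocating an NPU runtime (placeholder checkpoint id).
model = RBLNSiglipVisionModel.from_pretrained(
    "my-org/siglip-rbln",
    rbln_create_runtimes=False,
)

# Equivalent nested-config form taken from the error message.
model = RBLNSiglipVisionModel.from_pretrained(
    "my-org/siglip-rbln",
    rbln_config={"create_runtimes": False},
)
```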
@@ -291,7 +334,12 @@ class RBLNAegis(Aegis):
  torch.nn.Module.__init__(self)
  cache_dir = pathlib.Path(checkpoint_id) / "aegis"
  self.tokenizer = AutoTokenizer.from_pretrained(cache_dir)
- self.model = RBLNAutoModelForCausalLM.from_pretrained(cache_dir, rbln_device=rbln_config.aegis.device)
+ self.model = RBLNAutoModelForCausalLM.from_pretrained(
+ cache_dir,
+ rbln_device=rbln_config.aegis.device,
+ rbln_create_runtimes=rbln_config.aegis.create_runtimes,
+ rbln_activate_profiler=rbln_config.aegis.activate_profiler,
+ )
 
  else:
  super().__init__(checkpoint_id, base_model_id, aegis_adapter)
@@ -302,7 +350,9 @@ class RBLNAegis(Aegis):
  model,
  rbln_tensor_parallel_size=4,
  rbln_device=rbln_config.aegis.device,
+ rbln_create_runtimes=rbln_config.aegis.create_runtimes,
  rbln_npu=rbln_config.aegis.npu,
+ rbln_activate_profiler=rbln_config.aegis.activate_profiler,
  )
 
  self.rbln_config = rbln_config
@@ -335,19 +385,25 @@ class RBLNCosmosSafetyChecker(CosmosSafetyChecker):
 
  if rbln_config is None:
  rbln_config = RBLNCosmosSafetyCheckerConfig()
+ elif isinstance(rbln_config, dict):
+ rbln_config = RBLNCosmosSafetyCheckerConfig(**rbln_config)
 
  self.text_guardrail = GuardrailRunner(
  safety_models=[
  Blocklist(COSMOS_GUARDRAIL_CHECKPOINT), # Changed since it cannot be saved
- RBLNAegis(checkpoint_id, aegis_model_id, aegis_adapter_id, rbln_config=rbln_config),
+ RBLNAegis(
+ checkpoint_id=checkpoint_id,
+ base_model_id=aegis_model_id,
+ aegis_adapter=aegis_adapter_id,
+ rbln_config=rbln_config,
+ ),
  ]
  )
 
- with patch("torch.load", partial(torch.load, weights_only=True, map_location=torch.device("cpu"))):
- self.video_guardrail = GuardrailRunner(
- safety_models=[RBLNVideoContentSafetyFilter(checkpoint_id, rbln_config=rbln_config)],
- postprocessors=[RBLNRetinaFaceFilter(checkpoint_id, rbln_config=rbln_config)],
- )
+ self.video_guardrail = GuardrailRunner(
+ safety_models=[RBLNVideoContentSafetyFilter(checkpoint_id=checkpoint_id, rbln_config=rbln_config)],
+ postprocessors=[RBLNRetinaFaceFilter(checkpoint_id=checkpoint_id, rbln_config=rbln_config)],
+ )
 
  self.rbln_config = rbln_config
 
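Since `rbln_config` may now be a plain dict (promoted via `RBLNCosmosSafetyCheckerConfig(**rbln_config)`), per-submodule settings can presumably be passed as nested dicts keyed by the attribute names used in this file (`aegis`, `siglip_encoder`, `face_blur_filter`, `video_safety_model`). The constructor arguments and accepted keys in this sketch are assumptions, not confirmed API:

```python
from optimum.rbln.diffusers.pipelines.cosmos.cosmos_guardrail import RBLNCosmosSafetyChecker

# Sketch only: keys mirror the attributes referenced above; values are illustrative.
safety_checker = RBLNCosmosSafetyChecker(
    checkpoint_id="path/to/compiled/guardrail",  # placeholder path
    rbln_config={
        "aegis": {"device": 0, "create_runtimes": True},
        "siglip_encoder": {"device": 1, "create_runtimes": True},
        "face_blur_filter": {"device": 1, "create_runtimes": True},
        "video_safety_model": {"device": 1, "create_runtimes": True},
    },
)
```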
optimum/rbln/modeling.py CHANGED
@@ -64,7 +64,12 @@ class RBLNModel(RBLNBaseModel):
  def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
  model = cls.wrap_model_if_needed(model, rbln_config)
  rbln_compile_config = rbln_config.compile_cfgs[0]
- compiled_model = cls.compile(model, rbln_compile_config=rbln_compile_config)
+ compiled_model = cls.compile(
+ model,
+ rbln_compile_config=rbln_compile_config,
+ create_runtimes=rbln_config.create_runtimes,
+ device=rbln_config.device,
+ )
  return compiled_model
 
  @classmethod
@@ -237,7 +242,38 @@ class RBLNModel(RBLNBaseModel):
  for compiled_model in compiled_models
  ]
 
- def forward(self, *args, return_dict: Optional[bool] = None, **kwargs):
+ def forward(self, *args: Any, return_dict: Optional[bool] = None, **kwargs: Dict[str, Any]) -> Any:
+ """
+ Defines the forward pass of the RBLN model, providing a drop-in replacement for HuggingFace PreTrainedModel.
+
+ This method executes the compiled RBLN model on RBLN NPU devices while maintaining full compatibility
+ with HuggingFace transformers and diffusers APIs. The RBLNModel can be used as a direct substitute
+ for any HuggingFace nn.Module/PreTrainedModel, enabling seamless integration into existing workflows.
+
+ Args:
+ *args: Variable length argument list containing model inputs. The format matches the original
+ HuggingFace model's forward method signature (e.g., input_ids, attention_mask for
+ transformers models, or sample, timestep for diffusers models).
+ return_dict:
+ Whether to return outputs as a dictionary-like object or as a tuple. When `None`:
+ - For transformers models: Uses `self.config.use_return_dict` (typically `True`)
+ - For diffusers models: Defaults to `True`
+ **kwargs: Arbitrary keyword arguments containing additional model inputs and parameters,
+ matching the original HuggingFace model's interface.
+
+ Returns:
+ Model outputs in the same format as the original HuggingFace model.
+
+ - If `return_dict=True`: Returns a dictionary-like object (e.g., BaseModelOutput,
+ CausalLMOutput) with named fields such as `logits`, `hidden_states`, etc.
+ - If `return_dict=False`: Returns a tuple containing the raw model outputs.
+
+ Note:
+ - This method maintains the exact same interface as the original HuggingFace model's forward method
+ - The compiled model runs on RBLN NPU hardware for accelerated inference
+ - All HuggingFace model features (generation, attention patterns, etc.) are preserved
+ - Can be used directly in HuggingFace pipelines, transformers.Trainer, and other workflows
+ """
  if self.hf_library_name == "transformers":
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  else:
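A small sketch of the drop-in behaviour the new docstring describes, using `RBLNBertModel` from later in this diff; the checkpoint id is a placeholder and is assumed to already be compiled for RBLN:

```python
from transformers import AutoTokenizer
from optimum.rbln import RBLNBertModel  # drop-in replacement for transformers' BertModel

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = RBLNBertModel.from_pretrained("my-org/bert-base-rbln")  # placeholder, pre-compiled checkpoint

inputs = tokenizer("RBLN models keep the HuggingFace interface.", return_tensors="pt")

outputs = model(**inputs)                      # dict-like output (e.g. outputs.last_hidden_state)
as_tuple = model(**inputs, return_dict=False)  # plain tuple of tensors
```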
optimum/rbln/modeling_base.py CHANGED
@@ -27,7 +27,7 @@ from transformers import AutoConfig, AutoModel, GenerationConfig, PretrainedConf
  from .configuration_utils import RBLNAutoConfig, RBLNCompileConfig, RBLNModelConfig, get_rbln_config_class
  from .utils.hub import PushToHubMixin, pull_compiled_model_from_hub, validate_files
  from .utils.logging import get_logger
- from .utils.runtime_utils import UnavailableRuntime
+ from .utils.runtime_utils import UnavailableRuntime, tp_and_devices_are_ok
  from .utils.save_utils import maybe_load_preprocessors
  from .utils.submodule import SubModulesMixin
 
@@ -374,7 +374,23 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
  return from_pretrained_method(model_id=model_id, **kwargs, rbln_config=rbln_config)
 
  @classmethod
- def compile(cls, model, rbln_compile_config: Optional[RBLNCompileConfig] = None, **kwargs):
+ def compile(
+ cls,
+ model,
+ rbln_compile_config: RBLNCompileConfig,
+ create_runtimes: bool,
+ device: Union[int, List[int]],
+ **kwargs,
+ ):
+ if create_runtimes:
+ runtime_cannot_be_created = tp_and_devices_are_ok(
+ tensor_parallel_size=rbln_compile_config.tensor_parallel_size,
+ device=device,
+ npu=rbln_compile_config.npu,
+ )
+ if runtime_cannot_be_created:
+ raise ValueError(runtime_cannot_be_created)
+
  compiled_model = rebel.compile_from_torch(
  model,
  input_info=rbln_compile_config.input_info,
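From this call site, `tp_and_devices_are_ok` appears to return a falsy value when the requested tensor-parallel size, device(s), and NPU are usable, and an error-message string otherwise (hence the `runtime_cannot_be_created` name). A caller-side sketch under that assumption:

```python
from optimum.rbln.utils.runtime_utils import tp_and_devices_are_ok

# Sketch under the assumption stated above; the keyword names come from the diff.
error = tp_and_devices_are_ok(tensor_parallel_size=4, device=[0, 1, 2, 3], npu=None)
if error:
    # Mirrors the guard added to RBLNBaseModel.compile: fail before compiling.
    raise ValueError(error)
```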
optimum/rbln/transformers/modeling_generic.py CHANGED
@@ -139,7 +139,7 @@ class RBLNTransformerEncoder(RBLNModel):
  return rbln_config
 
 
- class _RBLNImageModel(RBLNModel):
+ class RBLNImageModel(RBLNModel):
  auto_model_class = AutoModel
  main_input_name = "pixel_values"
  output_class = BaseModelOutput
@@ -233,11 +233,11 @@ class RBLNTransformerEncoderForFeatureExtraction(RBLNTransformerEncoder):
  rbln_model_input_names = ["input_ids", "attention_mask"]
 
 
- class RBLNModelForImageClassification(_RBLNImageModel):
+ class RBLNModelForImageClassification(RBLNImageModel):
  auto_model_class = AutoModelForImageClassification
 
 
- class RBLNModelForDepthEstimation(_RBLNImageModel):
+ class RBLNModelForDepthEstimation(RBLNImageModel):
  auto_model_class = AutoModelForDepthEstimation
 
 
optimum/rbln/transformers/models/bart/configuration_bart.py CHANGED
@@ -17,8 +17,18 @@ from ..seq2seq import RBLNModelForSeq2SeqLMConfig
 
 
  class RBLNBartModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
- pass
+ """
+ Configuration class for RBLNBartModel.
+
+ This configuration class stores the configuration parameters specific to
+ RBLN-optimized BART models for feature extraction tasks.
+ """
 
 
  class RBLNBartForConditionalGenerationConfig(RBLNModelForSeq2SeqLMConfig):
- pass
+ """
+ Configuration class for RBLNBartForConditionalGeneration.
+
+ This configuration class stores the configuration parameters specific to
+ RBLN-optimized BART models for conditional text generation tasks.
+ """
optimum/rbln/transformers/models/bart/modeling_bart.py CHANGED
@@ -13,7 +13,7 @@
  # limitations under the License.
 
  import inspect
- from typing import TYPE_CHECKING, Any, Callable
+ from typing import Any, Callable
 
  from transformers import BartForConditionalGeneration, PreTrainedModel
 
@@ -27,19 +27,28 @@ from .configuration_bart import RBLNBartForConditionalGenerationConfig
  logger = get_logger()
 
 
- if TYPE_CHECKING:
- from transformers import PreTrainedModel
-
-
  class RBLNBartModel(RBLNTransformerEncoderForFeatureExtraction):
- pass
+ """
+ RBLN optimized BART model for feature extraction tasks.
+
+ This class provides hardware-accelerated inference for BART encoder models
+ on RBLN devices, optimized for feature extraction use cases.
+ """
 
 
  class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
+ """
+ RBLN optimized BART model for conditional text generation tasks.
+
+ This class provides hardware-accelerated inference for BART models
+ on RBLN devices, supporting sequence-to-sequence generation tasks
+ such as summarization, translation, and text generation.
+ """
+
  support_causal_attn = True
 
  @classmethod
- def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNBartForConditionalGenerationConfig):
+ def wrap_model_if_needed(self, model: PreTrainedModel, rbln_config: RBLNBartForConditionalGenerationConfig):
  return BartWrapper(
  model, enc_max_seq_len=rbln_config.enc_max_seq_len, use_attention_mask=rbln_config.use_attention_mask
  )
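The new class docstring mentions summarization-style generation; a minimal sketch with a placeholder, pre-compiled checkpoint (export/compile options are omitted here):

```python
from transformers import AutoTokenizer
from optimum.rbln import RBLNBartForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
model = RBLNBartForConditionalGeneration.from_pretrained("my-org/bart-large-cnn-rbln")  # placeholder id

article = "RBLN NPUs run compiled HuggingFace models through the same generate() interface."
inputs = tokenizer(article, return_tensors="pt")

# generate() works as with the original transformers model, per the docstring above.
summary_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
```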
optimum/rbln/transformers/models/bert/configuration_bert.py CHANGED
@@ -20,12 +20,27 @@ from ...configuration_generic import (
 
 
  class RBLNBertModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
- pass
+ """
+ Configuration class for RBLNBertModel.
+
+ This configuration class stores the configuration parameters specific to
+ RBLN-optimized BERT models for feature extraction tasks.
+ """
 
 
  class RBLNBertForMaskedLMConfig(RBLNModelForMaskedLMConfig):
- pass
+ """
+ Configuration class for RBLNBertForMaskedLM.
+
+ This configuration class stores the configuration parameters specific to
+ RBLN-optimized BERT models for masked language modeling tasks.
+ """
 
 
  class RBLNBertForQuestionAnsweringConfig(RBLNModelForQuestionAnsweringConfig):
- pass
+ """
+ Configuration class for RBLNBertForQuestionAnswering.
+
+ This configuration class stores the configuration parameters specific to
+ RBLN-optimized BERT models for question answering tasks.
+ """
optimum/rbln/transformers/models/bert/modeling_bert.py CHANGED
@@ -24,12 +24,36 @@ logger = get_logger(__name__)
 
 
  class RBLNBertModel(RBLNTransformerEncoderForFeatureExtraction):
+ """
+ RBLN optimized BERT model for feature extraction tasks.
+
+ This class provides hardware-accelerated inference for BERT models
+ on RBLN devices, optimized for extracting contextualized embeddings
+ and features from text sequences.
+ """
+
  rbln_model_input_names = ["input_ids", "attention_mask"]
 
 
  class RBLNBertForMaskedLM(RBLNModelForMaskedLM):
+ """
+ RBLN optimized BERT model for masked language modeling tasks.
+
+ This class provides hardware-accelerated inference for BERT models
+ on RBLN devices, supporting masked language modeling tasks such as
+ token prediction and text completion.
+ """
+
  rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
 
 
  class RBLNBertForQuestionAnswering(RBLNModelForQuestionAnswering):
+ """
+ RBLN optimized BERT model for question answering tasks.
+
+ This class provides hardware-accelerated inference for BERT models
+ on RBLN devices, supporting extractive question answering tasks where
+ the model predicts start and end positions of answers in text.
+ """
+
  rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
optimum/rbln/transformers/models/blip_2/configuration_blip_2.py CHANGED
@@ -18,10 +18,22 @@ from ....configuration_utils import RBLNModelConfig
 
 
  class RBLNBlip2VisionModelConfig(RBLNModelConfig):
- pass
+ """
+ Configuration class for RBLNBlip2VisionModel.
+
+ This configuration class stores the configuration parameters specific to
+ RBLN-optimized BLIP-2 vision encoder models for multimodal tasks.
+ """
 
 
  class RBLNBlip2QFormerModelConfig(RBLNModelConfig):
+ """
+ Configuration class for RBLNBlip2QFormerModel.
+
+ This configuration class stores the configuration parameters specific to
+ RBLN-optimized BLIP-2 Q-Former models that bridge vision and language modalities.
+ """
+
  def __init__(
  self,
  num_query_tokens: Optional[int] = None,
optimum/rbln/transformers/models/blip_2/modeling_blip_2.py CHANGED
@@ -65,6 +65,13 @@ class LoopProjector:
 
 
  class RBLNBlip2VisionModel(RBLNModel):
+ """
+ RBLN optimized BLIP-2 vision encoder model.
+
+ This class provides hardware-accelerated inference for BLIP-2 vision encoders
+ on RBLN devices, supporting image encoding for multimodal vision-language tasks.
+ """
+
  def get_input_embeddings(self):
  return self.embeddings
 
@@ -136,6 +143,14 @@ class RBLNBlip2VisionModel(RBLNModel):
 
 
  class RBLNBlip2QFormerModel(RBLNModel):
+ """
+ RBLN optimized BLIP-2 Q-Former model.
+
+ This class provides hardware-accelerated inference for BLIP-2 Q-Former models
+ on RBLN devices, which bridge vision and language modalities through cross-attention
+ mechanisms for multimodal understanding tasks.
+ """
+
  def get_input_embeddings(self):
  return self.embeddings.word_embeddings
optimum/rbln/transformers/models/clip/configuration_clip.py CHANGED
@@ -34,7 +34,12 @@ class RBLNCLIPTextModelConfig(RBLNModelConfig):
 
 
  class RBLNCLIPTextModelWithProjectionConfig(RBLNCLIPTextModelConfig):
- pass
+ """
+ Configuration class for RBLNCLIPTextModelWithProjection.
+
+ This configuration inherits from RBLNCLIPTextModelConfig and stores
+ configuration parameters for CLIP text models with projection layers.
+ """
 
 
  class RBLNCLIPVisionModelConfig(RBLNModelConfig):
@@ -76,4 +76,9 @@ class RBLNCLIPVisionModelConfig(RBLNModelConfig):
 
 
  class RBLNCLIPVisionModelWithProjectionConfig(RBLNCLIPVisionModelConfig):
- pass
+ """
+ Configuration class for RBLNCLIPVisionModelWithProjection.
+
+ This configuration inherits from RBLNCLIPVisionModelConfig and stores
+ configuration parameters for CLIP vision models with projection layers.
+ """