optimum-rbln 0.8.1a5__py3-none-any.whl → 0.8.1a7__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (70)
  1. optimum/rbln/__init__.py +18 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/diffusers/__init__.py +21 -1
  4. optimum/rbln/diffusers/configurations/__init__.py +4 -0
  5. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +82 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_cosmos_transformer.py +68 -0
  8. optimum/rbln/diffusers/configurations/pipelines/__init__.py +1 -0
  9. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +0 -4
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +110 -0
  11. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +0 -2
  12. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +0 -4
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +1 -4
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +0 -4
  15. optimum/rbln/diffusers/modeling_diffusers.py +57 -40
  16. optimum/rbln/diffusers/models/__init__.py +4 -0
  17. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  18. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +6 -1
  19. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +219 -0
  20. optimum/rbln/diffusers/models/autoencoders/vae.py +49 -5
  21. optimum/rbln/diffusers/models/autoencoders/vq_model.py +6 -1
  22. optimum/rbln/diffusers/models/transformers/__init__.py +1 -0
  23. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +321 -0
  24. optimum/rbln/diffusers/pipelines/__init__.py +10 -0
  25. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  26. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +102 -0
  27. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +451 -0
  28. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +98 -0
  29. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +98 -0
  30. optimum/rbln/modeling.py +38 -2
  31. optimum/rbln/modeling_base.py +18 -2
  32. optimum/rbln/transformers/modeling_generic.py +3 -3
  33. optimum/rbln/transformers/models/bart/configuration_bart.py +12 -2
  34. optimum/rbln/transformers/models/bart/modeling_bart.py +16 -7
  35. optimum/rbln/transformers/models/bert/configuration_bert.py +18 -3
  36. optimum/rbln/transformers/models/bert/modeling_bert.py +24 -0
  37. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +13 -1
  38. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +15 -0
  39. optimum/rbln/transformers/models/clip/configuration_clip.py +12 -2
  40. optimum/rbln/transformers/models/clip/modeling_clip.py +27 -1
  41. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +22 -20
  42. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +6 -1
  43. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +8 -0
  44. optimum/rbln/transformers/models/dpt/configuration_dpt.py +6 -1
  45. optimum/rbln/transformers/models/dpt/modeling_dpt.py +6 -1
  46. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +5 -3
  47. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +8 -0
  48. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +16 -0
  49. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +8 -0
  50. optimum/rbln/transformers/models/resnet/configuration_resnet.py +6 -1
  51. optimum/rbln/transformers/models/resnet/modeling_resnet.py +5 -1
  52. optimum/rbln/transformers/models/roberta/configuration_roberta.py +12 -2
  53. optimum/rbln/transformers/models/roberta/modeling_roberta.py +16 -0
  54. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +6 -2
  55. optimum/rbln/transformers/models/siglip/configuration_siglip.py +7 -0
  56. optimum/rbln/transformers/models/siglip/modeling_siglip.py +7 -0
  57. optimum/rbln/transformers/models/t5/configuration_t5.py +12 -2
  58. optimum/rbln/transformers/models/t5/modeling_t5.py +10 -4
  59. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +7 -0
  60. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +6 -2
  61. optimum/rbln/transformers/models/vit/configuration_vit.py +6 -1
  62. optimum/rbln/transformers/models/vit/modeling_vit.py +7 -1
  63. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +7 -0
  64. optimum/rbln/transformers/models/whisper/configuration_whisper.py +7 -0
  65. optimum/rbln/transformers/models/whisper/modeling_whisper.py +6 -2
  66. optimum/rbln/utils/runtime_utils.py +49 -1
  67. {optimum_rbln-0.8.1a5.dist-info → optimum_rbln-0.8.1a7.dist-info}/METADATA +1 -1
  68. {optimum_rbln-0.8.1a5.dist-info → optimum_rbln-0.8.1a7.dist-info}/RECORD +70 -60
  69. {optimum_rbln-0.8.1a5.dist-info → optimum_rbln-0.8.1a7.dist-info}/WHEEL +0 -0
  70. {optimum_rbln-0.8.1a5.dist-info → optimum_rbln-0.8.1a7.dist-info}/licenses/LICENSE +0 -0
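
The largest addition in this release is Cosmos world-model support (the new autoencoder, transformer, guardrail, and text2world/video2world pipeline modules listed above). Loading would presumably follow the existing RBLN diffusers pattern; the pipeline class name, checkpoint, and the export/save flow below are illustrative assumptions, not taken from this diff:

from optimum.rbln import RBLNCosmosTextToWorldPipeline  # assumed export name, mirroring other RBLN pipeline classes

# Compile once on an RBLN host (export=True), then reuse the saved compiled artifacts for inference.
pipe = RBLNCosmosTextToWorldPipeline.from_pretrained(
    "nvidia/Cosmos-1.0-Diffusion-7B-Text2World",  # illustrative checkpoint
    export=True,
)
pipe.save_pretrained("cosmos-text2world-rbln")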
optimum/rbln/transformers/models/gemma3/modeling_gemma3.py
@@ -326,7 +326,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
  attention_mask: torch.Tensor,
  position_ids: torch.Tensor,
  token_type_ids: Optional[torch.Tensor] = None,
- ):
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, torch.Tensor]:
  """
  Pads inputs, attention_mask, and position_ids so image token groups (256 tokens with token_type_ids == 1)
  start at multiples of prefill_chunk_size (256). Returns padded tensors and total padded length.
@@ -338,7 +338,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
  token_type_ids: (1, seq_len) tensor, 0 for text, 1 for image.

  Returns:
- Tuple: (inputs_padded, attention_mask_padded, position_ids_padded, padded_len, token_type_ids_padded).
+ (inputs_padded, attention_mask_padded, position_ids_padded, padded_len, token_type_ids_padded).
  """

  if token_type_ids is None:
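
For context, a minimal standalone sketch of the chunk-alignment arithmetic the docstring above describes; the function name and shapes are illustrative assumptions, not the library's implementation:

import torch

def pad_group_to_chunk_boundary(inputs: torch.Tensor, group_start: int, chunk: int = 256) -> torch.Tensor:
    # Number of pad tokens needed so the image group beginning at `group_start` lands on a multiple of `chunk`.
    pad_len = (-group_start) % chunk
    if pad_len == 0:
        return inputs
    pad = torch.zeros(inputs.shape[0], pad_len, dtype=inputs.dtype)
    # Left-pad ahead of the group; attention_mask and position_ids would be padded the same way.
    return torch.cat([inputs[:, :group_start], pad, inputs[:, group_start:]], dim=1)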
@@ -816,9 +816,11 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
  quantization.maybe_set_quantization_env()
  original_linear = torch.nn.functional.linear
  torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
- compiled_model = RBLNModel.compile(
+ compiled_model = cls.compile(
  wrapped_model,
  compile_config,
+ create_runtimes=rbln_config.create_runtimes,
+ device=rbln_config.device,
  example_inputs=example_inputs,
  compile_context=compile_context,
  )
optimum/rbln/transformers/models/llava_next/configuration_llava_next.py
@@ -18,6 +18,14 @@ from ....configuration_utils import RBLNModelConfig


  class RBLNLlavaNextForConditionalGenerationConfig(RBLNModelConfig):
+ """
+ Configuration class for RBLNLlavaNextForConditionalGeneration.
+
+ This configuration class stores the configuration parameters specific to
+ RBLN-optimized LLaVA-Next models for multimodal conditional generation tasks
+ that combine vision and language processing capabilities.
+ """
+
  submodules = ["vision_tower", "language_model"]

  def __init__(
optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py
@@ -19,6 +19,14 @@ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausa


  class RBLNQwen2_5_VLForConditionalGenerationConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+ """
+ Configuration class for RBLNQwen2_5_VLForConditionalGeneration.
+
+ This configuration class stores the configuration parameters specific to
+ RBLN-optimized Qwen2.5-VL models for multimodal conditional generation tasks
+ that combine vision and language processing capabilities.
+ """
+
  submodules = ["visual"]

  def __init__(
@@ -37,6 +45,14 @@ class RBLNQwen2_5_VLForConditionalGenerationConfig(RBLNDecoderOnlyModelForCausal


  class RBLNQwen2_5_VisionTransformerPretrainedModelConfig(RBLNModelConfig):
+ """
+ Configuration class for RBLNQwen2_5_VisionTransformerPretrainedModel.
+
+ This configuration class stores the configuration parameters specific to
+ RBLN-optimized Qwen2.5-VL vision transformer models with window-based attention
+ mechanisms for processing images and videos.
+ """
+
  def __init__(self, max_seq_lens: Union[int, List[int]] = None, **kwargs: Dict[str, Any]):
  """
  Args:
optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
@@ -54,6 +54,14 @@ if TYPE_CHECKING:


  class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
+ """
+ RBLN optimized Qwen2.5-VL vision transformer model.
+
+ This class provides hardware-accelerated inference for Qwen2.5-VL vision transformers
+ on RBLN devices, supporting image and video encoding for multimodal vision-language tasks
+ with window-based attention mechanisms.
+ """
+
  auto_model_class = None

  def __post_init__(self, **kwargs):
optimum/rbln/transformers/models/resnet/configuration_resnet.py
@@ -17,4 +17,9 @@ from ...configuration_generic import RBLNModelForImageClassificationConfig


  class RBLNResNetForImageClassificationConfig(RBLNModelForImageClassificationConfig):
- ""
+ """
+ Configuration class for RBLNResNetForImageClassification.
+
+ This configuration class stores the configuration parameters specific to
+ RBLN-optimized ResNet models for image classification tasks.
+ """
optimum/rbln/transformers/models/resnet/modeling_resnet.py
@@ -18,5 +18,9 @@ from ...modeling_generic import RBLNModelForImageClassification

  class RBLNResNetForImageClassification(RBLNModelForImageClassification):
  """
- ResNet model for image classification tasks on RBLN NPU.
+ RBLN optimized ResNet model for image classification tasks.
+
+ This class provides hardware-accelerated inference for ResNet models
+ on RBLN devices, supporting image classification with convolutional neural networks
+ designed for computer vision tasks.
  """
optimum/rbln/transformers/models/roberta/configuration_roberta.py
@@ -16,8 +16,18 @@ from ...configuration_generic import RBLNModelForMaskedLMConfig, RBLNModelForSeq


  class RBLNRobertaForMaskedLMConfig(RBLNModelForMaskedLMConfig):
- ""
+ """
+ Configuration class for RBLNRobertaForMaskedLM.
+
+ This configuration class stores the configuration parameters specific to
+ RBLN-optimized RoBERTa models for masked language modeling tasks.
+ """


  class RBLNRobertaForSequenceClassificationConfig(RBLNModelForSequenceClassificationConfig):
- ""
+ """
+ Configuration class for RBLNRobertaForSequenceClassification.
+
+ This configuration class stores the configuration parameters specific to
+ RBLN-optimized RoBERTa models for sequence classification tasks.
+ """
optimum/rbln/transformers/models/roberta/modeling_roberta.py
@@ -16,8 +16,24 @@ from ...modeling_generic import RBLNModelForMaskedLM, RBLNModelForSequenceClassi


  class RBLNRobertaForMaskedLM(RBLNModelForMaskedLM):
+ """
+ RBLN optimized RoBERTa model for masked language modeling tasks.
+
+ This class provides hardware-accelerated inference for RoBERTa models
+ on RBLN devices, supporting masked language modeling tasks such as
+ token prediction and text completion.
+ """
+
  rbln_model_input_names = ["input_ids", "attention_mask"]


  class RBLNRobertaForSequenceClassification(RBLNModelForSequenceClassification):
+ """
+ RBLN optimized RoBERTa model for sequence classification tasks.
+
+ This class provides hardware-accelerated inference for RoBERTa models
+ on RBLN devices, supporting text classification tasks such as sentiment analysis,
+ topic classification, and other sequence-level prediction tasks.
+ """
+
  rbln_model_input_names = ["input_ids", "attention_mask"]
optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py
@@ -161,16 +161,20 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
  if "key_value_states" in name:
  context.mark_static_address(tensor)

- compiled_encoder = super().compile(
+ compiled_encoder = cls.compile(
  wrapped_model.encoder,
  enc_compile_config,
+ create_runtimes=rbln_config.create_runtimes,
+ device=rbln_config.device,
  example_inputs=enc_example_inputs,
  compile_context=context,
  )

- compiled_decoder = super().compile(
+ compiled_decoder = cls.compile(
  wrapped_model.decoder,
  dec_compile_config,
+ create_runtimes=rbln_config.create_runtimes,
+ device=rbln_config.device,
  example_inputs=dec_example_inputs,
  compile_context=context,
  )
optimum/rbln/transformers/models/siglip/configuration_siglip.py
@@ -18,6 +18,13 @@ from ....configuration_utils import RBLNModelConfig


  class RBLNSiglipVisionModelConfig(RBLNModelConfig):
+ """
+ Configuration class for RBLNSiglipVisionModel.
+
+ This configuration class stores the configuration parameters specific to
+ RBLN-optimized SigLIP vision models for image encoding in multimodal tasks.
+ """
+
  def __init__(
  self,
  batch_size: Optional[int] = None,
optimum/rbln/transformers/models/siglip/modeling_siglip.py
@@ -58,6 +58,13 @@ class _SiglipVisionModel(torch.nn.Module):


  class RBLNSiglipVisionModel(RBLNModel):
+ """
+ RBLN optimized SigLIP vision model.
+
+ This class provides hardware-accelerated inference for SigLIP vision models
+ on RBLN devices, supporting image encoding for multimodal vision-language tasks.
+ """
+
  @classmethod
  def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNSiglipVisionModelConfig) -> torch.nn.Module:
  wrapper_cfg = {
optimum/rbln/transformers/models/t5/configuration_t5.py
@@ -17,8 +17,18 @@ from ..seq2seq import RBLNModelForSeq2SeqLMConfig


  class RBLNT5EncoderModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
- pass
+ """
+ Configuration class for RBLNT5EncoderModel.
+
+ This configuration class stores the configuration parameters specific to
+ RBLN-optimized T5 encoder models for feature extraction tasks.
+ """


  class RBLNT5ForConditionalGenerationConfig(RBLNModelForSeq2SeqLMConfig):
- pass
+ """
+ Configuration class for RBLNT5ForConditionalGeneration.
+
+ This configuration class stores the configuration parameters specific to
+ RBLN-optimized T5 models for conditional text generation tasks.
+ """
optimum/rbln/transformers/models/t5/modeling_t5.py
@@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Any, Callable

  import torch
  from transformers import AutoModelForTextEncoding, T5EncoderModel, T5ForConditionalGeneration
+ from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions

  from ...modeling_generic import RBLNTransformerEncoderForFeatureExtraction
  from ...models.seq2seq import RBLNModelForSeq2SeqLM
@@ -64,7 +65,7 @@ class RBLNT5EncoderModel(RBLNTransformerEncoderForFeatureExtraction):
  """

  auto_model_class = AutoModelForTextEncoding
- rbln_model_input_names = ["input_ids", "attention_mask"]
+ output_class = BaseModelOutputWithPastAndCrossAttentions

  @classmethod
  def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNT5EncoderModelConfig):
@@ -74,11 +75,16 @@ class RBLNT5EncoderModel(RBLNTransformerEncoderForFeatureExtraction):
  def update_rbln_config_using_pipe(
  cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
  ) -> "RBLNDiffusionMixinConfig":
- submodule_config = getattr(rbln_config, submodule_name)
- submodule_config.max_seq_len = rbln_config.max_seq_len or 256
- submodule_config.model_input_names = ["input_ids"]
  return rbln_config

+ def forward(self, input_ids=None, attention_mask=None, **kwargs):
+ input_dict = {"input_ids": input_ids.long()}
+ if attention_mask is not None:
+ input_dict["attention_mask"] = attention_mask.long()
+
+ output = super().forward(**input_dict, **kwargs)
+ return output
+

  class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
  """
optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py
@@ -4,6 +4,13 @@ from ....configuration_utils import RBLNModelConfig


  class RBLNTimeSeriesTransformerForPredictionConfig(RBLNModelConfig):
+ """
+ Configuration class for RBLNTimeSeriesTransformerForPrediction.
+
+ This configuration class stores the configuration parameters specific to
+ RBLN-optimized Time Series Transformer models for time series forecasting tasks.
+ """
+
  def __init__(
  self,
  batch_size: Optional[int] = None,
optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py
@@ -194,15 +194,19 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
  if "key_value_states" in name:
  context.mark_static_address(tensor)

- compiled_decoder = super().compile(
+ compiled_decoder = cls.compile(
  wrapped_model.decoder,
  dec_compile_config,
+ create_runtimes=rbln_config.create_runtimes,
+ device=rbln_config.device,
  example_inputs=dec_example_inputs,
  compile_context=context,
  )
- compiled_encoder = super().compile(
+ compiled_encoder = cls.compile(
  wrapped_model.encoder,
  enc_compile_config,
+ create_runtimes=rbln_config.create_runtimes,
+ device=rbln_config.device,
  example_inputs=enc_example_inputs,
  compile_context=context,
  )
optimum/rbln/transformers/models/vit/configuration_vit.py
@@ -16,4 +16,9 @@ from ...configuration_generic import RBLNModelForImageClassificationConfig


  class RBLNViTForImageClassificationConfig(RBLNModelForImageClassificationConfig):
- ""
+ """
+ Configuration class for RBLNViTForImageClassification.
+
+ This configuration class stores the configuration parameters specific to
+ RBLN-optimized Vision Transformer (ViT) models for image classification tasks.
+ """
optimum/rbln/transformers/models/vit/modeling_vit.py
@@ -16,4 +16,10 @@ from ...modeling_generic import RBLNModelForImageClassification


  class RBLNViTForImageClassification(RBLNModelForImageClassification):
- ""
+ """
+ RBLN optimized Vision Transformer (ViT) model for image classification tasks.
+
+ This class provides hardware-accelerated inference for Vision Transformer models
+ on RBLN devices, supporting image classification with transformer-based architectures
+ that process images as sequences of patches.
+ """
optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py
@@ -16,4 +16,11 @@ from ...configuration_generic import RBLNModelForMaskedLMConfig


  class RBLNWav2Vec2ForCTCConfig(RBLNModelForMaskedLMConfig):
+ """
+ Configuration class for RBLNWav2Vec2ForCTC.
+
+ This configuration class stores the configuration parameters specific to
+ RBLN-optimized Wav2Vec2 models for Connectionist Temporal Classification (CTC) tasks.
+ """
+
  rbln_model_input_names = ["input_values"]
optimum/rbln/transformers/models/whisper/configuration_whisper.py
@@ -24,6 +24,13 @@ logger = get_logger()


  class RBLNWhisperForConditionalGenerationConfig(RBLNModelConfig):
+ """
+ Configuration class for RBLNWhisperForConditionalGeneration.
+
+ This configuration class stores the configuration parameters specific to
+ RBLN-optimized Whisper models for speech recognition and transcription tasks.
+ """
+
  def __init__(
  self,
  batch_size: int = None,
optimum/rbln/transformers/models/whisper/modeling_whisper.py
@@ -230,15 +230,19 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
  if "key_value_states" in name:
  context.mark_static_address(tensor)

- compiled_encoder = super().compile(
+ compiled_encoder = cls.compile(
  wrapped_model.encoder,
  enc_compile_config,
+ create_runtimes=rbln_config.create_runtimes,
+ device=rbln_config.device,
  example_inputs=enc_example_inputs,
  compile_context=context,
  )
- compiled_decoder = super().compile(
+ compiled_decoder = cls.compile(
  wrapped_model.decoder,
  dec_compile_config,
+ create_runtimes=rbln_config.create_runtimes,
+ device=rbln_config.device,
  example_inputs=dec_example_inputs,
  compile_context=context,
  )
optimum/rbln/utils/runtime_utils.py
@@ -13,12 +13,57 @@
  # limitations under the License.

  import threading
- from typing import Any, Dict, List
+ from typing import Any, Dict, List, Optional, Union

  import rebel
  import torch


+ def tp_and_devices_are_ok(
+ tensor_parallel_size: Optional[int] = None,
+ device: Optional[Union[int, List[int]]] = None,
+ npu: Optional[str] = None,
+ ) -> Optional[str]:
+ if tensor_parallel_size is None:
+ tensor_parallel_size = 1
+
+ if rebel.device_count() < tensor_parallel_size:
+ return (
+ f"Tensor parallel size {tensor_parallel_size} is greater than "
+ f"the number of available devices {rebel.device_count()}."
+ )
+
+ if device is None:
+ device = list(range(tensor_parallel_size))
+ elif isinstance(device, int):
+ device = [device]
+ elif isinstance(device, list):
+ if any(not isinstance(d, int) for d in device):
+ return "Device must be a(n) (list of) integer(s)."
+ if len(device) != tensor_parallel_size:
+ return (
+ f"The number of devices ({len(device)}) does not match tensor parallel size ({tensor_parallel_size})."
+ )
+ else:
+ return f"Invalid device: {device}"
+
+ for device_id in device:
+ if device_id < 0:  # if any device is dummy device, skip it
+ return None
+ if rebel.get_npu_name(device_id) is None:
+ return (
+ f"Device {device_id} is not a valid NPU device. Please check your NPU status with 'rbln-stat' command."
+ )
+
+ if npu is not None:
+ for device_id in device:
+ npu_name = rebel.get_npu_name(device_id)
+ if npu_name != npu:
+ return f"Device {device_id} ({npu_name}) is not on the same NPU as {npu}."
+
+ return None
+
+
  class RBLNPytorchRuntime:
  mandatory_members = []

@@ -43,6 +88,9 @@ class RBLNPytorchRuntime:
  def __repr__(self) -> str:
  return repr(self.runtime)

+ def parameters(self):
+ yield torch.tensor([1.0], dtype=torch.float32, device=torch.device("cpu"))
+

  class UnavailableRuntime:
  """
{optimum_rbln-0.8.1a5.dist-info → optimum_rbln-0.8.1a7.dist-info}/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: optimum-rbln
- Version: 0.8.1a5
+ Version: 0.8.1a7
  Summary: Optimum RBLN is the interface between the HuggingFace Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
  Project-URL: Homepage, https://rebellions.ai
  Project-URL: Documentation, https://docs.rbln.ai