optimum-rbln 0.1.11__py3-none-any.whl → 0.1.13__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
Files changed (72)
  1. optimum/rbln/__init__.py +14 -7
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +30 -63
  4. optimum/rbln/diffusers/models/controlnet.py +36 -62
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +57 -156
  6. optimum/rbln/diffusers/pipelines/__init__.py +40 -12
  7. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -0
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -187
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -192
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +8 -206
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -207
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -111
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +12 -117
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +4 -123
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +4 -126
  16. optimum/rbln/modeling_alias.py +4 -9
  17. optimum/rbln/modeling_base.py +117 -144
  18. optimum/rbln/modeling_config.py +51 -0
  19. optimum/rbln/modeling_diffusers.py +400 -0
  20. optimum/rbln/transformers/__init__.py +10 -0
  21. optimum/rbln/transformers/cache_utils.py +5 -9
  22. optimum/rbln/transformers/modeling_rope_utils.py +283 -0
  23. optimum/rbln/transformers/models/__init__.py +80 -28
  24. optimum/rbln/transformers/models/auto/modeling_auto.py +1 -0
  25. optimum/rbln/transformers/models/bart/__init__.py +1 -1
  26. optimum/rbln/transformers/models/bart/bart_architecture.py +18 -12
  27. optimum/rbln/transformers/models/bart/modeling_bart.py +25 -6
  28. optimum/rbln/transformers/models/bert/modeling_bert.py +1 -2
  29. optimum/rbln/transformers/models/clip/modeling_clip.py +13 -23
  30. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -2
  31. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +376 -218
  32. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +246 -116
  33. optimum/rbln/transformers/models/dpt/modeling_dpt.py +0 -1
  34. optimum/rbln/transformers/models/exaone/__init__.py +32 -0
  35. optimum/rbln/transformers/models/exaone/exaone_architecture.py +81 -0
  36. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +181 -0
  37. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +1725 -0
  38. optimum/rbln/transformers/models/exaone/modeling_exaone.py +53 -0
  39. optimum/rbln/transformers/models/gemma/gemma_architecture.py +12 -2
  40. optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
  41. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +4 -30
  42. optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
  43. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +166 -151
  44. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -15
  45. optimum/rbln/transformers/models/midm/modeling_midm.py +8 -28
  46. optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
  47. optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
  48. optimum/rbln/transformers/models/phi/phi_architecture.py +75 -159
  49. optimum/rbln/transformers/models/qwen2/__init__.py +24 -0
  50. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +43 -0
  51. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +29 -0
  52. optimum/rbln/transformers/models/seq2seq/__init__.py +24 -0
  53. optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} +107 -166
  54. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  55. optimum/rbln/transformers/models/t5/modeling_t5.py +108 -0
  56. optimum/rbln/transformers/models/t5/t5_architecture.py +46 -32
  57. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +0 -1
  58. optimum/rbln/transformers/models/whisper/modeling_whisper.py +38 -13
  59. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +1 -2
  60. optimum/rbln/transformers/utils/rbln_quantization.py +8 -2
  61. optimum/rbln/utils/context.py +58 -0
  62. optimum/rbln/utils/decorator_utils.py +55 -0
  63. optimum/rbln/utils/import_utils.py +21 -0
  64. optimum/rbln/utils/logging.py +1 -1
  65. optimum/rbln/utils/runtime_utils.py +4 -4
  66. optimum/rbln/utils/timer_utils.py +26 -2
  67. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/METADATA +11 -9
  68. optimum_rbln-0.1.13.dist-info/RECORD +107 -0
  69. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/WHEEL +1 -1
  70. optimum_rbln-0.1.11.dist-info/RECORD +0 -93
  71. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/entry_points.txt +0 -0
  72. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/t5/t5_architecture.py

@@ -43,6 +43,12 @@ if TYPE_CHECKING:
     from transformers import T5ForConditionalGeneration


+class T5Wrapper:
+    def __init__(self, model):
+        self.encoder = T5EncoderWrapper(model)
+        self.decoder = T5DecoderWrapper(model)
+
+
 class T5Encoder(T5Stack):
     def forward(
         self,
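
The new T5Wrapper ties the two tracing wrappers together so the encoder and decoder halves can be compiled as separate RBLN graphs from one object. A minimal sketch of the intended use, assuming a standard Hugging Face T5 checkpoint:

    # Sketch only: splitting one T5 model into its two compile units.
    from transformers import T5ForConditionalGeneration

    model = T5ForConditionalGeneration.from_pretrained("t5-small")
    wrapper = T5Wrapper(model)

    encoder_graph = wrapper.encoder  # T5EncoderWrapper, traced as the encoder graph
    decoder_graph = wrapper.decoder  # T5DecoderWrapper, traced as the decoder graph
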
@@ -122,19 +128,26 @@ class T5EncoderWrapper(torch.nn.Module):
         )
         self.encoder_max_length = None
         self.decoder_max_length = None
-        self.decoder_batch_size = 1

-    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
-                cross_key_value: torch.Tensor = None, batch_idx: torch.Tensor = None,
-                ) -> torch.Tensor:
-        encoder_batch_size = input_ids.shape[0]
-        decoder_batch_size = self.decoder_batch_size
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        cross_key_value: torch.Tensor = None,
+        batch_idx: torch.Tensor = None,
+    ) -> torch.Tensor:
         decoder_max_length = self.decoder_max_length or self.default_max_length
         encoder_max_length = self.encoder_max_length or self.default_max_length

         attn_layer = self.encoder.block[0].layer[0].SelfAttention
         encoder_position_bias = T5Attention.compute_bias(attn_layer, encoder_max_length, encoder_max_length)
-        encoder_outputs = T5Encoder.forward(self.encoder, input_ids, attention_mask, encoder_position_bias, batch_ids=torch.tensor(0, dtype=torch.int32))
+        encoder_outputs = T5Encoder.forward(
+            self.encoder,
+            input_ids,
+            attention_mask,
+            encoder_position_bias,
+            batch_ids=torch.tensor(0, dtype=torch.int32),
+        )

         attn_layer = self.decoder.block[0].layer[0].SelfAttention
         decoder_position_bias = T5Attention.compute_bias(attn_layer, decoder_max_length, decoder_max_length)
@@ -145,22 +158,14 @@ class T5EncoderWrapper(torch.nn.Module):

         dummy_past_key_value = []
         for i in range(self.config.num_layers):
-            pkv_self_attn_key = torch.zeros(
-                decoder_batch_size, self.config.num_heads, decoder_max_length, self.config.d_kv
-            )
-            pkv_self_attn_value = torch.zeros(
-                decoder_batch_size, self.config.num_heads, decoder_max_length, self.config.d_kv
-            )
-            pkv_cross_attn_key = torch.zeros(
-                encoder_batch_size, self.config.num_heads, encoder_max_length, self.config.d_kv
-            )
-            pkv_cross_attn_value = torch.zeros(
-                encoder_batch_size, self.config.num_heads, encoder_max_length, self.config.d_kv
-            )
+            pkv_self_attn_key = torch.zeros(1, self.config.num_heads, decoder_max_length, self.config.d_kv)
+            pkv_self_attn_value = torch.zeros(1, self.config.num_heads, decoder_max_length, self.config.d_kv)
+            pkv_cross_attn_key = torch.zeros(1, self.config.num_heads, encoder_max_length, self.config.d_kv)
+            pkv_cross_attn_value = torch.zeros(1, self.config.num_heads, encoder_max_length, self.config.d_kv)
             layer_pkv = (pkv_self_attn_key, pkv_self_attn_value, pkv_cross_attn_key, pkv_cross_attn_value)
             dummy_past_key_value.append(layer_pkv)

-        decoder_attention_mask = torch.zeros(decoder_batch_size, decoder_max_length, dtype=torch.float32)
+        decoder_attention_mask = torch.zeros(1, decoder_max_length, dtype=torch.float32)
         decoder_attention_mask[:, :1] = 1

         # Since first step of decoder has different graph to further step of it,
@@ -168,7 +173,7 @@ class T5EncoderWrapper(torch.nn.Module):
         # TODO(jongho): Separate first-step-decoder.
         decoder_outputs = T5Decoder.forward(
             self.decoder,
-            input_ids=torch.zeros(decoder_batch_size, 1, dtype=torch.int64),
+            input_ids=torch.zeros(1, 1, dtype=torch.int64),
             attention_mask=decoder_attention_mask,
             position_bias=decoder_position_bias,
             encoder_decoder_position_bias=encoder_decoder_position_bias,
@@ -187,7 +192,7 @@ class T5EncoderWrapper(torch.nn.Module):
             cross_kv_cache.append(past_key_values[i][3])
         cross_kv_cache = torch.stack(cross_kv_cache, dim=0)

-        cross_key_value = cross_key_value.slice_scatter(cross_kv_cache, dim=1, start=batch_idx, end=batch_idx+1)
+        cross_key_value = cross_key_value.slice_scatter(cross_kv_cache, dim=1, start=batch_idx, end=batch_idx + 1)

         return cross_key_value

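The `slice_scatter` call writes a single request's cross-attention KV cache into its slot of a preallocated batched buffer, which is what lets a one-request encoder pass update a multi-slot cache. A small sketch of the semantics, with illustrative shapes:

    import torch

    # Illustrative shapes: a buffer with 4 batch slots on dim 1, and a
    # freshly computed cache for a single request.
    cross_key_value = torch.zeros(2, 4, 8, 16)
    cross_kv_cache = torch.ones(2, 1, 8, 16)

    batch_idx = 2
    # Out-of-place write into slot `batch_idx` along the batch dimension.
    cross_key_value = cross_key_value.slice_scatter(
        cross_kv_cache, dim=1, start=batch_idx, end=batch_idx + 1
    )

    assert cross_key_value[:, 2].eq(1).all() and cross_key_value[:, 0].eq(0).all()
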
@@ -240,6 +245,7 @@ class T5DecoderWrapper(torch.nn.Module):
         attn_layer = self.model.decoder.block[0].layer[0].SelfAttention
         _decoder_position_bias = T5Attention.compute_bias(attn_layer, decoder_max_length, decoder_max_length)

+        # position_bias needs to be computed per batch (for continuous batching)
         batch_decoder_position_bias = []
         for i in range(input_ids.shape[0]):
             batch_position_bias = _decoder_position_bias[:, :, cache_position[i][0]].unsqueeze(2)
@@ -259,7 +265,7 @@ class T5DecoderWrapper(torch.nn.Module):
             encoder_decoder_position_bias=encoder_decoder_position_bias,
             past_key_values=kv_cache,
             cache_position=cache_position,
-            batch_ids=rbln_batch_position
+            batch_ids=rbln_batch_position,
         )

         past_key_values = decoder_outputs.past_key_values
@@ -312,7 +318,7 @@ class _T5Attention(T5Attention):
             value_states = shape(self.v(hidden_states), batch_size)
         else:
             # cross-attn
-            if cache_position.dim() == 0 :
+            if cache_position.dim() == 0:
                 key_states = shape(self.k(key_value_states), key_value_states.shape[0])
                 value_states = shape(self.v(key_value_states), key_value_states.shape[0])
                 past_key_value = key_states, value_states
@@ -331,18 +337,24 @@ class _T5Attention(T5Attention):
             batch_value_states = value_states[b].unsqueeze(0)

             if is_self_attn and past_key_value is not None:
-                batch_key_states = past_key_value[0][b].unsqueeze(0).slice_scatter(
-                    batch_key_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
+                batch_key_states = (
+                    past_key_value[0][b]
+                    .unsqueeze(0)
+                    .slice_scatter(
+                        batch_key_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
+                    )
                 )
-                batch_value_states = past_key_value[1][b].unsqueeze(0).slice_scatter(
-                    batch_value_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
+                batch_value_states = (
+                    past_key_value[1][b]
+                    .unsqueeze(0)
+                    .slice_scatter(
+                        batch_value_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
+                    )
                 )

             scores = torch.matmul(batch_query_states, batch_key_states.transpose(3, 2))
             scores += position_bias[b]
-            attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
-                scores
-            )
+            attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
             attn_output = unshape(torch.matmul(attn_weights, batch_value_states), 1)
             all_key_states.append(batch_key_states)
             all_value_states.append(batch_value_states)
@@ -371,7 +383,9 @@ class _T5Attention(T5Attention):
                 scores
             )  # (batch_size, n_heads, seq_length, key_length)

-            attn_output = unshape(torch.matmul(attn_weights, value_states), batch_size)  # (batch_size, seq_length, dim)
+            attn_output = unshape(
+                torch.matmul(attn_weights, value_states), batch_size
+            )  # (batch_size, seq_length, dim)

             attn_output = self.o(attn_output)
             present_key_value = (key_states, value_states)
optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py

@@ -65,7 +65,6 @@ class RBLNWav2Vec2ForCTC(RBLNModel):
     - compiling the resulting graph using the RBLN compiler.
     """

-    model_type = "rbln_model"
     main_input_name = "input_values"
     auto_model_class = AutoModelForMaskedLM

optimum/rbln/transformers/models/whisper/modeling_whisper.py

@@ -59,14 +59,16 @@ if TYPE_CHECKING:
 class RBLNRuntimeEncoder(RBLNPytorchRuntime):
     mandatory_members = ["main_input_name"]

-    def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
+    def forward(self, input_features: torch.Tensor = None):
         # backward compatibility transformers==4.40.2
         # https://github.com/huggingface/transformers/blob/4fdf58afb72b0754da30037fc800b6044e7d9c99/src/transformers/pipelines/automatic_speech_recognition.py#L494
-        input_features = kwargs.get("input_features", None)
-        if input_features is None:
-            input_features = args[0]
+
+        n_pad_to_batch = self.batch_size - input_features.shape[0]
+        if n_pad_to_batch > 0:
+            input_features = torch.nn.functional.pad(input_features, (0, 0, 0, 0, 0, n_pad_to_batch))

         _ = super().forward(input_features=input_features)
+
         # dummy output for generation
         return BaseModelOutput(last_hidden_state=torch.tensor([[-1.0]]))

@@ -74,18 +76,33 @@ class RBLNRuntimeEncoder(RBLNPytorchRuntime):
 class RBLNRuntimeDecoder(RBLNPytorchRuntime):
     mandatory_members = ["main_input_name"]

-    def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
-        outputs = super().forward(*args, **kwargs)
+    def forward(
+        self,
+        decoder_input_ids: torch.Tensor = None,
+        decoder_attention_mask: torch.Tensor = None,
+        cache_position: torch.Tensor = None,
+    ):
+        inputs_bsz = decoder_input_ids.shape[0]
+        padded_bsz = self.batch_size - inputs_bsz
+        if padded_bsz > 0:
+            decoder_input_ids = torch.nn.functional.pad(decoder_input_ids, (0, 0, 0, padded_bsz))
+
+        outputs = super().forward(
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            cache_position=cache_position,
+        )
+
         if isinstance(outputs, torch.Tensor):
-            return Seq2SeqLMOutput(logits=outputs, cross_attentions=None)
+            return Seq2SeqLMOutput(logits=outputs[:inputs_bsz], cross_attentions=None)
         else:
-            return Seq2SeqLMOutput(logits=outputs[0], cross_attentions=outputs[1])
+            return Seq2SeqLMOutput(logits=outputs[0][:inputs_bsz], cross_attentions=outputs[1][:, :inputs_bsz])


 class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin):
     """
     The Whisper Model with a language modeling head. Can be used for automatic speech recognition.
-    This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.

     A class to convert and run pre-trained transformers based LlamaForCausalLM model on RBLN devices.
     It implements the methods to convert a pre-trained transformers LlamaForCausalLM model into a RBLN transformer model by:
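
The padding logic relies on `torch.nn.functional.pad` taking pad pairs from the last dimension backwards: for 2-D `decoder_input_ids`, `(0, 0, 0, n)` leaves the sequence dimension alone and appends `n` rows along the batch dimension, so a smaller live batch can be padded up to the fixed batch size the graph was compiled for and sliced back down (`[:inputs_bsz]`) on the way out. A quick sketch:

    import torch
    import torch.nn.functional as F

    compiled_batch_size = 4  # illustrative compile-time batch size
    decoder_input_ids = torch.ones(3, 1, dtype=torch.int64)  # 3 live requests

    n_pad = compiled_batch_size - decoder_input_ids.shape[0]
    # Pad pairs apply last dim first: (left, right) for dim -1, then dim -2.
    padded = F.pad(decoder_input_ids, (0, 0, 0, n_pad))
    print(padded.shape)  # torch.Size([4, 1]); the extra row is zero-padded
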
@@ -93,7 +110,6 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin):
     - compiling the resulting graph using the RBLN compiler.
     """

-    model_type = "rbln_model"
     auto_model_class = AutoModelForSpeechSeq2Seq
     main_input_name = "input_ids"

@@ -104,8 +120,12 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin):
         self.dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
         self.rbln_token_timestamps = self.rbln_config.model_cfg["token_timestamps"]

-        self.encoder = RBLNRuntimeEncoder(runtime=self.model[0], main_input_name="input_features")
-        self.decoder = RBLNRuntimeDecoder(runtime=self.model[1], main_input_name="input_ids")
+        self.encoder = RBLNRuntimeEncoder(
+            runtime=self.model[0], main_input_name="input_features", batch_size=self.batch_size
+        )
+        self.decoder = RBLNRuntimeDecoder(
+            runtime=self.model[1], main_input_name="input_ids", batch_size=self.batch_size
+        )

         # skip encoder & first decoder when language detected
         self.is_language_detected = False
@@ -200,7 +220,11 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin):
         expected_seq_len = model_config.max_source_positions * 2
         num_mel_bins = model_config.num_mel_bins
         enc_max_seq_len = model_config.max_source_positions
-        rbln_dec_max_seq_len = model_config.max_length
+
+        # 'whisper-large-v3-turbo' doesn't define 'max_length', but PretrainedConfig supplies a default for that key
+        rbln_dec_max_seq_len = getattr(model_config, "max_target_positions", None)
+        if rbln_dec_max_seq_len is None:
+            rbln_dec_max_seq_len = model_config.max_length

         # model input info
         enc_input_info = [("input_features", [rbln_batch_size, num_mel_bins, expected_seq_len], "float32")]
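
The lookup order matters because `PretrainedConfig` supplies a default `max_length` even when a checkpoint does not define a meaningful one, while newer Whisper configs carry the real decoder limit in `max_target_positions`. A sketch of the difference (the printed values are what these attributes typically hold, not verified output):

    from transformers import WhisperConfig

    config = WhisperConfig.from_pretrained("openai/whisper-large-v3-turbo")
    print(getattr(config, "max_target_positions", None))  # real decoder limit, e.g. 448
    print(config.max_length)  # PretrainedConfig default (20), not a usable decoder limit
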
@@ -273,6 +297,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin):
         self,
         input_ids,
         cache_position: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,  # needed to support transformers>=4.45.0
         **kwargs,
     ):
         """
optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py

@@ -25,7 +25,7 @@ import logging
 from typing import TYPE_CHECKING, Any, Dict, Optional, Union

 import torch
-from transformers import AutoModel, PretrainedConfig, PreTrainedModel, XLMRobertaConfig, XLMRobertaModel
+from transformers import PretrainedConfig, PreTrainedModel, XLMRobertaConfig, XLMRobertaModel

 from ....modeling_base import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig
@@ -38,7 +38,6 @@ if TYPE_CHECKING:


 class RBLNXLMRobertaModel(RBLNModel):
-    auto_model_class = AutoModel  # feature extraction
     original_model_class = XLMRobertaModel
     original_config_class = XLMRobertaConfig

optimum/rbln/transformers/utils/rbln_quantization.py

@@ -31,8 +31,13 @@ from torch.nn import functional as F

 # Constants
 QUANTIZED_WEIGHTS = {
-    "q_proj", "k_proj", "v_proj", "o_proj",
-    "gate_proj", "up_proj", "down_proj",
+    "q_proj",
+    "k_proj",
+    "v_proj",
+    "o_proj",
+    "gate_proj",
+    "up_proj",
+    "down_proj",
 }


@@ -81,6 +86,7 @@ def create_qlinear(layer: Linear) -> Linear:
     """
     Converts a standard linear layer to a quantized linear (qlinear) layer with a custom forward pass.
     """
+
     def qlinear_forward(self, inputs: torch.Tensor) -> torch.Tensor:
         if inputs.dtype != self.scales.dtype:
             raise TypeError(f"Expected input dtype {self.scales.dtype}, but got {inputs.dtype}")
optimum/rbln/utils/context.py (new file)

@@ -0,0 +1,58 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Union
+
+from optimum.exporters import TasksManager
+from transformers import AutoConfig, AutoModel
+
+
+@contextmanager
+def override_auto_classes(config_func=None, model_func=None, skip_taskmanager=True):
+    """Temporarily override Auto classes with original model classes"""
+    original_config = AutoConfig.from_pretrained
+    original_model = AutoModel.from_pretrained
+    original_get_model_from_task = TasksManager.get_model_from_task
+
+    def get_model_from_task(
+        task: str,
+        model_name_or_path: Union[str, Path],
+        **kwargs,
+    ):
+        return model_func(model_name_or_path, **kwargs)
+
+    def none_func(*args, **kwargs):
+        return None
+
+    try:
+        AutoConfig.from_pretrained = config_func or none_func
+        AutoModel.from_pretrained = model_func or none_func
+        if skip_taskmanager:
+            TasksManager.get_model_from_task = none_func if model_func is None else get_model_from_task
+        yield
+    finally:
+        AutoConfig.from_pretrained = original_config
+        AutoModel.from_pretrained = original_model
+        TasksManager.get_model_from_task = original_get_model_from_task
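
A plausible call-site sketch for the new context manager (the class names mirror the XLM-RoBERTa hunk above; the actual call sites live inside the model implementations):

    from transformers import XLMRobertaConfig, XLMRobertaModel

    with override_auto_classes(
        config_func=XLMRobertaConfig.from_pretrained,
        model_func=XLMRobertaModel.from_pretrained,
    ):
        # Inside the block, AutoConfig/AutoModel.from_pretrained resolve to the
        # concrete XLM-RoBERTa classes, so generic loading paths pick the right model.
        ...
    # On exit the original Auto methods are restored, even if an exception was raised.
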
optimum/rbln/utils/decorator_utils.py (new file)

@@ -0,0 +1,55 @@
+from functools import wraps
+
+from .logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
+def remove_compile_time_kwargs(func):
+    """
+    Decorator to handle compile-time parameters during inference.
+
+    For RBLN-optimized pipelines, several parameters must be determined during compilation
+    and cannot be modified during inference. This decorator:
+    1. Removes and warns about LoRA scale in cross_attention_kwargs
+    2. Removes and warns about image dimension parameters (height, width)
+
+    Args:
+        func: The pipeline's __call__ method to be wrapped
+    """
+
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        height_exists = "height" in kwargs and kwargs["height"] is not None
+        width_exists = "width" in kwargs and kwargs["width"] is not None
+        if height_exists or width_exists:
+            logger.warning(
+                "Image dimension parameters (`height`, `width`) will be ignored during inference. "
+                "Image dimensions must be specified during model compilation using from_pretrained()."
+            )
+            kwargs.pop("width", None)
+            kwargs.pop("height", None)
+
+        if "cross_attention_kwargs" in kwargs:
+            cross_attention_kwargs = kwargs.get("cross_attention_kwargs")
+            if not cross_attention_kwargs:
+                return func(self, *args, **kwargs)
+
+            has_scale = "scale" in cross_attention_kwargs
+            if has_scale:
+                logger.warning(
+                    "LoRA scale in cross_attention_kwargs will be ignored during inference. "
+                    "To adjust LoRA scale, specify it during model compilation using from_pretrained()."
+                )
+
+                # If scale is the only key, set to None
+                # Otherwise, remove scale and preserve other settings
+                if len(cross_attention_kwargs) == 1:
+                    kwargs["cross_attention_kwargs"] = None
+                else:
+                    kwargs["cross_attention_kwargs"].pop("scale")
+
+        return func(self, *args, **kwargs)
+
+    return wrapper
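
A sketch of how the decorator attaches to a pipeline's `__call__` (the class below is hypothetical; the real pipelines live under optimum/rbln/diffusers):

    class ExampleRBLNPipeline:  # hypothetical stand-in for an RBLN diffusion pipeline
        @remove_compile_time_kwargs
        def __call__(self, prompt, **kwargs):
            # By this point, height/width and any LoRA scale have been stripped
            # (with a warning), since those were fixed at compile time.
            ...

    ExampleRBLNPipeline()("an astronaut", height=768, width=768)  # warns, then runs
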
optimum/rbln/utils/import_utils.py

@@ -37,6 +37,27 @@ class VersionCompat:


 RBLN_VERSION_COMPATS = {
+    "0.1.13": [
+        VersionCompat(
+            package_name="rebel-compiler",
+            min_version="0.6.0",
+            max_version="0.6.1",
+        ),
+    ],
+    "0.1.12": [
+        VersionCompat(
+            package_name="rebel-compiler",
+            min_version="0.5.12",
+            max_version="0.5.13",
+        ),
+    ],
+    "0.1.11": [
+        VersionCompat(
+            package_name="rebel-compiler",
+            min_version="0.5.10",
+            max_version="0.5.11",
+        ),
+    ],
     "0.1.10": [
         VersionCompat(
             package_name="rebel-compiler",
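
A minimal sketch of how a table like this can be enforced, assuming `VersionCompat` exposes its fields as attributes and `max_version` is an exclusive upper bound (the actual check in import_utils.py may differ):

    from importlib.metadata import version

    from packaging.version import Version

    def check_compat(optimum_rbln_version: str) -> None:
        for compat in RBLN_VERSION_COMPATS.get(optimum_rbln_version, []):
            installed = Version(version(compat.package_name))
            if not Version(compat.min_version) <= installed < Version(compat.max_version):
                raise RuntimeError(
                    f"{compat.package_name}=={installed} is incompatible with "
                    f"optimum-rbln=={optimum_rbln_version}; expected "
                    f">={compat.min_version},<{compat.max_version}"
                )
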
optimum/rbln/utils/logging.py

@@ -22,7 +22,7 @@ log_levels = {
     "critical": logging.CRITICAL,
 }

-_default_log_level = logging.WARNING
+_default_log_level = logging.INFO


 def _get_default_logging_level():
optimum/rbln/utils/runtime_utils.py

@@ -67,7 +67,7 @@ class UnavailableRuntime:
         return iter([self])

     def forward(self, *args: List["torch.Tensor"], **kwargs: Dict[str, "torch.Tensor"]):
-        raise RuntimeError("RBLN-Runtime is not created, So it is not available.")
+        raise RuntimeError("The model can't run because the runtime hasn't been created.")

     def __repr__(self) -> str:
         return "UnavailableRuntime"
@@ -76,17 +76,17 @@ class UnavailableRuntime:
 class ContextRblnConfig:
     _local = threading.local()

-    def __init__(self, device, device_map, create_runtimes, optimze_host_mem):
+    def __init__(self, device=None, device_map=None, create_runtimes=None, optimize_host_mem=None):
         self.device = device
         self.device_map = device_map
         self.create_runtimes = create_runtimes
-        self.optimze_host_mem = optimze_host_mem
+        self.optimize_host_mem = optimize_host_mem

     def __enter__(self):
         self._local.device = self.device
         self._local.device_map = self.device_map
         self._local.create_runtimes = self.create_runtimes
-        self._local.optimize_host_memory = self.optimze_host_mem
+        self._local.optimize_host_memory = self.optimize_host_mem
         return self

     def __exit__(self, exc_type, exc_val, exc_tb):
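
With the defaults added, the class works as a plain context manager, and the `optimze_host_mem` typo fix means `__enter__` now reads the attribute `__init__` actually sets. An illustrative use, with made-up values (import path assumed from this diff's file layout):

    from optimum.rbln.utils.runtime_utils import ContextRblnConfig

    with ContextRblnConfig(device=0, create_runtimes=True):
        # Models loaded here pick up the thread-local settings
        # (device placement, runtime creation, host-memory optimization).
        ...
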
optimum/rbln/utils/timer_utils.py

@@ -1,5 +1,8 @@
+import os
 from datetime import datetime

+from halo import Halo
+
 from .logging import get_logger

@@ -9,11 +12,32 @@ logger = get_logger()
 def rbln_timer(print_name):
     def decorator(function):
         def wrapper(*args, **kwargs):
+            disable = os.getenv("OPTIMUM_RBLN_DISABLE_SPIN", "False").lower() in ("true", "1", "t")
+            if disable:
+                logger.info(f"{print_name} ...")
+
+            spinner = Halo(text=f"{print_name} ...", spinner="dots", color="green", enabled=(not disable))
+            spinner.start()
+
+            # Start timer
             tick = datetime.now()
-            result = function(*args, **kwargs)
-            logger.debug(f"{print_name}. Elasped time: {str(datetime.now() - tick)[:7]}")
+            try:
+                result = function(*args, **kwargs)
+            except Exception as e:
+                spinner.fail(f"{print_name} failed.")
+                raise e
+
+            # Print elapsed time.
+            if disable:
+                logger.info(f"{print_name} done. Elapsed time: {format_elapsed_time(tick)}")
+
+            spinner.stop()
+            spinner.succeed(text=f"{print_name} done. Elapsed time: {format_elapsed_time(tick)}")
             return result

         return wrapper

+    def format_elapsed_time(start_time: datetime) -> str:
+        return str(datetime.now() - start_time)[:7]
+
     return decorator
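
The timer now drives a Halo spinner, and the `OPTIMUM_RBLN_DISABLE_SPIN` environment variable swaps it for plain log lines (handy in CI). A small usage sketch; the decorated function is an illustrative stand-in:

    import os
    import time

    os.environ["OPTIMUM_RBLN_DISABLE_SPIN"] = "1"  # plain logs instead of a spinner

    @rbln_timer("Compiling model")
    def compile_model():  # stand-in workload
        time.sleep(1.5)

    compile_model()  # logs "Compiling model ..." then "Compiling model done. Elapsed time: 0:00:01"
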
optimum_rbln-0.1.13.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: optimum-rbln
-Version: 0.1.11
+Version: 0.1.13
 Summary: Optimum RBLN is the interface between the Hugging Face Transformers and Diffusers libraries and RBLN accelerators.
         It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.

@@ -14,22 +14,24 @@ Classifier: Intended Audience :: Education
 Classifier: Intended Audience :: Science/Research
 Classifier: Operating System :: POSIX :: Linux
 Classifier: Programming Language :: Python :: 3 :: Only
-Classifier: Programming Language :: Python :: 3.8
 Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.11
+Classifier: Programming Language :: Python :: 3.12
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Project-URL: Homepage, https://rebellions.ai
 Project-URL: Documentation, https://docs.rbln.ai
-Requires-Python: <3.11,>=3.8
-Requires-Dist: torch<=2.2.1
-Requires-Dist: torchvision<=0.17.1
-Requires-Dist: torchaudio<=2.2.1
-Requires-Dist: optimum<=1.22.0
+Requires-Python: <3.13,>=3.9
+Requires-Dist: torch<=2.5.1
+Requires-Dist: torchvision<=0.20.1
+Requires-Dist: torchaudio<=2.5.1
+Requires-Dist: optimum==1.23.1
 Requires-Dist: accelerate>=0.28.0
-Requires-Dist: transformers<=4.44.2,>=4.43.2
-Requires-Dist: diffusers<=0.30.3
+Requires-Dist: transformers==4.45.2
+Requires-Dist: diffusers<=0.31.0
 Requires-Dist: einops>=0.8.0
 Requires-Dist: packaging>=24.1
+Requires-Dist: halo
 Provides-Extra: tests
 Requires-Dist: pytest>=8.1.1; extra == "tests"
 Requires-Dist: psutil>=5.9.8; extra == "tests"