optimum-rbln 0.7.3.post1__py3-none-any.whl → 0.7.3.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
optimum/rbln/__version__.py

@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.7.3.post1'
+__version__ = version = '0.7.3.post2'
 __version_tuple__ = version_tuple = (0, 7, 3)
optimum/rbln/transformers/models/bart/modeling_bart.py

@@ -108,8 +108,6 @@ class RBLNBartModel(RBLNModel):
 
 
 class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
-    support_paged_causal_attn = True
-
     @classmethod
     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
         enc_max_seq_len = (
optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py

@@ -50,7 +50,6 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         runtime: rebel.Runtime,
         batch_size: int,
         dec_max_seq_len: int,
-        support_paged_causal_attn: Optional[bool] = None,
         use_attention_mask: Optional[bool] = None,
         **kwargs: Any,
     ) -> None:
@@ -58,10 +57,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         self.batch_size = batch_size
         self.dec_max_seq_len = dec_max_seq_len
         self.use_attention_mask = use_attention_mask
-        if support_paged_causal_attn:
-            self.default_block_tables = torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, 1)
-        else:
-            self.default_block_tables = None
+        self.default_block_tables = torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, 1)
 
     def forward(
         self,
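With the paged-attention flag gone, the identity block table is now built unconditionally. A minimal sketch of what that one-liner produces, using an example batch size (plain PyTorch, not part of the package):

import torch

# Identity block table: one KV-cache block per batch element, as built in
# RBLNRuntimeDecoder.__init__ above. batch_size = 4 is an example value;
# the real value comes from the compiled RBLN config.
batch_size = 4
default_block_tables = torch.arange(0, batch_size, dtype=torch.int16).view(batch_size, 1)
print(default_block_tables.tolist())  # [[0], [1], [2], [3]]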
@@ -98,7 +94,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
             decoder_attention_mask if self.use_attention_mask else None,
             attention_mask,
             cache_position,
-            block_tables=block_tables,
+            block_tables,
         )
 
         return Seq2SeqLMOutput(logits=lm_logits)
@@ -119,7 +115,6 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
 
     main_input_name = "input_ids"
     auto_model_class = AutoModelForSeq2SeqLM
-    support_paged_causal_attn = None
 
     def __post_init__(self, **kwargs):
         batch_size = self.rbln_config.model_cfg["batch_size"]
@@ -135,7 +130,6 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
             main_input_name="input_ids",
             batch_size=batch_size,
             dec_max_seq_len=dec_max_seq_len,
-            support_paged_causal_attn=self.support_paged_causal_attn,
             use_attention_mask=self.use_attention_mask,
         )
 
@@ -192,16 +186,13 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
         rbln_dec_max_seq_len = rbln_kwargs.get("dec_max_seq_len", None)
         rbln_batch_size = rbln_kwargs.get("batch_size", None)
         rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
+        rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
 
-        if cls.support_paged_causal_attn:
-            rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
-            if rbln_use_attention_mask is None:
-                rbln_use_attention_mask = False
-                rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
-                if rbln_npu == "RBLN-CA02":
-                    rbln_use_attention_mask = True
-        else:
-            rbln_use_attention_mask = True
+        if rbln_use_attention_mask is None:
+            rbln_use_attention_mask = False
+            rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
+            if rbln_npu == "RBLN-CA02":
+                rbln_use_attention_mask = True
 
         n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
         n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
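With the flag removed, the NPU-based default now applies to every seq2seq model. A self-contained sketch of the resolution order (an explicit caller setting wins, otherwise the mask defaults on only for RBLN-CA02); `rebel.get_npu_name()` is replaced by a plain argument so the snippet runs without an RBLN device, and "RBLN-CA12" is only a hypothetical other NPU name:

def resolve_use_attention_mask(rbln_kwargs: dict, detected_npu: str) -> bool:
    # Mirrors the hunk above: user override wins; otherwise the attention
    # mask is only enabled on RBLN-CA02.
    use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
    if use_attention_mask is None:
        use_attention_mask = False
        npu = rbln_kwargs.get("npu", None) or detected_npu
        if npu == "RBLN-CA02":
            use_attention_mask = True
    return use_attention_mask

print(resolve_use_attention_mask({}, "RBLN-CA02"))                            # True
print(resolve_use_attention_mask({}, "RBLN-CA12"))                            # False
print(resolve_use_attention_mask({"use_attention_mask": True}, "RBLN-CA12"))  # True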
@@ -274,6 +265,11 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
                 [rbln_batch_size, 1],
                 "int32",
             ),
+            (
+                "block_tables",
+                [rbln_batch_size, 1],
+                "int16",
+            ),
         ]
         dec_input_info.extend(
             [
@@ -306,8 +302,6 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
             ]
         )
 
-        if cls.support_paged_causal_attn:
-            dec_input_info.insert(3, ("block_tables", [rbln_batch_size, 1], "int16"))
         if rbln_use_attention_mask:
             dec_input_info.insert(1, ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"))
 
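Taken together, the two hunks above move `block_tables` from a conditional `insert(3, ...)` into the literal list, right behind `cache_position`, which is what lets the decoder pass it positionally in the earlier `super().forward(...)` change. A toy reconstruction of the resulting ordering (the list's leading entries are not shown in this diff, so the head here is hypothetical):

rbln_batch_size, rbln_dec_max_seq_len = 1, 256  # example values

dec_input_info = [
    ("input_ids", [rbln_batch_size, 1], "int64"),                    # hypothetical head entry
    ("encoder_attention_mask", [rbln_batch_size, 512], "float32"),   # hypothetical head entry
    ("cache_position", [rbln_batch_size, 1], "int32"),
    ("block_tables", [rbln_batch_size, 1], "int16"),                 # now always present
]

rbln_use_attention_mask = True
if rbln_use_attention_mask:
    dec_input_info.insert(1, ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"))

print([name for name, _, _ in dec_input_info])
# ['input_ids', 'attention_mask', 'encoder_attention_mask', 'cache_position', 'block_tables']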
optimum/rbln/transformers/models/t5/modeling_t5.py

@@ -13,8 +13,9 @@
 # limitations under the License.
 
 import inspect
-from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
+import rebel
 import torch
 from transformers import (
     AutoModelForTextEncoding,
@@ -22,7 +23,7 @@ from transformers import (
     T5EncoderModel,
     T5ForConditionalGeneration,
 )
-from transformers.modeling_outputs import BaseModelOutput
+from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 
 from ....diffusers.modeling_diffusers import RBLNDiffusionMixin
 from ....modeling import RBLNModel
@@ -57,6 +58,63 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         )
 
 
+class RBLNRuntimeEncoder(RBLNPytorchRuntime):
+    mandatory_members = ["main_input_name"]
+
+    def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
+        _ = super().forward(*args, **kwargs)
+        return BaseModelOutput(last_hidden_state=torch.tensor([1.0]))
+
+
+class RBLNRuntimeDecoder(RBLNPytorchRuntime):
+    mandatory_members = ["main_input_name"]
+
+    def __init__(
+        self,
+        runtime: rebel.Runtime,
+        batch_size: int,
+        dec_max_seq_len: int,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(runtime, **kwargs)
+        self.batch_size = batch_size
+        self.dec_max_seq_len = dec_max_seq_len
+
+    def forward(
+        self,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        decoder_attention_mask: Optional[torch.BoolTensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor]:
+        batch_size = decoder_input_ids.shape[0]
+        if batch_size != self.batch_size:
+            raise RuntimeError(
+                f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
+            )
+
+        if batch_size != cache_position.shape[0]:
+            raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
+
+        for b_idx in range(self.batch_size):
+            decoding_step = cache_position[b_idx].item()
+            if not (0 <= decoding_step < self.dec_max_seq_len):
+                raise ValueError(
+                    f"Decoding step {decoding_step} out of bounds; expected 0 <= step < {self.dec_max_seq_len}."
+                )
+            decoder_attention_mask[b_idx, : decoding_step + 1] = 1
+
+        lm_logits = super().forward(
+            decoder_input_ids,
+            decoder_attention_mask,
+            attention_mask,
+            cache_position,
+        )
+
+        return Seq2SeqLMOutput(logits=lm_logits)
+
+
 class T5EncoderWrapper(torch.nn.Module):
     def __init__(self, model: "T5EncoderModel") -> None:
         super().__init__()
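The new decoder fills the attention mask up to each sequence's current step before invoking the compiled runtime. A standalone demonstration of that per-batch fill (example sizes, independent of the package):

import torch

# Per-batch causal mask fill, as in RBLNRuntimeDecoder.forward above.
# Example sizes only; real values come from the compiled config.
batch_size, dec_max_seq_len = 2, 8
cache_position = torch.tensor([[3], [5]])  # current decoding step per sequence
decoder_attention_mask = torch.zeros(batch_size, dec_max_seq_len)

for b_idx in range(batch_size):
    decoding_step = cache_position[b_idx].item()
    decoder_attention_mask[b_idx, : decoding_step + 1] = 1

print(decoder_attention_mask)
# tensor([[1., 1., 1., 1., 0., 0., 0., 0.],
#         [1., 1., 1., 1., 1., 1., 0., 0.]])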
@@ -189,7 +247,20 @@ class RBLNT5EncoderModel(RBLNModel):
 
 
 class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
-    support_causal_paged_attn = False
+    def __post_init__(self, **kwargs):
+        batch_size = self.rbln_config.model_cfg["batch_size"]
+        dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
+
+        self.encoder = RBLNRuntimeEncoder(
+            runtime=self.model[0],
+            main_input_name="input_ids",
+        )
+        self.decoder = RBLNRuntimeDecoder(
+            runtime=self.model[1],
+            main_input_name="input_ids",
+            batch_size=batch_size,
+            dec_max_seq_len=dec_max_seq_len,
+        )
 
     @classmethod
     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
@@ -208,3 +279,139 @@ class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
                 return redirect(val)
 
             return val
+
+    @classmethod
+    def _get_rbln_config(
+        cls,
+        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+        model_config: "PretrainedConfig",
+        rbln_kwargs: Dict[str, Any] = {},
+    ) -> RBLNConfig:
+        rbln_enc_max_seq_len = rbln_kwargs.get("enc_max_seq_len", None)
+        rbln_dec_max_seq_len = rbln_kwargs.get("dec_max_seq_len", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
+
+        n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
+        n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
+        d_kv = (
+            model_config.d_kv
+            if hasattr(model_config, "d_kv")
+            else model_config.d_model // model_config.encoder_attention_heads
+        )
+
+        max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
+            model_config, "max_position_embeddings", None
+        )
+
+        rbln_pad_token_id = getattr(model_config, "pad_token_id", None)
+        if rbln_pad_token_id is None:
+            rbln_pad_token_id = getattr(model_config, "bos_token_id", None)
+            if rbln_pad_token_id is None:
+                rbln_pad_token_id = getattr(model_config, "eos_token_id", None)
+                if rbln_pad_token_id is None:
+                    rbln_pad_token_id = -1
+
+        if rbln_enc_max_seq_len is None:
+            rbln_enc_max_seq_len = max_position_embeddings
+            if rbln_enc_max_seq_len is None:
+                for tokenizer in preprocessors:
+                    if hasattr(tokenizer, "model_max_length"):
+                        rbln_enc_max_seq_len = tokenizer.model_max_length
+                        break
+                if rbln_enc_max_seq_len is None:
+                    raise ValueError("`rbln_enc_max_seq_len` should be specified!")
+        if max_position_embeddings is not None and rbln_enc_max_seq_len > max_position_embeddings:
+            raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
+
+        if rbln_dec_max_seq_len is None:
+            rbln_dec_max_seq_len = max_position_embeddings
+            if rbln_dec_max_seq_len is None:
+                for tokenizer in preprocessors:
+                    if hasattr(tokenizer, "model_max_length"):
+                        rbln_dec_max_seq_len = tokenizer.model_max_length
+                        break
+                if rbln_dec_max_seq_len is None:
+                    raise ValueError("`rbln_dec_max_seq_len` should be specified!")
+
+        if max_position_embeddings is not None and rbln_dec_max_seq_len > max_position_embeddings:
+            raise ValueError("`rbln_dec_max_seq_len` should be less or equal than max_position_embeddings!")
+
+        # model input info
+        enc_input_info = [
+            ("input_ids", [1, rbln_enc_max_seq_len], "int64"),
+            ("attention_mask", [1, rbln_enc_max_seq_len], "float32"),
+            (
+                "cross_key_value_states",
+                [
+                    n_layer * 2,
+                    rbln_batch_size,
+                    n_head,
+                    rbln_enc_max_seq_len,
+                    d_kv,
+                ],
+                "float32",
+            ),
+            ("block_tables", [1], "int16"),
+        ]
+
+        dec_input_info = [
+            ("input_ids", [rbln_batch_size, 1], "int64"),
+            ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"),
+            ("encoder_attention_mask", [rbln_batch_size, rbln_enc_max_seq_len], "float32"),
+            (
+                "cache_position",
+                [rbln_batch_size, 1],
+                "int32",
+            ),
+        ]
+        dec_input_info.extend(
+            [
+                (
+                    "cross_key_value_states",
+                    [
+                        n_layer * 2,
+                        rbln_batch_size,
+                        n_head,
+                        rbln_enc_max_seq_len,
+                        d_kv,
+                    ],
+                    "float32",
+                )
+            ]
+        )
+        dec_input_info.extend(
+            [
+                (
+                    f"self_key_value_states_{i}",
+                    [
+                        rbln_batch_size,
+                        n_head,
+                        rbln_dec_max_seq_len,
+                        d_kv,
+                    ],
+                    "float32",
+                )
+                for i in range(n_layer * 2)
+            ]
+        )
+
+        enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
+        dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
+
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[enc_compile_config, dec_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+
+        rbln_config.model_cfg.update(
+            {
+                "enc_max_seq_len": rbln_enc_max_seq_len,
+                "dec_max_seq_len": rbln_dec_max_seq_len,
+                "batch_size": rbln_batch_size,
+                "pad_token_id": rbln_pad_token_id,
+            }
+        )
+
+        return rbln_config
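For context, a hypothetical usage sketch of compiling a T5 model so that this config path runs. It assumes optimum-rbln's usual convention that `rbln_`-prefixed keyword arguments to `from_pretrained(..., export=True)` are collected into `rbln_kwargs`; the model id and sizes are illustrative, not taken from this diff:

from optimum.rbln import RBLNT5ForConditionalGeneration

# Hypothetical compile-time call; the rbln_* kwargs feed _get_rbln_config above.
model = RBLNT5ForConditionalGeneration.from_pretrained(
    "google-t5/t5-small",
    export=True,               # compile from the Hugging Face checkpoint
    rbln_batch_size=1,         # -> rbln_kwargs["batch_size"]
    rbln_enc_max_seq_len=512,  # -> rbln_kwargs["enc_max_seq_len"]
    rbln_dec_max_seq_len=256,  # -> rbln_kwargs["dec_max_seq_len"]
)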
optimum_rbln-0.7.3.post2.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: optimum-rbln
-Version: 0.7.3.post1
+Version: 0.7.3.post2
 Summary: Optimum RBLN is the interface between the Hugging Face Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
 Project-URL: Homepage, https://rebellions.ai
 Project-URL: Documentation, https://docs.rbln.ai
@@ -1,5 +1,5 @@
1
1
  optimum/rbln/__init__.py,sha256=ZDzXcl-oAcYJhKjJMpotjbTih9awo7HzUb6T3MUEP6Q,6894
2
- optimum/rbln/__version__.py,sha256=aegWGVZeZJ9bIegWWNAgPL2y9SAs5kPTsXCQi0EZ9go,517
2
+ optimum/rbln/__version__.py,sha256=OJRzB6Y7xaNgH7EkerbumPEoY0Nlzs1_xYhBJvOXTzQ,517
3
3
  optimum/rbln/modeling.py,sha256=nJsAs5zs--VVOYGFjYNpqfxYIemJIK4Lr0WEzlDLdP0,8390
4
4
  optimum/rbln/modeling_base.py,sha256=dNCL-BhrWCpuOVkZaj8-MW567Tf4lLo3p3Z3ldjWJfU,21779
5
5
  optimum/rbln/modeling_config.py,sha256=7104bxmrvKW4Q6XTruQayiIGl8GHDFmPkJ3cknMIInE,11335
@@ -55,7 +55,7 @@ optimum/rbln/transformers/models/auto/auto_factory.py,sha256=IK9jFrJ3EEzYQa9_aKp
 optimum/rbln/transformers/models/auto/modeling_auto.py,sha256=Un9qoqdy3dO8JBza_bTJF_6_fRVNM9QisihSgTRFI-o,3933
 optimum/rbln/transformers/models/bart/__init__.py,sha256=32HPe0_GIO0hp9U464Iv6Jd7M-1nop9g8hA1UZMHhyw,674
 optimum/rbln/transformers/models/bart/bart_architecture.py,sha256=Oo-Cdne7igKEex8wwP-gztKJHgs5GLHQjK1oc3IZIDE,5801
-optimum/rbln/transformers/models/bart/modeling_bart.py,sha256=6IpWXlBCd02v66KF77oEWfrv8-FnPBYjjjL_8KZL3Ow,5835
+optimum/rbln/transformers/models/bart/modeling_bart.py,sha256=iI3ubPOVvHmhLt0wEz_vkOfMyNTHVNjmnkLtbpOX760,5797
 optimum/rbln/transformers/models/bert/__init__.py,sha256=YVV7k_laU6yJBawZrgjIWjRmIF-Y4oQQHqyf8lsraQs,691
 optimum/rbln/transformers/models/bert/modeling_bert.py,sha256=p3utRqf3dv9_RkHwaMCa1EfXttNJkqCJUIZo3CeZ9YY,4674
 optimum/rbln/transformers/models/clip/__init__.py,sha256=H9vuBwrmFO0-CqZhXUrKF-uQL6igCqMlqrT1X_ELaAI,754
@@ -92,10 +92,10 @@ optimum/rbln/transformers/models/qwen2/__init__.py,sha256=RAMWc21W_2I6DH9xBjeNxP
 optimum/rbln/transformers/models/qwen2/modeling_qwen2.py,sha256=9-aFDvjMzPNUyGOz0qo33RE18bUFGYZ3Wt_68zb5uJY,1530
 optimum/rbln/transformers/models/qwen2/qwen2_architecture.py,sha256=XlNAMYAcDLohnSAhIFGKOPuCB5XLgzYs5ABWdeQSaZs,720
 optimum/rbln/transformers/models/seq2seq/__init__.py,sha256=EmEMV4rOYqKyruX85d0fR73-b8N6BSD6CPcbpYdBuVk,651
-optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py,sha256=9Pf9Y86ABDfhwIenlZqYfgqjbyFmtKBiPnbCD7zxw4M,18017
+optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py,sha256=NPfJf9Uk_bYOae7hXGHwteGiWH0va63Z-D93RmAMENg,17611
 optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py,sha256=tvzacIZam1sIr_1BvvZ_fDr8u5dXAiYiynFdX9tArtY,18877
 optimum/rbln/transformers/models/t5/__init__.py,sha256=1skR1RmnG62WTAP3-F5P1x-V_ReFhMyirH3u56vWwvc,675
-optimum/rbln/transformers/models/t5/modeling_t5.py,sha256=8PAhPlYT1dmpcWM7hUMmZV9lPd4d75CuMuFen1pzr3Q,8088
+optimum/rbln/transformers/models/t5/modeling_t5.py,sha256=nKRR3eH1EAu1YkKvhlqGyTrJXDRd-IWB5LOeG9jrcb4,16021
 optimum/rbln/transformers/models/t5/t5_architecture.py,sha256=AArCQhZRETVM583wlIRzMFOSYq7t2nzxaAeyhZxyxKk,9508
 optimum/rbln/transformers/models/wav2vec2/__init__.py,sha256=YpgA0K-vyg9veh0eL_jxauosbRpb_kpGKHvvQLBspKM,649
 optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py,sha256=JYJmV52j6cBwim4RanVJryfKnV80V96ol0A-oR6o7cg,3856
@@ -116,7 +116,7 @@ optimum/rbln/utils/model_utils.py,sha256=DfD_Z2qvZHqcddXqnzTM1AN8khanj3-DXK2lJvV
 optimum/rbln/utils/runtime_utils.py,sha256=5-DYniyP59nx-mrrbi7AqA77L85b4Cm5oLpaxidSyss,3699
 optimum/rbln/utils/save_utils.py,sha256=hG5uOtYmecSXZuGTvCXsTM-SiyZpr5q3InUGCCq_jzQ,3619
 optimum/rbln/utils/submodule.py,sha256=oZoGrItB8WqY4i-K9WJPlLlcLohc1YGB9OHB8_XZw3A,4071
-optimum_rbln-0.7.3.post1.dist-info/METADATA,sha256=dKER74SsqGQwVQgTXVM854y97xzhfRl5LKaGedd4IIw,5304
-optimum_rbln-0.7.3.post1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-optimum_rbln-0.7.3.post1.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
-optimum_rbln-0.7.3.post1.dist-info/RECORD,,
+optimum_rbln-0.7.3.post2.dist-info/METADATA,sha256=YgOp5SEpJ_VfYEohAoBhSQ20TaX1usvkRAzV7s7mS5I,5304
+optimum_rbln-0.7.3.post2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+optimum_rbln-0.7.3.post2.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
+optimum_rbln-0.7.3.post2.dist-info/RECORD,,