optimum-rbln 0.8.2a4__py3-none-any.whl → 0.8.2a5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of optimum-rbln might be problematic.

Files changed (57)
  1. optimum/rbln/__init__.py +36 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +4 -0
  4. optimum/rbln/ops/kv_cache_update.py +5 -0
  5. optimum/rbln/ops/linear.py +7 -0
  6. optimum/rbln/transformers/__init__.py +40 -0
  7. optimum/rbln/transformers/modeling_attention_utils.py +252 -0
  8. optimum/rbln/transformers/models/__init__.py +31 -14
  9. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -2
  10. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +204 -44
  11. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +124 -208
  12. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +565 -366
  13. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  14. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  15. optimum/rbln/transformers/models/gemma/modeling_gemma.py +13 -1
  16. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +0 -6
  17. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +10 -6
  18. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  19. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  20. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -7
  21. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +16 -1
  22. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +2 -2
  23. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  24. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  25. optimum/rbln/transformers/models/llama/modeling_llama.py +13 -1
  26. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +4 -4
  27. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  28. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  29. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  30. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  31. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  32. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  33. optimum/rbln/transformers/models/opt/modeling_opt.py +41 -1
  34. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  35. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  36. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +34 -0
  37. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +69 -0
  38. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +163 -0
  39. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  40. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  41. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  42. optimum/rbln/transformers/models/phi/phi_architecture.py +6 -6
  43. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  44. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  45. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  46. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +3 -3
  47. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
  48. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +10 -328
  49. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +0 -241
  50. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +0 -10
  51. optimum/rbln/transformers/models/whisper/configuration_whisper.py +1 -10
  52. optimum/rbln/transformers/models/whisper/modeling_whisper.py +5 -1
  53. optimum/rbln/utils/depreacate_utils.py +16 -0
  54. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a5.dist-info}/METADATA +1 -1
  55. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a5.dist-info}/RECORD +57 -51
  56. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a5.dist-info}/WHEEL +0 -0
  57. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a5.dist-info}/licenses/LICENSE +0 -0
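
Most of the churn in this release comes from two refactors visible below: the Qwen3-specific modeling and wrapper code is folded into the shared RBLNDecoderOnlyModel base classes, and the RBLN-CA02-specific attention-mask defaults are dropped from the Seq2Seq and Whisper configs. A quick, standard-library-only way to confirm which wheel is actually installed (the expected string simply mirrors the version bump in this diff):

```python
# Sanity check that the upgraded wheel is the one importable in the current environment.
from importlib.metadata import version

print(version("optimum-rbln"))  # expect "0.8.2a5" once the new wheel is installed
```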
optimum/rbln/transformers/models/qwen3/modeling_qwen3.py
@@ -12,32 +12,23 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from pathlib import Path
- from typing import TYPE_CHECKING, List, Optional, Union
+ from typing import TYPE_CHECKING

- import rebel
- import torch
- from rebel.compile_context import CompileContext
- from transformers import PretrainedConfig, PreTrainedModel
- from transformers.modeling_outputs import BaseModelOutputWithPast
- from transformers.modeling_utils import no_init_weights
+ from transformers import PretrainedConfig

- from ....configuration_utils import RBLNCompileConfig
- from ....modeling import RBLNModel
  from ....utils import logging
- from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM, RBLNDecoderOnlyModelForCausalLMConfig
- from ..decoderonly.modeling_decoderonly import set_default_values, validate_attention_method
- from .configuration_qwen3 import RBLNQwen3ModelConfig
- from .qwen3_architecture import Qwen3ModelWrapper, Qwen3Wrapper
+ from ...models.decoderonly import (
+     RBLNDecoderOnlyModel,
+     RBLNDecoderOnlyModelForCausalLM,
+     RBLNDecoderOnlyModelForCausalLMConfig,
+ )
+ from .qwen3_architecture import Qwen3Wrapper


  logger = logging.get_logger(__name__)

  if TYPE_CHECKING:
      from transformers import (
-         AutoFeatureExtractor,
-         AutoProcessor,
-         AutoTokenizer,
          PretrainedConfig,
      )

@@ -63,315 +54,6 @@ class RBLNQwen3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
          return super().forward(*args, **kwargs)


- class RBLNQwen3Model(RBLNModel):
-     _decoder_wrapper_cls = Qwen3ModelWrapper
+ class RBLNQwen3Model(RBLNDecoderOnlyModel):
+     _decoder_wrapper_cls = Qwen3Wrapper
      _use_rotary_emb = True
-
-     def __post_init__(self, **kwargs):
-         artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
-         self.embed_tokens = self._create_embedding_layer()
-         self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
-         self.block_tables = torch.arange(
-             self.rbln_config.max_seq_len / self.rbln_config.kvcache_block_size, dtype=torch.int16
-         )
-         self.causal_mask = 1 - torch.triu(
-             torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
-         )
-
-     @classmethod
-     def save_torch_artifacts(
-         cls,
-         model: PreTrainedModel,
-         save_dir_path: Path,
-         subfolder: str,
-         rbln_config: RBLNQwen3ModelConfig,
-     ):
-         save_dict = {}
-         save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
-         torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
-
-     def _create_embedding_layer(self):
-         with no_init_weights():
-             embed_tokens = torch.nn.Embedding(
-                 self.config.vocab_size,
-                 self.config.hidden_size,
-                 self.config.pad_token_id,
-             )
-         return embed_tokens
-
-     def get_input_embeddings(self):
-         return self.embed_tokens
-
-     @classmethod
-     def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNQwen3ModelConfig"):
-         wrapper_cfg = {
-             "max_seq_len": rbln_config.max_seq_len,
-             "attn_impl": rbln_config.attn_impl,
-             "kvcache_partition_len": rbln_config.kvcache_partition_len,
-             "kvcache_block_size": rbln_config.kvcache_block_size,
-             "use_rotary_emb": cls._use_rotary_emb,
-             "use_attention_mask": rbln_config.use_attention_mask,
-             "cache_impl": rbln_config.cache_impl,
-             "sliding_window": rbln_config.sliding_window,
-             "sliding_window_layers": rbln_config.sliding_window_layers,
-         }
-         return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
-
-     @classmethod
-     @torch.inference_mode()
-     def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNQwen3ModelConfig):
-         wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
-
-         rbln_compile_configs = rbln_config.compile_cfgs
-         prefill_compile_config = rbln_compile_configs[0]
-
-         context = CompileContext(use_weight_sharing=False)
-
-         meta_tensor_names = [name for name, _, _ in prefill_compile_config.input_info if "past_key_values" in name]
-         prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
-
-         static_tensors = {}
-         for (name, _, _), tensor in zip(prefill_compile_config.input_info, prefill_example_inputs):
-             if "past_key_values" in name:
-                 static_tensors[name] = tensor
-                 context.mark_static_address(tensor)
-
-         def compile_model(wrapped_model, compile_config, example_inputs, compile_context):
-             try:
-                 original_linear = torch.nn.functional.linear
-                 torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
-                 compiled_model = RBLNModel.compile(
-                     wrapped_model,
-                     compile_config,
-                     example_inputs=example_inputs,
-                     compile_context=compile_context,
-                     create_runtimes=rbln_config.create_runtimes,
-                     device=rbln_config.device,
-                 )
-                 return compiled_model
-             finally:
-                 torch.nn.functional.linear = original_linear
-
-         wrapped_model.phase = "prefill"
-         compiled_prefill = compile_model(wrapped_model, prefill_compile_config, prefill_example_inputs, context)
-
-         compiled_models = {"prefill": compiled_prefill}
-         return compiled_models
-
-     @classmethod
-     def get_input_info(
-         cls,
-         batch_size: int,
-         query_length: int,
-         rbln_config: RBLNQwen3ModelConfig,
-         model_config: PretrainedConfig,
-     ):
-         input_info = RBLNDecoderOnlyModelForCausalLM.get_input_info(
-             batch_size,
-             query_length,
-             rbln_config=rbln_config,
-             model_config=model_config,
-         )
-
-         if rbln_config.sliding_window is None:
-             # remove query position
-             input_info.pop(3)
-
-         return input_info
-
-     @classmethod
-     def _update_rbln_config(
-         cls,
-         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
-         model: Optional["PreTrainedModel"] = None,
-         model_config: Optional["PretrainedConfig"] = None,
-         rbln_config: Optional[RBLNQwen3ModelConfig] = None,
-     ) -> RBLNQwen3ModelConfig:
-         if rbln_config.max_seq_len is None:
-             rbln_config.max_seq_len = getattr(model_config, "max_position_embeddings", None) or getattr(
-                 model_config, "n_positions", None
-             )
-         if rbln_config.max_seq_len is None:
-             raise ValueError("`max_seq_len` should be specified.")
-
-         rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
-             attn_impl=rbln_config.attn_impl,
-             kvcache_partition_len=rbln_config.kvcache_partition_len,
-             kvcache_block_size=rbln_config.kvcache_block_size,
-             max_seq_len=rbln_config.max_seq_len,
-         )
-
-         validate_attention_method(
-             attn_impl=rbln_config.attn_impl,
-             kvcache_partition_len=rbln_config.kvcache_partition_len,
-             kvcache_block_size=rbln_config.kvcache_block_size,
-             max_seq_len=rbln_config.max_seq_len,
-         )
-
-         # only compile prefill cb -> always batch_size 1
-         required_num_blocks = rbln_config.max_seq_len // rbln_config.kvcache_block_size
-         max_num_blocks = required_num_blocks
-
-         if rbln_config.attn_impl == "flash_attn":
-             estimated_max_num_blocks = RBLNDecoderOnlyModelForCausalLM.get_maximum_num_blocks(
-                 config=model_config,
-                 tensor_parallel_size=rbln_config.tensor_parallel_size or 1,
-                 kvcache_block_size=rbln_config.kvcache_block_size,
-                 nbits_per_param=16 if not rbln_config.quantization else 4,
-                 n_model_params=sum(p.numel() for p in model.parameters()),
-                 num_runtimes=1 + len(rbln_config.decoder_batch_sizes),
-             )
-
-             max_num_blocks = min(max_num_blocks, estimated_max_num_blocks)
-
-             flash_min_blocks = rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1
-             if max_num_blocks < flash_min_blocks:
-                 max_num_blocks = flash_min_blocks
-
-         if rbln_config.kvcache_num_blocks is None:
-             rbln_config.kvcache_num_blocks = max_num_blocks
-
-         prefill_input_info = cls.get_input_info(
-             batch_size=1,
-             query_length=rbln_config.prefill_chunk_size,
-             rbln_config=rbln_config,
-             model_config=model_config,
-         )
-
-         prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
-         rbln_config.set_compile_cfgs([prefill_compile_config])
-
-         return rbln_config
-
-     @classmethod
-     def _create_runtimes(
-         cls,
-         compiled_models: List[rebel.RBLNCompiledModel],
-         rbln_config: RBLNQwen3ModelConfig,
-     ) -> List[rebel.Runtime]:
-         expected_model_names = ["prefill"]
-         if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
-             cls._raise_missing_compiled_file_error(expected_model_names)
-
-         return [
-             rebel.Runtime(
-                 compiled_models[0],
-                 tensor_type="pt",
-                 device=rbln_config.device_map["prefill"],
-                 activate_profiler=rbln_config.activate_profiler,
-             ),
-         ]
-
-     def _preprocess_chunked_prefill(
-         self,
-         inputs: torch.Tensor,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_embed: Optional[torch.Tensor] = None,
-     ):
-         # valid sequence length of inputs_embeds
-         query_length = inputs.shape[1] if attention_mask is None else torch.sum(attention_mask.view(-1)).item()
-
-         # extract valid inputs
-         inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
-         if position_embed is not None:
-             position_embed = (
-                 position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
-             )
-
-         if self.rbln_config.use_attention_mask:
-             chunked_attention_mask = (
-                 torch.zeros(
-                     1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32
-                 )
-                 if self.rbln_config.use_attention_mask
-                 else None
-             )
-         else:
-             chunked_attention_mask = None
-
-         # padding for chunked prefill
-         padding_size = (
-             self.rbln_config.prefill_chunk_size - (query_length % self.rbln_config.prefill_chunk_size)
-         ) % self.rbln_config.prefill_chunk_size
-         padded_len = query_length + padding_size
-
-         inputs = torch.nn.functional.pad(inputs, (0, padding_size))
-         position_embed = (
-             None if position_embed is None else torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
-         )
-         cache_position = torch.arange(padded_len, dtype=torch.int32).unsqueeze(0)
-
-         return inputs, chunked_attention_mask, position_embed, cache_position, query_length
-
-     def _chunked_prefill_forward(
-         self,
-         inputs_embeds: torch.Tensor,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_embed: Optional[torch.Tensor] = None,
-     ):
-         padded_input, chunked_attention_mask, padded_position_embed, cache_position, query_length = (
-             self._preprocess_chunked_prefill(inputs_embeds, attention_mask, position_embed)
-         )
-
-         # chunked prefill
-         last_hidden_states = []
-         for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
-             # Extract the current chunk of inputs and cache positions
-             input_chunk = padded_input[:, step : step + self.rbln_config.prefill_chunk_size]
-             cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
-
-             model_args = {
-                 "input_ids": input_chunk,
-                 "cache_position": cache_pos_chunk,
-                 "block_tables": self.block_tables,
-             }
-
-             if chunked_attention_mask is not None:
-                 if step >= self.rbln_config.prefill_chunk_size:
-                     chunked_attention_mask[:, :, :, step - self.rbln_config.prefill_chunk_size : step] = 1
-                 chunked_attention_mask[:, :, :, step : step + self.rbln_config.prefill_chunk_size] = self.causal_mask
-                 model_args["attention_mask"] = chunked_attention_mask
-
-             last_hidden_states_chunk = self.model[0](**model_args)
-             last_hidden_states.append(last_hidden_states_chunk)
-
-         last_hidden_states = torch.concat(last_hidden_states, dim=-2)[:, :query_length]
-
-         return self._postprocess_chunked_prefill(last_hidden_states, attention_mask)
-
-     def _postprocess_chunked_prefill(
-         self, last_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
-     ):
-         # index copy for attention mask
-         if attention_mask is not None:
-             new_last_hidden_states = torch.full(
-                 (1, attention_mask.shape[-1], last_hidden_states.shape[-1]),
-                 fill_value=1e-10,
-                 dtype=last_hidden_states.dtype,
-             )
-             mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]
-             new_last_hidden_states.index_copy_(dim=-2, index=mask_indices, source=last_hidden_states)
-         else:
-             new_last_hidden_states = last_hidden_states
-         return new_last_hidden_states
-
-     def forward(
-         self,
-         input_ids: Optional[torch.LongTensor] = None,
-         inputs_embeds: Optional[torch.Tensor] = None,
-         attention_mask: Optional[torch.LongTensor] = None,
-         position_embed: Optional[torch.Tensor] = None,
-         **kwargs,
-     ):
-         inputs = inputs_embeds if inputs_embeds is not None else input_ids
-         batch_size = inputs.shape[0]
-         all_last_hidden_states = []
-         for b_idx in range(batch_size):
-             last_hidden_states = self._chunked_prefill_forward(
-                 inputs[b_idx : b_idx + 1],
-                 attention_mask[b_idx] if attention_mask is not None else None,
-                 position_embed[b_idx : b_idx + 1] if position_embed is not None else None,
-             )
-             all_last_hidden_states.append(last_hidden_states)
-
-         return BaseModelOutputWithPast(last_hidden_state=torch.concat(all_last_hidden_states, dim=0))
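
The class removed above implemented single-request chunked prefill by hand: the prompt is right-padded to a multiple of prefill_chunk_size and the compiled prefill graph is run once per chunk. A standalone sketch of just that padding arithmetic, with made-up sizes rather than library defaults:

```python
import torch

# Pad the prompt so its length becomes a multiple of the prefill chunk size,
# mirroring the arithmetic of the removed _preprocess_chunked_prefill.
prefill_chunk_size = 128   # assumed chunk size, not an optimum-rbln default
query_length = 300         # assumed number of valid prompt tokens

padding_size = (prefill_chunk_size - (query_length % prefill_chunk_size)) % prefill_chunk_size
padded_len = query_length + padding_size  # 384 for these numbers

inputs = torch.randint(0, 32000, (1, query_length))
padded_inputs = torch.nn.functional.pad(inputs, (0, padding_size))  # right-pad the last dim
cache_position = torch.arange(padded_len, dtype=torch.int32).unsqueeze(0)

# The compiled prefill graph would then be invoked once per fixed-size chunk.
for step in range(0, padded_len, prefill_chunk_size):
    chunk = padded_inputs[:, step : step + prefill_chunk_size]
    assert chunk.shape[1] == prefill_chunk_size
```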
optimum/rbln/transformers/models/qwen3/qwen3_architecture.py
@@ -12,15 +12,10 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- import torch
- import torch.nn as nn
- from transformers import PreTrainedModel

  from ..decoderonly.decoderonly_architecture import (
      DecoderOnlyAttention,
-     DecoderOnlyLayer,
      DecoderOnlyWrapper,
-     RotaryEmbedding,
  )


@@ -37,239 +32,3 @@ class Qwen3Attention(DecoderOnlyAttention):
          self.o_proj = self._original_mod.o_proj
          self.q_norm = self._original_mod.q_norm
          self.k_norm = self._original_mod.k_norm
-
-
- class Qwen3ModelWrapper(nn.Module):
-     def __init__(
-         self,
-         model,
-         attn_impl=None,
-         use_inputs_embeds=None,
-         use_attention_mask=None,
-         use_rotary_emb=None,
-         cache_impl=None,
-         kvcache_partition_len=None,
-         max_seq_len=None,
-         kvcache_block_size=None,
-         sliding_window=None,
-         sliding_window_layers=None,
-     ):
-         super().__init__()
-         self.config = model.config
-
-         if use_rotary_emb:
-             rotary_embs = self.get_rotary_emb(max_seq_len=max_seq_len)
-             if isinstance(rotary_embs, tuple):
-                 self.rotary_emb_global, self.rotary_emb_local = rotary_embs
-             else:
-                 self.rotary_emb = rotary_embs
-         else:
-             self.rotary_emb = None
-
-         self._original_mod = model
-         self.use_inputs_embeds = use_inputs_embeds
-         self.attn_impl = attn_impl
-         self.cache_impl = cache_impl
-         self.use_attention_mask = use_attention_mask
-         self.kvcache_partition_len = kvcache_partition_len
-         self.kvcache_block_size = kvcache_block_size
-         self.max_seq_len = max_seq_len
-         self.sliding_window = sliding_window
-         self.sliding_window_layers = sliding_window_layers
-         self.model = self.convert_to_rbln_model(model)
-
-     def get_rotary_emb(self, max_seq_len):
-         return RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
-
-     def convert_to_rbln_model(self, base_model: PreTrainedModel):
-         for layer_idx, layer in enumerate(base_model.layers):
-             is_sliding = layer_idx in self.sliding_window_layers
-             new_self_attn = Qwen3Attention(
-                 layer.self_attn,
-                 self.use_attention_mask if not is_sliding else True,
-                 use_position_ids=None,
-                 kvcache_block_size=self.sliding_window
-                 if layer_idx in self.sliding_window_layers
-                 else self.kvcache_block_size,
-                 is_sliding=is_sliding,
-                 attn_impl=self.attn_impl if not is_sliding else "eager",
-                 kvcache_partition_len=self.kvcache_partition_len,
-             )
-             base_model.layers[layer_idx] = DecoderOnlyLayer(layer, new_self_attn)
-
-         return base_model
-
-     @property
-     def hidden_multiplier(self):
-         return 1
-
-     def get_last_layernorm(self) -> nn.LayerNorm:
-         return self._original_mod.norm
-
-     def get_embedding(self) -> nn.Embedding:
-         return self._original_mod.embed_tokens
-
-     def get_pos_embedding(self) -> nn.Embedding:
-         raise NotImplementedError(
-             "The 'get_pos_embedding' method is not implemented. Please define this method in a subclass."
-         )
-
-     def convert_sequence_positions_for_flash_attn(self, seq_positions, max_seq_len):
-         if self.attn_impl not in ["flash_attn"]:
-             raise NotImplementedError(f"Unknown attn_impl ({self.attn_impl}).")
-         partition_len = self.kvcache_partition_len
-         num_partition = max_seq_len // partition_len
-
-         cs = seq_positions.repeat(num_partition, 1).transpose(0, 1)
-         pidx = torch.arange(num_partition)
-         cache_pos_for_partitions = torch.clamp(cs - pidx * partition_len, 0, partition_len)
-         return cache_pos_for_partitions
-
-     def get_local_cache_positions(self, position_ids, query_position):
-         max_cache_len = self.model.config.sliding_window
-         valid_input_len = 1 if query_position is None else query_position + 1
-         cache_seq_len = torch.clamp(position_ids, max=max_cache_len)[:, :1]  # past seen tokens
-         cache_offset = (
-             torch.clamp(position_ids, max=max_cache_len)[:, :1] + valid_input_len
-         )  # cache offset for next steps
-
-         return cache_seq_len, cache_offset
-
-     def prepare_forward_args(self, *args):
-         args = list(args)
-         input_ids = None if self.use_inputs_embeds else args.pop(0)
-         inputs_embeds = args.pop(0) if self.use_inputs_embeds else None
-         cache_position = args.pop(0)
-         global_block_tables = args.pop(0) if self.cache_impl in ["hybrid", "static"] else None
-         local_block_tables = args.pop(0) if self.cache_impl in ["hybrid", "sliding_window"] else None
-         query_position = args.pop(0) if self.sliding_window else None
-         attention_mask = args.pop(0) if self.use_attention_mask else None
-         position_ids = None
-         past_key_values = args
-
-         if len(past_key_values) != 2 * self.config.num_hidden_layers:
-             raise ValueError(
-                 f"Different past_key_values to model's config. {len(past_key_values)} != {2 * self.config.num_hidden_layers}"
-             )
-
-         # [key, value] * n_layer -> ( (key, value) ) * n_layer
-         # cache shape : batch, n_heads, 1, max_seq_len, head_dim
-         _past_key_values = []
-         for i in range(self.config.num_hidden_layers):
-             key_states = past_key_values[i * 2]
-             value_states = past_key_values[i * 2 + 1]
-             past_key_value = [key_states, value_states]
-             _past_key_values.append(past_key_value)
-         past_key_values = _past_key_values
-
-         if hasattr(self, "rotary_emb_global") and hasattr(self, "rotary_emb_local"):
-             rotary_emb = (self.rotary_emb_global, self.rotary_emb_local)
-         else:
-             rotary_emb = self.rotary_emb
-
-         return (
-             input_ids,
-             inputs_embeds,
-             cache_position,
-             global_block_tables,
-             local_block_tables,
-             attention_mask,
-             position_ids,
-             query_position,
-             past_key_values,
-             rotary_emb,
-         )
-
-     def forward(self, *args):
-         (
-             input_ids,
-             inputs_embeds,
-             cache_position,
-             global_block_tables,
-             local_block_tables,
-             attention_mask,
-             position_ids,
-             query_position,
-             past_key_values,
-             rotary_emb,
-         ) = self.prepare_forward_args(*args)
-
-         # retrieve input_ids and inputs_embeds
-         if (input_ids is None) ^ (inputs_embeds is not None):
-             raise ValueError(
-                 "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
-             )
-
-         # embed positions
-         if inputs_embeds is None:
-             inputs_embeds = self.get_embedding()(input_ids)
-
-         hidden_states = inputs_embeds * self.hidden_multiplier
-
-         # get cos,sin vector if needed
-         position_ids = position_ids if position_ids is not None else cache_position
-         if rotary_emb is not None:
-             if isinstance(rotary_emb, torch.Tensor):
-                 cos = rotary_emb[0]
-                 sin = rotary_emb[1]
-             else:
-                 cos, sin = rotary_emb(hidden_states, self.max_seq_len)  # dtype carrier, max_seq_len
-                 cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
-         else:
-             batch_size = inputs_embeds.shape[0]
-             if position_ids.shape[0] > 1:
-                 position_embeds = []
-                 for b_idx in range(batch_size):
-                     position_embed = self.get_pos_embedding()(position_ids[b_idx])
-                     position_embeds.append(position_embed)
-
-                 position_embeds = torch.cat(position_embeds, dim=0).unsqueeze(1)
-             else:
-                 position_embeds = self.get_pos_embedding()(position_ids)
-             hidden_states = hidden_states + position_embeds
-             cos, sin = None, None
-
-         # Get sequence positions for flash attention
-         if self.attn_impl == "flash_attn":
-             seq_positions = cache_position[:, 0]
-             seq_positions = self.convert_sequence_positions_for_flash_attn(
-                 seq_positions=seq_positions, max_seq_len=self.max_seq_len
-             )
-         else:
-             seq_positions = cache_position[:, :1]
-
-         # Get local cache positions for sliding window layers
-         if len(self.sliding_window_layers) > 0:
-             sliding_cache_pos = self.get_local_cache_positions(position_ids, query_position)
-
-         for layer_idx, layer in enumerate(self.model.layers):
-             is_sliding = True if layer_idx in self.sliding_window_layers else False
-             hidden_states = layer(
-                 hidden_states=hidden_states,
-                 attention_mask=attention_mask,
-                 seq_positions=sliding_cache_pos if is_sliding else seq_positions,
-                 past_key_values=past_key_values,
-                 cos=cos,
-                 sin=sin,
-                 block_tables=local_block_tables if is_sliding else global_block_tables,
-             )
-
-         hidden_states = self.get_last_layernorm()(hidden_states)
-         return hidden_states
-
-
- def slice_and_unsqueeze_cos_sin(cos, sin, cache_position, unsqueeze_dim=1):
-     """Slice cos[cache_position], sin[cache_position] vector for the query."""
-     if cache_position.shape[0] > 1:
-         cos_all = []
-         sin_all = []
-         for i in range(cache_position.shape[0]):
-             cos_all.append(cos[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
-             sin_all.append(sin[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
-         cos = torch.cat(cos_all, dim=0)
-         sin = torch.cat(sin_all, dim=0)
-     else:
-         cos = cos[cache_position].unsqueeze(unsqueeze_dim)
-         sin = sin[cache_position].unsqueeze(unsqueeze_dim)
-
-     return cos, sin
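
Among the helpers removed above is slice_and_unsqueeze_cos_sin, which gathers the precomputed rotary cos/sin rows for the current cache positions and inserts a broadcast dimension for the attention heads. A self-contained illustration with toy shapes (the sizes are arbitrary, not the model's):

```python
import torch

# Toy-shaped version of the cos/sin slicing done by the removed helper.
max_seq_len, head_dim = 16, 8
cos = torch.randn(max_seq_len, head_dim)    # precomputed rotary table
sin = torch.randn(max_seq_len, head_dim)
cache_position = torch.tensor([[3, 4, 5]])  # batch of 1, three query positions

cos_q = cos[cache_position].unsqueeze(1)    # (1, 1, 3, head_dim): broadcasts over heads
sin_q = sin[cache_position].unsqueeze(1)
print(cos_q.shape, sin_q.shape)
```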
optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py
@@ -14,8 +14,6 @@

  from typing import Any, Dict, Optional

- import rebel
-
  from ....configuration_utils import RBLNModelConfig
  from ....utils.logging import get_logger

@@ -39,7 +37,6 @@ class RBLNModelForSeq2SeqLMConfig(RBLNModelConfig):
          enc_max_seq_len (Optional[int]): Maximum sequence length for the encoder.
          dec_max_seq_len (Optional[int]): Maximum sequence length for the decoder.
          use_attention_mask (Optional[bool]): Whether to use attention masks during inference.
-             This is automatically set to True for RBLN-CA02 devices.
          pad_token_id (Optional[int]): The ID of the padding token in the vocabulary.
          **kwargs: Additional arguments passed to the parent RBLNModelConfig.

@@ -55,12 +52,5 @@
          self.dec_max_seq_len = dec_max_seq_len

          self.use_attention_mask = use_attention_mask
-         npu = self.npu or rebel.get_npu_name()
-         if npu == "RBLN-CA02":
-             if self.use_attention_mask is False:
-                 logger.warning("Attention mask should be used with RBLN-CA02. Setting use_attention_mask to True.")
-             self.use_attention_mask = True
-         else:
-             self.use_attention_mask = self.use_attention_mask or False

          self.pad_token_id = pad_token_id
optimum/rbln/transformers/models/whisper/configuration_whisper.py
@@ -14,8 +14,6 @@

  from typing import Any, Dict

- import rebel
-
  from ....configuration_utils import RBLNModelConfig
  from ....utils.logging import get_logger

@@ -45,7 +43,6 @@ class RBLNWhisperForConditionalGenerationConfig(RBLNModelConfig):
          batch_size (int, optional): The batch size for inference. Defaults to 1.
          token_timestamps (bool, optional): Whether to output token timestamps during generation. Defaults to False.
          use_attention_mask (bool, optional): Whether to use attention masks during inference. This is automatically
-             set to True for RBLN-CA02 devices.
          enc_max_seq_len (int, optional): Maximum sequence length for the encoder.
          dec_max_seq_len (int, optional): Maximum sequence length for the decoder.
          **kwargs: Additional arguments passed to the parent RBLNModelConfig.
@@ -64,10 +61,4 @@
          self.dec_max_seq_len = dec_max_seq_len

          self.use_attention_mask = use_attention_mask
-         npu = self.npu or rebel.get_npu_name()
-         if npu == "RBLN-CA02":
-             if self.use_attention_mask is False:
-                 logger.warning("Attention mask should be used with RBLN-CA02. Setting use_attention_mask to True.")
-             self.use_attention_mask = True
-         else:
-             self.use_attention_mask = self.use_attention_mask or False
+         self.use_attention_mask = self.use_attention_mask or False
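
With the rebel.get_npu_name() check gone from both config classes, use_attention_mask now defaults to False unless the caller sets it explicitly. A hedged sketch of opting in by hand; the keyword comes from the constructor shown in this diff, while the top-level re-export path is an assumption:

```python
# Hypothetical opt-in, assuming the config class is re-exported at package level.
from optimum.rbln import RBLNWhisperForConditionalGenerationConfig

cfg = RBLNWhisperForConditionalGenerationConfig(
    batch_size=1,
    use_attention_mask=True,  # no longer forced on automatically for RBLN-CA02 devices
)
```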
optimum/rbln/transformers/models/whisper/modeling_whisper.py
@@ -73,6 +73,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
          decoder_input_ids: torch.Tensor = None,
          decoder_attention_mask: torch.Tensor = None,
          cache_position: torch.Tensor = None,
+         block_tables: torch.Tensor = None,
      ):
          inputs_bsz = decoder_input_ids.shape[0]
          padded_bsz = self.batch_size - inputs_bsz
@@ -89,11 +90,14 @@
                      )
                  decoder_attention_mask[b_idx, : decoding_step + 1] = 1

+         if block_tables is None:
+             block_tables = self.default_block_tables
+
          outputs = super().forward(
              decoder_input_ids,
              decoder_attention_mask if self.use_attention_mask else None,
              cache_position,
-             block_tables=self.default_block_tables,
+             block_tables=block_tables,
          )

          if isinstance(outputs, torch.Tensor):
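
The decoder runtime now accepts an explicit block_tables argument and only falls back to self.default_block_tables when none is given. For context, the Qwen3 code removed earlier in this diff built its static table as a plain torch.arange over max_seq_len // kvcache_block_size entries; a generic sketch of that fallback pattern with illustrative sizes:

```python
import torch

# Generic "use caller-provided block tables, otherwise a default" pattern.
max_seq_len, kvcache_block_size = 4096, 1024
default_block_tables = torch.arange(max_seq_len // kvcache_block_size, dtype=torch.int16)

def resolve_block_tables(block_tables=None):
    return default_block_tables if block_tables is None else block_tables

print(resolve_block_tables())  # tensor([0, 1, 2, 3], dtype=torch.int16)
print(resolve_block_tables(torch.tensor([2, 3], dtype=torch.int16)))
```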