optimum-rbln 0.8.2a1__py3-none-any.whl → 0.8.2a3__py3-none-any.whl

This diff covers the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as published in those registries.

Files changed (34)
  1. optimum/rbln/__init__.py +8 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +16 -1
  4. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +3 -0
  5. optimum/rbln/diffusers/modeling_diffusers.py +3 -4
  6. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +1 -0
  7. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1 -0
  8. optimum/rbln/diffusers/models/autoencoders/vq_model.py +1 -0
  9. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +1 -1
  10. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +10 -2
  11. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +4 -30
  12. optimum/rbln/modeling.py +2 -3
  13. optimum/rbln/modeling_base.py +17 -13
  14. optimum/rbln/transformers/__init__.py +8 -0
  15. optimum/rbln/transformers/models/__init__.py +2 -0
  16. optimum/rbln/transformers/models/clip/configuration_clip.py +12 -1
  17. optimum/rbln/transformers/models/clip/modeling_clip.py +123 -28
  18. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +13 -1
  19. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +2 -3
  20. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +107 -249
  21. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +18 -1
  22. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  23. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  24. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +377 -0
  25. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +275 -0
  26. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +2 -0
  27. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +2 -0
  28. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -0
  29. optimum/rbln/utils/hub.py +8 -47
  30. optimum/rbln/utils/runtime_utils.py +28 -2
  31. {optimum_rbln-0.8.2a1.dist-info → optimum_rbln-0.8.2a3.dist-info}/METADATA +1 -1
  32. {optimum_rbln-0.8.2a1.dist-info → optimum_rbln-0.8.2a3.dist-info}/RECORD +34 -30
  33. {optimum_rbln-0.8.2a1.dist-info → optimum_rbln-0.8.2a3.dist-info}/WHEEL +0 -0
  34. {optimum_rbln-0.8.2a1.dist-info → optimum_rbln-0.8.2a3.dist-info}/licenses/LICENSE +0 -0

optimum/rbln/transformers/models/qwen3/modeling_qwen3.py (new file)
@@ -0,0 +1,377 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import Path
+from typing import TYPE_CHECKING, List, Optional, Union
+
+import rebel
+import torch
+from rebel.compile_context import CompileContext
+from transformers import PretrainedConfig, PreTrainedModel
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.modeling_utils import no_init_weights
+
+from ....configuration_utils import RBLNCompileConfig
+from ....modeling import RBLNModel
+from ....utils import logging
+from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM, RBLNDecoderOnlyModelForCausalLMConfig
+from ..decoderonly.modeling_decoderonly import set_default_values, validate_attention_method
+from .configuration_qwen3 import RBLNQwen3ModelConfig
+from .qwen3_architecture import Qwen3ModelWrapper, Qwen3Wrapper
+
+
+logger = logging.get_logger(__name__)
+
+if TYPE_CHECKING:
+    from transformers import (
+        AutoFeatureExtractor,
+        AutoProcessor,
+        AutoTokenizer,
+        PretrainedConfig,
+    )
+
+
+class RBLNQwen3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
+    _decoder_wrapper_cls = Qwen3Wrapper
+
+    @classmethod
+    def _update_sliding_window_config(
+        cls, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
+    ):
+        # https://github.com/huggingface/transformers/issues/35896
+        # There seems to be a bug in transformers(v4.52.4). Therefore, similar to when attn_implementation is eager,
+        # we set all layers to use sliding window in this version. This should be updated once the bug is fixed.
+
+        rbln_config.cache_impl = "sliding_window"
+        rbln_config.sliding_window = model_config.sliding_window
+        rbln_config.sliding_window_layers = list(range(model_config.num_hidden_layers))
+        return rbln_config
+
+    def forward(self, *args, **kwargs):
+        kwargs["return_dict"] = True
+        return super().forward(*args, **kwargs)
+
+
+class RBLNQwen3Model(RBLNModel):
+    _decoder_wrapper_cls = Qwen3ModelWrapper
+    _use_rotary_emb = True
+
+    def __post_init__(self, **kwargs):
+        artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
+        self.embed_tokens = self._create_embedding_layer()
+        self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
+        self.block_tables = torch.arange(
+            self.rbln_config.max_seq_len / self.rbln_config.kvcache_block_size, dtype=torch.int16
+        )
+        self.causal_mask = 1 - torch.triu(
+            torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
+        )
+
+    @classmethod
+    def save_torch_artifacts(
+        cls,
+        model: PreTrainedModel,
+        save_dir_path: Path,
+        subfolder: str,
+        rbln_config: RBLNQwen3ModelConfig,
+    ):
+        save_dict = {}
+        save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
+        torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
+
+    def _create_embedding_layer(self):
+        with no_init_weights():
+            embed_tokens = torch.nn.Embedding(
+                self.config.vocab_size,
+                self.config.hidden_size,
+                self.config.pad_token_id,
+            )
+        return embed_tokens
+
+    def get_input_embeddings(self):
+        return self.embed_tokens
+
+    @classmethod
+    def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNQwen3ModelConfig"):
+        wrapper_cfg = {
+            "max_seq_len": rbln_config.max_seq_len,
+            "attn_impl": rbln_config.attn_impl,
+            "kvcache_partition_len": rbln_config.kvcache_partition_len,
+            "kvcache_block_size": rbln_config.kvcache_block_size,
+            "use_rotary_emb": cls._use_rotary_emb,
+            "use_attention_mask": rbln_config.use_attention_mask,
+            "cache_impl": rbln_config.cache_impl,
+            "sliding_window": rbln_config.sliding_window,
+            "sliding_window_layers": rbln_config.sliding_window_layers,
+        }
+        return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
+
+    @classmethod
+    @torch.inference_mode()
+    def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNQwen3ModelConfig):
+        wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
+
+        rbln_compile_configs = rbln_config.compile_cfgs
+        prefill_compile_config = rbln_compile_configs[0]
+
+        context = CompileContext(use_weight_sharing=False)
+
+        meta_tensor_names = [name for name, _, _ in prefill_compile_config.input_info if "past_key_values" in name]
+        prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
+
+        static_tensors = {}
+        for (name, _, _), tensor in zip(prefill_compile_config.input_info, prefill_example_inputs):
+            if "past_key_values" in name:
+                static_tensors[name] = tensor
+                context.mark_static_address(tensor)
+
+        def compile_model(wrapped_model, compile_config, example_inputs, compile_context):
+            try:
+                original_linear = torch.nn.functional.linear
+                torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
+                compiled_model = RBLNModel.compile(
+                    wrapped_model,
+                    compile_config,
+                    example_inputs=example_inputs,
+                    compile_context=compile_context,
+                    create_runtimes=rbln_config.create_runtimes,
+                    device=rbln_config.device,
+                )
+                return compiled_model
+            finally:
+                torch.nn.functional.linear = original_linear
+
+        wrapped_model.phase = "prefill"
+        compiled_prefill = compile_model(wrapped_model, prefill_compile_config, prefill_example_inputs, context)
+
+        compiled_models = {"prefill": compiled_prefill}
+        return compiled_models
+
+    @classmethod
+    def get_input_info(
+        cls,
+        batch_size: int,
+        query_length: int,
+        rbln_config: RBLNQwen3ModelConfig,
+        model_config: PretrainedConfig,
+    ):
+        input_info = RBLNDecoderOnlyModelForCausalLM.get_input_info(
+            batch_size,
+            query_length,
+            rbln_config=rbln_config,
+            model_config=model_config,
+        )
+
+        if rbln_config.sliding_window is None:
+            # remove query position
+            input_info.pop(3)
+
+        return input_info
+
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+        model: Optional["PreTrainedModel"] = None,
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_config: Optional[RBLNQwen3ModelConfig] = None,
+    ) -> RBLNQwen3ModelConfig:
+        if rbln_config.max_seq_len is None:
+            rbln_config.max_seq_len = getattr(model_config, "max_position_embeddings", None) or getattr(
+                model_config, "n_positions", None
+            )
+        if rbln_config.max_seq_len is None:
+            raise ValueError("`max_seq_len` should be specified.")
+
+        rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
+            attn_impl=rbln_config.attn_impl,
+            kvcache_partition_len=rbln_config.kvcache_partition_len,
+            kvcache_block_size=rbln_config.kvcache_block_size,
+            max_seq_len=rbln_config.max_seq_len,
+        )
+
+        validate_attention_method(
+            attn_impl=rbln_config.attn_impl,
+            kvcache_partition_len=rbln_config.kvcache_partition_len,
+            kvcache_block_size=rbln_config.kvcache_block_size,
+            max_seq_len=rbln_config.max_seq_len,
+        )
+
+        # only compile prefill cb -> always batch_size 1
+        required_num_blocks = rbln_config.max_seq_len // rbln_config.kvcache_block_size
+        max_num_blocks = required_num_blocks
+
+        if rbln_config.attn_impl == "flash_attn":
+            estimated_max_num_blocks = RBLNDecoderOnlyModelForCausalLM.get_maximum_num_blocks(
+                config=model_config,
+                tensor_parallel_size=rbln_config.tensor_parallel_size or 1,
+                kvcache_block_size=rbln_config.kvcache_block_size,
+                nbits_per_param=16 if not rbln_config.quantization else 4,
+                n_model_params=sum(p.numel() for p in model.parameters()),
+                num_runtimes=1 + len(rbln_config.decoder_batch_sizes),
+            )
+
+            max_num_blocks = min(max_num_blocks, estimated_max_num_blocks)
+
+            flash_min_blocks = rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1
+            if max_num_blocks < flash_min_blocks:
+                max_num_blocks = flash_min_blocks
+
+        if rbln_config.kvcache_num_blocks is None:
+            rbln_config.kvcache_num_blocks = max_num_blocks
+
+        prefill_input_info = cls.get_input_info(
+            batch_size=1,
+            query_length=rbln_config.prefill_chunk_size,
+            rbln_config=rbln_config,
+            model_config=model_config,
+        )
+
+        prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
+        rbln_config.set_compile_cfgs([prefill_compile_config])
+
+        return rbln_config
+
+    @classmethod
+    def _create_runtimes(
+        cls,
+        compiled_models: List[rebel.RBLNCompiledModel],
+        rbln_config: RBLNQwen3ModelConfig,
+    ) -> List[rebel.Runtime]:
+        expected_model_names = ["prefill"]
+        if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
+            cls._raise_missing_compiled_file_error(expected_model_names)
+
+        return [
+            rebel.Runtime(
+                compiled_models[0],
+                tensor_type="pt",
+                device=rbln_config.device_map["prefill"],
+                activate_profiler=rbln_config.activate_profiler,
+            ),
+        ]
+
+    def _preprocess_chunked_prefill(
+        self,
+        inputs: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_embed: Optional[torch.Tensor] = None,
+    ):
+        # valid sequence length of inputs_embeds
+        query_length = inputs.shape[1] if attention_mask is None else torch.sum(attention_mask.view(-1)).item()
+
+        # extract valid inputs
+        inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
+        if position_embed is not None:
+            position_embed = (
+                position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
+            )
+
+        if self.rbln_config.use_attention_mask:
+            chunked_attention_mask = (
+                torch.zeros(
+                    1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32
+                )
+                if self.rbln_config.use_attention_mask
+                else None
+            )
+        else:
+            chunked_attention_mask = None
+
+        # padding for chunked prefill
+        padding_size = (
+            self.rbln_config.prefill_chunk_size - (query_length % self.rbln_config.prefill_chunk_size)
+        ) % self.rbln_config.prefill_chunk_size
+        padded_len = query_length + padding_size
+
+        inputs = torch.nn.functional.pad(inputs, (0, padding_size))
+        position_embed = (
+            None if position_embed is None else torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
+        )
+        cache_position = torch.arange(padded_len, dtype=torch.int32).unsqueeze(0)
+
+        return inputs, chunked_attention_mask, position_embed, cache_position, query_length
+
+    def _chunked_prefill_forward(
+        self,
+        inputs_embeds: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_embed: Optional[torch.Tensor] = None,
+    ):
+        padded_input, chunked_attention_mask, padded_position_embed, cache_position, query_length = (
+            self._preprocess_chunked_prefill(inputs_embeds, attention_mask, position_embed)
+        )
+
+        # chunked prefill
+        last_hidden_states = []
+        for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
+            # Extract the current chunk of inputs and cache positions
+            input_chunk = padded_input[:, step : step + self.rbln_config.prefill_chunk_size]
+            cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
+
+            model_args = {
+                "input_ids": input_chunk,
+                "cache_position": cache_pos_chunk,
+                "block_tables": self.block_tables,
+            }
+
+            if chunked_attention_mask is not None:
+                if step >= self.rbln_config.prefill_chunk_size:
+                    chunked_attention_mask[:, :, :, step - self.rbln_config.prefill_chunk_size : step] = 1
+                chunked_attention_mask[:, :, :, step : step + self.rbln_config.prefill_chunk_size] = self.causal_mask
+                model_args["attention_mask"] = chunked_attention_mask
+
+            last_hidden_states_chunk = self.model[0](**model_args)
+            last_hidden_states.append(last_hidden_states_chunk)
+
+        last_hidden_states = torch.concat(last_hidden_states, dim=-2)[:, :query_length]
+
+        return self._postprocess_chunked_prefill(last_hidden_states, attention_mask)
+
+    def _postprocess_chunked_prefill(
+        self, last_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
+    ):
+        # index copy for attention mask
+        if attention_mask is not None:
+            new_last_hidden_states = torch.full(
+                (1, attention_mask.shape[-1], last_hidden_states.shape[-1]),
+                fill_value=1e-10,
+                dtype=last_hidden_states.dtype,
+            )
+            mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]
+            new_last_hidden_states.index_copy_(dim=-2, index=mask_indices, source=last_hidden_states)
+        else:
+            new_last_hidden_states = last_hidden_states
+        return new_last_hidden_states
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        position_embed: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        inputs = inputs_embeds if inputs_embeds is not None else input_ids
+        batch_size = inputs.shape[0]
+        all_last_hidden_states = []
+        for b_idx in range(batch_size):
+            last_hidden_states = self._chunked_prefill_forward(
+                inputs[b_idx : b_idx + 1],
+                attention_mask[b_idx] if attention_mask is not None else None,
+                position_embed[b_idx : b_idx + 1] if position_embed is not None else None,
+            )
+            all_last_hidden_states.append(last_hidden_states)
+
+        return BaseModelOutputWithPast(last_hidden_state=torch.concat(all_last_hidden_states, dim=0))
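
The two new classes above plug into the existing optimum-rbln entry points. For orientation, a minimal usage sketch follows, assuming the usual from_pretrained/export flow used by the other decoder-only models; the checkpoint name and the rbln_* keyword arguments are illustrative assumptions, not taken from this diff:

# Hypothetical usage sketch (not part of the diff): compile and run the new
# Qwen3 causal-LM class, following the pattern of the existing RBLN
# decoder-only models in optimum-rbln.
from transformers import AutoTokenizer

from optimum.rbln import RBLNQwen3ForCausalLM

model_id = "Qwen/Qwen3-8B"  # illustrative checkpoint name
model = RBLNQwen3ForCausalLM.from_pretrained(
    model_id,
    export=True,                  # compile the PyTorch checkpoint for RBLN NPUs
    rbln_max_seq_len=8192,        # assumed kwargs, mirroring other decoder-only configs
    rbln_tensor_parallel_size=4,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Hello, Rebellions!", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
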
optimum/rbln/transformers/models/qwen3/qwen3_architecture.py (new file)
@@ -0,0 +1,275 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+from transformers import PreTrainedModel
+
+from ..decoderonly.decoderonly_architecture import (
+    DecoderOnlyAttention,
+    DecoderOnlyLayer,
+    DecoderOnlyWrapper,
+    RotaryEmbedding,
+)
+
+
+class Qwen3Wrapper(DecoderOnlyWrapper):
+    def get_rbln_attn_class(self):
+        return Qwen3Attention
+
+
+class Qwen3Attention(DecoderOnlyAttention):
+    def __post_init__(self):
+        self.k_proj = self._original_mod.k_proj
+        self.v_proj = self._original_mod.v_proj
+        self.q_proj = self._original_mod.q_proj
+        self.o_proj = self._original_mod.o_proj
+        self.q_norm = self._original_mod.q_norm
+        self.k_norm = self._original_mod.k_norm
+
+
+class Qwen3ModelWrapper(nn.Module):
+    def __init__(
+        self,
+        model,
+        attn_impl=None,
+        use_inputs_embeds=None,
+        use_attention_mask=None,
+        use_rotary_emb=None,
+        cache_impl=None,
+        kvcache_partition_len=None,
+        max_seq_len=None,
+        kvcache_block_size=None,
+        sliding_window=None,
+        sliding_window_layers=None,
+    ):
+        super().__init__()
+        self.config = model.config
+
+        if use_rotary_emb:
+            rotary_embs = self.get_rotary_emb(max_seq_len=max_seq_len)
+            if isinstance(rotary_embs, tuple):
+                self.rotary_emb_global, self.rotary_emb_local = rotary_embs
+            else:
+                self.rotary_emb = rotary_embs
+        else:
+            self.rotary_emb = None
+
+        self._original_mod = model
+        self.use_inputs_embeds = use_inputs_embeds
+        self.attn_impl = attn_impl
+        self.cache_impl = cache_impl
+        self.use_attention_mask = use_attention_mask
+        self.kvcache_partition_len = kvcache_partition_len
+        self.kvcache_block_size = kvcache_block_size
+        self.max_seq_len = max_seq_len
+        self.sliding_window = sliding_window
+        self.sliding_window_layers = sliding_window_layers
+        self.model = self.convert_to_rbln_model(model)
+
+    def get_rotary_emb(self, max_seq_len):
+        return RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
+
+    def convert_to_rbln_model(self, base_model: PreTrainedModel):
+        for layer_idx, layer in enumerate(base_model.layers):
+            is_sliding = layer_idx in self.sliding_window_layers
+            new_self_attn = Qwen3Attention(
+                layer.self_attn,
+                self.use_attention_mask if not is_sliding else True,
+                use_position_ids=None,
+                kvcache_block_size=self.sliding_window
+                if layer_idx in self.sliding_window_layers
+                else self.kvcache_block_size,
+                is_sliding=is_sliding,
+                attn_impl=self.attn_impl if not is_sliding else "eager",
+                kvcache_partition_len=self.kvcache_partition_len,
+            )
+            base_model.layers[layer_idx] = DecoderOnlyLayer(layer, new_self_attn)
+
+        return base_model
+
+    @property
+    def hidden_multiplier(self):
+        return 1
+
+    def get_last_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.norm
+
+    def get_embedding(self) -> nn.Embedding:
+        return self._original_mod.embed_tokens
+
+    def get_pos_embedding(self) -> nn.Embedding:
+        raise NotImplementedError(
+            "The 'get_pos_embedding' method is not implemented. Please define this method in a subclass."
+        )
+
+    def convert_sequence_positions_for_flash_attn(self, seq_positions, max_seq_len):
+        if self.attn_impl not in ["flash_attn"]:
+            raise NotImplementedError(f"Unknown attn_impl ({self.attn_impl}).")
+        partition_len = self.kvcache_partition_len
+        num_partition = max_seq_len // partition_len
+
+        cs = seq_positions.repeat(num_partition, 1).transpose(0, 1)
+        pidx = torch.arange(num_partition)
+        cache_pos_for_partitions = torch.clamp(cs - pidx * partition_len, 0, partition_len)
+        return cache_pos_for_partitions
+
+    def get_local_cache_positions(self, position_ids, query_position):
+        max_cache_len = self.model.config.sliding_window
+        valid_input_len = 1 if query_position is None else query_position + 1
+        cache_seq_len = torch.clamp(position_ids, max=max_cache_len)[:, :1]  # past seen tokens
+        cache_offset = (
+            torch.clamp(position_ids, max=max_cache_len)[:, :1] + valid_input_len
+        )  # cache offset for next steps
+
+        return cache_seq_len, cache_offset
+
+    def prepare_forward_args(self, *args):
+        args = list(args)
+        input_ids = None if self.use_inputs_embeds else args.pop(0)
+        inputs_embeds = args.pop(0) if self.use_inputs_embeds else None
+        cache_position = args.pop(0)
+        global_block_tables = args.pop(0) if self.cache_impl in ["hybrid", "static"] else None
+        local_block_tables = args.pop(0) if self.cache_impl in ["hybrid", "sliding_window"] else None
+        query_position = args.pop(0) if self.sliding_window else None
+        attention_mask = args.pop(0) if self.use_attention_mask else None
+        position_ids = None
+        past_key_values = args
+
+        if len(past_key_values) != 2 * self.config.num_hidden_layers:
+            raise ValueError(
+                f"Different past_key_values to model's config. {len(past_key_values)} != {2 * self.config.num_hidden_layers}"
+            )
+
+        # [key, value] * n_layer -> ( (key, value) ) * n_layer
+        # cache shape : batch, n_heads, 1, max_seq_len, head_dim
+        _past_key_values = []
+        for i in range(self.config.num_hidden_layers):
+            key_states = past_key_values[i * 2]
+            value_states = past_key_values[i * 2 + 1]
+            past_key_value = [key_states, value_states]
+            _past_key_values.append(past_key_value)
+        past_key_values = _past_key_values
+
+        if hasattr(self, "rotary_emb_global") and hasattr(self, "rotary_emb_local"):
+            rotary_emb = (self.rotary_emb_global, self.rotary_emb_local)
+        else:
+            rotary_emb = self.rotary_emb
+
+        return (
+            input_ids,
+            inputs_embeds,
+            cache_position,
+            global_block_tables,
+            local_block_tables,
+            attention_mask,
+            position_ids,
+            query_position,
+            past_key_values,
+            rotary_emb,
+        )
+
+    def forward(self, *args):
+        (
+            input_ids,
+            inputs_embeds,
+            cache_position,
+            global_block_tables,
+            local_block_tables,
+            attention_mask,
+            position_ids,
+            query_position,
+            past_key_values,
+            rotary_emb,
+        ) = self.prepare_forward_args(*args)
+
+        # retrieve input_ids and inputs_embeds
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+            )
+
+        # embed positions
+        if inputs_embeds is None:
+            inputs_embeds = self.get_embedding()(input_ids)
+
+        hidden_states = inputs_embeds * self.hidden_multiplier
+
+        # get cos,sin vector if needed
+        position_ids = position_ids if position_ids is not None else cache_position
+        if rotary_emb is not None:
+            if isinstance(rotary_emb, torch.Tensor):
+                cos = rotary_emb[0]
+                sin = rotary_emb[1]
+            else:
+                cos, sin = rotary_emb(hidden_states, self.max_seq_len)  # dtype carrier, max_seq_len
+            cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
+        else:
+            batch_size = inputs_embeds.shape[0]
+            if position_ids.shape[0] > 1:
+                position_embeds = []
+                for b_idx in range(batch_size):
+                    position_embed = self.get_pos_embedding()(position_ids[b_idx])
+                    position_embeds.append(position_embed)
+
+                position_embeds = torch.cat(position_embeds, dim=0).unsqueeze(1)
+            else:
+                position_embeds = self.get_pos_embedding()(position_ids)
+            hidden_states = hidden_states + position_embeds
+            cos, sin = None, None
+
+        # Get sequence positions for flash attention
+        if self.attn_impl == "flash_attn":
+            seq_positions = cache_position[:, 0]
+            seq_positions = self.convert_sequence_positions_for_flash_attn(
+                seq_positions=seq_positions, max_seq_len=self.max_seq_len
+            )
+        else:
+            seq_positions = cache_position[:, :1]

+        # Get local cache positions for sliding window layers
+        if len(self.sliding_window_layers) > 0:
+            sliding_cache_pos = self.get_local_cache_positions(position_ids, query_position)
+
+        for layer_idx, layer in enumerate(self.model.layers):
+            is_sliding = True if layer_idx in self.sliding_window_layers else False
+            hidden_states = layer(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                seq_positions=sliding_cache_pos if is_sliding else seq_positions,
+                past_key_values=past_key_values,
+                cos=cos,
+                sin=sin,
+                block_tables=local_block_tables if is_sliding else global_block_tables,
+            )
+
+        hidden_states = self.get_last_layernorm()(hidden_states)
+        return hidden_states
+
+
+def slice_and_unsqueeze_cos_sin(cos, sin, cache_position, unsqueeze_dim=1):
+    """Slice cos[cache_position], sin[cache_position] vector for the query."""
+    if cache_position.shape[0] > 1:
+        cos_all = []
+        sin_all = []
+        for i in range(cache_position.shape[0]):
+            cos_all.append(cos[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
+            sin_all.append(sin[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
+        cos = torch.cat(cos_all, dim=0)
+        sin = torch.cat(sin_all, dim=0)
+    else:
+        cos = cos[cache_position].unsqueeze(unsqueeze_dim)
+        sin = sin[cache_position].unsqueeze(unsqueeze_dim)
+
+    return cos, sin
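
The helper at the end of this file simply gathers the cos/sin rows for the requested cache positions. A small, self-contained shape check (illustrative values only, replicating the batch-size-1 branch of slice_and_unsqueeze_cos_sin):

import torch

# Illustrative shapes: a rotary table of [max_seq_len, head_dim] and the
# cache_position tensor [batch, chunk] used during chunked prefill.
max_seq_len, head_dim = 16, 8
cos = torch.randn(max_seq_len, head_dim)
sin = torch.randn(max_seq_len, head_dim)
cache_position = torch.arange(4).unsqueeze(0)  # shape [1, 4]

# Batch-size-1 branch: index the table, then unsqueeze a heads dimension.
cos_sliced = cos[cache_position].unsqueeze(1)  # -> [1, 1, 4, head_dim]
sin_sliced = sin[cache_position].unsqueeze(1)
print(cos_sliced.shape, sin_sliced.shape)
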
optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py
@@ -327,12 +327,14 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
                 tensor_type="pt",
                 device=rbln_config.device_map["encoder"],
                 activate_profiler=rbln_config.activate_profiler,
+                timeout=rbln_config.timeout,
             ),
             rebel.Runtime(
                 compiled_models[1],
                 tensor_type="pt",
                 device=rbln_config.device_map["decoder"],
                 activate_profiler=rbln_config.activate_profiler,
+                timeout=rbln_config.timeout,
             ),
         ]

optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py
@@ -331,12 +331,14 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
                 tensor_type="pt",
                 device=rbln_config.device_map["encoder"],
                 activate_profiler=rbln_config.activate_profiler,
+                timeout=rbln_config.timeout,
             ),
             rebel.Runtime(
                 compiled_models[1],
                 tensor_type="pt",
                 device=rbln_config.device_map["decoder"],
                 activate_profiler=rbln_config.activate_profiler,
+                timeout=rbln_config.timeout,
             ),
         ]

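
Both hunks above (and the matching +2-line changes to modeling_whisper.py listed in the file summary) forward a new timeout value from the RBLN config into each rebel.Runtime. A hedged sketch of how that value might be supplied from user code, assuming the timeout field is settable through the rbln_config dictionary accepted by from_pretrained (the exact plumbing is not shown in this diff):

# Hypothetical sketch (not from the diff): request a runtime timeout when
# loading a seq2seq model; the value would end up as rebel.Runtime(timeout=...).
from optimum.rbln import RBLNT5ForConditionalGeneration

model = RBLNT5ForConditionalGeneration.from_pretrained(
    "google-t5/t5-small",           # illustrative checkpoint
    export=True,
    rbln_config={"timeout": 120},   # assumed key; corresponds to rbln_config.timeout above
)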