optimum-rbln 0.8.1rc0__py3-none-any.whl → 0.8.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +58 -9
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +24 -5
- optimum/rbln/diffusers/configurations/models/__init__.py +1 -1
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +5 -3
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +2 -2
- optimum/rbln/diffusers/configurations/models/{configuration_cosmos_transformer.py → configuration_transformer_cosmos.py} +7 -2
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +2 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +10 -6
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +4 -4
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +2 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +2 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +2 -2
- optimum/rbln/diffusers/modeling_diffusers.py +4 -5
- optimum/rbln/diffusers/models/__init__.py +3 -13
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +1 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +1 -1
- optimum/rbln/diffusers/pipelines/__init__.py +1 -5
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +12 -4
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +4 -26
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +2 -2
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +2 -2
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/modeling.py +4 -5
- optimum/rbln/modeling_base.py +18 -14
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +60 -0
- optimum/rbln/transformers/configuration_generic.py +4 -4
- optimum/rbln/transformers/modeling_attention_utils.py +252 -0
- optimum/rbln/transformers/modeling_generic.py +1 -4
- optimum/rbln/transformers/models/__init__.py +45 -30
- optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +2 -2
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -5
- optimum/rbln/transformers/models/clip/configuration_clip.py +14 -3
- optimum/rbln/transformers/models/clip/modeling_clip.py +123 -28
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +1 -4
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +2 -2
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +2 -10
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +214 -45
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -454
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +579 -362
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +17 -42
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +3 -44
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +21 -9
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +9 -63
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +200 -292
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +19 -24
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +2 -2
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -9
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +54 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +419 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +20 -3
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +6 -16
- optimum/rbln/transformers/models/midm/midm_architecture.py +14 -22
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +41 -1
- optimum/rbln/transformers/models/opt/opt_architecture.py +16 -25
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +34 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +69 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +16 -22
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +315 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +3 -3
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -15
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +1 -4
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -12
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +3 -1
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +2 -2
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +2 -2
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +3 -5
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +3 -12
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +8 -2
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/utils/depreacate_utils.py +16 -0
- optimum/rbln/utils/hub.py +8 -47
- optimum/rbln/utils/runtime_utils.py +31 -5
- {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2.dist-info}/METADATA +1 -1
- {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2.dist-info}/RECORD +120 -103
- {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2.dist-info}/licenses/LICENSE +0 -0
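Beyond the Gemma3 rework shown below, this release adds whole new model packages (llava, pixtral, pegasus, qwen3) and new exports such as RBLNGPT2Model. A hedged sketch of how one of the new classes would presumably be loaded, assuming RBLNQwen3ForCausalLM (from the new transformers/models/qwen3/ package) follows the same from_pretrained(export=True) convention as the existing RBLN causal-LM classes; the checkpoint name and rbln_* values are illustrative, not taken from this diff:

```python
# Hypothetical sketch, not library-verified: loading one of the newly added
# model families, assuming it follows the existing optimum-rbln convention.
from optimum.rbln import RBLNQwen3ForCausalLM

model = RBLNQwen3ForCausalLM.from_pretrained(
    "Qwen/Qwen3-8B",        # illustrative HF checkpoint
    export=True,            # compile for the RBLN NPU at load time
    rbln_batch_size=1,
    rbln_max_seq_len=8192,
)
model.save_pretrained("qwen3-rbln")  # reload later without recompiling
```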
optimum/rbln/transformers/models/gemma3/modeling_gemma3.py

@@ -19,33 +19,28 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Un
 import rebel
 import torch
 from rebel.compile_context import CompileContext
-from transformers import (
-    AutoModelForImageTextToText,
-    Gemma3ForConditionalGeneration,
-    PretrainedConfig,
-    PreTrainedModel,
-)
+from transformers import AutoModelForImageTextToText, Gemma3ForConditionalGeneration, PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPooling
 from transformers.modeling_utils import no_init_weights
 from transformers.models.gemma3.modeling_gemma3 import Gemma3TextScaledWordEmbedding
 
 from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
-from
-
+from ..decoderonly.modeling_decoderonly import (
+    RBLNDecoderOnlyForCausalLMOutput,
+    RBLNDecoderOnlyModelForCausalLM,
+    RBLNRuntimeModel,
+)
 from .configuration_gemma3 import RBLNGemma3ForCausalLMConfig
 from .gemma3_architecture import Gemma3ForCausalLMWrapper
 
 
-logger = get_logger()
-
-
 if TYPE_CHECKING:
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, Gemma3ForConditionalGeneration
 
 
 @dataclass
-class RBLNGemma3ForCausalLMOutput(
+class RBLNGemma3ForCausalLMOutput(RBLNDecoderOnlyForCausalLMOutput):
     attention_mask: Optional[torch.Tensor] = None
 
 
@@ -201,7 +196,7 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
 
     def _update_model_kwargs_for_generation(
         self,
-        outputs:
+        outputs: RBLNDecoderOnlyForCausalLMOutput,
         model_kwargs: Dict[str, Any],
         **kwargs,
    ) -> Dict[str, Any]:
@@ -258,19 +253,47 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
 
         return inputs_embeds
 
+    def get_padded_cache_position(
+        self,
+        cache_position: torch.Tensor,  # shape: [1, seq_len]
+        token_type_ids: torch.Tensor,  # shape: [1, seq_len]
+    ) -> torch.Tensor:
+        seq_len = cache_position[0][-1].item() + 1
+
+        # Find image start positions
+        image_starts = [
+            s
+            for s in torch.where(token_type_ids == 1)[1]
+            if torch.all(token_type_ids[:, s : s + self.rbln_config.image_prefill_chunk_size] == 1)
+        ]
+
+        # Initialize padded tensors
+        padded_input_len = seq_len
+        for image_start in image_starts:
+            pad_needed = (
+                self.rbln_config.image_prefill_chunk_size
+                - (image_start + padded_input_len - seq_len) % self.rbln_config.image_prefill_chunk_size
+            ) % self.rbln_config.image_prefill_chunk_size
+            padded_input_len += pad_needed
+
+        return torch.cat(
+            [cache_position, torch.arange(seq_len, padded_input_len, dtype=torch.int32).unsqueeze(0)],
+            dim=1,
+        )
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
+        attention_mask: torch.Tensor = None,
+        token_type_ids: torch.Tensor = None,
         pixel_values: torch.FloatTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         generate_idx: Optional[torch.Tensor] = None,
         padded_cache_lengths: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
         **lm_kwargs: Dict[str, Any],
-    ) -> Union[Tuple,
+    ) -> Union[Tuple, RBLNDecoderOnlyForCausalLMOutput]:
         # prefill
         if cache_position is None:
             logits = []
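The new get_padded_cache_position above aligns every image-token run to an image_prefill_chunk_size boundary by appending extra cache positions. A standalone sketch of that alignment arithmetic (illustrative only, not library code):

```python
# Sketch of the alignment rule: each image run must start on a multiple of
# the image chunk size, so padding is computed per image start and the total
# length grows accordingly.
def padded_length(seq_len: int, image_starts: list, chunk: int) -> int:
    padded = seq_len
    for start in image_starts:
        # shift the image start by the padding inserted so far, then align up
        shifted = start + (padded - seq_len)
        padded += (chunk - shifted % chunk) % chunk
    return padded

# Example: a 300-token prompt whose image run starts at position 70, with
# 256-token chunks, needs 186 pad positions so the image lands at 256.
assert padded_length(300, [70], 256) == 300 + (256 - 70)
```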
@@ -279,12 +302,15 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
 
             for b_idx in range(batch_size):
                 cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
+                token_type_id = token_type_ids[b_idx : b_idx + 1, attention_mask[b_idx].bool()]
+                cache_position = self.get_padded_cache_position(cache_position, token_type_id)
+
                 output = self.language_model.prefill_decoder(
                     inputs_embeds=inputs_embeds[b_idx : b_idx + 1],
                     attention_mask=attention_mask[b_idx],
                     cache_position=cache_position,
                     batch_idx=b_idx,
-                    token_type_ids=token_type_ids[b_idx : b_idx + 1]
+                    token_type_ids=token_type_ids[b_idx : b_idx + 1],  # do not pass token_type_id
                 )
                 padded_cache_lengths[b_idx] += output.padded_cache_lengths
                 logits.append(output.logits)
@@ -308,7 +334,7 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
                 position_ids=position_ids if self.rbln_config.language_model.use_position_ids else None,
             ).logits
 
-        return
+        return RBLNDecoderOnlyForCausalLMOutput(
             logits=logits, generate_idx=generate_idx, padded_cache_lengths=padded_cache_lengths
         )
 
@@ -320,194 +346,30 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         self.prefill = self.runtime if self.phase == "prefill" else None  # FIXME
         self.decode = self.runtime if self.phase == "decode" else None
 
-    def
-
-
-
-
-
-
-
-
-
-
-            inputs: (1, seq_len, hidden_size) tensor.
-            attention_mask: (1, seq_len) tensor, 1 for valid, 0 for masked.
-            position_ids: (1, seq_len) tensor for RoPE.
-            token_type_ids: (1, seq_len) tensor, 0 for text, 1 for image.
-
-        Returns:
-            (inputs_padded, attention_mask_padded, position_ids_padded, padded_len, token_type_ids_padded).
-        """
-
-        if token_type_ids is None:
-            return inputs, attention_mask, position_ids, 0, torch.zeros(inputs.shape[:2], dtype=torch.long)
-
-        seq_len = inputs.shape[1]
-
-        # Find image start positions
-        image_starts = [
-            s
-            for s in range(seq_len - self.rbln_config.prefill_chunk_size + 1)
-            if torch.all(token_type_ids[:, s : s + self.rbln_config.prefill_chunk_size] == 1)
-        ]
-
-        # Initialize padded tensors
-        padded_input_len = seq_len
-        for image_start in image_starts:
-            pad_needed = (
-                self.rbln_config.prefill_chunk_size
-                - (image_start + padded_input_len - seq_len) % self.rbln_config.prefill_chunk_size
-            ) % self.rbln_config.prefill_chunk_size
-            padded_input_len += pad_needed
-        total_padding = padded_input_len - seq_len
-
-        if inputs.dim() == 3:
-            inputs_padded = torch.zeros(1, padded_input_len, inputs.shape[2], dtype=inputs.dtype)
-        else:
-            inputs_padded = torch.zeros(1, padded_input_len, dtype=inputs.dtype)
-        attention_mask_padded = torch.zeros(1, padded_input_len, dtype=attention_mask.dtype)
-        position_ids_padded = torch.zeros(1, padded_input_len, dtype=position_ids.dtype)
-        token_type_ids_padded = torch.zeros(1, padded_input_len, dtype=token_type_ids.dtype)
-
-        # Fill padded tensors
-        dest_pos = 0
-        src_pos = 0
-        last_pos_id = -1
-        for image_start in image_starts + [seq_len]:
-            # Text segment
-            if src_pos < image_start:
-                length = image_start - src_pos
-                inputs_padded[:, dest_pos : dest_pos + length] = inputs[:, src_pos:image_start]
-                attention_mask_padded[:, dest_pos : dest_pos + length] = attention_mask[:, src_pos:image_start]
-                position_ids_padded[:, dest_pos : dest_pos + length] = position_ids[:, src_pos:image_start]
-                token_type_ids_padded[:, dest_pos : dest_pos + length] = token_type_ids[:, src_pos:image_start]
-                dest_pos += length
-                last_pos_id = position_ids[0, image_start - 1].item()
-                src_pos = image_start
-
-            # Padding
-            pad_needed = (
-                self.rbln_config.prefill_chunk_size - dest_pos % self.rbln_config.prefill_chunk_size
-            ) % self.rbln_config.prefill_chunk_size
-            if pad_needed and dest_pos < padded_input_len:
-                position_ids_padded[:, dest_pos : dest_pos + pad_needed] = torch.arange(
-                    last_pos_id + 1, last_pos_id + pad_needed + 1, dtype=position_ids.dtype
-                ).unsqueeze(0)
-                dest_pos += pad_needed
-
-            # Image segment
-            if src_pos < seq_len and src_pos == image_start:
-                inputs_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = inputs[
-                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
-                ]
-                attention_mask_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = attention_mask[
-                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
-                ]
-                position_ids_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = position_ids[
-                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
-                ]
-                token_type_ids_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = token_type_ids[
-                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
-                ]
-                dest_pos += self.rbln_config.prefill_chunk_size
-                src_pos += self.rbln_config.prefill_chunk_size
-                last_pos_id = position_ids[0, image_start + self.rbln_config.prefill_chunk_size - 1].item()
-
-        return inputs_padded, attention_mask_padded, position_ids_padded, total_padding, token_type_ids_padded
-
-    def _prepare_prefill_inputs(
-        self,
-        inputs: torch.Tensor,
-        cache_position: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_embed: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
-    ):
-        """
-        Prepare inputs for prefill phase.
-        """
-        # Handle continuous batching in a compiled graph by extracting valid inputs
-        # If an attention mask is provided, select only the valid (non-masked) inputs
-        inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
-        token_type_ids = (
-            token_type_ids[:, attention_mask.bool()]
-            if attention_mask is not None and token_type_ids is not None
-            else token_type_ids
-        )
-
-        if position_embed is not None:
-            position_embed = (
-                position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
-            )
-
-        seq_len = inputs.shape[1]
-        # Initialize attention mask for chunked processing
-        if self.rbln_config.use_attention_mask:
-            chunked_attention_mask = (
-                torch.ones(1, seq_len, dtype=torch.float32)
-                if self.rbln_config.use_position_ids
-                else torch.zeros(
-                    1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32
-                )
-            )
-        else:
-            chunked_attention_mask = None
-
-        # Buffer for storing output logits
-        out_buffers = [
-            torch.empty(
-                size=self.output_size,
-                dtype=torch.float32,
-                device="cpu",
-            )
-        ]
-
-        inputs, chunked_attention_mask, position_ids, padded_cache_lengths, token_type_ids_padded = (
-            self.pad_for_chunked_images(inputs, chunked_attention_mask, cache_position, token_type_ids)
-        )
-
-        query_length = inputs.shape[1]
-        if query_length > self.rbln_config.max_seq_len:
-            raise ValueError(
-                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.rbln_config.max_seq_len})."
-            )
-
-        # Align attention_mask to compiled shape
-        if self.rbln_config.use_position_ids:
-            chunked_attention_mask = torch.nn.functional.pad(
-                chunked_attention_mask, (0, self.rbln_config.max_seq_len - query_length)
-            )
-
-        # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
-        padding_size = 0
-        if query_length % self.rbln_config.prefill_chunk_size != 0:
-            padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
-            # inputs_embeds
-            if inputs.dim() == 3:
-                inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
-            # inputs_ids
-            else:
-                inputs = torch.nn.functional.pad(inputs, (0, padding_size))
-
-            position_ids = torch.cat(
-                [
-                    position_ids,
-                    torch.arange(
-                        query_length,
-                        query_length + padding_size,
-                        dtype=torch.int32,
-                    ).unsqueeze(0),
-                ],
-                dim=-1,
-            )
-            token_type_ids_padded = torch.nn.functional.pad(token_type_ids_padded, (0, padding_size))
+    def _prepare_prefill_inputs(self, *args, **kwargs):
+        (
+            inputs,
+            cache_position,
+            chunked_attention_mask,
+            out_buffers,
+            position_ids,
+            position_embed,
+            padded_cache_lengths,
+            query_length,
+            token_type_ids,
+        ) = super()._prepare_prefill_inputs(*args, **kwargs)
 
-
-
+        # chunked_attention_mask shape
+        chunked_attention_mask = torch.zeros(1, chunked_attention_mask.shape[-1], dtype=torch.float32)
 
-
+        # In case of Gemma3ForConditionalGeneration, the loop counter may not be a prefill_chunk_size,
+        # so we cannot guarantee that the last chunk starts at a position that is a multiple of prefill_chunk_size.
+        if self.rbln_config.use_image_prefill:
+            padding_size = self.rbln_config.image_prefill_chunk_size
+            inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
+            cache_position = torch.nn.functional.pad(cache_position, (0, padding_size))
+            position_ids = torch.nn.functional.pad(position_ids, (0, padding_size))
+            token_type_ids = torch.nn.functional.pad(token_type_ids, (0, padding_size), value=-1)
 
         return (
             inputs,
@@ -518,7 +380,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
             position_embed,
             padded_cache_lengths,
             query_length,
-
+            token_type_ids,
         )
 
     def prefill_forward(
@@ -541,65 +403,69 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         (
             inputs,
             cache_position,
-
+            chunked_attention_mask,
             out_buffers,
             position_ids,
             position_embed,
             padded_cache_lengths,
             query_length,
-
+            token_type_ids,
         ) = self._prepare_prefill_inputs(
             inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
         )
-        if not is_external_block_tables:
-            local_block_tables = torch.tensor([batch_idx], dtype=torch.int16)
-            self.dec_attn_mask[batch_idx : batch_idx + 1] = padded_attention_mask[:1]
 
-
-
+        step = 0
+        while step < query_length:
+            if self.rbln_config.use_image_prefill:
+                # Check if the prefill chunk is an image prefill
+                is_image_prefill = torch.all(
+                    token_type_ids[:, step : step + self.rbln_config.image_prefill_chunk_size] == 1
+                )
+                # Check if the prefill chunk is a text prefill which have image_tokens in it.
+                is_text_prefill_with_image_tokens = not is_image_prefill and torch.any(
+                    token_type_ids[:, step : step + self.rbln_config.prefill_chunk_size] == 1
+                )
+            else:
+                is_image_prefill, is_text_prefill_with_image_tokens = False, False
+
+            # Check if the prefill chunk is the last chunk
+            is_last_chunk = step + self.rbln_config.prefill_chunk_size >= query_length
 
-        # Process input in chunks of size `prefill_chunk_size`
-        for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
-            # Extract the current chunk of inputs and cache positions
             input_chunk = inputs[:, step : step + self.rbln_config.prefill_chunk_size]
-            cache_pos_chunk =
-
-                position_ids[:, step : step + self.rbln_config.prefill_chunk_size]
-                if position_ids is not None
-                else None
-            )
-
-            if self.rbln_config.use_attention_mask:
-                if self.rbln_config.use_position_ids:
-                    chunked_attention_mask[0, step : step + self.rbln_config.prefill_chunk_size] = (
-                        padded_attention_mask[0, step : step + self.rbln_config.prefill_chunk_size]
-                    )
-
-            # Define query position
-            query_position = (
-                torch.sum(
-                    chunked_attention_mask[0][step : step + self.rbln_config.prefill_chunk_size],
-                    dim=-1,
-                    dtype=torch.int16,
-                ).squeeze(0)
-                - 1
+            cache_pos_chunk = (
+                cache_position[:, step : step + self.rbln_config.prefill_chunk_size] + padded_cache_lengths
             )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            position_ids_chunk = position_ids[:, step : step + self.rbln_config.prefill_chunk_size]
+
+            # if text_prefill end with image_tokens, we only treat the text part.
+            num_processed_tokens = self.rbln_config.prefill_chunk_size
+            current_padded_cache_lengths = 0
+            if is_text_prefill_with_image_tokens:
+                first_image_token_idx = torch.where(
+                    token_type_ids[:, step : step + self.rbln_config.prefill_chunk_size] == 1
+                )[1][0]
+                num_processed_tokens = first_image_token_idx.item()
+                current_padded_cache_lengths = self.rbln_config.prefill_chunk_size - num_processed_tokens
+            if is_last_chunk:
+                num_processed_tokens = query_length - step
+
+            chunked_attention_mask[
+                :, step + padded_cache_lengths : step + num_processed_tokens + padded_cache_lengths
+            ] = 1
+            query_position = torch.tensor(num_processed_tokens - 1, dtype=torch.int16)
+
+            if is_image_prefill:
+                logits = self.image_prefill(
+                    input_chunk,
+                    cache_pos_chunk,
+                    block_tables,
+                    local_block_tables,
+                    query_position,
+                    chunked_attention_mask,
+                    position_ids_chunk,
+                    out=out_buffers,
+                )
             else:
-                # Forward pass for the current chunk
                 logits = self.prefill(
                     input_chunk,
                     cache_pos_chunk,
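The rewritten prefill loop above advances by a variable stride: a window of pure image tokens is dispatched to the dedicated image_prefill runtime, while a text window that merely contains image tokens is cut short so the next iteration starts exactly at the first image token. A minimal sketch of that classification logic (illustrative only, with a toy chunk size of 4; not library code):

```python
import torch

def classify_chunk(token_type_ids: torch.Tensor, step: int, chunk: int):
    # token_type_ids: [1, seq_len], 0 for text tokens, 1 for image tokens
    window = token_type_ids[:, step : step + chunk]
    if torch.all(window == 1):
        return "image_prefill", chunk          # full image chunk
    if torch.any(window == 1):
        first_image = torch.where(window == 1)[1][0].item()
        return "text_prefill", first_image     # process only the text prefix
    return "text_prefill", chunk               # plain text chunk

tt = torch.tensor([[0, 0, 0, 1, 1, 1, 1, 0]])
print(classify_chunk(tt, 0, 4))  # ('text_prefill', 3): stop before the image
print(classify_chunk(tt, 3, 4))  # ('image_prefill', 4): a full image chunk
```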
@@ -611,6 +477,12 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
                     out=out_buffers,
                 )
 
+            padded_cache_lengths += current_padded_cache_lengths
+            step += num_processed_tokens
+
+        if not is_external_block_tables:
+            self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask
+
         return RBLNGemma3ForCausalLMOutput(
             logits=logits, padded_cache_lengths=padded_cache_lengths, attention_mask=chunked_attention_mask
         )
@@ -666,7 +538,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
 
         logits = self.decode(inputs, cache_position, block_tables, local_block_tables, attention_mask, position_ids)
 
-        return
+        return RBLNDecoderOnlyForCausalLMOutput(logits=logits)
 
 
 class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
@@ -701,9 +573,10 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
             dtype=torch.int16,
         ).fill_(-1)
         free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
+
         self.prefill_decoder = RBLNGemma3RuntimeModel(
             runtime=self.model[0],
-            image_prefill=self.model[1],
+            image_prefill=self.model[1] if self.rbln_config.use_image_prefill else None,
             main_input_name=main_input_name,
             embed_tokens=self.embed_tokens,
             phase="prefill",
@@ -718,7 +591,7 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
         self.decoders = {}
         for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
             self.decoders[batch_size] = RBLNGemma3RuntimeModel(
-                runtime=self.model[i +
+                runtime=self.model[i + self.rbln_config.decoder_runtime_idx],
                 main_input_name=main_input_name,
                 embed_tokens=self.embed_tokens,
                 phase="decode",
@@ -757,13 +630,14 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
 
     @classmethod
     def _update_submodule_config(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
-        if rbln_config.
-            rbln_config.
+        if rbln_config.image_prefill_chunk_size is None:
+            rbln_config.image_prefill_chunk_size = model.config.mm_tokens_per_image
 
-        if rbln_config.
-
-                f"
+        if rbln_config.image_prefill_chunk_size != model.config.mm_tokens_per_image:
+            raise ValueError(
+                f"Image prefill chunk size is different from mm_tokens_per_image: {rbln_config.image_prefill_chunk_size} != {model.config.mm_tokens_per_image}"
             )
+
         return rbln_config
 
     @classmethod
@@ -777,15 +651,29 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
         # Update rbln_config with super class
         rbln_config = super()._update_rbln_config(preprocessors, model, model_config, rbln_config)
 
-
-
-
-
-
-
-
-
-
+        if not (rbln_config.use_attention_mask and rbln_config.use_position_ids):
+            raise ValueError("use_attention_mask and use_position_ids must be True for RBLNGemma3ForCausalLM")
+
+        if rbln_config.use_image_prefill:
+            if rbln_config.prefill_chunk_size != rbln_config.image_prefill_chunk_size:
+                raise NotImplementedError(
+                    "Not implemented for different prefill chunk sizes between text and image prefill."
+                )
+
+            # Update image prefill compile config
+            img_prefill_input_info = cls.get_input_info(
+                batch_size=1,
+                query_length=rbln_config.image_prefill_chunk_size,
+                rbln_config=rbln_config,
+                model_config=model_config,
+            )
+            image_prefill_compile_config = RBLNCompileConfig(
+                compiled_model_name="image_prefill", input_info=img_prefill_input_info
+            )
+            # Insert image_prefill compile config at index 1
+            compile_cfgs = rbln_config.compile_cfgs
+            compile_cfgs.insert(1, image_prefill_compile_config)
+            rbln_config.set_compile_cfgs(compile_cfgs)
 
         return rbln_config
 
@@ -838,20 +726,27 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
             context,
             rbln_config.quantization,
         )
+        compiled_models = {"prefill": compiled_prefill}
 
-
-
-
-
-
-
-
-
-
+        if rbln_config.use_image_prefill:
+            image_prefill_compile_config = rbln_compile_configs[1]
+            image_prefill_example_inputs = image_prefill_compile_config.get_dummy_inputs(
+                fill=0, static_tensors=static_tensors
+            )
+            wrapped_model.phase = "image_prefill"
+            compiled_image_prefill = compile_model(
+                wrapped_model,
+                image_prefill_compile_config,
+                image_prefill_example_inputs,
+                context,
+                rbln_config.quantization,
+            )
+            compiled_models["image_prefill"] = compiled_image_prefill
 
-        compiled_models = {"prefill": compiled_prefill, "image_prefill": compiled_image_prefill}
         wrapped_model.phase = "decode"
-        for batch_size, dec_compile_config in zip(
+        for batch_size, dec_compile_config in zip(
+            rbln_config.decoder_batch_sizes, rbln_compile_configs[rbln_config.decoder_runtime_idx :]
+        ):
             dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
             compiled_decoder = compile_model(
                 wrapped_model,
@@ -872,32 +767,45 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     ) -> List[rebel.Runtime]:
         expected_model_names = [
             "prefill",
-            "image_prefill",
             *[f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes],
         ]
+        if rbln_config.use_image_prefill:
+            expected_model_names.insert(1, "image_prefill")
+
 
         if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
             cls._raise_missing_compiled_file_error(expected_model_names)
 
-
+        ret_val = [
             rebel.Runtime(
                 compiled_models[0],
                 tensor_type="pt",
                 device=rbln_config.device_map["prefill"],
                 activate_profiler=rbln_config.activate_profiler,
-
-
-
-
-
-                activate_profiler=rbln_config.activate_profiler,
-            ),
-            *[
+                timeout=rbln_config.timeout,
+            )
+        ]
+        if rbln_config.use_image_prefill:
+            ret_val.append(
                 rebel.Runtime(
-                    compiled_models[
+                    compiled_models[1],
+                    tensor_type="pt",
+                    device=rbln_config.device_map["image_prefill"],
+                    activate_profiler=rbln_config.activate_profiler,
+                    timeout=rbln_config.timeout,
+                ),
+            )
+
+        ret_val.extend(
+            [
+                rebel.Runtime(
+                    compiled_models[i + rbln_config.decoder_runtime_idx],
                     tensor_type="pt",
                     device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
                     activate_profiler=rbln_config.activate_profiler,
+                    timeout=rbln_config.timeout,
                 )
                 for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
-            ]
-
+            ]
+        )
+
+        return ret_val
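The runtime list mirrors the compiled-model order: prefill at index 0, image_prefill at index 1 only when enabled, then one decoder per batch size starting at decoder_runtime_idx. A small sketch of the implied layout (my assumption, based on the insert-at-index-1 logic above: decoder_runtime_idx is 1 without image prefill and 2 with it):

```python
# Sketch of the runtime/compiled-model layout implied by the hunk above.
def runtime_names(use_image_prefill: bool, decoder_batch_sizes: list) -> list:
    names = ["prefill"]
    if use_image_prefill:
        names.append("image_prefill")  # inserted at index 1, matching the compile config
    names += [f"decoder_batch_{b}" for b in decoder_batch_sizes]
    return names

assert runtime_names(True, [1, 4]) == ["prefill", "image_prefill", "decoder_batch_1", "decoder_batch_4"]
assert runtime_names(False, [1]) == ["prefill", "decoder_batch_1"]
```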
optimum/rbln/transformers/models/gpt2/__init__.py

@@ -12,5 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .configuration_gpt2 import RBLNGPT2LMHeadModelConfig
-from .modeling_gpt2 import RBLNGPT2LMHeadModel
+from .configuration_gpt2 import RBLNGPT2LMHeadModelConfig, RBLNGPT2ModelConfig
+from .modeling_gpt2 import RBLNGPT2LMHeadModel, RBLNGPT2Model
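The gpt2 package now also exports the headless RBLNGPT2Model alongside the LM-head variant. A speculative usage sketch, assuming it follows the standard RBLNModel from_pretrained(export=True) pattern and that the compiled graph expects inputs padded to the compiled sequence length (neither detail is confirmed by this diff):

```python
# Speculative sketch: RBLNGPT2Model as a hidden-state encoder.
from optimum.rbln import RBLNGPT2Model
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default
model = RBLNGPT2Model.from_pretrained("gpt2", export=True, rbln_max_seq_len=1024)

# Compiled graphs have static shapes, so pad to the compiled length (assumption).
inputs = tokenizer("Hello, RBLN!", padding="max_length", max_length=1024, return_tensors="pt")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # expected: (1, 1024, 768)
```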