optimum-rbln 0.1.15__py3-none-any.whl → 0.2.1a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. optimum/rbln/__init__.py +26 -33
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/diffusers/__init__.py +4 -0
  4. optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +66 -24
  5. optimum/rbln/diffusers/models/__init__.py +2 -0
  6. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +38 -12
  7. optimum/rbln/diffusers/models/autoencoders/vae.py +0 -1
  8. optimum/rbln/diffusers/models/controlnet.py +1 -1
  9. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +1 -1
  10. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +5 -7
  11. optimum/rbln/diffusers/pipelines/__init__.py +1 -0
  12. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +8 -7
  13. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
  14. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -2
  15. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +17 -2
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +17 -2
  17. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -2
  18. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -2
  19. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -2
  20. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1 -2
  21. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1 -2
  22. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1 -2
  23. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +23 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +1 -2
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +1 -2
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +1 -2
  27. optimum/rbln/modeling.py +13 -347
  28. optimum/rbln/modeling_base.py +24 -4
  29. optimum/rbln/modeling_config.py +31 -7
  30. optimum/rbln/ops/__init__.py +26 -0
  31. optimum/rbln/ops/attn.py +221 -0
  32. optimum/rbln/ops/flash_attn.py +70 -0
  33. optimum/rbln/ops/kv_cache_update.py +69 -0
  34. optimum/rbln/transformers/__init__.py +20 -0
  35. optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
  36. optimum/rbln/transformers/modeling_generic.py +385 -0
  37. optimum/rbln/transformers/models/auto/__init__.py +23 -0
  38. optimum/rbln/transformers/models/auto/modeling_auto.py +0 -1
  39. optimum/rbln/transformers/models/bart/__init__.py +0 -1
  40. optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
  41. optimum/rbln/transformers/models/bart/modeling_bart.py +8 -4
  42. optimum/rbln/transformers/models/clip/modeling_clip.py +1 -1
  43. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -7
  44. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +329 -328
  45. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +92 -107
  46. optimum/rbln/transformers/models/exaone/exaone_architecture.py +2 -3
  47. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  48. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -10
  49. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
  50. optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
  51. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +1 -0
  52. optimum/rbln/transformers/models/midm/midm_architecture.py +11 -11
  53. optimum/rbln/transformers/models/midm/modeling_midm.py +0 -1
  54. optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
  55. optimum/rbln/transformers/models/phi/phi_architecture.py +2 -3
  56. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
  57. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +57 -57
  58. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
  59. optimum/rbln/transformers/models/t5/__init__.py +0 -1
  60. optimum/rbln/transformers/models/t5/modeling_t5.py +5 -2
  61. optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
  62. optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
  63. optimum/rbln/transformers/models/whisper/modeling_whisper.py +77 -54
  64. optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
  65. optimum/rbln/transformers/utils/rbln_quantization.py +1 -2
  66. optimum/rbln/utils/decorator_utils.py +51 -15
  67. optimum/rbln/utils/import_utils.py +8 -1
  68. optimum/rbln/utils/logging.py +38 -1
  69. optimum/rbln/utils/model_utils.py +0 -1
  70. optimum/rbln/utils/runtime_utils.py +9 -3
  71. optimum/rbln/utils/save_utils.py +17 -0
  72. optimum/rbln/utils/submodule.py +23 -0
  73. optimum_rbln-0.2.1a0.dist-info/METADATA +121 -0
  74. {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.1a0.dist-info}/RECORD +76 -72
  75. optimum_rbln-0.2.1a0.dist-info/licenses/LICENSE +288 -0
  76. optimum/rbln/transformers/cache_utils.py +0 -107
  77. optimum/rbln/utils/timer_utils.py +0 -43
  78. optimum_rbln-0.1.15.dist-info/METADATA +0 -106
  79. optimum_rbln-0.1.15.dist-info/licenses/LICENSE +0 -201
  80. {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.1a0.dist-info}/WHEEL +0 -0
@@ -20,6 +20,7 @@
 # are the intellectual property of Rebellions Inc. and may not be
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
+
 import inspect
 from dataclasses import dataclass
 from pathlib import Path
@@ -27,28 +28,26 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Un

 import rebel
 import torch
+from rebel.compile_context import CompileContext
 from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
 from transformers.modeling_utils import no_init_weights
 from transformers.utils import ModelOutput

 from ....modeling import RBLNModel
-from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
+from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
 from ....utils.runtime_utils import RBLNPytorchRuntime
-from ....utils.timer_utils import rbln_timer
 from ...utils.rbln_quantization import QuantizationManager
-from .decoderonly_architecture import DecoderOnlyWrapper
+from .decoderonly_architecture import (
+    DecoderOnlyWrapper,
+    validate_attention_method,
+)


 logger = get_logger()

 if TYPE_CHECKING:
-    from transformers import (
-        AutoFeatureExtractor,
-        AutoProcessor,
-        AutoTokenizer,
-        PretrainedConfig,
-    )
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig


 class RBLNRuntimeModel(RBLNPytorchRuntime):
@@ -60,32 +59,21 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         inputs_embeds: torch.Tensor,
         attention_mask: torch.Tensor,
         cache_position: torch.Tensor,
-        batch_position: torch.Tensor,
-        query_idx: torch.Tensor,
         **kwargs,
     ):
         if inputs_embeds is None:
             inp = input_ids
             if self.embed_tokens is not None:
                 inp = self.embed_tokens(inp)
-
-            return super().forward(
-                inp,
-                attention_mask,
-                cache_position,
-                batch_position,
-                query_idx,
-                **kwargs,
-            )
         else:
-            return super().forward(
-                inputs_embeds,
-                attention_mask,
-                cache_position,
-                batch_position,
-                query_idx,
-                **kwargs,
-            )
+            inp = inputs_embeds
+
+        return super().forward(
+            inp,
+            attention_mask,
+            cache_position,
+            **kwargs,
+        )


 @dataclass
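The rewritten `RBLNRuntimeModel.forward` drops `batch_position` and `query_idx` from the runtime signature and collapses the two `super().forward` calls into one: the input is first resolved to either embedded `input_ids` or the provided `inputs_embeds`, then forwarded once. A minimal standalone sketch of that dispatch pattern (the `DummyRuntime` class, its `_run` method, and the shapes are illustrative, not the library's code):

    import torch
    import torch.nn as nn

    class DummyRuntime(nn.Module):
        # Stand-in for the compiled runtime; only the dispatch logic matters here.
        def __init__(self, vocab_size=32, hidden=8):
            super().__init__()
            self.embed_tokens = nn.Embedding(vocab_size, hidden)

        def _run(self, inp, attention_mask, cache_position):
            return inp.sum()  # placeholder for the real runtime call

        def forward(self, input_ids=None, inputs_embeds=None, attention_mask=None, cache_position=None):
            # Resolve the single input tensor first, then make exactly one call.
            if inputs_embeds is None:
                inp = input_ids
                if self.embed_tokens is not None:
                    inp = self.embed_tokens(inp)
            else:
                inp = inputs_embeds
            return self._run(inp, attention_mask, cache_position)

    out = DummyRuntime()(input_ids=torch.zeros(1, 4, dtype=torch.long))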
@@ -243,11 +231,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     @classmethod
     def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
         wrapper_cfg = {"max_seq_len": rbln_config.model_cfg["max_seq_len"]}
-
-        # If the model wrapper supports rbln-custom-flash-attention
-        if "kvcache_partition_len" in inspect.signature(cls._decoder_wrapper_cls.__init__).parameters:
-            wrapper_cfg["kvcache_partition_len"] = rbln_config.model_cfg.get("kvcache_partition_len")
-
+        wrapper_cfg["attn_impl"] = rbln_config.model_cfg.get("attn_impl")
+        wrapper_cfg["kvcache_partition_len"] = rbln_config.model_cfg.get("kvcache_partition_len")
         wrapper_cfg["use_rotary_emb"] = cls._use_rotary_emb

         return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
@@ -258,72 +243,46 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         wrapped_model = cls.wrap_model_if_needed(model, rbln_config)

         rbln_compile_configs = rbln_config.compile_cfgs
-        prefill_rbln_compile_config = rbln_compile_configs[0]
-        dec_rbln_compile_config = rbln_compile_configs[1]
+        prefill_compile_config = rbln_compile_configs[0]
+        dec_compile_config = rbln_compile_configs[1]

-        @rbln_timer("JIT trace")
-        def get_scripted_model():
-            # This function is nested to dealloc the example inputs before compilation.
-            # FIXME: 3rd dummy_input(batch_idx) should be fill zero to compile flash_attn.
-            prefill_example_inputs = prefill_rbln_compile_config.get_dummy_inputs(fill=0)
-            dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=0)
+        context = CompileContext(use_weight_sharing=True)

-            wrapped_model.phase = "prefill"
-            prefill_scripted_model = torch.jit.trace(
-                wrapped_model, prefill_example_inputs, check_trace=False, _store_inputs=False
-            )
-            wrapped_model.phase = "decode"
-            dec_scripted_model = torch.jit.trace(
-                wrapped_model, dec_example_inputs, check_trace=False, _store_inputs=False
-            )
-            return prefill_scripted_model, dec_scripted_model
-
-        prefill_scripted_model, dec_scripted_model = get_scripted_model()
+        # Here we use meta tensor, for the memory efficiency.
+        meta_tensor_names = [name for name, _, _ in prefill_compile_config.input_info if "past_key_values" in name]
+        prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)

-        @rbln_timer("Model conversion")
-        def scripted_model_to_ir():
-            prefill_ir = rebel.torchscript_to_ir(
-                prefill_scripted_model,
-                input_names=[v[0] for v in prefill_rbln_compile_config.input_info],
-            )
-            dec_ir = rebel.torchscript_to_ir(
-                dec_scripted_model,
-                input_names=[v[0] for v in dec_rbln_compile_config.input_info],
-            )
-            return prefill_ir, dec_ir
+        # Mark static tensors (self kv states)
+        static_tensors = {}
+        for (name, _, _), tensor in zip(prefill_compile_config.input_info, prefill_example_inputs):
+            if "past_key_values" in name:
+                static_tensors[name] = tensor
+                context.mark_static_address(tensor)

-        prefill_ir, dec_ir = scripted_model_to_ir()
-        # Caching prefill_decoder/decoder I/O
-        cache_index_offset = 5
+        dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)

-        connections = [
-            (prefill_ir.outputs[1 + i], prefill_ir.inputs[cache_index_offset + i])
-            for i in range(model.config.num_hidden_layers * 2)
-        ]
-
-        # Extract quantize_config from rbln_config
         quantize_config = rbln_config.model_cfg.get("quantization", None)

         @QuantizationManager.with_quantization_env
         def compile_model(*args, **kwargs):
-            # Remove quantize_config from kwargs
-            kwargs.pop("quantize_config", None)
-
-            # Call rebel.compile with the updated kwargs
-            return rebel.compile(*args, **kwargs)
-
-        compiled_model = compile_model(
-            prefill_ir,
-            dec_ir,
-            connections=connections,
-            fusion=prefill_rbln_compile_config.fusion,
-            npu=prefill_rbln_compile_config.npu,
-            tensor_parallel_size=prefill_rbln_compile_config.tensor_parallel_size,
-            use_weight_sharing=True,
-            quantize_config=quantize_config,
-        )
+            wrapped_model.phase = "prefill"
+            compiled_prefill = RBLNModel.compile(
+                wrapped_model,
+                prefill_compile_config,
+                example_inputs=prefill_example_inputs,
+                compile_context=context,
+            )
+
+            wrapped_model.phase = "decode"
+            compiled_decoder = RBLNModel.compile(
+                wrapped_model,
+                dec_compile_config,
+                example_inputs=dec_example_inputs,
+                compile_context=context,
+            )
+            return {"prefill": compiled_prefill, "decoder": compiled_decoder}

-        return compiled_model
+        return compile_model(quantize_config=quantize_config)

     @classmethod
     def _get_rbln_config(
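The compilation path above no longer goes through an explicit `torch.jit.trace` → `rebel.torchscript_to_ir` → `rebel.compile(..., connections=...)` pipeline. Instead a `CompileContext(use_weight_sharing=True)` is created, every `past_key_values` input of the prefill graph is marked as a static address, and the same tensors are reused as static inputs of the decode graph, so both compiled submodels alias one KV cache. A rough sketch of the bookkeeping idea (plain Python, not the `rebel` API; the toy class and its signature are assumptions for illustration):

    from dataclasses import dataclass, field
    import torch

    @dataclass
    class ToyCompileContext:
        # Illustrative stand-in: remembers which buffers must live at a fixed
        # address so that several compiled graphs can alias the same memory.
        static_addresses: dict = field(default_factory=dict)

        def mark_static_address(self, name: str, tensor: torch.Tensor):
            self.static_addresses[name] = tensor.data_ptr()

    ctx = ToyCompileContext()
    kv_cache = {f"past_key_values.{i}": torch.zeros(1, 8, 128, 64) for i in range(4)}
    for name, tensor in kv_cache.items():
        ctx.mark_static_address(name, tensor)

    # Both the "prefill" and "decode" graphs would be compiled against the same
    # static addresses, so prefill's cache writes are visible to decode without copies.
    assert len(set(ctx.static_addresses.values())) == len(kv_cache)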
@@ -335,6 +294,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
         rbln_batch_size = rbln_kwargs.get("batch_size", None)
         rbln_use_inputs_embeds = rbln_kwargs.get("use_inputs_embeds", None)
+        rbln_attn_impl = rbln_kwargs.get("attn_impl", None)
+        rbln_kvcache_partition_len = rbln_kwargs.get("kvcache_partition_len", None)
         rbln_quantization = QuantizationManager.validate_quantization_config(rbln_kwargs.get("quantization", None))

         prefill_chunk_size = 128
@@ -344,9 +305,16 @@
         )
         if rbln_max_seq_len is None:
             raise ValueError("`rbln_max_seq_len` should be specified.")
+
         rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
         rbln_use_inputs_embeds = False if rbln_use_inputs_embeds is None else rbln_use_inputs_embeds

+        rbln_attn_impl, rbln_kvcache_partition_len = validate_attention_method(
+            rbln_attn_impl=rbln_attn_impl,
+            rbln_kvcache_partition_len=rbln_kvcache_partition_len,
+            rbln_max_seq_len=rbln_max_seq_len,
+        )
+
         num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
         num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
         num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
@@ -372,9 +340,14 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                    [batch_size, query_length],
                    "int32",
                ),
-                ("batch_position", [], "int16"),
-                ("query_idx", [], "int16"),
            ]
+            if query_length > 1:
+                input_info.extend(
+                    [
+                        ("batch_position", [], "int16"),
+                        ("query_position", [], "int16"),
+                    ]
+                )

            input_info.extend(
                [
@@ -407,12 +380,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
            hidden_size=hidden_size,
        )

-        prefill_rbln_compile_config = RBLNCompileConfig(input_info=prefill_input_info)
-        dec_rbln_compile_config = RBLNCompileConfig(input_info=dec_input_info)
+        prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
+        dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)

        rbln_config = RBLNConfig(
            rbln_cls=cls.__name__,
-            compile_cfgs=[prefill_rbln_compile_config, dec_rbln_compile_config],
+            compile_cfgs=[prefill_compile_config, dec_compile_config],
            rbln_kwargs=rbln_kwargs,
        )

@@ -422,6 +395,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                "batch_size": rbln_batch_size,
                "prefill_chunk_size": prefill_chunk_size,
                "use_inputs_embeds": rbln_use_inputs_embeds,
+                "kvcache_partition_len": rbln_kvcache_partition_len,
+                "attn_impl": rbln_attn_impl,
            }
        )

@@ -432,12 +407,21 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):

     @classmethod
     def _create_runtimes(
-        cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
+        cls,
+        compiled_models: List[rebel.RBLNCompiledModel],
+        rbln_device_map: Dict[str, int],
+        activate_profiler: Optional[bool] = None,
     ) -> List[rebel.Runtime]:
-        device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
+        if any(model_name not in rbln_device_map for model_name in ["prefill", "decoder"]):
+            cls._raise_missing_compiled_file_error(["prefill", "decoder"])
+
         return [
-            compiled_models[0].create_runtime(input_info_index=0, tensor_type="pt", device=device_val),
-            compiled_models[0].create_runtime(input_info_index=1, tensor_type="pt", device=device_val),
+            compiled_models[0].create_runtime(
+                tensor_type="pt", device=rbln_device_map["prefill"], activate_profiler=activate_profiler
+            ),
+            compiled_models[1].create_runtime(
+                tensor_type="pt", device=rbln_device_map["decoder"], activate_profiler=activate_profiler
+            ),
         ]

     def get_decoder(self):
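With the prefill and decoder graphs now compiled, named, and mapped to devices individually, and with `attn_impl`/`kvcache_partition_len` recorded in `model_cfg`, the attention backend becomes a user-facing export option rather than something inferred from the wrapper signature. A hedged usage sketch, assuming the usual `rbln_`-prefixed keyword convention and the `RBLNLlamaForCausalLM` entry point (the model id, the "flash_attn" value, and the partition length are illustrative assumptions, not verified defaults):

    from optimum.rbln import RBLNLlamaForCausalLM

    # Export-time knobs are passed as rbln_* kwargs at from_pretrained time.
    model = RBLNLlamaForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",
        export=True,
        rbln_max_seq_len=8192,
        rbln_batch_size=1,
        rbln_attn_impl="flash_attn",          # assumed knob name/value for illustration
        rbln_kvcache_partition_len=4096,      # assumed partition length for illustration
    )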
@@ -569,12 +553,16 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                ],
                dtype=torch.float32,
                device="cpu",
-            ),
-            torch.empty(size=[], dtype=torch.int16, device="cpu"),
+            )
        ]

        input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
        query_length = input_tensors.shape[1]
+        if query_length > self.max_seq_len:
+            raise ValueError(
+                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.max_seq_len})."
+            )
+
        _attention_mask = self.prefill_attention_mask.clone()

        for step in range(0, query_length, self.prefill_chunk_size):
@@ -607,15 +595,15 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
            _attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
            _attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask

-            query_idx = (query_length - 1) % self.prefill_chunk_size
+            query_position = (query_length - 1) % self.prefill_chunk_size

-            logits, _ = self.prefill_decoder(
+            logits = self.prefill_decoder(
                input_ids=_input_tensors.contiguous() if inputs_embeds is None else None,
                inputs_embeds=_input_tensors.contiguous() if inputs_embeds is not None else None,
                attention_mask=_attention_mask.contiguous(),
                cache_position=_cache_position.contiguous(),
                batch_position=torch.tensor(batch_idx, dtype=torch.int16),
-                query_idx=torch.tensor(query_idx, dtype=torch.int16),
+                query_position=torch.tensor(query_position, dtype=torch.int16),
                out=out_buffers,
            )

@@ -651,14 +639,11 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                    f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
                )
            self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
-
-        logits, _ = self.decoder(
+        logits = self.decoder(
            input_ids=input_tensors.contiguous() if inputs_embeds is None else None,
            inputs_embeds=input_tensors.contiguous() if inputs_embeds is not None else None,
            attention_mask=self.dec_attn_mask.contiguous(),
            cache_position=cache_position.contiguous(),
-            batch_position=torch.tensor(0, dtype=torch.int16),
-            query_idx=torch.tensor(0, dtype=torch.int16),
        )

        return logits
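The prefill loop above walks the prompt in `prefill_chunk_size` slices, and the renamed `query_position` (formerly `query_idx`) tells the compiled graph which offset inside the last chunk holds the final prompt token, so only that logit needs to be produced. A small arithmetic sketch of the chunking (pure Python; the prompt length is an illustrative value):

    prefill_chunk_size = 128
    query_length = 300  # illustrative prompt length

    steps = list(range(0, query_length, prefill_chunk_size))   # [0, 128, 256]
    query_position = (query_length - 1) % prefill_chunk_size   # 43

    # The last chunk starts at token 256 and is right-padded past token 299;
    # the logit of interest sits at offset 43 inside that chunk.
    assert steps[-1] + query_position == query_length - 1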
@@ -20,6 +20,7 @@
 # are the intellectual property of Rebellions Inc. and may not be
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
+
 from typing import TYPE_CHECKING

 import torch.nn as nn
@@ -58,7 +59,7 @@ class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):

            new_layer = ExaoneLayer(layer, new_self_attn)
            new_layers.append(new_layer)
-        new_model = ExaoneModel(causal_lm.transformer, new_layers)
+        new_model = ExaoneModel(causal_lm.transformer, new_layers, partition_len=self.kvcache_partition_len)
        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
        return new_causal_lm

@@ -85,7 +86,6 @@ class ExaoneAttention(DecoderOnlyAttention):
        self.k_proj = self._original_mod.k_proj
        self.v_proj = self._original_mod.v_proj
        self.o_proj = self._original_mod.out_proj
-        self.num_key_value_heads = self._original_mod.num_key_value_heads


 class ExaoneFlashAttention(DecoderOnlyFlashAttention):
@@ -94,4 +94,3 @@ class ExaoneFlashAttention(DecoderOnlyFlashAttention):
        self.k_proj = self._original_mod.k_proj
        self.v_proj = self._original_mod.v_proj
        self.o_proj = self._original_mod.out_proj
-        self.num_key_value_heads = self._original_mod.num_key_value_heads
@@ -51,7 +51,7 @@ class GemmaWrapper(DecoderOnlyWrapper):
                raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
            new_layer = DecoderOnlyLayer(layer, new_self_attn)
            new_layers.append(new_layer)
-        new_model = GemmaModel(causal_lm.model, new_layers)
+        new_model = GemmaModel(causal_lm.model, new_layers, partition_len=self.kvcache_partition_len)
        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
        return new_causal_lm

@@ -21,6 +21,7 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

+import math
 from typing import TYPE_CHECKING, Tuple

 import torch
@@ -54,8 +55,6 @@ class GPT2Wrapper(DecoderOnlyWrapper):


 class GPT2Model(DecoderOnlyModel):
-    mask_fmin = torch.finfo(torch.float32).min
-
     def get_last_layernorm(self) -> nn.LayerNorm:
         return self._original_mod.ln_f

@@ -79,16 +78,17 @@ class GPT2Attention(DecoderOnlyAttention):
        self.c_attn = self._original_mod.c_attn
        self.o_proj = self._original_mod.c_proj
        self.split_size = self._original_mod.split_size
-        self.num_key_value_heads = self._original_mod.num_heads

    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
        return query_states, key_states, value_states

-    def rbln_attention(self, *args, **kwargs):
-        return super().rbln_attention(
-            *args,
-            **kwargs,
-            layer_idx=self.layer_idx,
-            scale_attn_by_inverse_layer_idx=self._original_mod.scale_attn_by_inverse_layer_idx,
-        )
+    def get_attn_scale(self):
+        scale = 1.0
+        if self._original_mod.scale_attn_weights:
+            scale /= math.sqrt(self.head_dim)
+
+        if self._original_mod.scale_attn_by_inverse_layer_idx:
+            scale /= 1 + self.layer_idx
+
+        return scale
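Replacing the `rbln_attention` override with `get_attn_scale` condenses GPT-2's scaling rules into a single scalar the shared attention kernel can consume: 1/sqrt(head_dim) when `scale_attn_weights` is set, further divided by `1 + layer_idx` when `scale_attn_by_inverse_layer_idx` is set. A quick numeric check of that logic, restated as a free function with illustrative values:

    import math

    def gpt2_attn_scale(head_dim, layer_idx, scale_attn_weights=True, scale_by_inverse_layer_idx=False):
        # Mirrors the get_attn_scale logic above, outside the class for clarity.
        scale = 1.0
        if scale_attn_weights:
            scale /= math.sqrt(head_dim)
        if scale_by_inverse_layer_idx:
            scale /= 1 + layer_idx
        return scale

    assert gpt2_attn_scale(64, layer_idx=0) == 1 / 8                                        # 0.125
    assert gpt2_attn_scale(64, layer_idx=3, scale_by_inverse_layer_idx=True) == 1 / 8 / 4   # 0.03125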
@@ -23,7 +23,7 @@

 from ....utils import logging
 from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
-from .gpt2_architecture import GPT2Wrapper  # GPT2LMHeadModelWrapper
+from .gpt2_architecture import GPT2Wrapper


 logger = logging.get_logger(__name__)
@@ -21,7 +21,6 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-
 from ...models.decoderonly.decoderonly_architecture import DecoderOnlyWrapper


@@ -20,6 +20,7 @@
 # are the intellectual property of Rebellions Inc. and may not be
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
+
 import inspect
 import logging
 from pathlib import Path
@@ -21,12 +21,12 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

+import math
 from typing import TYPE_CHECKING, Tuple

 import torch
 import torch.nn as nn

-from ....transformers.models.decoderonly.decoderonly_architecture import rotate_half
 from ..decoderonly.decoderonly_architecture import (
     DecoderOnlyAttention,
     DecoderOnlyForCausalLM,
@@ -34,6 +34,7 @@ from ..decoderonly.decoderonly_architecture import (
     DecoderOnlyModel,
     DecoderOnlyWrapper,
     apply_rotary_pos_emb_partial,
+    rotate_half,
 )


@@ -77,8 +78,6 @@ class MidmLMHeadModelWrapper(DecoderOnlyWrapper):


 class MidmModel(DecoderOnlyModel):
-    mask_fmin = -10000.0
-
     def get_layernorm1p(self, module: nn.LayerNorm):
         def layernorm1p(input: torch.Tensor):
             """Applies Layer Normalization with a slight modification on the weights."""
@@ -135,14 +134,15 @@ class MidmAttention(DecoderOnlyAttention):
        query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
        return query_states, key_states, value_states

-    def rbln_attention(self, *args, **kwargs):
-        return super().rbln_attention(
-            *args,
-            **kwargs,
-            layer_idx=self.layer_idx,
-            scale_attn_weights=self._original_mod.scale_attn_weights,
-            scale_attn_by_inverse_layer_idx=self._original_mod.scale_attn_by_inverse_layer_idx,
-        )
+    def get_attn_scale(self):
+        scale = 1.0
+        if self._original_mod.scale_attn_weights:
+            scale /= math.sqrt(self.head_dim)
+
+        if self._original_mod.scale_attn_by_inverse_layer_idx and not self._original_mod.scale_qk_by_inverse_layer_idx:
+            scale /= 1 + self.layer_idx
+
+        return scale

    def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
        return apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim=cos.shape[-1])
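`apply_rotary_pos_emb_partial` with `ndim=cos.shape[-1]` applies rotary embeddings only to the first `ndim` channels of the query/key heads and passes the remaining channels through unchanged, which is how Midm (and Phi-style models) use partial RoPE. A minimal sketch of that behaviour, assuming the conventional `rotate_half` definition; this is an illustration, not the library's exact implementation:

    import torch

    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_partial(q, k, cos, sin, ndim):
        # Rotate only the first `ndim` channels; leave the rest untouched.
        q_rot, q_pass = q[..., :ndim], q[..., ndim:]
        k_rot, k_pass = k[..., :ndim], k[..., ndim:]
        q_rot = q_rot * cos + rotate_half(q_rot) * sin
        k_rot = k_rot * cos + rotate_half(k_rot) * sin
        return torch.cat([q_rot, q_pass], dim=-1), torch.cat([k_rot, k_pass], dim=-1)

    q = k = torch.randn(1, 2, 4, 8)        # batch, heads, seq, head_dim
    cos = torch.ones(1, 1, 4, 6)           # rotary applied to the first 6 of 8 dims
    sin = torch.zeros(1, 1, 4, 6)
    q_out, _ = apply_rotary_partial(q, k, cos, sin, ndim=cos.shape[-1])
    assert torch.equal(q_out, q)           # cos=1, sin=0 leaves the input unchanged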
@@ -21,7 +21,6 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-
 from transformers import AutoModelForCausalLM

 from ....utils import logging
@@ -21,7 +21,6 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-
 from ..decoderonly.decoderonly_architecture import DecoderOnlyWrapper


@@ -65,7 +65,6 @@ class PhiAttention(DecoderOnlyAttention):
        self.o_proj = self._original_mod.dense
        self.qk_layernorm = self._original_mod.qk_layernorm
        self.rotary_ndims = self._original_mod.rotary_ndims
-        self.num_key_value_heads = self.num_heads

    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        query_states = self.q_proj(hidden_states)
@@ -90,7 +89,7 @@ class PhiLayer(DecoderOnlyLayer):
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
-        current_steps: torch.LongTensor,
+        seq_positions: torch.LongTensor,
        batch_position: torch.Tensor,
        past_key_values: Tuple[Tuple[torch.Tensor]],
        cos: Optional[torch.Tensor] = None,
@@ -103,7 +102,7 @@ class PhiLayer(DecoderOnlyLayer):
        attn_outputs, present_key_values = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
-            current_steps=current_steps,
+            seq_positions=seq_positions,
            batch_position=batch_position,
            past_key_values=past_key_values,
            cos=cos,
@@ -21,7 +21,6 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-
 from ..decoderonly.decoderonly_architecture import DecoderOnlyWrapper
