optimum-rbln 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to their public registry, and is provided for informational purposes only.
Files changed (61)
  1. optimum/rbln/__init__.py +37 -2
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +36 -29
  4. optimum/rbln/diffusers/models/controlnet.py +56 -40
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +40 -28
  6. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +22 -15
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +22 -15
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +23 -17
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +24 -18
  10. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +22 -11
  11. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +22 -11
  12. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +24 -14
  13. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +24 -14
  14. optimum/rbln/modeling_alias.py +3 -3
  15. optimum/rbln/modeling_base.py +471 -231
  16. optimum/rbln/modeling_config.py +152 -77
  17. optimum/rbln/modeling_seq2seq.py +166 -77
  18. optimum/rbln/transformers/__init__.py +35 -1
  19. optimum/rbln/transformers/models/__init__.py +20 -1
  20. optimum/rbln/transformers/models/auto/__init__.py +14 -0
  21. optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
  22. optimum/rbln/transformers/models/auto/modeling_auto.py +94 -0
  23. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  24. optimum/rbln/transformers/models/bart/bart_architecture.py +189 -50
  25. optimum/rbln/transformers/models/bart/modeling_bart.py +106 -0
  26. optimum/rbln/transformers/models/bert/__init__.py +24 -0
  27. optimum/rbln/transformers/models/bert/modeling_bert.py +102 -0
  28. optimum/rbln/transformers/models/clip/__init__.py +1 -1
  29. optimum/rbln/transformers/models/clip/modeling_clip.py +127 -25
  30. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +28 -4
  31. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +302 -115
  32. optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -7
  33. optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -1
  34. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  35. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
  36. optimum/rbln/transformers/models/llama/modeling_llama.py +1 -1
  37. optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
  38. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +666 -0
  39. optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
  40. optimum/rbln/transformers/models/midm/modeling_midm.py +1 -1
  41. optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -1
  42. optimum/rbln/transformers/models/phi/__init__.py +24 -0
  43. optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
  44. optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
  45. optimum/rbln/transformers/models/t5/t5_architecture.py +92 -31
  46. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +17 -11
  47. optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
  48. optimum/rbln/transformers/models/whisper/modeling_whisper.py +141 -105
  49. optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
  50. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +17 -14
  51. optimum/rbln/transformers/utils/rbln_quantization.py +48 -60
  52. optimum/rbln/utils/import_utils.py +36 -1
  53. optimum/rbln/utils/logging.py +82 -0
  54. optimum/rbln/utils/runtime_utils.py +33 -0
  55. optimum/rbln/utils/timer_utils.py +19 -0
  56. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/METADATA +8 -7
  57. optimum_rbln-0.1.11.dist-info/RECORD +93 -0
  58. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/WHEEL +1 -1
  59. optimum_rbln-0.1.11.dist-info/entry_points.txt +4 -0
  60. optimum_rbln-0.1.9.dist-info/RECORD +0 -78
  61. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/modeling_config.py
@@ -23,24 +23,38 @@
 
 import copy
 import json
-from collections import UserDict
 from dataclasses import asdict, dataclass
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple
 
+import rebel
 import torch
 
+from .__version__ import __version__
+from .utils.runtime_utils import ContextRblnConfig
+
 
 DEFAULT_COMPILED_MODEL_NAME = "compiled_model"
 DEFAULT_MOD_NAME = "default"
 
 
 @dataclass
-class RBLNRuntimeConfig:
+class RBLNCompileConfig:
+    """
+    Configuration for RBLN compilation.
+
+    Attributes:
+        compiled_model_name (str): Name of the compiled model.
+        mod_name (str): Name of the RBLN module.
+        input_info (List[Tuple[str, Tuple[int], Optional[str]]]): Information about input tensors.
+        fusion (Optional[bool]): Whether to use fusion optimization.
+        npu (Optional[str]): NPU configuration.
+        tensor_parallel_size (Optional[int]): Size for tensor parallelism.
+    """
+
     compiled_model_name: str = DEFAULT_COMPILED_MODEL_NAME
-    rbln_mod_name: str = DEFAULT_MOD_NAME
+    mod_name: str = DEFAULT_MOD_NAME
     input_info: List[Tuple[str, Tuple[int], Optional[str]]] = None
-    batch_size: Optional[int] = None
     fusion: Optional[bool] = None
     npu: Optional[str] = None
     tensor_parallel_size: Optional[int] = None
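
The rename is mechanical but API-breaking: `RBLNRuntimeConfig` becomes `RBLNCompileConfig`, `rbln_mod_name` becomes `mod_name`, and the per-config `batch_size` field is dropped (batch size moves into `RBLNConfig.model_cfg`, as the seq2seq hunks below show). A minimal migration sketch; the import path and input shapes here are illustrative, not taken from the diff:

    from optimum.rbln.modeling_config import RBLNCompileConfig

    # 0.1.11 style: mod_name replaces rbln_mod_name; there is no batch_size field.
    cfg = RBLNCompileConfig(
        mod_name="encoder",
        input_info=[("input_ids", [1, 512], "int64")],  # (name, shape, dtype)
    )
    # 0.1.9 equivalent (no longer valid):
    # RBLNRuntimeConfig(rbln_mod_name="encoder", input_info=..., batch_size=1)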
@@ -48,8 +62,14 @@ class RBLNRuntimeConfig:
     @staticmethod
     def normalize_dtype(dtype):
         """
-        framework's dtype to string.
+        Convert framework-specific dtype to string representation.
         i.e. torch.float32 -> "float32"
+
+        Args:
+            dtype: The input dtype (can be string, torch dtype, or numpy dtype).
+
+        Returns:
+            str: The normalized string representation of the dtype.
         """
         if isinstance(dtype, str):
             return dtype
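
The expanded docstring now spells out the contract: strings pass through unchanged, and framework dtypes are reduced to their bare names. A behavioral sketch, assuming the implementation matches the docstring:

    import torch

    assert RBLNCompileConfig.normalize_dtype("int64") == "int64"
    assert RBLNCompileConfig.normalize_dtype(torch.float32) == "float32"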
@@ -60,13 +80,12 @@ class RBLNRuntimeConfig:
         return dtype
 
     def __post_init__(self):
-        self.input_info = [(i[0], i[1], RBLNRuntimeConfig.normalize_dtype(i[2]) or "float32") for i in self.input_info]
+        self.input_info = [(i[0], i[1], RBLNCompileConfig.normalize_dtype(i[2]) or "float32") for i in self.input_info]
 
-    def update(self, **kwargs):
+    def update(self, kwargs: Dict[str, Any]):
         self.compiled_model_name = kwargs.get("compiled_model_name", self.compiled_model_name)
-        self.rbln_mod_name = kwargs.get("rbln_mod_name", self.rbln_mod_name)
+        self.mod_name = kwargs.get("mod_name", self.mod_name)
         self.input_info = kwargs.get("input_info", self.input_info)
-        self.batch_size = kwargs.get("batch_size", self.batch_size)
         self.fusion = kwargs.get("fusion", self.fusion)
         self.npu = kwargs.get("npu", self.npu)
         self.tensor_parallel_size = kwargs.get("tensor_parallel_size", self.tensor_parallel_size)
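
Note the signature change: `update` now takes one dict positionally instead of `**kwargs`, which lets `RBLNConfig.__init__` (below) pass its whole `rbln_kwargs` mapping through unmodified, with unknown keys simply ignored by the `.get()` fallbacks. Continuing the sketch above:

    # 0.1.11: pass a dict (keys mirror COMPILE_KEYWORDS); values are illustrative.
    cfg.update({"mod_name": "decoder", "tensor_parallel_size": 4})
    # 0.1.9 form (removed): cfg.update(rbln_mod_name="decoder")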
@@ -86,84 +105,140 @@ class RBLNRuntimeConfig:
         return asdict(self)
 
 
-class RBLNConfig(UserDict):
-    def __init__(self, runtime_cfgs: Dict[str, List[RBLNRuntimeConfig]], _rbln_meta: Dict[str, Any] = None):
-        """Configurations for RBLN model compilation and inference.
+RUNTIME_KEYWORDS = ["create_runtimes", "optimize_host_memory", "device", "device_map"]
+COMPILE_KEYWORDS = ["compiled_model_name", "mod_name", "input_info", "fusion", "npu", "tensor_parallel_size"]
 
-        Args:
-            _rbln_meta (Dict[str, Any], optional):
-                Any rbln-specific configurations.
-                (i.e. max_seq_len for language models, image_size for image models).
-                Defaults to None.
-        """
-        super().__init__(runtime_cfgs)
-        if _rbln_meta:
-            self.meta = _rbln_meta
+
+class RBLNConfig:
+    """
+    Configuration for single RBLN OptimizedModel, representing multiple compiled models.
+
+    Attributes:
+        compile_cfgs (List[RBLNCompileConfig]): Compilation configurations.
+        meta (dict): Metadata including version and class information.
+        runtime_cfg (dict): Runtime-specific configuration.
+    """
+
+    # It represents multiple compiled model, one of each can have multiple runtimes.
+    def __init__(
+        self,
+        rbln_cls,
+        compile_cfgs: List[RBLNCompileConfig],
+        rbln_kwargs=None,
+        meta=None,
+    ) -> None:
+        if rbln_kwargs is None:
+            rbln_kwargs = {}
         else:
-            self.meta: Dict[str, Any] = {}
+            rbln_kwargs = copy.deepcopy(rbln_kwargs)
 
-    @staticmethod
-    def from_rbln_configs(rbln_configs: List["RBLNConfig"], names: Optional[List[str]] = None) -> "RBLNConfig":
-        # assume each rbln_config has exact one rbln_runtime_config
-        names = [None] * len(rbln_configs) if names is None else names
-        runtime_cfgs = []
-        for name, cfg in zip(names, rbln_configs):
-            if len(cfg) > 1:
-                msg = (
-                    "`from_rbln_configs` requires exact one `RBLNRuntimeConfig` for each `RBLNConfig`."
-                    f"But got {len(cfg)} `RBLNRuntimeConfig`."
-                )
-                raise RuntimeError(msg)
-
-            runtime_cfg = cfg[list(cfg.keys())[0]][0]
-            runtime_cfg = copy.deepcopy(runtime_cfg)
-            if name is not None:
-                runtime_cfg.compiled_model_name = name
-            runtime_cfgs.append(runtime_cfg)
-
-        metas = [cfg.meta for cfg in rbln_configs]
-        merged_meta = {k: v for meta in metas for k, v in meta.items()}
-
-        return RBLNConfig.from_rbln_runtime_configs(runtime_cfgs, _rbln_meta=merged_meta)
+        # meta : class, version and other informations.
+        if meta is None:
+            self.meta = {"version": __version__, "cls": rbln_cls}
+        else:
+            self.meta = meta
 
-    @staticmethod
-    def from_rbln_runtime_configs(
-        rbln_runtime_configs: List[RBLNRuntimeConfig],
-        _rbln_meta: Dict[str, Any] = None,
-    ) -> "RBLNConfig":
-        cfgs: Dict[str, List[RBLNRuntimeConfig]] = {}
-        for rbln_runtime_config in rbln_runtime_configs:
-            if rbln_runtime_config.compiled_model_name in cfgs:
-                cfgs[rbln_runtime_config.compiled_model_name].append(rbln_runtime_config)
-            else:
-                cfgs[rbln_runtime_config.compiled_model_name] = [rbln_runtime_config]
-        return RBLNConfig(cfgs, _rbln_meta=_rbln_meta)
+        # compile_cfgs : compile args for each runtime
+        self.compile_cfgs = compile_cfgs
+        for compile_cfg in self.compile_cfgs:
+            compile_cfg.update(rbln_kwargs)
+        for K in COMPILE_KEYWORDS:
+            rbln_kwargs.pop(K, None)
+
+        # runtime_cfg : Values that don't be saved / loaded.
+        self.runtime_cfg = {}
+        for runtime_key in RUNTIME_KEYWORDS:
+            if runtime_key in rbln_kwargs:
+                self.runtime_cfg[runtime_key] = rbln_kwargs.pop(runtime_key)
+
+        # model_cfg : All user-provided values such as "max_seq_len".
+        self.model_cfg: Dict[str, Any] = rbln_kwargs
 
     def save(self, dir_path: str):
         dir_path = Path(dir_path)
-        data = self.asdict()
-        data.update({"rbln_config_meta": self.meta})
+
+        s_json = {}
+        compile_cfgs = [asdict(cfg) for cfg in self.compile_cfgs]
+        s_json["_compile_cfgs"] = compile_cfgs
+        s_json["_meta"] = self.meta
+        s_json.update(self.model_cfg)
+
         with open(dir_path / "rbln_config.json", "w") as jsonf:
-            json.dump(data, jsonf, indent=2)
+            json.dump(s_json, jsonf, indent=2)
 
-    @staticmethod
-    def load(dir_path: str) -> "RBLNConfig":
+    @classmethod
+    def load(cls, dir_path: str) -> "RBLNConfig":
         dir_path = Path(dir_path)
         with open(dir_path / "rbln_config.json", "r") as jsonf:
             config_file = json.load(jsonf)
-        return RBLNConfig.fromdict(config_file)
-
-    def asdict(self):
-        dic = {k: [asdict(cfg) for cfg in cfgs] for k, cfgs in self.data.items()}
-        return dic
 
-    @staticmethod
-    def fromdict(dic: dict):
-        runtime_cfgs = {
-            k: [RBLNRuntimeConfig(**cfg) for cfg in cfgs] for k, cfgs in dic.items() if k != "rbln_config_meta"
-        }
-        if "rbln_config_meta" in dic:
-            meta = dic["rbln_config_meta"]
-        else:
-            meta = None
-        return RBLNConfig(runtime_cfgs, _rbln_meta=meta)
+        return cls.fromdict(config_file)
+
+    @classmethod
+    def fromdict(cls, dic: dict):
+        compile_cfgs = dic.pop("_compile_cfgs")
+        compile_cfgs = [RBLNCompileConfig(**cfg) for cfg in compile_cfgs]
+
+        meta = dic.pop("_meta")
+        rbln_cls = meta["cls"]
+
+        rbln_kwargs = dic
+        return cls(rbln_cls=rbln_cls, compile_cfgs=compile_cfgs, rbln_kwargs=rbln_kwargs, meta=meta)
+
+    def update_runtime_cfg(self, rbln_kwargs: Dict[str, Any]):
+        keys = list(rbln_kwargs.keys())
+        for key in keys:
+            if key in RUNTIME_KEYWORDS:
+                self.runtime_cfg[key] = rbln_kwargs[key]
+
+    def __repr__(self):
+        compile_cfgs_repr = [f"\n {cfg!r}" for cfg in self.compile_cfgs]
+        return (
+            f"RBLNConfig(\n"
+            f" rbln_cls={self.meta['cls']},\n"
+            f" version='{self.meta['version']}',\n"
+            f" compile_cfgs=[{''.join(compile_cfgs_repr)}\n ],\n"
+            f" model_cfg={self.model_cfg},\n"
+            f" runtime_cfg={self.runtime_cfg}\n"
+            f")"
+        )
+
+    @property
+    def create_runtimes(self):
+        context = ContextRblnConfig.get_current_context()["create_runtimes"]
+        if context is not None:
+            return context
+        elif self.runtime_cfg.get("create_runtimes", None) is None:
+            return rebel.npu_is_available()
+        return self.runtime_cfg["create_runtimes"]
+
+    @property
+    def optimize_host_memory(self):
+        context = ContextRblnConfig.get_current_context()["optimize_host_memory"]
+        if context is not None:
+            return context
+        elif self.runtime_cfg.get("optimize_host_memory", None) is None:
+            return True
+        return self.runtime_cfg["optimize_host_memory"]
+
+    @property
+    def device(self):
+        context = ContextRblnConfig.get_current_context()["device"]
+        if context:
+            return context
+        elif self.runtime_cfg.get("device", None) is None:
+            return 0
+        return self.runtime_cfg["device"]
+
+    @property
+    def device_map(self):
+        context = ContextRblnConfig.get_current_context()["device_map"]
+        if context:
+            return context
+        elif self.runtime_cfg.get("device_map", None) is None:
+            rbln_device_map = {}
+            device_val = self.device
+            for cfg in self.compile_cfgs:
+                rbln_device_map[cfg.compiled_model_name] = device_val
+            return rbln_device_map
+        return self.runtime_cfg["device_map"]
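
`save` and `load` now round-trip a flat rbln_config.json: compile configs under the reserved `_compile_cfgs` key, metadata under `_meta`, and every `model_cfg` entry at the top level, while runtime keywords are deliberately never persisted. A hedged round-trip sketch continuing from `cfg` above (paths are illustrative, and the target directory must already exist):

    from optimum.rbln.modeling_config import RBLNConfig

    config = RBLNConfig(
        rbln_cls="RBLNModelForSeq2SeqLM",            # recorded in meta["cls"]
        compile_cfgs=[cfg],
        rbln_kwargs={"batch_size": 1, "device": 0},  # "device" is a RUNTIME_KEYWORD
    )
    config.save("exported")                          # writes exported/rbln_config.json
    restored = RBLNConfig.load("exported")
    assert restored.model_cfg["batch_size"] == 1     # model_cfg survives the round trip
    assert "device" not in restored.runtime_cfg      # runtime values are not saved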
optimum/rbln/modeling_seq2seq.py
@@ -37,7 +37,7 @@ from transformers import (
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 
 from .modeling_base import RBLNModel
-from .modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
+from .modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
 from .transformers.models.bart import BartDecoderWrapper, BartEncoderWrapper
 from .transformers.models.t5 import T5DecoderWrapper, T5EncoderWrapper
 from .utils.runtime_utils import RBLNPytorchRuntime
@@ -88,12 +88,14 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
 
     def __post_init__(self, **kwargs):
         self.model_dim = self.config.d_model
-        self.batch_size = self.rbln_config[DEFAULT_COMPILED_MODEL_NAME][0].batch_size
-        self.enc_max_seq_len = self.rbln_config.meta["rbln_enc_max_seq_len"]
-        self.dec_max_seq_len = self.rbln_config.meta["rbln_dec_max_seq_len"]
-        self.pad_token_id = self.rbln_config.meta["rbln_pad_token_id"]
+        self.batch_size = self.rbln_config.model_cfg["batch_size"]
+        self.enc_max_seq_len = self.rbln_config.model_cfg["enc_max_seq_len"]
+        self.dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
+        self.pad_token_id = self.rbln_config.model_cfg["pad_token_id"]
         self.encoder = RBLNRuntimeEncoder(runtime=self.model[0], main_input_name="input_ids")
         self.decoder = RBLNRuntimeDecoder(runtime=self.model[1], main_input_name="input_ids")
+        self.enc_attention_mask = torch.zeros(1, self.enc_max_seq_len, dtype=torch.float32)
+        self.dec_enc_attention_mask = torch.zeros(self.batch_size, self.enc_max_seq_len, dtype=torch.float32)
 
     def can_generate(self):
         return True
@@ -117,32 +119,6 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
                 return redirect(val)
             return val
 
-    def prepare_inputs_for_generation(
-        self,
-        input_ids,
-        past_key_values=None,
-        attention_mask=None,
-        decoder_attention_mask=None,
-        **kwargs,
-    ):
-        max_seq_len = self.dec_max_seq_len
-        cur_seq_len = input_ids.shape[-1]
-        decoder_batch_size = input_ids.shape[0]
-        input_ids = input_ids[:, cur_seq_len - 1 : cur_seq_len].contiguous()
-
-        # In greedy decoding
-        decoder_attention_mask = torch.zeros(decoder_batch_size, max_seq_len, dtype=torch.int64)
-        decoder_attention_mask[:, :cur_seq_len] = 1
-        cache_position = torch.tensor(cur_seq_len - 1, dtype=torch.int32)
-
-        return {
-            "decoder_input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "attention_mask": attention_mask,
-            "decoder_attention_mask": decoder_attention_mask,
-            "cache_position": cache_position,
-        }
-
     @classmethod
     def update_kwargs(cls, kwargs):
         kwargs.update(
@@ -170,50 +146,54 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
 
         wrapped_encoder, wrapped_decoder = optimized_models(model)
 
-        wrapped_encoder.encoder_max_length = rbln_config.meta["rbln_enc_max_seq_len"]
-        wrapped_encoder.decoder_max_length = rbln_config.meta["rbln_dec_max_seq_len"]
-        wrapped_encoder.decoder_batch_size = rbln_config.meta["rbln_batch_size"]
+        wrapped_encoder.encoder_max_length = rbln_config.model_cfg["enc_max_seq_len"]
+        wrapped_encoder.decoder_max_length = rbln_config.model_cfg["dec_max_seq_len"]
+        wrapped_encoder.decoder_batch_size = rbln_config.model_cfg["batch_size"]
 
-        wrapped_decoder.encoder_max_length = rbln_config.meta["rbln_enc_max_seq_len"]
-        wrapped_decoder.decoder_max_length = rbln_config.meta["rbln_dec_max_seq_len"]
-        wrapped_decoder.decoder_batch_size = rbln_config.meta["rbln_batch_size"]
+        wrapped_decoder.encoder_max_length = rbln_config.model_cfg["enc_max_seq_len"]
+        wrapped_decoder.decoder_max_length = rbln_config.model_cfg["dec_max_seq_len"]
+        wrapped_decoder.decoder_batch_size = rbln_config.model_cfg["batch_size"]
 
-        enc_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
-        dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
+        enc_rbln_compile_config = rbln_config.compile_cfgs[0]
+        dec_rbln_compile_config = rbln_config.compile_cfgs[1]
 
         if isinstance(model, T5ForConditionalGeneration):
-            enc_example_inputs = enc_rbln_runtime_config.get_dummy_inputs(fill=1)
-            dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=1)
+            enc_example_inputs = enc_rbln_compile_config.get_dummy_inputs(fill=1)
+            dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=1)
         else:
-            enc_example_inputs = enc_rbln_runtime_config.get_dummy_inputs(fill=0)
-            dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=0)
+            enc_example_inputs = enc_rbln_compile_config.get_dummy_inputs(fill=0)
+            dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=0)
+
+        enc_example_inputs[3].fill_(0)
+        dec_example_inputs[4].fill_(-1)
 
         enc_scripted_model = torch.jit.trace(wrapped_encoder, enc_example_inputs, check_trace=False)
         dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs, check_trace=False)
 
         enc_ir = rebel.torchscript_to_ir(
             enc_scripted_model,
-            input_names=[v[0] for v in enc_rbln_runtime_config.input_info],
-            name=enc_rbln_runtime_config.rbln_mod_name,
+            input_names=[v[0] for v in enc_rbln_compile_config.input_info],
+            name=enc_rbln_compile_config.mod_name,
         )
         dec_ir = rebel.torchscript_to_ir(
             dec_scripted_model,
-            input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
-            name=dec_rbln_runtime_config.rbln_mod_name,
+            input_names=[v[0] for v in dec_rbln_compile_config.input_info],
+            name=dec_rbln_compile_config.mod_name,
         )
-        dec_ir.decoder_batch_size = rbln_config.meta["rbln_batch_size"]
+        dec_ir.decoder_batch_size = rbln_config.model_cfg["batch_size"]
 
         connections = [
-            (enc_ir.outputs[0], dec_ir.inputs[5]),
-            (dec_ir.outputs[1], dec_ir.inputs[4]),
+            (enc_ir.outputs[0], enc_ir.inputs[2], dec_ir.inputs[6]),
+            # (enc_ir.outputs[0], enc_ir.inputs[2]),
+            (dec_ir.outputs[1], dec_ir.inputs[5]),
         ]
         compiled_model = rebel.compile(
             enc_ir,
             dec_ir,
             connections=connections,
-            fusion=enc_rbln_runtime_config.fusion,
-            npu=enc_rbln_runtime_config.npu,
-            tensor_parallel_size=enc_rbln_runtime_config.tensor_parallel_size,
+            fusion=enc_rbln_compile_config.fusion,
+            npu=enc_rbln_compile_config.npu,
+            tensor_parallel_size=enc_rbln_compile_config.tensor_parallel_size,
        )
        return compiled_model
 
@@ -222,11 +202,12 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
         model_config: "PretrainedConfig",
-        rbln_enc_max_seq_len: Optional[int] = None,
-        rbln_dec_max_seq_len: Optional[int] = None,
-        rbln_batch_size: Optional[int] = 1,
+        rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-        meta = {}
+        rbln_enc_max_seq_len = rbln_kwargs.get("enc_max_seq_len", None)
+        rbln_dec_max_seq_len = rbln_kwargs.get("dec_max_seq_len", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
 
         if isinstance(model_config, BartConfig):
             n_layer = model_config.decoder_layers
@@ -274,28 +255,36 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
         if max_position_embeddings is not None and rbln_dec_max_seq_len > max_position_embeddings:
             raise ValueError("`rbln_dec_max_seq_len` should be less or equal than max_position_embeddings!")
 
-        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-
-        meta["rbln_enc_max_seq_len"] = rbln_enc_max_seq_len
-        meta["rbln_dec_max_seq_len"] = rbln_dec_max_seq_len
-        meta["rbln_batch_size"] = rbln_batch_size
-        meta["rbln_pad_token_id"] = rbln_pad_token_id
-
         # model input info
         enc_input_info = [
-            ("input_ids", [rbln_batch_size, rbln_enc_max_seq_len], "int64"),
-            ("attention_mask", [rbln_batch_size, rbln_enc_max_seq_len], "int64"),
+            ("input_ids", [1, rbln_enc_max_seq_len], "int64"),
+            ("attention_mask", [1, rbln_enc_max_seq_len], "float32"),
+            (
+                "cross_key_value_states",
+                [
+                    n_layer * 2,
+                    rbln_batch_size,
+                    n_head,
+                    rbln_enc_max_seq_len,
+                    d_kv,
+                ],
+                "float32",
+            ),
+            # int16 available?
+            ("batch_idx", [], "int32"),
         ]
 
         dec_input_info = [
             ("input_ids", [rbln_batch_size, 1], "int64"),
-            ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "int64"),
-            ("encoder_attention_mask", [rbln_batch_size, rbln_enc_max_seq_len], "int64"),
+            ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"),
+            ("encoder_attention_mask", [rbln_batch_size, rbln_enc_max_seq_len], "float32"),
             (
                 "cache_position",
-                [],
+                [rbln_batch_size, 1],
+                # [],
                 "int32",
             ),
+            ("batch_position", [], "int32"),
         ]
         dec_input_info.extend(
             [
@@ -327,12 +316,22 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
                 )
             ]
         )
-        enc_rbln_runtime_config = RBLNRuntimeConfig(rbln_mod_name="encoder", input_info=enc_input_info)
-        dec_rbln_runtime_config = RBLNRuntimeConfig(rbln_mod_name="decoder", input_info=dec_input_info)
+        enc_rbln_compile_config = RBLNCompileConfig(mod_name="encoder", input_info=enc_input_info)
+        dec_rbln_compile_config = RBLNCompileConfig(mod_name="decoder", input_info=dec_input_info)
 
-        rbln_config = RBLNConfig.from_rbln_runtime_configs(
-            [enc_rbln_runtime_config, dec_rbln_runtime_config],
-            _rbln_meta=meta,
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[enc_rbln_compile_config, dec_rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+
+        rbln_config.model_cfg.update(
+            {
+                "enc_max_seq_len": rbln_enc_max_seq_len,
+                "dec_max_seq_len": rbln_dec_max_seq_len,
+                "batch_size": rbln_batch_size,
+                "pad_token_id": rbln_pad_token_id,
+            }
         )
 
         return rbln_config
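
`_get_rbln_config` now receives one `rbln_kwargs` dict instead of discrete `rbln_*` parameters, and reads unprefixed keys (`enc_max_seq_len`, `batch_size`, ...). A hedged end-user sketch; the class name, checkpoint, and the convention that `modeling_base` strips the `rbln_` prefix off `from_pretrained` kwargs are inferred, not shown in this diff:

    from optimum.rbln import RBLNT5ForConditionalGeneration

    model = RBLNT5ForConditionalGeneration.from_pretrained(
        "t5-small",                # illustrative checkpoint
        export=True,               # compile rather than load a prebuilt model
        rbln_batch_size=1,         # -> rbln_kwargs["batch_size"]
        rbln_enc_max_seq_len=512,  # -> rbln_kwargs["enc_max_seq_len"]
        rbln_dec_max_seq_len=512,  # -> rbln_kwargs["dec_max_seq_len"]
    )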
@@ -347,7 +346,84 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
             compiled_models[0].create_runtime("decoder", tensor_type="pt", device=device_val),
         ]
 
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        decoder_attention_mask=None,
+        **kwargs,
+    ):
+        past_cache_length = past_key_values
+        if past_cache_length == 0:
+            cache_pos = []
+            for i in range(input_ids.shape[0]):
+                cache_pos.append([0])
+            cache_position = torch.tensor(cache_pos, dtype=torch.int32)
+
+        max_seq_len = self.dec_max_seq_len
+        cur_seq_len = input_ids.shape[-1]
+        decoder_batch_size = input_ids.shape[0]
+        input_ids = input_ids[:, cur_seq_len - 1 : cur_seq_len].contiguous()
+        # In greedy decoding
+        decoder_attention_mask = torch.zeros(decoder_batch_size, max_seq_len, dtype=torch.float32)
+        decoder_attention_mask[:, :cur_seq_len] = 1
+        cache_pos = []
+        for i in range(input_ids.shape[0]):
+            cache_pos.append([cur_seq_len - 1])
+        cache_position = torch.tensor(cache_pos, dtype=torch.int32)
+        return {
+            "decoder_input_ids": input_ids,
+            "past_key_values": past_key_values,
+            "attention_mask": attention_mask.to(torch.float32),
+            "decoder_attention_mask": decoder_attention_mask,
+            "cache_position": cache_position,
+        }
+
     def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        cache_position: Union[List[torch.Tensor], torch.Tensor] = None,  # vllm keyword argument
+        batch_idx: Optional[torch.LongTensor] = None,
+        enc_lengths: List[int] = None,
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor]:
+        # common decoder
+        if enc_lengths is None:
+            output = self._forward_decoder(input_ids=input_ids, cache_position=cache_position, **kwargs)
+            return output
+
+        # vllm & encoder
+        if batch_idx is not None:
+            enc_attention_mask = self.enc_attention_mask.clone()
+            enc_attention_mask[0][: enc_lengths[batch_idx] + 1] = 1
+            padding_need = self.enc_max_seq_len - input_ids.shape[-1]
+            input_ids = torch.nn.functional.pad(input_ids, (0, padding_need))
+            _ = self.encoder(input_ids, enc_attention_mask, batch_idx=batch_idx.to(torch.int32))
+            logits = torch.zeros(1, 1, self.config.vocab_size + 100)
+            logits[0][0][-1] = 1
+        # vllm & decoder
+        else:
+            input_ids[input_ids == (self.config.vocab_size + 99)] = self.config.decoder_start_token_id
+            cache_position[cache_position != 0] = cache_position[cache_position != 0] - 2
+
+            enc_attention_mask = self.dec_enc_attention_mask.clone()
+            dec_attention_mask = torch.zeros(self.batch_size, self.dec_max_seq_len, dtype=torch.float32)
+            for batch_idx in range(self.batch_size):
+                enc_attention_mask[batch_idx, : enc_lengths[batch_idx] + 1] = 1
+
+            logits = self._forward_decoder(
+                attention_mask=enc_attention_mask,
+                decoder_input_ids=input_ids,
+                decoder_attention_mask=dec_attention_mask,
+                cache_position=cache_position,
+            ).logits
+
+        return Seq2SeqLMOutput(
+            logits=logits,
+        )
+
+    def _forward_decoder(
         self,
         attention_mask: Optional[torch.FloatTensor] = None,
         decoder_input_ids: Optional[torch.LongTensor] = None,
@@ -355,13 +431,18 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
         cache_position: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
+        dec_attention_mask = decoder_attention_mask.clone()
+        for b_idx in range(self.rbln_config.model_cfg["batch_size"]):
+            dec_attention_mask[b_idx, : cache_position[b_idx] + 1] = 1
+
         decoder_output = self.decoder(
             input_ids=decoder_input_ids,
-            attention_mask=decoder_attention_mask,
+            attention_mask=dec_attention_mask,
             encoder_attention_mask=attention_mask,
             cache_position=cache_position,
+            batch_position=torch.tensor(0, dtype=torch.int32),
         )
-        lm_logits = decoder_output.logits
+        lm_logits = decoder_output.logits[0]
 
         return Seq2SeqLMOutput(logits=lm_logits)
 
@@ -405,6 +486,14 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
         model_input_name = model_input_name if model_input_name is not None else self.main_input_name
         encoder_kwargs["return_dict"] = True
         encoder_kwargs[model_input_name] = inputs_tensor
-        model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs)
+        for b in range(batch_size):
+            batch_idx = torch.tensor(b, dtype=torch.int32)
+            cb_inputs = {}
+            cb_inputs["return_dict"] = True
+            cb_inputs["output_hidden_states"] = False
+            cb_inputs["output_attentions"] = False
+            cb_inputs["input_ids"] = encoder_kwargs["input_ids"][b].unsqueeze(0)
+            cb_inputs["attention_mask"] = encoder_kwargs["attention_mask"][b].unsqueeze(0).to(torch.float32)
+            model_kwargs["encoder_outputs"] = encoder(**cb_inputs, batch_idx=batch_idx)
 
         return model_kwargs
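
With this change the encoder runs once per batch element with an explicit `batch_idx`, so each sequence's cross-attention KV cache is filled independently through the compiled-model connections set up in the compile hunk above; only the last `encoder_outputs` is kept in `model_kwargs`, which appears sufficient because the real encoder products live in the device-side cache. A hedged sketch of ordinary (non-vLLM) generation with the model from the earlier sketch:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("t5-small")  # illustrative
    batch = tokenizer("translate English to German: Hello", return_tensors="pt")
    output_ids = model.generate(**batch)                   # triggers the per-batch encoder prefill
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))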