optimum-rbln 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51) hide show
  1. optimum/rbln/__init__.py +115 -0
  2. optimum/rbln/__version__.py +1 -0
  3. optimum/rbln/diffusers/__init__.py +64 -0
  4. optimum/rbln/diffusers/models/__init__.py +26 -0
  5. optimum/rbln/diffusers/models/autoencoder_kl.py +313 -0
  6. optimum/rbln/diffusers/models/controlnet.py +180 -0
  7. optimum/rbln/diffusers/models/unet_2d_condition.py +352 -0
  8. optimum/rbln/diffusers/pipelines/__init__.py +30 -0
  9. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +24 -0
  10. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +266 -0
  11. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +26 -0
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_controlnet_img2img.py +731 -0
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +106 -0
  14. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +116 -0
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +2 -0
  16. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +109 -0
  17. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +111 -0
  18. optimum/rbln/modeling.py +0 -0
  19. optimum/rbln/modeling_alias.py +49 -0
  20. optimum/rbln/modeling_base.py +645 -0
  21. optimum/rbln/modeling_config.py +169 -0
  22. optimum/rbln/modeling_seq2seq.py +469 -0
  23. optimum/rbln/transformers/__init__.py +59 -0
  24. optimum/rbln/transformers/generation/__init__.py +24 -0
  25. optimum/rbln/transformers/generation/streamers.py +122 -0
  26. optimum/rbln/transformers/models/__init__.py +28 -0
  27. optimum/rbln/transformers/models/bart/__init__.py +24 -0
  28. optimum/rbln/transformers/models/bart/bart_architecture.py +377 -0
  29. optimum/rbln/transformers/models/clip/__init__.py +24 -0
  30. optimum/rbln/transformers/models/clip/modeling_clip.py +116 -0
  31. optimum/rbln/transformers/models/gpt2/__init__.py +24 -0
  32. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +253 -0
  33. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +700 -0
  34. optimum/rbln/transformers/models/llama/__init__.py +24 -0
  35. optimum/rbln/transformers/models/llama/llama_architecture.py +607 -0
  36. optimum/rbln/transformers/models/llama/modeling_llama.py +409 -0
  37. optimum/rbln/transformers/models/t5/__init__.py +24 -0
  38. optimum/rbln/transformers/models/t5/t5_architecture.py +439 -0
  39. optimum/rbln/transformers/models/wav2vec2/__init__.py +24 -0
  40. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +121 -0
  41. optimum/rbln/transformers/models/whisper/__init__.py +24 -0
  42. optimum/rbln/transformers/models/whisper/modeling_whisper.py +374 -0
  43. optimum/rbln/transformers/models/whisper/whisper_architecture.py +406 -0
  44. optimum/rbln/utils/__init__.py +25 -0
  45. optimum/rbln/utils/import_utils.py +28 -0
  46. optimum/rbln/utils/runtime_utils.py +71 -0
  47. optimum/rbln/utils/save_utils.py +92 -0
  48. optimum_rbln-0.1.0.dist-info/METADATA +144 -0
  49. optimum_rbln-0.1.0.dist-info/RECORD +51 -0
  50. optimum_rbln-0.1.0.dist-info/WHEEL +4 -0
  51. optimum_rbln-0.1.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,169 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ import copy
25
+ import json
26
+ from collections import UserDict
27
+ from dataclasses import asdict, dataclass
28
+ from pathlib import Path
29
+ from typing import Any, Dict, List, Optional, Tuple
30
+
31
+ import torch
32
+
33
+
34
DEFAULT_COMPILED_MODEL_NAME = "compiled_model"
DEFAULT_MOD_NAME = "default"


@dataclass
class RBLNRuntimeConfig:
    """Description of one compiled module: its names, input signature and compile options.

    ``input_info`` is a list of ``(name, shape, dtype)`` entries. On construction
    (and on ``update``) every dtype is normalised to a plain string such as
    ``"float32"``; a missing or ``None`` dtype falls back to ``"float32"``.
    """

    compiled_model_name: str = DEFAULT_COMPILED_MODEL_NAME
    rbln_mod_name: str = DEFAULT_MOD_NAME
    input_info: List[Tuple[str, Tuple[int], Optional[str]]] = None
    batch_size: Optional[int] = None
    fusion: Optional[bool] = None
    npu: Optional[str] = None
    tensor_parallel_size: Optional[int] = None

    @staticmethod
    def normalize_dtype(dtype):
        """
        framework's dtype to string.
        i.e. torch.float32 -> "float32"

        Strings are returned unchanged. ``None`` is returned as ``None`` so that
        callers can substitute their own default instead of getting ``"None"``.
        """
        if dtype is None or isinstance(dtype, str):
            return dtype
        # repr(torch.float32) == "torch.float32"; repr(numpy.float32) == "<class 'numpy.float32'>"
        name: str = repr(dtype).split(".")[-1]
        if name.endswith("'>"):  # numpy scalar type
            name = name[:-2]
        return name

    @staticmethod
    def _normalize_input_info(input_info):
        """Return ``input_info`` as a list of ``(name, shape, dtype_str)`` tuples.

        ``None`` becomes an empty list. Entries may omit the dtype (two-element
        entries) or give it as ``None``; both default to ``"float32"``.
        """
        if input_info is None:
            return []
        normalized = []
        for entry in input_info:
            raw_dtype = entry[2] if len(entry) > 2 else None
            dtype = RBLNRuntimeConfig.normalize_dtype(raw_dtype) or "float32"
            normalized.append((entry[0], entry[1], dtype))
        return normalized

    def __post_init__(self):
        self.input_info = RBLNRuntimeConfig._normalize_input_info(self.input_info)

    def update(self, **kwargs):
        """Overwrite the fields named in ``kwargs`` (unknown keys are ignored) and return ``self``."""
        self.compiled_model_name = kwargs.get("compiled_model_name", self.compiled_model_name)
        self.rbln_mod_name = kwargs.get("rbln_mod_name", self.rbln_mod_name)
        if "input_info" in kwargs:
            # Keep the same invariant as the constructor: dtypes are always strings.
            self.input_info = RBLNRuntimeConfig._normalize_input_info(kwargs["input_info"])
        self.batch_size = kwargs.get("batch_size", self.batch_size)
        self.fusion = kwargs.get("fusion", self.fusion)
        self.npu = kwargs.get("npu", self.npu)
        self.tensor_parallel_size = kwargs.get("tensor_parallel_size", self.tensor_parallel_size)
        return self

    def get_dummy_inputs(self, fill=0):
        """Build one tensor per ``input_info`` entry, filled with ``fill``.

        An empty shape yields a 0-dim tensor. Returns a tuple in ``input_info`` order.
        """
        return tuple(
            torch.full(tuple(shape), fill, dtype=getattr(torch, dtype))
            for _name, shape, dtype in self.input_info
        )

    def asdict(self):
        """Return the dataclass fields as a plain dictionary."""
        return asdict(self)
87
+
88
+
89
class RBLNConfig(UserDict):
    """Mapping of compiled-model name to its list of runtime configs, plus free-form metadata."""

    def __init__(self, runtime_cfgs: "Dict[str, List[RBLNRuntimeConfig]]", _rbln_meta: Dict[str, Any] = None):
        """Configurations for RBLN model compilation and inference.

        Args:
            runtime_cfgs (Dict[str, List[RBLNRuntimeConfig]]):
                Runtime configs grouped by compiled model name.
            _rbln_meta (Dict[str, Any], optional):
                Any rbln-specific configurations.
                (i.e. max_seq_len for language models, image_size for image models).
                Defaults to None.
        """
        super().__init__(runtime_cfgs)
        if _rbln_meta:
            self.meta = _rbln_meta
        else:
            self.meta: Dict[str, Any] = {}

    @staticmethod
    def from_rbln_configs(rbln_configs: List["RBLNConfig"], names: Optional[List[str]] = None) -> "RBLNConfig":
        """Merge several single-entry configs into one.

        The first runtime config of each input is deep-copied; when ``names`` is
        given, the copy's ``compiled_model_name`` is replaced by the matching name.
        Metadata dictionaries are merged, later configs overriding earlier keys.

        Raises:
            RuntimeError: if an input config does not hold exactly one compiled
                model entry, or that entry holds no runtime config.
        """
        # assume each rbln_config has exact one rbln_runtime_config
        names = [None] * len(rbln_configs) if names is None else names
        runtime_cfgs = []
        for name, cfg in zip(names, rbln_configs):
            if len(cfg) != 1:
                msg = (
                    "`from_rbln_configs` requires exact one `RBLNRuntimeConfig` for each `RBLNConfig`. "
                    f"But got {len(cfg)} `RBLNRuntimeConfig`."
                )
                raise RuntimeError(msg)

            (candidates,) = cfg.values()
            if not candidates:
                # An entry with an empty list would otherwise surface as a bare IndexError.
                msg = (
                    "`from_rbln_configs` requires exact one `RBLNRuntimeConfig` for each `RBLNConfig`. "
                    "But got 0 `RBLNRuntimeConfig`."
                )
                raise RuntimeError(msg)

            # Copy so that renaming does not mutate the caller's config.
            runtime_cfg = copy.deepcopy(candidates[0])
            if name is not None:
                runtime_cfg.compiled_model_name = name
            runtime_cfgs.append(runtime_cfg)

        merged_meta = {}
        for cfg in rbln_configs:
            merged_meta.update(cfg.meta)

        return RBLNConfig.from_rbln_runtime_configs(runtime_cfgs, _rbln_meta=merged_meta)

    @staticmethod
    def from_rbln_runtime_configs(
        rbln_runtime_configs: "List[RBLNRuntimeConfig]",
        _rbln_meta: Dict[str, Any] = None,
    ) -> "RBLNConfig":
        """Group runtime configs by ``compiled_model_name``, preserving input order within each group."""
        cfgs = {}
        for rbln_runtime_config in rbln_runtime_configs:
            cfgs.setdefault(rbln_runtime_config.compiled_model_name, []).append(rbln_runtime_config)
        return RBLNConfig(cfgs, _rbln_meta=_rbln_meta)

    def save(self, dir_path: str):
        """Write the config (and its metadata under ``rbln_config_meta``) to ``<dir_path>/rbln_config.json``."""
        dir_path = Path(dir_path)
        data = self.asdict()
        data.update({"rbln_config_meta": self.meta})
        with open(dir_path / "rbln_config.json", "w", encoding="utf-8") as jsonf:
            json.dump(data, jsonf, indent=2)

    @staticmethod
    def load(dir_path: str) -> "RBLNConfig":
        """Read ``<dir_path>/rbln_config.json`` and rebuild the config from it."""
        dir_path = Path(dir_path)
        with open(dir_path / "rbln_config.json", "r", encoding="utf-8") as jsonf:
            config_file = json.load(jsonf)
        return RBLNConfig.fromdict(config_file)

    def asdict(self):
        """Return the runtime configs as plain dictionaries, keyed by compiled model name (metadata excluded)."""
        return {k: [asdict(cfg) for cfg in cfgs] for k, cfgs in self.data.items()}

    @staticmethod
    def fromdict(dic: dict):
        """Inverse of ``save``'s payload: every key except ``rbln_config_meta`` is a list of runtime configs."""
        runtime_cfgs = {
            k: [RBLNRuntimeConfig(**cfg) for cfg in cfgs] for k, cfgs in dic.items() if k != "rbln_config_meta"
        }
        return RBLNConfig(runtime_cfgs, _rbln_meta=dic.get("rbln_config_meta"))
@@ -0,0 +1,469 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ import inspect
25
+ import logging
26
+ from pathlib import Path
27
+ from tempfile import TemporaryDirectory
28
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
29
+
30
+ import rebel
31
+ import torch
32
+ from optimum.exporters import TasksManager
33
+ from transformers import (
34
+ AutoModelForSeq2SeqLM,
35
+ BartConfig,
36
+ BartForConditionalGeneration,
37
+ PretrainedConfig,
38
+ T5ForConditionalGeneration,
39
+ )
40
+ from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
41
+
42
+ from .modeling_base import RBLNBaseModel
43
+ from .modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
44
+ from .transformers.models.bart import BartDecoderWrapper, BartEncoderWrapper
45
+ from .transformers.models.t5 import T5DecoderWrapper, T5EncoderWrapper
46
+ from .utils.runtime_utils import RBLNPytorchRuntime
47
+ from .utils.save_utils import maybe_save_preprocessors
48
+
49
+
50
+ logger = logging.getLogger(__name__)
51
+
52
+ if TYPE_CHECKING:
53
+ from transformers import (
54
+ AutoFeatureExtractor,
55
+ AutoProcessor,
56
+ AutoTokenizer,
57
+ PretrainedConfig,
58
+ )
59
+
60
+
61
class RBLNRuntimeEncoder(RBLNPytorchRuntime):
    """Runtime wrapper for the compiled encoder module.

    The parent ``forward`` is invoked for its effect only; its result is
    discarded and a one-element dummy ``last_hidden_state`` is returned.
    """

    mandatory_members = ["main_input_name"]

    def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
        super().forward(*args, **kwargs)
        # Just indicates that it is not None
        placeholder = torch.tensor([1.0])
        return BaseModelOutput(last_hidden_state=placeholder)
68
+
69
+
70
class RBLNRuntimeDecoder(RBLNPytorchRuntime):
    """Runtime wrapper for the compiled decoder module; exposes the parent's output as ``logits``."""

    mandatory_members = ["main_input_name"]

    def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
        logits = super().forward(*args, **kwargs)
        return Seq2SeqLMOutput(logits=logits)
76
+
77
+
78
class RBLNModelForSeq2SeqLM(RBLNBaseModel):
    """
    This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence-to-sequence language modeling head) when created with the from_pretrained() class method.
    This model inherits from [`RBLNBaseModel`]. Check the superclass documentation for the generic methods the library implements for all its models.

    A class to convert and run pre-trained transformers based Seq2SeqLM models on RBLN devices.
    It implements the methods to convert a pre-trained transformers Seq2SeqLM model into a RBLN transformer model by:
    - transferring the checkpoint weights of the original into an optimized RBLN graph,
    - compiling the resulting graph using the RBLN compiler.

    Currently, this model class only supports the 'bart' and 't5' models from the transformers library. Future updates may include support for additional model types.
    """

    model_type = "rbln_model"
    auto_model_class = AutoModelForSeq2SeqLM

    def __post_init__(self, **kwargs):
        """Cache sizes from the config/metadata and wrap the two runtimes (encoder first, decoder second)."""
        self.model_dim = self.config.d_model
        self.batch_size = self.rbln_config[DEFAULT_COMPILED_MODEL_NAME][0].batch_size
        self.enc_max_seq_len = self.rbln_config.meta["rbln_enc_max_seq_len"]
        self.dec_max_seq_len = self.rbln_config.meta["rbln_dec_max_seq_len"]
        self.pad_token_id = self.rbln_config.meta["rbln_pad_token_id"]
        self.encoder = RBLNRuntimeEncoder(runtime=self.runtimes[0], main_input_name="input_ids")
        self.decoder = RBLNRuntimeDecoder(runtime=self.runtimes[1], main_input_name="input_ids")

    def can_generate(self):
        return True

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def __getattr__(self, __name: str) -> Any:
        """Fall back to the matching transformers class for attributes not found on this object.

        The attribute is looked up on ``T5ForConditionalGeneration`` when the
        config lists that architecture, otherwise on ``BartForConditionalGeneration``.
        Callables that take a ``self`` parameter are returned bound to this instance.
        """

        def redirect(func):
            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)

        # `architectures` is expected to be a list of class names (or None), so this
        # must be a membership test; comparing it with `==` against a string never
        # matches a list.
        architectures = getattr(self.config, "architectures", None) or []
        if isinstance(architectures, str):
            architectures = [architectures]

        if "T5ForConditionalGeneration" in architectures:
            val = getattr(T5ForConditionalGeneration, __name)
        else:
            val = getattr(BartForConditionalGeneration, __name)

        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
            return redirect(val)
        return val

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        decoder_attention_mask=None,
        **kwargs,
    ):
        """Build the per-step decoder inputs.

        Only the last token of ``input_ids`` is fed to the decoder. The incoming
        ``decoder_attention_mask`` is ignored and rebuilt at the fixed length
        ``dec_max_seq_len`` with ones over the tokens generated so far;
        ``cache_position`` is the index of the last token.
        """
        max_seq_len = self.dec_max_seq_len
        cur_seq_len = input_ids.shape[-1]
        decoder_batch_size = input_ids.shape[0]
        input_ids = input_ids[:, cur_seq_len - 1 : cur_seq_len].contiguous()

        # In greedy decoding
        decoder_attention_mask = torch.zeros(decoder_batch_size, max_seq_len, dtype=torch.int64)
        decoder_attention_mask[:, :cur_seq_len] = 1
        cache_position = torch.tensor(cur_seq_len - 1, dtype=torch.int32)

        return {
            "decoder_input_ids": input_ids,
            "past_key_values": past_key_values,
            "attention_mask": attention_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "cache_position": cache_position,
        }

    @classmethod
    def _export(
        cls,
        model_id: str,
        config: "PretrainedConfig",
        use_auth_token: Optional[Union[bool, str]] = None,
        revision: Optional[str] = None,
        force_download: bool = False,
        cache_dir: Optional[str] = None,
        subfolder: str = "",
        local_files_only: bool = False,
        trust_remote_code: bool = False,
        **kwargs,
    ) -> "AutoModelForSeq2SeqLM":
        """
        Exports a vanilla Transformers model into a rbln-compiled Module.

        The model is loaded through ``TasksManager``, wrapped into encoder/decoder
        modules, traced, compiled, and saved together with the model config and
        the RBLN config into a temporary directory, from which the instance is
        then created via ``_from_pretrained``.

        Raises:
            ValueError: if the loaded model is neither a T5 nor a Bart
                conditional-generation model.
        """
        task = kwargs.pop("task", None)
        if task is None:
            task = TasksManager.infer_task_from_model(cls.auto_model_class)

        save_dir = TemporaryDirectory()
        save_dir_path = Path(save_dir.name)

        kwargs.update(
            {
                "torchscript": True,
                "return_dict": False,
                "use_cache": False,
            }
        )

        rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)

        model: AutoModelForSeq2SeqLM = TasksManager.get_model_from_task(
            task=task,
            model_name_or_path=model_id,
            subfolder=subfolder,
            revision=revision,
            framework="pt",
            cache_dir=cache_dir,
            use_auth_token=use_auth_token,
            local_files_only=local_files_only,
            force_download=force_download,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )

        if config is None:
            config = model.config

        config.save_pretrained(save_dir_path)
        preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)

        # Get compilation arguments. A caller-supplied `rbln_config` is used as is;
        # it must be bound here either way because the compile step below reads it.
        rbln_config = rbln_config_kwargs.get("rbln_config", None)
        if rbln_config is None:
            rbln_config = cls.get_rbln_config(
                preprocessors=preprocessors, model_config=model.config, **rbln_config_kwargs
            )

        def optimized_models(model):
            if isinstance(model, T5ForConditionalGeneration):
                encoder_model = T5EncoderWrapper(model).eval()
                decoder_model = T5DecoderWrapper(model).eval()
            elif isinstance(model, BartForConditionalGeneration):
                encoder_model = BartEncoderWrapper(model).eval()
                decoder_model = BartDecoderWrapper(model).eval()
            else:
                raise ValueError(f"{model.__class__.__name__} is not supported yet.")

            return encoder_model, decoder_model

        def compile_model():
            wrapped_encoder, wrapped_decoder = optimized_models(model)

            for wrapped in (wrapped_encoder, wrapped_decoder):
                wrapped.encoder_max_length = rbln_config.meta["rbln_enc_max_seq_len"]
                wrapped.decoder_max_length = rbln_config.meta["rbln_dec_max_seq_len"]
                wrapped.decoder_batch_size = rbln_config.meta["rbln_batch_size"]

            enc_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
            dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]

            # T5 is traced with inputs filled with ones, every other model with zeros.
            fill = 1 if isinstance(model, T5ForConditionalGeneration) else 0
            enc_example_inputs = enc_rbln_runtime_config.get_dummy_inputs(fill=fill)
            dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=fill)

            enc_scripted_model = torch.jit.trace(wrapped_encoder, enc_example_inputs)
            dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs)

            enc_ir = rebel.torchscript_to_ir(
                enc_scripted_model,
                input_names=[v[0] for v in enc_rbln_runtime_config.input_info],
                name=enc_rbln_runtime_config.rbln_mod_name,
            )
            dec_ir = rebel.torchscript_to_ir(
                dec_scripted_model,
                input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
                name=dec_rbln_runtime_config.rbln_mod_name,
            )
            dec_ir.decoder_batch_size = rbln_config.meta["rbln_batch_size"]

            # With the input order produced by `_get_rbln_config`, decoder input 4 is
            # `self_key_value_states` and input 5 is `cross_key_value_states`.
            connections = [
                (enc_ir.outputs[0], dec_ir.inputs[5]),
                (dec_ir.outputs[1], dec_ir.inputs[4]),
            ]
            compiled_model = rebel.compile(
                enc_ir,
                dec_ir,
                connections=connections,
                fusion=enc_rbln_runtime_config.fusion,
                npu=enc_rbln_runtime_config.npu,
                tensor_parallel_size=enc_rbln_runtime_config.tensor_parallel_size,
            )
            compiled_model.save(save_dir_path / f"{DEFAULT_COMPILED_MODEL_NAME}.rbln")

        compile_model()

        rbln_config.save(save_dir_path)

        return cls._from_pretrained(
            model_id=save_dir_path,
            config=config,
            model_save_dir=save_dir,
            **rbln_constructor_kwargs,
            **kwargs,
        )

    @classmethod
    def _get_rbln_config(
        cls,
        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
        model_config: "PretrainedConfig",
        rbln_enc_max_seq_len: Optional[int] = None,
        rbln_dec_max_seq_len: Optional[int] = None,
        rbln_batch_size: Optional[int] = 1,
    ) -> RBLNConfig:
        """Derive the encoder/decoder input signatures and metadata used for compilation.

        Sequence lengths that are not given default to the model's
        ``n_positions`` / ``max_position_embeddings``, then to the first
        preprocessor exposing ``model_max_length``. ``rbln_batch_size`` of
        ``None`` is treated as 1.

        Raises:
            ValueError: if a sequence length cannot be determined, or exceeds the
                model's maximum position embeddings.
        """
        meta = {}

        if rbln_batch_size is None:
            rbln_batch_size = 1

        if isinstance(model_config, BartConfig):
            n_layer = model_config.decoder_layers
            n_head = model_config.decoder_attention_heads
            # The cache shapes below use the decoder head count, so the head
            # dimension must be derived from the same value.
            d_kv = model_config.d_model // model_config.decoder_attention_heads
        else:
            n_layer = model_config.num_layers
            n_head = model_config.num_heads
            d_kv = model_config.d_kv

        max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
            model_config, "max_position_embeddings", None
        )

        # Pad token fallback chain: pad -> bos -> eos -> -1.
        rbln_pad_token_id = getattr(model_config, "pad_token_id", None)
        if rbln_pad_token_id is None:
            rbln_pad_token_id = getattr(model_config, "bos_token_id", None)
            if rbln_pad_token_id is None:
                rbln_pad_token_id = getattr(model_config, "eos_token_id", None)
                if rbln_pad_token_id is None:
                    rbln_pad_token_id = -1

        def tokenizer_max_length():
            # First preprocessor that exposes `model_max_length` wins.
            for tokenizer in () if preprocessors is None else preprocessors:
                if hasattr(tokenizer, "model_max_length"):
                    return tokenizer.model_max_length
            return None

        if rbln_enc_max_seq_len is None:
            rbln_enc_max_seq_len = max_position_embeddings
            if rbln_enc_max_seq_len is None:
                rbln_enc_max_seq_len = tokenizer_max_length()
                if rbln_enc_max_seq_len is None:
                    raise ValueError("`rbln_enc_max_seq_len` should be specified!")
        if max_position_embeddings is not None and rbln_enc_max_seq_len > max_position_embeddings:
            raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")

        if rbln_dec_max_seq_len is None:
            rbln_dec_max_seq_len = max_position_embeddings
            if rbln_dec_max_seq_len is None:
                rbln_dec_max_seq_len = tokenizer_max_length()
                if rbln_dec_max_seq_len is None:
                    raise ValueError("`rbln_dec_max_seq_len` should be specified!")

        if max_position_embeddings is not None and rbln_dec_max_seq_len > max_position_embeddings:
            raise ValueError("`rbln_dec_max_seq_len` should be less or equal than max_position_embeddings!")

        meta["rbln_enc_max_seq_len"] = rbln_enc_max_seq_len
        meta["rbln_dec_max_seq_len"] = rbln_dec_max_seq_len
        meta["rbln_batch_size"] = rbln_batch_size
        meta["rbln_pad_token_id"] = rbln_pad_token_id

        # model input info
        enc_input_info = [
            ("input_ids", [rbln_batch_size, rbln_enc_max_seq_len], "int64"),
            ("attention_mask", [rbln_batch_size, rbln_enc_max_seq_len], "int64"),
        ]

        # NOTE: `_export` wires compiled graph connections by position into this
        # list (indices 4 and 5), so the order must not change.
        dec_input_info = [
            ("input_ids", [rbln_batch_size, 1], "int64"),
            ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "int64"),
            ("encoder_attention_mask", [rbln_batch_size, rbln_enc_max_seq_len], "int64"),
            ("cache_position", [], "int32"),
            (
                "self_key_value_states",
                [n_layer * 2, rbln_batch_size, n_head, rbln_dec_max_seq_len, d_kv],
                "float32",
            ),
            (
                "cross_key_value_states",
                [n_layer * 2, rbln_batch_size, n_head, rbln_enc_max_seq_len, d_kv],
                "float32",
            ),
        ]

        enc_rbln_runtime_config = RBLNRuntimeConfig(rbln_mod_name="encoder", input_info=enc_input_info)
        dec_rbln_runtime_config = RBLNRuntimeConfig(rbln_mod_name="decoder", input_info=dec_input_info)

        rbln_config = RBLNConfig.from_rbln_runtime_configs(
            [enc_rbln_runtime_config, dec_rbln_runtime_config],
            _rbln_meta=meta,
        )

        return rbln_config

    def _create_runtimes(self, rbln_device_map: Dict[str, int]) -> List[rebel.Runtime]:
        """Create the encoder and decoder runtimes (in that order) from the single compiled model."""
        device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
        return [
            self.compiled_models[0].create_runtime("encoder", tensor_type="pt", device=device_val),
            self.compiled_models[0].create_runtime("decoder", tensor_type="pt", device=device_val),
        ]

    def forward(
        self,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor]:
        """Run one decoder step and return its logits; extra keyword arguments are ignored."""
        decoder_output = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_attention_mask=attention_mask,
            cache_position=cache_position,
        )
        lm_logits = decoder_output.logits

        return Seq2SeqLMOutput(logits=lm_logits)

    def __repr__(self):
        return repr(self.runtimes[0]) + "\n" + repr(self.runtimes[1])

    def _prepare_encoder_decoder_kwargs_for_generation(
        self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Pad the encoder inputs to ``enc_max_seq_len``, run the encoder, and store its output.

        ``model_kwargs["attention_mask"]`` is required and is replaced by its
        padded version; the encoder result is stored under ``"encoder_outputs"``.
        """
        # Pad input_ids & attention_mask to the compiled length regardless of the
        # user's tokenizer settings.
        # NOTE(review): for inputs longer than `enc_max_seq_len` the pad amount is
        # negative, which presumably crops rather than raises - confirm intended.
        _, input_len = inputs_tensor.shape
        pad_amount = self.enc_max_seq_len - input_len
        inputs_tensor = torch.nn.functional.pad(inputs_tensor, (0, pad_amount), value=self.pad_token_id)
        model_kwargs["attention_mask"] = torch.nn.functional.pad(
            model_kwargs["attention_mask"], (0, pad_amount), value=0
        )

        # 1. get encoder
        encoder = self.get_encoder()

        # 2. Prepare encoder args and encoder kwargs from model kwargs.
        irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
        encoder_kwargs = {
            argument: value
            for argument, value in model_kwargs.items()
            if not any(argument.startswith(p) for p in irrelevant_prefix)
        }
        encoder_signature = set(inspect.signature(encoder.forward).parameters)
        encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
        if not encoder_accepts_wildcard:
            encoder_kwargs = {
                argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
            }

        # 3. make sure that encoder returns `ModelOutput`
        model_input_name = model_input_name if model_input_name is not None else self.main_input_name
        encoder_kwargs["return_dict"] = True
        encoder_kwargs[model_input_name] = inputs_tensor
        model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs)

        return model_kwargs
@@ -0,0 +1,59 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ from typing import TYPE_CHECKING
25
+
26
+ from transformers.utils import _LazyModule
27
+
28
+
29
# Public names of this package, grouped by the submodule that defines them.
_GENERATION_EXPORTS = ["BatchTextIteratorStreamer"]
_MODEL_EXPORTS = [
    "RBLNCLIPTextModel",
    "RBLNCLIPTextModelWithProjection",
    "RBLNGPT2LMHeadModel",
    "RBLNWav2Vec2ForCTC",
    "RBLNWhisperForConditionalGeneration",
    "RBLNLlamaForCausalLM",
]

_import_structure = {
    "generation": _GENERATION_EXPORTS,
    "models": _MODEL_EXPORTS,
}

if TYPE_CHECKING:
    # Static type checkers get the real imports.
    from .generation import BatchTextIteratorStreamer
    from .models import (
        RBLNCLIPTextModel,
        RBLNCLIPTextModelWithProjection,
        RBLNGPT2LMHeadModel,
        RBLNLlamaForCausalLM,
        RBLNWav2Vec2ForCTC,
        RBLNWhisperForConditionalGeneration,
    )
else:
    import sys

    # At runtime this module is replaced in `sys.modules` by a `_LazyModule`
    # built from `_import_structure`.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )