optimum-rbln 0.1.4__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff compares publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (52)
  1. optimum/rbln/__init__.py +21 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/__init__.py +0 -1
  4. optimum/rbln/diffusers/models/autoencoder_kl.py +16 -98
  5. optimum/rbln/diffusers/models/controlnet.py +3 -0
  6. optimum/rbln/diffusers/models/unet_2d_condition.py +3 -3
  7. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +22 -146
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +109 -53
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +114 -53
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -0
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +8 -0
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +8 -0
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +9 -0
  16. optimum/rbln/modeling_alias.py +14 -0
  17. optimum/rbln/modeling_base.py +282 -100
  18. optimum/rbln/modeling_seq2seq.py +58 -132
  19. optimum/rbln/transformers/__init__.py +8 -0
  20. optimum/rbln/transformers/cache_utils.py +111 -0
  21. optimum/rbln/transformers/generation/utils.py +0 -2
  22. optimum/rbln/transformers/models/__init__.py +3 -0
  23. optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
  24. optimum/rbln/transformers/models/clip/modeling_clip.py +0 -1
  25. optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
  26. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +515 -0
  27. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +349 -0
  28. optimum/rbln/transformers/models/dpt/__init__.py +24 -0
  29. optimum/rbln/transformers/models/dpt/modeling_dpt.py +89 -0
  30. optimum/rbln/transformers/models/gemma/__init__.py +24 -0
  31. optimum/rbln/transformers/models/gemma/gemma_architecture.py +116 -0
  32. optimum/rbln/transformers/models/gemma/modeling_gemma.py +61 -0
  33. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +200 -174
  34. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +57 -293
  35. optimum/rbln/transformers/models/llama/llama_architecture.py +3 -613
  36. optimum/rbln/transformers/models/llama/modeling_llama.py +9 -469
  37. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
  38. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
  39. optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
  40. optimum/rbln/transformers/models/midm/modeling_midm.py +40 -308
  41. optimum/rbln/transformers/models/whisper/modeling_whisper.py +53 -123
  42. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
  43. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  44. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +125 -0
  45. optimum/rbln/utils/__init__.py +1 -1
  46. optimum/rbln/utils/import_utils.py +46 -0
  47. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/METADATA +18 -53
  48. optimum_rbln-0.1.8.dist-info/RECORD +73 -0
  49. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/WHEEL +1 -1
  50. optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -759
  51. optimum_rbln-0.1.4.dist-info/RECORD +0 -63
  52. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/licenses/LICENSE +0 -0
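The headline change in this release is a refactor of the decoder-only language models: GPT2, Llama, Midm, and the newly added Gemma now share the RBLNDecoderOnlyModelForCausalLM base class under optimum/rbln/transformers/models/decoderonly, which absorbs the per-model export, runtime, and generation plumbing, as the GPT2 diff below illustrates. For orientation only, here is a minimal usage sketch; it assumes the usual Optimum-style from_pretrained(..., export=True) entry point and the rbln_* keyword names that appear in this diff, and is not taken from the project documentation.

# Illustrative sketch only (not from the optimum-rbln docs): assumes the
# standard Optimum-style from_pretrained(..., export=True) convention and
# the rbln_* kwargs visible in the GPT2 diff below.
from transformers import AutoTokenizer
from optimum.rbln import RBLNGPT2LMHeadModel

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = RBLNGPT2LMHeadModel.from_pretrained(
    "gpt2",
    export=True,            # compile the PyTorch checkpoint for the RBLN NPU
    rbln_max_seq_len=1024,  # falls back to config.n_positions when omitted
    rbln_batch_size=1,      # decode-phase batch size
)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(generated[0], skip_special_tokens=True))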
optimum/rbln/transformers/models/gpt2/modeling_gpt2.py
@@ -23,26 +23,16 @@

  import inspect
  import logging
- from pathlib import Path
- from tempfile import TemporaryDirectory
- from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+ from typing import TYPE_CHECKING, Any, Callable, Optional, Union

- import rebel
- import torch
- from optimum.exporters import TasksManager
- from transformers import AutoModelForCausalLM, GPT2LMHeadModel, PretrainedConfig
- from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, Seq2SeqLMOutput
+ from transformers import GPT2LMHeadModel, PretrainedConfig, PreTrainedModel

- from ....modeling_base import RBLNBaseModel
- from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
- from ....utils.runtime_utils import RBLNPytorchRuntime
- from ....utils.save_utils import maybe_save_preprocessors
- from ...generation.utils import RBLNGenerationMixin
+ from ....modeling_config import RBLNConfig, RBLNRuntimeConfig
+ from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
  from .gpt2_architecture import GPT2LMHeadModelWrapper


  logger = logging.getLogger(__name__)
-
  if TYPE_CHECKING:
      from transformers import (
          AutoFeatureExtractor,
@@ -52,19 +42,12 @@ if TYPE_CHECKING:
      )


- class RBLNRuntimeDecoder(RBLNPytorchRuntime):
-     def forward(self, *args, **kwargs) -> Union[Tuple, Seq2SeqLMOutput]:
-         outputs = super().forward(*args, **kwargs)
-         logits = outputs
-         return Seq2SeqLMOutput(logits=logits)
-
-
- class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
+ class RBLNGPT2LMHeadModel(RBLNDecoderOnlyModelForCausalLM):
      """
      The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
      embeddings).

-     This model inherits from [`RBLNBaseModel`]. Check the superclass documentation for the generic methods the
+     This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the
      library implements for all its model.

      It implements the methods to convert a pre-trained transformers GPT2 model into a RBLN transformer model by:
@@ -73,29 +56,9 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):

      """

-     model_type = "rbln_model"
-     auto_model_class = AutoModelForCausalLM
-     main_input_name = "input_ids"
-
-     def __post_init__(self, **kwargs):
-         self.prefill_chunk_size = self.rbln_config.meta["rbln_prefill_chunk_size"]
-         self.max_seq_len = self.rbln_config.meta["rbln_max_seq_len"]
-
-         batch_size = self.rbln_config[DEFAULT_COMPILED_MODEL_NAME][0].input_info[0][1][0]
-         self.prefill_attention_mask = torch.zeros(
-             batch_size, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.int64
-         )
-         self.causal_mask = 1 - torch.triu(
-             torch.ones(batch_size, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
-         )
-
-         self.prefill_decoder = RBLNRuntimeDecoder(runtime=self.runtimes[0])
-         self.decoder = RBLNRuntimeDecoder(runtime=self.runtimes[1])
-         self.pad_token_id = self.rbln_config.meta["rbln_pad_token_id"]
-         self.past_cached_length = 0
-
-     def can_generate(self):
-         return True
+     @classmethod
+     def wrapping_torch_model(self, model: "PreTrainedModel", rbln_max_seq_len: int):
+         return GPT2LMHeadModelWrapper(model, rbln_max_seq_len).eval()

      def __getattr__(self, __name: str) -> Any:
          """This is the key method to implement RBLN-GPT2.
@@ -112,126 +75,6 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
              return redirect(val)
          return val

-     def _reorder_cache(self, past_key_values, beam_idx):
-         # TODO(jongho): implement
-         raise NotImplementedError
-
-     @classmethod
-     def _export(
-         cls,
-         model_id: str,
-         config: "PretrainedConfig",
-         use_auth_token: Optional[Union[bool, str]] = None,
-         revision: Optional[str] = None,
-         force_download: bool = False,
-         cache_dir: Optional[str] = None,
-         subfolder: str = "",
-         local_files_only: bool = False,
-         trust_remote_code: bool = False,
-         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
-         **kwargs,
-     ) -> "RBLNGPT2LMHeadModel":
-         """
-         Exports a vanilla Transformers model into a rbln-compiled Module.
-         """
-         task = kwargs.pop("task", None)
-         if task is None:
-             task = TasksManager.infer_task_from_model(cls.auto_model_class)
-
-         if model_save_dir is None:
-             save_dir = TemporaryDirectory()
-             save_dir_path = Path(save_dir.name)
-         else:
-             save_dir = model_save_dir
-             if isinstance(save_dir, TemporaryDirectory):
-                 save_dir_path = Path(model_save_dir.name)
-             else:
-                 save_dir_path = Path(model_save_dir)
-                 save_dir_path.mkdir(exist_ok=True)
-
-         kwargs.update(
-             {
-                 "torchscript": True,
-                 "return_dict": False,
-                 "use_cache": True,
-             }
-         )
-
-         rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
-
-         model: GPT2LMHeadModel = TasksManager.get_model_from_task(
-             task=task,
-             model_name_or_path=model_id,
-             subfolder=subfolder,
-             revision=revision,
-             framework="pt",
-             cache_dir=cache_dir,
-             use_auth_token=use_auth_token,
-             local_files_only=local_files_only,
-             force_download=force_download,
-             trust_remote_code=trust_remote_code,
-             **kwargs,
-         )
-
-         if config is None:
-             config = model.config
-
-         config.save_pretrained(save_dir_path)
-         preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
-
-         # Get compilation arguments
-         if rbln_config_kwargs.get("rbln_config", None) is None:
-             rbln_config = cls.get_rbln_config(
-                 preprocessors=preprocessors, model_config=model.config, **rbln_config_kwargs
-             )
-
-         def compile_gpt2():
-             wrapped_decoder = GPT2LMHeadModelWrapper(model).eval()
-
-             prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
-             dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
-
-             prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
-             dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=0)
-
-             prefill_scripted_model = torch.jit.trace(wrapped_decoder, prefill_example_inputs)
-             dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs)
-
-             prefill_ir = rebel.torchscript_to_ir(
-                 prefill_scripted_model,
-                 input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
-             )
-             dec_ir = rebel.torchscript_to_ir(
-                 dec_scripted_model,
-                 input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
-             )
-
-             connections = [
-                 (prefill_ir.outputs[1], prefill_ir.inputs[1]),
-             ]
-
-             compiled_model = rebel.compile(
-                 prefill_ir,
-                 dec_ir,
-                 connections=connections,
-                 fusion=prefill_rbln_runtime_config.fusion,
-                 npu=prefill_rbln_runtime_config.npu,
-                 tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
-                 use_weight_sharing=True,
-             )
-             compiled_model.save(save_dir_path / f"{DEFAULT_COMPILED_MODEL_NAME}.rbln")
-
-         compile_gpt2()
-         rbln_config.save(save_dir_path)
-
-         return cls._from_pretrained(
-             model_id=save_dir_path,
-             config=config,
-             model_save_dir=save_dir,
-             **rbln_constructor_kwargs,
-             **kwargs,
-         )
-
      @classmethod
      def _get_rbln_config(
          cls,
@@ -239,153 +82,74 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
          model_config: "PretrainedConfig",
          rbln_max_seq_len: Optional[int] = None,
          rbln_batch_size: Optional[int] = None,
-         rbln_pad_token_id: Optional[int] = None,
+         **kwargs,
      ) -> RBLNConfig:
          meta = {}

-         default_max_length = getattr(model_config, "n_positions", None)
-         for tokenizer in preprocessors:
-             default_max_length = default_max_length or getattr(tokenizer, "max_len_single_sentence", None)
-
          prefill_chunk_size = 128
+         if rbln_max_seq_len is None:  # differenct from llama
+             rbln_max_seq_len = getattr(model_config, "n_positions", None)
+         rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size

-         if rbln_max_seq_len is None:
-             rbln_max_seq_len = default_max_length
-
-         if rbln_max_seq_len is None:
-             raise ValueError("`rbln_max_seq_len` should be specified!")
-
-         if rbln_pad_token_id is None:
-             rbln_pad_token_id = getattr(model_config, "pad_token_id", None)
-             if rbln_pad_token_id is None:
-                 rbln_pad_token_id = getattr(model_config, "eos_token_id", None)
-                 if rbln_pad_token_id is None:
-                     rbln_pad_token_id = 50256
-
-         meta["rbln_prefill_chunk_size"] = prefill_chunk_size
          meta["rbln_max_seq_len"] = rbln_max_seq_len
-         meta["rbln_pad_token_id"] = rbln_pad_token_id
-
-         rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
+         meta["rbln_batch_size"] = rbln_batch_size
+         meta["rbln_prefill_chunk_size"] = prefill_chunk_size

-         def get_input_info(query_length):
-             return [
-                 ("input_ids", [rbln_batch_size, query_length], "int64"),
-                 (
-                     "past_key_values",
-                     [
-                         model_config.n_layer,
-                         2,
-                         rbln_batch_size,
-                         model_config.n_head,
-                         rbln_max_seq_len,
-                         model_config.hidden_size // model_config.n_head,
-                     ],
-                     "float32",
-                 ),
-                 ("attention_mask", [rbln_batch_size, 1, query_length, rbln_max_seq_len], "int64"),
+         def get_input_info(
+             batch_size,
+             query_length,
+         ):
+             head_dim = (
+                 model_config.head_dim
+                 if hasattr(model_config, "head_dim")
+                 else model_config.hidden_size // model_config.n_head
+             )
+             input_info = [
+                 ("input_ids", [batch_size, query_length], "int64"),
+                 ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "int64"),
                  (
                      "cache_position",
-                     [],
+                     [batch_size, query_length],
                      "int32",
                  ),
+                 ("batch_position", [], "int16"),
              ]

-         # model input info
-         prefill_input_info = get_input_info(query_length=prefill_chunk_size)
-         dec_input_info = get_input_info(query_length=1)
+             input_info.extend(
+                 [
+                     (
+                         f"past_key_values_{i}",
+                         [
+                             rbln_batch_size,
+                             model_config.n_head,  # differenct from llama
+                             rbln_max_seq_len,
+                             head_dim,
+                         ],
+                         "float32",
+                     )
+                     for i in range(model_config.n_layer * 2)  # differenct from llama
+                 ]
+             )
+
+             return input_info
+
+         prefill_input_info = get_input_info(
+             batch_size=1,
+             query_length=prefill_chunk_size,
+         )
+         dec_input_info = get_input_info(
+             batch_size=rbln_batch_size,
+             query_length=1,
+         )

          prefill_rbln_runtime_config = RBLNRuntimeConfig(input_info=prefill_input_info)
          dec_rbln_runtime_config = RBLNRuntimeConfig(input_info=dec_input_info)

+         dec_rbln_runtime_config.batch_size = rbln_batch_size
+
          rbln_config = RBLNConfig.from_rbln_runtime_configs(
              [prefill_rbln_runtime_config, dec_rbln_runtime_config],
              _rbln_meta=meta,
          )

          return rbln_config
-
-     def _create_runtimes(self, rbln_device_map: Dict[str, int]) -> List[rebel.Runtime]:
-         device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
-         return [
-             self.compiled_models[0].create_runtime(input_info_index=0, tensor_type="pt", device=device_val),
-             self.compiled_models[0].create_runtime(input_info_index=1, tensor_type="pt", device=device_val),
-         ]
-
-     def prepare_inputs_for_generation(self, input_ids, past_key_values=0, attention_mask=None, **kwargs):
-         batch_size, cur_len = input_ids.shape
-         past_cached_length = past_key_values
-
-         # In greedy decoding
-         if past_cached_length == 0:
-             self.prompt_ids = input_ids
-             self.rightpad_max_len = cur_len
-             prompt_min_len = torch.min(torch.sum(attention_mask, dim=-1))
-             self.dummy_len = torch.sum(attention_mask, dim=-1) - prompt_min_len
-
-             if cur_len % self.prefill_chunk_size == 0:
-                 pad_len = 0
-             else:
-                 pad_len = self.prefill_chunk_size - cur_len % self.prefill_chunk_size
-             input_ids = torch.nn.functional.pad(input_ids, (0, pad_len))
-             attention_mask = self.prefill_attention_mask.clone()
-             cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
-
-             query_length = prompt_min_len.item()
-         else:
-             cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
-             attention_mask = torch.zeros(batch_size, 1, 1, self.max_seq_len, dtype=torch.int64)
-             attention_mask[:, :, :, : cache_position + 1] = 1
-             input_ids = input_ids[:, cache_position : cache_position + 1].contiguous()
-             query_length = 1
-
-         model_inputs = {
-             "input_ids": input_ids,
-             "past_key_values": past_key_values,
-             "attention_mask": attention_mask,
-             # below are rbln-related kwargs
-             "cache_position": cache_position,
-             "query_length": query_length,
-         }
-
-         return model_inputs
-
-     def forward(
-         self,
-         input_ids: Optional[torch.LongTensor] = None,
-         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
-         attention_mask: Optional[torch.FloatTensor] = None,
-         cache_position: Optional[torch.Tensor] = None,
-         query_length: Optional[torch.Tensor] = None,
-         **kwargs,
-     ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
-         if past_key_values is not None:
-             past_key_values += query_length
-
-         if cache_position == 0:
-             for step in range(0, query_length, self.prefill_chunk_size):
-                 sliced_input_ids = input_ids[:, step : step + self.prefill_chunk_size]
-                 attention_mask[:, :, :, :step] = 1
-                 attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
-
-                 output = self.prefill_decoder(
-                     input_ids=sliced_input_ids.contiguous(),
-                     attention_mask=attention_mask.contiguous(),
-                     cache_position=cache_position + step,
-                 )
-
-             idx = query_length % self.prefill_chunk_size - 1
-             output = output.logits[:, idx].unsqueeze(1)
-
-         else:
-             output = self.decoder(
-                 input_ids=input_ids.contiguous(),
-                 attention_mask=attention_mask.contiguous(),
-                 cache_position=cache_position,
-             )
-             output = output.logits
-
-         return CausalLMOutputWithCrossAttentions(logits=output, past_key_values=past_key_values)
-
-     def __repr__(self):
-         return repr(self.runtimes[0]) + "\n" + repr(self.runtimes[1])
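For reference, the new get_input_info helper replaces the single stacked past_key_values input of 0.1.4 with one cache tensor per key and per value in every layer, plus a batch_position input. Below is a worked sketch of the resulting shapes, assuming GPT-2 small hyperparameters (n_layer=12, n_head=12, hidden_size=768) and the defaults above (prefill_chunk_size=128, rbln_max_seq_len=1024, rbln_batch_size=1); the concrete numbers are illustrative, while the layout follows the diff.

# Worked example of the input layout produced by get_input_info in 0.1.8,
# using GPT-2 small hyperparameters; shapes mirror the diff above.
n_layer, n_head, hidden_size = 12, 12, 768
max_seq_len, prefill_chunk_size, batch_size = 1024, 128, 1
head_dim = hidden_size // n_head  # 64

# The prefill graph consumes one 128-token chunk of the prompt at a time.
prefill_inputs = [
    ("input_ids", [1, prefill_chunk_size], "int64"),
    ("attention_mask", [1, 1, prefill_chunk_size, max_seq_len], "int64"),
    ("cache_position", [1, prefill_chunk_size], "int32"),
    ("batch_position", [], "int16"),
]

# The decode graph consumes a single token per step for the whole batch.
decode_inputs = [
    ("input_ids", [batch_size, 1], "int64"),
    ("attention_mask", [batch_size, 1, 1, max_seq_len], "int64"),
    ("cache_position", [batch_size, 1], "int32"),
    ("batch_position", [], "int16"),
]

# Both graphs share 2 * n_layer = 24 KV-cache tensors, one per key and one
# per value per layer, each shaped [rbln_batch_size, n_head, max_seq_len, head_dim].
kv_cache_inputs = [
    (f"past_key_values_{i}", [batch_size, n_head, max_seq_len, head_dim], "float32")
    for i in range(n_layer * 2)
]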