optimum-rbln 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. optimum/rbln/__init__.py +17 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/__init__.py +0 -1
  4. optimum/rbln/diffusers/models/autoencoder_kl.py +3 -3
  5. optimum/rbln/diffusers/models/controlnet.py +7 -3
  6. optimum/rbln/diffusers/models/unet_2d_condition.py +5 -5
  7. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +23 -146
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +107 -59
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +106 -54
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
  12. optimum/rbln/modeling_alias.py +19 -1
  13. optimum/rbln/modeling_base.py +162 -18
  14. optimum/rbln/transformers/__init__.py +8 -0
  15. optimum/rbln/transformers/cache_utils.py +111 -0
  16. optimum/rbln/transformers/generation/utils.py +0 -2
  17. optimum/rbln/transformers/models/__init__.py +3 -0
  18. optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
  19. optimum/rbln/transformers/models/clip/modeling_clip.py +1 -1
  20. optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
  21. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +516 -0
  22. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +464 -0
  23. optimum/rbln/transformers/models/gemma/__init__.py +24 -0
  24. optimum/rbln/transformers/models/gemma/gemma_architecture.py +123 -0
  25. optimum/rbln/transformers/models/gemma/modeling_gemma.py +67 -0
  26. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +201 -166
  27. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +10 -257
  28. optimum/rbln/transformers/models/llama/llama_architecture.py +3 -610
  29. optimum/rbln/transformers/models/llama/modeling_llama.py +12 -440
  30. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
  31. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
  32. optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
  33. optimum/rbln/transformers/models/midm/modeling_midm.py +10 -325
  34. optimum/rbln/transformers/models/mistral/__init__.py +24 -0
  35. optimum/rbln/transformers/models/mistral/mistral_architecture.py +29 -0
  36. optimum/rbln/transformers/models/mistral/modeling_mistral.py +68 -0
  37. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  38. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
  39. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  40. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +131 -0
  41. optimum/rbln/transformers/utils/__init__.py +0 -0
  42. optimum/rbln/transformers/utils/rbln_quantization.py +109 -0
  43. optimum/rbln/utils/import_utils.py +1 -4
  44. optimum/rbln/utils/runtime_utils.py +2 -1
  45. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/METADATA +11 -5
  46. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/RECORD +48 -35
  47. optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -764
  48. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/WHEEL +0 -0
  49. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/licenses/LICENSE +0 -0
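The largest structural change in 0.1.9 is the new optimum/rbln/transformers/models/decoderonly package: the per-model causal-LM implementations (GPT-2, Llama, Midm, plus the newly added Gemma and Mistral) are rebased onto a shared RBLNDecoderOnlyModelForCausalLM class, which is why the model-specific files above shrink so sharply. A minimal sketch of the resulting subclassing pattern, using Gemma as an example; the GemmaWrapper name, import path, and exact base-class contract are assumptions for illustration, while the confirmed piece is the wrap_model_if_needed override shown in the GPT-2 diff below (from optimum/rbln/transformers/models/gpt2/modeling_gpt2.py):

# Sketch only: illustrates the subclassing pattern introduced by the decoderonly refactor.
# "GemmaWrapper" and the import paths are assumed names for illustration.
from optimum.rbln.transformers.models.decoderonly import RBLNDecoderOnlyModelForCausalLM
from optimum.rbln.transformers.models.gemma.gemma_architecture import GemmaWrapper  # hypothetical, analogous to GPT2LMHeadModelWrapper


class RBLNGemmaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
    """Thin subclass: compilation, runtime creation, the prefill/decode loop, and
    generation plumbing are inherited from the shared decoder-only base."""

    @classmethod
    def wrap_model_if_needed(cls, model, rbln_config):
        # Wrap the Hugging Face module so the base class can trace and compile it
        # with the sequence length recorded in the RBLN config.
        return GemmaWrapper(model, rbln_config.meta["rbln_max_seq_len"]).eval()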
@@ -23,39 +23,21 @@
 
 import inspect
 import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable
 
-import rebel
-import torch
-from transformers import AutoModelForCausalLM, GPT2LMHeadModel, PretrainedConfig
-from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, Seq2SeqLMOutput
+from transformers import GPT2LMHeadModel
 
-from ....modeling_base import RBLNModel
-from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
-from ....utils.runtime_utils import RBLNPytorchRuntime
-from ...generation.utils import RBLNGenerationMixin
+from ....modeling_config import RBLNConfig
+from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .gpt2_architecture import GPT2LMHeadModelWrapper
 
 
 logger = logging.getLogger(__name__)
-
 if TYPE_CHECKING:
-    from transformers import (
-        AutoFeatureExtractor,
-        AutoProcessor,
-        AutoTokenizer,
-        PretrainedConfig,
-    )
-
-
-class RBLNRuntimeDecoder(RBLNPytorchRuntime):
-    def forward(self, *args, **kwargs) -> Union[Tuple, Seq2SeqLMOutput]:
-        outputs = super().forward(*args, **kwargs)
-        logits = outputs
-        return Seq2SeqLMOutput(logits=logits)
+    from transformers import PreTrainedModel
 
 
-class RBLNGPT2LMHeadModel(RBLNModel, RBLNGenerationMixin):
+class RBLNGPT2LMHeadModel(RBLNDecoderOnlyModelForCausalLM):
     """
     The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
     embeddings).
@@ -69,29 +51,10 @@ class RBLNGPT2LMHeadModel(RBLNModel, RBLNGenerationMixin):
 
     """
 
-    model_type = "rbln_model"
-    auto_model_class = AutoModelForCausalLM
-    main_input_name = "input_ids"
-
-    def __post_init__(self, **kwargs):
-        self.prefill_chunk_size = self.rbln_config.meta["rbln_prefill_chunk_size"]
-        self.max_seq_len = self.rbln_config.meta["rbln_max_seq_len"]
-
-        batch_size = self.rbln_config[DEFAULT_COMPILED_MODEL_NAME][0].input_info[0][1][0]
-        self.prefill_attention_mask = torch.zeros(
-            batch_size, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.int64
-        )
-        self.causal_mask = 1 - torch.triu(
-            torch.ones(batch_size, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
-        )
-
-        self.prefill_decoder = RBLNRuntimeDecoder(runtime=self.model[0])
-        self.decoder = RBLNRuntimeDecoder(runtime=self.model[1])
-        self.pad_token_id = self.rbln_config.meta["rbln_pad_token_id"]
-        self.past_cached_length = 0
-
-    def can_generate(self):
-        return True
+    @classmethod
+    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
+        rbln_max_seq_len = rbln_config.meta["rbln_max_seq_len"]
+        return GPT2LMHeadModelWrapper(model, rbln_max_seq_len).eval()
 
     def __getattr__(self, __name: str) -> Any:
         """This is the key method to implement RBLN-GPT2.
@@ -107,213 +70,3 @@ class RBLNGPT2LMHeadModel(RBLNModel, RBLNGenerationMixin):
         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
             return redirect(val)
         return val
-
-    def _reorder_cache(self, past_key_values, beam_idx):
-        # TODO(jongho): implement
-        raise NotImplementedError
-
-    @classmethod
-    def update_kwargs(cls, kwargs):
-        kwargs.update(
-            {
-                "torchscript": True,
-                "return_dict": False,
-                "use_cache": True,
-            }
-        )
-        return kwargs
-
-    @classmethod
-    @torch.inference_mode()
-    def get_compiled_model(cls, model: GPT2LMHeadModel, rbln_config: RBLNConfig):
-        wrapped_decoder = GPT2LMHeadModelWrapper(model).eval()
-
-        prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
-        dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
-
-        prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
-        dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=0)
-
-        prefill_scripted_model = torch.jit.trace(wrapped_decoder, prefill_example_inputs, check_trace=False)
-        dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs, check_trace=False)
-
-        prefill_ir = rebel.torchscript_to_ir(
-            prefill_scripted_model,
-            input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
-        )
-        dec_ir = rebel.torchscript_to_ir(
-            dec_scripted_model,
-            input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
-        )
-
-        connections = [(prefill_ir.outputs[1 + i], prefill_ir.inputs[3 + i]) for i in range(model.config.n_layer * 2)]
-
-        compiled_model = rebel.compile(
-            prefill_ir,
-            dec_ir,
-            connections=connections,
-            fusion=prefill_rbln_runtime_config.fusion,
-            npu=prefill_rbln_runtime_config.npu,
-            tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
-            use_weight_sharing=True,
-        )
-
-        return compiled_model
-
-    @classmethod
-    def _get_rbln_config(
-        cls,
-        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
-        model_config: "PretrainedConfig",
-        rbln_max_seq_len: Optional[int] = None,
-        rbln_batch_size: Optional[int] = None,
-        rbln_pad_token_id: Optional[int] = None,
-    ) -> RBLNConfig:
-        meta = {}
-
-        default_max_length = getattr(model_config, "n_positions", None)
-        for tokenizer in preprocessors:
-            default_max_length = default_max_length or getattr(tokenizer, "max_len_single_sentence", None)
-
-        prefill_chunk_size = 128
-
-        if rbln_max_seq_len is None:
-            rbln_max_seq_len = default_max_length
-
-        if rbln_max_seq_len is None:
-            raise ValueError("`rbln_max_seq_len` should be specified!")
-
-        if rbln_pad_token_id is None:
-            rbln_pad_token_id = getattr(model_config, "pad_token_id", None)
-        if rbln_pad_token_id is None:
-            rbln_pad_token_id = getattr(model_config, "eos_token_id", None)
-        if rbln_pad_token_id is None:
-            rbln_pad_token_id = 50256
-
-        meta["rbln_prefill_chunk_size"] = prefill_chunk_size
-        meta["rbln_max_seq_len"] = rbln_max_seq_len
-        meta["rbln_pad_token_id"] = rbln_pad_token_id
-
-        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-
-        def get_input_info(query_length):
-            return [
-                ("input_ids", [rbln_batch_size, query_length], "int64"),
-                ("attention_mask", [rbln_batch_size, 1, query_length, rbln_max_seq_len], "int64"),
-                (
-                    "cache_position",
-                    [],
-                    "int32",
-                ),
-            ] + [
-                (
-                    f"past_key_values_{i}",
-                    [
-                        rbln_batch_size,
-                        model_config.n_head,
-                        rbln_max_seq_len,
-                        model_config.hidden_size // model_config.n_head,
-                    ],
-                    "float32",
-                )
-                for i in range(model_config.n_layer * 2)
-            ]
-
-        # model input info
-        prefill_input_info = get_input_info(query_length=prefill_chunk_size)
-        dec_input_info = get_input_info(query_length=1)
-
-        prefill_rbln_runtime_config = RBLNRuntimeConfig(input_info=prefill_input_info)
-        dec_rbln_runtime_config = RBLNRuntimeConfig(input_info=dec_input_info)
-
-        rbln_config = RBLNConfig.from_rbln_runtime_configs(
-            [prefill_rbln_runtime_config, dec_rbln_runtime_config],
-            _rbln_meta=meta,
-        )
-
-        return rbln_config
-
-    @classmethod
-    def _create_runtimes(
-        cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
-    ) -> List[rebel.Runtime]:
-        device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
-        return [
-            compiled_models[0].create_runtime(input_info_index=0, tensor_type="pt", device=device_val),
-            compiled_models[0].create_runtime(input_info_index=1, tensor_type="pt", device=device_val),
-        ]
-
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=0, attention_mask=None, **kwargs):
-        batch_size, cur_len = input_ids.shape
-        past_cached_length = past_key_values
-
-        # In greedy decoding
-        if past_cached_length == 0:
-            self.prompt_ids = input_ids
-            self.rightpad_max_len = cur_len
-            prompt_min_len = torch.min(torch.sum(attention_mask, dim=-1))
-            self.dummy_len = torch.sum(attention_mask, dim=-1) - prompt_min_len
-
-            if cur_len % self.prefill_chunk_size == 0:
-                pad_len = 0
-            else:
-                pad_len = self.prefill_chunk_size - cur_len % self.prefill_chunk_size
-            input_ids = torch.nn.functional.pad(input_ids, (0, pad_len))
-            attention_mask = self.prefill_attention_mask.clone()
-            cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
-
-            query_length = prompt_min_len.item()
-        else:
-            cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
-            attention_mask = torch.zeros(batch_size, 1, 1, self.max_seq_len, dtype=torch.int64)
-            attention_mask[:, :, :, : cache_position + 1] = 1
-            input_ids = input_ids[:, cache_position : cache_position + 1].contiguous()
-            query_length = 1
-
-        model_inputs = {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "attention_mask": attention_mask,
-            # below are rbln-related kwargs
-            "cache_position": cache_position,
-            "query_length": query_length,
-        }
-
-        return model_inputs
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        cache_position: Optional[torch.Tensor] = None,
-        query_length: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
-        if past_key_values is not None:
-            past_key_values += query_length
-
-        if cache_position == 0:
-            for step in range(0, query_length, self.prefill_chunk_size):
-                sliced_input_ids = input_ids[:, step : step + self.prefill_chunk_size]
-                attention_mask[:, :, :, :step] = 1
-                attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
-
-                output = self.prefill_decoder(
-                    input_ids=sliced_input_ids.contiguous(),
-                    attention_mask=attention_mask.contiguous(),
-                    cache_position=cache_position + step,
-                )
-
-            idx = query_length % self.prefill_chunk_size - 1
-            output = output.logits[:, idx].unsqueeze(1)
-
-        else:
-            output = self.decoder(
-                input_ids=input_ids.contiguous(),
-                attention_mask=attention_mask.contiguous(),
-                cache_position=cache_position,
-            )
-            output = output.logits
-
-        return CausalLMOutputWithCrossAttentions(logits=output, past_key_values=past_key_values)
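
Outward usage of the migrated class is unchanged in spirit. Below is a minimal usage sketch, assuming the standard optimum-style from_pretrained(..., export=True) flow and that the rbln_max_seq_len / rbln_batch_size keyword arguments accepted by the removed _get_rbln_config are still honored by the RBLNDecoderOnlyModelForCausalLM base; an RBLN NPU and the rebel SDK are required to actually compile and run.

# Usage sketch only: the rbln_* kwargs below mirror the removed _get_rbln_config
# signature and are assumed to be accepted by the new decoder-only base class.
from transformers import AutoTokenizer

from optimum.rbln import RBLNGPT2LMHeadModel

model_id = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Trace, compile, and load the model for the RBLN NPU; the prefill/decode graphs and
# KV-cache wiring formerly built in get_compiled_model now live in the shared base.
model = RBLNGPT2LMHeadModel.from_pretrained(
    model_id,
    export=True,            # compile from the Hugging Face checkpoint
    rbln_max_seq_len=1024,  # assumed kwarg, mirrors the removed _get_rbln_config
    rbln_batch_size=1,
)

inputs = tokenizer("Hello, my dog is", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))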