optimum-rbln 0.1.7__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. optimum/rbln/__init__.py +14 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/__init__.py +0 -1
  4. optimum/rbln/diffusers/models/controlnet.py +3 -0
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +2 -2
  6. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +22 -144
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +107 -59
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +106 -54
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
  11. optimum/rbln/modeling_alias.py +14 -0
  12. optimum/rbln/modeling_base.py +110 -0
  13. optimum/rbln/transformers/__init__.py +6 -0
  14. optimum/rbln/transformers/cache_utils.py +111 -0
  15. optimum/rbln/transformers/generation/utils.py +0 -2
  16. optimum/rbln/transformers/models/__init__.py +2 -0
  17. optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
  18. optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
  19. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +515 -0
  20. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +349 -0
  21. optimum/rbln/transformers/models/gemma/__init__.py +24 -0
  22. optimum/rbln/transformers/models/gemma/gemma_architecture.py +116 -0
  23. optimum/rbln/transformers/models/gemma/modeling_gemma.py +61 -0
  24. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +201 -166
  25. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +56 -220
  26. optimum/rbln/transformers/models/llama/llama_architecture.py +3 -610
  27. optimum/rbln/transformers/models/llama/modeling_llama.py +8 -442
  28. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
  29. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
  30. optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
  31. optimum/rbln/transformers/models/midm/modeling_midm.py +40 -272
  32. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
  33. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  34. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +125 -0
  35. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/METADATA +2 -3
  36. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/RECORD +38 -30
  37. optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -764
  38. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/WHEEL +0 -0
  39. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/licenses/LICENSE +0 -0
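The dominant change in 0.1.8 is a refactor of decoder-only causal-LM support: the per-model compile/runtime/generation code in the GPT-2, Llama, and Midm files is replaced by a shared `RBLNDecoderOnlyModelForCausalLM` base class under `optimum/rbln/transformers/models/decoderonly/`, with Gemma and XLM-RoBERTa added as new model types. The hunks below are from `optimum/rbln/transformers/models/gpt2/modeling_gpt2.py` (item 25): `RBLNGPT2LMHeadModel` now subclasses the new base class and only supplies a torch-model wrapper plus its RBLN compile config. As a hypothetical usage sketch (the `from_pretrained(..., export=True)` entry point follows the usual optimum convention and is assumed here rather than shown in this diff; the `rbln_*` keyword names are taken from `_get_rbln_config` below):

```python
# Hypothetical usage sketch -- not taken from the package docs.
# Assumes the optimum-style `from_pretrained(..., export=True)` compile path and the
# `rbln_max_seq_len` / `rbln_batch_size` knobs visible in `_get_rbln_config` below.
from transformers import AutoTokenizer
from optimum.rbln import RBLNGPT2LMHeadModel

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = RBLNGPT2LMHeadModel.from_pretrained(
    "gpt2",
    export=True,            # compile the PyTorch checkpoint for the RBLN NPU
    rbln_max_seq_len=1024,  # falls back to config.n_positions when omitted
    rbln_batch_size=1,      # decode-phase batch size
)

inputs = tokenizer("Hello, my dog is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```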
@@ -23,22 +23,16 @@
 
 import inspect
 import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
-import rebel
-import torch
-from transformers import AutoModelForCausalLM, GPT2LMHeadModel, PretrainedConfig
-from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, Seq2SeqLMOutput
+from transformers import GPT2LMHeadModel, PretrainedConfig, PreTrainedModel
 
-from ....modeling_base import RBLNModel
-from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
-from ....utils.runtime_utils import RBLNPytorchRuntime
-from ...generation.utils import RBLNGenerationMixin
+from ....modeling_config import RBLNConfig, RBLNRuntimeConfig
+from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
 
 from .gpt2_architecture import GPT2LMHeadModelWrapper
 
 
 logger = logging.getLogger(__name__)
-
 if TYPE_CHECKING:
     from transformers import (
         AutoFeatureExtractor,
@@ -48,14 +42,7 @@ if TYPE_CHECKING:
     )
 
 
-class RBLNRuntimeDecoder(RBLNPytorchRuntime):
-    def forward(self, *args, **kwargs) -> Union[Tuple, Seq2SeqLMOutput]:
-        outputs = super().forward(*args, **kwargs)
-        logits = outputs
-        return Seq2SeqLMOutput(logits=logits)
-
-
-class RBLNGPT2LMHeadModel(RBLNModel, RBLNGenerationMixin):
+class RBLNGPT2LMHeadModel(RBLNDecoderOnlyModelForCausalLM):
     """
     The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
     embeddings).
@@ -69,29 +56,9 @@ class RBLNGPT2LMHeadModel(RBLNModel, RBLNGenerationMixin):
 
     """
 
-    model_type = "rbln_model"
-    auto_model_class = AutoModelForCausalLM
-    main_input_name = "input_ids"
-
-    def __post_init__(self, **kwargs):
-        self.prefill_chunk_size = self.rbln_config.meta["rbln_prefill_chunk_size"]
-        self.max_seq_len = self.rbln_config.meta["rbln_max_seq_len"]
-
-        batch_size = self.rbln_config[DEFAULT_COMPILED_MODEL_NAME][0].input_info[0][1][0]
-        self.prefill_attention_mask = torch.zeros(
-            batch_size, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.int64
-        )
-        self.causal_mask = 1 - torch.triu(
-            torch.ones(batch_size, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
-        )
-
-        self.prefill_decoder = RBLNRuntimeDecoder(runtime=self.model[0])
-        self.decoder = RBLNRuntimeDecoder(runtime=self.model[1])
-        self.pad_token_id = self.rbln_config.meta["rbln_pad_token_id"]
-        self.past_cached_length = 0
-
-    def can_generate(self):
-        return True
+    @classmethod
+    def wrapping_torch_model(self, model: "PreTrainedModel", rbln_max_seq_len: int):
+        return GPT2LMHeadModelWrapper(model, rbln_max_seq_len).eval()
 
     def __getattr__(self, __name: str) -> Any:
         """This is the key method to implement RBLN-GPT2.
@@ -108,58 +75,6 @@ class RBLNGPT2LMHeadModel(RBLNModel, RBLNGenerationMixin):
             return redirect(val)
         return val
 
-    def _reorder_cache(self, past_key_values, beam_idx):
-        # TODO(jongho): implement
-        raise NotImplementedError
-
-    @classmethod
-    def update_kwargs(cls, kwargs):
-        kwargs.update(
-            {
-                "torchscript": True,
-                "return_dict": False,
-                "use_cache": True,
-            }
-        )
-        return kwargs
-
-    @classmethod
-    @torch.inference_mode()
-    def get_compiled_model(cls, model: GPT2LMHeadModel, rbln_config: RBLNConfig):
-        wrapped_decoder = GPT2LMHeadModelWrapper(model).eval()
-
-        prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
-        dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
-
-        prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
-        dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=0)
-
-        prefill_scripted_model = torch.jit.trace(wrapped_decoder, prefill_example_inputs, check_trace=False)
-        dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs, check_trace=False)
-
-        prefill_ir = rebel.torchscript_to_ir(
-            prefill_scripted_model,
-            input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
-        )
-        dec_ir = rebel.torchscript_to_ir(
-            dec_scripted_model,
-            input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
-        )
-
-        connections = [(prefill_ir.outputs[1 + i], prefill_ir.inputs[3 + i]) for i in range(model.config.n_layer * 2)]
-
-        compiled_model = rebel.compile(
-            prefill_ir,
-            dec_ir,
-            connections=connections,
-            fusion=prefill_rbln_runtime_config.fusion,
-            npu=prefill_rbln_runtime_config.npu,
-            tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
-            use_weight_sharing=True,
-        )
-
-        return compiled_model
-
     @classmethod
     def _get_rbln_config(
         cls,
@@ -167,153 +82,74 @@ class RBLNGPT2LMHeadModel(RBLNModel, RBLNGenerationMixin):
         model_config: "PretrainedConfig",
         rbln_max_seq_len: Optional[int] = None,
         rbln_batch_size: Optional[int] = None,
-        rbln_pad_token_id: Optional[int] = None,
+        **kwargs,
     ) -> RBLNConfig:
         meta = {}
 
-        default_max_length = getattr(model_config, "n_positions", None)
-        for tokenizer in preprocessors:
-            default_max_length = default_max_length or getattr(tokenizer, "max_len_single_sentence", None)
-
         prefill_chunk_size = 128
+        if rbln_max_seq_len is None:  # differenct from llama
+            rbln_max_seq_len = getattr(model_config, "n_positions", None)
+        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
 
-        if rbln_max_seq_len is None:
-            rbln_max_seq_len = default_max_length
-
-        if rbln_max_seq_len is None:
-            raise ValueError("`rbln_max_seq_len` should be specified!")
-
-        if rbln_pad_token_id is None:
-            rbln_pad_token_id = getattr(model_config, "pad_token_id", None)
-            if rbln_pad_token_id is None:
-                rbln_pad_token_id = getattr(model_config, "eos_token_id", None)
-                if rbln_pad_token_id is None:
-                    rbln_pad_token_id = 50256
-
-        meta["rbln_prefill_chunk_size"] = prefill_chunk_size
         meta["rbln_max_seq_len"] = rbln_max_seq_len
-        meta["rbln_pad_token_id"] = rbln_pad_token_id
-
-        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
+        meta["rbln_batch_size"] = rbln_batch_size
+        meta["rbln_prefill_chunk_size"] = prefill_chunk_size
 
-        def get_input_info(query_length):
-            return [
-                ("input_ids", [rbln_batch_size, query_length], "int64"),
-                ("attention_mask", [rbln_batch_size, 1, query_length, rbln_max_seq_len], "int64"),
+        def get_input_info(
+            batch_size,
+            query_length,
+        ):
+            head_dim = (
+                model_config.head_dim
+                if hasattr(model_config, "head_dim")
+                else model_config.hidden_size // model_config.n_head
+            )
+            input_info = [
+                ("input_ids", [batch_size, query_length], "int64"),
+                ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "int64"),
                 (
                     "cache_position",
-                    [],
+                    [batch_size, query_length],
                     "int32",
                 ),
-            ] + [
-                (
-                    f"past_key_values_{i}",
-                    [
-                        rbln_batch_size,
-                        model_config.n_head,
-                        rbln_max_seq_len,
-                        model_config.hidden_size // model_config.n_head,
-                    ],
-                    "float32",
-                )
-                for i in range(model_config.n_layer * 2)
+                ("batch_position", [], "int16"),
             ]
 
-        # model input info
-        prefill_input_info = get_input_info(query_length=prefill_chunk_size)
-        dec_input_info = get_input_info(query_length=1)
+            input_info.extend(
+                [
+                    (
+                        f"past_key_values_{i}",
+                        [
+                            rbln_batch_size,
+                            model_config.n_head,  # differenct from llama
+                            rbln_max_seq_len,
+                            head_dim,
+                        ],
+                        "float32",
+                    )
+                    for i in range(model_config.n_layer * 2)  # differenct from llama
+                ]
+            )
+
+            return input_info
+
+        prefill_input_info = get_input_info(
+            batch_size=1,
+            query_length=prefill_chunk_size,
+        )
+        dec_input_info = get_input_info(
+            batch_size=rbln_batch_size,
+            query_length=1,
+        )
 
         prefill_rbln_runtime_config = RBLNRuntimeConfig(input_info=prefill_input_info)
         dec_rbln_runtime_config = RBLNRuntimeConfig(input_info=dec_input_info)
 
+        dec_rbln_runtime_config.batch_size = rbln_batch_size
+
         rbln_config = RBLNConfig.from_rbln_runtime_configs(
             [prefill_rbln_runtime_config, dec_rbln_runtime_config],
             _rbln_meta=meta,
         )
 
         return rbln_config
-
-    @classmethod
-    def _create_runtimes(
-        cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
-    ) -> List[rebel.Runtime]:
-        device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
-        return [
-            compiled_models[0].create_runtime(input_info_index=0, tensor_type="pt", device=device_val),
-            compiled_models[0].create_runtime(input_info_index=1, tensor_type="pt", device=device_val),
-        ]
-
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=0, attention_mask=None, **kwargs):
-        batch_size, cur_len = input_ids.shape
-        past_cached_length = past_key_values
-
-        # In greedy decoding
-        if past_cached_length == 0:
-            self.prompt_ids = input_ids
-            self.rightpad_max_len = cur_len
-            prompt_min_len = torch.min(torch.sum(attention_mask, dim=-1))
-            self.dummy_len = torch.sum(attention_mask, dim=-1) - prompt_min_len
-
-            if cur_len % self.prefill_chunk_size == 0:
-                pad_len = 0
-            else:
-                pad_len = self.prefill_chunk_size - cur_len % self.prefill_chunk_size
-            input_ids = torch.nn.functional.pad(input_ids, (0, pad_len))
-            attention_mask = self.prefill_attention_mask.clone()
-            cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
-
-            query_length = prompt_min_len.item()
-        else:
-            cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
-            attention_mask = torch.zeros(batch_size, 1, 1, self.max_seq_len, dtype=torch.int64)
-            attention_mask[:, :, :, : cache_position + 1] = 1
-            input_ids = input_ids[:, cache_position : cache_position + 1].contiguous()
-            query_length = 1
-
-        model_inputs = {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "attention_mask": attention_mask,
-            # below are rbln-related kwargs
-            "cache_position": cache_position,
-            "query_length": query_length,
-        }
-
-        return model_inputs
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        cache_position: Optional[torch.Tensor] = None,
-        query_length: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
-        if past_key_values is not None:
-            past_key_values += query_length
-
-        if cache_position == 0:
-            for step in range(0, query_length, self.prefill_chunk_size):
-                sliced_input_ids = input_ids[:, step : step + self.prefill_chunk_size]
-                attention_mask[:, :, :, :step] = 1
-                attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
-
-                output = self.prefill_decoder(
-                    input_ids=sliced_input_ids.contiguous(),
-                    attention_mask=attention_mask.contiguous(),
-                    cache_position=cache_position + step,
-                )
-
-            idx = query_length % self.prefill_chunk_size - 1
-            output = output.logits[:, idx].unsqueeze(1)
-
-        else:
-            output = self.decoder(
-                input_ids=input_ids.contiguous(),
-                attention_mask=attention_mask.contiguous(),
-                cache_position=cache_position,
-            )
-            output = output.logits
-
-        return CausalLMOutputWithCrossAttentions(logits=output, past_key_values=past_key_values)
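For reference, here is a standalone sketch of the compiled-model input layout that the new `get_input_info` helper produces, worked through for a stock GPT-2 small configuration (n_layer=12, n_head=12, hidden_size=768, n_positions=1024). It reproduces the shape arithmetic from the hunk above rather than calling the library, so treat it as illustrative only:

```python
# Sketch only: re-derives the shapes built by the new _get_rbln_config for GPT-2 small.
max_seq_len, chunk, batch = 1024, 128, 1   # rbln_max_seq_len, prefill_chunk_size, rbln_batch_size
head_dim = 768 // 12                       # hidden_size // n_head = 64

def input_info(batch_size, query_length):
    info = [
        ("input_ids", [batch_size, query_length], "int64"),
        ("attention_mask", [batch_size, 1, query_length, max_seq_len], "int64"),
        ("cache_position", [batch_size, query_length], "int32"),  # now a per-token tensor, no longer a scalar
        ("batch_position", [], "int16"),                          # new scalar input in 0.1.8
    ]
    # one entry per key and per value for every layer: 12 layers * 2 = 24 cache tensors,
    # always shaped with rbln_batch_size regardless of the per-call batch_size
    info += [
        (f"past_key_values_{i}", [batch, 12, max_seq_len, head_dim], "float32")
        for i in range(12 * 2)
    ]
    return info

prefill_info = input_info(batch_size=1, query_length=chunk)  # prefill always runs at batch 1
decode_info = input_info(batch_size=batch, query_length=1)   # decode runs at rbln_batch_size
print(prefill_info[0])  # ('input_ids', [1, 128], 'int64')
print(decode_info[1])   # ('attention_mask', [1, 1, 1, 1024], 'int64')
```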