optimum-rbln 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. optimum/rbln/__init__.py +6 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/__init__.py +7 -0
  4. optimum/rbln/diffusers/models/autoencoder_kl.py +30 -9
  5. optimum/rbln/diffusers/models/controlnet.py +93 -23
  6. optimum/rbln/diffusers/models/unet_2d_condition.py +78 -61
  7. optimum/rbln/diffusers/pipelines/__init__.py +7 -2
  8. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +4 -0
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +768 -0
  10. optimum/rbln/diffusers/pipelines/{stable_diffusion → controlnet}/pipeline_controlnet_img2img.py +25 -16
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +942 -0
  12. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +955 -0
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  14. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -4
  15. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +22 -9
  16. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +19 -3
  17. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +19 -3
  18. optimum/rbln/modeling_base.py +36 -3
  19. optimum/rbln/modeling_seq2seq.py +19 -4
  20. optimum/rbln/transformers/generation/__init__.py +1 -0
  21. optimum/rbln/transformers/generation/streamers.py +17 -0
  22. optimum/rbln/transformers/generation/utils.py +399 -0
  23. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +24 -333
  24. optimum/rbln/transformers/models/llama/modeling_llama.py +63 -45
  25. optimum/rbln/transformers/models/whisper/modeling_whisper.py +13 -3
  26. {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.1.dist-info}/METADATA +1 -1
  27. {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.1.dist-info}/RECORD +29 -25
  28. {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.1.dist-info}/WHEEL +0 -0
  29. {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/generation/utils.py (new file)
@@ -0,0 +1,399 @@
+ import traceback
+ import warnings
+ from typing import List, Optional, Union
+
+ import torch
+ from transformers.generation import GenerationConfig
+ from transformers.generation.logits_process import LogitsProcessorList
+ from transformers.generation.stopping_criteria import (
+     StoppingCriteriaList,
+     validate_stopping_criteria,
+ )
+ from transformers.generation.streamers import BaseStreamer
+ from transformers.generation.utils import SampleDecoderOnlyOutput
+
+
+ class RBLNGenerationMixin:
+     # Calling `greedy_search` directly is deprecated and was removed in v4.41.
+     def greedy_search(self, *args, **kwargs):
+         return self._greedy_search(*args, **kwargs)
+
+     def _greedy_search(
+         self,
+         input_ids: torch.LongTensor,
+         logits_processor: Optional[LogitsProcessorList] = None,
+         stopping_criteria: Optional[StoppingCriteriaList] = None,
+         max_length: Optional[int] = None,
+         pad_token_id: Optional[int] = None,
+         eos_token_id: Optional[Union[int, List[int]]] = None,
+         output_logits: Optional[bool] = None,
+         return_dict_in_generate: Optional[bool] = None,
+         streamer: Optional["BaseStreamer"] = None,
+         generation_config: Optional[GenerationConfig] = None,  # thkim change for 4.41.0
+         **model_kwargs,
+     ) -> Union[SampleDecoderOnlyOutput, torch.LongTensor]:
+
+         ###################### thkim change for 4.41.0 ############################
+         if generation_config is not None:
+             pad_token_id = generation_config.pad_token_id
+             output_logits = generation_config.output_logits
+             return_dict_in_generate = generation_config.return_dict_in_generate
+         ##########################################################################
+         # init values
+         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+         if max_length is not None:
+             warnings.warn(
+                 "`max_length` is deprecated in this function, use"
+                 " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
+                 UserWarning,
+             )
+             stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
+
+         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
+         eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+         if isinstance(eos_token_id, int):
+             eos_token_id = [eos_token_id]
+         eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
+
+         return_dict_in_generate = (
+             return_dict_in_generate
+             if return_dict_in_generate is not None
+             else self.generation_config.return_dict_in_generate
+         )
+
+         # init attention / hidden states / scores tuples
+         raw_logits = () if (return_dict_in_generate and output_logits) else None
+
+         # keep track of which sequences are already finished
+         unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
+
+         this_peer_finished = False  # used by synced_gpus only
+
+         while True:
+             # prepare model inputs
+             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+             # forward pass to get next token
+             try:
+                 outputs = self(
+                     **model_inputs,
+                     return_dict=True,
+                 )
+                 next_token_logits = outputs.logits[:, -1, :]
+             except Exception:
+                 traceback.print_exc()
+                 break
+
+             # pre-process distribution
+             next_tokens_scores = logits_processor(input_ids, next_token_logits)
+
+             # Store scores, attentions and hidden_states when required
+             if return_dict_in_generate:
+                 if output_logits:
+                     raw_logits += (next_token_logits,)
+
+             # argmax
+             next_tokens = torch.argmax(next_tokens_scores, dim=-1)
+
+             # finished sentences should have their next token be a padding token
+             if eos_token_id is not None:
+                 if pad_token_id is None:
+                     raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
+                 next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+
+             ########################################################################################################
+             # thkim change for right-padding batch
+             # if min_input_len <= update_idx < max_input_len
+             # update valid input_ids[:, update_idx]
+             # TODO : raw_logits contains the dummy next_token's logits
+             if hasattr(self, "rightpad_max_len"):
+                 update_idx = model_inputs["cache_position"] + model_inputs["query_length"]
+                 if update_idx < self.rightpad_max_len:
+                     # update existing input_ids in place rather than concatenating
+                     valid_indices = model_kwargs["attention_mask"][:, update_idx] == 0
+                     dummy_indices = model_kwargs["attention_mask"][:, update_idx] == 1
+
+                     input_ids[valid_indices, update_idx] = next_tokens[valid_indices]
+                     model_kwargs["attention_mask"][valid_indices, update_idx] = 1
+                     model_kwargs["past_key_values"] = outputs["past_key_values"]
+
+                     # dummy next_token -> pad_token_id for the streamer,
+                     # so that it can be skipped with `skip_special_tokens=True`
+                     if streamer is not None:
+                         next_tokens[dummy_indices] = pad_token_id
+                 else:
+                     input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+                     model_kwargs = self._update_model_kwargs_for_generation(
+                         outputs,
+                         model_kwargs,
+                         is_encoder_decoder=self.config.is_encoder_decoder,
+                     )
+             else:
+                 ############################################END#########################################################
+                 # update generated ids, model inputs, and length for next step
+                 input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+                 model_kwargs = self._update_model_kwargs_for_generation(
+                     outputs,
+                     model_kwargs,
+                     is_encoder_decoder=self.config.is_encoder_decoder,
+                 )
+
+             if streamer is not None:
+                 streamer.put(next_tokens.cpu())
+                 if streamer.is_blocked():
+                     this_peer_finished = True
+
+             # if eos_token was found in one sentence, set sentence to finished
+             if eos_token_id_tensor is not None:
+                 ####################################################################
+                 # thkim : do not finish sequences that are still in the right-padding dummy-decoding phase
+                 if hasattr(self, "rightpad_max_len"):
+                     update_idx = model_inputs["cache_position"] + model_inputs["query_length"]
+                     if update_idx < self.rightpad_max_len:
+                         next_tokens += (
+                             model_kwargs["attention_mask"][:, update_idx] * self.generation_config.eos_token_id
+                         )
+                 ######################################################################
+                 unfinished_sequences = unfinished_sequences.mul(
+                     next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
+                 )
+
+             # stop when each sentence is finished
+             if unfinished_sequences.max() == 0:
+                 this_peer_finished = True
+
+             # stop if we exceed the maximum length
+             # thkim : backward compatibility bool vs torch.BoolTensor
+             is_stop = stopping_criteria(input_ids, None)
+             if isinstance(is_stop, torch.BoolTensor):
+                 is_stop = torch.all(is_stop)
+             if is_stop:
+                 this_peer_finished = True
+
+             if this_peer_finished:
+                 break
+
+         if streamer is not None:
+             streamer.end()
+
+         if return_dict_in_generate:
+             ############## thkim : rotate raw_logits when right_padding #####################
+             if hasattr(self, "rightpad_max_len"):
+                 raw_logits = torch.stack(raw_logits).transpose(0, 1)
+                 for i in range(input_ids.shape[0]):
+                     raw_logits[i] = torch.cat((raw_logits[i][self.dummy_len[i] :], raw_logits[i][: self.dummy_len[i]]))
+                 raw_logits = raw_logits.transpose(1, 0)
+             ##################################################################################
+             return SampleDecoderOnlyOutput(
+                 sequences=input_ids,
+                 logits=raw_logits,
+             )
+         else:
+             return input_ids
+
+     # Calling `sample` directly is deprecated and was removed in v4.41.
+     def sample(self, *args, **kwargs):
+         return self._sample(*args, **kwargs)
+
+     def _sample(
+         self,
+         input_ids: torch.LongTensor,
+         logits_processor: Optional[LogitsProcessorList] = None,
+         stopping_criteria: Optional[StoppingCriteriaList] = None,
+         logits_warper: Optional[LogitsProcessorList] = None,
+         max_length: Optional[int] = None,
+         pad_token_id: Optional[int] = None,
+         eos_token_id: Optional[Union[int, List[int]]] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         output_scores: Optional[bool] = None,
+         output_logits: Optional[bool] = None,
+         return_dict_in_generate: Optional[bool] = None,
+         synced_gpus: bool = False,
+         streamer: Optional["BaseStreamer"] = None,
+         generation_config: Optional[GenerationConfig] = None,
+         do_sample: Optional[bool] = True,
+         **model_kwargs,
+     ) -> Union[SampleDecoderOnlyOutput, torch.LongTensor]:
+
+         ###################### thkim change for 4.41.0 ############################
+         if generation_config is not None:
+             pad_token_id = generation_config.pad_token_id
+             output_logits = generation_config.output_logits
+             return_dict_in_generate = generation_config.return_dict_in_generate
+             do_sample = generation_config.do_sample
+         ###########################################################################
+
+         # init values
+         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+         if max_length is not None:
+             warnings.warn(
+                 "`max_length` is deprecated in this function, use"
+                 " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
+                 UserWarning,
+             )
+             stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
+
+         logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
+         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
+         eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+
+         if isinstance(eos_token_id, int):
+             eos_token_id = [eos_token_id]
+         eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
+
+         output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
+         output_logits = output_logits if output_logits is not None else False
+
+         # init attention / hidden states / scores tuples
+         scores = () if (return_dict_in_generate and output_scores) else None
+         raw_logits = () if (return_dict_in_generate and output_logits) else None
+
+         # keep track of which sequences are already finished
+         batch_size, cur_len = input_ids.shape
+         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+         this_peer_finished = False
+
+         # model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
+
+         while True:
+             # prepare model inputs
+             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+             # forward pass to get next token
+             try:
+                 outputs = self(
+                     **model_inputs,
+                     return_dict=True,
+                     output_attentions=output_attentions,
+                     output_hidden_states=output_hidden_states,
+                 )
+                 next_token_logits = outputs.logits[:, -1, :]
+             except Exception:
+                 traceback.print_exc()
+                 break
+
+             # pre-process distribution
+             next_token_scores = logits_processor(input_ids, next_token_logits)
+
+             ###################### thkim change for 4.41.0 ############################
+             if do_sample:
+                 next_token_scores = logits_warper(input_ids, next_token_scores)
+             ###########################################################################
+
+             # Store scores, attentions and hidden_states when required
+             if return_dict_in_generate:
+                 if output_scores:
+                     scores += (next_token_scores,)
+                 if output_logits:
+                     raw_logits += (next_token_logits,)
+
+             # sample
+             ###################### thkim change for 4.41.0 ############################
+             if do_sample:
+                 probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
+                 next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+             else:
+                 next_tokens = torch.argmax(next_token_scores, dim=-1)
+             ###########################################################################
+
+             # finished sentences should have their next token be a padding token
+             if eos_token_id is not None:
+                 if pad_token_id is None:
+                     raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
+                 next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+
+             ############################### thkim change for right-padding batch #################################
+             # if min_input_len <= update_idx < max_input_len
+             # update valid input_ids[:, update_idx]
+             # TODO : raw_logits contains the dummy next_token's logits
+
+             if hasattr(self, "rightpad_max_len"):
+                 update_idx = model_inputs["cache_position"] + model_inputs["query_length"]
+                 if update_idx < self.rightpad_max_len:
+                     # update existing input_ids in place rather than concatenating
+                     valid_indices = model_kwargs["attention_mask"][:, update_idx] == 0
+                     dummy_indices = model_kwargs["attention_mask"][:, update_idx] == 1
+
+                     input_ids[valid_indices, update_idx] = next_tokens[valid_indices]
+                     model_kwargs["attention_mask"][valid_indices, update_idx] = 1
+                     model_kwargs["past_key_values"] = outputs["past_key_values"]
+                     # dummy next_token -> pad_token_id for the streamer,
+                     # so that it can be skipped with `skip_special_tokens=True`
+                     if streamer is not None:
+                         next_tokens[dummy_indices] = pad_token_id
+                 else:
+                     input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+                     model_kwargs = self._update_model_kwargs_for_generation(
+                         outputs,
+                         model_kwargs,
+                         is_encoder_decoder=self.config.is_encoder_decoder,
+                     )
+             else:
+                 ############################################END#########################################################
+                 # update generated ids, model inputs, and length for next step
+                 input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+
+                 model_kwargs = self._update_model_kwargs_for_generation(
+                     outputs,
+                     model_kwargs,
+                     is_encoder_decoder=self.config.is_encoder_decoder,
+                 )
+
+             if streamer is not None:
+                 streamer.put(next_tokens.cpu())
+                 if streamer.is_blocked():
+                     this_peer_finished = True
+
+             # if eos_token was found in one sentence, set sentence to finished
+             if eos_token_id_tensor is not None:
+                 ####################################################################
+                 # thkim : do not finish sequences that are still in the right-padding dummy-decoding phase
+                 if hasattr(self, "rightpad_max_len"):
+                     update_idx = model_inputs["cache_position"] + model_inputs["query_length"]
+                     if update_idx < self.rightpad_max_len:
+                         next_tokens += (
+                             model_kwargs["attention_mask"][:, update_idx] * self.generation_config.eos_token_id
+                         )
+
+                 ######################################################################
+                 unfinished_sequences = unfinished_sequences.mul(
+                     next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
+                 )
+
+             # stop when each sentence is finished
+             if unfinished_sequences.max() == 0:
+                 this_peer_finished = True
+
+             # stop if we exceed the maximum length
+             # thkim : backward compatibility bool vs torch.BoolTensor
+             is_stop = stopping_criteria(input_ids, None)
+             if isinstance(is_stop, torch.BoolTensor):
+                 is_stop = torch.all(is_stop)
+             if is_stop:
+                 this_peer_finished = True
+
+             if this_peer_finished:
+                 break
+
+         if streamer is not None:
+             streamer.end()
+
+         if return_dict_in_generate:
+             ############## thkim : rotate raw_logits when right_padding #####################
+             if hasattr(self, "rightpad_max_len"):
+                 raw_logits = torch.stack(raw_logits).transpose(0, 1)
+                 for i in range(input_ids.shape[0]):
+                     raw_logits[i] = torch.cat((raw_logits[i][self.dummy_len[i] :], raw_logits[i][: self.dummy_len[i]]))
+                 raw_logits = raw_logits.transpose(1, 0)
+             ##################################################################################
+             return SampleDecoderOnlyOutput(
+                 sequences=input_ids,
+                 scores=scores,
+                 logits=raw_logits,
+             )
+         else:
+             return input_ids
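
For context, a minimal sketch of how the loops above are reached in practice: `generate()` dispatches to `RBLNGenerationMixin._greedy_search` when `do_sample=False` and to `RBLNGenerationMixin._sample` otherwise, and the decoder-only wrappers touched in this release (modeling_gpt2.py, modeling_llama.py) are the likely consumers of the mixin. The `RBLNGPT2LMHeadModel` class name and the `export=True` argument below follow the usual Optimum conventions but are assumptions for illustration; they are not part of this diff.

from transformers import AutoTokenizer
from optimum.rbln import RBLNGPT2LMHeadModel  # class name assumed for illustration

tokenizer = AutoTokenizer.from_pretrained("gpt2")
# export=True (compile the checkpoint for the RBLN device at load time) is assumed here
model = RBLNGPT2LMHeadModel.from_pretrained("gpt2", export=True)

inputs = tokenizer("Rebellions is", return_tensors="pt")

# generate() routes to RBLNGenerationMixin._greedy_search (do_sample=False)
# or RBLNGenerationMixin._sample (do_sample=True) shown in the hunk above
output_ids = model.generate(**inputs, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))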