optimum-rbln 0.1.0__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +8 -0
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/__init__.py +7 -0
- optimum/rbln/diffusers/models/autoencoder_kl.py +30 -9
- optimum/rbln/diffusers/models/controlnet.py +93 -23
- optimum/rbln/diffusers/models/unet_2d_condition.py +78 -61
- optimum/rbln/diffusers/pipelines/__init__.py +7 -2
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +4 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +768 -0
- optimum/rbln/diffusers/pipelines/{stable_diffusion → controlnet}/pipeline_controlnet_img2img.py +25 -16
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +942 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +955 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -4
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +22 -9
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +19 -3
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +19 -3
- optimum/rbln/modeling_base.py +39 -6
- optimum/rbln/modeling_seq2seq.py +19 -4
- optimum/rbln/transformers/__init__.py +2 -0
- optimum/rbln/transformers/generation/__init__.py +1 -0
- optimum/rbln/transformers/generation/streamers.py +17 -0
- optimum/rbln/transformers/generation/utils.py +399 -0
- optimum/rbln/transformers/models/__init__.py +1 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +24 -333
- optimum/rbln/transformers/models/llama/llama_architecture.py +49 -17
- optimum/rbln/transformers/models/llama/llama_architecture_cb.py +759 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +187 -75
- optimum/rbln/transformers/models/midm/__init__.py +32 -0
- optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +22 -0
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +303 -0
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +1473 -0
- optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +98 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +506 -0
- optimum/rbln/transformers/models/midm/modeling_midm.py +426 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +13 -3
- {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.4.dist-info}/METADATA +5 -4
- optimum_rbln-0.1.4.dist-info/RECORD +63 -0
- {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.4.dist-info}/WHEEL +1 -1
- optimum_rbln-0.1.0.dist-info/RECORD +0 -51
- {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.4.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,399 @@
|
|
1
|
+
import traceback
|
2
|
+
import warnings
|
3
|
+
from typing import List, Optional, Union
|
4
|
+
|
5
|
+
import torch
|
6
|
+
from transformers.generation import GenerationConfig
|
7
|
+
from transformers.generation.logits_process import LogitsProcessorList
|
8
|
+
from transformers.generation.stopping_criteria import (
|
9
|
+
StoppingCriteriaList,
|
10
|
+
validate_stopping_criteria,
|
11
|
+
)
|
12
|
+
from transformers.generation.streamers import BaseStreamer
|
13
|
+
from transformers.generation.utils import SampleDecoderOnlyOutput
|
14
|
+
|
15
|
+
|
16
|
+
class RBLNGenerationMixin:
    """Greedy-search and sampling decode loops used by RBLN causal-LM wrappers.

    These methods provide ``_greedy_search`` / ``_sample`` (plus the deprecated
    public aliases) with the same names as transformers' ``GenerationMixin`` loops;
    presumably this mixin precedes ``GenerationMixin`` in the model's bases so these
    take priority — confirm against the model classes. The host class must provide
    ``prepare_inputs_for_generation``, ``__call__``, ``_update_model_kwargs_for_generation``,
    ``config.is_encoder_decoder`` and ``generation_config``.

    Right-padded batches: when the host defines ``rightpad_max_len`` (and ``dummy_len``,
    indexed per batch row), every step whose write position
    ``cache_position + query_length`` is below ``rightpad_max_len`` writes generated
    tokens into existing ``input_ids`` columns instead of appending a new column.
    """

    # Calling `greedy_search` directly is deprecated and was removed from transformers
    # in v4.41; it is kept as a thin alias for older callers.
    def greedy_search(self, *args, **kwargs):
        return self._greedy_search(*args, **kwargs)

    def _greedy_search(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_logits: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        streamer: Optional["BaseStreamer"] = None,
        generation_config: Optional[GenerationConfig] = None,  # passed by transformers >= 4.41
        **model_kwargs,
    ) -> Union[SampleDecoderOnlyOutput, torch.LongTensor]:
        """Generate tokens by taking the arg-max of the processed logits at each step.

        Args:
            input_ids: Prompt token ids, indexed as ``(batch, seq)`` by this loop.
            logits_processor: Applied to the last-position logits every step.
            stopping_criteria: Evaluated as ``stopping_criteria(input_ids, None)`` after
                every step; decoding stops once it is truthy for all rows.
            max_length: Deprecated; folded into ``stopping_criteria`` with a warning.
            pad_token_id: Token emitted for rows that already finished. Overridden by
                ``generation_config`` when given, else defaults to ``self.generation_config``.
            eos_token_id: Int or list of ints ending a row. Defaults to
                ``self.generation_config.eos_token_id``.
            output_logits: Collect the raw last-position logits of every step (only used
                when ``return_dict_in_generate``). Overridden by ``generation_config``.
            return_dict_in_generate: Return a ``SampleDecoderOnlyOutput`` instead of bare
                ids. Overridden by ``generation_config``, else defaults to
                ``self.generation_config.return_dict_in_generate``.
            streamer: Receives each step's tokens through ``put`` and ``end``.
            generation_config: Generation settings passed by newer transformers.
            **model_kwargs: Forwarded to ``prepare_inputs_for_generation``.

        Returns:
            ``SampleDecoderOnlyOutput(sequences=..., logits=...)`` when
            ``return_dict_in_generate`` is truthy, otherwise the final ``input_ids``.

        Raises:
            ValueError: An EOS id is known but no pad id is available.
        """
        # transformers 4.41 passes these settings via `generation_config` instead of arguments.
        if generation_config is not None:
            pad_token_id = generation_config.pad_token_id
            output_logits = generation_config.output_logits
            return_dict_in_generate = generation_config.return_dict_in_generate

        logits_processor, stopping_criteria, pad_token_id, eos_token_id = self._rbln_init_decoding(
            logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id
        )
        if return_dict_in_generate is None:
            return_dict_in_generate = self.generation_config.return_dict_in_generate

        def select_next_tokens(current_ids, next_token_logits):
            next_tokens_scores = logits_processor(current_ids, next_token_logits)
            return next_tokens_scores, torch.argmax(next_tokens_scores, dim=-1)

        input_ids, _, raw_logits = self._rbln_decode_loop(
            input_ids,
            stopping_criteria,
            select_next_tokens,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            store_scores=False,
            store_logits=bool(return_dict_in_generate and output_logits),
            streamer=streamer,
            forward_kwargs={},
            model_kwargs=model_kwargs,
        )

        if return_dict_in_generate:
            return SampleDecoderOnlyOutput(
                sequences=input_ids,
                logits=self._rbln_rotate_rightpad_logits(raw_logits, input_ids.shape[0]),
            )
        return input_ids

    # Calling `sample` directly is deprecated and was removed from transformers in v4.41;
    # it is kept as a thin alias for older callers.
    def sample(self, *args, **kwargs):
        return self._sample(*args, **kwargs)

    def _sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        logits_warper: Optional[LogitsProcessorList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        output_logits: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: bool = False,
        streamer: Optional["BaseStreamer"] = None,
        generation_config: Optional[GenerationConfig] = None,
        do_sample: Optional[bool] = True,
        **model_kwargs,
    ) -> Union[SampleDecoderOnlyOutput, torch.LongTensor]:
        """Generate tokens by multinomial sampling, or arg-max when ``do_sample`` is false.

        Args:
            input_ids: Prompt token ids, indexed as ``(batch, seq)`` by this loop.
            logits_processor: Applied to the last-position logits every step.
            stopping_criteria: Evaluated as ``stopping_criteria(input_ids, None)`` after
                every step; decoding stops once it is truthy for all rows.
            logits_warper: Applied after ``logits_processor`` only when sampling.
            max_length: Deprecated; folded into ``stopping_criteria`` with a warning.
            pad_token_id: Token emitted for finished rows. Overridden by
                ``generation_config`` when given, else defaults to ``self.generation_config``.
            eos_token_id: Int or list of ints ending a row. Defaults to
                ``self.generation_config.eos_token_id``.
            output_attentions: Forwarded to the model call.
            output_hidden_states: Forwarded to the model call.
            output_scores: Collect the processed (and warped) scores of every step.
                Overridden by ``generation_config``, else defaults to
                ``self.generation_config.output_scores``.
            output_logits: Collect the raw last-position logits of every step.
                Overridden by ``generation_config``; ``None`` means ``False``.
            return_dict_in_generate: Return a ``SampleDecoderOnlyOutput`` instead of bare
                ids. Overridden by ``generation_config``.
            synced_gpus: Accepted for signature compatibility; not used here.
            streamer: Receives each step's tokens through ``put`` and ``end``.
            generation_config: Generation settings passed by newer transformers.
            do_sample: Sample from the softmax of the warped scores when true, else take
                the arg-max. Overridden by ``generation_config``.
            **model_kwargs: Forwarded to ``prepare_inputs_for_generation``.

        Returns:
            ``SampleDecoderOnlyOutput(sequences=..., scores=..., logits=...)`` when
            ``return_dict_in_generate`` is truthy, otherwise the final ``input_ids``.

        Raises:
            ValueError: An EOS id is known but no pad id is available.
        """
        # transformers 4.41 passes these settings via `generation_config` instead of arguments.
        if generation_config is not None:
            pad_token_id = generation_config.pad_token_id
            # Read from the config as well; otherwise `generate(..., output_scores=True)` lost its scores.
            output_scores = generation_config.output_scores
            output_logits = generation_config.output_logits
            return_dict_in_generate = generation_config.return_dict_in_generate
            do_sample = generation_config.do_sample

        logits_processor, stopping_criteria, pad_token_id, eos_token_id = self._rbln_init_decoding(
            logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id
        )
        logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
        output_logits = output_logits if output_logits is not None else False

        def select_next_tokens(current_ids, next_token_logits):
            next_token_scores = logits_processor(current_ids, next_token_logits)
            if not do_sample:
                return next_token_scores, torch.argmax(next_token_scores, dim=-1)
            next_token_scores = logits_warper(current_ids, next_token_scores)
            probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
            return next_token_scores, torch.multinomial(probs, num_samples=1).squeeze(1)

        input_ids, scores, raw_logits = self._rbln_decode_loop(
            input_ids,
            stopping_criteria,
            select_next_tokens,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            store_scores=bool(return_dict_in_generate and output_scores),
            store_logits=bool(return_dict_in_generate and output_logits),
            streamer=streamer,
            forward_kwargs={"output_attentions": output_attentions, "output_hidden_states": output_hidden_states},
            model_kwargs=model_kwargs,
        )

        if return_dict_in_generate:
            return SampleDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                logits=self._rbln_rotate_rightpad_logits(raw_logits, input_ids.shape[0]),
            )
        return input_ids

    def _rbln_init_decoding(self, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id):
        """Apply the argument defaults shared by `_greedy_search` and `_sample`.

        Returns ``(logits_processor, stopping_criteria, pad_token_id, eos_token_id)``. In the
        result, missing processor/criteria lists are replaced by empty ones and the deprecated
        ``max_length`` is folded into ``stopping_criteria``. Missing ids are taken from
        ``self.generation_config``, and an int ``eos_token_id`` is wrapped in a list.
        """
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)

        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        return logits_processor, stopping_criteria, pad_token_id, eos_token_id

    def _rbln_decode_loop(
        self,
        input_ids: torch.LongTensor,
        stopping_criteria: StoppingCriteriaList,
        select_next_tokens,
        pad_token_id: Optional[int],
        eos_token_id: Optional[List[int]],
        store_scores: bool,
        store_logits: bool,
        streamer: Optional["BaseStreamer"],
        forward_kwargs: dict,
        model_kwargs: dict,
    ) -> tuple:
        """Run the token-by-token decode loop shared by greedy search and sampling.

        ``select_next_tokens(input_ids, next_token_logits)`` returns
        ``(next_token_scores, next_tokens)``; it is the only step that differs between
        the two strategies. The loop stops when one of the following happens:
        the forward pass raises, the streamer reports ``is_blocked()``, every row has
        produced an EOS token, or ``stopping_criteria`` holds for all rows.

        Returns:
            ``(input_ids, scores, raw_logits)``. ``scores`` and ``raw_logits`` are tuples
            of per-step tensors, or ``None`` when not requested.
        """
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
        scores = () if store_scores else None
        raw_logits = () if store_logits else None

        # 1 while a row is still generating, 0 once it has produced an EOS token.
        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)

        while True:
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # Best effort: a failing forward pass ends generation with the tokens produced
            # so far instead of raising (the traceback is still printed).
            try:
                outputs = self(**model_inputs, return_dict=True, **forward_kwargs)
                next_token_logits = outputs.logits[:, -1, :]
            except Exception:
                traceback.print_exc()
                break

            next_token_scores, next_tokens = select_next_tokens(input_ids, next_token_logits)
            if scores is not None:
                scores += (next_token_scores,)
            if raw_logits is not None:
                # NOTE: in right-padded batches this also holds the logits of dummy steps.
                raw_logits += (next_token_logits,)

            # Finished rows keep emitting the pad token.
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            input_ids, model_kwargs, dummy_mask = self._rbln_append_next_tokens(
                input_ids, next_tokens, outputs, model_inputs, model_kwargs, streamer, pad_token_id
            )

            stop = False
            if streamer is not None:
                streamer.put(next_tokens.cpu())
                if streamer.is_blocked():
                    stop = True

            if eos_token_id_tensor is not None:
                unfinished_sequences = self._rbln_update_unfinished(
                    unfinished_sequences, next_tokens, eos_token_id_tensor, dummy_mask
                )
                if unfinished_sequences.max() == 0:
                    stop = True

            if self._rbln_should_stop(stopping_criteria, input_ids):
                stop = True
            if stop:
                break

        if streamer is not None:
            streamer.end()
        return input_ids, scores, raw_logits

    def _rbln_append_next_tokens(self, input_ids, next_tokens, outputs, model_inputs, model_kwargs, streamer, pad_token_id):
        """Record ``next_tokens`` in ``input_ids`` and advance ``model_kwargs`` for the next step.

        Returns ``(input_ids, model_kwargs, dummy_mask)``. ``dummy_mask`` is a bool tensor
        marking rows whose generated token was discarded while filling the right-padded
        region, or ``None`` when the tokens were appended as a new column.
        """
        if hasattr(self, "rightpad_max_len"):
            # NOTE(review): assumes this sum is a scalar column index — confirm.
            update_idx = model_inputs["cache_position"] + model_inputs["query_length"]
            if update_idx < self.rightpad_max_len:
                attention_mask = model_kwargs["attention_mask"]
                # Mask 0 at update_idx: the row's prompt has ended, so the generated token is
                # real and overwrites that padding slot in place.
                valid_indices = attention_mask[:, update_idx] == 0
                # Mask 1: the row still has a prompt token there; the generated token is a dummy.
                # Captured before the mask is updated below.
                dummy_indices = attention_mask[:, update_idx] == 1

                input_ids[valid_indices, update_idx] = next_tokens[valid_indices]
                attention_mask[valid_indices, update_idx] = 1
                model_kwargs["past_key_values"] = outputs["past_key_values"]

                # Replace dummy tokens by the pad id so the streamer can skip them with
                # `skip_special_tokens=True` (mutates next_tokens in place).
                if streamer is not None:
                    next_tokens[dummy_indices] = pad_token_id
                return input_ids, model_kwargs, dummy_indices

        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs,
            model_kwargs,
            is_encoder_decoder=self.config.is_encoder_decoder,
        )
        return input_ids, model_kwargs, None

    @staticmethod
    def _rbln_update_unfinished(unfinished_sequences, next_tokens, eos_token_id_tensor, dummy_mask=None):
        """Zero the entries of ``unfinished_sequences`` whose ``next_tokens`` is an EOS id.

        Rows flagged in ``dummy_mask`` (discarded right-padding tokens) are never marked
        finished, whatever token was generated for them. Only those rows are exempt; real
        tokens written into the right-padded region can still finish their row.
        """
        not_eos = next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
        if dummy_mask is not None:
            not_eos = not_eos.masked_fill(dummy_mask.reshape(-1), 1)
        return unfinished_sequences.mul(not_eos)

    @staticmethod
    def _rbln_should_stop(stopping_criteria, input_ids) -> bool:
        """Evaluate ``stopping_criteria`` and collapse the result to a single bool.

        The criteria may return a plain bool or a tensor (one flag per row); a tensor
        result stops decoding only when every entry is true.
        """
        is_stop = stopping_criteria(input_ids, None)
        if isinstance(is_stop, torch.Tensor):
            # Any device/dtype: the former `torch.BoolTensor` check only matched CPU bool tensors.
            is_stop = torch.all(is_stop)
        return bool(is_stop)

    def _rbln_rotate_rightpad_logits(self, raw_logits, batch_size):
        """Reorder collected logits for right-padded batches.

        Without ``rightpad_max_len`` the value is returned unchanged. Otherwise the
        per-step logits are stacked, and each row ``i`` is rotated left by
        ``self.dummy_len[i]`` steps; presumably this moves that row's dummy steps to the
        end — confirm. The result is a tensor indexed ``(step, batch, ...)``.

        ``None`` (logits not requested) or an empty tuple (no step completed) is returned
        as-is instead of failing inside ``torch.stack``.
        """
        if raw_logits is None or len(raw_logits) == 0 or not hasattr(self, "rightpad_max_len"):
            return raw_logits
        per_row = torch.stack(raw_logits).transpose(0, 1)
        for i in range(batch_size):
            shift = self.dummy_len[i]
            per_row[i] = torch.cat((per_row[i][shift:], per_row[i][:shift]))
        return per_row.transpose(1, 0)
|
@@ -24,5 +24,6 @@
|
|
24
24
|
from .clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection
|
25
25
|
from .gpt2 import RBLNGPT2LMHeadModel
|
26
26
|
from .llama import RBLNLlamaForCausalLM
|
27
|
+
from .midm import RBLNMidmLMHeadModel
|
27
28
|
from .wav2vec2 import RBLNWav2Vec2ForCTC
|
28
29
|
from .whisper import RBLNWhisperForConditionalGeneration
|