optimum-rbln 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +17 -0
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/__init__.py +0 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/models/controlnet.py +7 -3
- optimum/rbln/diffusers/models/unet_2d_condition.py +5 -5
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +23 -146
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +107 -59
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +106 -54
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
- optimum/rbln/modeling_alias.py +19 -1
- optimum/rbln/modeling_base.py +162 -18
- optimum/rbln/transformers/__init__.py +8 -0
- optimum/rbln/transformers/cache_utils.py +111 -0
- optimum/rbln/transformers/generation/utils.py +0 -2
- optimum/rbln/transformers/models/__init__.py +3 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
- optimum/rbln/transformers/models/clip/modeling_clip.py +1 -1
- optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +516 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +464 -0
- optimum/rbln/transformers/models/gemma/__init__.py +24 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +123 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +67 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +201 -166
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +10 -257
- optimum/rbln/transformers/models/llama/llama_architecture.py +3 -610
- optimum/rbln/transformers/models/llama/modeling_llama.py +12 -440
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
- optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
- optimum/rbln/transformers/models/midm/modeling_midm.py +10 -325
- optimum/rbln/transformers/models/mistral/__init__.py +24 -0
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +29 -0
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +68 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +131 -0
- optimum/rbln/transformers/utils/__init__.py +0 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +109 -0
- optimum/rbln/utils/import_utils.py +1 -4
- optimum/rbln/utils/runtime_utils.py +2 -1
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/METADATA +11 -5
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/RECORD +48 -35
- optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -764
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/licenses/LICENSE +0 -0
@@ -23,39 +23,21 @@
|
|
23
23
|
|
24
24
|
import inspect
|
25
25
|
import logging
|
26
|
-
from typing import TYPE_CHECKING, Any, Callable
|
26
|
+
from typing import TYPE_CHECKING, Any, Callable
|
27
27
|
|
28
|
-
import
|
29
|
-
import torch
|
30
|
-
from transformers import AutoModelForCausalLM, GPT2LMHeadModel, PretrainedConfig
|
31
|
-
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, Seq2SeqLMOutput
|
28
|
+
from transformers import GPT2LMHeadModel
|
32
29
|
|
33
|
-
from ....
|
34
|
-
from
|
35
|
-
from ....utils.runtime_utils import RBLNPytorchRuntime
|
36
|
-
from ...generation.utils import RBLNGenerationMixin
|
30
|
+
from ....modeling_config import RBLNConfig
|
31
|
+
from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
|
37
32
|
from .gpt2_architecture import GPT2LMHeadModelWrapper
|
38
33
|
|
39
34
|
|
40
35
|
logger = logging.getLogger(__name__)
|
41
|
-
|
42
36
|
if TYPE_CHECKING:
|
43
|
-
from transformers import
|
44
|
-
AutoFeatureExtractor,
|
45
|
-
AutoProcessor,
|
46
|
-
AutoTokenizer,
|
47
|
-
PretrainedConfig,
|
48
|
-
)
|
49
|
-
|
50
|
-
|
51
|
-
class RBLNRuntimeDecoder(RBLNPytorchRuntime):
|
52
|
-
def forward(self, *args, **kwargs) -> Union[Tuple, Seq2SeqLMOutput]:
|
53
|
-
outputs = super().forward(*args, **kwargs)
|
54
|
-
logits = outputs
|
55
|
-
return Seq2SeqLMOutput(logits=logits)
|
37
|
+
from transformers import PreTrainedModel
|
56
38
|
|
57
39
|
|
58
|
-
class RBLNGPT2LMHeadModel(
|
40
|
+
class RBLNGPT2LMHeadModel(RBLNDecoderOnlyModelForCausalLM):
|
59
41
|
"""
|
60
42
|
The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
|
61
43
|
embeddings).
|
@@ -69,29 +51,10 @@ class RBLNGPT2LMHeadModel(RBLNModel, RBLNGenerationMixin):
|
|
69
51
|
|
70
52
|
"""
|
71
53
|
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
def __post_init__(self, **kwargs):
|
77
|
-
self.prefill_chunk_size = self.rbln_config.meta["rbln_prefill_chunk_size"]
|
78
|
-
self.max_seq_len = self.rbln_config.meta["rbln_max_seq_len"]
|
79
|
-
|
80
|
-
batch_size = self.rbln_config[DEFAULT_COMPILED_MODEL_NAME][0].input_info[0][1][0]
|
81
|
-
self.prefill_attention_mask = torch.zeros(
|
82
|
-
batch_size, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.int64
|
83
|
-
)
|
84
|
-
self.causal_mask = 1 - torch.triu(
|
85
|
-
torch.ones(batch_size, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
|
86
|
-
)
|
87
|
-
|
88
|
-
self.prefill_decoder = RBLNRuntimeDecoder(runtime=self.model[0])
|
89
|
-
self.decoder = RBLNRuntimeDecoder(runtime=self.model[1])
|
90
|
-
self.pad_token_id = self.rbln_config.meta["rbln_pad_token_id"]
|
91
|
-
self.past_cached_length = 0
|
92
|
-
|
93
|
-
def can_generate(self):
|
94
|
-
return True
|
54
|
+
@classmethod
|
55
|
+
def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
|
56
|
+
rbln_max_seq_len = rbln_config.meta["rbln_max_seq_len"]
|
57
|
+
return GPT2LMHeadModelWrapper(model, rbln_max_seq_len).eval()
|
95
58
|
|
96
59
|
def __getattr__(self, __name: str) -> Any:
|
97
60
|
"""This is the key method to implement RBLN-GPT2.
|
@@ -107,213 +70,3 @@ class RBLNGPT2LMHeadModel(RBLNModel, RBLNGenerationMixin):
|
|
107
70
|
if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
|
108
71
|
return redirect(val)
|
109
72
|
return val
|
110
|
-
|
111
|
-
def _reorder_cache(self, past_key_values, beam_idx):
|
112
|
-
# TODO(jongho): implement
|
113
|
-
raise NotImplementedError
|
114
|
-
|
115
|
-
@classmethod
|
116
|
-
def update_kwargs(cls, kwargs):
|
117
|
-
kwargs.update(
|
118
|
-
{
|
119
|
-
"torchscript": True,
|
120
|
-
"return_dict": False,
|
121
|
-
"use_cache": True,
|
122
|
-
}
|
123
|
-
)
|
124
|
-
return kwargs
|
125
|
-
|
126
|
-
@classmethod
|
127
|
-
@torch.inference_mode()
|
128
|
-
def get_compiled_model(cls, model: GPT2LMHeadModel, rbln_config: RBLNConfig):
|
129
|
-
wrapped_decoder = GPT2LMHeadModelWrapper(model).eval()
|
130
|
-
|
131
|
-
prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
|
132
|
-
dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
|
133
|
-
|
134
|
-
prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
|
135
|
-
dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=0)
|
136
|
-
|
137
|
-
prefill_scripted_model = torch.jit.trace(wrapped_decoder, prefill_example_inputs, check_trace=False)
|
138
|
-
dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs, check_trace=False)
|
139
|
-
|
140
|
-
prefill_ir = rebel.torchscript_to_ir(
|
141
|
-
prefill_scripted_model,
|
142
|
-
input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
|
143
|
-
)
|
144
|
-
dec_ir = rebel.torchscript_to_ir(
|
145
|
-
dec_scripted_model,
|
146
|
-
input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
|
147
|
-
)
|
148
|
-
|
149
|
-
connections = [(prefill_ir.outputs[1 + i], prefill_ir.inputs[3 + i]) for i in range(model.config.n_layer * 2)]
|
150
|
-
|
151
|
-
compiled_model = rebel.compile(
|
152
|
-
prefill_ir,
|
153
|
-
dec_ir,
|
154
|
-
connections=connections,
|
155
|
-
fusion=prefill_rbln_runtime_config.fusion,
|
156
|
-
npu=prefill_rbln_runtime_config.npu,
|
157
|
-
tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
|
158
|
-
use_weight_sharing=True,
|
159
|
-
)
|
160
|
-
|
161
|
-
return compiled_model
|
162
|
-
|
163
|
-
@classmethod
|
164
|
-
def _get_rbln_config(
|
165
|
-
cls,
|
166
|
-
preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
|
167
|
-
model_config: "PretrainedConfig",
|
168
|
-
rbln_max_seq_len: Optional[int] = None,
|
169
|
-
rbln_batch_size: Optional[int] = None,
|
170
|
-
rbln_pad_token_id: Optional[int] = None,
|
171
|
-
) -> RBLNConfig:
|
172
|
-
meta = {}
|
173
|
-
|
174
|
-
default_max_length = getattr(model_config, "n_positions", None)
|
175
|
-
for tokenizer in preprocessors:
|
176
|
-
default_max_length = default_max_length or getattr(tokenizer, "max_len_single_sentence", None)
|
177
|
-
|
178
|
-
prefill_chunk_size = 128
|
179
|
-
|
180
|
-
if rbln_max_seq_len is None:
|
181
|
-
rbln_max_seq_len = default_max_length
|
182
|
-
|
183
|
-
if rbln_max_seq_len is None:
|
184
|
-
raise ValueError("`rbln_max_seq_len` should be specified!")
|
185
|
-
|
186
|
-
if rbln_pad_token_id is None:
|
187
|
-
rbln_pad_token_id = getattr(model_config, "pad_token_id", None)
|
188
|
-
if rbln_pad_token_id is None:
|
189
|
-
rbln_pad_token_id = getattr(model_config, "eos_token_id", None)
|
190
|
-
if rbln_pad_token_id is None:
|
191
|
-
rbln_pad_token_id = 50256
|
192
|
-
|
193
|
-
meta["rbln_prefill_chunk_size"] = prefill_chunk_size
|
194
|
-
meta["rbln_max_seq_len"] = rbln_max_seq_len
|
195
|
-
meta["rbln_pad_token_id"] = rbln_pad_token_id
|
196
|
-
|
197
|
-
rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
|
198
|
-
|
199
|
-
def get_input_info(query_length):
|
200
|
-
return [
|
201
|
-
("input_ids", [rbln_batch_size, query_length], "int64"),
|
202
|
-
("attention_mask", [rbln_batch_size, 1, query_length, rbln_max_seq_len], "int64"),
|
203
|
-
(
|
204
|
-
"cache_position",
|
205
|
-
[],
|
206
|
-
"int32",
|
207
|
-
),
|
208
|
-
] + [
|
209
|
-
(
|
210
|
-
f"past_key_values_{i}",
|
211
|
-
[
|
212
|
-
rbln_batch_size,
|
213
|
-
model_config.n_head,
|
214
|
-
rbln_max_seq_len,
|
215
|
-
model_config.hidden_size // model_config.n_head,
|
216
|
-
],
|
217
|
-
"float32",
|
218
|
-
)
|
219
|
-
for i in range(model_config.n_layer * 2)
|
220
|
-
]
|
221
|
-
|
222
|
-
# model input info
|
223
|
-
prefill_input_info = get_input_info(query_length=prefill_chunk_size)
|
224
|
-
dec_input_info = get_input_info(query_length=1)
|
225
|
-
|
226
|
-
prefill_rbln_runtime_config = RBLNRuntimeConfig(input_info=prefill_input_info)
|
227
|
-
dec_rbln_runtime_config = RBLNRuntimeConfig(input_info=dec_input_info)
|
228
|
-
|
229
|
-
rbln_config = RBLNConfig.from_rbln_runtime_configs(
|
230
|
-
[prefill_rbln_runtime_config, dec_rbln_runtime_config],
|
231
|
-
_rbln_meta=meta,
|
232
|
-
)
|
233
|
-
|
234
|
-
return rbln_config
|
235
|
-
|
236
|
-
@classmethod
|
237
|
-
def _create_runtimes(
|
238
|
-
cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
|
239
|
-
) -> List[rebel.Runtime]:
|
240
|
-
device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
|
241
|
-
return [
|
242
|
-
compiled_models[0].create_runtime(input_info_index=0, tensor_type="pt", device=device_val),
|
243
|
-
compiled_models[0].create_runtime(input_info_index=1, tensor_type="pt", device=device_val),
|
244
|
-
]
|
245
|
-
|
246
|
-
def prepare_inputs_for_generation(self, input_ids, past_key_values=0, attention_mask=None, **kwargs):
|
247
|
-
batch_size, cur_len = input_ids.shape
|
248
|
-
past_cached_length = past_key_values
|
249
|
-
|
250
|
-
# In greedy decoding
|
251
|
-
if past_cached_length == 0:
|
252
|
-
self.prompt_ids = input_ids
|
253
|
-
self.rightpad_max_len = cur_len
|
254
|
-
prompt_min_len = torch.min(torch.sum(attention_mask, dim=-1))
|
255
|
-
self.dummy_len = torch.sum(attention_mask, dim=-1) - prompt_min_len
|
256
|
-
|
257
|
-
if cur_len % self.prefill_chunk_size == 0:
|
258
|
-
pad_len = 0
|
259
|
-
else:
|
260
|
-
pad_len = self.prefill_chunk_size - cur_len % self.prefill_chunk_size
|
261
|
-
input_ids = torch.nn.functional.pad(input_ids, (0, pad_len))
|
262
|
-
attention_mask = self.prefill_attention_mask.clone()
|
263
|
-
cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
|
264
|
-
|
265
|
-
query_length = prompt_min_len.item()
|
266
|
-
else:
|
267
|
-
cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
|
268
|
-
attention_mask = torch.zeros(batch_size, 1, 1, self.max_seq_len, dtype=torch.int64)
|
269
|
-
attention_mask[:, :, :, : cache_position + 1] = 1
|
270
|
-
input_ids = input_ids[:, cache_position : cache_position + 1].contiguous()
|
271
|
-
query_length = 1
|
272
|
-
|
273
|
-
model_inputs = {
|
274
|
-
"input_ids": input_ids,
|
275
|
-
"past_key_values": past_key_values,
|
276
|
-
"attention_mask": attention_mask,
|
277
|
-
# below are rbln-related kwargs
|
278
|
-
"cache_position": cache_position,
|
279
|
-
"query_length": query_length,
|
280
|
-
}
|
281
|
-
|
282
|
-
return model_inputs
|
283
|
-
|
284
|
-
def forward(
|
285
|
-
self,
|
286
|
-
input_ids: Optional[torch.LongTensor] = None,
|
287
|
-
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
288
|
-
attention_mask: Optional[torch.FloatTensor] = None,
|
289
|
-
cache_position: Optional[torch.Tensor] = None,
|
290
|
-
query_length: Optional[torch.Tensor] = None,
|
291
|
-
**kwargs,
|
292
|
-
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
293
|
-
if past_key_values is not None:
|
294
|
-
past_key_values += query_length
|
295
|
-
|
296
|
-
if cache_position == 0:
|
297
|
-
for step in range(0, query_length, self.prefill_chunk_size):
|
298
|
-
sliced_input_ids = input_ids[:, step : step + self.prefill_chunk_size]
|
299
|
-
attention_mask[:, :, :, :step] = 1
|
300
|
-
attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
|
301
|
-
|
302
|
-
output = self.prefill_decoder(
|
303
|
-
input_ids=sliced_input_ids.contiguous(),
|
304
|
-
attention_mask=attention_mask.contiguous(),
|
305
|
-
cache_position=cache_position + step,
|
306
|
-
)
|
307
|
-
|
308
|
-
idx = query_length % self.prefill_chunk_size - 1
|
309
|
-
output = output.logits[:, idx].unsqueeze(1)
|
310
|
-
|
311
|
-
else:
|
312
|
-
output = self.decoder(
|
313
|
-
input_ids=input_ids.contiguous(),
|
314
|
-
attention_mask=attention_mask.contiguous(),
|
315
|
-
cache_position=cache_position,
|
316
|
-
)
|
317
|
-
output = output.logits
|
318
|
-
|
319
|
-
return CausalLMOutputWithCrossAttentions(logits=output, past_key_values=past_key_values)
|