optimum-rbln 0.1.7__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +14 -0
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/__init__.py +0 -1
- optimum/rbln/diffusers/models/controlnet.py +3 -0
- optimum/rbln/diffusers/models/unet_2d_condition.py +2 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +22 -144
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +107 -59
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +106 -54
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
- optimum/rbln/modeling_alias.py +14 -0
- optimum/rbln/modeling_base.py +110 -0
- optimum/rbln/transformers/__init__.py +6 -0
- optimum/rbln/transformers/cache_utils.py +111 -0
- optimum/rbln/transformers/generation/utils.py +0 -2
- optimum/rbln/transformers/models/__init__.py +2 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
- optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +515 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +349 -0
- optimum/rbln/transformers/models/gemma/__init__.py +24 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +116 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +61 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +201 -166
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +56 -220
- optimum/rbln/transformers/models/llama/llama_architecture.py +3 -610
- optimum/rbln/transformers/models/llama/modeling_llama.py +8 -442
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
- optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
- optimum/rbln/transformers/models/midm/modeling_midm.py +40 -272
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +125 -0
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/METADATA +2 -3
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/RECORD +38 -30
- optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -764
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/licenses/LICENSE +0 -0
@@ -23,22 +23,16 @@
|
|
23
23
|
|
24
24
|
import inspect
|
25
25
|
import logging
|
26
|
-
from typing import TYPE_CHECKING, Any, Callable,
|
26
|
+
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
27
27
|
|
28
|
-
import
|
29
|
-
import torch
|
30
|
-
from transformers import AutoModelForCausalLM, GPT2LMHeadModel, PretrainedConfig
|
31
|
-
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, Seq2SeqLMOutput
|
28
|
+
from transformers import GPT2LMHeadModel, PretrainedConfig, PreTrainedModel
|
32
29
|
|
33
|
-
from ....
|
34
|
-
from
|
35
|
-
from ....utils.runtime_utils import RBLNPytorchRuntime
|
36
|
-
from ...generation.utils import RBLNGenerationMixin
|
30
|
+
from ....modeling_config import RBLNConfig, RBLNRuntimeConfig
|
31
|
+
from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
|
37
32
|
from .gpt2_architecture import GPT2LMHeadModelWrapper
|
38
33
|
|
39
34
|
|
40
35
|
logger = logging.getLogger(__name__)
|
41
|
-
|
42
36
|
if TYPE_CHECKING:
|
43
37
|
from transformers import (
|
44
38
|
AutoFeatureExtractor,
|
@@ -48,14 +42,7 @@ if TYPE_CHECKING:
|
|
48
42
|
)
|
49
43
|
|
50
44
|
|
51
|
-
class
|
52
|
-
def forward(self, *args, **kwargs) -> Union[Tuple, Seq2SeqLMOutput]:
|
53
|
-
outputs = super().forward(*args, **kwargs)
|
54
|
-
logits = outputs
|
55
|
-
return Seq2SeqLMOutput(logits=logits)
|
56
|
-
|
57
|
-
|
58
|
-
class RBLNGPT2LMHeadModel(RBLNModel, RBLNGenerationMixin):
|
45
|
+
class RBLNGPT2LMHeadModel(RBLNDecoderOnlyModelForCausalLM):
|
59
46
|
"""
|
60
47
|
The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
|
61
48
|
embeddings).
|
@@ -69,29 +56,9 @@ class RBLNGPT2LMHeadModel(RBLNModel, RBLNGenerationMixin):
|
|
69
56
|
|
70
57
|
"""
|
71
58
|
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
def __post_init__(self, **kwargs):
|
77
|
-
self.prefill_chunk_size = self.rbln_config.meta["rbln_prefill_chunk_size"]
|
78
|
-
self.max_seq_len = self.rbln_config.meta["rbln_max_seq_len"]
|
79
|
-
|
80
|
-
batch_size = self.rbln_config[DEFAULT_COMPILED_MODEL_NAME][0].input_info[0][1][0]
|
81
|
-
self.prefill_attention_mask = torch.zeros(
|
82
|
-
batch_size, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.int64
|
83
|
-
)
|
84
|
-
self.causal_mask = 1 - torch.triu(
|
85
|
-
torch.ones(batch_size, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
|
86
|
-
)
|
87
|
-
|
88
|
-
self.prefill_decoder = RBLNRuntimeDecoder(runtime=self.model[0])
|
89
|
-
self.decoder = RBLNRuntimeDecoder(runtime=self.model[1])
|
90
|
-
self.pad_token_id = self.rbln_config.meta["rbln_pad_token_id"]
|
91
|
-
self.past_cached_length = 0
|
92
|
-
|
93
|
-
def can_generate(self):
|
94
|
-
return True
|
59
|
+
@classmethod
|
60
|
+
def wrapping_torch_model(self, model: "PreTrainedModel", rbln_max_seq_len: int):
|
61
|
+
return GPT2LMHeadModelWrapper(model, rbln_max_seq_len).eval()
|
95
62
|
|
96
63
|
def __getattr__(self, __name: str) -> Any:
|
97
64
|
"""This is the key method to implement RBLN-GPT2.
|
@@ -108,58 +75,6 @@ class RBLNGPT2LMHeadModel(RBLNModel, RBLNGenerationMixin):
|
|
108
75
|
return redirect(val)
|
109
76
|
return val
|
110
77
|
|
111
|
-
def _reorder_cache(self, past_key_values, beam_idx):
|
112
|
-
# TODO(jongho): implement
|
113
|
-
raise NotImplementedError
|
114
|
-
|
115
|
-
@classmethod
|
116
|
-
def update_kwargs(cls, kwargs):
|
117
|
-
kwargs.update(
|
118
|
-
{
|
119
|
-
"torchscript": True,
|
120
|
-
"return_dict": False,
|
121
|
-
"use_cache": True,
|
122
|
-
}
|
123
|
-
)
|
124
|
-
return kwargs
|
125
|
-
|
126
|
-
@classmethod
|
127
|
-
@torch.inference_mode()
|
128
|
-
def get_compiled_model(cls, model: GPT2LMHeadModel, rbln_config: RBLNConfig):
|
129
|
-
wrapped_decoder = GPT2LMHeadModelWrapper(model).eval()
|
130
|
-
|
131
|
-
prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
|
132
|
-
dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
|
133
|
-
|
134
|
-
prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
|
135
|
-
dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=0)
|
136
|
-
|
137
|
-
prefill_scripted_model = torch.jit.trace(wrapped_decoder, prefill_example_inputs, check_trace=False)
|
138
|
-
dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs, check_trace=False)
|
139
|
-
|
140
|
-
prefill_ir = rebel.torchscript_to_ir(
|
141
|
-
prefill_scripted_model,
|
142
|
-
input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
|
143
|
-
)
|
144
|
-
dec_ir = rebel.torchscript_to_ir(
|
145
|
-
dec_scripted_model,
|
146
|
-
input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
|
147
|
-
)
|
148
|
-
|
149
|
-
connections = [(prefill_ir.outputs[1 + i], prefill_ir.inputs[3 + i]) for i in range(model.config.n_layer * 2)]
|
150
|
-
|
151
|
-
compiled_model = rebel.compile(
|
152
|
-
prefill_ir,
|
153
|
-
dec_ir,
|
154
|
-
connections=connections,
|
155
|
-
fusion=prefill_rbln_runtime_config.fusion,
|
156
|
-
npu=prefill_rbln_runtime_config.npu,
|
157
|
-
tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
|
158
|
-
use_weight_sharing=True,
|
159
|
-
)
|
160
|
-
|
161
|
-
return compiled_model
|
162
|
-
|
163
78
|
@classmethod
|
164
79
|
def _get_rbln_config(
|
165
80
|
cls,
|
@@ -167,153 +82,74 @@ class RBLNGPT2LMHeadModel(RBLNModel, RBLNGenerationMixin):
|
|
167
82
|
model_config: "PretrainedConfig",
|
168
83
|
rbln_max_seq_len: Optional[int] = None,
|
169
84
|
rbln_batch_size: Optional[int] = None,
|
170
|
-
|
85
|
+
**kwargs,
|
171
86
|
) -> RBLNConfig:
|
172
87
|
meta = {}
|
173
88
|
|
174
|
-
default_max_length = getattr(model_config, "n_positions", None)
|
175
|
-
for tokenizer in preprocessors:
|
176
|
-
default_max_length = default_max_length or getattr(tokenizer, "max_len_single_sentence", None)
|
177
|
-
|
178
89
|
prefill_chunk_size = 128
|
90
|
+
if rbln_max_seq_len is None: # differenct from llama
|
91
|
+
rbln_max_seq_len = getattr(model_config, "n_positions", None)
|
92
|
+
rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
|
179
93
|
|
180
|
-
if rbln_max_seq_len is None:
|
181
|
-
rbln_max_seq_len = default_max_length
|
182
|
-
|
183
|
-
if rbln_max_seq_len is None:
|
184
|
-
raise ValueError("`rbln_max_seq_len` should be specified!")
|
185
|
-
|
186
|
-
if rbln_pad_token_id is None:
|
187
|
-
rbln_pad_token_id = getattr(model_config, "pad_token_id", None)
|
188
|
-
if rbln_pad_token_id is None:
|
189
|
-
rbln_pad_token_id = getattr(model_config, "eos_token_id", None)
|
190
|
-
if rbln_pad_token_id is None:
|
191
|
-
rbln_pad_token_id = 50256
|
192
|
-
|
193
|
-
meta["rbln_prefill_chunk_size"] = prefill_chunk_size
|
194
94
|
meta["rbln_max_seq_len"] = rbln_max_seq_len
|
195
|
-
meta["
|
196
|
-
|
197
|
-
rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
|
95
|
+
meta["rbln_batch_size"] = rbln_batch_size
|
96
|
+
meta["rbln_prefill_chunk_size"] = prefill_chunk_size
|
198
97
|
|
199
|
-
def get_input_info(
|
200
|
-
|
201
|
-
|
202
|
-
|
98
|
+
def get_input_info(
|
99
|
+
batch_size,
|
100
|
+
query_length,
|
101
|
+
):
|
102
|
+
head_dim = (
|
103
|
+
model_config.head_dim
|
104
|
+
if hasattr(model_config, "head_dim")
|
105
|
+
else model_config.hidden_size // model_config.n_head
|
106
|
+
)
|
107
|
+
input_info = [
|
108
|
+
("input_ids", [batch_size, query_length], "int64"),
|
109
|
+
("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "int64"),
|
203
110
|
(
|
204
111
|
"cache_position",
|
205
|
-
[],
|
112
|
+
[batch_size, query_length],
|
206
113
|
"int32",
|
207
114
|
),
|
208
|
-
|
209
|
-
(
|
210
|
-
f"past_key_values_{i}",
|
211
|
-
[
|
212
|
-
rbln_batch_size,
|
213
|
-
model_config.n_head,
|
214
|
-
rbln_max_seq_len,
|
215
|
-
model_config.hidden_size // model_config.n_head,
|
216
|
-
],
|
217
|
-
"float32",
|
218
|
-
)
|
219
|
-
for i in range(model_config.n_layer * 2)
|
115
|
+
("batch_position", [], "int16"),
|
220
116
|
]
|
221
117
|
|
222
|
-
|
223
|
-
|
224
|
-
|
118
|
+
input_info.extend(
|
119
|
+
[
|
120
|
+
(
|
121
|
+
f"past_key_values_{i}",
|
122
|
+
[
|
123
|
+
rbln_batch_size,
|
124
|
+
model_config.n_head, # differenct from llama
|
125
|
+
rbln_max_seq_len,
|
126
|
+
head_dim,
|
127
|
+
],
|
128
|
+
"float32",
|
129
|
+
)
|
130
|
+
for i in range(model_config.n_layer * 2) # differenct from llama
|
131
|
+
]
|
132
|
+
)
|
133
|
+
|
134
|
+
return input_info
|
135
|
+
|
136
|
+
prefill_input_info = get_input_info(
|
137
|
+
batch_size=1,
|
138
|
+
query_length=prefill_chunk_size,
|
139
|
+
)
|
140
|
+
dec_input_info = get_input_info(
|
141
|
+
batch_size=rbln_batch_size,
|
142
|
+
query_length=1,
|
143
|
+
)
|
225
144
|
|
226
145
|
prefill_rbln_runtime_config = RBLNRuntimeConfig(input_info=prefill_input_info)
|
227
146
|
dec_rbln_runtime_config = RBLNRuntimeConfig(input_info=dec_input_info)
|
228
147
|
|
148
|
+
dec_rbln_runtime_config.batch_size = rbln_batch_size
|
149
|
+
|
229
150
|
rbln_config = RBLNConfig.from_rbln_runtime_configs(
|
230
151
|
[prefill_rbln_runtime_config, dec_rbln_runtime_config],
|
231
152
|
_rbln_meta=meta,
|
232
153
|
)
|
233
154
|
|
234
155
|
return rbln_config
|
235
|
-
|
236
|
-
@classmethod
|
237
|
-
def _create_runtimes(
|
238
|
-
cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
|
239
|
-
) -> List[rebel.Runtime]:
|
240
|
-
device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
|
241
|
-
return [
|
242
|
-
compiled_models[0].create_runtime(input_info_index=0, tensor_type="pt", device=device_val),
|
243
|
-
compiled_models[0].create_runtime(input_info_index=1, tensor_type="pt", device=device_val),
|
244
|
-
]
|
245
|
-
|
246
|
-
def prepare_inputs_for_generation(self, input_ids, past_key_values=0, attention_mask=None, **kwargs):
|
247
|
-
batch_size, cur_len = input_ids.shape
|
248
|
-
past_cached_length = past_key_values
|
249
|
-
|
250
|
-
# In greedy decoding
|
251
|
-
if past_cached_length == 0:
|
252
|
-
self.prompt_ids = input_ids
|
253
|
-
self.rightpad_max_len = cur_len
|
254
|
-
prompt_min_len = torch.min(torch.sum(attention_mask, dim=-1))
|
255
|
-
self.dummy_len = torch.sum(attention_mask, dim=-1) - prompt_min_len
|
256
|
-
|
257
|
-
if cur_len % self.prefill_chunk_size == 0:
|
258
|
-
pad_len = 0
|
259
|
-
else:
|
260
|
-
pad_len = self.prefill_chunk_size - cur_len % self.prefill_chunk_size
|
261
|
-
input_ids = torch.nn.functional.pad(input_ids, (0, pad_len))
|
262
|
-
attention_mask = self.prefill_attention_mask.clone()
|
263
|
-
cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
|
264
|
-
|
265
|
-
query_length = prompt_min_len.item()
|
266
|
-
else:
|
267
|
-
cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
|
268
|
-
attention_mask = torch.zeros(batch_size, 1, 1, self.max_seq_len, dtype=torch.int64)
|
269
|
-
attention_mask[:, :, :, : cache_position + 1] = 1
|
270
|
-
input_ids = input_ids[:, cache_position : cache_position + 1].contiguous()
|
271
|
-
query_length = 1
|
272
|
-
|
273
|
-
model_inputs = {
|
274
|
-
"input_ids": input_ids,
|
275
|
-
"past_key_values": past_key_values,
|
276
|
-
"attention_mask": attention_mask,
|
277
|
-
# below are rbln-related kwargs
|
278
|
-
"cache_position": cache_position,
|
279
|
-
"query_length": query_length,
|
280
|
-
}
|
281
|
-
|
282
|
-
return model_inputs
|
283
|
-
|
284
|
-
def forward(
|
285
|
-
self,
|
286
|
-
input_ids: Optional[torch.LongTensor] = None,
|
287
|
-
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
288
|
-
attention_mask: Optional[torch.FloatTensor] = None,
|
289
|
-
cache_position: Optional[torch.Tensor] = None,
|
290
|
-
query_length: Optional[torch.Tensor] = None,
|
291
|
-
**kwargs,
|
292
|
-
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
293
|
-
if past_key_values is not None:
|
294
|
-
past_key_values += query_length
|
295
|
-
|
296
|
-
if cache_position == 0:
|
297
|
-
for step in range(0, query_length, self.prefill_chunk_size):
|
298
|
-
sliced_input_ids = input_ids[:, step : step + self.prefill_chunk_size]
|
299
|
-
attention_mask[:, :, :, :step] = 1
|
300
|
-
attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
|
301
|
-
|
302
|
-
output = self.prefill_decoder(
|
303
|
-
input_ids=sliced_input_ids.contiguous(),
|
304
|
-
attention_mask=attention_mask.contiguous(),
|
305
|
-
cache_position=cache_position + step,
|
306
|
-
)
|
307
|
-
|
308
|
-
idx = query_length % self.prefill_chunk_size - 1
|
309
|
-
output = output.logits[:, idx].unsqueeze(1)
|
310
|
-
|
311
|
-
else:
|
312
|
-
output = self.decoder(
|
313
|
-
input_ids=input_ids.contiguous(),
|
314
|
-
attention_mask=attention_mask.contiguous(),
|
315
|
-
cache_position=cache_position,
|
316
|
-
)
|
317
|
-
output = output.logits
|
318
|
-
|
319
|
-
return CausalLMOutputWithCrossAttentions(logits=output, past_key_values=past_key_values)
|