optimum-rbln 0.1.7__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +14 -0
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/__init__.py +0 -1
- optimum/rbln/diffusers/models/controlnet.py +3 -0
- optimum/rbln/diffusers/models/unet_2d_condition.py +2 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +22 -144
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +107 -59
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +106 -54
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
- optimum/rbln/modeling_alias.py +14 -0
- optimum/rbln/modeling_base.py +110 -0
- optimum/rbln/transformers/__init__.py +6 -0
- optimum/rbln/transformers/cache_utils.py +111 -0
- optimum/rbln/transformers/generation/utils.py +0 -2
- optimum/rbln/transformers/models/__init__.py +2 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
- optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +515 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +349 -0
- optimum/rbln/transformers/models/gemma/__init__.py +24 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +116 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +61 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +201 -166
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +56 -220
- optimum/rbln/transformers/models/llama/llama_architecture.py +3 -610
- optimum/rbln/transformers/models/llama/modeling_llama.py +8 -442
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
- optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
- optimum/rbln/transformers/models/midm/modeling_midm.py +40 -272
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +125 -0
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/METADATA +2 -3
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/RECORD +38 -30
- optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -764
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/licenses/LICENSE +0 -0
@@ -23,17 +23,12 @@
|
|
23
23
|
|
24
24
|
import inspect
|
25
25
|
import logging
|
26
|
-
from typing import TYPE_CHECKING, Any, Callable,
|
26
|
+
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
27
27
|
|
28
|
-
import
|
29
|
-
import torch
|
30
|
-
from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
|
31
|
-
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
28
|
+
from transformers import PretrainedConfig, PreTrainedModel
|
32
29
|
|
33
|
-
from ....
|
34
|
-
from
|
35
|
-
from ....utils.runtime_utils import RBLNPytorchRuntime
|
36
|
-
from ...generation.utils import RBLNGenerationMixin
|
30
|
+
from ....modeling_config import RBLNConfig, RBLNRuntimeConfig
|
31
|
+
from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
|
37
32
|
from .hf_hub_cached.modeling_midm import MidmLMHeadModel
|
38
33
|
from .midm_architecture import (
|
39
34
|
MidmLMHeadModelWrapper,
|
@@ -41,7 +36,6 @@ from .midm_architecture import (
|
|
41
36
|
|
42
37
|
|
43
38
|
logger = logging.getLogger(__name__)
|
44
|
-
|
45
39
|
if TYPE_CHECKING:
|
46
40
|
from transformers import (
|
47
41
|
AutoFeatureExtractor,
|
@@ -51,31 +45,12 @@ if TYPE_CHECKING:
|
|
51
45
|
)
|
52
46
|
|
53
47
|
|
54
|
-
class
|
55
|
-
mandatory_members = ["main_input_name"]
|
56
|
-
|
57
|
-
# RBLN_Runtimemodule
|
58
|
-
def forward(
|
59
|
-
self,
|
60
|
-
input_ids: torch.LongTensor = None,
|
61
|
-
attention_mask: torch.LongTensor = None,
|
62
|
-
cache_position: torch.Tensor = None,
|
63
|
-
**kwargs: Dict[str, Any],
|
64
|
-
):
|
65
|
-
logits = super().forward(
|
66
|
-
input_ids=input_ids,
|
67
|
-
attention_mask=attention_mask,
|
68
|
-
cache_position=cache_position,
|
69
|
-
)
|
70
|
-
return logits
|
71
|
-
|
72
|
-
|
73
|
-
class RBLNMidmLMHeadModel(RBLNModel, RBLNGenerationMixin):
|
48
|
+
class RBLNMidmLMHeadModel(RBLNDecoderOnlyModelForCausalLM):
|
74
49
|
"""
|
75
50
|
The Midm Model transformer with a language modeling head on top (linear layer with weights tied to the input
|
76
51
|
embeddings).
|
77
52
|
|
78
|
-
This model inherits from [`
|
53
|
+
This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the
|
79
54
|
library implements for all its model.
|
80
55
|
|
81
56
|
It implements the methods to convert a pre-trained transformers Midm model into a RBLN transformer model by:
|
@@ -84,46 +59,9 @@ class RBLNMidmLMHeadModel(RBLNModel, RBLNGenerationMixin):
|
|
84
59
|
|
85
60
|
"""
|
86
61
|
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
def __init__(
|
92
|
-
self,
|
93
|
-
models: List[Union[PreTrainedModel, rebel.RBLNCompiledModel]],
|
94
|
-
config: PretrainedConfig = None,
|
95
|
-
preprocessors: Optional[List] = None,
|
96
|
-
rbln_config: Optional[RBLNConfig] = None,
|
97
|
-
rbln_device: Optional[List[int]] = None,
|
98
|
-
rbln_device_map: Optional[Dict[str, int]] = None,
|
99
|
-
**kwargs,
|
100
|
-
):
|
101
|
-
super().__init__(
|
102
|
-
models,
|
103
|
-
config,
|
104
|
-
preprocessors,
|
105
|
-
rbln_config,
|
106
|
-
rbln_device=rbln_device,
|
107
|
-
rbln_device_map=rbln_device_map,
|
108
|
-
**kwargs,
|
109
|
-
)
|
110
|
-
self.batch_size = self.rbln_config.meta["rbln_batch_size"]
|
111
|
-
self.prefill_chunk_size = self.rbln_config.meta["rbln_prefill_chunk_size"]
|
112
|
-
self.max_seq_len = self.rbln_config.meta["rbln_max_seq_len"]
|
113
|
-
|
114
|
-
self.prefill_attention_mask = torch.zeros(
|
115
|
-
self.batch_size, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.int64
|
116
|
-
)
|
117
|
-
self.causal_mask = 1 - torch.triu(
|
118
|
-
torch.ones(self.batch_size, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
|
119
|
-
)
|
120
|
-
|
121
|
-
self.prefill_decoder = RBLNRuntimeDecoder(runtime=self.model[0], main_input_name="input_ids")
|
122
|
-
self.decoder = RBLNRuntimeDecoder(runtime=self.model[1], main_input_name="input_ids")
|
123
|
-
self.past_cached_length = 0
|
124
|
-
|
125
|
-
def can_generate(self):
|
126
|
-
return True
|
62
|
+
@classmethod
|
63
|
+
def wrapping_torch_model(self, model: "PreTrainedModel", rbln_max_seq_len: int):
|
64
|
+
return MidmLMHeadModelWrapper(model, rbln_max_seq_len).eval()
|
127
65
|
|
128
66
|
def __getattr__(self, __name: str) -> Any:
|
129
67
|
"""This is the key method to implement RBLN-Midm.
|
@@ -140,142 +78,46 @@ class RBLNMidmLMHeadModel(RBLNModel, RBLNGenerationMixin):
|
|
140
78
|
return redirect(val)
|
141
79
|
return val
|
142
80
|
|
143
|
-
def _reorder_cache(self, past_key_values, beam_idx):
|
144
|
-
# TODO(jongho): implement
|
145
|
-
raise NotImplementedError
|
146
|
-
|
147
|
-
@classmethod
|
148
|
-
@torch.inference_mode()
|
149
|
-
def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
|
150
|
-
wrapped_decoder = MidmLMHeadModelWrapper(model).eval()
|
151
|
-
prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
|
152
|
-
dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
|
153
|
-
|
154
|
-
prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
|
155
|
-
dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=0)
|
156
|
-
|
157
|
-
prefill_scripted_model = torch.jit.trace(wrapped_decoder, prefill_example_inputs, check_trace=False)
|
158
|
-
dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs, check_trace=False)
|
159
|
-
|
160
|
-
prefill_ir = rebel.torchscript_to_ir(
|
161
|
-
prefill_scripted_model,
|
162
|
-
input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
|
163
|
-
)
|
164
|
-
dec_ir = rebel.torchscript_to_ir(
|
165
|
-
dec_scripted_model,
|
166
|
-
input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
|
167
|
-
)
|
168
|
-
|
169
|
-
connections = [(prefill_ir.outputs[1 + i], prefill_ir.inputs[3 + i]) for i in range(model.config.n_layer * 2)]
|
170
|
-
|
171
|
-
compiled_model = rebel.compile(
|
172
|
-
prefill_ir,
|
173
|
-
dec_ir,
|
174
|
-
connections=connections,
|
175
|
-
fusion=prefill_rbln_runtime_config.fusion,
|
176
|
-
npu=prefill_rbln_runtime_config.npu,
|
177
|
-
tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
|
178
|
-
use_weight_sharing=True,
|
179
|
-
)
|
180
|
-
return compiled_model
|
181
|
-
|
182
|
-
@classmethod
|
183
|
-
def update_kwargs(cls, kwargs):
|
184
|
-
"""
|
185
|
-
Update user-given kwargs to get proper pytorch model.
|
186
|
-
|
187
|
-
For example, `torchscript`=True should be set because torch.jit
|
188
|
-
does not support `transformers` output instances as module output;
|
189
|
-
"""
|
190
|
-
kwargs.update(
|
191
|
-
{
|
192
|
-
"torchscript": True,
|
193
|
-
"return_dict": False,
|
194
|
-
"use_cache": True,
|
195
|
-
"torch_dtype": torch.float32,
|
196
|
-
"_attn_implementation": "eager",
|
197
|
-
}
|
198
|
-
)
|
199
|
-
return kwargs
|
200
|
-
|
201
|
-
@classmethod
|
202
|
-
def get_pytorch_model(
|
203
|
-
cls,
|
204
|
-
model_id: str,
|
205
|
-
use_auth_token: Optional[Union[bool, str]] = None,
|
206
|
-
revision: Optional[str] = None,
|
207
|
-
force_download: bool = False,
|
208
|
-
cache_dir: Optional[str] = None,
|
209
|
-
subfolder: str = "",
|
210
|
-
local_files_only: bool = False,
|
211
|
-
trust_remote_code: bool = False,
|
212
|
-
rbln_config_kwargs: Optional[Dict[str, Any]] = None,
|
213
|
-
rbln_constructor_kwargs: Optional[Dict[str, Any]] = None,
|
214
|
-
**kwargs,
|
215
|
-
) -> PreTrainedModel:
|
216
|
-
if rbln_max_seq_len := rbln_config_kwargs.get("rbln_max_seq_len", None):
|
217
|
-
config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
|
218
|
-
if hf_position_embedding := getattr(config, "max_position_embeddings", None):
|
219
|
-
if hf_position_embedding < rbln_max_seq_len:
|
220
|
-
logger.warning(
|
221
|
-
f"`rbln_max_seq_len` is larger than original config({hf_position_embedding})."
|
222
|
-
"This may lead to incorrect inferences of the model."
|
223
|
-
)
|
224
|
-
kwargs.update({"max_position_embeddings": rbln_max_seq_len})
|
225
|
-
|
226
|
-
return super().get_pytorch_model(
|
227
|
-
model_id=model_id,
|
228
|
-
use_auth_token=use_auth_token,
|
229
|
-
revision=revision,
|
230
|
-
force_download=force_download,
|
231
|
-
cache_dir=cache_dir,
|
232
|
-
subfolder=subfolder,
|
233
|
-
local_files_only=local_files_only,
|
234
|
-
trust_remote_code=trust_remote_code,
|
235
|
-
rbln_config_kwargs=rbln_config_kwargs,
|
236
|
-
rbln_constructor_kwargs=rbln_constructor_kwargs,
|
237
|
-
ignore_mismatched_sizes=True,
|
238
|
-
**kwargs,
|
239
|
-
)
|
240
|
-
|
241
81
|
@classmethod
|
242
82
|
def _get_rbln_config(
|
243
83
|
cls,
|
244
84
|
preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
|
245
85
|
model_config: "PretrainedConfig",
|
246
|
-
rbln_prefill_chunk_size: Optional[int] = 128,
|
247
86
|
rbln_max_seq_len: Optional[int] = None,
|
248
87
|
rbln_batch_size: Optional[int] = None,
|
88
|
+
**kwargs,
|
249
89
|
) -> RBLNConfig:
|
250
90
|
meta = {}
|
251
|
-
if rbln_max_seq_len is None:
|
252
|
-
rbln_max_seq_len = getattr(model_config, "max_position_embeddings", None)
|
253
91
|
|
92
|
+
prefill_chunk_size = 128
|
254
93
|
if rbln_max_seq_len is None:
|
255
|
-
|
256
|
-
|
257
|
-
rbln_max_seq_len = tokenizer.model_max_length
|
258
|
-
break
|
259
|
-
if rbln_max_seq_len is None:
|
260
|
-
raise ValueError("`rbln_max_seq_len` should be specified!")
|
94
|
+
rbln_max_seq_len = getattr(model_config, "n_positions", None)
|
95
|
+
rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
|
261
96
|
|
262
|
-
if rbln_batch_size is None:
|
263
|
-
rbln_batch_size = 1
|
264
|
-
|
265
|
-
meta["rbln_prefill_chunk_size"] = rbln_prefill_chunk_size
|
266
97
|
meta["rbln_max_seq_len"] = rbln_max_seq_len
|
267
|
-
meta["rbln_batch_size"] = rbln_batch_size
|
268
|
-
|
269
|
-
|
98
|
+
meta["rbln_batch_size"] = rbln_batch_size
|
99
|
+
meta["rbln_prefill_chunk_size"] = prefill_chunk_size
|
100
|
+
|
101
|
+
def get_input_info(
|
102
|
+
batch_size,
|
103
|
+
query_length,
|
104
|
+
):
|
105
|
+
head_dim = (
|
106
|
+
model_config.head_dim
|
107
|
+
if hasattr(model_config, "head_dim")
|
108
|
+
else model_config.hidden_size // model_config.n_head
|
109
|
+
)
|
270
110
|
input_info = [
|
271
|
-
("input_ids", [
|
272
|
-
("attention_mask", [
|
111
|
+
("input_ids", [batch_size, query_length], "int64"),
|
112
|
+
("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "int64"),
|
273
113
|
(
|
274
114
|
"cache_position",
|
275
|
-
[],
|
115
|
+
[batch_size, query_length],
|
276
116
|
"int32",
|
277
117
|
),
|
118
|
+
("batch_position", [], "int16"),
|
278
119
|
]
|
120
|
+
|
279
121
|
input_info.extend(
|
280
122
|
[
|
281
123
|
(
|
@@ -284,18 +126,24 @@ class RBLNMidmLMHeadModel(RBLNModel, RBLNGenerationMixin):
|
|
284
126
|
rbln_batch_size,
|
285
127
|
model_config.n_head,
|
286
128
|
rbln_max_seq_len,
|
287
|
-
|
129
|
+
head_dim,
|
288
130
|
],
|
289
131
|
"float32",
|
290
132
|
)
|
291
133
|
for i in range(model_config.n_layer * 2)
|
292
134
|
]
|
293
135
|
)
|
136
|
+
|
294
137
|
return input_info
|
295
138
|
|
296
|
-
|
297
|
-
|
298
|
-
|
139
|
+
prefill_input_info = get_input_info(
|
140
|
+
batch_size=1,
|
141
|
+
query_length=prefill_chunk_size,
|
142
|
+
)
|
143
|
+
dec_input_info = get_input_info(
|
144
|
+
batch_size=rbln_batch_size,
|
145
|
+
query_length=1,
|
146
|
+
)
|
299
147
|
|
300
148
|
prefill_rbln_runtime_config = RBLNRuntimeConfig(input_info=prefill_input_info)
|
301
149
|
dec_rbln_runtime_config = RBLNRuntimeConfig(input_info=dec_input_info)
|
@@ -308,83 +156,3 @@ class RBLNMidmLMHeadModel(RBLNModel, RBLNGenerationMixin):
|
|
308
156
|
)
|
309
157
|
|
310
158
|
return rbln_config
|
311
|
-
|
312
|
-
@classmethod
|
313
|
-
def _create_runtimes(
|
314
|
-
cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
|
315
|
-
) -> List[rebel.Runtime]:
|
316
|
-
device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
|
317
|
-
return [
|
318
|
-
compiled_models[0].create_runtime(input_info_index=0, tensor_type="pt", device=device_val),
|
319
|
-
compiled_models[0].create_runtime(input_info_index=1, tensor_type="pt", device=device_val),
|
320
|
-
]
|
321
|
-
|
322
|
-
def prepare_inputs_for_generation(self, input_ids, past_key_values=0, attention_mask=None, **kwargs):
|
323
|
-
batch_size, cur_len = input_ids.shape
|
324
|
-
past_cached_length = past_key_values
|
325
|
-
|
326
|
-
if past_cached_length == 0:
|
327
|
-
mod_len = cur_len % self.prefill_chunk_size
|
328
|
-
self.pad_len = self.prefill_chunk_size - mod_len if mod_len > 0 else 0
|
329
|
-
|
330
|
-
prompt_attn_mask = torch.nn.functional.pad(attention_mask, (self.pad_len, 0), value=0)
|
331
|
-
self.prompt_attn_mask = prompt_attn_mask.reshape(batch_size, 1, 1, -1).contiguous()
|
332
|
-
|
333
|
-
input_ids = torch.nn.functional.pad(input_ids, (self.pad_len, 0), value=0)
|
334
|
-
attention_mask = self.prefill_attention_mask.clone()
|
335
|
-
cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
|
336
|
-
|
337
|
-
query_length = cur_len + self.pad_len
|
338
|
-
else:
|
339
|
-
attention_mask = torch.nn.functional.pad(
|
340
|
-
attention_mask, (self.pad_len, self.max_seq_len - cur_len - self.pad_len)
|
341
|
-
)
|
342
|
-
attention_mask = attention_mask.reshape(batch_size, 1, 1, -1).contiguous()
|
343
|
-
cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
|
344
|
-
input_ids = input_ids[:, -1:].contiguous()
|
345
|
-
query_length = 1
|
346
|
-
|
347
|
-
model_inputs = {
|
348
|
-
"input_ids": input_ids,
|
349
|
-
"past_key_values": past_cached_length,
|
350
|
-
"attention_mask": attention_mask,
|
351
|
-
"cache_position": cache_position,
|
352
|
-
"query_length": query_length,
|
353
|
-
}
|
354
|
-
|
355
|
-
return model_inputs
|
356
|
-
|
357
|
-
def forward(
|
358
|
-
self,
|
359
|
-
input_ids: Optional[torch.LongTensor] = None,
|
360
|
-
past_key_values: int = None,
|
361
|
-
attention_mask: Optional[torch.FloatTensor] = None,
|
362
|
-
cache_position: Optional[torch.Tensor] = None,
|
363
|
-
query_length: Optional[torch.Tensor] = None,
|
364
|
-
**kwargs,
|
365
|
-
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
366
|
-
past_cached_length = past_key_values
|
367
|
-
|
368
|
-
if past_cached_length is not None:
|
369
|
-
past_cached_length += query_length
|
370
|
-
|
371
|
-
if cache_position == 0:
|
372
|
-
for step in range(0, query_length, self.prefill_chunk_size):
|
373
|
-
sliced_input_ids = input_ids[:, step : step + self.prefill_chunk_size]
|
374
|
-
attention_mask[:, :, :, :step] = 1
|
375
|
-
attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
|
376
|
-
attention_mask[:, :, :, :query_length] *= self.prompt_attn_mask
|
377
|
-
|
378
|
-
output = self.prefill_decoder(
|
379
|
-
input_ids=sliced_input_ids.contiguous(),
|
380
|
-
attention_mask=attention_mask,
|
381
|
-
cache_position=cache_position + step,
|
382
|
-
)
|
383
|
-
cache_position += self.prefill_chunk_size
|
384
|
-
else:
|
385
|
-
output = self.decoder(
|
386
|
-
input_ids=input_ids.contiguous(),
|
387
|
-
attention_mask=attention_mask,
|
388
|
-
cache_position=cache_position,
|
389
|
-
)
|
390
|
-
return CausalLMOutputWithCrossAttentions(logits=output, past_key_values=past_cached_length)
|
@@ -57,7 +57,6 @@ class _WhisperAttention(WhisperAttention):
|
|
57
57
|
cache_position: Optional[torch.Tensor] = None,
|
58
58
|
**kwargs,
|
59
59
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
60
|
-
|
61
60
|
bsz, tgt_len, _ = hidden_states.size()
|
62
61
|
is_cross_attention = key_value_states is not None
|
63
62
|
|
@@ -123,7 +122,6 @@ class _WhisperSdpaAttention(WhisperSdpaAttention):
|
|
123
122
|
cache_position: Optional[torch.Tensor] = None,
|
124
123
|
**kwargs,
|
125
124
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
126
|
-
|
127
125
|
bsz, tgt_len, _ = hidden_states.size()
|
128
126
|
|
129
127
|
is_cross_attention = key_value_states is not None
|
@@ -189,7 +187,6 @@ class _WhisperDecoderLayer(WhisperDecoderLayer):
|
|
189
187
|
cache_position: Optional[torch.Tensor] = None,
|
190
188
|
attn_impl: str = "eager",
|
191
189
|
) -> torch.Tensor:
|
192
|
-
|
193
190
|
# Self Attention Block
|
194
191
|
residual = hidden_states
|
195
192
|
hidden_states = self.self_attn_layer_norm(hidden_states)
|
@@ -248,7 +245,6 @@ class _WhisperDecoder(WhisperDecoder):
|
|
248
245
|
attn_impl: str = "eager",
|
249
246
|
**kwargs,
|
250
247
|
):
|
251
|
-
|
252
248
|
input_shape = input_ids.size()
|
253
249
|
input_ids = input_ids.view(-1, input_shape[-1])
|
254
250
|
|
@@ -312,7 +308,6 @@ class _WhisperDecoderWrapper(torch.nn.Module):
|
|
312
308
|
self_kv_cache: torch.Tensor,
|
313
309
|
cross_kv_cache: torch.Tensor,
|
314
310
|
) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
|
315
|
-
|
316
311
|
# prepare past_key_values
|
317
312
|
kv_cache = ()
|
318
313
|
for i in range(0, self.num_layers * 2, 2):
|
@@ -367,7 +362,6 @@ class _WhisperEncoderWrapper(torch.nn.Module):
|
|
367
362
|
self,
|
368
363
|
input_features: Optional[torch.LongTensor] = None,
|
369
364
|
) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
|
370
|
-
|
371
365
|
encoder_outputs = self.encoder(input_features=input_features)
|
372
366
|
last_hidden_states = encoder_outputs[0]
|
373
367
|
|
@@ -0,0 +1,24 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
from .modeling_xlm_roberta import RBLNXLMRobertaModel
|
@@ -0,0 +1,125 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
import logging
|
25
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
26
|
+
|
27
|
+
import torch
|
28
|
+
from transformers import AutoModel, PretrainedConfig, PreTrainedModel, XLMRobertaConfig, XLMRobertaModel
|
29
|
+
|
30
|
+
from ....modeling_base import RBLNModel
|
31
|
+
from ....modeling_config import RBLNConfig, RBLNRuntimeConfig
|
32
|
+
|
33
|
+
|
34
|
+
logger = logging.getLogger(__name__)
|
35
|
+
|
36
|
+
if TYPE_CHECKING:
|
37
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
|
38
|
+
|
39
|
+
class RBLNXLMRobertaModel(RBLNModel):
    """RBLN wrapper around a Hugging Face ``XLMRobertaModel`` used for feature extraction.

    The model is compiled with static shapes: every input listed in
    ``rbln_model_input_names`` is an ``int64`` tensor of shape
    ``[rbln_batch_size, rbln_max_seq_len]``.
    """

    auto_model_class = AutoModel  # feature extraction
    original_model_class = XLMRobertaModel
    original_config_class = XLMRobertaConfig

    @classmethod
    def get_pytorch_model(
        cls,
        model_id: str,
        use_auth_token: Optional[Union[bool, str]] = None,
        revision: Optional[str] = None,
        force_download: bool = False,
        cache_dir: Optional[str] = None,
        subfolder: str = "",
        local_files_only: bool = False,
        trust_remote_code: bool = False,
        rbln_config_kwargs: Optional[Dict[str, Any]] = None,
        rbln_constructor_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> "PreTrainedModel":
        """Load the original PyTorch model through the base-class loader.

        Extra keyword arguments are forwarded to the base loader instead of being
        silently discarded. ``library_name`` defaults to ``"transformers"`` unless
        the caller supplies one. Using ``setdefault`` avoids a duplicate-keyword error.

        Returns:
            The loaded ``PreTrainedModel``.
        """
        kwargs.setdefault("library_name", "transformers")
        model: "PreTrainedModel" = super().get_pytorch_model(
            model_id=model_id,
            use_auth_token=use_auth_token,
            revision=revision,
            force_download=force_download,
            cache_dir=cache_dir,
            subfolder=subfolder,
            local_files_only=local_files_only,
            trust_remote_code=trust_remote_code,
            rbln_config_kwargs=rbln_config_kwargs,
            rbln_constructor_kwargs=rbln_constructor_kwargs,
            **kwargs,
        )

        return model

    @staticmethod
    def _get_usable_max_positions(model_config: Optional["PretrainedConfig"]) -> Optional[int]:
        """Return the longest sequence the position-embedding table can index, or None.

        RoBERTa-style embeddings derive position ids as
        ``cumsum(input_ids != pad) + pad_token_id``. A sequence of length L therefore
        uses indices up to ``L + pad_token_id``, so the usable length is
        ``max_position_embeddings - (pad_token_id + 1)``. For XLM-R this is 514 -> 512.
        """
        max_positions = getattr(model_config, "n_positions", None) or getattr(
            model_config, "max_position_embeddings", None
        )
        pad_token_id = getattr(model_config, "pad_token_id", None)
        if max_positions is not None and pad_token_id is not None:
            max_positions -= pad_token_id + 1
        return max_positions

    @classmethod
    def _get_rbln_config(
        cls,
        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
        model_config: Optional["PretrainedConfig"] = None,
        rbln_max_seq_len: Optional[int] = None,
        rbln_model_input_names: Optional[List[str]] = None,
        rbln_batch_size: Optional[int] = None,
    ) -> RBLNConfig:
        """Build the static-shape RBLN compile configuration.

        Args:
            preprocessors: Iterable of preprocessors. It is consulted for
                ``model_max_length`` only when the model config gives no length.
                May be None.
            model_config: Hugging Face config of the model.
            rbln_max_seq_len: Compiled sequence length. Defaults to the model's
                usable position count, then to a tokenizer's ``model_max_length``.
            rbln_model_input_names: Names of the compiled inputs. Defaults to
                ``["input_ids", "attention_mask", "token_type_ids"]``.
            rbln_batch_size: Compiled batch size. Defaults to 1.

        Raises:
            ValueError: if no sequence length can be determined, or if it exceeds
                what the position embeddings support.
        """
        max_position_embeddings = cls._get_usable_max_positions(model_config)

        if rbln_max_seq_len is None:
            rbln_max_seq_len = max_position_embeddings
        if rbln_max_seq_len is None:
            # `preprocessors` is Optional; guard so a None does not raise TypeError.
            for tokenizer in preprocessors or []:
                if hasattr(tokenizer, "model_max_length"):
                    rbln_max_seq_len = tokenizer.model_max_length
                    break
        if rbln_max_seq_len is None:
            raise ValueError("`rbln_max_seq_len` should be specified!")

        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
            raise ValueError(
                f"`rbln_max_seq_len` ({rbln_max_seq_len}) should be less than or equal to "
                f"the number of usable position embeddings ({max_position_embeddings})!"
            )

        if rbln_model_input_names is None:
            # Default to the BERT-style input triple.
            rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]

        if rbln_batch_size is None:
            rbln_batch_size = 1

        input_info = [
            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
            for model_input_name in rbln_model_input_names
        ]

        rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
        rbln_runtime_config.batch_size = rbln_batch_size

        meta = {"rbln_max_seq_len": rbln_max_seq_len}

        return RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)

    def forward(
        self,
        input_ids: "torch.Tensor",
        attention_mask: "torch.Tensor",
        token_type_ids: "torch.Tensor" = None,
        **kwargs,
    ):
        """Run the compiled model.

        ``token_type_ids`` defaults to an all-zero ``int64`` tensor shaped like
        ``input_ids``. Extra keyword arguments are accepted and ignored.
        """
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input=input_ids, dtype=torch.int64)
        # NOTE(review): all three tensors are always passed positionally. If the model
        # was compiled with a custom `rbln_model_input_names` that omits
        # token_type_ids, this presumably mismatches the runtime inputs — confirm.
        output = super().forward(input_ids, attention_mask, token_type_ids)
        return output
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: optimum-rbln
|
3
|
-
Version: 0.1.7
|
3
|
+
Version: 0.1.8
|
4
4
|
Summary: Optimum RBLN is the interface between the Hugging Face Transformers and Diffusers libraries and RBLN accelerators.
|
5
5
|
It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
|
6
6
|
Keywords: transformers,diffusers,inference,rbln,atom,rebel
|
@@ -21,7 +21,7 @@ Project-URL: Homepage, https://rebellions.ai
|
|
21
21
|
Project-URL: Documentation, https://docs.rbln.ai
|
22
22
|
Requires-Python: <3.11,>=3.8
|
23
23
|
Requires-Dist: torch<=2.2.1
|
24
|
-
Requires-Dist: optimum
|
24
|
+
Requires-Dist: optimum<=1.20.0
|
25
25
|
Requires-Dist: accelerate>=0.28.0
|
26
26
|
Requires-Dist: transformers<=4.40.2
|
27
27
|
Requires-Dist: diffusers<=0.29.2
|
@@ -35,7 +35,6 @@ Requires-Dist: sentencepiece>=0.2.0; extra == "tests"
|
|
35
35
|
Requires-Dist: datasets>=2.18.0; extra == "tests"
|
36
36
|
Requires-Dist: sacremoses>=0.1.1; extra == "tests"
|
37
37
|
Requires-Dist: safetensors>=0.4.2; extra == "tests"
|
38
|
-
Requires-Dist: black>=24.3.0; extra == "quality"
|
39
38
|
Requires-Dist: ruff>=0.3.3; extra == "quality"
|
40
39
|
Requires-Dist: isort>=5.13.2; extra == "quality"
|
41
40
|
Requires-Dist: hf-doc-builder>=0.5.0; extra == "quality"
|