optimum-rbln 0.1.4__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
- optimum/rbln/__init__.py +7 -0
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +16 -98
- optimum/rbln/diffusers/models/unet_2d_condition.py +1 -1
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +9 -11
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +8 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +8 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +8 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +9 -0
- optimum/rbln/modeling_base.py +172 -100
- optimum/rbln/modeling_seq2seq.py +58 -132
- optimum/rbln/transformers/__init__.py +2 -0
- optimum/rbln/transformers/models/__init__.py +1 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +0 -1
- optimum/rbln/transformers/models/dpt/__init__.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +89 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +24 -33
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +52 -124
- optimum/rbln/transformers/models/llama/llama_architecture.py +13 -16
- optimum/rbln/transformers/models/llama/llama_architecture_cb.py +41 -36
- optimum/rbln/transformers/models/llama/modeling_llama.py +94 -120
- optimum/rbln/transformers/models/midm/modeling_midm.py +85 -121
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +53 -123
- optimum/rbln/utils/__init__.py +1 -1
- optimum/rbln/utils/import_utils.py +46 -0
- {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.7.dist-info}/METADATA +17 -51
- {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.7.dist-info}/RECORD +31 -29
- {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.7.dist-info}/WHEEL +1 -1
- {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.7.dist-info}/licenses/LICENSE +0 -0
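
The recurring change visible in the hunks below is structural: `RBLNLlamaForCausalLM` and `RBLNMidmLMHeadModel` drop their per-model export plumbing (temporary save directories, `maybe_save_preprocessors`, manual `compiled_model.save`, `_from_pretrained` round-trips) and instead expose a small set of hooks to a shared `RBLNModel` base class: `update_kwargs`, `get_pytorch_model`, `get_compiled_model`, and `_create_runtimes`. The sketch below is illustrative only, assuming the base class drives those hooks roughly in this order; the hook names come from the diff, while everything else (the `load` driver, the toy return values, the `"compiled_model"` device-map key) is made up for the example.

```python
# Illustrative sketch only: the driver ("load") and toy return values are
# assumptions about how a shared base class could sequence the new hooks; only
# the hook names (update_kwargs, get_pytorch_model, get_compiled_model,
# _create_runtimes) are taken from this diff.
from typing import Any, Dict, List


class RBLNModelSketch:
    """Toy stand-in for the shared RBLNModel base class the subclasses now rely on."""

    @classmethod
    def load(cls, model_id: str, **kwargs) -> List[Any]:
        # Split rbln_* options from plain Hugging Face kwargs.
        rbln_kwargs = {k: v for k, v in kwargs.items() if k.startswith("rbln_")}
        hf_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("rbln_")}

        hf_kwargs = cls.update_kwargs(hf_kwargs)  # force torchscript-friendly settings
        model = cls.get_pytorch_model(model_id, rbln_config_kwargs=rbln_kwargs, **hf_kwargs)
        compiled = cls.get_compiled_model(model, rbln_config=rbln_kwargs)
        return cls._create_runtimes([compiled], rbln_device_map={"compiled_model": 0})

    # Hooks a backend (Llama, Midm, ...) overrides as needed.
    @classmethod
    def update_kwargs(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        kwargs.update({"torchscript": True, "return_dict": False})
        return kwargs

    @classmethod
    def get_pytorch_model(cls, model_id: str, **kwargs) -> Any:
        return ("pytorch-model", model_id)  # placeholder for the AutoModel loading step

    @classmethod
    def get_compiled_model(cls, model: Any, rbln_config: Dict[str, Any]) -> Any:
        return ("compiled", model)  # placeholder for torch.jit.trace + rebel.compile

    @classmethod
    def _create_runtimes(cls, compiled_models: List[Any], rbln_device_map: Dict[str, int]) -> List[Any]:
        return [("runtime", 0), ("runtime", 1)]  # prefill / token-by-token decode pair


print(RBLNModelSketch.load("some-llama-checkpoint", rbln_max_seq_len=4096, trust_remote_code=True))
```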
optimum/rbln/transformers/models/llama/modeling_llama.py

@@ -23,22 +23,18 @@
 
 import inspect  # noqa: I001
 import logging
-from pathlib import Path
-from tempfile import TemporaryDirectory
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch  # noqa: F401
 import rebel  # noqa: F401
 
-from
-from transformers import AutoModelForCausalLM, LlamaForCausalLM, PretrainedConfig, AutoConfig
+from transformers import AutoModelForCausalLM, LlamaForCausalLM, PreTrainedModel, PretrainedConfig, AutoConfig
 from transformers.modeling_outputs import CausalLMOutputWithPast
 
 from ...generation.utils import RBLNGenerationMixin
-from ....modeling_base import
+from ....modeling_base import RBLNModel
 from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
 from ....utils.runtime_utils import RBLNPytorchRuntime
-from ....utils.save_utils import maybe_save_preprocessors
 
 
 # FIXME:: Merge Two architecture Codes
@@ -72,10 +68,10 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
     mandatory_members = ["main_input_name"]
 
 
-class RBLNLlamaForCausalLM(
+class RBLNLlamaForCausalLM(RBLNModel, RBLNGenerationMixin):
     """
     The Llama Model transformer with a language modeling head (linear layer) on top.
-    This model inherits from [`
+    This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
 
     A class to convert and run pre-trained transformers based LlamaForCausalLM model on RBLN devices.
     It implements the methods to convert a pre-trained transformers LlamaForCausalLM model into a RBLN transformer model by:
@@ -83,7 +79,6 @@ class RBLNLlamaForCausalLM(RBLNBaseModel, RBLNGenerationMixin):
     - compiling the resulting graph using the RBLN compiler.
     """
 
-    model_type = "rbln_model"
     main_input_name = "input_ids"
     auto_model_class = AutoModelForCausalLM
 
@@ -102,17 +97,34 @@ class RBLNLlamaForCausalLM(RBLNBaseModel, RBLNGenerationMixin):
         )
         self.decoder_attention_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.int64)
 
-        self.prefill_decoder = RBLNRuntimeModel(runtime=self.
-        self.decoder = RBLNRuntimeModel(runtime=self.
+        self.prefill_decoder = RBLNRuntimeModel(runtime=self.model[0], main_input_name="input_ids")
+        self.decoder = RBLNRuntimeModel(runtime=self.model[1], main_input_name="input_ids")
         self.past_cached_length = 0
         self.right_padding = True
 
     @classmethod
-
-
+    def update_kwargs(cls, kwargs):
+        """
+        Update user-given kwargs to get proper pytorch model.
+
+        For example, `torchscript`=True should be set because torch.jit
+        does not support `transformers` output instances as module output;
+        """
+        kwargs.update(
+            {
+                "torchscript": True,
+                "return_dict": False,
+                "use_cache": True,
+                "torch_dtype": torch.float32,
+                "_attn_implementation": "eager",
+            }
+        )
+        return kwargs
+
+    @classmethod
+    def get_pytorch_model(
         cls,
         model_id: str,
-        config: "PretrainedConfig",
         use_auth_token: Optional[Union[bool, str]] = None,
         revision: Optional[str] = None,
         force_download: bool = False,
@@ -120,135 +132,94 @@ class RBLNLlamaForCausalLM(RBLNBaseModel, RBLNGenerationMixin):
         subfolder: str = "",
         local_files_only: bool = False,
         trust_remote_code: bool = False,
-
+        rbln_config_kwargs: Optional[Dict[str, Any]] = None,
+        rbln_constructor_kwargs: Optional[Dict[str, Any]] = None,
         **kwargs,
-    ) ->
-
-
-
-
-
-
-
-
-
-        if isinstance(save_dir, TemporaryDirectory):
-            save_dir_path = Path(model_save_dir.name)
-        else:
-            save_dir_path = Path(model_save_dir)
-            save_dir_path.mkdir(exist_ok=True)
-
-        def update_configs(kwargs):
-            hf_max_position_embeddings = getattr(AutoConfig.from_pretrained(model_id), "max_position_embeddings", None)
-            max_seq_len = kwargs.get("rbln_max_seq_len", None)
-            if max_seq_len is not None:
-                if max_seq_len <= hf_max_position_embeddings:
-                    kwargs.update({"max_position_embeddings": max_seq_len})
-                else:
-                    raise ValueError("`max_seq_len` should be less or equal than max_position_embeddings!")
-
-            kwargs.update(
-                {
-                    "torchscript": True,
-                    "return_dict": False,
-                    "use_cache": True,
-                    "torch_dtype": torch.float32,
-                    "_attn_implementation": "eager",
-                }
-            )
-
-            return kwargs
-
-        kwargs = update_configs(kwargs)
-
-        rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
+    ) -> PreTrainedModel:
+        if rbln_max_seq_len := rbln_config_kwargs.get("rbln_max_seq_len", None):
+            config = AutoConfig.from_pretrained(model_id)
+            if hf_position_embedding := getattr(config, "max_position_embeddings", None):
+                if hf_position_embedding < rbln_max_seq_len:
+                    logger.warning(
+                        f"`rbln_max_seq_len` is larger than original config({hf_position_embedding})."
+                        "This may lead to incorrect inferences of the model."
+                    )
+            kwargs.update({"max_position_embeddings": rbln_max_seq_len})
 
         # FIXME :: This should be moved when wrapping removed.
         use_continuous_batch = rbln_config_kwargs.get("rbln_batching", "static") == "vllm"
-
+        wrap_llama_cb() if use_continuous_batch else wrap_llama()
 
-        model
-
-
-            subfolder=subfolder,
+        model = super().get_pytorch_model(
+            model_id=model_id,
+            use_auth_token=use_auth_token,
             revision=revision,
-
+            force_download=force_download,
             cache_dir=cache_dir,
-
+            subfolder=subfolder,
             local_files_only=local_files_only,
-            force_download=force_download,
             trust_remote_code=trust_remote_code,
+            rbln_config_kwargs=rbln_config_kwargs,
+            rbln_constructor_kwargs=rbln_constructor_kwargs,
            **kwargs,
        )
 
-
-        config = model.config
+        unwrap_llama()
 
-
-        preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
+        return model
 
-
-
-
-
-        )
+    @classmethod
+    @torch.inference_mode()
+    def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
+        use_continuous_batch = rbln_config.meta["rbln_batching"] == "vllm"
 
-
-        wrapped_model = wrapper_cls(model).eval()
+        wrapper_cls = LlamaWrapper_cb if use_continuous_batch else LlamaWrapper
 
-
-        dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
+        wrapped_model = wrapper_cls(model).eval()
 
-
-
+        prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
+        dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
 
-
-
-        dec_example_inputs[batch_index_index].fill_(-1) # fill batch_position -1 to indicate it is decoder.
+        prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
+        dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=4)
 
-
-
+        if use_continuous_batch:
+            batch_index_index = 3
+            dec_example_inputs[batch_index_index].fill_(-1) # fill batch_position -1 to indicate it is decoder.
 
-
-            prefill_scripted_model,
-            input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
-        )
-        dec_ir = rebel.torchscript_to_ir(
-            dec_scripted_model,
-            input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
-        )
+        wrap_llama_cb() if use_continuous_batch else wrap_llama()
 
-
-
-        connections = [
-            (prefill_ir.outputs[1 + i], prefill_ir.inputs[cache_index_offset + i])
-            for i in range(model.config.num_hidden_layers * 2)
-        ]
+        prefill_scripted_model = torch.jit.trace(wrapped_model, prefill_example_inputs, check_trace=False)
+        dec_scripted_model = torch.jit.trace(wrapped_model, dec_example_inputs, check_trace=False)
 
-
-            prefill_ir,
-            dec_ir,
-            connections=connections,
-            fusion=prefill_rbln_runtime_config.fusion,
-            npu=prefill_rbln_runtime_config.npu,
-            tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
-            use_weight_sharing=True,
-        )
-        compiled_model.save(save_dir_path / f"{DEFAULT_COMPILED_MODEL_NAME}.rbln")
+        unwrap_llama()
 
-
-
-
+        prefill_ir = rebel.torchscript_to_ir(
+            prefill_scripted_model,
+            input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
+        )
+        dec_ir = rebel.torchscript_to_ir(
+            dec_scripted_model,
+            input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
+        )
 
-
+        # Caching prefill_decoder/decoder I/O
+        cache_index_offset = 4 if use_continuous_batch else 3
+        connections = [
+            (prefill_ir.outputs[1 + i], prefill_ir.inputs[cache_index_offset + i])
+            for i in range(model.config.num_hidden_layers * 2)
+        ]
 
-
-
-
-
-
-
+        compiled_model = rebel.compile(
+            prefill_ir,
+            dec_ir,
+            connections=connections,
+            fusion=prefill_rbln_runtime_config.fusion,
+            npu=prefill_rbln_runtime_config.npu,
+            tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
+            use_weight_sharing=True,
         )
+        return compiled_model
 
     @classmethod
     def _get_rbln_config(
@@ -338,11 +309,14 @@ class RBLNLlamaForCausalLM(RBLNBaseModel, RBLNGenerationMixin):
 
         return rbln_config
 
-
+    @classmethod
+    def _create_runtimes(
+        cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
+    ) -> List[rebel.Runtime]:
         device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
         return [
-
-
+            compiled_models[0].create_runtime(input_info_index=0, tensor_type="pt", device=device_val),
+            compiled_models[0].create_runtime(input_info_index=1, tensor_type="pt", device=device_val),
         ]
 
     def get_decoder(self):
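
The `connections` list built in `get_compiled_model` above pairs each KV-cache output of the prefill graph with the matching cache input, so the prefill and decode graphs share one cache buffer per layer (two entries per layer: one key cache and one value cache); output index 0 is presumably the logits, hence the `1 + i` offset. A standalone illustration of the index arithmetic, with a made-up layer count:

```python
# Standalone illustration of the cache-connection index arithmetic used above.
# The layer count is made up; the offset is 3 for static batching and 4 for
# continuous batching ("vllm"), which adds a batch_position input to the graph.
num_hidden_layers = 2
use_continuous_batch = False

cache_index_offset = 4 if use_continuous_batch else 3
connections = [
    (f"prefill.outputs[{1 + i}]", f"prefill.inputs[{cache_index_offset + i}]")
    for i in range(num_hidden_layers * 2)  # one key cache and one value cache per layer
]
for out_name, in_name in connections:
    print(out_name, "->", in_name)
# prefill.outputs[1] -> prefill.inputs[3]
# prefill.outputs[2] -> prefill.inputs[4]
# prefill.outputs[3] -> prefill.inputs[5]
# prefill.outputs[4] -> prefill.inputs[6]
```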
optimum/rbln/transformers/models/midm/modeling_midm.py

@@ -23,20 +23,16 @@
 
 import inspect
 import logging
-from pathlib import Path
-from tempfile import TemporaryDirectory
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
 import rebel
 import torch
-from
-from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
+from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
 
-from ....modeling_base import
+from ....modeling_base import RBLNModel
 from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
 from ....utils.runtime_utils import RBLNPytorchRuntime
-from ....utils.save_utils import maybe_save_preprocessors
 from ...generation.utils import RBLNGenerationMixin
 from .hf_hub_cached.modeling_midm import MidmLMHeadModel
 from .midm_architecture import (
@@ -74,7 +70,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         return logits
 
 
-class RBLNMidmLMHeadModel(
+class RBLNMidmLMHeadModel(RBLNModel, RBLNGenerationMixin):
     """
     The Midm Model transformer with a language modeling head on top (linear layer with weights tied to the input
     embeddings).
@@ -122,8 +118,8 @@ class RBLNMidmLMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
             torch.ones(self.batch_size, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
         )
 
-        self.prefill_decoder = RBLNRuntimeDecoder(runtime=self.
-        self.decoder = RBLNRuntimeDecoder(runtime=self.
+        self.prefill_decoder = RBLNRuntimeDecoder(runtime=self.model[0], main_input_name="input_ids")
+        self.decoder = RBLNRuntimeDecoder(runtime=self.model[1], main_input_name="input_ids")
         self.past_cached_length = 0
 
     def can_generate(self):
@@ -149,10 +145,63 @@ class RBLNMidmLMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
         raise NotImplementedError
 
     @classmethod
-
+    @torch.inference_mode()
+    def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
+        wrapped_decoder = MidmLMHeadModelWrapper(model).eval()
+        prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
+        dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
+
+        prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
+        dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=0)
+
+        prefill_scripted_model = torch.jit.trace(wrapped_decoder, prefill_example_inputs, check_trace=False)
+        dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs, check_trace=False)
+
+        prefill_ir = rebel.torchscript_to_ir(
+            prefill_scripted_model,
+            input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
+        )
+        dec_ir = rebel.torchscript_to_ir(
+            dec_scripted_model,
+            input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
+        )
+
+        connections = [(prefill_ir.outputs[1 + i], prefill_ir.inputs[3 + i]) for i in range(model.config.n_layer * 2)]
+
+        compiled_model = rebel.compile(
+            prefill_ir,
+            dec_ir,
+            connections=connections,
+            fusion=prefill_rbln_runtime_config.fusion,
+            npu=prefill_rbln_runtime_config.npu,
+            tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
+            use_weight_sharing=True,
+        )
+        return compiled_model
+
+    @classmethod
+    def update_kwargs(cls, kwargs):
+        """
+        Update user-given kwargs to get proper pytorch model.
+
+        For example, `torchscript`=True should be set because torch.jit
+        does not support `transformers` output instances as module output;
+        """
+        kwargs.update(
+            {
+                "torchscript": True,
+                "return_dict": False,
+                "use_cache": True,
+                "torch_dtype": torch.float32,
+                "_attn_implementation": "eager",
+            }
+        )
+        return kwargs
+
+    @classmethod
+    def get_pytorch_model(
         cls,
         model_id: str,
-        config: "PretrainedConfig",
         use_auth_token: Optional[Union[bool, str]] = None,
         revision: Optional[str] = None,
         force_download: bool = False,
@@ -160,120 +209,35 @@ class RBLNMidmLMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
         subfolder: str = "",
         local_files_only: bool = False,
         trust_remote_code: bool = False,
-
+        rbln_config_kwargs: Optional[Dict[str, Any]] = None,
+        rbln_constructor_kwargs: Optional[Dict[str, Any]] = None,
         **kwargs,
-    ) ->
-
-
-
-
-
-
-
-
-
-        save_dir = model_save_dir
-        if isinstance(save_dir, TemporaryDirectory):
-            save_dir_path = Path(model_save_dir.name)
-        else:
-            save_dir_path = Path(model_save_dir)
-            save_dir_path.mkdir(exist_ok=True)
-
-        def update_configs(kwargs):
-            max_seq_len = kwargs.get("rbln_max_seq_len", None)
-            if max_seq_len is not None:
-                kwargs.update({"max_position_embeddings": max_seq_len})
-
-            kwargs.update(
-                {
-                    "torchscript": True,
-                    "return_dict": False,
-                    "use_cache": True,
-                    "torch_dtype": torch.float32,
-                    "_attn_implementation": "eager",
-                }
-            )
-
-            return kwargs
-
-        kwargs = update_configs(kwargs)
-
-        rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
+    ) -> PreTrainedModel:
+        if rbln_max_seq_len := rbln_config_kwargs.get("rbln_max_seq_len", None):
+            config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
+            if hf_position_embedding := getattr(config, "max_position_embeddings", None):
+                if hf_position_embedding < rbln_max_seq_len:
+                    logger.warning(
+                        f"`rbln_max_seq_len` is larger than original config({hf_position_embedding})."
+                        "This may lead to incorrect inferences of the model."
+                    )
+            kwargs.update({"max_position_embeddings": rbln_max_seq_len})
 
-
-
-
-            subfolder=subfolder,
+        return super().get_pytorch_model(
+            model_id=model_id,
+            use_auth_token=use_auth_token,
             revision=revision,
-
+            force_download=force_download,
             cache_dir=cache_dir,
-
+            subfolder=subfolder,
             local_files_only=local_files_only,
-            force_download=force_download,
             trust_remote_code=trust_remote_code,
+            rbln_config_kwargs=rbln_config_kwargs,
+            rbln_constructor_kwargs=rbln_constructor_kwargs,
             ignore_mismatched_sizes=True,
             **kwargs,
         )
 
-        if config is None:
-            config = model.config
-
-        config.save_pretrained(save_dir_path)
-        preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
-
-        # Get compilation arguments
-        if rbln_config_kwargs.get("rbln_config", None) is None:
-            rbln_config = cls.get_rbln_config(
-                preprocessors=preprocessors, model_config=model.config, **rbln_config_kwargs
-            )
-
-        def compile_midm():
-            wrapped_decoder = MidmLMHeadModelWrapper(model).eval()
-            prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
-            dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
-
-            prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
-            dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=0)
-
-            prefill_scripted_model = torch.jit.trace(wrapped_decoder, prefill_example_inputs)
-            dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs)
-
-            prefill_ir = rebel.torchscript_to_ir(
-                prefill_scripted_model,
-                input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
-            )
-            dec_ir = rebel.torchscript_to_ir(
-                dec_scripted_model,
-                input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
-            )
-
-            connections = [
-                (prefill_ir.outputs[1 + i], prefill_ir.inputs[3 + i]) for i in range(model.config.n_layer * 2)
-            ]
-
-            compiled_model = rebel.compile(
-                prefill_ir,
-                dec_ir,
-                connections=connections,
-                fusion=prefill_rbln_runtime_config.fusion,
-                npu=prefill_rbln_runtime_config.npu,
-                tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
-                use_weight_sharing=True,
-            )
-            compiled_model.save(save_dir_path / f"{DEFAULT_COMPILED_MODEL_NAME}.rbln")
-
-        compile_midm()
-
-        rbln_config.save(save_dir_path)
-
-        return cls._from_pretrained(
-            model_id=save_dir_path,
-            config=config,
-            model_save_dir=save_dir,
-            **rbln_constructor_kwargs,
-            **kwargs,
-        )
-
     @classmethod
     def _get_rbln_config(
         cls,
@@ -345,11 +309,14 @@ class RBLNMidmLMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
 
         return rbln_config
 
-
+    @classmethod
+    def _create_runtimes(
+        cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
+    ) -> List[rebel.Runtime]:
         device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
         return [
-
-
+            compiled_models[0].create_runtime(input_info_index=0, tensor_type="pt", device=device_val),
+            compiled_models[0].create_runtime(input_info_index=1, tensor_type="pt", device=device_val),
         ]
 
     def prepare_inputs_for_generation(self, input_ids, past_key_values=0, attention_mask=None, **kwargs):
@@ -421,6 +388,3 @@ class RBLNMidmLMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
             cache_position=cache_position,
         )
         return CausalLMOutputWithCrossAttentions(logits=output, past_key_values=past_cached_length)
-
-    def __repr__(self):
-        return repr(self.runtimes[0]) + "\n" + repr(self.runtimes[1])
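
Both backends also gain the same `_create_runtimes` hook: a single compiled artifact exposes two entry points (`input_info_index=0` for the prefill pass, `input_info_index=1` for token-by-token decoding), and `__init__` picks them up as `self.model[0]` / `self.model[1]` when building the runtime wrappers. A toy stand-in for that wiring, with a fake object in place of `rebel.RBLNCompiledModel`:

```python
# Toy stand-in for the two-runtimes-from-one-compiled-model pattern in the diff.
# FakeCompiledModel only mimics the create_runtime call shown above; it is not
# the real rebel API object.
from typing import Dict, List


class FakeCompiledModel:
    def create_runtime(self, input_info_index: int, tensor_type: str = "pt", device: int = 0) -> str:
        return f"<runtime entry={input_info_index} device={device}>"


def create_runtimes(compiled_models: List[FakeCompiledModel], rbln_device_map: Dict[str, int]) -> List[str]:
    device_val = rbln_device_map["compiled_model"]  # stand-in for DEFAULT_COMPILED_MODEL_NAME
    return [
        compiled_models[0].create_runtime(input_info_index=0, tensor_type="pt", device=device_val),
        compiled_models[0].create_runtime(input_info_index=1, tensor_type="pt", device=device_val),
    ]


runtimes = create_runtimes([FakeCompiledModel()], {"compiled_model": 0})
prefill_decoder, decoder = runtimes[0], runtimes[1]  # cf. self.model[0] / self.model[1]
print(prefill_decoder, decoder)
```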