optimum-rbln 0.1.4__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
- optimum/rbln/__init__.py +21 -0
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/__init__.py +0 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +16 -98
- optimum/rbln/diffusers/models/controlnet.py +3 -0
- optimum/rbln/diffusers/models/unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +22 -146
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +109 -53
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +114 -53
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +8 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +8 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +9 -0
- optimum/rbln/modeling_alias.py +14 -0
- optimum/rbln/modeling_base.py +282 -100
- optimum/rbln/modeling_seq2seq.py +58 -132
- optimum/rbln/transformers/__init__.py +8 -0
- optimum/rbln/transformers/cache_utils.py +111 -0
- optimum/rbln/transformers/generation/utils.py +0 -2
- optimum/rbln/transformers/models/__init__.py +3 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
- optimum/rbln/transformers/models/clip/modeling_clip.py +0 -1
- optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +515 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +349 -0
- optimum/rbln/transformers/models/dpt/__init__.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +89 -0
- optimum/rbln/transformers/models/gemma/__init__.py +24 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +116 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +61 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +200 -174
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +57 -293
- optimum/rbln/transformers/models/llama/llama_architecture.py +3 -613
- optimum/rbln/transformers/models/llama/modeling_llama.py +9 -469
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
- optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
- optimum/rbln/transformers/models/midm/modeling_midm.py +40 -308
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +53 -123
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +125 -0
- optimum/rbln/utils/__init__.py +1 -1
- optimum/rbln/utils/import_utils.py +46 -0
- {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/METADATA +18 -53
- optimum_rbln-0.1.8.dist-info/RECORD +73 -0
- {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -759
- optimum_rbln-0.1.4.dist-info/RECORD +0 -63
- {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/licenses/LICENSE +0 -0
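The most significant structural change in this release is the new `transformers/models/decoderonly` package: model-specific classes such as `RBLNLlamaForCausalLM` now inherit their export, compilation, and generation plumbing from `RBLNDecoderOnlyModelForCausalLM` and only supply a model wrapper, as the modeling_llama.py diff below shows. For orientation, here is a minimal usage sketch; the `export=True` flag and the exact `rbln_*` keyword names follow the usual optimum conventions and the arguments visible in the removed 0.1.4 code, so treat them as assumptions rather than a verified 0.1.8 API.

```python
# Hedged sketch: exporting and running a Llama checkpoint with optimum-rbln.
# `export=True` and the rbln_* kwargs are assumptions based on the optimum
# convention and the rbln_max_seq_len / rbln_batch_size arguments visible in
# the removed 0.1.4 code below; check the 0.1.8 docs before relying on them.
from optimum.rbln import RBLNLlamaForCausalLM
from transformers import AutoTokenizer

model_id = "meta-llama/Llama-2-7b-chat-hf"  # hypothetical example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Compile the Hugging Face checkpoint into an RBLN artifact (prefill + decode graphs).
model = RBLNLlamaForCausalLM.from_pretrained(
    model_id,
    export=True,            # trace and compile instead of loading a precompiled model
    rbln_max_seq_len=4096,  # static KV-cache length baked into the compiled graphs
    rbln_batch_size=1,
)

inputs = tokenizer("Hello, RBLN!", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

In the 0.1.4 code removed below, the prefill and decode graphs are compiled together and share statically allocated KV-cache buffers, which is why the sequence length and batch size are fixed at export time.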
optimum/rbln/transformers/models/llama/modeling_llama.py

@@ -21,61 +21,23 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-import inspect
+import inspect
 import logging
-from
-from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable
 
-import
-import rebel  # noqa: F401
+from transformers import LlamaForCausalLM, PreTrainedModel
 
-from
-from
-from transformers.modeling_outputs import CausalLMOutputWithPast
-
-from ...generation.utils import RBLNGenerationMixin
-from ....modeling_base import RBLNBaseModel
-from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
-from ....utils.runtime_utils import RBLNPytorchRuntime
-from ....utils.save_utils import maybe_save_preprocessors
-
-
-# FIXME:: Merge Two architecture Codes
-from .llama_architecture import (
-    LlamaWrapper,
-    wrap_llama,
-    unwrap_llama,
-)
-
-from .llama_architecture_cb import (
-    LlamaDynamicBatchWrapper as LlamaWrapper_cb,
-    wrap_llama as wrap_llama_cb,
-)
+from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
+from .llama_architecture import LlamaWrapper
 
 
 logger = logging.getLogger(__name__)
 
-if TYPE_CHECKING:
-    from transformers import (
-        AutoFeatureExtractor,
-        AutoProcessor,
-        AutoTokenizer,
-        PretrainedConfig,
-    )
-
 
-
-
-
-class RBLNRuntimeModel(RBLNPytorchRuntime):
-    mandatory_members = ["main_input_name"]
-
-
-class RBLNLlamaForCausalLM(RBLNBaseModel, RBLNGenerationMixin):
+class RBLNLlamaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """
     The Llama Model transformer with a language modeling head (linear layer) on top.
-    This model inherits from [`
+    This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
 
     A class to convert and run pre-trained transformers based LlamaForCausalLM model on RBLN devices.
     It implements the methods to convert a pre-trained transformers LlamaForCausalLM model into a RBLN transformer model by:
@@ -83,273 +45,9 @@ class RBLNLlamaForCausalLM(RBLNBaseModel, RBLNGenerationMixin):
     - compiling the resulting graph using the RBLN compiler.
     """
 
-    model_type = "rbln_model"
-    main_input_name = "input_ids"
-    auto_model_class = AutoModelForCausalLM
-
-    def __post_init__(self, **kwargs):
-        self.batch_size = self.rbln_config.meta["rbln_batch_size"]
-        self.max_seq_len = self.rbln_config.meta["rbln_max_seq_len"]
-        self.prefill_chunk_size = self.rbln_config.meta["rbln_prefill_chunk_size"]
-        self.use_continuous_batch = self.rbln_config.meta["rbln_batching"] == "vllm"
-
-        prefill_batch_size = self.batch_size if not self.use_continuous_batch else 1
-        self.prefill_attention_mask = torch.zeros(
-            prefill_batch_size, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.int64
-        )
-        self.causal_mask = 1 - torch.triu(
-            torch.ones(prefill_batch_size, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
-        )
-        self.decoder_attention_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.int64)
-
-        self.prefill_decoder = RBLNRuntimeModel(runtime=self.runtimes[0], main_input_name="input_ids")
-        self.decoder = RBLNRuntimeModel(runtime=self.runtimes[1], main_input_name="input_ids")
-        self.past_cached_length = 0
-        self.right_padding = True
-
-    @classmethod
-    @torch.no_grad()
-    def _export(
-        cls,
-        model_id: str,
-        config: "PretrainedConfig",
-        use_auth_token: Optional[Union[bool, str]] = None,
-        revision: Optional[str] = None,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
-        subfolder: str = "",
-        local_files_only: bool = False,
-        trust_remote_code: bool = False,
-        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
-        **kwargs,
-    ) -> "RBLNLlamaForCausalLM":
-        task = kwargs.pop("task", None)
-        if task is None:
-            task = TasksManager.infer_task_from_model(cls.auto_model_class)
-
-        if model_save_dir is None:
-            save_dir = TemporaryDirectory()
-            save_dir_path = Path(save_dir.name)
-        else:
-            save_dir = model_save_dir
-            if isinstance(save_dir, TemporaryDirectory):
-                save_dir_path = Path(model_save_dir.name)
-            else:
-                save_dir_path = Path(model_save_dir)
-                save_dir_path.mkdir(exist_ok=True)
-
-        def update_configs(kwargs):
-            hf_max_position_embeddings = getattr(AutoConfig.from_pretrained(model_id), "max_position_embeddings", None)
-            max_seq_len = kwargs.get("rbln_max_seq_len", None)
-            if max_seq_len is not None:
-                if max_seq_len <= hf_max_position_embeddings:
-                    kwargs.update({"max_position_embeddings": max_seq_len})
-                else:
-                    raise ValueError("`max_seq_len` should be less or equal than max_position_embeddings!")
-
-            kwargs.update(
-                {
-                    "torchscript": True,
-                    "return_dict": False,
-                    "use_cache": True,
-                    "torch_dtype": torch.float32,
-                    "_attn_implementation": "eager",
-                }
-            )
-
-            return kwargs
-
-        kwargs = update_configs(kwargs)
-
-        rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
-
-        # FIXME :: This should be moved when wrapping removed.
-        use_continuous_batch = rbln_config_kwargs.get("rbln_batching", "static") == "vllm"
-        origin_mehtods = wrap_llama_cb() if use_continuous_batch else wrap_llama()
-
-        model: LlamaForCausalLM = TasksManager.get_model_from_task(
-            task=task,
-            model_name_or_path=model_id,
-            subfolder=subfolder,
-            revision=revision,
-            framework="pt",
-            cache_dir=cache_dir,
-            use_auth_token=use_auth_token,
-            local_files_only=local_files_only,
-            force_download=force_download,
-            trust_remote_code=trust_remote_code,
-            **kwargs,
-        )
-
-        if config is None:
-            config = model.config
-
-        config.save_pretrained(save_dir_path)
-        preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
-
-        # Get compilation arguments
-        if rbln_config_kwargs.get("rbln_config", None) is None:
-            rbln_config = cls.get_rbln_config(
-                preprocessors=preprocessors, model_config=model.config, **rbln_config_kwargs
-            )
-
-        def compile_llama(use_continuous_batch, wrapper_cls):
-            wrapped_model = wrapper_cls(model).eval()
-
-            prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
-            dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
-
-            prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
-            dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=4)
-
-            if use_continuous_batch:
-                batch_index_index = 3
-                dec_example_inputs[batch_index_index].fill_(-1)  # fill batch_position -1 to indicate it is decoder.
-
-            prefill_scripted_model = torch.jit.trace(wrapped_model, prefill_example_inputs)
-            dec_scripted_model = torch.jit.trace(wrapped_model, dec_example_inputs)
-
-            prefill_ir = rebel.torchscript_to_ir(
-                prefill_scripted_model,
-                input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
-            )
-            dec_ir = rebel.torchscript_to_ir(
-                dec_scripted_model,
-                input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
-            )
-
-            # Caching prefill_decoder/decoder I/O
-            cache_index_offset = 4 if use_continuous_batch else 3
-            connections = [
-                (prefill_ir.outputs[1 + i], prefill_ir.inputs[cache_index_offset + i])
-                for i in range(model.config.num_hidden_layers * 2)
-            ]
-
-            compiled_model = rebel.compile(
-                prefill_ir,
-                dec_ir,
-                connections=connections,
-                fusion=prefill_rbln_runtime_config.fusion,
-                npu=prefill_rbln_runtime_config.npu,
-                tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
-                use_weight_sharing=True,
-            )
-            compiled_model.save(save_dir_path / f"{DEFAULT_COMPILED_MODEL_NAME}.rbln")
-
-        wrapper_cls = LlamaWrapper_cb if use_continuous_batch else LlamaWrapper
-        compile_llama(use_continuous_batch=use_continuous_batch, wrapper_cls=wrapper_cls)
-        unwrap_llama(origin_mehtods)
-
-        rbln_config.save(save_dir_path)
-
-        return cls._from_pretrained(
-            model_id=save_dir_path,
-            config=config,
-            model_save_dir=save_dir,
-            **rbln_constructor_kwargs,
-            **kwargs,
-        )
-
     @classmethod
-    def
-
-        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
-        model_config: "PretrainedConfig",
-        rbln_max_seq_len: Optional[int] = None,
-        rbln_batch_size: Optional[int] = None,
-        rbln_batching: Optional[str] = None,
-    ) -> RBLNConfig:
-        meta = {}
-
-        prefill_chunk_size = 128
-        if rbln_max_seq_len is None:
-            rbln_max_seq_len = getattr(model_config, "max_position_embeddings", None)
-        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-        rbln_batching = "static" if rbln_batching is None else rbln_batching
-
-        meta["rbln_max_seq_len"] = rbln_max_seq_len
-        meta["rbln_batch_size"] = rbln_batch_size
-        meta["rbln_prefill_chunk_size"] = prefill_chunk_size
-        meta["rbln_batching"] = rbln_batching
-        use_continuous_batching = meta["rbln_batching"] == "vllm"
-
-        if rbln_batching not in SUPPORTED_BATCHING_MODES:
-            raise ValueError(
-                f'rbln_batching="{rbln_batching}" is not a supported batch mode, '
-                f"Possible: {SUPPORTED_BATCHING_MODES}"
-            )
-
-        def get_input_info(
-            batch_size,  # should be 1 if continous batch prefill
-            query_length,
-            continuous_batch=False,  # determines the shape of `cache position`
-        ):
-            input_info = [
-                ("input_ids", [batch_size, query_length], "int64"),
-                ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "int64"),
-                (
-                    "cache_position",
-                    [batch_size, query_length] if continuous_batch else [],
-                    "int32",
-                ),
-            ]
-
-            if continuous_batch:
-                input_info.append(("batch_position", [], "int16"))
-
-            input_info.extend(
-                [
-                    (
-                        f"past_key_values_{i}",
-                        [
-                            rbln_batch_size,
-                            model_config.num_key_value_heads,
-                            rbln_max_seq_len,
-                            model_config.hidden_size // model_config.num_attention_heads,
-                        ],
-                        "float32",
-                    )
-                    for i in range(model_config.num_hidden_layers * 2)
-                ]
-            )
-
-            return input_info
-
-        prefill_input_info = get_input_info(
-            batch_size=1 if use_continuous_batching else rbln_batch_size,
-            query_length=prefill_chunk_size,
-            continuous_batch=use_continuous_batching,
-        )
-        dec_input_info = get_input_info(
-            batch_size=rbln_batch_size,
-            query_length=1,
-            continuous_batch=use_continuous_batching,
-        )
-
-        prefill_rbln_runtime_config = RBLNRuntimeConfig(input_info=prefill_input_info)
-        dec_rbln_runtime_config = RBLNRuntimeConfig(input_info=dec_input_info)
-
-        dec_rbln_runtime_config.batch_size = rbln_batch_size
-
-        rbln_config = RBLNConfig.from_rbln_runtime_configs(
-            [prefill_rbln_runtime_config, dec_rbln_runtime_config],
-            _rbln_meta=meta,
-        )
-
-        return rbln_config
-
-    def _create_runtimes(self, rbln_device_map: Dict[str, int]) -> List[rebel.Runtime]:
-        device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
-        return [
-            self.compiled_models[0].create_runtime(input_info_index=0, tensor_type="pt", device=device_val),
-            self.compiled_models[0].create_runtime(input_info_index=1, tensor_type="pt", device=device_val),
-        ]
-
-    def get_decoder(self):
-        return self.decoder
-
-    def can_generate(self):
-        return True
+    def wrapping_torch_model(self, model: "PreTrainedModel", rbln_max_seq_len: int):
+        return LlamaWrapper(model, rbln_max_seq_len).eval()
 
     def __getattr__(self, __name: str) -> Any:
         def redirect(func):
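The `get_input_info` helper removed in the hunk above spells out the contract with the ahead-of-time compiler: every input, including one `past_key_values_{i}` buffer per layer for both K and V, is declared with a fixed shape sized to `rbln_max_seq_len`. The sketch below re-creates that shape logic in isolation so it can be inspected without the rest of the class; the helper itself presumably now lives in the shared `decoderonly` modules, and the config values used here are illustrative only.

```python
# Hedged, self-contained recap of the shape logic from the removed helper.
# The toy config values are illustrative, not tied to any real checkpoint.
def get_input_info(batch_size, query_length, max_seq_len,
                   num_hidden_layers, num_key_value_heads, head_dim):
    info = [
        ("input_ids", [batch_size, query_length], "int64"),
        ("attention_mask", [batch_size, 1, query_length, max_seq_len], "int64"),
        # Static-batching layout; the removed code used [batch_size, query_length]
        # for the "vllm" continuous-batching mode instead of a scalar.
        ("cache_position", [], "int32"),
    ]
    # One statically allocated K and one V buffer per layer, sized for the full
    # max_seq_len so the compiled graph never has to grow the cache.
    info += [
        (f"past_key_values_{i}",
         [batch_size, num_key_value_heads, max_seq_len, head_dim],
         "float32")
        for i in range(num_hidden_layers * 2)
    ]
    return info

# Example: a toy 2-layer model with 4 KV heads and head_dim 64.
for name, shape, dtype in get_input_info(1, 128, 4096, 2, 4, 64)[:5]:
    print(name, shape, dtype)
```

Note that the prefill entry of the compiled model used `query_length = prefill_chunk_size` (128) while the decode entry used `query_length = 1`; that is the only difference between the two input signatures built above.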
@@ -361,161 +59,3 @@ class RBLNLlamaForCausalLM(RBLNBaseModel, RBLNGenerationMixin):
             return redirect(val)
 
         return val
-
-    def _reorder_cache(self, past_key_values, beam_idx):
-        raise NotImplementedError
-
-    # args input_ids, past_key_values and attention_mask are updated by _update_model_kwargs_for_generation() in _greedy_search() in GenerationMixin
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=0, attention_mask=None, **kwargs):
-        batch_size, cur_len = input_ids.shape
-        past_cached_length = past_key_values
-
-        # In greedy decoding
-        if past_cached_length == 0:
-            # padding with prefill_chunk_size
-            # TODO left padding + left padding has issue on stoppingcriteria(max_len)
-            if cur_len % self.prefill_chunk_size != 0:
-                pad_len = self.prefill_chunk_size - cur_len % self.prefill_chunk_size
-                input_ids = torch.nn.functional.pad(input_ids, (0, pad_len))
-
-            # padding_side
-            if batch_size > 1 and torch.all(attention_mask[..., -1] == 1):
-                self.right_padding = False
-
-            if self.right_padding:
-                self.rightpad_max_len = cur_len
-                prompt_min_len = torch.min(torch.sum(attention_mask, dim=-1))
-                self.dummy_len = torch.sum(attention_mask, dim=-1) - prompt_min_len  # dummy_decoder generation length
-                query_length = prompt_min_len.item()
-            else:
-                query_length = cur_len - past_cached_length
-                self.prompt_length = query_length
-                self.prompt_attn_mask = attention_mask.unsqueeze(1).unsqueeze(1).contiguous()
-
-            attention_mask = self.prefill_attention_mask.clone()
-            cache_position = torch.tensor(0, dtype=torch.int32)
-
-        else:
-            if self.right_padding:
-                attention_mask = torch.zeros(batch_size, 1, 1, self.max_seq_len, dtype=torch.int64)
-                attention_mask[:, :, :, : past_cached_length + 1] = 1
-                input_ids = input_ids[:, past_cached_length : past_cached_length + 1].contiguous()
-            else:
-                attention_mask = torch.nn.functional.pad(attention_mask, (0, self.max_seq_len - cur_len))
-                attention_mask = attention_mask.reshape(batch_size, 1, 1, -1).contiguous()
-                input_ids = input_ids[:, -1:]
-
-            cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
-            query_length = 1
-
-        model_inputs = {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "attention_mask": attention_mask,
-            "cache_position": cache_position,
-            "query_length": query_length,
-        }
-
-        return model_inputs
-
-    def forward(self, *args, **kwargs):
-        if self.use_continuous_batch:
-            return self.forward_cb(*args, **kwargs)
-        else:
-            return self.forward_static(*args, **kwargs)
-
-    def forward_static(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        past_key_values: int = None,
-        cache_position: Optional[torch.Tensor] = None,
-        query_length: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> Tuple[torch.FloatTensor]:
-        if past_key_values is not None:
-            past_key_values += query_length
-
-        # prefill_decoder
-        if cache_position == 0:
-            for step in range(0, query_length, self.prefill_chunk_size):
-                sliced_input_ids = input_ids[:, step : step + self.prefill_chunk_size]
-                attention_mask[:, :, :, :step] = 1
-                attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
-                if not self.right_padding:
-                    attention_mask[:, :, :, : self.prompt_length] &= self.prompt_attn_mask[:, :, :, :]
-
-                outputs = self.prefill_decoder(
-                    input_ids=sliced_input_ids.contiguous(),
-                    attention_mask=attention_mask.contiguous(),
-                    cache_position=cache_position + step,
-                )
-            outputs = outputs[:, query_length % self.prefill_chunk_size - 1].unsqueeze(1)
-
-        # decoder
-        else:
-            outputs = self.decoder(
-                input_ids.contiguous(),
-                attention_mask.contiguous(),
-                cache_position=cache_position,
-            )
-
-        return CausalLMOutputWithPast(
-            logits=outputs,
-            past_key_values=past_key_values,
-        )
-
-    def forward_cb(
-        self,
-        input_ids: torch.LongTensor = None,
-        cache_position: Optional[torch.Tensor] = None,  # torch.tensor(,dtype=int32) (1,64) // (4,1)
-        batch_idx: int = None,
-        **kwargs,
-    ) -> Tuple[torch.FloatTensor]:
-        # prefill_decoder
-        if cache_position.shape[1] > 1:
-            query_length = input_ids.shape[1]
-            attention_mask = self.prefill_attention_mask.clone()
-            for step in range(0, query_length, self.prefill_chunk_size):
-                if step + self.prefill_chunk_size > query_length:
-                    input_ids = torch.nn.functional.pad(input_ids, (0, step + self.prefill_chunk_size - query_length))
-                    cache_position = torch.cat(
-                        [
-                            cache_position,
-                            torch.arange(
-                                query_length,
-                                step + self.prefill_chunk_size,
-                                dtype=torch.int32,
-                            ).unsqueeze(0),
-                        ],
-                        dim=-1,
-                    )
-
-                sliced_input_ids = input_ids[:, step : step + self.prefill_chunk_size]
-                sliced_cache_positions = cache_position[:, step : step + self.prefill_chunk_size]
-                attention_mask[:, :, :, :step] = 1
-                attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
-
-                outputs, _ = self.prefill_decoder(
-                    sliced_input_ids.contiguous(),
-                    attention_mask.contiguous(),
-                    sliced_cache_positions.contiguous(),
-                    torch.tensor(batch_idx, dtype=torch.int16),
-                )
-            outputs = outputs[:, query_length % self.prefill_chunk_size - 1].unsqueeze(1)
-        # decoder
-        else:
-            attention_mask = self.decoder_attention_mask.clone()
-            for b_idx in range(self.batch_size):
-                attention_mask[b_idx, :, :, : cache_position[b_idx].item() + 1] = 1
-
-            outputs = self.decoder(
-                input_ids.contiguous(),
-                attention_mask.contiguous(),
-                cache_position.contiguous(),
-                torch.tensor(0, dtype=torch.int16),
-            )[0]
-
-        return CausalLMOutputWithPast(
-            logits=outputs,
-        )
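The removed `forward_static` and `forward_cb` paths above implement chunked prefill: the prompt is processed `prefill_chunk_size` tokens at a time, and the attention mask is rebuilt at each step so that previously cached chunks are fully visible while the current chunk sees a causal (lower-triangular) block. The toy sketch below isolates just that masking pattern; the chunk and sequence sizes are illustrative (the 0.1.4 code used a chunk size of 128).

```python
# Hedged illustration of the chunked-prefill masking pattern from the removed
# forward paths, with toy sizes so the masks are small enough to print.
import torch

prefill_chunk_size, max_seq_len = 4, 12
attention_mask = torch.zeros(1, 1, prefill_chunk_size, max_seq_len, dtype=torch.int64)
causal_mask = 1 - torch.triu(
    torch.ones(1, 1, prefill_chunk_size, prefill_chunk_size), diagonal=1
)

query_length = 8  # prompt length, already padded to a multiple of the chunk size
for step in range(0, query_length, prefill_chunk_size):
    attention_mask[:, :, :, :step] = 1                                        # cached chunks: fully visible
    attention_mask[:, :, :, step : step + prefill_chunk_size] = causal_mask   # current chunk: causal block
    print(f"chunk starting at {step}:\n{attention_mask[0, 0]}")
```

During decode, the removed code instead used a `[batch, 1, 1, max_seq_len]` mask with ones up to each sequence's current cache position, so a single token attends to everything already stored in the static KV cache.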
optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py

@@ -10,7 +10,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
+"""Tokenization class for model Midm_bitext_tonkenizer."""
+
 import os
 import re
 import warnings
optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py

@@ -817,7 +817,6 @@ class MidmModel(MidmPreTrainedModel):
         all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
         all_hidden_states = () if output_hidden_states else None
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
-
             # Model parallel
             if self.model_parallel:
                 torch.cuda.set_device(hidden_states.device)
@@ -833,7 +832,6 @@ class MidmModel(MidmPreTrainedModel):
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
-
                 if use_cache:
                     logger.warning(
                         "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
@@ -1174,7 +1172,6 @@ class MidmDoubleHeadsModel(MidmPreTrainedModel):
         return_dict=None,
         **kwargs,
     ):
-
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         transformer_outputs = self.transformer(
@@ -1445,7 +1442,6 @@ def get_submodule(module, target: str):  # -> "Module":
     mod: torch.nn.Module = module
 
     for item in atoms:
-
         if not hasattr(mod, item):
             raise AttributeError(mod._get_name() + " has no " "attribute `" + item + "`")
 