optimum-rbln 0.1.4__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. optimum/rbln/__init__.py +7 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +16 -98
  4. optimum/rbln/diffusers/models/unet_2d_condition.py +1 -1
  5. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +9 -11
  6. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +8 -0
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -0
  8. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -0
  9. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +8 -0
  10. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +8 -0
  11. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +9 -0
  12. optimum/rbln/modeling_base.py +172 -100
  13. optimum/rbln/modeling_seq2seq.py +58 -132
  14. optimum/rbln/transformers/__init__.py +2 -0
  15. optimum/rbln/transformers/models/__init__.py +1 -0
  16. optimum/rbln/transformers/models/clip/modeling_clip.py +0 -1
  17. optimum/rbln/transformers/models/dpt/__init__.py +24 -0
  18. optimum/rbln/transformers/models/dpt/modeling_dpt.py +89 -0
  19. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +24 -33
  20. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +52 -124
  21. optimum/rbln/transformers/models/llama/llama_architecture.py +13 -16
  22. optimum/rbln/transformers/models/llama/llama_architecture_cb.py +41 -36
  23. optimum/rbln/transformers/models/llama/modeling_llama.py +94 -120
  24. optimum/rbln/transformers/models/midm/modeling_midm.py +85 -121
  25. optimum/rbln/transformers/models/whisper/modeling_whisper.py +53 -123
  26. optimum/rbln/utils/__init__.py +1 -1
  27. optimum/rbln/utils/import_utils.py +46 -0
  28. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.7.dist-info}/METADATA +17 -51
  29. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.7.dist-info}/RECORD +31 -29
  30. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.7.dist-info}/WHEEL +1 -1
  31. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.7.dist-info}/licenses/LICENSE +0 -0
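At a high level, this release replaces the monolithic `_export` implementations in the causal-LM classes with a shared `RBLNModel` base plus per-model `get_pytorch_model`/`get_compiled_model` hooks, as the hunks below show. For orientation, here is a hedged sketch of how an exported model is typically instantiated; the checkpoint id and the `export=True` argument are illustrative assumptions following the optimum convention, not taken from this diff (only `rbln_max_seq_len` appears below).

    # Hypothetical usage sketch -- checkpoint id and the export flag are
    # assumptions, not confirmed by this diff.
    from optimum.rbln import RBLNLlamaForCausalLM

    model = RBLNLlamaForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",  # assumed checkpoint
        export=True,                 # assumed flag: compile from PyTorch weights
        rbln_max_seq_len=4096,       # forwarded to get_pytorch_model via rbln_config_kwargs
    )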
@@ -23,22 +23,18 @@

  import inspect # noqa: I001
  import logging
- from pathlib import Path
- from tempfile import TemporaryDirectory
  from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

  import torch # noqa: F401
  import rebel # noqa: F401

- from optimum.exporters import TasksManager
- from transformers import AutoModelForCausalLM, LlamaForCausalLM, PretrainedConfig, AutoConfig
+ from transformers import AutoModelForCausalLM, LlamaForCausalLM, PreTrainedModel, PretrainedConfig, AutoConfig
  from transformers.modeling_outputs import CausalLMOutputWithPast

  from ...generation.utils import RBLNGenerationMixin
- from ....modeling_base import RBLNBaseModel
+ from ....modeling_base import RBLNModel
  from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
  from ....utils.runtime_utils import RBLNPytorchRuntime
- from ....utils.save_utils import maybe_save_preprocessors


  # FIXME:: Merge Two architecture Codes
@@ -72,10 +68,10 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
      mandatory_members = ["main_input_name"]


- class RBLNLlamaForCausalLM(RBLNBaseModel, RBLNGenerationMixin):
+ class RBLNLlamaForCausalLM(RBLNModel, RBLNGenerationMixin):
      """
      The Llama Model transformer with a language modeling head (linear layer) on top.
-     This model inherits from [`RBLNBaseModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+     This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.

      A class to convert and run pre-trained transformers based LlamaForCausalLM model on RBLN devices.
      It implements the methods to convert a pre-trained transformers LlamaForCausalLM model into a RBLN transformer model by:
@@ -83,7 +79,6 @@ class RBLNLlamaForCausalLM(RBLNBaseModel, RBLNGenerationMixin):
      - compiling the resulting graph using the RBLN compiler.
      """

-     model_type = "rbln_model"
      main_input_name = "input_ids"
      auto_model_class = AutoModelForCausalLM

@@ -102,17 +97,34 @@ class RBLNLlamaForCausalLM(RBLNBaseModel, RBLNGenerationMixin):
          )
          self.decoder_attention_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.int64)

-         self.prefill_decoder = RBLNRuntimeModel(runtime=self.runtimes[0], main_input_name="input_ids")
-         self.decoder = RBLNRuntimeModel(runtime=self.runtimes[1], main_input_name="input_ids")
+         self.prefill_decoder = RBLNRuntimeModel(runtime=self.model[0], main_input_name="input_ids")
+         self.decoder = RBLNRuntimeModel(runtime=self.model[1], main_input_name="input_ids")
          self.past_cached_length = 0
          self.right_padding = True

      @classmethod
-     @torch.no_grad()
-     def _export(
+     def update_kwargs(cls, kwargs):
+         """
+         Update user-given kwargs to get proper pytorch model.
+
+         For example, `torchscript`=True should be set because torch.jit
+         does not support `transformers` output instances as module output;
+         """
+         kwargs.update(
+             {
+                 "torchscript": True,
+                 "return_dict": False,
+                 "use_cache": True,
+                 "torch_dtype": torch.float32,
+                 "_attn_implementation": "eager",
+             }
+         )
+         return kwargs
+
+     @classmethod
+     def get_pytorch_model(
          cls,
          model_id: str,
-         config: "PretrainedConfig",
          use_auth_token: Optional[Union[bool, str]] = None,
          revision: Optional[str] = None,
          force_download: bool = False,
@@ -120,135 +132,94 @@ class RBLNLlamaForCausalLM(RBLNBaseModel, RBLNGenerationMixin):
          subfolder: str = "",
          local_files_only: bool = False,
          trust_remote_code: bool = False,
-         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+         rbln_config_kwargs: Optional[Dict[str, Any]] = None,
+         rbln_constructor_kwargs: Optional[Dict[str, Any]] = None,
          **kwargs,
-     ) -> "RBLNLlamaForCausalLM":
-         task = kwargs.pop("task", None)
-         if task is None:
-             task = TasksManager.infer_task_from_model(cls.auto_model_class)
-
-         if model_save_dir is None:
-             save_dir = TemporaryDirectory()
-             save_dir_path = Path(save_dir.name)
-         else:
-             save_dir = model_save_dir
-             if isinstance(save_dir, TemporaryDirectory):
-                 save_dir_path = Path(model_save_dir.name)
-             else:
-                 save_dir_path = Path(model_save_dir)
-             save_dir_path.mkdir(exist_ok=True)
-
-         def update_configs(kwargs):
-             hf_max_position_embeddings = getattr(AutoConfig.from_pretrained(model_id), "max_position_embeddings", None)
-             max_seq_len = kwargs.get("rbln_max_seq_len", None)
-             if max_seq_len is not None:
-                 if max_seq_len <= hf_max_position_embeddings:
-                     kwargs.update({"max_position_embeddings": max_seq_len})
-                 else:
-                     raise ValueError("`max_seq_len` should be less or equal than max_position_embeddings!")
-
-             kwargs.update(
-                 {
-                     "torchscript": True,
-                     "return_dict": False,
-                     "use_cache": True,
-                     "torch_dtype": torch.float32,
-                     "_attn_implementation": "eager",
-                 }
-             )
-
-             return kwargs
-
-         kwargs = update_configs(kwargs)
-
-         rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
+     ) -> PreTrainedModel:
+         if rbln_max_seq_len := rbln_config_kwargs.get("rbln_max_seq_len", None):
+             config = AutoConfig.from_pretrained(model_id)
+             if hf_position_embedding := getattr(config, "max_position_embeddings", None):
+                 if hf_position_embedding < rbln_max_seq_len:
+                     logger.warning(
+                         f"`rbln_max_seq_len` is larger than original config({hf_position_embedding})."
+                         "This may lead to incorrect inferences of the model."
+                     )
+             kwargs.update({"max_position_embeddings": rbln_max_seq_len})

          # FIXME :: This should be moved when wrapping removed.
          use_continuous_batch = rbln_config_kwargs.get("rbln_batching", "static") == "vllm"
-         origin_mehtods = wrap_llama_cb() if use_continuous_batch else wrap_llama()
+         wrap_llama_cb() if use_continuous_batch else wrap_llama()

-         model: LlamaForCausalLM = TasksManager.get_model_from_task(
-             task=task,
-             model_name_or_path=model_id,
-             subfolder=subfolder,
+         model = super().get_pytorch_model(
+             model_id=model_id,
+             use_auth_token=use_auth_token,
              revision=revision,
-             framework="pt",
+             force_download=force_download,
              cache_dir=cache_dir,
-             use_auth_token=use_auth_token,
+             subfolder=subfolder,
              local_files_only=local_files_only,
-             force_download=force_download,
              trust_remote_code=trust_remote_code,
+             rbln_config_kwargs=rbln_config_kwargs,
+             rbln_constructor_kwargs=rbln_constructor_kwargs,
              **kwargs,
          )

-         if config is None:
-             config = model.config
+         unwrap_llama()

-         config.save_pretrained(save_dir_path)
-         preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
+         return model

-         # Get compilation arguments
-         if rbln_config_kwargs.get("rbln_config", None) is None:
-             rbln_config = cls.get_rbln_config(
-                 preprocessors=preprocessors, model_config=model.config, **rbln_config_kwargs
-             )
+     @classmethod
+     @torch.inference_mode()
+     def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
+         use_continuous_batch = rbln_config.meta["rbln_batching"] == "vllm"

-         def compile_llama(use_continuous_batch, wrapper_cls):
-             wrapped_model = wrapper_cls(model).eval()
+         wrapper_cls = LlamaWrapper_cb if use_continuous_batch else LlamaWrapper

-             prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
-             dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
+         wrapped_model = wrapper_cls(model).eval()

-             prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
-             dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=4)
+         prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
+         dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]

-             if use_continuous_batch:
-                 batch_index_index = 3
-                 dec_example_inputs[batch_index_index].fill_(-1)  # fill batch_position -1 to indicate it is decoder.
+         prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
+         dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=4)

-             prefill_scripted_model = torch.jit.trace(wrapped_model, prefill_example_inputs)
-             dec_scripted_model = torch.jit.trace(wrapped_model, dec_example_inputs)
+         if use_continuous_batch:
+             batch_index_index = 3
+             dec_example_inputs[batch_index_index].fill_(-1)  # fill batch_position -1 to indicate it is decoder.

-             prefill_ir = rebel.torchscript_to_ir(
-                 prefill_scripted_model,
-                 input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
-             )
-             dec_ir = rebel.torchscript_to_ir(
-                 dec_scripted_model,
-                 input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
-             )
+         wrap_llama_cb() if use_continuous_batch else wrap_llama()

-             # Caching prefill_decoder/decoder I/O
-             cache_index_offset = 4 if use_continuous_batch else 3
-             connections = [
-                 (prefill_ir.outputs[1 + i], prefill_ir.inputs[cache_index_offset + i])
-                 for i in range(model.config.num_hidden_layers * 2)
-             ]
+         prefill_scripted_model = torch.jit.trace(wrapped_model, prefill_example_inputs, check_trace=False)
+         dec_scripted_model = torch.jit.trace(wrapped_model, dec_example_inputs, check_trace=False)

-             compiled_model = rebel.compile(
-                 prefill_ir,
-                 dec_ir,
-                 connections=connections,
-                 fusion=prefill_rbln_runtime_config.fusion,
-                 npu=prefill_rbln_runtime_config.npu,
-                 tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
-                 use_weight_sharing=True,
-             )
-             compiled_model.save(save_dir_path / f"{DEFAULT_COMPILED_MODEL_NAME}.rbln")
+         unwrap_llama()

-         wrapper_cls = LlamaWrapper_cb if use_continuous_batch else LlamaWrapper
-         compile_llama(use_continuous_batch=use_continuous_batch, wrapper_cls=wrapper_cls)
-         unwrap_llama(origin_mehtods)
+         prefill_ir = rebel.torchscript_to_ir(
+             prefill_scripted_model,
+             input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
+         )
+         dec_ir = rebel.torchscript_to_ir(
+             dec_scripted_model,
+             input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
+         )

-         rbln_config.save(save_dir_path)
+         # Caching prefill_decoder/decoder I/O
+         cache_index_offset = 4 if use_continuous_batch else 3
+         connections = [
+             (prefill_ir.outputs[1 + i], prefill_ir.inputs[cache_index_offset + i])
+             for i in range(model.config.num_hidden_layers * 2)
+         ]

-         return cls._from_pretrained(
-             model_id=save_dir_path,
-             config=config,
-             model_save_dir=save_dir,
-             **rbln_constructor_kwargs,
-             **kwargs,
+         compiled_model = rebel.compile(
+             prefill_ir,
+             dec_ir,
+             connections=connections,
+             fusion=prefill_rbln_runtime_config.fusion,
+             npu=prefill_rbln_runtime_config.npu,
+             tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
+             use_weight_sharing=True,
          )
+         return compiled_model

      @classmethod
      def _get_rbln_config(
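Note that both tracing calls now pass `check_trace=False`. By default `torch.jit.trace` re-runs the module and compares traced against eager outputs; disabling that check is the usual remedy when a wrapper carries mutable cache state, so repeated runs legitimately diverge. That rationale is an inference from the surrounding code, not stated in the diff. A small stand-alone demonstration of the failure mode the flag avoids:

    # Stand-in module (not from optimum-rbln): a buffer mutated in forward
    # makes each call return different values, so trace verification would
    # flag a mismatch; check_trace=False skips that comparison.
    import torch

    class StatefulWrapper(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("step", torch.zeros(1))

        def forward(self, x):
            self.step += 1  # cache-like side effect
            return x + self.step

    traced = torch.jit.trace(StatefulWrapper(), torch.zeros(1), check_trace=False)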
@@ -338,11 +309,14 @@ class RBLNLlamaForCausalLM(RBLNBaseModel, RBLNGenerationMixin):

          return rbln_config

-     def _create_runtimes(self, rbln_device_map: Dict[str, int]) -> List[rebel.Runtime]:
+     @classmethod
+     def _create_runtimes(
+         cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
+     ) -> List[rebel.Runtime]:
          device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
          return [
-             self.compiled_models[0].create_runtime(input_info_index=0, tensor_type="pt", device=device_val),
-             self.compiled_models[0].create_runtime(input_info_index=1, tensor_type="pt", device=device_val),
+             compiled_models[0].create_runtime(input_info_index=0, tensor_type="pt", device=device_val),
+             compiled_models[0].create_runtime(input_info_index=1, tensor_type="pt", device=device_val),
          ]

      def get_decoder(self):
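`_create_runtimes` is now a classmethod that receives the compiled artifacts explicitly instead of reading `self.compiled_models`, which lets the `RBLNModel` base class own loading and pass compiled models in. Both runtimes are carved from the same `RBLNCompiledModel`, so prefill (`input_info_index=0`) and decode (`input_info_index=1`) share one set of weights on the device. A sketch of that contract, using only the calls visible in this diff; the exact `rebel` signatures belong to the RBLN SDK and are not independently verified here:

    from typing import List
    import rebel  # RBLN SDK; calls mirror the diff above

    def create_prefill_decode_runtimes(
        compiled_model: "rebel.RBLNCompiledModel", device: int
    ) -> List["rebel.Runtime"]:
        # One compiled artifact, two entry points: index 0 = prefill graph,
        # index 1 = single-token decode graph, weights shared on-device.
        return [
            compiled_model.create_runtime(input_info_index=0, tensor_type="pt", device=device),
            compiled_model.create_runtime(input_info_index=1, tensor_type="pt", device=device),
        ]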
@@ -23,20 +23,16 @@

  import inspect
  import logging
- from pathlib import Path
- from tempfile import TemporaryDirectory
  from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

  import rebel
  import torch
- from optimum.exporters import TasksManager
- from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
+ from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
  from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

- from ....modeling_base import RBLNBaseModel
+ from ....modeling_base import RBLNModel
  from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
  from ....utils.runtime_utils import RBLNPytorchRuntime
- from ....utils.save_utils import maybe_save_preprocessors
  from ...generation.utils import RBLNGenerationMixin
  from .hf_hub_cached.modeling_midm import MidmLMHeadModel
  from .midm_architecture import (
@@ -74,7 +70,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
          return logits


- class RBLNMidmLMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
+ class RBLNMidmLMHeadModel(RBLNModel, RBLNGenerationMixin):
      """
      The Midm Model transformer with a language modeling head on top (linear layer with weights tied to the input
      embeddings).
@@ -122,8 +118,8 @@ class RBLNMidmLMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
              torch.ones(self.batch_size, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
          )

-         self.prefill_decoder = RBLNRuntimeDecoder(runtime=self.runtimes[0], main_input_name="input_ids")
-         self.decoder = RBLNRuntimeDecoder(runtime=self.runtimes[1], main_input_name="input_ids")
+         self.prefill_decoder = RBLNRuntimeDecoder(runtime=self.model[0], main_input_name="input_ids")
+         self.decoder = RBLNRuntimeDecoder(runtime=self.model[1], main_input_name="input_ids")
          self.past_cached_length = 0

      def can_generate(self):
149
145
  raise NotImplementedError
150
146
 
151
147
  @classmethod
152
- def _export(
148
+ @torch.inference_mode()
149
+ def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
150
+ wrapped_decoder = MidmLMHeadModelWrapper(model).eval()
151
+ prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
152
+ dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
153
+
154
+ prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
155
+ dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=0)
156
+
157
+ prefill_scripted_model = torch.jit.trace(wrapped_decoder, prefill_example_inputs, check_trace=False)
158
+ dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs, check_trace=False)
159
+
160
+ prefill_ir = rebel.torchscript_to_ir(
161
+ prefill_scripted_model,
162
+ input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
163
+ )
164
+ dec_ir = rebel.torchscript_to_ir(
165
+ dec_scripted_model,
166
+ input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
167
+ )
168
+
169
+ connections = [(prefill_ir.outputs[1 + i], prefill_ir.inputs[3 + i]) for i in range(model.config.n_layer * 2)]
170
+
171
+ compiled_model = rebel.compile(
172
+ prefill_ir,
173
+ dec_ir,
174
+ connections=connections,
175
+ fusion=prefill_rbln_runtime_config.fusion,
176
+ npu=prefill_rbln_runtime_config.npu,
177
+ tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
178
+ use_weight_sharing=True,
179
+ )
180
+ return compiled_model
181
+
182
+ @classmethod
183
+ def update_kwargs(cls, kwargs):
184
+ """
185
+ Update user-given kwargs to get proper pytorch model.
186
+
187
+ For example, `torchscript`=True should be set because torch.jit
188
+ does not support `transformers` output instances as module output;
189
+ """
190
+ kwargs.update(
191
+ {
192
+ "torchscript": True,
193
+ "return_dict": False,
194
+ "use_cache": True,
195
+ "torch_dtype": torch.float32,
196
+ "_attn_implementation": "eager",
197
+ }
198
+ )
199
+ return kwargs
200
+
201
+ @classmethod
202
+ def get_pytorch_model(
153
203
  cls,
154
204
  model_id: str,
155
- config: "PretrainedConfig",
156
205
  use_auth_token: Optional[Union[bool, str]] = None,
157
206
  revision: Optional[str] = None,
158
207
  force_download: bool = False,
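The `connections` list above is what keeps the KV cache resident on the NPU between the two graphs: each prefill cache output is wired back to the matching cache input, with two tensors (key and value) per layer. Outputs start at index 1 because index 0 is the logits; cache inputs start at index 3 per the offset in this diff, though which tensors occupy input slots 0-2 is not spelled out here and remains an assumption. The worked example below only illustrates the index arithmetic:

    # Index pairing for a hypothetical 2-layer model: 2 * n_layer cache
    # tensors, outputs offset by 1 (logits first), inputs offset by 3.
    n_layer = 2
    pairs = [(1 + i, 3 + i) for i in range(n_layer * 2)]
    print(pairs)  # [(1, 3), (2, 4), (3, 5), (4, 6)]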
@@ -160,120 +209,35 @@ class RBLNMidmLMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
          subfolder: str = "",
          local_files_only: bool = False,
          trust_remote_code: bool = False,
-         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+         rbln_config_kwargs: Optional[Dict[str, Any]] = None,
+         rbln_constructor_kwargs: Optional[Dict[str, Any]] = None,
          **kwargs,
-     ) -> "RBLNMidmLMHeadModel":
-
-         task = kwargs.pop("task", None)
-         if task is None:
-             task = TasksManager.infer_task_from_model(cls.auto_model_class)
-
-         if model_save_dir is None:
-             save_dir = TemporaryDirectory()
-             save_dir_path = Path(save_dir.name)
-         else:
-             save_dir = model_save_dir
-             if isinstance(save_dir, TemporaryDirectory):
-                 save_dir_path = Path(model_save_dir.name)
-             else:
-                 save_dir_path = Path(model_save_dir)
-             save_dir_path.mkdir(exist_ok=True)
-
-         def update_configs(kwargs):
-             max_seq_len = kwargs.get("rbln_max_seq_len", None)
-             if max_seq_len is not None:
-                 kwargs.update({"max_position_embeddings": max_seq_len})
-
-             kwargs.update(
-                 {
-                     "torchscript": True,
-                     "return_dict": False,
-                     "use_cache": True,
-                     "torch_dtype": torch.float32,
-                     "_attn_implementation": "eager",
-                 }
-             )
-
-             return kwargs
-
-         kwargs = update_configs(kwargs)
-
-         rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
+     ) -> PreTrainedModel:
+         if rbln_max_seq_len := rbln_config_kwargs.get("rbln_max_seq_len", None):
+             config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
+             if hf_position_embedding := getattr(config, "max_position_embeddings", None):
+                 if hf_position_embedding < rbln_max_seq_len:
+                     logger.warning(
+                         f"`rbln_max_seq_len` is larger than original config({hf_position_embedding})."
+                         "This may lead to incorrect inferences of the model."
+                     )
+             kwargs.update({"max_position_embeddings": rbln_max_seq_len})

-         model: MidmLMHeadModel = TasksManager.get_model_from_task(
-             task=task,
-             model_name_or_path=model_id,
-             subfolder=subfolder,
+         return super().get_pytorch_model(
+             model_id=model_id,
+             use_auth_token=use_auth_token,
              revision=revision,
-             framework="pt",
+             force_download=force_download,
              cache_dir=cache_dir,
-             use_auth_token=use_auth_token,
+             subfolder=subfolder,
              local_files_only=local_files_only,
-             force_download=force_download,
              trust_remote_code=trust_remote_code,
+             rbln_config_kwargs=rbln_config_kwargs,
+             rbln_constructor_kwargs=rbln_constructor_kwargs,
              ignore_mismatched_sizes=True,
              **kwargs,
          )

-         if config is None:
-             config = model.config
-
-         config.save_pretrained(save_dir_path)
-         preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
-
-         # Get compilation arguments
-         if rbln_config_kwargs.get("rbln_config", None) is None:
-             rbln_config = cls.get_rbln_config(
-                 preprocessors=preprocessors, model_config=model.config, **rbln_config_kwargs
-             )
-
-         def compile_midm():
-             wrapped_decoder = MidmLMHeadModelWrapper(model).eval()
-             prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
-             dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
-
-             prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
-             dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=0)
-
-             prefill_scripted_model = torch.jit.trace(wrapped_decoder, prefill_example_inputs)
-             dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs)
-
-             prefill_ir = rebel.torchscript_to_ir(
-                 prefill_scripted_model,
-                 input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
-             )
-             dec_ir = rebel.torchscript_to_ir(
-                 dec_scripted_model,
-                 input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
-             )
-
-             connections = [
-                 (prefill_ir.outputs[1 + i], prefill_ir.inputs[3 + i]) for i in range(model.config.n_layer * 2)
-             ]
-
-             compiled_model = rebel.compile(
-                 prefill_ir,
-                 dec_ir,
-                 connections=connections,
-                 fusion=prefill_rbln_runtime_config.fusion,
-                 npu=prefill_rbln_runtime_config.npu,
-                 tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
-                 use_weight_sharing=True,
-             )
-             compiled_model.save(save_dir_path / f"{DEFAULT_COMPILED_MODEL_NAME}.rbln")
-
-         compile_midm()
-
-         rbln_config.save(save_dir_path)
-
-         return cls._from_pretrained(
-             model_id=save_dir_path,
-             config=config,
-             model_save_dir=save_dir,
-             **rbln_constructor_kwargs,
-             **kwargs,
-         )
-
      @classmethod
      def _get_rbln_config(
          cls,
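Both `get_pytorch_model` overrides replace the old hard `ValueError` with a logged warning when `rbln_max_seq_len` exceeds the checkpoint's `max_position_embeddings`, so oversized values now proceed at the caller's risk. Condensed into a self-contained sketch (the helper name and config values are made up; the guard logic mirrors the diff):

    import logging

    logger = logging.getLogger(__name__)

    def check_max_seq_len(rbln_config_kwargs: dict, hf_max_position_embeddings: int) -> None:
        # Warn instead of raising, as in the 0.1.7 code above.
        if rbln_max_seq_len := rbln_config_kwargs.get("rbln_max_seq_len"):
            if hf_max_position_embeddings < rbln_max_seq_len:
                logger.warning(
                    f"`rbln_max_seq_len` is larger than original config({hf_max_position_embeddings})."
                    "This may lead to incorrect inferences of the model."
                )

    check_max_seq_len({"rbln_max_seq_len": 8192}, hf_max_position_embeddings=4096)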
@@ -345,11 +309,14 @@ class RBLNMidmLMHeadModel(RBLNBaseModel, RBLNGenerationMixin):

          return rbln_config

-     def _create_runtimes(self, rbln_device_map: Dict[str, int]) -> List[rebel.Runtime]:
+     @classmethod
+     def _create_runtimes(
+         cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
+     ) -> List[rebel.Runtime]:
          device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
          return [
-             self.compiled_models[0].create_runtime(input_info_index=0, tensor_type="pt", device=device_val),
-             self.compiled_models[0].create_runtime(input_info_index=1, tensor_type="pt", device=device_val),
+             compiled_models[0].create_runtime(input_info_index=0, tensor_type="pt", device=device_val),
+             compiled_models[0].create_runtime(input_info_index=1, tensor_type="pt", device=device_val),
          ]

      def prepare_inputs_for_generation(self, input_ids, past_key_values=0, attention_mask=None, **kwargs):
@@ -421,6 +388,3 @@ class RBLNMidmLMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
              cache_position=cache_position,
          )
          return CausalLMOutputWithCrossAttentions(logits=output, past_key_values=past_cached_length)
-
-     def __repr__(self):
-         return repr(self.runtimes[0]) + "\n" + repr(self.runtimes[1])