optimum-rbln 0.1.13__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (103)
  1. optimum/rbln/__init__.py +41 -38
  2. optimum/rbln/__version__.py +16 -1
  3. optimum/rbln/diffusers/__init__.py +26 -2
  4. optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +97 -126
  5. optimum/rbln/diffusers/models/__init__.py +36 -3
  6. optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
  7. optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +73 -61
  8. optimum/rbln/diffusers/models/autoencoders/vae.py +83 -0
  9. optimum/rbln/diffusers/models/controlnet.py +54 -14
  10. optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
  11. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
  12. optimum/rbln/diffusers/models/unets/__init__.py +24 -0
  13. optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +82 -22
  14. optimum/rbln/diffusers/pipelines/__init__.py +23 -2
  15. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +13 -33
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
  17. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +18 -2
  18. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -2
  19. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +18 -2
  20. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
  21. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -2
  22. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -13
  23. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  27. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  28. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +24 -0
  29. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +15 -8
  30. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +15 -8
  31. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  32. optimum/rbln/modeling.py +238 -0
  33. optimum/rbln/modeling_base.py +186 -760
  34. optimum/rbln/modeling_config.py +31 -7
  35. optimum/rbln/ops/__init__.py +26 -0
  36. optimum/rbln/ops/attn.py +221 -0
  37. optimum/rbln/ops/flash_attn.py +70 -0
  38. optimum/rbln/ops/kv_cache_update.py +69 -0
  39. optimum/rbln/transformers/__init__.py +20 -2
  40. optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
  41. optimum/rbln/transformers/modeling_generic.py +385 -0
  42. optimum/rbln/transformers/models/auto/__init__.py +23 -0
  43. optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
  44. optimum/rbln/transformers/models/auto/modeling_auto.py +36 -12
  45. optimum/rbln/transformers/models/bart/__init__.py +0 -1
  46. optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
  47. optimum/rbln/transformers/models/bart/modeling_bart.py +10 -9
  48. optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
  49. optimum/rbln/transformers/models/clip/modeling_clip.py +8 -25
  50. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -10
  51. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +775 -514
  52. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +128 -260
  53. optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
  54. optimum/rbln/transformers/models/exaone/exaone_architecture.py +60 -45
  55. optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -2
  56. optimum/rbln/transformers/models/gemma/gemma_architecture.py +33 -104
  57. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
  58. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +3 -2
  59. optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
  60. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -75
  61. optimum/rbln/transformers/models/midm/midm_architecture.py +84 -238
  62. optimum/rbln/transformers/models/midm/modeling_midm.py +5 -6
  63. optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
  64. optimum/rbln/transformers/models/phi/phi_architecture.py +60 -261
  65. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
  66. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -103
  67. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
  68. optimum/rbln/transformers/models/t5/__init__.py +0 -1
  69. optimum/rbln/transformers/models/t5/modeling_t5.py +106 -5
  70. optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
  71. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  72. optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
  73. optimum/rbln/transformers/models/whisper/modeling_whisper.py +78 -55
  74. optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
  75. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
  76. optimum/rbln/transformers/utils/rbln_quantization.py +120 -4
  77. optimum/rbln/utils/decorator_utils.py +51 -11
  78. optimum/rbln/utils/hub.py +131 -0
  79. optimum/rbln/utils/import_utils.py +22 -1
  80. optimum/rbln/utils/logging.py +37 -0
  81. optimum/rbln/utils/model_utils.py +52 -0
  82. optimum/rbln/utils/runtime_utils.py +10 -4
  83. optimum/rbln/utils/save_utils.py +17 -0
  84. optimum/rbln/utils/submodule.py +137 -0
  85. optimum_rbln-0.2.0.dist-info/METADATA +117 -0
  86. optimum_rbln-0.2.0.dist-info/RECORD +114 -0
  87. {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.2.0.dist-info}/WHEEL +1 -1
  88. optimum_rbln-0.2.0.dist-info/licenses/LICENSE +288 -0
  89. optimum/rbln/transformers/cache_utils.py +0 -107
  90. optimum/rbln/transformers/generation/streamers.py +0 -139
  91. optimum/rbln/transformers/generation/utils.py +0 -397
  92. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
  93. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
  94. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
  95. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
  96. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
  97. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
  98. optimum/rbln/utils/context.py +0 -58
  99. optimum/rbln/utils/timer_utils.py +0 -43
  100. optimum_rbln-0.1.13.dist-info/METADATA +0 -120
  101. optimum_rbln-0.1.13.dist-info/RECORD +0 -107
  102. optimum_rbln-0.1.13.dist-info/entry_points.txt +0 -4
  103. optimum_rbln-0.1.13.dist-info/licenses/LICENSE +0 -201
@@ -0,0 +1,31 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ from diffusers import StableDiffusion3Img2ImgPipeline
25
+
26
+ from ...modeling_diffusers import RBLNDiffusionMixin
27
+
28
+
29
class RBLNStableDiffusion3Img2ImgPipeline(RBLNDiffusionMixin, StableDiffusion3Img2ImgPipeline):
    """RBLN variant of diffusers' ``StableDiffusion3Img2ImgPipeline``.

    Combines ``RBLNDiffusionMixin`` with the upstream image-to-image pipeline class.
    """

    # Component names handed to RBLNDiffusionMixin; presumably these are the
    # pipeline parts the mixin compiles/loads for RBLN (verify in modeling_diffusers).
    # The order is kept exactly as originally listed.
    _submodules = [
        "transformer",
        "text_encoder_3",
        "text_encoder",
        "text_encoder_2",
        "vae",
    ]

    # The upstream diffusers pipeline class this wrapper mirrors.
    original_class = StableDiffusion3Img2ImgPipeline
@@ -0,0 +1,31 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ from diffusers import StableDiffusion3InpaintPipeline
25
+
26
+ from ...modeling_diffusers import RBLNDiffusionMixin
27
+
28
+
29
class RBLNStableDiffusion3InpaintPipeline(RBLNDiffusionMixin, StableDiffusion3InpaintPipeline):
    """RBLN variant of diffusers' ``StableDiffusion3InpaintPipeline``.

    Combines ``RBLNDiffusionMixin`` with the upstream inpainting pipeline class.
    """

    # Component names handed to RBLNDiffusionMixin; presumably these are the
    # pipeline parts the mixin compiles/loads for RBLN (verify in modeling_diffusers).
    # The order is kept exactly as originally listed.
    _submodules = [
        "transformer",
        "text_encoder_3",
        "text_encoder",
        "text_encoder_2",
        "vae",
    ]

    # The upstream diffusers pipeline class this wrapper mirrors.
    original_class = StableDiffusion3InpaintPipeline
@@ -1,2 +1,26 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
1
24
  from .pipeline_stable_diffusion_xl import RBLNStableDiffusionXLPipeline
2
25
  from .pipeline_stable_diffusion_xl_img2img import RBLNStableDiffusionXLImg2ImgPipeline
26
+ from .pipeline_stable_diffusion_xl_inpaint import RBLNStableDiffusionXLInpaintPipeline
@@ -1,22 +1,29 @@
1
- # coding=utf-8
2
- # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
- #
1
+ # Copyright 2024 Rebellions Inc.
2
+
4
3
  # Licensed under the Apache License, Version 2.0 (the "License");
5
4
  # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
5
+ # You may obtain a copy of the License at:
6
+
8
7
  # http://www.apache.org/licenses/LICENSE-2.0
9
- #
8
+
10
9
  # Unless required by applicable law or agreed to in writing, software
11
10
  # distributed under the License is distributed on an "AS IS" BASIS,
12
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
12
  # See the License for the specific language governing permissions and
14
13
  # limitations under the License.
15
- """RBLNStableDiffusionXLPipeline class for inference of diffusion models on rbln devices."""
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
16
23
 
17
24
  from diffusers import StableDiffusionXLPipeline
18
25
 
19
- from ....modeling_diffusers import RBLNDiffusionMixin
26
+ from ...modeling_diffusers import RBLNDiffusionMixin
20
27
 
21
28
 
22
29
  class RBLNStableDiffusionXLPipeline(RBLNDiffusionMixin, StableDiffusionXLPipeline):
@@ -1,22 +1,29 @@
1
- # coding=utf-8
2
- # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
- #
1
+ # Copyright 2024 Rebellions Inc.
2
+
4
3
  # Licensed under the Apache License, Version 2.0 (the "License");
5
4
  # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
5
+ # You may obtain a copy of the License at:
6
+
8
7
  # http://www.apache.org/licenses/LICENSE-2.0
9
- #
8
+
10
9
  # Unless required by applicable law or agreed to in writing, software
11
10
  # distributed under the License is distributed on an "AS IS" BASIS,
12
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
12
  # See the License for the specific language governing permissions and
14
13
  # limitations under the License.
15
- """RBLNStableDiffusionXLPipeline class for inference of diffusion models on rbln devices."""
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
16
23
 
17
24
  from diffusers import StableDiffusionXLImg2ImgPipeline
18
25
 
19
- from ....modeling_diffusers import RBLNDiffusionMixin
26
+ from ...modeling_diffusers import RBLNDiffusionMixin
20
27
 
21
28
 
22
29
  class RBLNStableDiffusionXLImg2ImgPipeline(RBLNDiffusionMixin, StableDiffusionXLImg2ImgPipeline):
@@ -0,0 +1,31 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ from diffusers import StableDiffusionXLInpaintPipeline
25
+
26
+ from ...modeling_diffusers import RBLNDiffusionMixin
27
+
28
+
29
class RBLNStableDiffusionXLInpaintPipeline(RBLNDiffusionMixin, StableDiffusionXLInpaintPipeline):
    """RBLN variant of diffusers' ``StableDiffusionXLInpaintPipeline``.

    Combines ``RBLNDiffusionMixin`` with the upstream SDXL inpainting pipeline class.
    """

    # Component names handed to RBLNDiffusionMixin; presumably these are the
    # pipeline parts the mixin compiles/loads for RBLN (verify in modeling_diffusers).
    # The order is kept exactly as originally listed.
    _submodules = [
        "text_encoder",
        "text_encoder_2",
        "unet",
        "vae",
    ]

    # The upstream diffusers pipeline class this wrapper mirrors.
    original_class = StableDiffusionXLInpaintPipeline
optimum/rbln/modeling.py CHANGED
@@ -0,0 +1,238 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ import logging
25
+ from pathlib import Path
26
+ from tempfile import TemporaryDirectory
27
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
28
+
29
+ import rebel
30
+ import torch
31
+ from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
32
+ from transformers import AutoConfig, PretrainedConfig
33
+
34
+ from .modeling_base import RBLNBaseModel
35
+ from .modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, use_rbln_config
36
+
37
+
38
+ if TYPE_CHECKING:
39
+ from transformers import PreTrainedModel
40
+
41
+
42
+ logger = logging.getLogger(__name__)
43
+
44
+
45
class RBLNModel(RBLNBaseModel):
    """
    A class that inherits from RBLNBaseModel for models consisting of a single `torch.nn.Module`.

    This class supports all the functionality of RBLNBaseModel, including loading and saving models using
    the `from_pretrained` and `save_pretrained` methods, compiling PyTorch models for execution on RBLN NPU
    devices.

    Example:
    ```python
    model = RBLNModel.from_pretrained("model_id", export=True, rbln_npu="npu_name")
    outputs = model(**inputs)
    ```
    """

    @classmethod
    def update_kwargs(cls, kwargs):
        """
        Update user-given kwargs to get a proper pytorch model.

        `torchscript=True` and `return_dict=False` are forced because torch.jit
        does not support `transformers` output instances as module outputs.

        Args:
            kwargs: Keyword arguments destined for `hf_class.from_pretrained`.
                The dict is modified in place.

        Returns:
            The same `kwargs` dict with `torchscript` and `return_dict` overwritten.
        """
        kwargs.update(
            {
                "torchscript": True,
                "return_dict": False,
            }
        )
        return kwargs

    @classmethod
    def save_torch_artifacts(
        cls,
        model: "PreTrainedModel",
        save_dir_path: Path,
        subfolder: str,
        rbln_config: RBLNConfig,
    ):
        """
        Hook for saving torch-side artifacts next to the compiled model.

        If part of the model unavoidably runs on a CPU rather than an RBLN device,
        store the torch tensors, weights, etc. in this function. The base
        implementation does nothing; subclasses override it when needed.
        """

    @classmethod
    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
        """Return the module to be compiled. The base implementation returns `model` unchanged."""
        return model

    @classmethod
    def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
        """
        Compile `model` using the first compile config in `rbln_config.compile_cfgs`.

        The model is first passed through `wrap_model_if_needed`. The result of
        `cls.compile` is returned as-is.
        """
        model = cls.wrap_model_if_needed(model, rbln_config)
        rbln_compile_config = rbln_config.compile_cfgs[0]
        compiled_model = cls.compile(model, rbln_compile_config=rbln_compile_config)
        return compiled_model

    @classmethod
    @use_rbln_config
    def from_model(
        cls,
        model: "PreTrainedModel",
        config: Optional[PretrainedConfig] = None,
        rbln_config: Optional[Dict[str, Any]] = None,
        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
        subfolder: str = "",
        **kwargs,
    ):
        """
        Compile an in-memory torch model for RBLN and return an instance of `cls`.

        The following are written into `model_save_dir / subfolder`:
        - the model config;
        - the generation config, when the model can generate;
        - any preprocessors;
        - the compiled `.rbln` files;
        - the RBLN config;
        - any torch artifacts.

        Submodules are then loaded if `cls._rbln_submodules` is non-empty, and the
        instance is created via `_from_pretrained`.

        Args:
            model: The torch model to compile.
            config: Model config. Defaults to `model.config`. Configs with an
                `AutoConfig` entry in `auto_map` (remote code) are reloaded with
                `AutoConfig.from_pretrained`. Non-`PretrainedConfig` (diffusers)
                configs are wrapped in a `PretrainedConfig`.
            rbln_config: Raw dict of user RBLN options. Presumably populated by the
                `use_rbln_config` decorator (verify there). `None` means no options.
            model_save_dir: Where to save artifacts. `None` creates a
                `TemporaryDirectory`. A str/Path is created if missing.
            subfolder: Subdirectory inside `model_save_dir` for this model's files.
            **kwargs: `preprocessors` (list) is popped and saved. The rest is
                forwarded to `AutoConfig.from_pretrained` (remote configs only),
                `_load_submodules`, and `_from_pretrained`.
        """
        preprocessors = kwargs.pop("preprocessors", [])
        # `rbln_config` is rebound below to the resolved RBLNConfig, so keep the raw
        # user dict under its own name. A fresh dict per call avoids sharing state
        # across calls through a mutable default.
        rbln_kwargs = rbln_config if rbln_config is not None else {}

        # Directory to save compile artifacts (.rbln) and original configs.
        if model_save_dir is None:
            save_dir = TemporaryDirectory()
            save_dir_path = Path(save_dir.name)
        else:
            save_dir = model_save_dir
            if isinstance(save_dir, TemporaryDirectory):
                save_dir_path = Path(model_save_dir.name)
            else:
                save_dir_path = Path(model_save_dir)
            # `parents=True` so nested, not-yet-existing save paths work too.
            save_dir_path.mkdir(parents=True, exist_ok=True)
        (save_dir_path / subfolder).mkdir(parents=True, exist_ok=True)

        # Save configs
        if config is None:
            config = model.config
            # remote_config
            if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map:
                config = AutoConfig.from_pretrained(config._name_or_path, **kwargs)

        if hasattr(model, "can_generate") and model.can_generate():
            generation_config = model.generation_config
            generation_config.save_pretrained(save_dir_path / subfolder)

        if not isinstance(config, PretrainedConfig):  # diffusers config
            config = PretrainedConfig(**config)
        config.save_pretrained(save_dir_path / subfolder)

        # Save preprocessor
        for preprocessor in preprocessors:
            preprocessor.save_pretrained(save_dir_path / subfolder)

        # Get compilation arguments (e.g. input_info).
        # get_rbln_config also applies `rbln_kwargs` to the runtime config.
        rbln_config: RBLNConfig = cls.get_rbln_config(
            preprocessors=preprocessors, model_config=config, rbln_kwargs=rbln_kwargs
        )

        compiled_model: Union[rebel.RBLNCompiledModel, Dict[str, rebel.RBLNCompiledModel]] = cls.get_compiled_model(
            model, rbln_config=rbln_config
        )

        # Save compiled models (.rbln); a single model is stored under the default name.
        if not isinstance(compiled_model, dict):
            compiled_models = {DEFAULT_COMPILED_MODEL_NAME: compiled_model}
        else:
            compiled_models = compiled_model
        for compiled_model_name, cm in compiled_models.items():
            cm.save(save_dir_path / subfolder / f"{compiled_model_name}.rbln")
        rbln_config.save(save_dir_path / subfolder)

        # Save torch artifacts (e.g. embedding matrix if needed.)
        cls.save_torch_artifacts(model, save_dir_path=save_dir_path, subfolder=subfolder, rbln_config=rbln_config)

        # Load submodules
        if len(cls._rbln_submodules) > 0:
            rbln_submodules = cls._load_submodules(
                model=model,
                model_save_dir=save_dir,
                rbln_kwargs=rbln_kwargs,
                **kwargs,
            )
        else:
            rbln_submodules = []

        # Instantiate
        return cls._from_pretrained(
            model_id=save_dir_path,
            config=config,
            model_save_dir=save_dir,
            subfolder=subfolder,
            rbln_config=rbln_config,
            rbln_compiled_models=compiled_models,
            rbln_submodules=rbln_submodules,
            **kwargs,
        )

    @classmethod
    def get_pytorch_model(
        cls,
        model_id: str,
        use_auth_token: Optional[Union[bool, str]] = None,
        revision: Optional[str] = None,
        force_download: bool = False,
        cache_dir: Optional[str] = HUGGINGFACE_HUB_CACHE,
        subfolder: str = "",
        local_files_only: bool = False,
        trust_remote_code: bool = False,
        # Some rbln-kwargs should be applied before loading torch module (i.e. quantized llm)
        rbln_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> "PreTrainedModel":
        """
        Load the torch model via `cls.hf_class.from_pretrained`.

        `kwargs` are first passed through `update_kwargs`, which forces
        `torchscript=True` and `return_dict=False`. `rbln_kwargs` is not used by this
        base implementation; it is accepted so subclasses can use it.
        """
        kwargs = cls.update_kwargs(kwargs)
        return cls.hf_class.from_pretrained(
            model_id,
            subfolder=subfolder,
            revision=revision,
            cache_dir=cache_dir,
            use_auth_token=use_auth_token,
            local_files_only=local_files_only,
            force_download=force_download,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )

    @classmethod
    def _create_runtimes(
        cls,
        compiled_models: List[rebel.RBLNCompiledModel],
        rbln_device_map: Dict[str, int],
        activate_profiler: Optional[bool] = None,
    ) -> List[rebel.Runtime]:
        """
        Create one torch-tensor runtime per compiled model.

        All runtimes are placed on `rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]`. If
        that key is missing, `_raise_missing_compiled_file_error` is called
        (presumably it raises; see modeling_base).
        """
        if DEFAULT_COMPILED_MODEL_NAME not in rbln_device_map:
            cls._raise_missing_compiled_file_error([DEFAULT_COMPILED_MODEL_NAME])

        device = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
        return [
            compiled_model.create_runtime(tensor_type="pt", device=device, activate_profiler=activate_profiler)
            for compiled_model in compiled_models
        ]

    def forward(self, *args: torch.Tensor, **kwargs: torch.Tensor):
        """
        Run the first entry of `self.model` with the given inputs and return its output.

        `self.model` is presumably the runtime list built by `_create_runtimes`
        (verify in RBLNBaseModel).
        """
        output = self.model[0](*args, **kwargs)
        return output