optimum-rbln 0.1.13 (py3-none-any.whl) → 0.2.0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (103):
  1. optimum/rbln/__init__.py +41 -38
  2. optimum/rbln/__version__.py +16 -1
  3. optimum/rbln/diffusers/__init__.py +26 -2
  4. optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +97 -126
  5. optimum/rbln/diffusers/models/__init__.py +36 -3
  6. optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
  7. optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +73 -61
  8. optimum/rbln/diffusers/models/autoencoders/vae.py +83 -0
  9. optimum/rbln/diffusers/models/controlnet.py +54 -14
  10. optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
  11. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
  12. optimum/rbln/diffusers/models/unets/__init__.py +24 -0
  13. optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +82 -22
  14. optimum/rbln/diffusers/pipelines/__init__.py +23 -2
  15. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +13 -33
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
  17. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +18 -2
  18. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -2
  19. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +18 -2
  20. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
  21. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -2
  22. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -13
  23. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  27. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  28. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +24 -0
  29. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +15 -8
  30. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +15 -8
  31. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  32. optimum/rbln/modeling.py +238 -0
  33. optimum/rbln/modeling_base.py +186 -760
  34. optimum/rbln/modeling_config.py +31 -7
  35. optimum/rbln/ops/__init__.py +26 -0
  36. optimum/rbln/ops/attn.py +221 -0
  37. optimum/rbln/ops/flash_attn.py +70 -0
  38. optimum/rbln/ops/kv_cache_update.py +69 -0
  39. optimum/rbln/transformers/__init__.py +20 -2
  40. optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
  41. optimum/rbln/transformers/modeling_generic.py +385 -0
  42. optimum/rbln/transformers/models/auto/__init__.py +23 -0
  43. optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
  44. optimum/rbln/transformers/models/auto/modeling_auto.py +36 -12
  45. optimum/rbln/transformers/models/bart/__init__.py +0 -1
  46. optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
  47. optimum/rbln/transformers/models/bart/modeling_bart.py +10 -9
  48. optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
  49. optimum/rbln/transformers/models/clip/modeling_clip.py +8 -25
  50. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -10
  51. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +775 -514
  52. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +128 -260
  53. optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
  54. optimum/rbln/transformers/models/exaone/exaone_architecture.py +60 -45
  55. optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -2
  56. optimum/rbln/transformers/models/gemma/gemma_architecture.py +33 -104
  57. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
  58. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +3 -2
  59. optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
  60. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -75
  61. optimum/rbln/transformers/models/midm/midm_architecture.py +84 -238
  62. optimum/rbln/transformers/models/midm/modeling_midm.py +5 -6
  63. optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
  64. optimum/rbln/transformers/models/phi/phi_architecture.py +60 -261
  65. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
  66. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -103
  67. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
  68. optimum/rbln/transformers/models/t5/__init__.py +0 -1
  69. optimum/rbln/transformers/models/t5/modeling_t5.py +106 -5
  70. optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
  71. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  72. optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
  73. optimum/rbln/transformers/models/whisper/modeling_whisper.py +78 -55
  74. optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
  75. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
  76. optimum/rbln/transformers/utils/rbln_quantization.py +120 -4
  77. optimum/rbln/utils/decorator_utils.py +51 -11
  78. optimum/rbln/utils/hub.py +131 -0
  79. optimum/rbln/utils/import_utils.py +22 -1
  80. optimum/rbln/utils/logging.py +37 -0
  81. optimum/rbln/utils/model_utils.py +52 -0
  82. optimum/rbln/utils/runtime_utils.py +10 -4
  83. optimum/rbln/utils/save_utils.py +17 -0
  84. optimum/rbln/utils/submodule.py +137 -0
  85. optimum_rbln-0.2.0.dist-info/METADATA +117 -0
  86. optimum_rbln-0.2.0.dist-info/RECORD +114 -0
  87. {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.2.0.dist-info}/WHEEL +1 -1
  88. optimum_rbln-0.2.0.dist-info/licenses/LICENSE +288 -0
  89. optimum/rbln/transformers/cache_utils.py +0 -107
  90. optimum/rbln/transformers/generation/streamers.py +0 -139
  91. optimum/rbln/transformers/generation/utils.py +0 -397
  92. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
  93. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
  94. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
  95. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
  96. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
  97. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
  98. optimum/rbln/utils/context.py +0 -58
  99. optimum/rbln/utils/timer_utils.py +0 -43
  100. optimum_rbln-0.1.13.dist-info/METADATA +0 -120
  101. optimum_rbln-0.1.13.dist-info/RECORD +0 -107
  102. optimum_rbln-0.1.13.dist-info/entry_points.txt +0 -4
  103. optimum_rbln-0.1.13.dist-info/licenses/LICENSE +0 -201
optimum/rbln/__init__.py CHANGED
@@ -30,27 +30,15 @@ from .utils import check_version_compats
30
30
 
31
31
 
32
32
  _import_structure = {
33
- "modeling_alias": [
34
- "RBLNASTForAudioClassification",
35
- "RBLNBertForQuestionAnswering",
36
- "RBLNDistilBertForQuestionAnswering",
37
- "RBLNResNetForImageClassification",
38
- "RBLNXLMRobertaForSequenceClassification",
39
- "RBLNRobertaForSequenceClassification",
40
- "RBLNRobertaForMaskedLM",
41
- "RBLNViTForImageClassification",
42
- ],
43
- "modeling_base": [
33
+ "modeling": [
44
34
  "RBLNBaseModel",
45
35
  "RBLNModel",
46
- "RBLNModelForQuestionAnswering",
47
- "RBLNModelForAudioClassification",
48
- "RBLNModelForImageClassification",
49
- "RBLNModelForSequenceClassification",
50
- "RBLNModelForMaskedLM",
36
+ ],
37
+ "modeling_config": [
38
+ "RBLNCompileConfig",
39
+ "RBLNConfig",
51
40
  ],
52
41
  "transformers": [
53
- "BatchTextIteratorStreamer",
54
42
  "RBLNAutoModel",
55
43
  "RBLNAutoModelForAudioClassification",
56
44
  "RBLNAutoModelForCausalLM",
@@ -84,6 +72,14 @@ _import_structure = {
84
72
  "RBLNMistralForCausalLM",
85
73
  "RBLNWhisperForConditionalGeneration",
86
74
  "RBLNXLMRobertaModel",
75
+ "RBLNASTForAudioClassification",
76
+ "RBLNBertForQuestionAnswering",
77
+ "RBLNDistilBertForQuestionAnswering",
78
+ "RBLNResNetForImageClassification",
79
+ "RBLNXLMRobertaForSequenceClassification",
80
+ "RBLNRobertaForSequenceClassification",
81
+ "RBLNRobertaForMaskedLM",
82
+ "RBLNViTForImageClassification",
87
83
  ],
88
84
  "diffusers": [
89
85
  "RBLNStableDiffusionPipeline",
@@ -92,55 +88,54 @@ _import_structure = {
92
88
  "RBLNUNet2DConditionModel",
93
89
  "RBLNControlNetModel",
94
90
  "RBLNStableDiffusionImg2ImgPipeline",
91
+ "RBLNStableDiffusionInpaintPipeline",
95
92
  "RBLNStableDiffusionControlNetImg2ImgPipeline",
96
93
  "RBLNMultiControlNetModel",
97
94
  "RBLNStableDiffusionXLImg2ImgPipeline",
95
+ "RBLNStableDiffusionXLInpaintPipeline",
98
96
  "RBLNStableDiffusionControlNetPipeline",
99
97
  "RBLNStableDiffusionXLControlNetPipeline",
100
98
  "RBLNStableDiffusionXLControlNetImg2ImgPipeline",
99
+ "RBLNSD3Transformer2DModel",
100
+ "RBLNStableDiffusion3Img2ImgPipeline",
101
+ "RBLNStableDiffusion3InpaintPipeline",
102
+ "RBLNStableDiffusion3Pipeline",
103
+ "RBLNDiffusionMixin",
101
104
  ],
102
- "modeling_config": ["RBLNCompileConfig", "RBLNConfig"],
103
- "modeling_diffusers": ["RBLNDiffusionMixin"],
104
105
  }
105
106
 
106
107
  if TYPE_CHECKING:
107
108
  from .diffusers import (
108
109
  RBLNAutoencoderKL,
109
110
  RBLNControlNetModel,
111
+ RBLNDiffusionMixin,
110
112
  RBLNMultiControlNetModel,
113
+ RBLNSD3Transformer2DModel,
114
+ RBLNStableDiffusion3Img2ImgPipeline,
115
+ RBLNStableDiffusion3InpaintPipeline,
116
+ RBLNStableDiffusion3Pipeline,
111
117
  RBLNStableDiffusionControlNetImg2ImgPipeline,
112
118
  RBLNStableDiffusionControlNetPipeline,
113
119
  RBLNStableDiffusionImg2ImgPipeline,
120
+ RBLNStableDiffusionInpaintPipeline,
114
121
  RBLNStableDiffusionPipeline,
115
122
  RBLNStableDiffusionXLControlNetImg2ImgPipeline,
116
123
  RBLNStableDiffusionXLControlNetPipeline,
117
124
  RBLNStableDiffusionXLImg2ImgPipeline,
125
+ RBLNStableDiffusionXLInpaintPipeline,
118
126
  RBLNStableDiffusionXLPipeline,
119
127
  RBLNUNet2DConditionModel,
120
128
  )
121
- from .modeling_alias import (
122
- RBLNASTForAudioClassification,
123
- RBLNBertForQuestionAnswering,
124
- RBLNResNetForImageClassification,
125
- RBLNRobertaForMaskedLM,
126
- RBLNRobertaForSequenceClassification,
127
- RBLNT5ForConditionalGeneration,
128
- RBLNViTForImageClassification,
129
- RBLNXLMRobertaForSequenceClassification,
130
- )
131
- from .modeling_base import (
129
+ from .modeling import (
132
130
  RBLNBaseModel,
133
131
  RBLNModel,
134
- RBLNModelForAudioClassification,
135
- RBLNModelForImageClassification,
136
- RBLNModelForMaskedLM,
137
- RBLNModelForQuestionAnswering,
138
- RBLNModelForSequenceClassification,
139
132
  )
140
- from .modeling_config import RBLNCompileConfig, RBLNConfig
141
- from .modeling_diffusers import RBLNDiffusionMixin
133
+ from .modeling_config import (
134
+ RBLNCompileConfig,
135
+ RBLNConfig,
136
+ )
142
137
  from .transformers import (
143
- BatchTextIteratorStreamer,
138
+ RBLNASTForAudioClassification,
144
139
  RBLNAutoModel,
145
140
  RBLNAutoModelForAudioClassification,
146
141
  RBLNAutoModelForCausalLM,
@@ -155,10 +150,12 @@ if TYPE_CHECKING:
155
150
  RBLNAutoModelForVision2Seq,
156
151
  RBLNBartForConditionalGeneration,
157
152
  RBLNBartModel,
153
+ RBLNBertForQuestionAnswering,
158
154
  RBLNBertModel,
159
155
  RBLNCLIPTextModel,
160
156
  RBLNCLIPTextModelWithProjection,
161
157
  RBLNCLIPVisionModel,
158
+ RBLNDistilBertForQuestionAnswering,
162
159
  RBLNDPTForDepthEstimation,
163
160
  RBLNExaoneForCausalLM,
164
161
  RBLNGemmaForCausalLM,
@@ -169,12 +166,18 @@ if TYPE_CHECKING:
169
166
  RBLNMistralForCausalLM,
170
167
  RBLNPhiForCausalLM,
171
168
  RBLNQwen2ForCausalLM,
169
+ RBLNResNetForImageClassification,
170
+ RBLNRobertaForMaskedLM,
171
+ RBLNRobertaForSequenceClassification,
172
172
  RBLNT5EncoderModel,
173
173
  RBLNT5ForConditionalGeneration,
174
+ RBLNViTForImageClassification,
174
175
  RBLNWav2Vec2ForCTC,
175
176
  RBLNWhisperForConditionalGeneration,
177
+ RBLNXLMRobertaForSequenceClassification,
176
178
  RBLNXLMRobertaModel,
177
179
  )
180
+
178
181
  else:
179
182
  import sys
180
183
 
@@ -1 +1,16 @@
1
- __version__ = '0.1.13'
1
+ # file generated by setuptools_scm
2
+ # don't change, don't track in version control
3
+ TYPE_CHECKING = False
4
+ if TYPE_CHECKING:
5
+ from typing import Tuple, Union
6
+ VERSION_TUPLE = Tuple[Union[int, str], ...]
7
+ else:
8
+ VERSION_TUPLE = object
9
+
10
+ version: str
11
+ __version__: str
12
+ __version_tuple__: VERSION_TUPLE
13
+ version_tuple: VERSION_TUPLE
14
+
15
+ __version__ = version = '0.2.0'
16
+ __version_tuple__ = version_tuple = (0, 2, 0)
@@ -36,27 +36,51 @@ _import_structure = {
36
36
  "RBLNStableDiffusionPipeline",
37
37
  "RBLNStableDiffusionXLPipeline",
38
38
  "RBLNStableDiffusionImg2ImgPipeline",
39
+ "RBLNStableDiffusionInpaintPipeline",
39
40
  "RBLNStableDiffusionControlNetImg2ImgPipeline",
40
41
  "RBLNMultiControlNetModel",
41
42
  "RBLNStableDiffusionXLImg2ImgPipeline",
43
+ "RBLNStableDiffusionXLInpaintPipeline",
42
44
  "RBLNStableDiffusionControlNetPipeline",
43
45
  "RBLNStableDiffusionXLControlNetPipeline",
44
46
  "RBLNStableDiffusionXLControlNetImg2ImgPipeline",
47
+ "RBLNStableDiffusion3Pipeline",
48
+ "RBLNStableDiffusion3Img2ImgPipeline",
49
+ "RBLNStableDiffusion3InpaintPipeline",
50
+ ],
51
+ "models": [
52
+ "RBLNAutoencoderKL",
53
+ "RBLNUNet2DConditionModel",
54
+ "RBLNControlNetModel",
55
+ "RBLNSD3Transformer2DModel",
56
+ ],
57
+ "modeling_diffusers": [
58
+ "RBLNDiffusionMixin",
45
59
  ],
46
- "models": ["RBLNAutoencoderKL", "RBLNUNet2DConditionModel", "RBLNControlNetModel"],
47
60
  }
48
61
 
49
62
  if TYPE_CHECKING:
50
- from .models import RBLNAutoencoderKL, RBLNControlNetModel, RBLNUNet2DConditionModel
63
+ from .modeling_diffusers import RBLNDiffusionMixin
64
+ from .models import (
65
+ RBLNAutoencoderKL,
66
+ RBLNControlNetModel,
67
+ RBLNSD3Transformer2DModel,
68
+ RBLNUNet2DConditionModel,
69
+ )
51
70
  from .pipelines import (
52
71
  RBLNMultiControlNetModel,
72
+ RBLNStableDiffusion3Img2ImgPipeline,
73
+ RBLNStableDiffusion3InpaintPipeline,
74
+ RBLNStableDiffusion3Pipeline,
53
75
  RBLNStableDiffusionControlNetImg2ImgPipeline,
54
76
  RBLNStableDiffusionControlNetPipeline,
55
77
  RBLNStableDiffusionImg2ImgPipeline,
78
+ RBLNStableDiffusionInpaintPipeline,
56
79
  RBLNStableDiffusionPipeline,
57
80
  RBLNStableDiffusionXLControlNetImg2ImgPipeline,
58
81
  RBLNStableDiffusionXLControlNetPipeline,
59
82
  RBLNStableDiffusionXLImg2ImgPipeline,
83
+ RBLNStableDiffusionXLInpaintPipeline,
60
84
  RBLNStableDiffusionXLPipeline,
61
85
  )
62
86
  else:
@@ -20,16 +20,21 @@
20
20
  # are the intellectual property of Rebellions Inc. and may not be
21
21
  # copied, modified, or distributed without prior written permission
22
22
  # from Rebellions Inc.
23
+
24
+ import copy
23
25
  import importlib
24
26
  from os import PathLike
25
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
27
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
26
28
 
27
29
  import torch
28
30
 
29
- from .modeling_base import RBLNModel
30
- from .modeling_config import ContextRblnConfig, use_rbln_config
31
- from .utils.decorator_utils import remove_compile_time_kwargs
31
+ from ..modeling import RBLNModel
32
+ from ..modeling_config import RUNTIME_KEYWORDS, ContextRblnConfig, use_rbln_config
33
+ from ..utils.decorator_utils import remove_compile_time_kwargs
34
+ from ..utils.logging import get_logger
35
+
32
36
 
37
+ logger = get_logger(__name__)
33
38
 
34
39
  if TYPE_CHECKING:
35
40
  from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
@@ -74,127 +79,40 @@ class RBLNDiffusionMixin:
74
79
 
75
80
  @classmethod
76
81
  @property
77
- def use_encode(cls):
82
+ def img2img_pipeline(cls):
78
83
  return "Img2Img" in cls.__name__
79
84
 
80
85
  @classmethod
81
- def _get_unet_batch_size(cls, model: torch.nn.Module, rbln_config: Dict[str, Any]) -> int:
82
- # Calculates the batch size based on guidance scale
83
- batch_size = rbln_config.get("batch_size", 1)
84
- do_guidance = rbln_config.get("guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
85
- return batch_size * 2 if do_guidance else batch_size
86
-
87
- @classmethod
88
- def _get_vae_sample_size(cls, model: torch.nn.Module, rbln_config: Dict[str, Any]) -> Union[int, Tuple[int, int]]:
89
- image_size = (rbln_config.get("img_height"), rbln_config.get("img_width"))
90
- if (image_size[0] is None) != (image_size[1] is None):
91
- raise ValueError("Both image height and image width must be given or not given")
92
- elif image_size[0] is None and image_size[1] is None:
93
- if cls.use_encode:
94
- sample_size = model.vae.config.sample_size
95
- else:
96
- # In case of text2img, sample size of vae decoder is determined by unet.
97
- unet_sample_size = model.unet.config.sample_size
98
- if isinstance(unet_sample_size, int):
99
- sample_size = unet_sample_size * model.vae_scale_factor
100
- else:
101
- sample_size = (
102
- unet_sample_size[0] * model.vae_scale_factor,
103
- unet_sample_size[1] * model.vae_scale_factor,
104
- )
105
-
106
- else:
107
- sample_size = (image_size[0], image_size[1])
108
- return sample_size
109
-
110
- @classmethod
111
- def _get_unet_sample_size(cls, model: torch.nn.Module, rbln_config: Dict[str, Any]) -> Union[int, Tuple[int, int]]:
112
- image_size = (rbln_config.get("img_height"), rbln_config.get("img_width"))
113
- if (image_size[0] is None) != (image_size[1] is None):
114
- raise ValueError("Both image height and image width must be given or not given")
115
- elif image_size[0] is None and image_size[1] is None:
116
- if cls.use_encode:
117
- # In case of img2img, sample size of unet is determined by vae encoder.
118
- vae_sample_size = model.vae.config.sample_size
119
- if isinstance(vae_sample_size, int):
120
- sample_size = vae_sample_size // model.vae_scale_factor
121
- else:
122
- sample_size = (
123
- vae_sample_size[0] // model.vae_scale_factor,
124
- vae_sample_size[1] // model.vae_scale_factor,
125
- )
126
- else:
127
- sample_size = model.unet.config.sample_size
128
- else:
129
- sample_size = (image_size[0] // model.vae_scale_factor, image_size[1] // model.vae_scale_factor)
130
- return sample_size
131
-
132
- @classmethod
133
- def _get_default_config(cls, model: torch.nn.Module, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
134
- # default configurations for each submodules
135
- return {"img2img_pipeline": cls.use_encode}
86
+ @property
87
+ def inpaint_pipeline(cls):
88
+ return "Inpaint" in cls.__name__
136
89
 
137
90
  @classmethod
138
- def get_default_rbln_config_text_encoder(
139
- cls, model: torch.nn.Module, rbln_config: Dict[str, Any]
91
+ def get_submodule_rbln_config(
92
+ cls, model: torch.nn.Module, submodule_name: str, rbln_config: Dict[str, Any]
140
93
  ) -> Dict[str, Any]:
141
- batch_size = rbln_config.get("batch_size", 1)
142
- return {"batch_size": batch_size}
94
+ submodule = getattr(model, submodule_name)
95
+ submodule_class_name = submodule.__class__.__name__
143
96
 
144
- @classmethod
145
- def get_default_rbln_config_text_encoder_2(
146
- cls, model: torch.nn.Module, rbln_config: Dict[str, Any]
147
- ) -> Dict[str, Any]:
148
- batch_size = rbln_config.get("batch_size", 1)
149
- return {"batch_size": batch_size}
97
+ if submodule_class_name == "MultiControlNetModel":
98
+ submodule_class_name = "ControlNetModel"
150
99
 
151
- @classmethod
152
- def get_default_rbln_config_unet(cls, model: torch.nn.Module, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
153
- # configuration for unet
154
- unet_batch_size = cls._get_unet_batch_size(model, rbln_config)
155
- text_model_hidden_size = model.text_encoder_2.config.hidden_size if hasattr(model, "text_encoder_2") else None
156
- return {
157
- **cls._get_default_config(model, rbln_config),
158
- "max_seq_len": model.text_encoder.config.max_position_embeddings,
159
- "text_model_hidden_size": text_model_hidden_size,
160
- "batch_size": unet_batch_size,
161
- "sample_size": cls._get_unet_sample_size(model, rbln_config),
162
- "is_controlnet": "controlnet" in model.config.keys(),
163
- }
100
+ submodule_cls: RBLNModel = getattr(importlib.import_module("optimum.rbln"), f"RBLN{submodule_class_name}")
164
101
 
165
- @classmethod
166
- def get_default_rbln_config_vae(cls, model: torch.nn.Module, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
167
- # configuration for vae
168
- batch_size = rbln_config.get("batch_size", 1)
169
- return {
170
- **cls._get_default_config(model, rbln_config),
171
- "sample_size": cls._get_vae_sample_size(model, rbln_config),
172
- "batch_size": batch_size,
173
- }
102
+ submodule_config = rbln_config.get(submodule_name, {})
103
+ submodule_config = copy.deepcopy(submodule_config)
174
104
 
175
- @classmethod
176
- def get_default_rbln_config_controlnet(cls, model: torch.nn.Module, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
177
- # configuration for controlnet
178
- unet_batch_size = cls._get_unet_batch_size(model, rbln_config)
179
- text_model_hidden_size = model.text_encoder_2.config.hidden_size if hasattr(model, "text_encoder_2") else None
180
- return {
181
- **cls._get_default_config(model, rbln_config),
182
- "max_seq_len": model.text_encoder.config.max_position_embeddings,
183
- "vae_sample_size": cls._get_vae_sample_size(model, rbln_config),
184
- "unet_sample_size": cls._get_unet_sample_size(model, rbln_config),
185
- "batch_size": unet_batch_size,
186
- "text_model_hidden_size": text_model_hidden_size,
187
- }
105
+ pipe_global_config = {k: v for k, v in rbln_config.items() if k not in cls._submodules}
188
106
 
189
- @classmethod
190
- def get_default_rbln_config(
191
- cls, model: torch.nn.Module, submodule_name: str, rbln_config: Dict[str, Any]
192
- ) -> Dict[str, Any]:
193
- # Returns the default configuration based on submodule name
194
- config_method = f"get_default_rbln_config_{submodule_name}"
195
- if hasattr(cls, config_method):
196
- return getattr(cls, config_method)(model, rbln_config)
197
- raise ValueError(f"Unknown submodule: {submodule_name}")
107
+ submodule_config.update({k: v for k, v in pipe_global_config.items() if k not in submodule_config})
108
+ submodule_config.update(
109
+ {
110
+ "img2img_pipeline": cls.img2img_pipeline,
111
+ "inpaint_pipeline": cls.inpaint_pipeline,
112
+ }
113
+ )
114
+ submodule_config = submodule_cls.update_rbln_config_using_pipe(model, submodule_config)
115
+ return submodule_config
198
116
 
199
117
  @staticmethod
200
118
  def _maybe_apply_and_fuse_lora(
@@ -256,17 +174,46 @@ class RBLNDiffusionMixin:
256
174
 
257
175
  else:
258
176
  # raise error if any of submodules are torch module.
259
- for name in cls._submodules:
260
- if isinstance(kwargs.get(name), torch.nn.Module):
177
+ model_index_config = None
178
+ for submodule_name in cls._submodules:
179
+ if isinstance(kwargs.get(submodule_name), torch.nn.Module):
261
180
  raise AssertionError(
262
- f"{name} is not compiled torch module. If you want to compile, set `export=True`."
181
+ f"{submodule_name} is not compiled torch module. If you want to compile, set `export=True`."
263
182
  )
264
183
 
184
+ submodule_config = rbln_config.get(submodule_name, {})
185
+
186
+ for key, value in rbln_config.items():
187
+ if key in RUNTIME_KEYWORDS and key not in submodule_config:
188
+ submodule_config[key] = value
189
+
190
+ if not any(kwd in submodule_config for kwd in RUNTIME_KEYWORDS):
191
+ continue
192
+
193
+ if model_index_config is None:
194
+ model_index_config = cls.load_config(pretrained_model_name_or_path=model_id)
195
+
196
+ module_name, class_name = model_index_config[submodule_name]
197
+ if module_name != "optimum.rbln":
198
+ raise ValueError(
199
+ f"Invalid module_name '{module_name}' found in model_index.json for "
200
+ f"submodule '{submodule_name}'. "
201
+ "Expected 'optimum.rbln'. Please check the model_index.json configuration."
202
+ )
203
+
204
+ submodule_cls: RBLNModel = getattr(importlib.import_module("optimum.rbln"), class_name)
205
+
206
+ submodule = submodule_cls.from_pretrained(
207
+ model_id, export=False, subfolder=submodule_name, rbln_config=submodule_config
208
+ )
209
+ kwargs[submodule_name] = submodule
210
+
265
211
  with ContextRblnConfig(
266
212
  device=rbln_config.get("device"),
267
213
  device_map=rbln_config.get("device_map"),
268
214
  create_runtimes=rbln_config.get("create_runtimes"),
269
215
  optimize_host_mem=rbln_config.get("optimize_host_memory"),
216
+ activate_profiler=rbln_config.get("activate_profiler"),
270
217
  ):
271
218
  model = super().from_pretrained(pretrained_model_name_or_path=model_id, **kwargs)
272
219
 
@@ -291,16 +238,11 @@ class RBLNDiffusionMixin:
291
238
  model_save_dir: Optional[PathLike],
292
239
  rbln_config: Dict[str, Any],
293
240
  ) -> Dict[str, RBLNModel]:
294
- # Compile submodules based on rbln_config
295
241
  compiled_submodules = {}
296
242
 
297
- # FIXME : Currently, optimum-rbln for transformer does not use base rbln config.
298
- base_rbln_config = {k: v for k, v in rbln_config.items() if k not in cls._submodules}
299
243
  for submodule_name in cls._submodules:
300
244
  submodule = passed_submodules.get(submodule_name) or getattr(model, submodule_name, None)
301
- submodule_rbln_config = cls.get_default_rbln_config(model, submodule_name, rbln_config)
302
- submodule_rbln_config.update(base_rbln_config)
303
- submodule_rbln_config.update(rbln_config.get(submodule_name, {}))
245
+ submodule_rbln_config = cls.get_submodule_rbln_config(model, submodule_name, rbln_config)
304
246
 
305
247
  if submodule is None:
306
248
  raise ValueError(f"submodule ({submodule_name}) cannot be accessed since it is not provided.")
@@ -337,8 +279,8 @@ class RBLNDiffusionMixin:
337
279
  controlnet_rbln_config: Dict[str, Any],
338
280
  ):
339
281
  # Compile multiple ControlNet models for a MultiControlNet setup
340
- from .diffusers.models.controlnet import RBLNControlNetModel
341
- from .diffusers.pipelines.controlnet import RBLNMultiControlNetModel
282
+ from .models.controlnet import RBLNControlNetModel
283
+ from .pipelines.controlnet import RBLNMultiControlNetModel
342
284
 
343
285
  compiled_controlnets = [
344
286
  RBLNControlNetModel.from_model(
@@ -349,7 +291,7 @@ class RBLNDiffusionMixin:
349
291
  )
350
292
  for i, controlnet in enumerate(controlnets.nets)
351
293
  ]
352
- return RBLNMultiControlNetModel(compiled_controlnets, config=controlnets.nets[0].config)
294
+ return RBLNMultiControlNetModel(compiled_controlnets)
353
295
 
354
296
  @classmethod
355
297
  def _construct_pipe(cls, model, submodules, model_save_dir, rbln_config):
@@ -395,6 +337,35 @@ class RBLNDiffusionMixin:
395
337
 
396
338
  return model
397
339
 
340
+ def get_compiled_image_size(self):
341
+ if hasattr(self, "vae"):
342
+ compiled_image_size = self.vae.image_size
343
+ else:
344
+ compiled_image_size = None
345
+ return compiled_image_size
346
+
347
+ def handle_additional_kwargs(self, **kwargs):
348
+ """
349
+ Function to handle additional compile-time parameters during inference.
350
+
351
+ If the additional variable is determined by another module, this method should be overrided.
352
+
353
+ Example:
354
+ ```python
355
+ if hasattr(self, "movq"):
356
+ compiled_image_size = self.movq.image_size
357
+ kwargs["height"] = compiled_image_size[0]
358
+ kwargs["width"] = compiled_image_size[1]
359
+
360
+ compiled_num_frames = self.unet.rbln_config.model_cfg.get("num_frames", None)
361
+ if compiled_num_frames is not None:
362
+ kwargs["num_frames"] = self.unet.rbln_config.model_cfg.get("num_frames")
363
+ return kwargs
364
+ ```
365
+ """
366
+ return kwargs
367
+
398
368
  @remove_compile_time_kwargs
399
369
  def __call__(self, *args, **kwargs):
370
+ kwargs = self.handle_additional_kwargs(**kwargs)
400
371
  return super().__call__(*args, **kwargs)
@@ -21,6 +21,39 @@
21
21
  # copied, modified, or distributed without prior written permission
22
22
  # from Rebellions Inc.
23
23
 
24
- from .autoencoder_kl import RBLNAutoencoderKL
25
- from .controlnet import RBLNControlNetModel
26
- from .unet_2d_condition import RBLNUNet2DConditionModel
24
+ from typing import TYPE_CHECKING
25
+
26
+ from transformers.utils import _LazyModule
27
+
28
+
29
+ _import_structure = {
30
+ "autoencoders": [
31
+ "RBLNAutoencoderKL",
32
+ ],
33
+ "unets": [
34
+ "RBLNUNet2DConditionModel",
35
+ ],
36
+ "controlnet": ["RBLNControlNetModel"],
37
+ "transformers": ["RBLNSD3Transformer2DModel"],
38
+ }
39
+
40
+ if TYPE_CHECKING:
41
+ from .autoencoders import (
42
+ RBLNAutoencoderKL,
43
+ )
44
+ from .controlnet import RBLNControlNetModel
45
+ from .transformers import (
46
+ RBLNSD3Transformer2DModel,
47
+ )
48
+ from .unets import (
49
+ RBLNUNet2DConditionModel,
50
+ )
51
+ else:
52
+ import sys
53
+
54
+ sys.modules[__name__] = _LazyModule(
55
+ __name__,
56
+ globals()["__file__"],
57
+ _import_structure,
58
+ module_spec=__spec__,
59
+ )
@@ -21,5 +21,4 @@
21
21
  # copied, modified, or distributed without prior written permission
22
22
  # from Rebellions Inc.
23
23
 
24
- from .streamers import BatchTextIteratorStreamer
25
- from .utils import RBLNGenerationMixin
24
+ from .autoencoder_kl import RBLNAutoencoderKL