optimum-rbln 0.1.0__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. optimum/rbln/__init__.py +8 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/__init__.py +7 -0
  4. optimum/rbln/diffusers/models/autoencoder_kl.py +30 -9
  5. optimum/rbln/diffusers/models/controlnet.py +93 -23
  6. optimum/rbln/diffusers/models/unet_2d_condition.py +78 -61
  7. optimum/rbln/diffusers/pipelines/__init__.py +7 -2
  8. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +4 -0
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +768 -0
  10. optimum/rbln/diffusers/pipelines/{stable_diffusion → controlnet}/pipeline_controlnet_img2img.py +25 -16
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +942 -0
  12. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +955 -0
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  14. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -4
  15. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +22 -9
  16. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +19 -3
  17. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +19 -3
  18. optimum/rbln/modeling_base.py +39 -6
  19. optimum/rbln/modeling_seq2seq.py +19 -4
  20. optimum/rbln/transformers/__init__.py +2 -0
  21. optimum/rbln/transformers/generation/__init__.py +1 -0
  22. optimum/rbln/transformers/generation/streamers.py +17 -0
  23. optimum/rbln/transformers/generation/utils.py +399 -0
  24. optimum/rbln/transformers/models/__init__.py +1 -0
  25. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +24 -333
  26. optimum/rbln/transformers/models/llama/llama_architecture.py +49 -17
  27. optimum/rbln/transformers/models/llama/llama_architecture_cb.py +759 -0
  28. optimum/rbln/transformers/models/llama/modeling_llama.py +187 -75
  29. optimum/rbln/transformers/models/midm/__init__.py +32 -0
  30. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +22 -0
  31. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +303 -0
  32. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +1473 -0
  33. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +98 -0
  34. optimum/rbln/transformers/models/midm/midm_architecture.py +506 -0
  35. optimum/rbln/transformers/models/midm/modeling_midm.py +426 -0
  36. optimum/rbln/transformers/models/whisper/modeling_whisper.py +13 -3
  37. {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.4.dist-info}/METADATA +5 -4
  38. optimum_rbln-0.1.4.dist-info/RECORD +63 -0
  39. {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.4.dist-info}/WHEEL +1 -1
  40. optimum_rbln-0.1.0.dist-info/RECORD +0 -51
  41. {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.4.dist-info}/licenses/LICENSE +0 -0
@@ -21,6 +21,5 @@
21
21
  # copied, modified, or distributed without prior written permission
22
22
  # from Rebellions Inc.
23
23
 
24
- from .pipeline_controlnet_img2img import RBLNStableDiffusionControlNetImg2ImgPipeline
25
24
  from .pipeline_stable_diffusion import RBLNStableDiffusionPipeline
26
25
  from .pipeline_stable_diffusion_img2img import RBLNStableDiffusionImg2ImgPipeline
@@ -22,7 +22,6 @@
22
22
  # from Rebellions Inc.
23
23
  """RBLNStableDiffusionPipeline class for inference of diffusion models on rbln devices."""
24
24
 
25
-
26
25
  from diffusers import StableDiffusionPipeline
27
26
 
28
27
  from ....modeling_base import RBLNBaseModel
@@ -50,17 +49,22 @@ class RBLNStableDiffusionPipeline(StableDiffusionPipeline):
50
49
  - A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`],
51
50
  """
52
51
  export = kwargs.pop("export", None)
52
+ model_save_dir = kwargs.pop("model_save_dir", None)
53
+ rbln_config_kwargs, rbln_constructor_kwargs = RBLNBaseModel.pop_rbln_kwargs_from_kwargs(kwargs)
53
54
  model = super().from_pretrained(pretrained_model_name_or_path=model_id, **kwargs)
55
+
54
56
  if export is None or export is False:
55
57
  return model
56
58
 
57
- rbln_config_kwargs, rbln_constructor_kwargs = RBLNBaseModel.pop_rbln_kwargs_from_kwargs(kwargs)
59
+ do_classifier_free_guidance = (
60
+ rbln_config_kwargs.pop("rbln_guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
61
+ )
58
62
 
59
- # compile model, create runtime
60
63
  vae = RBLNAutoencoderKL.from_pretrained(
61
64
  model_id=model_id,
62
65
  subfolder="vae",
63
66
  export=True,
67
+ model_save_dir=model_save_dir,
64
68
  rbln_unet_sample_size=model.unet.config.sample_size,
65
69
  rbln_use_encode=False,
66
70
  **rbln_config_kwargs,
@@ -70,16 +74,19 @@ class RBLNStableDiffusionPipeline(StableDiffusionPipeline):
70
74
  model_id=model_id,
71
75
  subfolder="text_encoder",
72
76
  export=True,
77
+ model_save_dir=model_save_dir,
73
78
  **rbln_config_kwargs,
74
79
  **rbln_constructor_kwargs,
75
80
  )
76
81
 
77
82
  batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
78
- unet_batch_size = batch_size * 2
83
+ unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
84
+
79
85
  unet = RBLNUNet2DConditionModel.from_pretrained(
80
86
  model_id=model_id,
81
87
  subfolder="unet",
82
88
  export=True,
89
+ model_save_dir=model_save_dir,
83
90
  rbln_max_seq_len=text_encoder.config.max_position_embeddings,
84
91
  rbln_batch_size=unet_batch_size,
85
92
  rbln_use_encode=False,
@@ -88,6 +95,14 @@ class RBLNStableDiffusionPipeline(StableDiffusionPipeline):
88
95
  **rbln_constructor_kwargs,
89
96
  )
90
97
 
98
+ if model_save_dir is not None:
99
+ # To skip saving original pytorch modules
100
+ del (model.vae, model.text_encoder, model.unet)
101
+
102
+ # Direct calling of `save_pretrained` causes config.unet = (None, None).
103
+ # So config must be saved again, later.
104
+ model.save_pretrained(model_save_dir)
105
+
91
106
  # replace modules
92
107
  model.vae = vae
93
108
  model.text_encoder = text_encoder
@@ -101,6 +116,10 @@ class RBLNStableDiffusionPipeline(StableDiffusionPipeline):
101
116
  }
102
117
  model.register_to_config(**update_dict)
103
118
 
119
+ if model_save_dir is not None:
120
+ # overwrite to replace incorrect config
121
+ model.save_config(model_save_dir)
122
+
104
123
  model.models = [vae.model[0], text_encoder.model[0], unet.model[0]]
105
124
 
106
125
  return model
@@ -22,9 +22,6 @@
22
22
  # from Rebellions Inc.
23
23
  """RBLNStableDiffusionPipeline class for inference of diffusion models on rbln devices."""
24
24
 
25
- from pathlib import Path
26
- from tempfile import TemporaryDirectory
27
-
28
25
  from diffusers import StableDiffusionImg2ImgPipeline
29
26
 
30
27
  from ....modeling_base import RBLNBaseModel
@@ -52,21 +49,23 @@ class RBLNStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
52
49
  - A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`],
53
50
  """
54
51
  export = kwargs.pop("export", None)
52
+ model_save_dir = kwargs.pop("model_save_dir", None)
53
+ rbln_config_kwargs, rbln_constructor_kwargs = RBLNBaseModel.pop_rbln_kwargs_from_kwargs(kwargs)
55
54
  model = super().from_pretrained(pretrained_model_name_or_path=model_id, **kwargs)
55
+
56
56
  if export is None or export is False:
57
57
  return model
58
58
 
59
- rbln_config_kwargs, rbln_constructor_kwargs = RBLNBaseModel.pop_rbln_kwargs_from_kwargs(kwargs)
60
-
61
- save_dir = TemporaryDirectory()
62
- save_dir_path = Path(save_dir.name)
63
- model.save_pretrained(save_directory=save_dir_path, **kwargs)
59
+ do_classifier_free_guidance = (
60
+ rbln_config_kwargs.pop("rbln_guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
61
+ )
64
62
 
65
63
  # compile model, create runtime
66
64
  vae = RBLNAutoencoderKL.from_pretrained(
67
65
  model_id=model_id,
68
66
  subfolder="vae",
69
67
  export=True,
68
+ model_save_dir=model_save_dir,
70
69
  rbln_unet_sample_size=model.unet.config.sample_size,
71
70
  rbln_use_encode=True,
72
71
  rbln_vae_scale_factor=model.vae_scale_factor,
@@ -77,17 +76,19 @@ class RBLNStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
77
76
  model_id=model_id,
78
77
  subfolder="text_encoder",
79
78
  export=True,
79
+ model_save_dir=model_save_dir,
80
80
  **rbln_config_kwargs,
81
81
  **rbln_constructor_kwargs,
82
82
  )
83
83
 
84
84
  batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
85
- unet_batch_size = batch_size * 2
85
+ unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
86
86
 
87
87
  unet = RBLNUNet2DConditionModel.from_pretrained(
88
88
  model_id=model_id,
89
89
  subfolder="unet",
90
90
  export=True,
91
+ model_save_dir=model_save_dir,
91
92
  rbln_max_seq_len=text_encoder.config.max_position_embeddings,
92
93
  rbln_batch_size=unet_batch_size,
93
94
  rbln_use_encode=True,
@@ -97,6 +98,14 @@ class RBLNStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
97
98
  **rbln_constructor_kwargs,
98
99
  )
99
100
 
101
+ if model_save_dir is not None:
102
+ # To skip saving original pytorch modules
103
+ del (model.vae, model.text_encoder, model.unet)
104
+
105
+ # Direct calling of `save_pretrained` causes config.unet = (None, None).
106
+ # So config must be saved again, later.
107
+ model.save_pretrained(model_save_dir)
108
+
100
109
  # replace modules
101
110
  model.vae = vae
102
111
  model.text_encoder = text_encoder
@@ -110,6 +119,10 @@ class RBLNStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
110
119
  }
111
120
  model.register_to_config(**update_dict)
112
121
 
122
+ if model_save_dir is not None:
123
+ # overwrite to replace incorrect config
124
+ model.save_config(model_save_dir)
125
+
113
126
  # vae encoder, vae decoder, text_encoder, unet
114
127
  model.models = [vae.model[0], vae.model[1], text_encoder.model[0], unet.model[0]]
115
128
 
@@ -14,7 +14,6 @@
14
14
  # limitations under the License.
15
15
  """RBLNStableDiffusionXLPipeline class for inference of diffusion models on rbln devices."""
16
16
 
17
-
18
17
  from diffusers import StableDiffusionXLPipeline
19
18
 
20
19
  from ....modeling_base import RBLNBaseModel
@@ -42,12 +41,13 @@ class RBLNStableDiffusionXLPipeline(StableDiffusionXLPipeline):
42
41
  - A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`],
43
42
  """
44
43
  export = kwargs.pop("export", None)
44
+ model_save_dir = kwargs.pop("model_save_dir", None)
45
+ rbln_config_kwargs, rbln_constructor_kwargs = RBLNBaseModel.pop_rbln_kwargs_from_kwargs(kwargs)
45
46
  model = super().from_pretrained(pretrained_model_name_or_path=model_id, **kwargs)
47
+
46
48
  if export is None or export is False:
47
49
  return model
48
50
 
49
- rbln_config_kwargs, rbln_constructor_kwargs = RBLNBaseModel.pop_rbln_kwargs_from_kwargs(kwargs)
50
-
51
51
  do_classifier_free_guidance = (
52
52
  rbln_config_kwargs.pop("rbln_guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
53
53
  )
@@ -56,6 +56,7 @@ class RBLNStableDiffusionXLPipeline(StableDiffusionXLPipeline):
56
56
  model_id=model_id,
57
57
  subfolder="vae",
58
58
  export=True,
59
+ model_save_dir=model_save_dir,
59
60
  rbln_unet_sample_size=model.unet.config.sample_size,
60
61
  rbln_use_encode=False,
61
62
  **rbln_config_kwargs,
@@ -65,6 +66,7 @@ class RBLNStableDiffusionXLPipeline(StableDiffusionXLPipeline):
65
66
  model_id=model_id,
66
67
  subfolder="text_encoder",
67
68
  export=True,
69
+ model_save_dir=model_save_dir,
68
70
  **rbln_config_kwargs,
69
71
  **rbln_constructor_kwargs,
70
72
  )
@@ -72,6 +74,7 @@ class RBLNStableDiffusionXLPipeline(StableDiffusionXLPipeline):
72
74
  model_id=model_id,
73
75
  subfolder="text_encoder_2",
74
76
  export=True,
77
+ model_save_dir=model_save_dir,
75
78
  **rbln_config_kwargs,
76
79
  **rbln_constructor_kwargs,
77
80
  )
@@ -83,6 +86,7 @@ class RBLNStableDiffusionXLPipeline(StableDiffusionXLPipeline):
83
86
  model_id=model_id,
84
87
  subfolder="unet",
85
88
  export=True,
89
+ model_save_dir=model_save_dir,
86
90
  rbln_max_seq_len=model.text_encoder.config.max_position_embeddings,
87
91
  rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
88
92
  rbln_batch_size=unet_batch_size,
@@ -92,6 +96,14 @@ class RBLNStableDiffusionXLPipeline(StableDiffusionXLPipeline):
92
96
  **rbln_constructor_kwargs,
93
97
  )
94
98
 
99
+ if model_save_dir is not None:
100
+ # To skip saving original pytorch modules
101
+ del (model.vae, model.text_encoder, model.unet, model.text_encoder_2)
102
+
103
+ # Direct calling of `save_pretrained` causes config.unet = (None, None).
104
+ # So config must be saved again, later.
105
+ model.save_pretrained(model_save_dir)
106
+
95
107
  model.vae = vae
96
108
  model.text_encoder = text_encoder
97
109
  model.unet = unet
@@ -104,6 +116,10 @@ class RBLNStableDiffusionXLPipeline(StableDiffusionXLPipeline):
104
116
  }
105
117
  model.register_to_config(**update_dict)
106
118
 
119
+ if model_save_dir is not None:
120
+ # overwrite to replace incorrect config
121
+ model.save_config(model_save_dir)
122
+
107
123
  model.models = [vae.model[0], unet.model[0], text_encoder.model[0], text_encoder_2.model[0]]
108
124
 
109
125
  return model
@@ -14,7 +14,6 @@
14
14
  # limitations under the License.
15
15
  """RBLNStableDiffusionXLPipeline class for inference of diffusion models on rbln devices."""
16
16
 
17
-
18
17
  from diffusers import StableDiffusionXLImg2ImgPipeline
19
18
 
20
19
  from ....modeling_base import RBLNBaseModel
@@ -42,12 +41,13 @@ class RBLNStableDiffusionXLImg2ImgPipeline(StableDiffusionXLImg2ImgPipeline):
42
41
  - A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`],
43
42
  """
44
43
  export = kwargs.pop("export", None)
44
+ model_save_dir = kwargs.pop("model_save_dir", None)
45
+ rbln_config_kwargs, rbln_constructor_kwargs = RBLNBaseModel.pop_rbln_kwargs_from_kwargs(kwargs)
45
46
  model = super().from_pretrained(pretrained_model_name_or_path=model_id, **kwargs)
47
+
46
48
  if export is None or export is False:
47
49
  return model
48
50
 
49
- rbln_config_kwargs, rbln_constructor_kwargs = RBLNBaseModel.pop_rbln_kwargs_from_kwargs(kwargs)
50
-
51
51
  do_classifier_free_guidance = (
52
52
  rbln_config_kwargs.pop("rbln_guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
53
53
  )
@@ -56,6 +56,7 @@ class RBLNStableDiffusionXLImg2ImgPipeline(StableDiffusionXLImg2ImgPipeline):
56
56
  model_id=model_id,
57
57
  subfolder="vae",
58
58
  export=True,
59
+ model_save_dir=model_save_dir,
59
60
  rbln_unet_sample_size=model.unet.config.sample_size,
60
61
  rbln_use_encode=True,
61
62
  rbln_vae_scale_factor=model.vae_scale_factor,
@@ -66,6 +67,7 @@ class RBLNStableDiffusionXLImg2ImgPipeline(StableDiffusionXLImg2ImgPipeline):
66
67
  model_id=model_id,
67
68
  subfolder="text_encoder",
68
69
  export=True,
70
+ model_save_dir=model_save_dir,
69
71
  **rbln_config_kwargs,
70
72
  **rbln_constructor_kwargs,
71
73
  )
@@ -73,6 +75,7 @@ class RBLNStableDiffusionXLImg2ImgPipeline(StableDiffusionXLImg2ImgPipeline):
73
75
  model_id=model_id,
74
76
  subfolder="text_encoder_2",
75
77
  export=True,
78
+ model_save_dir=model_save_dir,
76
79
  **rbln_config_kwargs,
77
80
  **rbln_constructor_kwargs,
78
81
  )
@@ -84,6 +87,7 @@ class RBLNStableDiffusionXLImg2ImgPipeline(StableDiffusionXLImg2ImgPipeline):
84
87
  model_id=model_id,
85
88
  subfolder="unet",
86
89
  export=True,
90
+ model_save_dir=model_save_dir,
87
91
  rbln_max_seq_len=model.text_encoder.config.max_position_embeddings,
88
92
  rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
89
93
  rbln_batch_size=unet_batch_size,
@@ -94,6 +98,14 @@ class RBLNStableDiffusionXLImg2ImgPipeline(StableDiffusionXLImg2ImgPipeline):
94
98
  **rbln_constructor_kwargs,
95
99
  )
96
100
 
101
+ if model_save_dir is not None:
102
+ # To skip saving original pytorch modules
103
+ del (model.vae, model.text_encoder, model.unet, model.text_encoder_2)
104
+
105
+ # Direct calling of `save_pretrained` causes config.unet = (None, None).
106
+ # So config must be saved again, later.
107
+ model.save_pretrained(model_save_dir)
108
+
97
109
  model.vae = vae
98
110
  model.text_encoder = text_encoder
99
111
  model.unet = unet
@@ -106,6 +118,10 @@ class RBLNStableDiffusionXLImg2ImgPipeline(StableDiffusionXLImg2ImgPipeline):
106
118
  }
107
119
  model.register_to_config(**update_dict)
108
120
 
121
+ if model_save_dir is not None:
122
+ # overwrite to replace incorrect config
123
+ model.save_config(model_save_dir)
124
+
109
125
  model.models = [vae.model[0], vae.model[1], unet.model[0], text_encoder.model[0], text_encoder_2.model[0]]
110
126
 
111
127
  return model
@@ -99,7 +99,7 @@ class RBLNBaseModel(OptimizedModel, ABC):
99
99
 
100
100
  model_type = "rbln_model"
101
101
  auto_model_class = AutoModel # feature extraction
102
- config_name = "model_index.json"
102
+ config_name = "config.json"
103
103
 
104
104
  def __init__(
105
105
  self,
@@ -109,7 +109,8 @@ class RBLNBaseModel(OptimizedModel, ABC):
109
109
  rbln_config: Optional[RBLNConfig],
110
110
  rbln_device: Optional[List[int]] = None,
111
111
  rbln_device_map: Optional[Dict[str, int]] = None,
112
- rbln_create_runtimes: Optional[bool] = True,
112
+ rbln_create_runtimes: Optional[bool] = None,
113
+ model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
113
114
  **kwargs,
114
115
  ):
115
116
  super().__init__(models, config)
@@ -145,9 +146,24 @@ class RBLNBaseModel(OptimizedModel, ABC):
145
146
 
146
147
  self.device = torch.device("cpu")
147
148
 
149
+ if rbln_create_runtimes is None:
150
+ rbln_create_runtimes = rebel.npu_is_available()
151
+
148
152
  # create runtimes only if `rbln_create_runtimes` is enabled
149
153
  self.runtimes = self._create_runtimes(self.rbln_device_map) if rbln_create_runtimes else UnavailableRuntime()
150
154
 
155
+ # FIXME :: model_save_dir is not used after initialized. (This can be used when save/load)
156
+ # This attribute is needed to keep one reference on the temporary directory, since garbage collecting it
157
+ # would end-up removing the directory containing the underlying ONNX model.
158
+ self._model_save_dir_tempdirectory_instance = None
159
+ if isinstance(model_save_dir, TemporaryDirectory):
160
+ self._model_save_dir_tempdirectory_instance = model_save_dir
161
+ self.model_save_dir = Path(model_save_dir.name)
162
+ elif isinstance(model_save_dir, str):
163
+ self.model_save_dir = Path(model_save_dir)
164
+ else:
165
+ self.model_save_dir = model_save_dir
166
+
151
167
  self.__post_init__(**kwargs)
152
168
 
153
169
  def __post_init__(self, **kwargs):
@@ -179,6 +195,7 @@ class RBLNBaseModel(OptimizedModel, ABC):
179
195
  cache_dir: Optional[str] = None,
180
196
  subfolder: str = "",
181
197
  local_files_only: bool = False,
198
+ model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
182
199
  **kwargs,
183
200
  ) -> "RBLNBaseModel":
184
201
  model_path = Path(model_id)
@@ -216,6 +233,7 @@ class RBLNBaseModel(OptimizedModel, ABC):
216
233
  rebel.RBLNCompiledModel(model_path / f"{compiled_model_name}.rbln")
217
234
  for compiled_model_name in rbln_config
218
235
  ]
236
+ new_model_save_dir = model_path
219
237
 
220
238
  else:
221
239
  rbln_config_filename = rbln_config_filenames[0]
@@ -243,14 +261,19 @@ class RBLNBaseModel(OptimizedModel, ABC):
243
261
  local_files_only=local_files_only,
244
262
  )
245
263
  models.append(rebel.RBLNCompiledModel(model_cache_path))
264
+ new_model_save_dir = Path(rbln_config_cache_path).parent
246
265
 
247
266
  preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder)
248
267
 
268
+ if model_save_dir is None:
269
+ model_save_dir = new_model_save_dir
270
+
249
271
  return cls(
250
272
  models,
251
273
  config,
252
274
  preprocessors,
253
275
  rbln_config=rbln_config,
276
+ model_save_dir=model_save_dir,
254
277
  **kwargs,
255
278
  )
256
279
 
@@ -370,6 +393,7 @@ class RBLNBaseModel(OptimizedModel, ABC):
370
393
  subfolder: str = "",
371
394
  local_files_only: bool = False,
372
395
  trust_remote_code: bool = False,
396
+ model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
373
397
  **kwargs,
374
398
  ):
375
399
  """
@@ -408,6 +432,7 @@ class RBLNModel(RBLNBaseModel):
408
432
  subfolder: str = "",
409
433
  local_files_only: bool = False,
410
434
  trust_remote_code: bool = False,
435
+ model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
411
436
  **kwargs,
412
437
  ) -> "RBLNModel":
413
438
  """
@@ -417,8 +442,16 @@ class RBLNModel(RBLNBaseModel):
417
442
  if task is None:
418
443
  task = TasksManager.infer_task_from_model(cls.auto_model_class)
419
444
 
420
- save_dir = TemporaryDirectory()
421
- save_dir_path = Path(save_dir.name)
445
+ if model_save_dir is None:
446
+ save_dir = TemporaryDirectory()
447
+ save_dir_path = Path(save_dir.name)
448
+ else:
449
+ save_dir = model_save_dir
450
+ if isinstance(save_dir, TemporaryDirectory):
451
+ save_dir_path = Path(model_save_dir.name)
452
+ else:
453
+ save_dir_path = Path(model_save_dir)
454
+ save_dir_path.mkdir(exist_ok=True)
422
455
 
423
456
  kwargs.update(
424
457
  {
@@ -457,7 +490,7 @@ class RBLNModel(RBLNBaseModel):
457
490
  preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
458
491
 
459
492
  # Get compilation arguments
460
- if rbln_config_kwargs.get("rbln_config", None) is None:
493
+ if (rbln_config := rbln_config_kwargs.pop("rbln_config", None)) is None:
461
494
  rbln_config = cls.get_rbln_config(preprocessors=preprocessors, model_config=config, **rbln_config_kwargs)
462
495
 
463
496
  rbln_runtime_configs = list(rbln_config.values())
@@ -562,7 +595,7 @@ class RBLNModelForImageClassification(RBLNModel):
562
595
  rbln_image_size = processor.size["shortest_edge"]
563
596
  break
564
597
  if rbln_image_size is None:
565
- raise ValueError("`rbln_rbln_image_size` should be specified!")
598
+ raise ValueError("`rbln_image_size` should be specified!")
566
599
 
567
600
  if rbln_batch_size is None:
568
601
  rbln_batch_size = 1
@@ -160,6 +160,7 @@ class RBLNModelForSeq2SeqLM(RBLNBaseModel):
160
160
  subfolder: str = "",
161
161
  local_files_only: bool = False,
162
162
  trust_remote_code: bool = False,
163
+ model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
163
164
  **kwargs,
164
165
  ) -> "AutoModelForSeq2SeqLM":
165
166
  """
@@ -169,8 +170,16 @@ class RBLNModelForSeq2SeqLM(RBLNBaseModel):
169
170
  if task is None:
170
171
  task = TasksManager.infer_task_from_model(cls.auto_model_class)
171
172
 
172
- save_dir = TemporaryDirectory()
173
- save_dir_path = Path(save_dir.name)
173
+ if model_save_dir is None:
174
+ save_dir = TemporaryDirectory()
175
+ save_dir_path = Path(save_dir.name)
176
+ else:
177
+ save_dir = model_save_dir
178
+ if isinstance(save_dir, TemporaryDirectory):
179
+ save_dir_path = Path(model_save_dir.name)
180
+ else:
181
+ save_dir_path = Path(model_save_dir)
182
+ save_dir_path.mkdir(exist_ok=True)
174
183
 
175
184
  kwargs.update(
176
185
  {
@@ -339,6 +348,8 @@ class RBLNModelForSeq2SeqLM(RBLNBaseModel):
339
348
  if max_position_embeddings is not None and rbln_dec_max_seq_len > max_position_embeddings:
340
349
  raise ValueError("`rbln_dec_max_seq_len` should be less or equal than max_position_embeddings!")
341
350
 
351
+ rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
352
+
342
353
  meta["rbln_enc_max_seq_len"] = rbln_enc_max_seq_len
343
354
  meta["rbln_dec_max_seq_len"] = rbln_dec_max_seq_len
344
355
  meta["rbln_batch_size"] = rbln_batch_size
@@ -429,9 +440,13 @@ class RBLNModelForSeq2SeqLM(RBLNBaseModel):
429
440
  return repr(self.runtimes[0]) + "\n" + repr(self.runtimes[1])
430
441
 
431
442
  def _prepare_encoder_decoder_kwargs_for_generation(
432
- self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
443
+ self,
444
+ inputs_tensor: torch.Tensor,
445
+ model_kwargs,
446
+ model_input_name: Optional[str] = None,
447
+ *args,
448
+ **kwargs,
433
449
  ) -> Dict[str, Any]:
434
-
435
450
  ########## thkim change start ###################
436
451
  # padding input_ids & attention_mask regardless of user's tokenizer usage
437
452
  batch_size, input_len = inputs_tensor.shape
@@ -35,6 +35,7 @@ _import_structure = {
35
35
  "RBLNWav2Vec2ForCTC",
36
36
  "RBLNWhisperForConditionalGeneration",
37
37
  "RBLNLlamaForCausalLM",
38
+ "RBLNMidmLMHeadModel",
38
39
  ],
39
40
  }
40
41
 
@@ -45,6 +46,7 @@ if TYPE_CHECKING:
45
46
  RBLNCLIPTextModelWithProjection,
46
47
  RBLNGPT2LMHeadModel,
47
48
  RBLNLlamaForCausalLM,
49
+ RBLNMidmLMHeadModel,
48
50
  RBLNWav2Vec2ForCTC,
49
51
  RBLNWhisperForConditionalGeneration,
50
52
  )
@@ -22,3 +22,4 @@
22
22
  # from Rebellions Inc.
23
23
 
24
24
  from .streamers import BatchTextIteratorStreamer
25
+ from .utils import RBLNGenerationMixin
@@ -64,11 +64,13 @@ class BatchTextIteratorStreamer(TextIteratorStreamer):
64
64
  self.batch_size: int = batch_size
65
65
  self.token_cache: List[List[int]] = [[] for _ in range(batch_size)]
66
66
  self.print_len = [0] * batch_size
67
+ self.blocked = False
67
68
 
68
69
  def put(self, value):
69
70
  """
70
71
  Receives tokens, decodes them, and prints them to buffer as soon as they form entire words.
71
72
  """
73
+
72
74
  if len(value.shape) < 2:
73
75
  value = torch.reshape(value, (self.batch_size, value.shape[0] // self.batch_size))
74
76
 
@@ -115,8 +117,23 @@ class BatchTextIteratorStreamer(TextIteratorStreamer):
115
117
 
116
118
  self.next_tokens_are_prompt = True
117
119
  self.on_finalized_text(batch_printable_text, stream_end=True)
120
+ self.blocked = False
118
121
 
119
122
  def on_finalized_text(self, texts: List[str], stream_end: bool = False):
120
123
  self.text_queue.put(texts, timeout=self.timeout)
121
124
  if stream_end:
122
125
  self.text_queue.put(self.stop_signal, timeout=self.timeout)
126
+
127
+ # thkim change for demo
128
+ def __next__(self):
129
+ value = self.text_queue.get(timeout=self.timeout)
130
+ if value == self.stop_signal:
131
+ raise StopIteration()
132
+ else:
133
+ return value
134
+
135
+ def block(self):
136
+ self.blocked = True
137
+
138
+ def is_blocked(self):
139
+ return self.blocked