optimum-rbln 0.7.3.post2__py3-none-any.whl → 0.7.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133)
  1. optimum/rbln/__init__.py +173 -35
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +816 -0
  4. optimum/rbln/diffusers/__init__.py +56 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +30 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +62 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +52 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +56 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +74 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
  13. optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +236 -0
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +289 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
  19. optimum/rbln/diffusers/modeling_diffusers.py +111 -137
  20. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
  21. optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
  22. optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
  23. optimum/rbln/diffusers/models/controlnet.py +56 -71
  24. optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
  25. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +44 -69
  26. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +111 -114
  27. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +2 -0
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -0
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +2 -0
  31. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +2 -0
  32. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +2 -0
  33. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -0
  34. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +2 -0
  35. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +2 -0
  36. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +2 -0
  37. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -1
  38. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +2 -0
  39. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +2 -0
  40. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +2 -0
  41. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +2 -0
  42. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +2 -0
  43. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +2 -0
  44. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +2 -0
  45. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +2 -0
  46. optimum/rbln/modeling.py +66 -40
  47. optimum/rbln/modeling_base.py +111 -86
  48. optimum/rbln/ops/__init__.py +4 -7
  49. optimum/rbln/ops/attn.py +271 -205
  50. optimum/rbln/ops/flash_attn.py +161 -67
  51. optimum/rbln/ops/kv_cache_update.py +4 -40
  52. optimum/rbln/ops/linear.py +25 -0
  53. optimum/rbln/transformers/__init__.py +97 -8
  54. optimum/rbln/transformers/configuration_alias.py +49 -0
  55. optimum/rbln/transformers/configuration_generic.py +142 -0
  56. optimum/rbln/transformers/modeling_generic.py +193 -280
  57. optimum/rbln/transformers/models/__init__.py +120 -32
  58. optimum/rbln/transformers/models/auto/auto_factory.py +6 -6
  59. optimum/rbln/transformers/models/bart/__init__.py +2 -0
  60. optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
  61. optimum/rbln/transformers/models/bart/modeling_bart.py +12 -85
  62. optimum/rbln/transformers/models/bert/__init__.py +1 -0
  63. optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
  64. optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
  65. optimum/rbln/transformers/models/clip/__init__.py +6 -0
  66. optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
  67. optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
  68. optimum/rbln/transformers/models/decoderonly/__init__.py +11 -0
  69. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
  70. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +197 -178
  71. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +343 -249
  72. optimum/rbln/transformers/models/dpt/__init__.py +1 -0
  73. optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
  74. optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
  75. optimum/rbln/transformers/models/exaone/__init__.py +1 -0
  76. optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
  77. optimum/rbln/transformers/models/gemma/__init__.py +1 -0
  78. optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
  79. optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
  80. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
  81. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  82. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +51 -0
  83. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +459 -0
  84. optimum/rbln/transformers/models/llama/__init__.py +1 -0
  85. optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
  86. optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
  87. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
  88. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +18 -23
  89. optimum/rbln/transformers/models/midm/__init__.py +1 -0
  90. optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
  91. optimum/rbln/transformers/models/mistral/__init__.py +1 -0
  92. optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
  93. optimum/rbln/transformers/models/phi/__init__.py +1 -0
  94. optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
  95. optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
  96. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
  97. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  98. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +68 -0
  99. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +608 -0
  100. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +214 -0
  101. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
  102. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
  103. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +99 -112
  104. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +12 -21
  105. optimum/rbln/transformers/models/t5/__init__.py +2 -0
  106. optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
  107. optimum/rbln/transformers/models/t5/modeling_t5.py +21 -356
  108. optimum/rbln/transformers/models/t5/t5_architecture.py +10 -5
  109. optimum/rbln/transformers/models/time_series_transformers/__init__.py +26 -0
  110. optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
  111. optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +420 -0
  112. optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py +331 -0
  113. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
  114. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
  115. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
  116. optimum/rbln/transformers/models/whisper/__init__.py +2 -0
  117. optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
  118. optimum/rbln/transformers/models/whisper/modeling_whisper.py +135 -100
  119. optimum/rbln/transformers/models/whisper/whisper_architecture.py +73 -40
  120. optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
  121. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
  122. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
  123. optimum/rbln/utils/hub.py +2 -2
  124. optimum/rbln/utils/import_utils.py +23 -6
  125. optimum/rbln/utils/model_utils.py +4 -4
  126. optimum/rbln/utils/runtime_utils.py +33 -2
  127. optimum/rbln/utils/submodule.py +36 -44
  128. {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/METADATA +6 -6
  129. optimum_rbln-0.7.4.dist-info/RECORD +169 -0
  130. optimum/rbln/modeling_config.py +0 -310
  131. optimum_rbln-0.7.3.post2.dist-info/RECORD +0 -122
  132. {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/WHEEL +0 -0
  133. {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/licenses/LICENSE +0 -0
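The largest structural change in 0.7.4 is the move from modeling_config.py (deleted, -310) to the new configuration_utils.py (+816) plus the per-model configuration_*.py files listed above, which introduce typed RBLN*Config classes (for example the RBLNTimeSeriesTransformerForPredictionConfig imported in the diff below). A minimal usage sketch follows; it assumes the config class is re-exported from the package root and that from_pretrained() accepts it via an rbln_config keyword (both assumptions), while the field names come from the diff below:

# Sketch only. Assumptions: (1) the config class is re-exported from the
# package root, and (2) from_pretrained() accepts it via the rbln_config
# keyword. Field names (batch_size, num_parallel_samples, dec_max_seq_len)
# are taken from the diff below; the checkpoint id is just an example.
from optimum.rbln import (
    RBLNTimeSeriesTransformerForPrediction,
    RBLNTimeSeriesTransformerForPredictionConfig,
)

config = RBLNTimeSeriesTransformerForPredictionConfig(
    batch_size=64,              # fixed at compile time
    num_parallel_samples=None,  # falls back to model_config.num_parallel_samples
    dec_max_seq_len=None,       # falls back to prediction_length rounded up to a multiple of 64
)

model = RBLNTimeSeriesTransformerForPrediction.from_pretrained(
    "huggingface/time-series-transformer-tourism-monthly",
    export=True,         # compile for the RBLN NPU instead of loading a compiled artifact
    rbln_config=config,
)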
optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py (new file)
@@ -0,0 +1,420 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+import inspect
+import logging
+from dataclasses import dataclass
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
+
+import rebel
+import torch
+from rebel.compile_context import CompileContext
+from transformers import (
+    PretrainedConfig,
+    TimeSeriesTransformerForPrediction,
+    TimeSeriesTransformerModel,
+)
+from transformers.modeling_outputs import ModelOutput, SampleTSPredictionOutput, Seq2SeqTSModelOutput
+from transformers.modeling_utils import no_init_weights
+
+from ....configuration_utils import RBLNCompileConfig
+from ....modeling import RBLNModel
+from ....utils.runtime_utils import RBLNPytorchRuntime
+from .configuration_time_series_transformer import RBLNTimeSeriesTransformerForPredictionConfig
+from .time_series_transformers_architecture import TimeSeriesTransformersWrapper
+
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
+
+
+class RBLNRuntimeEncoder(RBLNPytorchRuntime):
+    mandatory_members = ["main_input_name"]
+
+    def __init__(
+        self,
+        runtime: rebel.Runtime,
+        model: TimeSeriesTransformerModel,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(runtime, **kwargs)
+        self._origin_model = model
+
+    def forward(
+        self,
+        past_values: torch.Tensor,
+        past_time_features: torch.Tensor,
+        static_categorical_features: Optional[torch.Tensor] = None,
+        static_real_features: Optional[torch.Tensor] = None,
+        past_observed_mask: Optional[torch.Tensor] = None,
+        future_values: Optional[torch.Tensor] = None,
+        future_time_features: Optional[torch.Tensor] = None,
+    ):
+        # preprocess
+        transformer_inputs, loc, scale, static_feat = self._origin_model.create_network_inputs(
+            past_values=past_values,
+            past_time_features=past_time_features,
+            past_observed_mask=past_observed_mask,
+            static_categorical_features=static_categorical_features,
+            static_real_features=static_real_features,
+            future_values=future_values,
+            future_time_features=future_time_features,
+        )
+        enc_input = transformer_inputs[:, : self._origin_model.config.context_length, ...]
+
+        # enc_attn_key_value_caches is updated to device dram in-place
+        _ = super().forward(inputs_embeds=enc_input)
+
+        return Seq2SeqTSModelOutput(
+            loc=loc,
+            scale=scale,
+            static_features=static_feat,
+        )
+
+
+class RBLNRuntimeDecoder(RBLNPytorchRuntime):
+    mandatory_members = ["main_input_name"]
+
+    def forward(
+        self,
+        inputs_embeds: torch.Tensor = None,
+        attention_mask: torch.Tensor = None,
+        cache_position: torch.Tensor = None,
+    ):
+        block_tables = torch.zeros(1, 1, dtype=torch.int16)
+        outputs = super().forward(inputs_embeds, attention_mask, cache_position, block_tables)
+
+        return RBLNSeq2SeqTSDecoderOutput(
+            params=outputs[:-1],
+            last_hidden_states=outputs[-1],
+        )
+
+
+@dataclass
+class RBLNSeq2SeqTSDecoderOutput(ModelOutput):
+    last_hidden_states: torch.FloatTensor = None
+    params: Tuple[torch.FloatTensor] = None
+
+
+class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
+    auto_model_class = None
+    main_input_name = "inputs_embeds"
+
+    def __post_init__(self, **kwargs):
+        super().__post_init__(**kwargs)
+        self.batch_size = self.rbln_config.batch_size
+        self.dec_max_seq_len = self.rbln_config.dec_max_seq_len
+        self.num_parallel_samples = self.rbln_config.num_parallel_samples
+
+        with no_init_weights():
+            self._origin_model = TimeSeriesTransformerForPrediction._from_config(self.config)
+        artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
+        self._origin_model.model.embedder.load_state_dict(artifacts["embedder"])
+        self.encoder = RBLNRuntimeEncoder(
+            runtime=self.model[0],
+            main_input_name="inputs_embeds",
+            model=self._origin_model.model,
+        )
+        self.decoder = RBLNRuntimeDecoder(
+            runtime=self.model[1],
+            main_input_name="inputs_embeds",
+        )
+
+    def __getattr__(self, __name: str) -> Any:
+        """This is the key method to implement RBLNTimeSeriesTransformerForPrediction.
+        Returns:
+            Any: TimeSeriesTransformerForPrediction's corresponding attribute, with instance methods redirected so that `self` is bound to this wrapper
+        """
+
+        def redirect(func):
+            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+        val = getattr(TimeSeriesTransformerForPrediction, __name)
+        if val is not None and isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+            return redirect(val)
+        return val
+
+    @classmethod
+    def wrap_model_if_needed(
+        cls, model: "PreTrainedModel", rbln_config: RBLNTimeSeriesTransformerForPredictionConfig
+    ):
+        return TimeSeriesTransformersWrapper(model, rbln_config.num_parallel_samples)
+
+    @classmethod
+    @torch.inference_mode()
+    def get_compiled_model(cls, model, rbln_config: RBLNTimeSeriesTransformerForPredictionConfig):
+        wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
+
+        enc_compile_config = rbln_config.compile_cfgs[0]
+        dec_compile_config = rbln_config.compile_cfgs[1]
+
+        context = CompileContext(use_weight_sharing=False)
+
+        enc_example_inputs = enc_compile_config.get_dummy_inputs(fill=0)
+
+        # Mark encoder's static tensors (cross kv states)
+        static_tensors = {}
+        for (name, _, _), tensor in zip(enc_compile_config.input_info, enc_example_inputs):
+            if "key_value_states" in name:
+                static_tensors[name] = tensor
+                context.mark_static_address(tensor)
+
+        dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
+
+        # Mark decoder's static tensors (self kv states)
+        for (name, _, _), tensor in zip(dec_compile_config.input_info, dec_example_inputs):
+            if "key_value_states" in name:
+                context.mark_static_address(tensor)
+
+        compiled_decoder = super().compile(
+            wrapped_model.decoder,
+            dec_compile_config,
+            example_inputs=dec_example_inputs,
+            compile_context=context,
+        )
+        compiled_encoder = super().compile(
+            wrapped_model.encoder,
+            enc_compile_config,
+            example_inputs=enc_example_inputs,
+            compile_context=context,
+        )
+
+        return {"encoder": compiled_encoder, "decoder": compiled_decoder}
+
+    @classmethod
+    def save_torch_artifacts(
+        cls,
+        model: "PreTrainedModel",
+        save_dir_path: Path,
+        subfolder: str,
+        rbln_config: RBLNTimeSeriesTransformerForPredictionConfig,
+    ):
+        """
+        Save any torch artifacts (weights, tensors, etc.) for submodules that must
+        keep running on the host CPU rather than the RBLN device, so they can be
+        reloaded at inference time.
+        """
+        save_dict = {}
+        save_dict["embedder"] = model.model.embedder.state_dict()
+        torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
+
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+        model: Optional["PreTrainedModel"] = None,
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_config: Optional[RBLNTimeSeriesTransformerForPredictionConfig] = None,
+    ) -> RBLNTimeSeriesTransformerForPredictionConfig:
+        rbln_config.num_parallel_samples = rbln_config.num_parallel_samples or model_config.num_parallel_samples
+
+        if rbln_config.dec_max_seq_len is None:
+            # Default: prediction_length rounded up to the nearest multiple of 64.
+            predict_length = model_config.prediction_length
+            rbln_config.dec_max_seq_len = (
+                predict_length if predict_length % 64 == 0 else predict_length + (64 - predict_length % 64)
+            )
+
+        # model input info
+        enc_input_info = [
+            (
+                "inputs_embeds",
+                [rbln_config.batch_size, model_config.context_length, model_config.feature_size],
+                "float32",
+            ),
+        ]
+        enc_input_info.extend(
+            [
+                (
+                    "cross_key_value_states",
+                    [
+                        model_config.decoder_layers * 2,
+                        rbln_config.batch_size,
+                        model_config.decoder_attention_heads,
+                        model_config.context_length,
+                        model_config.d_model // model_config.decoder_attention_heads,
+                    ],
+                    "float32",
+                )
+            ]
+        )
+
+        dec_input_info = [
+            (
+                "inputs_embeds",
+                [rbln_config.batch_size * rbln_config.num_parallel_samples, 1, model_config.feature_size],
+                "float32",
+            ),
+            ("attention_mask", [1, rbln_config.dec_max_seq_len], "float32"),
+            ("cache_position", [], "int32"),
+            ("block_tables", [1, 1], "int16"),
+        ]
+        dec_input_info.extend(
+            [
+                (
+                    "cross_key_value_states",
+                    [
+                        model_config.decoder_layers * 2,
+                        rbln_config.batch_size,
+                        model_config.decoder_attention_heads,
+                        model_config.context_length,
+                        model_config.d_model // model_config.decoder_attention_heads,
+                    ],
+                    "float32",
+                )
+            ]
+        )
+        dec_input_info.extend(
+            [
+                (
+                    f"self_key_value_states_{i}",
+                    [
+                        1,
+                        model_config.decoder_attention_heads
+                        * rbln_config.num_parallel_samples
+                        * rbln_config.batch_size,
+                        rbln_config.dec_max_seq_len,
+                        model_config.d_model // model_config.encoder_attention_heads,
+                    ],
+                    "float32",
+                )
+                for i in range(model_config.decoder_layers * 2)
+            ]
+        )
+        enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
+        dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
+
+        rbln_config.set_compile_cfgs([enc_compile_config, dec_compile_config])
+        return rbln_config
+
+    @classmethod
+    def _create_runtimes(
+        cls,
+        compiled_models: List[rebel.RBLNCompiledModel],
+        rbln_config: RBLNTimeSeriesTransformerForPredictionConfig,
+    ) -> List[rebel.Runtime]:
+        if any(model_name not in rbln_config.device_map for model_name in ["encoder", "decoder"]):
+            cls._raise_missing_compiled_file_error(["encoder", "decoder"])
+
+        return [
+            rebel.Runtime(
+                compiled_models[0],
+                tensor_type="pt",
+                device=rbln_config.device_map["encoder"],
+                activate_profiler=rbln_config.activate_profiler,
+            ),
+            rebel.Runtime(
+                compiled_models[1],
+                tensor_type="pt",
+                device=rbln_config.device_map["decoder"],
+                activate_profiler=rbln_config.activate_profiler,
+            ),
+        ]
+
+    def validate_batch_size(self, **kwargs):
+        for k, v in kwargs.items():
+            if v is not None and v.shape[0] != self.batch_size:
+                raise RuntimeError(
+                    f"Batch size mismatch in '{k}': expected {self.batch_size}, but got {v.shape[0]}.\n"
+                    f"Tensor shape: {v.shape}\n\n"
+                    "Note: `batch_size` is set at compile time.\n"
+                    "To change it, pass `export=True` along with `rbln_batch_size` when calling "
+                    "`from_pretrained()` to trigger recompilation."
+                )
+
+    @torch.no_grad()
+    def generate(
+        self,
+        past_values: torch.Tensor,
+        past_time_features: torch.Tensor,
+        future_time_features: torch.Tensor,
+        past_observed_mask: Optional[torch.Tensor] = None,
+        static_categorical_features: Optional[torch.Tensor] = None,
+        static_real_features: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> SampleTSPredictionOutput:
+        self.validate_batch_size(**{k: v for k, v in locals().items() if isinstance(v, torch.Tensor)})
+
+        outputs = self.encoder(
+            static_categorical_features=static_categorical_features,
+            static_real_features=static_real_features,
+            past_time_features=past_time_features,
+            past_values=past_values,
+            past_observed_mask=past_observed_mask,
+            future_time_features=future_time_features,
+        )
+
+        loc = outputs.loc
+        scale = outputs.scale
+        static_feat = outputs.static_features
+
+        num_parallel_samples = self.num_parallel_samples
+        repeated_loc = loc.repeat_interleave(repeats=num_parallel_samples, dim=0)
+        repeated_scale = scale.repeat_interleave(repeats=num_parallel_samples, dim=0)
+
+        repeated_past_values = (
+            past_values.repeat_interleave(repeats=num_parallel_samples, dim=0) - repeated_loc
+        ) / repeated_scale
+
+        expanded_static_feat = static_feat.unsqueeze(1).expand(-1, future_time_features.shape[1], -1)
+        features = torch.cat((expanded_static_feat, future_time_features), dim=-1)
+        repeated_features = features.repeat_interleave(repeats=num_parallel_samples, dim=0)
+
+        # autoregressive decoding: sample one step at a time from the output distribution
+        future_samples = []
+        dec_attn_mask = torch.zeros(1, self.dec_max_seq_len)
+        for k in range(self.config.prediction_length):
+            lagged_sequence = self._origin_model.model.get_lagged_subsequences(
+                sequence=repeated_past_values,
+                subsequences_length=1 + k,
+                shift=1,
+            )
+
+            lags_shape = lagged_sequence.shape
+            reshaped_lagged_sequence = lagged_sequence.reshape(lags_shape[0], lags_shape[1], -1)
+            decoder_input = torch.cat((reshaped_lagged_sequence, repeated_features[:, : k + 1]), dim=-1)
+
+            dec_attn_mask[:, k] = 1
+            dec_inputs_embeds = decoder_input[:, -1:]
+
+            decoder_out = self.decoder(
+                inputs_embeds=dec_inputs_embeds.contiguous(),
+                attention_mask=dec_attn_mask,
+                cache_position=torch.tensor(k, dtype=torch.int32),
+            )
+            params = decoder_out.params
+
+            distr = self._origin_model.output_distribution(params, loc=repeated_loc, scale=repeated_scale)
+            next_sample = distr.sample()
+
+            repeated_past_values = torch.cat(
+                (repeated_past_values, (next_sample - repeated_loc) / repeated_scale), dim=1
+            )
+            future_samples.append(next_sample)
+
+        concat_future_samples = torch.cat(future_samples, dim=1)
+
+        return SampleTSPredictionOutput(
+            sequences=concat_future_samples.reshape(
+                (-1, num_parallel_samples, self.config.prediction_length) + self._origin_model.target_shape,
+            )
+        )
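
For orientation, here is a hedged end-to-end sketch of calling the class added above. The generate() signature and the compile-time batch-size constraint are taken from the diff; the input shapes follow the standard transformers TimeSeriesTransformer convention (past_values must cover context_length plus the largest lag), which this model reuses via create_network_inputs(). Paths and random inputs are illustrative only:

import torch
from optimum.rbln import RBLNTimeSeriesTransformerForPrediction

# Load a previously compiled model from disk (path is illustrative).
model = RBLNTimeSeriesTransformerForPrediction.from_pretrained("./tst-compiled")

B = model.batch_size  # fixed at compile time; mismatched inputs raise RuntimeError
cfg = model.config
past_len = cfg.context_length + max(cfg.lags_sequence)

out = model.generate(
    past_values=torch.randn(B, past_len),
    past_time_features=torch.randn(B, past_len, cfg.num_time_features),
    future_time_features=torch.randn(B, cfg.prediction_length, cfg.num_time_features),
    past_observed_mask=torch.ones(B, past_len),
    # Pass real tensors here if your checkpoint was trained with static features.
    static_categorical_features=None,
    static_real_features=None,
)
# sequences: (batch, num_parallel_samples, prediction_length)
print(out.sequences.shape)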