optimum-rbln 0.7.4a0__py3-none-any.whl → 0.7.4a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
optimum/rbln/__init__.py CHANGED
@@ -73,6 +73,7 @@ _import_structure = {
      "RBLNRobertaForMaskedLM",
      "RBLNViTForImageClassification",
      "RBLNBertForMaskedLM",
+     "RBLNTimeSeriesTransformerForPrediction",
  ],
  "diffusers": [
      "RBLNAutoencoderKL",
@@ -184,6 +185,7 @@ if TYPE_CHECKING:
      RBLNRobertaForSequenceClassification,
      RBLNT5EncoderModel,
      RBLNT5ForConditionalGeneration,
+     RBLNTimeSeriesTransformerForPrediction,
      RBLNViTForImageClassification,
      RBLNWav2Vec2ForCTC,
      RBLNWhisperForConditionalGeneration,
@@ -17,5 +17,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE

- __version__ = version = '0.7.4a0'
+ __version__ = version = '0.7.4a1'
  __version_tuple__ = version_tuple = (0, 7, 4)
@@ -19,3 +19,4 @@ from .attn import (
  )
  from .flash_attn import register_rbln_custom_paged_flash_attention, register_rbln_custom_paged_flash_causal_attention
  from .kv_cache_update import register_rbln_custom_cache_update
+ from .linear import linear
@@ -0,0 +1,25 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional
+
+ import torch
+ from torch import Tensor
+
+
+ @torch.library.custom_op("rbln_custom_ops::linear", mutates_args=())
+ def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
+     output_shape = list(input.shape[:-1])
+     output_shape += [weight.shape[0]]
+     return torch.empty(size=output_shape, dtype=input.dtype, device=input.device, requires_grad=input.requires_grad)
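The new rbln_custom_ops::linear op is a shape-only stand-in for torch.nn.functional.linear: its body allocates an uninitialized tensor with the shape a real linear would produce, which is all that matters while tracing for compilation. A minimal sketch of how the registered op behaves; the import that triggers registration and the tensor sizes are illustrative, not part of the diff:

import torch
import optimum.rbln  # assumed to import the module above and register rbln_custom_ops (illustrative)

x = torch.randn(2, 8, 16)   # (batch, seq, in_features) -- example sizes
w = torch.randn(32, 16)     # (out_features, in_features)

out = torch.ops.rbln_custom_ops.linear(x, w)
print(out.shape)  # torch.Size([2, 8, 32]) -- same shape as torch.nn.functional.linear(x, w)
# The values are uninitialized: the op only propagates shape, dtype, and device,
# so it is only swapped in during compilation (see the compile_model change below).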
@@ -52,6 +52,7 @@ _import_structure = {
      "RBLNPhiForCausalLM",
      "RBLNT5EncoderModel",
      "RBLNT5ForConditionalGeneration",
+     "RBLNTimeSeriesTransformerForPrediction",
      "RBLNLlavaNextForConditionalGeneration",
      "RBLNMidmLMHeadModel",
      "RBLNXLMRobertaModel",
@@ -113,6 +114,7 @@ if TYPE_CHECKING:
      RBLNQwen2ForCausalLM,
      RBLNT5EncoderModel,
      RBLNT5ForConditionalGeneration,
+     RBLNTimeSeriesTransformerForPrediction,
      RBLNWav2Vec2ForCTC,
      RBLNWhisperForConditionalGeneration,
      RBLNXLMRobertaModel,
@@ -50,6 +50,7 @@ _import_structure = {
      "mistral": ["RBLNMistralForCausalLM"],
      "phi": ["RBLNPhiForCausalLM"],
      "qwen2": ["RBLNQwen2ForCausalLM"],
+     "time_series_transformers": ["RBLNTimeSeriesTransformerForPrediction"],
      "t5": ["RBLNT5EncoderModel", "RBLNT5ForConditionalGeneration"],
      "wav2vec2": ["RBLNWav2Vec2ForCTC"],
      "whisper": ["RBLNWhisperForConditionalGeneration"],
@@ -90,6 +91,7 @@ if TYPE_CHECKING:
  from .phi import RBLNPhiForCausalLM
  from .qwen2 import RBLNQwen2ForCausalLM
  from .t5 import RBLNT5EncoderModel, RBLNT5ForConditionalGeneration
+ from .time_series_transformers import RBLNTimeSeriesTransformerForPrediction
  from .wav2vec2 import RBLNWav2Vec2ForCTC
  from .whisper import RBLNWhisperForConditionalGeneration
  from .xlm_roberta import RBLNXLMRobertaModel
@@ -94,12 +94,11 @@ class RBLNBartModel(RBLNModel):
      for model_input_name in rbln_model_input_names
  ]

- enc_compile_config = RBLNCompileConfig(input_info=input_info, compiled_model_name="encoder")
- dec_compile_config = RBLNCompileConfig(input_info=input_info, compiled_model_name="decoder")
+ rbln_compile_config = RBLNCompileConfig(input_info=input_info)

  rbln_config = RBLNConfig(
      rbln_cls=cls.__name__,
-     compile_cfgs=[enc_compile_config, dec_compile_config],
+     compile_cfgs=[rbln_compile_config],
      rbln_kwargs=rbln_kwargs,
  )

@@ -222,8 +222,6 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):

  attention_mask = self.dec_attn_mask

- attention_mask = self.dec_attn_mask
-
  logits = super().forward(
      inputs,
      cache_position,
@@ -547,22 +545,27 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):

  @QuantizationManager.with_quantization_env
  def compile_model(*args, **kwargs):
-     wrapped_model.phase = "prefill"
-     compiled_prefill = RBLNModel.compile(
-         wrapped_model,
-         prefill_compile_config,
-         example_inputs=prefill_example_inputs,
-         compile_context=context,
-     )
+     try:
+         original_linear = torch.nn.functional.linear
+         torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
+         wrapped_model.phase = "prefill"
+         compiled_prefill = RBLNModel.compile(
+             wrapped_model,
+             prefill_compile_config,
+             example_inputs=prefill_example_inputs,
+             compile_context=context,
+         )

-     wrapped_model.phase = "decode"
-     compiled_decoder = RBLNModel.compile(
-         wrapped_model,
-         dec_compile_config,
-         example_inputs=dec_example_inputs,
-         compile_context=context,
-     )
-     return {"prefill": compiled_prefill, "decoder": compiled_decoder}
+         wrapped_model.phase = "decode"
+         compiled_decoder = RBLNModel.compile(
+             wrapped_model,
+             dec_compile_config,
+             example_inputs=dec_example_inputs,
+             compile_context=context,
+         )
+         return {"prefill": compiled_prefill, "decoder": compiled_decoder}
+     finally:
+         torch.nn.functional.linear = original_linear

  return compile_model(quantize_config=quantize_config)

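The decoder-only compile path now swaps torch.nn.functional.linear for the shape-only rbln_custom_ops.linear while both the prefill and decode graphs are compiled, and the finally block guarantees the real implementation is restored even if compilation raises. The same temporary-patch pattern in isolation, as a hedged sketch (the helper name is illustrative; it assumes the custom op above has been registered):

import torch

def run_with_patched_linear(fn, *args, **kwargs):
    # Illustrative helper, not part of the diff: patch, run, and always restore.
    original_linear = torch.nn.functional.linear
    torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
    try:
        return fn(*args, **kwargs)
    finally:
        # Restore in `finally` so an exception inside `fn` cannot leave the
        # shape-only stub installed for later eager execution.
        torch.nn.functional.linear = original_linear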
@@ -38,8 +38,8 @@ class RBLNRuntimeEncoder(RBLNPytorchRuntime):
  mandatory_members = ["main_input_name"]

  def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
-     _ = super().forward(*args, **kwargs)
-     return BaseModelOutput(last_hidden_state=torch.tensor([1.0]))
+     output = super().forward(*args, **kwargs)
+     return BaseModelOutput(last_hidden_state=output)


 class RBLNRuntimeDecoder(RBLNPytorchRuntime):
@@ -0,0 +1,24 @@
+ # Copyright 2024 Rebellions Inc.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Portions of this software are licensed under the Apache License,
+ # Version 2.0. See the NOTICE file distributed with this work for
+ # additional information regarding copyright ownership.
+
+ # All other portions of this software, including proprietary code,
+ # are the intellectual property of Rebellions Inc. and may not be
+ # copied, modified, or distributed without prior written permission
+ # from Rebellions Inc.
+
+ from .modeling_time_series_transformers import RBLNTimeSeriesTransformerForPrediction
@@ -0,0 +1,422 @@
+ # Copyright 2024 Rebellions Inc.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Portions of this software are licensed under the Apache License,
+ # Version 2.0. See the NOTICE file distributed with this work for
+ # additional information regarding copyright ownership.
+
+ # All other portions of this software, including proprietary code,
+ # are the intellectual property of Rebellions Inc. and may not be
+ # copied, modified, or distributed without prior written permission
+ # from Rebellions Inc.
+
+ import inspect
+ import logging
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+
+ import rebel
+ import torch
+ from rebel.compile_context import CompileContext
+ from transformers import (
+     PretrainedConfig,
+     TimeSeriesTransformerForPrediction,
+     TimeSeriesTransformerModel,
+ )
+ from transformers.modeling_outputs import ModelOutput, SampleTSPredictionOutput, Seq2SeqTSModelOutput
+ from transformers.modeling_utils import no_init_weights
+
+ from ....modeling import RBLNModel
+ from ....modeling_config import RBLNCompileConfig, RBLNConfig
+ from ....utils.runtime_utils import RBLNPytorchRuntime
+ from .time_series_transformers_architecture import TimeSeriesTransformersWrapper
+
+
+ logger = logging.getLogger(__name__)
+
+ if TYPE_CHECKING:
+     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
+
+
+ class RBLNRuntimeEncoder(RBLNPytorchRuntime):
+     mandatory_members = ["main_input_name"]
+
+     def __init__(
+         self,
+         runtime: rebel.Runtime,
+         model: TimeSeriesTransformerModel,
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(runtime, **kwargs)
+         self._origin_model = model
+
+     def forward(
+         self,
+         past_values: torch.Tensor,
+         past_time_features: torch.Tensor,
+         static_categorical_features: Optional[torch.Tensor] = None,
+         static_real_features: Optional[torch.Tensor] = None,
+         past_observed_mask: Optional[torch.Tensor] = None,
+         future_values: Optional[torch.Tensor] = None,
+         future_time_features: Optional[torch.Tensor] = None,
+     ):
+         # preprocess
+         transformer_inputs, loc, scale, static_feat = self._origin_model.create_network_inputs(
+             past_values=past_values,
+             past_time_features=past_time_features,
+             past_observed_mask=past_observed_mask,
+             static_categorical_features=static_categorical_features,
+             static_real_features=static_real_features,
+             future_values=future_values,
+             future_time_features=future_time_features,
+         )
+         enc_input = transformer_inputs[:, : self._origin_model.config.context_length, ...]
+
+         # enc_attn_key_value_caches is updated to device dram in-place
+         _ = super().forward(inputs_embeds=enc_input)
+
+         return Seq2SeqTSModelOutput(
+             loc=loc,
+             scale=scale,
+             static_features=static_feat,
+         )
+
+
+ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
+     mandatory_members = ["main_input_name"]
+
+     def forward(
+         self,
+         inputs_embeds: torch.Tensor = None,
+         attention_mask: torch.Tensor = None,
+         cache_position: torch.Tensor = None,
+     ):
+         block_tables = torch.zeros(1, 1, dtype=torch.int16)
+         outputs = super().forward(inputs_embeds, attention_mask, cache_position, block_tables)
+
+         return RBLNSeq2SeqTSDecoderOutput(
+             params=outputs[:-1],
+             last_hidden_states=outputs[-1],
+         )
+
+
+ @dataclass
+ class RBLNSeq2SeqTSDecoderOutput(ModelOutput):
+     last_hidden_states: torch.FloatTensor = None
+     params: Tuple[torch.FloatTensor] = None
+
+
+ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
+     auto_model_class = None
+     main_input_name = "inputs_embeds"
+
+     def __post_init__(self, **kwargs):
+         super().__post_init__(**kwargs)
+         self.batch_size = self.rbln_config.model_cfg["batch_size"]
+         self.dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
+         self.num_parallel_samples = self.rbln_config.model_cfg["num_parallel_samples"]
+
+         with no_init_weights():
+             self._origin_model = TimeSeriesTransformerForPrediction._from_config(self.config)
+         artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
+         self._origin_model.model.embedder.load_state_dict(artifacts["embedder"])
+         self.encoder = RBLNRuntimeEncoder(
+             runtime=self.model[0],
+             main_input_name="inputs_embeds",
+             model=self._origin_model.model,
+         )
+         self.decoder = RBLNRuntimeDecoder(
+             runtime=self.model[1],
+             main_input_name="inputs_embeds",
+         )
+
+     def __getattr__(self, __name: str) -> Any:
+         """This is the key method to implement RBLN-TimeSeriesTransformersForPrediction.
+         Returns:
+             Any: TimeSeriesTransformersForPrediction's corresponding method
+         """
+
+         def redirect(func):
+             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+         val = getattr(TimeSeriesTransformerForPrediction, __name)
+         if val is not None and isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+             return redirect(val)
+
+     @classmethod
+     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
+         return TimeSeriesTransformersWrapper(model, rbln_config.model_cfg["num_parallel_samples"])
+
+     @classmethod
+     @torch.inference_mode()
+     def get_compiled_model(cls, model, rbln_config: RBLNConfig):
+         wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
+
+         enc_compile_config = rbln_config.compile_cfgs[0]
+         dec_compile_config = rbln_config.compile_cfgs[1]
+
+         context = CompileContext(use_weight_sharing=False)
+
+         enc_example_inputs = enc_compile_config.get_dummy_inputs(fill=0)
+
+         # Mark encoder's static tensors (cross kv states)
+         static_tensors = {}
+         for (name, _, _), tensor in zip(enc_compile_config.input_info, enc_example_inputs):
+             if "key_value_states" in name:
+                 static_tensors[name] = tensor
+                 context.mark_static_address(tensor)
+
+         dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
+
+         # Mark decoder's static tensors (self kv states)
+         for (name, _, _), tensor in zip(dec_compile_config.input_info, dec_example_inputs):
+             if "key_value_states" in name:
+                 context.mark_static_address(tensor)
+
+         compiled_decoder = super().compile(
+             wrapped_model.decoder,
+             dec_compile_config,
+             example_inputs=dec_example_inputs,
+             compile_context=context,
+         )
+         compiled_encoder = super().compile(
+             wrapped_model.encoder,
+             enc_compile_config,
+             example_inputs=enc_example_inputs,
+             compile_context=context,
+         )
+
+         return {"encoder": compiled_encoder, "decoder": compiled_decoder}
+
+     @classmethod
+     def save_torch_artifacts(
+         cls,
+         model: "PreTrainedModel",
+         save_dir_path: Path,
+         subfolder: str,
+         rbln_config: RBLNConfig,
+     ):
+         """
+         If you are unavoidably running on a CPU rather than an RBLN device,
+         store the torch tensor, weight, etc. in this function.
+         """
+         save_dict = {}
+         save_dict["embedder"] = model.model.embedder.state_dict()
+         torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
+
+     @classmethod
+     def _get_rbln_config(
+         cls,
+         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+         model_config: "PretrainedConfig",
+         rbln_kwargs: Dict[str, Any] = {},
+     ) -> RBLNConfig:
+         rbln_batch_size = rbln_kwargs.get("batch_size", 1)
+         rbln_dec_max_seq_len = rbln_kwargs.get("dec_max_seq_len", None)
+         rbln_num_parallel_samples = rbln_kwargs.get("num_parallel_samples", None)
+
+         if not isinstance(rbln_batch_size, int):
+             raise TypeError(f"Expected rbln_batch_size to be an int, but got {type(rbln_batch_size)}")
+
+         rbln_num_parallel_samples = (
+             model_config.num_parallel_samples if rbln_num_parallel_samples is None else rbln_num_parallel_samples
+         )
+         if rbln_dec_max_seq_len is None:
+             predict_length = model_config.prediction_length
+             rbln_dec_max_seq_len = (
+                 predict_length if predict_length % 64 == 0 else predict_length + (64 - predict_length % 64)
+             )
+
+         # model input info
+         enc_input_info = [
+             ("inputs_embeds", [rbln_batch_size, model_config.context_length, model_config.feature_size], "float32"),
+         ]
+         enc_input_info.extend(
+             [
+                 (
+                     "cross_key_value_states",
+                     [
+                         model_config.decoder_layers * 2,
+                         rbln_batch_size,
+                         model_config.decoder_attention_heads,
+                         model_config.context_length,
+                         model_config.d_model // model_config.decoder_attention_heads,
+                     ],
+                     "float32",
+                 )
+             ]
+         )
+
+         dec_input_info = [
+             ("inputs_embeds", [rbln_batch_size * rbln_num_parallel_samples, 1, model_config.feature_size], "float32"),
+             ("attention_mask", [1, rbln_dec_max_seq_len], "float32"),
+             ("cache_position", [], "int32"),
+             ("block_tables", [1, 1], "int16"),
+         ]
+         dec_input_info.extend(
+             [
+                 (
+                     "cross_key_value_states",
+                     [
+                         model_config.decoder_layers * 2,  # 4
+                         rbln_batch_size,  # 64
+                         model_config.decoder_attention_heads,  # 2
+                         model_config.context_length,  # 24
+                         model_config.d_model // model_config.decoder_attention_heads,  # 13
+                     ],
+                     "float32",
+                 )
+             ]
+         )
+         dec_input_info.extend(
+             [
+                 (
+                     f"self_key_value_states_{i}",
+                     [
+                         1,
+                         model_config.decoder_attention_heads * rbln_num_parallel_samples * rbln_batch_size,
+                         rbln_dec_max_seq_len,
+                         model_config.d_model // model_config.encoder_attention_heads,
+                     ],
+                     "float32",
+                 )
+                 for i in range(model_config.decoder_layers * 2)
+             ]
+         )
+         enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
+         dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
+
+         rbln_config = RBLNConfig(
+             rbln_cls=cls.__name__,
+             compile_cfgs=[enc_compile_config, dec_compile_config],
+             rbln_kwargs=rbln_kwargs,
+         )
+
+         rbln_config.model_cfg.update(
+             {
+                 "batch_size": rbln_batch_size,
+                 "num_parallel_samples": rbln_num_parallel_samples,
+                 "dec_max_seq_len": rbln_dec_max_seq_len,
+             }
+         )
+
+         return rbln_config
+
+     @classmethod
+     def _create_runtimes(
+         cls,
+         compiled_models: List[rebel.RBLNCompiledModel],
+         rbln_device_map: Dict[str, int],
+         activate_profiler: Optional[bool] = None,
+     ) -> List[rebel.Runtime]:
+         if any(model_name not in rbln_device_map for model_name in ["encoder", "decoder"]):
+             cls._raise_missing_compiled_file_error(["encoder", "decoder"])
+
+         return [
+             compiled_models[0].create_runtime(
+                 tensor_type="pt", device=rbln_device_map["encoder"], activate_profiler=activate_profiler
+             ),
+             compiled_models[1].create_runtime(
+                 tensor_type="pt", device=rbln_device_map["decoder"], activate_profiler=activate_profiler
+             ),
+         ]
+
+     def validate_batch_size(self, **kwargs):
+         for k, v in kwargs.items():
+             if v is not None and v.shape[0] != self.batch_size:
+                 raise RuntimeError(
+                     f"Batch size mismatch in '{k}': Expected {self.batch_size}, but got {v.shape[0]}. \n"
+                     f"Tensor shape: {v.shape} \n\n"
+                     f"Note: `batch_size` is set at compile time. \n"
+                     f"To change it, pass `export=True` along with `rbln_batch_size` when calling `from_pretrained()` to trigger recompilation."
+                 )
+
+     @torch.no_grad()
+     def generate(
+         self,
+         past_values: torch.Tensor,
+         past_time_features: torch.Tensor,
+         future_time_features: torch.Tensor,
+         past_observed_mask: Optional[torch.Tensor] = None,
+         static_categorical_features: Optional[torch.Tensor] = None,
+         static_real_features: Optional[torch.Tensor] = None,
+         **kwargs,
+     ) -> SampleTSPredictionOutput:
+         self.validate_batch_size(**{k: v for k, v in locals().items() if isinstance(v, torch.Tensor)})
+
+         outputs = self.encoder(
+             static_categorical_features=static_categorical_features,
+             static_real_features=static_real_features,
+             past_time_features=past_time_features,
+             past_values=past_values,
+             past_observed_mask=past_observed_mask,
+             future_time_features=future_time_features,
+         )
+
+         loc = outputs.loc
+         scale = outputs.scale
+         static_feat = outputs.static_features
+
+         num_parallel_samples = self.num_parallel_samples
+         repeated_loc = loc.repeat_interleave(repeats=num_parallel_samples, dim=0)
+         repeated_scale = scale.repeat_interleave(repeats=num_parallel_samples, dim=0)
+
+         repeated_past_values = (
+             past_values.repeat_interleave(repeats=num_parallel_samples, dim=0) - repeated_loc
+         ) / repeated_scale
+
+         expanded_static_feat = static_feat.unsqueeze(1).expand(-1, future_time_features.shape[1], -1)
+         features = torch.cat((expanded_static_feat, future_time_features), dim=-1)
+         repeated_features = features.repeat_interleave(repeats=num_parallel_samples, dim=0)
+
+         # greedy decoding
+         future_samples = []
+         dec_attn_mask = torch.zeros(1, self.dec_max_seq_len)
+         for k in range(self.config.prediction_length):
+             lagged_sequence = self._origin_model.model.get_lagged_subsequences(
+                 sequence=repeated_past_values,
+                 subsequences_length=1 + k,
+                 shift=1,
+             )

+             lags_shape = lagged_sequence.shape
+             reshaped_lagged_sequence = lagged_sequence.reshape(lags_shape[0], lags_shape[1], -1)
+             decoder_input = torch.cat((reshaped_lagged_sequence, repeated_features[:, : k + 1]), dim=-1)
+
+             dec_attn_mask[:, k] = 1
+             dec_inputs_embeds = decoder_input[:, -1:]
+
+             decoder_out = self.decoder(
+                 inputs_embeds=dec_inputs_embeds.contiguous(),
+                 attention_mask=dec_attn_mask,
+                 cache_position=torch.tensor(k, dtype=torch.int32),
+             )
+             params = decoder_out.params
+
+             distr = self._origin_model.output_distribution(params, loc=repeated_loc, scale=repeated_scale)
+             next_sample = distr.sample()
+
+             repeated_past_values = torch.cat(
+                 (repeated_past_values, (next_sample - repeated_loc) / repeated_scale), dim=1
+             )
+             future_samples.append(next_sample)
+
+         concat_future_samples = torch.cat(future_samples, dim=1)
+
+         return SampleTSPredictionOutput(
+             sequences=concat_future_samples.reshape(
+                 (-1, num_parallel_samples, self.config.prediction_length) + self._origin_model.target_shape,
+             )
+         )
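The new class follows the usual optimum-rbln flow: compile once via from_pretrained(..., export=True), with compile-time options passed as rbln_* keywords (as the batch-size error message above indicates), then call generate on the RBLN runtimes. A hedged usage sketch; the checkpoint name and batch size are illustrative, and the input tensors stand in for a batch prepared exactly as for the original transformers TimeSeriesTransformerForPrediction:

from optimum.rbln import RBLNTimeSeriesTransformerForPrediction

# Compile a Hugging Face time-series checkpoint for RBLN NPUs
# (checkpoint name and rbln_batch_size value are illustrative).
model = RBLNTimeSeriesTransformerForPrediction.from_pretrained(
    "huggingface/time-series-transformer-tourism-monthly",
    export=True,         # compile instead of loading an already-compiled model
    rbln_batch_size=64,  # becomes rbln_kwargs["batch_size"]; defaults to 1
)

# past_values, past_time_features, past_observed_mask and future_time_features
# are placeholders for a prepared batch whose leading dimension matches the
# compile-time batch size (see validate_batch_size above).
prediction = model.generate(
    past_values=past_values,
    past_time_features=past_time_features,
    past_observed_mask=past_observed_mask,
    future_time_features=future_time_features,
)
# prediction.sequences has shape (batch, num_parallel_samples, prediction_length).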